THRIFT-5007 Implement MAX_MESSAGE_SIZE and remaining read bytes control
Client: Delphi
Patch: Jens Geyer

This closes #1932
diff --git a/lib/delphi/src/Thrift.Transport.pas b/lib/delphi/src/Thrift.Transport.pas
index bede57c..a3476bf 100644
--- a/lib/delphi/src/Thrift.Transport.pas
+++ b/lib/delphi/src/Thrift.Transport.pas
@@ -44,14 +44,39 @@
   Thrift.WinHTTP,
   Thrift.Stream;
 
+const
+  DEFAULT_MAX_MESSAGE_SIZE = 100 * 1024 * 1024; // 100 MB; TTransportControlImpl.Create falls back to it for sizes <= 0
+  DEFAULT_THRIFT_TIMEOUT = 5 * 1000; // ms; default timeout of the TServerSocketImpl / TSocketImpl constructors
+
 type
+  ITransportControl = interface  // byte budget that read bytes are counted against until it is reset
+    ['{CDA35E2C-F1D2-4BE3-9927-7F1540923265}']
+    function  MaxAllowedMessageSize : Integer;  // size of the full budget
+    procedure ConsumeReadBytes( const count : Integer);  // deducts count bytes; TTransportControlImpl raises TTransportExceptionEndOfFile when the budget is exceeded
+    procedure ResetConsumedMessageSize;  // restores the full budget
+  end;
+
+  TTransportControlImpl = class( TInterfacedObject, ITransportControl)  // default ITransportControl implementation
+  strict private
+    FMaxAllowedMsgSize : Integer;  // always > 0 after Create
+    FRemainingMsgSize : Integer;   // bytes that may still be consumed before ConsumeReadBytes raises
+  strict protected
+    // ITransportControl
+    function  MaxAllowedMessageSize : Integer;
+    procedure ConsumeReadBytes( const count : Integer);
+    procedure ResetConsumedMessageSize;
+  public
+    constructor Create( const aMaxMessageSize : Integer = DEFAULT_MAX_MESSAGE_SIZE);  reintroduce;  // sizes <= 0 select DEFAULT_MAX_MESSAGE_SIZE
+  end;
+
   ITransport = interface
-    ['{DB84961E-8BB3-4532-99E1-A8C7AC2300F7}']
+    ['{938F6EB5-1848-43D5-8AC4-07633C55B229}']  // new GUID; the interface also gained TransportControl and CheckReadBytesAvailable
     function GetIsOpen: Boolean;
     property IsOpen: Boolean read GetIsOpen;
     function Peek: Boolean;
     procedure Open;
     procedure Close;
+
     function Read(var buf: TBytes; off: Integer; len: Integer): Integer; overload;
     function Read(const pBuf : Pointer; const buflen : Integer; off: Integer; len: Integer): Integer; overload;
     function ReadAll(var buf: TBytes; off: Integer; len: Integer): Integer; overload;
@@ -61,15 +86,22 @@
     procedure Write( const pBuf : Pointer; off, len : Integer); overload;
     procedure Write( const pBuf : Pointer; len : Integer); overload;
     procedure Flush;
+    // message size control
+    function  TransportControl : ITransportControl;  // the control that read bytes are reported to
+    procedure CheckReadBytesAvailable( const value : Integer);  // implementations raise TTransportExceptionEndOfFile / TTransportExceptionNotOpen when they can tell that value bytes are not readable
   end;
 
   TTransportImpl = class( TInterfacedObject, ITransport)
+  strict private
+    FTransportControl : ITransportControl;  // never nil after Create
+
   strict protected
     function GetIsOpen: Boolean; virtual; abstract;
     property IsOpen: Boolean read GetIsOpen;
     function Peek: Boolean; virtual;
     procedure Open(); virtual; abstract;
     procedure Close(); virtual; abstract;
+
     function Read(var buf: TBytes; off: Integer; len: Integer): Integer; overload; inline;
     function Read(const pBuf : Pointer; const buflen : Integer; off: Integer; len: Integer): Integer; overload; virtual; abstract;
     function ReadAll(var buf: TBytes; off: Integer; len: Integer): Integer;  overload; inline;
@@ -79,6 +111,13 @@
     procedure Write( const pBuf : Pointer; len : Integer); overload; inline;
     procedure Write( const pBuf : Pointer; off, len : Integer); overload; virtual; abstract;
     procedure Flush; virtual;
+    // message size control
+    function  TransportControl : ITransportControl;  inline;
+    procedure ConsumeReadBytes( const count : Integer);  inline;  // forwards count to TransportControl.ConsumeReadBytes
+    procedure CheckReadBytesAvailable( const value : Integer); virtual; abstract;
+
+  public
+    constructor Create( const aTransportCtl : ITransportControl);  reintroduce;  // nil: a default TTransportControlImpl is created
   end;
 
   TTransportException = class abstract( TException)
@@ -98,9 +137,9 @@
     constructor HiddenCreate(const Msg: string);
     class function GetType: TExceptionType;  virtual; abstract;
   public
-    class function Create( AType: TExceptionType): TTransportException; overload; deprecated 'Use specialized TTransportException types (or regenerate from IDL)';
+    class function Create( aType: TExceptionType): TTransportException; overload; deprecated 'Use specialized TTransportException types (or regenerate from IDL)';  // case-only rename of AType
     class function Create( const msg: string): TTransportException; reintroduce; overload; deprecated 'Use specialized TTransportException types (or regenerate from IDL)';
-    class function Create( AType: TExceptionType; const msg: string): TTransportException; overload; deprecated 'Use specialized TTransportException types (or regenerate from IDL)';
+    class function Create( aType: TExceptionType; const msg: string): TTransportException; overload; deprecated 'Use specialized TTransportException types (or regenerate from IDL)';  // case-only rename of AType
     property Type_: TExceptionType read GetType;
   end;
 
@@ -196,7 +235,7 @@
 
   ITransportFactory = interface
     ['{DD809446-000F-49E1-9BFF-E0D0DC76A9D7}']
-    function GetTransport( const ATrans: ITransport): ITransport;
+    function GetTransport( const aTransport: ITransport): ITransport;  // parameter renamed from ATrans
   end;
 
   TTransportFactoryImpl = class( TInterfacedObject, ITransportFactory)
@@ -222,6 +261,7 @@
   strict protected
     procedure Write( const pBuf : Pointer; offset, count: Integer); override;
     function Read( const pBuf : Pointer; const buflen : Integer; offset: Integer; count: Integer): Integer; override;
+    procedure CheckReadBytesAvailable( const value : Integer);  override;  // no-op: a socket stream cannot tell how many bytes will still arrive
     procedure Open; override;
     procedure Close; override;
     procedure Flush; override;
@@ -230,9 +270,9 @@
     function ToArray: TBytes; override;
   public
 {$IFDEF OLD_SOCKETS}
-    constructor Create( const ATcpClient: TCustomIpClient; const aTimeout : Integer = 0);
+    constructor Create( const aTcpClient: TCustomIpClient; const aTimeout : Integer = 0);  // case-only rename of ATcpClient
 {$ELSE}
-    constructor Create( const ATcpClient: TSocket; const aTimeout : Longword = 0);
+    constructor Create( const aTcpClient: TSocket; const aTimeout : Longword = 0);  // case-only rename of ATcpClient
 {$ENDIF}
   end;
 
@@ -254,14 +294,15 @@
     function GetInputStream: IThriftStream;
     function GetOutputStream: IThriftStream;
 
-  protected
+    procedure CheckReadBytesAvailable( const value : Integer); override;  // delegates to InputStream; NOTE(review): now lives in the section above, whose visibility is not shown here
+  strict protected
     procedure Open; override;
     procedure Close; override;
     procedure Flush; override;
     function  Read( const pBuf : Pointer; const buflen : Integer; off: Integer; len: Integer): Integer; override;
     procedure Write( const pBuf : Pointer; off, len : Integer); override;
   public
-    constructor Create( const aInputStream, aOutputStream : IThriftStream);
+    constructor Create( const aInputStream, aOutputStream : IThriftStream; const aTransportCtl : ITransportControl = nil);  // nil: TTransportImpl.Create substitutes a default control
     destructor Destroy; override;
 
     property InputStream : IThriftStream read GetInputStream;
@@ -277,6 +318,7 @@
   strict protected
     procedure Write( const pBuf : Pointer; offset: Integer; count: Integer); override;
     function Read( const pBuf : Pointer; const buflen : Integer; offset: Integer; count: Integer): Integer; override;
+    procedure CheckReadBytesAvailable( const value : Integer); override;  // subtracts the buffered bytes, then asks the wrapped stream
     procedure Open;  override;
     procedure Close; override;
     procedure Flush; override;
@@ -298,15 +340,19 @@
 {$ENDIF}
     FUseBufferedSocket : Boolean;
     FOwnsServer : Boolean;
+    FTransportControl : ITransportControl;  // handed to every TSocketImpl created in Accept
+
   strict protected
     function Accept( const fnAccepting: TProc) : ITransport; override;
+    property TransportControl : ITransportControl read FTransportControl;
+
   public
 {$IFDEF OLD_SOCKETS}
-    constructor Create( const aServer: TTcpServer; const aClientTimeout: Integer = 0); overload;
-    constructor Create( const aPort: Integer; const aClientTimeout: Integer = 0; const aUseBufferedSockets: Boolean = FALSE); overload;
+    constructor Create( const aServer: TTcpServer; const aClientTimeout: Integer = DEFAULT_THRIFT_TIMEOUT; const aTransportCtl : ITransportControl = nil); overload;  // NOTE: default client timeout is now DEFAULT_THRIFT_TIMEOUT (was 0)
+    constructor Create( const aPort: Integer; const aClientTimeout: Integer = DEFAULT_THRIFT_TIMEOUT; aUseBufferedSockets: Boolean = FALSE; const aTransportCtl : ITransportControl = nil); overload;  // NOTE(review): 'const' was dropped from aUseBufferedSockets
 {$ELSE}
-    constructor Create( const aServer: TServerSocket; const aClientTimeout: Longword = 0); overload;
-    constructor Create( const aPort: Integer; const aClientTimeout: Longword = 0; const aUseBufferedSockets: Boolean = FALSE); overload;
+    constructor Create( const aServer: TServerSocket; const aClientTimeout: Longword = DEFAULT_THRIFT_TIMEOUT; const aTransportCtl : ITransportControl = nil); overload;  // NOTE: default client timeout is now DEFAULT_THRIFT_TIMEOUT (was 0)
+    constructor Create( const aPort: Integer; const aClientTimeout: Longword = DEFAULT_THRIFT_TIMEOUT; aUseBufferedSockets: Boolean = FALSE; const aTransportCtl : ITransportControl = nil); overload;  // NOTE(review): 'const' was dropped from aUseBufferedSockets
 {$ENDIF}
     destructor Destroy; override;
     procedure Listen; override;
@@ -325,6 +371,7 @@
   strict protected
     function GetIsOpen: Boolean; override;
     procedure Flush; override;
+    procedure CheckReadBytesAvailable( const value : Integer);  override;  // subtracts the buffered bytes, then asks the input buffer stream
   public
     type
       TFactory = class( TTransportFactoryImpl )
@@ -363,11 +410,11 @@
   public
     procedure Open; override;
 {$IFDEF OLD_SOCKETS}
-    constructor Create( const aClient : TCustomIpClient; const aOwnsClient : Boolean; const aTimeout: Integer = 0); overload;
-    constructor Create( const aHost: string; const aPort: Integer; const aTimeout: Integer = 0); overload;
+    constructor Create( const aClient : TCustomIpClient; const aOwnsClient : Boolean; const aTimeout: Integer = DEFAULT_THRIFT_TIMEOUT; const aTransportCtl : ITransportControl = nil); overload;  // nil control: a default is substituted
+    constructor Create( const aHost: string; const aPort: Integer; const aTimeout: Integer = DEFAULT_THRIFT_TIMEOUT; const aTransportCtl : ITransportControl = nil); overload;  // nil control: a default is substituted
 {$ELSE}
-    constructor Create( const aClient: TSocket; const aOwnsClient: Boolean); overload;
-    constructor Create( const aHost: string; const aPort: Integer; const aTimeout: Longword = 0); overload;
+    constructor Create(const aClient: TSocket; const aOwnsClient: Boolean; const aTransportCtl : ITransportControl = nil); overload;  // nil control: a default is substituted
+    constructor Create( const aHost: string; const aPort: Integer; const aTimeout: Longword = DEFAULT_THRIFT_TIMEOUT; const aTransportCtl : ITransportControl = nil); overload;  // nil control: a default is substituted
 {$ENDIF}
     destructor Destroy; override;
     procedure Close; override;
@@ -402,6 +449,7 @@
     function  Read( const pBuf : Pointer; const buflen : Integer; off: Integer; len: Integer): Integer; override;
     procedure Write( const pBuf : Pointer; off, len : Integer); override;
     procedure Flush; override;
+    procedure CheckReadBytesAvailable( const value : Integer); override;  // checks against the unread part of FReadBuffer only
   public
     type
       TFactory = class( TTransportFactoryImpl )
@@ -409,20 +457,67 @@
         function GetTransport( const aTransport: ITransport): ITransport; override;
       end;
 
+    constructor Create( const aTransportCtl : ITransportControl); overload;  // NOTE(review): this overload leaves FTransport unassigned
     constructor Create( const aTransport: ITransport); overload;
     destructor Destroy; override;
   end;
 
 
 const
-  DEFAULT_THRIFT_TIMEOUT = 5 * 1000; // ms
   DEFAULT_THRIFT_SECUREPROTOCOLS = [ TSecureProtocol.TLS_1_1, TSecureProtocol.TLS_1_2];
 
 implementation
 
 
+{ TTransportControlImpl }
+// Default ITransportControl: a fixed byte budget that ConsumeReadBytes counts down.
+constructor TTransportControlImpl.Create( const aMaxMessageSize : Integer);
+begin
+  inherited Create;
+  // non-positive sizes fall back to DEFAULT_MAX_MESSAGE_SIZE
+  if aMaxMessageSize > 0
+  then FMaxAllowedMsgSize := aMaxMessageSize
+  else FMaxAllowedMsgSize := DEFAULT_MAX_MESSAGE_SIZE;
+  // start with the full budget
+  ResetConsumedMessageSize;
+end;
+
+function TTransportControlImpl.MaxAllowedMessageSize : Integer;
+begin
+  result := FMaxAllowedMsgSize;
+end;
+
+procedure TTransportControlImpl.ResetConsumedMessageSize;
+begin
+  FRemainingMsgSize := MaxAllowedMessageSize;
+end;
+
+// Deducts count read bytes from the budget; raises TTransportExceptionEndOfFile once it is exceeded.
+procedure TTransportControlImpl.ConsumeReadBytes( const count : Integer);
+begin
+  if count <= 0 then Exit;  // nothing consumed; a negative count must never enlarge the budget
+  if count <= FRemainingMsgSize
+  then Dec( FRemainingMsgSize, count)
+  else begin
+    FRemainingMsgSize := 0;  // budget exhausted: every further positive count raises until ResetConsumedMessageSize
+    raise TTransportExceptionEndOfFile.Create('Maximum message size reached');
+  end;
+end;
+
+
 { TTransportImpl }
 
+constructor TTransportImpl.Create( const aTransportCtl : ITransportControl);
+begin
+  inherited Create;
+  // nil: use a private default TTransportControlImpl (DEFAULT_MAX_MESSAGE_SIZE)
+  if aTransportCtl <> nil
+  then FTransportControl := aTransportCtl
+  else FTransportControl := TTransportControlImpl.Create;
+  ASSERT( FTransportControl <> nil);
+end;
+
+
 procedure TTransportImpl.Flush;
 begin
   // nothing to do
@@ -477,6 +572,19 @@
 end;
 
 
+function TTransportImpl.TransportControl : ITransportControl;  // returns the control passed to (or created by) Create
+begin
+  result := FTransportControl;
+end;
+
+
+// Reports count bytes just read to the transport control.
+procedure TTransportImpl.ConsumeReadBytes( const count : Integer);
+begin
+  if FTransportControl <> nil  // defensive only; Create never leaves it nil
+  then FTransportControl.ConsumeReadBytes( count);
+end;
+
+
 { TTransportException }
 
 constructor TTransportException.HiddenCreate(const Msg: string);
@@ -571,13 +679,15 @@
 { TServerSocket }
 
 {$IFDEF OLD_SOCKETS}
-constructor TServerSocketImpl.Create( const aServer: TTcpServer; const aClientTimeout : Integer);
+constructor TServerSocketImpl.Create( const aServer: TTcpServer; const aClientTimeout : Integer; const aTransportCtl : ITransportControl);
 {$ELSE}
-constructor TServerSocketImpl.Create( const aServer: TServerSocket; const aClientTimeout: Longword);
+constructor TServerSocketImpl.Create( const aServer: TServerSocket; const aClientTimeout: Longword; const aTransportCtl : ITransportControl);
 {$ENDIF}
 begin
   inherited Create;
   FServer := aServer;
+  if aTransportCtl <> nil then FTransportControl := aTransportCtl else FTransportControl := TTransportControlImpl.Create;  // aTransportCtl defaults to nil in the declaration
+  ASSERT( FTransportControl <> nil);
 
 {$IFDEF OLD_SOCKETS}
   FClientTimeout := aClientTimeout;
@@ -589,13 +699,18 @@
 
 
 {$IFDEF OLD_SOCKETS}
-constructor TServerSocketImpl.Create( const aPort: Integer; const aClientTimeout: Integer; const aUseBufferedSockets: Boolean);
+constructor TServerSocketImpl.Create( const aPort: Integer; const aClientTimeout: Integer; aUseBufferedSockets: Boolean; const aTransportCtl : ITransportControl);
 {$ELSE}
-constructor TServerSocketImpl.Create( const aPort: Integer; const aClientTimeout: Longword; const aUseBufferedSockets: Boolean);
+constructor TServerSocketImpl.Create( const aPort: Integer; const aClientTimeout: Longword; aUseBufferedSockets: Boolean; const aTransportCtl : ITransportControl);
 {$ENDIF}
 begin
   inherited Create;
 
+  if aTransportCtl <> nil  // nil: create a default control
+  then FTransportControl := aTransportCtl
+  else FTransportControl := TTransportControlImpl.Create;  // budget: DEFAULT_MAX_MESSAGE_SIZE
+  ASSERT( FTransportControl <> nil);
+
 {$IFDEF OLD_SOCKETS}
   FPort := aPort;
   FClientTimeout := aClientTimeout;
@@ -657,7 +772,7 @@
       Exit;
     end;
 
-    trans := TSocketImpl.Create( client, TRUE, FClientTimeout);
+    trans := TSocketImpl.Create( client, TRUE, FClientTimeout, TransportControl);  // NOTE(review): all accepted clients share this one ITransportControl and its remaining-bytes counter; confirm intended
     client := nil;  // trans owns it now
 
     if FUseBufferedSocket
@@ -676,7 +791,7 @@
 
   client := FServer.Accept;
   try
-    trans := TSocketImpl.Create(client, MaxMessageSize, True);
+    trans := TSocketImpl.Create(client, True, TransportControl);  // arguments now match (aClient, aOwnsClient, aTransportCtl); NOTE(review): control is shared by all accepted clients
     client := nil;
 
     if FUseBufferedSocket then
@@ -725,9 +840,9 @@
 { TSocket }
 
 {$IFDEF OLD_SOCKETS}
-constructor TSocketImpl.Create( const aClient : TCustomIpClient; const aOwnsClient : Boolean; const aTimeout: Integer);
+constructor TSocketImpl.Create( const aClient : TCustomIpClient; const aOwnsClient : Boolean; const aTimeout: Integer; const aTransportCtl : ITransportControl);
 {$ELSE}
-constructor TSocketImpl.Create(const aClient: TSocket; const aOwnsClient: Boolean);
+constructor TSocketImpl.Create(const aClient: TSocket; const aOwnsClient: Boolean; const aTransportCtl : ITransportControl);
 {$ENDIF}
 var stream : IThriftStream;
 begin
@@ -741,16 +856,16 @@
 {$ENDIF}
 
   stream := TTcpSocketStreamImpl.Create( FClient, FTimeout);
-  inherited Create( stream, stream);
+  inherited Create( stream, stream, aTransportCtl);  // nil control: TTransportImpl.Create substitutes a default
 end;
 
 {$IFDEF OLD_SOCKETS}
-constructor TSocketImpl.Create(const aHost: string; const aPort, aTimeout: Integer);
+constructor TSocketImpl.Create(const aHost: string; const aPort, aTimeout: Integer; const aTransportCtl : ITransportControl);
 {$ELSE}
-constructor TSocketImpl.Create(const aHost: string; const aPort : Integer; const aTimeout: Longword);
+constructor TSocketImpl.Create(const aHost: string; const aPort : Integer; const aTimeout: Longword; const aTransportCtl : ITransportControl);
 {$ENDIF}
 begin
-  inherited Create(nil,nil);
+  inherited Create(nil,nil, aTransportCtl);  // no streams yet; nil control: a default is substituted
   FHost := aHost;
   FPort := aPort;
   FTimeout := aTimeout;
@@ -928,6 +1043,22 @@
 end;
 
 
+procedure TBufferedStreamImpl.CheckReadBytesAvailable( const value : Integer);  // checks whether buffer plus wrapped stream can supply value bytes
+var nRequired : Integer;
+begin
+  nRequired := value;
+  // bytes still buffered (Size - Position) need not come from the wrapped stream
+  if FReadBuffer <> nil then begin
+    Dec( nRequired, (FReadBuffer.Size - FReadBuffer.Position));
+    if nRequired <= 0 then Exit;
+  end;
+  // ask the wrapped stream for the rest
+  if FStream <> nil
+  then FStream.CheckReadBytesAvailable( nRequired)
+  else raise TTransportExceptionEndOfFile.Create('Not enough input data');
+end;
+
+
 function TBufferedStreamImpl.ToArray: TBytes;
 var len : Integer;
 begin
@@ -963,9 +1094,9 @@
 
 { TStreamTransportImpl }
 
-constructor TStreamTransportImpl.Create( const aInputStream, aOutputStream : IThriftStream);
+constructor TStreamTransportImpl.Create( const aInputStream, aOutputStream : IThriftStream; const aTransportCtl : ITransportControl);
 begin
-  inherited Create;
+  inherited Create( aTransportCtl);  // nil: a default TTransportControlImpl is created
   FInputStream := aInputStream;
   FOutputStream := aOutputStream;
 end;
@@ -1018,6 +1149,7 @@
   then raise TTransportExceptionNotOpen.Create('Cannot read from null inputstream' );
 
   Result := FInputStream.Read( pBuf,buflen, off, len );
+  ConsumeReadBytes( result);  // count against the read budget; the default control raises TTransportExceptionEndOfFile when it is exceeded
 end;
 
 procedure TStreamTransportImpl.Write( const pBuf : Pointer; off, len : Integer);
@@ -1028,13 +1160,20 @@
   FOutputStream.Write( pBuf, off, len );
 end;
 
+procedure TStreamTransportImpl.CheckReadBytesAvailable( const value : Integer);  // delegates to the input stream
+begin
+  if FInputStream <> nil
+  then FInputStream.CheckReadBytesAvailable( value)
+  else raise TTransportExceptionNotOpen.Create('Cannot read from null inputstream' );
+end;
+
 
 { TBufferedTransportImpl }
 
 constructor TBufferedTransportImpl.Create( const aTransport : IStreamTransport; const aBufSize: Integer);
 begin
   ASSERT( aTransport <> nil);
-  inherited Create;
+  inherited Create( aTransport.TransportControl);  // share the wrapped transport's control
   FTransport := aTransport;
   FBufSize := aBufSize;
   InitBuffers;
@@ -1094,6 +1233,23 @@
   end;
 end;
 
+procedure TBufferedTransportImpl.CheckReadBytesAvailable( const value : Integer);  // subtracts buffered bytes, then asks FInputBuffer
+var stm2 : IThriftStream2;
+    need : Integer;
+begin
+  need := value;
+
+  // bytes already buffered count towards value
+  if Supports( FInputBuffer, IThriftStream2, stm2) then begin
+    Dec( need, stm2.Size - stm2.Position);
+    if need <= 0 then Exit;
+  end;
+  // NOTE(review): if FInputBuffer is a TBufferedStreamImpl, its own check subtracts its buffered bytes again - confirm no double count
+  if FInputBuffer <> nil
+  then FInputBuffer.CheckReadBytesAvailable( need)
+  else raise TTransportExceptionNotOpen.Create('Cannot read from null inputstream' );
+end;
+
 { TBufferedTransportImpl.TFactory }
 
 function TBufferedTransportImpl.TFactory.GetTransport( const aTransport: ITransport): ITransport;
@@ -1104,10 +1260,18 @@
 
 { TFramedTransportImpl }
 
+constructor TFramedTransportImpl.Create( const aTransportCtl : ITransportControl);
+begin
+  inherited Create( aTransportCtl);
+  // NOTE(review): FTransport is not assigned by this overload
+  InitMaxFrameSize;
+  InitWriteBuffer;
+end;
+
 constructor TFramedTransportImpl.Create( const aTransport: ITransport);
 begin
   ASSERT( aTransport <> nil);
-  inherited Create;
+  inherited Create( aTransport.TransportControl);  // share the wrapped transport's control
 
   InitMaxFrameSize;
   InitWriteBuffer;
@@ -1122,8 +1286,15 @@
 end;
 
 procedure TFramedTransportImpl.InitMaxFrameSize;
+var maxLen : Integer;
 begin
   FMaxFrameSize := DEFAULT_MAX_LENGTH;
+
+  // the frame limit may only shrink to MaxAllowedMessageSize minus the frame header, never grow
+  if TransportControl <> nil then begin
+    maxLen := TransportControl.MaxAllowedMessageSize - SizeOf(TFramedHeader);  // NOTE(review): <= 0 for very small MaxAllowedMessageSize values
+    FMaxFrameSize := Min( FMaxFrameSize, maxLen);
+  end;
 end;
 
 procedure TFramedTransportImpl.Close;
@@ -1225,6 +1396,7 @@
     raise TTransportExceptionCorruptedData.Create('Frame size ('+IntToStr(size)+') larger than allowed maximum ('+IntToStr(FMaxFrameSize)+')');
   end;
 
+  FTransport.CheckReadBytesAvailable( size);  // checked before SetLength allocates the frame buffer
   SetLength( buff, size );
   FTransport.ReadAll( buff, 0, size );
 
@@ -1247,6 +1419,18 @@
 end;
 
 
+procedure TFramedTransportImpl.CheckReadBytesAvailable( const value : Integer);  // checks the unread part of FReadBuffer only
+var nRemaining : Int64;
+begin
+  if FReadBuffer = nil
+  then raise TTransportExceptionEndOfFile.Create('Cannot read from null inputstream');
+  // the underlying transport is not consulted here
+  nRemaining := FReadBuffer.Size - FReadBuffer.Position;
+  if value > nRemaining
+  then raise TTransportExceptionEndOfFile.Create('Not enough input data');
+end;
+
+
 { TFramedTransport.TFactory }
 
 function TFramedTransportImpl.TFactory.GetTransport( const aTransport: ITransport): ITransport;
@@ -1577,6 +1761,10 @@
 
 {$ENDIF}
 
+procedure TTcpSocketStreamImpl.CheckReadBytesAvailable( const value : Integer);  // intentionally a no-op
+begin
+  // a socket stream cannot know how many bytes are still to arrive, so no check is possible
+end;
 
 
 end.