THRIFT-3234 Delphi: Limit recursion depth to 64
Client: Delphi
Patch: Jens Geyer
diff --git a/compiler/cpp/src/generate/t_delphi_generator.cc b/compiler/cpp/src/generate/t_delphi_generator.cc
index 71c49d3..cdf49c6 100644
--- a/compiler/cpp/src/generate/t_delphi_generator.cc
+++ b/compiler/cpp/src/generate/t_delphi_generator.cc
@@ -2452,7 +2452,7 @@
if(events_) {
indent_impl(s_service_impl) << "if events <> nil then events.PostWrite;" << endl;
}
- indent_impl(s_service_impl) << "Exit;" << endl;
+ indent_impl(s_service_impl) << "Exit;" << endl;
indent_down_impl();
indent_impl(s_service_impl) << "finally" << endl;
indent_up_impl();
@@ -3481,6 +3481,9 @@
indent_impl(code_block) << "begin" << endl;
indent_up_impl();
+ indent_impl(local_vars) << "tracker : IProtocolRecursionTracker;" << endl;
+ indent_impl(code_block) << "tracker := iprot.NextRecursionLevel;" << endl;
+
// local bools for required fields
for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) {
if ((*f_iter)->get_req() == t_field::T_REQUIRED) {
@@ -3620,8 +3623,10 @@
indent_impl(code_block) << "begin" << endl;
indent_up_impl();
+ indent_impl(local_vars) << "tracker : IProtocolRecursionTracker;" << endl;
+ indent_impl(code_block) << "tracker := oprot.NextRecursionLevel;" << endl;
+
indent_impl(code_block) << "struc := TStructImpl.Create('" << name << "');" << endl;
-
indent_impl(code_block) << "oprot.WriteStructBegin(struc);" << endl;
if (fields.size() > 0) {
@@ -3682,8 +3687,10 @@
indent_impl(code_block) << "begin" << endl;
indent_up_impl();
- indent_impl(code_block) << "struc := TStructImpl.Create('" << name << "');" << endl;
+ indent_impl(local_vars) << "tracker : IProtocolRecursionTracker;" << endl;
+ indent_impl(code_block) << "tracker := oprot.NextRecursionLevel;" << endl;
+ indent_impl(code_block) << "struc := TStructImpl.Create('" << name << "');" << endl;
indent_impl(code_block) << "oprot.WriteStructBegin(struc);" << endl;
if (fields.size() > 0) {
diff --git a/lib/delphi/src/Thrift.Protocol.pas b/lib/delphi/src/Thrift.Protocol.pas
index 606823d..01b11a8 100644
--- a/lib/delphi/src/Thrift.Protocol.pas
+++ b/lib/delphi/src/Thrift.Protocol.pas
@@ -65,6 +65,9 @@
VALID_MESSAGETYPES = [Low(TMessageType)..High(TMessageType)];
+const
+ DEFAULT_RECURSION_LIMIT = 64;
+
type
IProtocol = interface;
IStruct = interface;
@@ -244,8 +247,21 @@
class procedure Skip( prot: IProtocol; type_: TType);
end;
+ IProtocolRecursionTracker = interface
+ ['{29CA033F-BB56-49B1-9EE3-31B1E82FC7A5}']
+ // no members yet
+ end;
+
+ TProtocolRecursionTrackerImpl = class abstract( TInterfacedObject, IProtocolRecursionTracker)
+ protected
+ FProtocol : IProtocol;
+ public
+ constructor Create( prot : IProtocol);
+ destructor Destroy; override;
+ end;
+
IProtocol = interface
- ['{FD95C151-1527-4C96-8134-B902BFC4B4FC}']
+ ['{602A7FFB-0D9E-4CD8-8D7F-E5076660588A}']
function GetTransport: ITransport;
procedure WriteMessageBegin( const msg: IMessage);
procedure WriteMessageEnd;
@@ -291,12 +307,29 @@
function ReadBinary: TBytes;
function ReadString: string;
function ReadAnsiString: AnsiString;
+
+ procedure SetRecursionLimit( value : Integer);
+ function GetRecursionLimit : Integer;
+ function NextRecursionLevel : IProtocolRecursionTracker;
+ procedure IncrementRecursionDepth;
+ procedure DecrementRecursionDepth;
+
property Transport: ITransport read GetTransport;
+ property RecursionLimit : Integer read GetRecursionLimit write SetRecursionLimit;
end;
TProtocolImpl = class abstract( TInterfacedObject, IProtocol)
protected
FTrans : ITransport;
+ FRecursionLimit : Integer;
+ FRecursionDepth : Integer;
+
+ procedure SetRecursionLimit( value : Integer);
+ function GetRecursionLimit : Integer;
+ function NextRecursionLevel : IProtocolRecursionTracker;
+ procedure IncrementRecursionDepth;
+ procedure DecrementRecursionDepth;
+
function GetTransport: ITransport;
public
procedure WriteMessageBegin( const msg: IMessage); virtual; abstract;
@@ -609,12 +642,65 @@
FType := Value;
end;
+{ TProtocolRecursionTrackerImpl }
+
+constructor TProtocolRecursionTrackerImpl.Create( prot : IProtocol);
+begin
+ inherited Create;
+
+ // storing the pointer *after* the (successful) increment is important here
+ prot.IncrementRecursionDepth;
+ FProtocol := prot;
+end;
+
+destructor TProtocolRecursionTrackerImpl.Destroy;
+begin
+ try
+ // we have to release the reference iff the pointer has been stored
+ if FProtocol <> nil then begin
+ FProtocol.DecrementRecursionDepth;
+ FProtocol := nil;
+ end;
+ finally
+ inherited Destroy;
+ end;
+end;
+
{ TProtocolImpl }
constructor TProtocolImpl.Create(trans: ITransport);
begin
inherited Create;
FTrans := trans;
+ FRecursionLimit := DEFAULT_RECURSION_LIMIT;
+ FRecursionDepth := 0;
+end;
+
+procedure TProtocolImpl.SetRecursionLimit( value : Integer);
+begin
+ FRecursionLimit := value;
+end;
+
+function TProtocolImpl.GetRecursionLimit : Integer;
+begin
+ result := FRecursionLimit;
+end;
+
+function TProtocolImpl.NextRecursionLevel : IProtocolRecursionTracker;
+begin
+ result := TProtocolRecursionTrackerImpl.Create(Self);
+end;
+
+procedure TProtocolImpl.IncrementRecursionDepth;
+begin
+ if FRecursionDepth < FRecursionLimit
+ then Inc(FRecursionDepth)
+ else raise TProtocolException.Create( TProtocolException.DEPTH_LIMIT, 'Depth limit exceeded');
+end;
+
+procedure TProtocolImpl.DecrementRecursionDepth;
+begin
+ Dec(FRecursionDepth)
end;
function TProtocolImpl.GetTransport: ITransport;
@@ -672,7 +758,9 @@
set_ : ISet;
list : IList;
i : Integer;
+ tracker : IProtocolRecursionTracker;
begin
+ tracker := prot.NextRecursionLevel;
case type_ of
// simple types
TType.Bool_ : prot.ReadBool();