THRIFT-4252: Close sockets when shutting down server (#1905)
* THRIFT-4252: Close sockets when shutting down server
In TThreadPoolServer, threads block on I/O with open sockets. As long
as clients don't close the connection, server threads are never
stopped, even after shutdown is called on the server (because they
are blocked waiting for I/O).
To stop all server threads properly, the server should proactively
close the open sockets once a shutdown is initiated.
* Fix indentation
Use spaces for indentation instead of tabs.
diff --git a/.gitignore b/.gitignore
index d10f769..4e2f427 100644
--- a/.gitignore
+++ b/.gitignore
@@ -205,6 +205,7 @@
/lib/dart/**/pubspec.lock
/lib/delphi/test/skip/*.request
/lib/delphi/test/skip/*.response
+/lib/delphi/test/serializer/*.dat
/lib/delphi/**/*.identcache
/lib/delphi/**/*.local
/lib/delphi/**/*.dcu
diff --git a/CHANGES.md b/CHANGES.md
index 72c7eb8..e179a63 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -1,5 +1,16 @@
# Apache Thrift Changelog
+## 0.14.0
+
+### Breaking Changes
+
+- [THRIFT-4990](https://issues.apache.org/jira/browse/THRIFT-4990) - Upgrade to .NET Core 3.1 (LTS)
+- [THRIFT-4981](https://issues.apache.org/jira/browse/THRIFT-4981) - Remove deprecated netcore bindings from the code base
+- [THRIFT-5006](https://issues.apache.org/jira/browse/THRIFT-5006) - Implement DEFAULT_MAX_LENGTH at TFramedTransport
+
+### Java
+
+- [THRIFT-5022](https://issues.apache.org/jira/browse/THRIFT-5022) - TIOStreamTransport.isOpen returns true for one-sided transports (see THRIFT-2530).
## 0.13.0
### New Languages
diff --git a/LANGUAGES.md b/LANGUAGES.md
index 923b045..afd7799 100644
--- a/LANGUAGES.md
+++ b/LANGUAGES.md
@@ -315,7 +315,7 @@
<!-- Since -----------------><td>0.2.0</td>
<!-- Build Systems ---------><td><img src="doc/images/cgrn.png" alt="Yes"/></td><td><img src="doc/images/cgrn.png" alt="Yes"/></td>
<!-- Language Levels -------><td>2.7.12, 3.5.2</td><td>2.7.15, 3.6.8</td>
-<!-- Low-Level Transports --><td><img src="doc/images/cred.png" alt=""/></td><td><img src="doc/images/cred.png" alt=""/></td><td><img src="doc/images/cred.png" alt=""/></td><td><img src="doc/images/cred.png" alt=""/></td><td><img src="doc/images/cgrn.png" alt="Yes"/></td><td><img src="doc/images/cgrn.png" alt="Yes"/></td>
+<!-- Low-Level Transports --><td><img src="doc/images/cgrn.png" alt="Yes"/></td><td><img src="doc/images/cred.png" alt=""/></td><td><img src="doc/images/cred.png" alt=""/></td><td><img src="doc/images/cred.png" alt=""/></td><td><img src="doc/images/cgrn.png" alt="Yes"/></td><td><img src="doc/images/cgrn.png" alt="Yes"/></td>
<!-- Transport Wrappers ----><td><img src="doc/images/cred.png" alt=""/></td><td><img src="doc/images/cgrn.png" alt="Yes"/></td><td><img src="doc/images/cgrn.png" alt="Yes"/></td><td><img src="doc/images/cgrn.png" alt="Yes"/></td>
<!-- Protocols -------------><td><img src="doc/images/cgrn.png" alt="Yes"/></td><td><img src="doc/images/cgrn.png" alt="Yes"/></td><td><img src="doc/images/cgrn.png" alt="Yes"/></td><td><img src="doc/images/cgrn.png" alt="Yes"/></td>
<!-- Servers ---------------><td><img src="doc/images/cgrn.png" alt="Yes"/></td><td><img src="doc/images/cgrn.png" alt="Yes"/></td><td><img src="doc/images/cgrn.png" alt="Yes"/></td><td><img src="doc/images/cred.png" alt=""/></td><td><img src="doc/images/cred.png" alt=""/></td>
diff --git a/build/appveyor/MSVC-appveyor-install.bat b/build/appveyor/MSVC-appveyor-install.bat
index bc4655a..09b7cc4 100644
--- a/build/appveyor/MSVC-appveyor-install.bat
+++ b/build/appveyor/MSVC-appveyor-install.bat
@@ -56,7 +56,8 @@
tornado ^
twisted || EXIT /B
-cinst -y ghc || EXIT /B
+cinst -y cabal --version 2.4.1.0 || EXIT /B
+cinst -y ghc --version 8.6.5 || EXIT /B
:: Adobe Flex SDK 4.6 for ActionScript
MKDIR "C:\Adobe\Flex\SDK\4.6" || EXIT /B
diff --git a/build/cmake/DefineOptions.cmake b/build/cmake/DefineOptions.cmake
index 778be8d..6a69c6d 100644
--- a/build/cmake/DefineOptions.cmake
+++ b/build/cmake/DefineOptions.cmake
@@ -151,7 +151,7 @@
message(STATUS " Build compiler: ${BUILD_COMPILER}")
message(STATUS " Build libraries: ${BUILD_LIBRARIES}")
message(STATUS " Build tests: ${BUILD_TESTING}")
-MESSAGE_DEP(HAVE_COMPILER "Disabled because BUILD_THRIFT=OFF and no valid THRIFT_COMPILER is given")
+MESSAGE_DEP(HAVE_COMPILER "Disabled because BUILD_COMPILER=OFF and no valid THRIFT_COMPILER is given")
message(STATUS " Build type: ${CMAKE_BUILD_TYPE}")
message(STATUS)
message(STATUS "Language libraries:")
diff --git a/build/docker/README.md b/build/docker/README.md
index 08c1372..8c8be22 100644
--- a/build/docker/README.md
+++ b/build/docker/README.md
@@ -177,7 +177,6 @@
| d | 2.075.1 | 2.087.0 | |
| dart | 2.0.0 | 2.4.0 | |
| delphi | | | Not in CI |
-| dotnet | 2.1.503 | 2.2.300 | |
| erlang | 18.3 | 22.0 | |
| go | 1.10.8 | 1.12.6 | |
| haskell | 7.10.3 | 8.0.2 | |
@@ -185,6 +184,7 @@
| java | 1.8.0_191 | 11.0.3 | |
| js | | | Unsure how to look for version info? |
| lua | | 5.2.4 | Lua 5.3: see THRIFT-4386 |
+| netstd | 3.1 | 3.1 | LTS version |
| nodejs | 6.16.0 | 10.16.0 | |
| ocaml | | 4.05.0 | THRIFT-4517: ocaml 4.02.3 on xenial appears broken |
| perl | 5.22.1 | 5.26.1 | |
diff --git a/build/docker/ubuntu-bionic/Dockerfile b/build/docker/ubuntu-bionic/Dockerfile
index 5f9833f..79d698f 100644
--- a/build/docker/ubuntu-bionic/Dockerfile
+++ b/build/docker/ubuntu-bionic/Dockerfile
@@ -130,7 +130,7 @@
RUN apt-get install -y --no-install-recommends \
`# dotnet core dependencies` \
- dotnet-sdk-2.2
+ dotnet-sdk-3.1
RUN apt-get install -y --no-install-recommends \
`# Erlang dependencies` \
diff --git a/build/docker/ubuntu-disco/Dockerfile b/build/docker/ubuntu-disco/Dockerfile
index b017c4e..95a2c78 100644
--- a/build/docker/ubuntu-disco/Dockerfile
+++ b/build/docker/ubuntu-disco/Dockerfile
@@ -130,7 +130,7 @@
RUN apt-get install -y --no-install-recommends \
`# dotnet core dependencies` \
- dotnet-sdk-2.2
+ dotnet-sdk-3.1
RUN apt-get install -y --no-install-recommends \
`# Erlang dependencies` \
diff --git a/build/docker/ubuntu-xenial/Dockerfile b/build/docker/ubuntu-xenial/Dockerfile
index 8dc6497..8df0887 100644
--- a/build/docker/ubuntu-xenial/Dockerfile
+++ b/build/docker/ubuntu-xenial/Dockerfile
@@ -127,7 +127,7 @@
RUN apt-get install -y --no-install-recommends \
`# dotnet core dependencies` \
- dotnet-sdk-2.1
+ dotnet-sdk-3.1
RUN apt-get install -y --no-install-recommends \
`# Erlang dependencies` \
diff --git a/compiler/cpp/src/thrift/generate/t_cpp_generator.cc b/compiler/cpp/src/thrift/generate/t_cpp_generator.cc
index d66b6e6..896c43f 100644
--- a/compiler/cpp/src/thrift/generate/t_cpp_generator.cc
+++ b/compiler/cpp/src/thrift/generate/t_cpp_generator.cc
@@ -504,6 +504,7 @@
* @param ttypedef The type definition
*/
void t_cpp_generator::generate_typedef(t_typedef* ttypedef) {
+ generate_java_doc(f_types_, ttypedef);
f_types_ << indent() << "typedef " << type_name(ttypedef->get_type(), true) << " "
<< ttypedef->get_symbolic() << ";" << endl << endl;
}
@@ -524,6 +525,7 @@
} else {
f << "," << endl;
}
+ generate_java_doc(f, *c_iter);
indent(f) << prefix << (*c_iter)->get_name() << suffix;
if (include_values) {
f << " = " << (*c_iter)->get_value();
@@ -547,6 +549,7 @@
std::string enum_name = tenum->get_name();
if (!gen_pure_enums_) {
enum_name = "type";
+ generate_java_doc(f_types_, tenum);
f_types_ << indent() << "struct " << tenum->get_name() << " {" << endl;
indent_up();
}
@@ -1075,6 +1078,8 @@
out << endl;
+ generate_java_doc(out, tstruct);
+
// Open struct def
out << indent() << "class " << tstruct->get_name() << extends << " {" << endl << indent()
<< " public:" << endl << endl;
@@ -1147,6 +1152,7 @@
// Declare all fields
for (m_iter = members.begin(); m_iter != members.end(); ++m_iter) {
+ generate_java_doc(out, *m_iter);
indent(out) << declare_field(*m_iter,
false,
(pointers && !(*m_iter)->get_type()->is_xception()),
@@ -1933,6 +1939,9 @@
if (style == "CobCl" && gen_templates_) {
f_header_ << "template <class Protocol_>" << endl;
}
+
+ generate_java_doc(f_header_, tservice);
+
f_header_ << "class " << service_if_name << extends << " {" << endl << " public:" << endl;
indent_up();
f_header_ << indent() << "virtual ~" << service_if_name << "() {}" << endl;
@@ -2225,6 +2234,7 @@
indent_up();
for (f_iter = functions.begin(); f_iter != functions.end(); ++f_iter) {
+ generate_java_doc(f_header_, *f_iter);
t_struct* arglist = (*f_iter)->get_arglist();
const vector<t_field*>& args = arglist->get_members();
vector<t_field*>::const_iterator a_iter;
@@ -2438,6 +2448,8 @@
}
if (style == "Cob") {
+ generate_java_doc(f_header_, tservice);
+
f_header_ << indent()
<< "::std::shared_ptr< ::apache::thrift::async::TAsyncChannel> getChannel() {" << endl
<< indent() << " return " << _this << "channel_;" << endl << indent() << "}" << endl;
@@ -2449,6 +2461,7 @@
vector<t_function*> functions = tservice->get_functions();
vector<t_function*>::const_iterator f_iter;
for (f_iter = functions.begin(); f_iter != functions.end(); ++f_iter) {
+ generate_java_doc(f_header_, *f_iter);
indent(f_header_) << function_signature(*f_iter, ifstyle) << ";" << endl;
// TODO(dreiss): Use private inheritance to avoid generating thise in cob-style.
if (style == "Concurrent" && !(*f_iter)->is_oneway()) {
diff --git a/compiler/cpp/src/thrift/generate/t_delphi_generator.cc b/compiler/cpp/src/thrift/generate/t_delphi_generator.cc
index 4a2ebda..cffe305 100644
--- a/compiler/cpp/src/thrift/generate/t_delphi_generator.cc
+++ b/compiler/cpp/src/thrift/generate/t_delphi_generator.cc
@@ -152,11 +152,13 @@
void generate_delphi_struct_writer_impl(ostream& out,
std::string cls_prefix,
t_struct* tstruct,
- bool is_exception);
+ bool is_exception,
+ bool is_x_factory);
void generate_delphi_struct_result_writer_impl(ostream& out,
std::string cls_prefix,
t_struct* tstruct,
- bool is_exception);
+ bool is_exception,
+ bool is_x_factory);
void generate_delphi_struct_tostring_impl(ostream& out,
std::string cls_prefix,
@@ -169,7 +171,8 @@
void generate_delphi_struct_reader_impl(ostream& out,
std::string cls_prefix,
t_struct* tstruct,
- bool is_exception);
+ bool is_exception,
+ bool is_x_factory);
void generate_delphi_create_exception_impl(ostream& out,
string cls_prefix,
t_struct* tstruct,
@@ -1532,11 +1535,30 @@
indent_impl(out) << "begin" << endl;
indent_up_impl();
indent_impl(out) << "if F" << exception_factory_name << " = nil" << endl;
- indent_impl(out) << "then F" << exception_factory_name << " := T" << exception_factory_name << "Impl.Create;" << endl;
- indent_impl(out) << endl;
+ indent_impl(out) << "then F" << exception_factory_name << " := T" << exception_factory_name << "Impl.Create;" << endl << endl;
indent_impl(out) << "result := F" << exception_factory_name << ";" << endl;
indent_down_impl();
indent_impl(out) << "end;" << endl << endl;
+ indent_impl(out) << "function " << cls_prefix << cls_nm << ".QueryInterface(const IID: TGUID; out Obj): HRESULT;" << endl;
+ indent_impl(out) << "begin" << endl;
+ indent_up_impl();
+ indent_impl(out) << "if GetInterface(IID, Obj)" << endl;
+ indent_impl(out) << "then result := S_OK" << endl;
+ indent_impl(out) << "else result := E_NOINTERFACE;" << endl;
+ indent_down_impl();
+ indent_impl(out) << "end;" << endl << endl;
+ indent_impl(out) << "function " << cls_prefix << cls_nm << "._AddRef: Integer;" << endl;
+ indent_impl(out) << "begin" << endl;
+ indent_up_impl();
+ indent_impl(out) << "result := -1; // not refcounted" << endl;
+ indent_down_impl();
+ indent_impl(out) << "end;" << endl << endl;
+ indent_impl(out) << "function " << cls_prefix << cls_nm << "._Release: Integer;" << endl;
+ indent_impl(out) << "begin" << endl;
+ indent_up_impl();
+ indent_impl(out) << "result := -1; // not refcounted" << endl;
+ indent_down_impl();
+ indent_impl(out) << "end;" << endl << endl;
}
if (tstruct->is_union()) {
@@ -1586,13 +1608,11 @@
}
}
- if ((!is_exception) || is_x_factory) {
- generate_delphi_struct_reader_impl(out, cls_prefix, tstruct, is_exception);
- if (is_result) {
- generate_delphi_struct_result_writer_impl(out, cls_prefix, tstruct, is_exception);
- } else {
- generate_delphi_struct_writer_impl(out, cls_prefix, tstruct, is_exception);
- }
+ generate_delphi_struct_reader_impl(out, cls_prefix, tstruct, is_exception, is_x_factory);
+ if (is_result) {
+ generate_delphi_struct_result_writer_impl(out, cls_prefix, tstruct, is_exception, is_x_factory);
+ } else {
+ generate_delphi_struct_writer_impl(out, cls_prefix, tstruct, is_exception, is_x_factory);
}
generate_delphi_struct_tostring_impl(out, cls_prefix, tstruct, is_exception, is_x_factory);
@@ -1741,7 +1761,7 @@
}
out << "class(";
if (is_exception && (!is_x_factory)) {
- out << "TException";
+ out << "TException, IInterface, IBase, ISupportsToString";
} else {
out << "TInterfacedObject, IBase, ISupportsToString, " << struct_intf_name;
}
@@ -1801,8 +1821,18 @@
}
}
- indent_down();
+ if (is_exception && (!is_x_factory)) {
+ out << endl;
+ indent_down();
+ indent(out) << "strict protected" << endl;
+ indent_up();
+ indent(out) << "function QueryInterface(const IID: TGUID; out Obj): HRESULT; stdcall;" << endl;
+ indent(out) << "function _AddRef: Integer; stdcall;" << endl;
+ indent(out) << "function _Release: Integer; stdcall;" << endl;
+ out << endl;
+ }
+ indent_down();
indent(out) << "public" << endl;
indent_up();
@@ -1825,12 +1855,10 @@
indent(out) << "function " << exception_factory_name << ": " << struct_intf_name << ";" << endl;
}
- if ((!is_exception) || is_x_factory) {
- out << endl;
- indent(out) << "// IBase" << endl;
- indent(out) << "procedure Read( const iprot: IProtocol);" << endl;
- indent(out) << "procedure Write( const oprot: IProtocol);" << endl;
- }
+ out << endl;
+ indent(out) << "// IBase" << endl;
+ indent(out) << "procedure Read( const iprot: IProtocol);" << endl;
+ indent(out) << "procedure Write( const oprot: IProtocol);" << endl;
if (is_exception && is_x_factory) {
out << endl;
@@ -2163,9 +2191,7 @@
indent_impl(s_service_impl) << "begin" << endl;
indent_up_impl();
indent_impl(s_service_impl) << msgvar << " := iprot_.ReadMessageBegin();" << endl;
- indent_impl(s_service_impl) << "if (" << msgvar << ".Type_ = TMessageType.Exception) then"
- << endl;
- indent_impl(s_service_impl) << "begin" << endl;
+ indent_impl(s_service_impl) << "if (" << msgvar << ".Type_ = TMessageType.Exception) then begin" << endl;
indent_up_impl();
indent_impl(s_service_impl) << appexvar << " := TApplicationException.Read(iprot_);" << endl;
indent_impl(s_service_impl) << "iprot_.ReadMessageEnd();" << endl;
@@ -2178,8 +2204,7 @@
indent_impl(s_service_impl) << "iprot_.ReadMessageEnd();" << endl;
if (!(*f_iter)->get_returntype()->is_void()) {
- indent_impl(s_service_impl) << "if (" << retvar << ".__isset_success) then" << endl;
- indent_impl(s_service_impl) << "begin" << endl;
+ indent_impl(s_service_impl) << "if (" << retvar << ".__isset_success) then begin" << endl;
indent_up_impl();
indent_impl(s_service_impl) << "Result := " << retvar << ".Success;" << endl;
t_type* type = (*f_iter)->get_returntype();
@@ -2195,8 +2220,7 @@
vector<t_field*>::const_iterator x_iter;
for (x_iter = xceptions.begin(); x_iter != xceptions.end(); ++x_iter) {
indent_impl(s_service_impl) << "if (" << retvar << ".__isset_" << prop_name(*x_iter)
- << ") then" << endl;
- indent_impl(s_service_impl) << "begin" << endl;
+ << ") then begin" << endl;
indent_up_impl();
indent_impl(s_service_impl) << exceptvar << " := " << retvar << "." << prop_name(*x_iter)
<< ".CreateException;" << endl;
@@ -2324,8 +2348,7 @@
indent_impl(s_service_impl) << "msg := iprot.ReadMessageBegin();" << endl;
indent_impl(s_service_impl) << "fn := nil;" << endl;
indent_impl(s_service_impl) << "if not processMap_.TryGetValue(msg.Name, fn)" << endl;
- indent_impl(s_service_impl) << "or not Assigned(fn) then" << endl;
- indent_impl(s_service_impl) << "begin" << endl;
+ indent_impl(s_service_impl) << "or not Assigned(fn) then begin" << endl;
indent_up_impl();
indent_impl(s_service_impl) << "TProtocolUtil.Skip(iprot, TType.Struct);" << endl;
indent_impl(s_service_impl) << "iprot.ReadMessageEnd();" << endl;
@@ -2716,8 +2739,7 @@
indent_impl(out) << obj << " := iprot.ReadListBegin();" << endl;
}
- indent_impl(out) << "for " << counter << " := 0 to " << obj << ".Count - 1 do" << endl;
- indent_impl(out) << "begin" << endl;
+ indent_impl(out) << "for " << counter << " := 0 to " << obj << ".Count - 1 do begin" << endl;
indent_up_impl();
if (ttype->is_map()) {
generate_deserialize_map_element(out, is_xception, (t_map*)ttype, name, local_vars);
@@ -2904,20 +2926,17 @@
string iter = tmp("_iter");
if (ttype->is_map()) {
local_vars << " " << iter << ": " << type_name(((t_map*)ttype)->get_key_type()) << ";" << endl;
- indent_impl(out) << "for " << iter << " in " << prefix << ".Keys do" << endl;
- indent_impl(out) << "begin" << endl;
+ indent_impl(out) << "for " << iter << " in " << prefix << ".Keys do begin" << endl;
indent_up_impl();
} else if (ttype->is_set()) {
local_vars << " " << iter << ": " << type_name(((t_set*)ttype)->get_elem_type()) << ";"
<< endl;
- indent_impl(out) << "for " << iter << " in " << prefix << " do" << endl;
- indent_impl(out) << "begin" << endl;
+ indent_impl(out) << "for " << iter << " in " << prefix << " do begin" << endl;
indent_up_impl();
} else if (ttype->is_list()) {
local_vars << " " << iter << ": " << type_name(((t_list*)ttype)->get_elem_type()) << ";"
<< endl;
- indent_impl(out) << "for " << iter << " in " << prefix << " do" << endl;
- indent_impl(out) << "begin" << endl;
+ indent_impl(out) << "for " << iter << " in " << prefix << " do begin" << endl;
indent_up_impl();
}
@@ -3575,7 +3594,8 @@
void t_delphi_generator::generate_delphi_struct_reader_impl(ostream& out,
string cls_prefix,
t_struct* tstruct,
- bool is_exception) {
+ bool is_exception,
+ bool is_x_factory) {
ostringstream local_vars;
ostringstream code_block;
@@ -3604,32 +3624,28 @@
indent_impl(code_block) << "try" << endl;
indent_up_impl();
- indent_impl(code_block) << "while (true) do" << endl;
- indent_impl(code_block) << "begin" << endl;
+ indent_impl(code_block) << "while (true) do begin" << endl;
indent_up_impl();
indent_impl(code_block) << "field_ := iprot.ReadFieldBegin();" << endl;
- indent_impl(code_block) << "if (field_.Type_ = TType.Stop) then" << endl;
- indent_impl(code_block) << "begin" << endl;
- indent_up_impl();
- indent_impl(code_block) << "Break;" << endl;
- indent_down_impl();
- indent_impl(code_block) << "end;" << endl;
+ indent_impl(code_block) << "if (field_.Type_ = TType.Stop) then Break;" << endl;
bool first = true;
for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) {
if (first) {
+ code_block << endl;
indent_impl(code_block) << "case field_.ID of" << endl;
indent_up_impl();
}
first = false;
if (f_iter != fields.begin()) {
- code_block << ";" << endl;
+ code_block << endl;
}
+
indent_impl(code_block) << (*f_iter)->get_key() << ": begin" << endl;
indent_up_impl();
indent_impl(code_block) << "if (field_.Type_ = " << type_to_enum((*f_iter)->get_type())
@@ -3652,12 +3668,13 @@
indent_down_impl();
indent_impl(code_block) << "end;" << endl;
indent_down_impl();
- indent_impl(code_block) << "end";
+ indent_impl(code_block) << "end;";
}
if (!first) {
code_block << endl;
- indent_impl(code_block) << "else begin" << endl;
+ indent_down_impl();
+ indent_impl(code_block) << "else" << endl;
indent_up_impl();
}
@@ -3666,8 +3683,6 @@
if (!first) {
indent_down_impl();
indent_impl(code_block) << "end;" << endl;
- indent_down_impl();
- indent_impl(code_block) << "end;" << endl;
}
indent_impl(code_block) << "iprot.ReadFieldEnd;" << endl;
@@ -3684,8 +3699,13 @@
indent_impl(code_block) << "end;" << endl;
// all required fields have been read?
+ first = true;
for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) {
if ((*f_iter)->get_req() == t_field::T_REQUIRED) {
+ if(first) {
+ code_block << endl;
+ first = false;
+ }
indent_impl(code_block) << "if not _req_isset_" << prop_name(*f_iter, is_exception) << endl;
indent_impl(code_block)
<< "then raise TProtocolExceptionInvalidData.Create("
@@ -3693,13 +3713,17 @@
<< endl;
}
}
-
+
+ if( is_exception && (!is_x_factory)) {
+ code_block << endl;
+ indent_impl(code_block) << "UpdateMessageProperty;" << endl;
+ }
indent_down_impl();
indent_impl(code_block) << "end;" << endl << endl;
string cls_nm;
- cls_nm = type_name(tstruct, true, false, is_exception, is_exception);
+ cls_nm = type_name(tstruct, true, is_exception && (!is_x_factory), is_x_factory, is_x_factory);
indent_impl(out) << "procedure " << cls_prefix << cls_nm << ".Read( const iprot: IProtocol);"
<< endl;
@@ -3715,7 +3739,8 @@
void t_delphi_generator::generate_delphi_struct_result_writer_impl(ostream& out,
string cls_prefix,
t_struct* tstruct,
- bool is_exception) {
+ bool is_exception,
+ bool is_x_factory) {
ostringstream local_vars;
ostringstream code_block;
@@ -3759,7 +3784,7 @@
string cls_nm;
- cls_nm = type_name(tstruct, true, false, is_exception, is_exception);
+ cls_nm = type_name(tstruct, true, is_exception && (!is_x_factory), is_x_factory, is_x_factory);
indent_impl(out) << "procedure " << cls_prefix << cls_nm << ".Write( const oprot: IProtocol);"
<< endl;
@@ -3779,7 +3804,8 @@
void t_delphi_generator::generate_delphi_struct_writer_impl(ostream& out,
string cls_prefix,
t_struct* tstruct,
- bool is_exception) {
+ bool is_exception,
+ bool is_x_factory) {
ostringstream local_vars;
ostringstream code_block;
@@ -3847,7 +3873,7 @@
string cls_nm;
- cls_nm = type_name(tstruct, true, false, is_exception, is_exception);
+ cls_nm = type_name(tstruct, true, is_exception && (!is_x_factory), is_x_factory, is_x_factory);
indent_impl(out) << "procedure " << cls_prefix << cls_nm << ".Write( const oprot: IProtocol);"
<< endl;
diff --git a/compiler/cpp/src/thrift/generate/t_go_generator.cc b/compiler/cpp/src/thrift/generate/t_go_generator.cc
index 33b7547..2093841 100644
--- a/compiler/cpp/src/thrift/generate/t_go_generator.cc
+++ b/compiler/cpp/src/thrift/generate/t_go_generator.cc
@@ -311,6 +311,7 @@
std::string camelcase(const std::string& value) const;
void fix_common_initialism(std::string& value, int i) const;
std::string publicize(const std::string& value, bool is_args_or_result = false) const;
+ std::string publicize(const std::string& value, bool is_args_or_result, const std::string& service_name) const;
std::string privatize(const std::string& value) const;
std::string new_prefix(const std::string& value) const;
static std::string variable_name_to_go_name(const std::string& value);
@@ -464,7 +465,7 @@
}
}
-std::string t_go_generator::publicize(const std::string& value, bool is_args_or_result) const {
+std::string t_go_generator::publicize(const std::string& value, bool is_args_or_result, const std::string& service_name) const {
if (value.size() <= 0) {
return value;
}
@@ -506,12 +507,16 @@
// Avoid naming collisions with other services
if (is_args_or_result) {
- prefix += publicize(service_name_);
+ prefix += publicize(service_name);
}
return prefix + value2;
}
+std::string t_go_generator::publicize(const std::string& value, bool is_args_or_result) const {
+ return publicize(value, is_args_or_result, service_name_);
+}
+
std::string t_go_generator::new_prefix(const std::string& value) const {
if (value.size() <= 0) {
return value;
@@ -772,11 +777,16 @@
const vector<t_program*>& includes = program_->get_includes();
string result = "";
string local_namespace = program_->get_namespace("go");
+ std::set<std::string> included;
for (auto include : includes) {
if (!local_namespace.empty() && local_namespace == include->get_namespace("go")) {
continue;
}
+ if (!included.insert(include->get_namespace("go")).second) {
+ continue;
+ }
+
result += render_program_import(include, unused_prot);
}
return result;
@@ -2121,13 +2131,26 @@
* @param tservice The service to generate a remote for.
*/
void t_go_generator::generate_service_remote(t_service* tservice) {
- vector<t_function*> functions = tservice->get_functions();
- t_service* parent = tservice->get_extends();
+ vector<t_function*> functions;
+ std::unordered_map<std::string, std::string> func_to_service;
- // collect inherited functions
+ // collect all functions including inherited functions
+ t_service* parent = tservice;
while (parent != NULL) {
vector<t_function*> p_functions = parent->get_functions();
functions.insert(functions.end(), p_functions.begin(), p_functions.end());
+
+ // We need to maintain a map of functions names to the name of their parent.
+ // This is because functions may come from a parent service, and if we need
+ // to create the arguments struct (e.g. `NewParentServiceNameFuncNameArgs()`)
+ // we need to make sure to specify the correct service name.
+ for (vector<t_function*>::iterator f_iter = p_functions.begin(); f_iter != p_functions.end(); ++f_iter) {
+ auto it = func_to_service.find((*f_iter)->get_name());
+ if (it == func_to_service.end()) {
+ func_to_service.emplace((*f_iter)->get_name(), parent->get_name());
+ }
+ }
+
parent = parent->get_extends();
}
@@ -2340,7 +2363,7 @@
std::vector<t_field*>::size_type num_args = args.size();
string funcName((*f_iter)->get_name());
string pubName(publicize(funcName));
- string argumentsName(publicize(funcName + "_args", true));
+ string argumentsName(publicize(funcName + "_args", true, func_to_service[funcName]));
f_remote << indent() << "case \"" << escape_string(funcName) << "\":" << endl;
indent_up();
f_remote << indent() << "if flag.NArg() - 1 != " << num_args << " {" << endl;
diff --git a/compiler/cpp/src/thrift/generate/t_netstd_generator.cc b/compiler/cpp/src/thrift/generate/t_netstd_generator.cc
index ffe51ab..ba55960 100644
--- a/compiler/cpp/src/thrift/generate/t_netstd_generator.cc
+++ b/compiler/cpp/src/thrift/generate/t_netstd_generator.cc
@@ -51,10 +51,11 @@
: t_oop_generator(program)
{
(void)option_string;
-
+ use_pascal_case_properties = false;
union_ = false;
serialize_ = false;
wcf_ = false;
+
wcf_namespace_.clear();
map<string, string>::const_iterator iter;
@@ -75,9 +76,12 @@
wcf_ = true;
wcf_namespace_ = iter->second;
}
- else
+ else if (iter->first.compare("pascal") == 0)
{
- throw "unknown option netstd:" + iter->first;
+ use_pascal_case_properties = true;
+ }
+ else {
+ throw "unknown option netstd:" + iter->first;
}
}
@@ -188,6 +192,7 @@
pverbose("- union ...... %s\n", (is_union_enabled() ? "ON" : "off"));
pverbose("- serialize .. %s\n", (is_serialize_enabled() ? "ON" : "off"));
pverbose("- wcf ........ %s\n", (is_wcf_enabled() ? "ON" : "off"));
+ pverbose("- pascal ..... %s\n", (use_pascal_case_properties ? "ON" : "off"));
}
string t_netstd_generator::normalize_name(string name)
@@ -2674,18 +2679,41 @@
}
}
-string t_netstd_generator::prop_name(t_field* tfield, bool suppress_mapping)
-{
- string name(tfield->get_name());
- if (suppress_mapping)
- {
- name[0] = toupper(name[0]);
+
+string t_netstd_generator::convert_to_pascal_case(const string& str) {
+ string out;
+ bool must_capitalize = true;
+ bool first_character = true;
+ for (auto it = str.begin(); it != str.end(); ++it) {
+ if (std::isalnum(*it)) {
+ if (must_capitalize) {
+ out.append(1, (char)::toupper(*it));
+ must_capitalize = false;
+ } else {
+ out.append(1, *it);
+ }
+ } else {
+ if (first_character) //this is a private variable and should not be PascalCased
+ return str;
+ must_capitalize = true;
}
- else
- {
- name = get_mapped_member_name(name);
- }
- return name;
+ first_character = false;
+ }
+ return out;
+}
+
+
+string t_netstd_generator::prop_name(t_field* tfield, bool suppress_mapping) {
+ string name(tfield->get_name());
+ if (suppress_mapping) {
+ name[0] = toupper(name[0]);
+ if (use_pascal_case_properties)
+ name = t_netstd_generator::convert_to_pascal_case(name);
+ } else {
+ name = get_mapped_member_name(name);
+ }
+
+ return name;
}
string t_netstd_generator::type_name(t_type* ttype)
@@ -3020,4 +3048,5 @@
" wcf: Adds bindings for WCF to generated classes.\n"
" serial: Add serialization support to generated classes.\n"
" union: Use new union typing, which includes a static read function for union types.\n"
+ " pascal: Generate Pascal Case property names according to Microsoft naming convention.\n"
)
diff --git a/compiler/cpp/src/thrift/generate/t_netstd_generator.h b/compiler/cpp/src/thrift/generate/t_netstd_generator.h
index fd9e6c0..1e23f91 100644
--- a/compiler/cpp/src/thrift/generate/t_netstd_generator.h
+++ b/compiler/cpp/src/thrift/generate/t_netstd_generator.h
@@ -134,6 +134,7 @@
string argument_list(t_struct* tstruct);
string type_to_enum(t_type* ttype);
string prop_name(t_field* tfield, bool suppress_mapping = false);
+ string convert_to_pascal_case(const string& str);
string get_enum_class_name(t_type* type);
private:
@@ -145,6 +146,7 @@
bool hashcode_;
bool serialize_;
bool wcf_;
+ bool use_pascal_case_properties;
string wcf_namespace_;
map<string, int> netstd_keywords;
diff --git a/compiler/cpp/src/thrift/generate/t_py_generator.cc b/compiler/cpp/src/thrift/generate/t_py_generator.cc
index 982bca1..e93bbe1 100644
--- a/compiler/cpp/src/thrift/generate/t_py_generator.cc
+++ b/compiler/cpp/src/thrift/generate/t_py_generator.cc
@@ -65,6 +65,7 @@
coding_ = "";
gen_dynbaseclass_ = "";
gen_dynbaseclass_exc_ = "";
+ gen_dynbaseclass_frozen_exc_ = "";
gen_dynbaseclass_frozen_ = "";
import_dynbase_ = "";
package_prefix_ = "";
@@ -94,8 +95,11 @@
if( gen_dynbaseclass_exc_.empty()) {
gen_dynbaseclass_exc_ = "TExceptionBase";
}
+ if( gen_dynbaseclass_frozen_exc_.empty()) {
+ gen_dynbaseclass_frozen_exc_ = "TFrozenExceptionBase";
+ }
if( import_dynbase_.empty()) {
- import_dynbase_ = "from thrift.protocol.TBase import TBase, TFrozenBase, TExceptionBase, TTransport\n";
+ import_dynbase_ = "from thrift.protocol.TBase import TBase, TFrozenBase, TExceptionBase, TFrozenExceptionBase, TTransport\n";
}
} else if( iter->first.compare("dynbase") == 0) {
gen_dynbase_ = true;
@@ -104,6 +108,8 @@
gen_dynbaseclass_frozen_ = (iter->second);
} else if( iter->first.compare("dynexc") == 0) {
gen_dynbaseclass_exc_ = (iter->second);
+ } else if( iter->first.compare("dynfrozenexc") == 0) {
+ gen_dynbaseclass_frozen_exc_ = (iter->second);
} else if( iter->first.compare("dynimport") == 0) {
gen_dynbase_ = true;
import_dynbase_ = (iter->second);
@@ -269,7 +275,16 @@
}
static bool is_immutable(t_type* ttype) {
- return ttype->annotations_.find("python.immutable") != ttype->annotations_.end();
+ std::map<std::string, std::string>::iterator it = ttype->annotations_.find("python.immutable");
+
+ if (it == ttype->annotations_.end()) {
+ // Exceptions are immutable by default.
+ return ttype->is_xception();
+ } else if (it->second == "false") {
+ return false;
+ } else {
+ return true;
+ }
}
private:
@@ -288,6 +303,7 @@
std::string gen_dynbaseclass_;
std::string gen_dynbaseclass_frozen_;
std::string gen_dynbaseclass_exc_;
+ std::string gen_dynbaseclass_frozen_exc_;
std::string import_dynbase_;
@@ -742,7 +758,11 @@
out << endl << endl << "class " << tstruct->get_name();
if (is_exception) {
if (gen_dynamic_) {
- out << "(" << gen_dynbaseclass_exc_ << ")";
+ if (is_immutable(tstruct)) {
+ out << "(" << gen_dynbaseclass_frozen_exc_ << ")";
+ } else {
+ out << "(" << gen_dynbaseclass_exc_ << ")";
+ }
} else {
out << "(TException)";
}
@@ -2774,6 +2794,7 @@
" dynbase=CLS Derive generated classes from class CLS instead of TBase.\n"
" dynfrozen=CLS Derive generated immutable classes from class CLS instead of TFrozenBase.\n"
" dynexc=CLS Derive generated exceptions from CLS instead of TExceptionBase.\n"
+ " dynfrozenexc=CLS Derive generated immutable exceptions from CLS instead of TFrozenExceptionBase.\n"
" dynimport='from foo.bar import CLS'\n"
" Add an import line to generated code to find the dynbase class.\n"
" package_prefix='top.package.'\n"
diff --git a/compiler/cpp/src/thrift/generate/t_swift_generator.cc b/compiler/cpp/src/thrift/generate/t_swift_generator.cc
index eb746c1..4a2f87d 100644
--- a/compiler/cpp/src/thrift/generate/t_swift_generator.cc
+++ b/compiler/cpp/src/thrift/generate/t_swift_generator.cc
@@ -697,7 +697,9 @@
}
block_open(out);
- for (m_iter = members.begin(); m_iter != members.end(); ++m_iter) {
+ vector<t_field*> sorted = members;
+ sort(sorted.begin(), sorted.end(), [](t_field *a, t_field *b) { return (a->get_key() < b->get_key()); } );
+ for (m_iter = sorted.begin(); m_iter != sorted.end(); ++m_iter) {
out << endl;
// TODO: Defaults
diff --git a/configure.ac b/configure.ac
index d5d7ecb..64283a8 100755
--- a/configure.ac
+++ b/configure.ac
@@ -505,7 +505,7 @@
if test "$with_netstd" = "yes"; then
AC_PATH_PROG([DOTNETCORE], [dotnet])
if [[ -x "$DOTNETCORE" ]] ; then
- AX_PROG_DOTNETCORE_VERSION( [2.0.0], have_netstd="yes", have_netstd="no")
+ AX_PROG_DOTNETCORE_VERSION( [3.1.0], have_netstd="yes", have_netstd="no")
fi
fi
AM_CONDITIONAL(WITH_DOTNET, [test "$have_netstd" = "yes"])
diff --git a/doc/specs/thrift-compact-protocol.md b/doc/specs/thrift-compact-protocol.md
index 02467dd..6be2a62 100644
--- a/doc/specs/thrift-compact-protocol.md
+++ b/doc/specs/thrift-compact-protocol.md
@@ -97,8 +97,9 @@
### Double encoding
Values of type `double` are first converted to an int64 according to the IEEE 754 floating-point "double format" bit
-layout. Most run-times provide a library to make this conversion. Both the binary protocol as the compact protocol then
-encode the int64 in 8 bytes in big endian order.
+layout. Most run-times provide a library to make this conversion. But while the binary protocol encodes the int64
+in 8 bytes in big endian order, the compact protocol encodes it in little endian order; this is due to an early
+implementation bug that eventually became the de-facto standard.
### Boolean encoding
diff --git a/doc/specs/thrift-tconfiguration.md b/doc/specs/thrift-tconfiguration.md
new file mode 100644
index 0000000..e7736cf
--- /dev/null
+++ b/doc/specs/thrift-tconfiguration.md
@@ -0,0 +1,92 @@
+Thrift TConfiguration
+====================================================================
+
+Last Modified: 2019-Dec-03
+
+<!--
+--------------------------------------------------------------------
+
+Licensed to the Apache Software Foundation (ASF) under one
+or more contributor license agreements. See the NOTICE file
+distributed with this work for additional information
+regarding copyright ownership. The ASF licenses this file
+to you under the Apache License, Version 2.0 (the
+"License"); you may not use this file except in compliance
+with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing,
+software distributed under the License is distributed on an
+"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+KIND, either express or implied. See the License for the
+specific language governing permissions and limitations
+under the License.
+
+--------------------------------------------------------------------
+-->
+
+Starting with THRIFT-5021, the need to centralize certain limit settings that are used throughout the whole protocol / transport stack became obvious. Previous patches had already added some of these limits, but they were not managed consistently and were scattered across the code base.
+
+# Design goals
+
+In keeping with Thrift's tradition of offering a similar experience across all languages, every implementation should meet these design goals:
+
+ * There MUST be a standard CTOR (or equivalent thereof) that provides a default TConfiguration instance.
+ * The default values used SHOULD be implemented as outlined below.
+ * For backwards compatibility, the protocol / transport stack should accept a null TConfiguration argument, in which case it should fall back to a default instance automatically. This is to avoid code-breaking changes as much as possible.
+
+# Implementation
+
+The new TConfiguration class or struct currently holds three settings:
+
+## MaxMessageSize
+
+The MaxMessageSize member defines the maximum size of a (received) message, in bytes. The default value is represented by a constant named DEFAULT_MAX_MESSAGE_SIZE, whose value is 100 * 1024 * 1024 bytes.
+
+## MaxFrameSize
+
+MaxFrameSize limits the size of one frame of data for the TFramedTransport. Since all implementations currently send each message as a single frame when TFramedTransport is used, this value may interfere with MaxMessageSize. In the case of a conflict, the smaller of the two values is used (see remark below). The default value is called DEFAULT_MAX_FRAME_SIZE and is 16384000 bytes.
+
+## RecursionLimit
+
+RecursionLimit defines how deeply structures may be nested into each other. The default, named DEFAULT_RECURSION_DEPTH (DEFAULT_RECURSION_LIMIT in the Delphi library), allows structures to be nested up to 64 levels deep.
+
+# Further considerations
+
+## MaxFrameSize vs. MaxMessageSize
+
+The difference between the two options is that MaxFrameSize has existed much longer and is used only in conjunction with TFramedTransport. In contrast, MaxMessageSize is intended to be a general device that can be used with any transport or protocol.
+
+To combine both approaches optimally when TFramedTransport is used, the implementation SHOULD update the remaining number of bytes to read based on the frame size received for the current message.
+
+For calculation purposes it is important to know that MaxFrameSize excludes the 4 bytes that hold the frame size, while MaxMessageSize always covers the whole data. Hence, when updating the remaining read byte count, the known message size should be set to frameSize + sizeof(i32).
+
+## Error handling
+
+If any limit is exceeded, an error should be thrown. Additionally, it may be helpful to check larger memory allocations against the remaining max number of bytes before the allocation attempt takes place.
+
+# Q&A
+
+## Is this a breaking change or not?
+
+There are actually two answers to that question.
+
+1. If done right, it should not be a breaking change vis-à-vis compiling your source code that uses Thrift.
+
+1. It may, however, be a breaking change in the way it limits the accepted overall size of messages or the accepted frame size. This behaviour is by design. If your application hits any of these limits during normal operation, it may require you to instantiate an actual TConfiguration and tweak the settings according to your needs.
+
+## Is splitting the general transport base class into Endpoint and Layered transport base classes necessary?
+
+No, it's not. However, it turned out that this split is a great help when it comes to managing the TConfiguration instance that is passed through the stack. Having a distinct base class for each of the two transport types allows a shared solution for this to be implemented.
+
+An added benefit is that a clear distinction between the two transport types makes the Thrift architectural idea much clearer to "newbie" developers.
+
+## I want to contribute an implementation of TConfiguration. Should I pick a class or a struct?
+
+Short answer: Pick whatever is more efficient in the language of your choice.
+
+Technically, remember that the instance is passed down the stack and should therefore be cheap to copy. To ensure this, and to make sure all pieces of the protocol / transport stack really point to the same TConfiguration instance, we want to pass the instance **by reference** rather than by value.
+
+For example, in the C# language a class is a suitable choice for this: classes are reference types and are therefore naturally passed by reference, while structs are value types and get copied.
+
diff --git a/lib/cpp/src/thrift/server/TNonblockingServer.cpp b/lib/cpp/src/thrift/server/TNonblockingServer.cpp
index eea0427..26ffa68 100644
--- a/lib/cpp/src/thrift/server/TNonblockingServer.cpp
+++ b/lib/cpp/src/thrift/server/TNonblockingServer.cpp
@@ -492,7 +492,11 @@
case SOCKET_RECV:
// It is an error to be in this state if we already have all the data
- assert(readBufferPos_ < readWant_);
+ if (!(readBufferPos_ < readWant_)) {
+ GlobalOutput.printf("TNonblockingServer: frame size too short");
+ close();
+ return;
+ }
try {
// Read from the socket
diff --git a/lib/cpp/test/processor/Handlers.h b/lib/cpp/test/processor/Handlers.h
index 05d19ed..d72a23c 100644
--- a/lib/cpp/test/processor/Handlers.h
+++ b/lib/cpp/test/processor/Handlers.h
@@ -139,7 +139,7 @@
std::shared_ptr<EventLog> log_;
};
-#ifdef _WIN32
+#ifdef _MSC_VER
#pragma warning( push )
#pragma warning (disable : 4250 ) //inheriting methods via dominance
#endif
@@ -168,7 +168,7 @@
int32_t value_;
};
-#ifdef _WIN32
+#ifdef _MSC_VER
#pragma warning( pop )
#endif
diff --git a/lib/delphi/src/Thrift.Collections.pas b/lib/delphi/src/Thrift.Collections.pas
index 3b56fe2..ad852ac 100644
--- a/lib/delphi/src/Thrift.Collections.pas
+++ b/lib/delphi/src/Thrift.Collections.pas
@@ -65,9 +65,9 @@
end;
TThriftDictionaryImpl<TKey,TValue> = class( TInterfacedObject, IThriftDictionary<TKey,TValue>, IThriftContainer, ISupportsToString)
- private
+ strict private
FDictionaly : TDictionary<TKey,TValue>;
- protected
+ strict protected
function GetEnumerator: TEnumerator<TPair<TKey,TValue>>;
function GetKeys: TDictionary<TKey,TValue>.TKeyCollection;
@@ -142,9 +142,9 @@
end;
TThriftListImpl<T> = class( TInterfacedObject, IThriftList<T>, IThriftContainer, ISupportsToString)
- private
+ strict private
FList : TList<T>;
- protected
+ strict protected
function GetEnumerator: TEnumerator<T>;
function GetCapacity: Integer;
procedure SetCapacity(Value: Integer);
@@ -205,10 +205,10 @@
end;
THashSetImpl<TValue> = class( TInterfacedObject, IHashSet<TValue>, IThriftContainer, ISupportsToString)
- private
+ strict private
FDictionary : IThriftDictionary<TValue,Integer>;
FIsReadOnly: Boolean;
- protected
+ strict protected
function GetEnumerator: TEnumerator<TValue>;
function GetIsReadOnly: Boolean;
function GetCount: Integer;
diff --git a/lib/delphi/src/Thrift.Configuration.pas b/lib/delphi/src/Thrift.Configuration.pas
new file mode 100644
index 0000000..0cb11af
--- /dev/null
+++ b/lib/delphi/src/Thrift.Configuration.pas
@@ -0,0 +1,121 @@
+(*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ *)
+
+unit Thrift.Configuration;
+// Limit settings shared by the protocol / transport stack: recursion limit, max. frame size, max. message size.
+interface
+
+uses
+ SysUtils, Generics.Collections, Generics.Defaults;
+
+const
+ DEFAULT_RECURSION_LIMIT = 64; // max. nesting depth of structures
+ DEFAULT_MAX_MESSAGE_SIZE = 100 * 1024 * 1024; // 100 MB
+ DEFAULT_MAX_FRAME_SIZE = 16384000; // this value is used consistently across all Thrift libraries
+
+ DEFAULT_THRIFT_TIMEOUT = 5 * 1000; // ms; NOTE(review): not referenced in this unit, presumably used by transports -- confirm
+
+type
+ IThriftConfiguration = interface // read/write access to the configured limits
+ ['{ADD75449-1A67-4B78-9B75-502A1E338CFC}']
+ function GetRecursionLimit : Cardinal;
+ procedure SetRecursionLimit( const value : Cardinal);
+ function GetMaxFrameSize : Cardinal;
+ procedure SetMaxFrameSize( const value : Cardinal);
+ function GetMaxMessageSize : Cardinal;
+ procedure SetMaxMessageSize( const value : Cardinal);
+
+ property RecursionLimit : Cardinal read GetRecursionLimit write SetRecursionLimit;
+ property MaxFrameSize : Cardinal read GetMaxFrameSize write SetMaxFrameSize;
+ property MaxMessageSize : Cardinal read GetMaxMessageSize write SetMaxMessageSize;
+ end;
+
+ // Default implementation: each limit is a plain field, initialised from the DEFAULT_* constants in Create.
+ TThriftConfigurationImpl = class( TInterfacedObject, IThriftConfiguration)
+ strict protected
+ FRecursionLimit : Cardinal;
+ FMaxFrameSize : Cardinal;
+ FMaxMessageSize : Cardinal;
+
+ // IThriftConfiguration
+ function GetRecursionLimit : Cardinal;
+ procedure SetRecursionLimit( const value : Cardinal);
+ function GetMaxFrameSize : Cardinal;
+ procedure SetMaxFrameSize( const value : Cardinal);
+ function GetMaxMessageSize : Cardinal;
+ procedure SetMaxMessageSize( const value : Cardinal);
+ // note: the setters store values as-is, without any range validation
+ public
+ constructor Create;
+ end;
+
+
+implementation
+
+
+{ TThriftConfigurationImpl }
+
+
+constructor TThriftConfigurationImpl.Create;
+begin
+ inherited Create;
+ // start out with the library-wide defaults
+ FRecursionLimit := DEFAULT_RECURSION_LIMIT;
+ FMaxFrameSize := DEFAULT_MAX_FRAME_SIZE;
+ FMaxMessageSize := DEFAULT_MAX_MESSAGE_SIZE;
+end;
+
+
+function TThriftConfigurationImpl.GetRecursionLimit: Cardinal;
+begin
+ result := FRecursionLimit;
+end;
+
+
+procedure TThriftConfigurationImpl.SetRecursionLimit(const value: Cardinal);
+begin
+ FRecursionLimit := value;
+end;
+
+
+function TThriftConfigurationImpl.GetMaxFrameSize: Cardinal;
+begin
+ result := FMaxFrameSize;
+end;
+
+
+procedure TThriftConfigurationImpl.SetMaxFrameSize(const value: Cardinal);
+begin
+ FMaxFrameSize := value;
+end;
+
+
+function TThriftConfigurationImpl.GetMaxMessageSize: Cardinal;
+begin
+ result := FMaxMessageSize;
+end;
+
+
+procedure TThriftConfigurationImpl.SetMaxMessageSize(const value: Cardinal);
+begin
+ FMaxMessageSize := value;
+end;
+
+
+end.
diff --git a/lib/delphi/src/Thrift.Exception.pas b/lib/delphi/src/Thrift.Exception.pas
index 5d15c36..88b1cfe 100644
--- a/lib/delphi/src/Thrift.Exception.pas
+++ b/lib/delphi/src/Thrift.Exception.pas
@@ -29,6 +29,8 @@
type
// base class for all Thrift exceptions
TException = class( SysUtils.Exception)
+ strict private
+ function GetMessageText : string;
public
function Message : string; // hide inherited property: allow read, but prevent accidental writes
procedure UpdateMessageProperty; // update inherited message property with toString()
@@ -45,17 +47,25 @@
// allow read (exception summary), but prevent accidental writes
// read will return the exception summary
begin
- result := Self.ToString;
+ result := Self.GetMessageText;
end;
+
procedure TException.UpdateMessageProperty;
// Update the inherited Message property to better conform to standard behaviour.
// Nice benefit: The IDE is now able to show the exception message again.
begin
- inherited Message := Self.ToString; // produces a summary text
+ inherited Message := Self.GetMessageText;
end;
+function TException.GetMessageText : string;
+// Summary text from ToString(); one enclosing "(...)" pair is stripped only if both parentheses are present.
+begin
+ result := Self.ToString;
+ if (Length(result) >= 2) and (result[1] = '(') and (result[Length(result)] = ')')
+ then result := Copy(result,2,Length(result)-2);
+end;
end.
diff --git a/lib/delphi/src/Thrift.Processor.Multiplex.pas b/lib/delphi/src/Thrift.Processor.Multiplex.pas
index 8cf23db..ba77d94 100644
--- a/lib/delphi/src/Thrift.Processor.Multiplex.pas
+++ b/lib/delphi/src/Thrift.Processor.Multiplex.pas
@@ -62,19 +62,19 @@
TMultiplexedProcessorImpl = class( TInterfacedObject, IMultiplexedProcessor, IProcessor)
- private type
+ strict private type
// Our goal was to work with any protocol. In order to do that, we needed
// to allow them to call readMessageBegin() and get a TMessage in exactly
// the standard format, without the service name prepended to TMessage.name.
TStoredMessageProtocol = class( TProtocolDecorator)
- private
+ strict private
FMessageBegin : TThriftMessage;
public
constructor Create( const protocol : IProtocol; const aMsgBegin : TThriftMessage);
function ReadMessageBegin: TThriftMessage; override;
end;
- private
+ strict private
FServiceProcessorMap : TDictionary<String, IProcessor>;
FDefaultProcessor : IProcessor;
@@ -113,12 +113,6 @@
end;
-function TMultiplexedProcessorImpl.TStoredMessageProtocol.ReadMessageBegin: TThriftMessage;
-begin
- result := FMessageBegin;
-end;
-
-
constructor TMultiplexedProcessorImpl.Create;
begin
inherited Create;
@@ -136,6 +130,13 @@
end;
+function TMultiplexedProcessorImpl.TStoredMessageProtocol.ReadMessageBegin: TThriftMessage;
+begin
+ Reset; // reset protocol state first, as TJSONProtocolImpl.ReadMessageBegin also does
+ result := FMessageBegin; // the stored header (service name not prepended), as passed to Create
+end;
+
+
procedure TMultiplexedProcessorImpl.RegisterProcessor( const serviceName : String; const processor : IProcessor; const asDefault : Boolean);
begin
FServiceProcessorMap.Add( serviceName, processor);
diff --git a/lib/delphi/src/Thrift.Protocol.Compact.pas b/lib/delphi/src/Thrift.Protocol.Compact.pas
index 07cab9a..665cfc4 100644
--- a/lib/delphi/src/Thrift.Protocol.Compact.pas
+++ b/lib/delphi/src/Thrift.Protocol.Compact.pas
@@ -28,6 +28,7 @@
SysUtils,
Math,
Generics.Collections,
+ Thrift.Configuration,
Thrift.Transport,
Thrift.Protocol,
Thrift.Utils;
@@ -47,7 +48,7 @@
function GetProtocol( const trans: ITransport): IProtocol;
end;
- private const
+ strict private const
{ TODO
static TStruct ANONYMOUS_STRUCT = new TStruct("");
@@ -61,7 +62,7 @@
TYPE_BITS = Byte( $07); // 0000 0111
TYPE_SHIFT_AMOUNT = Byte( 5);
- private type
+ strict private type
// All of the on-wire type codes.
Types = (
STOP = $00,
@@ -79,7 +80,7 @@
STRUCT = $0C
);
- private const
+ strict private const
ttypeToCompactType : array[TType] of Types = (
Types.STOP, // Stop = 0,
Types(-1), // Void = 1,
@@ -115,7 +116,7 @@
TType.Struct // STRUCT
);
- private
+ strict private
// Used to keep track of the last field for the current and previous structs,
// so we can do the delta stuff.
lastField_ : TStack<Integer>;
@@ -123,19 +124,17 @@
// If we encounter a boolean field begin, save the TField here so it can
// have the value incorporated.
- private booleanField_ : TThriftField;
+ strict private booleanField_ : TThriftField;
// If we Read a field header, and it's a boolean field, save the boolean
// value here so that ReadBool can use it.
- private boolValue_ : ( unused, bool_true, bool_false);
+ strict private boolValue_ : ( unused, bool_true, bool_false);
public
constructor Create(const trans : ITransport);
destructor Destroy; override;
- procedure Reset;
-
- private
+ strict private
procedure WriteByteDirect( const b : Byte); overload;
// Writes a byte without any possibility of all that field header nonsense.
@@ -145,7 +144,7 @@
// TODO: make a permanent buffer like WriteVarint64?
procedure WriteVarint32( n : Cardinal);
- private
+ strict private
// The workhorse of WriteFieldBegin. It has the option of doing a 'type override'
// of the type header. This is used specifically in the boolean field case.
procedure WriteFieldBeginInternal( const field : TThriftField; typeOverride : Byte);
@@ -172,7 +171,7 @@
procedure WriteDouble( const dub: Double); override;
procedure WriteBinary( const b: TBytes); overload; override;
- private
+ private // unit visible stuff
class function DoubleToInt64Bits( const db : Double) : Int64;
class function Int64BitsToDouble( const i64 : Int64) : Double;
@@ -193,6 +192,10 @@
//Convert a Int64 into little-endian bytes in buf starting at off and going until off+7.
class procedure fixedLongToBytes( const n : Int64; var buf : TBytes);
+ strict protected
+ function GetMinSerializedSize( const aType : TType) : Integer; override;
+ procedure Reset; override;
+
public
function ReadMessageBegin: TThriftMessage; override;
procedure ReadMessageEnd(); override;
@@ -266,7 +269,7 @@
//--- TCompactProtocolImpl -------------------------------------------------
-constructor TCompactProtocolImpl.Create(const trans: ITransport);
+constructor TCompactProtocolImpl.Create( const trans : ITransport);
begin
inherited Create( trans);
@@ -291,6 +294,7 @@
procedure TCompactProtocolImpl.Reset;
begin
+ inherited Reset;
lastField_.Clear();
lastFieldId_ := 0;
Init( booleanField_, '', TType.Stop, 0);
@@ -735,6 +739,7 @@
val := getTType( Byte( keyAndValueType and $F));
Init( result, key, val, size);
ASSERT( (result.KeyType = key) and (result.ValueType = val));
+ CheckReadBytesAvailable(result);
end;
@@ -755,6 +760,7 @@
type_ := getTType( size_and_type);
Init( result, type_, size);
+ CheckReadBytesAvailable(result);
end;
@@ -775,6 +781,7 @@
type_ := getTType( size_and_type);
Init( result, type_, size);
+ CheckReadBytesAvailable(result);
end;
@@ -836,6 +843,7 @@
var length : Integer;
begin
length := Integer( ReadVarint32);
+ FTrans.CheckReadBytesAvailable(length);
SetLength( result, length);
if (length > 0)
then Transport.ReadAll( result, 0, length);
@@ -968,6 +976,32 @@
end;
+function TCompactProtocolImpl.GetMinSerializedSize( const aType : TType) : Integer;
+// Return the minimum number of bytes a value of aType consumes on the wire (compact encoding); raises for unknown type codes.
+begin
+ case aType of
+ TType.Stop: result := 0;
+ TType.Void: result := 0;
+ TType.Bool_: result := SizeOf(Byte);
+ TType.Byte_: result := SizeOf(Byte);
+ TType.Double_: result := 8; // uses fixedLongToBytes() which always writes 8 bytes
+ TType.I16: result := SizeOf(Byte); // presumably a varint: at least one byte -- confirm
+ TType.I32: result := SizeOf(Byte); // presumably a varint: at least one byte -- confirm
+ TType.I64: result := SizeOf(Byte); // presumably a varint: at least one byte -- confirm
+ TType.String_: result := SizeOf(Byte); // string length
+ TType.Struct: result := 0; // empty struct
+ TType.Map: result := SizeOf(Byte); // element count
+ TType.Set_: result := SizeOf(Byte); // element count
+ TType.List: result := SizeOf(Byte); // element count
+ else
+ raise TTransportExceptionBadArgs.Create('Unhandled type code');
+ end;
+end;
+
+
+
+
+
//--- unit tests -------------------------------------------
{$IFDEF Debug}
diff --git a/lib/delphi/src/Thrift.Protocol.JSON.pas b/lib/delphi/src/Thrift.Protocol.JSON.pas
index 30600aa..61cad8b 100644
--- a/lib/delphi/src/Thrift.Protocol.JSON.pas
+++ b/lib/delphi/src/Thrift.Protocol.JSON.pas
@@ -29,6 +29,7 @@
SysUtils,
Math,
Generics.Collections,
+ Thrift.Configuration,
Thrift.Transport,
Thrift.Protocol,
Thrift.Utils;
@@ -52,17 +53,17 @@
function GetProtocol( const trans: ITransport): IProtocol;
end;
- private
+ strict private
class function GetTypeNameForTypeID(typeID : TType) : string;
class function GetTypeIDForTypeName( const name : string) : TType;
- protected
+ strict protected
type
// Base class for tracking JSON contexts that may require
// inserting/Reading additional JSON syntax characters.
// This base context does nothing.
TJSONBaseContext = class
- protected
+ strict protected
FProto : Pointer; // weak IJSONProtocol;
public
constructor Create( const aProto : IJSONProtocol);
@@ -74,7 +75,7 @@
// Context for JSON lists.
// Will insert/Read commas before each item except for the first one.
TJSONListContext = class( TJSONBaseContext)
- private
+ strict private
FFirst : Boolean;
public
constructor Create( const aProto : IJSONProtocol);
@@ -86,7 +87,7 @@
// pair, and commas before each key except the first. In addition, will indicate that numbers
// in the key position need to be escaped in quotes (since JSON keys must be strings).
TJSONPairContext = class( TJSONBaseContext)
- private
+ strict private
FFirst, FColon : Boolean;
public
constructor Create( const aProto : IJSONProtocol);
@@ -97,11 +98,13 @@
// Holds up to one byte from the transport
TLookaheadReader = class
- protected
+ strict protected
FProto : Pointer; // weak IJSONProtocol;
+
+ protected
constructor Create( const aProto : IJSONProtocol);
- private
+ strict private
FHasData : Boolean;
FData : Byte;
@@ -115,7 +118,7 @@
function Peek : Byte;
end;
- protected
+ strict protected
// Stack of nested contexts that we may be in
FContextStack : TStack<TJSONBaseContext>;
@@ -130,17 +133,21 @@
procedure PushContext( const aCtx : TJSONBaseContext);
procedure PopContext;
+ strict protected
+ function GetMinSerializedSize( const aType : TType) : Integer; override;
+ procedure Reset; override;
+
public
// TJSONProtocolImpl Constructor
constructor Create( const aTrans : ITransport);
destructor Destroy; override;
- protected
+ strict protected
// IJSONProtocol
// Read a byte that must match b; otherwise an exception is thrown.
procedure ReadJSONSyntaxChar( b : Byte);
- private
+ strict private
// Convert a byte containing a hex char ('0'-'9' or 'a'-'f') into its corresponding hex value
class function HexVal( ch : Byte) : Byte;
@@ -213,7 +220,7 @@
function ReadBinary: TBytes; override;
- private
+ strict private
// Reading methods.
// Read in a JSON string, unescaping as appropriate.
@@ -292,7 +299,7 @@
function TJSONProtocolImpl.TFactory.GetProtocol( const trans: ITransport): IProtocol;
begin
- result := TJSONProtocolImpl.Create(trans);
+ result := TJSONProtocolImpl.Create( trans);
end;
class function TJSONProtocolImpl.GetTypeNameForTypeID(typeID : TType) : string;
@@ -478,6 +485,13 @@
end;
+procedure TJSONProtocolImpl.Reset;
+begin
+ inherited Reset;
+ ResetContextStack; // additionally discard any JSON contexts still left on the context stack
+end;
+
+
procedure TJSONProtocolImpl.ResetContextStack;
begin
while FContextStack.Count > 0
@@ -681,6 +695,7 @@
procedure TJSONProtocolImpl.WriteMessageBegin( const aMsg : TThriftMessage);
begin
+ Reset;
ResetContextStack; // THRIFT-1473
WriteJSONArrayStart;
@@ -1051,6 +1066,7 @@
function TJSONProtocolImpl.ReadMessageBegin: TThriftMessage;
begin
+ Reset;
ResetContextStack; // THRIFT-1473
Init( result);
@@ -1121,6 +1137,8 @@
result.ValueType := GetTypeIDForTypeName( str);
result.Count := ReadJSONInteger;
+ CheckReadBytesAvailable(result);
+
ReadJSONObjectStart;
end;
@@ -1141,6 +1159,7 @@
str := SysUtils.TEncoding.UTF8.GetString( ReadJSONString(FALSE));
result.ElementType := GetTypeIDForTypeName( str);
result.Count := ReadJSONInteger;
+ CheckReadBytesAvailable(result);
end;
@@ -1159,6 +1178,7 @@
str := SysUtils.TEncoding.UTF8.GetString( ReadJSONString(FALSE));
result.ElementType := GetTypeIDForTypeName( str);
result.Count := ReadJSONInteger;
+ CheckReadBytesAvailable(result);
end;
@@ -1216,6 +1236,30 @@
end;
+function TJSONProtocolImpl.GetMinSerializedSize( const aType : TType) : Integer;
+// Return the minimum number of bytes a value of aType consumes on the wire (JSON text); raises for unknown type codes.
+begin
+ case aType of
+ TType.Stop: result := 0;
+ TType.Void: result := 0;
+ TType.Bool_: result := 1;
+ TType.Byte_: result := 1;
+ TType.Double_: result := 1;
+ TType.I16: result := 1;
+ TType.I32: result := 1;
+ TType.I64: result := 1;
+ TType.String_: result := 2; // empty string
+ TType.Struct: result := 2; // empty struct
+ TType.Map: result := 2; // empty map
+ TType.Set_: result := 2; // empty set
+ TType.List: result := 2; // empty list
+ else
+ raise TTransportExceptionBadArgs.Create('Unhandled type code');
+ end;
+end;
+
+
+
//--- init code ---
procedure InitBytes( var b : TBytes; aData : array of Byte);
diff --git a/lib/delphi/src/Thrift.Protocol.Multiplex.pas b/lib/delphi/src/Thrift.Protocol.Multiplex.pas
index 93a3838..e5e0cd9 100644
--- a/lib/delphi/src/Thrift.Protocol.Multiplex.pas
+++ b/lib/delphi/src/Thrift.Protocol.Multiplex.pas
@@ -54,7 +54,7 @@
{ Used to delimit the service name from the function name }
SEPARATOR = ':';
- private
+ strict private
FServiceName : String;
public
diff --git a/lib/delphi/src/Thrift.Protocol.pas b/lib/delphi/src/Thrift.Protocol.pas
index 609dfc6..d5a7587 100644
--- a/lib/delphi/src/Thrift.Protocol.pas
+++ b/lib/delphi/src/Thrift.Protocol.pas
@@ -31,6 +31,7 @@
Thrift.Stream,
Thrift.Utils,
Thrift.Collections,
+ Thrift.Configuration,
Thrift.Transport;
type
@@ -67,9 +68,6 @@
VALID_MESSAGETYPES = [Low(TMessageType)..High(TMessageType)];
-const
- DEFAULT_RECURSION_LIMIT = 64;
-
type
IProtocol = interface;
@@ -106,30 +104,32 @@
end;
-
IProtocolFactory = interface
['{7CD64A10-4E9F-4E99-93BF-708A31F4A67B}']
function GetProtocol( const trans: ITransport): IProtocol;
end;
- TProtocolException = class( TException)
+ TProtocolException = class abstract( TException)
public
- const // TODO(jensg): change into enum
- UNKNOWN = 0;
- INVALID_DATA = 1;
- NEGATIVE_SIZE = 2;
- SIZE_LIMIT = 3;
- BAD_VERSION = 4;
- NOT_IMPLEMENTED = 5;
- DEPTH_LIMIT = 6;
- protected
+ type TExceptionType = (
+ UNKNOWN = 0,
+ INVALID_DATA = 1,
+ NEGATIVE_SIZE = 2,
+ SIZE_LIMIT = 3,
+ BAD_VERSION = 4,
+ NOT_IMPLEMENTED = 5,
+ DEPTH_LIMIT = 6
+ );
+ strict protected
constructor HiddenCreate(const Msg: string);
+ class function GetType: TExceptionType; virtual; abstract;
public
// purposefully hide inherited constructor
class function Create(const Msg: string): TProtocolException; overload; deprecated 'Use specialized TProtocolException types (or regenerate from IDL)';
class function Create: TProtocolException; overload; deprecated 'Use specialized TProtocolException types (or regenerate from IDL)';
- class function Create( type_: Integer): TProtocolException; overload; deprecated 'Use specialized TProtocolException types (or regenerate from IDL)';
- class function Create( type_: Integer; const msg: string): TProtocolException; overload; deprecated 'Use specialized TProtocolException types (or regenerate from IDL)';
+ class function Create( aType: TExceptionType): TProtocolException; overload; deprecated 'Use specialized TProtocolException types (or regenerate from IDL)';
+ class function Create( aType: TExceptionType; const msg: string): TProtocolException; overload; deprecated 'Use specialized TProtocolException types (or regenerate from IDL)';
+ property Type_: TExceptionType read GetType;
end;
// Needed to remove deprecation warning
@@ -138,13 +138,41 @@
constructor Create(const Msg: string);
end;
- TProtocolExceptionUnknown = class (TProtocolExceptionSpecialized);
- TProtocolExceptionInvalidData = class (TProtocolExceptionSpecialized);
- TProtocolExceptionNegativeSize = class (TProtocolExceptionSpecialized);
- TProtocolExceptionSizeLimit = class (TProtocolExceptionSpecialized);
- TProtocolExceptionBadVersion = class (TProtocolExceptionSpecialized);
- TProtocolExceptionNotImplemented = class (TProtocolExceptionSpecialized);
- TProtocolExceptionDepthLimit = class (TProtocolExceptionSpecialized);
+ TProtocolExceptionUnknown = class (TProtocolExceptionSpecialized)
+ strict protected
+ class function GetType: TProtocolException.TExceptionType; override;
+ end;
+
+ TProtocolExceptionInvalidData = class (TProtocolExceptionSpecialized)
+ strict protected
+ class function GetType: TProtocolException.TExceptionType; override;
+ end;
+
+ TProtocolExceptionNegativeSize = class (TProtocolExceptionSpecialized)
+ strict protected
+ class function GetType: TProtocolException.TExceptionType; override;
+ end;
+
+ TProtocolExceptionSizeLimit = class (TProtocolExceptionSpecialized)
+ strict protected
+ class function GetType: TProtocolException.TExceptionType; override;
+ end;
+
+ TProtocolExceptionBadVersion = class (TProtocolExceptionSpecialized)
+ strict protected
+ class function GetType: TProtocolException.TExceptionType; override;
+ end;
+
+ TProtocolExceptionNotImplemented = class (TProtocolExceptionSpecialized)
+ strict protected
+ class function GetType: TProtocolException.TExceptionType; override;
+ end;
+
+ TProtocolExceptionDepthLimit = class (TProtocolExceptionSpecialized)
+ strict protected
+ class function GetType: TProtocolException.TExceptionType; override;
+ end;
+
TProtocolUtil = class
@@ -158,7 +186,7 @@
end;
TProtocolRecursionTrackerImpl = class abstract( TInterfacedObject, IProtocolRecursionTracker)
- protected
+ strict protected
FProtocol : IProtocol;
public
constructor Create( prot : IProtocol);
@@ -166,7 +194,7 @@
end;
IProtocol = interface
- ['{602A7FFB-0D9E-4CD8-8D7F-E5076660588A}']
+ ['{F0040D99-937F-400D-9932-AF04F665899F}']
function GetTransport: ITransport;
procedure WriteMessageBegin( const msg: TThriftMessage);
procedure WriteMessageEnd;
@@ -213,30 +241,34 @@
function ReadString: string;
function ReadAnsiString: AnsiString;
- procedure SetRecursionLimit( value : Integer);
- function GetRecursionLimit : Integer;
function NextRecursionLevel : IProtocolRecursionTracker;
procedure IncrementRecursionDepth;
procedure DecrementRecursionDepth;
+ function GetMinSerializedSize( const aType : TType) : Integer;
property Transport: ITransport read GetTransport;
- property RecursionLimit : Integer read GetRecursionLimit write SetRecursionLimit;
+ function Configuration : IThriftConfiguration;
end;
TProtocolImpl = class abstract( TInterfacedObject, IProtocol)
- protected
+ strict protected
FTrans : ITransport;
FRecursionLimit : Integer;
FRecursionDepth : Integer;
- procedure SetRecursionLimit( value : Integer);
- function GetRecursionLimit : Integer;
function NextRecursionLevel : IProtocolRecursionTracker;
procedure IncrementRecursionDepth;
procedure DecrementRecursionDepth;
- function GetTransport: ITransport;
- public
+ function GetMinSerializedSize( const aType : TType) : Integer; virtual; abstract;
+ procedure CheckReadBytesAvailable( const value : TThriftList); overload; inline;
+ procedure CheckReadBytesAvailable( const value : TThriftSet); overload; inline;
+ procedure CheckReadBytesAvailable( const value : TThriftMap); overload; inline;
+
+ procedure Reset; virtual;
+ function GetTransport: ITransport;
+ function Configuration : IThriftConfiguration;
+
procedure WriteMessageBegin( const msg: TThriftMessage); virtual; abstract;
procedure WriteMessageEnd; virtual; abstract;
procedure WriteStructBegin( const struc: TThriftStruct); virtual; abstract;
@@ -282,9 +314,10 @@
function ReadString: string; virtual;
function ReadAnsiString: AnsiString; virtual;
- property Transport: ITransport read GetTransport;
+ property Transport: ITransport read GetTransport;
- constructor Create( trans: ITransport );
+ public
+ constructor Create( const aTransport : ITransport);
end;
IBase = interface( ISupportsToString)
@@ -295,33 +328,31 @@
TBinaryProtocolImpl = class( TProtocolImpl )
- protected
+ strict protected
const
VERSION_MASK : Cardinal = $ffff0000;
VERSION_1 : Cardinal = $80010000;
- protected
+ strict protected
FStrictRead : Boolean;
FStrictWrite : Boolean;
+ function GetMinSerializedSize( const aType : TType) : Integer; override;
- private
+ strict private
function ReadAll( const pBuf : Pointer; const buflen : Integer; off: Integer; len: Integer ): Integer; inline;
function ReadStringBody( size: Integer): string;
public
-
type
TFactory = class( TInterfacedObject, IProtocolFactory)
- protected
+ strict protected
FStrictRead : Boolean;
FStrictWrite : Boolean;
- public
function GetProtocol( const trans: ITransport): IProtocol;
- constructor Create( AStrictRead, AStrictWrite: Boolean ); overload;
- constructor Create; overload;
+ public
+ constructor Create( const aStrictRead : Boolean = FALSE; const aStrictWrite: Boolean = TRUE); reintroduce;
end;
- constructor Create( const trans: ITransport); overload;
- constructor Create( const trans: ITransport; strictRead: Boolean; strictWrite: Boolean); overload;
+ constructor Create( const trans: ITransport; strictRead: Boolean = FALSE; strictWrite: Boolean = TRUE); reintroduce;
procedure WriteMessageBegin( const msg: TThriftMessage); override;
procedure WriteMessageEnd; override;
@@ -374,9 +405,12 @@
See p.175 of Design Patterns (by Gamma et al.)
}
TProtocolDecorator = class( TProtocolImpl)
- private
+ strict private
FWrappedProtocol : IProtocol;
+ strict protected
+ function GetMinSerializedSize( const aType : TType) : Integer; override;
+
public
// Encloses the specified protocol.
// All operations will be forward to the given protocol. Must be non-null.
@@ -476,13 +510,13 @@
implementation
-function ConvertInt64ToDouble( const n: Int64): Double;
+function ConvertInt64ToDouble( const n: Int64): Double; inline;
begin
ASSERT( SizeOf(n) = SizeOf(Result));
System.Move( n, Result, SizeOf(Result));
end;
-function ConvertDoubleToInt64( const d: Double): Int64;
+function ConvertDoubleToInt64( const d: Double): Int64; inline;
begin
ASSERT( SizeOf(d) = SizeOf(Result));
System.Move( d, Result, SizeOf(Result));
@@ -516,24 +550,14 @@
{ TProtocolImpl }
-constructor TProtocolImpl.Create(trans: ITransport);
+constructor TProtocolImpl.Create( const aTransport : ITransport);
begin
inherited Create;
- FTrans := trans;
- FRecursionLimit := DEFAULT_RECURSION_LIMIT;
+ FTrans := aTransport;
+ FRecursionLimit := aTransport.Configuration.RecursionLimit;
FRecursionDepth := 0;
end;
-procedure TProtocolImpl.SetRecursionLimit( value : Integer);
-begin
- FRecursionLimit := value;
-end;
-
-function TProtocolImpl.GetRecursionLimit : Integer;
-begin
- result := FRecursionLimit;
-end;
-
function TProtocolImpl.NextRecursionLevel : IProtocolRecursionTracker;
begin
result := TProtocolRecursionTrackerImpl.Create(Self);
@@ -556,6 +580,16 @@
Result := FTrans;
end;
+// The protocol has no configuration of its own: it simply exposes the
+// configuration object of the transport it was created on.
+function TProtocolImpl.Configuration : IThriftConfiguration;
+begin
+  result := GetTransport.Configuration;
+end;
+
+// Resets the transport's consumed-message-size bookkeeping.
+// TBinaryProtocolImpl invokes this at the start of ReadMessageBegin and WriteMessageBegin.
+procedure TProtocolImpl.Reset;
+begin
+  GetTransport.ResetConsumedMessageSize;
+end;
+
function TProtocolImpl.ReadAnsiString: AnsiString;
var
b : TBytes;
@@ -564,8 +598,7 @@
Result := '';
b := ReadBinary;
len := Length( b );
- if len > 0 then
- begin
+ if len > 0 then begin
SetLength( Result, len);
System.Move( b[0], Pointer(Result)^, len );
end;
@@ -583,8 +616,7 @@
begin
len := Length(s);
SetLength( b, len);
- if len > 0 then
- begin
+ if len > 0 then begin
System.Move( Pointer(s)^, b[0], len );
end;
WriteBinary( b );
@@ -598,6 +630,26 @@
WriteBinary( b );
end;
+
+// Container-header sanity checks: before any element is read, ask the transport
+// whether enough bytes remain for Count elements of the minimum wire size.
+// Count comes straight off the wire, so the byte total is computed in Int64;
+// 32-bit Integer arithmetic could overflow and wrap to a small or negative
+// value that would slip past the transport's size limit.
+// NOTE(review): assumes ITransport.CheckReadBytesAvailable takes an Int64 - confirm.
+procedure TProtocolImpl.CheckReadBytesAvailable( const value : TThriftList);
+var nBytes : Int64;
+begin
+  nBytes := Int64(value.Count) * GetMinSerializedSize(value.ElementType);
+  FTrans.CheckReadBytesAvailable( nBytes);
+end;
+
+
+// Same check as for lists: Count elements of ElementType, computed in Int64.
+procedure TProtocolImpl.CheckReadBytesAvailable( const value : TThriftSet);
+var nBytes : Int64;
+begin
+  nBytes := Int64(value.Count) * GetMinSerializedSize(value.ElementType);
+  FTrans.CheckReadBytesAvailable( nBytes);
+end;
+
+
+// For maps, each entry costs at least one key plus one value; both the pair
+// size and the total are kept in Int64 to avoid wrap-around.
+procedure TProtocolImpl.CheckReadBytesAvailable( const value : TThriftMap);
+var nPairSize : Int64;
+begin
+  nPairSize := Int64(GetMinSerializedSize(value.KeyType)) + GetMinSerializedSize(value.ValueType);
+  FTrans.CheckReadBytesAvailable( Int64(value.Count) * nPairSize);
+end;
+
{ TProtocolUtil }
class procedure TProtocolUtil.Skip( prot: IProtocol; type_: TType);
@@ -662,16 +714,9 @@
{ TBinaryProtocolImpl }
-constructor TBinaryProtocolImpl.Create( const trans: ITransport);
+constructor TBinaryProtocolImpl.Create( const trans: ITransport; strictRead, strictWrite: Boolean);
begin
- //no inherited
- Create( trans, False, True);
-end;
-
-constructor TBinaryProtocolImpl.Create( const trans: ITransport; strictRead,
- strictWrite: Boolean);
-begin
- inherited Create( trans );
+ inherited Create( trans);
FStrictRead := strictRead;
FStrictWrite := strictWrite;
end;
@@ -687,7 +732,8 @@
buf : TBytes;
begin
size := ReadI32;
- SetLength( buf, size );
+ FTrans.CheckReadBytesAvailable( size);
+ SetLength( buf, size);
FTrans.ReadAll( buf, 0, size);
Result := buf;
end;
@@ -759,6 +805,7 @@
begin
result.ElementType := TType(ReadByte);
result.Count := ReadI32;
+ CheckReadBytesAvailable(result);
end;
procedure TBinaryProtocolImpl.ReadListEnd;
@@ -771,6 +818,7 @@
result.KeyType := TType(ReadByte);
result.ValueType := TType(ReadByte);
result.Count := ReadI32;
+ CheckReadBytesAvailable(result);
end;
procedure TBinaryProtocolImpl.ReadMapEnd;
@@ -783,6 +831,7 @@
size : Integer;
version : Integer;
begin
+ Reset;
Init( result);
size := ReadI32;
if (size < 0) then begin
@@ -814,6 +863,7 @@
begin
result.ElementType := TType(ReadByte);
result.Count := ReadI32;
+ CheckReadBytesAvailable(result);
end;
procedure TBinaryProtocolImpl.ReadSetEnd;
@@ -822,10 +872,10 @@
end;
function TBinaryProtocolImpl.ReadStringBody( size: Integer): string;
-var
- buf : TBytes;
+var buf : TBytes;
begin
- SetLength( buf, size );
+ FTrans.CheckReadBytesAvailable( size);
+ SetLength( buf, size);
FTrans.ReadAll( buf, 0, size );
Result := TEncoding.UTF8.GetString( buf);
end;
@@ -940,17 +990,15 @@
end;
procedure TBinaryProtocolImpl.WriteMessageBegin( const msg: TThriftMessage);
-var
- version : Cardinal;
+var version : Cardinal;
begin
- if FStrictWrite then
- begin
+ Reset;
+ if FStrictWrite then begin
version := VERSION_1 or Cardinal( msg.Type_);
WriteI32( Integer( version) );
WriteString( msg.Name);
WriteI32( msg.SeqID);
- end else
- begin
+ end else begin
WriteString( msg.Name);
WriteByte(ShortInt( msg.Type_));
WriteI32( msg.SeqID);
@@ -983,6 +1031,29 @@
end;
+// Smallest number of bytes a value of type aType can occupy on the wire in the
+// binary protocol. Strings and containers count only their Int32 length/count
+// prefix; Stop, Void and Struct are counted as zero bytes.
+// Any type code not listed raises TTransportExceptionBadArgs.
+function TBinaryProtocolImpl.GetMinSerializedSize( const aType : TType) : Integer;
+begin
+  case aType of
+    TType.Stop, TType.Void, TType.Struct:
+      result := 0;
+    TType.Bool_, TType.Byte_:
+      result := SizeOf(Byte);
+    TType.I16:
+      result := SizeOf(Int16);
+    TType.I32, TType.String_, TType.Map, TType.Set_, TType.List:
+      result := SizeOf(Int32);
+    TType.I64:
+      result := SizeOf(Int64);
+    TType.Double_:
+      result := SizeOf(Double);
+  else
+    raise TTransportExceptionBadArgs.Create('Unhandled type code');
+  end;
+end;
+
+
{ TProtocolException }
constructor TProtocolException.HiddenCreate(const Msg: string);
@@ -1000,23 +1071,24 @@
Result := TProtocolExceptionUnknown.Create('');
end;
-class function TProtocolException.Create(type_: Integer): TProtocolException;
+class function TProtocolException.Create(aType: TExceptionType): TProtocolException;
begin
{$WARN SYMBOL_DEPRECATED OFF}
- Result := Create(type_, '');
+ Result := Create(aType, '');
{$WARN SYMBOL_DEPRECATED DEFAULT}
end;
-class function TProtocolException.Create(type_: Integer; const msg: string): TProtocolException;
+class function TProtocolException.Create(aType: TExceptionType; const msg: string): TProtocolException;
begin
- case type_ of
- INVALID_DATA: Result := TProtocolExceptionInvalidData.Create(msg);
- NEGATIVE_SIZE: Result := TProtocolExceptionNegativeSize.Create(msg);
- SIZE_LIMIT: Result := TProtocolExceptionSizeLimit.Create(msg);
- BAD_VERSION: Result := TProtocolExceptionBadVersion.Create(msg);
- NOT_IMPLEMENTED: Result := TProtocolExceptionNotImplemented.Create(msg);
- DEPTH_LIMIT: Result := TProtocolExceptionDepthLimit.Create(msg);
+ case aType of
+ TExceptionType.INVALID_DATA: Result := TProtocolExceptionInvalidData.Create(msg);
+ TExceptionType.NEGATIVE_SIZE: Result := TProtocolExceptionNegativeSize.Create(msg);
+ TExceptionType.SIZE_LIMIT: Result := TProtocolExceptionSizeLimit.Create(msg);
+ TExceptionType.BAD_VERSION: Result := TProtocolExceptionBadVersion.Create(msg);
+ TExceptionType.NOT_IMPLEMENTED: Result := TProtocolExceptionNotImplemented.Create(msg);
+ TExceptionType.DEPTH_LIMIT: Result := TProtocolExceptionDepthLimit.Create(msg);
else
+ ASSERT( TExceptionType.UNKNOWN = aType);
Result := TProtocolExceptionUnknown.Create(msg);
end;
end;
@@ -1028,21 +1100,52 @@
inherited HiddenCreate(Msg);
end;
+{ specialized TProtocolExceptions }
+// Each specialized exception class reports the TExceptionType code it stands for;
+// this mirrors the mapping used by TProtocolException.Create( aType, msg).
+
+class function TProtocolExceptionUnknown.GetType: TProtocolException.TExceptionType;
+begin
+  Result := TProtocolException.TExceptionType.UNKNOWN;
+end;
+
+class function TProtocolExceptionInvalidData.GetType: TProtocolException.TExceptionType;
+begin
+  Result := TProtocolException.TExceptionType.INVALID_DATA;
+end;
+
+class function TProtocolExceptionNegativeSize.GetType: TProtocolException.TExceptionType;
+begin
+  Result := TProtocolException.TExceptionType.NEGATIVE_SIZE;
+end;
+
+class function TProtocolExceptionSizeLimit.GetType: TProtocolException.TExceptionType;
+begin
+  Result := TProtocolException.TExceptionType.SIZE_LIMIT;
+end;
+
+class function TProtocolExceptionBadVersion.GetType: TProtocolException.TExceptionType;
+begin
+  Result := TProtocolException.TExceptionType.BAD_VERSION;
+end;
+
+class function TProtocolExceptionNotImplemented.GetType: TProtocolException.TExceptionType;
+begin
+  Result := TProtocolException.TExceptionType.NOT_IMPLEMENTED;
+end;
+
+class function TProtocolExceptionDepthLimit.GetType: TProtocolException.TExceptionType;
+begin
+  Result := TProtocolException.TExceptionType.DEPTH_LIMIT;
+end;
+
{ TBinaryProtocolImpl.TFactory }
-constructor TBinaryProtocolImpl.TFactory.Create(AStrictRead, AStrictWrite: Boolean);
+constructor TBinaryProtocolImpl.TFactory.Create( const aStrictRead, aStrictWrite: Boolean);
begin
inherited Create;
FStrictRead := AStrictRead;
FStrictWrite := AStrictWrite;
end;
-constructor TBinaryProtocolImpl.TFactory.Create;
-begin
- //no inherited;
- Create( False, True )
-end;
-
function TBinaryProtocolImpl.TFactory.GetProtocol( const trans: ITransport): IProtocol;
begin
Result := TBinaryProtocolImpl.Create( trans, FStrictRead, FStrictWrite);
@@ -1317,6 +1420,12 @@
end;
+// A decorator adds no wire overhead of its own: the minimum sizes are
+// exactly those reported by the enclosed protocol.
+function TProtocolDecorator.GetMinSerializedSize( const aType : TType) : Integer;
+begin
+  Result := FWrappedProtocol.GetMinSerializedSize( aType);
+end;
+
+
{ Init helper functions }
procedure Init( var rec : TThriftMessage; const AName: string; const AMessageType: TMessageType; const ASeqID: Integer);
@@ -1364,7 +1473,5 @@
-
-
end.
diff --git a/lib/delphi/src/Thrift.Serializer.pas b/lib/delphi/src/Thrift.Serializer.pas
index 5f2905a..cb62603 100644
--- a/lib/delphi/src/Thrift.Serializer.pas
+++ b/lib/delphi/src/Thrift.Serializer.pas
@@ -28,6 +28,7 @@
{$ELSE}
System.Classes, Winapi.Windows, System.SysUtils,
{$ENDIF}
+ Thrift.Configuration,
Thrift.Protocol,
Thrift.Transport,
Thrift.Stream;
@@ -36,18 +37,15 @@
type
// Generic utility for easily serializing objects into a byte array or Stream.
TSerializer = class
- private
+ strict private
FStream : TMemoryStream;
FTransport : ITransport;
FProtocol : IProtocol;
public
- // Create a new TSerializer that uses the TBinaryProtocol by default.
- constructor Create; overload;
-
- // Create a new TSerializer.
- // It will use the TProtocol specified by the factory that is passed in.
- constructor Create( const factory : IProtocolFactory); overload;
+ constructor Create( const aProtFact : IProtocolFactory = nil; // defaults to TBinaryProtocol
+ const aTransFact : ITransportFactory = nil;
+ const aConfig : IThriftConfiguration = nil);
// DTOR
destructor Destroy; override;
@@ -60,18 +58,15 @@
// Generic utility for easily deserializing objects from byte array or Stream.
TDeserializer = class
- private
+ strict private
FStream : TMemoryStream;
FTransport : ITransport;
FProtocol : IProtocol;
public
- // Create a new TDeserializer that uses the TBinaryProtocol by default.
- constructor Create; overload;
-
- // Create a new TDeserializer.
- // It will use the TProtocol specified by the factory that is passed in.
- constructor Create( const factory : IProtocolFactory); overload;
+ constructor Create( const aProtFact : IProtocolFactory = nil; // defaults to TBinaryProtocol
+ const aTransFact : ITransportFactory = nil;
+ const aConfig : IThriftConfiguration = nil);
// DTOR
destructor Destroy; override;
@@ -89,24 +84,27 @@
{ TSerializer }
-constructor TSerializer.Create();
-// Create a new TSerializer that uses the TBinaryProtocol by default.
-begin
- //no inherited;
- Create( TBinaryProtocolImpl.TFactory.Create);
-end;
-
-
-constructor TSerializer.Create( const factory : IProtocolFactory);
-// Create a new TSerializer.
-// It will use the TProtocol specified by the factory that is passed in.
+constructor TSerializer.Create( const aProtFact : IProtocolFactory;
+ const aTransFact : ITransportFactory;
+ const aConfig : IThriftConfiguration);
var adapter : IThriftStream;
+ protfact : IProtocolFactory;
begin
inherited Create;
+
FStream := TMemoryStream.Create;
adapter := TThriftStreamAdapterDelphi.Create( FStream, FALSE);
- FTransport := TStreamTransportImpl.Create( nil, adapter);
- FProtocol := factory.GetProtocol( FTransport);
+
+ FTransport := TStreamTransportImpl.Create( nil, adapter, aConfig);
+ if aTransfact <> nil then FTransport := aTransfact.GetTransport( FTransport);
+
+ if aProtFact <> nil
+ then protfact := aProtFact
+ else protfact := TBinaryProtocolImpl.TFactory.Create;
+ FProtocol := protfact.GetProtocol( FTransport);
+
+ if not FTransport.IsOpen
+ then FTransport.Open;
end;
@@ -131,6 +129,8 @@
try
FStream.Size := 0;
input.Write( FProtocol);
+ FTransport.Flush;
+
SetLength( result, FStream.Size);
iBytes := Length(result);
if iBytes > 0
@@ -150,6 +150,8 @@
try
FStream.Size := 0;
input.Write( FProtocol);
+ FTransport.Flush;
+
aStm.CopyFrom( FStream, COPY_ENTIRE_STREAM);
finally
FStream.Size := 0; // free any allocated memory
@@ -160,24 +162,27 @@
{ TDeserializer }
-constructor TDeserializer.Create();
-// Create a new TDeserializer that uses the TBinaryProtocol by default.
-begin
- //no inherited;
- Create( TBinaryProtocolImpl.TFactory.Create);
-end;
-
-
-constructor TDeserializer.Create( const factory : IProtocolFactory);
-// Create a new TDeserializer.
-// It will use the TProtocol specified by the factory that is passed in.
+constructor TDeserializer.Create( const aProtFact : IProtocolFactory;
+ const aTransFact : ITransportFactory;
+ const aConfig : IThriftConfiguration);
var adapter : IThriftStream;
+ protfact : IProtocolFactory;
begin
inherited Create;
+
FStream := TMemoryStream.Create;
adapter := TThriftStreamAdapterDelphi.Create( FStream, FALSE);
- FTransport := TStreamTransportImpl.Create( adapter, nil);
- FProtocol := factory.GetProtocol( FTransport);
+
+ FTransport := TStreamTransportImpl.Create( adapter, nil, aConfig);
+ if aTransfact <> nil then FTransport := aTransfact.GetTransport( FTransport);
+
+ if aProtFact <> nil
+ then protfact := aProtFact
+ else protfact := TBinaryProtocolImpl.TFactory.Create;
+ FProtocol := protfact.GetProtocol( FTransport);
+
+ if not FTransport.IsOpen
+ then FTransport.Open;
end;
diff --git a/lib/delphi/src/Thrift.Server.pas b/lib/delphi/src/Thrift.Server.pas
index 13c5762..a73e6cb 100644
--- a/lib/delphi/src/Thrift.Server.pas
+++ b/lib/delphi/src/Thrift.Server.pas
@@ -32,7 +32,8 @@
{$ENDIF}
Thrift,
Thrift.Protocol,
- Thrift.Transport;
+ Thrift.Transport,
+ Thrift.Configuration;
type
IServerEvents = interface
@@ -61,7 +62,7 @@
public
type
TLogDelegate = reference to procedure( const str: string);
- protected
+ strict protected
FProcessor : IProcessor;
FServerTransport : IServerTransport;
FInputTransportFactory : ITransportFactory;
@@ -70,6 +71,7 @@
FOutputProtocolFactory : IProtocolFactory;
FLogDelegate : TLogDelegate;
FServerEvents : IServerEvents;
+ FConfiguration : IThriftConfiguration;
class procedure DefaultLogDelegate( const str: string);
@@ -80,52 +82,31 @@
procedure Stop; virtual; abstract;
public
constructor Create(
- const AProcessor :IProcessor;
- const AServerTransport: IServerTransport;
- const AInputTransportFactory : ITransportFactory;
- const AOutputTransportFactory : ITransportFactory;
- const AInputProtocolFactory : IProtocolFactory;
- const AOutputProtocolFactory : IProtocolFactory;
- const ALogDelegate : TLogDelegate
+ const aProcessor :IProcessor;
+ const aServerTransport: IServerTransport;
+ const aInputTransportFactory : ITransportFactory;
+ const aOutputTransportFactory : ITransportFactory;
+ const aInputProtocolFactory : IProtocolFactory;
+ const aOutputProtocolFactory : IProtocolFactory;
+ const aConfig : IThriftConfiguration;
+ const aLogDelegate : TLogDelegate
); overload;
constructor Create(
- const AProcessor :IProcessor;
- const AServerTransport: IServerTransport
- ); overload;
-
- constructor Create(
- const AProcessor :IProcessor;
- const AServerTransport: IServerTransport;
- const ALogDelegate: TLogDelegate
- ); overload;
-
- constructor Create(
- const AProcessor :IProcessor;
- const AServerTransport: IServerTransport;
- const ATransportFactory : ITransportFactory
- ); overload;
-
- constructor Create(
- const AProcessor :IProcessor;
- const AServerTransport: IServerTransport;
- const ATransportFactory : ITransportFactory;
- const AProtocolFactory : IProtocolFactory
+ const aProcessor: IProcessor;
+ const aServerTransport: IServerTransport;
+ const aTransportFactory: ITransportFactory = nil;
+ const aProtocolFactory: IProtocolFactory = nil;
+ const aConfig : IThriftConfiguration = nil;
+ const aLogDel: TServerImpl.TLogDelegate = nil
); overload;
end;
+
TSimpleServer = class( TServerImpl)
private
FStop : Boolean;
public
- constructor Create( const AProcessor: IProcessor; const AServerTransport: IServerTransport); overload;
- constructor Create( const AProcessor: IProcessor; const AServerTransport: IServerTransport;
- ALogDel: TServerImpl.TLogDelegate); overload;
- constructor Create( const AProcessor: IProcessor; const AServerTransport: IServerTransport;
- const ATransportFactory: ITransportFactory); overload;
- constructor Create( const AProcessor: IProcessor; const AServerTransport: IServerTransport;
- const ATransportFactory: ITransportFactory; const AProtocolFactory: IProtocolFactory); overload;
-
procedure Serve; override;
procedure Stop; override;
end;
@@ -135,84 +116,57 @@
{ TServerImpl }
-constructor TServerImpl.Create( const AProcessor: IProcessor;
- const AServerTransport: IServerTransport; const ALogDelegate: TLogDelegate);
-var
- InputFactory, OutputFactory : IProtocolFactory;
- InputTransFactory, OutputTransFactory : ITransportFactory;
-
-begin
- InputFactory := TBinaryProtocolImpl.TFactory.Create;
- OutputFactory := TBinaryProtocolImpl.TFactory.Create;
- InputTransFactory := TTransportFactoryImpl.Create;
- OutputTransFactory := TTransportFactoryImpl.Create;
-
- //no inherited;
- Create(
- AProcessor,
- AServerTransport,
- InputTransFactory,
- OutputTransFactory,
- InputFactory,
- OutputFactory,
- ALogDelegate
- );
-end;
-
-constructor TServerImpl.Create(const AProcessor: IProcessor;
- const AServerTransport: IServerTransport);
-var
- InputFactory, OutputFactory : IProtocolFactory;
- InputTransFactory, OutputTransFactory : ITransportFactory;
-
-begin
- InputFactory := TBinaryProtocolImpl.TFactory.Create;
- OutputFactory := TBinaryProtocolImpl.TFactory.Create;
- InputTransFactory := TTransportFactoryImpl.Create;
- OutputTransFactory := TTransportFactoryImpl.Create;
-
- //no inherited;
- Create(
- AProcessor,
- AServerTransport,
- InputTransFactory,
- OutputTransFactory,
- InputFactory,
- OutputFactory,
- DefaultLogDelegate
- );
-end;
-
-constructor TServerImpl.Create(const AProcessor: IProcessor;
- const AServerTransport: IServerTransport; const ATransportFactory: ITransportFactory);
-var
- InputProtocolFactory : IProtocolFactory;
- OutputProtocolFactory : IProtocolFactory;
-begin
- InputProtocolFactory := TBinaryProtocolImpl.TFactory.Create;
- OutputProtocolFactory := TBinaryProtocolImpl.TFactory.Create;
-
- //no inherited;
- Create( AProcessor, AServerTransport, ATransportFactory, ATransportFactory,
- InputProtocolFactory, OutputProtocolFactory, DefaultLogDelegate);
-end;
-
-constructor TServerImpl.Create(const AProcessor: IProcessor;
- const AServerTransport: IServerTransport;
- const AInputTransportFactory, AOutputTransportFactory: ITransportFactory;
- const AInputProtocolFactory, AOutputProtocolFactory: IProtocolFactory;
- const ALogDelegate : TLogDelegate);
+constructor TServerImpl.Create( const aProcessor: IProcessor;
+ const aServerTransport: IServerTransport;
+ const aInputTransportFactory, aOutputTransportFactory: ITransportFactory;
+ const aInputProtocolFactory, aOutputProtocolFactory: IProtocolFactory;
+ const aConfig : IThriftConfiguration;
+ const aLogDelegate : TLogDelegate);
begin
inherited Create;
- FProcessor := AProcessor;
- FServerTransport := AServerTransport;
- FInputTransportFactory := AInputTransportFactory;
- FOutputTransportFactory := AOutputTransportFactory;
- FInputProtocolFactory := AInputProtocolFactory;
- FOutputProtocolFactory := AOutputProtocolFactory;
- FLogDelegate := ALogDelegate;
+ FProcessor := aProcessor;
+ FServerTransport := aServerTransport;
+
+ if aConfig <> nil
+ then FConfiguration := aConfig
+ else FConfiguration := TThriftConfigurationImpl.Create;
+
+ if aInputTransportFactory <> nil
+ then FInputTransportFactory := aInputTransportFactory
+ else FInputTransportFactory := TTransportFactoryImpl.Create;
+
+ if aOutputTransportFactory <> nil
+ then FOutputTransportFactory := aOutputTransportFactory
+ else FOutputTransportFactory := TTransportFactoryImpl.Create;
+
+ if aInputProtocolFactory <> nil
+ then FInputProtocolFactory := aInputProtocolFactory
+ else FInputProtocolFactory := TBinaryProtocolImpl.TFactory.Create;
+
+ if aOutputProtocolFactory <> nil
+ then FOutputProtocolFactory := aOutputProtocolFactory
+ else FOutputProtocolFactory := TBinaryProtocolImpl.TFactory.Create;
+
+ if Assigned(aLogDelegate)
+ then FLogDelegate := aLogDelegate
+ else FLogDelegate := DefaultLogDelegate;
end;
+
+// Convenience constructor: a single transport factory and a single protocol
+// factory serve both the input and the output side. The full constructor
+// substitutes defaults for nil factories, a nil configuration and an
+// unassigned log delegate.
+constructor TServerImpl.Create( const aProcessor: IProcessor;
+                                const aServerTransport: IServerTransport;
+                                const aTransportFactory: ITransportFactory;
+                                const aProtocolFactory: IProtocolFactory;
+                                const aConfig : IThriftConfiguration;
+                                const aLogDel: TServerImpl.TLogDelegate);
+begin
+  Create( aProcessor,
+          aServerTransport,
+          aTransportFactory,   // input transport factory
+          aTransportFactory,   // output transport factory
+          aProtocolFactory,    // input protocol factory
+          aProtocolFactory,    // output protocol factory
+          aConfig,
+          aLogDel);
+end;
+
+
class procedure TServerImpl.DefaultLogDelegate( const str: string);
begin
try
@@ -223,16 +177,6 @@
end;
end;
-constructor TServerImpl.Create( const AProcessor: IProcessor;
- const AServerTransport: IServerTransport; const ATransportFactory: ITransportFactory;
- const AProtocolFactory: IProtocolFactory);
-begin
- //no inherited;
- Create( AProcessor, AServerTransport,
- ATransportFactory, ATransportFactory,
- AProtocolFactory, AProtocolFactory,
- DefaultLogDelegate);
-end;
function TServerImpl.GetServerEvents : IServerEvents;
@@ -250,55 +194,6 @@
{ TSimpleServer }
-constructor TSimpleServer.Create( const AProcessor: IProcessor;
- const AServerTransport: IServerTransport);
-var
- InputProtocolFactory : IProtocolFactory;
- OutputProtocolFactory : IProtocolFactory;
- InputTransportFactory : ITransportFactory;
- OutputTransportFactory : ITransportFactory;
-begin
- InputProtocolFactory := TBinaryProtocolImpl.TFactory.Create;
- OutputProtocolFactory := TBinaryProtocolImpl.TFactory.Create;
- InputTransportFactory := TTransportFactoryImpl.Create;
- OutputTransportFactory := TTransportFactoryImpl.Create;
-
- inherited Create( AProcessor, AServerTransport, InputTransportFactory,
- OutputTransportFactory, InputProtocolFactory, OutputProtocolFactory, DefaultLogDelegate);
-end;
-
-constructor TSimpleServer.Create( const AProcessor: IProcessor;
- const AServerTransport: IServerTransport; ALogDel: TServerImpl.TLogDelegate);
-var
- InputProtocolFactory : IProtocolFactory;
- OutputProtocolFactory : IProtocolFactory;
- InputTransportFactory : ITransportFactory;
- OutputTransportFactory : ITransportFactory;
-begin
- InputProtocolFactory := TBinaryProtocolImpl.TFactory.Create;
- OutputProtocolFactory := TBinaryProtocolImpl.TFactory.Create;
- InputTransportFactory := TTransportFactoryImpl.Create;
- OutputTransportFactory := TTransportFactoryImpl.Create;
-
- inherited Create( AProcessor, AServerTransport, InputTransportFactory,
- OutputTransportFactory, InputProtocolFactory, OutputProtocolFactory, ALogDel);
-end;
-
-constructor TSimpleServer.Create( const AProcessor: IProcessor;
- const AServerTransport: IServerTransport; const ATransportFactory: ITransportFactory);
-begin
- inherited Create( AProcessor, AServerTransport, ATransportFactory,
- ATransportFactory, TBinaryProtocolImpl.TFactory.Create, TBinaryProtocolImpl.TFactory.Create, DefaultLogDelegate);
-end;
-
-constructor TSimpleServer.Create( const AProcessor: IProcessor;
- const AServerTransport: IServerTransport; const ATransportFactory: ITransportFactory;
- const AProtocolFactory: IProtocolFactory);
-begin
- inherited Create( AProcessor, AServerTransport, ATransportFactory,
- ATransportFactory, AProtocolFactory, AProtocolFactory, DefaultLogDelegate);
-end;
-
procedure TSimpleServer.Serve;
var
client : ITransport;
@@ -372,20 +267,17 @@
end;
except
- on E: TTransportException do
- begin
+ on E: TTransportException do begin
if FStop
then FLogDelegate('TSimpleServer was shutting down, caught ' + E.ToString)
else FLogDelegate( E.ToString);
end;
- on E: Exception do
- begin
+ on E: Exception do begin
FLogDelegate( E.ToString);
end;
end;
- if context <> nil
- then begin
+ if context <> nil then begin
context.CleanupContext;
context := nil;
end;
diff --git a/lib/delphi/src/Thrift.Socket.pas b/lib/delphi/src/Thrift.Socket.pas
index f0cab79..b33f202 100644
--- a/lib/delphi/src/Thrift.Socket.pas
+++ b/lib/delphi/src/Thrift.Socket.pas
@@ -81,7 +81,7 @@
TScopeId = record
public
Value: ULONG;
- private
+ strict private
function GetBitField(Loc: Integer): Integer; inline;
procedure SetBitField(Loc: Integer; const aValue: Integer); inline;
public
@@ -125,7 +125,7 @@
ISmartPointer<T> = reference to function: T;
TSmartPointer<T> = class(TInterfacedObject, ISmartPointer<T>)
- private
+ strict private
FValue: T;
FDestroyer: TSmartPointerDestroyer<T>;
public
@@ -147,7 +147,7 @@
class constructor Create;
class destructor Destroy;
class procedure DefaultLogDelegate(const Str: string);
- protected type
+ strict protected type
IGetAddrInfoWrapper = interface
function Init: Integer;
function GetRes: PAddrInfoW;
diff --git a/lib/delphi/src/Thrift.Stream.pas b/lib/delphi/src/Thrift.Stream.pas
index 3308c53..1668059 100644
--- a/lib/delphi/src/Thrift.Stream.pas
+++ b/lib/delphi/src/Thrift.Stream.pas
@@ -36,9 +36,8 @@
Thrift.Utils;
type
-
IThriftStream = interface
- ['{2A77D916-7446-46C1-8545-0AEC0008DBCA}']
+ ['{3A61A8A6-3639-4B91-A260-EFCA23944F3A}']
procedure Write( const buffer: TBytes; offset: Integer; count: Integer); overload;
procedure Write( const pBuf : Pointer; offset: Integer; count: Integer); overload;
function Read( var buffer: TBytes; offset: Integer; count: Integer): Integer; overload;
@@ -48,12 +47,16 @@
procedure Flush;
function IsOpen: Boolean;
function ToArray: TBytes;
+ function Size : Int64;
+ function Position : Int64;
end;
- TThriftStreamImpl = class( TInterfacedObject, IThriftStream)
- private
+
+ TThriftStreamImpl = class abstract( TInterfacedObject, IThriftStream)
+ strict private
procedure CheckSizeAndOffset( const pBuf : Pointer; const buflen : Integer; offset: Integer; count: Integer); overload;
- protected
+ strict protected
+ // IThriftStream
procedure Write( const buffer: TBytes; offset: Integer; count: Integer); overload; inline;
procedure Write( const pBuf : Pointer; offset: Integer; count: Integer); overload; virtual;
function Read( var buffer: TBytes; offset: Integer; count: Integer): Integer; overload; inline;
@@ -63,13 +66,16 @@
procedure Flush; virtual; abstract;
function IsOpen: Boolean; virtual; abstract;
function ToArray: TBytes; virtual; abstract;
+ function Size : Int64; virtual;
+ function Position : Int64; virtual;
end;
- TThriftStreamAdapterDelphi = class( TThriftStreamImpl )
- private
+ TThriftStreamAdapterDelphi = class( TThriftStreamImpl)
+ strict private
FStream : TStream;
FOwnsStream : Boolean;
- protected
+ strict protected
+ // IThriftStream
procedure Write( const pBuf : Pointer; offset: Integer; count: Integer); override;
function Read( const pBuf : Pointer; const buflen : Integer; offset: Integer; count: Integer): Integer; override;
procedure Open; override;
@@ -77,15 +83,18 @@
procedure Flush; override;
function IsOpen: Boolean; override;
function ToArray: TBytes; override;
+ function Size : Int64; override;
+ function Position : Int64; override;
public
- constructor Create( const AStream: TStream; AOwnsStream : Boolean);
+ constructor Create( const aStream: TStream; aOwnsStream : Boolean);
destructor Destroy; override;
end;
TThriftStreamAdapterCOM = class( TThriftStreamImpl)
- private
+ strict private
FStream : IStream;
- protected
+ strict protected
+ // IThriftStream
procedure Write( const pBuf : Pointer; offset: Integer; count: Integer); override;
function Read( const pBuf : Pointer; const buflen : Integer; offset: Integer; count: Integer): Integer; override;
procedure Open; override;
@@ -93,12 +102,16 @@
procedure Flush; override;
function IsOpen: Boolean; override;
function ToArray: TBytes; override;
+ function Size : Int64; override;
+ function Position : Int64; override;
public
- constructor Create( const AStream: IStream);
+ constructor Create( const aStream: IStream);
end;
implementation
+uses Thrift.Transport;
+
{ TThriftStreamAdapterCOM }
procedure TThriftStreamAdapterCOM.Close;
@@ -106,10 +119,10 @@
FStream := nil;
end;
-constructor TThriftStreamAdapterCOM.Create( const AStream: IStream);
+constructor TThriftStreamAdapterCOM.Create( const aStream: IStream);
begin
inherited Create;
- FStream := AStream;
+ FStream := aStream;
end;
procedure TThriftStreamAdapterCOM.Flush;
@@ -121,6 +134,24 @@
end;
end;
+function TThriftStreamAdapterCOM.Size : Int64;
+var statstg: TStatStg;
+begin
+ FillChar( statstg, SizeOf( statstg), 0);
+ if IsOpen
+ and Succeeded( FStream.Stat( statstg, STATFLAG_NONAME ))
+ then result := statstg.cbSize
+ else result := 0;
+end;
+
+function TThriftStreamAdapterCOM.Position : Int64;
+var newpos : {$IF CompilerVersion >= 29.0} UInt64 {$ELSE} Int64 {$IFEND};
+begin
+ if IsOpen and SUCCEEDED( FStream.Seek( 0, STREAM_SEEK_CUR, newpos))  // FStream is nil after Close(), guard like Size does
+ then result := Int64(newpos)
+ else raise TTransportExceptionEndOfFile.Create('Seek() error');
+end;
+
function TThriftStreamAdapterCOM.IsOpen: Boolean;
begin
Result := FStream <> nil;
@@ -151,19 +182,11 @@
function TThriftStreamAdapterCOM.ToArray: TBytes;
var
- statstg: TStatStg;
- len : Integer;
+ len : Int64;
NewPos : {$IF CompilerVersion >= 29.0} UInt64 {$ELSE} Int64 {$IFEND};
cbRead : Integer;
begin
- FillChar( statstg, SizeOf( statstg), 0);
- len := 0;
- if IsOpen then begin
- if Succeeded( FStream.Stat( statstg, STATFLAG_NONAME )) then begin
- len := statstg.cbSize;
- end;
- end;
-
+ len := Self.Size;
SetLength( Result, len );
if len > 0 then begin
@@ -225,8 +248,36 @@
CheckSizeAndOffset( pBuf, offset+count, offset, count);
end;
+function TThriftStreamImpl.Size : Int64;
+begin
+ ASSERT(FALSE);
+ raise ENotImplemented.Create(ClassName+'.Size');
+end;
+
+function TThriftStreamImpl.Position : Int64;
+begin
+ ASSERT(FALSE);
+ raise ENotImplemented.Create(ClassName+'.Position');
+end;
+
+
{ TThriftStreamAdapterDelphi }
+constructor TThriftStreamAdapterDelphi.Create( const aStream: TStream; aOwnsStream: Boolean);
+begin
+ inherited Create;
+ FStream := aStream;
+ FOwnsStream := aOwnsStream;
+end;
+
+destructor TThriftStreamAdapterDelphi.Destroy;
+begin
+ if FOwnsStream
+ then Close;
+
+ inherited;
+end;
+
procedure TThriftStreamAdapterDelphi.Close;
begin
FStream.Free;
@@ -234,26 +285,21 @@
FOwnsStream := False;
end;
-constructor TThriftStreamAdapterDelphi.Create( const AStream: TStream; AOwnsStream: Boolean);
-begin
- inherited Create;
- FStream := AStream;
- FOwnsStream := AOwnsStream;
-end;
-
-destructor TThriftStreamAdapterDelphi.Destroy;
-begin
- if FOwnsStream
- then Close;
-
- inherited;
-end;
-
procedure TThriftStreamAdapterDelphi.Flush;
begin
// nothing to do
end;
+function TThriftStreamAdapterDelphi.Size : Int64;
+begin
+ if FStream <> nil then result := FStream.Size else result := 0;  // FStream may be nil once closed (see IsOpen/ToArray)
+end;
+
+function TThriftStreamAdapterDelphi.Position : Int64;
+begin
+ result := FStream.Position;
+end;
+
function TThriftStreamAdapterDelphi.IsOpen: Boolean;
begin
Result := FStream <> nil;
@@ -285,11 +331,9 @@
OrgPos : Integer;
len : Integer;
begin
- len := 0;
- if FStream <> nil then
- begin
- len := FStream.Size;
- end;
+ if FStream <> nil
+ then len := FStream.Size
+ else len := 0;
SetLength( Result, len );
diff --git a/lib/delphi/src/Thrift.Transport.MsxmlHTTP.pas b/lib/delphi/src/Thrift.Transport.MsxmlHTTP.pas
index c666e7f..bdc65d1 100644
--- a/lib/delphi/src/Thrift.Transport.MsxmlHTTP.pas
+++ b/lib/delphi/src/Thrift.Transport.MsxmlHTTP.pas
@@ -34,14 +34,15 @@
Winapi.ActiveX, Winapi.msxml,
{$ENDIF}
Thrift.Collections,
+ Thrift.Configuration,
Thrift.Transport,
Thrift.Exception,
Thrift.Utils,
Thrift.Stream;
type
- TMsxmlHTTPClientImpl = class( TTransportImpl, IHTTPClient)
- private
+ TMsxmlHTTPClientImpl = class( TEndpointTransportBase, IHTTPClient)
+ strict private
FUri : string;
FInputStream : IThriftStream;
FOutputStream : IThriftStream;
@@ -52,7 +53,7 @@
FCustomHeaders : IThriftDictionary<string,string>;
function CreateRequest: IXMLHTTPRequest;
- protected
+ strict protected
function GetIsOpen: Boolean; override;
procedure Open(); override;
procedure Close(); override;
@@ -80,26 +81,29 @@
property ReadTimeout: Integer read GetReadTimeout write SetReadTimeout;
property CustomHeaders: IThriftDictionary<string,string> read GetCustomHeaders;
public
- constructor Create( const AUri: string);
+ constructor Create( const aUri: string; const aConfig : IThriftConfiguration); reintroduce;
destructor Destroy; override;
end;
implementation
+const
+ XMLHTTP_CONNECTION_TIMEOUT = 60 * 1000;
+ XMLHTTP_SENDRECV_TIMEOUT = 30 * 1000;
{ TMsxmlHTTPClientImpl }
-constructor TMsxmlHTTPClientImpl.Create(const AUri: string);
+constructor TMsxmlHTTPClientImpl.Create( const aUri: string; const aConfig : IThriftConfiguration);
begin
- inherited Create;
- FUri := AUri;
+ inherited Create( aConfig);
+ FUri := aUri;
// defaults according to MSDN
FDnsResolveTimeout := 0; // no timeout
- FConnectionTimeout := 60 * 1000;
- FSendTimeout := 30 * 1000;
- FReadTimeout := 30 * 1000;
+ FConnectionTimeout := XMLHTTP_CONNECTION_TIMEOUT;
+ FSendTimeout := XMLHTTP_SENDRECV_TIMEOUT;
+ FReadTimeout := XMLHTTP_SENDRECV_TIMEOUT;
FCustomHeaders := TThriftDictionaryImpl<string,string>.Create;
FOutputStream := TThriftStreamAdapterDelphi.Create( TMemoryStream.Create, True);
@@ -225,7 +229,8 @@
end;
try
- Result := FInputStream.Read( pBuf, buflen, off, len)
+ Result := FInputStream.Read( pBuf, buflen, off, len);
+ CountConsumedMessageBytes( result);
except
on E: Exception
do raise TTransportExceptionUnknown.Create(E.Message);
@@ -252,6 +257,7 @@
xmlhttp.send( IUnknown( TStreamAdapter.Create( ms, soReference )));
FInputStream := nil;
FInputStream := TThriftStreamAdapterCOM.Create( IUnknown( xmlhttp.responseStream) as IStream);
+ UpdateKnownMessageSize( FInputStream.Size);
finally
ms.Free;
end;
diff --git a/lib/delphi/src/Thrift.Transport.Pipes.pas b/lib/delphi/src/Thrift.Transport.Pipes.pas
index 77a343b..635a841 100644
--- a/lib/delphi/src/Thrift.Transport.Pipes.pas
+++ b/lib/delphi/src/Thrift.Transport.Pipes.pas
@@ -29,6 +29,7 @@
{$ELSE}
Winapi.Windows, System.SysUtils, System.Math, Winapi.AccCtrl, Winapi.AclAPI, System.SyncObjs,
{$ENDIF}
+ Thrift.Configuration,
Thrift.Transport,
Thrift.Utils,
Thrift.Stream;
@@ -64,7 +65,9 @@
public
constructor Create( aEnableOverlapped : Boolean;
const aTimeOut : DWORD = DEFAULT_THRIFT_TIMEOUT;
- const aOpenTimeOut : DWORD = DEFAULT_THRIFT_PIPE_OPEN_TIMEOUT);
+ const aOpenTimeOut : DWORD = DEFAULT_THRIFT_PIPE_OPEN_TIMEOUT
+ ); reintroduce; overload;
+
destructor Destroy; override;
end;
@@ -84,7 +87,8 @@
const aShareMode: DWORD = 0;
const aSecurityAttributes: PSecurityAttributes = nil;
const aTimeOut : DWORD = DEFAULT_THRIFT_TIMEOUT;
- const aOpenTimeOut : DWORD = DEFAULT_THRIFT_PIPE_OPEN_TIMEOUT); overload;
+ const aOpenTimeOut : DWORD = DEFAULT_THRIFT_PIPE_OPEN_TIMEOUT
+ ); reintroduce; overload;
end;
@@ -98,7 +102,9 @@
public
constructor Create( const aPipeHandle : THandle;
const aOwnsHandle, aEnableOverlapped : Boolean;
- const aTimeOut : DWORD = DEFAULT_THRIFT_TIMEOUT); overload;
+ const aTimeOut : DWORD = DEFAULT_THRIFT_TIMEOUT
+ ); reintroduce; overload;
+
destructor Destroy; override;
end;
@@ -112,7 +118,7 @@
TPipeTransportBase = class( TStreamTransportImpl, IPipeTransport)
- public
+ strict protected
// ITransport
function GetIsOpen: Boolean; override;
procedure Open; override;
@@ -123,33 +129,46 @@
TNamedPipeTransportClientEndImpl = class( TPipeTransportBase)
public
// Named pipe constructors
- constructor Create( aPipe : THandle; aOwnsHandle : Boolean;
- const aTimeOut : DWORD); overload;
+ constructor Create( const aPipe : THandle;
+ const aOwnsHandle : Boolean;
+ const aTimeOut : DWORD;
+ const aConfig : IThriftConfiguration = nil
+ ); reintroduce; overload;
+
constructor Create( const aPipeName : string;
const aShareMode: DWORD = 0;
const aSecurityAttributes: PSecurityAttributes = nil;
const aTimeOut : DWORD = DEFAULT_THRIFT_TIMEOUT;
- const aOpenTimeOut : DWORD = DEFAULT_THRIFT_PIPE_OPEN_TIMEOUT); overload;
+ const aOpenTimeOut : DWORD = DEFAULT_THRIFT_PIPE_OPEN_TIMEOUT;
+ const aConfig : IThriftConfiguration = nil
+ ); reintroduce; overload;
end;
TNamedPipeTransportServerEndImpl = class( TNamedPipeTransportClientEndImpl)
strict private
FHandle : THandle;
- public
+ strict protected
// ITransport
procedure Close; override;
- constructor Create( aPipe : THandle; aOwnsHandle : Boolean;
- const aTimeOut : DWORD = DEFAULT_THRIFT_TIMEOUT); reintroduce;
+ public
+ constructor Create( const aPipe : THandle;
+ const aOwnsHandle : Boolean;
+ const aTimeOut : DWORD = DEFAULT_THRIFT_TIMEOUT;
+ const aConfig : IThriftConfiguration = nil
+ ); reintroduce; overload;
+
end;
TAnonymousPipeTransportImpl = class( TPipeTransportBase)
public
// Anonymous pipe constructor
- constructor Create(const aPipeRead, aPipeWrite : THandle;
- aOwnsHandles : Boolean;
- const aTimeOut : DWORD = DEFAULT_THRIFT_TIMEOUT); overload;
+ constructor Create( const aPipeRead, aPipeWrite : THandle;
+ const aOwnsHandles : Boolean;
+ const aTimeOut : DWORD = DEFAULT_THRIFT_TIMEOUT;
+ const aConfig : IThriftConfiguration = nil
+ ); reintroduce; overload;
end;
@@ -179,7 +198,7 @@
procedure InternalClose; virtual; abstract;
function QueryStopServer : Boolean;
public
- constructor Create;
+ constructor Create( const aConfig : IThriftConfiguration);
destructor Destroy; override;
procedure Listen; override;
procedure Close; override;
@@ -199,7 +218,7 @@
FClientAnonWrite : THandle;
FTimeOut: DWORD;
- protected
+ strict protected
function Accept(const fnAccepting: TProc): ITransport; override;
function CreateAnonPipe : Boolean;
@@ -213,7 +232,10 @@
procedure InternalClose; override;
public
- constructor Create(aBufsize : Cardinal = 4096; aTimeOut : DWORD = DEFAULT_THRIFT_TIMEOUT);
+ constructor Create( const aBufsize : Cardinal = 4096;
+ const aTimeOut : DWORD = DEFAULT_THRIFT_TIMEOUT;
+ const aConfig : IThriftConfiguration = nil
+ ); reintroduce; overload;
end;
@@ -237,9 +259,12 @@
procedure InternalClose; override;
public
- constructor Create( aPipename : string; aBufsize : Cardinal = 4096;
- aMaxConns : Cardinal = PIPE_UNLIMITED_INSTANCES;
- aTimeOut : Cardinal = INFINITE);
+ constructor Create( const aPipename : string;
+ const aBufsize : Cardinal = 4096;
+ const aMaxConns : Cardinal = PIPE_UNLIMITED_INSTANCES;
+ const aTimeOut : Cardinal = INFINITE;
+ const aConfig : IThriftConfiguration = nil
+ ); reintroduce; overload;
end;
@@ -270,15 +295,14 @@
{ TPipeStreamBase }
-constructor TPipeStreamBase.Create( aEnableOverlapped : Boolean;
- const aTimeOut, aOpenTimeOut : DWORD);
+constructor TPipeStreamBase.Create( aEnableOverlapped : Boolean; const aTimeOut, aOpenTimeOut : DWORD);
begin
inherited Create;
- ASSERT( aTimeout > 0); // aOpenTimeout may be 0
FPipe := INVALID_HANDLE_VALUE;
FTimeout := aTimeOut;
FOpenTimeOut := aOpenTimeOut;
FOverlapped := aEnableOverlapped;
+ ASSERT( FTimeout > 0); // FOpenTimeout may be 0
end;
@@ -524,7 +548,7 @@
const aSecurityAttributes: PSecurityAttributes;
const aTimeOut, aOpenTimeOut : DWORD);
begin
- inherited Create( aEnableOverlapped, aTimeout, aOpenTimeOut);
+ inherited Create( aEnableOverlapped, aTimeOut, aOpenTimeOut);
FPipeName := aPipeName;
FShareMode := aShareMode;
@@ -587,7 +611,7 @@
const aOwnsHandle, aEnableOverlapped : Boolean;
const aTimeOut : DWORD);
begin
- inherited Create( aEnableOverlapped, aTimeOut);
+ inherited Create( aEnableOverlapped, aTimeout, aTimeout);
if aOwnsHandle
then FSrcHandle := aPipeHandle
@@ -641,23 +665,27 @@
{ TNamedPipeTransportClientEndImpl }
-constructor TNamedPipeTransportClientEndImpl.Create( const aPipeName : string; const aShareMode: DWORD;
- const aSecurityAttributes: PSecurityAttributes;
- const aTimeOut, aOpenTimeOut : DWORD);
+constructor TNamedPipeTransportClientEndImpl.Create( const aPipeName : string;
+ const aShareMode: DWORD;
+ const aSecurityAttributes: PSecurityAttributes;
+ const aTimeOut, aOpenTimeOut : DWORD;
+ const aConfig : IThriftConfiguration);
// Named pipe constructor
begin
- inherited Create( nil, nil);
+ inherited Create( nil, nil, aConfig);
FInputStream := TNamedPipeStreamImpl.Create( aPipeName, TRUE, aShareMode, aSecurityAttributes, aTimeOut, aOpenTimeOut);
FOutputStream := FInputStream; // true for named pipes
end;
-constructor TNamedPipeTransportClientEndImpl.Create( aPipe : THandle; aOwnsHandle : Boolean;
- const aTimeOut : DWORD);
+constructor TNamedPipeTransportClientEndImpl.Create( const aPipe : THandle;
+ const aOwnsHandle : Boolean;
+ const aTimeOut : DWORD;
+ const aConfig : IThriftConfiguration);
// Named pipe constructor
begin
- inherited Create( nil, nil);
- FInputStream := THandlePipeStreamImpl.Create( aPipe, TRUE, aOwnsHandle, aTimeOut);
+ inherited Create( nil, nil, aConfig);
+ FInputStream := THandlePipeStreamImpl.Create( aPipe, aOwnsHandle, TRUE, aTimeOut);
FOutputStream := FInputStream; // true for named pipes
end;
@@ -665,12 +693,14 @@
{ TNamedPipeTransportServerEndImpl }
-constructor TNamedPipeTransportServerEndImpl.Create( aPipe : THandle; aOwnsHandle : Boolean;
- const aTimeOut : DWORD);
+constructor TNamedPipeTransportServerEndImpl.Create( const aPipe : THandle;
+ const aOwnsHandle : Boolean;
+ const aTimeOut : DWORD;
+ const aConfig : IThriftConfiguration);
// Named pipe constructor
begin
FHandle := DuplicatePipeHandle( aPipe);
- inherited Create( aPipe, aOwnsHandle, aTimeOut);
+ inherited Create( aPipe, aOwnsHandle, aTimeout, aConfig);
end;
@@ -688,23 +718,24 @@
constructor TAnonymousPipeTransportImpl.Create( const aPipeRead, aPipeWrite : THandle;
- aOwnsHandles : Boolean;
- const aTimeOut : DWORD = DEFAULT_THRIFT_TIMEOUT);
+ const aOwnsHandles : Boolean;
+ const aTimeOut : DWORD;
+ const aConfig : IThriftConfiguration);
// Anonymous pipe constructor
begin
- inherited Create( nil, nil);
+ inherited Create( nil, nil, aConfig);
// overlapped is not supported with AnonPipes, see MSDN
- FInputStream := THandlePipeStreamImpl.Create( aPipeRead, aOwnsHandles, FALSE, aTimeOut);
- FOutputStream := THandlePipeStreamImpl.Create( aPipeWrite, aOwnsHandles, FALSE, aTimeOut);
+ FInputStream := THandlePipeStreamImpl.Create( aPipeRead, aOwnsHandles, FALSE, aTimeout);
+ FOutputStream := THandlePipeStreamImpl.Create( aPipeWrite, aOwnsHandles, FALSE, aTimeout);
end;
{ TPipeServerTransportBase }
-constructor TPipeServerTransportBase.Create;
+constructor TPipeServerTransportBase.Create( const aConfig : IThriftConfiguration);
begin
- inherited Create;
+ inherited Create( aConfig);
FStopServer := TEvent.Create(nil,TRUE,FALSE,''); // manual reset
end;
@@ -741,11 +772,12 @@
{ TAnonymousPipeServerTransportImpl }
-
-constructor TAnonymousPipeServerTransportImpl.Create(aBufsize : Cardinal; aTimeOut : DWORD);
+constructor TAnonymousPipeServerTransportImpl.Create( const aBufsize : Cardinal;
+ const aTimeOut : DWORD;
+ const aConfig : IThriftConfiguration);
// Anonymous pipe CTOR
begin
- inherited Create;
+ inherited Create(aConfig);
FBufsize := aBufSize;
FReadHandle := INVALID_HANDLE_VALUE;
FWriteHandle := INVALID_HANDLE_VALUE;
@@ -774,7 +806,7 @@
then raise TTransportExceptionNotOpen.Create('TServerPipe unable to initiate pipe communication');
// create the transport impl
- result := TAnonymousPipeTransportImpl.Create( FReadHandle, FWriteHandle, FALSE, FTimeOut);
+ result := TAnonymousPipeTransportImpl.Create( FReadHandle, FWriteHandle, FALSE, FTimeOut, Configuration);
end;
@@ -852,17 +884,19 @@
{ TNamedPipeServerTransportImpl }
-constructor TNamedPipeServerTransportImpl.Create( aPipename : string; aBufsize, aMaxConns, aTimeOut : Cardinal);
+constructor TNamedPipeServerTransportImpl.Create( const aPipename : string;
+ const aBufsize, aMaxConns, aTimeOut : Cardinal;
+ const aConfig : IThriftConfiguration);
// Named Pipe CTOR
begin
- inherited Create;
- ASSERT( aTimeout > 0);
+ inherited Create( aConfig);
FPipeName := aPipename;
FBufsize := aBufSize;
FMaxConns := Max( 1, Min( PIPE_UNLIMITED_INSTANCES, aMaxConns));
FHandle := INVALID_HANDLE_VALUE;
FTimeout := aTimeOut;
FConnected := FALSE;
+ ASSERT( FTimeout > 0);
if Copy(FPipeName,1,2) <> '\\'
then FPipeName := '\\.\pipe\' + FPipeName; // assume localhost
@@ -931,7 +965,7 @@
hPipe := THandle( InterlockedExchangePointer( Pointer(FHandle), Pointer(INVALID_HANDLE_VALUE)));
try
FConnected := FALSE;
- result := TNamedPipeTransportServerEndImpl.Create( hPipe, TRUE, FTimeout);
+ result := TNamedPipeTransportServerEndImpl.Create( hPipe, TRUE, FTimeout, Configuration);
except
ClosePipeHandle(hPipe);
raise;
diff --git a/lib/delphi/src/Thrift.Transport.WinHTTP.pas b/lib/delphi/src/Thrift.Transport.WinHTTP.pas
index 262e38f..7a1b48f 100644
--- a/lib/delphi/src/Thrift.Transport.WinHTTP.pas
+++ b/lib/delphi/src/Thrift.Transport.WinHTTP.pas
@@ -29,6 +29,7 @@
Math,
Generics.Collections,
Thrift.Collections,
+ Thrift.Configuration,
Thrift.Transport,
Thrift.Exception,
Thrift.Utils,
@@ -36,8 +37,8 @@
Thrift.Stream;
type
- TWinHTTPClientImpl = class( TTransportImpl, IHTTPClient)
- private
+ TWinHTTPClientImpl = class( TEndpointTransportBase, IHTTPClient)
+ strict private
FUri : string;
FInputStream : IThriftStream;
FOutputMemoryStream : TMemoryStream;
@@ -51,14 +52,14 @@
function CreateRequest: IWinHTTPRequest;
function SecureProtocolsAsWinHTTPFlags : Cardinal;
- private
+ strict private
type
TErrorInfo = ( SplitUrl, WinHTTPSession, WinHTTPConnection, WinHTTPRequest, RequestSetup, AutoProxy );
THTTPResponseStream = class( TThriftStreamImpl)
- private
+ strict private
FRequest : IWinHTTPRequest;
- protected
+ strict protected
procedure Write( const pBuf : Pointer; offset: Integer; count: Integer); override;
function Read( const pBuf : Pointer; const buflen : Integer; offset: Integer; count: Integer): Integer; override;
procedure Open; override;
@@ -71,7 +72,7 @@
destructor Destroy; override;
end;
- protected
+ strict protected
function GetIsOpen: Boolean; override;
procedure Open(); override;
procedure Close(); override;
@@ -99,25 +100,29 @@
property ReadTimeout: Integer read GetReadTimeout write SetReadTimeout;
property CustomHeaders: IThriftDictionary<string,string> read GetCustomHeaders;
public
- constructor Create( const AUri: string);
+ constructor Create( const aUri: string; const aConfig : IThriftConfiguration = nil);
destructor Destroy; override;
end;
implementation
+const
+ WINHTTP_CONNECTION_TIMEOUT = 60 * 1000;
+ WINHTTP_SENDRECV_TIMEOUT = 30 * 1000;
+
{ TWinHTTPClientImpl }
-constructor TWinHTTPClientImpl.Create(const AUri: string);
+constructor TWinHTTPClientImpl.Create( const aUri: string; const aConfig : IThriftConfiguration);
begin
- inherited Create;
+ inherited Create( aConfig);
FUri := AUri;
// defaults according to MSDN
FDnsResolveTimeout := 0; // no timeout
- FConnectionTimeout := 60 * 1000;
- FSendTimeout := 30 * 1000;
- FReadTimeout := 30 * 1000;
+ FConnectionTimeout := WINHTTP_CONNECTION_TIMEOUT;
+ FSendTimeout := WINHTTP_SENDRECV_TIMEOUT;
+ FReadTimeout := WINHTTP_SENDRECV_TIMEOUT;
FSecureProtocols := DEFAULT_THRIFT_SECUREPROTOCOLS;
@@ -258,7 +263,7 @@
function TWinHTTPClientImpl.GetIsOpen: Boolean;
begin
- Result := True;
+ Result := Assigned( FOutputMemoryStream);
end;
procedure TWinHTTPClientImpl.Open;
@@ -291,7 +296,8 @@
end;
try
- Result := FInputStream.Read( pBuf, buflen, off, len)
+ Result := FInputStream.Read( pBuf, buflen, off, len);
+ CountConsumedMessageBytes( result);
except
on E: Exception
do raise TTransportExceptionUnknown.Create(E.Message);
@@ -327,7 +333,8 @@
else raise TTransportExceptionInterrupted.Create( sMsg);
end;
- FInputStream := THTTPResponseStream.Create(http);
+ FInputStream := THTTPResponseStream.Create( http);
+ UpdateKnownMessageSize( http.QueryTotalResponseSize);
end;
procedure TWinHTTPClientImpl.Write( const pBuf : Pointer; off, len : Integer);
diff --git a/lib/delphi/src/Thrift.Transport.pas b/lib/delphi/src/Thrift.Transport.pas
index c2071df..af62548 100644
--- a/lib/delphi/src/Thrift.Transport.pas
+++ b/lib/delphi/src/Thrift.Transport.pas
@@ -38,20 +38,28 @@
Thrift.Socket,
{$ENDIF}
{$ENDIF}
+ Thrift.Configuration,
Thrift.Collections,
Thrift.Exception,
Thrift.Utils,
Thrift.WinHTTP,
Thrift.Stream;
+const
+ DEFAULT_MAX_MESSAGE_SIZE = 100 * 1024 * 1024; // 100 MB
+ DEFAULT_THRIFT_TIMEOUT = 5 * 1000; // ms
+
type
+ IStreamTransport = interface;
+
ITransport = interface
- ['{DB84961E-8BB3-4532-99E1-A8C7AC2300F7}']
+ ['{52F81383-F880-492F-8AA7-A66B85B93D6B}']
function GetIsOpen: Boolean;
property IsOpen: Boolean read GetIsOpen;
function Peek: Boolean;
procedure Open;
procedure Close;
+
function Read(var buf: TBytes; off: Integer; len: Integer): Integer; overload;
function Read(const pBuf : Pointer; const buflen : Integer; off: Integer; len: Integer): Integer; overload;
function ReadAll(var buf: TBytes; off: Integer; len: Integer): Integer; overload;
@@ -61,15 +69,22 @@
procedure Write( const pBuf : Pointer; off, len : Integer); overload;
procedure Write( const pBuf : Pointer; len : Integer); overload;
procedure Flush;
+
+ function Configuration : IThriftConfiguration;
+ function MaxMessageSize : Integer;
+ procedure ResetConsumedMessageSize( const knownSize : Int64 = -1);
+ procedure CheckReadBytesAvailable( const numBytes : Int64);
+ procedure UpdateKnownMessageSize( const size : Int64);
end;
- TTransportImpl = class( TInterfacedObject, ITransport)
- protected
+ TTransportBase = class abstract( TInterfacedObject)
+ strict protected
function GetIsOpen: Boolean; virtual; abstract;
property IsOpen: Boolean read GetIsOpen;
function Peek: Boolean; virtual;
procedure Open(); virtual; abstract;
procedure Close(); virtual; abstract;
+
function Read(var buf: TBytes; off: Integer; len: Integer): Integer; overload; inline;
function Read(const pBuf : Pointer; const buflen : Integer; off: Integer; len: Integer): Integer; overload; virtual; abstract;
function ReadAll(var buf: TBytes; off: Integer; len: Integer): Integer; overload; inline;
@@ -79,9 +94,48 @@
procedure Write( const pBuf : Pointer; len : Integer); overload; inline;
procedure Write( const pBuf : Pointer; off, len : Integer); overload; virtual; abstract;
procedure Flush; virtual;
+
+ function Configuration : IThriftConfiguration; virtual; abstract;
+ procedure UpdateKnownMessageSize( const size : Int64); virtual; abstract;
end;
- TTransportException = class( TException)
+ // base class for all endpoint transports, e.g. sockets, pipes or HTTP
+ TEndpointTransportBase = class abstract( TTransportBase, ITransport)
+ strict private
+ FRemainingMessageSize : Int64;
+ FKnownMessageSize : Int64;
+ FConfiguration : IThriftConfiguration;
+ strict protected
+ function Configuration : IThriftConfiguration; override;
+ function MaxMessageSize : Integer;
+ property RemainingMessageSize : Int64 read FRemainingMessageSize;
+ property KnownMessageSize : Int64 read FKnownMessageSize;
+ procedure ResetConsumedMessageSize( const newSize : Int64 = -1); inline;
+ procedure UpdateKnownMessageSize(const size : Int64); override;
+ procedure CheckReadBytesAvailable(const numBytes : Int64); inline;
+ procedure CountConsumedMessageBytes(const numBytes : Int64); inline;
+ public
+ constructor Create( const aConfig : IThriftConfiguration); reintroduce;
+ end;
+
+ // base class for all layered transports, e.g. framed
+ TLayeredTransportBase<T : ITransport> = class abstract( TTransportBase, ITransport)
+ strict private
+ FTransport : T;
+ strict protected
+ property InnerTransport : T read FTransport;
+ function GetUnderlyingTransport: ITransport;
+ function Configuration : IThriftConfiguration; override;
+ procedure UpdateKnownMessageSize( const size : Int64); override;
+ function MaxMessageSize : Integer; inline;
+ procedure ResetConsumedMessageSize( const knownSize : Int64 = -1); inline;
+ procedure CheckReadBytesAvailable( const numBytes : Int64); virtual;
+ public
+ constructor Create( const aTransport: T); reintroduce;
+ property UnderlyingTransport: ITransport read GetUnderlyingTransport;
+ end;
+
+ TTransportException = class abstract( TException)
public
type
TExceptionType = (
@@ -91,16 +145,16 @@
TimedOut,
EndOfFile,
BadArgs,
- Interrupted
+ Interrupted,
+ CorruptedData
);
- private
- function GetType: TExceptionType;
- protected
+ strict protected
constructor HiddenCreate(const Msg: string);
+ class function GetType: TExceptionType; virtual; abstract;
public
- class function Create( AType: TExceptionType): TTransportException; overload; deprecated 'Use specialized TTransportException types (or regenerate from IDL)';
+ class function Create( aType: TExceptionType): TTransportException; overload; deprecated 'Use specialized TTransportException types (or regenerate from IDL)';
class function Create( const msg: string): TTransportException; reintroduce; overload; deprecated 'Use specialized TTransportException types (or regenerate from IDL)';
- class function Create( AType: TExceptionType; const msg: string): TTransportException; overload; deprecated 'Use specialized TTransportException types (or regenerate from IDL)';
+ class function Create( aType: TExceptionType; const msg: string): TTransportException; overload; deprecated 'Use specialized TTransportException types (or regenerate from IDL)';
property Type_: TExceptionType read GetType;
end;
@@ -110,13 +164,45 @@
constructor Create(const Msg: string);
end;
- TTransportExceptionUnknown = class (TTransportExceptionSpecialized);
- TTransportExceptionNotOpen = class (TTransportExceptionSpecialized);
- TTransportExceptionAlreadyOpen = class (TTransportExceptionSpecialized);
- TTransportExceptionTimedOut = class (TTransportExceptionSpecialized);
- TTransportExceptionEndOfFile = class (TTransportExceptionSpecialized);
- TTransportExceptionBadArgs = class (TTransportExceptionSpecialized);
- TTransportExceptionInterrupted = class (TTransportExceptionSpecialized);
+ TTransportExceptionUnknown = class (TTransportExceptionSpecialized)
+ strict protected
+ class function GetType: TTransportException.TExceptionType; override;
+ end;
+
+ TTransportExceptionNotOpen = class (TTransportExceptionSpecialized)
+ strict protected
+ class function GetType: TTransportException.TExceptionType; override;
+ end;
+
+ TTransportExceptionAlreadyOpen = class (TTransportExceptionSpecialized)
+ strict protected
+ class function GetType: TTransportException.TExceptionType; override;
+ end;
+
+ TTransportExceptionTimedOut = class (TTransportExceptionSpecialized)
+ strict protected
+ class function GetType: TTransportException.TExceptionType; override;
+ end;
+
+ TTransportExceptionEndOfFile = class (TTransportExceptionSpecialized)
+ strict protected
+ class function GetType: TTransportException.TExceptionType; override;
+ end;
+
+ TTransportExceptionBadArgs = class (TTransportExceptionSpecialized)
+ strict protected
+ class function GetType: TTransportException.TExceptionType; override;
+ end;
+
+ TTransportExceptionInterrupted = class (TTransportExceptionSpecialized)
+ strict protected
+ class function GetType: TTransportException.TExceptionType; override;
+ end;
+
+ TTransportExceptionCorruptedData = class (TTransportExceptionSpecialized)
+ strict protected
+ class function GetType: TTransportException.TExceptionType; override;
+ end;
TSecureProtocol = (
SSL_2, SSL_3, TLS_1, // outdated, for compatibilty only
@@ -149,33 +235,41 @@
end;
IServerTransport = interface
- ['{C43B87ED-69EA-47C4-B77C-15E288252900}']
+ ['{FA01363F-6B40-482F-971E-4A085535EFC8}']
procedure Listen;
procedure Close;
function Accept( const fnAccepting: TProc): ITransport;
+ function Configuration : IThriftConfiguration;
end;
TServerTransportImpl = class( TInterfacedObject, IServerTransport)
- protected
+ strict private
+ FConfig : IThriftConfiguration;
+ strict protected
+ function Configuration : IThriftConfiguration;
procedure Listen; virtual; abstract;
procedure Close; virtual; abstract;
- function Accept( const fnAccepting: TProc): ITransport; virtual; abstract;
+ function Accept( const fnAccepting: TProc): ITransport; virtual; abstract;
+ public
+ constructor Create( const aConfig : IThriftConfiguration);
end;
ITransportFactory = interface
['{DD809446-000F-49E1-9BFF-E0D0DC76A9D7}']
- function GetTransport( const ATrans: ITransport): ITransport;
+ function GetTransport( const aTransport: ITransport): ITransport;
end;
- TTransportFactoryImpl = class( TInterfacedObject, ITransportFactory)
- function GetTransport( const ATrans: ITransport): ITransport; virtual;
+ TTransportFactoryImpl = class ( TInterfacedObject, ITransportFactory)
+ strict protected
+ function GetTransport( const aTransport: ITransport): ITransport; virtual;
end;
- TTcpSocketStreamImpl = class( TThriftStreamImpl )
+
+ TTcpSocketStreamImpl = class( TThriftStreamImpl)
{$IFDEF OLD_SOCKETS}
- private type
+ strict private type
TWaitForData = ( wfd_HaveData, wfd_Timeout, wfd_Error);
- private
+ strict private
FTcpClient : TCustomIpClient;
FTimeout : Integer;
function Select( ReadReady, WriteReady, ExceptFlag: PBoolean;
@@ -184,10 +278,10 @@
var wsaError, bytesReady : Integer): TWaitForData;
{$ELSE}
FTcpClient: TSocket;
- protected const
+ strict protected const
SLEEP_TIME = 200;
{$ENDIF}
- protected
+ strict protected
procedure Write( const pBuf : Pointer; offset, count: Integer); override;
function Read( const pBuf : Pointer; const buflen : Integer; offset: Integer; count: Integer): Integer; override;
procedure Open; override;
@@ -198,9 +292,9 @@
function ToArray: TBytes; override;
public
{$IFDEF OLD_SOCKETS}
- constructor Create( const ATcpClient: TCustomIpClient; const aTimeout : Integer = 0);
+ constructor Create( const aTcpClient: TCustomIpClient; const aTimeout : Integer = DEFAULT_THRIFT_TIMEOUT);
{$ELSE}
- constructor Create( const ATcpClient: TSocket; const aTimeout : Longword = 0);
+ constructor Create( const aTcpClient: TSocket; const aTimeout : Longword = DEFAULT_THRIFT_TIMEOUT);
{$ENDIF}
end;
@@ -212,35 +306,37 @@
property OutputStream : IThriftStream read GetOutputStream;
end;
- TStreamTransportImpl = class( TTransportImpl, IStreamTransport)
- protected
+ TStreamTransportImpl = class( TEndpointTransportBase, IStreamTransport)
+ strict protected
FInputStream : IThriftStream;
FOutputStream : IThriftStream;
- protected
+ strict protected
function GetIsOpen: Boolean; override;
function GetInputStream: IThriftStream;
function GetOutputStream: IThriftStream;
- public
- property InputStream : IThriftStream read GetInputStream;
- property OutputStream : IThriftStream read GetOutputStream;
+ strict protected
procedure Open; override;
procedure Close; override;
procedure Flush; override;
function Read( const pBuf : Pointer; const buflen : Integer; off: Integer; len: Integer): Integer; override;
procedure Write( const pBuf : Pointer; off, len : Integer); override;
- constructor Create( const AInputStream : IThriftStream; const AOutputStream : IThriftStream);
+ public
+ constructor Create( const aInputStream, aOutputStream : IThriftStream; const aConfig : IThriftConfiguration = nil); reintroduce;
destructor Destroy; override;
+
+ property InputStream : IThriftStream read GetInputStream;
+ property OutputStream : IThriftStream read GetOutputStream;
end;
TBufferedStreamImpl = class( TThriftStreamImpl)
- private
+ strict private
FStream : IThriftStream;
FBufSize : Integer;
FReadBuffer : TMemoryStream;
FWriteBuffer : TMemoryStream;
- protected
+ strict protected
procedure Write( const pBuf : Pointer; offset: Integer; count: Integer); override;
function Read( const pBuf : Pointer; const buflen : Integer; offset: Integer; count: Integer): Integer; override;
procedure Open; override;
@@ -248,13 +344,15 @@
procedure Flush; override;
function IsOpen: Boolean; override;
function ToArray: TBytes; override;
+ function Size : Int64; override;
+ function Position : Int64; override;
public
- constructor Create( const AStream: IThriftStream; ABufSize: Integer);
+ constructor Create( const aStream: IThriftStream; const aBufSize : Integer);
destructor Destroy; override;
end;
TServerSocketImpl = class( TServerTransportImpl)
- private
+ strict private
{$IFDEF OLD_SOCKETS}
FServer : TTcpServer;
FPort : Integer;
@@ -264,46 +362,52 @@
{$ENDIF}
FUseBufferedSocket : Boolean;
FOwnsServer : Boolean;
- protected
+
+ strict protected
function Accept( const fnAccepting: TProc) : ITransport; override;
+
public
-{$IFDEF OLD_SOCKETS}
- constructor Create( const AServer: TTcpServer; AClientTimeout: Integer = 0); overload;
- constructor Create( APort: Integer; AClientTimeout: Integer = 0; AUseBufferedSockets: Boolean = FALSE); overload;
-{$ELSE}
- constructor Create( const AServer: TServerSocket; AClientTimeout: Longword = 0); overload;
- constructor Create( APort: Integer; AClientTimeout: Longword = 0; AUseBufferedSockets: Boolean = FALSE); overload;
-{$ENDIF}
+ {$IFDEF OLD_SOCKETS}
+ constructor Create( const aServer: TTcpServer; const aClientTimeout : Integer = DEFAULT_THRIFT_TIMEOUT; const aConfig : IThriftConfiguration = nil); overload;
+ constructor Create( const aPort: Integer; const aClientTimeout: Integer = DEFAULT_THRIFT_TIMEOUT; aUseBufferedSockets: Boolean = FALSE; const aConfig : IThriftConfiguration = nil); overload;
+ {$ELSE}
+ constructor Create( const aServer: TServerSocket; const aClientTimeout: Longword = DEFAULT_THRIFT_TIMEOUT; const aConfig : IThriftConfiguration = nil); overload;
+ constructor Create( const aPort: Integer; const aClientTimeout: Longword = DEFAULT_THRIFT_TIMEOUT; aUseBufferedSockets: Boolean = FALSE; const aConfig : IThriftConfiguration = nil); overload;
+ {$ENDIF}
+
destructor Destroy; override;
procedure Listen; override;
procedure Close; override;
end;
- TBufferedTransportImpl = class( TTransportImpl )
- private
+ TBufferedTransportImpl = class( TLayeredTransportBase<IStreamTransport>)
+ // Layered transport that buffers reads and writes of an inner IStreamTransport by
+ // wrapping its InputStream/OutputStream in TBufferedStreamImpl instances.
+ strict private
FInputBuffer : IThriftStream;
FOutputBuffer : IThriftStream;
- FTransport : IStreamTransport;
FBufSize : Integer;
procedure InitBuffers;
- function GetUnderlyingTransport: ITransport;
- protected
+ strict protected
function GetIsOpen: Boolean; override;
procedure Flush; override;
public
+ type
+ // Transport factory nested in TBufferedTransportImpl.
+ TFactory = class( TTransportFactoryImpl )
+ public
+ function GetTransport( const aTransport: ITransport): ITransport; override;
+ end;
+
+ // aBufSize: buffer size in bytes used for the TBufferedStreamImpl wrappers.
+ // The default (1024) keeps former Create(ATransport) callers working.
+ constructor Create( const aTransport : IStreamTransport; const aBufSize: Integer = 1024);
procedure Open(); override;
procedure Close(); override;
function Read( const pBuf : Pointer; const buflen : Integer; off: Integer; len: Integer): Integer; override;
procedure Write( const pBuf : Pointer; off, len : Integer); override;
- constructor Create( const ATransport : IStreamTransport ); overload;
- constructor Create( const ATransport : IStreamTransport; ABufSize: Integer); overload;
- property UnderlyingTransport: ITransport read GetUnderlyingTransport;
+ // Counts bytes already held in the input buffer before asking the inner transport.
+ procedure CheckReadBytesAvailable( const value : Int64); override;
property IsOpen: Boolean read GetIsOpen;
end;
TSocketImpl = class(TStreamTransportImpl)
- private
+ strict private
{$IFDEF OLD_SOCKETS}
FClient : TCustomIpClient;
{$ELSE}
@@ -319,18 +423,19 @@
{$ENDIF}
procedure InitSocket;
- protected
+ strict protected
function GetIsOpen: Boolean; override;
public
- procedure Open; override;
{$IFDEF OLD_SOCKETS}
- constructor Create( const AClient : TCustomIpClient; aOwnsClient : Boolean; ATimeout: Integer = 0); overload;
- constructor Create( const AHost: string; APort: Integer; ATimeout: Integer = 0); overload;
+ constructor Create( const aClient : TCustomIpClient; const aOwnsClient : Boolean; const aTimeout: Integer = DEFAULT_THRIFT_TIMEOUT; const aConfig : IThriftConfiguration = nil); overload;
+ constructor Create( const aHost: string; const aPort: Integer; const aTimeout: Integer = DEFAULT_THRIFT_TIMEOUT; const aConfig : IThriftConfiguration = nil); overload;
{$ELSE}
- constructor Create(const AClient: TSocket; aOwnsClient: Boolean); overload;
- constructor Create( const AHost: string; APort: Integer; ATimeout: Longword = 0); overload;
+ constructor Create(const aClient: TSocket; const aOwnsClient: Boolean; const aConfig : IThriftConfiguration = nil); overload;
+ constructor Create( const aHost: string; const aPort: Integer; const aTimeout: Longword = DEFAULT_THRIFT_TIMEOUT; const aConfig : IThriftConfiguration = nil); overload;
{$ENDIF}
destructor Destroy; override;
+
+ procedure Open; override;
procedure Close; override;
{$IFDEF OLD_SOCKETS}
property TcpClient: TCustomIpClient read FClient;
@@ -341,93 +446,70 @@
property Port: Integer read FPort;
end;
- TFramedTransportImpl = class( TTransportImpl)
- private const
- FHeaderSize : Integer = 4;
- private class var
- FHeader_Dummy : array of Byte;
- protected
- FTransport : ITransport;
+ TFramedTransportImpl = class( TLayeredTransportBase<ITransport>)
+ // Layered transport that exchanges data in frames; each frame is preceded by a
+ // TFramedHeader (4 bytes) holding the big-endian payload length.
+ strict protected type
+ TFramedHeader = Int32;
+ strict protected
+ // FWriteBuffer: outgoing frame including a placeholder header.
+ // FReadBuffer: current incoming frame; nil until the first ReadFrame.
FWriteBuffer : TMemoryStream;
FReadBuffer : TMemoryStream;
procedure InitWriteBuffer;
procedure ReadFrame;
- public
- type
- TFactory = class( TTransportFactoryImpl )
- public
- function GetTransport( const ATrans: ITransport): ITransport; override;
- end;
-
- {$IFDEF HAVE_CLASS_CTOR}
- class constructor Create;
- {$ENDIF}
-
- constructor Create; overload;
- constructor Create( const ATrans: ITransport); overload;
- destructor Destroy; override;
procedure Open(); override;
- function GetIsOpen: Boolean; override;
+ function GetIsOpen: Boolean; override;
procedure Close(); override;
function Read( const pBuf : Pointer; const buflen : Integer; off: Integer; len: Integer): Integer; override;
procedure Write( const pBuf : Pointer; off, len : Integer); override;
+ // Counts bytes left in the current frame before asking the inner transport.
+ procedure CheckReadBytesAvailable( const value : Int64); override;
procedure Flush; override;
+
+ public
+ type
+ // Factory that wraps a given transport in a TFramedTransportImpl.
+ TFactory = class( TTransportFactoryImpl )
+ public
+ function GetTransport( const aTransport: ITransport): ITransport; override;
+ end;
+
+ // aTransport is asserted to be non-nil.
+ constructor Create( const aTransport: ITransport); overload;
+ destructor Destroy; override;
end;
-{$IFNDEF HAVE_CLASS_CTOR}
-procedure TFramedTransportImpl_Initialize;
-{$ENDIF}
const
- DEFAULT_THRIFT_TIMEOUT = 5 * 1000; // ms
+ // NOTE(review): DEFAULT_THRIFT_TIMEOUT is removed here but is still used as a default by the
+ // constructors declared above, so it is presumably declared earlier in the unit - confirm.
+ // Default set of TLS versions: TLS 1.1 and TLS 1.2.
DEFAULT_THRIFT_SECUREPROTOCOLS = [ TSecureProtocol.TLS_1_1, TSecureProtocol.TLS_1_2];
-
-
implementation
-{ TTransportImpl }
-procedure TTransportImpl.Flush;
+{ TTransportBase }
+
+procedure TTransportBase.Flush;
+// Default: does nothing. Transports that buffer output override Flush.
begin
// nothing to do
end;
-function TTransportImpl.Peek: Boolean;
+function TTransportBase.Peek: Boolean;
+// Default: reports whether the transport is open.
begin
Result := IsOpen;
end;
-function TTransportImpl.Read(var buf: TBytes; off: Integer; len: Integer): Integer;
+function TTransportBase.Read(var buf: TBytes; off: Integer; len: Integer): Integer;
+// TBytes overload: forwards to the pointer-based Read, passing Length(buf) as buflen.
+// Returns 0 for an empty buffer.
begin
if Length(buf) > 0
then result := Read( @buf[0], Length(buf), off, len)
else result := 0;
end;
-function TTransportImpl.ReadAll(var buf: TBytes; off: Integer; len: Integer): Integer;
+function TTransportBase.ReadAll(var buf: TBytes; off: Integer; len: Integer): Integer;
+// TBytes overload: forwards to the pointer-based ReadAll, passing Length(buf) as buflen.
+// Returns 0 for an empty buffer.
begin
if Length(buf) > 0
then result := ReadAll( @buf[0], Length(buf), off, len)
else result := 0;
end;
-procedure TTransportImpl.Write( const buf: TBytes);
-begin
- if Length(buf) > 0
- then Write( @buf[0], 0, Length(buf));
-end;
-
-procedure TTransportImpl.Write( const buf: TBytes; off: Integer; len: Integer);
-begin
- if Length(buf) > 0
- then Write( @buf[0], off, len);
-end;
-
-function TTransportImpl.ReadAll(const pBuf : Pointer; const buflen : Integer; off: Integer; len: Integer): Integer;
+function TTransportBase.ReadAll(const pBuf : Pointer; const buflen : Integer; off: Integer; len: Integer): Integer;
var ret : Integer;
begin
result := 0;
@@ -439,41 +521,162 @@
end;
end;
-procedure TTransportImpl.Write( const pBuf : Pointer; len : Integer);
+procedure TTransportBase.Write( const buf: TBytes);
+// Writes the whole buffer; does nothing when buf is empty.
+begin
+ if Length(buf) > 0
+ then Write( @buf[0], 0, Length(buf));
+end;
+
+procedure TTransportBase.Write( const buf: TBytes; off: Integer; len: Integer);
+// Writes len bytes of buf starting at off; does nothing when buf is empty.
+// NOTE(review): off/len are not checked against Length(buf) here.
+begin
+ if Length(buf) > 0
+ then Write( @buf[0], off, len);
+end;
+
+procedure TTransportBase.Write( const pBuf : Pointer; len : Integer);
+// Writes len bytes starting at pBuf (offset 0).
begin
Self.Write( pBuf, 0, len);
end;
-{ TTransportException }
-function TTransportException.GetType: TExceptionType;
+{ TEndpointTransportBase }
+
+constructor TEndpointTransportBase.Create( const aConfig : IThriftConfiguration);
+// Stores aConfig, or a new TThriftConfigurationImpl when aConfig is nil,
+// then initializes the message size counters.
begin
- if Self is TTransportExceptionNotOpen then Result := TExceptionType.NotOpen
- else if Self is TTransportExceptionAlreadyOpen then Result := TExceptionType.AlreadyOpen
- else if Self is TTransportExceptionTimedOut then Result := TExceptionType.TimedOut
- else if Self is TTransportExceptionEndOfFile then Result := TExceptionType.EndOfFile
- else if Self is TTransportExceptionBadArgs then Result := TExceptionType.BadArgs
- else if Self is TTransportExceptionInterrupted then Result := TExceptionType.Interrupted
- else Result := TExceptionType.Unknown;
+ inherited Create;
+
+ if aConfig <> nil
+ then FConfiguration := aConfig
+ else FConfiguration := TThriftConfigurationImpl.Create;
+
+ // presumably uses the default newSize (< 0), i.e. a full reset to MaxMessageSize - confirm against declaration
+ ResetConsumedMessageSize;
end;
+
+function TEndpointTransportBase.Configuration : IThriftConfiguration;
+// Returns the configuration set by the constructor (never nil after construction).
+begin
+ result := FConfiguration;
+end;
+
+
+function TEndpointTransportBase.MaxMessageSize : Integer;
+// Upper limit for a single message, read from Configuration.MaxMessageSize.
+begin
+ ASSERT( Configuration <> nil);
+ result := Configuration.MaxMessageSize;
+end;
+
+
+procedure TEndpointTransportBase.ResetConsumedMessageSize( const newSize : Int64);
+// newSize < 0 : full reset, KnownMessageSize and RemainingMessageSize are set to MaxMessageSize.
+// newSize >= 0: both are set to newSize. The known message size may only shrink, so a
+// newSize above KnownMessageSize raises TTransportExceptionEndOfFile.
+begin
+ // full reset
+ if newSize < 0 then begin
+ FKnownMessageSize := MaxMessageSize;
+ FRemainingMessageSize := MaxMessageSize;
+ Exit;
+ end;
+
+ // update only: message size can shrink, but not grow
+ ASSERT( KnownMessageSize <= MaxMessageSize);
+ // must be raised: merely creating the exception object leaks it and skips the check
+ if newSize > KnownMessageSize
+ then raise TTransportExceptionEndOfFile.Create('MaxMessageSize reached');
+
+ FKnownMessageSize := newSize;
+ FRemainingMessageSize := newSize;
+end;
+
+
+procedure TEndpointTransportBase.UpdateKnownMessageSize( const size : Int64);
+// Updates RemainingMessageSize to reflect the known real message size (e.g. framed transport).
+// Will throw if we already consumed too many bytes.
+var consumed : Int64;
+begin
+ // bytes already consumed under the previous known size
+ consumed := KnownMessageSize - RemainingMessageSize;
+ ResetConsumedMessageSize(size);
+ // charge the already consumed bytes against the new size (raises if they do not fit)
+ CountConsumedMessageBytes(consumed);
+end;
+
+
+procedure TEndpointTransportBase.CheckReadBytesAvailable( const numBytes : Int64);
+// Throws if a read of numBytes would exceed the remaining message size budget (RemainingMessageSize).
+// Only checks; it neither looks at the stream nor consumes anything.
+begin
+ if RemainingMessageSize < numBytes
+ then raise TTransportExceptionEndOfFile.Create('MaxMessageSize reached');
+end;
+
+
+procedure TEndpointTransportBase.CountConsumedMessageBytes( const numBytes : Int64);
+// Consumes numBytes from the RemainingMessageSize.
+// Raises TTransportExceptionEndOfFile when numBytes exceeds what is left.
+begin
+ if (RemainingMessageSize >= numBytes)
+ then Dec( FRemainingMessageSize, numBytes)
+ else begin
+ // exhaust the budget so subsequent checks for any positive amount fail as well
+ FRemainingMessageSize := 0;
+ raise TTransportExceptionEndOfFile.Create('MaxMessageSize reached');
+ end;
+end;
+
+{ TLayeredTransportBase }
+// A layered transport keeps no message-size state of its own: Create stores the wrapped
+// transport in FTransport, and the members below forward configuration and message-size
+// bookkeeping to InnerTransport.
+
+constructor TLayeredTransportBase<T>.Create( const aTransport: T);
+begin
+ inherited Create;
+ FTransport := aTransport;
+end;
+
+function TLayeredTransportBase<T>.GetUnderlyingTransport: ITransport;
+// Exposes the inner transport as a plain ITransport.
+begin
+ result := InnerTransport;
+end;
+
+function TLayeredTransportBase<T>.Configuration : IThriftConfiguration;
+begin
+ result := InnerTransport.Configuration;
+end;
+
+procedure TLayeredTransportBase<T>.UpdateKnownMessageSize( const size : Int64);
+begin
+ InnerTransport.UpdateKnownMessageSize( size);
+end;
+
+
+function TLayeredTransportBase<T>.MaxMessageSize : Integer;
+begin
+ result := InnerTransport.MaxMessageSize;
+end;
+
+
+procedure TLayeredTransportBase<T>.ResetConsumedMessageSize( const knownSize : Int64 = -1);
+begin
+ InnerTransport.ResetConsumedMessageSize( knownSize);
+end;
+
+
+procedure TLayeredTransportBase<T>.CheckReadBytesAvailable( const numBytes : Int64);
+begin
+ InnerTransport.CheckReadBytesAvailable( numBytes);
+end;
+
+
+
+{ TTransportException }
+
constructor TTransportException.HiddenCreate(const Msg: string);
+// Plain message constructor; the specialized exception classes call it via inherited HiddenCreate.
begin
inherited Create(Msg);
end;
-class function TTransportException.Create(AType: TExceptionType): TTransportException;
+class function TTransportException.Create(aType: TExceptionType): TTransportException;
+// Factory without message: forwards to Create(aType, ''); the deprecation warning
+// for that overload is suppressed around the call.
begin
//no inherited;
{$WARN SYMBOL_DEPRECATED OFF}
- Result := Create(AType, '')
+ Result := Create(aType, '')
{$WARN SYMBOL_DEPRECATED DEFAULT}
end;
-class function TTransportException.Create(AType: TExceptionType;
- const msg: string): TTransportException;
+class function TTransportException.Create(aType: TExceptionType; const msg: string): TTransportException;
begin
- case AType of
+ case aType of
TExceptionType.NotOpen: Result := TTransportExceptionNotOpen.Create(msg);
TExceptionType.AlreadyOpen: Result := TTransportExceptionAlreadyOpen.Create(msg);
TExceptionType.TimedOut: Result := TTransportExceptionTimedOut.Create(msg);
@@ -481,6 +684,7 @@
TExceptionType.BadArgs: Result := TTransportExceptionBadArgs.Create(msg);
TExceptionType.Interrupted: Result := TTransportExceptionInterrupted.Create(msg);
else
+ ASSERT( TExceptionType.Unknown = aType);
Result := TTransportExceptionUnknown.Create(msg);
end;
end;
@@ -497,42 +701,105 @@
inherited HiddenCreate(Msg);
end;
+{ specialized TTransportExceptions }
+// Each specialized exception class reports its fixed TExceptionType via the class function GetType.
+
+class function TTransportExceptionUnknown.GetType: TTransportException.TExceptionType;
+begin
+ result := TExceptionType.Unknown;
+end;
+
+class function TTransportExceptionNotOpen.GetType: TTransportException.TExceptionType;
+begin
+ result := TExceptionType.NotOpen;
+end;
+
+class function TTransportExceptionAlreadyOpen.GetType: TTransportException.TExceptionType;
+begin
+ result := TExceptionType.AlreadyOpen;
+end;
+
+class function TTransportExceptionTimedOut.GetType: TTransportException.TExceptionType;
+begin
+ result := TExceptionType.TimedOut;
+end;
+
+class function TTransportExceptionEndOfFile.GetType: TTransportException.TExceptionType;
+begin
+ result := TExceptionType.EndOfFile;
+end;
+
+class function TTransportExceptionBadArgs.GetType: TTransportException.TExceptionType;
+begin
+ result := TExceptionType.BadArgs;
+end;
+
+class function TTransportExceptionInterrupted.GetType: TTransportException.TExceptionType;
+begin
+ result := TExceptionType.Interrupted;
+end;
+
+class function TTransportExceptionCorruptedData.GetType: TTransportException.TExceptionType;
+begin
+ result := TExceptionType.CorruptedData;
+end;
+
{ TTransportFactoryImpl }
-function TTransportFactoryImpl.GetTransport( const ATrans: ITransport): ITransport;
+function TTransportFactoryImpl.GetTransport( const aTransport: ITransport): ITransport;
+// Default factory: returns aTransport unchanged (adds no layer).
begin
- Result := ATrans;
+ Result := aTransport;
+end;
+
+
+{ TServerTransportImpl }
+
+constructor TServerTransportImpl.Create( const aConfig : IThriftConfiguration);
+// Stores aConfig, or a new TThriftConfigurationImpl when aConfig is nil.
+begin
+ inherited Create;
+ if aConfig <> nil
+ then FConfig := aConfig
+ else FConfig := TThriftConfigurationImpl.Create;
+end;
+
+function TServerTransportImpl.Configuration : IThriftConfiguration;
+// Returns the stored configuration; TServerSocketImpl.Accept passes it on to the client transports it creates.
+begin
+ result := FConfig;
end;
{ TServerSocket }
{$IFDEF OLD_SOCKETS}
-constructor TServerSocketImpl.Create( const AServer: TTcpServer; AClientTimeout: Integer);
-begin
- inherited Create;
- FServer := AServer;
- FClientTimeout := AClientTimeout;
-end;
+constructor TServerSocketImpl.Create( const aServer: TTcpServer; const aClientTimeout : Integer; const aConfig : IThriftConfiguration);
{$ELSE}
-constructor TServerSocketImpl.Create( const AServer: TServerSocket; AClientTimeout: Longword);
-begin
- inherited Create;
- FServer := AServer;
- FServer.RecvTimeout := AClientTimeout;
- FServer.SendTimeout := AClientTimeout;
-end;
+constructor TServerSocketImpl.Create( const aServer: TServerSocket; const aClientTimeout: Longword; const aConfig : IThriftConfiguration);
{$ENDIF}
+// Wraps an existing server socket. Unlike the aPort overload, FOwnsServer is not set here,
+// so it keeps its zero-initialized value (False).
+begin
+ inherited Create( aConfig);
+ FServer := aServer;
+
{$IFDEF OLD_SOCKETS}
-constructor TServerSocketImpl.Create(APort, AClientTimeout: Integer; AUseBufferedSockets: Boolean);
+ // OLD_SOCKETS: the timeout is kept and later handed to each accepted client socket
+ FClientTimeout := aClientTimeout;
{$ELSE}
-constructor TServerSocketImpl.Create(APort: Integer; AClientTimeout: Longword; AUseBufferedSockets: Boolean);
+ FServer.RecvTimeout := aClientTimeout;
+ FServer.SendTimeout := aClientTimeout;
+{$ENDIF}
+end;
+
+
+{$IFDEF OLD_SOCKETS}
+constructor TServerSocketImpl.Create( const aPort: Integer; const aClientTimeout: Integer; aUseBufferedSockets: Boolean; const aConfig : IThriftConfiguration);
+{$ELSE}
+constructor TServerSocketImpl.Create( const aPort: Integer; const aClientTimeout: Longword; aUseBufferedSockets: Boolean; const aConfig : IThriftConfiguration);
{$ENDIF}
begin
- inherited Create;
+ inherited Create( aConfig);
+
{$IFDEF OLD_SOCKETS}
- FPort := APort;
- FClientTimeout := AClientTimeout;
+ FPort := aPort;
+ FClientTimeout := aClientTimeout;
+
+ FOwnsServer := True;
FServer := TTcpServer.Create( nil );
FServer.BlockMode := bmBlocking;
{$IF CompilerVersion >= 21.0}
@@ -541,10 +808,11 @@
FServer.LocalPort := IntToStr( FPort);
{$IFEND}
{$ELSE}
- FServer := TServerSocket.Create(APort, AClientTimeout, AClientTimeout);
-{$ENDIF}
- FUseBufferedSocket := AUseBufferedSockets;
FOwnsServer := True;
+ FServer := TServerSocket.Create(aPort, aClientTimeout, aClientTimeout);
+{$ENDIF}
+
+ FUseBufferedSocket := aUseBufferedSockets;
end;
destructor TServerSocketImpl.Destroy;
@@ -588,7 +856,7 @@
Exit;
end;
- trans := TSocketImpl.Create( client, TRUE, FClientTimeout);
+ trans := TSocketImpl.Create( client, TRUE, FClientTimeout, Configuration);
client := nil; // trans owns it now
if FUseBufferedSocket
@@ -607,7 +875,7 @@
client := FServer.Accept;
try
- trans := TSocketImpl.Create(client, True);
+ trans := TSocketImpl.Create(client, TRUE, Configuration);
client := nil;
if FUseBufferedSocket then
@@ -656,37 +924,36 @@
{ TSocket }
{$IFDEF OLD_SOCKETS}
-constructor TSocketImpl.Create( const AClient : TCustomIpClient; aOwnsClient : Boolean; ATimeout: Integer = 0);
-var stream : IThriftStream;
-begin
- FClient := AClient;
- FTimeout := ATimeout;
- FOwnsClient := aOwnsClient;
- stream := TTcpSocketStreamImpl.Create( FClient, FTimeout);
- inherited Create( stream, stream);
-end;
+constructor TSocketImpl.Create( const aClient : TCustomIpClient; const aOwnsClient : Boolean; const aTimeout: Integer; const aConfig : IThriftConfiguration);
{$ELSE}
-constructor TSocketImpl.Create(const AClient: TSocket; aOwnsClient: Boolean);
+constructor TSocketImpl.Create(const aClient: TSocket; const aOwnsClient: Boolean; const aConfig : IThriftConfiguration);
+{$ENDIF}
+// Wraps an already created client socket; aOwnsClient records whether this transport owns aClient.
+// One TTcpSocketStreamImpl serves as both input and output stream.
var stream : IThriftStream;
begin
- FClient := AClient;
- FTimeout := AClient.RecvTimeout;
+ FClient := aClient;
FOwnsClient := aOwnsClient;
- stream := TTcpSocketStreamImpl.Create(FClient, FTimeout);
- inherited Create(stream, stream);
-end;
-{$ENDIF}
{$IFDEF OLD_SOCKETS}
-constructor TSocketImpl.Create(const AHost: string; APort, ATimeout: Integer);
+ FTimeout := aTimeout;
{$ELSE}
-constructor TSocketImpl.Create(const AHost: string; APort: Integer; ATimeout: Longword);
+ // no timeout parameter in this variant: take it from the socket itself
+ FTimeout := aClient.RecvTimeout;
+{$ENDIF}
+
+ stream := TTcpSocketStreamImpl.Create( FClient, FTimeout);
+ inherited Create( stream, stream, aConfig);
+end;
+
+
+{$IFDEF OLD_SOCKETS}
+constructor TSocketImpl.Create(const aHost: string; const aPort, aTimeout: Integer; const aConfig : IThriftConfiguration);
+{$ELSE}
+constructor TSocketImpl.Create(const aHost: string; const aPort : Integer; const aTimeout: Longword; const aConfig : IThriftConfiguration);
{$ENDIF}
+// Targets aHost:aPort. Passes nil streams to the inherited constructor and calls InitSocket;
+// presumably the connection is made later in Open - confirm.
begin
- inherited Create(nil,nil);
- FHost := AHost;
- FPort := APort;
- FTimeout := ATimeout;
+ inherited Create(nil,nil, aConfig);
+ FHost := aHost;
+ FPort := aPort;
+ FTimeout := aTimeout;
InitSocket;
end;
@@ -781,11 +1048,11 @@
FWriteBuffer := nil;
end;
-constructor TBufferedStreamImpl.Create( const AStream: IThriftStream; ABufSize: Integer);
+constructor TBufferedStreamImpl.Create( const aStream: IThriftStream; const aBufSize : Integer);
+// Wraps aStream, stores aBufSize as the buffer size and allocates separate read and write memory buffers.
begin
inherited Create;
- FStream := AStream;
- FBufSize := ABufSize;
+ FStream := aStream;
+ FBufSize := aBufSize;
FReadBuffer := TMemoryStream.Create;
FWriteBuffer := TMemoryStream.Create;
end;
@@ -860,14 +1127,13 @@
end;
end;
+
function TBufferedStreamImpl.ToArray: TBytes;
var len : Integer;
begin
- len := 0;
-
- if IsOpen then begin
- len := FReadBuffer.Size;
- end;
+ if IsOpen
+ then len := FReadBuffer.Size
+ else len := 0;
SetLength( Result, len);
@@ -893,13 +1159,26 @@
end;
end;
+
+function TBufferedStreamImpl.Size : Int64;
+// Size of the internal read buffer (not of the wrapped stream).
+// Returns 0 when the read buffer has been released; ToArray likewise guards its
+// FReadBuffer access, so a nil buffer must not be dereferenced here.
+begin
+ if FReadBuffer <> nil
+ then result := FReadBuffer.Size
+ else result := 0;
+end;
+
+
+function TBufferedStreamImpl.Position : Int64;
+// Current read position within the internal read buffer; 0 when the buffer has been released.
+begin
+ if FReadBuffer <> nil
+ then result := FReadBuffer.Position
+ else result := 0;
+end;
+
+
{ TStreamTransportImpl }
-constructor TStreamTransportImpl.Create( const AInputStream : IThriftStream; const AOutputStream : IThriftStream);
+constructor TStreamTransportImpl.Create( const aInputStream, aOutputStream : IThriftStream; const aConfig : IThriftConfiguration);
+// Either stream may be nil (TSocketImpl passes nil,nil); Read/Write then raise TTransportExceptionNotOpen.
begin
- inherited Create;
- FInputStream := AInputStream;
- FOutputStream := AOutputStream;
+ inherited Create( aConfig);
+ FInputStream := aInputStream;
+ FOutputStream := aOutputStream;
end;
destructor TStreamTransportImpl.Destroy;
@@ -941,48 +1220,41 @@
procedure TStreamTransportImpl.Open;
begin
-
+ // nothing to do
end;
function TStreamTransportImpl.Read( const pBuf : Pointer; const buflen : Integer; off: Integer; len: Integer): Integer;
+// Reads from the input stream and charges the bytes read against the remaining message size.
begin
- if FInputStream = nil then begin
- raise TTransportExceptionNotOpen.Create('Cannot read from null inputstream' );
- end;
+ if FInputStream = nil
+ then raise TTransportExceptionNotOpen.Create('Cannot read from null inputstream' );
Result := FInputStream.Read( pBuf,buflen, off, len );
+ // raises TTransportExceptionEndOfFile when the message size budget is exceeded
+ CountConsumedMessageBytes( result);
end;
procedure TStreamTransportImpl.Write( const pBuf : Pointer; off, len : Integer);
+// Writes to the output stream; no message-size accounting is done on the write side here.
begin
- if FOutputStream = nil then begin
- raise TTransportExceptionNotOpen.Create('Cannot write to null outputstream' );
- end;
+ if FOutputStream = nil
+ then raise TTransportExceptionNotOpen.Create('Cannot write to null outputstream' );
FOutputStream.Write( pBuf, off, len );
end;
{ TBufferedTransportImpl }
-constructor TBufferedTransportImpl.Create( const ATransport: IStreamTransport);
+constructor TBufferedTransportImpl.Create( const aTransport : IStreamTransport; const aBufSize: Integer);
+// Wraps aTransport (asserted non-nil) and creates aBufSize-byte buffers around its streams.
begin
- //no inherited;
- Create( ATransport, 1024 );
-end;
-
-constructor TBufferedTransportImpl.Create( const ATransport: IStreamTransport; ABufSize: Integer);
-begin
- inherited Create;
- FTransport := ATransport;
- FBufSize := ABufSize;
+ ASSERT( aTransport <> nil);
+ inherited Create( aTransport);
+ FBufSize := aBufSize;
InitBuffers;
end;
procedure TBufferedTransportImpl.Close;
+// Closes the inner transport and drops both buffers; output not yet flushed is discarded.
begin
- FTransport.Close;
+ InnerTransport.Close;
FInputBuffer := nil;
- FOutputBuffer := nil;
+ FOutputBuffer := nil;
end;
procedure TBufferedTransportImpl.Flush;
@@ -994,36 +1266,30 @@
function TBufferedTransportImpl.GetIsOpen: Boolean;
+// Open state is that of the inner transport.
begin
- Result := FTransport.IsOpen;
-end;
-
-function TBufferedTransportImpl.GetUnderlyingTransport: ITransport;
-begin
- Result := FTransport;
+ Result := InnerTransport.IsOpen;
end;
procedure TBufferedTransportImpl.InitBuffers;
+// (Re)creates buffered wrappers around the inner transport's streams. When a stream is nil,
+// the corresponding buffer field is left unchanged.
begin
- if FTransport.InputStream <> nil then begin
- FInputBuffer := TBufferedStreamImpl.Create( FTransport.InputStream, FBufSize );
+ if InnerTransport.InputStream <> nil then begin
+ FInputBuffer := TBufferedStreamImpl.Create( InnerTransport.InputStream, FBufSize );
end;
- if FTransport.OutputStream <> nil then begin
- FOutputBuffer := TBufferedStreamImpl.Create( FTransport.OutputStream, FBufSize );
+ if InnerTransport.OutputStream <> nil then begin
+ FOutputBuffer := TBufferedStreamImpl.Create( InnerTransport.OutputStream, FBufSize );
end;
end;
procedure TBufferedTransportImpl.Open;
+// Opens the inner transport, then rebuilds the buffers around its (possibly new) streams.
begin
- FTransport.Open;
+ InnerTransport.Open;
InitBuffers; // we need to get the buffers to match FTransport substreams again
end;
function TBufferedTransportImpl.Read( const pBuf : Pointer; const buflen : Integer; off: Integer; len: Integer): Integer;
+// Reads through the input buffer; returns 0 when there is no input buffer.
begin
- Result := 0;
- if FInputBuffer <> nil then begin
- Result := FInputBuffer.Read( pBuf,buflen, off, len );
- end;
+ if FInputBuffer <> nil
+ then Result := FInputBuffer.Read( pBuf,buflen, off, len )
+ else Result := 0;
end;
procedure TBufferedTransportImpl.Write( const pBuf : Pointer; off, len : Integer);
@@ -1033,65 +1299,69 @@
end;
end;
+procedure TBufferedTransportImpl.CheckReadBytesAvailable( const value : Int64);
+// Throws if fewer than value bytes may still be read. Bytes already held in the input
+// buffer are counted first; only the remainder is checked by the inner transport.
+// FInputBuffer is nil after Close, or when the inner transport has no InputStream; in that
+// case nothing is buffered and the full amount is checked by the inner transport.
+var buffered : Int64;
+begin
+ if FInputBuffer <> nil
+ then buffered := FInputBuffer.Size - FInputBuffer.Position
+ else buffered := 0;
+
+ if buffered < value
+ then InnerTransport.CheckReadBytesAvailable( value - buffered);
+end;
+
+
+{ TBufferedTransportImpl.TFactory }
+
+function TBufferedTransportImpl.TFactory.GetTransport( const aTransport: ITransport): ITransport;
+// Wraps aTransport in a TBufferedTransportImpl (previously created a TFramedTransportImpl by mistake,
+// which silently changed the wire format). The buffered layer needs the inner streams, so
+// aTransport must implement IStreamTransport; otherwise TTransportExceptionBadArgs is raised.
+var streamTransport : IStreamTransport;
+begin
+ if not Supports( aTransport, IStreamTransport, streamTransport)
+ then raise TTransportExceptionBadArgs.Create('TBufferedTransportImpl.TFactory: transport does not implement IStreamTransport');
+ Result := TBufferedTransportImpl.Create( streamTransport);
+end;
+
+
{ TFramedTransportImpl }
-{$IFDEF HAVE_CLASS_CTOR}
-class constructor TFramedTransportImpl.Create;
+constructor TFramedTransportImpl.Create( const aTransport: ITransport);
+// Wraps aTransport (asserted non-nil) and prepares an empty outgoing frame.
+// FReadBuffer is not created here; it stays nil until the first frame is read.
begin
- SetLength( FHeader_Dummy, FHeaderSize);
- FillChar( FHeader_Dummy[0], Length( FHeader_Dummy) * SizeOf( Byte ), 0);
-end;
-{$ELSE}
-procedure TFramedTransportImpl_Initialize;
-begin
- SetLength( TFramedTransportImpl.FHeader_Dummy, TFramedTransportImpl.FHeaderSize);
- FillChar( TFramedTransportImpl.FHeader_Dummy[0],
- Length( TFramedTransportImpl.FHeader_Dummy) * SizeOf( Byte ), 0);
-end;
-{$ENDIF}
+ ASSERT( aTransport <> nil);
+ inherited Create( aTransport);
-constructor TFramedTransportImpl.Create;
-begin
- inherited Create;
InitWriteBuffer;
end;
-procedure TFramedTransportImpl.Close;
-begin
- FTransport.Close;
-end;
-
-constructor TFramedTransportImpl.Create( const ATrans: ITransport);
-begin
- inherited Create;
- InitWriteBuffer;
- FTransport := ATrans;
-end;
-
destructor TFramedTransportImpl.Destroy;
+// Releases both frame buffers.
begin
FWriteBuffer.Free;
+ FWriteBuffer := nil;
FReadBuffer.Free;
+ FReadBuffer := nil;
inherited;
end;
+procedure TFramedTransportImpl.Close;
+// Closes the inner transport; the frame buffers are not reset here.
+begin
+ InnerTransport.Close;
+end;
+
procedure TFramedTransportImpl.Flush;
var
buf : TBytes;
len : Integer;
- data_len : Integer;
-
+ data_len : Int64;
begin
+ if not IsOpen
+ then raise TTransportExceptionNotOpen.Create('not open');
+
len := FWriteBuffer.Size;
SetLength( buf, len);
if len > 0 then begin
System.Move( FWriteBuffer.Memory^, buf[0], len );
end;
- data_len := len - FHeaderSize;
- if (data_len < 0) then begin
- raise TTransportExceptionUnknown.Create('TFramedTransport.Flush: data_len < 0' );
- end;
+ data_len := len - SizeOf(TFramedHeader);
+ if (0 > data_len) or (data_len > Configuration.MaxFrameSize)
+ then raise TTransportExceptionUnknown.Create('TFramedTransport.Flush: invalid frame size ('+IntToStr(data_len)+')')
+ else UpdateKnownMessageSize( len);
InitWriteBuffer;
@@ -1100,13 +1370,13 @@
buf[2] := Byte($FF and (data_len shr 8));
buf[3] := Byte($FF and data_len);
- FTransport.Write( buf, 0, len );
- FTransport.Flush;
+ InnerTransport.Write( buf, 0, len );
+ InnerTransport.Flush;
end;
function TFramedTransportImpl.GetIsOpen: Boolean;
+// Open state is that of the inner transport.
begin
- Result := FTransport.IsOpen;
+ Result := InnerTransport.IsOpen;
end;
type
@@ -1114,16 +1384,17 @@
end;
procedure TFramedTransportImpl.InitWriteBuffer;
+// Starts a new outgoing frame: a fresh buffer (initial capacity 1024) that begins with a zeroed
+// placeholder header; Flush later writes the real frame length into the header bytes.
+const DUMMY_HEADER : TFramedHeader = 0;
begin
- FWriteBuffer.Free;
+ FreeAndNil( FWriteBuffer);
FWriteBuffer := TMemoryStream.Create;
TAccessMemoryStream(FWriteBuffer).Capacity := 1024;
- FWriteBuffer.Write( Pointer(@FHeader_Dummy[0])^, FHeaderSize);
+ FWriteBuffer.Write( DUMMY_HEADER, SizeOf(DUMMY_HEADER));
end;
procedure TFramedTransportImpl.Open;
+// Opens the inner transport.
begin
- FTransport.Open;
+ InnerTransport.Open;
end;
function TFramedTransportImpl.Read( const pBuf : Pointer; const buflen : Integer; off: Integer; len: Integer): Integer;
@@ -1137,9 +1408,7 @@
if (FReadBuffer <> nil) and (len > 0) then begin
result := FReadBuffer.Read( pTmp^, len);
- if result > 0 then begin
- Exit;
- end;
+ if result > 0 then Exit;
end;
ReadFrame;
@@ -1150,20 +1419,33 @@
procedure TFramedTransportImpl.ReadFrame;
var
- i32rd : TBytes;
+ i32rd : packed array[0..SizeOf(TFramedHeader)-1] of Byte;
size : Integer;
buff : TBytes;
begin
- SetLength( i32rd, FHeaderSize );
- FTransport.ReadAll( i32rd, 0, FHeaderSize);
+ InnerTransport.ReadAll( @i32rd[0], SizeOf(i32rd), 0, SizeOf(i32rd));
size :=
((i32rd[0] and $FF) shl 24) or
((i32rd[1] and $FF) shl 16) or
((i32rd[2] and $FF) shl 8) or
(i32rd[3] and $FF);
+
+ if size < 0 then begin
+ Close();
+ raise TTransportExceptionCorruptedData.Create('Read a negative frame size ('+IntToStr(size)+')');
+ end;
+
+ if Int64(size) > Int64(Configuration.MaxFrameSize) then begin
+ Close();
+ raise TTransportExceptionCorruptedData.Create('Frame size ('+IntToStr(size)+') larger than allowed maximum ('+IntToStr(Configuration.MaxFrameSize)+')');
+ end;
+
+ UpdateKnownMessageSize(size + SizeOf(size));
+
SetLength( buff, size );
- FTransport.ReadAll( buff, 0, size );
- FReadBuffer.Free;
+ InnerTransport.ReadAll( buff, 0, size );
+
+ FreeAndNil( FReadBuffer);
FReadBuffer := TMemoryStream.Create;
if Length(buff) > 0
then FReadBuffer.Write( Pointer(@buff[0])^, size );
@@ -1181,11 +1463,24 @@
end;
end;
+
+procedure TFramedTransportImpl.CheckReadBytesAvailable( const value : Int64);
+// Throws if fewer than value bytes may still be read. Bytes left in the current frame are
+// counted first; only the remainder is checked by the inner transport.
+// FReadBuffer is nil until the first frame has been read (it is only created in ReadFrame),
+// so a nil buffer is treated as "nothing buffered" instead of being dereferenced.
+var buffered : Int64;
+begin
+ if FReadBuffer <> nil
+ then buffered := FReadBuffer.Size - FReadBuffer.Position
+ else buffered := 0;
+
+ if buffered < value
+ then InnerTransport.CheckReadBytesAvailable( value - buffered);
+end;
+
+
{ TFramedTransport.TFactory }
-function TFramedTransportImpl.TFactory.GetTransport( const ATrans: ITransport): ITransport;
+function TFramedTransportImpl.TFactory.GetTransport( const aTransport: ITransport): ITransport;
+// Wraps aTransport in a new TFramedTransportImpl.
begin
- Result := TFramedTransportImpl.Create( ATrans );
+ Result := TFramedTransportImpl.Create( aTransport);
end;
{ TTcpSocketStreamImpl }
@@ -1196,17 +1491,17 @@
end;
{$IFDEF OLD_SOCKETS}
-constructor TTcpSocketStreamImpl.Create( const ATcpClient: TCustomIpClient; const aTimeout : Integer);
+constructor TTcpSocketStreamImpl.Create( const aTcpClient: TCustomIpClient; const aTimeout : Integer);
begin
inherited Create;
- FTcpClient := ATcpClient;
+ FTcpClient := aTcpClient;
FTimeout := aTimeout;
end;
{$ELSE}
-constructor TTcpSocketStreamImpl.Create( const ATcpClient: TSocket; const aTimeout : Longword);
+constructor TTcpSocketStreamImpl.Create( const aTcpClient: TSocket; const aTimeout : Longword);
begin
inherited Create;
- FTcpClient := ATcpClient;
+ FTcpClient := aTcpClient;
if aTimeout = 0 then
FTcpClient.RecvTimeout := SLEEP_TIME
else
@@ -1217,9 +1512,10 @@
procedure TTcpSocketStreamImpl.Flush;
+// No-op for the raw socket stream.
begin
-
+ // nothing to do
end;
+
function TTcpSocketStreamImpl.IsOpen: Boolean;
begin
{$IFDEF OLD_SOCKETS}
@@ -1304,7 +1600,7 @@
{$IFDEF LINUX}
result := Libc.select( socket + 1, ReadFdsptr, WriteFdsptr, ExceptFdsptr, Timeptr);
{$ENDIF}
-
+
if result = SOCKET_ERROR
then wsaError := WSAGetLastError;
@@ -1375,7 +1671,7 @@
result := 0;
pTmp := pBuf;
Inc( pTmp, offset);
- while count > 0 do begin
+ while (count > 0) and (result = 0) do begin
while TRUE do begin
wfd := WaitForData( msecs, pTmp, count, wsaError, nBytes);
@@ -1385,10 +1681,7 @@
TWaitForData.wfd_Timeout : begin
if (FTimeout = 0)
then Exit
- else begin
- raise TTransportExceptionTimedOut.Create(SysErrorMessage(Cardinal(wsaError)));
-
- end;
+ else raise TTransportExceptionTimedOut.Create(SysErrorMessage(Cardinal(wsaError)));
end;
else
ASSERT( FALSE);
@@ -1512,12 +1805,4 @@
{$ENDIF}
-{$IF CompilerVersion < 21.0}
-initialization
-begin
- TFramedTransportImpl_Initialize;
-end;
-{$IFEND}
-
-
end.
diff --git a/lib/delphi/src/Thrift.TypeRegistry.pas b/lib/delphi/src/Thrift.TypeRegistry.pas
index c18e97f..3d31c52 100644
--- a/lib/delphi/src/Thrift.TypeRegistry.pas
+++ b/lib/delphi/src/Thrift.TypeRegistry.pas
@@ -29,7 +29,7 @@
TFactoryMethod<T> = function:T;
TypeRegistry = class
- private
+ strict private
class var FTypeInfoToFactoryLookup : TDictionary<Pointer, Pointer>;
public
class constructor Create;
diff --git a/lib/delphi/src/Thrift.Utils.pas b/lib/delphi/src/Thrift.Utils.pas
index ede2656..bc9b460 100644
--- a/lib/delphi/src/Thrift.Utils.pas
+++ b/lib/delphi/src/Thrift.Utils.pas
@@ -97,7 +97,7 @@
THRIFT_MIMETYPE = 'application/x-thrift';
{$IFDEF Win64}
-function InterlockedExchangeAdd64( var Addend : Int64; Value : Int64) : Int64;
+function InterlockedExchangeAdd64( var Addend : Int64; Value : Int64) : Int64;
{$ENDIF}
@@ -289,8 +289,15 @@
var pType : PTypeInfo;
begin
pType := PTypeInfo(TypeInfo(T));
- if Assigned(pType) and (pType^.Kind = tkEnumeration)
- then result := GetEnumName(pType,value)
+ if Assigned(pType)
+ and (pType^.Kind = tkEnumeration)
+ {$IF CompilerVersion >= 23.0} // TODO: Range correct? What we know is that XE does not offer it, but Rio has it
+ and (pType^.TypeData^.MaxValue >= value)
+ and (pType^.TypeData^.MinValue <= value)
+ {$ELSE}
+ and FALSE // THRIFT-5048: pType^.TypeData^ member not supported -> prevent GetEnumName() from reading outside the legal range
+ {$IFEND}
+ then result := GetEnumName( PTypeInfo(pType), value)
else result := IntToStr(Ord(value));
end;
diff --git a/lib/delphi/src/Thrift.WinHTTP.pas b/lib/delphi/src/Thrift.WinHTTP.pas
index 854d7c0..d060066 100644
--- a/lib/delphi/src/Thrift.WinHTTP.pas
+++ b/lib/delphi/src/Thrift.WinHTTP.pas
@@ -205,6 +205,8 @@
// flags for WinHttpOpen():
WINHTTP_FLAG_ASYNC = $10000000; // want async session, requires WinHttpSetStatusCallback() usage
+ WINHTTP_IGNORE_REQUEST_TOTAL_LENGTH = 0;
+
// ports
INTERNET_DEFAULT_PORT = 0; // use the protocol-specific default (80 or 443)
@@ -218,8 +220,16 @@
WINHTTP_FLAG_ESCAPE_DISABLE_QUERY = $00000080; // if escaping enabled escape path part, but do not escape query
// flags for WinHttpSendRequest():
+ WINHTTP_NO_PROXY_NAME = nil;
+ WINHTTP_NO_PROXY_BYPASS = nil;
+ WINHTTP_NO_CLIENT_CERT_CONTEXT = nil;
+ WINHTTP_NO_REFERER = nil;
+ WINHTTP_DEFAULT_ACCEPT_TYPES = nil;
WINHTTP_NO_ADDITIONAL_HEADERS = nil;
WINHTTP_NO_REQUEST_DATA = nil;
+ WINHTTP_HEADER_NAME_BY_INDEX = nil;
+ WINHTTP_NO_OUTPUT_BUFFER = nil;
+ WINHTTP_NO_HEADER_INDEX = nil;
// WinHttpAddRequestHeaders() dwModifiers
WINHTTP_ADDREQ_INDEX_MASK = $0000FFFF;
@@ -247,8 +257,6 @@
INTERNET_SCHEME_HTTP = INTERNET_SCHEME(1);
INTERNET_SCHEME_HTTPS = INTERNET_SCHEME(2);
- WINHTTP_NO_CLIENT_CERT_CONTEXT = nil;
-
// options manifests for WinHttp{Query|Set}Option
WINHTTP_OPTION_CALLBACK = 1;
WINHTTP_OPTION_RESOLVE_TIMEOUT = 2;
@@ -384,6 +392,88 @@
SECURITY_FLAG_STRENGTH_MEDIUM = $40000000;
SECURITY_FLAG_STRENGTH_STRONG = $20000000;
+ // query flags
+ WINHTTP_QUERY_MIME_VERSION = 0;
+ WINHTTP_QUERY_CONTENT_TYPE = 1;
+ WINHTTP_QUERY_CONTENT_TRANSFER_ENCODING = 2;
+ WINHTTP_QUERY_CONTENT_ID = 3;
+ WINHTTP_QUERY_CONTENT_DESCRIPTION = 4;
+ WINHTTP_QUERY_CONTENT_LENGTH = 5;
+ WINHTTP_QUERY_CONTENT_LANGUAGE = 6;
+ WINHTTP_QUERY_ALLOW = 7;
+ WINHTTP_QUERY_PUBLIC = 8;
+ WINHTTP_QUERY_DATE = 9;
+ WINHTTP_QUERY_EXPIRES = 10;
+ WINHTTP_QUERY_LAST_MODIFIED = 11;
+ WINHTTP_QUERY_MESSAGE_ID = 12;
+ WINHTTP_QUERY_URI = 13;
+ WINHTTP_QUERY_DERIVED_FROM = 14;
+ WINHTTP_QUERY_COST = 15;
+ WINHTTP_QUERY_LINK = 16;
+ WINHTTP_QUERY_PRAGMA = 17;
+ WINHTTP_QUERY_VERSION = 18;
+ WINHTTP_QUERY_STATUS_CODE = 19;
+ WINHTTP_QUERY_STATUS_TEXT = 20;
+ WINHTTP_QUERY_RAW_HEADERS = 21;
+ WINHTTP_QUERY_RAW_HEADERS_CRLF = 22;
+ WINHTTP_QUERY_CONNECTION = 23;
+ WINHTTP_QUERY_ACCEPT = 24;
+ WINHTTP_QUERY_ACCEPT_CHARSET = 25;
+ WINHTTP_QUERY_ACCEPT_ENCODING = 26;
+ WINHTTP_QUERY_ACCEPT_LANGUAGE = 27;
+ WINHTTP_QUERY_AUTHORIZATION = 28;
+ WINHTTP_QUERY_CONTENT_ENCODING = 29;
+ WINHTTP_QUERY_FORWARDED = 30;
+ WINHTTP_QUERY_FROM = 31;
+ WINHTTP_QUERY_IF_MODIFIED_SINCE = 32;
+ WINHTTP_QUERY_LOCATION = 33;
+ WINHTTP_QUERY_ORIG_URI = 34;
+ WINHTTP_QUERY_REFERER = 35;
+ WINHTTP_QUERY_RETRY_AFTER = 36;
+ WINHTTP_QUERY_SERVER = 37;
+ WINHTTP_QUERY_TITLE = 38;
+ WINHTTP_QUERY_USER_AGENT = 39;
+ WINHTTP_QUERY_WWW_AUTHENTICATE = 40;
+ WINHTTP_QUERY_PROXY_AUTHENTICATE = 41;
+ WINHTTP_QUERY_ACCEPT_RANGES = 42;
+ WINHTTP_QUERY_SET_COOKIE = 43;
+ WINHTTP_QUERY_COOKIE = 44;
+ WINHTTP_QUERY_REQUEST_METHOD = 45;
+ WINHTTP_QUERY_REFRESH = 46;
+ WINHTTP_QUERY_CONTENT_DISPOSITION = 47;
+ WINHTTP_QUERY_AGE = 48;
+ WINHTTP_QUERY_CACHE_CONTROL = 49;
+ WINHTTP_QUERY_CONTENT_BASE = 50;
+ WINHTTP_QUERY_CONTENT_LOCATION = 51;
+ WINHTTP_QUERY_CONTENT_MD5 = 52;
+ WINHTTP_QUERY_CONTENT_RANGE = 53;
+ WINHTTP_QUERY_ETAG = 54;
+ WINHTTP_QUERY_HOST = 55;
+ WINHTTP_QUERY_IF_MATCH = 56;
+ WINHTTP_QUERY_IF_NONE_MATCH = 57;
+ WINHTTP_QUERY_IF_RANGE = 58;
+ WINHTTP_QUERY_IF_UNMODIFIED_SINCE = 59;
+ WINHTTP_QUERY_MAX_FORWARDS = 60;
+ WINHTTP_QUERY_PROXY_AUTHORIZATION = 61;
+ WINHTTP_QUERY_RANGE = 62;
+ WINHTTP_QUERY_TRANSFER_ENCODING = 63;
+ WINHTTP_QUERY_UPGRADE = 64;
+ WINHTTP_QUERY_VARY = 65;
+ WINHTTP_QUERY_VIA = 66;
+ WINHTTP_QUERY_WARNING = 67;
+ WINHTTP_QUERY_EXPECT = 68;
+ WINHTTP_QUERY_PROXY_CONNECTION = 69;
+ WINHTTP_QUERY_UNLESS_MODIFIED_SINCE = 70;
+ WINHTTP_QUERY_PROXY_SUPPORT = 75;
+ WINHTTP_QUERY_AUTHENTICATION_INFO = 76;
+ WINHTTP_QUERY_PASSPORT_URLS = 77;
+ WINHTTP_QUERY_PASSPORT_CONFIG = 78;
+ WINHTTP_QUERY_MAX = 78;
+ WINHTTP_QUERY_CUSTOM = 65535;
+ WINHTTP_QUERY_FLAG_REQUEST_HEADERS = $80000000;
+ WINHTTP_QUERY_FLAG_SYSTEMTIME = $40000000;
+ WINHTTP_QUERY_FLAG_NUMBER = $20000000;
+
// Secure connection error status flags
WINHTTP_CALLBACK_STATUS_FLAG_CERT_REV_FAILED = $00000001;
WINHTTP_CALLBACK_STATUS_FLAG_INVALID_CERT = $00000002;
@@ -486,7 +576,7 @@
IWinHTTPConnection = interface;
IWinHTTPRequest = interface
- ['{F65952F2-2F3B-47DC-B524-F1694E6D2AD7}']
+ ['{7A8E7255-5440-4621-A8A8-1E9FFAA6D6FA}']
function Handle : HINTERNET;
function Connection : IWinHTTPConnection;
function AddRequestHeader( const aHeader : string; const addflag : DWORD = WINHTTP_ADDREQ_FLAG_ADD) : Boolean;
@@ -498,6 +588,8 @@
function FlushAndReceiveResponse : Boolean;
function ReadData( const dwRead : DWORD) : TBytes; overload;
function ReadData( const pBuf : Pointer; const dwRead : DWORD) : DWORD; overload;
+ function QueryDataAvailable : DWORD;
+ function QueryTotalResponseSize : DWORD;
end;
IWinHTTPConnection = interface
@@ -616,6 +708,8 @@
function FlushAndReceiveResponse : Boolean;
function ReadData( const dwRead : DWORD) : TBytes; overload;
function ReadData( const pBuf : Pointer; const dwRead : DWORD) : DWORD; overload;
+ function QueryDataAvailable : DWORD;
+ function QueryTotalResponseSize : DWORD;
public
constructor Create( const aConnection : IWinHTTPConnection;
@@ -1111,6 +1205,32 @@
end;
+function TWinHTTPRequestImpl.QueryDataAvailable : DWORD;
+begin
+ if not WinHttpQueryDataAvailable( FHandle, result)
+ then result := 0;
+end;
+
+
+function TWinHTTPRequestImpl.QueryTotalResponseSize : DWORD;
+var dwBytes, dwError, dwIndex : DWORD;
+begin
+ dwBytes := SizeOf( result);
+ dwIndex := DWORD( WINHTTP_NO_HEADER_INDEX);
+ if not WinHttpQueryHeaders( FHandle,
+ WINHTTP_QUERY_CONTENT_LENGTH or WINHTTP_QUERY_FLAG_NUMBER,
+ WINHTTP_HEADER_NAME_BY_INDEX,
+ @result, dwBytes,
+ dwIndex)
+ then begin
+ dwError := GetLastError;
+ if dwError <> ERROR_WINHTTP_HEADER_NOT_FOUND then ASSERT(FALSE); // anything else would be a real error
+ result := MAXINT; // we don't know
+ end;
+end;
+
+
+
{ TWinHTTPUrlImpl }
constructor TWinHTTPUrlImpl.Create(const aUri: UnicodeString);
diff --git a/lib/delphi/src/Thrift.pas b/lib/delphi/src/Thrift.pas
index e127380..1926b11 100644
--- a/lib/delphi/src/Thrift.pas
+++ b/lib/delphi/src/Thrift.pas
@@ -23,6 +23,7 @@
uses
SysUtils,
+ Thrift.Utils,
Thrift.Exception,
Thrift.Protocol;
@@ -34,7 +35,7 @@
TApplicationExceptionSpecializedClass = class of TApplicationExceptionSpecialized;
- TApplicationException = class( TException)
+ TApplicationException = class( TException, IBase, ISupportsToString)
public
type
{$SCOPEDENUMS ON}
@@ -52,10 +53,18 @@
UnsupportedClientType
);
{$SCOPEDENUMS OFF}
- private
- function GetType: TExceptionType;
- protected
+ strict private
+ FExceptionType : TExceptionType;
+
+ strict protected
+ function QueryInterface(const IID: TGUID; out Obj): HResult; stdcall;
+ function _AddRef: Integer; stdcall;
+ function _Release: Integer; stdcall;
+
+ strict protected
constructor HiddenCreate(const Msg: string);
+ class function GetSpecializedExceptionType(AType: TExceptionType): TApplicationExceptionSpecializedClass;
+
public
// purposefully hide inherited constructor
class function Create(const Msg: string): TApplicationException; overload; deprecated 'Use specialized TApplicationException types (or regenerate from IDL)';
@@ -63,7 +72,10 @@
class function Create( AType: TExceptionType): TApplicationException; overload; deprecated 'Use specialized TApplicationException types (or regenerate from IDL)';
class function Create( AType: TExceptionType; const msg: string): TApplicationException; overload; deprecated 'Use specialized TApplicationException types (or regenerate from IDL)';
- class function GetSpecializedExceptionType(AType: TExceptionType): TApplicationExceptionSpecializedClass;
+ function Type_: TExceptionType; virtual;
+
+ procedure IBase_Read( const iprot: IProtocol);
+ procedure IBase.Read = IBase_Read;
class function Read( const iprot: IProtocol): TApplicationException;
procedure Write( const oprot: IProtocol );
@@ -71,42 +83,74 @@
// Needed to remove deprecation warning
TApplicationExceptionSpecialized = class abstract (TApplicationException)
+ strict protected
+ class function GetType: TApplicationException.TExceptionType; virtual; abstract;
public
constructor Create(const Msg: string);
+ function Type_: TApplicationException.TExceptionType; override;
end;
- TApplicationExceptionUnknown = class (TApplicationExceptionSpecialized);
- TApplicationExceptionUnknownMethod = class (TApplicationExceptionSpecialized);
- TApplicationExceptionInvalidMessageType = class (TApplicationExceptionSpecialized);
- TApplicationExceptionWrongMethodName = class (TApplicationExceptionSpecialized);
- TApplicationExceptionBadSequenceID = class (TApplicationExceptionSpecialized);
- TApplicationExceptionMissingResult = class (TApplicationExceptionSpecialized);
- TApplicationExceptionInternalError = class (TApplicationExceptionSpecialized);
- TApplicationExceptionProtocolError = class (TApplicationExceptionSpecialized);
- TApplicationExceptionInvalidTransform = class (TApplicationExceptionSpecialized);
- TApplicationExceptionInvalidProtocol = class (TApplicationExceptionSpecialized);
- TApplicationExceptionUnsupportedClientType = class (TApplicationExceptionSpecialized);
+ TApplicationExceptionUnknown = class (TApplicationExceptionSpecialized)
+ strict protected
+ class function GetType: TApplicationException.TExceptionType; override;
+ end;
+
+ TApplicationExceptionUnknownMethod = class (TApplicationExceptionSpecialized)
+ strict protected
+ class function GetType: TApplicationException.TExceptionType; override;
+ end;
+
+ TApplicationExceptionInvalidMessageType = class (TApplicationExceptionSpecialized)
+ strict protected
+ class function GetType: TApplicationException.TExceptionType; override;
+ end;
+
+ TApplicationExceptionWrongMethodName = class (TApplicationExceptionSpecialized)
+ strict protected
+ class function GetType: TApplicationException.TExceptionType; override;
+ end;
+
+ TApplicationExceptionBadSequenceID = class (TApplicationExceptionSpecialized)
+ strict protected
+ class function GetType: TApplicationException.TExceptionType; override;
+ end;
+
+ TApplicationExceptionMissingResult = class (TApplicationExceptionSpecialized)
+ strict protected
+ class function GetType: TApplicationException.TExceptionType; override;
+ end;
+
+ TApplicationExceptionInternalError = class (TApplicationExceptionSpecialized)
+ strict protected
+ class function GetType: TApplicationException.TExceptionType; override;
+ end;
+
+ TApplicationExceptionProtocolError = class (TApplicationExceptionSpecialized)
+ strict protected
+ class function GetType: TApplicationException.TExceptionType; override;
+ end;
+
+ TApplicationExceptionInvalidTransform = class (TApplicationExceptionSpecialized)
+ strict protected
+ class function GetType: TApplicationException.TExceptionType; override;
+ end;
+
+ TApplicationExceptionInvalidProtocol = class (TApplicationExceptionSpecialized)
+ strict protected
+ class function GetType: TApplicationException.TExceptionType; override;
+ end;
+
+ TApplicationExceptionUnsupportedClientType = class (TApplicationExceptionSpecialized)
+ strict protected
+ class function GetType: TApplicationException.TExceptionType; override;
+ end;
+
implementation
{ TApplicationException }
-function TApplicationException.GetType: TExceptionType;
-begin
- if Self is TApplicationExceptionUnknownMethod then Result := TExceptionType.UnknownMethod
- else if Self is TApplicationExceptionInvalidMessageType then Result := TExceptionType.InvalidMessageType
- else if Self is TApplicationExceptionWrongMethodName then Result := TExceptionType.WrongMethodName
- else if Self is TApplicationExceptionBadSequenceID then Result := TExceptionType.BadSequenceID
- else if Self is TApplicationExceptionMissingResult then Result := TExceptionType.MissingResult
- else if Self is TApplicationExceptionInternalError then Result := TExceptionType.InternalError
- else if Self is TApplicationExceptionProtocolError then Result := TExceptionType.ProtocolError
- else if Self is TApplicationExceptionInvalidTransform then Result := TExceptionType.InvalidTransform
- else if Self is TApplicationExceptionInvalidProtocol then Result := TExceptionType.InvalidProtocol
- else if Self is TApplicationExceptionUnsupportedClientType then Result := TExceptionType.UnsupportedClientType
- else Result := TExceptionType.Unknown;
-end;
-
constructor TApplicationException.HiddenCreate(const Msg: string);
begin
inherited Create(Msg);
@@ -134,6 +178,31 @@
Result := GetSpecializedExceptionType(AType).Create(msg);
end;
+
+function TApplicationException.QueryInterface(const IID: TGUID; out Obj): HResult;
+begin
+ if GetInterface(IID, Obj)
+ then result := S_OK
+ else result := E_NOINTERFACE;
+end;
+
+function TApplicationException._AddRef: Integer;
+begin
+ result := -1; // not refcounted
+end;
+
+function TApplicationException._Release: Integer;
+begin
+ result := -1; // not refcounted
+end;
+
+
+function TApplicationException.Type_: TExceptionType;
+begin
+ result := FExceptionType;
+end;
+
+
class function TApplicationException.GetSpecializedExceptionType(AType: TExceptionType): TApplicationExceptionSpecializedClass;
begin
case AType of
@@ -148,56 +217,66 @@
TExceptionType.InvalidProtocol: Result := TApplicationExceptionInvalidProtocol;
TExceptionType.UnsupportedClientType: Result := TApplicationExceptionUnsupportedClientType;
else
+ // Unknown, and any out-of-range code read off the wire (e.g. sent by a newer peer), maps to Unknown -- no ASSERT, AType may come from untrusted input via Read()
Result := TApplicationExceptionUnknown;
end;
end;
-class function TApplicationException.Read( const iprot: IProtocol): TApplicationException;
+
+procedure TApplicationException.IBase_Read( const iprot: IProtocol);
var
field : TThriftField;
- msg : string;
- typ : TExceptionType;
struc : TThriftStruct;
begin
- msg := '';
- typ := TExceptionType.Unknown;
struc := iprot.ReadStructBegin;
while ( True ) do
begin
field := iprot.ReadFieldBegin;
- if ( field.Type_ = TType.Stop) then
- begin
+ if ( field.Type_ = TType.Stop) then begin
Break;
end;
case field.Id of
1 : begin
- if ( field.Type_ = TType.String_) then
- begin
- msg := iprot.ReadString;
- end else
- begin
+ if ( field.Type_ = TType.String_) then begin
+ Exception(Self).Message := iprot.ReadString;
+ end else begin
TProtocolUtil.Skip( iprot, field.Type_ );
end;
end;
2 : begin
- if ( field.Type_ = TType.I32) then
- begin
- typ := TExceptionType( iprot.ReadI32 );
- end else
- begin
+ if ( field.Type_ = TType.I32) then begin
+ FExceptionType := TExceptionType( iprot.ReadI32 );
+ end else begin
TProtocolUtil.Skip( iprot, field.Type_ );
end;
- end else
- begin
+ end else begin
TProtocolUtil.Skip( iprot, field.Type_);
end;
end;
iprot.ReadFieldEnd;
end;
iprot.ReadStructEnd;
- Result := GetSpecializedExceptionType(typ).Create(msg);
+end;
+
+
+class function TApplicationException.Read( const iprot: IProtocol): TApplicationException;
+var instance : TApplicationException;
+ base : IBase;
+begin
+ instance := TApplicationException.CreateFmt('',[]);
+ try
+ if Supports( instance, IBase, base) then try
+ base.Read(iprot);
+ finally
+ base := nil; // clear ref before free
+ end;
+
+ result := GetSpecializedExceptionType(instance.Type_).Create( Exception(instance).Message);
+ finally
+ instance.Free;
+ end;
end;
procedure TApplicationException.Write( const oprot: IProtocol);
@@ -209,8 +288,7 @@
Init(field);
oprot.WriteStructBegin( struc );
- if Message <> '' then
- begin
+ if Message <> '' then begin
field.Name := 'message';
field.Type_ := TType.String_;
field.Id := 1;
@@ -223,7 +301,7 @@
field.Type_ := TType.I32;
field.Id := 2;
oprot.WriteFieldBegin(field);
- oprot.WriteI32(Integer(GetType));
+ oprot.WriteI32(Integer(Type_));
oprot.WriteFieldEnd();
oprot.WriteFieldStop();
oprot.WriteStructEnd();
@@ -236,4 +314,68 @@
inherited HiddenCreate(Msg);
end;
+function TApplicationExceptionSpecialized.Type_: TApplicationException.TExceptionType;
+begin
+ result := GetType;
+end;
+
+
+{ specialized TApplicationExceptions }
+
+class function TApplicationExceptionUnknownMethod.GetType : TApplicationException.TExceptionType;
+begin
+ result := TExceptionType.UnknownMethod;
+end;
+
+class function TApplicationExceptionInvalidMessageType.GetType : TApplicationException.TExceptionType;
+begin
+ result := TExceptionType.InvalidMessageType;
+end;
+
+class function TApplicationExceptionWrongMethodName.GetType : TApplicationException.TExceptionType;
+begin
+ result := TExceptionType.WrongMethodName;
+end;
+
+class function TApplicationExceptionBadSequenceID.GetType : TApplicationException.TExceptionType;
+begin
+ result := TExceptionType.BadSequenceID;
+end;
+
+class function TApplicationExceptionMissingResult.GetType : TApplicationException.TExceptionType;
+begin
+ result := TExceptionType.MissingResult;
+end;
+
+class function TApplicationExceptionInternalError.GetType : TApplicationException.TExceptionType;
+begin
+ result := TExceptionType.InternalError;
+end;
+
+class function TApplicationExceptionProtocolError.GetType : TApplicationException.TExceptionType;
+begin
+ result := TExceptionType.ProtocolError;
+end;
+
+class function TApplicationExceptionInvalidTransform.GetType : TApplicationException.TExceptionType;
+begin
+ result := TExceptionType.InvalidTransform;
+end;
+
+class function TApplicationExceptionInvalidProtocol.GetType : TApplicationException.TExceptionType;
+begin
+ result := TExceptionType.InvalidProtocol;
+end;
+
+class function TApplicationExceptionUnsupportedClientType.GetType : TApplicationException.TExceptionType;
+begin
+ result := TExceptionType.UnsupportedClientType;
+end;
+
+class function TApplicationExceptionUnknown.GetType : TApplicationException.TExceptionType;
+begin
+ result := TExceptionType.Unknown;
+end;
+
+
end.
diff --git a/lib/delphi/test/Performance/PerfTests.pas b/lib/delphi/test/Performance/PerfTests.pas
index 2c820b1..e485212 100644
--- a/lib/delphi/test/Performance/PerfTests.pas
+++ b/lib/delphi/test/Performance/PerfTests.pas
@@ -21,6 +21,7 @@
uses
Windows, Classes, SysUtils,
Thrift.Collections,
+ Thrift.Configuration,
Thrift.Test,
Thrift.Protocol,
Thrift.Protocol.JSON,
@@ -34,9 +35,10 @@
type
TPerformanceTests = class
strict private
- Testdata : ICrazyNesting;
- MemBuffer : TMemoryStream;
- Transport : ITransport;
+ FTestdata : ICrazyNesting;
+ FMemBuffer : TMemoryStream;
+ FTransport : ITransport;
+ FConfig : IThriftConfiguration;
procedure ProtocolPeformanceTest;
procedure RunTest( const ptyp : TKnownProtocol; const layered : TLayeredTransport);
@@ -74,7 +76,7 @@
var layered : TLayeredTransport;
begin
Console.WriteLine('Setting up for ProtocolPeformanceTest ...');
- Testdata := TestDataFactory.CreateCrazyNesting();
+ FTestdata := TestDataFactory.CreateCrazyNesting();
for layered := Low(TLayeredTransport) to High(TLayeredTransport) do begin
RunTest( TKnownProtocol.prot_Binary, layered);
@@ -91,10 +93,12 @@
begin
QueryPerformanceFrequency( freq);
+ FConfig := TThriftConfigurationImpl.Create;
+
proto := GenericProtocolFactory( ptyp, layered, TRUE);
QueryPerformanceCounter( start);
- Testdata.Write(proto);
- Transport.Flush;
+ FTestdata.Write(proto);
+ FTransport.Flush;
QueryPerformanceCounter( stop);
Console.WriteLine( Format('RunTest(%s): write = %d msec', [
GetProtocolTransportName(ptyp,layered),
@@ -121,24 +125,24 @@
begin
// read happens after write here, so let's take over the written bytes
newBuf := TMemoryStream.Create;
- if not forWrite then newBuf.CopyFrom( MemBuffer, COPY_ENTIRE_STREAM);
- MemBuffer := newBuf;
- MemBuffer.Position := 0;
+ if not forWrite then newBuf.CopyFrom( FMemBuffer, COPY_ENTIRE_STREAM);
+ FMemBuffer := newBuf;
+ FMemBuffer.Position := 0;
// layered transports anyone?
stream := TThriftStreamAdapterDelphi.Create( newBuf, TRUE);
if forWrite
- then trans := TStreamTransportImpl.Create( nil, stream)
- else trans := TStreamTransportImpl.Create( stream, nil);
+ then trans := TStreamTransportImpl.Create( nil, stream, FConfig)
+ else trans := TStreamTransportImpl.Create( stream, nil, FConfig);
case layered of
- trns_Framed : Transport := TFramedTransportImpl.Create( trans);
- trns_Buffered : Transport := TBufferedTransportImpl.Create( trans);
+ trns_Framed : FTransport := TFramedTransportImpl.Create( trans);
+ trns_Buffered : FTransport := TBufferedTransportImpl.Create( trans);
else
- Transport := trans;
+ FTransport := trans;
end;
- if not Transport.IsOpen
- then Transport.Open;
+ if not FTransport.IsOpen
+ then FTransport.Open;
case ptyp of
prot_Binary : result := TBinaryProtocolImpl.Create(trans);
diff --git a/lib/delphi/test/TestClient.pas b/lib/delphi/test/TestClient.pas
index e59c327..1579bd5 100644
--- a/lib/delphi/test/TestClient.pas
+++ b/lib/delphi/test/TestClient.pas
@@ -53,6 +53,8 @@
Thrift.Test,
Thrift.WinHTTP,
Thrift.Utils,
+
+ Thrift.Configuration,
Thrift.Collections;
type
@@ -93,7 +95,7 @@
Normal, // Fairly small array of usual size (256 bytes)
ByteArrayTest, // THRIFT-4454 Large writes/reads may cause range check errors in debug mode
PipeWriteLimit, // THRIFT-4372 Pipe write operations across a network are limited to 65,535 bytes per write.
- TwentyMB // that's quite a bit of data
+ FifteenMB // quite a bit of data, but still below the default max frame size
);
private
@@ -122,7 +124,7 @@
procedure InitializeProtocolTransportStack;
procedure ShutdownProtocolTransportStack;
- function InitializeHttpTransport( const aTimeoutSetting : Integer) : IHTTPClient;
+ function InitializeHttpTransport( const aTimeoutSetting : Integer; const aConfig : IThriftConfiguration = nil) : IHTTPClient;
procedure JSONProtocolReadWriteTest;
function PrepareBinaryData( aRandomDist : Boolean; aSize : TTestSize) : TBytes;
@@ -1024,7 +1026,7 @@
Normal : SetLength( result, $100);
ByteArrayTest : SetLength( result, SizeOf(TByteArray) + 128);
PipeWriteLimit : SetLength( result, 65535 + 128);
- TwentyMB : SetLength( result, 20 * 1024 * 1024);
+ FifteenMB : SetLength( result, 15 * 1024 * 1024);
else
raise EArgumentException.Create('aSize');
end;
@@ -1068,6 +1070,7 @@
var prot : IProtocol;
stm : TStringStream;
list : TThriftList;
+ config : IThriftConfiguration;
binary, binRead, emptyBinary : TBytes;
i,iErr : Integer;
const
@@ -1089,6 +1092,8 @@
try
StartTestGroup( 'JsonProtocolTest', test_Unknown);
+ config := TThriftConfigurationImpl.Create;
+
// prepare binary data
binary := PrepareBinaryData( FALSE, Normal);
SetLength( emptyBinary, 0); // empty binary data block
@@ -1096,7 +1101,7 @@
// output setup
prot := TJSONProtocolImpl.Create(
TStreamTransportImpl.Create(
- nil, TThriftStreamAdapterDelphi.Create( stm, FALSE)));
+ nil, TThriftStreamAdapterDelphi.Create( stm, FALSE), config));
// write
Init( list, TType.String_, 9);
@@ -1119,7 +1124,7 @@
stm.Position := 0;
prot := TJSONProtocolImpl.Create(
TStreamTransportImpl.Create(
- TThriftStreamAdapterDelphi.Create( stm, FALSE), nil));
+ TThriftStreamAdapterDelphi.Create( stm, FALSE), nil, config));
// read and compare
list := prot.ReadListBegin;
@@ -1161,7 +1166,7 @@
stm.Position := 0;
prot := TJSONProtocolImpl.Create(
TStreamTransportImpl.Create(
- TThriftStreamAdapterDelphi.Create( stm, FALSE), nil));
+ TThriftStreamAdapterDelphi.Create( stm, FALSE), nil, config));
Expect( prot.ReadString = SOLIDUS_EXCPECTED, 'Solidus encoding');
@@ -1172,12 +1177,12 @@
stm.Size := 0;
prot := TJSONProtocolImpl.Create(
TStreamTransportImpl.Create(
- nil, TThriftStreamAdapterDelphi.Create( stm, FALSE)));
+ nil, TThriftStreamAdapterDelphi.Create( stm, FALSE), config));
prot.WriteString( G_CLEF_AND_CYRILLIC_TEXT);
stm.Position := 0;
prot := TJSONProtocolImpl.Create(
TStreamTransportImpl.Create(
- TThriftStreamAdapterDelphi.Create( stm, FALSE), nil));
+ TThriftStreamAdapterDelphi.Create( stm, FALSE), nil, config));
Expect( prot.ReadString = G_CLEF_AND_CYRILLIC_TEXT, 'Writing JSON with chars > 8 bit');
// Widechars should work with hex-encoding too. Do they?
@@ -1187,7 +1192,7 @@
stm.Position := 0;
prot := TJSONProtocolImpl.Create(
TStreamTransportImpl.Create(
- TThriftStreamAdapterDelphi.Create( stm, FALSE), nil));
+ TThriftStreamAdapterDelphi.Create( stm, FALSE), nil, config));
Expect( prot.ReadString = G_CLEF_AND_CYRILLIC_TEXT, 'Reading JSON with chars > 8 bit');
@@ -1330,7 +1335,7 @@
end;
-function TClientThread.InitializeHttpTransport( const aTimeoutSetting : Integer) : IHTTPClient;
+function TClientThread.InitializeHttpTransport( const aTimeoutSetting : Integer; const aConfig : IThriftConfiguration) : IHTTPClient;
var sUrl : string;
comps : URL_COMPONENTS;
dwChars : DWORD;
@@ -1367,8 +1372,8 @@
Console.WriteLine('Target URL: '+sUrl);
case FSetup.endpoint of
- trns_MsxmlHttp : result := TMsxmlHTTPClientImpl.Create( sUrl);
- trns_WinHttp : result := TWinHTTPClientImpl.Create( sUrl);
+ trns_MsxmlHttp : result := TMsxmlHTTPClientImpl.Create( sUrl, aConfig);
+ trns_WinHttp : result := TWinHTTPClientImpl.Create( sUrl, aConfig);
else
raise Exception.Create(ENDPOINT_TRANSPORTS[FSetup.endpoint]+' unhandled case');
end;
@@ -1396,7 +1401,7 @@
case FSetup.endpoint of
trns_Sockets: begin
Console.WriteLine('Using sockets ('+FSetup.host+' port '+IntToStr(FSetup.port)+')');
- streamtrans := TSocketImpl.Create( FSetup.host, FSetup.port );
+ streamtrans := TSocketImpl.Create( FSetup.host, FSetup.port);
FTransport := streamtrans;
end;
@@ -1417,7 +1422,7 @@
end;
trns_AnonPipes: begin
- streamtrans := TAnonymousPipeTransportImpl.Create( FSetup.hAnonRead, FSetup.hAnonWrite, FALSE);
+ streamtrans := TAnonymousPipeTransportImpl.Create( FSetup.hAnonRead, FSetup.hAnonWrite, FALSE, PIPE_TIMEOUT);
FTransport := streamtrans;
end;
diff --git a/lib/delphi/test/TestServer.pas b/lib/delphi/test/TestServer.pas
index 2a80d52..bbc798b 100644
--- a/lib/delphi/test/TestServer.pas
+++ b/lib/delphi/test/TestServer.pas
@@ -36,6 +36,7 @@
Thrift.Protocol.JSON,
Thrift.Protocol.Compact,
Thrift.Collections,
+ Thrift.Configuration,
Thrift.Utils,
Thrift.Test,
Thrift,
@@ -157,13 +158,11 @@
procedure TTestServer.TTestHandlerImpl.testException(const arg: string);
begin
Console.WriteLine('testException(' + arg + ')');
- if ( arg = 'Xception') then
- begin
+ if ( arg = 'Xception') then begin
raise TXception.Create( 1001, arg);
end;
- if (arg = 'TException') then
- begin
+ if (arg = 'TException') then begin
raise TException.Create('TException');
end;
@@ -585,7 +584,7 @@
trns_Sockets : begin
Console.WriteLine('- sockets (port '+IntToStr(port)+')');
if (trns_Buffered in layered) then Console.WriteLine('- buffered');
- servertrans := TServerSocketImpl.Create( Port, 0, (trns_Buffered in layered));
+ servertrans := TServerSocketImpl.Create( Port, DEFAULT_THRIFT_TIMEOUT, (trns_Buffered in layered));
end;
trns_MsxmlHttp,
@@ -595,7 +594,7 @@
trns_NamedPipes : begin
Console.WriteLine('- named pipe ('+sPipeName+')');
- namedpipe := TNamedPipeServerTransportImpl.Create( sPipeName, 4096, PIPE_UNLIMITED_INSTANCES);
+ namedpipe := TNamedPipeServerTransportImpl.Create( sPipeName, 4096, PIPE_UNLIMITED_INSTANCES, INFINITE);
servertrans := namedpipe;
end;
@@ -616,7 +615,7 @@
if (trns_Framed in layered) then begin
Console.WriteLine('- framed transport');
- TransportFactory := TFramedTransportImpl.TFactory.Create
+ TransportFactory := TFramedTransportImpl.TFactory.Create;
end
else begin
TransportFactory := TTransportFactoryImpl.Create;
diff --git a/lib/delphi/test/client.dpr b/lib/delphi/test/client.dpr
index 83727f6..d4875b8 100644
--- a/lib/delphi/test/client.dpr
+++ b/lib/delphi/test/client.dpr
@@ -31,6 +31,7 @@
Thrift in '..\src\Thrift.pas',
Thrift.Transport in '..\src\Thrift.Transport.pas',
Thrift.Socket in '..\src\Thrift.Socket.pas',
+ Thrift.Configuration in '..\src\Thrift.Configuration.pas',
Thrift.Exception in '..\src\Thrift.Exception.pas',
Thrift.Transport.Pipes in '..\src\Thrift.Transport.Pipes.pas',
Thrift.Transport.WinHTTP in '..\src\Thrift.Transport.WinHTTP.pas',
diff --git a/lib/delphi/test/multiplexed/Multiplex.Client.Main.pas b/lib/delphi/test/multiplexed/Multiplex.Client.Main.pas
index 35fdf6f..4b6a0a2 100644
--- a/lib/delphi/test/multiplexed/Multiplex.Client.Main.pas
+++ b/lib/delphi/test/multiplexed/Multiplex.Client.Main.pas
@@ -35,6 +35,7 @@
Thrift.Transport,
Thrift.Stream,
Thrift.Collections,
+ Thrift.Configuration,
Benchmark, // in gen-delphi folder
Aggr, // in gen-delphi folder
Multiplex.Test.Common;
@@ -93,8 +94,10 @@
procedure TTestClient.Setup;
var trans : ITransport;
+ config : IThriftConfiguration;
begin
- trans := TSocketImpl.Create( 'localhost', 9090);
+ config := TThriftConfigurationImpl.Create;
+ trans := TSocketImpl.Create( 'localhost', 9090, DEFAULT_THRIFT_TIMEOUT, config);
trans := TFramedTransportImpl.Create( trans);
trans.Open;
FProtocol := TBinaryProtocolImpl.Create( trans, TRUE, TRUE);
diff --git a/lib/delphi/test/multiplexed/Multiplex.Server.Main.pas b/lib/delphi/test/multiplexed/Multiplex.Server.Main.pas
index 3860f5a..a23ff37 100644
--- a/lib/delphi/test/multiplexed/Multiplex.Server.Main.pas
+++ b/lib/delphi/test/multiplexed/Multiplex.Server.Main.pas
@@ -35,6 +35,7 @@
Thrift.Protocol.Multiplex,
Thrift.Processor.Multiplex,
Thrift.Collections,
+ Thrift.Configuration,
Thrift.Utils,
Thrift,
Benchmark, // in gen-delphi folder
@@ -156,11 +157,14 @@
aggrProcessor : IProcessor;
multiplex : IMultiplexedProcessor;
ServerEngine : IServer;
+ config : IThriftConfiguration;
begin
try
+ config := TThriftConfigurationImpl.Create;
+
// create protocol factory, default to BinaryProtocol
ProtocolFactory := TBinaryProtocolImpl.TFactory.Create( TRUE, TRUE);
- servertrans := TServerSocketImpl.Create( 9090, 0, FALSE);
+ servertrans := TServerSocketImpl.Create( 9090, DEFAULT_THRIFT_TIMEOUT, FALSE, config);
TransportFactory := TFramedTransportImpl.TFactory.Create;
benchHandler := TBenchmarkServiceImpl.Create;
diff --git a/lib/delphi/test/multiplexed/Multiplex.Test.Client.dpr b/lib/delphi/test/multiplexed/Multiplex.Test.Client.dpr
index a57e93a..19f8f6a 100644
--- a/lib/delphi/test/multiplexed/Multiplex.Test.Client.dpr
+++ b/lib/delphi/test/multiplexed/Multiplex.Test.Client.dpr
@@ -33,6 +33,7 @@
Thrift.Protocol in '..\..\src\Thrift.Protocol.pas',
Thrift.Protocol.Multiplex in '..\..\src\Thrift.Protocol.Multiplex.pas',
Thrift.Collections in '..\..\src\Thrift.Collections.pas',
+ Thrift.Configuration in '..\..\src\Thrift.Configuration.pas',
Thrift.Server in '..\..\src\Thrift.Server.pas',
Thrift.Stream in '..\..\src\Thrift.Stream.pas',
Thrift.TypeRegistry in '..\..\src\Thrift.TypeRegistry.pas',
diff --git a/lib/delphi/test/multiplexed/Multiplex.Test.Server.dpr b/lib/delphi/test/multiplexed/Multiplex.Test.Server.dpr
index 81ed3dd..307a9c2 100644
--- a/lib/delphi/test/multiplexed/Multiplex.Test.Server.dpr
+++ b/lib/delphi/test/multiplexed/Multiplex.Test.Server.dpr
@@ -33,6 +33,7 @@
Thrift.Protocol in '..\..\src\Thrift.Protocol.pas',
Thrift.Protocol.Multiplex in '..\..\src\Thrift.Protocol.Multiplex.pas',
Thrift.Processor.Multiplex in '..\..\src\Thrift.Processor.Multiplex.pas',
+ Thrift.Configuration in '..\..\src\Thrift.Configuration.pas',
Thrift.Collections in '..\..\src\Thrift.Collections.pas',
Thrift.Server in '..\..\src\Thrift.Server.pas',
Thrift.Utils in '..\..\src\Thrift.Utils.pas',
diff --git a/lib/delphi/test/serializer/TestSerializer.Tests.pas b/lib/delphi/test/serializer/TestSerializer.Tests.pas
new file mode 100644
index 0000000..83d67b1
--- /dev/null
+++ b/lib/delphi/test/serializer/TestSerializer.Tests.pas
@@ -0,0 +1,381 @@
+unit TestSerializer.Tests;
+(*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ *)
+
+interface
+
+uses
+ Classes,
+ Windows,
+ SysUtils,
+ Generics.Collections,
+ Thrift,
+ Thrift.Exception,
+ Thrift.Socket,
+ Thrift.Transport,
+ Thrift.Protocol,
+ Thrift.Protocol.JSON,
+ Thrift.Protocol.Compact,
+ Thrift.Collections,
+ Thrift.Configuration,
+ Thrift.Server,
+ Thrift.Utils,
+ Thrift.Serializer,
+ Thrift.Stream,
+ Thrift.WinHTTP,
+ Thrift.TypeRegistry,
+ System_,
+ DebugProtoTest,
+ TestSerializer.Data;
+
+// One protocol factory plus an optional transport factory (nil = no extra transport layer).
+type
+  TFactoryPair = record
+    prot  : IProtocolFactory;
+    trans : ITransportFactory;   // may be nil, see TTestSerializer.Create
+  end;
+
+  TTestSerializer = class //extends TestCase {
+  private type
+    TMethod = (   // how the serialized data is handed to/from (De)Serialize
+      mt_Bytes,   // as a TBytes array
+      mt_Stream   // through a TStream
+    );
+
+  private
+    FProtocols : TList< TFactoryPair>;   // all factory combinations under test, filled in Create
+    procedure AddFactoryCombination( const aProto : IProtocolFactory; const aTrans : ITransportFactory);
+    class function UserFriendlyName( const factory : TFactoryPair) : string; overload;   // console label
+    class function UserFriendlyName( const method : TMethod) : string; overload;         // console label
+
+    class function Serialize(const input : IBase; const factory : TFactoryPair) : TBytes; overload;
+    class procedure Serialize(const input : IBase; const factory : TFactoryPair; const aStream : TStream); overload;
+
+    class procedure Deserialize( const input : TBytes; const target : IBase; const factory : TFactoryPair); overload;
+    class procedure Deserialize( const input : TStream; const target : IBase; const factory : TFactoryPair); overload;
+
+    class procedure ValidateReadToEnd( const input : TBytes; const serial : TDeserializer); overload;   // expects end-of-file
+    class procedure ValidateReadToEnd( const input : TStream; const serial : TDeserializer); overload;  // expects end-of-file
+
+    procedure Test_Serializer_Deserializer;   // drives all tests over TMethod x FProtocols
+    procedure Test_OneOfEach( const method : TMethod; const factory : TFactoryPair; const stream : TFileStream);
+    procedure Test_CompactStruct( const method : TMethod; const factory : TFactoryPair; const stream : TFileStream);
+
+  public
+    constructor Create;
+    destructor Destroy; override;
+
+    procedure RunTests;   // entry point; failures are reported on the console
+  end;
+
+
+implementation
+
+
+{ TTestSerializer }
+// Builds the test matrix: binary, compact and JSON protocol, each without / framed / buffered transport.
+constructor TTestSerializer.Create;
+begin
+  inherited Create;
+  FProtocols := TList< TFactoryPair>.Create;
+  // protocols on their own, no transport factory (trans = nil)
+  AddFactoryCombination( TBinaryProtocolImpl.TFactory.Create, nil);
+  AddFactoryCombination( TCompactProtocolImpl.TFactory.Create, nil);
+  AddFactoryCombination( TJSONProtocolImpl.TFactory.Create, nil);
+  // each protocol combined with the framed transport factory
+  AddFactoryCombination( TBinaryProtocolImpl.TFactory.Create, TFramedTransportImpl.TFactory.Create);
+  AddFactoryCombination( TCompactProtocolImpl.TFactory.Create, TFramedTransportImpl.TFactory.Create);
+  AddFactoryCombination( TJSONProtocolImpl.TFactory.Create, TFramedTransportImpl.TFactory.Create);
+  // each protocol combined with the buffered transport factory
+  AddFactoryCombination( TBinaryProtocolImpl.TFactory.Create, TBufferedTransportImpl.TFactory.Create);
+  AddFactoryCombination( TCompactProtocolImpl.TFactory.Create, TBufferedTransportImpl.TFactory.Create);
+  AddFactoryCombination( TJSONProtocolImpl.TFactory.Create, TBufferedTransportImpl.TFactory.Create);
+end;
+
+// Frees the factory list; the try/finally makes sure inherited Destroy runs even if that raises.
+destructor TTestSerializer.Destroy;
+begin
+  try
+    FreeAndNil( FProtocols);
+  finally
+    inherited Destroy;
+  end;
+end;
+
+// Appends one protocol/transport factory pair to the test matrix; aTrans may be nil.
+procedure TTestSerializer.AddFactoryCombination( const aProto : IProtocolFactory; const aTrans : ITransportFactory);
+var pair : TFactoryPair;
+begin
+  pair.trans := aTrans;
+  pair.prot  := aProto;
+  FProtocols.Add( pair);
+end;
+
+// Round-trips a OneOfEach fixture via the given method/factory pair and compares every field.
+procedure TTestSerializer.Test_OneOfEach( const method : TMethod; const factory : TFactoryPair; const stream : TFileStream);
+var tested, correct : IOneOfEach;
+    bytes : TBytes;
+    i : Integer;
+begin
+  // write the fixture, either into a byte array or into the truncated file stream
+  tested := Fixtures.CreateOneOfEach;
+  case method of
+    mt_Bytes:  bytes := Serialize( tested, factory);
+    mt_Stream: begin
+      stream.Size := 0;
+      Serialize( tested, factory, stream);
+    end
+  else
+    ASSERT( FALSE);
+  end;
+
+  // init + read into a fresh, empty instance
+  tested := TOneOfEachImpl.Create;
+  case method of
+    mt_Bytes:  Deserialize( bytes, tested, factory);
+    mt_Stream: begin
+      stream.Position := 0;
+      Deserialize( stream, tested, factory);
+    end
+  else
+    ASSERT( FALSE);
+  end;
+
+  // check against a freshly created fixture
+  correct := Fixtures.CreateOneOfEach;
+  ASSERT( tested.Im_true = correct.Im_true);
+  ASSERT( tested.Im_false = correct.Im_false);
+  ASSERT( tested.A_bite = correct.A_bite);
+  ASSERT( tested.Integer16 = correct.Integer16);
+  ASSERT( tested.Integer32 = correct.Integer32);
+  ASSERT( tested.Integer64 = correct.Integer64);
+  ASSERT( Abs( tested.Double_precision - correct.Double_precision) < 1E-12);
+  ASSERT( tested.Some_characters = correct.Some_characters);
+  ASSERT( tested.Zomg_unicode = correct.Zomg_unicode);
+  ASSERT( tested.What_who = correct.What_who);
+  // binary field: an empty array has no element [0], so only compare memory when there is content
+  ASSERT( Length(tested.Base64) = Length(correct.Base64));
+  ASSERT( (Length(correct.Base64) = 0) or CompareMem( @tested.Base64[0], @correct.Base64[0], Length(correct.Base64)));
+
+  ASSERT( tested.Byte_list.Count = correct.Byte_list.Count);
+  for i := 0 to tested.Byte_list.Count-1
+  do ASSERT( tested.Byte_list[i] = correct.Byte_list[i]);
+
+  ASSERT( tested.I16_list.Count = correct.I16_list.Count);
+  for i := 0 to tested.I16_list.Count-1
+  do ASSERT( tested.I16_list[i] = correct.I16_list[i]);
+
+  ASSERT( tested.I64_list.Count = correct.I64_list.Count);
+  for i := 0 to tested.I64_list.Count-1
+  do ASSERT( tested.I64_list[i] = correct.I64_list[i]);
+end;
+
+// Round-trips a CompactProtoTestStruct fixture and checks that its high field ids survive.
+procedure TTestSerializer.Test_CompactStruct( const method : TMethod; const factory : TFactoryPair; const stream : TFileStream);
+var tested, correct : ICompactProtoTestStruct;
+    bytes : TBytes;
+begin
+  // write the fixture, either into a byte array or into the truncated file stream
+  tested := Fixtures.CreateCompactProtoTestStruct;
+  case method of
+    mt_Bytes:  bytes := Serialize( tested, factory);
+    mt_Stream: begin
+      stream.Size := 0;
+      Serialize( tested, factory, stream);
+    end
+  else
+    ASSERT( FALSE);
+  end;
+
+  // init + read into a FRESH instance - reading into the populated fixture would make the checks vacuous
+  tested := TCompactProtoTestStructImpl.Create;
+  case method of
+    mt_Bytes:  Deserialize( bytes, tested, factory);
+    mt_Stream: begin
+      stream.Position := 0;
+      Deserialize( stream, tested, factory);
+    end
+  else
+    ASSERT( FALSE);
+  end;
+
+  // check against a freshly created fixture
+  correct := Fixtures.CreateCompactProtoTestStruct;
+  ASSERT( correct.Field500 = tested.Field500);
+  ASSERT( correct.Field5000 = tested.Field5000);
+  ASSERT( correct.Field20000 = tested.Field20000);
+end;
+
+// Runs both round-trip tests for every TMethod x factory pair; the first failure aborts the run.
+procedure TTestSerializer.Test_Serializer_Deserializer;
+var factory : TFactoryPair;
+    stream : TFileStream;
+    method : TMethod;
+begin
+  stream := TFileStream.Create( 'TestSerializer.dat', fmCreate);  // scratch file for mt_Stream, relative to the current dir
+  try
+    for method in [Low(TMethod)..High(TMethod)] do begin
+      Writeln( UserFriendlyName(method));
+
+      for factory in FProtocols do begin
+        Writeln('- '+UserFriendlyName(factory));
+
+        Test_OneOfEach( method, factory, stream);
+        Test_CompactStruct( method, factory, stream);
+      end;
+
+      Writeln;
+    end;
+
+  finally
+    stream.Free;
+  end;
+end;
+
+// Short console label for a factory pair, presumably e.g. 'Framed Binary' (assumes ClassName like 'TBinaryProtocolImpl.TFactory').
+class function TTestSerializer.UserFriendlyName( const factory : TFactoryPair) : string;
+begin
+  result := Copy( (factory.prot as TObject).ClassName, 2, MAXINT);   // drop the first char (the 'T' prefix)
+
+  if factory.trans <> nil
+  then result := Copy( (factory.trans as TObject).ClassName, 2, MAXINT) +' '+ result;   // transport name goes first
+
+  result := StringReplace( result, 'Impl', '', [rfReplaceAll]);
+  result := StringReplace( result, 'Transport.TFactory', '', [rfReplaceAll]);
+  result := StringReplace( result, 'Protocol.TFactory', '', [rfReplaceAll]);
+end;
+
+// Console label for a TMethod: the EnumUtils name with every 'mt_' removed (presumably 'Bytes' / 'Stream').
+class function TTestSerializer.UserFriendlyName( const method : TMethod) : string;
+begin
+  result := EnumUtils<TMethod>.ToString(Ord(method));
+  result := StringReplace( result, 'mt_', '', [rfReplaceAll]);
+end;
+
+// Runs all tests; on failure prints the exception, sets a non-zero process exit code and waits for ENTER.
+procedure TTestSerializer.RunTests;
+begin
+  try
+    Test_Serializer_Deserializer;
+  except
+    on e:Exception do begin
+      Writeln( e.ClassName+': '+ e.Message);  ExitCode := 1;  // failures must not end with exit code 0
+      Write('Hit ENTER to close ... '); Readln;
+    end;
+  end;
+end;
+
+// Serializes input into a new byte array using the given protocol/transport factory pair.
+class function TTestSerializer.Serialize(const input : IBase; const factory : TFactoryPair) : TBytes;
+var cfg    : IThriftConfiguration;
+    writer : TSerializer;
+begin
+  cfg := TThriftConfigurationImpl.Create;
+  cfg.MaxMessageSize := 0;  // not relevant here: this path only writes, nothing is read
+  // one-shot serializer for this factory pair, released right after use
+  writer := TSerializer.Create( factory.prot, factory.trans, cfg);
+  try
+    result := writer.Serialize( input);
+  finally
+    writer.Free;
+  end;
+end;
+
+// Serializes input into aStream using the given protocol/transport factory pair.
+class procedure TTestSerializer.Serialize(const input : IBase; const factory : TFactoryPair; const aStream : TStream);
+var cfg    : IThriftConfiguration;
+    writer : TSerializer;
+begin
+  cfg := TThriftConfigurationImpl.Create;
+  cfg.MaxMessageSize := 0;  // not relevant here: this path only writes, nothing is read
+  // one-shot serializer for this factory pair, released right after use
+  writer := TSerializer.Create( factory.prot, factory.trans, cfg);
+  try
+    writer.Serialize( input, aStream);
+  finally
+    writer.Free;
+  end;
+end;
+
+// Deserializes input into target, then requires that a further read on the same deserializer hits end-of-file.
+class procedure TTestSerializer.Deserialize( const input : TBytes; const target : IBase; const factory : TFactoryPair);
+var serial : TDeserializer;
+    config : IThriftConfiguration;
+begin
+  config := TThriftConfigurationImpl.Create;
+  config.MaxMessageSize := Length(input);   // max message size = exact payload size
+
+  serial := TDeserializer.Create( factory.prot, factory.trans, config);
+  try
+    serial.Deserialize( input, target);
+    ValidateReadToEnd( input, serial);   // raises unless the extra read ends in TTransportExceptionEndOfFile
+  finally
+    serial.Free;
+  end;
+end;
+
+// Deserializes the input stream into target, then requires that a further read on the same deserializer hits end-of-file.
+class procedure TTestSerializer.Deserialize( const input : TStream; const target : IBase; const factory : TFactoryPair);
+var serial : TDeserializer;
+    config : IThriftConfiguration;
+begin
+  config := TThriftConfigurationImpl.Create;
+  config.MaxMessageSize := input.Size;   // max message size = exact stream size
+
+  serial := TDeserializer.Create( factory.prot, factory.trans, config);
+  try
+    serial.Deserialize( input, target);
+    ValidateReadToEnd( input, serial);   // raises unless the extra read ends in TTransportExceptionEndOfFile
+  finally
+    serial.Free;
+  end;
+end;
+
+// Raises unless deserializing input once more with serial fails with TTransportExceptionEndOfFile.
+class procedure TTestSerializer.ValidateReadToEnd( const input : TBytes; const serial : TDeserializer);
+// we should not have any more bytes to read (presumably because the MaxMessageSize budget is used up - confirm in TDeserializer)
+var dummy : IBase;
+begin
+  try
+    dummy := TOneOfEachImpl.Create;   // throw-away target for the extra read
+    serial.Deserialize( input, dummy);
+    raise EInOutError.Create('Expected exception not thrown?');
+  except
+    on e:TTransportExceptionEndOfFile do {expected};
+    on e:Exception do raise; // unexpected - this also re-raises the EInOutError above
+  end;
+end;
+
+// Raises unless re-reading the rewound input stream with serial fails with TTransportExceptionEndOfFile.
+class procedure TTestSerializer.ValidateReadToEnd( const input : TStream; const serial : TDeserializer);
+// we should not have any more bytes to read (presumably because the MaxMessageSize budget is used up - confirm in TDeserializer)
+var dummy : IBase;
+begin
+  try
+    input.Position := 0;   // rewind; presumably so an EOF cannot come merely from an exhausted stream
+    dummy := TOneOfEachImpl.Create;   // throw-away target for the extra read
+    serial.Deserialize( input, dummy);
+    raise EInOutError.Create('Expected exception not thrown?');
+  except
+    on e:TTransportExceptionEndOfFile do {expected};
+    on e:Exception do raise; // unexpected - this also re-raises the EInOutError above
+  end;
+end;
+
+end.
diff --git a/lib/delphi/test/serializer/TestSerializer.dpr b/lib/delphi/test/serializer/TestSerializer.dpr
index 56d0d15..0620014 100644
--- a/lib/delphi/test/serializer/TestSerializer.dpr
+++ b/lib/delphi/test/serializer/TestSerializer.dpr
@@ -22,7 +22,10 @@
{$APPTYPE CONSOLE}
uses
- Classes, Windows, SysUtils, Generics.Collections,
+ Classes,
+ Windows,
+ SysUtils,
+ Generics.Collections,
Thrift in '..\..\src\Thrift.pas',
Thrift.Exception in '..\..\src\Thrift.Exception.pas',
Thrift.Socket in '..\..\src\Thrift.Socket.pas',
@@ -31,6 +34,7 @@
Thrift.Protocol.JSON in '..\..\src\Thrift.Protocol.JSON.pas',
Thrift.Protocol.Compact in '..\..\src\Thrift.Protocol.Compact.pas',
Thrift.Collections in '..\..\src\Thrift.Collections.pas',
+ Thrift.Configuration in '..\..\src\Thrift.Configuration.pas',
Thrift.Server in '..\..\src\Thrift.Server.pas',
Thrift.Utils in '..\..\src\Thrift.Utils.pas',
Thrift.Serializer in '..\..\src\Thrift.Serializer.pas',
@@ -39,236 +43,8 @@
Thrift.TypeRegistry in '..\..\src\Thrift.TypeRegistry.pas',
System_,
DebugProtoTest,
- TestSerializer.Data;
-
-
-
-type
- TTestSerializer = class //extends TestCase {
- private type
- TMethod = (
- mt_Bytes,
- mt_Stream
- );
-
- private
- FProtocols : TList< IProtocolFactory>;
-
- class function Serialize(const input : IBase; const factory : IProtocolFactory) : TBytes; overload;
- class procedure Serialize(const input : IBase; const factory : IProtocolFactory; const aStream : TStream); overload;
- class procedure Deserialize( const input : TBytes; const target : IBase; const factory : IProtocolFactory); overload;
- class procedure Deserialize( const input : TStream; const target : IBase; const factory : IProtocolFactory); overload;
-
- procedure Test_Serializer_Deserializer;
- procedure Test_OneOfEach( const method : TMethod; const factory : IProtocolFactory; const stream : TFileStream);
- procedure Test_CompactStruct( const method : TMethod; const factory : IProtocolFactory; const stream : TFileStream);
-
- public
- constructor Create;
- destructor Destroy; override;
-
- procedure RunTests;
- end;
-
-
-
-{ TTestSerializer }
-
-constructor TTestSerializer.Create;
-begin
- inherited Create;
- FProtocols := TList< IProtocolFactory>.Create;
- FProtocols.Add( TBinaryProtocolImpl.TFactory.Create);
- FProtocols.Add( TCompactProtocolImpl.TFactory.Create);
- FProtocols.Add( TJSONProtocolImpl.TFactory.Create);
-end;
-
-
-destructor TTestSerializer.Destroy;
-begin
- try
- FreeAndNil( FProtocols);
- finally
- inherited Destroy;
- end;
-end;
-
-
-procedure TTestSerializer.Test_OneOfEach( const method : TMethod; const factory : IProtocolFactory; const stream : TFileStream);
-var tested, correct : IOneOfEach;
- bytes : TBytes;
- i : Integer;
-begin
- // write
- tested := Fixtures.CreateOneOfEach;
- case method of
- mt_Bytes: bytes := Serialize( tested, factory);
- mt_Stream: begin
- stream.Size := 0;
- Serialize( tested, factory, stream);
- end
- else
- ASSERT( FALSE);
- end;
-
- // init + read
- tested := TOneOfEachImpl.Create;
- case method of
- mt_Bytes: Deserialize( bytes, tested, factory);
- mt_Stream: begin
- stream.Position := 0;
- Deserialize( stream, tested, factory);
- end
- else
- ASSERT( FALSE);
- end;
-
- // check
- correct := Fixtures.CreateOneOfEach;
- ASSERT( tested.Im_true = correct.Im_true);
- ASSERT( tested.Im_false = correct.Im_false);
- ASSERT( tested.A_bite = correct.A_bite);
- ASSERT( tested.Integer16 = correct.Integer16);
- ASSERT( tested.Integer32 = correct.Integer32);
- ASSERT( tested.Integer64 = correct.Integer64);
- ASSERT( Abs( tested.Double_precision - correct.Double_precision) < 1E-12);
- ASSERT( tested.Some_characters = correct.Some_characters);
- ASSERT( tested.Zomg_unicode = correct.Zomg_unicode);
- ASSERT( tested.What_who = correct.What_who);
-
- ASSERT( Length(tested.Base64) = Length(correct.Base64));
- ASSERT( CompareMem( @tested.Base64[0], @correct.Base64[0], Length(correct.Base64)));
-
- ASSERT( tested.Byte_list.Count = correct.Byte_list.Count);
- for i := 0 to tested.Byte_list.Count-1
- do ASSERT( tested.Byte_list[i] = correct.Byte_list[i]);
-
- ASSERT( tested.I16_list.Count = correct.I16_list.Count);
- for i := 0 to tested.I16_list.Count-1
- do ASSERT( tested.I16_list[i] = correct.I16_list[i]);
-
- ASSERT( tested.I64_list.Count = correct.I64_list.Count);
- for i := 0 to tested.I64_list.Count-1
- do ASSERT( tested.I64_list[i] = correct.I64_list[i]);
-end;
-
-
-procedure TTestSerializer.Test_CompactStruct( const method : TMethod; const factory : IProtocolFactory; const stream : TFileStream);
-var tested, correct : ICompactProtoTestStruct;
- bytes : TBytes;
-begin
- // write
- tested := Fixtures.CreateCompactProtoTestStruct;
- case method of
- mt_Bytes: bytes := Serialize( tested, factory);
- mt_Stream: begin
- stream.Size := 0;
- Serialize( tested, factory, stream);
- end
- else
- ASSERT( FALSE);
- end;
-
- // init + read
- correct := TCompactProtoTestStructImpl.Create;
- case method of
- mt_Bytes: Deserialize( bytes, tested, factory);
- mt_Stream: begin
- stream.Position := 0;
- Deserialize( stream, tested, factory);
- end
- else
- ASSERT( FALSE);
- end;
-
- // check
- correct := Fixtures.CreateCompactProtoTestStruct;
- ASSERT( correct.Field500 = tested.Field500);
- ASSERT( correct.Field5000 = tested.Field5000);
- ASSERT( correct.Field20000 = tested.Field20000);
-end;
-
-
-procedure TTestSerializer.Test_Serializer_Deserializer;
-var factory : IProtocolFactory;
- stream : TFileStream;
- method : TMethod;
-begin
- stream := TFileStream.Create( 'TestSerializer.dat', fmCreate);
- try
-
- for method in [Low(TMethod)..High(TMethod)] do begin
- for factory in FProtocols do begin
-
- Test_OneOfEach( method, factory, stream);
- Test_CompactStruct( method, factory, stream);
- end;
- end;
-
- finally
- stream.Free;
- end;
-end;
-
-
-procedure TTestSerializer.RunTests;
-begin
- try
- Test_Serializer_Deserializer;
- except
- on e:Exception do begin
- Writeln( e.Message);
- Write('Hit ENTER to close ... '); Readln;
- end;
- end;
-end;
-
-
-class function TTestSerializer.Serialize(const input : IBase; const factory : IProtocolFactory) : TBytes;
-var serial : TSerializer;
-begin
- serial := TSerializer.Create( factory);
- try
- result := serial.Serialize( input);
- finally
- serial.Free;
- end;
-end;
-
-
-class procedure TTestSerializer.Serialize(const input : IBase; const factory : IProtocolFactory; const aStream : TStream);
-var serial : TSerializer;
-begin
- serial := TSerializer.Create( factory);
- try
- serial.Serialize( input, aStream);
- finally
- serial.Free;
- end;
-end;
-
-
-class procedure TTestSerializer.Deserialize( const input : TBytes; const target : IBase; const factory : IProtocolFactory);
-var serial : TDeserializer;
-begin
- serial := TDeserializer.Create( factory);
- try
- serial.Deserialize( input, target);
- finally
- serial.Free;
- end;
-end;
-
-class procedure TTestSerializer.Deserialize( const input : TStream; const target : IBase; const factory : IProtocolFactory);
-var serial : TDeserializer;
-begin
- serial := TDeserializer.Create( factory);
- try
- serial.Deserialize( input, target);
- finally
- serial.Free;
- end;
-end;
+ TestSerializer.Tests in 'TestSerializer.Tests.pas',
+ TestSerializer.Data in 'TestSerializer.Data.pas';
var test : TTestSerializer;
diff --git a/lib/delphi/test/server.dpr b/lib/delphi/test/server.dpr
index 9731dd4..954d0b6 100644
--- a/lib/delphi/test/server.dpr
+++ b/lib/delphi/test/server.dpr
@@ -37,6 +37,7 @@
Thrift.Protocol.Multiplex in '..\src\Thrift.Protocol.Multiplex.pas',
Thrift.Processor.Multiplex in '..\src\Thrift.Processor.Multiplex.pas',
Thrift.Collections in '..\src\Thrift.Collections.pas',
+ Thrift.Configuration in '..\src\Thrift.Configuration.pas',
Thrift.Server in '..\src\Thrift.Server.pas',
Thrift.TypeRegistry in '..\src\Thrift.TypeRegistry.pas',
Thrift.Utils in '..\src\Thrift.Utils.pas',
diff --git a/lib/delphi/test/skip/idl/skiptest_version_1.thrift b/lib/delphi/test/skip/idl/skiptest_version_1.thrift
index 8353c5e..4221177 100644
--- a/lib/delphi/test/skip/idl/skiptest_version_1.thrift
+++ b/lib/delphi/test/skip/idl/skiptest_version_1.thrift
@@ -24,12 +24,14 @@
const i32 SKIPTESTSERVICE_VERSION = 1
-struct Pong {
- 1 : optional i32 version1
+enum PingPongEnum {
+ PingOne = 0,
+ PongOne = 1,
}
struct Ping {
1 : optional i32 version1
+ 100 : PingPongEnum EnumTest
}
exception PongFailed {
@@ -38,7 +40,7 @@
service SkipTestService {
- void PingPong( 1: Ping pong) throws (444: PongFailed pof);
+ Ping PingPong( 1: Ping ping) throws (444: PongFailed pof);
}
diff --git a/lib/delphi/test/skip/idl/skiptest_version_2.thrift b/lib/delphi/test/skip/idl/skiptest_version_2.thrift
index f3352d3..3ea69f7 100644
--- a/lib/delphi/test/skip/idl/skiptest_version_2.thrift
+++ b/lib/delphi/test/skip/idl/skiptest_version_2.thrift
@@ -24,9 +24,17 @@
const i32 SKIPTESTSERVICE_VERSION = 2
+enum PingPongEnum {
+ PingOne = 0,
+ PongOne = 1,
+ PingTwo = 2,
+ PongTwo = 3,
+}
+
struct Pong {
1 : optional i32 version1
2 : optional i16 version2
+ 100 : PingPongEnum EnumTest
}
struct Ping {
@@ -40,6 +48,7 @@
16 : optional string strVal
17 : optional Pong structVal
18 : optional map< list< Pong>, set< string>> mapVal
+ 100 : PingPongEnum EnumTest
}
exception PingFailed {
diff --git a/lib/delphi/test/skip/skiptest_version1.dpr b/lib/delphi/test/skip/skiptest_version1.dpr
index 0bfe96f..f7cde2f 100644
--- a/lib/delphi/test/skip/skiptest_version1.dpr
+++ b/lib/delphi/test/skip/skiptest_version1.dpr
@@ -30,7 +30,9 @@
Thrift.Transport in '..\..\src\Thrift.Transport.pas',
Thrift.Protocol in '..\..\src\Thrift.Protocol.pas',
Thrift.Protocol.JSON in '..\..\src\Thrift.Protocol.JSON.pas',
+ Thrift.Protocol.Compact in '..\..\src\Thrift.Protocol.Compact.pas',
Thrift.Collections in '..\..\src\Thrift.Collections.pas',
+ Thrift.Configuration in '..\..\src\Thrift.Configuration.pas',
Thrift.Server in '..\..\src\Thrift.Server.pas',
Thrift.Utils in '..\..\src\Thrift.Utils.pas',
Thrift.WinHTTP in '..\..\src\Thrift.WinHTTP.pas',
@@ -46,6 +48,7 @@
begin
result := TPingImpl.Create;
result.Version1 := Tskiptest_version_1Constants.SKIPTESTSERVICE_VERSION;
+ result.EnumTest := TPingPongEnum.PingOne;
end;
@@ -53,14 +56,16 @@
TDummyServer = class( TInterfacedObject, TSkipTestService.Iface)
protected
// TSkipTestService.Iface
- procedure PingPong(const ping: IPing);
+ function PingPong(const ping: IPing): IPing;
end;
-procedure TDummyServer.PingPong(const ping: IPing);
+function TDummyServer.PingPong(const ping: IPing): IPing;
// TSkipTestService.Iface
begin
Writeln('- performing request from version '+IntToStr(ping.Version1)+' client');
+ Writeln( ping.ToString);
+ result := CreatePing;
end;
@@ -70,8 +75,8 @@
begin
adapt := TThriftStreamAdapterDelphi.Create( stm, FALSE);
if aForInput
- then trans := TStreamTransportImpl.Create( adapt, nil)
- else trans := TStreamTransportImpl.Create( nil, adapt);
+ then trans := TStreamTransportImpl.Create( adapt, nil, TThriftConfigurationImpl.Create)
+ else trans := TStreamTransportImpl.Create( nil, adapt, TThriftConfigurationImpl.Create);
result := protfact.GetProtocol( trans);
end;
@@ -108,6 +113,7 @@
procedure ReadResponse( protfact : IProtocolFactory; fname : string);
var stm : TFileStream;
+ ping : IPing;
proto : IProtocol;
client : TSkipTestService.TClient; // we need access to send/recv_pingpong()
cliRef : IUnknown; // holds the refcount
@@ -115,11 +121,11 @@
Writeln('- reading response');
stm := TFileStream.Create( fname+RESPONSE_EXT, fmOpenRead);
try
- // save request data
+ // load request data
proto := CreateProtocol( protfact, stm, TRUE);
client := TSkipTestService.TClient.Create( proto, nil);
cliRef := client as IUnknown;
- client.recv_PingPong;
+ ping := client.recv_PingPong;
finally
client := nil; // not Free!
@@ -163,12 +169,14 @@
procedure Test( protfact : IProtocolFactory; fname : string);
begin
// try to read an existing request
+ Writeln('Reading data file '+fname);
if FileExists( fname + REQUEST_EXT) then begin
ProcessFile( protfact, fname);
ReadResponse( protfact, fname);
end;
// create a new request and try to process
+ Writeln('Writing data file '+fname);
CreateRequest( protfact, fname);
ProcessFile( protfact, fname);
ReadResponse( protfact, fname);
@@ -176,8 +184,9 @@
const
- FILE_BINARY = 'pingpong.bin';
- FILE_JSON = 'pingpong.json';
+ FILE_BINARY = 'pingpong.bin';
+ FILE_JSON = 'pingpong.json';
+ FILE_COMPACT = 'pingpong.compact';
begin
try
Writeln( 'Delphi SkipTest '+IntToStr(Tskiptest_version_1Constants.SKIPTESTSERVICE_VERSION)+' using '+Thrift.Version);
@@ -191,6 +200,10 @@
Test( TJSONProtocolImpl.TFactory.Create, FILE_JSON);
Writeln;
+ Writeln('Compact protocol');
+ Test( TCompactProtocolImpl.TFactory.Create, FILE_COMPACT);
+
+ Writeln;
Writeln('Test completed without errors.');
Writeln;
Write('Press ENTER to close ...'); Readln;
diff --git a/lib/delphi/test/skip/skiptest_version2.dpr b/lib/delphi/test/skip/skiptest_version2.dpr
index 7893748..478ea7c 100644
--- a/lib/delphi/test/skip/skiptest_version2.dpr
+++ b/lib/delphi/test/skip/skiptest_version2.dpr
@@ -30,7 +30,9 @@
Thrift.Transport in '..\..\src\Thrift.Transport.pas',
Thrift.Protocol in '..\..\src\Thrift.Protocol.pas',
Thrift.Protocol.JSON in '..\..\src\Thrift.Protocol.JSON.pas',
+ Thrift.Protocol.Compact in '..\..\src\Thrift.Protocol.Compact.pas',
Thrift.Collections in '..\..\src\Thrift.Collections.pas',
+ Thrift.Configuration in '..\..\src\Thrift.Configuration.pas',
Thrift.Server in '..\..\src\Thrift.Server.pas',
Thrift.Utils in '..\..\src\Thrift.Utils.pas',
Thrift.WinHTTP in '..\..\src\Thrift.WinHTTP.pas',
@@ -41,12 +43,15 @@
REQUEST_EXT = '.request';
RESPONSE_EXT = '.response';
+
function CreatePing : IPing;
var list : IThriftList<IPong>;
set_ : IHashSet<string>;
begin
result := TPingImpl.Create;
result.Version1 := Tskiptest_version_2Constants.SKIPTESTSERVICE_VERSION;
+ result.EnumTest := TPingPongEnum.PingTwo;
+
result.BoolVal := TRUE;
result.ByteVal := 2;
result.DbVal := 3;
@@ -58,6 +63,7 @@
result.StructVal := TPongImpl.Create;
result.StructVal.Version1 := -1;
result.StructVal.Version2 := -2;
+ result.StructVal.EnumTest := TPingPongEnum.PongTwo;
list := TThriftListImpl<IPong>.Create;
list.Add( result.StructVal);
@@ -86,6 +92,7 @@
// TSkipTestService.Iface
begin
Writeln('- performing request from version '+IntToStr(ping.Version1)+' client');
+ Writeln( ping.ToString);
result := CreatePing;
end;
@@ -96,8 +103,8 @@
begin
adapt := TThriftStreamAdapterDelphi.Create( stm, FALSE);
if aForInput
- then trans := TStreamTransportImpl.Create( adapt, nil)
- else trans := TStreamTransportImpl.Create( nil, adapt);
+ then trans := TStreamTransportImpl.Create( adapt, nil, TThriftConfigurationImpl.Create)
+ else trans := TStreamTransportImpl.Create( nil, adapt, TThriftConfigurationImpl.Create);
result := protfact.GetProtocol( trans);
end;
@@ -142,7 +149,7 @@
Writeln('- reading response');
stm := TFileStream.Create( fname+RESPONSE_EXT, fmOpenRead);
try
- // save request data
+ // load request data
proto := CreateProtocol( protfact, stm, TRUE);
client := TSkipTestService.TClient.Create( proto, nil);
cliRef := client as IUnknown;
@@ -190,12 +197,16 @@
procedure Test( protfact : IProtocolFactory; fname : string);
begin
// try to read an existing request
+ Writeln;
+ Writeln('Reading data file '+fname);
if FileExists( fname + REQUEST_EXT) then begin
ProcessFile( protfact, fname);
ReadResponse( protfact, fname);
end;
// create a new request and try to process
+ Writeln;
+ Writeln('Writing data file '+fname);
CreateRequest( protfact, fname);
ProcessFile( protfact, fname);
ReadResponse( protfact, fname);
@@ -203,8 +214,9 @@
const
- FILE_BINARY = 'pingpong.bin';
- FILE_JSON = 'pingpong.json';
+ FILE_BINARY = 'pingpong.bin';
+ FILE_JSON = 'pingpong.json';
+ FILE_COMPACT = 'pingpong.compact';
begin
try
Writeln( 'Delphi SkipTest '+IntToStr(Tskiptest_version_2Constants.SKIPTESTSERVICE_VERSION)+' using '+Thrift.Version);
@@ -218,6 +230,10 @@
Test( TJSONProtocolImpl.TFactory.Create, FILE_JSON);
Writeln;
+ Writeln('Compact protocol');
+ Test( TCompactProtocolImpl.TFactory.Create, FILE_COMPACT);
+
+ Writeln;
Writeln('Test completed without errors.');
Writeln;
Write('Press ENTER to close ...'); Readln;
diff --git a/lib/delphi/test/typeregistry/TestTypeRegistry.dpr b/lib/delphi/test/typeregistry/TestTypeRegistry.dpr
index fd5e3dd..31c0fb2 100644
--- a/lib/delphi/test/typeregistry/TestTypeRegistry.dpr
+++ b/lib/delphi/test/typeregistry/TestTypeRegistry.dpr
@@ -30,6 +30,7 @@
Thrift.Protocol in '..\..\src\Thrift.Protocol.pas',
Thrift.Protocol.JSON in '..\..\src\Thrift.Protocol.JSON.pas',
Thrift.Collections in '..\..\src\Thrift.Collections.pas',
+ Thrift.Configuration in '..\..\src\Thrift.Configuration.pas',
Thrift.Server in '..\..\src\Thrift.Server.pas',
Thrift.Utils in '..\..\src\Thrift.Utils.pas',
Thrift.Serializer in '..\..\src\Thrift.Serializer.pas',
diff --git a/lib/go/test/DuplicateImportsTest.thrift b/lib/go/test/DuplicateImportsTest.thrift
new file mode 100644
index 0000000..ffe1caf
--- /dev/null
+++ b/lib/go/test/DuplicateImportsTest.thrift
@@ -0,0 +1,5 @@
+include "common/a.thrift"
+include "common/b.thrift"
+
+typedef a.A A
+typedef b.B B
diff --git a/lib/go/test/Makefile.am b/lib/go/test/Makefile.am
index 244ddff..f5de26e 100644
--- a/lib/go/test/Makefile.am
+++ b/lib/go/test/Makefile.am
@@ -45,7 +45,8 @@
ConflictNamespaceTestC.thrift \
ConflictNamespaceTestD.thrift \
ConflictNamespaceTestSuperThing.thrift \
- ConflictNamespaceServiceTest.thrift
+ ConflictNamespaceServiceTest.thrift \
+ DuplicateImportsTest.thrift
mkdir -p gopath/src
grep -v list.*map.*list.*map $(THRIFTTEST) | grep -v 'set<Insanity>' > ThriftTest.thrift
$(THRIFT) $(THRIFTARGS) -r IncludesTest.thrift
@@ -71,6 +72,7 @@
$(THRIFT) $(THRIFTARGS) ConflictNamespaceTestD.thrift
$(THRIFT) $(THRIFTARGS) ConflictNamespaceTestSuperThing.thrift
$(THRIFT) $(THRIFTARGS) ConflictNamespaceServiceTest.thrift
+ $(THRIFT) $(THRIFTARGS) -r DuplicateImportsTest.thrift
GOPATH=`pwd`/gopath $(GO) get github.com/golang/mock/gomock || true
sed -i 's/\"context\"/\"golang.org\/x\/net\/context\"/g' gopath/src/github.com/golang/mock/gomock/controller.go || true
GOPATH=`pwd`/gopath $(GO) get github.com/golang/mock/gomock
@@ -93,7 +95,9 @@
ignoreinitialismstest \
unionbinarytest \
conflictnamespacetestsuperthing \
- conflict/context/conflict_service-remote
+ conflict/context/conflict_service-remote \
+ servicestest/container_test-remote \
+ duplicateimportstest
GOPATH=`pwd`/gopath $(GO) test thrift tests dontexportrwtest
clean-local:
@@ -127,5 +131,5 @@
ConflictNamespaceTestB.thrift \
ConflictNamespaceTestC.thrift \
ConflictNamespaceTestD.thrift \
- ConflictNamespaceTestSuperThing.thrift
+ ConflictNamespaceTestSuperThing.thrift \
ConflictNamespaceServiceTest.thrift
diff --git a/lib/go/test/ServicesTest.thrift b/lib/go/test/ServicesTest.thrift
index 882b03a..666197f 100644
--- a/lib/go/test/ServicesTest.thrift
+++ b/lib/go/test/ServicesTest.thrift
@@ -107,5 +107,12 @@
struct_a struct_a_func_2ex_1int_1s(1: i64 i, 2: string s) throws(1: moderate_disaster err1, 2:total_disaster err2)
struct_a struct_a_func_1struct_a(1: struct_a st)
+}
+service container_test_parent {
+ void parent_only_func(1: set<i32> s)
+}
+
+service container_test extends container_test_parent {
+ void child_only_func(1: set<i32> s)
}
diff --git a/lib/go/test/common/a.thrift b/lib/go/test/common/a.thrift
new file mode 100644
index 0000000..37e0e1c
--- /dev/null
+++ b/lib/go/test/common/a.thrift
@@ -0,0 +1,5 @@
+namespace go common
+
+struct A {
+ 1: optional string a
+}
diff --git a/lib/go/test/common/b.thrift b/lib/go/test/common/b.thrift
new file mode 100644
index 0000000..19930e7
--- /dev/null
+++ b/lib/go/test/common/b.thrift
@@ -0,0 +1,5 @@
+namespace go common
+
+struct B {
+ 1: optional string b
+}
diff --git a/lib/go/test/tests/thrifttest_driver.go b/lib/go/test/tests/thrifttest_driver.go
index de54cbc..4fc5baa 100644
--- a/lib/go/test/tests/thrifttest_driver.go
+++ b/lib/go/test/tests/thrifttest_driver.go
@@ -213,24 +213,25 @@
2: {thrifttest.Numberz_SIX: crazyEmpty},
}
if r, err := client.TestInsanity(defaultCtx, crazy); !reflect.DeepEqual(r, insanity) || err != nil {
- t.Fatal("TestInsanity failed")
+ t.Fatal("TestInsanity failed:", err)
}
if err := client.TestException(defaultCtx, "TException"); err == nil {
- t.Fatal("TestException TException failed")
+ t.Fatal("TestException TException failed:", err)
}
- if err, ok := client.TestException(defaultCtx, "Xception").(*thrifttest.Xception); ok == false || err == nil {
- t.Fatal("TestException Xception failed")
- } else if err.ErrorCode != 1001 || err.Message != "Xception" {
- t.Fatal("TestException Xception failed")
+ err := client.TestException(defaultCtx, "Xception")
+ if e, ok := err.(*thrifttest.Xception); ok == false || e == nil {
+ t.Fatal("TestException Xception failed:", err)
+ } else if e.ErrorCode != 1001 || e.Message != "Xception" {
+ t.Fatal("TestException Xception failed:", e)
}
if err := client.TestException(defaultCtx, "no Exception"); err != nil {
- t.Fatal("TestException no Exception failed")
+ t.Fatal("TestException no Exception failed:", err)
}
if err := client.TestOneway(defaultCtx, 0); err != nil {
- t.Fatal("TestOneway failed")
+ t.Fatal("TestOneway failed:", err)
}
}
diff --git a/lib/go/thrift/debug_protocol.go b/lib/go/thrift/debug_protocol.go
index 57943e0..c33fba8 100644
--- a/lib/go/thrift/debug_protocol.go
+++ b/lib/go/thrift/debug_protocol.go
@@ -21,23 +21,40 @@
import (
"context"
- "log"
+ "fmt"
)
type TDebugProtocol struct {
Delegate TProtocol
LogPrefix string
+ Logger Logger
}
type TDebugProtocolFactory struct {
Underlying TProtocolFactory
LogPrefix string
+ Logger Logger
}
+// NewTDebugProtocolFactory creates a TDebugProtocolFactory.
+//
+// Deprecated: Please use NewTDebugProtocolFactoryWithLogger or the struct
+// itself instead. This version will use the default logger from standard
+// library.
func NewTDebugProtocolFactory(underlying TProtocolFactory, logPrefix string) *TDebugProtocolFactory {
return &TDebugProtocolFactory{
Underlying: underlying,
LogPrefix: logPrefix,
+ Logger: StdLogger(nil),
+ }
+}
+
+// NewTDebugProtocolFactoryWithLogger creates a TDebugProtocolFactory.
+func NewTDebugProtocolFactoryWithLogger(underlying TProtocolFactory, logPrefix string, logger Logger) *TDebugProtocolFactory {
+ return &TDebugProtocolFactory{
+ Underlying: underlying,
+ LogPrefix: logPrefix,
+ Logger: logger,
}
}
@@ -45,223 +62,228 @@
return &TDebugProtocol{
Delegate: t.Underlying.GetProtocol(trans),
LogPrefix: t.LogPrefix,
+ Logger: fallbackLogger(t.Logger),
}
}
+func (tdp *TDebugProtocol) logf(format string, v ...interface{}) {
+ fallbackLogger(tdp.Logger)(fmt.Sprintf(format, v...))
+}
+
func (tdp *TDebugProtocol) WriteMessageBegin(name string, typeId TMessageType, seqid int32) error {
err := tdp.Delegate.WriteMessageBegin(name, typeId, seqid)
- log.Printf("%sWriteMessageBegin(name=%#v, typeId=%#v, seqid=%#v) => %#v", tdp.LogPrefix, name, typeId, seqid, err)
+ tdp.logf("%sWriteMessageBegin(name=%#v, typeId=%#v, seqid=%#v) => %#v", tdp.LogPrefix, name, typeId, seqid, err)
return err
}
func (tdp *TDebugProtocol) WriteMessageEnd() error {
err := tdp.Delegate.WriteMessageEnd()
- log.Printf("%sWriteMessageEnd() => %#v", tdp.LogPrefix, err)
+ tdp.logf("%sWriteMessageEnd() => %#v", tdp.LogPrefix, err)
return err
}
func (tdp *TDebugProtocol) WriteStructBegin(name string) error {
err := tdp.Delegate.WriteStructBegin(name)
- log.Printf("%sWriteStructBegin(name=%#v) => %#v", tdp.LogPrefix, name, err)
+ tdp.logf("%sWriteStructBegin(name=%#v) => %#v", tdp.LogPrefix, name, err)
return err
}
func (tdp *TDebugProtocol) WriteStructEnd() error {
err := tdp.Delegate.WriteStructEnd()
- log.Printf("%sWriteStructEnd() => %#v", tdp.LogPrefix, err)
+ tdp.logf("%sWriteStructEnd() => %#v", tdp.LogPrefix, err)
return err
}
func (tdp *TDebugProtocol) WriteFieldBegin(name string, typeId TType, id int16) error {
err := tdp.Delegate.WriteFieldBegin(name, typeId, id)
- log.Printf("%sWriteFieldBegin(name=%#v, typeId=%#v, id%#v) => %#v", tdp.LogPrefix, name, typeId, id, err)
+ tdp.logf("%sWriteFieldBegin(name=%#v, typeId=%#v, id=%#v) => %#v", tdp.LogPrefix, name, typeId, id, err)
return err
}
func (tdp *TDebugProtocol) WriteFieldEnd() error {
err := tdp.Delegate.WriteFieldEnd()
- log.Printf("%sWriteFieldEnd() => %#v", tdp.LogPrefix, err)
+ tdp.logf("%sWriteFieldEnd() => %#v", tdp.LogPrefix, err)
return err
}
func (tdp *TDebugProtocol) WriteFieldStop() error {
err := tdp.Delegate.WriteFieldStop()
- log.Printf("%sWriteFieldStop() => %#v", tdp.LogPrefix, err)
+ tdp.logf("%sWriteFieldStop() => %#v", tdp.LogPrefix, err)
return err
}
func (tdp *TDebugProtocol) WriteMapBegin(keyType TType, valueType TType, size int) error {
err := tdp.Delegate.WriteMapBegin(keyType, valueType, size)
- log.Printf("%sWriteMapBegin(keyType=%#v, valueType=%#v, size=%#v) => %#v", tdp.LogPrefix, keyType, valueType, size, err)
+ tdp.logf("%sWriteMapBegin(keyType=%#v, valueType=%#v, size=%#v) => %#v", tdp.LogPrefix, keyType, valueType, size, err)
return err
}
func (tdp *TDebugProtocol) WriteMapEnd() error {
err := tdp.Delegate.WriteMapEnd()
- log.Printf("%sWriteMapEnd() => %#v", tdp.LogPrefix, err)
+ tdp.logf("%sWriteMapEnd() => %#v", tdp.LogPrefix, err)
return err
}
func (tdp *TDebugProtocol) WriteListBegin(elemType TType, size int) error {
err := tdp.Delegate.WriteListBegin(elemType, size)
- log.Printf("%sWriteListBegin(elemType=%#v, size=%#v) => %#v", tdp.LogPrefix, elemType, size, err)
+ tdp.logf("%sWriteListBegin(elemType=%#v, size=%#v) => %#v", tdp.LogPrefix, elemType, size, err)
return err
}
func (tdp *TDebugProtocol) WriteListEnd() error {
err := tdp.Delegate.WriteListEnd()
- log.Printf("%sWriteListEnd() => %#v", tdp.LogPrefix, err)
+ tdp.logf("%sWriteListEnd() => %#v", tdp.LogPrefix, err)
return err
}
func (tdp *TDebugProtocol) WriteSetBegin(elemType TType, size int) error {
err := tdp.Delegate.WriteSetBegin(elemType, size)
- log.Printf("%sWriteSetBegin(elemType=%#v, size=%#v) => %#v", tdp.LogPrefix, elemType, size, err)
+ tdp.logf("%sWriteSetBegin(elemType=%#v, size=%#v) => %#v", tdp.LogPrefix, elemType, size, err)
return err
}
func (tdp *TDebugProtocol) WriteSetEnd() error {
err := tdp.Delegate.WriteSetEnd()
- log.Printf("%sWriteSetEnd() => %#v", tdp.LogPrefix, err)
+ tdp.logf("%sWriteSetEnd() => %#v", tdp.LogPrefix, err)
return err
}
func (tdp *TDebugProtocol) WriteBool(value bool) error {
err := tdp.Delegate.WriteBool(value)
- log.Printf("%sWriteBool(value=%#v) => %#v", tdp.LogPrefix, value, err)
+ tdp.logf("%sWriteBool(value=%#v) => %#v", tdp.LogPrefix, value, err)
return err
}
func (tdp *TDebugProtocol) WriteByte(value int8) error {
err := tdp.Delegate.WriteByte(value)
- log.Printf("%sWriteByte(value=%#v) => %#v", tdp.LogPrefix, value, err)
+ tdp.logf("%sWriteByte(value=%#v) => %#v", tdp.LogPrefix, value, err)
return err
}
func (tdp *TDebugProtocol) WriteI16(value int16) error {
err := tdp.Delegate.WriteI16(value)
- log.Printf("%sWriteI16(value=%#v) => %#v", tdp.LogPrefix, value, err)
+ tdp.logf("%sWriteI16(value=%#v) => %#v", tdp.LogPrefix, value, err)
return err
}
func (tdp *TDebugProtocol) WriteI32(value int32) error {
err := tdp.Delegate.WriteI32(value)
- log.Printf("%sWriteI32(value=%#v) => %#v", tdp.LogPrefix, value, err)
+ tdp.logf("%sWriteI32(value=%#v) => %#v", tdp.LogPrefix, value, err)
return err
}
func (tdp *TDebugProtocol) WriteI64(value int64) error {
err := tdp.Delegate.WriteI64(value)
- log.Printf("%sWriteI64(value=%#v) => %#v", tdp.LogPrefix, value, err)
+ tdp.logf("%sWriteI64(value=%#v) => %#v", tdp.LogPrefix, value, err)
return err
}
func (tdp *TDebugProtocol) WriteDouble(value float64) error {
err := tdp.Delegate.WriteDouble(value)
- log.Printf("%sWriteDouble(value=%#v) => %#v", tdp.LogPrefix, value, err)
+ tdp.logf("%sWriteDouble(value=%#v) => %#v", tdp.LogPrefix, value, err)
return err
}
func (tdp *TDebugProtocol) WriteString(value string) error {
err := tdp.Delegate.WriteString(value)
- log.Printf("%sWriteString(value=%#v) => %#v", tdp.LogPrefix, value, err)
+ tdp.logf("%sWriteString(value=%#v) => %#v", tdp.LogPrefix, value, err)
return err
}
func (tdp *TDebugProtocol) WriteBinary(value []byte) error {
err := tdp.Delegate.WriteBinary(value)
- log.Printf("%sWriteBinary(value=%#v) => %#v", tdp.LogPrefix, value, err)
+ tdp.logf("%sWriteBinary(value=%#v) => %#v", tdp.LogPrefix, value, err)
return err
}
func (tdp *TDebugProtocol) ReadMessageBegin() (name string, typeId TMessageType, seqid int32, err error) {
name, typeId, seqid, err = tdp.Delegate.ReadMessageBegin()
- log.Printf("%sReadMessageBegin() (name=%#v, typeId=%#v, seqid=%#v, err=%#v)", tdp.LogPrefix, name, typeId, seqid, err)
+ tdp.logf("%sReadMessageBegin() (name=%#v, typeId=%#v, seqid=%#v, err=%#v)", tdp.LogPrefix, name, typeId, seqid, err)
return
}
func (tdp *TDebugProtocol) ReadMessageEnd() (err error) {
err = tdp.Delegate.ReadMessageEnd()
- log.Printf("%sReadMessageEnd() err=%#v", tdp.LogPrefix, err)
+ tdp.logf("%sReadMessageEnd() err=%#v", tdp.LogPrefix, err)
return
}
func (tdp *TDebugProtocol) ReadStructBegin() (name string, err error) {
name, err = tdp.Delegate.ReadStructBegin()
- log.Printf("%sReadStructBegin() (name%#v, err=%#v)", tdp.LogPrefix, name, err)
+ tdp.logf("%sReadStructBegin() (name=%#v, err=%#v)", tdp.LogPrefix, name, err)
return
}
func (tdp *TDebugProtocol) ReadStructEnd() (err error) {
err = tdp.Delegate.ReadStructEnd()
- log.Printf("%sReadStructEnd() err=%#v", tdp.LogPrefix, err)
+ tdp.logf("%sReadStructEnd() err=%#v", tdp.LogPrefix, err)
return
}
func (tdp *TDebugProtocol) ReadFieldBegin() (name string, typeId TType, id int16, err error) {
name, typeId, id, err = tdp.Delegate.ReadFieldBegin()
- log.Printf("%sReadFieldBegin() (name=%#v, typeId=%#v, id=%#v, err=%#v)", tdp.LogPrefix, name, typeId, id, err)
+ tdp.logf("%sReadFieldBegin() (name=%#v, typeId=%#v, id=%#v, err=%#v)", tdp.LogPrefix, name, typeId, id, err)
return
}
func (tdp *TDebugProtocol) ReadFieldEnd() (err error) {
err = tdp.Delegate.ReadFieldEnd()
- log.Printf("%sReadFieldEnd() err=%#v", tdp.LogPrefix, err)
+ tdp.logf("%sReadFieldEnd() err=%#v", tdp.LogPrefix, err)
return
}
func (tdp *TDebugProtocol) ReadMapBegin() (keyType TType, valueType TType, size int, err error) {
keyType, valueType, size, err = tdp.Delegate.ReadMapBegin()
- log.Printf("%sReadMapBegin() (keyType=%#v, valueType=%#v, size=%#v, err=%#v)", tdp.LogPrefix, keyType, valueType, size, err)
+ tdp.logf("%sReadMapBegin() (keyType=%#v, valueType=%#v, size=%#v, err=%#v)", tdp.LogPrefix, keyType, valueType, size, err)
return
}
func (tdp *TDebugProtocol) ReadMapEnd() (err error) {
err = tdp.Delegate.ReadMapEnd()
- log.Printf("%sReadMapEnd() err=%#v", tdp.LogPrefix, err)
+ tdp.logf("%sReadMapEnd() err=%#v", tdp.LogPrefix, err)
return
}
func (tdp *TDebugProtocol) ReadListBegin() (elemType TType, size int, err error) {
elemType, size, err = tdp.Delegate.ReadListBegin()
- log.Printf("%sReadListBegin() (elemType=%#v, size=%#v, err=%#v)", tdp.LogPrefix, elemType, size, err)
+ tdp.logf("%sReadListBegin() (elemType=%#v, size=%#v, err=%#v)", tdp.LogPrefix, elemType, size, err)
return
}
func (tdp *TDebugProtocol) ReadListEnd() (err error) {
err = tdp.Delegate.ReadListEnd()
- log.Printf("%sReadListEnd() err=%#v", tdp.LogPrefix, err)
+ tdp.logf("%sReadListEnd() err=%#v", tdp.LogPrefix, err)
return
}
func (tdp *TDebugProtocol) ReadSetBegin() (elemType TType, size int, err error) {
elemType, size, err = tdp.Delegate.ReadSetBegin()
- log.Printf("%sReadSetBegin() (elemType=%#v, size=%#v, err=%#v)", tdp.LogPrefix, elemType, size, err)
+ tdp.logf("%sReadSetBegin() (elemType=%#v, size=%#v, err=%#v)", tdp.LogPrefix, elemType, size, err)
return
}
func (tdp *TDebugProtocol) ReadSetEnd() (err error) {
err = tdp.Delegate.ReadSetEnd()
- log.Printf("%sReadSetEnd() err=%#v", tdp.LogPrefix, err)
+ tdp.logf("%sReadSetEnd() err=%#v", tdp.LogPrefix, err)
return
}
func (tdp *TDebugProtocol) ReadBool() (value bool, err error) {
value, err = tdp.Delegate.ReadBool()
- log.Printf("%sReadBool() (value=%#v, err=%#v)", tdp.LogPrefix, value, err)
+ tdp.logf("%sReadBool() (value=%#v, err=%#v)", tdp.LogPrefix, value, err)
return
}
func (tdp *TDebugProtocol) ReadByte() (value int8, err error) {
value, err = tdp.Delegate.ReadByte()
- log.Printf("%sReadByte() (value=%#v, err=%#v)", tdp.LogPrefix, value, err)
+ tdp.logf("%sReadByte() (value=%#v, err=%#v)", tdp.LogPrefix, value, err)
return
}
func (tdp *TDebugProtocol) ReadI16() (value int16, err error) {
value, err = tdp.Delegate.ReadI16()
- log.Printf("%sReadI16() (value=%#v, err=%#v)", tdp.LogPrefix, value, err)
+ tdp.logf("%sReadI16() (value=%#v, err=%#v)", tdp.LogPrefix, value, err)
return
}
func (tdp *TDebugProtocol) ReadI32() (value int32, err error) {
value, err = tdp.Delegate.ReadI32()
- log.Printf("%sReadI32() (value=%#v, err=%#v)", tdp.LogPrefix, value, err)
+ tdp.logf("%sReadI32() (value=%#v, err=%#v)", tdp.LogPrefix, value, err)
return
}
func (tdp *TDebugProtocol) ReadI64() (value int64, err error) {
value, err = tdp.Delegate.ReadI64()
- log.Printf("%sReadI64() (value=%#v, err=%#v)", tdp.LogPrefix, value, err)
+ tdp.logf("%sReadI64() (value=%#v, err=%#v)", tdp.LogPrefix, value, err)
return
}
func (tdp *TDebugProtocol) ReadDouble() (value float64, err error) {
value, err = tdp.Delegate.ReadDouble()
- log.Printf("%sReadDouble() (value=%#v, err=%#v)", tdp.LogPrefix, value, err)
+ tdp.logf("%sReadDouble() (value=%#v, err=%#v)", tdp.LogPrefix, value, err)
return
}
func (tdp *TDebugProtocol) ReadString() (value string, err error) {
value, err = tdp.Delegate.ReadString()
- log.Printf("%sReadString() (value=%#v, err=%#v)", tdp.LogPrefix, value, err)
+ tdp.logf("%sReadString() (value=%#v, err=%#v)", tdp.LogPrefix, value, err)
return
}
func (tdp *TDebugProtocol) ReadBinary() (value []byte, err error) {
value, err = tdp.Delegate.ReadBinary()
- log.Printf("%sReadBinary() (value=%#v, err=%#v)", tdp.LogPrefix, value, err)
+ tdp.logf("%sReadBinary() (value=%#v, err=%#v)", tdp.LogPrefix, value, err)
return
}
func (tdp *TDebugProtocol) Skip(fieldType TType) (err error) {
err = tdp.Delegate.Skip(fieldType)
- log.Printf("%sSkip(fieldType=%#v) (err=%#v)", tdp.LogPrefix, fieldType, err)
+ tdp.logf("%sSkip(fieldType=%#v) (err=%#v)", tdp.LogPrefix, fieldType, err)
return
}
func (tdp *TDebugProtocol) Flush(ctx context.Context) (err error) {
err = tdp.Delegate.Flush(ctx)
- log.Printf("%sFlush() (err=%#v)", tdp.LogPrefix, err)
+ tdp.logf("%sFlush() (err=%#v)", tdp.LogPrefix, err)
return
}
diff --git a/lib/go/thrift/header_protocol.go b/lib/go/thrift/header_protocol.go
index 46205b2..99deaf7 100644
--- a/lib/go/thrift/header_protocol.go
+++ b/lib/go/thrift/header_protocol.go
@@ -303,3 +303,17 @@
func (p *THeaderProtocol) Skip(fieldType TType) error {
return p.protocol.Skip(fieldType)
}
+
+// GetResponseHeadersFromClient is a helper function to get the read THeaderMap
+// from the last response received from the given client.
+//
+// If the last response was not sent over THeader protocol,
+// a nil map will be returned.
+func GetResponseHeadersFromClient(c TClient) THeaderMap {
+ if sc, ok := c.(*TStandardClient); ok {
+ if hp, ok := sc.iprot.(*THeaderProtocol); ok {
+ return hp.transport.readHeaders
+ }
+ }
+ return nil
+}
diff --git a/lib/go/thrift/http_client.go b/lib/go/thrift/http_client.go
index 5c82bf5..d093eeb 100644
--- a/lib/go/thrift/http_client.go
+++ b/lib/go/thrift/http_client.go
@@ -218,7 +218,7 @@
}
const maxSize = ^uint64(0)
- return maxSize // the thruth is, we just don't know unless framed is used
+ return maxSize // the truth is, we just don't know unless framed is used
}
// Deprecated: Use NewTHttpClientTransportFactory instead.
diff --git a/lib/go/thrift/iostream_transport.go b/lib/go/thrift/iostream_transport.go
index fea93bc..0b1775d 100644
--- a/lib/go/thrift/iostream_transport.go
+++ b/lib/go/thrift/iostream_transport.go
@@ -210,5 +210,5 @@
func (p *StreamTransport) RemainingBytes() (num_bytes uint64) {
const maxSize = ^uint64(0)
- return maxSize // the thruth is, we just don't know unless framed is used
+ return maxSize // the truth is, we just don't know unless framed is used
}
diff --git a/lib/go/thrift/logger.go b/lib/go/thrift/logger.go
new file mode 100644
index 0000000..c42aac9
--- /dev/null
+++ b/lib/go/thrift/logger.go
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package thrift
+
+import (
+ "log"
+ "os"
+ "testing"
+)
+
+// Logger is a simple wrapper around a logging function.
+//
+// In practice users rely on many different logging libraries, and those
+// libraries are not always compatible with each other.
+//
+// Logger is meant to be a minimal common ground: whatever logging library is
+// in use can easily be wrapped into a func(msg string).
+//
+// See https://issues.apache.org/jira/browse/THRIFT-4985 for the design
+// discussion behind it.
+type Logger func(msg string)
+
+// NopLogger is a Logger implementation that discards every message.
+func NopLogger(msg string) {}
+
+// StdLogger adapts a standard library *log.Logger into a Logger that forwards each message to its Print method.
+//
+// If logger is nil, it falls back to a new *log.Logger writing to os.Stderr with log.LstdFlags.
+func StdLogger(logger *log.Logger) Logger {
+ if logger == nil {
+ logger = log.New(os.Stderr, "", log.LstdFlags)
+ }
+ return func(msg string) {
+ logger.Print(msg)
+ }
+}
+
+// TestLogger is a Logger implementation that can be used in test code.
+//
+// Every call reports the message through tb.Errorf, marking the test as failed without stopping it.
+func TestLogger(tb testing.TB) Logger {
+ return func(msg string) {
+ tb.Errorf("logger called with msg: %q", msg)
+ }
+}
+
+func fallbackLogger(logger Logger) Logger {
+ if logger == nil {
+ return StdLogger(nil)
+ }
+ return logger
+}
diff --git a/lib/go/thrift/response_helper.go b/lib/go/thrift/response_helper.go
new file mode 100644
index 0000000..d884c6a
--- /dev/null
+++ b/lib/go/thrift/response_helper.go
@@ -0,0 +1,94 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package thrift
+
+import (
+ "context"
+)
+
+// See https://godoc.org/context#WithValue for why an unexported key type is used.
+type responseHelperKey struct{}
+
+// TResponseHelper defines an object with a set of helper functions that can be
+// retrieved from the context object passed into server handler functions.
+//
+// Use GetResponseHelper to retrieve the injected TResponseHelper implementation
+// from the context object.
+//
+// The zero value of TResponseHelper is valid, with all helper functions being
+// no-ops.
+type TResponseHelper struct {
+ // THeader related functions (promoted from the embedded pointer; nil is valid)
+ *THeaderResponseHelper
+}
+
+// THeaderResponseHelper defines THeader related TResponseHelper functions.
+//
+// The zero value of *THeaderResponseHelper is valid with all helper functions
+// being no-op.
+type THeaderResponseHelper struct {
+ proto *THeaderProtocol
+}
+
+// NewTHeaderResponseHelper creates a new THeaderResponseHelper from the
+// underlying TProtocol.
+func NewTHeaderResponseHelper(proto TProtocol) *THeaderResponseHelper {
+ if hp, ok := proto.(*THeaderProtocol); ok {
+ return &THeaderResponseHelper{
+ proto: hp,
+ }
+ }
+ return nil
+}
+
+// SetHeader sets a response header.
+//
+// It's no-op if the underlying protocol/transport does not support THeader.
+func (h *THeaderResponseHelper) SetHeader(key, value string) {
+ if h != nil && h.proto != nil {
+ h.proto.SetWriteHeader(key, value)
+ }
+}
+
+// ClearHeaders clears all the response headers previously set.
+//
+// It's no-op if the underlying protocol/transport does not support THeader.
+func (h *THeaderResponseHelper) ClearHeaders() {
+ if h != nil && h.proto != nil {
+ h.proto.ClearWriteHeaders()
+ }
+}
+
+// GetResponseHelper retrieves the TResponseHelper implementation injected into
+// the context object.
+//
+// If no helper was found in the context object, a nop helper with ok == false
+// will be returned.
+func GetResponseHelper(ctx context.Context) (helper TResponseHelper, ok bool) {
+ if v := ctx.Value(responseHelperKey{}); v != nil {
+ helper, ok = v.(TResponseHelper)
+ }
+ return
+}
+
+// SetResponseHelper injects TResponseHelper into the context object.
+func SetResponseHelper(ctx context.Context, helper TResponseHelper) context.Context {
+ return context.WithValue(ctx, responseHelperKey{}, helper)
+}
diff --git a/lib/go/thrift/response_helper_test.go b/lib/go/thrift/response_helper_test.go
new file mode 100644
index 0000000..69f76d3
--- /dev/null
+++ b/lib/go/thrift/response_helper_test.go
@@ -0,0 +1,137 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package thrift
+
+import (
+ "context"
+ "testing"
+)
+
+func TestResponseHelperContext(t *testing.T) {
+ ctx := context.Background()
+
+ t.Run(
+ "empty-noop",
+ func(t *testing.T) {
+ helper, ok := GetResponseHelper(ctx)
+ if ok {
+ t.Error("GetResponseHelper expected ok == false")
+ }
+ // Just make sure those function calls do not panic
+ helper.SetHeader("foo", "bar")
+ helper.ClearHeaders()
+ },
+ )
+
+ t.Run(
+ "set-get",
+ func(t *testing.T) {
+ trans := NewTHeaderTransport(NewTMemoryBuffer())
+ proto := NewTHeaderProtocol(trans)
+ ctx := SetResponseHelper( // local ctx: don't leak state into sibling subtests
+ ctx,
+ TResponseHelper{
+ THeaderResponseHelper: NewTHeaderResponseHelper(proto),
+ },
+ )
+ helper, ok := GetResponseHelper(ctx)
+ if !ok {
+ t.Error("GetResponseHelper expected ok == true")
+ }
+ if helper.THeaderResponseHelper == nil {
+ t.Error("GetResponseHelper expected THeaderResponseHelper to be non-nil")
+ }
+ },
+ )
+}
+
+func TestHeaderHelper(t *testing.T) {
+ t.Run(
+ "THeaderProtocol",
+ func(t *testing.T) {
+ trans := NewTHeaderTransport(NewTMemoryBuffer())
+ proto := NewTHeaderProtocol(trans)
+ helper := NewTHeaderResponseHelper(proto)
+
+ const (
+ key = "key"
+ value = "value"
+ )
+ helper.SetHeader(key, value)
+ if len(trans.writeHeaders) != 1 {
+ t.Errorf(
+ "Expected THeaderTransport.writeHeaders to be with size of 1, got %+v",
+ trans.writeHeaders,
+ )
+ }
+ actual := trans.writeHeaders[key]
+ if actual != value {
+ t.Errorf(
+ "Expected THeaderTransport.writeHeaders to have %q:%q, got %+v",
+ key,
+ value,
+ trans.writeHeaders,
+ )
+ }
+ helper.ClearHeaders()
+ if len(trans.writeHeaders) != 0 {
+ t.Errorf(
+ "Expected THeaderTransport.writeHeaders to be empty after ClearHeaders call, got %+v",
+ trans.writeHeaders,
+ )
+ }
+ },
+ )
+
+ t.Run(
+ "other-protocol",
+ func(t *testing.T) {
+ trans := NewTMemoryBuffer()
+ proto := NewTCompactProtocol(trans)
+ helper := NewTHeaderResponseHelper(proto)
+
+ // We only need to make sure that functions in helper
+ // don't panic here.
+ helper.SetHeader("foo", "bar")
+ helper.ClearHeaders()
+ },
+ )
+
+ t.Run(
+ "zero-value",
+ func(t *testing.T) {
+ var helper *THeaderResponseHelper
+
+ // We only need to make sure that functions in helper
+ // don't panic here.
+ helper.SetHeader("foo", "bar")
+ helper.ClearHeaders()
+ },
+ )
+}
+
+func TestTResponseHelperZeroValue(t *testing.T) {
+ // Zero TResponseHelper (nil embedded pointer) and zero THeaderResponseHelper (nil proto) must be no-ops.
+ var helper TResponseHelper
+ for _, h := range []*THeaderResponseHelper{helper.THeaderResponseHelper, {}} {
+ h.SetHeader("foo", "bar") // must not panic
+ h.ClearHeaders()
+ }
+}
diff --git a/lib/go/thrift/simple_server.go b/lib/go/thrift/simple_server.go
index f8efbed..5a9c9c9 100644
--- a/lib/go/thrift/simple_server.go
+++ b/lib/go/thrift/simple_server.go
@@ -20,8 +20,8 @@
package thrift
import (
- "log"
- "runtime/debug"
+ "fmt"
+ "io"
"sync"
"sync/atomic"
)
@@ -45,6 +45,8 @@
// Headers to auto forward in THeaderProtocol
forwardHeaders []string
+
+ logger Logger
}
func NewTSimpleServer2(processor TProcessor, serverTransport TServerTransport) *TSimpleServer {
@@ -148,6 +150,14 @@
p.forwardHeaders = keys
}
+// SetLogger sets the logger used by this TSimpleServer.
+//
+// If no logger was set before Serve is called, a default logger using standard
+// log library will be used.
+func (p *TSimpleServer) SetLogger(logger Logger) {
+ p.logger = logger
+}
+
func (p *TSimpleServer) innerAccept() (int32, error) {
client, err := p.serverTransport.Accept()
p.mu.Lock()
@@ -164,7 +174,7 @@
go func() {
defer p.wg.Done()
if err := p.processRequests(client); err != nil {
- log.Println("error processing request:", err)
+ fallbackLogger(p.logger)(fmt.Sprintf("error processing request: %v", err)) // p.logger is nil if Serve() was bypassed (Listen+AcceptLoop)
}
}()
}
@@ -184,6 +194,8 @@
}
func (p *TSimpleServer) Serve() error {
+ p.logger = fallbackLogger(p.logger)
+
err := p.Listen()
if err != nil {
return err
@@ -204,7 +216,27 @@
return nil
}
-func (p *TSimpleServer) processRequests(client TTransport) error {
+// treatEOFErrorsAsNil returns nil when err represents an EOF, otherwise err unchanged.
+func treatEOFErrorsAsNil(err error) error {
+ if err == nil {
+ return nil
+ }
+ // err could be io.EOF wrapped in a TProtocolException, so a plain
+ // err == io.EOF comparison is not enough; compare the messages instead.
+ if err.Error() == io.EOF.Error() {
+ return nil
+ }
+ if err, ok := err.(TTransportException); ok && err.TypeId() == END_OF_FILE {
+ return nil
+ }
+ return err
+}
+
+func (p *TSimpleServer) processRequests(client TTransport) (err error) {
+ defer func() {
+ err = treatEOFErrorsAsNil(err)
+ }()
+
processor := p.processorFactory.GetProcessor(client)
inputTransport, err := p.inputTransportFactory.GetTransport(client)
if err != nil {
@@ -229,12 +261,6 @@
outputProtocol = p.outputProtocolFactory.GetProtocol(outputTransport)
}
- defer func() {
- if e := recover(); e != nil {
- log.Printf("panic in processor: %s: %s", e, debug.Stack())
- }
- }()
-
if inputTransport != nil {
defer inputTransport.Close()
}
@@ -246,7 +272,12 @@
return nil
}
- ctx := defaultCtx
+ ctx := SetResponseHelper(
+ defaultCtx,
+ TResponseHelper{
+ THeaderResponseHelper: NewTHeaderResponseHelper(outputProtocol),
+ },
+ )
if headerProtocol != nil {
// We need to call ReadFrame here, otherwise we won't
// get any headers on the AddReadTHeaderToContext call.
@@ -257,14 +288,12 @@
if err := headerProtocol.ReadFrame(); err != nil {
return err
}
- ctx = AddReadTHeaderToContext(defaultCtx, headerProtocol.GetReadHeaders())
+ ctx = AddReadTHeaderToContext(ctx, headerProtocol.GetReadHeaders())
ctx = SetWriteHeaderList(ctx, p.forwardHeaders)
}
ok, err := processor.Process(ctx, inputProtocol, outputProtocol)
- if err, ok := err.(TTransportException); ok && err.TypeId() == END_OF_FILE {
- return nil
- } else if err != nil {
+ if _, ok := err.(TTransportException); ok && err != nil {
return err
}
if err, ok := err.(TApplicationException); ok && err.TypeId() == UNKNOWN_METHOD {
diff --git a/lib/go/thrift/ssl_socket.go b/lib/go/thrift/ssl_socket.go
index ba63377..45bf38a 100644
--- a/lib/go/thrift/ssl_socket.go
+++ b/lib/go/thrift/ssl_socket.go
@@ -172,5 +172,5 @@
func (p *TSSLSocket) RemainingBytes() (num_bytes uint64) {
const maxSize = ^uint64(0)
- return maxSize // the thruth is, we just don't know unless framed is used
+ return maxSize // the truth is, we just don't know unless framed is used
}
diff --git a/lib/go/thrift/zlib_transport.go b/lib/go/thrift/zlib_transport.go
index f3d4267..e7efdfb 100644
--- a/lib/go/thrift/zlib_transport.go
+++ b/lib/go/thrift/zlib_transport.go
@@ -23,7 +23,6 @@
"compress/zlib"
"context"
"io"
- "log"
)
// TZlibTransportFactory is a factory for TZlibTransport instances
@@ -67,7 +66,6 @@
func NewTZlibTransport(trans TTransport, level int) (*TZlibTransport, error) {
w, err := zlib.NewWriterLevel(trans, level)
if err != nil {
- log.Println(err)
return nil, err
}
diff --git a/lib/hs/CMakeLists.txt b/lib/hs/CMakeLists.txt
index 1a5b8fd..c477c9b 100644
--- a/lib/hs/CMakeLists.txt
+++ b/lib/hs/CMakeLists.txt
@@ -37,7 +37,7 @@
)
if(BUILD_TESTING)
- list(APPEND haskell_soruces
+ list(APPEND haskell_sources
test/Spec.hs
test/BinarySpec.hs
test/CompactSpec.hs
diff --git a/lib/java/src/org/apache/thrift/server/TSaslNonblockingServer.java b/lib/java/src/org/apache/thrift/server/TSaslNonblockingServer.java
new file mode 100644
index 0000000..89dbb78
--- /dev/null
+++ b/lib/java/src/org/apache/thrift/server/TSaslNonblockingServer.java
@@ -0,0 +1,480 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.thrift.server;
+
+import java.io.IOException;
+import java.nio.channels.SelectionKey;
+import java.nio.channels.Selector;
+import java.util.ArrayList;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.concurrent.BlockingQueue;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.LinkedBlockingQueue;
+import java.util.concurrent.TimeUnit;
+
+import javax.security.auth.callback.CallbackHandler;
+
+import org.apache.thrift.TProcessor;
+import org.apache.thrift.transport.TNonblockingServerSocket;
+import org.apache.thrift.transport.TNonblockingServerTransport;
+import org.apache.thrift.transport.TNonblockingTransport;
+import org.apache.thrift.transport.TTransportException;
+import org.apache.thrift.transport.sasl.NonblockingSaslHandler;
+import org.apache.thrift.transport.sasl.NonblockingSaslHandler.Phase;
+import org.apache.thrift.transport.sasl.TBaseSaslProcessorFactory;
+import org.apache.thrift.transport.sasl.TSaslProcessorFactory;
+import org.apache.thrift.transport.sasl.TSaslServerFactory;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * TServer with sasl support, using asynchronous execution and nonblocking io.
+ * Network threads handle the connection io; sasl evaluation and request processing run on separate executors.
+ */
+public class TSaslNonblockingServer extends TServer {
+ private static final Logger LOGGER = LoggerFactory.getLogger(TSaslNonblockingServer.class);
+
+ private static final int DEFAULT_NETWORK_THREADS = 1;
+ private static final int DEFAULT_AUTHENTICATION_THREADS = 1;
+ private static final int DEFAULT_PROCESSING_THREADS = Runtime.getRuntime().availableProcessors();
+
+ private final AcceptorThread acceptor;
+ private final NetworkThreadPool networkThreadPool;
+ private final ExecutorService authenticationExecutor;
+ private final ExecutorService processingExecutor;
+ private final TSaslServerFactory saslServerFactory;
+ private final TSaslProcessorFactory saslProcessorFactory;
+
+ public TSaslNonblockingServer(Args args) throws IOException {
+ super(args);
+ acceptor = new AcceptorThread((TNonblockingServerSocket) serverTransport_); // NOTE(review): requires a TNonblockingServerSocket transport
+ networkThreadPool = new NetworkThreadPool(args.networkThreads);
+ authenticationExecutor = Executors.newFixedThreadPool(args.saslThreads);
+ processingExecutor = Executors.newFixedThreadPool(args.processingThreads);
+ saslServerFactory = args.saslServerFactory;
+ saslProcessorFactory = args.saslProcessorFactory;
+ }
+
+ @Override
+ public void serve() { // Starts the network threads and the acceptor, then returns without blocking.
+ if (eventHandler_ != null) {
+ eventHandler_.preServe();
+ }
+ networkThreadPool.start();
+ acceptor.start();
+ setServing(true);
+ }
+
+ /**
+ * Trigger a graceful shutdown, but it does not block to wait for the shutdown to finish.
+ */
+ @Override
+ public void stop() {
+ if (!stopped_) {
+ setServing(false);
+ stopped_ = true;
+ acceptor.wakeup();
+ networkThreadPool.wakeupAll();
+ authenticationExecutor.shutdownNow();
+ processingExecutor.shutdownNow();
+ }
+ }
+
+ /**
+ * Gracefully shut down the server and block until all threads are stopped.
+ *
+ * @throws InterruptedException if is interrupted while waiting for shutdown.
+ */
+ public void shutdown() throws InterruptedException {
+ stop();
+ acceptor.join();
+ for (NetworkThread networkThread : networkThreadPool.networkThreads) {
+ networkThread.join();
+ }
+ while (!authenticationExecutor.isTerminated()) {
+ authenticationExecutor.awaitTermination(10, TimeUnit.SECONDS);
+ }
+ while (!processingExecutor.isTerminated()) {
+ processingExecutor.awaitTermination(10, TimeUnit.SECONDS);
+ }
+ }
+
+ // Accepts incoming connections and hands them round robin to the network thread pool.
+ private class AcceptorThread extends Thread {
+
+ private final TNonblockingServerTransport serverTransport;
+ private final Selector acceptSelector;
+
+ private AcceptorThread(TNonblockingServerSocket serverTransport) throws IOException {
+ super("acceptor-thread");
+ this.serverTransport = serverTransport;
+ acceptSelector = Selector.open();
+ serverTransport.registerSelector(acceptSelector);
+ }
+
+ @Override
+ public void run() {
+ try {
+ serverTransport.listen();
+ while (!stopped_) {
+ select();
+ acceptNewConnection();
+ }
+ } catch (TTransportException e) {
+ // Failed to listen.
+ LOGGER.error("Failed to listen on server socket, error " + e.getType(), e);
+ } catch (Throwable e) {
+ // Unexpected errors.
+ LOGGER.error("Unexpected error in acceptor thread.", e);
+ } finally {
+ TSaslNonblockingServer.this.stop();
+ close();
+ }
+ }
+
+ void wakeup() {
+ acceptSelector.wakeup();
+ }
+
+ private void acceptNewConnection() {
+ Iterator<SelectionKey> selectedKeyItr = acceptSelector.selectedKeys().iterator();
+ while (!stopped_ && selectedKeyItr.hasNext()) {
+ SelectionKey selected = selectedKeyItr.next();
+ selectedKeyItr.remove();
+ if (selected.isAcceptable()) {
+ try {
+ while (true) {
+ // Accept all available connections from the backlog.
+ TNonblockingTransport connection = serverTransport.accept();
+ if (connection == null) {
+ break;
+ }
+ if (!networkThreadPool.acceptNewConnection(connection)) {
+ LOGGER.error("Network thread does not accept: " + connection);
+ connection.close();
+ }
+ }
+ } catch (TTransportException e) {
+ LOGGER.warn("Failed to accept incoming connection.", e);
+ }
+ } else {
+ LOGGER.error("Not acceptable selection: " + selected.channel());
+ }
+ }
+ }
+
+ private void select() {
+ try {
+ acceptSelector.select();
+ } catch (IOException e) {
+ LOGGER.error("Failed to select on the server socket.", e);
+ }
+ }
+
+ private void close() {
+ LOGGER.info("Closing acceptor thread.");
+ serverTransport.close();
+ try {
+ acceptSelector.close();
+ } catch (IOException e) {
+ LOGGER.error("Failed to close accept selector.", e);
+ }
+ }
+ }
+
+ // Multiplexes io for its assigned connections on one selector and drives their sasl state machines.
+ private class NetworkThread extends Thread {
+ private final BlockingQueue<TNonblockingTransport> incomingConnections = new LinkedBlockingQueue<>();
+ private final BlockingQueue<NonblockingSaslHandler> stateTransitions = new LinkedBlockingQueue<>();
+ private final Selector ioSelector;
+
+ NetworkThread(String name) throws IOException {
+ super(name);
+ ioSelector = Selector.open();
+ }
+
+ @Override
+ public void run() {
+ try {
+ while (!stopped_) {
+ handleIncomingConnections();
+ handleStateChanges();
+ select();
+ handleIO();
+ }
+ } catch (Throwable e) {
+ LOGGER.error("Unrecoverable error in " + getName(), e);
+ } finally {
+ close();
+ }
+ }
+
+ private void handleStateChanges() {
+ NonblockingSaslHandler statemachine;
+ while ((statemachine = stateTransitions.poll()) != null) {
+ tryRunNextPhase(statemachine);
+ }
+ }
+
+ private void select() {
+ try {
+ ioSelector.select();
+ } catch (IOException e) {
+ LOGGER.error("Failed to select in " + getName(), e);
+ }
+ }
+
+ private void handleIO() {
+ Iterator<SelectionKey> selectedKeyItr = ioSelector.selectedKeys().iterator();
+ while (!stopped_ && selectedKeyItr.hasNext()) {
+ SelectionKey selected = selectedKeyItr.next();
+ selectedKeyItr.remove();
+ if (!selected.isValid()) {
+ // isReadable()/isWritable() throw CancelledKeyException on an invalid key, so skip it once closed.
+ closeChannel(selected);
+ continue;
+ }
+ NonblockingSaslHandler saslHandler = (NonblockingSaslHandler) selected.attachment();
+ if (selected.isReadable()) {
+ saslHandler.handleRead();
+ } else if (selected.isWritable()) {
+ saslHandler.handleWrite();
+ } else {
+ LOGGER.error("Invalid interest op " + selected.interestOps());
+ closeChannel(selected);
+ continue;
+ }
+ if (saslHandler.isCurrentPhaseDone()) {
+ tryRunNextPhase(saslHandler);
+ }
+ }
+ }
+
+ // The following methods are modifying the registered channel set on the selector, which itself
+ // is not thread safe. Thus we need a lock to protect it from race condition.
+
+ private synchronized void handleIncomingConnections() {
+ TNonblockingTransport connection;
+ while ((connection = incomingConnections.poll()) != null) {
+ if (!connection.isOpen()) {
+ LOGGER.warn("Incoming connection is already closed");
+ continue;
+ }
+ try {
+ SelectionKey selectionKey = connection.registerSelector(ioSelector, SelectionKey.OP_READ);
+ if (selectionKey.isValid()) {
+ NonblockingSaslHandler saslHandler = new NonblockingSaslHandler(selectionKey, connection,
+ saslServerFactory, saslProcessorFactory, inputProtocolFactory_, outputProtocolFactory_,
+ eventHandler_);
+ selectionKey.attach(saslHandler);
+ }
+ } catch (IOException e) {
+ LOGGER.error("Failed to register connection for the selector, close it.", e);
+ connection.close();
+ }
+ }
+ }
+
+ private synchronized void close() {
+ LOGGER.warn("Closing " + getName());
+ TNonblockingTransport incomingConnection;
+ while ((incomingConnection = incomingConnections.poll()) != null) {
+ incomingConnection.close();
+ }
+ Set<SelectionKey> registered = ioSelector.keys();
+ for (SelectionKey selection : registered) {
+ closeChannel(selection);
+ }
+ try {
+ ioSelector.close();
+ } catch (IOException e) {
+ LOGGER.error("Failed to close io selector " + getName(), e);
+ }
+ }
+
+ private synchronized void closeChannel(SelectionKey selectionKey) {
+ if (selectionKey.attachment() == null) {
+ try {
+ selectionKey.channel().close();
+ } catch (IOException e) {
+ LOGGER.error("Failed to close channel.", e);
+ } finally {
+ selectionKey.cancel();
+ }
+ } else {
+ NonblockingSaslHandler saslHandler = (NonblockingSaslHandler) selectionKey.attachment();
+ saslHandler.close();
+ }
+ }
+
+ // Steps to the next phase: sasl evaluation and processing go to the executors, closing runs inline, others await io.
+ private void tryRunNextPhase(NonblockingSaslHandler saslHandler) {
+ Phase nextPhase = saslHandler.getNextPhase();
+ saslHandler.stepToNextPhase();
+ switch (nextPhase) {
+ case EVALUATING_SASL_RESPONSE:
+ authenticationExecutor.submit(new Computation(saslHandler));
+ break;
+ case PROCESSING:
+ processingExecutor.submit(new Computation(saslHandler));
+ break;
+ case CLOSING:
+ saslHandler.runCurrentPhase();
+ break;
+ default: // waiting for next io event for the current state machine
+ }
+ }
+
+ public boolean accept(TNonblockingTransport connection) {
+ if (stopped_) {
+ return false;
+ }
+ if (incomingConnections.offer(connection)) {
+ wakeup();
+ return true;
+ }
+ return false;
+ }
+
+ private void wakeup() {
+ ioSelector.wakeup();
+ }
+
+ // Runs the current phase on an executor thread until done, then queues the state machine back to this network thread.
+ private class Computation implements Runnable {
+
+ private final NonblockingSaslHandler statemachine;
+
+ private Computation(NonblockingSaslHandler statemachine) {
+ this.statemachine = statemachine;
+ }
+
+ @Override
+ public void run() {
+ try {
+ while (!statemachine.isCurrentPhaseDone()) {
+ statemachine.runCurrentPhase();
+ }
+ stateTransitions.add(statemachine);
+ wakeup();
+ } catch (Throwable e) {
+ LOGGER.error("Unexpected error while running the current phase of a sasl state machine.", e);
+ }
+ }
+ }
+ }
+
+ private class NetworkThreadPool {
+ private final List<NetworkThread> networkThreads;
+ private int accepted = 0;
+
+ NetworkThreadPool(int size) throws IOException {
+ networkThreads = new ArrayList<>(size);
+ int digits = (int) Math.log10(size) + 1;
+ String threadNamePattern = "network-thread-%0" + digits + "d";
+ for (int i = 0; i < size; i++) {
+ networkThreads.add(new NetworkThread(String.format(threadNamePattern, i)));
+ }
+ }
+
+ /**
+ * Round robin new connection among all the network threads.
+ *
+ * @param connection incoming connection.
+ * @return true if the incoming connection is accepted by network thread pool.
+ */
+ boolean acceptNewConnection(TNonblockingTransport connection) {
+ NetworkThread target = networkThreads.get(accepted);
+ accepted = (accepted + 1) % networkThreads.size(); // wrap around instead of overflowing into a negative index
+ return target.accept(connection);
+ }
+
+ public void start() {
+ for (NetworkThread thread : networkThreads) {
+ thread.start();
+ }
+ }
+
+ void wakeupAll() {
+ for (NetworkThread networkThread : networkThreads) {
+ networkThread.wakeup();
+ }
+ }
+ }
+
+ public static class Args extends AbstractServerArgs<Args> {
+
+ private int networkThreads = DEFAULT_NETWORK_THREADS;
+ private int saslThreads = DEFAULT_AUTHENTICATION_THREADS;
+ private int processingThreads = DEFAULT_PROCESSING_THREADS;
+ private TSaslServerFactory saslServerFactory = new TSaslServerFactory();
+ private TSaslProcessorFactory saslProcessorFactory;
+
+ public Args(TNonblockingServerTransport transport) {
+ super(transport);
+ }
+
+ public Args networkThreads(int networkThreads) {
+ this.networkThreads = networkThreads <= 0 ? DEFAULT_NETWORK_THREADS : networkThreads;
+ return this;
+ }
+
+ public Args saslThreads(int authenticationThreads) {
+ this.saslThreads = authenticationThreads <= 0 ? DEFAULT_AUTHENTICATION_THREADS : authenticationThreads;
+ return this;
+ }
+
+ public Args processingThreads(int processingThreads) {
+ this.processingThreads = processingThreads <= 0 ? DEFAULT_PROCESSING_THREADS : processingThreads;
+ return this;
+ }
+
+ public Args processor(TProcessor processor) {
+ saslProcessorFactory = new TBaseSaslProcessorFactory(processor);
+ return this;
+ }
+
+ public Args saslProcessorFactory(TSaslProcessorFactory saslProcessorFactory) {
+ if (saslProcessorFactory == null) {
+ throw new NullPointerException("Processor factory cannot be null");
+ }
+ this.saslProcessorFactory = saslProcessorFactory;
+ return this;
+ }
+
+ public Args addSaslMechanism(String mechanism, String protocol, String serverName,
+ Map<String, String> props, CallbackHandler cbh) {
+ saslServerFactory.addSaslMechanism(mechanism, protocol, serverName, props, cbh);
+ return this;
+ }
+
+ public Args saslServerFactory(TSaslServerFactory saslServerFactory) {
+ if (saslServerFactory == null) {
+ throw new NullPointerException("saslServerFactory cannot be null");
+ }
+ this.saslServerFactory = saslServerFactory;
+ return this;
+ }
+ }
+}
diff --git a/lib/java/src/org/apache/thrift/server/TServerEventHandler.java b/lib/java/src/org/apache/thrift/server/TServerEventHandler.java
index f069b9b..3bd7959 100644
--- a/lib/java/src/org/apache/thrift/server/TServerEventHandler.java
+++ b/lib/java/src/org/apache/thrift/server/TServerEventHandler.java
@@ -28,6 +28,10 @@
* about. Your subclass can also store local data that you may care about,
* such as additional "arguments" to these methods (stored in the object
* instance's state).
+ *
+ * TODO: It seems this is a custom code entry point created for some resource management purpose in hive.
+ * But when looking into hive code, we see that the arguments of TProtocol and TTransport are never used.
+ * We probably should remove these arguments from all the methods.
*/
public interface TServerEventHandler {
@@ -56,4 +60,4 @@
void processContext(ServerContext serverContext,
TTransport inputTransport, TTransport outputTransport);
-}
\ No newline at end of file
+}
diff --git a/lib/java/src/org/apache/thrift/transport/TEOFException.java b/lib/java/src/org/apache/thrift/transport/TEOFException.java
new file mode 100644
index 0000000..b5ae6ef
--- /dev/null
+++ b/lib/java/src/org/apache/thrift/transport/TEOFException.java
@@ -0,0 +1,30 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.thrift.transport;
+
+/**
+ * Transport exception of type END_OF_FILE, typically raised once the underlying socket has been closed.
+ */
+public class TEOFException extends TTransportException {
+
+ public TEOFException(String msg) {
+ super(END_OF_FILE, msg);
+ }
+}
diff --git a/lib/java/src/org/apache/thrift/transport/TIOStreamTransport.java b/lib/java/src/org/apache/thrift/transport/TIOStreamTransport.java
index 2d31f39..d97d506 100644
--- a/lib/java/src/org/apache/thrift/transport/TIOStreamTransport.java
+++ b/lib/java/src/org/apache/thrift/transport/TIOStreamTransport.java
@@ -27,8 +27,8 @@
import java.io.OutputStream;
/**
- * This is the most commonly used base transport. It takes an InputStream
- * and an OutputStream and uses those to perform all transport operations.
+ * This is the most commonly used base transport. It takes an InputStream or
+ * an OutputStream or both and uses it/them to perform transport operations.
* This allows for compatibility with all the nice constructs Java already
* has to provide a variety of types of streams.
*
@@ -50,7 +50,7 @@
protected TIOStreamTransport() {}
/**
- * Input stream constructor.
+ * Input stream constructor, constructs an input only transport.
*
* @param is Input stream to read from
*/
@@ -59,9 +59,9 @@
}
/**
- * Output stream constructor.
+ * Output stream constructor, constructs an output only transport.
*
- * @param os Output stream to read from
+ * @param os Output stream to write to
*/
public TIOStreamTransport(OutputStream os) {
outputStream_ = os;
@@ -83,7 +83,7 @@
* @return false after close is called.
*/
public boolean isOpen() {
- return inputStream_ != null && outputStream_ != null;
+ return inputStream_ != null || outputStream_ != null;
}
/**
@@ -95,20 +95,23 @@
* Closes both the input and output streams.
*/
public void close() {
- if (inputStream_ != null) {
- try {
- inputStream_.close();
- } catch (IOException iox) {
- LOGGER.warn("Error closing input stream.", iox);
+ try {
+ if (inputStream_ != null) {
+ try {
+ inputStream_.close();
+ } catch (IOException iox) {
+ LOGGER.warn("Error closing input stream.", iox);
+ }
}
+ if (outputStream_ != null) {
+ try {
+ outputStream_.close();
+ } catch (IOException iox) {
+ LOGGER.warn("Error closing output stream.", iox);
+ }
+ }
+ } finally {
inputStream_ = null;
- }
- if (outputStream_ != null) {
- try {
- outputStream_.close();
- } catch (IOException iox) {
- LOGGER.warn("Error closing output stream.", iox);
- }
outputStream_ = null;
}
}
diff --git a/lib/java/src/org/apache/thrift/transport/TMemoryTransport.java b/lib/java/src/org/apache/thrift/transport/TMemoryTransport.java
new file mode 100644
index 0000000..f41bc09
--- /dev/null
+++ b/lib/java/src/org/apache/thrift/transport/TMemoryTransport.java
@@ -0,0 +1,81 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.thrift.transport;
+
+import java.nio.ByteBuffer;
+
+import org.apache.thrift.TByteArrayOutputStream;
+
+/**
+ * In memory transport with separate buffers: reads consume a fixed input array, writes go to an output buffer.
+ */
+public class TMemoryTransport extends TTransport {
+
+ private final ByteBuffer inputBuffer;
+ private final TByteArrayOutputStream outputBuffer;
+
+ public TMemoryTransport(byte[] input) {
+ inputBuffer = ByteBuffer.wrap(input); // wraps (does not copy) the caller's array
+ outputBuffer = new TByteArrayOutputStream(1024);
+ }
+
+ @Override
+ public boolean isOpen() {
+ return true;
+ }
+
+ /**
+ * Opening an in memory transport has no effect.
+ */
+ @Override
+ public void open() {
+ // Nothing to open.
+ }
+
+ @Override
+ public void close() {
+ // Nothing to release; the buffers stay usable after close.
+ }
+
+ @Override
+ public int read(byte[] buf, int off, int len) throws TTransportException {
+ int remaining = inputBuffer.remaining();
+ if (remaining < len) { // all-or-nothing: a short read is reported as END_OF_FILE
+ throw new TTransportException(TTransportException.END_OF_FILE,
+ "There's only " + remaining + " bytes, but it asks for " + len);
+ }
+ inputBuffer.get(buf, off, len);
+ return len;
+ }
+
+ @Override
+ public void write(byte[] buf, int off, int len) throws TTransportException {
+ outputBuffer.write(buf, off, len);
+ }
+
+ /**
+ * Get all the bytes written by thrift output protocol.
+ *
+ * @return the output buffer holding everything written so far.
+ */
+ public TByteArrayOutputStream getOutput() {
+ return outputBuffer;
+ }
+}
diff --git a/lib/java/src/org/apache/thrift/transport/TNonblockingServerSocket.java b/lib/java/src/org/apache/thrift/transport/TNonblockingServerSocket.java
index df37cb0..1631892 100644
--- a/lib/java/src/org/apache/thrift/transport/TNonblockingServerSocket.java
+++ b/lib/java/src/org/apache/thrift/transport/TNonblockingServerSocket.java
@@ -108,7 +108,8 @@
}
}
- protected TNonblockingSocket acceptImpl() throws TTransportException {
+ @Override
+ public TNonblockingSocket accept() throws TTransportException {
if (serverSocket_ == null) {
throw new TTransportException(TTransportException.NOT_OPEN, "No underlying server socket.");
}
@@ -160,4 +161,9 @@
return serverSocket_.getLocalPort();
}
+ // Package-private accessor so tests can reach the underlying server socket channel.
+ ServerSocketChannel getServerSocketChannel() {
+ return this.serverSocketChannel;
+ }
+
}
diff --git a/lib/java/src/org/apache/thrift/transport/TNonblockingServerTransport.java b/lib/java/src/org/apache/thrift/transport/TNonblockingServerTransport.java
index ba45b09..daac0d5 100644
--- a/lib/java/src/org/apache/thrift/transport/TNonblockingServerTransport.java
+++ b/lib/java/src/org/apache/thrift/transport/TNonblockingServerTransport.java
@@ -28,4 +28,12 @@
public abstract class TNonblockingServerTransport extends TServerTransport {
public abstract void registerSelector(Selector selector);
+
+ /**
+ * Accepts an incoming connection without blocking.
+ * @return an incoming connection or null if there is none.
+ * @throws TTransportException if accepting fails, e.g. when there is no open server socket.
+ */
+ @Override
+ public abstract TNonblockingTransport accept() throws TTransportException;
}
diff --git a/lib/java/src/org/apache/thrift/transport/TNonblockingSocket.java b/lib/java/src/org/apache/thrift/transport/TNonblockingSocket.java
index f86a48b..37a66d6 100644
--- a/lib/java/src/org/apache/thrift/transport/TNonblockingSocket.java
+++ b/lib/java/src/org/apache/thrift/transport/TNonblockingSocket.java
@@ -207,4 +207,9 @@
return socketChannel_.finishConnect();
}
+ @Override
+ public String toString() { // Both endpoints as full socket addresses (host and port), for log messages.
+ return "[remote: " + socketChannel_.socket().getRemoteSocketAddress() +
+ ", local: " + socketChannel_.socket().getLocalSocketAddress() + "]";
+ }
}
diff --git a/lib/java/src/org/apache/thrift/transport/TSaslClientTransport.java b/lib/java/src/org/apache/thrift/transport/TSaslClientTransport.java
index 4b1ca0a..5fc7cff 100644
--- a/lib/java/src/org/apache/thrift/transport/TSaslClientTransport.java
+++ b/lib/java/src/org/apache/thrift/transport/TSaslClientTransport.java
@@ -27,6 +27,7 @@
import javax.security.sasl.SaslClient;
import javax.security.sasl.SaslException;
+import org.apache.thrift.transport.sasl.NegotiationStatus;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
diff --git a/lib/java/src/org/apache/thrift/transport/TSaslServerTransport.java b/lib/java/src/org/apache/thrift/transport/TSaslServerTransport.java
index 39b81ca..31f309e 100644
--- a/lib/java/src/org/apache/thrift/transport/TSaslServerTransport.java
+++ b/lib/java/src/org/apache/thrift/transport/TSaslServerTransport.java
@@ -31,6 +31,8 @@
import javax.security.sasl.SaslException;
import javax.security.sasl.SaslServer;
+import org.apache.thrift.transport.sasl.NegotiationStatus;
+import org.apache.thrift.transport.sasl.TSaslServerDefinition;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -50,29 +52,9 @@
private Map<String, TSaslServerDefinition> serverDefinitionMap = new HashMap<String, TSaslServerDefinition>();
/**
- * Contains all the parameters used to define a SASL server implementation.
- */
- private static class TSaslServerDefinition {
- public String mechanism;
- public String protocol;
- public String serverName;
- public Map<String, String> props;
- public CallbackHandler cbh;
-
- public TSaslServerDefinition(String mechanism, String protocol, String serverName,
- Map<String, String> props, CallbackHandler cbh) {
- this.mechanism = mechanism;
- this.protocol = protocol;
- this.serverName = serverName;
- this.props = props;
- this.cbh = cbh;
- }
- }
-
- /**
* Uses the given underlying transport. Assumes that addServerDefinition is
* called later.
- *
+ *
* @param transport
* Transport underlying this one.
*/
@@ -84,7 +66,7 @@
* Creates a <code>SaslServer</code> using the given SASL-specific parameters.
* See the Java documentation for <code>Sasl.createSaslServer</code> for the
* details of the parameters.
- *
+ *
* @param transport
* The underlying Thrift transport.
*/
diff --git a/lib/java/src/org/apache/thrift/transport/TSaslTransport.java b/lib/java/src/org/apache/thrift/transport/TSaslTransport.java
index 4a453b6..d1a3d31 100644
--- a/lib/java/src/org/apache/thrift/transport/TSaslTransport.java
+++ b/lib/java/src/org/apache/thrift/transport/TSaslTransport.java
@@ -20,8 +20,6 @@
package org.apache.thrift.transport;
import java.nio.charset.StandardCharsets;
-import java.util.HashMap;
-import java.util.Map;
import javax.security.sasl.Sasl;
import javax.security.sasl.SaslClient;
@@ -30,6 +28,7 @@
import org.apache.thrift.EncodingUtils;
import org.apache.thrift.TByteArrayOutputStream;
+import org.apache.thrift.transport.sasl.NegotiationStatus;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -52,39 +51,6 @@
}
/**
- * Status bytes used during the initial Thrift SASL handshake.
- */
- protected static enum NegotiationStatus {
- START((byte)0x01),
- OK((byte)0x02),
- BAD((byte)0x03),
- ERROR((byte)0x04),
- COMPLETE((byte)0x05);
-
- private final byte value;
-
- private static final Map<Byte, NegotiationStatus> reverseMap =
- new HashMap<Byte, NegotiationStatus>();
- static {
- for (NegotiationStatus s : NegotiationStatus.class.getEnumConstants()) {
- reverseMap.put(s.getValue(), s);
- }
- }
-
- private NegotiationStatus(byte val) {
- this.value = val;
- }
-
- public byte getValue() {
- return value;
- }
-
- public static NegotiationStatus byValue(byte val) {
- return reverseMap.get(val);
- }
- }
-
- /**
* Transport underlying this one.
*/
protected TTransport underlyingTransport;
@@ -392,7 +358,7 @@
try {
sasl.dispose();
} catch (SaslException e) {
- // Not much we can do here.
+ LOGGER.warn("Failed to dispose sasl participant.", e);
}
}
@@ -427,9 +393,7 @@
} catch (TTransportException transportException) {
// If there is no-data or no-sasl header in the stream, log the failure, and rethrow.
if (transportException.getType() == TTransportException.END_OF_FILE) {
- if (LOGGER.isDebugEnabled()) {
- LOGGER.debug("No data or no sasl data in the stream during negotiation");
- }
+ LOGGER.debug("No data or no sasl data in the stream during negotiation");
}
throw transportException;
}
diff --git a/lib/java/src/org/apache/thrift/transport/TServerSocket.java b/lib/java/src/org/apache/thrift/transport/TServerSocket.java
index 79f7b7f..eb302fd 100644
--- a/lib/java/src/org/apache/thrift/transport/TServerSocket.java
+++ b/lib/java/src/org/apache/thrift/transport/TServerSocket.java
@@ -121,18 +121,23 @@
}
}
- protected TSocket acceptImpl() throws TTransportException {
+ @Override
+ public TSocket accept() throws TTransportException {
if (serverSocket_ == null) {
throw new TTransportException(TTransportException.NOT_OPEN, "No underlying server socket.");
}
+ Socket result;
try {
- Socket result = serverSocket_.accept();
- TSocket result2 = new TSocket(result);
- result2.setTimeout(clientTimeout_);
- return result2;
- } catch (IOException iox) {
- throw new TTransportException(iox);
+ result = serverSocket_.accept();
+ } catch (Exception e) {
+ throw new TTransportException(e);
}
+ if (result == null) {
+ throw new TTransportException("Blocking server's accept() may not return NULL");
+ }
+ TSocket socket = new TSocket(result);
+ socket.setTimeout(clientTimeout_);
+ return socket;
}
public void close() {
diff --git a/lib/java/src/org/apache/thrift/transport/TServerTransport.java b/lib/java/src/org/apache/thrift/transport/TServerTransport.java
index 424e4fa..55ef0c4 100644
--- a/lib/java/src/org/apache/thrift/transport/TServerTransport.java
+++ b/lib/java/src/org/apache/thrift/transport/TServerTransport.java
@@ -56,18 +56,18 @@
public abstract void listen() throws TTransportException;
- public final TTransport accept() throws TTransportException {
- TTransport transport = acceptImpl();
- if (transport == null) {
- throw new TTransportException("accept() may not return NULL");
- }
- return transport;
- }
+ /**
+ * Accept an incoming connection on the server socket. When no incoming connection is available,
+ * a blocking implementation should block until one arrives, whereas a nonblocking implementation
+ * should return null.
+ *
+ * @return the new connection, or null from a nonblocking implementation when none is pending
+ * @throws TTransportException if an IO error occurs while accepting.
+ */
+ public abstract TTransport accept() throws TTransportException;
public abstract void close();
- protected abstract TTransport acceptImpl() throws TTransportException;
-
/**
* Optional method implementation. This signals to the server transport
* that it should break out of any accept() or listen() that it is currently
diff --git a/lib/java/src/org/apache/thrift/transport/sasl/DataFrameHeaderReader.java b/lib/java/src/org/apache/thrift/transport/sasl/DataFrameHeaderReader.java
new file mode 100644
index 0000000..2900df9
--- /dev/null
+++ b/lib/java/src/org/apache/thrift/transport/sasl/DataFrameHeaderReader.java
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.thrift.transport.sasl;
+
+/**
+ * Reads the header of a data frame, which consists solely of a 4-byte payload size.
+ */
+public class DataFrameHeaderReader extends FixedSizeHeaderReader {
+ public static final int PAYLOAD_LENGTH_BYTES = 4;
+ private int payloadSize; // decoded by onComplete() once the whole header has been read
+
+ @Override
+ public int payloadSize() {
+ return payloadSize;
+ }
+
+ @Override
+ protected int headerSize() {
+ return PAYLOAD_LENGTH_BYTES;
+ }
+
+ @Override
+ protected void onComplete() throws TInvalidSaslFrameException {
+ payloadSize = byteBuffer.getInt(0);
+ if (payloadSize >= 0) {
+ return;
+ }
+ throw new TInvalidSaslFrameException("Payload size is negative: " + payloadSize);
+ }
+}
diff --git a/lib/java/src/org/apache/thrift/transport/sasl/DataFrameReader.java b/lib/java/src/org/apache/thrift/transport/sasl/DataFrameReader.java
new file mode 100644
index 0000000..e6900bb
--- /dev/null
+++ b/lib/java/src/org/apache/thrift/transport/sasl/DataFrameReader.java
@@ -0,0 +1,30 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.thrift.transport.sasl;
+
+/**
+ * Reader for frames carrying thrift (serialized) messages: a payload-size header (see
+ * {@link DataFrameHeaderReader}) followed by the payload bytes.
+ */
+public class DataFrameReader extends FrameReader<DataFrameHeaderReader> {
+
+  /**
+   * Create a reader backed by a fresh, empty data frame header.
+   */
+  public DataFrameReader() {
+    super(newHeaderReader());
+  }
+
+  // Factory for the header reader handed to the superclass constructor.
+  private static DataFrameHeaderReader newHeaderReader() {
+    return new DataFrameHeaderReader();
+  }
+}
diff --git a/lib/java/src/org/apache/thrift/transport/sasl/DataFrameWriter.java b/lib/java/src/org/apache/thrift/transport/sasl/DataFrameWriter.java
new file mode 100644
index 0000000..a2dd15a
--- /dev/null
+++ b/lib/java/src/org/apache/thrift/transport/sasl/DataFrameWriter.java
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.thrift.transport.sasl;
+
+import java.nio.ByteBuffer;
+
+import org.apache.thrift.EncodingUtils;
+import org.apache.thrift.utils.StringUtils;
+
+import static org.apache.thrift.transport.sasl.DataFrameHeaderReader.PAYLOAD_LENGTH_BYTES;
+
+/**
+ * Writes frames of thrift messages: the payload length (encoded with
+ * {@code EncodingUtils.encodeBigEndian} into {@link DataFrameHeaderReader#PAYLOAD_LENGTH_BYTES}
+ * bytes) followed by the payload. It expects an empty/null extra header to be provided with a
+ * payload; non-empty headers are considered an error.
+ */
+public class DataFrameWriter extends FrameWriter {
+
+  /**
+   * Prepare a frame wrapping {@code length} bytes of {@code payload} starting at {@code offset}.
+   *
+   * @param payload array holding the payload
+   * @param offset start of the payload in {@code payload}
+   * @param length number of payload bytes
+   * @throws IllegalStateException if the previous frame has not been fully written out
+   */
+  @Override
+  public void withOnlyPayload(byte[] payload, int offset, int length) {
+    if (!isComplete()) {
+      throw new IllegalStateException("Previsous write is not yet complete, with " +
+          frameBytes.remaining() + " bytes left.");
+    }
+    frameBytes = buildFrameWithPayload(payload, offset, length);
+  }
+
+  /**
+   * Data frames carry no extra header, so only the payload is framed.
+   *
+   * @throws IllegalArgumentException if a non-empty extra header is supplied
+   */
+  @Override
+  protected ByteBuffer buildFrame(byte[] header, int headerOffset, int headerLength,
+                                  byte[] payload, int payloadOffset, int payloadLength) {
+    if (header != null && headerLength > 0) {
+      // Report the offending header's own offset/length (not the payload's).
+      throw new IllegalArgumentException("Extra header [" + StringUtils.bytesToHexString(header) +
+          "] offset " + headerOffset + " length " + headerLength);
+    }
+    return buildFrameWithPayload(payload, payloadOffset, payloadLength);
+  }
+
+  // Allocate [length prefix][payload slice] and wrap it for nonblocking writes.
+  private ByteBuffer buildFrameWithPayload(byte[] payload, int offset, int length) {
+    byte[] bytes = new byte[PAYLOAD_LENGTH_BYTES + length];
+    EncodingUtils.encodeBigEndian(length, bytes, 0);
+    System.arraycopy(payload, offset, bytes, PAYLOAD_LENGTH_BYTES, length);
+    return ByteBuffer.wrap(bytes);
+  }
+}
diff --git a/lib/java/src/org/apache/thrift/transport/sasl/FixedSizeHeaderReader.java b/lib/java/src/org/apache/thrift/transport/sasl/FixedSizeHeaderReader.java
new file mode 100644
index 0000000..1cbc0ac
--- /dev/null
+++ b/lib/java/src/org/apache/thrift/transport/sasl/FixedSizeHeaderReader.java
@@ -0,0 +1,74 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.thrift.transport.sasl;
+
+import org.apache.thrift.transport.TTransport;
+import org.apache.thrift.transport.TTransportException;
+import org.apache.thrift.utils.StringUtils;
+
+import java.nio.ByteBuffer;
+
+/**
+ * Headers' size should be predefined.
+ */
+public abstract class FixedSizeHeaderReader implements FrameHeaderReader {
+
+ protected final ByteBuffer byteBuffer = ByteBuffer.allocate(headerSize());
+
+ @Override
+ public boolean isComplete() {
+ return !byteBuffer.hasRemaining();
+ }
+
+ @Override
+ public void clear() {
+ byteBuffer.clear();
+ }
+
+ @Override
+ public byte[] toBytes() {
+ if (!isComplete()) {
+ throw new IllegalStateException("Header is not yet complete " + StringUtils.bytesToHexString(byteBuffer.array(), 0, byteBuffer.position()));
+ }
+ return byteBuffer.array();
+ }
+
+ @Override
+ public boolean read(TTransport transport) throws TTransportException {
+ FrameReader.readAvailable(transport, byteBuffer);
+ if (byteBuffer.hasRemaining()) {
+ return false;
+ }
+ onComplete();
+ return true;
+ }
+
+ /**
+ * @return Size of the header.
+ */
+ protected abstract int headerSize();
+
+ /**
+ * Actions (e.g. validation) to carry out when the header is complete.
+ *
+ * @throws TTransportException
+ */
+ protected abstract void onComplete() throws TTransportException;
+}
diff --git a/lib/java/src/org/apache/thrift/transport/sasl/FrameHeaderReader.java b/lib/java/src/org/apache/thrift/transport/sasl/FrameHeaderReader.java
new file mode 100644
index 0000000..f7c6593
--- /dev/null
+++ b/lib/java/src/org/apache/thrift/transport/sasl/FrameHeaderReader.java
@@ -0,0 +1,64 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.thrift.transport.sasl;
+
+import org.apache.thrift.transport.TTransport;
+import org.apache.thrift.transport.TTransportException;
+
+/**
+ * Reads the header of a frame. For each frame, the header carries the payload size and possibly
+ * other metadata.
+ */
+public interface FrameHeaderReader {
+
+  /**
+   * (Nonblocking) Read the header bytes currently available from the underlying transport.
+   *
+   * @param transport underlying transport.
+   * @return true if the header is complete after this read.
+   * @throws TSaslNegotiationException if a valid header of a sasl negotiation message cannot be read.
+   * @throws TTransportException on io error.
+   */
+  boolean read(TTransport transport) throws TSaslNegotiationException, TTransportException;
+
+  /**
+   * @return true if all fields of this header have been read.
+   */
+  boolean isComplete();
+
+  /**
+   * As the thrift sasl specification states, every sasl message (both for negotiating and for
+   * sending data) has a header indicating the size of the payload.
+   *
+   * @return size of the payload.
+   */
+  int payloadSize();
+
+  /**
+   * @return the received bytes of the header.
+   * @throws IllegalStateException if {@link #isComplete()} returns false.
+   */
+  byte[] toBytes();
+
+  /**
+   * Reset the header so that it can be used to read a new header.
+   */
+  void clear();
+}
diff --git a/lib/java/src/org/apache/thrift/transport/sasl/FrameReader.java b/lib/java/src/org/apache/thrift/transport/sasl/FrameReader.java
new file mode 100644
index 0000000..acb4b73
--- /dev/null
+++ b/lib/java/src/org/apache/thrift/transport/sasl/FrameReader.java
@@ -0,0 +1,162 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.thrift.transport.sasl;
+
+import org.apache.thrift.transport.TEOFException;
+import org.apache.thrift.transport.TTransport;
+import org.apache.thrift.transport.TTransportException;
+
+import java.nio.ByteBuffer;
+
+/**
+ * Read frames from a transport. Each frame has a header and a payload. The header indicates the
+ * size of the payload and other information about how to decode the payload.
+ * Implementations should subclass it by providing a header reader implementation.
+ *
+ * @param <T> Header type.
+ */
+public abstract class FrameReader<T extends FrameHeaderReader> {
+  private final T header;
+  // Allocated once the header is complete; null before that and after clear().
+  private ByteBuffer payload;
+
+  protected FrameReader(T header) {
+    this.header = header;
+  }
+
+  /**
+   * (Nonblocking) Read available bytes out of the transport without blocking to wait for incoming
+   * data. Once the frame is complete, further calls return true without reading; call
+   * {@link #clear()} to start reading the next frame.
+   *
+   * @param transport TTransport
+   * @return true if current frame is complete after read.
+   * @throws TSaslNegotiationException if fail to read back a valid sasl negotiation message.
+   * @throws TTransportException if io error.
+   */
+  public boolean read(TTransport transport) throws TSaslNegotiationException, TTransportException {
+    if (!header.isComplete()) {
+      if (!readHeader(transport)) {
+        return false;
+      }
+      payload = ByteBuffer.allocate(header.payloadSize());
+    }
+    if (header.payloadSize() == 0 || isComplete()) {
+      // Empty payload, or the payload has already been fully read: nothing left to read.
+      return true;
+    }
+    return readPayload(transport);
+  }
+
+  /**
+   * (Nonblocking) Try to read available header bytes from transport.
+   *
+   * @return true if header is complete after read.
+   * @throws TSaslNegotiationException if fail to read back a valid sasl negotiation header.
+   * @throws TTransportException if io error.
+   */
+  private boolean readHeader(TTransport transport) throws TSaslNegotiationException, TTransportException {
+    return header.read(transport);
+  }
+
+  /**
+   * (Nonblocking) Try to read available payload bytes from transport.
+   *
+   * @param transport underlying transport.
+   * @return true if payload is complete after read.
+   * @throws TTransportException if io error.
+   */
+  private boolean readPayload(TTransport transport) throws TTransportException {
+    readAvailable(transport, payload);
+    // The payload is complete once there is no room left in its buffer.
+    return !payload.hasRemaining();
+  }
+
+  /**
+   * @return header of the frame
+   */
+  public T getHeader() {
+    return header;
+  }
+
+  /**
+   * @return number of bytes of the header
+   * @throws IllegalStateException if the header is not complete (see FrameHeaderReader#toBytes)
+   */
+  public int getHeaderSize() {
+    return header.toBytes().length;
+  }
+
+  /**
+   * @return backing byte array of the payload. The payload buffer only exists once the header is
+   *     complete; before that (or after {@link #clear()}) this throws NullPointerException.
+   */
+  public byte[] getPayload() {
+    return payload.array();
+  }
+
+  /**
+   * @return size of the payload as announced by the header
+   */
+  public int getPayloadSize() {
+    return header.payloadSize();
+  }
+
+  /**
+   * @return true if the reader has fully read a frame
+   */
+  public boolean isComplete() {
+    return !(payload == null || payload.hasRemaining());
+  }
+
+  /**
+   * Reset the state of the reader so that it can be reused to read a new frame.
+   */
+  public void clear() {
+    header.clear();
+    payload = null;
+  }
+
+  /**
+   * Read immediately available bytes from the transport into the byte buffer.
+   *
+   * @param transport TTransport
+   * @param recipient ByteBuffer with remaining room; its position is advanced by the bytes read
+   * @return number of bytes read out of the transport
+   * @throws IllegalStateException if the recipient has no remaining room
+   * @throws TTransportException if io error, or TEOFException if the transport reports end of stream
+   */
+  static int readAvailable(TTransport transport, ByteBuffer recipient) throws TTransportException {
+    if (!recipient.hasRemaining()) {
+      throw new IllegalStateException("Trying to fill a full recipient with " + recipient.limit()
+          + " bytes");
+    }
+    int currentPosition = recipient.position();
+    byte[] bytes = recipient.array();
+    int offset = recipient.arrayOffset() + currentPosition;
+    int expectedLength = recipient.remaining();
+    int got = transport.read(bytes, offset, expectedLength);
+    if (got < 0) {
+      throw new TEOFException("Transport is closed, while trying to read " + expectedLength +
+          " bytes");
+    }
+    recipient.position(currentPosition + got);
+    return got;
+  }
+}
diff --git a/lib/java/src/org/apache/thrift/transport/sasl/FrameWriter.java b/lib/java/src/org/apache/thrift/transport/sasl/FrameWriter.java
new file mode 100644
index 0000000..5f48121
--- /dev/null
+++ b/lib/java/src/org/apache/thrift/transport/sasl/FrameWriter.java
@@ -0,0 +1,122 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.thrift.transport.sasl;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+
+import org.apache.thrift.transport.TNonblockingTransport;
+
+/**
+ * Write frame (header and payload) to transport in a nonblocking way.
+ */
+public abstract class FrameWriter {
+
+  // The frame being written; its remaining bytes are the ones not yet written. Null until a frame
+  // is provided and after clear().
+  protected ByteBuffer frameBytes;
+
+  /**
+   * Provide (maybe empty) header and payload to the frame. This can be called only when isComplete
+   * returns true (last frame has been written out).
+   *
+   * @param header Some extra header bytes (without the 4 bytes for payload length), which will be
+   *               the start of the frame. It can be null or empty, depending on the message format
+   * @param payload Payload as a byte array; null is treated as an empty payload
+   * @throws IllegalStateException if it is called when isComplete returns false
+   * @throws IllegalArgumentException if header or payload is invalid
+   */
+  public void withHeaderAndPayload(byte[] header, byte[] payload) {
+    if (payload == null) {
+      payload = new byte[0];
+    }
+    if (header == null) {
+      withOnlyPayload(payload);
+    } else {
+      withHeaderAndPayload(header, 0, header.length, payload, 0, payload.length);
+    }
+  }
+
+  /**
+   * Provide extra header and payload to the frame.
+   *
+   * @param header byte array containing the extra header
+   * @param headerOffset starting offset of the header portion
+   * @param headerLength length of the extra header
+   * @param payload byte array containing the payload
+   * @param payloadOffset starting offset of the payload portion
+   * @param payloadLength length of the payload
+   * @throws IllegalStateException if previous frame is not yet complete (isComplete returns false)
+   * @throws IllegalArgumentException if header or payload is invalid
+   */
+  public void withHeaderAndPayload(byte[] header, int headerOffset, int headerLength,
+                                   byte[] payload, int payloadOffset, int payloadLength) {
+    if (!isComplete()) {
+      throw new IllegalStateException("Previsous write is not yet complete, with " +
+          frameBytes.remaining() + " bytes left.");
+    }
+    frameBytes = buildFrame(header, headerOffset, headerLength, payload, payloadOffset, payloadLength);
+  }
+
+  /**
+   * Provide only payload to the frame. Throws UnsupportedOperationException if the frame expects
+   * a header.
+   *
+   * @param payload payload as a byte array
+   */
+  public void withOnlyPayload(byte[] payload) {
+    withOnlyPayload(payload, 0, payload.length);
+  }
+
+  /**
+   * Provide only payload to the frame. Throws UnsupportedOperationException if the frame expects
+   * a header.
+   *
+   * @param payload The underlying byte array as a recipient of the payload
+   * @param offset The offset in the byte array starting from where the payload is located
+   * @param length The length of the payload
+   */
+  public abstract void withOnlyPayload(byte[] payload, int offset, int length);
+
+  /**
+   * Build the complete frame (length prefix, extra header, payload) ready to be written.
+   */
+  protected abstract ByteBuffer buildFrame(byte[] header, int headerOffset, int headerLength,
+                                           byte[] payload, int payloadOffset, int payloadLength);
+
+  /**
+   * Nonblocking write of (part of) the pending frame to the underlying transport. Check
+   * {@link #isComplete()} afterwards to know whether the whole frame has been written out.
+   * Does nothing when no frame is pending.
+   *
+   * @param transport nonblocking transport to write to
+   * @throws IOException if the underlying write fails
+   */
+  public void write(TNonblockingTransport transport) throws IOException {
+    if (frameBytes == null) {
+      // No frame provided yet (or cleared): never hand a null buffer to the transport.
+      return;
+    }
+    transport.write(frameBytes);
+  }
+
+  /**
+   * @return true when no more data needs to be written out
+   */
+  public boolean isComplete() {
+    return frameBytes == null || !frameBytes.hasRemaining();
+  }
+
+  /**
+   * Release the byte buffer.
+   */
+  public void clear() {
+    frameBytes = null;
+  }
+}
diff --git a/lib/java/src/org/apache/thrift/transport/sasl/NegotiationStatus.java b/lib/java/src/org/apache/thrift/transport/sasl/NegotiationStatus.java
new file mode 100644
index 0000000..ad704a0
--- /dev/null
+++ b/lib/java/src/org/apache/thrift/transport/sasl/NegotiationStatus.java
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.thrift.transport.sasl;
+
+import java.util.HashMap;
+import java.util.Map;
+
+import static org.apache.thrift.transport.sasl.TSaslNegotiationException.ErrorType.PROTOCOL_ERROR;
+
+/**
+ * Status bytes used during the initial Thrift SASL handshake.
+ */
+public enum NegotiationStatus {
+ START((byte)0x01),
+ OK((byte)0x02),
+ BAD((byte)0x03),
+ ERROR((byte)0x04),
+ COMPLETE((byte)0x05);
+
+ private static final Map<Byte, NegotiationStatus> reverseMap = new HashMap<>();
+
+ static {
+ for (NegotiationStatus s : NegotiationStatus.values()) {
+ reverseMap.put(s.getValue(), s);
+ }
+ }
+
+ private final byte value;
+
+ NegotiationStatus(byte val) {
+ this.value = val;
+ }
+
+ public byte getValue() {
+ return value;
+ }
+
+ public static NegotiationStatus byValue(byte val) throws TSaslNegotiationException {
+ if (!reverseMap.containsKey(val)) {
+ throw new TSaslNegotiationException(PROTOCOL_ERROR, "Invalid status " + val);
+ }
+ return reverseMap.get(val);
+ }
+}
diff --git a/lib/java/src/org/apache/thrift/transport/sasl/NonblockingSaslHandler.java b/lib/java/src/org/apache/thrift/transport/sasl/NonblockingSaslHandler.java
new file mode 100644
index 0000000..4557146
--- /dev/null
+++ b/lib/java/src/org/apache/thrift/transport/sasl/NonblockingSaslHandler.java
@@ -0,0 +1,528 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.thrift.transport.sasl;
+
+import java.io.IOException;
+import java.nio.channels.SelectionKey;
+import java.nio.charset.StandardCharsets;
+
+import javax.security.sasl.SaslServer;
+
+import org.apache.thrift.TByteArrayOutputStream;
+import org.apache.thrift.TProcessor;
+import org.apache.thrift.protocol.TProtocol;
+import org.apache.thrift.protocol.TProtocolFactory;
+import org.apache.thrift.server.ServerContext;
+import org.apache.thrift.server.TServerEventHandler;
+import org.apache.thrift.transport.TMemoryTransport;
+import org.apache.thrift.transport.TNonblockingTransport;
+import org.apache.thrift.transport.TTransportException;
+import org.apache.thrift.transport.sasl.TSaslNegotiationException.ErrorType;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import static org.apache.thrift.transport.sasl.NegotiationStatus.COMPLETE;
+import static org.apache.thrift.transport.sasl.NegotiationStatus.OK;
+
+/**
+ * State machine managing one sasl connection in a nonblocking way.
+ */
+public class NonblockingSaslHandler {
+ private static final Logger LOGGER = LoggerFactory.getLogger(NonblockingSaslHandler.class);
+
+ // Selection interest values. handleRead/handleWrite pass INTEREST_READ/INTEREST_WRITE to
+ // handleOps, which compares them with Phase#selectionInterest. INTEREST_NONE is presumably used
+ // by phases that wait on no socket I/O — confirm against the Phase enum.
+ private static final int INTEREST_NONE = 0;
+ private static final int INTEREST_READ = SelectionKey.OP_READ;
+ private static final int INTEREST_WRITE = SelectionKey.OP_WRITE;
+
+ // Tracking the current running phase
+ private Phase currentPhase = Phase.INITIIALIIZING;
+ // Tracking the next phase on the next invocation of the state machine.
+ // It should be the same as current phase if current phase is not yet finished.
+ // Otherwise, if it is different from current phase, the state machine is in a transition state:
+ // current phase is done, and next phase is not yet started (see stepToNextPhase).
+ private Phase nextPhase = currentPhase;
+
+ // Underlying nonblocking transport, and the selection key used to switch read/write interest
+ private SelectionKey selectionKey;
+ private TNonblockingTransport underlyingTransport;
+
+ // APIs for intercepting event / customizing behaviors:
+ // Factories (decorating the base implementations) & EventHandler (intercepting; may be null)
+ private TSaslServerFactory saslServerFactory;
+ private TSaslProcessorFactory processorFactory;
+ private TProtocolFactory inputProtocolFactory;
+ private TProtocolFactory outputProtocolFactory;
+ private TServerEventHandler eventHandler;
+ private ServerContext serverContext;
+ // It turns out the event handler implementation in hive sometimes creates a null ServerContext.
+ // In order to know whether TServerEventHandler#createContext is called we use such a flag.
+ private boolean serverContextCreated = false;
+
+ // Wrapper around sasl server; created in handleInitializing from the client's mechanism, null before
+ private ServerSaslPeer saslPeer;
+
+ // Sasl negotiation io; both are set to null once the success message has been written
+ private SaslNegotiationFrameReader saslResponse;
+ private SaslNegotiationFrameWriter saslChallenge;
+ // IO for request from and response to the socket
+ private DataFrameReader requestReader;
+ private DataFrameWriter responseWriter;
+ // If sasl is negotiated for integrity/confidentiality protection; when true, requests are
+ // unwrapped and responses wrapped through saslPeer
+ private boolean dataProtected;
+
+  /**
+   * Create the state machine for one connection. It starts in the initializing phase, waiting for
+   * the client's START message naming the sasl mechanism.
+   *
+   * @param selectionKey selection key of the connection, used to switch read/write interest
+   * @param underlyingTransport nonblocking transport of the connection
+   * @param saslServerFactory provides the sasl peer for the mechanism requested by the client
+   * @param processorFactory provides the processor handling each request
+   * @param inputProtocolFactory protocol used to decode requests
+   * @param outputProtocolFactory protocol used to encode responses
+   * @param eventHandler optional server event handler (may be null)
+   */
+  public NonblockingSaslHandler(SelectionKey selectionKey, TNonblockingTransport underlyingTransport,
+                                TSaslServerFactory saslServerFactory, TSaslProcessorFactory processorFactory,
+                                TProtocolFactory inputProtocolFactory, TProtocolFactory outputProtocolFactory,
+                                TServerEventHandler eventHandler) {
+    // Frame io for the negotiation and for the data exchange that follows it.
+    this.saslResponse = new SaslNegotiationFrameReader();
+    this.saslChallenge = new SaslNegotiationFrameWriter();
+    this.requestReader = new DataFrameReader();
+    this.responseWriter = new DataFrameWriter();
+
+    // Connection.
+    this.selectionKey = selectionKey;
+    this.underlyingTransport = underlyingTransport;
+
+    // Customization hooks.
+    this.saslServerFactory = saslServerFactory;
+    this.processorFactory = processorFactory;
+    this.inputProtocolFactory = inputProtocolFactory;
+    this.outputProtocolFactory = outputProtocolFactory;
+    this.eventHandler = eventHandler;
+  }
+
+  /**
+   * @return the phase the state machine is in; it may already be done, see
+   *     {@link #isCurrentPhaseDone()}.
+   */
+  public Phase getCurrentPhase() {
+    return this.currentPhase;
+  }
+
+  /**
+   * @return the phase to run next. It equals the current phase while that phase is still in
+   *     progress, and differs from it once the current phase is done but
+   *     {@link #stepToNextPhase()} has not been called yet.
+   */
+  public Phase getNextPhase() {
+    return this.nextPhase;
+  }
+
+  /**
+   * @return the nonblocking transport (socket) this handler reads from and writes to.
+   */
+  public TNonblockingTransport getUnderlyingTransport() {
+    return this.underlyingTransport;
+  }
+
+  /**
+   * @return the SaslServer negotiating with this client. The sasl peer is created only once the
+   *     client's START message has been read, so calling this earlier throws NullPointerException.
+   */
+  public SaslServer getSaslServer() {
+    return this.saslPeer.getSaslServer();
+  }
+
+  /**
+   * @return true once the current phase has finished, i.e. a different next phase is scheduled.
+   */
+  public boolean isCurrentPhaseDone() {
+    return nextPhase != currentPhase;
+  }
+
+ /**
+ * Run state machine.
+ *
+ * @throws IllegalStateException if current state is already done.
+ */
+ public void runCurrentPhase() {
+ currentPhase.runStateMachine(this);
+ }
+
+  /**
+   * Entry point for a read-ready selection. Runs the current phase and then keeps advancing through
+   * the following phases as long as each one finishes and the next is also interested in reads.
+   *
+   * @throws IllegalStateException if the current phase is not interested in reads.
+   */
+  public void handleRead() {
+    handleOps(INTEREST_READ);
+  }
+
+  /**
+   * Entry point for a write-ready selection; the write counterpart of {@link #handleRead()}.
+   *
+   * @throws IllegalStateException if the current phase is not interested in writes.
+   */
+  public void handleWrite() {
+    handleOps(INTEREST_WRITE);
+  }
+
+  // Run back to back every phase interested in the given selection ops, stopping when a phase is
+  // not done yet or the next phase waits for different ops.
+  private void handleOps(int interestOps) {
+    if (currentPhase.selectionInterest != interestOps) {
+      throw new IllegalStateException("Current phase " + currentPhase + " but got interest " +
+          interestOps);
+    }
+    runCurrentPhase();
+    while (isCurrentPhaseDone() && nextPhase.selectionInterest == interestOps) {
+      stepToNextPhase();
+      runCurrentPhase();
+    }
+  }
+
+ /**
+ * When current phase is finished, it's expected to call this method first before running the
+ * state machine again.
+ * By calling this, "next phase" is marked as started (and not done), thus is ready to run.
+ *
+ * @throws IllegalArgumentException if current phase is not yet done.
+ */
+ public void stepToNextPhase() {
+ if (!isCurrentPhaseDone()) {
+ throw new IllegalArgumentException("Not yet done with current phase: " + currentPhase);
+ }
+ LOGGER.debug("Switch phase {} to {}", currentPhase, nextPhase);
+ switch (nextPhase) {
+ case INITIIALIIZING:
+ throw new IllegalStateException("INITIALIZING cannot be the next phase of " + currentPhase);
+ default:
+ }
+ // If next phase's interest is not the same as current, nor the same as the selection key,
+ // we need to change interest on the selector.
+ if (!(nextPhase.selectionInterest == currentPhase.selectionInterest ||
+ nextPhase.selectionInterest == selectionKey.interestOps())) {
+ changeSelectionInterest(nextPhase.selectionInterest);
+ }
+ currentPhase = nextPhase;
+ }
+
+  // Replace the interest set of the connection's selection key.
+  private void changeSelectionInterest(int selectionInterest) {
+    this.selectionKey.interestOps(selectionInterest);
+  }
+
+ // sasl negotiaion failure handling
+ private void failSaslNegotiation(TSaslNegotiationException e) {
+ LOGGER.error("Sasl negotiation failed", e);
+ String errorMsg = e.getDetails();
+ saslChallenge.withHeaderAndPayload(new byte[]{e.getErrorType().code.getValue()},
+ errorMsg.getBytes(StandardCharsets.UTF_8));
+ nextPhase = Phase.WRITING_FAILURE_MESSAGE;
+ }
+
+ private void fail(Exception e) {
+ LOGGER.error("Failed io in " + currentPhase, e);
+ nextPhase = Phase.CLOSING;
+ }
+
+ private void failIO(TTransportException e) {
+ StringBuilder errorMsg = new StringBuilder("IO failure ")
+ .append(e.getType())
+ .append(" in ")
+ .append(currentPhase);
+ if (e.getMessage() != null) {
+ errorMsg.append(": ").append(e.getMessage());
+ }
+ LOGGER.error(errorMsg.toString(), e);
+ nextPhase = Phase.CLOSING;
+ }
+
+ // Read handlings
+
+  // Initializing: read the client's START message, whose payload names the sasl mechanism, and
+  // create the matching sasl peer before waiting for the first sasl response.
+  private void handleInitializing() {
+    try {
+      saslResponse.read(underlyingTransport);
+      if (!saslResponse.isComplete()) {
+        return; // wait for more bytes
+      }
+      SaslNegotiationHeaderReader header = saslResponse.getHeader();
+      if (header.getStatus() != NegotiationStatus.START) {
+        throw new TInvalidSaslFrameException("Expecting START status but got " + header.getStatus());
+      }
+      String mechanism = new String(saslResponse.getPayload(), StandardCharsets.UTF_8);
+      saslPeer = saslServerFactory.getSaslPeer(mechanism);
+      saslResponse.clear();
+      nextPhase = Phase.READING_SASL_RESPONSE;
+    } catch (TSaslNegotiationException e) {
+      failSaslNegotiation(e);
+    } catch (TTransportException e) {
+      failIO(e);
+    }
+  }
+
+  // Reading sasl response: accumulate the client's next negotiation message, then evaluate it.
+  private void handleReadingSaslResponse() {
+    try {
+      saslResponse.read(underlyingTransport);
+    } catch (TSaslNegotiationException e) {
+      failSaslNegotiation(e);
+      return;
+    } catch (TTransportException e) {
+      failIO(e);
+      return;
+    }
+    if (saslResponse.isComplete()) {
+      nextPhase = Phase.EVALUATING_SASL_RESPONSE;
+    }
+  }
+
+  // Reading request: accumulate the next data frame, then process it.
+  private void handleReadingRequest() {
+    try {
+      requestReader.read(underlyingTransport);
+    } catch (TTransportException e) {
+      failIO(e);
+      return;
+    }
+    if (requestReader.isComplete()) {
+      nextPhase = Phase.PROCESSING;
+    }
+  }
+
+ // Computation executions
+
+  // Evaluating sasl response: feed the client's response to the sasl peer, then queue either the
+  // final COMPLETE message (authenticated) or the next OK challenge.
+  private void executeEvaluatingSaslResponse() {
+    SaslNegotiationHeaderReader header = saslResponse.getHeader();
+    if (header.getStatus() != OK && header.getStatus() != COMPLETE) {
+      String error = "Expect status OK or COMPLETE, but got " + header.getStatus();
+      failSaslNegotiation(new TSaslNegotiationException(ErrorType.PROTOCOL_ERROR, error));
+      return;
+    }
+    try {
+      byte[] response = saslResponse.getPayload();
+      saslResponse.clear();
+      byte[] newChallenge = saslPeer.evaluate(response);
+      boolean authenticated = saslPeer.isAuthenticated();
+      if (authenticated) {
+        dataProtected = saslPeer.isDataProtected();
+      }
+      NegotiationStatus status = authenticated ? COMPLETE : OK;
+      saslChallenge.withHeaderAndPayload(new byte[]{status.getValue()}, newChallenge);
+      nextPhase = authenticated ? Phase.WRITING_SUCCESS_MESSAGE : Phase.WRITING_SASL_CHALLENGE;
+    } catch (TSaslNegotiationException e) {
+      failSaslNegotiation(e);
+    }
+  }
+
+  // Processing: decode the (possibly sasl-wrapped) request, run the processor on it and queue the
+  // (possibly sasl-wrapped) response. An empty output means a oneway request, so the handler goes
+  // straight back to reading the next request.
+  private void executeProcessing() {
+    try {
+      byte[] received = requestReader.getPayload();
+      requestReader.clear();
+      byte[] plainRequest = dataProtected ? saslPeer.unwrap(received) : received;
+      TMemoryTransport memoryTransport = new TMemoryTransport(plainRequest);
+      TProtocol requestProtocol = inputProtocolFactory.getProtocol(memoryTransport);
+      TProtocol responseProtocol = outputProtocolFactory.getProtocol(memoryTransport);
+      notifyEventHandlerForRequest(requestProtocol, responseProtocol, memoryTransport);
+
+      processorFactory.getProcessor(this).process(requestProtocol, responseProtocol);
+
+      TByteArrayOutputStream rawOutput = memoryTransport.getOutput();
+      if (rawOutput.len() == 0) {
+        // This is a oneway request, no response to send back. Waiting for next incoming request.
+        nextPhase = Phase.READING_REQUEST;
+        return;
+      }
+      byte[] responseBytes;
+      int responseLength;
+      if (dataProtected) {
+        responseBytes = saslPeer.wrap(rawOutput.get(), 0, rawOutput.len());
+        responseLength = responseBytes.length;
+      } else {
+        responseBytes = rawOutput.get();
+        responseLength = rawOutput.len();
+      }
+      responseWriter.withOnlyPayload(responseBytes, 0, responseLength);
+      nextPhase = Phase.WRITING_RESPONSE;
+    } catch (TTransportException e) {
+      failIO(e);
+    } catch (Exception e) {
+      fail(e);
+    }
+  }
+
+  // Create the server context on the first request only, then notify the event handler (if any)
+  // about the current request.
+  private void notifyEventHandlerForRequest(TProtocol requestProtocol, TProtocol responseProtocol,
+                                            TMemoryTransport memoryTransport) {
+    if (eventHandler == null) {
+      return;
+    }
+    if (!serverContextCreated) {
+      serverContext = eventHandler.createContext(requestProtocol, responseProtocol);
+      serverContextCreated = true;
+    }
+    eventHandler.processContext(serverContext, memoryTransport, memoryTransport);
+  }
+
+ // Write handlers
+
+ // Flush the pending challenge; once fully sent, reset it and wait for the peer's response.
+ private void handleWritingSaslChallenge() {
+ try {
+ saslChallenge.write(underlyingTransport);
+ if (!saslChallenge.isComplete()) return;
+ saslChallenge.clear();
+ nextPhase = Phase.READING_SASL_RESPONSE;
+ } catch (IOException writeError) {
+ fail(writeError);
+ }
+ }
+
+ // Send the success (COMPLETE) message; once fully written, null out the negotiation frames and serve requests.
+ private void handleWritingSuccessMessage() {
+ try {
+ saslChallenge.write(underlyingTransport);
+ if (!saslChallenge.isComplete()) return;
+ LOGGER.debug("Authentication is done.");
+ saslChallenge = null;
+ saslResponse = null;
+ nextPhase = Phase.READING_REQUEST;
+ } catch (IOException writeError) {
+ fail(writeError);
+ }
+ }
+
+ // Push the failure message out; once it is fully written, move on to CLOSING.
+ private void handleWritingFailureMessage() {
+ try {
+ saslChallenge.write(underlyingTransport);
+ boolean sent = saslChallenge.isComplete();
+ if (sent) nextPhase = Phase.CLOSING;
+ } catch (IOException writeError) {
+ fail(writeError);
+ }
+ }
+
+ // Write the staged response; when it is fully out, reset the writer and read the next request.
+ private void handleWritingResponse() {
+ try {
+ responseWriter.write(underlyingTransport);
+ if (!responseWriter.isComplete()) return;
+ responseWriter.clear();
+ nextPhase = Phase.READING_REQUEST;
+ } catch (IOException writeError) {
+ fail(writeError);
+ }
+ }
+
+ /**
+ * Release the connection, selection key, sasl peer and server context, then mark CLOSED (repeat
+ * calls are no-ops). Invoke from the network thread that manages the selector to avoid blocking.
+ */
+ public void close() {
+ if (currentPhase == Phase.CLOSED) return; // already released: avoid double dispose/deleteContext
+ try {
+ underlyingTransport.close();
+ selectionKey.cancel();
+ if (saslPeer != null) saslPeer.dispose();
+ if (serverContextCreated) {
+ eventHandler.deleteContext(serverContext, inputProtocolFactory.getProtocol(underlyingTransport),
+ outputProtocolFactory.getProtocol(underlyingTransport));
+ }
+ } finally {
+ nextPhase = currentPhase = Phase.CLOSED; // reach CLOSED even if a cleanup step throws
+ }
+ LOGGER.trace("Connection closed: {}", underlyingTransport);
+ }
+
+ /** Phases of the state machine; each carries its selection-key interest and the step it runs. */
+ public enum Phase {
+ INITIIALIIZING(INTEREST_READ) {
+ @Override
+ void unsafeRun(NonblockingSaslHandler statemachine) {
+ statemachine.handleInitializing();
+ }
+ },
+ READING_SASL_RESPONSE(INTEREST_READ) {
+ @Override
+ void unsafeRun(NonblockingSaslHandler statemachine) {
+ statemachine.handleReadingSaslResponse();
+ }
+ },
+ EVALUATING_SASL_RESPONSE(INTEREST_NONE) {
+ @Override
+ void unsafeRun(NonblockingSaslHandler statemachine) {
+ statemachine.executeEvaluatingSaslResponse();
+ }
+ },
+ WRITING_SASL_CHALLENGE(INTEREST_WRITE) {
+ @Override
+ void unsafeRun(NonblockingSaslHandler statemachine) {
+ statemachine.handleWritingSaslChallenge();
+ }
+ },
+ WRITING_SUCCESS_MESSAGE(INTEREST_WRITE) {
+ @Override
+ void unsafeRun(NonblockingSaslHandler statemachine) {
+ statemachine.handleWritingSuccessMessage();
+ }
+ },
+ WRITING_FAILURE_MESSAGE(INTEREST_WRITE) {
+ @Override
+ void unsafeRun(NonblockingSaslHandler statemachine) {
+ statemachine.handleWritingFailureMessage();
+ }
+ },
+ READING_REQUEST(INTEREST_READ) {
+ @Override
+ void unsafeRun(NonblockingSaslHandler statemachine) {
+ statemachine.handleReadingRequest();
+ }
+ },
+ PROCESSING(INTEREST_NONE) {
+ @Override
+ void unsafeRun(NonblockingSaslHandler statemachine) {
+ statemachine.executeProcessing();
+ }
+ },
+ WRITING_RESPONSE(INTEREST_WRITE) {
+ @Override
+ void unsafeRun(NonblockingSaslHandler statemachine) {
+ statemachine.handleWritingResponse();
+ }
+ },
+ CLOSING(INTEREST_NONE) {
+ @Override
+ void unsafeRun(NonblockingSaslHandler statemachine) {
+ statemachine.close();
+ }
+ },
+ CLOSED(INTEREST_NONE) {
+ @Override
+ void unsafeRun(NonblockingSaslHandler statemachine) {
+ // Do nothing.
+ }
+ };
+
+ // The interest on the selection key while the state machine is in this phase (fixed per phase).
+ private final int selectionInterest;
+
+ Phase(int selectionInterest) {
+ this.selectionInterest = selectionInterest;
+ }
+
+ /**
+ * Run this phase's step on the given state machine, after checking that the machine is in this
+ * phase and has not finished it yet. Steps advance the machine by assigning its next phase.
+ *
+ * @param statemachine The state machine to run.
+ * @throws IllegalArgumentException if the state machine's current phase is different.
+ * @throws IllegalStateException if the state machine's current phase is already done.
+ */
+ void runStateMachine(NonblockingSaslHandler statemachine) {
+ if (statemachine.currentPhase != this) {
+ throw new IllegalArgumentException("State machine is " + statemachine.currentPhase +
+ " but is expected to be " + this);
+ }
+ if (statemachine.isCurrentPhaseDone()) {
+ throw new IllegalStateException("State machine should step into " + statemachine.nextPhase);
+ }
+ unsafeRun(statemachine);
+ }
+
+ // Run this phase's step without checking the state machine's current phase.
+ // It should not be called directly by users; go through runStateMachine instead.
+ abstract void unsafeRun(NonblockingSaslHandler statemachine);
+ }
+}
diff --git a/lib/java/src/org/apache/thrift/transport/sasl/SaslNegotiationFrameReader.java b/lib/java/src/org/apache/thrift/transport/sasl/SaslNegotiationFrameReader.java
new file mode 100644
index 0000000..01c1728
--- /dev/null
+++ b/lib/java/src/org/apache/thrift/transport/sasl/SaslNegotiationFrameReader.java
@@ -0,0 +1,30 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.thrift.transport.sasl;
+
+/**
+ * FrameReader for sasl negotiation messages, whose frames start with a SaslNegotiationHeaderReader header.
+ */
+public class SaslNegotiationFrameReader extends FrameReader<SaslNegotiationHeaderReader> {
+ /** Each reader decodes its frames with its own, freshly created header reader. */
+ public SaslNegotiationFrameReader() {
+ super(new SaslNegotiationHeaderReader());
+ }
+}
diff --git a/lib/java/src/org/apache/thrift/transport/sasl/SaslNegotiationFrameWriter.java b/lib/java/src/org/apache/thrift/transport/sasl/SaslNegotiationFrameWriter.java
new file mode 100644
index 0000000..1e9ad15
--- /dev/null
+++ b/lib/java/src/org/apache/thrift/transport/sasl/SaslNegotiationFrameWriter.java
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.thrift.transport.sasl;
+
+import java.nio.ByteBuffer;
+
+import org.apache.thrift.EncodingUtils;
+import org.apache.thrift.utils.StringUtils;
+
+import static org.apache.thrift.transport.sasl.SaslNegotiationHeaderReader.PAYLOAD_LENGTH_BYTES;
+import static org.apache.thrift.transport.sasl.SaslNegotiationHeaderReader.STATUS_BYTES;
+
+/**
+ * Writes sasl negotiation frames: status byte, 4-byte big-endian length, payload. Other header sizes are rejected.
+ */
+public class SaslNegotiationFrameWriter extends FrameWriter {
+ public static final int HEADER_BYTES = STATUS_BYTES + PAYLOAD_LENGTH_BYTES; // status + length field
+
+ /** Unsupported: every sasl negotiation frame must carry a status byte header. */
+ @Override
+ public void withOnlyPayload(byte[] payload, int offset, int length) {
+ throw new UnsupportedOperationException("Status byte is expected for sasl frame header.");
+ }
+
+ @Override
+ protected ByteBuffer buildFrame(byte[] header, int headerOffset, int headerLength,
+ byte[] payload, int payloadOffset, int payloadLength) {
+ boolean validHeader = header != null && headerLength == STATUS_BYTES;
+ if (!validHeader) {
+ throw new IllegalArgumentException("Header " + StringUtils.bytesToHexString(header)
+ + " does not have expected length " + STATUS_BYTES);
+ }
+ byte[] frame = new byte[HEADER_BYTES + payloadLength];
+ System.arraycopy(header, headerOffset, frame, 0, STATUS_BYTES);
+ EncodingUtils.encodeBigEndian(payloadLength, frame, STATUS_BYTES);
+ System.arraycopy(payload, payloadOffset, frame, HEADER_BYTES, payloadLength);
+ return ByteBuffer.wrap(frame);
+ }
+}
diff --git a/lib/java/src/org/apache/thrift/transport/sasl/SaslNegotiationHeaderReader.java b/lib/java/src/org/apache/thrift/transport/sasl/SaslNegotiationHeaderReader.java
new file mode 100644
index 0000000..2d76ddb
--- /dev/null
+++ b/lib/java/src/org/apache/thrift/transport/sasl/SaslNegotiationHeaderReader.java
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.thrift.transport.sasl;
+
+import static org.apache.thrift.transport.sasl.TSaslNegotiationException.ErrorType.PROTOCOL_ERROR;
+
+/**
+ * Header for sasl negotiation frames. It contains status byte of negotiation and a 4-byte integer
+ * (payload size).
+ */
+public class SaslNegotiationHeaderReader extends FixedSizeHeaderReader {
+ public static final int STATUS_BYTES = 1;
+ public static final int PAYLOAD_LENGTH_BYTES = 4;
+
+ private NegotiationStatus negotiationStatus;
+ private int payloadSize;
+
+ @Override
+ protected int headerSize() {
+ return STATUS_BYTES + PAYLOAD_LENGTH_BYTES;
+ }
+
+ @Override
+ protected void onComplete() throws TSaslNegotiationException {
+ negotiationStatus = NegotiationStatus.byValue(byteBuffer.get(0));
+ payloadSize = byteBuffer.getInt(1);
+ if (payloadSize < 0) {
+ throw new TSaslNegotiationException(PROTOCOL_ERROR, "Payload size is negative: " + payloadSize);
+ }
+ }
+
+ @Override
+ public int payloadSize() {
+ return payloadSize;
+ }
+
+ public NegotiationStatus getStatus() {
+ return negotiationStatus;
+ }
+}
diff --git a/lib/java/src/org/apache/thrift/transport/sasl/SaslPeer.java b/lib/java/src/org/apache/thrift/transport/sasl/SaslPeer.java
new file mode 100644
index 0000000..8f81380
--- /dev/null
+++ b/lib/java/src/org/apache/thrift/transport/sasl/SaslPeer.java
@@ -0,0 +1,100 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.thrift.transport.sasl;
+
+import org.apache.thrift.transport.TTransportException;
+
+/**
+ * One side of a sasl negotiation.
+ */
+public interface SaslPeer {
+
+ /**
+ * Process a negotiation message (response or challenge) that arrived from the other side.
+ *
+ * @param negotiationMessage the response/challenge bytes received from the peer.
+ * @return the next response/challenge to send back; may be null once authentication succeeds.
+ * @throws TSaslNegotiationException when sasl authentication fails.
+ */
+ byte[] evaluate(byte[] negotiationMessage) throws TSaslNegotiationException;
+
+ /**
+ * @return whether the authentication exchange has finished.
+ */
+ boolean isAuthenticated();
+
+ /**
+ * Only valid after negotiation has completed (isAuthenticated() returns true); calling it earlier
+ * results in an IllegalStateException.
+ *
+ * @return true if the negotiated qop asks for integrity and/or confidentiality protection.
+ * @throws IllegalStateException if negotiation has not completed yet.
+ */
+ boolean isDataProtected();
+
+ /**
+ * Protect a slice of raw bytes before sending it to the peer.
+ *
+ * @param data buffer holding the raw bytes.
+ * @param offset index of the first byte to protect.
+ * @param length number of bytes to protect.
+ * @return protected bytes, ready to be sent to the peer.
+ * @throws TTransportException on failure.
+ */
+ byte[] wrap(byte[] data, int offset, int length) throws TTransportException;
+
+ /**
+ * Convenience overload that protects the whole array.
+ *
+ * Equivalent to {@code wrap(data, 0, data.length)}.
+ *
+ * @param data raw bytes.
+ * @return protected bytes.
+ * @throws TTransportException on failure.
+ */
+ default byte[] wrap(byte[] data) throws TTransportException { return wrap(data, 0, data.length); }
+
+ /**
+ * Recover raw bytes from a slice of protected data received from the peer.
+ *
+ * @param data buffer holding the protected bytes.
+ * @param offset index of the first byte to unprotect.
+ * @param length number of bytes to unprotect.
+ * @return the raw bytes.
+ * @throws TTransportException on failure.
+ */
+ byte[] unwrap(byte[] data, int offset, int length) throws TTransportException;
+
+ /**
+ * Convenience overload that unprotects the whole array.
+ *
+ * Equivalent to {@code unwrap(data, 0, data.length)}.
+ *
+ * @param data protected bytes.
+ * @return the raw bytes.
+ * @throws TTransportException on failure.
+ */
+ default byte[] unwrap(byte[] data) throws TTransportException { return unwrap(data, 0, data.length); }
+
+ /**
+ * Close this peer and free whatever resources it holds.
+ */
+ void dispose();
+}
diff --git a/lib/java/src/org/apache/thrift/transport/sasl/ServerSaslPeer.java b/lib/java/src/org/apache/thrift/transport/sasl/ServerSaslPeer.java
new file mode 100644
index 0000000..31992e5
--- /dev/null
+++ b/lib/java/src/org/apache/thrift/transport/sasl/ServerSaslPeer.java
@@ -0,0 +1,108 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.thrift.transport.sasl;
+
+import javax.security.sasl.Sasl;
+import javax.security.sasl.SaslException;
+import javax.security.sasl.SaslServer;
+
+import org.apache.thrift.transport.TTransportException;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import static org.apache.thrift.transport.sasl.TSaslNegotiationException.ErrorType.AUTHENTICATION_FAILURE;
+
+/**
+ * Server-side SaslPeer that delegates to a SaslServer and maps SaslException to thrift exceptions.
+ */
+public class ServerSaslPeer implements SaslPeer {
+ private static final Logger LOGGER = LoggerFactory.getLogger(ServerSaslPeer.class);
+
+ private static final String QOP_AUTH_INT = "auth-int";
+ private static final String QOP_AUTH_CONF = "auth-conf";
+
+ private final SaslServer saslServer;
+
+ /** @param saslServer the SaslServer that every method of this peer delegates to. */
+ public ServerSaslPeer(SaslServer saslServer) {
+ this.saslServer = saslServer;
+ }
+
+ @Override
+ public byte[] evaluate(byte[] negotiationMessage) throws TSaslNegotiationException {
+ try {
+ return saslServer.evaluateResponse(negotiationMessage);
+ } catch (SaslException e) {
+ throw new TSaslNegotiationException(AUTHENTICATION_FAILURE,
+ "Authentication failed with " + saslServer.getMechanismName(), e);
+ }
+ }
+
+ @Override
+ public boolean isAuthenticated() {
+ return saslServer.isComplete();
+ }
+
+ /** True if any comma-separated token of the negotiated QOP is auth-int or auth-conf. */
+ @Override
+ public boolean isDataProtected() {
+ Object qop = saslServer.getNegotiatedProperty(Sasl.QOP);
+ if (qop == null) return false;
+ // trim(): outer blanks would stick to a token; equalsIgnoreCase: toLowerCase() is locale-sensitive.
+ for (String word : qop.toString().trim().split("\\s*,\\s*")) {
+ if (QOP_AUTH_INT.equalsIgnoreCase(word) || QOP_AUTH_CONF.equalsIgnoreCase(word)) {
+ return true;
+ }
+ }
+ return false;
+ }
+
+ @Override
+ public byte[] wrap(byte[] data, int offset, int length) throws TTransportException {
+ try {
+ return saslServer.wrap(data, offset, length);
+ } catch (SaslException e) {
+ throw new TTransportException("Failed to wrap data", e);
+ }
+ }
+
+ @Override
+ public byte[] unwrap(byte[] data, int offset, int length) throws TTransportException {
+ try {
+ return saslServer.unwrap(data, offset, length);
+ } catch (SaslException e) {
+ throw new TTransportException(TTransportException.CORRUPTED_DATA, "Failed to unwrap data", e);
+ }
+ }
+
+ @Override
+ public void dispose() {
+ try {
+ saslServer.dispose();
+ } catch (Exception e) {
+ LOGGER.warn("Failed to close sasl server {}", saslServer.getMechanismName(), e);
+ }
+ }
+
+ // Package-private access to the wrapped SaslServer.
+ SaslServer getSaslServer() {
+ return saslServer;
+ }
+}
diff --git a/lib/java/src/org/apache/thrift/transport/sasl/TBaseSaslProcessorFactory.java b/lib/java/src/org/apache/thrift/transport/sasl/TBaseSaslProcessorFactory.java
new file mode 100644
index 0000000..c08884c
--- /dev/null
+++ b/lib/java/src/org/apache/thrift/transport/sasl/TBaseSaslProcessorFactory.java
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.thrift.transport.sasl;
+
+import org.apache.thrift.TProcessor;
+
+/** TSaslProcessorFactory that returns one fixed TProcessor no matter which handler asks. */
+public class TBaseSaslProcessorFactory implements TSaslProcessorFactory {
+ private final TProcessor sharedProcessor; // returned as-is by every getProcessor call
+ /** @param processor the processor every getProcessor call will return. */
+ public TBaseSaslProcessorFactory(TProcessor processor) {
+ sharedProcessor = processor;
+ }
+
+ /** The handler argument is ignored. */
+ @Override
+ public TProcessor getProcessor(NonblockingSaslHandler saslHandler) {
+ return sharedProcessor;
+ }
+}
diff --git a/lib/java/src/org/apache/thrift/transport/sasl/TInvalidSaslFrameException.java b/lib/java/src/org/apache/thrift/transport/sasl/TInvalidSaslFrameException.java
new file mode 100644
index 0000000..ff57ea5
--- /dev/null
+++ b/lib/java/src/org/apache/thrift/transport/sasl/TInvalidSaslFrameException.java
@@ -0,0 +1,30 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.thrift.transport.sasl;
+
+/**
+ * Signals a frame that violates the thrift sasl framing; always reported as a PROTOCOL_ERROR.
+ */
+public class TInvalidSaslFrameException extends TSaslNegotiationException {
+ /** @param message description of what was wrong with the frame. */
+ public TInvalidSaslFrameException(String message) {
+ super(TSaslNegotiationException.ErrorType.PROTOCOL_ERROR, message);
+ }
+}
diff --git a/lib/java/src/org/apache/thrift/transport/sasl/TSaslNegotiationException.java b/lib/java/src/org/apache/thrift/transport/sasl/TSaslNegotiationException.java
new file mode 100644
index 0000000..9b1fa06
--- /dev/null
+++ b/lib/java/src/org/apache/thrift/transport/sasl/TSaslNegotiationException.java
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.thrift.transport.sasl;
+
+import org.apache.thrift.transport.TTransportException;
+
+/**
+ * Exception for sasl negotiation errors.
+ */
+public class TSaslNegotiationException extends TTransportException {
+
+ private final ErrorType error;
+
+ public TSaslNegotiationException(ErrorType error, String summary) {
+ super(summary);
+ this.error = error;
+ }
+
+ public TSaslNegotiationException(ErrorType error, String summary, Throwable cause) {
+ super(summary, cause);
+ this.error = error;
+ }
+
+ public ErrorType getErrorType() {
+ return error;
+ }
+
+ /**
+ * @return Errory type plus the message.
+ */
+ public String getSummary() {
+ return error.name() + ": " + getMessage();
+ }
+
+ /**
+ * @return Summary and eventually the cause's message.
+ */
+ public String getDetails() {
+ return getCause() == null ? getSummary() : getSummary() + "\nReason: " + getCause().getMessage();
+ }
+
+ public enum ErrorType {
+ // Unexpected system internal error during negotiation (e.g. sasl initialization failure)
+ INTERNAL_ERROR(NegotiationStatus.ERROR),
+ // Cannot read correct sasl frames from the connection => Send "ERROR" status byte to peer
+ PROTOCOL_ERROR(NegotiationStatus.ERROR),
+ // Peer is using unsupported sasl mechanisms => Send "BAD" status byte to peer
+ MECHANISME_MISMATCH(NegotiationStatus.BAD),
+ // Sasl authentication failure => Send "BAD" status byte to peer
+ AUTHENTICATION_FAILURE(NegotiationStatus.BAD),
+ ;
+
+ public final NegotiationStatus code;
+
+ ErrorType(NegotiationStatus code) {
+ this.code = code;
+ }
+ }
+}
diff --git a/lib/java/src/org/apache/thrift/transport/sasl/TSaslProcessorFactory.java b/lib/java/src/org/apache/thrift/transport/sasl/TSaslProcessorFactory.java
new file mode 100644
index 0000000..877d049
--- /dev/null
+++ b/lib/java/src/org/apache/thrift/transport/sasl/TSaslProcessorFactory.java
@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.thrift.transport.sasl;
+
+import org.apache.thrift.TException;
+import org.apache.thrift.TProcessor;
+
+/**
+ * Hands out the TProcessor for a given sasl state machine, so users can tailor processing per handler.
+ */
+public interface TSaslProcessorFactory {
+
+ /** Returns the processor that requests of {@code saslHandler} are run with; declared to throw TException. */
+ TProcessor getProcessor(NonblockingSaslHandler saslHandler) throws TException;
+}
diff --git a/lib/java/src/org/apache/thrift/transport/sasl/TSaslServerDefinition.java b/lib/java/src/org/apache/thrift/transport/sasl/TSaslServerDefinition.java
new file mode 100644
index 0000000..5486641
--- /dev/null
+++ b/lib/java/src/org/apache/thrift/transport/sasl/TSaslServerDefinition.java
@@ -0,0 +1,43 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.thrift.transport.sasl;
+
+import javax.security.auth.callback.CallbackHandler;
+import java.util.Map;
+
+/**
+ * Value object bundling the arguments needed to create a SaslServer for one mechanism.
+ */
+public class TSaslServerDefinition {
+ public final String mechanism; // sasl mechanism name
+ public final String protocol; // protocol argument for the sasl server
+ public final String serverName; // server name argument for the sasl server
+ public final Map<String, String> props; // sasl properties, kept by reference (not copied)
+ public final CallbackHandler cbh; // callback handler for the sasl server
+ /** Captures the given arguments verbatim; nothing is validated or copied. */
+ public TSaslServerDefinition(String mechanism, String protocol, String serverName,
+ Map<String, String> props, CallbackHandler cbh) {
+ this.cbh = cbh;
+ this.props = props;
+ this.serverName = serverName;
+ this.protocol = protocol;
+ this.mechanism = mechanism;
+ }
+}
diff --git a/lib/java/src/org/apache/thrift/transport/sasl/TSaslServerFactory.java b/lib/java/src/org/apache/thrift/transport/sasl/TSaslServerFactory.java
new file mode 100644
index 0000000..06cf534
--- /dev/null
+++ b/lib/java/src/org/apache/thrift/transport/sasl/TSaslServerFactory.java
@@ -0,0 +1,64 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.thrift.transport.sasl;
+
+import java.util.HashMap;
+import java.util.Map;
+
+import javax.security.auth.callback.CallbackHandler;
+import javax.security.sasl.Sasl;
+import javax.security.sasl.SaslException;
+import javax.security.sasl.SaslServer;
+
+import static org.apache.thrift.transport.sasl.TSaslNegotiationException.ErrorType.MECHANISME_MISMATCH;
+import static org.apache.thrift.transport.sasl.TSaslNegotiationException.ErrorType.PROTOCOL_ERROR;
+
+/**
+ * Factory to create sasl servers. Users can extend this class to customize the SaslServer creation.
+ */
+public class TSaslServerFactory {
+ // Registered server definitions, keyed by mechanism name.
+ private final Map<String, TSaslServerDefinition> saslMechanisms = new HashMap<>();
+
+ /** Register (or replace) the parameters used to create a SaslServer for {@code mechanism}. */
+ public void addSaslMechanism(String mechanism, String protocol, String serverName,
+ Map<String, String> props, CallbackHandler cbh) {
+ saslMechanisms.put(mechanism, new TSaslServerDefinition(mechanism, protocol, serverName, props, cbh));
+ }
+
+ /** Fails with MECHANISME_MISMATCH for unregistered mechanisms, INTERNAL_ERROR if creation fails. */
+ public ServerSaslPeer getSaslPeer(String mechanism) throws TSaslNegotiationException {
+ TSaslServerDefinition saslDef = saslMechanisms.get(mechanism);
+ if (saslDef == null) {
+ throw new TSaslNegotiationException(MECHANISME_MISMATCH, "Unsupported mechanism " + mechanism);
+ }
+ try {
+ SaslServer saslServer = Sasl.createSaslServer(saslDef.mechanism, saslDef.protocol,
+ saslDef.serverName, saslDef.props, saslDef.cbh);
+ if (saslServer == null) { // returned (instead of an exception) when no installed factory matches
+ throw new SaslException("No SaslServerFactory can create a server for " + mechanism);
+ }
+ return new ServerSaslPeer(saslServer);
+ } catch (SaslException e) {
+ throw new TSaslNegotiationException(TSaslNegotiationException.ErrorType.INTERNAL_ERROR,
+ "Fail to create sasl server " + mechanism, e);
+ }
+ }
+}
diff --git a/lib/java/src/org/apache/thrift/utils/StringUtils.java b/lib/java/src/org/apache/thrift/utils/StringUtils.java
new file mode 100644
index 0000000..15183a3
--- /dev/null
+++ b/lib/java/src/org/apache/thrift/utils/StringUtils.java
@@ -0,0 +1,78 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.thrift.utils;
+
+/**
+ * Static string-formatting helpers; not instantiable.
+ */
+public final class StringUtils {
+
+  private StringUtils() {
+    // Utility class.
+  }
+
+  // Upper-case hex digits indexed by nibble value (0-15).
+  private static final char[] HEX_CHARS = {'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', 'D', 'E', 'F'};
+
+  /**
+   * Stringify a byte array to the hex representation for each byte.
+   *
+   * @param bytes byte array, may be null.
+   * @return upper-case hex string with two characters per byte, or null if {@code bytes} is null.
+   */
+  public static String bytesToHexString(byte[] bytes) {
+    if (bytes == null) {
+      return null;
+    }
+    return bytesToHexString(bytes, 0, bytes.length);
+  }
+
+  /**
+   * Stringify a portion of the byte array.
+   *
+   * @param bytes byte array, must not be null.
+   * @param offset portion start.
+   * @param length portion length.
+   * @return upper-case hex string with two characters per byte of the portion.
+   * @throws IllegalArgumentException if {@code length} is negative.
+   * @throws IndexOutOfBoundsException if {@code offset} is negative or the portion runs past the end
+   *     of {@code bytes}.
+   */
+  public static String bytesToHexString(byte[] bytes, int offset, int length) {
+    if (length < 0) {
+      throw new IllegalArgumentException("Negative length " + length);
+    }
+    if (offset < 0) {
+      throw new IndexOutOfBoundsException("Negative start offset " + offset);
+    }
+    // Compare via subtraction so that offset + length cannot overflow int.
+    if (offset > bytes.length - length) {
+      throw new IndexOutOfBoundsException("Portion at offset " + offset + " with length " + length
+          + " exceeds array length " + bytes.length);
+    }
+    char[] chars = new char[length * 2];
+    for (int i = 0; i < length; i++) {
+      int unsignedInt = bytes[i + offset] & 0xFF;
+      chars[2 * i] = HEX_CHARS[unsignedInt >>> 4];
+      chars[2 * i + 1] = HEX_CHARS[unsignedInt & 0x0F];
+    }
+    return new String(chars);
+  }
+}
diff --git a/lib/java/test/org/apache/thrift/server/TestSaslNonblockingServer.java b/lib/java/test/org/apache/thrift/server/TestSaslNonblockingServer.java
new file mode 100644
index 0000000..d0a6746
--- /dev/null
+++ b/lib/java/test/org/apache/thrift/server/TestSaslNonblockingServer.java
@@ -0,0 +1,96 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.thrift.server;
+
+import org.apache.thrift.TProcessor;
+import org.apache.thrift.protocol.TBinaryProtocol;
+import org.apache.thrift.protocol.TProtocolFactory;
+import org.apache.thrift.transport.TNonblockingServerSocket;
+import org.apache.thrift.transport.TNonblockingServerTransport;
+import org.apache.thrift.transport.TSaslClientTransport;
+import org.apache.thrift.transport.TSocket;
+import org.apache.thrift.transport.TTransportException;
+import org.apache.thrift.transport.TTransportFactory;
+import org.apache.thrift.transport.TestTSaslTransports;
+import org.apache.thrift.transport.TestTSaslTransports.TestSaslCallbackHandler;
+import org.apache.thrift.transport.sasl.TSaslNegotiationException;
+import thrift.test.ThriftTest;
+
+import static org.apache.thrift.transport.sasl.TSaslNegotiationException.ErrorType.AUTHENTICATION_FAILURE;
+
+public class TestSaslNonblockingServer extends TestTSaslTransports.TestTSaslTransportsWithServer {
+
+ private TSaslNonblockingServer server;
+
+ @Override
+ public void startServer(TProcessor processor, TProtocolFactory protoFactory, TTransportFactory factory)
+ throws Exception {
+ TNonblockingServerTransport serverSocket = new TNonblockingServerSocket(
+ new TNonblockingServerSocket.NonblockingAbstractServerSocketArgs().port(PORT));
+ TSaslNonblockingServer.Args args = new TSaslNonblockingServer.Args(serverSocket)
+ .processor(processor)
+ .transportFactory(factory)
+ .protocolFactory(protoFactory)
+ .addSaslMechanism(TestTSaslTransports.WRAPPED_MECHANISM, TestTSaslTransports.SERVICE,
+ TestTSaslTransports.HOST, TestTSaslTransports.WRAPPED_PROPS,
+ new TestSaslCallbackHandler(TestTSaslTransports.PASSWORD));
+ server = new TSaslNonblockingServer(args);
+ server.serve();
+ }
+
+ @Override
+ public void stopServer() throws Exception {
+ server.shutdown();
+ }
+
+ @Override
+ public void testIt() throws Exception {
+ super.testIt();
+ }
+
+ public void testBadPassword() throws Exception {
+ TProtocolFactory protocolFactory = new TBinaryProtocol.Factory();
+ TProcessor processor = new ThriftTest.Processor<>(new TestHandler());
+ startServer(processor, protocolFactory);
+
+ TSocket socket = new TSocket(HOST, PORT);
+ socket.setTimeout(SOCKET_TIMEOUT);
+ TSaslClientTransport client = new TSaslClientTransport(TestTSaslTransports.WRAPPED_MECHANISM,
+ TestTSaslTransports.PRINCIPAL, TestTSaslTransports.SERVICE, TestTSaslTransports.HOST,
+ TestTSaslTransports.WRAPPED_PROPS, new TestSaslCallbackHandler("bad_password"), socket);
+ try {
+ client.open();
+ fail("Client should fail with sasl negotiation.");
+ } catch (TTransportException error) {
+ TSaslNegotiationException serverSideError = new TSaslNegotiationException(AUTHENTICATION_FAILURE,
+ "Authentication failed with " + TestTSaslTransports.WRAPPED_MECHANISM);
+ assertTrue("Server should return error message \"" + serverSideError.getSummary() + "\"",
+ error.getMessage().contains(serverSideError.getSummary()));
+ } finally {
+ stopServer();
+ client.close();
+ }
+ }
+
+ @Override
+ public void testTransportFactory() {
+ // This test is irrelevant here, so skipped.
+ }
+}
diff --git a/lib/java/test/org/apache/thrift/transport/TestNonblockingServerSocket.java b/lib/java/test/org/apache/thrift/transport/TestNonblockingServerSocket.java
new file mode 100644
index 0000000..6b28dfd
--- /dev/null
+++ b/lib/java/test/org/apache/thrift/transport/TestNonblockingServerSocket.java
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.thrift.transport;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+import java.nio.channels.ServerSocketChannel;
+
+public class TestNonblockingServerSocket {
+
+ @Test
+ public void testSocketChannelBlockingMode() throws TTransportException {
+ try (TNonblockingServerSocket nonblockingServer = new TNonblockingServerSocket(0)){
+ ServerSocketChannel socketChannel = nonblockingServer.getServerSocketChannel();
+ Assert.assertFalse("Socket channel should be nonblocking", socketChannel.isBlocking());
+ }
+ }
+}
diff --git a/lib/java/test/org/apache/thrift/transport/TestTIOStreamTransport.java b/lib/java/test/org/apache/thrift/transport/TestTIOStreamTransport.java
new file mode 100644
index 0000000..5965446
--- /dev/null
+++ b/lib/java/test/org/apache/thrift/transport/TestTIOStreamTransport.java
@@ -0,0 +1,62 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.thrift.transport;
+
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+import java.io.InputStream;
+import java.io.OutputStream;
+
+import junit.framework.TestCase;
+
+/**
+ * Regression tests for THRIFT-5022: TIOStreamTransport.isOpen returned true for one-sided transports.
+ * Each test opens and closes a transport built from a different combination of streams.
+ */
+public class TestTIOStreamTransport extends TestCase {
+
+  // THRIFT-5022
+  public void testOpenClose_2streams() throws TTransportException {
+    byte[] dummy = {20}; // So the input stream isn't EOF immediately.
+    InputStream input = new ByteArrayInputStream(dummy);
+    OutputStream output = new ByteArrayOutputStream();
+    runOpenClose(new TIOStreamTransport(input, output));
+  }
+
+  // THRIFT-5022
+  public void testOpenClose_1input() throws TTransportException {
+    byte[] dummy = {20}; // So the input stream isn't EOF immediately.
+    InputStream input = new ByteArrayInputStream(dummy);
+    runOpenClose(new TIOStreamTransport(input));
+  }
+
+  // THRIFT-5022
+  public void testOpenClose_1output() throws TTransportException {
+    OutputStream output = new ByteArrayOutputStream();
+    runOpenClose(new TIOStreamTransport(output));
+  }
+
+  /** Opens the transport, expects isOpen() == true, closes it, then expects isOpen() == false. */
+  private void runOpenClose(TTransport transport) throws TTransportException {
+    transport.open();
+    assertTrue("Transport should be open after open()", transport.isOpen());
+    transport.close();
+    assertFalse("Transport should not be open after close()", transport.isOpen());
+  }
+}
diff --git a/lib/java/test/org/apache/thrift/transport/TestTMemoryTransport.java b/lib/java/test/org/apache/thrift/transport/TestTMemoryTransport.java
new file mode 100644
index 0000000..2e20ffe
--- /dev/null
+++ b/lib/java/test/org/apache/thrift/transport/TestTMemoryTransport.java
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.thrift.transport;
+
+import org.apache.thrift.TByteArrayOutputStream;
+import org.junit.Assert;
+import org.junit.Test;
+
+import java.nio.ByteBuffer;
+
+/**
+ * Tests for {@link TMemoryTransport} reads and writes.
+ */
+public class TestTMemoryTransport {
+
+  /** Reading the input in two consecutive batches must return every byte in order. */
+  @Test
+  public void testReadBatches() throws TTransportException {
+    byte[] inputBytes = {0x10, 0x7A, (byte) 0xBF, (byte) 0xFE, 0x53, (byte) 0x82, (byte) 0xFF};
+    // Try every split point in [0, length) instead of a random one, so a failure is reproducible.
+    for (int firstBatch = 0; firstBatch < inputBytes.length; firstBatch++) {
+      TMemoryTransport transport = new TMemoryTransport(inputBytes);
+      byte[] read = new byte[inputBytes.length];
+      int secondBatch = inputBytes.length - firstBatch;
+      transport.read(read, 0, firstBatch);
+      transport.read(read, firstBatch, secondBatch);
+      Assert.assertEquals("Split at " + firstBatch, ByteBuffer.wrap(inputBytes), ByteBuffer.wrap(read));
+    }
+  }
+
+  /** Asking for more bytes than remain must throw TTransportException. */
+  @Test (expected = TTransportException.class)
+  public void testReadMoreThanRemaining() throws TTransportException {
+    TMemoryTransport transport = new TMemoryTransport(new byte[] {0x00, 0x32});
+    byte[] read = new byte[3];
+    transport.read(read, 0, 3);
+  }
+
+  /** A whole-array write followed by a slice write accumulate, in order, in the transport's output. */
+  @Test
+  public void testWrite() throws TTransportException {
+    TMemoryTransport transport = new TMemoryTransport(new byte[0]);
+    byte[] output1 = {0x72, 0x56, 0x29, (byte) 0xAF, (byte) 0x9B};
+    transport.write(output1);
+    byte[] output2 = {(byte) 0x83, 0x10, 0x00};
+    transport.write(output2, 0, 2);
+    byte[] expected = {0x72, 0x56, 0x29, (byte) 0xAF, (byte) 0x9B, (byte) 0x83, 0x10};
+    TByteArrayOutputStream outputByteArray = transport.getOutput();
+    // Compare only the first len() bytes of the array returned by get().
+    Assert.assertEquals(ByteBuffer.wrap(expected), ByteBuffer.wrap(outputByteArray.get(), 0, outputByteArray.len()));
+  }
+}
diff --git a/lib/java/test/org/apache/thrift/transport/TestTSaslTransports.java b/lib/java/test/org/apache/thrift/transport/TestTSaslTransports.java
index 36a06e9..6eb38e7 100644
--- a/lib/java/test/org/apache/thrift/transport/TestTSaslTransports.java
+++ b/lib/java/test/org/apache/thrift/transport/TestTSaslTransports.java
@@ -53,17 +53,17 @@
private static final Logger LOGGER = LoggerFactory.getLogger(TestTSaslTransports.class);
- private static final String HOST = "localhost";
- private static final String SERVICE = "thrift-test";
- private static final String PRINCIPAL = "thrift-test-principal";
- private static final String PASSWORD = "super secret password";
- private static final String REALM = "thrift-test-realm";
+ public static final String HOST = "localhost";
+ public static final String SERVICE = "thrift-test";
+ public static final String PRINCIPAL = "thrift-test-principal";
+ public static final String PASSWORD = "super secret password";
+ public static final String REALM = "thrift-test-realm";
- private static final String UNWRAPPED_MECHANISM = "CRAM-MD5";
- private static final Map<String, String> UNWRAPPED_PROPS = null;
+ public static final String UNWRAPPED_MECHANISM = "CRAM-MD5";
+ public static final Map<String, String> UNWRAPPED_PROPS = null;
- private static final String WRAPPED_MECHANISM = "DIGEST-MD5";
- private static final Map<String, String> WRAPPED_PROPS = new HashMap<String, String>();
+ public static final String WRAPPED_MECHANISM = "DIGEST-MD5";
+ public static final Map<String, String> WRAPPED_PROPS = new HashMap<String, String>();
static {
WRAPPED_PROPS.put(Sasl.QOP, "auth-int");
@@ -80,7 +80,7 @@
+ "'We hold these truths to be self-evident, that all men are created equal.'";
- private static class TestSaslCallbackHandler implements CallbackHandler {
+ public static class TestSaslCallbackHandler implements CallbackHandler {
private final String password;
public TestSaslCallbackHandler(String password) {
@@ -265,7 +265,7 @@
new TestTSaslTransportsWithServer().testIt();
}
- private static class TestTSaslTransportsWithServer extends ServerTestBase {
+ public static class TestTSaslTransportsWithServer extends ServerTestBase {
private Thread serverThread;
private TServer server;
diff --git a/lib/java/test/org/apache/thrift/transport/sasl/TestDataFrameReader.java b/lib/java/test/org/apache/thrift/transport/sasl/TestDataFrameReader.java
new file mode 100644
index 0000000..9ae0e1e
--- /dev/null
+++ b/lib/java/test/org/apache/thrift/transport/sasl/TestDataFrameReader.java
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.thrift.transport.sasl;
+
+import org.apache.thrift.transport.TMemoryInputTransport;
+import org.apache.thrift.transport.TTransportException;
+import org.junit.Assert;
+import org.junit.Test;
+
+import java.nio.ByteBuffer;
+
+/**
+ * Tests for {@link DataFrameReader}.
+ */
+public class TestDataFrameReader {
+
+  /** Delivers one data frame to the reader in pieces and checks header/payload completion after each. */
+  @Test
+  public void testRead() throws TTransportException {
+    final int payloadSize = 23;
+    // Frame: payload length (written with putInt) followed by the bytes 0, 1, ..., payloadSize - 1.
+    ByteBuffer frame = ByteBuffer.allocate(DataFrameHeaderReader.PAYLOAD_LENGTH_BYTES + payloadSize);
+    frame.putInt(payloadSize);
+    for (byte value = 0; value < payloadSize; value++) {
+      frame.put(value);
+    }
+    frame.rewind();
+    byte[] frameBytes = frame.array();
+    // First delivery: the complete length header plus the first part of the payload.
+    final int firstChunk = 6;
+
+    TMemoryInputTransport transport = new TMemoryInputTransport();
+    DataFrameReader reader = new DataFrameReader();
+
+    // Nothing has arrived yet.
+    reader.read(transport);
+    Assert.assertFalse("No bytes received", reader.isComplete());
+    Assert.assertFalse("No bytes received", reader.getHeader().isComplete());
+
+    transport.reset(frameBytes, 0, firstChunk);
+    reader.read(transport);
+    Assert.assertFalse("Only header is complete", reader.isComplete());
+    Assert.assertTrue("Header should be complete", reader.getHeader().isComplete());
+    Assert.assertEquals("Payload size should be " + payloadSize, payloadSize, reader.getHeader().payloadSize());
+
+    // The rest of the frame arrives.
+    transport.reset(frameBytes, firstChunk, frameBytes.length - firstChunk);
+    reader.read(transport);
+    Assert.assertTrue("Reader should be complete", reader.isComplete());
+    // Compare only the payload portion of the frame (position skips the length header).
+    frame.position(DataFrameHeaderReader.PAYLOAD_LENGTH_BYTES);
+    Assert.assertEquals("Payload should be the same as from the transport", frame, ByteBuffer.wrap(reader.getPayload()));
+  }
+}
diff --git a/lib/java/test/org/apache/thrift/transport/sasl/TestDataFrameWriter.java b/lib/java/test/org/apache/thrift/transport/sasl/TestDataFrameWriter.java
new file mode 100644
index 0000000..d242593
--- /dev/null
+++ b/lib/java/test/org/apache/thrift/transport/sasl/TestDataFrameWriter.java
@@ -0,0 +1,109 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.thrift.transport.sasl;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+
+import org.apache.thrift.EncodingUtils;
+import org.apache.thrift.transport.TNonblockingTransport;
+import org.junit.Assert;
+import org.junit.Test;
+import org.mockito.Mockito;
+import org.mockito.invocation.InvocationOnMock;
+import org.mockito.stubbing.Answer;
+
+import static org.apache.thrift.transport.sasl.DataFrameHeaderReader.PAYLOAD_LENGTH_BYTES;
+
+/**
+ * Tests for {@link DataFrameWriter}.
+ */
+public class TestDataFrameWriter {
+
+  private static final byte[] BYTES = new byte[]{0x32, 0x2A, (byte) 0xE1, 0x18, (byte) 0x90, 0x75};
+
+  /** Expected frame for a slice of BYTES: the big-endian slice length followed by the slice itself. */
+  private static ByteBuffer expectedFrame(int offset, int length) {
+    byte[] frame = new byte[PAYLOAD_LENGTH_BYTES + length];
+    EncodingUtils.encodeBigEndian(length, frame);
+    System.arraycopy(BYTES, offset, frame, PAYLOAD_LENGTH_BYTES, length);
+    return ByteBuffer.wrap(frame);
+  }
+
+  @Test
+  public void testProvideEntireByteArrayAsPayload() {
+    DataFrameWriter writer = new DataFrameWriter();
+    writer.withOnlyPayload(BYTES);
+    Assert.assertEquals(expectedFrame(0, BYTES.length), writer.frameBytes);
+  }
+
+  @Test
+  public void testProvideByteArrayPortionAsPayload() {
+    final int offset = 2;
+    final int length = 3;
+    DataFrameWriter writer = new DataFrameWriter();
+    writer.withOnlyPayload(BYTES, offset, length);
+    Assert.assertEquals(expectedFrame(offset, length), writer.frameBytes);
+  }
+
+  /** Passing an explicit header to a data frame writer is expected to throw IllegalArgumentException. */
+  @Test(expected = IllegalArgumentException.class)
+  public void testProvideHeaderAndPayload() {
+    new DataFrameWriter().withHeaderAndPayload(new byte[1], new byte[1]);
+  }
+
+  /** A second payload must be refused while the first frame is still incomplete. */
+  @Test(expected = IllegalStateException.class)
+  public void testProvidePayloadToIncompleteFrame() {
+    DataFrameWriter writer = new DataFrameWriter();
+    writer.withOnlyPayload(BYTES);
+    writer.withOnlyPayload(new byte[1]);
+  }
+
+  /** Against a transport that accepts one byte per call, the writer completes only after the last byte. */
+  @Test
+  public void testWrite() throws IOException {
+    DataFrameWriter writer = new DataFrameWriter();
+    writer.withOnlyPayload(BYTES);
+    TNonblockingTransport transport = Mockito.mock(TNonblockingTransport.class);
+    SlowWriting oneBytePerCall = new SlowWriting();
+    Mockito.when(transport.write(writer.frameBytes)).thenAnswer(oneBytePerCall);
+    writer.write(transport);
+    while (oneBytePerCall.written < writer.frameBytes.limit()) {
+      Assert.assertFalse("Frame writer should not be complete", writer.isComplete());
+      writer.write(transport);
+    }
+    Assert.assertTrue("Frame writer should be complete", writer.isComplete());
+  }
+
+  /** Mock answer simulating a slow socket: consumes one byte of the buffer per call and reports 1 written. */
+  private static class SlowWriting implements Answer<Integer> {
+    // Number of bytes consumed so far, across all calls.
+    int written = 0;
+
+    @Override
+    public Integer answer(InvocationOnMock invocation) throws Throwable {
+      ByteBuffer buffer = (ByteBuffer) invocation.getArguments()[0];
+      buffer.get();
+      written++;
+      return 1;
+    }
+  }
+}
diff --git a/lib/java/test/org/apache/thrift/transport/sasl/TestSaslNegotiationFrameReader.java b/lib/java/test/org/apache/thrift/transport/sasl/TestSaslNegotiationFrameReader.java
new file mode 100644
index 0000000..f2abbe6
--- /dev/null
+++ b/lib/java/test/org/apache/thrift/transport/sasl/TestSaslNegotiationFrameReader.java
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.thrift.transport.sasl;
+
+import org.apache.thrift.transport.TMemoryInputTransport;
+import org.apache.thrift.transport.TTransportException;
+import org.junit.Assert;
+import org.junit.Test;
+
+import java.nio.ByteBuffer;
+
+/**
+ * Tests for {@link SaslNegotiationFrameReader}.
+ */
+public class TestSaslNegotiationFrameReader {
+
+  /** Delivers a negotiation frame in stages (nothing, header, payload) and checks the reader after each. */
+  @Test
+  public void testRead() throws TTransportException {
+    final int payloadLength = 10;
+    TMemoryInputTransport transport = new TMemoryInputTransport();
+    SaslNegotiationFrameReader reader = new SaslNegotiationFrameReader();
+
+    // Stage 1: no bytes received.
+    reader.read(transport);
+    Assert.assertFalse("No bytes received", reader.isComplete());
+    Assert.assertFalse("No bytes received", reader.getHeader().isComplete());
+
+    // Stage 2: a 5-byte header holding the status byte (OK) at index 0 and the int payload length at index 1.
+    ByteBuffer header = ByteBuffer.allocate(5);
+    header.put(0, NegotiationStatus.OK.getValue());
+    header.putInt(1, payloadLength);
+    transport.reset(header.array());
+    reader.read(transport);
+    Assert.assertFalse("Only header is complete", reader.isComplete());
+    Assert.assertTrue("Header should be complete", reader.getHeader().isComplete());
+    Assert.assertEquals("Payload size should be 10", payloadLength, reader.getHeader().payloadSize());
+
+    // Stage 3: more bytes than needed are available; the payload must still be exactly payloadLength long.
+    transport.reset(new byte[20]);
+    reader.read(transport);
+    Assert.assertTrue("Reader should be complete", reader.isComplete());
+    Assert.assertEquals("Payload length should be 10", payloadLength, reader.getPayload().length);
+  }
+
+  /** An invalid status byte (-1) must make the reader throw TSaslNegotiationException. */
+  @Test (expected = TSaslNegotiationException.class)
+  public void testReadInvalidNegotiationStatus() throws TTransportException {
+    byte[] frame = new byte[5];
+    frame[0] = -1;
+    TMemoryInputTransport transport = new TMemoryInputTransport(frame);
+    new SaslNegotiationFrameReader().read(transport);
+  }
+}
diff --git a/lib/java/test/org/apache/thrift/transport/sasl/TestSaslNegotiationFrameWriter.java b/lib/java/test/org/apache/thrift/transport/sasl/TestSaslNegotiationFrameWriter.java
new file mode 100644
index 0000000..ce7ff29
--- /dev/null
+++ b/lib/java/test/org/apache/thrift/transport/sasl/TestSaslNegotiationFrameWriter.java
@@ -0,0 +1,62 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.thrift.transport.sasl;
+
+import java.nio.ByteBuffer;
+
+import org.apache.thrift.EncodingUtils;
+import org.junit.Assert;
+import org.junit.Test;
+
+import static org.apache.thrift.transport.sasl.SaslNegotiationFrameWriter.HEADER_BYTES;
+
+/**
+ * Tests for {@link SaslNegotiationFrameWriter}.
+ */
+public class TestSaslNegotiationFrameWriter {
+
+  private static final byte[] PAYLOAD = {0x11, 0x08, 0x3F, 0x58, 0x73, 0x22, 0x00, (byte) 0xFF};
+
+  /** The frame is the status byte, then the big-endian payload length at offset 1, then the payload. */
+  @Test
+  public void testWithHeaderAndPayload() {
+    byte status = NegotiationStatus.OK.getValue();
+    SaslNegotiationFrameWriter writer = new SaslNegotiationFrameWriter();
+    writer.withHeaderAndPayload(new byte[] {status}, PAYLOAD);
+
+    byte[] expected = new byte[HEADER_BYTES + PAYLOAD.length];
+    expected[0] = status;
+    EncodingUtils.encodeBigEndian(PAYLOAD.length, expected, 1);
+    System.arraycopy(PAYLOAD, 0, expected, HEADER_BYTES, PAYLOAD.length);
+    Assert.assertEquals(ByteBuffer.wrap(expected), writer.frameBytes);
+  }
+
+  /** A 2-byte header slice is expected to be rejected with IllegalArgumentException. */
+  @Test(expected = IllegalArgumentException.class)
+  public void testWithInvalidHeaderLength() {
+    new SaslNegotiationFrameWriter().withHeaderAndPayload(new byte[5], 0, 2, PAYLOAD, 0, 1);
+  }
+
+  /** withOnlyPayload is expected to be unsupported for negotiation frames. */
+  @Test(expected = UnsupportedOperationException.class)
+  public void testWithOnlyPayload() {
+    new SaslNegotiationFrameWriter().withOnlyPayload(new byte[0]);
+  }
+}
diff --git a/lib/java/test/org/apache/thrift/utils/TestStringUtils.java b/lib/java/test/org/apache/thrift/utils/TestStringUtils.java
new file mode 100644
index 0000000..3a8cf39
--- /dev/null
+++ b/lib/java/test/org/apache/thrift/utils/TestStringUtils.java
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.thrift.utils;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+/**
+ * Tests for {@link StringUtils}.
+ */
+public class TestStringUtils {
+
+  private static final byte[] BYTES = {0x00, 0x1A, (byte) 0xEF, (byte) 0xAB, (byte) 0x92};
+
+  @Test
+  public void testToHexString() {
+    Assert.assertEquals("001AEFAB92", StringUtils.bytesToHexString(BYTES));
+    Assert.assertEquals("EFAB92", StringUtils.bytesToHexString(BYTES, 2, 3));
+    Assert.assertNull(StringUtils.bytesToHexString(null));
+  }
+
+  /** Empty arrays and zero-length portions (even at the very end) produce an empty string. */
+  @Test
+  public void testToHexStringEmpty() {
+    Assert.assertEquals("", StringUtils.bytesToHexString(new byte[0]));
+    Assert.assertEquals("", StringUtils.bytesToHexString(BYTES, BYTES.length, 0));
+  }
+
+  @Test(expected = IllegalArgumentException.class)
+  public void testNegativeLength() {
+    StringUtils.bytesToHexString(BYTES, 0, -1);
+  }
+
+  @Test(expected = IndexOutOfBoundsException.class)
+  public void testNegativeOffset() {
+    StringUtils.bytesToHexString(BYTES, -1, 1);
+  }
+
+  /** A portion running past the end of the array must be rejected. */
+  @Test(expected = IndexOutOfBoundsException.class)
+  public void testPortionPastEnd() {
+    StringUtils.bytesToHexString(BYTES, 3, 3);
+  }
+}
diff --git a/lib/js/package.json b/lib/js/package.json
index f578e4b..b75019d 100644
--- a/lib/js/package.json
+++ b/lib/js/package.json
@@ -2,6 +2,7 @@
"name": "thrift",
"version": "0.14.0",
"description": "Thrift is a software framework for scalable cross-language services development.",
+ "main": "./src/thrift",
"author": {
"name": "Apache Thrift Developers",
"email": "dev@thrift.apache.org"
diff --git a/lib/lua/TCompactProtocol.lua b/lib/lua/TCompactProtocol.lua
index 877595a..7b75967 100644
--- a/lib/lua/TCompactProtocol.lua
+++ b/lib/lua/TCompactProtocol.lua
@@ -124,8 +124,8 @@
end
function TCompactProtocol:writeStructEnd()
- self.lastFieldIndex = self.lastFieldIndex - 1
self.lastFieldId = self.lastField[self.lastFieldIndex]
+ self.lastFieldIndex = self.lastFieldIndex - 1
end
function TCompactProtocol:writeFieldBegin(name, ttype, id)
diff --git a/lib/netstd/README.md b/lib/netstd/README.md
index 88ba73a..d554e38 100644
--- a/lib/netstd/README.md
+++ b/lib/netstd/README.md
@@ -11,7 +11,7 @@
- Build with scripts
## How to build on Unix/Linux
-- Ensure you have .NET SDK >= 2.0 installed, or use the [Ubuntu docker image](../../build/docker/README.md)
+- Ensure you have .NET Core SDK 3.1 (LTS) installed, or use the [Ubuntu docker image](../../build/docker/README.md)
- Follow common automake build practice: `./ bootstrap && ./ configure && make`
## Known issues
diff --git a/lib/netstd/Tests/Thrift.IntegrationTests/Protocols/ProtocolsOperationsTests.cs b/lib/netstd/Tests/Thrift.IntegrationTests/Protocols/ProtocolsOperationsTests.cs
index b1f8418..b8df515 100644
--- a/lib/netstd/Tests/Thrift.IntegrationTests/Protocols/ProtocolsOperationsTests.cs
+++ b/lib/netstd/Tests/Thrift.IntegrationTests/Protocols/ProtocolsOperationsTests.cs
@@ -31,6 +31,7 @@
public class ProtocolsOperationsTests
{
private readonly CompareLogic _compareLogic = new CompareLogic();
+ private static readonly TConfiguration Configuration = null; // or new TConfiguration() if needed
[DataTestMethod]
[DataRow(typeof(TBinaryProtocol), TMessageType.Call)]
@@ -494,7 +495,7 @@
private static Tuple<Stream, TProtocol> GetProtocolInstance(Type protocolType)
{
var memoryStream = new MemoryStream();
- var streamClientTransport = new TStreamTransport(memoryStream, memoryStream);
+ var streamClientTransport = new TStreamTransport(memoryStream, memoryStream,Configuration);
var protocol = (TProtocol) Activator.CreateInstance(protocolType, streamClientTransport);
return new Tuple<Stream, TProtocol>(memoryStream, protocol);
}
diff --git a/lib/netstd/Tests/Thrift.IntegrationTests/Thrift.IntegrationTests.csproj b/lib/netstd/Tests/Thrift.IntegrationTests/Thrift.IntegrationTests.csproj
index 381d8aa..7c5639b 100644
--- a/lib/netstd/Tests/Thrift.IntegrationTests/Thrift.IntegrationTests.csproj
+++ b/lib/netstd/Tests/Thrift.IntegrationTests/Thrift.IntegrationTests.csproj
@@ -19,7 +19,7 @@
-->
<PropertyGroup>
- <TargetFramework>netcoreapp2.0</TargetFramework>
+ <TargetFramework>netcoreapp3.1</TargetFramework>
<AssemblyName>Thrift.IntegrationTests</AssemblyName>
<PackageId>Thrift.IntegrationTests</PackageId>
<OutputType>Exe</OutputType>
@@ -32,11 +32,11 @@
</PropertyGroup>
<ItemGroup>
- <PackageReference Include="CompareNETObjects" Version="4.58.0" />
- <PackageReference Include="Microsoft.NET.Test.Sdk" Version="15.9.0" />
- <PackageReference Include="MSTest.TestAdapter" Version="1.4.0" />
- <PackageReference Include="MSTest.TestFramework" Version="1.4.0" />
- <PackageReference Include="System.ServiceModel.Primitives" Version="4.5.3" />
+ <PackageReference Include="CompareNETObjects" Version="4.64.0" />
+ <PackageReference Include="Microsoft.NET.Test.Sdk" Version="16.4.0" />
+ <PackageReference Include="MSTest.TestAdapter" Version="2.0.0" />
+ <PackageReference Include="MSTest.TestFramework" Version="2.0.0" />
+ <PackageReference Include="System.ServiceModel.Primitives" Version="4.7.0" />
</ItemGroup>
<ItemGroup>
<ProjectReference Include="..\..\Thrift\Thrift.csproj" />
diff --git a/lib/netstd/Tests/Thrift.PublicInterfaces.Compile.Tests/Thrift.PublicInterfaces.Compile.Tests.csproj b/lib/netstd/Tests/Thrift.PublicInterfaces.Compile.Tests/Thrift.PublicInterfaces.Compile.Tests.csproj
index 58f61a2..d2db348 100644
--- a/lib/netstd/Tests/Thrift.PublicInterfaces.Compile.Tests/Thrift.PublicInterfaces.Compile.Tests.csproj
+++ b/lib/netstd/Tests/Thrift.PublicInterfaces.Compile.Tests/Thrift.PublicInterfaces.Compile.Tests.csproj
@@ -19,7 +19,7 @@
-->
<PropertyGroup>
- <TargetFramework>netcoreapp2.0</TargetFramework>
+ <TargetFramework>netcoreapp3.1</TargetFramework>
<AssemblyName>Thrift.PublicInterfaces.Compile.Tests</AssemblyName>
<PackageId>Thrift.PublicInterfaces.Compile.Tests</PackageId>
<GenerateAssemblyConfigurationAttribute>false</GenerateAssemblyConfigurationAttribute>
@@ -33,7 +33,7 @@
</ItemGroup>
<ItemGroup>
- <PackageReference Include="System.ServiceModel.Primitives" Version="4.5.3" />
+ <PackageReference Include="System.ServiceModel.Primitives" Version="4.7.0" />
</ItemGroup>
<Target Name="PreBuild" BeforeTargets="_GenerateRestoreProjectSpec;Restore;Compile">
diff --git a/lib/netstd/Tests/Thrift.Tests/Protocols/TJsonProtocolTests.cs b/lib/netstd/Tests/Thrift.Tests/Protocols/TJsonProtocolTests.cs
index 970ce7e..4054a29 100644
--- a/lib/netstd/Tests/Thrift.Tests/Protocols/TJsonProtocolTests.cs
+++ b/lib/netstd/Tests/Thrift.Tests/Protocols/TJsonProtocolTests.cs
@@ -21,7 +21,6 @@
using System.Threading;
using System.Threading.Tasks;
using Microsoft.VisualStudio.TestTools.UnitTesting;
-using NSubstitute;
using Thrift.Protocol;
using Thrift.Protocol.Entities;
using Thrift.Transport;
@@ -36,7 +35,7 @@
[TestMethod]
public void TJSONProtocol_Can_Create_Instance_Test()
{
- var httpClientTransport = Substitute.For<THttpTransport>(new Uri("http://localhost"), null, null);
+ var httpClientTransport = new THttpTransport( new Uri("http://localhost"), null, null, null);
var result = new TJSONProtocolWrapper(httpClientTransport);
@@ -45,7 +44,7 @@
Assert.IsNotNull(result.WrappedReader);
Assert.IsNotNull(result.Transport);
Assert.IsTrue(result.WrappedRecursionDepth == 0);
- Assert.IsTrue(result.WrappedRecursionLimit == TProtocol.DefaultRecursionDepth);
+ Assert.IsTrue(result.WrappedRecursionLimit == TConfiguration.DEFAULT_RECURSION_DEPTH);
Assert.IsTrue(result.Transport.Equals(httpClientTransport));
Assert.IsTrue(result.WrappedContext.GetType().Name.Equals("JSONBaseContext", StringComparison.OrdinalIgnoreCase));
diff --git a/lib/netstd/Tests/Thrift.Tests/Thrift.Tests.csproj b/lib/netstd/Tests/Thrift.Tests/Thrift.Tests.csproj
index 434424d..20fdfe4 100644
--- a/lib/netstd/Tests/Thrift.Tests/Thrift.Tests.csproj
+++ b/lib/netstd/Tests/Thrift.Tests/Thrift.Tests.csproj
@@ -18,14 +18,14 @@
under the License.
-->
<PropertyGroup>
- <TargetFramework>netcoreapp2.0</TargetFramework>
+ <TargetFramework>netcoreapp3.1</TargetFramework>
</PropertyGroup>
<ItemGroup>
- <PackageReference Include="CompareNETObjects" Version="4.58.0" />
- <PackageReference Include="Microsoft.NET.Test.Sdk" Version="15.9.0" />
- <PackageReference Include="MSTest.TestAdapter" Version="1.4.0" />
- <PackageReference Include="MSTest.TestFramework" Version="1.4.0" />
- <PackageReference Include="NSubstitute" Version="4.0.0" />
+ <PackageReference Include="CompareNETObjects" Version="4.64.0" />
+ <PackageReference Include="Microsoft.NET.Test.Sdk" Version="16.4.0" />
+ <PackageReference Include="MSTest.TestAdapter" Version="2.0.0" />
+ <PackageReference Include="MSTest.TestFramework" Version="2.0.0" />
+ <PackageReference Include="NSubstitute" Version="4.2.1" />
</ItemGroup>
<ItemGroup>
<ProjectReference Include="..\..\Thrift\Thrift.csproj" />
diff --git a/lib/netstd/Thrift/Protocol/TBinaryProtocol.cs b/lib/netstd/Thrift/Protocol/TBinaryProtocol.cs
index 3f30d4a..a00c5c1 100644
--- a/lib/netstd/Thrift/Protocol/TBinaryProtocol.cs
+++ b/lib/netstd/Thrift/Protocol/TBinaryProtocol.cs
@@ -16,6 +16,7 @@
// under the License.
using System;
+using System.Buffers.Binary;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
@@ -215,9 +216,7 @@
{
return;
}
-
- PreAllocatedBuffer[0] = (byte)(0xff & (i16 >> 8));
- PreAllocatedBuffer[1] = (byte)(0xff & i16);
+ BinaryPrimitives.WriteInt16BigEndian(PreAllocatedBuffer, i16);
await Trans.WriteAsync(PreAllocatedBuffer, 0, 2, cancellationToken);
}
@@ -229,10 +228,7 @@
return;
}
- PreAllocatedBuffer[0] = (byte)(0xff & (i32 >> 24));
- PreAllocatedBuffer[1] = (byte)(0xff & (i32 >> 16));
- PreAllocatedBuffer[2] = (byte)(0xff & (i32 >> 8));
- PreAllocatedBuffer[3] = (byte)(0xff & i32);
+ BinaryPrimitives.WriteInt32BigEndian(PreAllocatedBuffer, i32);
await Trans.WriteAsync(PreAllocatedBuffer, 0, 4, cancellationToken);
}
@@ -245,14 +241,7 @@
return;
}
- PreAllocatedBuffer[0] = (byte)(0xff & (i64 >> 56));
- PreAllocatedBuffer[1] = (byte)(0xff & (i64 >> 48));
- PreAllocatedBuffer[2] = (byte)(0xff & (i64 >> 40));
- PreAllocatedBuffer[3] = (byte)(0xff & (i64 >> 32));
- PreAllocatedBuffer[4] = (byte)(0xff & (i64 >> 24));
- PreAllocatedBuffer[5] = (byte)(0xff & (i64 >> 16));
- PreAllocatedBuffer[6] = (byte)(0xff & (i64 >> 8));
- PreAllocatedBuffer[7] = (byte)(0xff & i64);
+ BinaryPrimitives.WriteInt64BigEndian(PreAllocatedBuffer, i64);
await Trans.WriteAsync(PreAllocatedBuffer, 0, 8, cancellationToken);
}
@@ -381,7 +370,7 @@
ValueType = (TType) await ReadByteAsync(cancellationToken),
Count = await ReadI32Async(cancellationToken)
};
-
+ CheckReadBytesAvailable(map);
return map;
}
@@ -405,7 +394,7 @@
ElementType = (TType) await ReadByteAsync(cancellationToken),
Count = await ReadI32Async(cancellationToken)
};
-
+ CheckReadBytesAvailable(list);
return list;
}
@@ -429,7 +418,7 @@
ElementType = (TType) await ReadByteAsync(cancellationToken),
Count = await ReadI32Async(cancellationToken)
};
-
+ CheckReadBytesAvailable(set);
return set;
}
@@ -470,7 +459,7 @@
}
await Trans.ReadAllAsync(PreAllocatedBuffer, 0, 2, cancellationToken);
- var result = (short) (((PreAllocatedBuffer[0] & 0xff) << 8) | PreAllocatedBuffer[1] & 0xff);
+ var result = BinaryPrimitives.ReadInt16BigEndian(PreAllocatedBuffer);
return result;
}
@@ -483,34 +472,11 @@
await Trans.ReadAllAsync(PreAllocatedBuffer, 0, 4, cancellationToken);
- var result =
- ((PreAllocatedBuffer[0] & 0xff) << 24) |
- ((PreAllocatedBuffer[1] & 0xff) << 16) |
- ((PreAllocatedBuffer[2] & 0xff) << 8) |
- PreAllocatedBuffer[3] & 0xff;
+ var result = BinaryPrimitives.ReadInt32BigEndian(PreAllocatedBuffer);
return result;
}
-#pragma warning disable 675
-
- protected internal long ReadI64FromPreAllocatedBuffer()
- {
- var result =
- ((long) (PreAllocatedBuffer[0] & 0xff) << 56) |
- ((long) (PreAllocatedBuffer[1] & 0xff) << 48) |
- ((long) (PreAllocatedBuffer[2] & 0xff) << 40) |
- ((long) (PreAllocatedBuffer[3] & 0xff) << 32) |
- ((long) (PreAllocatedBuffer[4] & 0xff) << 24) |
- ((long) (PreAllocatedBuffer[5] & 0xff) << 16) |
- ((long) (PreAllocatedBuffer[6] & 0xff) << 8) |
- PreAllocatedBuffer[7] & 0xff;
-
- return result;
- }
-
-#pragma warning restore 675
-
public override async ValueTask<long> ReadI64Async(CancellationToken cancellationToken)
{
if (cancellationToken.IsCancellationRequested)
@@ -519,7 +485,7 @@
}
await Trans.ReadAllAsync(PreAllocatedBuffer, 0, 8, cancellationToken);
- return ReadI64FromPreAllocatedBuffer();
+ return BinaryPrimitives.ReadInt64BigEndian(PreAllocatedBuffer);
}
public override async ValueTask<double> ReadDoubleAsync(CancellationToken cancellationToken)
@@ -541,6 +507,7 @@
}
var size = await ReadI32Async(cancellationToken);
+ Transport.CheckReadBytesAvailable(size);
var buf = new byte[size];
await Trans.ReadAllAsync(buf, 0, size, cancellationToken);
return buf;
@@ -570,11 +537,34 @@
return Encoding.UTF8.GetString(PreAllocatedBuffer, 0, size);
}
+ Transport.CheckReadBytesAvailable(size);
var buf = new byte[size];
await Trans.ReadAllAsync(buf, 0, size, cancellationToken);
return Encoding.UTF8.GetString(buf, 0, buf.Length);
}
+ // Return the minimum number of bytes a type will consume on the wire
+ public override int GetMinSerializedSize(TType type)
+ {
+ switch (type)
+ {
+ case TType.Stop: return 0;
+ case TType.Void: return 0;
+ case TType.Bool: return sizeof(byte);
+ case TType.Byte: return sizeof(byte);
+ case TType.Double: return sizeof(double);
+ case TType.I16: return sizeof(short);
+ case TType.I32: return sizeof(int);
+ case TType.I64: return sizeof(long);
+ case TType.String: return sizeof(int); // string length
+ case TType.Struct: return 0; // empty struct
+ case TType.Map: return sizeof(int); // element count
+ case TType.Set: return sizeof(int); // element count
+ case TType.List: return sizeof(int); // element count
+ default: throw new TTransportException(TTransportException.ExceptionType.Unknown, "unrecognized type code");
+ }
+ }
+
public class Factory : TProtocolFactory
{
protected bool StrictRead;
diff --git a/lib/netstd/Thrift/Protocol/TCompactProtocol.cs b/lib/netstd/Thrift/Protocol/TCompactProtocol.cs
index c26633a..a8a46f2 100644
--- a/lib/netstd/Thrift/Protocol/TCompactProtocol.cs
+++ b/lib/netstd/Thrift/Protocol/TCompactProtocol.cs
@@ -16,6 +16,7 @@
// under the License.
using System;
+using System.Buffers.Binary;
using System.Collections.Generic;
using System.Diagnostics;
using System.Text;
@@ -399,8 +400,7 @@
{
return;
}
-
- FixedLongToBytes(BitConverter.DoubleToInt64Bits(d), PreAllocatedBuffer, 0);
+ BinaryPrimitives.WriteInt64LittleEndian(PreAllocatedBuffer, BitConverter.DoubleToInt64Bits(d));
await Trans.WriteAsync(PreAllocatedBuffer, 0, 8, cancellationToken);
}
@@ -590,7 +590,9 @@
var size = (int) await ReadVarInt32Async(cancellationToken);
var keyAndValueType = size == 0 ? (byte) 0 : (byte) await ReadByteAsync(cancellationToken);
- return new TMap(GetTType((byte) (keyAndValueType >> 4)), GetTType((byte) (keyAndValueType & 0xf)), size);
+ var map = new TMap(GetTType((byte) (keyAndValueType >> 4)), GetTType((byte) (keyAndValueType & 0xf)), size);
+ CheckReadBytesAvailable(map);
+ return map;
}
public override async Task ReadMapEndAsync(CancellationToken cancellationToken)
@@ -683,8 +685,8 @@
}
await Trans.ReadAllAsync(PreAllocatedBuffer, 0, 8, cancellationToken);
-
- return BitConverter.Int64BitsToDouble(BytesToLong(PreAllocatedBuffer));
+
+ return BitConverter.Int64BitsToDouble(BinaryPrimitives.ReadInt64LittleEndian(PreAllocatedBuffer));
}
public override async ValueTask<string> ReadStringAsync(CancellationToken cancellationToken)
@@ -703,6 +705,7 @@
return Encoding.UTF8.GetString(PreAllocatedBuffer, 0, length);
}
+ Transport.CheckReadBytesAvailable(length);
var buf = new byte[length];
await Trans.ReadAllAsync(buf, 0, length, cancellationToken);
return Encoding.UTF8.GetString(buf, 0, length);
@@ -718,6 +721,7 @@
}
// read data
+ Transport.CheckReadBytesAvailable(length);
var buf = new byte[length];
await Trans.ReadAllAsync(buf, 0, length, cancellationToken);
return buf;
@@ -745,7 +749,9 @@
}
var type = GetTType(sizeAndType);
- return new TList(type, size);
+ var list = new TList(type, size);
+ CheckReadBytesAvailable(list);
+ return list;
}
public override async Task ReadListEndAsync(CancellationToken cancellationToken)
@@ -838,25 +844,6 @@
return (long) (n >> 1) ^ -(long) (n & 1);
}
- private static long BytesToLong(byte[] bytes)
- {
- /*
- Note that it's important that the mask bytes are long literals,
- otherwise they'll default to ints, and when you shift an int left 56 bits,
- you just get a messed up int.
- */
-
- return
- ((bytes[7] & 0xffL) << 56) |
- ((bytes[6] & 0xffL) << 48) |
- ((bytes[5] & 0xffL) << 40) |
- ((bytes[4] & 0xffL) << 32) |
- ((bytes[3] & 0xffL) << 24) |
- ((bytes[2] & 0xffL) << 16) |
- ((bytes[1] & 0xffL) << 8) |
- (bytes[0] & 0xffL);
- }
-
private static TType GetTType(byte type)
{
// Given a TCompactProtocol.Types constant, convert it to its corresponding TType value.
@@ -875,17 +862,26 @@
return (uint) (n << 1) ^ (uint) (n >> 31);
}
- private static void FixedLongToBytes(long n, byte[] buf, int off)
+ // Return the minimum number of bytes a type will consume on the wire
+ public override int GetMinSerializedSize(TType type)
{
- // Convert a long into little-endian bytes in buf starting at off and going until off+7.
- buf[off + 0] = (byte) (n & 0xff);
- buf[off + 1] = (byte) ((n >> 8) & 0xff);
- buf[off + 2] = (byte) ((n >> 16) & 0xff);
- buf[off + 3] = (byte) ((n >> 24) & 0xff);
- buf[off + 4] = (byte) ((n >> 32) & 0xff);
- buf[off + 5] = (byte) ((n >> 40) & 0xff);
- buf[off + 6] = (byte) ((n >> 48) & 0xff);
- buf[off + 7] = (byte) ((n >> 56) & 0xff);
+ switch (type)
+ {
+ case TType.Stop: return 0;
+ case TType.Void: return 0;
+ case TType.Bool: return sizeof(byte);
+ case TType.Double: return 8; // uses fixedLongToBytes() which always writes 8 bytes
+ case TType.Byte: return sizeof(byte);
+ case TType.I16: return sizeof(byte); // zigzag
+ case TType.I32: return sizeof(byte); // zigzag
+ case TType.I64: return sizeof(byte); // zigzag
+ case TType.String: return sizeof(byte); // string length
+ case TType.Struct: return 0; // empty struct
+ case TType.Map: return sizeof(byte); // element count
+ case TType.Set: return sizeof(byte); // element count
+ case TType.List: return sizeof(byte); // element count
+ default: throw new TTransportException(TTransportException.ExceptionType.Unknown, "unrecognized type code");
+ }
}
public class Factory : TProtocolFactory
diff --git a/lib/netstd/Thrift/Protocol/TJSONProtocol.cs b/lib/netstd/Thrift/Protocol/TJSONProtocol.cs
index 464bd62..7bc7130 100644
--- a/lib/netstd/Thrift/Protocol/TJSONProtocol.cs
+++ b/lib/netstd/Thrift/Protocol/TJSONProtocol.cs
@@ -703,6 +703,7 @@
map.KeyType = TJSONProtocolHelper.GetTypeIdForTypeName(await ReadJsonStringAsync(false, cancellationToken));
map.ValueType = TJSONProtocolHelper.GetTypeIdForTypeName(await ReadJsonStringAsync(false, cancellationToken));
map.Count = (int) await ReadJsonIntegerAsync(cancellationToken);
+ CheckReadBytesAvailable(map);
await ReadJsonObjectStartAsync(cancellationToken);
return map;
}
@@ -719,6 +720,7 @@
await ReadJsonArrayStartAsync(cancellationToken);
list.ElementType = TJSONProtocolHelper.GetTypeIdForTypeName(await ReadJsonStringAsync(false, cancellationToken));
list.Count = (int) await ReadJsonIntegerAsync(cancellationToken);
+ CheckReadBytesAvailable(list);
return list;
}
@@ -733,6 +735,7 @@
await ReadJsonArrayStartAsync(cancellationToken);
set.ElementType = TJSONProtocolHelper.GetTypeIdForTypeName(await ReadJsonStringAsync(false, cancellationToken));
set.Count = (int) await ReadJsonIntegerAsync(cancellationToken);
+ CheckReadBytesAvailable(set);
return set;
}
@@ -782,6 +785,28 @@
return await ReadJsonBase64Async(cancellationToken);
}
+ // Return the minimum number of bytes a type will consume on the wire
+ public override int GetMinSerializedSize(TType type)
+ {
+ switch (type)
+ {
+ case TType.Stop: return 0;
+ case TType.Void: return 0;
+ case TType.Bool: return 1; // written as int
+ case TType.Byte: return 1;
+ case TType.Double: return 1;
+ case TType.I16: return 1;
+ case TType.I32: return 1;
+ case TType.I64: return 1;
+ case TType.String: return 2; // empty string
+ case TType.Struct: return 2; // empty struct
+ case TType.Map: return 2; // empty map
+ case TType.Set: return 2; // empty set
+ case TType.List: return 2; // empty list
+ default: throw new TTransportException(TTransportException.ExceptionType.Unknown, "unrecognized type code");
+ }
+ }
+
/// <summary>
/// Factory for JSON protocol objects
/// </summary>
diff --git a/lib/netstd/Thrift/Protocol/TProtocol.cs b/lib/netstd/Thrift/Protocol/TProtocol.cs
index 75edb11..5275c9c 100644
--- a/lib/netstd/Thrift/Protocol/TProtocol.cs
+++ b/lib/netstd/Thrift/Protocol/TProtocol.cs
@@ -27,7 +27,6 @@
// ReSharper disable once InconsistentNaming
public abstract class TProtocol : IDisposable
{
- public const int DefaultRecursionDepth = 64;
private bool _isDisposed;
protected int RecursionDepth;
@@ -36,7 +35,7 @@
protected TProtocol(TTransport trans)
{
Trans = trans;
- RecursionLimit = DefaultRecursionDepth;
+ RecursionLimit = trans.Configuration.RecursionLimit;
RecursionDepth = 0;
}
@@ -78,6 +77,27 @@
_isDisposed = true;
}
+
+ protected void CheckReadBytesAvailable(TSet set)
+ {
+ Transport.CheckReadBytesAvailable(set.Count * GetMinSerializedSize(set.ElementType));
+ }
+
+ protected void CheckReadBytesAvailable(TList list)
+ {
+ Transport.CheckReadBytesAvailable(list.Count * GetMinSerializedSize(list.ElementType));
+ }
+
+ protected void CheckReadBytesAvailable(TMap map)
+ {
+ var elmSize = GetMinSerializedSize(map.KeyType) + GetMinSerializedSize(map.ValueType);
+ Transport.CheckReadBytesAvailable(map.Count * elmSize);
+ }
+
+ // Returns the minimum amount of bytes needed to store the smallest possible instance of TType.
+ public abstract int GetMinSerializedSize(TType type);
+
+
public virtual async Task WriteMessageBeginAsync(TMessage message)
{
await WriteMessageBeginAsync(message, CancellationToken.None);
diff --git a/lib/netstd/Thrift/Protocol/TProtocolDecorator.cs b/lib/netstd/Thrift/Protocol/TProtocolDecorator.cs
index 845c827..b032e83 100644
--- a/lib/netstd/Thrift/Protocol/TProtocolDecorator.cs
+++ b/lib/netstd/Thrift/Protocol/TProtocolDecorator.cs
@@ -243,5 +243,13 @@
{
return await _wrappedProtocol.ReadBinaryAsync(cancellationToken);
}
+
+ // Returns the minimum amount of bytes needed to store the smallest possible instance of TType.
+ public override int GetMinSerializedSize(TType type)
+ {
+ return _wrappedProtocol.GetMinSerializedSize(type);
+ }
+
+
}
}
diff --git a/lib/netstd/Thrift/Server/TSimpleAsyncServer.cs b/lib/netstd/Thrift/Server/TSimpleAsyncServer.cs
index bdaa348..45e5513 100644
--- a/lib/netstd/Thrift/Server/TSimpleAsyncServer.cs
+++ b/lib/netstd/Thrift/Server/TSimpleAsyncServer.cs
@@ -66,7 +66,8 @@
outputTransportFactory,
inputProtocolFactory,
outputProtocolFactory,
- loggerFactory.CreateLogger<TSimpleAsyncServer>())
+ loggerFactory.CreateLogger<TSimpleAsyncServer>(),
+ clientWaitingDelay)
{
}
diff --git a/lib/netstd/Thrift/TConfiguration.cs b/lib/netstd/Thrift/TConfiguration.cs
new file mode 100644
index 0000000..c8dde10
--- /dev/null
+++ b/lib/netstd/Thrift/TConfiguration.cs
@@ -0,0 +1,19 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Thrift
+{
+ public class TConfiguration
+ {
+ public const int DEFAULT_MAX_MESSAGE_SIZE = 100 * 1024 * 1024;
+ public const int DEFAULT_MAX_FRAME_SIZE = 16384000; // this value is used consistently across all Thrift libraries
+ public const int DEFAULT_RECURSION_DEPTH = 64;
+
+ public int MaxMessageSize { get; set; } = DEFAULT_MAX_MESSAGE_SIZE;
+ public int MaxFrameSize { get; set; } = DEFAULT_MAX_FRAME_SIZE;
+ public int RecursionLimit { get; set; } = DEFAULT_RECURSION_DEPTH;
+
+ // TODO(JensG): add connection and i/o timeouts
+ }
+}
diff --git a/lib/netstd/Thrift/Thrift.csproj b/lib/netstd/Thrift/Thrift.csproj
index 5d8a9c3..e40db00 100644
--- a/lib/netstd/Thrift/Thrift.csproj
+++ b/lib/netstd/Thrift/Thrift.csproj
@@ -44,16 +44,16 @@
<ItemGroup>
<PackageReference Include="Microsoft.AspNetCore.Http.Abstractions" Version="2.2.0" />
- <PackageReference Include="Microsoft.Extensions.Logging" Version="2.2.0" />
- <PackageReference Include="Microsoft.Extensions.Logging.Console" Version="2.2.0" />
- <PackageReference Include="Microsoft.Extensions.Logging.Debug" Version="2.2.0" />
+ <PackageReference Include="Microsoft.Extensions.Logging" Version="3.1.0" />
+ <PackageReference Include="Microsoft.Extensions.Logging.Console" Version="3.1.0" />
+ <PackageReference Include="Microsoft.Extensions.Logging.Debug" Version="3.1.0" />
<PackageReference Include="System.IO.Pipes" Version="[4.3,)" />
<PackageReference Include="System.IO.Pipes.AccessControl" Version="4.5.1" />
- <PackageReference Include="System.Net.Http.WinHttpHandler" Version="4.5.2" />
+ <PackageReference Include="System.Net.Http.WinHttpHandler" Version="4.7.0" />
<PackageReference Include="System.Net.NameResolution" Version="[4.3,)" />
<PackageReference Include="System.Net.Requests" Version="[4.3,)" />
<PackageReference Include="System.Net.Security" Version="4.3.2" />
- <PackageReference Include="System.Threading.Tasks.Extensions" Version="4.5.2" />
+ <PackageReference Include="System.Threading.Tasks.Extensions" Version="4.5.3" />
</ItemGroup>
</Project>
diff --git a/lib/netstd/Thrift/Transport/Client/THttpTransport.cs b/lib/netstd/Thrift/Transport/Client/THttpTransport.cs
index c84df83..bbd94fa 100644
--- a/lib/netstd/Thrift/Transport/Client/THttpTransport.cs
+++ b/lib/netstd/Thrift/Transport/Client/THttpTransport.cs
@@ -28,7 +28,7 @@
namespace Thrift.Transport.Client
{
// ReSharper disable once InconsistentNaming
- public class THttpTransport : TTransport
+ public class THttpTransport : TEndpointTransport
{
private readonly X509Certificate[] _certificates;
private readonly Uri _uri;
@@ -39,13 +39,14 @@
private MemoryStream _outputStream = new MemoryStream();
private bool _isDisposed;
- public THttpTransport(Uri uri, IDictionary<string, string> customRequestHeaders = null, string userAgent = null)
- : this(uri, Enumerable.Empty<X509Certificate>(), customRequestHeaders, userAgent)
+ public THttpTransport(Uri uri, TConfiguration config, IDictionary<string, string> customRequestHeaders = null, string userAgent = null)
+ : this(uri, config, Enumerable.Empty<X509Certificate>(), customRequestHeaders, userAgent)
{
}
- public THttpTransport(Uri uri, IEnumerable<X509Certificate> certificates,
+ public THttpTransport(Uri uri, TConfiguration config, IEnumerable<X509Certificate> certificates,
IDictionary<string, string> customRequestHeaders, string userAgent = null)
+ : base(config)
{
_uri = uri;
_certificates = (certificates ?? Enumerable.Empty<X509Certificate>()).ToArray();
@@ -99,24 +100,22 @@
public override async ValueTask<int> ReadAsync(byte[] buffer, int offset, int length, CancellationToken cancellationToken)
{
if (cancellationToken.IsCancellationRequested)
- {
return await Task.FromCanceled<int>(cancellationToken);
- }
if (_inputStream == null)
- {
throw new TTransportException(TTransportException.ExceptionType.NotOpen, "No request has been sent");
- }
+
+ CheckReadBytesAvailable(length);
try
{
var ret = await _inputStream.ReadAsync(buffer, offset, length, cancellationToken);
-
if (ret == -1)
{
throw new TTransportException(TTransportException.ExceptionType.EndOfFile, "No more data available");
}
+ CountConsumedMessageBytes(ret);
return ret;
}
catch (IOException iox)
@@ -201,9 +200,11 @@
finally
{
_outputStream = new MemoryStream();
+ ResetConsumedMessageSize();
}
}
+
// IDisposable
protected override void Dispose(bool disposing)
{
diff --git a/lib/netstd/Thrift/Transport/Client/TMemoryBufferTransport.cs b/lib/netstd/Thrift/Transport/Client/TMemoryBufferTransport.cs
index 25895c2..290e50c 100644
--- a/lib/netstd/Thrift/Transport/Client/TMemoryBufferTransport.cs
+++ b/lib/netstd/Thrift/Transport/Client/TMemoryBufferTransport.cs
@@ -24,26 +24,24 @@
namespace Thrift.Transport.Client
{
// ReSharper disable once InconsistentNaming
- public class TMemoryBufferTransport : TTransport
+ public class TMemoryBufferTransport : TEndpointTransport
{
private bool IsDisposed;
private byte[] Bytes;
private int _bytesUsed;
- public TMemoryBufferTransport()
+ public TMemoryBufferTransport(TConfiguration config, int initialCapacity = 2048)
+ : base(config)
{
- Bytes = new byte[2048]; // default size
+ Bytes = new byte[initialCapacity];
}
- public TMemoryBufferTransport(int initialCapacity)
- {
- Bytes = new byte[initialCapacity]; // default size
- }
-
- public TMemoryBufferTransport(byte[] buf)
+ public TMemoryBufferTransport(byte[] buf, TConfiguration config)
+ :base(config)
{
Bytes = (byte[])buf.Clone();
_bytesUsed = Bytes.Length;
+ UpdateKnownMessageSize(_bytesUsed);
}
public int Position { get; set; }
@@ -117,6 +115,9 @@
if ((0 > newPos) || (newPos > _bytesUsed))
throw new ArgumentException(nameof(origin));
Position = newPos;
+
+ ResetConsumedMessageSize();
+ CountConsumedMessageBytes(Position);
}
public override ValueTask<int> ReadAsync(byte[] buffer, int offset, int length, CancellationToken cancellationToken)
@@ -124,6 +125,7 @@
var count = Math.Min(Length - Position, length);
Buffer.BlockCopy(Bytes, Position, buffer, offset, count);
Position += count;
+ CountConsumedMessageBytes(count);
return new ValueTask<int>(count);
}
@@ -147,6 +149,7 @@
{
await Task.FromCanceled(cancellationToken);
}
+ ResetConsumedMessageSize();
}
public byte[] GetBuffer()
@@ -162,7 +165,6 @@
return true;
}
-
// IDisposable
protected override void Dispose(bool disposing)
{
diff --git a/lib/netstd/Thrift/Transport/Client/TNamedPipeTransport.cs b/lib/netstd/Thrift/Transport/Client/TNamedPipeTransport.cs
index 7dfe013..f7f10b7 100644
--- a/lib/netstd/Thrift/Transport/Client/TNamedPipeTransport.cs
+++ b/lib/netstd/Thrift/Transport/Client/TNamedPipeTransport.cs
@@ -23,17 +23,18 @@
namespace Thrift.Transport.Client
{
// ReSharper disable once InconsistentNaming
- public class TNamedPipeTransport : TTransport
+ public class TNamedPipeTransport : TEndpointTransport
{
private NamedPipeClientStream PipeStream;
- private int ConnectTimeout;
+ private readonly int ConnectTimeout;
- public TNamedPipeTransport(string pipe, int timeout = Timeout.Infinite)
- : this(".", pipe, timeout)
+ public TNamedPipeTransport(string pipe, TConfiguration config, int timeout = Timeout.Infinite)
+ : this(".", pipe, config, timeout)
{
}
- public TNamedPipeTransport(string server, string pipe, int timeout = Timeout.Infinite)
+ public TNamedPipeTransport(string server, string pipe, TConfiguration config, int timeout = Timeout.Infinite)
+ : base(config)
{
var serverName = string.IsNullOrWhiteSpace(server) ? server : ".";
ConnectTimeout = (timeout > 0) ? timeout : Timeout.Infinite;
@@ -51,6 +52,7 @@
}
await PipeStream.ConnectAsync( ConnectTimeout, cancellationToken);
+ ResetConsumedMessageSize();
}
public override void Close()
@@ -69,7 +71,10 @@
throw new TTransportException(TTransportException.ExceptionType.NotOpen);
}
- return await PipeStream.ReadAsync(buffer, offset, length, cancellationToken);
+ CheckReadBytesAvailable(length);
+ var numRead = await PipeStream.ReadAsync(buffer, offset, length, cancellationToken);
+ CountConsumedMessageBytes(numRead);
+ return numRead;
}
public override async Task WriteAsync(byte[] buffer, int offset, int length, CancellationToken cancellationToken)
@@ -98,11 +103,16 @@
{
await Task.FromCanceled(cancellationToken);
}
+ ResetConsumedMessageSize();
}
+
protected override void Dispose(bool disposing)
{
- PipeStream.Dispose();
+ if(disposing)
+ {
+ PipeStream?.Dispose();
+ }
}
}
}
diff --git a/lib/netstd/Thrift/Transport/Client/TSocketTransport.cs b/lib/netstd/Thrift/Transport/Client/TSocketTransport.cs
index 00da045..d559154 100644
--- a/lib/netstd/Thrift/Transport/Client/TSocketTransport.cs
+++ b/lib/netstd/Thrift/Transport/Client/TSocketTransport.cs
@@ -30,18 +30,15 @@
private bool _isDisposed;
- public TSocketTransport(TcpClient client)
+ public TSocketTransport(TcpClient client, TConfiguration config)
+ : base(config)
{
TcpClient = client ?? throw new ArgumentNullException(nameof(client));
SetInputOutputStream();
}
- public TSocketTransport(IPAddress host, int port)
- : this(host, port, 0)
- {
- }
-
- public TSocketTransport(IPAddress host, int port, int timeout)
+ public TSocketTransport(IPAddress host, int port, TConfiguration config, int timeout = 0)
+ : base(config)
{
Host = host;
Port = port;
@@ -52,7 +49,8 @@
SetInputOutputStream();
}
- public TSocketTransport(string host, int port, int timeout = 0)
+ public TSocketTransport(string host, int port, TConfiguration config, int timeout = 0)
+ : base(config)
{
try
{
@@ -84,7 +82,7 @@
}
}
- public TcpClient TcpClient { get; private set; }
+ public TcpClient TcpClient { get; private set; }
public IPAddress Host { get; }
public int Port { get; }
@@ -159,4 +157,4 @@
_isDisposed = true;
}
}
-}
\ No newline at end of file
+}
diff --git a/lib/netstd/Thrift/Transport/Client/TStreamTransport.cs b/lib/netstd/Thrift/Transport/Client/TStreamTransport.cs
index d8574d6..e04b3b3 100644
--- a/lib/netstd/Thrift/Transport/Client/TStreamTransport.cs
+++ b/lib/netstd/Thrift/Transport/Client/TStreamTransport.cs
@@ -22,15 +22,17 @@
namespace Thrift.Transport.Client
{
// ReSharper disable once InconsistentNaming
- public class TStreamTransport : TTransport
+ public class TStreamTransport : TEndpointTransport
{
private bool _isDisposed;
- protected TStreamTransport()
+ protected TStreamTransport(TConfiguration config)
+ :base(config)
{
}
- public TStreamTransport(Stream inputStream, Stream outputStream)
+ public TStreamTransport(Stream inputStream, Stream outputStream, TConfiguration config)
+ : base(config)
{
InputStream = inputStream;
OutputStream = outputStream;
@@ -38,7 +40,14 @@
protected Stream OutputStream { get; set; }
- protected Stream InputStream { get; set; }
+ private Stream _InputStream = null;
+ protected Stream InputStream {
+ get => _InputStream;
+ set {
+ _InputStream = value;
+ ResetConsumedMessageSize();
+ }
+ }
public override bool IsOpen => true;
@@ -90,8 +99,10 @@
public override async Task FlushAsync(CancellationToken cancellationToken)
{
await OutputStream.FlushAsync(cancellationToken);
+ ResetConsumedMessageSize();
}
+
// IDisposable
protected override void Dispose(bool disposing)
{
diff --git a/lib/netstd/Thrift/Transport/Client/TTlsSocketTransport.cs b/lib/netstd/Thrift/Transport/Client/TTlsSocketTransport.cs
index 9295bb0..0980526 100644
--- a/lib/netstd/Thrift/Transport/Client/TTlsSocketTransport.cs
+++ b/lib/netstd/Thrift/Transport/Client/TTlsSocketTransport.cs
@@ -42,10 +42,12 @@
private SslStream _secureStream;
private int _timeout;
- public TTlsSocketTransport(TcpClient client, X509Certificate2 certificate, bool isServer = false,
+ public TTlsSocketTransport(TcpClient client, TConfiguration config,
+ X509Certificate2 certificate, bool isServer = false,
RemoteCertificateValidationCallback certValidator = null,
LocalCertificateSelectionCallback localCertificateSelectionCallback = null,
SslProtocols sslProtocols = SslProtocols.Tls12)
+ : base(config)
{
_client = client;
_certificate = certificate;
@@ -67,11 +69,12 @@
}
}
- public TTlsSocketTransport(IPAddress host, int port, string certificatePath,
+ public TTlsSocketTransport(IPAddress host, int port, TConfiguration config,
+ string certificatePath,
RemoteCertificateValidationCallback certValidator = null,
LocalCertificateSelectionCallback localCertificateSelectionCallback = null,
SslProtocols sslProtocols = SslProtocols.Tls12)
- : this(host, port, 0,
+ : this(host, port, config, 0,
new X509Certificate2(certificatePath),
certValidator,
localCertificateSelectionCallback,
@@ -79,12 +82,12 @@
{
}
- public TTlsSocketTransport(IPAddress host, int port,
+ public TTlsSocketTransport(IPAddress host, int port, TConfiguration config,
X509Certificate2 certificate = null,
RemoteCertificateValidationCallback certValidator = null,
LocalCertificateSelectionCallback localCertificateSelectionCallback = null,
SslProtocols sslProtocols = SslProtocols.Tls12)
- : this(host, port, 0,
+ : this(host, port, config, 0,
certificate,
certValidator,
localCertificateSelectionCallback,
@@ -92,11 +95,12 @@
{
}
- public TTlsSocketTransport(IPAddress host, int port, int timeout,
+ public TTlsSocketTransport(IPAddress host, int port, TConfiguration config, int timeout,
X509Certificate2 certificate,
RemoteCertificateValidationCallback certValidator = null,
LocalCertificateSelectionCallback localCertificateSelectionCallback = null,
SslProtocols sslProtocols = SslProtocols.Tls12)
+ : base(config)
{
_host = host;
_port = port;
@@ -109,11 +113,12 @@
InitSocket();
}
- public TTlsSocketTransport(string host, int port, int timeout,
+ public TTlsSocketTransport(string host, int port, TConfiguration config, int timeout,
X509Certificate2 certificate,
RemoteCertificateValidationCallback certValidator = null,
LocalCertificateSelectionCallback localCertificateSelectionCallback = null,
SslProtocols sslProtocols = SslProtocols.Tls12)
+ : base(config)
{
try
{
diff --git a/lib/netstd/Thrift/Transport/TBufferedTransport.cs b/lib/netstd/Thrift/Transport/Layered/TBufferedTransport.cs
similarity index 89%
rename from lib/netstd/Thrift/Transport/TBufferedTransport.cs
rename to lib/netstd/Thrift/Transport/Layered/TBufferedTransport.cs
index e4fdd3a..dee52dd 100644
--- a/lib/netstd/Thrift/Transport/TBufferedTransport.cs
+++ b/lib/netstd/Thrift/Transport/Layered/TBufferedTransport.cs
@@ -24,12 +24,11 @@
namespace Thrift.Transport
{
// ReSharper disable once InconsistentNaming
- public class TBufferedTransport : TTransport
+ public class TBufferedTransport : TLayeredTransport
{
private readonly int DesiredBufferSize;
- private readonly Client.TMemoryBufferTransport ReadBuffer = new Client.TMemoryBufferTransport(1024);
- private readonly Client.TMemoryBufferTransport WriteBuffer = new Client.TMemoryBufferTransport(1024);
- private readonly TTransport InnerTransport;
+ private readonly Client.TMemoryBufferTransport ReadBuffer;
+ private readonly Client.TMemoryBufferTransport WriteBuffer;
private bool IsDisposed;
public class Factory : TTransportFactory
@@ -42,19 +41,20 @@
//TODO: should support only specified input transport?
public TBufferedTransport(TTransport transport, int bufSize = 1024)
+ : base(transport)
{
if (bufSize <= 0)
{
throw new ArgumentOutOfRangeException(nameof(bufSize), "Buffer size must be a positive number.");
}
- InnerTransport = transport ?? throw new ArgumentNullException(nameof(transport));
DesiredBufferSize = bufSize;
- if (DesiredBufferSize != ReadBuffer.Capacity)
- ReadBuffer.Capacity = DesiredBufferSize;
- if (DesiredBufferSize != WriteBuffer.Capacity)
- WriteBuffer.Capacity = DesiredBufferSize;
+ WriteBuffer = new Client.TMemoryBufferTransport(InnerTransport.Configuration, bufSize);
+ ReadBuffer = new Client.TMemoryBufferTransport(InnerTransport.Configuration, bufSize);
+
+ Debug.Assert(DesiredBufferSize == ReadBuffer.Capacity);
+ Debug.Assert(DesiredBufferSize == WriteBuffer.Capacity);
}
public TTransport UnderlyingTransport
@@ -172,6 +172,17 @@
await InnerTransport.FlushAsync(cancellationToken);
}
+ public override void CheckReadBytesAvailable(long numBytes)
+ {
+ var buffered = ReadBuffer.Length - ReadBuffer.Position;
+ if (buffered < numBytes)
+ {
+ numBytes -= buffered;
+ InnerTransport.CheckReadBytesAvailable(numBytes);
+ }
+ }
+
+
private void CheckNotDisposed()
{
if (IsDisposed)
diff --git a/lib/netstd/Thrift/Transport/TFramedTransport.cs b/lib/netstd/Thrift/Transport/Layered/TFramedTransport.cs
similarity index 84%
rename from lib/netstd/Thrift/Transport/TFramedTransport.cs
rename to lib/netstd/Thrift/Transport/Layered/TFramedTransport.cs
index de6df72..be1513f 100644
--- a/lib/netstd/Thrift/Transport/TFramedTransport.cs
+++ b/lib/netstd/Thrift/Transport/Layered/TFramedTransport.cs
@@ -16,6 +16,7 @@
// under the License.
using System;
+using System.Buffers.Binary;
using System.IO;
using System.Threading;
using System.Threading.Tasks;
@@ -23,13 +24,12 @@
namespace Thrift.Transport
{
// ReSharper disable once InconsistentNaming
- public class TFramedTransport : TTransport
+ public class TFramedTransport : TLayeredTransport
{
private const int HeaderSize = 4;
private readonly byte[] HeaderBuf = new byte[HeaderSize];
- private readonly Client.TMemoryBufferTransport ReadBuffer = new Client.TMemoryBufferTransport();
- private readonly Client.TMemoryBufferTransport WriteBuffer = new Client.TMemoryBufferTransport();
- private readonly TTransport InnerTransport;
+ private readonly Client.TMemoryBufferTransport ReadBuffer;
+ private readonly Client.TMemoryBufferTransport WriteBuffer;
private bool IsDisposed;
@@ -42,9 +42,10 @@
}
public TFramedTransport(TTransport transport)
+ : base(transport)
{
- InnerTransport = transport ?? throw new ArgumentNullException(nameof(transport));
-
+ ReadBuffer = new Client.TMemoryBufferTransport(Configuration);
+ WriteBuffer = new Client.TMemoryBufferTransport(Configuration);
InitWriteBuffer();
}
@@ -86,7 +87,11 @@
private async ValueTask ReadFrameAsync(CancellationToken cancellationToken)
{
await InnerTransport.ReadAllAsync(HeaderBuf, 0, HeaderSize, cancellationToken);
- var size = DecodeFrameSize(HeaderBuf);
+ int size = BinaryPrimitives.ReadInt32BigEndian(HeaderBuf);
+
+ if ((0 > size) || (size > Configuration.MaxFrameSize)) // size must be in the range 0 to allowed max
+ throw new TTransportException(TTransportException.ExceptionType.Unknown, $"Maximum frame size exceeded ({size} bytes)");
+ UpdateKnownMessageSize(size + HeaderSize);
ReadBuffer.SetLength(size);
ReadBuffer.Seek(0, SeekOrigin.Begin);
@@ -133,7 +138,7 @@
}
// Inject message header into the reserved buffer space
- EncodeFrameSize(dataLen, bufSegment.Array);
+ BinaryPrimitives.WriteInt32BigEndian(bufSegment.Array, dataLen);
// Send the entire message at once
await InnerTransport.WriteAsync(bufSegment.Array, 0, bufSegment.Count, cancellationToken);
@@ -150,24 +155,16 @@
WriteBuffer.Seek(0, SeekOrigin.End);
}
- private static void EncodeFrameSize(int frameSize, byte[] buf)
+ public override void CheckReadBytesAvailable(long numBytes)
{
- buf[0] = (byte) (0xff & (frameSize >> 24));
- buf[1] = (byte) (0xff & (frameSize >> 16));
- buf[2] = (byte) (0xff & (frameSize >> 8));
- buf[3] = (byte) (0xff & (frameSize));
+ var buffered = ReadBuffer.Length - ReadBuffer.Position;
+ if (buffered < numBytes)
+ {
+ numBytes -= buffered;
+ InnerTransport.CheckReadBytesAvailable(numBytes);
+ }
}
- private static int DecodeFrameSize(byte[] buf)
- {
- return
- ((buf[0] & 0xff) << 24) |
- ((buf[1] & 0xff) << 16) |
- ((buf[2] & 0xff) << 8) |
- (buf[3] & 0xff);
- }
-
-
private void CheckNotDisposed()
{
if (IsDisposed)
diff --git a/lib/netstd/Thrift/Transport/Layered/TLayeredTransport.cs b/lib/netstd/Thrift/Transport/Layered/TLayeredTransport.cs
new file mode 100644
index 0000000..2137ae4
--- /dev/null
+++ b/lib/netstd/Thrift/Transport/Layered/TLayeredTransport.cs
@@ -0,0 +1,58 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Thrift.Transport
+{
+    /// <summary>
+    /// Base class for transports that wrap another transport (e.g. TBufferedTransport, TFramedTransport).
+    /// The configuration and the message-size accounting are delegated to the wrapped transport.
+    /// </summary>
+    public abstract class TLayeredTransport : TTransport
+    {
+        /// <summary>The wrapped transport; never null.</summary>
+        public readonly TTransport InnerTransport;
+
+        /// <summary>Configuration of the wrapped transport.</summary>
+        public override TConfiguration Configuration => InnerTransport.Configuration;
+
+        /// <param name="transport">The transport to wrap.</param>
+        /// <exception cref="ArgumentNullException"><paramref name="transport"/> is null.</exception>
+        protected TLayeredTransport(TTransport transport)
+        {
+            InnerTransport = transport ?? throw new ArgumentNullException(nameof(transport));
+        }
+
+        /// <summary>Forwards the known message size to the wrapped transport.</summary>
+        public override void UpdateKnownMessageSize(long size)
+        {
+            InnerTransport.UpdateKnownMessageSize(size);
+        }
+
+        /// <summary>
+        /// Forwards the check to the wrapped transport. Layers that buffer input
+        /// should override this to account for the bytes they already hold.
+        /// </summary>
+        public override void CheckReadBytesAvailable(long numBytes)
+        {
+            InnerTransport.CheckReadBytesAvailable(numBytes);
+        }
+    }
+}
diff --git a/lib/netstd/Thrift/Transport/Server/THttpServerTransport.cs b/lib/netstd/Thrift/Transport/Server/THttpServerTransport.cs
index 056300c..7271f50 100644
--- a/lib/netstd/Thrift/Transport/Server/THttpServerTransport.cs
+++ b/lib/netstd/Thrift/Transport/Server/THttpServerTransport.cs
@@ -42,24 +42,31 @@
protected TTransportFactory OutputTransportFactory;
protected ITAsyncProcessor Processor;
+ protected TConfiguration Configuration;
- public THttpServerTransport(ITAsyncProcessor processor, RequestDelegate next = null, ILoggerFactory loggerFactory = null)
- : this(processor, new TBinaryProtocol.Factory(), null, next, loggerFactory)
+ public THttpServerTransport(
+ ITAsyncProcessor processor,
+ TConfiguration config,
+ RequestDelegate next = null,
+ ILoggerFactory loggerFactory = null)
+ : this(processor, config, new TBinaryProtocol.Factory(), null, next, loggerFactory)
{
}
public THttpServerTransport(
- ITAsyncProcessor processor,
+ ITAsyncProcessor processor,
+ TConfiguration config,
TProtocolFactory protocolFactory,
TTransportFactory transFactory = null,
RequestDelegate next = null,
ILoggerFactory loggerFactory = null)
- : this(processor, protocolFactory, protocolFactory, transFactory, transFactory, next, loggerFactory)
+ : this(processor, config, protocolFactory, protocolFactory, transFactory, transFactory, next, loggerFactory)
{
}
public THttpServerTransport(
- ITAsyncProcessor processor,
+ ITAsyncProcessor processor,
+ TConfiguration config,
TProtocolFactory inputProtocolFactory,
TProtocolFactory outputProtocolFactory,
TTransportFactory inputTransFactory = null,
@@ -70,6 +77,8 @@
// loggerFactory == null is not illegal anymore
Processor = processor ?? throw new ArgumentNullException(nameof(processor));
+ Configuration = config; // may be null
+
InputProtocolFactory = inputProtocolFactory ?? throw new ArgumentNullException(nameof(inputProtocolFactory));
OutputProtocolFactory = outputProtocolFactory ?? throw new ArgumentNullException(nameof(outputProtocolFactory));
@@ -88,7 +97,7 @@
public async Task ProcessRequestAsync(HttpContext context, CancellationToken cancellationToken)
{
- var transport = new TStreamTransport(context.Request.Body, context.Response.Body);
+ var transport = new TStreamTransport(context.Request.Body, context.Response.Body, Configuration);
try
{
diff --git a/lib/netstd/Thrift/Transport/Server/TNamedPipeServerTransport.cs b/lib/netstd/Thrift/Transport/Server/TNamedPipeServerTransport.cs
index 77b8251..a8b64c4 100644
--- a/lib/netstd/Thrift/Transport/Server/TNamedPipeServerTransport.cs
+++ b/lib/netstd/Thrift/Transport/Server/TNamedPipeServerTransport.cs
@@ -38,7 +38,8 @@
private volatile bool _isPending = true;
private NamedPipeServerStream _stream = null;
- public TNamedPipeServerTransport(string pipeAddress)
+ public TNamedPipeServerTransport(string pipeAddress, TConfiguration config)
+ : base(config)
{
_pipeAddress = pipeAddress;
}
@@ -92,10 +93,16 @@
try
{
var handle = CreatePipeNative(_pipeAddress, inbuf, outbuf);
- if( (handle != null) && (!handle.IsInvalid))
+ if ((handle != null) && (!handle.IsInvalid))
+ {
_stream = new NamedPipeServerStream(PipeDirection.InOut, _asyncMode, false, handle);
+ handle = null; // we don't own it any longer
+ }
else
+ {
+ handle?.Dispose();
_stream = new NamedPipeServerStream(_pipeAddress, direction, maxconn, mode, options, inbuf, outbuf/*, pipesec*/);
+ }
}
catch (NotImplementedException) // Mono still does not support async, fallback to sync
{
@@ -218,7 +225,7 @@
await _stream.WaitForConnectionAsync(cancellationToken);
- var trans = new ServerTransport(_stream);
+ var trans = new ServerTransport(_stream, Configuration);
_stream = null; // pass ownership to ServerTransport
//_isPending = false;
@@ -237,11 +244,12 @@
}
}
- private class ServerTransport : TTransport
+ private class ServerTransport : TEndpointTransport
{
private readonly NamedPipeServerStream PipeStream;
- public ServerTransport(NamedPipeServerStream stream)
+ public ServerTransport(NamedPipeServerStream stream, TConfiguration config)
+ : base(config)
{
PipeStream = stream;
}
@@ -268,7 +276,10 @@
throw new TTransportException(TTransportException.ExceptionType.NotOpen);
}
- return await PipeStream.ReadAsync(buffer, offset, length, cancellationToken);
+ CheckReadBytesAvailable(length);
+ var numBytes = await PipeStream.ReadAsync(buffer, offset, length, cancellationToken);
+ CountConsumedMessageBytes(numBytes);
+ return numBytes;
}
public override async Task WriteAsync(byte[] buffer, int offset, int length, CancellationToken cancellationToken)
@@ -297,11 +308,16 @@
{
await Task.FromCanceled(cancellationToken);
}
+
+ ResetConsumedMessageSize();
}
protected override void Dispose(bool disposing)
{
- PipeStream?.Dispose();
+ if (disposing)
+ {
+ PipeStream?.Dispose();
+ }
}
}
}
diff --git a/lib/netstd/Thrift/Transport/Server/TServerSocketTransport.cs b/lib/netstd/Thrift/Transport/Server/TServerSocketTransport.cs
index 86d82e3..6656b64 100644
--- a/lib/netstd/Thrift/Transport/Server/TServerSocketTransport.cs
+++ b/lib/netstd/Thrift/Transport/Server/TServerSocketTransport.cs
@@ -31,14 +31,15 @@
private readonly int _clientTimeout;
private TcpListener _server;
- public TServerSocketTransport(TcpListener listener, int clientTimeout = 0)
+ public TServerSocketTransport(TcpListener listener, TConfiguration config, int clientTimeout = 0)
+ : base(config)
{
_server = listener;
_clientTimeout = clientTimeout;
}
- public TServerSocketTransport(int port, int clientTimeout = 0)
- : this(null, clientTimeout)
+ public TServerSocketTransport(int port, TConfiguration config, int clientTimeout = 0)
+ : this(null, config, clientTimeout)
{
try
{
@@ -93,7 +94,7 @@
try
{
- tSocketTransport = new TSocketTransport(tcpClient)
+ tSocketTransport = new TSocketTransport(tcpClient,Configuration)
{
Timeout = _clientTimeout
};
diff --git a/lib/netstd/Thrift/Transport/TServerTransport.cs b/lib/netstd/Thrift/Transport/Server/TServerTransport.cs
similarity index 88%
rename from lib/netstd/Thrift/Transport/TServerTransport.cs
rename to lib/netstd/Thrift/Transport/Server/TServerTransport.cs
index 74c54cd..31f578d 100644
--- a/lib/netstd/Thrift/Transport/TServerTransport.cs
+++ b/lib/netstd/Thrift/Transport/Server/TServerTransport.cs
@@ -23,6 +23,13 @@
// ReSharper disable once InconsistentNaming
public abstract class TServerTransport
{
+ public readonly TConfiguration Configuration;
+
+ protected TServerTransport(TConfiguration config)
+ {
+ Configuration = config ?? new TConfiguration(); // null selects a default configuration
+ }
+
public abstract void Listen();
public abstract void Close();
public abstract bool IsClientPending();
@@ -34,7 +41,7 @@
protected abstract ValueTask<TTransport> AcceptImplementationAsync(CancellationToken cancellationToken);
- public async ValueTask<TTransport> AcceptAsync()
+ public async ValueTask<TTransport> AcceptAsync()
{
return await AcceptAsync(CancellationToken.None);
}
@@ -51,4 +58,4 @@
return transport;
}
}
-}
\ No newline at end of file
+}
diff --git a/lib/netstd/Thrift/Transport/Server/TTlsServerSocketTransport.cs b/lib/netstd/Thrift/Transport/Server/TTlsServerSocketTransport.cs
index 1286805..9f74562 100644
--- a/lib/netstd/Thrift/Transport/Server/TTlsServerSocketTransport.cs
+++ b/lib/netstd/Thrift/Transport/Server/TTlsServerSocketTransport.cs
@@ -39,10 +39,12 @@
public TTlsServerSocketTransport(
TcpListener listener,
+ TConfiguration config,
X509Certificate2 certificate,
RemoteCertificateValidationCallback clientCertValidator = null,
LocalCertificateSelectionCallback localCertificateSelectionCallback = null,
SslProtocols sslProtocols = SslProtocols.Tls12)
+ : base(config)
{
if (!certificate.HasPrivateKey)
{
@@ -59,11 +61,12 @@
public TTlsServerSocketTransport(
int port,
+ TConfiguration config,
X509Certificate2 certificate,
RemoteCertificateValidationCallback clientCertValidator = null,
LocalCertificateSelectionCallback localCertificateSelectionCallback = null,
SslProtocols sslProtocols = SslProtocols.Tls12)
- : this(null, certificate, clientCertValidator, localCertificateSelectionCallback)
+ : this(null, config, certificate, clientCertValidator, localCertificateSelectionCallback, sslProtocols)
{
try
{
@@ -117,7 +120,9 @@
client.SendTimeout = client.ReceiveTimeout = _clientTimeout;
//wrap the client in an SSL Socket passing in the SSL cert
- var tTlsSocket = new TTlsSocketTransport(client, _serverCertificate, true, _clientCertValidator,
+ var tTlsSocket = new TTlsSocketTransport(
+ client, Configuration,
+ _serverCertificate, true, _clientCertValidator,
_localCertificateSelectionCallback, _sslProtocols);
await tTlsSocket.SetupTlsAsync();
diff --git a/lib/netstd/Thrift/Transport/TEndpointTransport.cs b/lib/netstd/Thrift/Transport/TEndpointTransport.cs
new file mode 100644
index 0000000..fa2ac6b
--- /dev/null
+++ b/lib/netstd/Thrift/Transport/TEndpointTransport.cs
@@ -0,0 +1,122 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+using System;
+using System.Collections.Generic;
+using System.Diagnostics;
+using System.Text;
+
+namespace Thrift.Transport
+{
+    /// <summary>
+    /// Base class for transports that talk to an endpoint directly (e.g. stream or named-pipe transports).
+    /// Owns the <see cref="TConfiguration"/> and tracks how many bytes of the current message
+    /// may still be read, rejecting messages larger than allowed.
+    /// </summary>
+    public abstract class TEndpointTransport : TTransport
+    {
+        /// <summary>Configured upper bound for a single message, in bytes.</summary>
+        protected long MaxMessageSize => Configuration.MaxMessageSize;
+
+        /// <summary>Size of the current message as far as it is known; <see cref="MaxMessageSize"/> after a full reset.</summary>
+        protected long KnownMessageSize { get; private set; }
+
+        /// <summary>Bytes of the current message that may still be consumed.</summary>
+        protected long RemainingMessageSize { get; private set; }
+
+        private readonly TConfiguration _configuration;
+
+        /// <summary>Configuration passed to the constructor, or a default instance if none was given.</summary>
+        public override TConfiguration Configuration => _configuration;
+
+        /// <param name="config">Transport configuration; a default <see cref="TConfiguration"/> is used when null.</param>
+        protected TEndpointTransport(TConfiguration config)
+        {
+            _configuration = config ?? new TConfiguration();
+            Debug.Assert(Configuration != null);
+
+            ResetConsumedMessageSize();
+        }
+
+        /// <summary>
+        /// Resets the message-size accounting.
+        /// With a negative <paramref name="newSize"/> (the default) both <see cref="KnownMessageSize"/>
+        /// and <see cref="RemainingMessageSize"/> are reset to <see cref="MaxMessageSize"/>.
+        /// Otherwise both are set to <paramref name="newSize"/>: the known size may shrink, but never grow.
+        /// </summary>
+        /// <exception cref="TTransportException"><paramref name="newSize"/> is larger than <see cref="KnownMessageSize"/>.</exception>
+        protected void ResetConsumedMessageSize(long newSize = -1)
+        {
+            // full reset
+            if (newSize < 0)
+            {
+                KnownMessageSize = MaxMessageSize;
+                RemainingMessageSize = MaxMessageSize;
+                return;
+            }
+
+            // update only: message size can shrink, but not grow
+            Debug.Assert(KnownMessageSize <= MaxMessageSize);
+            if (newSize > KnownMessageSize)
+                throw new TTransportException(TTransportException.ExceptionType.EndOfFile, "MaxMessageSize reached");
+
+            KnownMessageSize = newSize;
+            RemainingMessageSize = newSize;
+        }
+
+        /// <summary>
+        /// Updates RemainingMessageSize to reflect the known real message size (e.g. framed transport).
+        /// Will throw if we already consumed too many bytes or if the new size is larger than allowed.
+        /// </summary>
+        /// <param name="size">The actual size of the current message, in bytes.</param>
+        public override void UpdateKnownMessageSize(long size)
+        {
+            // bytes already read from the current message are charged against the new size
+            var consumed = KnownMessageSize - RemainingMessageSize;
+            ResetConsumedMessageSize(size);
+            CountConsumedMessageBytes(consumed);
+        }
+
+        /// <summary>
+        /// Throws if reading numBytes more bytes would exceed RemainingMessageSize.
+        /// </summary>
+        /// <param name="numBytes">Number of bytes the caller intends to read.</param>
+        public override void CheckReadBytesAvailable(long numBytes)
+        {
+            if (RemainingMessageSize < numBytes)
+                throw new TTransportException(TTransportException.ExceptionType.EndOfFile, "MaxMessageSize reached");
+        }
+
+        /// <summary>
+        /// Consumes numBytes from the RemainingMessageSize.
+        /// If more bytes are consumed than remain, RemainingMessageSize is clamped to 0 and an exception is thrown.
+        /// </summary>
+        /// <param name="numBytes">Number of bytes that were actually read.</param>
+        protected void CountConsumedMessageBytes(long numBytes)
+        {
+            if (RemainingMessageSize >= numBytes)
+            {
+                RemainingMessageSize -= numBytes;
+            }
+            else
+            {
+                RemainingMessageSize = 0;
+                throw new TTransportException(TTransportException.ExceptionType.EndOfFile, "MaxMessageSize reached");
+            }
+        }
+    }
+}
diff --git a/lib/netstd/Thrift/Transport/TTransport.cs b/lib/netstd/Thrift/Transport/TTransport.cs
index 7998012..dedd51d 100644
--- a/lib/netstd/Thrift/Transport/TTransport.cs
+++ b/lib/netstd/Thrift/Transport/TTransport.cs
@@ -30,8 +30,11 @@
//TODO: think how to avoid peek byte
private readonly byte[] _peekBuffer = new byte[1];
private bool _hasPeekByte;
- public abstract bool IsOpen { get; }
+ public abstract bool IsOpen { get; }
+ public abstract TConfiguration Configuration { get; }
+ public abstract void UpdateKnownMessageSize(long size);
+ public abstract void CheckReadBytesAvailable(long numBytes);
public void Dispose()
{
Dispose(true);
diff --git a/lib/nodejs/lib/thrift/ws_transport.js b/lib/nodejs/lib/thrift/ws_transport.js
index 3513b84..4cf62b9 100644
--- a/lib/nodejs/lib/thrift/ws_transport.js
+++ b/lib/nodejs/lib/thrift/ws_transport.js
@@ -83,8 +83,8 @@
//If the user made calls before the connection was fully
//open, send them now
this.send_pending.forEach(function(elem) {
- this.socket.send(elem.buf);
- this.callbacks.push((function() {
+ self.socket.send(elem.buf);
+ self.callbacks.push((function() {
var clientCallback = elem.cb;
return function(msg) {
self.setRecvBuffer(msg);
diff --git a/lib/php/test/Fixtures.php b/lib/php/test/Fixtures.php
index 996f4af..fd57d83 100644
--- a/lib/php/test/Fixtures.php
+++ b/lib/php/test/Fixtures.php
@@ -66,32 +66,32 @@
self::$testArgs['testStruct'] =
new Xtruct(
- array(
+ array(
'string_thing' => 'worked',
'byte_thing' => 0x01,
'i32_thing' => pow(2, 30),
'i64_thing' => self::$testArgs['testI64']
)
- );
+ );
self::$testArgs['testNestNested'] =
new Xtruct(
- array(
+ array(
'string_thing' => 'worked',
'byte_thing' => 0x01,
'i32_thing' => pow(2, 30),
'i64_thing' => self::$testArgs['testI64']
)
- );
+ );
self::$testArgs['testNest'] =
new Xtruct2(
- array(
+ array(
'byte_thing' => 0x01,
'struct_thing' => self::$testArgs['testNestNested'],
'i32_thing' => pow(2, 15)
)
- );
+ );
self::$testArgs['testMap'] =
array(
@@ -138,23 +138,23 @@
$xtruct1 =
new Xtruct(
- array(
+ array(
'string_thing' => 'Goodbye4',
'byte_thing' => 4,
'i32_thing' => 4,
'i64_thing' => 4
)
- );
+ );
$xtruct2 =
new Xtruct(
- array(
+ array(
'string_thing' => 'Hello2',
'byte_thing' => 2,
'i32_thing' => 2,
'i64_thing' => 2
)
- );
+ );
$userMap =
array(
@@ -164,21 +164,21 @@
$insanity2 =
new Insanity(
- array(
+ array(
'userMap' => $userMap,
'xtructs' => array($xtruct1, $xtruct2)
)
- );
+ );
$insanity3 = $insanity2;
$insanity6 =
new Insanity(
- array(
+ array(
'userMap' => null,
'xtructs' => null
)
- );
+ );
self::$testArgs['testInsanityExpectedResult'] =
array(
diff --git a/lib/py/src/Thrift.py b/lib/py/src/Thrift.py
index c390cbb..ef655ea 100644
--- a/lib/py/src/Thrift.py
+++ b/lib/py/src/Thrift.py
@@ -17,8 +17,6 @@
# under the License.
#
-import sys
-
class TType(object):
STOP = 0
@@ -90,15 +88,6 @@
class TException(Exception):
"""Base class for all thrift exceptions."""
- # BaseException.message is deprecated in Python v[2.6,3.0)
- if (2, 6, 0) <= sys.version_info < (3, 0):
- def _get_message(self):
- return self._message
-
- def _set_message(self, message):
- self._message = message
- message = property(_get_message, _set_message)
-
def __init__(self, message=None):
Exception.__init__(self, message)
self.message = message
diff --git a/lib/py/src/ext/protocol.tcc b/lib/py/src/ext/protocol.tcc
index e15df7e..ede2bb4 100644
--- a/lib/py/src/ext/protocol.tcc
+++ b/lib/py/src/ext/protocol.tcc
@@ -174,7 +174,7 @@
if (output_->buf.capacity() < need) {
try {
output_->buf.reserve(need);
- } catch (std::bad_alloc& ex) {
+ } catch (std::bad_alloc&) {
PyErr_SetString(PyExc_MemoryError, "Failed to allocate write buffer");
return false;
}
diff --git a/lib/py/src/protocol/TBase.py b/lib/py/src/protocol/TBase.py
index 9ae1b11..6c6ef18 100644
--- a/lib/py/src/protocol/TBase.py
+++ b/lib/py/src/protocol/TBase.py
@@ -80,3 +80,7 @@
[self.__class__, self.thrift_spec])
else:
return iprot.readStruct(cls, cls.thrift_spec, True)
+
+
+class TFrozenExceptionBase(TFrozenBase, TExceptionBase):
+ pass
diff --git a/lib/py/src/protocol/TProtocol.py b/lib/py/src/protocol/TProtocol.py
index 3456e8f..339a283 100644
--- a/lib/py/src/protocol/TProtocol.py
+++ b/lib/py/src/protocol/TProtocol.py
@@ -303,8 +303,14 @@
def readContainerStruct(self, spec):
(obj_class, obj_spec) = spec
- obj = obj_class()
- obj.read(self)
+
+ # If obj_class.read is a classmethod (e.g. in frozen structs),
+ # call it as such.
+ if getattr(obj_class.read, '__self__', None) is obj_class:
+ obj = obj_class.read(self)
+ else:
+ obj = obj_class()
+ obj.read(self)
return obj
def readContainerMap(self, spec):
diff --git a/lib/rb/lib/thrift/transport/http_client_transport.rb b/lib/rb/lib/thrift/transport/http_client_transport.rb
index 5c1dd5c..c84304d 100644
--- a/lib/rb/lib/thrift/transport/http_client_transport.rb
+++ b/lib/rb/lib/thrift/transport/http_client_transport.rb
@@ -47,6 +47,8 @@
http.use_ssl = @url.scheme == 'https'
http.verify_mode = @ssl_verify_mode if @url.scheme == 'https'
resp = http.post(@url.request_uri, @outbuf, @headers)
+ raise TransportException.new(TransportException::UNKNOWN, "#{self.class.name} Could not connect to #{@url}, HTTP status code #{resp.code.to_i}") unless (200..299).include?(resp.code.to_i)
+
data = resp.body
data = Bytes.force_binary_encoding(data)
@inbuf = StringIO.new data
diff --git a/lib/rb/spec/http_client_spec.rb b/lib/rb/spec/http_client_spec.rb
index df472ab..292c752 100644
--- a/lib/rb/spec/http_client_spec.rb
+++ b/lib/rb/spec/http_client_spec.rb
@@ -45,6 +45,7 @@
expect(http).to receive(:post).with("/path/to/service?param=value", "a test frame", {"Content-Type"=>"application/x-thrift"}) do
double("Net::HTTPOK").tap do |response|
expect(response).to receive(:body).and_return "data"
+ expect(response).to receive(:code).and_return "200"
end
end
end
@@ -65,6 +66,7 @@
expect(http).to receive(:post).with("/path/to/service?param=value", "test", headers) do
double("Net::HTTPOK").tap do |response|
expect(response).to receive(:body).and_return "data"
+ expect(response).to receive(:code).and_return "200"
end
end
end
@@ -86,6 +88,24 @@
expect(@client.instance_variable_get(:@outbuf)).to eq(Thrift::Bytes.empty_byte_buffer)
end
+ it 'should raise TransportError on HTTP failures' do
+ @client.write "test"
+
+ expect(Net::HTTP).to receive(:new).with("my.domain.com", 80) do
+ double("Net::HTTP").tap do |http|
+ expect(http).to receive(:use_ssl=).with(false)
+ expect(http).to receive(:post).with("/path/to/service?param=value", "test", {"Content-Type"=>"application/x-thrift"}) do
+ double("Net::HTTPOK").tap do |response|
+ expect(response).not_to receive(:body)
+ expect(response).to receive(:code).at_least(:once).and_return "503"
+ end
+ end
+ end
+ end
+
+ expect { @client.flush }.to raise_error(Thrift::TransportException)
+ end
+
end
describe 'ssl enabled' do
@@ -107,6 +127,7 @@
"Content-Type" => "application/x-thrift") do
double("Net::HTTPOK").tap do |response|
expect(response).to receive(:body).and_return "data"
+ expect(response).to receive(:code).and_return "200"
end
end
end
@@ -128,6 +149,7 @@
"Content-Type" => "application/x-thrift") do
double("Net::HTTPOK").tap do |response|
expect(response).to receive(:body).and_return "data"
+ expect(response).to receive(:code).and_return "200"
end
end
end
diff --git a/lib/rs/src/protocol/compact.rs b/lib/rs/src/protocol/compact.rs
index 1750bc4..3e17398 100644
--- a/lib/rs/src/protocol/compact.rs
+++ b/lib/rs/src/protocol/compact.rs
@@ -15,7 +15,7 @@
// specific language governing permissions and limitations
// under the License.
-use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
+use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
use integer_encoding::{VarIntReader, VarIntWriter};
use std::convert::{From, TryFrom};
use std::io;
@@ -247,7 +247,7 @@
}
fn read_double(&mut self) -> ::Result<f64> {
- self.transport.read_f64::<BigEndian>().map_err(From::from)
+ self.transport.read_f64::<LittleEndian>().map_err(From::from)
}
fn read_string(&mut self) -> ::Result<String> {
@@ -521,7 +521,7 @@
}
fn write_double(&mut self, d: f64) -> ::Result<()> {
- self.transport.write_f64::<BigEndian>(d).map_err(From::from)
+ self.transport.write_f64::<LittleEndian>(d).map_err(From::from)
}
fn write_string(&mut self, s: &str) -> ::Result<()> {
@@ -2374,6 +2374,29 @@
(i_prot, o_prot)
}
+ #[test]
+ fn must_read_write_double() {
+ let (mut i_prot, mut o_prot) = test_objects();
+
+ let double = 3.141592653589793238462643;
+ o_prot.write_double(double).unwrap();
+ copy_write_buffer_to_read_buffer!(o_prot);
+
+ assert_eq!(i_prot.read_double().unwrap(), double);
+ }
+
+ #[test]
+ fn must_encode_double_as_other_langs() {
+ let (_, mut o_prot) = test_objects();
+ let expected = [24, 45, 68, 84, 251, 33, 9, 64];
+
+ let double = 3.141592653589793238462643;
+ o_prot.write_double(double).unwrap();
+
+ assert_eq_written_bytes!(o_prot, expected);
+
+ }
+
fn assert_no_write<F>(mut write_fn: F)
where
F: FnMut(&mut TCompactOutputProtocol<WriteHalf<TBufferChannel>>) -> ::Result<()>,
diff --git a/lib/rs/src/protocol/stored.rs b/lib/rs/src/protocol/stored.rs
index faa5128..bf2d8ba 100644
--- a/lib/rs/src/protocol/stored.rs
+++ b/lib/rs/src/protocol/stored.rs
@@ -52,8 +52,8 @@
/// impl TProcessor for ActualProcessor {
/// fn process(
/// &self,
-/// _: &mut TInputProtocol,
-/// _: &mut TOutputProtocol
+/// _: &mut dyn TInputProtocol,
+/// _: &mut dyn TOutputProtocol
/// ) -> thrift::Result<()> {
/// unimplemented!()
/// }
diff --git a/lib/rs/src/server/mod.rs b/lib/rs/src/server/mod.rs
index b719d1b..f24c113 100644
--- a/lib/rs/src/server/mod.rs
+++ b/lib/rs/src/server/mod.rs
@@ -56,7 +56,7 @@
///
/// // `TProcessor` implementation for `SimpleService`
/// impl TProcessor for SimpleServiceSyncProcessor {
-/// fn process(&self, i: &mut TInputProtocol, o: &mut TOutputProtocol) -> thrift::Result<()> {
+/// fn process(&self, i: &mut dyn TInputProtocol, o: &mut dyn TOutputProtocol) -> thrift::Result<()> {
/// unimplemented!();
/// }
/// }
diff --git a/lib/rs/src/server/threaded.rs b/lib/rs/src/server/threaded.rs
index 8f8c082..b33239a 100644
--- a/lib/rs/src/server/threaded.rs
+++ b/lib/rs/src/server/threaded.rs
@@ -64,7 +64,7 @@
///
/// // `TProcessor` implementation for `SimpleService`
/// impl TProcessor for SimpleServiceSyncProcessor {
-/// fn process(&self, i: &mut TInputProtocol, o: &mut TOutputProtocol) -> thrift::Result<()> {
+/// fn process(&self, i: &mut dyn TInputProtocol, o: &mut dyn TOutputProtocol) -> thrift::Result<()> {
/// unimplemented!();
/// }
/// }
@@ -90,10 +90,10 @@
/// let processor = SimpleServiceSyncProcessor::new(SimpleServiceHandlerImpl {});
///
/// // instantiate the server
-/// let i_tr_fact: Box<TReadTransportFactory> = Box::new(TBufferedReadTransportFactory::new());
-/// let i_pr_fact: Box<TInputProtocolFactory> = Box::new(TBinaryInputProtocolFactory::new());
-/// let o_tr_fact: Box<TWriteTransportFactory> = Box::new(TBufferedWriteTransportFactory::new());
-/// let o_pr_fact: Box<TOutputProtocolFactory> = Box::new(TBinaryOutputProtocolFactory::new());
+/// let i_tr_fact: Box<dyn TReadTransportFactory> = Box::new(TBufferedReadTransportFactory::new());
+/// let i_pr_fact: Box<dyn TInputProtocolFactory> = Box::new(TBinaryInputProtocolFactory::new());
+/// let o_tr_fact: Box<dyn TWriteTransportFactory> = Box::new(TBufferedWriteTransportFactory::new());
+/// let o_pr_fact: Box<dyn TOutputProtocolFactory> = Box::new(TBinaryOutputProtocolFactory::new());
///
/// let mut server = TServer::new(
/// i_tr_fact,
diff --git a/lib/rs/src/transport/mem.rs b/lib/rs/src/transport/mem.rs
index 82c4b57..9874257 100644
--- a/lib/rs/src/transport/mem.rs
+++ b/lib/rs/src/transport/mem.rs
@@ -31,7 +31,7 @@
/// `set_readable_bytes(...)`. Callers can then read until the buffer is
/// depleted. No further reads are accepted until the internal read buffer is
/// replenished again.
-#[derive(Debug)]
+#[derive(Clone, Debug)]
pub struct TBufferChannel {
read: Arc<Mutex<ReadData>>,
write: Arc<Mutex<WriteData>>,
diff --git a/test/DebugProtoTest.thrift b/test/DebugProtoTest.thrift
index de47ea7..1ab0f6a 100644
--- a/test/DebugProtoTest.thrift
+++ b/test/DebugProtoTest.thrift
@@ -241,6 +241,10 @@
2: map<string, string> map_field;
}
+exception MutableException {
+ 1: string msg;
+} (python.immutable = "false")
+
service ServiceForExceptionWithAMap {
void methodThatThrowsAnException() throws (1: ExceptionWithAMap xwamap);
}
diff --git a/test/known_failures_Linux.json b/test/known_failures_Linux.json
index dd6195a..e165aa4 100644
--- a/test/known_failures_Linux.json
+++ b/test/known_failures_Linux.json
@@ -85,74 +85,109 @@
"cpp-nodejs_multij-json_http-domain",
"cpp-nodejs_multij-json_http-ip",
"cpp-nodejs_multij-json_http-ip-ssl",
+ "cpp-py3_binary-accel_http-domain",
"cpp-py3_binary-accel_http-ip",
"cpp-py3_binary-accel_http-ip-ssl",
+ "cpp-py3_binary_http-domain",
"cpp-py3_binary_http-ip",
"cpp-py3_binary_http-ip-ssl",
+ "cpp-py3_compact-accelc_http-domain",
"cpp-py3_compact-accelc_http-ip",
"cpp-py3_compact-accelc_http-ip-ssl",
+ "cpp-py3_compact_http-domain",
"cpp-py3_compact_http-ip",
"cpp-py3_compact_http-ip-ssl",
+ "cpp-py3_header_http-domain",
"cpp-py3_header_http-ip",
"cpp-py3_header_http-ip-ssl",
+ "cpp-py3_json_http-domain",
"cpp-py3_json_http-ip",
"cpp-py3_json_http-ip-ssl",
+ "cpp-py3_multi-accel_http-domain",
"cpp-py3_multi-accel_http-ip",
"cpp-py3_multi-accel_http-ip-ssl",
+ "cpp-py3_multi-binary_http-domain",
"cpp-py3_multi-binary_http-ip",
"cpp-py3_multi-binary_http-ip-ssl",
+ "cpp-py3_multi-multia_http-domain",
"cpp-py3_multi-multia_http-ip",
"cpp-py3_multi-multia_http-ip-ssl",
+ "cpp-py3_multi_http-domain",
"cpp-py3_multi_http-ip",
"cpp-py3_multi_http-ip-ssl",
+ "cpp-py3_multic-accelc_http-domain",
"cpp-py3_multic-accelc_http-ip",
"cpp-py3_multic-accelc_http-ip-ssl",
+ "cpp-py3_multic-compact_http-domain",
"cpp-py3_multic-compact_http-ip",
"cpp-py3_multic-compact_http-ip-ssl",
+ "cpp-py3_multic-multiac_http-domain",
"cpp-py3_multic-multiac_http-ip",
"cpp-py3_multic-multiac_http-ip-ssl",
+ "cpp-py3_multic_http-domain",
"cpp-py3_multic_http-ip",
"cpp-py3_multic_http-ip-ssl",
+ "cpp-py3_multih-header_http-domain",
"cpp-py3_multih-header_http-ip",
"cpp-py3_multih-header_http-ip-ssl",
+ "cpp-py3_multij-json_http-domain",
"cpp-py3_multij-json_http-ip",
"cpp-py3_multij-json_http-ip-ssl",
+ "cpp-py3_multij_http-domain",
"cpp-py3_multij_http-ip",
"cpp-py3_multij_http-ip-ssl",
+ "cpp-py_binary-accel_http-domain",
"cpp-py_binary-accel_http-ip",
"cpp-py_binary-accel_http-ip-ssl",
+ "cpp-py_binary_http-domain",
"cpp-py_binary_http-ip",
"cpp-py_binary_http-ip-ssl",
+ "cpp-py_compact-accelc_http-domain",
"cpp-py_compact-accelc_http-ip",
"cpp-py_compact-accelc_http-ip-ssl",
+ "cpp-py_compact_http-domain",
"cpp-py_compact_http-ip",
"cpp-py_compact_http-ip-ssl",
+ "cpp-py_header_http-domain",
"cpp-py_header_http-ip",
"cpp-py_header_http-ip-ssl",
+ "cpp-py_json_http-domain",
"cpp-py_json_http-ip",
"cpp-py_json_http-ip-ssl",
+ "cpp-py_multi-accel_http-domain",
"cpp-py_multi-accel_http-ip",
"cpp-py_multi-accel_http-ip-ssl",
+ "cpp-py_multi-binary_http-domain",
"cpp-py_multi-binary_http-ip",
"cpp-py_multi-binary_http-ip-ssl",
+ "cpp-py_multi-multia_http-domain",
"cpp-py_multi-multia_http-ip",
"cpp-py_multi-multia_http-ip-ssl",
+ "cpp-py_multi_http-domain",
"cpp-py_multi_http-ip",
"cpp-py_multi_http-ip-ssl",
+ "cpp-py_multic-accelc_http-domain",
"cpp-py_multic-accelc_http-ip",
"cpp-py_multic-accelc_http-ip-ssl",
+ "cpp-py_multic-compact_http-domain",
"cpp-py_multic-compact_http-ip",
"cpp-py_multic-compact_http-ip-ssl",
+ "cpp-py_multic-multiac_http-domain",
"cpp-py_multic-multiac_http-ip",
"cpp-py_multic-multiac_http-ip-ssl",
+ "cpp-py_multic_http-domain",
"cpp-py_multic_http-ip",
"cpp-py_multic_http-ip-ssl",
+ "cpp-py_multih-header_http-domain",
"cpp-py_multih-header_http-ip",
"cpp-py_multih-header_http-ip-ssl",
+ "cpp-py_multih_http-domain",
"cpp-py_multih_http-ip",
"cpp-py_multih_http-ip-ssl",
+ "cpp-py_multij-json_http-domain",
"cpp-py_multij-json_http-ip",
"cpp-py_multij-json_http-ip-ssl",
+ "cpp-py_multij_http-domain",
"cpp-py_multij_http-ip",
"cpp-py_multij_http-ip-ssl",
"cpp-rs_multi_buffered-ip",
@@ -389,76 +424,112 @@
"nodejs-lua_binary_http-ip",
"nodejs-lua_compact_http-ip",
"nodejs-lua_json_http-ip",
+ "nodejs-py3_binary-accel_http-domain",
"nodejs-py3_binary-accel_http-ip",
"nodejs-py3_binary-accel_http-ip-ssl",
+ "nodejs-py3_binary_http-domain",
"nodejs-py3_binary_http-ip",
"nodejs-py3_binary_http-ip-ssl",
+ "nodejs-py3_compact-accelc_http-domain",
"nodejs-py3_compact-accelc_http-ip",
"nodejs-py3_compact-accelc_http-ip-ssl",
+ "nodejs-py3_compact_http-domain",
"nodejs-py3_compact_http-ip",
"nodejs-py3_compact_http-ip-ssl",
+ "nodejs-py3_header_http-domain",
"nodejs-py3_header_http-ip",
"nodejs-py3_header_http-ip-ssl",
+ "nodejs-py3_json_http-domain",
"nodejs-py3_json_http-ip",
"nodejs-py3_json_http-ip-ssl",
+ "nodejs-py_binary-accel_http-domain",
"nodejs-py_binary-accel_http-ip",
"nodejs-py_binary-accel_http-ip-ssl",
+ "nodejs-py_binary_http-domain",
"nodejs-py_binary_http-ip",
"nodejs-py_binary_http-ip-ssl",
+ "nodejs-py_compact-accelc_http-domain",
"nodejs-py_compact-accelc_http-ip",
"nodejs-py_compact-accelc_http-ip-ssl",
+ "nodejs-py_compact_http-domain",
"nodejs-py_compact_http-ip",
"nodejs-py_compact_http-ip-ssl",
+ "nodejs-py_header_http-domain",
"nodejs-py_header_http-ip",
"nodejs-py_header_http-ip-ssl",
+ "nodejs-py_json_http-domain",
"nodejs-py_json_http-ip",
"nodejs-py_json_http-ip-ssl",
"perl-rs_multi_buffered-ip",
"perl-rs_multi_framed-ip",
+ "py-cpp_accel-binary_http-domain",
"py-cpp_accel-binary_http-ip",
"py-cpp_accel-binary_http-ip-ssl",
+ "py-cpp_accel-binary_zlib-domain",
"py-cpp_accel-binary_zlib-ip",
"py-cpp_accel-binary_zlib-ip-ssl",
+ "py-cpp_accelc-compact_http-domain",
"py-cpp_accelc-compact_http-ip",
"py-cpp_accelc-compact_http-ip-ssl",
+ "py-cpp_accelc-compact_zlib-domain",
"py-cpp_accelc-compact_zlib-ip",
"py-cpp_accelc-compact_zlib-ip-ssl",
+ "py-cpp_binary_http-domain",
"py-cpp_binary_http-ip",
"py-cpp_binary_http-ip-ssl",
+ "py-cpp_compact_http-domain",
"py-cpp_compact_http-ip",
"py-cpp_compact_http-ip-ssl",
+ "py-cpp_header_http-domain",
"py-cpp_header_http-ip",
"py-cpp_header_http-ip-ssl",
+ "py-cpp_json_http-domain",
"py-cpp_json_http-ip",
"py-cpp_json_http-ip-ssl",
+ "py-cpp_multi-binary_http-domain",
"py-cpp_multi-binary_http-ip",
"py-cpp_multi-binary_http-ip-ssl",
+ "py-cpp_multi_http-domain",
"py-cpp_multi_http-ip",
"py-cpp_multi_http-ip-ssl",
+ "py-cpp_multia-binary_http-domain",
"py-cpp_multia-binary_http-ip",
"py-cpp_multia-binary_http-ip-ssl",
+ "py-cpp_multia-binary_zlib-domain",
"py-cpp_multia-binary_zlib-ip",
"py-cpp_multia-binary_zlib-ip-ssl",
+ "py-cpp_multia-multi_http-domain",
"py-cpp_multia-multi_http-ip",
"py-cpp_multia-multi_http-ip-ssl",
+ "py-cpp_multia-multi_zlib-domain",
"py-cpp_multia-multi_zlib-ip",
"py-cpp_multia-multi_zlib-ip-ssl",
+ "py-cpp_multiac-compact_http-domain",
"py-cpp_multiac-compact_http-ip",
"py-cpp_multiac-compact_http-ip-ssl",
+ "py-cpp_multiac-compact_zlib-domain",
"py-cpp_multiac-compact_zlib-ip",
"py-cpp_multiac-compact_zlib-ip-ssl",
+ "py-cpp_multiac-multic_http-domain",
"py-cpp_multiac-multic_http-ip",
"py-cpp_multiac-multic_http-ip-ssl",
+ "py-cpp_multiac-multic_zlib-domain",
"py-cpp_multiac-multic_zlib-ip",
"py-cpp_multiac-multic_zlib-ip-ssl",
+ "py-cpp_multic-compact_http-domain",
"py-cpp_multic-compact_http-ip",
"py-cpp_multic-compact_http-ip-ssl",
+ "py-cpp_multic_http-domain",
"py-cpp_multic_http-ip",
"py-cpp_multic_http-ip-ssl",
+ "py-cpp_multih_http-domain",
+ "py-cpp_multih-header_http-domain",
"py-cpp_multih-header_http-ip",
"py-cpp_multih-header_http-ip-ssl",
+ "py-cpp_multij_http-domain",
"py-cpp_multih_http-ip",
"py-cpp_multih_http-ip-ssl",
+ "py-cpp_multij-json_http-domain",
"py-cpp_multij-json_http-ip",
"py-cpp_multij-json_http-ip-ssl",
"py-cpp_multij_http-ip",
@@ -504,6 +575,12 @@
"py-lua_binary_http-ip",
"py-lua_compact_http-ip",
"py-lua_json_http-ip",
+ "py-nodejs_accel-binary_http-domain",
+ "py-nodejs_accelc-compact_http-domain",
+ "py-nodejs_binary_http-domain",
+ "py-nodejs_compact_http-domain",
+ "py-nodejs_json_http-domain",
+ "py-nodejs_header_http-domain",
"py-rs_multi_buffered-ip",
"py-rs_multi_framed-ip",
"py-rs_multia-multi_buffered-ip",
@@ -512,52 +589,76 @@
"py-rs_multiac-multic_framed-ip",
"py-rs_multic_buffered-ip",
"py-rs_multic_framed-ip",
+ "py3-cpp_accel-binary_http-domain",
"py3-cpp_accel-binary_http-ip",
"py3-cpp_accel-binary_http-ip-ssl",
+ "py3-cpp_accel-binary_zlib-domain",
"py3-cpp_accel-binary_zlib-ip",
"py3-cpp_accel-binary_zlib-ip-ssl",
+ "py3-cpp_accelc-compact_http-domain",
"py3-cpp_accelc-compact_http-ip",
"py3-cpp_accelc-compact_http-ip-ssl",
+ "py3-cpp_accelc-compact_zlib-domain",
"py3-cpp_accelc-compact_zlib-ip",
"py3-cpp_accelc-compact_zlib-ip-ssl",
+ "py3-cpp_binary_http-domain",
"py3-cpp_binary_http-ip",
"py3-cpp_binary_http-ip-ssl",
+ "py3-cpp_compact_http-domain",
"py3-cpp_compact_http-ip",
"py3-cpp_compact_http-ip-ssl",
+ "py3-cpp_header_http-domain",
"py3-cpp_header_http-ip",
"py3-cpp_header_http-ip-ssl",
+ "py3-cpp_json_http-domain",
"py3-cpp_json_http-ip",
"py3-cpp_json_http-ip-ssl",
+ "py3-cpp_multi-binary_http-domain",
"py3-cpp_multi-binary_http-ip",
"py3-cpp_multi-binary_http-ip-ssl",
+ "py3-cpp_multi_http-domain",
"py3-cpp_multi_http-ip",
"py3-cpp_multi_http-ip-ssl",
+ "py3-cpp_multia-binary_http-domain",
"py3-cpp_multia-binary_http-ip",
"py3-cpp_multia-binary_http-ip-ssl",
+ "py3-cpp_multia-binary_zlib-domain",
"py3-cpp_multia-binary_zlib-ip",
"py3-cpp_multia-binary_zlib-ip-ssl",
+ "py3-cpp_multia-multi_http-domain",
"py3-cpp_multia-multi_http-ip",
"py3-cpp_multia-multi_http-ip-ssl",
+ "py3-cpp_multia-multi_zlib-domain",
"py3-cpp_multia-multi_zlib-ip",
"py3-cpp_multia-multi_zlib-ip-ssl",
+ "py3-cpp_multiac-compact_http-domain",
"py3-cpp_multiac-compact_http-ip",
"py3-cpp_multiac-compact_http-ip-ssl",
+ "py3-cpp_multiac-compact_zlib-domain",
"py3-cpp_multiac-compact_zlib-ip",
"py3-cpp_multiac-compact_zlib-ip-ssl",
+ "py3-cpp_multiac-multic_http-domain",
"py3-cpp_multiac-multic_http-ip",
"py3-cpp_multiac-multic_http-ip-ssl",
+ "py3-cpp_multiac-multic_zlib-domain",
"py3-cpp_multiac-multic_zlib-ip",
"py3-cpp_multiac-multic_zlib-ip-ssl",
+ "py3-cpp_multic-compact_http-domain",
"py3-cpp_multic-compact_http-ip",
"py3-cpp_multic-compact_http-ip-ssl",
+ "py3-cpp_multic_http-domain",
"py3-cpp_multic_http-ip",
"py3-cpp_multic_http-ip-ssl",
+ "py3-cpp_multih-header_http-domain",
"py3-cpp_multih-header_http-ip",
"py3-cpp_multih-header_http-ip-ssl",
+ "py3-cpp_multih_http-domain",
"py3-cpp_multih_http-ip",
"py3-cpp_multih_http-ip-ssl",
+ "py3-cpp_multij-json_http-domain",
"py3-cpp_multij-json_http-ip",
"py3-cpp_multij-json_http-ip-ssl",
+ "py3-cpp_multij_http-domain",
"py3-cpp_multij_http-ip",
"py3-cpp_multij_http-ip-ssl",
"py3-d_accel-binary_http-ip",
@@ -601,6 +702,12 @@
"py3-lua_binary_http-ip",
"py3-lua_compact_http-ip",
"py3-lua_json_http-ip",
+ "py3-nodejs_accel-binary_http-domain",
+ "py3-nodejs_accelc-compact_http-domain",
+ "py3-nodejs_binary_http-domain",
+ "py3-nodejs_compact_http-domain",
+ "py3-nodejs_json_http-domain",
+ "py3-nodejs_header_http-domain",
"py3-rs_multi_buffered-ip",
"py3-rs_multi_framed-ip",
"py3-rs_multia-multi_buffered-ip",
@@ -615,4 +722,4 @@
"rb-cpp_json_framed-domain",
"rb-cpp_json_framed-ip",
"rb-cpp_json_framed-ip-ssl"
-]
\ No newline at end of file
+]
diff --git a/test/netstd/Client/Client.csproj b/test/netstd/Client/Client.csproj
index ed30c30..4ed57cb 100644
--- a/test/netstd/Client/Client.csproj
+++ b/test/netstd/Client/Client.csproj
@@ -19,7 +19,7 @@
-->
<PropertyGroup>
- <TargetFramework>netcoreapp2.0</TargetFramework>
+ <TargetFramework>netcoreapp3.1</TargetFramework>
<AssemblyName>Client</AssemblyName>
<PackageId>Client</PackageId>
<OutputType>Exe</OutputType>
@@ -31,9 +31,9 @@
<GenerateAssemblyCopyrightAttribute>false</GenerateAssemblyCopyrightAttribute>
</PropertyGroup>
<ItemGroup>
- <PackageReference Include="System.Net.Http.WinHttpHandler" Version="4.5.2" />
+ <PackageReference Include="System.Net.Http.WinHttpHandler" Version="4.7.0" />
<PackageReference Include="System.Runtime.Serialization.Primitives" Version="[4.3,)" />
- <PackageReference Include="System.ServiceModel.Primitives" Version="4.5.3" />
+ <PackageReference Include="System.ServiceModel.Primitives" Version="4.7.0" />
<PackageReference Include="System.Threading" Version="[4.3,)" />
</ItemGroup>
<ItemGroup>
diff --git a/test/netstd/Client/Performance/PerformanceTests.cs b/test/netstd/Client/Performance/PerformanceTests.cs
index 041d12e..2c79aa6 100644
--- a/test/netstd/Client/Performance/PerformanceTests.cs
+++ b/test/netstd/Client/Performance/PerformanceTests.cs
@@ -20,6 +20,7 @@
using System.Text;
using ThriftTest;
using Thrift.Collections;
+using Thrift;
using Thrift.Protocol;
using System.Threading;
using Thrift.Transport.Client;
@@ -36,6 +37,7 @@
private TMemoryBufferTransport MemBuffer;
private TTransport Transport;
private LayeredChoice Layered;
+ private readonly TConfiguration Configuration = new TConfiguration();
internal static int Execute()
{
@@ -52,6 +54,11 @@
return 0;
}
+ public PerformanceTests()
+ {
+ Configuration.MaxFrameSize = Configuration.MaxMessageSize; // default frame size is too small for this test
+ }
+
private async Task ProtocolPeformanceTestAsync()
{
Console.WriteLine("Setting up for ProtocolPeformanceTestAsync ...");
@@ -61,10 +68,9 @@
foreach (var layered in Enum.GetValues(typeof(LayeredChoice)))
{
Layered = (LayeredChoice)layered;
-
await RunTestAsync(async (bool b) => { return await GenericProtocolFactory<TBinaryProtocol>(b); });
await RunTestAsync(async (bool b) => { return await GenericProtocolFactory<TCompactProtocol>(b); });
- //await RunTestAsync(async (bool b) => { return await GenericProtocolFactory<TJsonProtocol>(b); });
+ await RunTestAsync(async (bool b) => { return await GenericProtocolFactory<TJsonProtocol>(b); });
}
}
@@ -76,9 +82,9 @@
{
// read happens after write here, so let's take over the written bytes
if (forWrite)
- MemBuffer = new TMemoryBufferTransport();
+ MemBuffer = new TMemoryBufferTransport(Configuration);
else
- MemBuffer = new TMemoryBufferTransport(MemBuffer.GetBuffer());
+ MemBuffer = new TMemoryBufferTransport(MemBuffer.GetBuffer(), Configuration);
// layered transports anyone?
switch (Layered)
diff --git a/test/netstd/Client/Program.cs b/test/netstd/Client/Program.cs
index 62933e6..92000da 100644
--- a/test/netstd/Client/Program.cs
+++ b/test/netstd/Client/Program.cs
@@ -34,37 +34,30 @@
Console.WriteLine("Failed to grow scroll-back buffer");
}
- // split mode and options
- var subArgs = new List<string>(args);
- var firstArg = string.Empty;
- if (subArgs.Count > 0)
- {
- firstArg = subArgs[0];
- subArgs.RemoveAt(0);
- }
-
- // run whatever mode is choosen
- switch(firstArg)
+ // run whatever mode is chosen, default to test impl
+ var firstArg = args.Length > 0 ? args[0] : string.Empty;
+ switch (firstArg)
{
case "client":
- return TestClient.Execute(subArgs);
- case "performance":
+ Console.WriteLine("The 'client' argument is no longer required.");
+ PrintHelp();
+ return -1;
+ case "--performance":
+ case "--performance-test":
return Tests.PerformanceTests.Execute();
case "--help":
PrintHelp();
return 0;
default:
- Console.WriteLine("Invalid argument: {0}", firstArg);
- PrintHelp();
- return -1;
+ return TestClient.Execute(new List<string>(args));
}
}
private static void PrintHelp()
{
Console.WriteLine("Usage:");
- Console.WriteLine(" Client client [options]");
- Console.WriteLine(" Client performance");
+ Console.WriteLine(" Client [options]");
+ Console.WriteLine(" Client --performance-test");
Console.WriteLine(" Client --help");
Console.WriteLine("");
diff --git a/test/netstd/Client/TestClient.cs b/test/netstd/Client/TestClient.cs
index 0f58f95..3eab865 100644
--- a/test/netstd/Client/TestClient.cs
+++ b/test/netstd/Client/TestClient.cs
@@ -28,6 +28,7 @@
using System.Text;
using System.Threading;
using System.Threading.Tasks;
+using Thrift;
using Thrift.Collections;
using Thrift.Protocol;
using Thrift.Transport;
@@ -72,6 +73,7 @@
public LayeredChoice layered = LayeredChoice.None;
public ProtocolChoice protocol = ProtocolChoice.Binary;
public TransportChoice transport = TransportChoice.Socket;
+ private readonly TConfiguration Configuration = null; // or new TConfiguration() if needed
internal void Parse(List<string> args)
{
@@ -235,12 +237,12 @@
{
case TransportChoice.Http:
Debug.Assert(url != null);
- trans = new THttpTransport(new Uri(url), null);
+ trans = new THttpTransport(new Uri(url), Configuration);
break;
case TransportChoice.NamedPipe:
Debug.Assert(pipe != null);
- trans = new TNamedPipeTransport(pipe);
+ trans = new TNamedPipeTransport(pipe,Configuration);
break;
case TransportChoice.TlsSocket:
@@ -250,14 +252,15 @@
throw new InvalidOperationException("Certificate doesn't contain private key");
}
- trans = new TTlsSocketTransport(host, port, 0, cert,
+ trans = new TTlsSocketTransport(host, port, Configuration, 0,
+ cert,
(sender, certificate, chain, errors) => true,
null, SslProtocols.Tls | SslProtocols.Tls11 | SslProtocols.Tls12);
break;
case TransportChoice.Socket:
default:
- trans = new TSocketTransport(host, port);
+ trans = new TSocketTransport(host, port, Configuration);
break;
}
@@ -443,12 +446,12 @@
Normal, // Fairly small array of usual size (256 bytes)
Large, // Large writes/reads may cause range check errors
PipeWriteLimit, // Windows Limit: Pipe write operations across a network are limited to 65,535 bytes per write.
- TwentyMB // that's quite a bit of data
+ FifteenMB // that's quite a bit of data
};
public static byte[] PrepareTestData(bool randomDist, BinaryTestSize testcase)
{
- int amount = -1;
+ int amount;
switch (testcase)
{
case BinaryTestSize.Empty:
@@ -463,8 +466,8 @@
case BinaryTestSize.PipeWriteLimit:
amount = 0xFFFF + 128;
break;
- case BinaryTestSize.TwentyMB:
- amount = 20 * 1024 * 1024;
+ case BinaryTestSize.FifteenMB:
+ amount = 15 * 1024 * 1024;
break;
default:
throw new ArgumentException(nameof(testcase));
@@ -622,26 +625,29 @@
{
Console.WriteLine("*** FAILED ***");
returnCode |= ErrorContainers;
- throw new Exception("CrazyNesting.Equals failed");
}
}
// TODO: Validate received message
Console.Write("testStruct({\"Zero\", 1, -3, -5})");
- var o = new Xtruct();
- o.String_thing = "Zero";
- o.Byte_thing = (sbyte)1;
- o.I32_thing = -3;
- o.I64_thing = -5;
+ var o = new Xtruct
+ {
+ String_thing = "Zero",
+ Byte_thing = (sbyte)1,
+ I32_thing = -3,
+ I64_thing = -5
+ };
var i = await client.testStructAsync(o, MakeTimeoutToken());
Console.WriteLine(" = {\"" + i.String_thing + "\", " + i.Byte_thing + ", " + i.I32_thing + ", " + i.I64_thing + "}");
// TODO: Validate received message
Console.Write("testNest({1, {\"Zero\", 1, -3, -5}, 5})");
- var o2 = new Xtruct2();
- o2.Byte_thing = (sbyte)1;
- o2.Struct_thing = o;
- o2.I32_thing = 5;
+ var o2 = new Xtruct2
+ {
+ Byte_thing = (sbyte)1,
+ Struct_thing = o,
+ I32_thing = 5
+ };
var i2 = await client.testNestAsync(o2, MakeTimeoutToken());
i = i2.Struct_thing;
Console.WriteLine(" = {" + i2.Byte_thing + ", {\"" + i.String_thing + "\", " + i.Byte_thing + ", " + i.I32_thing + ", " + i.I64_thing + "}, " + i2.I32_thing + "}");
@@ -838,16 +844,24 @@
Console.WriteLine("}");
// TODO: Validate received message
- var insane = new Insanity();
- insane.UserMap = new Dictionary<Numberz, long>();
- insane.UserMap[Numberz.FIVE] = 5000L;
- var truck = new Xtruct();
- truck.String_thing = "Truck";
- truck.Byte_thing = (sbyte)8;
- truck.I32_thing = 8;
- truck.I64_thing = 8;
- insane.Xtructs = new List<Xtruct>();
- insane.Xtructs.Add(truck);
+ var insane = new Insanity
+ {
+ UserMap = new Dictionary<Numberz, long>
+ {
+ [Numberz.FIVE] = 5000L
+ }
+ };
+ var truck = new Xtruct
+ {
+ String_thing = "Truck",
+ Byte_thing = (sbyte)8,
+ I32_thing = 8,
+ I64_thing = 8
+ };
+ insane.Xtructs = new List<Xtruct>
+ {
+ truck
+ };
Console.Write("testInsanity()");
var whoa = await client.testInsanityAsync(insane, MakeTimeoutToken());
Console.Write(" = {");
@@ -902,8 +916,10 @@
sbyte arg0 = 1;
var arg1 = 2;
var arg2 = long.MaxValue;
- var multiDict = new Dictionary<short, string>();
- multiDict[1] = "one";
+ var multiDict = new Dictionary<short, string>
+ {
+ [1] = "one"
+ };
var tmpMultiDict = new List<string>();
foreach (var pair in multiDict)
diff --git a/test/netstd/README.md b/test/netstd/README.md
index ed728d1..4ece059 100644
--- a/test/netstd/README.md
+++ b/test/netstd/README.md
@@ -1,12 +1,12 @@
# Apache Thrift net-core-lib tests
-Tests for Thrift client library ported to Microsoft .Net Core
+Tests for Thrift client library ported to Microsoft .NET Core
# Content
- ThriftTest - tests for Thrift library
# Reused components
-- NET Core Standard 1.6 (SDK 2.0.0)
+- .NET Core SDK 3.1 (LTS)
# How to build on Windows
- Get Thrift IDL compiler executable, add to some folder and add path to this folder into PATH variable
@@ -15,6 +15,6 @@
- Build with scripts
# How to build on Unix
-- Ensure you have .NET Core 2.0.0 SDK installed or use the Ubuntu Xenial docker image
+- Ensure you have .NET Core 3.1 SDK installed or use the Ubuntu Xenial docker image
- Follow common build practice for Thrift: bootstrap, configure, and make precross
diff --git a/test/netstd/Server/Program.cs b/test/netstd/Server/Program.cs
index 8bfa371..1b8ffd4 100644
--- a/test/netstd/Server/Program.cs
+++ b/test/netstd/Server/Program.cs
@@ -34,34 +34,26 @@
Console.WriteLine("Failed to grow scroll-back buffer");
}
- // split mode and options
- var subArgs = new List<string>(args);
- var firstArg = string.Empty;
- if (subArgs.Count > 0)
- {
- firstArg = subArgs[0];
- subArgs.RemoveAt(0);
- }
-
- // run whatever mode is choosen
- switch(firstArg)
+ // run whatever mode is chosen, default to test impl
+ var firstArg = args.Length > 0 ? args[0] : string.Empty;
+ switch (firstArg)
{
case "server":
- return TestServer.Execute(subArgs);
+ Console.WriteLine("The 'server' argument is no longer required.");
+ PrintHelp();
+ return -1;
case "--help":
PrintHelp();
return 0;
default:
- Console.WriteLine("Invalid argument: {0}", firstArg);
- PrintHelp();
- return -1;
+ return TestServer.Execute(new List<string>( args));
}
}
private static void PrintHelp()
{
Console.WriteLine("Usage:");
- Console.WriteLine(" Server server [options]'");
+ Console.WriteLine(" Server [options]");
Console.WriteLine(" Server --help");
Console.WriteLine("");
diff --git a/test/netstd/Server/Server.csproj b/test/netstd/Server/Server.csproj
index 44f46c9..fa5ce46 100644
--- a/test/netstd/Server/Server.csproj
+++ b/test/netstd/Server/Server.csproj
@@ -19,7 +19,7 @@
-->
<PropertyGroup>
- <TargetFramework>netcoreapp2.0</TargetFramework>
+ <TargetFramework>netcoreapp3.1</TargetFramework>
<AssemblyName>Server</AssemblyName>
<PackageId>Server</PackageId>
<OutputType>Exe</OutputType>
@@ -33,9 +33,9 @@
<ItemGroup>
<PackageReference Include="System.IO.Pipes" Version="4.3.0" />
<PackageReference Include="System.IO.Pipes.AccessControl" Version="4.5.1" />
- <PackageReference Include="System.Net.Http.WinHttpHandler" Version="4.5.2" />
+ <PackageReference Include="System.Net.Http.WinHttpHandler" Version="4.7.0" />
<PackageReference Include="System.Runtime.Serialization.Primitives" Version="[4.3,)" />
- <PackageReference Include="System.ServiceModel.Primitives" Version="4.5.3" />
+ <PackageReference Include="System.ServiceModel.Primitives" Version="4.7.0" />
<PackageReference Include="System.Threading" Version="[4.3,)" />
</ItemGroup>
<ItemGroup>
@@ -49,4 +49,4 @@
<Exec Condition="Exists('thrift')" Command="thrift -out $(ProjectDir) -gen netstd:wcf,union,serial -r ./../../ThriftTest.thrift" />
<Exec Condition="Exists('$(ProjectDir)/../../../compiler/cpp/thrift')" Command="$(ProjectDir)/../../../compiler/cpp/thrift -out $(ProjectDir) -gen netstd:wcf,union,serial -r ./../../ThriftTest.thrift" />
</Target>
-</Project>
\ No newline at end of file
+</Project>
diff --git a/test/netstd/Server/TestServer.cs b/test/netstd/Server/TestServer.cs
index 25c2afc..68461dc 100644
--- a/test/netstd/Server/TestServer.cs
+++ b/test/netstd/Server/TestServer.cs
@@ -148,6 +148,8 @@
public class TestServer
{
public static int _clientID = -1;
+ private static readonly TConfiguration Configuration = null; // or new TConfiguration() if needed
+
public delegate void TestLogDelegate(string msg, params object[] values);
public class MyServerEventHandler : TServerEventHandler
@@ -181,19 +183,19 @@
public class TestHandlerAsync : ThriftTest.IAsync
{
- public TServer server { get; set; }
- private int handlerID;
- private StringBuilder sb = new StringBuilder();
- private TestLogDelegate logger;
+ public TServer Server { get; set; }
+ private readonly int handlerID;
+ private readonly StringBuilder sb = new StringBuilder();
+ private readonly TestLogDelegate logger;
public TestHandlerAsync()
{
handlerID = Interlocked.Increment(ref _clientID);
- logger += testConsoleLogger;
+ logger += TestConsoleLogger;
logger.Invoke("New TestHandler instance created");
}
- public void testConsoleLogger(string msg, params object[] values)
+ public void TestConsoleLogger(string msg, params object[] values)
{
sb.Clear();
sb.AppendFormat("handler{0:D3}:", handlerID);
@@ -525,117 +527,122 @@
public static int Execute(List<string> args)
{
- var loggerFactory = new LoggerFactory();//.AddConsole().AddDebug();
- var logger = new LoggerFactory().CreateLogger("Test");
-
- try
+ using (var loggerFactory = new LoggerFactory()) //.AddConsole().AddDebug();
{
- var param = new ServerParam();
+ var logger = loggerFactory.CreateLogger("Test");
try
{
- param.Parse(args);
+ var param = new ServerParam();
+
+ try
+ {
+ param.Parse(args);
+ }
+ catch (Exception ex)
+ {
+ Console.WriteLine("*** FAILED ***");
+ Console.WriteLine("Error while parsing arguments");
+ Console.WriteLine(ex.Message + " ST: " + ex.StackTrace);
+ return 1;
+ }
+
+
+ // Endpoint transport (mandatory)
+ TServerTransport trans;
+ switch (param.transport)
+ {
+ case TransportChoice.NamedPipe:
+ Debug.Assert(param.pipe != null);
+ trans = new TNamedPipeServerTransport(param.pipe, Configuration);
+ break;
+
+
+ case TransportChoice.TlsSocket:
+ var cert = GetServerCert();
+ if (cert == null || !cert.HasPrivateKey)
+ {
+ cert?.Dispose();
+ throw new InvalidOperationException("Certificate doesn't contain private key");
+ }
+
+ trans = new TTlsServerSocketTransport(param.port, Configuration,
+ cert,
+ (sender, certificate, chain, errors) => true,
+ null, SslProtocols.Tls | SslProtocols.Tls11 | SslProtocols.Tls12);
+ break;
+
+ case TransportChoice.Socket:
+ default:
+ trans = new TServerSocketTransport(param.port, Configuration);
+ break;
+ }
+
+ // Layered transport (mandatory)
+ TTransportFactory transFactory = null;
+ switch (param.buffering)
+ {
+ case BufferChoice.Framed:
+ transFactory = new TFramedTransport.Factory();
+ break;
+ case BufferChoice.Buffered:
+ transFactory = new TBufferedTransport.Factory();
+ break;
+ default:
+ Debug.Assert(param.buffering == BufferChoice.None, "unhandled case");
+ transFactory = null; // no layered transprt
+ break;
+ }
+
+ // Protocol (mandatory)
+ TProtocolFactory proto;
+ switch (param.protocol)
+ {
+ case ProtocolChoice.Compact:
+ proto = new TCompactProtocol.Factory();
+ break;
+ case ProtocolChoice.Json:
+ proto = new TJsonProtocol.Factory();
+ break;
+ case ProtocolChoice.Binary:
+ default:
+ proto = new TBinaryProtocol.Factory();
+ break;
+ }
+
+ // Processor
+ var testHandler = new TestHandlerAsync();
+ var testProcessor = new ThriftTest.AsyncProcessor(testHandler);
+ var processorFactory = new TSingletonProcessorFactory(testProcessor);
+
+ TServer serverEngine = new TSimpleAsyncServer(processorFactory, trans, transFactory, transFactory, proto, proto, logger);
+
+ //Server event handler
+ var serverEvents = new MyServerEventHandler();
+ serverEngine.SetEventHandler(serverEvents);
+
+ // Run it
+ var where = (!string.IsNullOrEmpty(param.pipe)) ? "on pipe " + param.pipe : "on port " + param.port;
+ Console.WriteLine("Starting the AsyncBaseServer " + where +
+ " with processor TPrototypeProcessorFactory prototype factory " +
+ (param.buffering == BufferChoice.Buffered ? " with buffered transport" : "") +
+ (param.buffering == BufferChoice.Framed ? " with framed transport" : "") +
+ (param.transport == TransportChoice.TlsSocket ? " with encryption" : "") +
+ (param.protocol == ProtocolChoice.Compact ? " with compact protocol" : "") +
+ (param.protocol == ProtocolChoice.Json ? " with json protocol" : "") +
+ "...");
+ serverEngine.ServeAsync(CancellationToken.None).GetAwaiter().GetResult();
+ Console.ReadLine();
}
- catch (Exception ex)
+ catch (Exception x)
{
- Console.WriteLine("*** FAILED ***");
- Console.WriteLine("Error while parsing arguments");
- Console.WriteLine(ex.Message + " ST: " + ex.StackTrace);
+ Console.Error.Write(x);
return 1;
}
-
- // Endpoint transport (mandatory)
- TServerTransport trans;
- switch (param.transport)
- {
- case TransportChoice.NamedPipe:
- Debug.Assert(param.pipe != null);
- trans = new TNamedPipeServerTransport(param.pipe);
- break;
-
-
- case TransportChoice.TlsSocket:
- var cert = GetServerCert();
- if (cert == null || !cert.HasPrivateKey)
- {
- throw new InvalidOperationException("Certificate doesn't contain private key");
- }
-
- trans = new TTlsServerSocketTransport( param.port, cert,
- (sender, certificate, chain, errors) => true,
- null, SslProtocols.Tls | SslProtocols.Tls11 | SslProtocols.Tls12);
- break;
-
- case TransportChoice.Socket:
- default:
- trans = new TServerSocketTransport(param.port, 0);
- break;
- }
-
- // Layered transport (mandatory)
- TTransportFactory transFactory = null;
- switch (param.buffering)
- {
- case BufferChoice.Framed:
- transFactory = new TFramedTransport.Factory();
- break;
- case BufferChoice.Buffered:
- transFactory = new TBufferedTransport.Factory();
- break;
- default:
- Debug.Assert(param.buffering == BufferChoice.None, "unhandled case");
- transFactory = null; // no layered transprt
- break;
- }
-
- // Protocol (mandatory)
- TProtocolFactory proto;
- switch (param.protocol)
- {
- case ProtocolChoice.Compact:
- proto = new TCompactProtocol.Factory();
- break;
- case ProtocolChoice.Json:
- proto = new TJsonProtocol.Factory();
- break;
- case ProtocolChoice.Binary:
- default:
- proto = new TBinaryProtocol.Factory();
- break;
- }
-
- // Processor
- var testHandler = new TestHandlerAsync();
- var testProcessor = new ThriftTest.AsyncProcessor(testHandler);
- var processorFactory = new TSingletonProcessorFactory(testProcessor);
-
- TServer serverEngine = new TSimpleAsyncServer(processorFactory, trans, transFactory, transFactory, proto, proto, logger);
-
- //Server event handler
- var serverEvents = new MyServerEventHandler();
- serverEngine.SetEventHandler(serverEvents);
-
- // Run it
- var where = (! string.IsNullOrEmpty(param.pipe)) ? "on pipe " + param.pipe : "on port " + param.port;
- Console.WriteLine("Starting the AsyncBaseServer " + where +
- " with processor TPrototypeProcessorFactory prototype factory " +
- (param.buffering == BufferChoice.Buffered ? " with buffered transport" : "") +
- (param.buffering == BufferChoice.Framed ? " with framed transport" : "") +
- (param.transport == TransportChoice.TlsSocket ? " with encryption" : "") +
- (param.protocol == ProtocolChoice.Compact ? " with compact protocol" : "") +
- (param.protocol == ProtocolChoice.Json ? " with json protocol" : "") +
- "...");
- serverEngine.ServeAsync(CancellationToken.None).GetAwaiter().GetResult();
- Console.ReadLine();
+ Console.WriteLine("done.");
+ return 0;
}
- catch (Exception x)
- {
- Console.Error.Write(x);
- return 1;
- }
- Console.WriteLine("done.");
- return 0;
}
}
diff --git a/test/netstd/ThriftTest.sln b/test/netstd/ThriftTest.sln
index 6bd0855..352576e 100644
--- a/test/netstd/ThriftTest.sln
+++ b/test/netstd/ThriftTest.sln
@@ -4,9 +4,9 @@
MinimumVisualStudioVersion = 10.0.40219.1
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Thrift", "..\..\lib\netstd\Thrift\Thrift.csproj", "{C20EA2A9-7660-47DE-9A49-D1EF12FB2895}"
EndProject
-Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Client", "Client\Client.csproj", "{21039F25-6ED7-4E80-A545-EBC93472EBD1}"
+Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Client", "Client\Client.csproj", "{21039F25-6ED7-4E80-A545-EBC93472EBD1}"
EndProject
-Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Server", "Server\Server.csproj", "{0C6E8685-F191-4479-9842-882A38961127}"
+Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Server", "Server\Server.csproj", "{0C6E8685-F191-4479-9842-882A38961127}"
EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
diff --git a/test/py.tornado/test_suite.py b/test/py.tornado/test_suite.py
index 447fde6..0ee0a9b 100755
--- a/test/py.tornado/test_suite.py
+++ b/test/py.tornado/test_suite.py
@@ -82,10 +82,7 @@
def testException(self, s):
if s == 'Xception':
- x = Xception()
- x.errorCode = 1001
- x.message = s
- raise x
+ raise Xception(1001, s)
elif s == 'throw_undeclared':
raise ValueError('testing undeclared exception')
diff --git a/test/py.twisted/test_suite.py b/test/py.twisted/test_suite.py
index 02eb7f1..6e04493 100755
--- a/test/py.twisted/test_suite.py
+++ b/test/py.twisted/test_suite.py
@@ -76,10 +76,7 @@
def testException(self, s):
if s == 'Xception':
- x = Xception()
- x.errorCode = 1001
- x.message = s
- raise x
+ raise Xception(1001, s)
elif s == "throw_undeclared":
raise ValueError("foo")
diff --git a/test/py/TestClient.py b/test/py/TestClient.py
index e7a9a1a..8a30c3a 100755
--- a/test/py/TestClient.py
+++ b/test/py/TestClient.py
@@ -51,7 +51,7 @@
from thrift.transport import TSSLSocket
socket = TSSLSocket.TSSLSocket(options.host, options.port, validate=False)
else:
- socket = TSocket.TSocket(options.host, options.port)
+ socket = TSocket.TSocket(options.host, options.port, options.domain_socket)
# frame or buffer depending upon args
self.transport = TTransport.TBufferedTransport(socket)
if options.trans == 'framed':
@@ -474,6 +474,8 @@
help="protocol to use, one of: accel, accelc, binary, compact, header, json, multi, multia, multiac, multic, multih, multij")
parser.add_option('--transport', dest="trans", type="string",
help="transport to use, one of: buffered, framed, http")
+ parser.add_option('--domain-socket', dest="domain_socket", type="string",
+ help="Unix domain socket path")
parser.set_defaults(framed=False, http_path=None, verbose=1, host='localhost', port=9090, proto='binary')
options, args = parser.parse_args()
diff --git a/test/py/TestFrozen.py b/test/py/TestFrozen.py
index 6d2595c..ce7425f 100755
--- a/test/py/TestFrozen.py
+++ b/test/py/TestFrozen.py
@@ -19,7 +19,9 @@
# under the License.
#
+from DebugProtoTest import Srv
from DebugProtoTest.ttypes import CompactProtoTestStruct, Empty, Wrapper
+from DebugProtoTest.ttypes import ExceptionWithAMap, MutableException
from thrift.Thrift import TFrozenDict
from thrift.transport import TTransport
from thrift.protocol import TBinaryProtocol, TCompactProtocol
@@ -94,6 +96,21 @@
x2 = self._roundtrip(x, Wrapper)
self.assertEqual(x2.foo, Empty())
+ def test_frozen_exception(self):
+ exc = ExceptionWithAMap(blah='foo')
+ with self.assertRaises(TypeError):
+ exc.blah = 'bar'
+ mutexc = MutableException(msg='foo')
+ mutexc.msg = 'bar'
+ self.assertEqual(mutexc.msg, 'bar')
+
+ def test_frozen_exception_serialization(self):
+ result = Srv.declaredExceptionMethod_result(
+ xwamap=ExceptionWithAMap(blah="error"))
+ deserialized = self._roundtrip(
+ result, Srv.declaredExceptionMethod_result())
+ self.assertEqual(result, deserialized)
+
class TestFrozen(TestFrozenBase):
def protocol(self, trans):
diff --git a/test/py/TestServer.py b/test/py/TestServer.py
index d0a13e5..4d90f8f 100755
--- a/test/py/TestServer.py
+++ b/test/py/TestServer.py
@@ -307,7 +307,7 @@
from thrift.transport import TSSLSocket
transport = TSSLSocket.TSSLServerSocket(host, options.port, certfile=abs_key_path)
else:
- transport = TSocket.TServerSocket(host, options.port)
+ transport = TSocket.TServerSocket(host, options.port, options.domain_socket)
tfactory = TTransport.TBufferedTransportFactory()
if options.trans == 'buffered':
tfactory = TTransport.TBufferedTransportFactory()
@@ -385,6 +385,8 @@
help="protocol to use, one of: accel, accelc, binary, compact, json, multi, multia, multiac, multic, multih, multij")
parser.add_option('--transport', dest="trans", type="string",
help="transport to use, one of: buffered, framed, http")
+ parser.add_option('--domain-socket', dest="domain_socket", type="string",
+ help="Unix domain socket path")
parser.add_option('--container-limit', dest='container_limit', type='int', default=None)
parser.add_option('--string-limit', dest='string_limit', type='int', default=None)
parser.set_defaults(port=9090, verbose=1, proto='binary', transport='buffered')
diff --git a/test/tests.json b/test/tests.json
index 78d4c0e..b8b85be 100644
--- a/test/tests.json
+++ b/test/tests.json
@@ -264,7 +264,8 @@
],
"sockets": [
"ip",
- "ip-ssl"
+ "ip-ssl",
+ "domain"
],
"protocols": [
"binary",
@@ -313,7 +314,8 @@
],
"sockets": [
"ip",
- "ip-ssl"
+ "ip-ssl",
+ "domain"
],
"protocols": [
"binary",
diff --git a/tutorial/delphi/DelphiClient/DelphiClient.dpr b/tutorial/delphi/DelphiClient/DelphiClient.dpr
index 4ea9eb3..64d7d68 100644
--- a/tutorial/delphi/DelphiClient/DelphiClient.dpr
+++ b/tutorial/delphi/DelphiClient/DelphiClient.dpr
@@ -26,6 +26,7 @@
Generics.Collections,
Thrift in '..\..\..\lib\delphi\src\Thrift.pas',
Thrift.Collections in '..\..\..\lib\delphi\src\Thrift.Collections.pas',
+ Thrift.Configuration in '..\..\..\lib\delphi\src\Thrift.Configuration.pas',
Thrift.Exception in '..\..\..\lib\delphi\src\Thrift.Exception.pas',
Thrift.Utils in '..\..\..\lib\delphi\src\Thrift.Utils.pas',
Thrift.Stream in '..\..\..\lib\delphi\src\Thrift.Stream.pas',
diff --git a/tutorial/delphi/DelphiClient/DelphiClient.dproj b/tutorial/delphi/DelphiClient/DelphiClient.dproj
index f9adf85..2612f14 100644
--- a/tutorial/delphi/DelphiClient/DelphiClient.dproj
+++ b/tutorial/delphi/DelphiClient/DelphiClient.dproj
@@ -52,6 +52,7 @@
</DelphiCompile>
<DCCReference Include="..\..\..\lib\delphi\src\Thrift.pas"/>
<DCCReference Include="..\..\..\lib\delphi\src\Thrift.Collections.pas"/>
+ <DCCReference Include="..\..\..\lib\delphi\src\Thrift.Configuration.pas"/>
<DCCReference Include="..\..\..\lib\delphi\src\Thrift.Exception.pas"/>
<DCCReference Include="..\..\..\lib\delphi\src\Thrift.Utils.pas"/>
<DCCReference Include="..\..\..\lib\delphi\src\Thrift.Stream.pas"/>
diff --git a/tutorial/delphi/DelphiServer/DelphiServer.dpr b/tutorial/delphi/DelphiServer/DelphiServer.dpr
index fc9997a..41a3514 100644
--- a/tutorial/delphi/DelphiServer/DelphiServer.dpr
+++ b/tutorial/delphi/DelphiServer/DelphiServer.dpr
@@ -28,6 +28,7 @@
Generics.Collections,
Thrift in '..\..\..\lib\delphi\src\Thrift.pas',
Thrift.Collections in '..\..\..\lib\delphi\src\Thrift.Collections.pas',
+ Thrift.Configuration in '..\..\..\lib\delphi\src\Thrift.Configuration.pas',
Thrift.Exception in '..\..\..\lib\delphi\src\Thrift.Exception.pas',
Thrift.Utils in '..\..\..\lib\delphi\src\Thrift.Utils.pas',
Thrift.Stream in '..\..\..\lib\delphi\src\Thrift.Stream.pas',
diff --git a/tutorial/delphi/DelphiServer/DelphiServer.dproj b/tutorial/delphi/DelphiServer/DelphiServer.dproj
index 132d1bf..f62257e 100644
--- a/tutorial/delphi/DelphiServer/DelphiServer.dproj
+++ b/tutorial/delphi/DelphiServer/DelphiServer.dproj
@@ -51,6 +51,7 @@
</DelphiCompile>
<DCCReference Include="..\..\..\lib\delphi\src\Thrift.pas"/>
<DCCReference Include="..\..\..\lib\delphi\src\Thrift.Collections.pas"/>
+ <DCCReference Include="..\..\..\lib\delphi\src\Thrift.Configuration.pas"/>
<DCCReference Include="..\..\..\lib\delphi\src\Thrift.Exception.pas"/>
<DCCReference Include="..\..\..\lib\delphi\src\Thrift.Utils.pas"/>
<DCCReference Include="..\..\..\lib\delphi\src\Thrift.Stream.pas"/>
diff --git a/tutorial/netstd/Client/Client.csproj b/tutorial/netstd/Client/Client.csproj
index a1470a9..10d5040 100644
--- a/tutorial/netstd/Client/Client.csproj
+++ b/tutorial/netstd/Client/Client.csproj
@@ -19,7 +19,7 @@
-->
<PropertyGroup>
- <TargetFramework>netcoreapp2.0</TargetFramework>
+ <TargetFramework>netcoreapp3.1</TargetFramework>
<AssemblyName>Client</AssemblyName>
<PackageId>Client</PackageId>
<OutputType>Exe</OutputType>
@@ -30,7 +30,7 @@
</PropertyGroup>
<ItemGroup>
- <PackageReference Include="Microsoft.Extensions.DependencyInjection" Version="2.2.0" />
+ <PackageReference Include="Microsoft.Extensions.DependencyInjection" Version="3.1.0" />
</ItemGroup>
<ItemGroup>
diff --git a/tutorial/netstd/Client/Program.cs b/tutorial/netstd/Client/Program.cs
index f9509fa..857b3e8 100644
--- a/tutorial/netstd/Client/Program.cs
+++ b/tutorial/netstd/Client/Program.cs
@@ -40,6 +40,7 @@
{
private static ServiceCollection ServiceCollection = new ServiceCollection();
private static ILogger Logger;
+ private static readonly TConfiguration Configuration = null; // new TConfiguration() if needed
private static void DisplayHelp()
{
@@ -143,7 +144,7 @@
private static TTransport GetTransport(string[] args)
{
- TTransport transport = new TSocketTransport(IPAddress.Loopback, 9090);
+ TTransport transport = new TSocketTransport(IPAddress.Loopback, 9090, Configuration);
// construct endpoint transport
var transportArg = args.FirstOrDefault(x => x.StartsWith("-tr"))?.Split(':')?[1];
@@ -152,19 +153,20 @@
switch (selectedTransport)
{
case Transport.Tcp:
- transport = new TSocketTransport(IPAddress.Loopback, 9090);
+ transport = new TSocketTransport(IPAddress.Loopback, 9090, Configuration);
break;
case Transport.NamedPipe:
- transport = new TNamedPipeTransport(".test");
+ transport = new TNamedPipeTransport(".test", Configuration);
break;
case Transport.Http:
- transport = new THttpTransport(new Uri("http://localhost:9090"), null);
+ transport = new THttpTransport(new Uri("http://localhost:9090"), Configuration);
break;
case Transport.TcpTls:
- transport = new TTlsSocketTransport(IPAddress.Loopback, 9090, GetCertificate(), CertValidator, LocalCertificateSelectionCallback);
+ transport = new TTlsSocketTransport(IPAddress.Loopback, 9090, Configuration,
+ GetCertificate(), CertValidator, LocalCertificateSelectionCallback);
break;
default:
diff --git a/tutorial/netstd/Interfaces/Interfaces.csproj b/tutorial/netstd/Interfaces/Interfaces.csproj
index 4ebeb4f..c8b2bd8 100644
--- a/tutorial/netstd/Interfaces/Interfaces.csproj
+++ b/tutorial/netstd/Interfaces/Interfaces.csproj
@@ -33,7 +33,7 @@
</ItemGroup>
<ItemGroup>
- <PackageReference Include="System.ServiceModel.Primitives" Version="4.5.3" />
+ <PackageReference Include="System.ServiceModel.Primitives" Version="4.7.0" />
</ItemGroup>
<Target Name="PreBuild" BeforeTargets="_GenerateRestoreProjectSpec;Restore;Compile">
diff --git a/tutorial/netstd/README.md b/tutorial/netstd/README.md
index b1dea4e..297f4ee 100644
--- a/tutorial/netstd/README.md
+++ b/tutorial/netstd/README.md
@@ -1,11 +1,10 @@
# Building of samples for different platforms
-# Reused components
-- NET Core Standard 2.0
-- NET Core App 2.0
+# Requirements
+- .NET Core 3.1 (LTS) runtime or SDK (see below for further info)
# How to build
-- Download and install the latest .NET Core SDK for your platform https://www.microsoft.com/net/core#windowsvs2015 (archive for SDK 1.0.0-preview2-003121 located by: https://github.com/dotnet/core/blob/master/release-notes/download-archive.md)
+- Download and install the latest .NET Core SDK for your platform https://dotnet.microsoft.com/download/dotnet-core
- Ensure that you have thrift.exe which supports netstd lib and it added to PATH
- Go to current folder
- Run **build.sh** or **build.cmd** from the root of cloned repository
@@ -14,29 +13,26 @@
# How to run
-Notes: dotnet run supports passing arguments to app after -- symbols (https://docs.microsoft.com/en-us/dotnet/articles/core/tools/dotnet-run) - example: **dotnet run -- -h** will show help for app
+Depending on the platform, the names of the generated executables vary: on Linux they are just "Client" and "Server", while on Windows they are "Client.exe" and "Server.exe". In the following, we use the abbreviated forms "Client" and "Server".
- build
- go to folder (Client/Server)
-- run with specifying of correct parameters **dotnet run -tr:tcp -pr:multiplexed**, **dotnet run -help** (later, after migration to csproj and latest SDK will be possibility to use more usable form **dotnet run -- arguments**)
-
-#Notes
-- Possible adding additional platforms after stabilization of .NET Core (runtimes, platforms (Red Hat Linux, OpenSuse, etc.)
+- run the generated executables: server first, then client from a second console
#Known issues
- In trace logging mode you can see some not important internal exceptions
# Running of samples
-Please install Thrift C# .NET Core library or copy sources and build them to correcly build and run samples
+On machines that do not have the SDK installed, you need to install the .NET Core runtime first. The SDK is only needed to build programs; otherwise the runtime is sufficient.
# NetCore Server
Usage:
- Server.exe -h
+ Server -h
will diplay help information
- Server.exe -tr:<transport> -pr:<protocol>
+ Server -tr:<transport> -pr:<protocol>
will run server with specified arguments (tcp transport and binary protocol by default)
Options:
@@ -59,7 +55,7 @@
Sample:
- Server.exe -tr:tcp
+ Server -tr:tcp
**Remarks**:
@@ -72,10 +68,10 @@
Usage:
- Client.exe -h
+ Client -h
will diplay help information
- Client.exe -tr:<transport> -pr:<protocol> -mc:<numClients>
+ Client -tr:<transport> -pr:<protocol> -mc:<numClients>
will run client with specified arguments (tcp transport and binary protocol by default)
Options:
@@ -101,7 +97,7 @@
Sample:
- Client.exe -tr:tcp -pr:binary -mc:10
+ Client -tr:tcp -pr:binary -mc:10
Remarks:
@@ -111,8 +107,8 @@
# How to test communication between NetCore and Python
-* Generate code with the latest **thrift.exe** util
-* Ensure that **thrift.exe** util generated folder **gen-py** with generated code for Python
+* Generate code with the latest **thrift** utility
+* Ensure that **thrift** generated the folder **gen-py** containing the Python code
* Create **client.py** and **server.py** from the code examples below and save them to the folder with previosly generated folder **gen-py**
* Run netstd samples (client and server) and python samples (client and server)
diff --git a/tutorial/netstd/Server/Program.cs b/tutorial/netstd/Server/Program.cs
index 25e7dae..c1e0cb3 100644
--- a/tutorial/netstd/Server/Program.cs
+++ b/tutorial/netstd/Server/Program.cs
@@ -44,6 +44,7 @@
{
private static ServiceCollection ServiceCollection = new ServiceCollection();
private static ILogger Logger;
+ private static readonly TConfiguration Configuration = null; // new TConfiguration() if needed
public static void Main(string[] args)
{
@@ -163,13 +164,14 @@
switch (transport)
{
case Transport.Tcp:
- serverTransport = new TServerSocketTransport(9090);
+ serverTransport = new TServerSocketTransport(9090, Configuration);
break;
case Transport.NamedPipe:
- serverTransport = new TNamedPipeServerTransport(".test");
+ serverTransport = new TNamedPipeServerTransport(".test", Configuration);
break;
case Transport.TcpTls:
- serverTransport = new TTlsServerSocketTransport(9090, GetCertificate(), ClientCertValidator, LocalCertificateSelectionCallback);
+ serverTransport = new TTlsServerSocketTransport(9090, Configuration,
+ GetCertificate(), ClientCertValidator, LocalCertificateSelectionCallback);
break;
}
@@ -346,7 +348,7 @@
public class Startup
{
- public Startup(IHostingEnvironment env)
+ public Startup(IWebHostEnvironment env)
{
var builder = new ConfigurationBuilder()
.SetBasePath(env.ContentRootPath)
@@ -366,7 +368,7 @@
}
// This method gets called by the runtime. Use this method to configure the HTTP request pipeline.
- public void Configure(IApplicationBuilder app, IHostingEnvironment env, ILoggerFactory loggerFactory)
+ public void Configure(IApplicationBuilder app, IWebHostEnvironment env, ILoggerFactory loggerFactory)
{
app.UseMiddleware<THttpServerTransport>();
}
diff --git a/tutorial/netstd/Server/Server.csproj b/tutorial/netstd/Server/Server.csproj
index fbc2c03..b3ff516 100644
--- a/tutorial/netstd/Server/Server.csproj
+++ b/tutorial/netstd/Server/Server.csproj
@@ -19,7 +19,7 @@
-->
<PropertyGroup>
- <TargetFramework>netcoreapp2.0</TargetFramework>
+ <TargetFramework>netcoreapp3.1</TargetFramework>
<AssemblyName>Server</AssemblyName>
<PackageId>Server</PackageId>
<OutputType>Exe</OutputType>
@@ -38,7 +38,7 @@
<PackageReference Include="Microsoft.AspNetCore" Version="2.2.0" />
<PackageReference Include="Microsoft.AspNetCore.Server.IISIntegration" Version="2.2.1" />
<PackageReference Include="Microsoft.AspNetCore.Server.Kestrel" Version="2.2.0" />
- <PackageReference Include="Microsoft.Extensions.Configuration.FileExtensions" Version="2.2.0" />
+ <PackageReference Include="Microsoft.Extensions.Configuration.FileExtensions" Version="3.1.0" />
</ItemGroup>
</Project>