Thrift: Native-code Binary Protocol encoder.
Summary:
Merging a patch from Ben Maurer.
This adds a python extension (i.e., a C module) that
encodes Python thrift structs into the standard binary protocol
much faster than our generated Python code.
Also added by-value equality comparison to thrift structs
(to help with testing).
Cleaned up some trailing whitespace too.
Reviewed By: mcslee, dreiss
Test Plan:
Recompiled Thrift.
Thrifted a bunch of IDLs and compared the generated Python output.
Looked at the extension module a lot.
test/FastBinaryTest.py
Revert Plan: ok
git-svn-id: https://svn.apache.org/repos/asf/incubator/thrift/trunk@665224 13f79535-47bb-0310-9956-ffa450edef68
diff --git a/compiler/cpp/src/generate/t_py_generator.cc b/compiler/cpp/src/generate/t_py_generator.cc
index aeabeef..5ed2567 100644
--- a/compiler/cpp/src/generate/t_py_generator.cc
+++ b/compiler/cpp/src/generate/t_py_generator.cc
@@ -8,6 +8,7 @@
#include <sys/stat.h>
#include <sys/types.h>
#include <sstream>
+#include <algorithm>
#include "t_py_generator.h"
using namespace std;
@@ -48,7 +49,10 @@
f_types_ <<
py_autogen_comment() << endl <<
py_imports() << endl <<
- render_includes() << endl;
+ render_includes() << endl <<
+ "from thrift.transport import TTransport" << endl <<
+ "from thrift.protocol import fastbinary" << endl <<
+ "from thrift.protocol import TBinaryProtocol" << endl;
f_consts_ <<
py_autogen_comment() << endl <<
@@ -118,7 +122,7 @@
f_types_ <<
"class " << tenum->get_name() << ":" << endl;
indent_up();
-
+
vector<t_enum_value*> constants = tenum->get_constants();
vector<t_enum_value*>::iterator c_iter;
int value = -1;
@@ -144,7 +148,7 @@
t_type* type = tconst->get_type();
string name = tconst->get_name();
t_const_value* value = tconst->get_value();
-
+
indent(f_consts_) << name << " = " << render_const_value(type, value);
f_consts_ << endl << endl;
}
@@ -268,7 +272,7 @@
* @param txception The struct definition
*/
void t_py_generator::generate_xception(t_struct* txception) {
- generate_py_struct(txception, true);
+ generate_py_struct(txception, true);
}
/**
@@ -280,6 +284,19 @@
}
/**
+ * Comparator to sort fields in ascending order by key.
+ * Make this a functor instead of a function to help GCC inline it.
+ * The arguments are (const) references to const pointers to const t_fields.
+ * Unfortunately, we cannot declare it within the function. Boo!
+ * http://www.open-std.org/jtc1/sc22/open/n2356/ (paragraph 9).
+ */
+struct FieldKeyCompare {
+ bool operator()(t_field const * const & a, t_field const * const & b) {
+ return a->get_key() < b->get_key();
+ }
+};
+
+/**
* Generates a struct definition for a thrift data type. This is nothing in PHP
* where the objects are all just associative arrays (unless of course we
* decide to start using objects for them...)
@@ -290,8 +307,11 @@
t_struct* tstruct,
bool is_exception,
bool is_result) {
+
const vector<t_field*>& members = tstruct->get_members();
- vector<t_field*>::const_iterator m_iter;
+ vector<t_field*>::const_iterator m_iter;
+ vector<t_field*> sorted_members(members);
+ std::sort(sorted_members.begin(), sorted_members.end(), FieldKeyCompare());
out <<
"class " << tstruct->get_name();
@@ -304,6 +324,53 @@
out << endl;
+ /*
+ Here we generate the structure specification for the fastbinary codec.
+ These specifications have the following structure:
+ thrift_spec -> tuple of item_spec
+ item_spec -> None | (tag, type_enum, name, spec_args, default)
+ tag -> integer
+ type_enum -> TType.I32 | TType.STRING | TType.STRUCT | ...
+ name -> string_literal
+ default -> None # Handled by __init__
+ spec_args -> None # For simple types
+ | (type_enum, spec_args) # Value type for list/set
+ | (type_enum, spec_args, type_enum, spec_args)
+ # Key and value for map
+ | (class_name, spec_args_ptr) # For struct/exception
+ class_name -> identifier # Basically a pointer to the class
+ spec_args_ptr -> expression # just class_name.spec_args
+
+ TODO(dreiss): Consider making this work for structs with negative tags.
+ */
+
+ if (sorted_members.empty() || (sorted_members[0]->get_key() >= 0)) {
+ indent(out) << "thrift_spec = (" << endl;
+ indent_up();
+
+ int sorted_keys_pos = 0;
+ for (m_iter = sorted_members.begin(); m_iter != sorted_members.end(); ++m_iter) {
+
+ for (; sorted_keys_pos != (*m_iter)->get_key(); sorted_keys_pos++) {
+ indent(out) << "None, # " << sorted_keys_pos << endl;
+ }
+
+ indent(out) << "(" << (*m_iter)->get_key() << ", "
+ << type_to_enum((*m_iter)->get_type()) << ", "
+ << "'" << (*m_iter)->get_name() << "'" << ", "
+ << type_to_spec_args((*m_iter)->get_type()) << ", "
+ << "None" << ", "
+ << "),"
+ << " # " << sorted_keys_pos
+ << endl;
+
+ sorted_keys_pos ++;
+ }
+
+ indent_down();
+ indent(out) << ")" << endl << endl;
+ }
+
out <<
indent() << "def __init__(self, d=None):" << endl;
indent_up();
@@ -330,9 +397,10 @@
}
indent_down();
-
+
out << endl;
+
generate_py_struct_reader(out, tstruct);
generate_py_struct_writer(out, tstruct);
@@ -346,6 +414,24 @@
indent() << " return repr(self.__dict__)" << endl <<
endl;
+ // Equality and inequality methods that compare by value
+ out <<
+ indent() << "def __eq__(self, other):" << endl;
+ indent_up();
+ out <<
+ indent() << "return isinstance(other, self.__class__) and "
+ "self.__dict__ == other.__dict__" << endl;
+ indent_down();
+ out << endl;
+
+ out <<
+ indent() << "def __ne__(self, other):" << endl;
+ indent_up();
+ out <<
+ indent() << "return not (self == other)" << endl;
+ indent_down();
+ out << endl;
+
indent_down();
}
@@ -360,15 +446,26 @@
indent(out) <<
"def read(self, iprot):" << endl;
indent_up();
-
+
indent(out) <<
- "iprot.readStructBegin()" << endl;
+ "if iprot.__class__ == TBinaryProtocol.TBinaryProtocolAccelerated "
+ "and isinstance(iprot.trans, TTransport.CReadableTransport):" << endl;
+ indent_up();
+
+ indent(out) <<
+ "fastbinary.decode_binary(self, iprot.trans, (self.__class__, self.thrift_spec))" << endl;
+ indent(out) <<
+ "return" << endl;
+ indent_down();
+
+ indent(out) <<
+ "iprot.readStructBegin()" << endl;
// Loop over reading in fields
indent(out) <<
"while True:" << endl;
indent_up();
-
+
// Read beginning field marker
indent(out) <<
"(fname, ftype, fid) = iprot.readFieldBegin()" << endl;
@@ -380,10 +477,10 @@
indent(out) <<
"break" << endl;
indent_down();
-
+
// Switch statement on the field we are reading
bool first = true;
-
+
// Generate deserialization code for known cases
for (f_iter = fields.begin(); f_iter != fields.end(); ++f_iter) {
if (first) {
@@ -405,18 +502,18 @@
indent() << " iprot.skip(ftype)" << endl;
indent_down();
}
-
+
// In the default case we skip the field
out <<
indent() << "else:" << endl <<
indent() << " iprot.skip(ftype)" << endl;
-
+
// Read field end marker
indent(out) <<
"iprot.readFieldEnd()" << endl;
-
+
indent_down();
-
+
indent(out) <<
"iprot.readStructEnd()" << endl;
@@ -433,7 +530,17 @@
indent(out) <<
"def write(self, oprot):" << endl;
indent_up();
-
+
+ indent(out) <<
+ "if oprot.__class__ == TBinaryProtocol.TBinaryProtocolAccelerated:" << endl;
+ indent_up();
+
+ indent(out) <<
+ "oprot.trans.write(fastbinary.encode_binary(self, (self.__class__, self.thrift_spec)))" << endl;
+ indent(out) <<
+ "return" << endl;
+ indent_down();
+
indent(out) <<
"oprot.writeStructBegin('" << name << "')" << endl;
@@ -487,8 +594,11 @@
}
f_service_ <<
- "from ttypes import *" << endl <<
+ "from ttypes import *" << endl <<
"from thrift.Thrift import TProcessor" << endl <<
+ "from thrift.transport import TTransport" << endl <<
+ "from thrift.protocol import fastbinary" << endl <<
+ "from thrift.protocol import TBinaryProtocol" << endl <<
endl;
// Generate the three main parts of the service (well, two for now in PHP)
@@ -560,7 +670,7 @@
"class Iface" << extends_if << ":" << endl;
indent_up();
vector<t_function*> functions = tservice->get_functions();
- vector<t_function*>::iterator f_iter;
+ vector<t_function*>::iterator f_iter;
for (f_iter = functions.begin(); f_iter != functions.end(); ++f_iter) {
f_service_ <<
indent() << "def " << function_signature(*f_iter) << ":" << endl <<
@@ -606,7 +716,7 @@
// Generate client method implementations
vector<t_function*> functions = tservice->get_functions();
- vector<t_function*>::const_iterator f_iter;
+ vector<t_function*>::const_iterator f_iter;
for (f_iter = functions.begin(); f_iter != functions.end(); ++f_iter) {
t_struct* arg_struct = (*f_iter)->get_arglist();
const vector<t_field*>& fields = arg_struct->get_members();
@@ -651,27 +761,27 @@
// Serialize the request header
f_service_ <<
indent() << "self._oprot.writeMessageBegin('" << (*f_iter)->get_name() << "', TMessageType.CALL, self._seqid)" << endl;
-
+
f_service_ <<
indent() << "args = " << argsname << "()" << endl;
-
+
for (fld_iter = fields.begin(); fld_iter != fields.end(); ++fld_iter) {
f_service_ <<
indent() << "args." << (*fld_iter)->get_name() << " = " << (*fld_iter)->get_name() << endl;
}
-
+
// Write to the stream
f_service_ <<
indent() << "args.write(self._oprot)" << endl <<
indent() << "self._oprot.writeMessageEnd()" << endl <<
- indent() << "self._oprot.trans.flush()" << endl;
+ indent() << "self._oprot.trans.flush()" << endl;
indent_down();
if (!(*f_iter)->is_async()) {
std::string resultname = (*f_iter)->get_name() + "_result";
t_struct noargs(program_);
-
+
t_function recv_function((*f_iter)->get_returntype(),
string("recv_") + (*f_iter)->get_name(),
&noargs);
@@ -719,12 +829,12 @@
} else {
f_service_ <<
indent() << "raise TApplicationException(TApplicationException.MISSING_RESULT, \"" << (*f_iter)->get_name() << " failed: unknown result\");" << endl;
- }
+ }
// Close function
indent_down();
- f_service_ << endl;
- }
+ f_service_ << endl;
+ }
}
indent_down();
@@ -739,7 +849,7 @@
*/
void t_py_generator::generate_service_remote(t_service* tservice) {
vector<t_function*> functions = tservice->get_functions();
- vector<t_function*>::iterator f_iter;
+ vector<t_function*>::iterator f_iter;
string f_remote_name = package_dir_+"/"+service_name_+"-remote";
ofstream f_remote;
@@ -759,7 +869,7 @@
f_remote <<
"import " << service_name_ << endl <<
- "from ttypes import *" << endl <<
+ "from ttypes import *" << endl <<
endl;
f_remote <<
@@ -782,11 +892,11 @@
} else {
f_remote << ", ";
}
- f_remote <<
+ f_remote <<
args[i]->get_type()->get_name() << " " << args[i]->get_name();
}
f_remote << ")'" << endl;
- }
+ }
f_remote <<
" print ''" << endl <<
" sys.exit(0)" << endl <<
@@ -838,7 +948,7 @@
"client = " << service_name_ << ".Client(protocol)" << endl <<
"transport.open()" << endl <<
endl;
-
+
// Generate the dispatch methods
bool first = true;
@@ -868,15 +978,15 @@
}
}
f_remote << "))" << endl;
-
+
f_remote << endl;
}
f_remote << "transport.close()" << endl;
-
+
// Close service file
f_remote.close();
-
+
// Make file executable, love that bitwise OR action
chmod(f_remote_name.c_str(),
S_IRUSR |
@@ -896,7 +1006,7 @@
void t_py_generator::generate_service_server(t_service* tservice) {
// Generate the dispatch methods
vector<t_function*> functions = tservice->get_functions();
- vector<t_function*>::iterator f_iter;
+ vector<t_function*>::iterator f_iter;
string extends = "";
string extends_processor = "";
@@ -924,10 +1034,10 @@
for (f_iter = functions.begin(); f_iter != functions.end(); ++f_iter) {
f_service_ <<
indent() << "self._processMap[\"" << (*f_iter)->get_name() << "\"] = Processor.process_" << (*f_iter)->get_name() << endl;
- }
+ }
indent_down();
f_service_ << endl;
-
+
// Generate the server implementation
indent(f_service_) <<
"def process(self, iprot, oprot):" << endl;
@@ -1005,7 +1115,7 @@
indent() << "try:" << endl;
indent_up();
}
-
+
// Generate the function call
t_struct* arg_struct = tfunction->get_arglist();
const std::vector<t_field*>& fields = arg_struct->get_members();
@@ -1090,7 +1200,7 @@
} else if (type->is_base_type() || type->is_enum()) {
indent(out) <<
name << " = iprot.";
-
+
if (type->is_base_type()) {
t_base_type::t_base tbase = ((t_base_type*)type)->get_base();
switch (tbase) {
@@ -1098,7 +1208,7 @@
throw "compiler error: cannot serialize void field in a struct: " +
name;
break;
- case t_base_type::TYPE_STRING:
+ case t_base_type::TYPE_STRING:
out << "readString();";
break;
case t_base_type::TYPE_BOOL:
@@ -1130,7 +1240,7 @@
} else {
printf("DO NOT KNOW HOW TO DESERIALIZE FIELD '%s' TYPE '%s'\n",
tfield->get_name().c_str(), type->get_name().c_str());
- }
+ }
}
/**
@@ -1155,7 +1265,7 @@
string ktype = tmp("_ktype");
string vtype = tmp("_vtype");
string etype = tmp("_etype");
-
+
t_field fsize(g_type_i32, size);
t_field fktype(g_type_byte, ktype);
t_field fvtype(g_type_byte, vtype);
@@ -1180,9 +1290,9 @@
string i = tmp("_i");
indent(out) <<
"for " << i << " in xrange(" << size << "):" << endl;
-
+
indent_up();
-
+
if (ttype->is_map()) {
generate_deserialize_map_element(out, (t_map*)ttype, prefix);
} else if (ttype->is_set()) {
@@ -1190,7 +1300,7 @@
} else if (ttype->is_list()) {
generate_deserialize_list_element(out, (t_list*)ttype, prefix);
}
-
+
indent_down();
// Read container end
@@ -1269,7 +1379,7 @@
throw "CANNOT GENERATE SERIALIZE CODE FOR void TYPE: " +
prefix + tfield->get_name();
}
-
+
if (type->is_struct() || type->is_xception()) {
generate_serialize_struct(out,
(t_struct*)type,
@@ -1284,7 +1394,7 @@
indent(out) <<
"oprot.";
-
+
if (type->is_base_type()) {
t_base_type::t_base tbase = ((t_base_type*)type)->get_base();
switch (tbase) {
@@ -1365,27 +1475,27 @@
if (ttype->is_map()) {
string kiter = tmp("kiter");
string viter = tmp("viter");
- indent(out) <<
+ indent(out) <<
"for " << kiter << "," << viter << " in " << prefix << ".items():" << endl;
indent_up();
generate_serialize_map_element(out, (t_map*)ttype, kiter, viter);
indent_down();
} else if (ttype->is_set()) {
string iter = tmp("iter");
- indent(out) <<
+ indent(out) <<
"for " << iter << " in " << prefix << ":" << endl;
indent_up();
generate_serialize_set_element(out, (t_set*)ttype, iter);
indent_down();
} else if (ttype->is_list()) {
string iter = tmp("iter");
- indent(out) <<
+ indent(out) <<
"for " << iter << " in " << prefix << ":" << endl;
indent_up();
generate_serialize_list_element(out, (t_list*)ttype, iter);
indent_down();
}
-
+
if (ttype->is_map()) {
indent(out) <<
"oprot.writeMapEnd()" << endl;
@@ -1500,7 +1610,7 @@
*/
string t_py_generator::type_to_enum(t_type* type) {
type = get_true_type(type);
-
+
if (type->is_base_type()) {
t_base_type::t_base tbase = ((t_base_type*)type)->get_base();
switch (tbase) {
@@ -1535,3 +1645,37 @@
throw "INVALID TYPE IN type_to_enum: " + type->get_name();
}
+
+/** See the comment inside generate_py_struct_definition for what this is. */
+string t_py_generator::type_to_spec_args(t_type* ttype) {
+ while (ttype->is_typedef()) {
+ ttype = ((t_typedef*)ttype)->get_type();
+ }
+
+ if (ttype->is_base_type() || ttype->is_enum()) {
+ return "None";
+ } else if (ttype->is_struct() || ttype->is_xception()) {
+ return "(" + type_name(ttype) + ", " + type_name(ttype) + ".thrift_spec)";
+ } else if (ttype->is_map()) {
+ return "(" +
+ type_to_enum(((t_map*)ttype)->get_key_type()) + "," +
+ type_to_spec_args(((t_map*)ttype)->get_key_type()) + "," +
+ type_to_enum(((t_map*)ttype)->get_val_type()) + "," +
+ type_to_spec_args(((t_map*)ttype)->get_val_type()) +
+ ")";
+
+ } else if (ttype->is_set()) {
+ return "(" +
+ type_to_enum(((t_set*)ttype)->get_elem_type()) + "," +
+ type_to_spec_args(((t_set*)ttype)->get_elem_type()) +
+ ")";
+
+ } else if (ttype->is_list()) {
+ return "(" +
+ type_to_enum(((t_list*)ttype)->get_elem_type()) + "," +
+ type_to_spec_args(((t_list*)ttype)->get_elem_type()) +
+ ")";
+ }
+
+ throw "INVALID TYPE IN type_to_spec_args: " + ttype->get_name();
+}
diff --git a/compiler/cpp/src/generate/t_py_generator.h b/compiler/cpp/src/generate/t_py_generator.h
index 87b3f47..b301d36 100644
--- a/compiler/cpp/src/generate/t_py_generator.h
+++ b/compiler/cpp/src/generate/t_py_generator.h
@@ -72,18 +72,18 @@
*/
void generate_deserialize_field (std::ofstream &out,
- t_field* tfield,
+ t_field* tfield,
std::string prefix="",
bool inclass=false);
-
+
void generate_deserialize_struct (std::ofstream &out,
t_struct* tstruct,
std::string prefix="");
-
+
void generate_deserialize_container (std::ofstream &out,
t_type* ttype,
std::string prefix="");
-
+
void generate_deserialize_set_element (std::ofstream &out,
t_set* tset,
std::string prefix="");
@@ -133,6 +133,7 @@
std::string function_signature(t_function* tfunction, std::string prefix="");
std::string argument_list(t_struct* tstruct);
std::string type_to_enum(t_type* ttype);
+ std::string type_to_spec_args(t_type* ttype);
private:
@@ -141,7 +142,7 @@
*/
std::ofstream f_types_;
- std::ofstream f_consts_;
+ std::ofstream f_consts_;
std::ofstream f_service_;
std::string package_dir_;
diff --git a/lib/py/setup.py b/lib/py/setup.py
index 8ff1645..582a985 100644
--- a/lib/py/setup.py
+++ b/lib/py/setup.py
@@ -6,7 +6,11 @@
# See accompanying file LICENSE or visit the Thrift site at:
# http://developers.facebook.com/thrift/
-from distutils.core import setup
+from distutils.core import setup, Extension
+
+fastbinarymod = Extension('thrift.protocol.fastbinary',
+ sources = ['src/protocol/fastbinary.c'],
+ )
setup(name = 'Thrift',
version = '1.0',
@@ -16,5 +20,6 @@
url = 'http://code.facebook.com/thrift',
packages = ['thrift', 'thrift.protocol', 'thrift.transport', 'thrift.server'],
package_dir = {'thrift' : 'src'},
+ ext_modules = [fastbinarymod],
)
diff --git a/lib/py/src/protocol/TBinaryProtocol.py b/lib/py/src/protocol/TBinaryProtocol.py
index 6ae0c86..3fd6b02 100644
--- a/lib/py/src/protocol/TBinaryProtocol.py
+++ b/lib/py/src/protocol/TBinaryProtocol.py
@@ -77,7 +77,7 @@
self.writeByte(1)
else:
self.writeByte(0)
-
+
def writeByte(self, byte):
buff = pack("!b", byte)
self.trans.write(buff)
@@ -89,7 +89,7 @@
def writeI32(self, i32):
buff = pack("!i", i32)
self.trans.write(buff)
-
+
def writeI64(self, i64):
buff = pack("!q", i64)
self.trans.write(buff)
@@ -199,6 +199,7 @@
str = self.trans.readAll(len)
return str
+
class TBinaryProtocolFactory:
def __init__(self, strictRead=False, strictWrite=True):
self.strictRead = strictRead
@@ -207,3 +208,32 @@
def getProtocol(self, trans):
prot = TBinaryProtocol(trans, self.strictRead, self.strictWrite)
return prot
+
+
+class TBinaryProtocolAccelerated(TBinaryProtocol):
+
+ """C-Accelerated version of TBinaryProtocol.
+
+ This class does not override any of TBinaryProtocol's methods,
+ but the generated code recognizes it directly and will call into
+ our C module to do the encoding, bypassing this object entirely.
+ We inherit from TBinaryProtocol so that the normal TBinaryProtocol
+ encoding can happen if the fastbinary module doesn't work for some
+ reason. (TODO(dreiss): Make this happen sanely.)
+
+ In order to take advantage of the C module, just use
+ TBinaryProtocolAccelerated instead of TBinaryProtocol.
+
+ NOTE: This code was contributed by an external developer.
+ The internal Thrift team has reviewed and tested it,
+ but we cannot guarantee that it is production-ready.
+ Please feel free to report bugs and/or success stories
+ to the public mailing list.
+ """
+
+ pass
+
+
+class TBinaryProtocolAcceleratedFactory:
+ def getProtocol(self, trans):
+ return TBinaryProtocolAccelerated(trans)
diff --git a/lib/py/src/protocol/__init__.py b/lib/py/src/protocol/__init__.py
index bcc981d..11ae3a7 100644
--- a/lib/py/src/protocol/__init__.py
+++ b/lib/py/src/protocol/__init__.py
@@ -6,4 +6,4 @@
# See accompanying file LICENSE or visit the Thrift site at:
# http://developers.facebook.com/thrift/
-__all__ = ['TProtocol', 'TBinaryProtocol']
+__all__ = ['TProtocol', 'TBinaryProtocol', 'fastbinary']
diff --git a/lib/py/src/protocol/fastbinary.c b/lib/py/src/protocol/fastbinary.c
new file mode 100644
index 0000000..cfc504e
--- /dev/null
+++ b/lib/py/src/protocol/fastbinary.c
@@ -0,0 +1,1146 @@
+// Copyright (c) 2006- Facebook
+// Distributed under the Thrift Software License
+//
+// See accompanying file LICENSE or visit the Thrift site at:
+// http://developers.facebook.com/thrift/
+//
+// NOTE: This code was contributed by an external developer.
+// The internal Thrift team has reviewed and tested it,
+// but we cannot guarantee that it is production-ready.
+// Please feel free to report bugs and/or success stories
+// to the public mailing list.
+
+#include <Python.h>
+#include "cStringIO.h"
+#include <stdbool.h>
+#include <stdint.h>
+#include <netinet/in.h>
+
+// TODO(dreiss): defval appears to be unused. Look into removing it.
+// TODO(dreiss): Make parse_spec_args recursive, and cache the output
+// permanently in the object. (Malloc and orphan.)
+// TODO(dreiss): Why do we need cStringIO for reading, why not just char*?
+// Can cStringIO let us work with a BufferedTransport?
+// TODO(dreiss): Don't ignore the rv from cwrite (maybe).
+
+/* ====== BEGIN UTILITIES ====== */
+
+#define INIT_OUTBUF_SIZE 128
+
+// Stolen out of TProtocol.h.
+// It would be a huge pain to have both get this from one place.
+typedef enum TType {
+ T_STOP = 0,
+ T_VOID = 1,
+ T_BOOL = 2,
+ T_BYTE = 3,
+ T_I08 = 3,
+ T_I16 = 6,
+ T_I32 = 8,
+ T_U64 = 9,
+ T_I64 = 10,
+ T_DOUBLE = 4,
+ T_STRING = 11,
+ T_UTF7 = 11,
+ T_STRUCT = 12,
+ T_MAP = 13,
+ T_SET = 14,
+ T_LIST = 15,
+ T_UTF8 = 16,
+ T_UTF16 = 17
+} TType;
+
+// Same comment as the enum. Sorry.
+#if __BYTE_ORDER == __BIG_ENDIAN
+# define ntohll(n) (n)
+# define htonll(n) (n)
+#elif __BYTE_ORDER == __LITTLE_ENDIAN
+# if defined(__GNUC__) && defined(__GLIBC__)
+# include <byteswap.h>
+# define ntohll(n) bswap_64(n)
+# define htonll(n) bswap_64(n)
+# else /* GNUC & GLIBC */
+# define ntohll(n) ( (((unsigned long long)ntohl(n)) << 32) + ntohl(n >> 32) )
+# define htonll(n) ( (((unsigned long long)htonl(n)) << 32) + htonl(n >> 32) )
+# endif /* GNUC & GLIBC */
+#else /* __BYTE_ORDER */
+# error "Can't define htonll or ntohll!"
+#endif
+
+// Doing a benchmark shows that interning actually makes a difference, amazingly.
+#define INTERN_STRING(value) _intern_ ## value
+
+#define INT_CONV_ERROR_OCCURRED(v) ( ((v) == -1) && PyErr_Occurred() )
+#define CHECK_RANGE(v, min, max) ( ((v) <= (max)) && ((v) >= (min)) )
+
+/**
+ * A cache of the spec_args for a set or list,
+ * so we don't have to keep calling PyTuple_GET_ITEM.
+ */
+typedef struct {
+ TType element_type;
+ PyObject* typeargs;
+} SetListTypeArgs;
+
+/**
+ * A cache of the spec_args for a map,
+ * so we don't have to keep calling PyTuple_GET_ITEM.
+ */
+typedef struct {
+ TType ktag;
+ TType vtag;
+ PyObject* ktypeargs;
+ PyObject* vtypeargs;
+} MapTypeArgs;
+
+/**
+ * A cache of the spec_args for a struct,
+ * so we don't have to keep calling PyTuple_GET_ITEM.
+ */
+typedef struct {
+ PyObject* klass;
+ PyObject* spec;
+} StructTypeArgs;
+
+/**
+ * A cache of the item spec from a struct specification,
+ * so we don't have to keep calling PyTuple_GET_ITEM.
+ */
+typedef struct {
+ int tag;
+ TType type;
+ PyObject* attrname;
+ PyObject* typeargs;
+ PyObject* defval;
+} StructItemSpec;
+
+/**
+ * A cache of the two key attributes of a CReadableTransport,
+ * so we don't have to keep calling PyObject_GetAttr.
+ */
+typedef struct {
+ PyObject* stringiobuf;
+ PyObject* refill_callable;
+} DecodeBuffer;
+
+/** Pointer to interned string to speed up attribute lookup. */
+static PyObject* INTERN_STRING(cstringio_buf);
+/** Pointer to interned string to speed up attribute lookup. */
+static PyObject* INTERN_STRING(cstringio_refill);
+
+static inline bool
+check_ssize_t_32(Py_ssize_t len) {
+ // error from getting the int
+ if (INT_CONV_ERROR_OCCURRED(len)) {
+ return false;
+ }
+ if (!CHECK_RANGE(len, 0, INT32_MAX)) {
+ PyErr_SetString(PyExc_OverflowError, "string size out of range");
+ return false;
+ }
+ return true;
+}
+
+static inline bool
+parse_pyint(PyObject* o, int32_t* ret, int32_t min, int32_t max) {
+ long val = PyInt_AsLong(o);
+
+ if (INT_CONV_ERROR_OCCURRED(val)) {
+ return false;
+ }
+ if (!CHECK_RANGE(val, min, max)) {
+ PyErr_SetString(PyExc_OverflowError, "int out of range");
+ return false;
+ }
+
+ *ret = (int32_t) val;
+ return true;
+}
+
+
+/* --- FUNCTIONS TO PARSE STRUCT SPECIFICATOINS --- */
+
+static bool
+parse_set_list_args(SetListTypeArgs* dest, PyObject* typeargs) {
+ if (PyTuple_Size(typeargs) != 2) {
+ PyErr_SetString(PyExc_TypeError, "expecting tuple of size 2 for list/set type args");
+ return false;
+ }
+
+ dest->element_type = PyInt_AsLong(PyTuple_GET_ITEM(typeargs, 0));
+ if (INT_CONV_ERROR_OCCURRED(dest->element_type)) {
+ return false;
+ }
+
+ dest->typeargs = PyTuple_GET_ITEM(typeargs, 1);
+
+ return true;
+}
+
+static bool
+parse_map_args(MapTypeArgs* dest, PyObject* typeargs) {
+ if (PyTuple_Size(typeargs) != 4) {
+ PyErr_SetString(PyExc_TypeError, "expecting 4 arguments for typeargs to map");
+ return false;
+ }
+
+ dest->ktag = PyInt_AsLong(PyTuple_GET_ITEM(typeargs, 0));
+ if (INT_CONV_ERROR_OCCURRED(dest->ktag)) {
+ return false;
+ }
+
+ dest->vtag = PyInt_AsLong(PyTuple_GET_ITEM(typeargs, 2));
+ if (INT_CONV_ERROR_OCCURRED(dest->vtag)) {
+ return false;
+ }
+
+ dest->ktypeargs = PyTuple_GET_ITEM(typeargs, 1);
+ dest->vtypeargs = PyTuple_GET_ITEM(typeargs, 3);
+
+ return true;
+}
+
+static bool
+parse_struct_args(StructTypeArgs* dest, PyObject* typeargs) {
+ if (PyTuple_Size(typeargs) != 2) {
+ PyErr_SetString(PyExc_TypeError, "expecting tuple of size 2 for struct args");
+ return false;
+ }
+
+ dest->klass = PyTuple_GET_ITEM(typeargs, 0);
+ dest->spec = PyTuple_GET_ITEM(typeargs, 1);
+
+ return true;
+}
+
+static int
+parse_struct_item_spec(StructItemSpec* dest, PyObject* spec_tuple) {
+
+ // i'd like to use ParseArgs here, but it seems to be a bottleneck.
+ if (PyTuple_Size(spec_tuple) != 5) {
+ PyErr_SetString(PyExc_TypeError, "expecting 5 arguments for spec tuple");
+ return false;
+ }
+
+ dest->tag = PyInt_AsLong(PyTuple_GET_ITEM(spec_tuple, 0));
+ if (INT_CONV_ERROR_OCCURRED(dest->tag)) {
+ return false;
+ }
+
+ dest->type = PyInt_AsLong(PyTuple_GET_ITEM(spec_tuple, 1));
+ if (INT_CONV_ERROR_OCCURRED(dest->type)) {
+ return false;
+ }
+
+ dest->attrname = PyTuple_GET_ITEM(spec_tuple, 2);
+ dest->typeargs = PyTuple_GET_ITEM(spec_tuple, 3);
+ dest->defval = PyTuple_GET_ITEM(spec_tuple, 4);
+ return true;
+}
+
+/* ====== END UTILITIES ====== */
+
+
+/* ====== BEGIN WRITING FUNCTIONS ====== */
+
+/* --- LOW-LEVEL WRITING FUNCTIONS --- */
+
+static void writeByte(PyObject* outbuf, int8_t val) {
+ int8_t net = val;
+ PycStringIO->cwrite(outbuf, (char*)&net, sizeof(int8_t));
+}
+
+static void writeI16(PyObject* outbuf, int16_t val) {
+ int16_t net = (int16_t)htons(val);
+ PycStringIO->cwrite(outbuf, (char*)&net, sizeof(int16_t));
+}
+
+static void writeI32(PyObject* outbuf, int32_t val) {
+ int32_t net = (int32_t)htonl(val);
+ PycStringIO->cwrite(outbuf, (char*)&net, sizeof(int32_t));
+}
+
+static void writeI64(PyObject* outbuf, int64_t val) {
+ int64_t net = (int64_t)htonll(val);
+ PycStringIO->cwrite(outbuf, (char*)&net, sizeof(int64_t));
+}
+
+static void writeDouble(PyObject* outbuf, double dub) {
+ // Unfortunately, bitwise_cast doesn't work in C. Bad C!
+ union {
+ double f;
+ int64_t t;
+ } transfer;
+ transfer.f = dub;
+ writeI64(outbuf, transfer.t);
+}
+
+
+/* --- MAIN RECURSIVE OUTPUT FUCNTION -- */
+
+static int
+output_val(PyObject* output, PyObject* value, TType type, PyObject* typeargs) {
+ /*
+ * Refcounting Strategy:
+ *
+ * We assume that elements of the thrift_spec tuple are not going to be
+ * mutated, so we don't ref count those at all. Other than that, we try to
+ * keep a reference to all the user-created objects while we work with them.
+ * output_val assumes that a reference is already held. The *caller* is
+ * responsible for handling references
+ */
+
+ switch (type) {
+
+ case T_BOOL: {
+ int v = PyObject_IsTrue(value);
+ if (v == -1) {
+ return false;
+ }
+
+ writeByte(output, (int8_t) v);
+ break;
+ }
+ case T_I08: {
+ int32_t val;
+
+ if (!parse_pyint(value, &val, INT8_MIN, INT8_MAX)) {
+ return false;
+ }
+
+ writeByte(output, (int8_t) val);
+ break;
+ }
+ case T_I16: {
+ int32_t val;
+
+ if (!parse_pyint(value, &val, INT16_MIN, INT16_MAX)) {
+ return false;
+ }
+
+ writeI16(output, (int16_t) val);
+ break;
+ }
+ case T_I32: {
+ int32_t val;
+
+ if (!parse_pyint(value, &val, INT32_MIN, INT32_MAX)) {
+ return false;
+ }
+
+ writeI32(output, val);
+ break;
+ }
+ case T_I64: {
+ int64_t nval = PyLong_AsLongLong(value);
+
+ if (INT_CONV_ERROR_OCCURRED(nval)) {
+ return false;
+ }
+
+ if (!CHECK_RANGE(nval, INT64_MIN, INT64_MAX)) {
+ PyErr_SetString(PyExc_OverflowError, "int out of range");
+ return false;
+ }
+
+ writeI64(output, nval);
+ break;
+ }
+
+ case T_DOUBLE: {
+ double nval = PyFloat_AsDouble(value);
+ if (nval == -1.0 && PyErr_Occurred()) {
+ return false;
+ }
+
+ writeDouble(output, nval);
+ break;
+ }
+
+ case T_STRING: {
+ Py_ssize_t len = PyString_Size(value);
+
+ if (!check_ssize_t_32(len)) {
+ return false;
+ }
+
+ writeI32(output, (int32_t) len);
+ PycStringIO->cwrite(output, PyString_AsString(value), (int32_t) len);
+ break;
+ }
+
+ case T_LIST:
+ case T_SET: {
+ Py_ssize_t len;
+ SetListTypeArgs parsedargs;
+ PyObject *item;
+ PyObject *iterator;
+
+ if (!parse_set_list_args(&parsedargs, typeargs)) {
+ return false;
+ }
+
+ len = PyObject_Length(value);
+
+ if (!check_ssize_t_32(len)) {
+ return false;
+ }
+
+ writeByte(output, parsedargs.element_type);
+ writeI32(output, (int32_t) len);
+
+ iterator = PyObject_GetIter(value);
+ if (iterator == NULL) {
+ return false;
+ }
+
+ while ((item = PyIter_Next(iterator))) {
+ if (!output_val(output, item, parsedargs.element_type, parsedargs.typeargs)) {
+ Py_DECREF(item);
+ Py_DECREF(iterator);
+ return false;
+ }
+ Py_DECREF(item);
+ }
+
+ Py_DECREF(iterator);
+
+ if (PyErr_Occurred()) {
+ return false;
+ }
+
+ break;
+ }
+
+ case T_MAP: {
+ PyObject *k, *v;
+ int pos = 0;
+ Py_ssize_t len;
+
+ MapTypeArgs parsedargs;
+
+ len = PyDict_Size(value);
+ if (!check_ssize_t_32(len)) {
+ return false;
+ }
+
+ if (!parse_map_args(&parsedargs, typeargs)) {
+ return false;
+ }
+
+ writeByte(output, parsedargs.ktag);
+ writeByte(output, parsedargs.vtag);
+ writeI32(output, len);
+
+ // TODO(bmaurer): should support any mapping, not just dicts
+ while (PyDict_Next(value, &pos, &k, &v)) {
+ // TODO(dreiss): Think hard about whether these INCREFs actually
+ // turn any unsafe scenarios into safe scenarios.
+ Py_INCREF(k);
+ Py_INCREF(v);
+
+ if (!output_val(output, k, parsedargs.ktag, parsedargs.ktypeargs)
+ || !output_val(output, v, parsedargs.vtag, parsedargs.vtypeargs)) {
+ Py_DECREF(k);
+ Py_DECREF(v);
+ return false;
+ }
+ }
+ break;
+ }
+
+ // TODO(dreiss): Consider breaking this out as a function
+ // the way we did for decode_struct.
+ case T_STRUCT: {
+ StructTypeArgs parsedargs;
+ Py_ssize_t nspec;
+ Py_ssize_t i;
+
+ if (!parse_struct_args(&parsedargs, typeargs)) {
+ return false;
+ }
+
+ nspec = PyTuple_Size(parsedargs.spec);
+
+ if (nspec == -1) {
+ return false;
+ }
+
+ for (i = 0; i < nspec; i++) {
+ StructItemSpec parsedspec;
+ PyObject* spec_tuple;
+ PyObject* instval = NULL;
+
+ spec_tuple = PyTuple_GET_ITEM(parsedargs.spec, i);
+ if (spec_tuple == Py_None) {
+ continue;
+ }
+
+ if (!parse_struct_item_spec (&parsedspec, spec_tuple)) {
+ return false;
+ }
+
+ instval = PyObject_GetAttr(value, parsedspec.attrname);
+
+ if (!instval) {
+ return false;
+ }
+
+ if (instval == Py_None) {
+ Py_DECREF(instval);
+ continue;
+ }
+
+ writeByte(output, (int8_t) parsedspec.type);
+ writeI16(output, parsedspec.tag);
+
+ if (!output_val(output, instval, parsedspec.type, parsedspec.typeargs)) {
+ Py_DECREF(instval);
+ return false;
+ }
+
+ Py_DECREF(instval);
+ }
+
+ writeByte(output, (int8_t)T_STOP);
+ break;
+ }
+
+ case T_STOP:
+ case T_VOID:
+ case T_UTF16:
+ case T_UTF8:
+ case T_U64:
+ default:
+ PyErr_SetString(PyExc_TypeError, "Unexpected TType");
+ return false;
+
+ }
+
+ return true;
+}
+
+
+/* --- TOP-LEVEL WRAPPER FOR OUTPUT -- */
+
+static PyObject *
+encode_binary(PyObject *self, PyObject *args) {
+ PyObject* enc_obj;
+ PyObject* type_args;
+ PyObject* buf;
+ PyObject* ret = NULL;
+
+ if (!PyArg_ParseTuple(args, "OO", &enc_obj, &type_args)) {
+ return NULL;
+ }
+
+ buf = PycStringIO->NewOutput(INIT_OUTBUF_SIZE);
+ if (output_val(buf, enc_obj, T_STRUCT, type_args)) {
+ ret = PycStringIO->cgetvalue(buf);
+ }
+
+ Py_DECREF(buf);
+ return ret;
+}
+
+/* ====== END WRITING FUNCTIONS ====== */
+
+
+/* ====== BEGIN READING FUNCTIONS ====== */
+
+/* --- LOW-LEVEL READING FUNCTIONS --- */
+
+static void
+free_decodebuf(DecodeBuffer* d) {
+ Py_XDECREF(d->stringiobuf);
+ Py_XDECREF(d->refill_callable);
+}
+
+static bool
+decode_buffer_from_obj(DecodeBuffer* dest, PyObject* obj) {
+ dest->stringiobuf = PyObject_GetAttr(obj, INTERN_STRING(cstringio_buf));
+ if (!dest->stringiobuf) {
+ return false;
+ }
+
+ if (!PycStringIO_InputCheck(dest->stringiobuf)) {
+ free_decodebuf(dest);
+ PyErr_SetString(PyExc_TypeError, "expecting stringio input");
+ return false;
+ }
+
+ dest->refill_callable = PyObject_GetAttr(obj, INTERN_STRING(cstringio_refill));
+
+ if(!dest->refill_callable) {
+ free_decodebuf(dest);
+ return false;
+ }
+
+ if (!PyCallable_Check(dest->refill_callable)) {
+ free_decodebuf(dest);
+ PyErr_SetString(PyExc_TypeError, "expecting callable");
+ return false;
+ }
+
+ return true;
+}
+
+static bool readBytes(DecodeBuffer* input, char** output, int len) {
+ int read;
+
+ // TODO(dreiss): Don't fear the malloc. Think about taking a copy of
+ // the partial read instead of forcing the transport
+ // to prepend it to its buffer.
+
+ read = PycStringIO->cread(input->stringiobuf, output, len);
+
+ if (read == len) {
+ return true;
+ } else if (read == -1) {
+ return false;
+ } else {
+ PyObject* newiobuf;
+
+ // using building functions as this is a rare codepath
+ newiobuf = PyObject_CallFunction(
+ input->refill_callable, "s#i", *output, len, read, NULL);
+ if (newiobuf == NULL) {
+ return false;
+ }
+
+ // must do this *AFTER* the call so that we don't deref the io buffer
+ Py_CLEAR(input->stringiobuf);
+ input->stringiobuf = newiobuf;
+
+ read = PycStringIO->cread(input->stringiobuf, output, len);
+
+ if (read == len) {
+ return true;
+ } else if (read == -1) {
+ return false;
+ } else {
+ // TODO(dreiss): This could be a valid code path for big binary blobs.
+ PyErr_SetString(PyExc_TypeError,
+ "refill claimed to have refilled the buffer, but didn't!!");
+ return false;
+ }
+ }
+}
+
+static int8_t readByte(DecodeBuffer* input) {
+ char* buf;
+ if (!readBytes(input, &buf, sizeof(int8_t))) {
+ return -1;
+ }
+
+ return *(int8_t*) buf;
+}
+
+static int16_t readI16(DecodeBuffer* input) {
+ char* buf;
+ if (!readBytes(input, &buf, sizeof(int16_t))) {
+ return -1;
+ }
+
+ return (int16_t) ntohs(*(int16_t*) buf);
+}
+
+static int32_t readI32(DecodeBuffer* input) {
+ char* buf;
+ if (!readBytes(input, &buf, sizeof(int32_t))) {
+ return -1;
+ }
+ return (int32_t) ntohl(*(int32_t*) buf);
+}
+
+
+static int64_t readI64(DecodeBuffer* input) {
+ char* buf;
+ if (!readBytes(input, &buf, sizeof(int64_t))) {
+ return -1;
+ }
+
+ return (int64_t) ntohll(*(int64_t*) buf);
+}
+
+static double readDouble(DecodeBuffer* input) {
+ union {
+ int64_t f;
+ double t;
+ } transfer;
+
+ transfer.f = readI64(input);
+ if (transfer.f == -1) {
+ return -1;
+ }
+ return transfer.t;
+}
+
+static bool
+checkTypeByte(DecodeBuffer* input, TType expected) {
+ TType got = readByte(input);
+
+ if (expected != got) {
+ PyErr_SetString(PyExc_TypeError, "got wrong ttype while reading field");
+ return false;
+ }
+ return true;
+}
+
+static bool
+skip(DecodeBuffer* input, TType type) {
+#define SKIPBYTES(n) \
+ do { \
+ if (!readBytes(input, &dummy_buf, (n))) { \
+ return false; \
+ } \
+ } while(0)
+
+ char* dummy_buf;
+
+ switch (type) {
+
+ case T_BOOL:
+ case T_I08: SKIPBYTES(1); break;
+ case T_I16: SKIPBYTES(2); break;
+ case T_I32: SKIPBYTES(4); break;
+ case T_I64:
+ case T_DOUBLE: SKIPBYTES(8); break;
+
+ case T_STRING: {
+ // TODO(dreiss): Find out if these check_ssize_t32s are really necessary.
+ int len = readI32(input);
+ if (!check_ssize_t_32(len)) {
+ return false;
+ }
+ SKIPBYTES(len);
+ break;
+ }
+
+ case T_LIST:
+ case T_SET: {
+ TType etype;
+ int len, i;
+
+ etype = readByte(input);
+ if (etype == -1) {
+ return false;
+ }
+
+ len = readI32(input);
+ if (!check_ssize_t_32(len)) {
+ return false;
+ }
+
+ for (i = 0; i < len; i++) {
+ if (!skip(input, etype)) {
+ return false;
+ }
+ }
+ break;
+ }
+
+ case T_MAP: {
+ TType ktype, vtype;
+ int len, i;
+
+ ktype = readByte(input);
+ if (ktype == -1) {
+ return false;
+ }
+
+ vtype = readByte(input);
+ if (vtype == -1) {
+ return false;
+ }
+
+ len = readI32(input);
+ if (!check_ssize_t_32(len)) {
+ return false;
+ }
+
+ for (i = 0; i < len; i++) {
+ if (!(skip(input, ktype) && skip(input, vtype))) {
+ return false;
+ }
+ }
+ break;
+ }
+
+ case T_STRUCT: {
+ while (true) {
+ TType type;
+
+ type = readByte(input);
+ if (type == -1) {
+ return false;
+ }
+
+ if (type == T_STOP)
+ break;
+
+ SKIPBYTES(2); // tag
+ if (!skip(input, type)) {
+ return false;
+ }
+ }
+ break;
+ }
+
+ case T_STOP:
+ case T_VOID:
+ case T_UTF16:
+ case T_UTF8:
+ case T_U64:
+ default:
+ PyErr_SetString(PyExc_TypeError, "Unexpected TType");
+ return false;
+
+ }
+
+ return false;
+
+#undef SKIPBYTES
+}
+
+
+/* --- HELPER FUNCTION FOR DECODE_VAL --- */
+
+static PyObject*
+decode_val(DecodeBuffer* input, TType type, PyObject* typeargs);
+
+static bool
+decode_struct(DecodeBuffer* input, PyObject* output, PyObject* spec_seq) {
+ int spec_seq_len = PyTuple_Size(spec_seq);
+ if (spec_seq_len == -1) {
+ return false;
+ }
+
+ while (true) {
+ TType type;
+ int16_t tag;
+ PyObject* item_spec;
+ PyObject* fieldval = NULL;
+ StructItemSpec parsedspec;
+
+ type = readByte(input);
+ if (type == T_STOP) {
+ break;
+ }
+ tag = readI16(input);
+
+ if (tag >= 0 && tag < spec_seq_len) {
+ item_spec = PyTuple_GET_ITEM(spec_seq, tag);
+ } else {
+ item_spec = Py_None;
+ }
+
+ if (item_spec == Py_None) {
+ if (!skip(input, type)) {
+ return false;
+ }
+ }
+
+ if (!parse_struct_item_spec(&parsedspec, item_spec)) {
+ return false;
+ }
+ if (parsedspec.type != type) {
+ PyErr_SetString(PyExc_TypeError, "struct field had wrong type while reading");
+ return false;
+ }
+
+ fieldval = decode_val(input, parsedspec.type, parsedspec.typeargs);
+ if (fieldval == NULL) {
+ return false;
+ }
+
+ if (PyObject_SetAttr(output, parsedspec.attrname, fieldval) == -1) {
+ Py_DECREF(fieldval);
+ return false;
+ }
+ Py_DECREF(fieldval);
+ }
+ return true;
+}
+
+
+/* --- MAIN RECURSIVE INPUT FUCNTION --- */
+
+// Returns a new reference.
+static PyObject*
+decode_val(DecodeBuffer* input, TType type, PyObject* typeargs) {
+ switch (type) {
+
+ case T_BOOL: {
+ int8_t v = readByte(input);
+ if (INT_CONV_ERROR_OCCURRED(v)) {
+ return NULL;
+ }
+
+ switch (v) {
+ case 0: Py_RETURN_FALSE;
+ case 1: Py_RETURN_TRUE;
+ // Don't laugh. This is a potentially serious issue.
+ default: PyErr_SetString(PyExc_TypeError, "boolean out of range"); return NULL;
+ }
+ break;
+ }
+ case T_I08: {
+ int8_t v = readByte(input);
+ if (INT_CONV_ERROR_OCCURRED(v)) {
+ return NULL;
+ }
+
+ return PyInt_FromLong(v);
+ }
+ case T_I16: {
+ int16_t v = readI16(input);
+ if (INT_CONV_ERROR_OCCURRED(v)) {
+ return NULL;
+ }
+ return PyInt_FromLong(v);
+ }
+ case T_I32: {
+ int32_t v = readI32(input);
+ if (INT_CONV_ERROR_OCCURRED(v)) {
+ return NULL;
+ }
+ return PyInt_FromLong(v);
+ }
+
+ case T_I64: {
+ int64_t v = readI64(input);
+ if (INT_CONV_ERROR_OCCURRED(v)) {
+ return NULL;
+ }
+ // TODO(dreiss): Find out if we can take this fastpath always when
+ // sizeof(long) == sizeof(long long).
+ if (CHECK_RANGE(v, LONG_MIN, LONG_MAX)) {
+ return PyInt_FromLong((long) v);
+ }
+
+ return PyLong_FromLongLong(v);
+ }
+
+ case T_DOUBLE: {
+ double v = readDouble(input);
+ if (v == -1.0 && PyErr_Occurred()) {
+ return false;
+ }
+ return PyFloat_FromDouble(v);
+ }
+
+ case T_STRING: {
+ Py_ssize_t len = readI32(input);
+ char* buf;
+ if (!readBytes(input, &buf, len)) {
+ return NULL;
+ }
+
+ return PyString_FromStringAndSize(buf, len);
+ }
+
+ case T_LIST:
+ case T_SET: {
+ SetListTypeArgs parsedargs;
+ int32_t len;
+ PyObject* ret = NULL;
+ int i;
+
+ if (!parse_set_list_args(&parsedargs, typeargs)) {
+ return NULL;
+ }
+
+ if (!checkTypeByte(input, parsedargs.element_type)) {
+ return NULL;
+ }
+
+ len = readI32(input);
+ if (!check_ssize_t_32(len)) {
+ return NULL;
+ }
+
+ ret = PyList_New(len);
+ if (!ret) {
+ return NULL;
+ }
+
+ for (i = 0; i < len; i++) {
+ PyObject* item = decode_val(input, parsedargs.element_type, parsedargs.typeargs);
+ if (!item) {
+ Py_DECREF(ret);
+ return NULL;
+ }
+ PyList_SET_ITEM(ret, i, item);
+ }
+
+ // TODO(dreiss): Consider biting the bullet and making two separate cases
+ // for list and set, avoiding this post facto conversion.
+ if (type == T_SET) {
+ PyObject* setret;
+#if (PY_VERSION_HEX < 0x02050000)
+ // hack needed for older versions
+ setret = PyObject_CallFunctionObjArgs((PyObject*)&PySet_Type, ret, NULL);
+#else
+ // official version
+ setret = PySet_New(ret);
+#endif
+ Py_DECREF(ret);
+ return setret;
+ }
+ return ret;
+ }
+
+ case T_MAP: {
+ int32_t len;
+ int i;
+ MapTypeArgs parsedargs;
+ PyObject* ret = NULL;
+
+ if (!parse_map_args(&parsedargs, typeargs)) {
+ return NULL;
+ }
+
+ if (!checkTypeByte(input, parsedargs.ktag)) {
+ return NULL;
+ }
+ if (!checkTypeByte(input, parsedargs.vtag)) {
+ return NULL;
+ }
+
+ len = readI32(input);
+ if (!check_ssize_t_32(len)) {
+ return false;
+ }
+
+ ret = PyDict_New();
+ if (!ret) {
+ goto error;
+ }
+
+ for (i = 0; i < len; i++) {
+ PyObject* k = NULL;
+ PyObject* v = NULL;
+ k = decode_val(input, parsedargs.ktag, parsedargs.ktypeargs);
+ if (k == NULL) {
+ goto loop_error;
+ }
+ v = decode_val(input, parsedargs.vtag, parsedargs.vtypeargs);
+ if (v == NULL) {
+ goto loop_error;
+ }
+ if (PyDict_SetItem(ret, k, v) == -1) {
+ goto loop_error;
+ }
+
+ Py_DECREF(k);
+ Py_DECREF(v);
+ continue;
+
+ // Yuck! Destructors, anyone?
+ loop_error:
+ Py_XDECREF(k);
+ Py_XDECREF(v);
+ goto error;
+ }
+
+ return ret;
+
+ error:
+ Py_XDECREF(ret);
+ return NULL;
+ }
+
+ case T_STRUCT: {
+ StructTypeArgs parsedargs;
+ if (!parse_struct_args(&parsedargs, typeargs)) {
+ return NULL;
+ }
+
+ PyObject* ret = PyObject_CallObject(parsedargs.klass, NULL);
+ if (!ret) {
+ return NULL;
+ }
+
+ if (!decode_struct(input, ret, parsedargs.spec)) {
+ Py_DECREF(ret);
+ return NULL;
+ }
+
+ return ret;
+ }
+
+ case T_STOP:
+ case T_VOID:
+ case T_UTF16:
+ case T_UTF8:
+ case T_U64:
+ default:
+ PyErr_SetString(PyExc_TypeError, "Unexpected TType");
+ return NULL;
+ }
+}
+
+
+/* --- TOP-LEVEL WRAPPER FOR INPUT -- */
+
+static PyObject*
+decode_binary(PyObject *self, PyObject *args) {
+ PyObject* output_obj = NULL;
+ PyObject* transport = NULL;
+ PyObject* typeargs = NULL;
+ StructTypeArgs parsedargs;
+ DecodeBuffer input = {};
+
+ if (!PyArg_ParseTuple(args, "OOO", &output_obj, &transport, &typeargs)) {
+ return NULL;
+ }
+
+ if (!parse_struct_args(&parsedargs, typeargs)) {
+ return NULL;
+ }
+
+ if (!decode_buffer_from_obj(&input, transport)) {
+ return NULL;
+ }
+
+ if (!decode_struct(&input, output_obj, parsedargs.spec)) {
+ free_decodebuf(&input);
+ return NULL;
+ }
+
+ free_decodebuf(&input);
+
+ Py_RETURN_NONE;
+}
+
+/* ====== END READING FUNCTIONS ====== */
+
+
+/* -- PYTHON MODULE SETUP STUFF --- */
+
+static PyMethodDef ThriftFastBinaryMethods[] = {
+
+ {"encode_binary", encode_binary, METH_VARARGS, ""},
+ {"decode_binary", decode_binary, METH_VARARGS, ""},
+
+ {NULL, NULL, 0, NULL} /* Sentinel */
+};
+
+PyMODINIT_FUNC
+initfastbinary(void) {
+#define INIT_INTERN_STRING(value) \
+ do { \
+ INTERN_STRING(value) = PyString_InternFromString(#value); \
+ if(!INTERN_STRING(value)) return; \
+ } while(0)
+
+ INIT_INTERN_STRING(cstringio_buf);
+ INIT_INTERN_STRING(cstringio_refill);
+#undef INIT_INTERN_STRING
+
+ PycString_IMPORT;
+ if (PycStringIO == NULL) return;
+
+ (void) Py_InitModule("thrift.protocol.fastbinary", ThriftFastBinaryMethods);
+}
diff --git a/lib/py/src/transport/TTransport.py b/lib/py/src/transport/TTransport.py
index 3c18221..0f5bfdc 100644
--- a/lib/py/src/transport/TTransport.py
+++ b/lib/py/src/transport/TTransport.py
@@ -55,6 +55,34 @@
def flush(self):
pass
+# This class should be thought of as an interface.
+class CReadableTransport:
+ """base class for transports that are readable from C"""
+
+ # TODO(dreiss): Think about changing this interface to allow us to use
+ # a (Python, not c) StringIO instead, because it allows
+ # you to write after reading.
+
+ # NOTE: This is a classic class, so properties will NOT work
+ # correctly for setting.
+ @property
+ def cstringio_buf(self):
+ """A cStringIO buffer that contains the current chunk we are reading."""
+ pass
+
+ def cstringio_refill(self, partialread, reqlen):
+ """Refills cstringio_buf.
+
+ Returns the currently used buffer (which can but need not be the same as
+ the old cstringio_buf). partialread is what the C code has read from the
+ buffer, and should be inserted into the buffer before any more reads. The
+ return value must be a new, not borrowed reference. Something along the
+ lines of self._buf should be fine.
+
+ If reqlen bytes can't be read, throw EOFError.
+ """
+ pass
+
class TServerTransportBase:
"""Base class for Thrift server transports."""
@@ -112,8 +140,14 @@
self.__trans.flush()
self.__buf = StringIO()
-class TMemoryBuffer(TTransportBase):
- """Wraps a string object as a TTransport"""
+class TMemoryBuffer(TTransportBase, CReadableTransport):
+ """Wraps a cStringIO object as a TTransport.
+
+ NOTE: Unlike the C++ version of this class, you cannot write to it
+ then immediately read from it. If you want to read from a
+ TMemoryBuffer, you must either pass a string to the constructor.
+ TODO(dreiss): Make this work like the C++ version.
+ """
def __init__(self, value=None):
"""value -- a value to read from for stringio
@@ -146,6 +180,15 @@
def getvalue(self):
return self._buffer.getvalue()
+ # Implement the CReadableTransport interface.
+ @property
+ def cstringio_buf(self):
+ return self._buffer
+
+ def cstringio_refill(self, partialread, reqlen):
+ # only one shot at reading...
+ raise EOFException()
+
class TFramedTransportFactory:
"""Factory transport that builds framed transports"""
@@ -193,7 +236,7 @@
buff = self.__trans.readAll(4)
sz, = unpack('!i', buff)
self.__rbuf = self.__trans.readAll(sz)
-
+
def write(self, buf):
if self.__wbuf == None:
return self.__trans.write(buf)
diff --git a/test/DebugProtoTest.thrift b/test/DebugProtoTest.thrift
index ac3b9b4..bbd86df 100644
--- a/test/DebugProtoTest.thrift
+++ b/test/DebugProtoTest.thrift
@@ -36,3 +36,30 @@
2: set<list<string>> contain,
3: map<string,list<Bonk>> bonks,
}
+
+struct Backwards {
+ 2: i32 first_tag2,
+ 1: i32 second_tag1,
+}
+
+struct Empty {
+}
+
+struct Wrapper {
+ 1: Empty foo
+}
+
+struct RandomStuff {
+ 1: i32 a,
+ 2: i32 b,
+ 3: i32 c,
+ 4: i32 d,
+ 5: list<i32> myintlist,
+ 6: map<i32,Wrapper> maps,
+ 7: i64 bigint,
+ 8: double triple,
+}
+
+service Srv {
+ i32 Janky(i32 arg)
+}
diff --git a/test/FastbinaryTest.py b/test/FastbinaryTest.py
new file mode 100755
index 0000000..0918002
--- /dev/null
+++ b/test/FastbinaryTest.py
@@ -0,0 +1,190 @@
+#!/usr/bin/env python
+r"""
+thrift -py DebugProtoTest.thrift
+./FastbinaryTest.py
+"""
+
+# TODO(dreiss): Test error cases. Check for memory leaks.
+
+import sys
+sys.path.append('./gen-py')
+
+import math
+from DebugProtoTest import Srv
+from DebugProtoTest.ttypes import *
+from thrift.transport import TTransport
+from thrift.protocol import TBinaryProtocol
+
+import timeit
+from cStringIO import StringIO
+from copy import deepcopy
+from pprint import pprint
+
+class TDevNullTransport(TTransport.TTransportBase):
+ def __init__(self):
+ pass
+ def isOpen(self):
+ return True
+
+ooe1 = OneOfEach()
+ooe1.im_true = True;
+ooe1.im_false = False;
+ooe1.a_bite = 0xd6;
+ooe1.integer16 = 27000;
+ooe1.integer32 = 1<<24;
+ooe1.integer64 = 6000 * 1000 * 1000;
+ooe1.double_precision = math.pi;
+ooe1.some_characters = "Debug THIS!";
+ooe1.zomg_unicode = "\xd7\n\a\t";
+
+ooe2 = OneOfEach();
+ooe2.integer16 = 16;
+ooe2.integer32 = 32;
+ooe2.integer64 = 64;
+ooe2.double_precision = (math.sqrt(5)+1)/2;
+ooe2.some_characters = ":R (me going \"rrrr\")";
+ooe2.zomg_unicode = "\xd3\x80\xe2\x85\xae\xce\x9d\x20"\
+ "\xd0\x9d\xce\xbf\xe2\x85\xbf\xd0\xbe"\
+ "\xc9\xa1\xd0\xb3\xd0\xb0\xcf\x81\xe2\x84\x8e"\
+ "\x20\xce\x91\x74\x74\xce\xb1\xe2\x85\xbd\xce\xba"\
+ "\xc7\x83\xe2\x80\xbc";
+
+hm = HolyMoley({"big":[], "contain":set(), "bonks":{}})
+hm.big.append(ooe1)
+hm.big.append(ooe2)
+hm.big[0].a_bite = 0x22;
+hm.big[1].a_bite = 0x22;
+
+hm.contain.add(("and a one", "and a two"))
+hm.contain.add(("then a one, two", "three!", "FOUR!"))
+hm.contain.add(())
+
+hm.bonks["nothing"] = [];
+hm.bonks["something"] = [
+ Bonk({"type":1, "message":"Wait."}),
+ Bonk({"type":2, "message":"What?"}),
+]
+hm.bonks["poe"] = [
+ Bonk({"type":3, "message":"quoth"}),
+ Bonk({"type":4, "message":"the raven"}),
+ Bonk({"type":5, "message":"nevermore"}),
+]
+
+rs = RandomStuff()
+rs.a = 1
+rs.b = 2
+rs.c = 3
+rs.myintlist = range(20)
+rs.maps = {1:Wrapper({"foo":Empty()}),2:Wrapper({"foo":Empty()})}
+rs.bigint = 124523452435L
+rs.triple = 3.14
+
+my_zero = Srv.Janky_result({"arg":5})
+my_nega = Srv.Janky_args({"success":6})
+
+def checkWrite(o):
+ trans_fast = TTransport.TMemoryBuffer()
+ trans_slow = TTransport.TMemoryBuffer()
+ prot_fast = TBinaryProtocol.TBinaryProtocolAccelerated(trans_fast)
+ prot_slow = TBinaryProtocol.TBinaryProtocol(trans_slow)
+
+ o.write(prot_fast)
+ o.write(prot_slow)
+ ORIG = trans_slow.getvalue()
+ MINE = trans_fast.getvalue()
+ if ORIG != MINE:
+ print "mine: %s\norig: %s" % (repr(MINE), repr(ORIG))
+
+def checkRead(o):
+ prot = TBinaryProtocol.TBinaryProtocol(TTransport.TMemoryBuffer())
+ o.write(prot)
+ prot = TBinaryProtocol.TBinaryProtocolAccelerated(
+ TTransport.TMemoryBuffer(
+ prot.trans.getvalue()))
+ c = o.__class__()
+ c.read(prot)
+ if c != o:
+ print "copy: "
+ pprint(eval(repr(c)))
+ print "orig: "
+ pprint(eval(repr(o)))
+
+
+def doTest():
+ checkWrite(hm)
+ no_set = deepcopy(hm)
+ no_set.contain = set()
+ checkRead(no_set)
+ checkWrite(rs)
+ checkRead(rs)
+ checkWrite(my_zero)
+ checkRead(my_zero)
+ checkRead(Backwards({"first_tag2":4, "second_tag1":2}))
+ try:
+ checkWrite(my_nega)
+ print "Hey, did this get fixed?"
+ except AttributeError:
+ # Sorry, doesn't work with negative tags.
+ pass
+
+ # One case where the serialized form changes, but only superficially.
+ o = Backwards({"first_tag2":4, "second_tag1":2})
+ trans_fast = TTransport.TMemoryBuffer()
+ trans_slow = TTransport.TMemoryBuffer()
+ prot_fast = TBinaryProtocol.TBinaryProtocolAccelerated(trans_fast)
+ prot_slow = TBinaryProtocol.TBinaryProtocol(trans_slow)
+
+ o.write(prot_fast)
+ o.write(prot_slow)
+ ORIG = trans_slow.getvalue()
+ MINE = trans_fast.getvalue()
+ if ORIG == MINE:
+ print "That shouldn't happen."
+
+
+ prot = TBinaryProtocol.TBinaryProtocolAccelerated(TTransport.TMemoryBuffer())
+ o.write(prot)
+ prot = TBinaryProtocol.TBinaryProtocol(
+ TTransport.TMemoryBuffer(
+ prot.trans.getvalue()))
+ c = o.__class__()
+ c.read(prot)
+ if c != o:
+ print "copy: "
+ pprint(eval(repr(c)))
+ print "orig: "
+ pprint(eval(repr(o)))
+
+
+
+def doBenchmark():
+
+ iters = 25000
+
+ setup = """
+from __main__ import hm, rs, TDevNullTransport
+from thrift.protocol import TBinaryProtocol
+trans = TDevNullTransport()
+prot = TBinaryProtocol.TBinaryProtocol%s(trans)
+"""
+
+ setup_fast = setup % "Accelerated"
+ setup_slow = setup % ""
+
+ print "Starting Benchmarks"
+
+ print "HolyMoley Standard = %f" % \
+ timeit.Timer('hm.write(prot)', setup_slow).timeit(number=iters)
+ print "HolyMoley Acceler. = %f" % \
+ timeit.Timer('hm.write(prot)', setup_fast).timeit(number=iters)
+
+ print "FastStruct Standard = %f" % \
+ timeit.Timer('rs.write(prot)', setup_slow).timeit(number=iters)
+ print "FastStruct Acceler. = %f" % \
+ timeit.Timer('rs.write(prot)', setup_fast).timeit(number=iters)
+
+
+
+doTest()
+doBenchmark()
+