THRIFT-1115 python TBase class for dynamic (de)serialization, and __slots__ option for memory savings
Patch: Will Pierce
git-svn-id: https://svn.apache.org/repos/asf/thrift/trunk@1169492 13f79535-47bb-0310-9956-ffa450edef68
diff --git a/.gitignore b/.gitignore
index bfb61d2..082f7b7 100644
--- a/.gitignore
+++ b/.gitignore
@@ -179,6 +179,7 @@
/test/py/Makefile
/test/py/Makefile.in
/test/py/gen-py
+/test/py/gen-py-*
/test/py.twisted/Makefile
/test/py.twisted/Makefile.in
/test/py.twisted/_trial_temp/
diff --git a/compiler/cpp/src/generate/t_py_generator.cc b/compiler/cpp/src/generate/t_py_generator.cc
index 6a82bd7..34acba4 100644
--- a/compiler/cpp/src/generate/t_py_generator.cc
+++ b/compiler/cpp/src/generate/t_py_generator.cc
@@ -52,12 +52,44 @@
iter = parsed_options.find("new_style");
gen_newstyle_ = (iter != parsed_options.end());
+ iter = parsed_options.find("slots");
+ gen_slots_ = (iter != parsed_options.end());
+
+ iter = parsed_options.find("dynamic");
+ gen_dynamic_ = (iter != parsed_options.end());
+
+ if (gen_dynamic_) {
+ gen_newstyle_ = 0; // dynamic is newstyle
+ gen_dynbaseclass_ = "TBase";
+ gen_dynbaseclass_exc_ = "TExceptionBase";
+ import_dynbase_ = "from thrift.protocol.TBase import TBase, TExceptionBase\n";
+ }
+
+ iter = parsed_options.find("dynbase");
+ if (iter != parsed_options.end()) {
+ gen_dynbase_ = true;
+ gen_dynbaseclass_ = (iter->second);
+ }
+
+ iter = parsed_options.find("dynexc");
+ if (iter != parsed_options.end()) {
+ gen_dynbaseclass_exc_ = (iter->second);
+ }
+
+ iter = parsed_options.find("dynimport");
+ if (iter != parsed_options.end()) {
+ gen_dynbase_ = true;
+ import_dynbase_ = (iter->second);
+ }
+
iter = parsed_options.find("twisted");
gen_twisted_ = (iter != parsed_options.end());
iter = parsed_options.find("utf8strings");
gen_utf8strings_ = (iter != parsed_options.end());
+ copy_options_ = option_string;
+
if (gen_twisted_){
out_dir_base_ = "gen-py.twisted";
} else {
@@ -214,17 +246,32 @@
private:
/**
- * True iff we should generate new-style classes.
+ * True if we should generate new-style classes.
*/
bool gen_newstyle_;
+ /**
+ * True if we should generate dynamic style classes.
+ */
+ bool gen_dynamic_;
+
+ bool gen_dynbase_;
+ std::string gen_dynbaseclass_;
+ std::string gen_dynbaseclass_exc_;
+
+ std::string import_dynbase_;
+
+ bool gen_slots_;
+
+ std::string copy_options_;
+
/**
- * True iff we should generate Twisted-friendly RPC services.
+ * True if we should generate Twisted-friendly RPC services.
*/
bool gen_twisted_;
/**
- * True iff strings should be encoded using utf-8.
+ * True if strings should be encoded using utf-8.
*/
bool gen_utf8strings_;
@@ -325,13 +372,19 @@
* Renders all the imports necessary to use the accelerated TBinaryProtocol
*/
string t_py_generator::render_fastbinary_includes() {
- return
- "from thrift.transport import TTransport\n"
- "from thrift.protocol import TBinaryProtocol, TProtocol\n"
- "try:\n"
- " from thrift.protocol import fastbinary\n"
- "except:\n"
- " fastbinary = None\n";
+ string hdr = "";
+ if (gen_dynamic_) {
+ hdr += std::string(import_dynbase_);
+ } else {
+ hdr +=
+ "from thrift.transport import TTransport\n"
+ "from thrift.protocol import TBinaryProtocol, TProtocol\n"
+ "try:\n"
+ " from thrift.protocol import fastbinary\n"
+ "except:\n"
+ " fastbinary = None\n";
+ }
+ return hdr;
}
/**
@@ -343,6 +396,8 @@
"# Autogenerated by Thrift Compiler (" + THRIFT_VERSION + ")\n" +
"#\n" +
"# DO NOT EDIT UNLESS YOU ARE SURE THAT YOU KNOW WHAT YOU ARE DOING\n" +
+ "#\n" +
+ "# options string: " + copy_options_ + "\n" +
"#\n";
}
@@ -351,7 +406,7 @@
*/
string t_py_generator::py_imports() {
return
- string("from thrift.Thrift import *");
+ string("from thrift.Thrift import TType, TMessageType");
}
/**
@@ -384,6 +439,7 @@
f_types_ <<
"class " << tenum->get_name() <<
(gen_newstyle_ ? "(object)" : "") <<
+ (gen_dynamic_ ? "(" + gen_dynbaseclass_ + ")" : "") <<
":" << endl;
indent_up();
generate_python_docstring(f_types_, tenum);
@@ -575,12 +631,19 @@
out << std::endl <<
"class " << tstruct->get_name();
if (is_exception) {
- out << "(Exception)";
- } else if (gen_newstyle_) {
- out << "(object)";
+ if (gen_dynamic_) {
+ out << "(" << gen_dynbaseclass_exc_ << ")";
+ } else {
+ out << "(Exception)";
+ }
+ } else {
+ if (gen_newstyle_) {
+ out << "(object)";
+ } else if (gen_dynamic_) {
+ out << "(" << gen_dynbaseclass_ << ")";
+ }
}
- out <<
- ":" << endl;
+ out << ":" << endl;
indent_up();
generate_python_docstring(out, tstruct);
@@ -606,6 +669,17 @@
TODO(dreiss): Consider making this work for structs with negative tags.
*/
+ if (gen_slots_) {
+ indent(out) << "__slots__ = [ " << endl;
+ indent_up();
+ for (m_iter = sorted_members.begin(); m_iter != sorted_members.end(); ++m_iter) {
+ indent(out) << "'" << (*m_iter)->get_name() << "'," << endl;
+ }
+ indent_down();
+ indent(out) << " ]" << endl << endl;
+
+ }
+
// TODO(dreiss): Look into generating an empty tuple instead of None
// for structures with no members.
// TODO(dreiss): Test encoding of structs where some inner structs
@@ -672,8 +746,10 @@
out << endl;
}
- generate_py_struct_reader(out, tstruct);
- generate_py_struct_writer(out, tstruct);
+ if (!gen_dynamic_) {
+ generate_py_struct_reader(out, tstruct);
+ generate_py_struct_writer(out, tstruct);
+ }
// For exceptions only, generate a __str__ method. This is
// because when raised exceptions are printed to the console, __repr__
@@ -685,31 +761,61 @@
endl;
}
- // Printing utilities so that on the command line thrift
- // structs look pretty like dictionaries
- out <<
- indent() << "def __repr__(self):" << endl <<
- indent() << " L = ['%s=%r' % (key, value)" << endl <<
- indent() << " for key, value in self.__dict__.iteritems()]" << endl <<
- indent() << " return '%s(%s)' % (self.__class__.__name__, ', '.join(L))" << endl <<
- endl;
+ if (!gen_slots_) {
+ // Printing utilities so that on the command line thrift
+ // structs look pretty like dictionaries
+ out <<
+ indent() << "def __repr__(self):" << endl <<
+ indent() << " L = ['%s=%r' % (key, value)" << endl <<
+ indent() << " for key, value in self.__dict__.iteritems()]" << endl <<
+ indent() << " return '%s(%s)' % (self.__class__.__name__, ', '.join(L))" << endl <<
+ endl;
- // Equality and inequality methods that compare by value
- out <<
- indent() << "def __eq__(self, other):" << endl;
- indent_up();
- out <<
- indent() << "return isinstance(other, self.__class__) and "
- "self.__dict__ == other.__dict__" << endl;
- indent_down();
- out << endl;
+ // Equality and inequality methods that compare by value
+ out <<
+ indent() << "def __eq__(self, other):" << endl;
+ indent_up();
+ out <<
+ indent() << "return isinstance(other, self.__class__) and "
+ "self.__dict__ == other.__dict__" << endl;
+ indent_down();
+ out << endl;
- out <<
- indent() << "def __ne__(self, other):" << endl;
- indent_up();
- out <<
- indent() << "return not (self == other)" << endl;
- indent_down();
+ out <<
+ indent() << "def __ne__(self, other):" << endl;
+ indent_up();
+
+ out <<
+ indent() << "return not (self == other)" << endl;
+ indent_down();
+ } else if (!gen_dynamic_) {
+ // no base class available to implement __eq__ and __repr__ and __ne__ for us
+ // so we must provide one that uses __slots__
+ out <<
+ indent() << "def __repr__(self):" << endl <<
+ indent() << " L = ['%s=%r' % (key, getattr(self, key))" << endl <<
+ indent() << " for key in self.__slots__]" << endl <<
+ indent() << " return '%s(%s)' % (self.__class__.__name__, ', '.join(L))" << endl <<
+ endl;
+
+ // Equality method that compares each attribute by value and type, walking __slots__
+ out <<
+ indent() << "def __eq__(self, other):" << endl <<
+ indent() << " if not isinstance(other, self.__class__):" << endl <<
+ indent() << " return False" << endl <<
+ indent() << " for attr in self.__slots__:" << endl <<
+ indent() << " my_val = getattr(self, attr)" << endl <<
+ indent() << " other_val = getattr(other, attr)" << endl <<
+ indent() << " if my_val != other_val:" << endl <<
+ indent() << " return False" << endl <<
+ indent() << " return True" << endl <<
+ endl;
+
+ out <<
+ indent() << "def __ne__(self, other):" << endl <<
+ indent() << " return not (self == other)" << endl <<
+ endl;
+ }
indent_down();
}
@@ -984,7 +1090,7 @@
} else {
if (gen_twisted_) {
extends_if = "(Interface)";
- } else if (gen_newstyle_) {
+ } else if (gen_newstyle_ || gen_dynamic_) {
extends_if = "(object)";
}
}
@@ -1031,8 +1137,8 @@
extends_client = extends + ".Client, ";
}
} else {
- if (gen_twisted_ && gen_newstyle_) {
- extends_client = "(object)";
+ if (gen_twisted_ && (gen_newstyle_ || gen_dynamic_)) {
+ extends_client = "(object)";
}
}
@@ -2388,6 +2494,11 @@
THRIFT_REGISTER_GENERATOR(py, "Python",
" new_style: Generate new-style classes.\n" \
" twisted: Generate Twisted-friendly RPC services.\n" \
-" utf8strings: Encode/decode strings using utf8 in the generated code.\n"
-)
+" utf8strings: Encode/decode strings using utf8 in the generated code.\n" \
+" slots: Generate code using slots for instance members.\n" \
+" dynamic: Generate dynamic code, less code generated but slower.\n" \
+" dynbase=CLS Derive generated classes from class CLS instead of TBase.\n" \
+" dynexc=CLS Derive generated exceptions from CLS instead of TExceptionBase.\n" \
+" dynimport='from foo.bar import CLS'\n" \
+" Add an import line to generated code to find the dynbase class.\n")
diff --git a/lib/py/src/Thrift.py b/lib/py/src/Thrift.py
index af6f58d..1d271fc 100644
--- a/lib/py/src/Thrift.py
+++ b/lib/py/src/Thrift.py
@@ -38,6 +38,25 @@
UTF8 = 16
UTF16 = 17
+ _VALUES_TO_NAMES = ( 'STOP',
+ 'VOID',
+ 'BOOL',
+ 'BYTE',
+ 'DOUBLE',
+ None,
+ 'I16',
+ None,
+ 'I32',
+ None,
+ 'I64',
+ 'STRING',
+ 'STRUCT',
+ 'MAP',
+ 'SET',
+ 'LIST',
+ 'UTF8',
+ 'UTF16' )
+
class TMessageType:
CALL = 1
REPLY = 2
diff --git a/lib/py/src/protocol/TBase.py b/lib/py/src/protocol/TBase.py
new file mode 100644
index 0000000..e675c7d
--- /dev/null
+++ b/lib/py/src/protocol/TBase.py
@@ -0,0 +1,72 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+from thrift.Thrift import *
+from thrift.protocol import TBinaryProtocol
+from thrift.transport import TTransport
+
+try:
+ from thrift.protocol import fastbinary
+except:
+ fastbinary = None
+
+class TBase(object):
+ __slots__ = []
+
+ def __repr__(self):
+ L = ['%s=%r' % (key, getattr(self, key))
+ for key in self.__slots__ ]
+ return '%s(%s)' % (self.__class__.__name__, ', '.join(L))
+
+ def __eq__(self, other):
+ if not isinstance(other, self.__class__):
+ return False
+ for attr in self.__slots__:
+ my_val = getattr(self, attr)
+ other_val = getattr(other, attr)
+ if my_val != other_val:
+ return False
+ return True
+
+ def __ne__(self, other):
+ return not (self == other)
+
+ def read(self, iprot):
+ if iprot.__class__ == TBinaryProtocol.TBinaryProtocolAccelerated and isinstance(iprot.trans, TTransport.CReadableTransport) and self.thrift_spec is not None and fastbinary is not None:
+ fastbinary.decode_binary(self, iprot.trans, (self.__class__, self.thrift_spec))
+ return
+ iprot.readStruct(self, self.thrift_spec)
+
+ def write(self, oprot):
+ if oprot.__class__ == TBinaryProtocol.TBinaryProtocolAccelerated and self.thrift_spec is not None and fastbinary is not None:
+ oprot.trans.write(fastbinary.encode_binary(self, (self.__class__, self.thrift_spec)))
+ return
+ oprot.writeStruct(self, self.thrift_spec)
+
+class TExceptionBase(Exception):
+ # old style class so python2.4 can raise exceptions derived from this
+ # This can't inherit from TBase because of that limitation.
+ __slots__ = []
+
+ __repr__ = TBase.__repr__.im_func
+ __eq__ = TBase.__eq__.im_func
+ __ne__ = TBase.__ne__.im_func
+ read = TBase.read.im_func
+ write = TBase.write.im_func
+
diff --git a/lib/py/src/protocol/TCompactProtocol.py b/lib/py/src/protocol/TCompactProtocol.py
index 6d57aeb..016a331 100644
--- a/lib/py/src/protocol/TCompactProtocol.py
+++ b/lib/py/src/protocol/TCompactProtocol.py
@@ -1,3 +1,22 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
from TProtocol import *
from struct import pack, unpack
diff --git a/lib/py/src/protocol/TProtocol.py b/lib/py/src/protocol/TProtocol.py
index be3cb14..7338ff6 100644
--- a/lib/py/src/protocol/TProtocol.py
+++ b/lib/py/src/protocol/TProtocol.py
@@ -200,6 +200,205 @@
self.skip(etype)
self.readListEnd()
+ # tuple of: ( 'reader method' name, is_container boolean, 'writer_method' name )
+ _TTYPE_HANDLERS = (
+ (None, None, False), # 0 == TType,STOP
+ (None, None, False), # 1 == TType.VOID # TODO: handle void?
+ ('readBool', 'writeBool', False), # 2 == TType.BOOL
+ ('readByte', 'writeByte', False), # 3 == TType.BYTE and I08
+ ('readDouble', 'writeDouble', False), # 4 == TType.DOUBLE
+ (None, None, False), # 5, undefined
+ ('readI16', 'writeI16', False), # 6 == TType.I16
+ (None, None, False), # 7, undefined
+ ('readI32', 'writeI32', False), # 8 == TType.I32
+ (None, None, False), # 9, undefined
+ ('readI64', 'writeI64', False), # 10 == TType.I64
+ ('readString', 'writeString', False), # 11 == TType.STRING and UTF7
+ ('readContainerStruct', 'writeContainerStruct', True), # 12 == TType.STRUCT
+ ('readContainerMap', 'writeContainerMap', True), # 13 == TType.MAP
+ ('readContainerSet', 'writeContainerSet', True), # 14 == TType.SET
+ ('readContainerList', 'writeContainerList', True), # 15 == TType.LIST
+ (None, None, False), # 16 == TType.UTF8 # TODO: handle utf8 types?
+ (None, None, False)# 17 == TType.UTF16 # TODO: handle utf16 types?
+ )
+
+ def readFieldByTType(self, ttype, spec):
+ try:
+ (r_handler, w_handler, is_container) = self._TTYPE_HANDLERS[ttype]
+ except IndexError:
+ raise TProtocolException(type=TProtocolException.INVALID_DATA,
+ message='Invalid field type %d' % (ttype))
+ if r_handler is None:
+ raise TProtocolException(type=TProtocolException.INVALID_DATA,
+ message='Invalid field type %d' % (ttype))
+ reader = getattr(self, r_handler)
+ if not is_container:
+ return reader()
+ return reader(spec)
+
+ def readContainerList(self, spec):
+ results = []
+ ttype, tspec = spec[0], spec[1]
+ r_handler = self._TTYPE_HANDLERS[ttype][0]
+ reader = getattr(self, r_handler)
+ (list_type, list_len) = self.readListBegin()
+ if tspec is None:
+ # list values are simple types
+ for idx in xrange(list_len):
+ results.append(reader())
+ else:
+ # this is like an inlined readFieldByTType
+ container_reader = self._TTYPE_HANDLERS[list_type][0]
+ val_reader = getattr(self, container_reader)
+ for idx in xrange(list_len):
+ val = val_reader(tspec)
+ results.append(val)
+ self.readListEnd()
+ return results
+
+ def readContainerSet(self, spec):
+ results = set()
+ ttype, tspec = spec[0], spec[1]
+ r_handler = self._TTYPE_HANDLERS[ttype][0]
+ reader = getattr(self, r_handler)
+ (set_type, set_len) = self.readSetBegin()
+ if tspec is None:
+ # set members are simple types
+ for idx in xrange(set_len):
+ results.add(reader())
+ else:
+ container_reader = self._TTYPE_HANDLERS[set_type][0]
+ val_reader = getattr(self, container_reader)
+ for idx in xrange(set_len):
+ results.add(val_reader(tspec))
+ self.readSetEnd()
+ return results
+
+ def readContainerStruct(self, spec):
+ (obj_class, obj_spec) = spec
+ obj = obj_class()
+ obj.read(self)
+ return obj
+
+ def readContainerMap(self, spec):
+ results = dict()
+ key_ttype, key_spec = spec[0], spec[1]
+ val_ttype, val_spec = spec[2], spec[3]
+ (map_ktype, map_vtype, map_len) = self.readMapBegin()
+ # TODO: compare types we just decoded with thrift_spec and abort/skip if types disagree
+ key_reader = getattr(self, self._TTYPE_HANDLERS[key_ttype][0])
+ val_reader = getattr(self, self._TTYPE_HANDLERS[val_ttype][0])
+ # list values are simple types
+ for idx in xrange(map_len):
+ if key_spec is None:
+ k_val = key_reader()
+ else:
+ k_val = self.readFieldByTType(key_ttype, key_spec)
+ if val_spec is None:
+ v_val = val_reader()
+ else:
+ v_val = self.readFieldByTType(val_ttype, val_spec)
+ # this raises a TypeError with unhashable keys types. i.e. d=dict(); d[[0,1]] = 2 fails
+ results[k_val] = v_val
+ self.readMapEnd()
+ return results
+
+ def readStruct(self, obj, thrift_spec):
+ self.readStructBegin()
+ while True:
+ (fname, ftype, fid) = self.readFieldBegin()
+ if ftype == TType.STOP:
+ break
+ try:
+ field = thrift_spec[fid]
+ except IndexError:
+ self.skip(ftype)
+ else:
+ if field is not None and ftype == field[1]:
+ fname = field[2]
+ fspec = field[3]
+ val = self.readFieldByTType(ftype, fspec)
+ setattr(obj, fname, val)
+ else:
+ self.skip(ftype)
+ self.readFieldEnd()
+ self.readStructEnd()
+
+ def writeContainerStruct(self, val, spec):
+ val.write(self)
+
+ def writeContainerList(self, val, spec):
+ self.writeListBegin(spec[0], len(val))
+ r_handler, w_handler, is_container = self._TTYPE_HANDLERS[spec[0]]
+ e_writer = getattr(self, w_handler)
+ if not is_container:
+ for elem in val:
+ e_writer(elem)
+ else:
+ for elem in val:
+ e_writer(elem, spec[1])
+ self.writeListEnd()
+
+ def writeContainerSet(self, val, spec):
+ self.writeSetBegin(spec[0], len(val))
+ r_handler, w_handler, is_container = self._TTYPE_HANDLERS[spec[0]]
+ e_writer = getattr(self, w_handler)
+ if not is_container:
+ for elem in val:
+ e_writer(elem)
+ else:
+ for elem in val:
+ e_writer(elem, spec[1])
+ self.writeSetEnd()
+
+ def writeContainerMap(self, val, spec):
+ k_type = spec[0]
+ v_type = spec[2]
+ ignore, ktype_name, k_is_container = self._TTYPE_HANDLERS[k_type]
+ ignore, vtype_name, v_is_container = self._TTYPE_HANDLERS[v_type]
+ k_writer = getattr(self, ktype_name)
+ v_writer = getattr(self, vtype_name)
+ self.writeMapBegin(k_type, v_type, len(val))
+ for m_key, m_val in val.iteritems():
+ if not k_is_container:
+ k_writer(m_key)
+ else:
+ k_writer(m_key, spec[1])
+ if not v_is_container:
+ v_writer(m_val)
+ else:
+ v_writer(m_val, spec[3])
+ self.writeMapEnd()
+
+ def writeStruct(self, obj, thrift_spec):
+ self.writeStructBegin(obj.__class__.__name__)
+ for field in thrift_spec:
+ if field is None:
+ continue
+ fname = field[2]
+ val = getattr(obj, fname)
+ if val is None:
+ # skip writing out unset fields
+ continue
+ fid = field[0]
+ ftype = field[1]
+ fspec = field[3]
+ # get the writer method for this value
+ self.writeFieldBegin(fname, ftype, fid)
+ self.writeFieldByTType(ftype, val, fspec)
+ self.writeFieldEnd()
+ self.writeFieldStop()
+ self.writeStructEnd()
+
+ def writeFieldByTType(self, ttype, val, spec):
+ r_handler, w_handler, is_container = self._TTYPE_HANDLERS[ttype]
+ writer = getattr(self, w_handler)
+ if is_container:
+ writer(val, spec)
+ else:
+ writer(val)
+
class TProtocolFactory:
def getProtocol(self, trans):
pass
+
diff --git a/lib/py/src/protocol/__init__.py b/lib/py/src/protocol/__init__.py
index 01bfe18..d53359b 100644
--- a/lib/py/src/protocol/__init__.py
+++ b/lib/py/src/protocol/__init__.py
@@ -17,4 +17,4 @@
# under the License.
#
-__all__ = ['TProtocol', 'TBinaryProtocol', 'fastbinary']
+__all__ = ['TProtocol', 'TBinaryProtocol', 'fastbinary', 'TBase']
diff --git a/test/ThriftTest.thrift b/test/ThriftTest.thrift
index b6cd939..17b0295 100644
--- a/test/ThriftTest.thrift
+++ b/test/ThriftTest.thrift
@@ -208,3 +208,21 @@
3000: VersioningTestV2 vertwo3000,
4000: list<i32> big_numbers
}
+
+struct NestedListsI32x2 {
+ 1: list<list<i32>> integerlist
+}
+struct NestedListsI32x3 {
+ 1: list<list<list<i32>>> integerlist
+}
+struct NestedMixedx2 {
+ 1: list<set<i32>> int_set_list
+ 2: map<i32,set<string>> map_int_strset
+ 3: list<map<i32,set<string>>> map_int_strset_list
+}
+struct ListBonks {
+ 1: list<Bonk> bonk
+}
+struct NestedListsBonk {
+ 1: list<list<list<Bonk>>> bonk
+}
diff --git a/test/py/Makefile.am b/test/py/Makefile.am
index 63b7a89..2317ef6 100644
--- a/test/py/Makefile.am
+++ b/test/py/Makefile.am
@@ -19,22 +19,30 @@
THRIFT = $(top_srcdir)/compiler/cpp/thrift
-py_unit_tests = \
- SerializationTest.py \
- TestEof.py \
- TestSyntax.py \
- RunClientServer.py
+py_unit_tests = RunClientServer.py
thrift_gen = \
gen-py/ThriftTest/__init__.py \
- gen-py/DebugProtoTest/__init__.py
+ gen-py/DebugProtoTest/__init__.py \
+ gen-py-default/ThriftTest/__init__.py \
+ gen-py-default/DebugProtoTest/__init__.py \
+ gen-py-slots/ThriftTest/__init__.py \
+ gen-py-slots/DebugProtoTest/__init__.py \
+ gen-py-newstyle/ThriftTest/__init__.py \
+ gen-py-newstyle/DebugProtoTest/__init__.py \
+ gen-py-newstyleslots/ThriftTest/__init__.py \
+ gen-py-newstyleslots/DebugProtoTest/__init__.py \
+ gen-py-dynamic/ThriftTest/__init__.py \
+ gen-py-dynamic/DebugProtoTest/__init__.py \
+ gen-py-dynamicslots/ThriftTest/__init__.py \
+ gen-py-dynamicslots/DebugProtoTest/__init__.py
helper_scripts= \
TestClient.py \
TestServer.py
check_SCRIPTS= \
- $(thrift_gen) \
+ $(thrift_gen) \
$(py_unit_tests) \
$(helper_scripts)
@@ -42,7 +50,29 @@
gen-py/%/__init__.py: ../%.thrift
- $(THRIFT) --gen py $<
+ $(THRIFT) --gen py $<
+ test -d gen-py-default || mkdir gen-py-default
+ $(THRIFT) --gen py -out gen-py-default $<
+
+gen-py-slots/%/__init__.py: ../%.thrift
+ test -d gen-py-slots || mkdir gen-py-slots
+ $(THRIFT) --gen py:slots -out gen-py-slots $<
+
+gen-py-newstyle/%/__init__.py: ../%.thrift
+ test -d gen-py-newstyle || mkdir gen-py-newstyle
+ $(THRIFT) --gen py:new_style -out gen-py-newstyle $<
+
+gen-py-newstyleslots/%/__init__.py: ../%.thrift
+ test -d gen-py-newstyleslots || mkdir gen-py-newstyleslots
+ $(THRIFT) --gen py:new_style,slots -out gen-py-newstyleslots $<
+
+gen-py-dynamic/%/__init__.py: ../%.thrift
+ test -d gen-py-dynamic || mkdir gen-py-dynamic
+ $(THRIFT) --gen py:dynamic -out gen-py-dynamic $<
+
+gen-py-dynamicslots/%/__init__.py: ../%.thrift
+ test -d gen-py-dynamicslots || mkdir gen-py-dynamicslots
+ $(THRIFT) --gen py:dynamic,slots -out gen-py-dynamicslots $<
clean-local:
- $(RM) -r gen-py
+ $(RM) -r gen-py gen-py-slots gen-py-default gen-py-newstyle gen-py-newstyleslots gen-py-dynamic gen-py-dynamicslots
diff --git a/test/py/RunClientServer.py b/test/py/RunClientServer.py
index 633856f..8a7fda6 100755
--- a/test/py/RunClientServer.py
+++ b/test/py/RunClientServer.py
@@ -28,6 +28,9 @@
from optparse import OptionParser
parser = OptionParser()
+parser.add_option('--genpydirs', type='string', dest='genpydirs',
+ default='default,slots,newstyle,newstyleslots,dynamic,dynamicslots',
+ help='directory extensions for generated code, used as suffixes for \"gen-py-*\" added sys.path for individual tests')
parser.add_option("--port", type="int", dest="port", default=9090,
help="port number for server to listen on")
parser.add_option('-v', '--verbose', action="store_const",
@@ -39,11 +42,15 @@
parser.set_defaults(verbose=1)
options, args = parser.parse_args()
+generated_dirs = []
+for gp_dir in options.genpydirs.split(','):
+ generated_dirs.append('gen-py-%s' % (gp_dir))
+
+SCRIPTS = ['SerializationTest.py', 'TestEof.py', 'TestSyntax.py', 'TestSocket.py']
FRAMED = ["TNonblockingServer"]
SKIP_ZLIB = ['TNonblockingServer', 'THttpServer']
SKIP_SSL = ['TNonblockingServer', 'THttpServer']
-EXTRA_DELAY = ['TProcessPoolServer']
-EXTRA_SLEEP = 3.5
+EXTRA_DELAY = dict(TProcessPoolServer=3.5)
PROTOS= [
'accel',
@@ -85,11 +92,21 @@
def relfile(fname):
return os.path.join(os.path.dirname(__file__), fname)
-def runTest(server_class, proto, port, use_zlib, use_ssl):
+def runScriptTest(genpydir, script):
+ script_args = [sys.executable, relfile(script) ]
+ script_args.append('--genpydir=%s' % genpydir)
+ serverproc = subprocess.Popen(script_args)
+ print '\nTesting script: %s\n----' % (' '.join(script_args))
+ ret = subprocess.call(script_args)
+ if ret != 0:
+ raise Exception("Script subprocess failed, retcode=%d, args: %s" % (ret, ' '.join(script_args)))
+
+def runServiceTest(genpydir, server_class, proto, port, use_zlib, use_ssl):
# Build command line arguments
server_args = [sys.executable, relfile('TestServer.py') ]
cli_args = [sys.executable, relfile('TestClient.py') ]
for which in (server_args, cli_args):
+ which.append('--genpydir=%s' % genpydir)
which.append('--proto=%s' % proto) # accel, binary or compact
which.append('--port=%d' % port) # default to 9090
if use_zlib:
@@ -110,7 +127,7 @@
if options.verbose > 0:
print 'Testing server %s: %s' % (server_class, ' '.join(server_args))
serverproc = subprocess.Popen(server_args)
- time.sleep(0.2)
+ time.sleep(0.15)
try:
if options.verbose > 0:
print 'Testing client: %s' % (' '.join(cli_args))
@@ -124,29 +141,47 @@
print 'FAIL: Server process (%s) failed with retcode %d' % (' '.join(server_args), serverproc.returncode)
raise Exception('Server subprocess %s died, args: %s' % (server_class, ' '.join(server_args)))
else:
- if server_class in EXTRA_DELAY:
- if options.verbose > 0:
- print 'Giving %s (proto=%s,zlib=%s,ssl=%s) an extra %d seconds for child processes to terminate via alarm' % (server_class,
- proto, use_zlib, use_ssl, EXTRA_SLEEP)
- time.sleep(EXTRA_SLEEP)
+ extra_sleep = EXTRA_DELAY.get(server_class, 0)
+ if extra_sleep > 0 and options.verbose > 0:
+ print 'Giving %s (proto=%s,zlib=%s,ssl=%s) an extra %d seconds for child processes to terminate via alarm' % (server_class,
+ proto, use_zlib, use_ssl, extra_sleep)
+ time.sleep(extra_sleep)
os.kill(serverproc.pid, signal.SIGKILL)
# wait for shutdown
- time.sleep(0.1)
+ time.sleep(0.05)
test_count = 0
+# run tests without a client/server first
+print '----------------'
+print ' Executing individual test scripts with various generated code directories'
+print ' Directories to be tested: ' + ', '.join(generated_dirs)
+print ' Scripts to be tested: ' + ', '.join(SCRIPTS)
+print '----------------'
+for genpydir in generated_dirs:
+ for script in SCRIPTS:
+ runScriptTest(genpydir, script)
+
+print '----------------'
+print ' Executing Client/Server tests with various generated code directories'
+print ' Servers to be tested: ' + ', '.join(SERVERS)
+print ' Directories to be tested: ' + ', '.join(generated_dirs)
+print ' Protocols to be tested: ' + ', '.join(PROTOS)
+print ' Options to be tested: ZLIB(yes/no), SSL(yes/no)'
+print '----------------'
for try_server in SERVERS:
- for try_proto in PROTOS:
- for with_zlib in (False, True):
- # skip any servers that don't work with the Zlib transport
- if with_zlib and try_server in SKIP_ZLIB:
- continue
- for with_ssl in (False, True):
- # skip any servers that don't work with SSL
- if with_ssl and try_server in SKIP_SSL:
+ for genpydir in generated_dirs:
+ for try_proto in PROTOS:
+ for with_zlib in (False, True):
+ # skip any servers that don't work with the Zlib transport
+ if with_zlib and try_server in SKIP_ZLIB:
continue
- test_count += 1
- if options.verbose > 0:
- print '\nTest run #%d: Server=%s, Proto=%s, zlib=%s, SSL=%s' % (test_count, try_server, try_proto, with_zlib, with_ssl)
- runTest(try_server, try_proto, options.port, with_zlib, with_ssl)
- if options.verbose > 0:
- print 'OK: Finished %s / %s proto / zlib=%s / SSL=%s. %d combinations tested.' % (try_server, try_proto, with_zlib, with_ssl, test_count)
+ for with_ssl in (False, True):
+ # skip any servers that don't work with SSL
+ if with_ssl and try_server in SKIP_SSL:
+ continue
+ test_count += 1
+ if options.verbose > 0:
+ print '\nTest run #%d: (includes %s) Server=%s, Proto=%s, zlib=%s, SSL=%s' % (test_count, genpydir, try_server, try_proto, with_zlib, with_ssl)
+ runServiceTest(genpydir, try_server, try_proto, options.port, with_zlib, with_ssl)
+ if options.verbose > 0:
+ print 'OK: Finished (includes %s) %s / %s proto / zlib=%s / SSL=%s. %d combinations tested.' % (genpydir, try_server, try_proto, with_zlib, with_ssl, test_count)
diff --git a/test/py/SerializationTest.py b/test/py/SerializationTest.py
index 3ba76fb..0664146 100755
--- a/test/py/SerializationTest.py
+++ b/test/py/SerializationTest.py
@@ -20,7 +20,12 @@
#
import sys, glob
-sys.path.insert(0, './gen-py')
+from optparse import OptionParser
+parser = OptionParser()
+parser.add_option('--genpydir', type='string', dest='genpydir', default='gen-py')
+options, args = parser.parse_args()
+del sys.argv[1:] # clean up hack so unittest doesn't complain
+sys.path.insert(0, options.genpydir)
sys.path.insert(0, glob.glob('../../lib/py/build/lib.*')[0])
from ThriftTest.ttypes import *
@@ -119,28 +124,86 @@
byte_list_map={0 : [], 1 : [1], 2 : [1, 2]},
)
+ self.nested_lists_i32x2 = NestedListsI32x2(
+ [
+ [ 1, 1, 2 ],
+ [ 2, 7, 9 ],
+ [ 3, 5, 8 ]
+ ]
+ )
+
+ self.nested_lists_i32x3 = NestedListsI32x3(
+ [
+ [
+ [ 2, 7, 9 ],
+ [ 3, 5, 8 ]
+ ],
+ [
+ [ 1, 1, 2 ],
+ [ 1, 4, 9 ]
+ ]
+ ]
+ )
+
+ self.nested_mixedx2 = NestedMixedx2( int_set_list=[
+ set([1,2,3]),
+ set([1,4,9]),
+ set([1,2,3,5,8,13,21]),
+ set([-1, 0, 1])
+ ],
+ # note, the sets below are sets of chars, since the strings are iterated
+ map_int_strset={ 10:set('abc'), 20:set('def'), 30:set('GHI') },
+ map_int_strset_list=[
+ { 10:set('abc'), 20:set('def'), 30:set('GHI') },
+ { 100:set('lmn'), 200:set('opq'), 300:set('RST') },
+ { 1000:set('uvw'), 2000:set('wxy'), 3000:set('XYZ') }
+ ]
+ )
+
+ self.nested_lists_bonk = NestedListsBonk(
+ [
+ [
+ [
+ Bonk(message='inner A first', type=1),
+ Bonk(message='inner A second', type=1)
+ ],
+ [
+ Bonk(message='inner B first', type=2),
+ Bonk(message='inner B second', type=2)
+ ]
+ ]
+ ]
+ )
+
+ self.list_bonks = ListBonks(
+ [
+ Bonk(message='inner A', type=1),
+ Bonk(message='inner B', type=2),
+ Bonk(message='inner C', type=0)
+ ]
+ )
def _serialize(self, obj):
- trans = TTransport.TMemoryBuffer()
- prot = self.protocol_factory.getProtocol(trans)
- obj.write(prot)
- return trans.getvalue()
+ trans = TTransport.TMemoryBuffer()
+ prot = self.protocol_factory.getProtocol(trans)
+ obj.write(prot)
+ return trans.getvalue()
def _deserialize(self, objtype, data):
- prot = self.protocol_factory.getProtocol(TTransport.TMemoryBuffer(data))
- ret = objtype()
- ret.read(prot)
- return ret
+ prot = self.protocol_factory.getProtocol(TTransport.TMemoryBuffer(data))
+ ret = objtype()
+ ret.read(prot)
+ return ret
def testForwards(self):
- obj = self._deserialize(VersioningTestV2, self._serialize(self.v1obj))
- self.assertEquals(obj.begin_in_both, self.v1obj.begin_in_both)
- self.assertEquals(obj.end_in_both, self.v1obj.end_in_both)
+ obj = self._deserialize(VersioningTestV2, self._serialize(self.v1obj))
+ self.assertEquals(obj.begin_in_both, self.v1obj.begin_in_both)
+ self.assertEquals(obj.end_in_both, self.v1obj.end_in_both)
def testBackwards(self):
- obj = self._deserialize(VersioningTestV1, self._serialize(self.v2obj))
- self.assertEquals(obj.begin_in_both, self.v2obj.begin_in_both)
- self.assertEquals(obj.end_in_both, self.v2obj.end_in_both)
+ obj = self._deserialize(VersioningTestV1, self._serialize(self.v2obj))
+ self.assertEquals(obj.begin_in_both, self.v2obj.begin_in_both)
+ self.assertEquals(obj.end_in_both, self.v2obj.end_in_both)
def testSerializeV1(self):
obj = self._deserialize(VersioningTestV1, self._serialize(self.v1obj))
@@ -152,20 +215,57 @@
def testBools(self):
self.assertNotEquals(self.bools, self.bools_flipped)
+ self.assertNotEquals(self.bools, self.v1obj)
obj = self._deserialize(Bools, self._serialize(self.bools))
self.assertEquals(obj, self.bools)
obj = self._deserialize(Bools, self._serialize(self.bools_flipped))
self.assertEquals(obj, self.bools_flipped)
+ rep = repr(self.bools)
+ self.assertTrue(len(rep) > 0)
def testLargeDeltas(self):
# test large field deltas (meaningful in CompactProto only)
obj = self._deserialize(LargeDeltas, self._serialize(self.large_deltas))
self.assertEquals(obj, self.large_deltas)
+ rep = repr(self.large_deltas)
+ self.assertTrue(len(rep) > 0)
+
+ def testNestedListsI32x2(self):
+ obj = self._deserialize(NestedListsI32x2, self._serialize(self.nested_lists_i32x2))
+ self.assertEquals(obj, self.nested_lists_i32x2)
+ rep = repr(self.nested_lists_i32x2)
+ self.assertTrue(len(rep) > 0)
+
+ def testNestedListsI32x3(self):
+ obj = self._deserialize(NestedListsI32x3, self._serialize(self.nested_lists_i32x3))
+ self.assertEquals(obj, self.nested_lists_i32x3)
+ rep = repr(self.nested_lists_i32x3)
+ self.assertTrue(len(rep) > 0)
+
+ def testNestedMixedx2(self):
+ obj = self._deserialize(NestedMixedx2, self._serialize(self.nested_mixedx2))
+ self.assertEquals(obj, self.nested_mixedx2)
+ rep = repr(self.nested_mixedx2)
+ self.assertTrue(len(rep) > 0)
+
+ def testNestedListsBonk(self):
+ obj = self._deserialize(NestedListsBonk, self._serialize(self.nested_lists_bonk))
+ self.assertEquals(obj, self.nested_lists_bonk)
+ rep = repr(self.nested_lists_bonk)
+ self.assertTrue(len(rep) > 0)
+
+ def testListBonks(self):
+ obj = self._deserialize(ListBonks, self._serialize(self.list_bonks))
+ self.assertEquals(obj, self.list_bonks)
+ rep = repr(self.list_bonks)
+ self.assertTrue(len(rep) > 0)
def testCompactStruct(self):
# test large field deltas (meaningful in CompactProto only)
obj = self._deserialize(CompactProtoTestStruct, self._serialize(self.compact_struct))
self.assertEquals(obj, self.compact_struct)
+ rep = repr(self.compact_struct)
+ self.assertTrue(len(rep) > 0)
class NormalBinaryTest(AbstractTest):
protocol_factory = TBinaryProtocol.TBinaryProtocolFactory()
diff --git a/test/py/TestClient.py b/test/py/TestClient.py
index 6429ec3..e5d4326 100755
--- a/test/py/TestClient.py
+++ b/test/py/TestClient.py
@@ -20,23 +20,16 @@
#
import sys, glob
-sys.path.insert(0, './gen-py')
sys.path.insert(0, glob.glob('../../lib/py/build/lib.*')[0])
-from ThriftTest import ThriftTest
-from ThriftTest.ttypes import *
-from thrift.transport import TTransport
-from thrift.transport import TSocket
-from thrift.transport import THttpClient
-from thrift.transport import TZlibTransport
-from thrift.protocol import TBinaryProtocol
-from thrift.protocol import TCompactProtocol
import unittest
import time
from optparse import OptionParser
-
parser = OptionParser()
+parser.add_option('--genpydir', type='string', dest='genpydir',
+ default='gen-py',
+ help='include this local directory in sys.path for locating generated code')
parser.add_option("--port", type="int", dest="port",
help="connect to server at port")
parser.add_option("--host", type="string", dest="host",
@@ -60,6 +53,17 @@
parser.set_defaults(framed=False, http_path=None, verbose=1, host='localhost', port=9090, proto='binary')
options, args = parser.parse_args()
+sys.path.insert(0, options.genpydir)
+
+from ThriftTest import ThriftTest
+from ThriftTest.ttypes import *
+from thrift.transport import TTransport
+from thrift.transport import TSocket
+from thrift.transport import THttpClient
+from thrift.transport import TZlibTransport
+from thrift.protocol import TBinaryProtocol
+from thrift.protocol import TCompactProtocol
+
class AbstractTest(unittest.TestCase):
def setUp(self):
if options.http_path:
@@ -176,6 +180,9 @@
except Xception, x:
self.assertEqual(x.errorCode, 1001)
self.assertEqual(x.message, 'Xception')
+ # ensure exception's repr method works
+ x_repr = repr(x)
+ self.assertEqual(x_repr, 'Xception(errorCode=1001, message=\'Xception\')')
try:
self.client.testException("throw_undeclared")
@@ -225,4 +232,4 @@
self.createTests()
if __name__ == "__main__":
- OwnArgsTestProgram(defaultTest="suite", testRunner=unittest.TextTestRunner(verbosity=2))
+ OwnArgsTestProgram(defaultTest="suite", testRunner=unittest.TextTestRunner(verbosity=1))
diff --git a/test/py/TestEof.py b/test/py/TestEof.py
index 7ff0b42..a9d81f1 100755
--- a/test/py/TestEof.py
+++ b/test/py/TestEof.py
@@ -20,7 +20,12 @@
#
import sys, glob
-sys.path.insert(0, './gen-py')
+from optparse import OptionParser
+parser = OptionParser()
+parser.add_option('--genpydir', type='string', dest='genpydir', default='gen-py')
+options, args = parser.parse_args()
+del sys.argv[1:] # clean up hack so unittest doesn't complain
+sys.path.insert(0, options.genpydir)
sys.path.insert(0, glob.glob('../../lib/py/build/lib.*')[0])
from ThriftTest import ThriftTest
diff --git a/test/py/TestServer.py b/test/py/TestServer.py
index fa62765..6f4af44 100755
--- a/test/py/TestServer.py
+++ b/test/py/TestServer.py
@@ -20,24 +20,13 @@
#
from __future__ import division
import sys, glob, time
-sys.path.insert(0, './gen-py')
sys.path.insert(0, glob.glob('../../lib/py/build/lib.*')[0])
from optparse import OptionParser
-from ThriftTest import ThriftTest
-from ThriftTest.ttypes import *
-from thrift.transport import TTransport
-from thrift.transport import TSocket
-from thrift.transport import TZlibTransport
-from thrift.protocol import TBinaryProtocol
-from thrift.protocol import TCompactProtocol
-from thrift.server import TServer, TNonblockingServer, THttpServer
-
-PROT_FACTORIES = {'binary': TBinaryProtocol.TBinaryProtocolFactory,
- 'accel': TBinaryProtocol.TBinaryProtocolAcceleratedFactory,
- 'compact': TCompactProtocol.TCompactProtocolFactory}
-
parser = OptionParser()
+parser.add_option('--genpydir', type='string', dest='genpydir',
+ default='gen-py',
+ help='include this local directory in sys.path for locating generated code')
parser.add_option("--port", type="int", dest="port",
help="port number for server to listen on")
parser.add_option("--zlib", action="store_true", dest="zlib",
@@ -55,6 +44,21 @@
parser.set_defaults(port=9090, verbose=1, proto='binary')
options, args = parser.parse_args()
+sys.path.insert(0, options.genpydir)
+
+from ThriftTest import ThriftTest
+from ThriftTest.ttypes import *
+from thrift.transport import TTransport
+from thrift.transport import TSocket
+from thrift.transport import TZlibTransport
+from thrift.protocol import TBinaryProtocol
+from thrift.protocol import TCompactProtocol
+from thrift.server import TServer, TNonblockingServer, THttpServer
+
+PROT_FACTORIES = {'binary': TBinaryProtocol.TBinaryProtocolFactory,
+ 'accel': TBinaryProtocol.TBinaryProtocolAcceleratedFactory,
+ 'compact': TCompactProtocol.TCompactProtocolFactory}
+
class TestHandler:
def testVoid(self):
@@ -105,7 +109,7 @@
x.message = str
raise x
elif str == "throw_undeclared":
- raise ValueError("foo")
+ raise ValueError("Exception test PASSES.")
def testOneway(self, seconds):
if options.verbose > 1:
@@ -206,7 +210,10 @@
worker.terminate()
if options.verbose > 0:
print 'Requesting server to stop()'
- server.stop()
+ try:
+ server.stop()
+ except:
+ pass
signal.signal(signal.SIGALRM, clean_shutdown)
signal.alarm(2)
set_alarm()
diff --git a/test/py/TestSocket.py b/test/py/TestSocket.py
index 2f7353f..b9bdf27 100755
--- a/test/py/TestSocket.py
+++ b/test/py/TestSocket.py
@@ -20,7 +20,12 @@
#
import sys, glob
-sys.path.insert(0, './gen-py')
+from optparse import OptionParser
+parser = OptionParser()
+parser.add_option('--genpydir', type='string', dest='genpydir', default='gen-py')
+options, args = parser.parse_args()
+del sys.argv[1:] # clean up hack so unittest doesn't complain
+sys.path.insert(0, options.genpydir)
sys.path.insert(0, glob.glob('../../lib/py/build/lib.*')[0])
from ThriftTest import ThriftTest
diff --git a/test/py/TestSyntax.py b/test/py/TestSyntax.py
index df67d48..9f71cf5 100755
--- a/test/py/TestSyntax.py
+++ b/test/py/TestSyntax.py
@@ -20,7 +20,12 @@
#
import sys, glob
-sys.path.insert(0, './gen-py')
+from optparse import OptionParser
+parser = OptionParser()
+parser.add_option('--genpydir', type='string', dest='genpydir', default='gen-py')
+options, args = parser.parse_args()
+del sys.argv[1:] # clean up hack so unittest doesn't complain
+sys.path.insert(0, options.genpydir)
sys.path.insert(0, glob.glob('../../lib/py/build/lib.*')[0])
# Just import these generated files to make sure they are syntactically valid