THRIFT-2642 Recursive structs don't work in python
Client: Python
Patch: Eric Conner <eric@pinterest.com>
This closes #1293
diff --git a/compiler/cpp/src/thrift/generate/t_py_generator.cc b/compiler/cpp/src/thrift/generate/t_py_generator.cc
index 6b8697d..9d50aaf 100644
--- a/compiler/cpp/src/thrift/generate/t_py_generator.cc
+++ b/compiler/cpp/src/thrift/generate/t_py_generator.cc
@@ -150,6 +150,7 @@
void generate_enum(t_enum* tenum);
void generate_const(t_const* tconst);
void generate_struct(t_struct* tstruct);
+ void generate_forward_declaration(t_struct* tstruct);
void generate_xception(t_struct* txception);
void generate_service(t_service* tservice);
@@ -160,6 +161,7 @@
*/
void generate_py_struct(t_struct* tstruct, bool is_exception);
+ void generate_py_thrift_spec(std::ofstream& out, t_struct* tstruct, bool is_exception);
void generate_py_struct_definition(std::ofstream& out,
t_struct* tstruct,
bool is_xception = false);
@@ -380,6 +382,8 @@
<< "from thrift.transport import TTransport" << endl
<< import_dynbase_;
+ f_types_ << "all_structs = []" << endl;
+
f_consts_ <<
py_autogen_comment() << endl <<
py_imports() << endl <<
@@ -419,7 +423,11 @@
ss << "from thrift.Thrift import TType, TMessageType, TFrozenDict, TException, "
"TApplicationException"
<< endl
- << "from thrift.protocol.TProtocol import TProtocolException";
+ << "from thrift.protocol.TProtocol import TProtocolException"
+ << endl
+ << "from thrift.TRecursive import fix_spec"
+ << endl;
+
if (gen_utf8strings_) {
ss << endl << "import sys";
}
@@ -430,6 +438,11 @@
* Closes the type files
*/
void t_py_generator::close_generator() {
+
+ // Fix thrift_spec definitions for recursive structs.
+ f_types_ << "fix_spec(all_structs)" << endl;
+ f_types_ << "del all_structs" << endl;
+
// Close types file
f_types_.close();
f_consts_.close();
@@ -610,11 +623,21 @@
return out.str();
}
+/**
+ * Generates the "forward declarations" for python structs.
+ * These are actually full class definitions so that calls to generate_struct
+ * can add the thrift_spec field. This is needed so that all thrift_spec
+ * definitions are grouped at the end of the file to enable co-recursive structs.
+ */
+void t_py_generator::generate_forward_declaration(t_struct* tstruct) {
+ generate_py_struct(tstruct, tstruct->is_xception());
+}
+
/**
* Generates a python struct
*/
void t_py_generator::generate_struct(t_struct* tstruct) {
- generate_py_struct(tstruct, false);
+ generate_py_thrift_spec(f_types_, tstruct, false);
}
/**
@@ -624,7 +647,7 @@
* @param txception The struct definition
*/
void t_py_generator::generate_xception(t_struct* txception) {
- generate_py_struct(txception, true);
+ generate_py_thrift_spec(f_types_, txception, true);
}
/**
@@ -634,6 +657,54 @@
generate_py_struct_definition(f_types_, tstruct, is_exception);
}
+
+/**
+ * Generate the thrift_spec for a struct
+ * For example,
+ * all_structs.append(Recursive)
+ * Recursive.thrift_spec = (
+ * None, # 0
+ * (1, TType.LIST, 'Children', (TType.STRUCT, (Recursive, None), False), None, ), # 1
+ * )
+ */
+void t_py_generator::generate_py_thrift_spec(ofstream& out,
+ t_struct* tstruct,
+ bool /*is_exception*/) {
+ const vector<t_field*>& sorted_members = tstruct->get_sorted_members();
+ vector<t_field*>::const_iterator m_iter;
+
+ // Add struct definition to list so thrift_spec can be fixed for recursive structures.
+ indent(out) << "all_structs.append(" << tstruct->get_name() << ")" << endl;
+
+ if (sorted_members.empty() || (sorted_members[0]->get_key() >= 0)) {
+ indent(out) << tstruct->get_name() << ".thrift_spec = (" << endl;
+ indent_up();
+
+ int sorted_keys_pos = 0;
+ for (m_iter = sorted_members.begin(); m_iter != sorted_members.end(); ++m_iter) {
+
+ for (; sorted_keys_pos != (*m_iter)->get_key(); sorted_keys_pos++) {
+ indent(out) << "None, # " << sorted_keys_pos << endl;
+ }
+
+ indent(out) << "(" << (*m_iter)->get_key() << ", " << type_to_enum((*m_iter)->get_type())
+ << ", "
+ << "'" << (*m_iter)->get_name() << "'"
+ << ", " << type_to_spec_args((*m_iter)->get_type()) << ", "
+ << render_field_default_value(*m_iter) << ", "
+ << "),"
+ << " # " << sorted_keys_pos << endl;
+
+ sorted_keys_pos++;
+ }
+
+ indent_down();
+ indent(out) << ")" << endl;
+ } else {
+ indent(out) << tstruct->get_name() << ".thrift_spec = ()" << endl;
+ }
+}
+
/**
* Generates a struct definition for a thrift data type.
*
@@ -702,43 +773,14 @@
// for structures with no members.
// TODO(dreiss): Test encoding of structs where some inner structs
// don't have thrift_spec.
- if (sorted_members.empty() || (sorted_members[0]->get_key() >= 0)) {
- indent(out) << "thrift_spec = (" << endl;
- indent_up();
-
- int sorted_keys_pos = 0;
- for (m_iter = sorted_members.begin(); m_iter != sorted_members.end(); ++m_iter) {
-
- for (; sorted_keys_pos != (*m_iter)->get_key(); sorted_keys_pos++) {
- indent(out) << "None, # " << sorted_keys_pos << endl;
- }
-
- indent(out) << "(" << (*m_iter)->get_key() << ", " << type_to_enum((*m_iter)->get_type())
- << ", "
- << "'" << (*m_iter)->get_name() << "'"
- << ", " << type_to_spec_args((*m_iter)->get_type()) << ", "
- << render_field_default_value(*m_iter) << ", "
- << "),"
- << " # " << sorted_keys_pos << endl;
-
- sorted_keys_pos++;
- }
-
- indent_down();
- indent(out) << ")" << endl;
- } else {
- indent(out) << "thrift_spec = None" << endl;
- }
if (members.size() > 0) {
out << endl;
out << indent() << "def __init__(self,";
for (m_iter = members.begin(); m_iter != members.end(); ++m_iter) {
- // This fills in default values, as opposed to nulls
out << " " << declare_argument(*m_iter) << ",";
}
-
out << "):" << endl;
indent_up();
@@ -887,9 +929,9 @@
indent_up();
if (is_immutable(tstruct)) {
- indent(out) << "return iprot._fast_decode(None, iprot, (cls, cls.thrift_spec))" << endl;
+ indent(out) << "return iprot._fast_decode(None, iprot, [cls, cls.thrift_spec])" << endl;
} else {
- indent(out) << "iprot._fast_decode(self, iprot, (self.__class__, self.thrift_spec))" << endl;
+ indent(out) << "iprot._fast_decode(self, iprot, [self.__class__, self.thrift_spec])" << endl;
indent(out) << "return" << endl;
}
indent_down();
@@ -970,7 +1012,7 @@
indent_up();
indent(out)
- << "oprot.trans.write(oprot._fast_encode(self, (self.__class__, self.thrift_spec)))"
+ << "oprot.trans.write(oprot._fast_encode(self, [self.__class__, self.thrift_spec]))"
<< endl;
indent(out) << "return" << endl;
indent_down();
@@ -1059,6 +1101,8 @@
f_service_ << "from tornado import concurrent" << endl;
}
+ f_service_ << "all_structs = []" << endl;
+
// Generate the three main parts of the service
generate_service_interface(tservice);
generate_service_client(tservice);
@@ -1067,6 +1111,8 @@
generate_service_remote(tservice);
// Close service file
+ f_service_ << "fix_spec(all_structs)" << endl
+ << "del all_structs" << endl << endl;
f_service_.close();
}
@@ -1084,6 +1130,7 @@
for (f_iter = functions.begin(); f_iter != functions.end(); ++f_iter) {
t_struct* ts = (*f_iter)->get_arglist();
generate_py_struct_definition(f_service_, ts, false);
+ generate_py_thrift_spec(f_service_, ts, false);
generate_py_function_helpers(*f_iter);
}
}
@@ -1108,6 +1155,7 @@
result.append(*f_iter);
}
generate_py_struct_definition(f_service_, &result, false);
+ generate_py_thrift_spec(f_service_, &result, false);
}
}
@@ -2456,7 +2504,7 @@
std::ostringstream result;
result << tfield->get_name() << "=";
if (tfield->get_value() != NULL) {
- result << "thrift_spec[" << tfield->get_key() << "][4]";
+ result << render_field_default_value(tfield);
} else {
result << "None";
}
@@ -2607,7 +2655,7 @@
} else if (ttype->is_base_type() || ttype->is_enum()) {
return "None";
} else if (ttype->is_struct() || ttype->is_xception()) {
- return "(" + type_name(ttype) + ", " + type_name(ttype) + ".thrift_spec)";
+ return "[" + type_name(ttype) + ", None]";
} else if (ttype->is_map()) {
return "(" + type_to_enum(((t_map*)ttype)->get_key_type()) + ", "
+ type_to_spec_args(((t_map*)ttype)->get_key_type()) + ", "
diff --git a/lib/py/src/TRecursive.py b/lib/py/src/TRecursive.py
new file mode 100644
index 0000000..d5a5686
--- /dev/null
+++ b/lib/py/src/TRecursive.py
@@ -0,0 +1,83 @@
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+from thrift.Thrift import TType
+
+TYPE_IDX = 1
+SPEC_ARGS_IDX = 3
+SPEC_ARGS_CLASS_REF_IDX = 0
+SPEC_ARGS_THRIFT_SPEC_IDX = 1
+
+
+def fix_spec(all_structs):
+ """Wire up recursive references for all TStruct definitions inside of each thrift_spec."""
+ for struc in all_structs:
+ spec = struc.thrift_spec
+ for thrift_spec in spec:
+ if thrift_spec is None:
+ continue
+ elif thrift_spec[TYPE_IDX] == TType.STRUCT:
+ other = thrift_spec[SPEC_ARGS_IDX][SPEC_ARGS_CLASS_REF_IDX].thrift_spec
+ thrift_spec[SPEC_ARGS_IDX][SPEC_ARGS_THRIFT_SPEC_IDX] = other
+ elif thrift_spec[TYPE_IDX] in (TType.LIST, TType.SET):
+ _fix_list_or_set(thrift_spec[SPEC_ARGS_IDX])
+ elif thrift_spec[TYPE_IDX] == TType.MAP:
+ _fix_map(thrift_spec[SPEC_ARGS_IDX])
+
+
+def _fix_list_or_set(element_type):
+ # For a list or set, the thrift_spec entry looks like,
+ # (1, TType.LIST, 'lister', (TType.STRUCT, [RecList, None], False), None, ), # 1
+ # so ``element_type`` will be,
+ # (TType.STRUCT, [RecList, None], False)
+ if element_type[0] == TType.STRUCT:
+ element_type[1][1] = element_type[1][0].thrift_spec
+ elif element_type[0] in (TType.LIST, TType.SET):
+ _fix_list_or_set(element_type[1])
+ elif element_type[0] == TType.MAP:
+ _fix_map(element_type[1])
+
+
+def _fix_map(element_type):
+ # For a map of key -> value type, ``element_type`` will be,
+ # (TType.I16, None, TType.STRUCT, [RecMapBasic, None], False), None, )
+ # which is just a normal struct definition.
+ #
+ # For a map of key -> list / set, ``element_type`` will be,
+ # (TType.I16, None, TType.LIST, (TType.STRUCT, [RecMapList, None], False), False)
+ # and we need to process the 3rd element as a list.
+ #
+ # For a map of key -> map, ``element_type`` will be,
+ # (TType.I16, None, TType.MAP, (TType.I16, None, TType.STRUCT,
+ # [RecMapMap, None], False), False)
+ # and need to process 3rd element as a map.
+
+ # Is the map key a struct?
+ if element_type[0] == TType.STRUCT:
+ element_type[1][1] = element_type[1][0].thrift_spec
+ elif element_type[0] in (TType.LIST, TType.SET):
+ _fix_list_or_set(element_type[1])
+ elif element_type[0] == TType.MAP:
+ _fix_map(element_type[1])
+
+ # Is the map value a struct?
+ if element_type[2] == TType.STRUCT:
+ element_type[3][1] = element_type[3][0].thrift_spec
+ elif element_type[2] in (TType.LIST, TType.SET):
+ _fix_list_or_set(element_type[3])
+ elif element_type[2] == TType.MAP:
+ _fix_map(element_type[3])
diff --git a/lib/py/src/ext/types.cpp b/lib/py/src/ext/types.cpp
index 849ab2f..68443fb 100644
--- a/lib/py/src/ext/types.cpp
+++ b/lib/py/src/ext/types.cpp
@@ -98,13 +98,13 @@
}
bool parse_struct_args(StructTypeArgs* dest, PyObject* typeargs) {
- if (PyTuple_Size(typeargs) != 2) {
- PyErr_SetString(PyExc_TypeError, "expecting tuple of size 2 for struct args");
+ if (PyList_Size(typeargs) != 2) {
+ PyErr_SetString(PyExc_TypeError, "expecting list of size 2 for struct args");
return false;
}
- dest->klass = PyTuple_GET_ITEM(typeargs, 0);
- dest->spec = PyTuple_GET_ITEM(typeargs, 1);
+ dest->klass = PyList_GET_ITEM(typeargs, 0);
+ dest->spec = PyList_GET_ITEM(typeargs, 1);
return true;
}
diff --git a/lib/py/src/protocol/TBase.py b/lib/py/src/protocol/TBase.py
index 55da19e..9ae1b11 100644
--- a/lib/py/src/protocol/TBase.py
+++ b/lib/py/src/protocol/TBase.py
@@ -44,14 +44,14 @@
if (iprot._fast_decode is not None and
isinstance(iprot.trans, TTransport.CReadableTransport) and
self.thrift_spec is not None):
- iprot._fast_decode(self, iprot, (self.__class__, self.thrift_spec))
+ iprot._fast_decode(self, iprot, [self.__class__, self.thrift_spec])
else:
iprot.readStruct(self, self.thrift_spec)
def write(self, oprot):
if (oprot._fast_encode is not None and self.thrift_spec is not None):
oprot.trans.write(
- oprot._fast_encode(self, (self.__class__, self.thrift_spec)))
+ oprot._fast_encode(self, [self.__class__, self.thrift_spec]))
else:
oprot.writeStruct(self, self.thrift_spec)
@@ -77,6 +77,6 @@
cls.thrift_spec is not None):
self = cls()
return iprot._fast_decode(None, iprot,
- (self.__class__, self.thrift_spec))
+ [self.__class__, self.thrift_spec])
else:
return iprot.readStruct(cls, cls.thrift_spec, True)
diff --git a/test/py/Makefile.am b/test/py/Makefile.am
index f105737..53c1e63 100644
--- a/test/py/Makefile.am
+++ b/test/py/Makefile.am
@@ -25,18 +25,26 @@
thrift_gen = \
gen-py/ThriftTest/__init__.py \
gen-py/DebugProtoTest/__init__.py \
+ gen-py/Recursive/__init__.py \
gen-py-default/ThriftTest/__init__.py \
gen-py-default/DebugProtoTest/__init__.py \
+ gen-py-default/Recursive/__init__.py \
gen-py-slots/ThriftTest/__init__.py \
gen-py-slots/DebugProtoTest/__init__.py \
+ gen-py-slots/Recursive/__init__.py \
gen-py-oldstyle/ThriftTest/__init__.py \
gen-py-oldstyle/DebugProtoTest/__init__.py \
+ gen-py-oldstyle/Recursive/__init__.py \
gen-py-no_utf8strings/ThriftTest/__init__.py \
gen-py-no_utf8strings/DebugProtoTest/__init__.py \
+ gen-py-no_utf8strings/Recursive/__init__.py \
gen-py-dynamic/ThriftTest/__init__.py \
gen-py-dynamic/DebugProtoTest/__init__.py \
+ gen-py-dynamic/Recursive/__init__.py \
gen-py-dynamicslots/ThriftTest/__init__.py \
- gen-py-dynamicslots/DebugProtoTest/__init__.py
+ gen-py-dynamicslots/DebugProtoTest/__init__.py \
+ gen-py-dynamicslots/Recursive/__init__.py
+
precross: $(thrift_gen)
BUILT_SOURCES = $(thrift_gen)
diff --git a/test/py/SerializationTest.py b/test/py/SerializationTest.py
index efe3c6d..b080d87 100755
--- a/test/py/SerializationTest.py
+++ b/test/py/SerializationTest.py
@@ -35,6 +35,11 @@
Xtruct2,
)
+from Recursive.ttypes import RecTree
+from Recursive.ttypes import RecList
+from Recursive.ttypes import CoRec
+from Recursive.ttypes import CoRec2
+from Recursive.ttypes import VectorTest
from DebugProtoTest.ttypes import CompactProtoTestStruct, Empty
from thrift.transport import TTransport
from thrift.protocol import TBinaryProtocol, TCompactProtocol, TJSONProtocol
@@ -285,6 +290,67 @@
for value in bad_values:
self.assertRaises(Exception, self._serialize, value)
+ def testRecTree(self):
+ """Ensure recursive tree node can be created."""
+ children = []
+ for idx in range(1, 5):
+ node = RecTree(item=idx, children=None)
+ children.append(node)
+
+ parent = RecTree(item=0, children=children)
+ serde_parent = self._deserialize(RecTree, self._serialize(parent))
+ self.assertEquals(0, serde_parent.item)
+ self.assertEquals(4, len(serde_parent.children))
+ for child in serde_parent.children:
+ # Cannot use assertIsInstance in python 2.6?
+ self.assertTrue(isinstance(child, RecTree))
+
+ def _buildLinkedList(self):
+ head = cur = RecList(item=0)
+ for idx in range(1, 5):
+ node = RecList(item=idx)
+ cur.nextitem = node
+ cur = node
+ return head
+
+ def _collapseLinkedList(self, head):
+ out_list = []
+ cur = head
+ while cur is not None:
+ out_list.append(cur.item)
+ cur = cur.nextitem
+ return out_list
+
+ def testRecList(self):
+ """Ensure recursive linked list can be created."""
+ rec_list = self._buildLinkedList()
+ serde_list = self._deserialize(RecList, self._serialize(rec_list))
+ out_list = self._collapseLinkedList(serde_list)
+ self.assertEquals([0, 1, 2, 3, 4], out_list)
+
+ def testCoRec(self):
+ """Ensure co-recursive structures can be created."""
+ item1 = CoRec()
+ item2 = CoRec2()
+
+ item1.other = item2
+ item2.other = item1
+
+ # NOTE [econner724,2017-06-21]: These objects cannot be serialized as serialization
+ # results in an infinite loop. fbthrift also suffers from this
+ # problem.
+
+ def testRecVector(self):
+ """Ensure a list of recursive nodes can be created."""
+ mylist = [self._buildLinkedList(), self._buildLinkedList()]
+ myvec = VectorTest(lister=mylist)
+
+ serde_vec = self._deserialize(VectorTest, self._serialize(myvec))
+ golden_list = [0, 1, 2, 3, 4]
+ for cur_list in serde_vec.lister:
+ out_list = self._collapseLinkedList(cur_list)
+ self.assertEqual(golden_list, out_list)
+
class NormalBinaryTest(AbstractTest):
protocol_factory = TBinaryProtocol.TBinaryProtocolFactory()
diff --git a/test/py/generate.cmake b/test/py/generate.cmake
index 44c5357..46263c8 100644
--- a/test/py/generate.cmake
+++ b/test/py/generate.cmake
@@ -20,3 +20,10 @@
generate(${MY_PROJECT_DIR}/test/DebugProtoTest.thrift py:no_utf8strings gen-py-no_utf8strings)
generate(${MY_PROJECT_DIR}/test/DebugProtoTest.thrift py:dynamic gen-py-dynamic)
generate(${MY_PROJECT_DIR}/test/DebugProtoTest.thrift py:dynamic,slots gen-py-dynamicslots)
+
+generate(${MY_PROJECT_DIR}/test/Recursive.thrift py gen-py-default)
+generate(${MY_PROJECT_DIR}/test/Recursive.thrift py:slots gen-py-slots)
+generate(${MY_PROJECT_DIR}/test/Recursive.thrift py:old_style gen-py-oldstyle)
+generate(${MY_PROJECT_DIR}/test/Recursive.thrift py:no_utf8strings gen-py-no_utf8strings)
+generate(${MY_PROJECT_DIR}/test/Recursive.thrift py:dynamic gen-py-dynamic)
+generate(${MY_PROJECT_DIR}/test/Recursive.thrift py:dynamic,slots gen-py-dynamicslots)