THRIFT-2642 Recursive structs don't work in python
Client: Python
Patch: Eric Conner <eric@pinterest.com>

This closes #1293
diff --git a/compiler/cpp/src/thrift/generate/t_py_generator.cc b/compiler/cpp/src/thrift/generate/t_py_generator.cc
index 6b8697d..9d50aaf 100644
--- a/compiler/cpp/src/thrift/generate/t_py_generator.cc
+++ b/compiler/cpp/src/thrift/generate/t_py_generator.cc
@@ -150,6 +150,7 @@
   void generate_enum(t_enum* tenum);
   void generate_const(t_const* tconst);
   void generate_struct(t_struct* tstruct);
+  void generate_forward_declaration(t_struct* tstruct);
   void generate_xception(t_struct* txception);
   void generate_service(t_service* tservice);
 
@@ -160,6 +161,7 @@
    */
 
   void generate_py_struct(t_struct* tstruct, bool is_exception);
+  void generate_py_thrift_spec(std::ofstream& out, t_struct* tstruct, bool is_exception);
   void generate_py_struct_definition(std::ofstream& out,
                                      t_struct* tstruct,
                                      bool is_xception = false);
@@ -380,6 +382,8 @@
            << "from thrift.transport import TTransport" << endl
            << import_dynbase_;
 
+  f_types_ << "all_structs = []" << endl;
+
   f_consts_ <<
     py_autogen_comment() << endl <<
     py_imports() << endl <<
@@ -419,7 +423,11 @@
   ss << "from thrift.Thrift import TType, TMessageType, TFrozenDict, TException, "
         "TApplicationException"
      << endl
-     << "from thrift.protocol.TProtocol import TProtocolException";
+     << "from thrift.protocol.TProtocol import TProtocolException"
+     << endl
+     << "from thrift.TRecursive import fix_spec" 
+     << endl;
+
   if (gen_utf8strings_) {
     ss << endl << "import sys";
   }
@@ -430,6 +438,11 @@
  * Closes the type files
  */
 void t_py_generator::close_generator() {
+
+  // Fix thrift_spec definitions for recursive structs.
+  f_types_ << "fix_spec(all_structs)" << endl;
+  f_types_ << "del all_structs" << endl;
+
   // Close types file
   f_types_.close();
   f_consts_.close();
@@ -610,11 +623,21 @@
   return out.str();
 }
 
+/** 
+ * Generates the "forward declarations" for python structs.
+ * These are actually full class definitions so that calls to generate_struct
+ * can add the thrift_spec field.  This is needed so that all thrift_spec 
+ * definitions are grouped at the end of the file to enable co-recursive structs.
+ */
+void t_py_generator::generate_forward_declaration(t_struct* tstruct) {
+    generate_py_struct(tstruct, tstruct->is_xception());
+}
+
 /**
  * Generates a python struct
  */
 void t_py_generator::generate_struct(t_struct* tstruct) {
-  generate_py_struct(tstruct, false);
+  generate_py_thrift_spec(f_types_, tstruct, false);
 }
 
 /**
@@ -624,7 +647,7 @@
  * @param txception The struct definition
  */
 void t_py_generator::generate_xception(t_struct* txception) {
-  generate_py_struct(txception, true);
+  generate_py_thrift_spec(f_types_, txception, true);
 }
 
 /**
@@ -634,6 +657,54 @@
   generate_py_struct_definition(f_types_, tstruct, is_exception);
 }
 
+
+/**
+ * Generate the thrift_spec for a struct
+ * For example,
+ *   all_structs.append(Recursive)
+ *   Recursive.thrift_spec = (
+ *       None,  # 0
+ *       (1, TType.LIST, 'Children', (TType.STRUCT, (Recursive, None), False), None, ),  # 1
+ *   )
+ */
+void t_py_generator::generate_py_thrift_spec(ofstream& out,
+                                             t_struct* tstruct,
+                                             bool /*is_exception*/) {
+  const vector<t_field*>& sorted_members = tstruct->get_sorted_members();
+  vector<t_field*>::const_iterator m_iter;
+
+  // Add struct definition to list so thrift_spec can be fixed for recursive structures.
+  indent(out) << "all_structs.append(" << tstruct->get_name() << ")" << endl;
+
+  if (sorted_members.empty() || (sorted_members[0]->get_key() >= 0)) {
+    indent(out) << tstruct->get_name() << ".thrift_spec = (" << endl;
+    indent_up();
+
+    int sorted_keys_pos = 0;
+    for (m_iter = sorted_members.begin(); m_iter != sorted_members.end(); ++m_iter) {
+
+      for (; sorted_keys_pos != (*m_iter)->get_key(); sorted_keys_pos++) {
+        indent(out) << "None,  # " << sorted_keys_pos << endl;
+      }
+
+      indent(out) << "(" << (*m_iter)->get_key() << ", " << type_to_enum((*m_iter)->get_type())
+                  << ", "
+                  << "'" << (*m_iter)->get_name() << "'"
+                  << ", " << type_to_spec_args((*m_iter)->get_type()) << ", "
+                  << render_field_default_value(*m_iter) << ", "
+                  << "),"
+                  << "  # " << sorted_keys_pos << endl;
+
+      sorted_keys_pos++;
+    }
+
+    indent_down();
+    indent(out) << ")" << endl;
+  } else {
+    indent(out) << tstruct->get_name() << ".thrift_spec = ()" << endl;
+  }
+}
+
 /**
  * Generates a struct definition for a thrift data type.
  *
@@ -702,43 +773,14 @@
   // for structures with no members.
   // TODO(dreiss): Test encoding of structs where some inner structs
   // don't have thrift_spec.
-  if (sorted_members.empty() || (sorted_members[0]->get_key() >= 0)) {
-    indent(out) << "thrift_spec = (" << endl;
-    indent_up();
-
-    int sorted_keys_pos = 0;
-    for (m_iter = sorted_members.begin(); m_iter != sorted_members.end(); ++m_iter) {
-
-      for (; sorted_keys_pos != (*m_iter)->get_key(); sorted_keys_pos++) {
-        indent(out) << "None,  # " << sorted_keys_pos << endl;
-      }
-
-      indent(out) << "(" << (*m_iter)->get_key() << ", " << type_to_enum((*m_iter)->get_type())
-                  << ", "
-                  << "'" << (*m_iter)->get_name() << "'"
-                  << ", " << type_to_spec_args((*m_iter)->get_type()) << ", "
-                  << render_field_default_value(*m_iter) << ", "
-                  << "),"
-                  << "  # " << sorted_keys_pos << endl;
-
-      sorted_keys_pos++;
-    }
-
-    indent_down();
-    indent(out) << ")" << endl;
-  } else {
-    indent(out) << "thrift_spec = None" << endl;
-  }
 
   if (members.size() > 0) {
     out << endl;
     out << indent() << "def __init__(self,";
 
     for (m_iter = members.begin(); m_iter != members.end(); ++m_iter) {
-      // This fills in default values, as opposed to nulls
       out << " " << declare_argument(*m_iter) << ",";
     }
-
     out << "):" << endl;
 
     indent_up();
@@ -887,9 +929,9 @@
   indent_up();
 
   if (is_immutable(tstruct)) {
-    indent(out) << "return iprot._fast_decode(None, iprot, (cls, cls.thrift_spec))" << endl;
+    indent(out) << "return iprot._fast_decode(None, iprot, [cls, cls.thrift_spec])" << endl;
   } else {
-    indent(out) << "iprot._fast_decode(self, iprot, (self.__class__, self.thrift_spec))" << endl;
+    indent(out) << "iprot._fast_decode(self, iprot, [self.__class__, self.thrift_spec])" << endl;
     indent(out) << "return" << endl;
   }
   indent_down();
@@ -970,7 +1012,7 @@
   indent_up();
 
   indent(out)
-      << "oprot.trans.write(oprot._fast_encode(self, (self.__class__, self.thrift_spec)))"
+      << "oprot.trans.write(oprot._fast_encode(self, [self.__class__, self.thrift_spec]))"
       << endl;
   indent(out) << "return" << endl;
   indent_down();
@@ -1059,6 +1101,8 @@
     f_service_ << "from tornado import concurrent" << endl;
   }
 
+  f_service_ << "all_structs = []" << endl;
+
   // Generate the three main parts of the service
   generate_service_interface(tservice);
   generate_service_client(tservice);
@@ -1067,6 +1111,8 @@
   generate_service_remote(tservice);
 
   // Close service file
+  f_service_ << "fix_spec(all_structs)" << endl
+             << "del all_structs" << endl << endl;
   f_service_.close();
 }
 
@@ -1084,6 +1130,7 @@
   for (f_iter = functions.begin(); f_iter != functions.end(); ++f_iter) {
     t_struct* ts = (*f_iter)->get_arglist();
     generate_py_struct_definition(f_service_, ts, false);
+    generate_py_thrift_spec(f_service_, ts, false);
     generate_py_function_helpers(*f_iter);
   }
 }
@@ -1108,6 +1155,7 @@
       result.append(*f_iter);
     }
     generate_py_struct_definition(f_service_, &result, false);
+    generate_py_thrift_spec(f_service_, &result, false);
   }
 }
 
@@ -2456,7 +2504,7 @@
   std::ostringstream result;
   result << tfield->get_name() << "=";
   if (tfield->get_value() != NULL) {
-    result << "thrift_spec[" << tfield->get_key() << "][4]";
+    result << render_field_default_value(tfield);
   } else {
     result << "None";
   }
@@ -2607,7 +2655,7 @@
   } else if (ttype->is_base_type() || ttype->is_enum()) {
     return  "None";
   } else if (ttype->is_struct() || ttype->is_xception()) {
-    return "(" + type_name(ttype) + ", " + type_name(ttype) + ".thrift_spec)";
+    return "[" + type_name(ttype) + ", None]";
   } else if (ttype->is_map()) {
     return "(" + type_to_enum(((t_map*)ttype)->get_key_type()) + ", "
            + type_to_spec_args(((t_map*)ttype)->get_key_type()) + ", "
diff --git a/lib/py/src/TRecursive.py b/lib/py/src/TRecursive.py
new file mode 100644
index 0000000..d5a5686
--- /dev/null
+++ b/lib/py/src/TRecursive.py
@@ -0,0 +1,83 @@
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+from thrift.Thrift import TType
+
+TYPE_IDX = 1
+SPEC_ARGS_IDX = 3
+SPEC_ARGS_CLASS_REF_IDX = 0
+SPEC_ARGS_THRIFT_SPEC_IDX = 1
+
+
+def fix_spec(all_structs):
+    """Wire up recursive references for all TStruct definitions inside of each thrift_spec."""
+    for struc in all_structs:
+        spec = struc.thrift_spec
+        for thrift_spec in spec:
+            if thrift_spec is None:
+                continue
+            elif thrift_spec[TYPE_IDX] == TType.STRUCT:
+                other = thrift_spec[SPEC_ARGS_IDX][SPEC_ARGS_CLASS_REF_IDX].thrift_spec
+                thrift_spec[SPEC_ARGS_IDX][SPEC_ARGS_THRIFT_SPEC_IDX] = other
+            elif thrift_spec[TYPE_IDX] in (TType.LIST, TType.SET):
+                _fix_list_or_set(thrift_spec[SPEC_ARGS_IDX])
+            elif thrift_spec[TYPE_IDX] == TType.MAP:
+                _fix_map(thrift_spec[SPEC_ARGS_IDX])
+
+
+def _fix_list_or_set(element_type):
+    # For a list or set, the thrift_spec entry looks like, 
+    # (1, TType.LIST, 'lister', (TType.STRUCT, [RecList, None], False), None, ),  # 1
+    # so ``element_type`` will be,
+    # (TType.STRUCT, [RecList, None], False)
+    if element_type[0] == TType.STRUCT:
+        element_type[1][1] = element_type[1][0].thrift_spec
+    elif element_type[0] in (TType.LIST, TType.SET):
+        _fix_list_or_set(element_type[1])
+    elif element_type[0] == TType.MAP:
+        _fix_map(element_type[1])
+
+
+def _fix_map(element_type):
+    # For a map of key -> value type, ``element_type`` will be,
+    # (TType.I16, None, TType.STRUCT, [RecMapBasic, None], False), None, )
+    # which is just a normal struct definition.
+    #
+    # For a map of key -> list / set, ``element_type`` will be,
+    # (TType.I16, None, TType.LIST, (TType.STRUCT, [RecMapList, None], False), False)
+    # and we need to process the 3rd element as a list.
+    # 
+    # For a map of key -> map, ``element_type`` will be,
+    # (TType.I16, None, TType.MAP, (TType.I16, None, TType.STRUCT, 
+    #  [RecMapMap, None], False), False)
+    # and need to process 3rd element as a map.
+
+    # Is the map key a struct?
+    if element_type[0] == TType.STRUCT:
+        element_type[1][1] = element_type[1][0].thrift_spec
+    elif element_type[0] in (TType.LIST, TType.SET):
+        _fix_list_or_set(element_type[1])
+    elif element_type[0] == TType.MAP:
+        _fix_map(element_type[1])
+
+    # Is the map value a struct?
+    if element_type[2] == TType.STRUCT:
+        element_type[3][1] = element_type[3][0].thrift_spec
+    elif element_type[2] in (TType.LIST, TType.SET):
+        _fix_list_or_set(element_type[3])
+    elif element_type[2] == TType.MAP:
+        _fix_map(element_type[3])
diff --git a/lib/py/src/ext/types.cpp b/lib/py/src/ext/types.cpp
index 849ab2f..68443fb 100644
--- a/lib/py/src/ext/types.cpp
+++ b/lib/py/src/ext/types.cpp
@@ -98,13 +98,13 @@
 }
 
 bool parse_struct_args(StructTypeArgs* dest, PyObject* typeargs) {
-  if (PyTuple_Size(typeargs) != 2) {
-    PyErr_SetString(PyExc_TypeError, "expecting tuple of size 2 for struct args");
+  if (PyList_Size(typeargs) != 2) {
+    PyErr_SetString(PyExc_TypeError, "expecting list of size 2 for struct args");
     return false;
   }
 
-  dest->klass = PyTuple_GET_ITEM(typeargs, 0);
-  dest->spec = PyTuple_GET_ITEM(typeargs, 1);
+  dest->klass = PyList_GET_ITEM(typeargs, 0);
+  dest->spec = PyList_GET_ITEM(typeargs, 1);
 
   return true;
 }
diff --git a/lib/py/src/protocol/TBase.py b/lib/py/src/protocol/TBase.py
index 55da19e..9ae1b11 100644
--- a/lib/py/src/protocol/TBase.py
+++ b/lib/py/src/protocol/TBase.py
@@ -44,14 +44,14 @@
         if (iprot._fast_decode is not None and
                 isinstance(iprot.trans, TTransport.CReadableTransport) and
                 self.thrift_spec is not None):
-            iprot._fast_decode(self, iprot, (self.__class__, self.thrift_spec))
+            iprot._fast_decode(self, iprot, [self.__class__, self.thrift_spec])
         else:
             iprot.readStruct(self, self.thrift_spec)
 
     def write(self, oprot):
         if (oprot._fast_encode is not None and self.thrift_spec is not None):
             oprot.trans.write(
-                oprot._fast_encode(self, (self.__class__, self.thrift_spec)))
+                oprot._fast_encode(self, [self.__class__, self.thrift_spec]))
         else:
             oprot.writeStruct(self, self.thrift_spec)
 
@@ -77,6 +77,6 @@
                 cls.thrift_spec is not None):
             self = cls()
             return iprot._fast_decode(None, iprot,
-                                      (self.__class__, self.thrift_spec))
+                                      [self.__class__, self.thrift_spec])
         else:
             return iprot.readStruct(cls, cls.thrift_spec, True)
diff --git a/test/py/Makefile.am b/test/py/Makefile.am
index f105737..53c1e63 100644
--- a/test/py/Makefile.am
+++ b/test/py/Makefile.am
@@ -25,18 +25,26 @@
 thrift_gen =                                    \
         gen-py/ThriftTest/__init__.py           \
         gen-py/DebugProtoTest/__init__.py \
+        gen-py/Recursive/__init__.py \
         gen-py-default/ThriftTest/__init__.py           \
         gen-py-default/DebugProtoTest/__init__.py \
+        gen-py-default/Recursive/__init__.py \
         gen-py-slots/ThriftTest/__init__.py           \
         gen-py-slots/DebugProtoTest/__init__.py \
+        gen-py-slots/Recursive/__init__.py \
         gen-py-oldstyle/ThriftTest/__init__.py \
         gen-py-oldstyle/DebugProtoTest/__init__.py \
+        gen-py-oldstyle/Recursive/__init__.py \
         gen-py-no_utf8strings/ThriftTest/__init__.py \
         gen-py-no_utf8strings/DebugProtoTest/__init__.py \
+        gen-py-no_utf8strings/Recursive/__init__.py \
         gen-py-dynamic/ThriftTest/__init__.py           \
         gen-py-dynamic/DebugProtoTest/__init__.py \
+        gen-py-dynamic/Recursive/__init__.py \
         gen-py-dynamicslots/ThriftTest/__init__.py           \
-        gen-py-dynamicslots/DebugProtoTest/__init__.py
+        gen-py-dynamicslots/DebugProtoTest/__init__.py \
+        gen-py-dynamicslots/Recursive/__init__.py
+
 
 precross: $(thrift_gen)
 BUILT_SOURCES = $(thrift_gen)
diff --git a/test/py/SerializationTest.py b/test/py/SerializationTest.py
index efe3c6d..b080d87 100755
--- a/test/py/SerializationTest.py
+++ b/test/py/SerializationTest.py
@@ -35,6 +35,11 @@
     Xtruct2,
 )
 
+from Recursive.ttypes import RecTree
+from Recursive.ttypes import RecList
+from Recursive.ttypes import CoRec
+from Recursive.ttypes import CoRec2
+from Recursive.ttypes import VectorTest
 from DebugProtoTest.ttypes import CompactProtoTestStruct, Empty
 from thrift.transport import TTransport
 from thrift.protocol import TBinaryProtocol, TCompactProtocol, TJSONProtocol
@@ -285,6 +290,67 @@
         for value in bad_values:
             self.assertRaises(Exception, self._serialize, value)
 
+    def testRecTree(self):
+        """Ensure recursive tree node can be created."""
+        children = []
+        for idx in range(1, 5):
+            node = RecTree(item=idx, children=None)
+            children.append(node)
+
+        parent = RecTree(item=0, children=children)
+        serde_parent = self._deserialize(RecTree, self._serialize(parent))
+        self.assertEquals(0, serde_parent.item)
+        self.assertEquals(4, len(serde_parent.children))
+        for child in serde_parent.children:
+            # Cannot use assertIsInstance in python 2.6?
+            self.assertTrue(isinstance(child, RecTree))
+
+    def _buildLinkedList(self):
+        head = cur = RecList(item=0)
+        for idx in range(1, 5):
+            node = RecList(item=idx)
+            cur.nextitem = node
+            cur = node
+        return head
+
+    def _collapseLinkedList(self, head):
+        out_list = []
+        cur = head
+        while cur is not None:
+            out_list.append(cur.item)
+            cur = cur.nextitem
+        return out_list
+
+    def testRecList(self):
+        """Ensure recursive linked list can be created."""
+        rec_list = self._buildLinkedList()
+        serde_list = self._deserialize(RecList, self._serialize(rec_list))
+        out_list = self._collapseLinkedList(serde_list)
+        self.assertEquals([0, 1, 2, 3, 4], out_list)
+
+    def testCoRec(self):
+        """Ensure co-recursive structures can be created."""
+        item1 = CoRec()
+        item2 = CoRec2()
+
+        item1.other = item2
+        item2.other = item1
+
+        # NOTE [econner724,2017-06-21]: These objects cannot be serialized as serialization
+        # results in an infinite loop. fbthrift also suffers from this
+        # problem.
+
+    def testRecVector(self):
+        """Ensure a list of recursive nodes can be created."""
+        mylist = [self._buildLinkedList(), self._buildLinkedList()]
+        myvec = VectorTest(lister=mylist)
+
+        serde_vec = self._deserialize(VectorTest, self._serialize(myvec))
+        golden_list = [0, 1, 2, 3, 4]
+        for cur_list in serde_vec.lister:
+            out_list = self._collapseLinkedList(cur_list)
+            self.assertEqual(golden_list, out_list)
+
 
 class NormalBinaryTest(AbstractTest):
     protocol_factory = TBinaryProtocol.TBinaryProtocolFactory()
diff --git a/test/py/generate.cmake b/test/py/generate.cmake
index 44c5357..46263c8 100644
--- a/test/py/generate.cmake
+++ b/test/py/generate.cmake
@@ -20,3 +20,10 @@
 generate(${MY_PROJECT_DIR}/test/DebugProtoTest.thrift py:no_utf8strings gen-py-no_utf8strings)
 generate(${MY_PROJECT_DIR}/test/DebugProtoTest.thrift py:dynamic gen-py-dynamic)
 generate(${MY_PROJECT_DIR}/test/DebugProtoTest.thrift py:dynamic,slots gen-py-dynamicslots)
+
+generate(${MY_PROJECT_DIR}/test/Recursive.thrift py gen-py-default)
+generate(${MY_PROJECT_DIR}/test/Recursive.thrift py:slots gen-py-slots)
+generate(${MY_PROJECT_DIR}/test/Recursive.thrift py:old_style gen-py-oldstyle)
+generate(${MY_PROJECT_DIR}/test/Recursive.thrift py:no_utf8strings gen-py-no_utf8strings)
+generate(${MY_PROJECT_DIR}/test/Recursive.thrift py:dynamic gen-py-dynamic)
+generate(${MY_PROJECT_DIR}/test/Recursive.thrift py:dynamic,slots gen-py-dynamicslots)