THRIFT-4196 Support recursive types in Rust
Client: rs
Patch: Allen George <allen.george@gmail.com>

This closes #1267
diff --git a/.gitignore b/.gitignore
index 88bb9c6..00cf8bb 100644
--- a/.gitignore
+++ b/.gitignore
@@ -273,6 +273,7 @@
 /lib/rs/test/src/base_one.rs
 /lib/rs/test/src/base_two.rs
 /lib/rs/test/src/midlayer.rs
+/lib/rs/test/src/recursive.rs
 /lib/rs/test/src/ultimate.rs
 /lib/rs/*.iml
 /lib/rs/**/*.iml
diff --git a/compiler/cpp/src/thrift/generate/t_rs_generator.cc b/compiler/cpp/src/thrift/generate/t_rs_generator.cc
index 1f1e1d8..89afd7a 100644
--- a/compiler/cpp/src/thrift/generate/t_rs_generator.cc
+++ b/compiler/cpp/src/thrift/generate/t_rs_generator.cc
@@ -260,7 +260,8 @@
   void render_struct_sync_read(const string &struct_name, t_struct *tstruct, t_rs_generator::e_struct_type struct_type);
 
   // Write the rust function that deserializes a single type (i.e. i32 etc.) from its wire representation.
-  void render_type_sync_read(const string &type_var, t_type *ttype);
+  // Set `is_boxed` to `true` if the resulting value should be wrapped in a `Box::new(...)`.
+  void render_type_sync_read(const string &type_var, t_type *ttype, bool is_boxed = false);
 
   // Read the wire representation of a list and convert it to its corresponding rust implementation.
   // The deserialized list is stored in `list_variable`.
@@ -353,12 +354,28 @@
 
   string handler_successful_return_struct(t_function* tfunc);
 
+  // Writes the result of `render_rift_error_struct` wrapped in an `Err(thrift::Error(...))`.
   void render_rift_error(
     const string& error_kind,
     const string& error_struct,
     const string& sub_error_kind,
     const string& error_message
   );
+
+  // Write a thrift::Error variant struct. Error structs take the form:
+  // ```
+  // pub struct error_struct {
+  //   kind: sub_error_kind,
+  //   message: error_message,
+  // }
+  // ```
+  // A concrete example is:
+  // ```
+  //  pub struct ApplicationError {
+  //    kind: ApplicationErrorKind::Unknown,
+  //    message: "This is some error message",
+  //  }
+  // ```
   void render_rift_error_struct(
     const string& error_struct,
     const string& sub_error_kind,
@@ -1858,7 +1875,7 @@
 }
 
 // Construct the rust representation of all supported types from the wire.
-void t_rs_generator::render_type_sync_read(const string &type_var, t_type *ttype) {
+void t_rs_generator::render_type_sync_read(const string &type_var, t_type *ttype, bool is_boxed) {
   if (ttype->is_base_type()) {
     t_base_type* tbase_type = (t_base_type*)ttype;
     switch (tbase_type->get_base()) {
@@ -1891,13 +1908,23 @@
       return;
     }
   } else if (ttype->is_typedef()) {
+    // FIXME: not a fan of separate `is_boxed` parameter
+    // This is problematic because it's an optional parameter, and only comes
+    // into play once. The core issue is that I lose an important piece of type
+    // information (whether the type is a fwd ref) by unwrapping the typedef'd
+    // type and making the recursive call using it. I can't modify or wrap the
+    // generated string after the fact because it's written directly into the file,
+    // so I have to pass this parameter along. Going with this approach because it
+    // seems like the lowest-cost option to easily support recursive types.
     t_typedef* ttypedef = (t_typedef*)ttype;
-    render_type_sync_read(type_var, ttypedef->get_type());
+    render_type_sync_read(type_var, ttypedef->get_type(), ttypedef->is_forward_typedef());
     return;
   } else if (ttype->is_enum() || ttype->is_struct() || ttype->is_xception()) {
+    string read_call(to_rust_type(ttype) + "::read_from_in_protocol(i_prot)?");
+    read_call = is_boxed ? "Box::new(" + read_call + ")" : read_call;
     f_gen_
       << indent()
-      << "let " << type_var << " = " <<  to_rust_type(ttype) << "::read_from_in_protocol(i_prot)?;"
+      << "let " << type_var << " = " <<  read_call << ";"
       << endl;
     return;
   } else if (ttype->is_map()) {
@@ -2979,7 +3006,10 @@
       }
     }
   } else if (ttype->is_typedef()) {
-    return rust_namespace(ttype) + ((t_typedef*)ttype)->get_symbolic();
+    t_typedef* ttypedef = (t_typedef*)ttype;
+    string rust_type = rust_namespace(ttype) + ttypedef->get_symbolic();
+    rust_type =  ttypedef->is_forward_typedef() ? "Box<" + rust_type + ">" : rust_type;
+    return rust_type;
   } else if (ttype->is_enum()) {
     return rust_namespace(ttype) + ttype->get_name();
   } else if (ttype->is_struct() || ttype->is_xception()) {
diff --git a/lib/rs/test/Makefile.am b/lib/rs/test/Makefile.am
index 8896940..87208d7 100644
--- a/lib/rs/test/Makefile.am
+++ b/lib/rs/test/Makefile.am
@@ -19,11 +19,12 @@
 
 THRIFT = $(top_builddir)/compiler/cpp/thrift
 
-stubs: thrifts/Base_One.thrift thrifts/Base_Two.thrift thrifts/Midlayer.thrift thrifts/Ultimate.thrift $(THRIFT)
+stubs: thrifts/Base_One.thrift thrifts/Base_Two.thrift thrifts/Midlayer.thrift thrifts/Ultimate.thrift $(top_builddir)/test/Recursive.thrift $(THRIFT)
 	$(THRIFT) -I ./thrifts -out src --gen rs thrifts/Base_One.thrift
 	$(THRIFT) -I ./thrifts -out src --gen rs thrifts/Base_Two.thrift
 	$(THRIFT) -I ./thrifts -out src --gen rs thrifts/Midlayer.thrift
 	$(THRIFT) -I ./thrifts -out src --gen rs thrifts/Ultimate.thrift
+	$(THRIFT) -out src --gen rs $(top_builddir)/test/Recursive.thrift
 
 check: stubs
 	$(CARGO) build
diff --git a/lib/rs/test/src/bin/kitchen_sink_client.rs b/lib/rs/test/src/bin/kitchen_sink_client.rs
index 9738298..fb6ea15 100644
--- a/lib/rs/test/src/bin/kitchen_sink_client.rs
+++ b/lib/rs/test/src/bin/kitchen_sink_client.rs
@@ -21,8 +21,12 @@
 extern crate kitchen_sink;
 extern crate thrift;
 
+use std::convert::Into;
+
 use kitchen_sink::base_two::{TNapkinServiceSyncClient, TRamenServiceSyncClient};
 use kitchen_sink::midlayer::{MealServiceSyncClient, TMealServiceSyncClient};
+use kitchen_sink::recursive;
+use kitchen_sink::recursive::{CoRec, CoRec2, RecList, RecTree, TTestServiceSyncClient};
 use kitchen_sink::ultimate::{FullMealServiceSyncClient, TFullMealServiceSyncClient};
 use thrift::transport::{ReadHalf, TFramedReadTransport, TFramedWriteTransport, TIoChannel,
                         TTcpChannel, WriteHalf};
@@ -47,7 +51,7 @@
         (@arg host: --host +takes_value "Host on which the Thrift test server is located")
         (@arg port: --port +takes_value "Port on which the Thrift test server is listening")
         (@arg protocol: --protocol +takes_value "Thrift protocol implementation to use (\"binary\", \"compact\")")
-        (@arg service: --service +takes_value "Service type to contact (\"part\", \"full\")")
+        (@arg service: --service +takes_value "Service type to contact (\"part\", \"full\", \"recursive\")")
     )
             .get_matches();
 
@@ -80,8 +84,9 @@
     o_prot: Box<TOutputProtocol>,
 ) -> thrift::Result<()> {
     match service {
-        "full" => run_full_meal_service(i_prot, o_prot),
-        "part" => run_meal_service(i_prot, o_prot),
+        "full" => exec_full_meal_client(i_prot, o_prot),
+        "part" => exec_meal_client(i_prot, o_prot),
+        "recursive" => exec_recursive_client(i_prot, o_prot),
         _ => Err(thrift::Error::from(format!("unknown service type {}", service)),),
     }
 }
@@ -95,7 +100,7 @@
     c.split()
 }
 
-fn run_meal_service(
+fn exec_meal_client(
     i_prot: Box<TInputProtocol>,
     o_prot: Box<TOutputProtocol>,
 ) -> thrift::Result<()> {
@@ -105,28 +110,155 @@
     // this is because the MealService struct does not contain the appropriate service marker
 
     // only the following three calls work
-    execute_call("part", "ramen", || client.ramen(50))?;
-    execute_call("part", "meal", || client.meal())?;
-    execute_call("part", "napkin", || client.napkin())?;
+    execute_call("part", "ramen", || client.ramen(50))
+        .map(|_| ())?;
+    execute_call("part", "meal", || client.meal())
+        .map(|_| ())?;
+    execute_call("part", "napkin", || client.napkin())
+        .map(|_| ())?;
 
     Ok(())
 }
 
-fn run_full_meal_service(
+fn exec_full_meal_client(
     i_prot: Box<TInputProtocol>,
     o_prot: Box<TOutputProtocol>,
 ) -> thrift::Result<()> {
     let mut client = FullMealServiceSyncClient::new(i_prot, o_prot);
 
-    execute_call("full", "ramen", || client.ramen(100))?;
-    execute_call("full", "meal", || client.meal())?;
-    execute_call("full", "napkin", || client.napkin())?;
-    execute_call("full", "full meal", || client.full_meal())?;
+    execute_call("full", "ramen", || client.ramen(100))
+        .map(|_| ())?;
+    execute_call("full", "meal", || client.meal())
+        .map(|_| ())?;
+    execute_call("full", "napkin", || client.napkin())
+        .map(|_| ())?;
+    execute_call("full", "full meal", || client.full_meal())
+        .map(|_| ())?;
 
     Ok(())
 }
 
-fn execute_call<F, R>(service_type: &str, call_name: &str, mut f: F) -> thrift::Result<()>
+fn exec_recursive_client(
+    i_prot: Box<TInputProtocol>,
+    o_prot: Box<TOutputProtocol>,
+) -> thrift::Result<()> {
+    let mut client = recursive::TestServiceSyncClient::new(i_prot, o_prot);
+
+    let tree = RecTree {
+        children: Some(
+            vec![
+                Box::new(
+                    RecTree {
+                        children: Some(
+                            vec![
+                                Box::new(
+                                    RecTree {
+                                        children: None,
+                                        item: Some(3),
+                                    },
+                                ),
+                                Box::new(
+                                    RecTree {
+                                        children: None,
+                                        item: Some(4),
+                                    },
+                                ),
+                            ],
+                        ),
+                        item: Some(2),
+                    },
+                ),
+            ],
+        ),
+        item: Some(1),
+    };
+
+    let expected_tree = RecTree {
+        children: Some(
+            vec![
+                Box::new(
+                    RecTree {
+                        children: Some(
+                            vec![
+                                Box::new(
+                                    RecTree {
+                                        children: Some(Vec::new()), // remote returns an empty list
+                                        item: Some(3),
+                                    },
+                                ),
+                                Box::new(
+                                    RecTree {
+                                        children: Some(Vec::new()), // remote returns an empty list
+                                        item: Some(4),
+                                    },
+                                ),
+                            ],
+                        ),
+                        item: Some(2),
+                    },
+                ),
+            ],
+        ),
+        item: Some(1),
+    };
+
+    let returned_tree = execute_call("recursive", "echo_tree", || client.echo_tree(tree.clone()))?;
+    if returned_tree != expected_tree {
+        return Err(
+            format!(
+                "mismatched recursive tree {:?} {:?}",
+                expected_tree,
+                returned_tree
+            )
+                    .into(),
+        );
+    }
+
+    let list = RecList {
+        nextitem: Some(
+            Box::new(
+                RecList {
+                    nextitem: Some(
+                        Box::new(
+                            RecList {
+                                nextitem: None,
+                                item: Some(3),
+                            },
+                        ),
+                    ),
+                    item: Some(2),
+                },
+            ),
+        ),
+        item: Some(1),
+    };
+    let returned_list = execute_call("recursive", "echo_list", || client.echo_list(list.clone()))?;
+    if returned_list != list {
+        return Err(format!("mismatched recursive list {:?} {:?}", list, returned_list).into(),);
+    }
+
+    let co_rec = CoRec {
+        other: Some(
+            Box::new(
+                CoRec2 {
+                    other: Some(CoRec { other: Some(Box::new(CoRec2 { other: None })) }),
+                },
+            ),
+        ),
+    };
+    let returned_co_rec = execute_call(
+        "recursive",
+        "echo_co_rec",
+        || client.echo_co_rec(co_rec.clone()),
+    )?;
+    if returned_co_rec != co_rec {
+        return Err(format!("mismatched co_rec {:?} {:?}", co_rec, returned_co_rec).into(),);
+    }
+
+    Ok(())
+}
+
+fn execute_call<F, R>(service_type: &str, call_name: &str, mut f: F) -> thrift::Result<R>
 where
     F: FnMut() -> thrift::Result<R>,
 {
@@ -144,5 +276,5 @@
         }
     }
 
-    res.map(|_| ())
+    res
 }
diff --git a/lib/rs/test/src/bin/kitchen_sink_server.rs b/lib/rs/test/src/bin/kitchen_sink_server.rs
index 19112cd..15ceb29 100644
--- a/lib/rs/test/src/bin/kitchen_sink_server.rs
+++ b/lib/rs/test/src/bin/kitchen_sink_server.rs
@@ -24,6 +24,7 @@
 use kitchen_sink::base_one::Noodle;
 use kitchen_sink::base_two::{Napkin, NapkinServiceSyncHandler, Ramen, RamenServiceSyncHandler};
 use kitchen_sink::midlayer::{Dessert, Meal, MealServiceSyncHandler, MealServiceSyncProcessor};
+use kitchen_sink::recursive;
 use kitchen_sink::ultimate::{Drink, FullMeal, FullMealAndDrinks,
                              FullMealAndDrinksServiceSyncProcessor, FullMealServiceSyncHandler};
 use kitchen_sink::ultimate::FullMealAndDrinksServiceSyncHandler;
@@ -52,7 +53,7 @@
         (about: "Thrift Rust kitchen sink test server")
         (@arg port: --port +takes_value "port on which the test server listens")
         (@arg protocol: --protocol +takes_value "Thrift protocol implementation to use (\"binary\", \"compact\")")
-        (@arg service: --service +takes_value "Service type to contact (\"part\", \"full\")")
+        (@arg service: --service +takes_value "Service type to contact (\"part\", \"full\", \"recursive\")")
     )
             .get_matches();
 
@@ -111,6 +112,15 @@
                 o_protocol_factory,
             )
         }
+        "recursive" => {
+            run_recursive_server(
+                &listen_address,
+                r_transport_factory,
+                i_protocol_factory,
+                w_transport_factory,
+                o_protocol_factory,
+            )
+        }
         unknown => Err(format!("unsupported service type {}", unknown).into()),
     }
 }
@@ -248,3 +258,47 @@
 fn napkin() -> Napkin {
     Napkin {}
 }
+
+fn run_recursive_server<RTF, IPF, WTF, OPF>(
+    listen_address: &str,
+    r_transport_factory: RTF,
+    i_protocol_factory: IPF,
+    w_transport_factory: WTF,
+    o_protocol_factory: OPF,
+) -> thrift::Result<()>
+where
+    RTF: TReadTransportFactory + 'static,
+    IPF: TInputProtocolFactory + 'static,
+    WTF: TWriteTransportFactory + 'static,
+    OPF: TOutputProtocolFactory + 'static,
+{
+    let processor = recursive::TestServiceSyncProcessor::new(RecursiveTestServerHandler {});
+    let mut server = TServer::new(
+        r_transport_factory,
+        i_protocol_factory,
+        w_transport_factory,
+        o_protocol_factory,
+        processor,
+        1,
+    );
+
+    server.listen(listen_address)
+}
+
+struct RecursiveTestServerHandler;
+impl recursive::TestServiceSyncHandler for RecursiveTestServerHandler {
+    fn handle_echo_tree(&self, tree: recursive::RecTree) -> thrift::Result<recursive::RecTree> {
+        println!("{:?}", tree);
+        Ok(tree)
+    }
+
+    fn handle_echo_list(&self, lst: recursive::RecList) -> thrift::Result<recursive::RecList> {
+        println!("{:?}", lst);
+        Ok(lst)
+    }
+
+    fn handle_echo_co_rec(&self, item: recursive::CoRec) -> thrift::Result<recursive::CoRec> {
+        println!("{:?}", item);
+        Ok(item)
+    }
+}
diff --git a/lib/rs/test/src/lib.rs b/lib/rs/test/src/lib.rs
index 53f4873..e5e176e 100644
--- a/lib/rs/test/src/lib.rs
+++ b/lib/rs/test/src/lib.rs
@@ -23,6 +23,7 @@
 pub mod base_two;
 pub mod midlayer;
 pub mod ultimate;
+pub mod recursive;
 
 #[cfg(test)]
 mod tests {