THRIFT-2945 Add Rust support
Client: Rust
Patch: Allen George <allen.george@gmail.com>
This closes #1147
diff --git a/.gitignore b/.gitignore
index 9d2463e..0a98a13 100644
--- a/.gitignore
+++ b/.gitignore
@@ -261,6 +261,17 @@
/lib/erl/logs/
/lib/go/test/gopath/
/lib/go/test/ThriftTest.thrift
+/lib/rs/target/
+/lib/rs/Cargo.lock
+/lib/rs/test/Cargo.lock
+/lib/rs/test/target/
+/lib/rs/test/bin/
+/lib/rs/test/src/base_one.rs
+/lib/rs/test/src/base_two.rs
+/lib/rs/test/src/midlayer.rs
+/lib/rs/test/src/ultimate.rs
+/lib/rs/*.iml
+/lib/rs/**/*.iml
/libtool
/ltmain.sh
/missing
@@ -300,6 +311,12 @@
/test/netcore/**/obj
/test/netcore/**/gen-*
/test/netcore/Thrift
+/test/rs/Cargo.lock
+/test/rs/src/thrift_test.rs
+/test/rs/bin/
+/test/rs/target/
+/test/rs/*.iml
+/test/rs/**/*.iml
/tutorial/cpp/TutorialClient
/tutorial/cpp/TutorialServer
/tutorial/c_glib/tutorial_client
@@ -338,4 +355,10 @@
/tutorial/netcore/**/obj
/tutorial/netcore/**/gen-*
/tutorial/netcore/Thrift
+/tutorial/rs/*.iml
+/tutorial/rs/src/shared.rs
+/tutorial/rs/src/tutorial.rs
+/tutorial/rs/bin
+/tutorial/rs/target
+/tutorial/rs/Cargo.lock
/ylwrap
diff --git a/Makefile.am b/Makefile.am
index ed58265..89a0adc 100755
--- a/Makefile.am
+++ b/Makefile.am
@@ -54,7 +54,7 @@
space := $(empty) $(empty)
comma := ,
-CROSS_LANGS = @MAYBE_CPP@ @MAYBE_C_GLIB@ @MAYBE_D@ @MAYBE_JAVA@ @MAYBE_CSHARP@ @MAYBE_PYTHON@ @MAYBE_PY3@ @MAYBE_RUBY@ @MAYBE_HASKELL@ @MAYBE_PERL@ @MAYBE_PHP@ @MAYBE_GO@ @MAYBE_NODEJS@ @MAYBE_DART@ @MAYBE_ERLANG@ @MAYBE_LUA@
+CROSS_LANGS = @MAYBE_CPP@ @MAYBE_C_GLIB@ @MAYBE_D@ @MAYBE_JAVA@ @MAYBE_CSHARP@ @MAYBE_PYTHON@ @MAYBE_PY3@ @MAYBE_RUBY@ @MAYBE_HASKELL@ @MAYBE_PERL@ @MAYBE_PHP@ @MAYBE_GO@ @MAYBE_NODEJS@ @MAYBE_DART@ @MAYBE_ERLANG@ @MAYBE_LUA@ @MAYBE_RS@
CROSS_LANGS_COMMA_SEPARATED = $(subst $(space),$(comma),$(CROSS_LANGS))
if WITH_PY3
diff --git a/compiler/cpp/CMakeLists.txt b/compiler/cpp/CMakeLists.txt
index 9f7585d..8e861e4 100644
--- a/compiler/cpp/CMakeLists.txt
+++ b/compiler/cpp/CMakeLists.txt
@@ -101,6 +101,7 @@
THRIFT_ADD_COMPILER(d "Enable compiler for D" ON)
THRIFT_ADD_COMPILER(lua "Enable compiler for Lua" ON)
THRIFT_ADD_COMPILER(gv "Enable compiler for GraphViz" ON)
+THRIFT_ADD_COMPILER(rs "Enable compiler for Rust" ON)
THRIFT_ADD_COMPILER(xml "Enable compiler for XML" ON)
# Thrift is looking for include files in the src directory
diff --git a/compiler/cpp/Makefile.am b/compiler/cpp/Makefile.am
index 5d424b4..5082033 100644
--- a/compiler/cpp/Makefile.am
+++ b/compiler/cpp/Makefile.am
@@ -107,7 +107,8 @@
src/thrift/generate/t_go_generator.cc \
src/thrift/generate/t_gv_generator.cc \
src/thrift/generate/t_d_generator.cc \
- src/thrift/generate/t_lua_generator.cc
+ src/thrift/generate/t_lua_generator.cc \
+ src/thrift/generate/t_rs_generator.cc
thrift_CPPFLAGS = -I$(srcdir)/src
thrift_CXXFLAGS = -Wall -Wextra -pedantic
diff --git a/compiler/cpp/compiler.vcxproj b/compiler/cpp/compiler.vcxproj
index 1e86360..4b03253 100644
--- a/compiler/cpp/compiler.vcxproj
+++ b/compiler/cpp/compiler.vcxproj
@@ -79,6 +79,7 @@
<ClCompile Include="src\thrift\generate\t_php_generator.cc" />
<ClCompile Include="src\thrift\generate\t_py_generator.cc" />
<ClCompile Include="src\thrift\generate\t_rb_generator.cc" />
+ <ClCompile Include="src\thrift\generate\t_rs_generator.cc" />
<ClCompile Include="src\thrift\generate\t_st_generator.cc" />
<ClCompile Include="src\thrift\generate\t_swift_generator.cc" />
<ClCompile Include="src\thrift\generate\t_xml_generator.cc" />
diff --git a/compiler/cpp/compiler.vcxproj.filters b/compiler/cpp/compiler.vcxproj.filters
index 9b14bbf..b96865b 100644
--- a/compiler/cpp/compiler.vcxproj.filters
+++ b/compiler/cpp/compiler.vcxproj.filters
@@ -161,6 +161,9 @@
<ClCompile Include="src\generate\t_rb_generator.cc">
<Filter>generate</Filter>
</ClCompile>
+ <ClCompile Include="src\generate\t_rs_generator.cc">
+ <Filter>generate</Filter>
+ </ClCompile>
<ClCompile Include="src\generate\t_st_generator.cc">
<Filter>generate</Filter>
</ClCompile>
@@ -193,4 +196,4 @@
<None Include="src\thriftl.ll" />
<None Include="src\thrifty.yy" />
</ItemGroup>
-</Project>
\ No newline at end of file
+</Project>
diff --git a/compiler/cpp/src/thrift/generate/t_rs_generator.cc b/compiler/cpp/src/thrift/generate/t_rs_generator.cc
new file mode 100644
index 0000000..5cd304b
--- /dev/null
+++ b/compiler/cpp/src/thrift/generate/t_rs_generator.cc
@@ -0,0 +1,3164 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <string>
+#include <fstream>
+#include <iostream>
+
+#include "thrift/platform.h"
+#include "thrift/generate/t_generator.h"
+
+using std::map;
+using std::ofstream;
+using std::ostringstream;
+using std::string;
+using std::vector;
+using std::set;
+
+static const string endl = "\n"; // avoid ostream << std::endl flushes
+
+static const string SERVICE_RESULT_VARIABLE = "result_value";
+static const string RESULT_STRUCT_SUFFIX = "Result";
+static const string RUST_RESERVED_WORDS[] = {
+ "abstract", "alignof", "as", "become",
+ "box", "break", "const", "continue",
+ "crate", "do", "else", "enum",
+ "extern", "false", "final", "fn",
+ "for", "if", "impl", "in",
+ "let", "loop", "macro", "match",
+ "mod", "move", "mut", "offsetof",
+ "override", "priv", "proc", "pub",
+ "pure", "ref", "return", "Self",
+ "self", "sizeof", "static", "struct",
+ "super", "trait", "true", "type",
+ "typeof", "unsafe", "unsized", "use",
+ "virtual", "where", "while", "yield"
+};
+const set<string> RUST_RESERVED_WORDS_SET(
+ RUST_RESERVED_WORDS,
+ RUST_RESERVED_WORDS + sizeof(RUST_RESERVED_WORDS)/sizeof(RUST_RESERVED_WORDS[0])
+);
+
+// FIXME: extract common TMessageIdentifier function
+// FIXME: have to_rust_type deal with Option
+
+class t_rs_generator : public t_generator {
+public:
+ t_rs_generator(
+ t_program* program,
+ const std::map<std::string, std::string>&,
+ const std::string&
+ ) : t_generator(program) {
+ gen_dir_ = get_out_dir();
+ }
+
+ /**
+ * Init and close methods
+ */
+
+ void init_generator();
+ void close_generator();
+
+ /**
+ * Program-level generation functions
+ */
+
+ void generate_typedef(t_typedef* ttypedef);
+ void generate_enum(t_enum* tenum);
+ void generate_const(t_const* tconst);
+ void generate_struct(t_struct* tstruct);
+ void generate_xception(t_struct* txception);
+ void generate_service(t_service* tservice);
+
+private:
+ // struct type
+ // T_REGULAR: user-defined struct in the IDL
+ // T_ARGS: struct used to hold all service-call parameters
+ // T_RESULT: struct used to hold all service-call returns and exceptions
+ // T_EXCEPTION: user-defined exception in the IDL
+ enum e_struct_type { T_REGULAR, T_ARGS, T_RESULT, T_EXCEPTION };
+
+ // Directory to which generated code is written.
+ string gen_dir_;
+
+ // File to which generated code is written.
+ std::ofstream f_gen_;
+
+ // Write the common compiler attributes and module includes to the top of the auto-generated file.
+ void render_attributes_and_includes();
+
+ // Create the closure of Rust modules referenced by this service.
+ void compute_service_referenced_modules(t_service *tservice, set<string> &referenced_modules);
+
+ // Write the rust representation of an enum.
+ void render_enum_definition(t_enum* tenum, const string& enum_name);
+
+ // Write the impl blocks associated with the traits necessary to convert an enum to/from an i32.
+ void render_enum_conversion(t_enum* tenum, const string& enum_name);
+
+ // Write the impl block associated with the rust representation of an enum. This includes methods
+ // to write the enum to a protocol, read it from a protocol, etc.
+ void render_enum_impl(const string& enum_name);
+
+ // Write a simple rust const value (ie. `pub const FOO: foo...`).
+ void render_const_value(const string& name, t_type* ttype, t_const_value* tvalue);
+
+ // Write a constant list, set, map or struct. These constants require allocation and cannot be defined
+ // using a 'pub const'. As a result, I create a holder struct with a single `const_value` method that
+ // returns the initialized instance.
+ void render_const_value_holder(const string& name, t_type* ttype, t_const_value* tvalue);
+
+ // Write the actual const value - the right side of a const definition.
+ void render_const_value(t_type* ttype, t_const_value* tvalue);
+
+ // Write a const struct (returned from `const_value` method).
+ void render_const_struct(t_type* ttype, t_const_value* tvalue);
+
+ // Write a const list (returned from `const_value` method).
+ void render_const_list(t_type* ttype, t_const_value* tvalue);
+
+ // Write a const set (returned from `const_value` method).
+ void render_const_set(t_type* ttype, t_const_value* tvalue);
+
+ // Write a const map (returned from `const_value` method).
+ void render_const_map(t_type* ttype, t_const_value* tvalue);
+
+ // Write the code to insert constant values into a rust vec or set. The
+ // `insert_function` is the rust function that we'll use to insert the elements.
+ void render_container_const_value(
+ const string& insert_function,
+ t_type* ttype,
+ t_const_value* tvalue
+ );
+
+ // Write the rust representation of a thrift struct to the generated file. Set `struct_type` to `T_ARGS`
+ // if rendering the struct used to pack arguments for a service call. When `struct_type` is `T_ARGS` the
+ // struct and its members have module visibility, and all fields are required. When `struct_type` is
+ // anything else the struct and its members have public visibility and fields have the visibility set
+ // in their definition.
+ void render_struct(const string& struct_name, t_struct* tstruct, t_rs_generator::e_struct_type struct_type);
+
+ // Write the comment block preceding a type definition (and implementation).
+ void render_type_comment(const string& struct_name);
+
+ // Write the rust representation of a thrift struct. Supports argument structs, result structs,
+ // user-defined structs and exception structs. The exact struct type to be generated is controlled
+ // by the `struct_type` parameter, which (among other things) modifies the visibility of the
+ // written struct and members, controls which trait implementations are generated.
+ void render_struct_definition(
+ const string& struct_name,
+ t_struct* tstruct,
+ t_rs_generator::e_struct_type struct_type
+ );
+
+ // Writes the impl block associated with the rust representation of a struct. At minimum this
+ // contains the methods to read from a protocol and write to a protocol. Additional methods may
+ // be generated depending on `struct_type`.
+ void render_struct_impl(
+ const string& struct_name,
+ t_struct* tstruct,
+ t_rs_generator::e_struct_type struct_type
+ );
+
+ // Generate a `fn new(...)` for a struct with name `struct_name` and type `t_struct`. The auto-generated
+ // code may include generic type parameters to make the constructor more ergonomic. `struct_type` controls
+ // the visibility of the generated constructor.
+ void render_struct_constructor(
+ const string& struct_name,
+ t_struct* tstruct,
+ t_rs_generator::e_struct_type struct_type
+ );
+
+ // Write the `ok_or` method added to all Thrift service call result structs. You can use this method
+ // to convert a struct into a `Result` and use it in a `try!` or combinator chain.
+ void render_result_struct_to_result_method(t_struct* tstruct);
+
+ // Write the implementations for the `Error` and `Debug` traits. These traits are necessary for a
+ // user-defined exception to be properly handled as Rust errors.
+ void render_exception_struct_error_trait_impls(const string& struct_name, t_struct* tstruct);
+
+ // Write the implementations for the `Default`. This trait allows you to specify only the fields you want
+ // and use `..Default::default()` to fill in the rest.
+ void render_struct_default_trait_impl(const string& struct_name, t_struct* tstruct);
+
+ // Write the function that serializes a struct to its wire representation. If `struct_type` is `T_ARGS`
+ // then all fields are considered "required", if not, the default optionality is used.
+ void render_struct_sync_write(t_struct *tstruct, t_rs_generator::e_struct_type struct_type);
+
+ // Helper function that serializes a single struct field to its wire representation. Unpacks the
+ // variable (since it may be optional) and serializes according to the optionality rules required by `req`.
+ // Variables in auto-generated code are passed by reference. Since this function may be called in
+ // contexts where the variable is *already* a reference you can set `field_var_is_ref` to `true` to avoid
+ // generating an extra, unnecessary `&` that the compiler will have to automatically dereference.
+ void render_struct_field_sync_write(
+ const string &field_var,
+ bool field_var_is_ref,
+ t_field *tfield,
+ t_field::e_req req);
+
+ // Write the rust function that serializes a single type (i.e. a i32 etc.) to its wire representation.
+ // Variables in auto-generated code are passed by reference. Since this function may be called in
+ // contexts where the variable is *already* a reference you can set `type_var_is_ref` to `true` to avoid
+ // generating an extra, unnecessary `&` that the compiler will have to automatically dereference.
+ void render_type_sync_write(const string &type_var, bool type_var_is_ref, t_type *ttype);
+
+ // Write a list to the output protocol. `list_variable` is the variable containing the list
+ // that will be written to the output protocol.
+ // Variables in auto-generated code are passed by reference. Since this function may be called in
+ // contexts where the variable is *already* a reference you can set `list_var_is_ref` to `true` to avoid
+ // generating an extra, unnecessary `&` that the compiler will have to automatically dereference.
+ void render_list_sync_write(const string &list_var, bool list_var_is_ref, t_list *tlist);
+
+ // Write a set to the output protocol. `set_variable` is the variable containing the set that will
+ // be written to the output protocol.
+ // Variables in auto-generated code are passed by reference. Since this function may be called in
+ // contexts where the variable is *already* a reference you can set `set_var_is_ref` to `true` to avoid
+ // generating an extra, unnecessary `&` that the compiler will have to automatically dereference.
+ void render_set_sync_write(const string &set_var, bool set_var_is_ref, t_set *tset);
+
+ // Write a map to the output protocol. `map_variable` is the variable containing the map that will
+ // be written to the output protocol.
+ // Variables in auto-generated code are passed by reference. Since this function may be called in
+ // contexts where the variable is *already* a reference you can set `map_var_is_ref` to `true` to avoid
+ // generating an extra, unnecessary `&` that the compiler will have to automatically dereference.
+ void render_map_sync_write(const string &map_var, bool map_var_is_ref, t_map *tset);
+
+ // Return `true` if we need to dereference ths type when writing an element from a container.
+ // Iterations on rust containers are performed as follows: `for v in &values { ... }`
+ // where `v` has type `&RUST_TYPE` All defined functions take primitives by value, so, if the
+ // rendered code is calling such a function it has to dereference `v`.
+ bool needs_deref_on_container_write(t_type* ttype);
+
+ // Write the code to read bytes from the wire into the given `t_struct`. `struct_name` is the
+ // actual Rust name of the `t_struct`. If `struct_type` is `T_ARGS` then all struct fields are
+ // necessary. Otherwise, the field's default optionality is used.
+ void render_struct_sync_read(const string &struct_name, t_struct *tstruct, t_rs_generator::e_struct_type struct_type);
+
+ // Write the rust function that deserializes a single type (i.e. i32 etc.) from its wire representation.
+ void render_type_sync_read(const string &type_var, t_type *ttype);
+
+ // Read the wire representation of a list and convert it to its corresponding rust implementation.
+ // The deserialized list is stored in `list_variable`.
+ void render_list_sync_read(t_list *tlist, const string &list_variable);
+
+ // Read the wire representation of a set and convert it to its corresponding rust implementation.
+ // The deserialized set is stored in `set_variable`.
+ void render_set_sync_read(t_set *tset, const string &set_variable);
+
+ // Read the wire representation of a map and convert it to its corresponding rust implementation.
+ // The deserialized map is stored in `map_variable`.
+ void render_map_sync_read(t_map *tmap, const string &map_variable);
+
+ // Return a temporary variable used to store values when deserializing nested containers.
+ string struct_field_read_temp_variable(t_field* tfield);
+
+ // Top-level function that calls the various render functions necessary to write the rust representation
+ // of a thrift union (i.e. an enum).
+ void render_union(t_struct* tstruct);
+
+ // Write the enum corresponding to the Thrift union.
+ void render_union_definition(const string& union_name, t_struct* tstruct);
+
+ // Write the enum impl (with read/write functions) for the Thrift union.
+ void render_union_impl(const string& union_name, t_struct* tstruct);
+
+ // Write the `ENUM::write_to_out_protocol` function.
+ void render_union_sync_write(const string &union_name, t_struct *tstruct);
+
+ // Write the `ENUM::read_from_in_protocol` function.
+ void render_union_sync_read(const string &union_name, t_struct *tstruct);
+
+ // Top-level function that calls the various render functions necessary to write the rust representation
+ // of a Thrift client.
+ void render_sync_client(t_service* tservice);
+
+ // Write the trait with the service-call methods for `tservice`.
+ void render_sync_client_trait(t_service *tservice);
+
+ // Write the trait to be implemented by the client impl if end users can use it to make service calls.
+ void render_sync_client_marker_trait(t_service *tservice);
+
+ // Write the code to create the Thrift service sync client struct and its matching 'impl' block.
+ void render_sync_client_definition_and_impl(const string& client_impl_name);
+
+ // Write the code to create the `SyncClient::new` functions as well as any other functions
+ // callers would like to use on the Thrift service sync client.
+ void render_sync_client_lifecycle_functions(const string& client_struct);
+
+ // Write the code to create the impl block for the `TThriftClient` trait. Since generated
+ // Rust Thrift clients perform all their operations using methods defined in this trait, we
+ // have to implement it for the client structs.
+ void render_sync_client_tthriftclient_impl(const string &client_impl_name);
+
+ // Write the marker traits for any service(s) being extended, including the one for the current
+ // service itself (i.e. `tservice`)
+ void render_sync_client_marker_trait_impls(t_service *tservice, const string &impl_struct_name);
+
+ // Generate a list of all the traits this Thrift client struct extends.
+ string sync_client_marker_traits_for_extension(t_service *tservice);
+
+ // Top-level function that writes the code to make the Thrift service calls.
+ void render_sync_client_process_impl(t_service* tservice);
+
+ // Write the actual function that calls out to the remote service and processes its response.
+ void render_sync_send_recv_wrapper(t_function* tfunc);
+
+ // Write the `send` functionality for a Thrift service call represented by a `t_service->t_function`.
+ void render_sync_send(t_function* tfunc);
+
+ // Write the `recv` functionality for a Thrift service call represented by a `t_service->t_function`.
+ // This method is only rendered if the function is *not* oneway.
+ void render_sync_recv(t_function* tfunc);
+
+ void render_sync_processor(t_service *tservice);
+
+ void render_sync_handler_trait(t_service *tservice);
+ void render_sync_processor_definition_and_impl(t_service *tservice);
+ void render_sync_process_delegation_functions(t_service *tservice);
+ void render_sync_process_function(t_function *tfunc, const string &handler_type);
+ void render_process_match_statements(t_service* tservice);
+ void render_sync_handler_succeeded(t_function *tfunc);
+ void render_sync_handler_failed(t_function *tfunc);
+ void render_sync_handler_failed_user_exception_branch(t_function *tfunc);
+ void render_sync_handler_failed_application_exception_branch(t_function *tfunc, const string &app_err_var);
+ void render_sync_handler_failed_default_exception_branch(t_function *tfunc);
+ void render_sync_handler_send_exception_response(t_function *tfunc, const string &err_var);
+ void render_service_call_structs(t_service* tservice);
+ void render_result_value_struct(t_function* tfunc);
+
+ string handler_successful_return_struct(t_function* tfunc);
+
+ void render_rift_error(
+ const string& error_kind,
+ const string& error_struct,
+ const string& sub_error_kind,
+ const string& error_message
+ );
+ void render_rift_error_struct(
+ const string& error_struct,
+ const string& sub_error_kind,
+ const string& error_message
+ );
+
+ // Return a string containing all the unpacked service call args given a service call function
+ // `t_function`. Prepends the args with `&mut self` and includes the arg types in the returned string,
+ // for example: `fn foo(&mut self, field_0: String)`.
+ string rust_sync_service_call_declaration(t_function* tfunc);
+
+ // Return a string containing all the unpacked service call args given a service call function
+ // `t_function`. Only includes the arg names, each of which is prefixed with the optional prefix
+ // `field_prefix`, for example: `self.field_0`.
+ string rust_sync_service_call_invocation(t_function* tfunc, const string& field_prefix = "");
+
+ // Return a string containing all fields in the struct `tstruct` for use in a function declaration.
+ // Each field is followed by its type, for example: `field_0: String`.
+ string struct_to_declaration(t_struct* tstruct, t_rs_generator::e_struct_type struct_type);
+
+ // Return a string containing all fields in the struct `tstruct` for use in a function call,
+ // for example: `field_0: String`.
+ string struct_to_invocation(t_struct* tstruct, const string& field_prefix = "");
+
+ // Write the documentation for a struct, service-call or other documentation-annotated element.
+ void render_rustdoc(t_doc* tdoc);
+
+ // Return a string representing the rust type given a `t_type`.
+ string to_rust_type(t_type* ttype, bool ordered_float = true);
+
+ // Return a string representing the rift `protocol::TType` given a `t_type`.
+ string to_rust_field_type_enum(t_type* ttype);
+
+ // Return the default value to be used when initializing a struct field which has `OPT_IN_REQ_OUT`
+ // optionality.
+ string opt_in_req_out_value(t_type* ttype);
+
+ // Return `true` if we can write a const of the form `pub const FOO: ...`.
+ bool can_generate_simple_const(t_type* ttype);
+
+ // Return `true` if we cannot write a standard Rust constant (because the type needs some allocation).
+ bool can_generate_const_holder(t_type* ttype);
+
+ // Return `true` if this type is a void, and should be represented by the rust `()` type.
+ bool is_void(t_type* ttype);
+
+ t_field::e_req actual_field_req(t_field* tfield, t_rs_generator::e_struct_type struct_type);
+
+ // Return `true` if this `t_field::e_req` is either `t_field::T_OPTIONAL` or `t_field::T_OPT_IN_REQ_OUT`
+ // and needs to be wrapped by an `Option<TYPE_NAME>`, `false` otherwise.
+ bool is_optional(t_field::e_req req);
+
+ // Return `true` if the service call has arguments, `false` otherwise.
+ bool has_args(t_function* tfunc);
+
+ // Return `true` if a service call has non-`()` arguments, `false` otherwise.
+ bool has_non_void_args(t_function* tfunc);
+
+ // Return `pub ` (notice trailing whitespace!) if the struct should be public, `` (empty string) otherwise.
+ string visibility_qualifier(t_rs_generator::e_struct_type struct_type);
+
+ // Returns the namespace prefix for a given Thrift service. If the type is defined in the presently-computed
+ // Thrift program, then an empty string is returned.
+ string rust_namespace(t_service* tservice);
+
+ // Returns the namespace prefix for a given Thrift type. If the type is defined in the presently-computed
+ // Thrift program, then an empty string is returned.
+ string rust_namespace(t_type* ttype);
+
+ // Returns the camel-cased name for a Rust struct type. Handles the case where `tstruct->get_name()` is
+ // a reserved word.
+ string rust_struct_name(t_struct* tstruct);
+
+ // Returns the snake-cased name for a Rust field or local variable. Handles the case where
+ // `tfield->get_name()` is a reserved word.
+ string rust_field_name(t_field* tstruct);
+
+ // Returns the camel-cased name for a Rust union type. Handles the case where `tstruct->get_name()` is
+ // a reserved word.
+ string rust_union_field_name(t_field* tstruct);
+
+ // Converts any variable name into a 'safe' variant that does not clash with any Rust reserved keywords.
+ string rust_safe_name(const string& name);
+
+ // Return `true` if the name is a reserved Rust keyword, `false` otherwise.
+ bool is_reserved(const string& name);
+
+ // Return the name of the function that users will invoke to make outgoing service calls.
+ string service_call_client_function_name(t_function* tfunc);
+
+ // Return the name of the function that users will have to implement to handle incoming service calls.
+ string service_call_handler_function_name(t_function* tfunc);
+
+ // Return the name of the struct used to pack the return value
+ // and user-defined exceptions for the thrift service call.
+ string service_call_result_struct_name(t_function* tfunc);
+
+ string rust_sync_client_marker_trait_name(t_service* tservice);
+
+ // Return the trait name for the sync service client given a `t_service`.
+ string rust_sync_client_trait_name(t_service* tservice);
+
+ // Return the name for the sync service client struct given a `t_service`.
+ string rust_sync_client_impl_name(t_service* tservice);
+
+ // Return the trait name that users will have to implement for the server half of a Thrift service.
+ string rust_sync_handler_trait_name(t_service* tservice);
+
+ // Return the struct name for the server half of a Thrift service.
+ string rust_sync_processor_name(t_service* tservice);
+
+ // Return the struct name for the struct that contains all the service-call implementations for
+ // the server half of a Thrift service.
+ string rust_sync_processor_impl_name(t_service *tservice);
+
+ // Properly uppercase names for use in Rust.
+ string rust_upper_case(const string& name);
+
+ // Snake-case field, parameter and function names and make them Rust friendly.
+ string rust_snake_case(const string& name);
+
+ // Camel-case type/variant names and make them Rust friendly.
+ string rust_camel_case(const string& name);
+
+ // Replace all instances of `search_string` with `replace_string` in `target`.
+ void string_replace(string& target, const string& search_string, const string& replace_string);
+};
+
+void t_rs_generator::init_generator() {
+ // make output directory for this thrift program
+ MKDIR(gen_dir_.c_str());
+
+ // create the file into which we're going to write the generated code
+ string f_gen_name = gen_dir_ + "/" + rust_snake_case(get_program()->get_name()) + ".rs";
+ f_gen_.open(f_gen_name.c_str());
+
+ // header comment
+ f_gen_ << "// " << autogen_summary() << endl;
+ f_gen_ << "// DO NOT EDIT UNLESS YOU ARE SURE THAT YOU KNOW WHAT YOU ARE DOING" << endl;
+ f_gen_ << endl;
+
+ render_attributes_and_includes();
+}
+
+void t_rs_generator::render_attributes_and_includes() {
+ // turn off some compiler/clippy warnings
+
+ // code always includes BTreeMap/BTreeSet/OrderedFloat
+ f_gen_ << "#![allow(unused_imports)]" << endl;
+ // constructors take *all* struct parameters, which can trigger the "too many arguments" warning
+ // some auto-gen'd types can be deeply nested. clippy recommends factoring them out which is hard to autogen
+ f_gen_ << "#![cfg_attr(feature = \"cargo-clippy\", allow(too_many_arguments, type_complexity))]" << endl;
+ f_gen_ << endl;
+
+ // add standard includes
+ f_gen_ << "extern crate ordered_float;" << endl;
+ f_gen_ << "extern crate thrift;" << endl;
+ f_gen_ << "extern crate try_from;" << endl;
+ f_gen_ << endl;
+ f_gen_ << "use ordered_float::OrderedFloat;" << endl;
+ f_gen_ << "use std::cell::RefCell;" << endl;
+ f_gen_ << "use std::collections::{BTreeMap, BTreeSet};" << endl;
+ f_gen_ << "use std::convert::From;" << endl;
+ f_gen_ << "use std::default::Default;" << endl;
+ f_gen_ << "use std::error::Error;" << endl;
+ f_gen_ << "use std::fmt;" << endl;
+ f_gen_ << "use std::fmt::{Display, Formatter};" << endl;
+ f_gen_ << "use std::rc::Rc;" << endl;
+ f_gen_ << "use try_from::TryFrom;" << endl;
+ f_gen_ << endl;
+ f_gen_ << "use thrift::{ApplicationError, ApplicationErrorKind, ProtocolError, ProtocolErrorKind, TThriftClient};" << endl;
+ f_gen_ << "use thrift::protocol::{TFieldIdentifier, TListIdentifier, TMapIdentifier, TMessageIdentifier, TMessageType, TInputProtocol, TOutputProtocol, TSetIdentifier, TStructIdentifier, TType};" << endl;
+ f_gen_ << "use thrift::protocol::field_id;" << endl;
+ f_gen_ << "use thrift::protocol::verify_expected_message_type;" << endl;
+ f_gen_ << "use thrift::protocol::verify_expected_sequence_number;" << endl;
+ f_gen_ << "use thrift::protocol::verify_expected_service_call;" << endl;
+ f_gen_ << "use thrift::protocol::verify_required_field_exists;" << endl;
+ f_gen_ << "use thrift::server::TProcessor;" << endl;
+ f_gen_ << endl;
+
+ // add all the program includes
+ // NOTE: this is more involved than you would expect because of service extension
+ // Basically, I have to find the closure of all the services and include their modules at the top-level
+
+ set<string> referenced_modules;
+
+ // first, start by adding explicit thrift includes
+ const vector<t_program*> includes = get_program()->get_includes();
+ vector<t_program*>::const_iterator includes_iter;
+ for(includes_iter = includes.begin(); includes_iter != includes.end(); ++includes_iter) {
+ referenced_modules.insert((*includes_iter)->get_name());
+ }
+
+ // next, recursively iterate through all the services and add the names of any programs they reference
+ const vector<t_service*> services = get_program()->get_services();
+ vector<t_service*>::const_iterator service_iter;
+ for (service_iter = services.begin(); service_iter != services.end(); ++service_iter) {
+ compute_service_referenced_modules(*service_iter, referenced_modules);
+ }
+
+ // finally, write all the "pub use..." declarations
+ if (!referenced_modules.empty()) {
+ set<string>::iterator module_iter;
+ for (module_iter = referenced_modules.begin(); module_iter != referenced_modules.end(); ++module_iter) {
+ f_gen_ << "use " << rust_snake_case(*module_iter) << ";" << endl;
+ }
+ f_gen_ << endl;
+ }
+}
+
+void t_rs_generator::compute_service_referenced_modules(
+ t_service *tservice,
+ set<string> &referenced_modules
+) {
+ t_service* extends = tservice->get_extends();
+ if (extends) {
+ if (extends->get_program() != get_program()) {
+ referenced_modules.insert(extends->get_program()->get_name());
+ }
+ compute_service_referenced_modules(extends, referenced_modules);
+ }
+}
+
+void t_rs_generator::close_generator() {
+ f_gen_.close();
+}
+
+//-----------------------------------------------------------------------------
+//
+// Consts
+//
+// NOTE: consider using macros to generate constants
+//
+//-----------------------------------------------------------------------------
+
+// This is worse than it should be because constants
+// aren't (sensibly) limited to scalar types
+void t_rs_generator::generate_const(t_const* tconst) {
+ string name = tconst->get_name();
+ t_type* ttype = tconst->get_type();
+ t_const_value* tvalue = tconst->get_value();
+
+ if (can_generate_simple_const(ttype)) {
+ render_const_value(name, ttype, tvalue);
+ } else if (can_generate_const_holder(ttype)) {
+ render_const_value_holder(name, ttype, tvalue);
+ } else {
+ throw "cannot generate const for " + name;
+ }
+}
+
+void t_rs_generator::render_const_value(const string& name, t_type* ttype, t_const_value* tvalue) {
+ if (!can_generate_simple_const(ttype)) {
+ throw "cannot generate simple rust constant for " + ttype->get_name();
+ }
+
+ f_gen_ << "pub const " << rust_upper_case(name) << ": " << to_rust_type(ttype) << " = ";
+ render_const_value(ttype, tvalue);
+ f_gen_ << ";" << endl;
+ f_gen_ << endl;
+}
+
+void t_rs_generator::render_const_value_holder(const string& name, t_type* ttype, t_const_value* tvalue) {
+ if (!can_generate_const_holder(ttype)) {
+ throw "cannot generate constant holder for " + ttype->get_name();
+ }
+
+ string holder_name("Const" + rust_camel_case(name));
+
+ f_gen_ << indent() << "pub struct " << holder_name << ";" << endl;
+ f_gen_ << indent() << "impl " << holder_name << " {" << endl;
+ indent_up();
+
+ f_gen_ << indent() << "pub fn const_value() -> " << to_rust_type(ttype) << " {" << endl;
+ indent_up();
+ render_const_value(ttype, tvalue);
+ indent_down();
+ f_gen_ << indent() << "}" << endl;
+
+ indent_down();
+ f_gen_ << indent() << "}" << endl;
+ f_gen_ << endl;
+}
+
+void t_rs_generator::render_const_value(t_type* ttype, t_const_value* tvalue) {
+ if (ttype->is_base_type()) {
+ t_base_type* tbase_type = (t_base_type*)ttype;
+ switch (tbase_type->get_base()) {
+ case t_base_type::TYPE_STRING:
+ if (tbase_type->is_binary()) {
+ f_gen_ << "\"" << tvalue->get_string() << "\""<< ".to_owned().into_bytes()";
+ } else {
+ f_gen_ << "\"" << tvalue->get_string() << "\""<< ".to_owned()";
+ }
+ break;
+ case t_base_type::TYPE_BOOL:
+ f_gen_ << (tvalue->get_integer() ? "true" : "false");
+ break;
+ case t_base_type::TYPE_I8:
+ case t_base_type::TYPE_I16:
+ case t_base_type::TYPE_I32:
+ case t_base_type::TYPE_I64:
+ f_gen_ << tvalue->get_integer();
+ break;
+ case t_base_type::TYPE_DOUBLE:
+ f_gen_
+ << indent()
+ << "OrderedFloat::from(" << tvalue->get_double() << ")"
+ << endl;
+ break;
+ default:
+ throw "cannot generate const value for " + t_base_type::t_base_name(tbase_type->get_base());
+ }
+ } else if (ttype->is_typedef()) {
+ render_const_value(get_true_type(ttype), tvalue);
+ } else if (ttype->is_enum()) {
+ f_gen_ << indent() << "{" << endl;
+ indent_up();
+ f_gen_
+ << indent()
+ << to_rust_type(ttype)
+ << "::try_from("
+ << tvalue->get_integer()
+ << ").expect(\"expecting valid const value\")"
+ << endl;
+ indent_down();
+ f_gen_ << indent() << "}" << endl;
+ } else if (ttype->is_struct() || ttype->is_xception()) {
+ render_const_struct(ttype, tvalue);
+ } else if (ttype->is_container()) {
+ f_gen_ << indent() << "{" << endl;
+ indent_up();
+
+ if (ttype->is_list()) {
+ render_const_list(ttype, tvalue);
+ } else if (ttype->is_set()) {
+ render_const_set(ttype, tvalue);
+ } else if (ttype->is_map()) {
+ render_const_map(ttype, tvalue);
+ } else {
+ throw "cannot generate const container value for " + ttype->get_name();
+ }
+
+ indent_down();
+ f_gen_ << indent() << "}" << endl;
+ } else {
+ throw "cannot generate const value for " + ttype->get_name();
+ }
+}
+
+void t_rs_generator::render_const_struct(t_type* ttype, t_const_value*) {
+ if (((t_struct*)ttype)->is_union()) {
+ f_gen_ << indent() << "{" << endl;
+ indent_up();
+ f_gen_ << indent() << "unimplemented!()" << endl;
+ indent_down();
+ f_gen_ << indent() << "}" << endl;
+ } else {
+ f_gen_ << indent() << "{" << endl;
+ indent_up();
+ f_gen_ << indent() << "unimplemented!()" << endl;
+ indent_down();
+ f_gen_ << indent() << "}" << endl;
+ }
+}
+
+void t_rs_generator::render_const_list(t_type* ttype, t_const_value* tvalue) {
+ t_type* elem_type = ((t_list*)ttype)->get_elem_type();
+ f_gen_ << indent() << "let mut l: Vec<" << to_rust_type(elem_type) << "> = Vec::new();" << endl;
+ const vector<t_const_value*>& elems = tvalue->get_list();
+ vector<t_const_value*>::const_iterator elem_iter;
+ for(elem_iter = elems.begin(); elem_iter != elems.end(); ++elem_iter) {
+ t_const_value* elem_value = (*elem_iter);
+ render_container_const_value("l.push", elem_type, elem_value);
+ }
+ f_gen_ << indent() << "l" << endl;
+}
+
+void t_rs_generator::render_const_set(t_type* ttype, t_const_value* tvalue) {
+ t_type* elem_type = ((t_set*)ttype)->get_elem_type();
+ f_gen_ << indent() << "let mut s: BTreeSet<" << to_rust_type(elem_type) << "> = BTreeSet::new();" << endl;
+ const vector<t_const_value*>& elems = tvalue->get_list();
+ vector<t_const_value*>::const_iterator elem_iter;
+ for(elem_iter = elems.begin(); elem_iter != elems.end(); ++elem_iter) {
+ t_const_value* elem_value = (*elem_iter);
+ render_container_const_value("s.insert", elem_type, elem_value);
+ }
+ f_gen_ << indent() << "s" << endl;
+}
+
+void t_rs_generator::render_const_map(t_type* ttype, t_const_value* tvalue) {
+ t_type* key_type = ((t_map*)ttype)->get_key_type();
+ t_type* val_type = ((t_map*)ttype)->get_val_type();
+ f_gen_
+ << indent()
+ << "let mut m: BTreeMap<"
+ << to_rust_type(key_type) << ", " << to_rust_type(val_type)
+ << "> = BTreeMap::new();"
+ << endl;
+ const map<t_const_value*, t_const_value*>& elems = tvalue->get_map();
+ map<t_const_value*, t_const_value*>::const_iterator elem_iter;
+ for (elem_iter = elems.begin(); elem_iter != elems.end(); ++elem_iter) {
+ t_const_value* key_value = elem_iter->first;
+ t_const_value* val_value = elem_iter->second;
+ if (get_true_type(key_type)->is_base_type()) {
+ f_gen_ << indent() << "let k = ";
+ render_const_value(key_type, key_value);
+ f_gen_ << ";" << endl;
+ } else {
+ f_gen_ << indent() << "let k = {" << endl;
+ indent_up();
+ render_const_value(key_type, key_value);
+ indent_down();
+ f_gen_ << indent() << "};" << endl;
+ }
+ if (get_true_type(val_type)->is_base_type()) {
+ f_gen_ << indent() << "let v = ";
+ render_const_value(val_type, val_value);
+ f_gen_ << ";" << endl;
+ } else {
+ f_gen_ << indent() << "let v = {" << endl;
+ indent_up();
+ render_const_value(val_type, val_value);
+ indent_down();
+ f_gen_ << indent() << "};" << endl;
+ }
+ f_gen_ << indent() << "m.insert(k, v);" << endl;
+ }
+ f_gen_ << indent() << "m" << endl;
+}
+
+void t_rs_generator::render_container_const_value(
+ const string& insert_function,
+ t_type* ttype,
+ t_const_value* tvalue
+) {
+ if (get_true_type(ttype)->is_base_type()) {
+ f_gen_ << indent() << insert_function << "(";
+ render_const_value(ttype, tvalue);
+ f_gen_ << ");" << endl;
+ } else {
+ f_gen_ << indent() << insert_function << "(" << endl;
+ indent_up();
+ render_const_value(ttype, tvalue);
+ indent_down();
+ f_gen_ << indent() << ");" << endl;
+ }
+}
+
+//-----------------------------------------------------------------------------
+//
+// Typedefs
+//
+//-----------------------------------------------------------------------------
+
+void t_rs_generator::generate_typedef(t_typedef* ttypedef) {
+ std::string actual_type = to_rust_type(ttypedef->get_type());
+ f_gen_ << "pub type " << rust_safe_name(ttypedef->get_symbolic()) << " = " << actual_type << ";" << endl;
+ f_gen_ << endl;
+}
+
+//-----------------------------------------------------------------------------
+//
+// Enums
+//
+//-----------------------------------------------------------------------------
+
+void t_rs_generator::generate_enum(t_enum* tenum) {
+ string enum_name(rust_camel_case(tenum->get_name()));
+ render_enum_definition(tenum, enum_name);
+ render_enum_impl(enum_name);
+ render_enum_conversion(tenum, enum_name);
+}
+
+void t_rs_generator::render_enum_definition(t_enum* tenum, const string& enum_name) {
+ render_rustdoc((t_doc*) tenum);
+ f_gen_ << "#[derive(Copy, Clone, Debug, Eq, Ord, PartialEq, PartialOrd)]" << endl;
+ f_gen_ << "pub enum " << enum_name << " {" << endl;
+ indent_up();
+
+ vector<t_enum_value*> constants = tenum->get_constants();
+ vector<t_enum_value*>::iterator constants_iter;
+ for (constants_iter = constants.begin(); constants_iter != constants.end(); ++constants_iter) {
+ t_enum_value* val = (*constants_iter);
+ render_rustdoc((t_doc*) val);
+ f_gen_
+ << indent()
+ << uppercase(val->get_name())
+ << " = "
+ << val->get_value()
+ << ","
+ << endl;
+ }
+
+ indent_down();
+ f_gen_ << "}" << endl;
+ f_gen_ << endl;
+}
+
+void t_rs_generator::render_enum_impl(const string& enum_name) {
+ f_gen_ << "impl " << enum_name << " {" << endl;
+ indent_up();
+
+ f_gen_
+ << indent()
+ << "pub fn write_to_out_protocol(&self, o_prot: &mut TOutputProtocol) -> thrift::Result<()> {"
+ << endl;
+ indent_up();
+ f_gen_ << indent() << "o_prot.write_i32(*self as i32)" << endl;
+ indent_down();
+ f_gen_ << indent() << "}" << endl;
+
+ f_gen_
+ << indent()
+ << "pub fn read_from_in_protocol(i_prot: &mut TInputProtocol) -> thrift::Result<" << enum_name << "> {"
+ << endl;
+ indent_up();
+
+ f_gen_ << indent() << "let enum_value = i_prot.read_i32()?;" << endl;
+ f_gen_ << indent() << enum_name << "::try_from(enum_value)";
+
+ indent_down();
+ f_gen_ << indent() << "}" << endl;
+
+ indent_down();
+ f_gen_ << "}" << endl;
+ f_gen_ << endl;
+}
+
+void t_rs_generator::render_enum_conversion(t_enum* tenum, const string& enum_name) {
+ f_gen_ << "impl TryFrom<i32> for " << enum_name << " {" << endl;
+ indent_up();
+
+ f_gen_ << indent() << "type Err = thrift::Error;";
+
+ f_gen_ << indent() << "fn try_from(i: i32) -> Result<Self, Self::Err> {" << endl;
+ indent_up();
+
+ f_gen_ << indent() << "match i {" << endl;
+ indent_up();
+
+ vector<t_enum_value*> constants = tenum->get_constants();
+ vector<t_enum_value*>::iterator constants_iter;
+ for (constants_iter = constants.begin(); constants_iter != constants.end(); ++constants_iter) {
+ t_enum_value* val = (*constants_iter);
+ f_gen_
+ << indent()
+ << val->get_value()
+ << " => Ok(" << enum_name << "::" << uppercase(val->get_name()) << "),"
+ << endl;
+ }
+ f_gen_ << indent() << "_ => {" << endl;
+ indent_up();
+ render_rift_error(
+ "Protocol",
+ "ProtocolError",
+ "ProtocolErrorKind::InvalidData",
+ "format!(\"cannot convert enum constant {} to " + enum_name + "\", i)"
+ );
+ indent_down();
+ f_gen_ << indent() << "}," << endl;
+
+ indent_down();
+ f_gen_ << indent() << "}" << endl;
+
+ indent_down();
+ f_gen_ << indent() << "}" << endl;
+
+ indent_down();
+ f_gen_ << "}" << endl;
+ f_gen_ << endl;
+}
+
+//-----------------------------------------------------------------------------
+//
+// Structs, Unions and Exceptions
+//
+//-----------------------------------------------------------------------------
+
+void t_rs_generator::generate_xception(t_struct* txception) {
+ render_struct(rust_struct_name(txception), txception, t_rs_generator::T_EXCEPTION);
+}
+
+void t_rs_generator::generate_struct(t_struct* tstruct) {
+ if (tstruct->is_union()) {
+ render_union(tstruct);
+ } else if (tstruct->is_struct()) {
+ render_struct(rust_struct_name(tstruct), tstruct, t_rs_generator::T_REGULAR);
+ } else {
+ throw "cannot generate struct for exception";
+ }
+}
+
+void t_rs_generator::render_struct(
+ const string& struct_name,
+ t_struct* tstruct,
+ t_rs_generator::e_struct_type struct_type
+) {
+ render_type_comment(struct_name);
+ render_struct_definition(struct_name, tstruct, struct_type);
+ render_struct_impl(struct_name, tstruct, struct_type);
+ if (struct_type == t_rs_generator::T_REGULAR || struct_type == t_rs_generator::T_EXCEPTION) {
+ render_struct_default_trait_impl(struct_name, tstruct);
+ }
+ if (struct_type == t_rs_generator::T_EXCEPTION) {
+ render_exception_struct_error_trait_impls(struct_name, tstruct);
+ }
+}
+
+void t_rs_generator::render_struct_definition(
+ const string& struct_name,
+ t_struct* tstruct,
+ t_rs_generator::e_struct_type struct_type
+) {
+ render_rustdoc((t_doc*) tstruct);
+ f_gen_ << "#[derive(Clone, Debug, Eq, Ord, PartialEq, PartialOrd)]" << endl;
+ f_gen_ << visibility_qualifier(struct_type) << "struct " << struct_name << " {" << endl;
+
+ // render the members
+ vector<t_field*> members = tstruct->get_sorted_members();
+ if (!members.empty()) {
+ indent_up();
+
+ vector<t_field*>::iterator members_iter;
+ for(members_iter = members.begin(); members_iter != members.end(); ++members_iter) {
+ t_field* member = (*members_iter);
+ t_field::e_req member_req = actual_field_req(member, struct_type);
+
+ string rust_type = to_rust_type(member->get_type());
+ rust_type = is_optional(member_req) ? "Option<" + rust_type + ">" : rust_type;
+
+ render_rustdoc((t_doc*) member);
+ f_gen_
+ << indent()
+ << visibility_qualifier(struct_type)
+ << rust_field_name(member) << ": " << rust_type << ","
+ << endl;
+ }
+
+ indent_down();
+ }
+
+ f_gen_ << "}" << endl;
+ f_gen_ << endl;
+}
+
+void t_rs_generator::render_exception_struct_error_trait_impls(const string& struct_name, t_struct* tstruct) {
+ // error::Error trait
+ f_gen_ << "impl Error for " << struct_name << " {" << endl;
+ indent_up();
+ f_gen_ << indent() << "fn description(&self) -> &str {" << endl;
+ indent_up();
+ f_gen_ << indent() << "\"" << "remote service threw " << tstruct->get_name() << "\"" << endl; // use *original* name
+ indent_down();
+ f_gen_ << indent() << "}" << endl;
+ indent_down();
+ f_gen_ << "}" << endl;
+ f_gen_ << endl;
+
+ // convert::From trait
+ f_gen_ << "impl From<" << struct_name << "> for thrift::Error {" << endl;
+ indent_up();
+ f_gen_ << indent() << "fn from(e: " << struct_name << ") -> Self {" << endl;
+ indent_up();
+ f_gen_ << indent() << "thrift::Error::User(Box::new(e))" << endl;
+ indent_down();
+ f_gen_ << indent() << "}" << endl;
+ indent_down();
+ f_gen_ << "}" << endl;
+ f_gen_ << endl;
+
+ // fmt::Display trait
+ f_gen_ << "impl Display for " << struct_name << " {" << endl;
+ indent_up();
+ f_gen_ << indent() << "fn fmt(&self, f: &mut Formatter) -> fmt::Result {" << endl;
+ indent_up();
+ f_gen_ << indent() << "self.description().fmt(f)" << endl;
+ indent_down();
+ f_gen_ << indent() << "}" << endl;
+ indent_down();
+ f_gen_ << "}" << endl;
+ f_gen_ << endl;
+}
+
+void t_rs_generator::render_struct_default_trait_impl(const string& struct_name, t_struct* tstruct) {
+ bool has_required_field = false;
+
+ const vector<t_field*>& members = tstruct->get_sorted_members();
+ vector<t_field*>::const_iterator members_iter;
+ for (members_iter = members.begin(); members_iter != members.end(); ++members_iter) {
+ t_field* member = *members_iter;
+ if (!is_optional(member->get_req())) {
+ has_required_field = true;
+ break;
+ }
+ }
+
+ if (has_required_field) {
+ return;
+ }
+
+ f_gen_ << "impl Default for " << struct_name << " {" << endl;
+ indent_up();
+ f_gen_ << indent() << "fn default() -> Self {" << endl;
+ indent_up();
+
+ if (members.empty()) {
+ f_gen_ << indent() << struct_name << "{}" << endl;
+ } else {
+ f_gen_ << indent() << struct_name << "{" << endl;
+ indent_up();
+ for (members_iter = members.begin(); members_iter != members.end(); ++members_iter) {
+ t_field *member = (*members_iter);
+ string member_name(rust_field_name(member));
+ f_gen_ << indent() << member_name << ": " << opt_in_req_out_value(member->get_type()) << "," << endl;
+ }
+ indent_down();
+ f_gen_ << indent() << "}" << endl;
+ }
+
+ indent_down();
+ f_gen_ << indent() << "}" << endl;
+ indent_down();
+ f_gen_ << "}" << endl;
+ f_gen_ << endl;
+}
+
+void t_rs_generator::render_struct_impl(
+ const string& struct_name,
+ t_struct* tstruct,
+ t_rs_generator::e_struct_type struct_type
+) {
+ f_gen_ << "impl " << struct_name << " {" << endl;
+ indent_up();
+
+ if (struct_type == t_rs_generator::T_REGULAR || struct_type == t_rs_generator::T_EXCEPTION) {
+ render_struct_constructor(struct_name, tstruct, struct_type);
+ }
+
+ render_struct_sync_read(struct_name, tstruct, struct_type);
+ render_struct_sync_write(tstruct, struct_type);
+
+ if (struct_type == t_rs_generator::T_RESULT) {
+ render_result_struct_to_result_method(tstruct);
+ }
+
+ indent_down();
+ f_gen_ << "}" << endl;
+ f_gen_ << endl;
+}
+
+void t_rs_generator::render_struct_constructor(
+ const string& struct_name,
+ t_struct* tstruct,
+ t_rs_generator::e_struct_type struct_type
+) {
+ const vector<t_field*>& members = tstruct->get_sorted_members();
+ vector<t_field*>::const_iterator members_iter;
+
+ // build the convenience type parameters that allows us to pass unwrapped values to a constructor and
+ // have them automatically converted into Option<value>
+ bool first_arg = true;
+
+ ostringstream generic_type_parameters;
+ ostringstream generic_type_qualifiers;
+ for(members_iter = members.begin(); members_iter != members.end(); ++members_iter) {
+ t_field* member = (*members_iter);
+ t_field::e_req member_req = actual_field_req(member, struct_type);
+
+ if (is_optional(member_req)) {
+ if (first_arg) {
+ first_arg = false;
+ } else {
+ generic_type_parameters << ", ";
+ generic_type_qualifiers << ", ";
+ }
+ generic_type_parameters << "F" << member->get_key();
+ generic_type_qualifiers << "F" << member->get_key() << ": Into<Option<" << to_rust_type(member->get_type()) << ">>";
+ }
+ }
+
+ string type_parameter_string = generic_type_parameters.str();
+ if (type_parameter_string.length() != 0) {
+ type_parameter_string = "<" + type_parameter_string + ">";
+ }
+
+ string type_qualifier_string = generic_type_qualifiers.str();
+ if (type_qualifier_string.length() != 0) {
+ type_qualifier_string = "where " + type_qualifier_string + " ";
+ }
+
+ // now build the actual constructor arg list
+ // when we're building this list we have to use the type parameters in place of the actual type names
+ // if necessary
+ ostringstream args;
+ first_arg = true;
+ for(members_iter = members.begin(); members_iter != members.end(); ++members_iter) {
+ t_field* member = (*members_iter);
+ t_field::e_req member_req = actual_field_req(member, struct_type);
+ string member_name(rust_field_name(member));
+
+ if (first_arg) {
+ first_arg = false;
+ } else {
+ args << ", ";
+ }
+
+ if (is_optional(member_req)) {
+ args << member_name << ": " << "F" << member->get_key();
+ } else {
+ args << member_name << ": " << to_rust_type(member->get_type());
+ }
+ }
+
+ string arg_string = args.str();
+
+ string visibility(visibility_qualifier(struct_type));
+ f_gen_
+ << indent()
+ << visibility
+ << "fn new"
+ << type_parameter_string
+ << "("
+ << arg_string
+ << ") -> "
+ << struct_name
+ << " "
+ << type_qualifier_string
+ << "{"
+ << endl;
+ indent_up();
+
+ if (members.size() == 0) {
+ f_gen_ << indent() << struct_name << " {}" << endl;
+ } else {
+ f_gen_ << indent() << struct_name << " {" << endl;
+ indent_up();
+
+ for(members_iter = members.begin(); members_iter != members.end(); ++members_iter) {
+ t_field* member = (*members_iter);
+ t_field::e_req member_req = actual_field_req(member, struct_type);
+ string member_name(rust_field_name(member));
+
+ if (is_optional(member_req)) {
+ f_gen_ << indent() << member_name << ": " << member_name << ".into()," << endl;
+ } else {
+ f_gen_ << indent() << member_name << ": " << member_name << "," << endl;
+ }
+ }
+
+ indent_down();
+ f_gen_ << indent() << "}" << endl;
+ }
+
+ indent_down();
+ f_gen_ << indent() << "}" << endl;
+}
+
+void t_rs_generator::render_result_struct_to_result_method(t_struct* tstruct) {
+ // we don't use the rust struct name in this method, just the service call name
+ string service_call_name = tstruct->get_name();
+
+ // check that we actually have a result
+ size_t index = service_call_name.find(RESULT_STRUCT_SUFFIX, 0);
+ if (index == std::string::npos) {
+ throw "result struct " + service_call_name + " missing result suffix";
+ } else {
+ service_call_name.replace(index, 6, "");
+ }
+
+ const vector<t_field*>& members = tstruct->get_sorted_members();
+ vector<t_field*>::const_iterator members_iter;
+
+ // find out what the call's expected return type was
+ string rust_return_type = "()";
+ for(members_iter = members.begin(); members_iter != members.end(); ++members_iter) {
+ t_field* member = (*members_iter);
+ if (member->get_name() == SERVICE_RESULT_VARIABLE) { // don't have to check safe name here
+ rust_return_type = to_rust_type(member->get_type());
+ break;
+ }
+ }
+
+ // NOTE: ideally I would generate the branches and render them separately
+ // I tried this however, and the resulting code was harder to understand
+ // maintaining a rendered branch count (while a little ugly) got me the
+ // rendering I wanted with code that was reasonably understandable
+
+ f_gen_ << indent() << "fn ok_or(self) -> thrift::Result<" << rust_return_type << "> {" << endl;
+ indent_up();
+
+ int rendered_branch_count = 0;
+
+ // render the exception branches
+ for(members_iter = members.begin(); members_iter != members.end(); ++members_iter) {
+ t_field* tfield = (*members_iter);
+ if (tfield->get_name() != SERVICE_RESULT_VARIABLE) { // don't have to check safe name here
+ string field_name("self." + rust_field_name(tfield));
+ string branch_statement = rendered_branch_count == 0 ? "if" : "} else if";
+
+ f_gen_ << indent() << branch_statement << " " << field_name << ".is_some() {" << endl;
+ indent_up();
+ f_gen_ << indent() << "Err(thrift::Error::User(Box::new(" << field_name << ".unwrap())))" << endl;
+ indent_down();
+
+ rendered_branch_count++;
+ }
+ }
+
+ // render the return value branches
+ if (rust_return_type == "()") {
+ if (rendered_branch_count == 0) {
+ // we have the unit return and this service call has no user-defined
+ // exceptions. this means that we've a trivial return (happens with oneways)
+ f_gen_ << indent() << "Ok(())" << endl;
+ } else {
+ // we have the unit return, but there are user-defined exceptions
+ // if we've gotten this far then we have the default return (i.e. call successful)
+ f_gen_ << indent() << "} else {" << endl;
+ indent_up();
+ f_gen_ << indent() << "Ok(())" << endl;
+ indent_down();
+ f_gen_ << indent() << "}" << endl;
+ }
+ } else {
+ string branch_statement = rendered_branch_count == 0 ? "if" : "} else if";
+ f_gen_ << indent() << branch_statement << " self." << SERVICE_RESULT_VARIABLE << ".is_some() {" << endl;
+ indent_up();
+ f_gen_ << indent() << "Ok(self." << SERVICE_RESULT_VARIABLE << ".unwrap())" << endl;
+ indent_down();
+ f_gen_ << indent() << "} else {" << endl;
+ indent_up();
+ // if we haven't found a valid return value *or* a user exception
+ // then we're in trouble; return a default error
+ render_rift_error(
+ "Application",
+ "ApplicationError",
+ "ApplicationErrorKind::MissingResult",
+ "\"no result received for " + service_call_name + "\""
+ );
+ indent_down();
+ f_gen_ << indent() << "}" << endl;
+ }
+
+ indent_down();
+ f_gen_ << indent() << "}" << endl;
+}
+
+void t_rs_generator::render_union(t_struct* tstruct) {
+ string union_name(rust_struct_name(tstruct));
+ render_type_comment(union_name);
+ render_union_definition(union_name, tstruct);
+ render_union_impl(union_name, tstruct);
+}
+
+void t_rs_generator::render_union_definition(const string& union_name, t_struct* tstruct) {
+ const vector<t_field*>& members = tstruct->get_sorted_members();
+ if (members.empty()) {
+ throw "cannot generate rust enum with 0 members"; // may be valid thrift, but it's invalid rust
+ }
+
+ f_gen_ << "#[derive(Clone, Debug, Eq, Ord, PartialEq, PartialOrd)]" << endl;
+ f_gen_ << "pub enum " << union_name << " {" << endl;
+ indent_up();
+
+ vector<t_field*>::const_iterator member_iter;
+ for(member_iter = members.begin(); member_iter != members.end(); ++member_iter) {
+ t_field* tfield = (*member_iter);
+ f_gen_
+ << indent()
+ << rust_union_field_name(tfield)
+ << "(" << to_rust_type(tfield->get_type()) << "),"
+ << endl;
+ }
+
+ indent_down();
+ f_gen_ << "}" << endl;
+ f_gen_ << endl;
+}
+
+void t_rs_generator::render_union_impl(const string& union_name, t_struct* tstruct) {
+ f_gen_ << "impl " << union_name << " {" << endl;
+ indent_up();
+
+ render_union_sync_read(union_name, tstruct);
+ render_union_sync_write(union_name, tstruct);
+
+ indent_down();
+ f_gen_ << "}" << endl;
+ f_gen_ << endl;
+}
+
+//-----------------------------------------------------------------------------
+//
+// Sync Struct Write
+//
+//-----------------------------------------------------------------------------
+
+void t_rs_generator::render_struct_sync_write(
+ t_struct *tstruct,
+ t_rs_generator::e_struct_type struct_type
+) {
+ f_gen_
+ << indent()
+ << visibility_qualifier(struct_type)
+ << "fn write_to_out_protocol(&self, o_prot: &mut TOutputProtocol) -> thrift::Result<()> {"
+ << endl;
+ indent_up();
+
+ // write struct header to output protocol
+ // note: use the *original* struct name here
+ f_gen_ << indent() << "let struct_ident = TStructIdentifier::new(\"" + tstruct->get_name() + "\");" << endl;
+ f_gen_ << indent() << "o_prot.write_struct_begin(&struct_ident)?;" << endl;
+
+ // write struct members to output protocol
+ vector<t_field*> members = tstruct->get_sorted_members();
+ if (!members.empty()) {
+ vector<t_field*>::iterator members_iter;
+ for(members_iter = members.begin(); members_iter != members.end(); ++members_iter) {
+ t_field* member = (*members_iter);
+ t_field::e_req member_req = actual_field_req(member, struct_type);
+ string member_var("self." + rust_field_name(member));
+ render_struct_field_sync_write(member_var, false, member, member_req);
+ }
+ }
+
+ // write struct footer to output protocol
+ f_gen_ << indent() << "o_prot.write_field_stop()?;" << endl;
+ f_gen_ << indent() << "o_prot.write_struct_end()" << endl;
+
+ indent_down();
+ f_gen_ << indent() << "}" << endl;
+}
+
+void t_rs_generator::render_union_sync_write(const string &union_name, t_struct *tstruct) {
+ f_gen_
+ << indent()
+ << "pub fn write_to_out_protocol(&self, o_prot: &mut TOutputProtocol) -> thrift::Result<()> {"
+ << endl;
+ indent_up();
+
+ // write struct header to output protocol
+ // note: use the *original* struct name here
+ f_gen_ << indent() << "let struct_ident = TStructIdentifier::new(\"" + tstruct->get_name() + "\");" << endl;
+ f_gen_ << indent() << "o_prot.write_struct_begin(&struct_ident)?;" << endl;
+
+ // write the enum field to the output protocol
+ vector<t_field*> members = tstruct->get_sorted_members();
+ if (!members.empty()) {
+ f_gen_ << indent() << "match *self {" << endl;
+ indent_up();
+ vector<t_field*>::iterator members_iter;
+ for(members_iter = members.begin(); members_iter != members.end(); ++members_iter) {
+ t_field* member = (*members_iter);
+ t_field::e_req member_req = t_field::T_REQUIRED;
+ t_type* ttype = member->get_type();
+ string match_var((ttype->is_base_type() && !ttype->is_string()) ? "f" : "ref f");
+ f_gen_
+ << indent()
+ << union_name << "::" << rust_union_field_name(member)
+ << "(" << match_var << ") => {"
+ << endl;
+ indent_up();
+ render_struct_field_sync_write("f", true, member, member_req);
+ indent_down();
+ f_gen_ << indent() << "}," << endl;
+ }
+ indent_down();
+ f_gen_ << indent() << "}" << endl;
+ }
+
+ // write struct footer to output protocol
+ f_gen_ << indent() << "o_prot.write_field_stop()?;" << endl;
+ f_gen_ << indent() << "o_prot.write_struct_end()" << endl;
+
+ indent_down();
+ f_gen_ << indent() << "}" << endl;
+}
+
+void t_rs_generator::render_struct_field_sync_write(
+ const string &field_var,
+ bool field_var_is_ref,
+ t_field *tfield,
+ t_field::e_req req
+) {
+ t_type* field_type = tfield->get_type();
+ t_type* actual_type = get_true_type(field_type);
+
+ ostringstream field_stream;
+ field_stream
+ << "TFieldIdentifier::new("
+ << "\"" << tfield->get_name() << "\"" << ", " // note: use *original* name
+ << to_rust_field_type_enum(field_type) << ", "
+ << tfield->get_key() << ")";
+ string field_ident_string = field_stream.str();
+
+ if (is_optional(req)) {
+ string let_var((actual_type->is_base_type() && !actual_type->is_string()) ? "fld_var" : "ref fld_var");
+ f_gen_ << indent() << "if let Some(" << let_var << ") = " << field_var << " {" << endl;
+ indent_up();
+ f_gen_ << indent() << "o_prot.write_field_begin(&" << field_ident_string << ")?;" << endl;
+ render_type_sync_write("fld_var", true, field_type);
+ f_gen_ << indent() << "o_prot.write_field_end()?;" << endl;
+ f_gen_ << indent() << "()" << endl; // FIXME: remove this extraneous '()'
+ indent_down();
+ f_gen_ << indent() << "} else {" << endl; // FIXME: remove else branch
+ indent_up();
+ /* FIXME: rethink how I deal with OPT_IN_REQ_OUT
+ if (req == t_field::T_OPT_IN_REQ_OUT) {
+ f_gen_ << indent() << "let field_ident = " << field_ident_string << ";" << endl;
+ f_gen_ << indent() << "o_prot.write_field_begin(&field_ident)?;" << endl;
+ f_gen_ << indent() << "o_prot.write_field_end()?;" << endl;
+ }*/
+ f_gen_ << indent() << "()" << endl;
+ indent_down();
+ f_gen_ << indent() << "}" << endl;
+ } else {
+ f_gen_ << indent() << "o_prot.write_field_begin(&" << field_ident_string << ")?;" << endl;
+ render_type_sync_write(field_var, field_var_is_ref, tfield->get_type());
+ f_gen_ << indent() << "o_prot.write_field_end()?;" << endl;
+ }
+}
+
+void t_rs_generator::render_type_sync_write(const string &type_var, bool type_var_is_ref, t_type *ttype) {
+ if (ttype->is_base_type()) {
+ t_base_type* tbase_type = (t_base_type*)ttype;
+ switch (tbase_type->get_base()) {
+ case t_base_type::TYPE_VOID:
+ throw "cannot write field of type TYPE_VOID to output protocol";
+ case t_base_type::TYPE_STRING: {
+ string ref(type_var_is_ref ? "" : "&");
+ if (tbase_type->is_binary()) {
+ f_gen_ << indent() << "o_prot.write_bytes(" + ref + type_var + ")?;" << endl;
+ } else {
+ f_gen_ << indent() << "o_prot.write_string(" + ref + type_var + ")?;" << endl;
+ }
+ return;
+ }
+ case t_base_type::TYPE_BOOL:
+ f_gen_ << indent() << "o_prot.write_bool(" + type_var + ")?;" << endl;
+ return;
+ case t_base_type::TYPE_I8:
+ f_gen_ << indent() << "o_prot.write_i8(" + type_var + ")?;" << endl;
+ return;
+ case t_base_type::TYPE_I16:
+ f_gen_ << indent() << "o_prot.write_i16(" + type_var + ")?;" << endl;
+ return;
+ case t_base_type::TYPE_I32:
+ f_gen_ << indent() << "o_prot.write_i32(" + type_var + ")?;" << endl;
+ return;
+ case t_base_type::TYPE_I64:
+ f_gen_ << indent() << "o_prot.write_i64(" + type_var + ")?;" << endl;
+ return;
+ case t_base_type::TYPE_DOUBLE:
+ f_gen_ << indent() << "o_prot.write_double(" + type_var + ".into())?;" << endl;
+ return;
+ }
+ } else if (ttype->is_typedef()) {
+ t_typedef* ttypedef = (t_typedef*) ttype;
+ render_type_sync_write(type_var, type_var_is_ref, ttypedef->get_type());
+ return;
+ } else if (ttype->is_enum() || ttype->is_struct() || ttype->is_xception()) {
+ f_gen_ << indent() << type_var + ".write_to_out_protocol(o_prot)?;" << endl;
+ return;
+ } else if (ttype->is_map()) {
+ render_map_sync_write(type_var, type_var_is_ref, (t_map *) ttype);
+ return;
+ } else if (ttype->is_set()) {
+ render_set_sync_write(type_var, type_var_is_ref, (t_set *) ttype);
+ return;
+ } else if (ttype->is_list()) {
+ render_list_sync_write(type_var, type_var_is_ref, (t_list *) ttype);
+ return;
+ }
+
+ throw "cannot write unsupported type " + ttype->get_name();
+}
+
+void t_rs_generator::render_list_sync_write(const string &list_var, bool list_var_is_ref, t_list *tlist) {
+ t_type* elem_type = tlist->get_elem_type();
+
+ f_gen_
+ << indent()
+ << "o_prot.write_list_begin("
+ << "&TListIdentifier::new("
+ << to_rust_field_type_enum(elem_type) << ", "
+ << list_var << ".len() as i32" << ")"
+ << ")?;"
+ << endl;
+
+ string ref(list_var_is_ref ? "" : "&");
+ f_gen_ << indent() << "for e in " << ref << list_var << " {" << endl;
+ indent_up();
+ render_type_sync_write(needs_deref_on_container_write(elem_type) ? "*e" : "e", true, elem_type);
+ f_gen_ << indent() << "o_prot.write_list_end()?;" << endl;
+ indent_down();
+ f_gen_ << indent() << "}" << endl;
+}
+
+void t_rs_generator::render_set_sync_write(const string &set_var, bool set_var_is_ref, t_set *tset) {
+ t_type* elem_type = tset->get_elem_type();
+
+ f_gen_
+ << indent()
+ << "o_prot.write_set_begin("
+ << "&TSetIdentifier::new("
+ << to_rust_field_type_enum(elem_type) << ", "
+ << set_var << ".len() as i32" << ")"
+ << ")?;"
+ << endl;
+
+ string ref(set_var_is_ref ? "" : "&");
+ f_gen_ << indent() << "for e in " << ref << set_var << " {" << endl;
+ indent_up();
+ render_type_sync_write(needs_deref_on_container_write(elem_type) ? "*e" : "e", true, elem_type);
+ f_gen_ << indent() << "o_prot.write_set_end()?;" << endl;
+ indent_down();
+ f_gen_ << indent() << "}" << endl;
+}
+
+void t_rs_generator::render_map_sync_write(const string &map_var, bool map_var_is_ref, t_map *tmap) {
+ t_type* key_type = tmap->get_key_type();
+ t_type* val_type = tmap->get_val_type();
+
+ f_gen_
+ << indent()
+ << "o_prot.write_map_begin("
+ << "&TMapIdentifier::new("
+ << to_rust_field_type_enum(key_type) << ", "
+ << to_rust_field_type_enum(val_type) << ", "
+ << map_var << ".len() as i32)"
+ << ")?;"
+ << endl;
+
+ string ref(map_var_is_ref ? "" : "&");
+ f_gen_ << indent() << "for (k, v) in " << ref << map_var << " {" << endl;
+ indent_up();
+ render_type_sync_write(needs_deref_on_container_write(key_type) ? "*k" : "k", true, key_type);
+ render_type_sync_write(needs_deref_on_container_write(val_type) ? "*v" : "v", true, val_type);
+ f_gen_ << indent() << "o_prot.write_map_end()?;" << endl;
+ indent_down();
+ f_gen_ << indent() << "}" << endl;
+}
+
+bool t_rs_generator::needs_deref_on_container_write(t_type* ttype) {
+ ttype = get_true_type(ttype);
+ return ttype->is_base_type() && !ttype->is_string();
+}
+
+//-----------------------------------------------------------------------------
+//
+// Sync Struct Read
+//
+//-----------------------------------------------------------------------------
+
+void t_rs_generator::render_struct_sync_read(
+ const string &struct_name,
+ t_struct *tstruct, t_rs_generator::e_struct_type struct_type
+) {
+ f_gen_
+ << indent()
+ << visibility_qualifier(struct_type)
+ << "fn read_from_in_protocol(i_prot: &mut TInputProtocol) -> thrift::Result<" << struct_name << "> {"
+ << endl;
+
+ indent_up();
+
+ f_gen_ << indent() << "i_prot.read_struct_begin()?;" << endl;
+
+ // create temporary variables: one for each field in the struct
+ const vector<t_field*> members = tstruct->get_sorted_members();
+ vector<t_field*>::const_iterator members_iter;
+ for (members_iter = members.begin(); members_iter != members.end(); ++members_iter) {
+ t_field* member = (*members_iter);
+ t_field::e_req member_req = actual_field_req(member, struct_type);
+
+ f_gen_
+ << indent()
+ << "let mut " << struct_field_read_temp_variable(member)
+ << ": Option<" << to_rust_type(member->get_type()) << "> = ";
+ if (member_req == t_field::T_OPT_IN_REQ_OUT) {
+ f_gen_ << opt_in_req_out_value(member->get_type()) << ";";
+ } else {
+ f_gen_ << "None;";
+ }
+ f_gen_ << endl;
+ }
+
+ // now loop through the fields we've received
+ f_gen_ << indent() << "loop {" << endl; // start loop
+ indent_up();
+
+ // break out if you've found the Stop field
+ f_gen_ << indent() << "let field_ident = i_prot.read_field_begin()?;" << endl;
+ f_gen_ << indent() << "if field_ident.field_type == TType::Stop {" << endl;
+ indent_up();
+ f_gen_ << indent() << "break;" << endl;
+ indent_down();
+ f_gen_ << indent() << "}" << endl;
+
+ // now read all the fields found
+ f_gen_ << indent() << "let field_id = field_id(&field_ident)?;" << endl;
+ f_gen_ << indent() << "match field_id {" << endl; // start match
+ indent_up();
+
+ for (members_iter = members.begin(); members_iter != members.end(); ++members_iter) {
+ t_field* tfield = (*members_iter);
+ f_gen_ << indent() << tfield->get_key() << " => {" << endl;
+ indent_up();
+ render_type_sync_read("val", tfield->get_type());
+ f_gen_ << indent() << struct_field_read_temp_variable(tfield) << " = Some(val);" << endl;
+ indent_down();
+ f_gen_ << indent() << "}," << endl;
+ }
+
+ // default case (skip fields)
+ f_gen_ << indent() << "_ => {" << endl;
+ indent_up();
+ f_gen_ << indent() << "i_prot.skip(field_ident.field_type)?;" << endl;
+ indent_down();
+ f_gen_ << indent() << "}," << endl;
+
+ indent_down();
+ f_gen_ << indent() << "};" << endl; // finish match
+ f_gen_ << indent() << "i_prot.read_field_end()?;" << endl;
+ indent_down();
+ f_gen_ << indent() << "}" << endl; // finish loop
+ f_gen_ << indent() << "i_prot.read_struct_end()?;" << endl; // read message footer from the wire
+
+ // verify that all required fields exist
+ for (members_iter = members.begin(); members_iter != members.end(); ++members_iter) {
+ t_field* tfield = (*members_iter);
+ t_field::e_req req = actual_field_req(tfield, struct_type);
+ if (!is_optional(req)) {
+ f_gen_
+ << indent()
+ << "verify_required_field_exists("
+ << "\"" << struct_name << "." << rust_field_name(tfield) << "\""
+ << ", "
+ << "&" << struct_field_read_temp_variable(tfield)
+ << ")?;" << endl;
+ }
+ }
+
+ // construct the struct
+ if (members.size() == 0) {
+ f_gen_ << indent() << "let ret = " << struct_name << " {};" << endl;
+ } else {
+ f_gen_ << indent() << "let ret = " << struct_name << " {" << endl;
+ indent_up();
+
+ for (members_iter = members.begin(); members_iter != members.end(); ++members_iter) {
+ t_field* tfield = (*members_iter);
+ t_field::e_req req = actual_field_req(tfield, struct_type);
+ string field_name(rust_field_name(tfield));
+ string field_key = struct_field_read_temp_variable(tfield);
+ if (is_optional(req)) {
+ f_gen_ << indent() << field_name << ": " << field_key << "," << endl;
+ } else {
+ f_gen_
+ << indent()
+ << field_name
+ << ": "
+ << field_key
+ << ".expect(\"auto-generated code should have checked for presence of required fields\")"
+ << ","
+ << endl;
+ }
+ }
+
+ indent_down();
+ f_gen_ << indent() << "};" << endl;
+ }
+
+ // return the constructed value
+ f_gen_ << indent() << "Ok(ret)" << endl;
+
+ indent_down();
+ f_gen_ << indent() << "}" << endl;
+}
+
+void t_rs_generator::render_union_sync_read(const string &union_name, t_struct *tstruct) {
+ f_gen_
+ << indent()
+ << "pub fn read_from_in_protocol(i_prot: &mut TInputProtocol) -> thrift::Result<" << union_name << "> {"
+ << endl;
+ indent_up();
+
+ // create temporary variables to hold the
+ // completed union as well as a count of fields read
+ f_gen_ << indent() << "let mut ret: Option<" << union_name << "> = None;" << endl;
+ f_gen_ << indent() << "let mut received_field_count = 0;" << endl;
+
+ // read the struct preamble
+ f_gen_ << indent() << "i_prot.read_struct_begin()?;" << endl;
+
+ // now loop through the fields we've received
+ f_gen_ << indent() << "loop {" << endl; // start loop
+ indent_up();
+
+ // break out if you've found the Stop field
+ f_gen_ << indent() << "let field_ident = i_prot.read_field_begin()?;" << endl;
+ f_gen_ << indent() << "if field_ident.field_type == TType::Stop {" << endl;
+ indent_up();
+ f_gen_ << indent() << "break;" << endl;
+ indent_down();
+ f_gen_ << indent() << "}" << endl;
+
+ // now read all the fields found
+ f_gen_ << indent() << "let field_id = field_id(&field_ident)?;" << endl;
+ f_gen_ << indent() << "match field_id {" << endl; // start match
+ indent_up();
+
+ const vector<t_field*> members = tstruct->get_sorted_members();
+ vector<t_field*>::const_iterator members_iter;
+ for (members_iter = members.begin(); members_iter != members.end(); ++members_iter) {
+ t_field* member = (*members_iter);
+ f_gen_ << indent() << member->get_key() << " => {" << endl;
+ indent_up();
+ render_type_sync_read("val", member->get_type());
+ f_gen_ << indent() << "if ret.is_none() {" << endl;
+ indent_up();
+ f_gen_
+ << indent()
+ << "ret = Some(" << union_name << "::" << rust_union_field_name(member) << "(val));"
+ << endl;
+ indent_down();
+ f_gen_ << indent() << "}" << endl;
+ f_gen_ << indent() << "received_field_count += 1;" << endl;
+ indent_down();
+ f_gen_ << indent() << "}," << endl;
+ }
+
+ // default case (skip fields)
+ f_gen_ << indent() << "_ => {" << endl;
+ indent_up();
+ f_gen_ << indent() << "i_prot.skip(field_ident.field_type)?;" << endl;
+ f_gen_ << indent() << "received_field_count += 1;" << endl;
+ indent_down();
+ f_gen_ << indent() << "}," << endl;
+
+ indent_down();
+ f_gen_ << indent() << "};" << endl; // finish match
+ f_gen_ << indent() << "i_prot.read_field_end()?;" << endl;
+ indent_down();
+ f_gen_ << indent() << "}" << endl; // finish loop
+ f_gen_ << indent() << "i_prot.read_struct_end()?;" << endl; // finish reading message from wire
+
+ // return the value or an error
+ f_gen_ << indent() << "if received_field_count == 0 {" << endl;
+ indent_up();
+ render_rift_error(
+ "Protocol",
+ "ProtocolError",
+ "ProtocolErrorKind::InvalidData",
+ "\"received empty union from remote " + union_name + "\""
+ );
+ indent_down();
+ f_gen_ << indent() << "} else if received_field_count > 1 {" << endl;
+ indent_up();
+ render_rift_error(
+ "Protocol",
+ "ProtocolError",
+ "ProtocolErrorKind::InvalidData",
+ "\"received multiple fields for union from remote " + union_name + "\""
+ );
+ indent_down();
+ f_gen_ << indent() << "} else {" << endl;
+ indent_up();
+ f_gen_ << indent() << "Ok(ret.expect(\"return value should have been constructed\"))" << endl;
+ indent_down();
+ f_gen_ << indent() << "}" << endl;
+
+ indent_down();
+ f_gen_ << indent() << "}" << endl;
+}
+
+// Construct the rust representation of all supported types from the wire.
+void t_rs_generator::render_type_sync_read(const string &type_var, t_type *ttype) {
+ if (ttype->is_base_type()) {
+ t_base_type* tbase_type = (t_base_type*)ttype;
+ switch (tbase_type->get_base()) {
+ case t_base_type::TYPE_VOID:
+ throw "cannot read field of type TYPE_VOID from input protocol";
+ case t_base_type::TYPE_STRING:
+ if (tbase_type->is_binary()) {
+ f_gen_ << indent() << "let " << type_var << " = i_prot.read_bytes()?;" << endl;
+ } else {
+ f_gen_ << indent() << "let " << type_var << " = i_prot.read_string()?;" << endl;
+ }
+ return;
+ case t_base_type::TYPE_BOOL:
+ f_gen_ << indent() << "let " << type_var << " = i_prot.read_bool()?;" << endl;
+ return;
+ case t_base_type::TYPE_I8:
+ f_gen_ << indent() << "let " << type_var << " = i_prot.read_i8()?;" << endl;
+ return;
+ case t_base_type::TYPE_I16:
+ f_gen_ << indent() << "let " << type_var << " = i_prot.read_i16()?;" << endl;
+ return;
+ case t_base_type::TYPE_I32:
+ f_gen_ << indent() << "let " << type_var << " = i_prot.read_i32()?;" << endl;
+ return;
+ case t_base_type::TYPE_I64:
+ f_gen_ << indent() << "let " << type_var << " = i_prot.read_i64()?;" << endl;
+ return;
+ case t_base_type::TYPE_DOUBLE:
+ f_gen_ << indent() << "let " << type_var << " = OrderedFloat::from(i_prot.read_double()?);" << endl;
+ return;
+ }
+ } else if (ttype->is_typedef()) {
+ t_typedef* ttypedef = (t_typedef*)ttype;
+ render_type_sync_read(type_var, ttypedef->get_type());
+ return;
+ } else if (ttype->is_enum() || ttype->is_struct() || ttype->is_xception()) {
+ f_gen_
+ << indent()
+ << "let " << type_var << " = " << to_rust_type(ttype) << "::read_from_in_protocol(i_prot)?;"
+ << endl;
+ return;
+ } else if (ttype->is_map()) {
+ render_map_sync_read((t_map *) ttype, type_var);
+ return;
+ } else if (ttype->is_set()) {
+ render_set_sync_read((t_set *) ttype, type_var);
+ return;
+ } else if (ttype->is_list()) {
+ render_list_sync_read((t_list *) ttype, type_var);
+ return;
+ }
+
+ throw "cannot read unsupported type " + ttype->get_name();
+}
+
+// Construct the rust representation of a list from the wire.
+void t_rs_generator::render_list_sync_read(t_list *tlist, const string &list_var) {
+ t_type* elem_type = tlist->get_elem_type();
+
+ f_gen_ << indent() << "let list_ident = i_prot.read_list_begin()?;" << endl;
+ f_gen_
+ << indent()
+ << "let mut " << list_var << ": " << to_rust_type((t_type*) tlist)
+ << " = Vec::with_capacity(list_ident.size as usize);"
+ << endl;
+ f_gen_ << indent() << "for _ in 0..list_ident.size {" << endl;
+
+ indent_up();
+
+ string list_elem_var = tmp("list_elem_");
+ render_type_sync_read(list_elem_var, elem_type);
+ f_gen_ << indent() << list_var << ".push(" << list_elem_var << ");" << endl;
+
+ indent_down();
+
+ f_gen_ << indent() << "}" << endl;
+ f_gen_ << indent() << "i_prot.read_list_end()?;" << endl;
+}
+
+// Construct the rust representation of a set from the wire.
+void t_rs_generator::render_set_sync_read(t_set *tset, const string &set_var) {
+ t_type* elem_type = tset->get_elem_type();
+
+ f_gen_ << indent() << "let set_ident = i_prot.read_set_begin()?;" << endl;
+ f_gen_
+ << indent()
+ << "let mut " << set_var << ": " << to_rust_type((t_type*) tset)
+ << " = BTreeSet::new();"
+ << endl;
+ f_gen_ << indent() << "for _ in 0..set_ident.size {" << endl;
+
+ indent_up();
+
+ string set_elem_var = tmp("set_elem_");
+ render_type_sync_read(set_elem_var, elem_type);
+ f_gen_ << indent() << set_var << ".insert(" << set_elem_var << ");" << endl;
+
+ indent_down();
+
+ f_gen_ << indent() << "}" << endl;
+ f_gen_ << indent() << "i_prot.read_set_end()?;" << endl;
+}
+
+// Construct the rust representation of a map from the wire.
+void t_rs_generator::render_map_sync_read(t_map *tmap, const string &map_var) {
+ t_type* key_type = tmap->get_key_type();
+ t_type* val_type = tmap->get_val_type();
+
+ f_gen_ << indent() << "let map_ident = i_prot.read_map_begin()?;" << endl;
+ f_gen_
+ << indent()
+ << "let mut " << map_var << ": " << to_rust_type((t_type*) tmap)
+ << " = BTreeMap::new();"
+ << endl;
+ f_gen_ << indent() << "for _ in 0..map_ident.size {" << endl;
+
+ indent_up();
+
+ string key_elem_var = tmp("map_key_");
+ render_type_sync_read(key_elem_var, key_type);
+ string val_elem_var = tmp("map_val_");
+ render_type_sync_read(val_elem_var, val_type);
+ f_gen_ << indent() << map_var << ".insert(" << key_elem_var << ", " << val_elem_var << ");" << endl;
+
+ indent_down();
+
+ f_gen_ << indent() << "}" << endl;
+ f_gen_ << indent() << "i_prot.read_map_end()?;" << endl;
+}
+
+string t_rs_generator::struct_field_read_temp_variable(t_field* tfield) {
+ std::ostringstream foss;
+ foss << "f_" << tfield->get_key();
+ return foss.str();
+}
+
+//-----------------------------------------------------------------------------
+//
+// Sync Client
+//
+//-----------------------------------------------------------------------------
+
+void t_rs_generator::generate_service(t_service* tservice) {
+ render_sync_client(tservice);
+ render_sync_processor(tservice);
+ render_service_call_structs(tservice);
+}
+
+void t_rs_generator::render_service_call_structs(t_service* tservice) {
+ const std::vector<t_function*> functions = tservice->get_functions();
+ std::vector<t_function*>::const_iterator func_iter;
+
+ // thrift args for service calls are packed
+ // into a struct that's transmitted over the wire, so
+ // generate structs for those too
+ //
+ // thrift returns are *also* packed into a struct
+ // that's passed over the wire, so, generate the struct
+ // for that too. Note that this result struct *also*
+ // contains the exceptions as well
+ for(func_iter = functions.begin(); func_iter != functions.end(); ++func_iter) {
+ t_function* tfunc = (*func_iter);
+ render_struct(rust_struct_name(tfunc->get_arglist()), tfunc->get_arglist(), t_rs_generator::T_ARGS);
+ if (!tfunc->is_oneway()) {
+ render_result_value_struct(tfunc);
+ }
+ }
+}
+
+void t_rs_generator::render_sync_client(t_service* tservice) {
+ string client_impl_name(rust_sync_client_impl_name(tservice));
+
+ render_type_comment(tservice->get_name() + " service client"); // note: use *original* name
+ render_sync_client_trait(tservice);
+ render_sync_client_marker_trait(tservice);
+ render_sync_client_definition_and_impl(client_impl_name);
+ render_sync_client_tthriftclient_impl(client_impl_name);
+ render_sync_client_marker_trait_impls(tservice, client_impl_name); f_gen_ << endl;
+ render_sync_client_process_impl(tservice);
+}
+
+void t_rs_generator::render_sync_client_trait(t_service *tservice) {
+ string extension = "";
+ if (tservice->get_extends()) {
+ t_service* extends = tservice->get_extends();
+ extension = " : " + rust_namespace(extends) + rust_sync_client_trait_name(extends);
+ }
+
+ render_rustdoc((t_doc*) tservice);
+ f_gen_ << "pub trait " << rust_sync_client_trait_name(tservice) << extension << " {" << endl;
+ indent_up();
+
+ const std::vector<t_function*> functions = tservice->get_functions();
+ std::vector<t_function*>::const_iterator func_iter;
+ for(func_iter = functions.begin(); func_iter != functions.end(); ++func_iter) {
+ t_function* tfunc = (*func_iter);
+ string func_name = service_call_client_function_name(tfunc);
+ string func_args = rust_sync_service_call_declaration(tfunc);
+ string func_return = to_rust_type(tfunc->get_returntype());
+ render_rustdoc((t_doc*) tfunc);
+ f_gen_ << indent() << "fn " << func_name << func_args << " -> thrift::Result<" << func_return << ">;" << endl;
+ }
+
+ indent_down();
+ f_gen_ << indent() << "}" << endl;
+ f_gen_ << endl;
+}
+
+void t_rs_generator::render_sync_client_marker_trait(t_service *tservice) {
+ f_gen_ << indent() << "pub trait " << rust_sync_client_marker_trait_name(tservice) << " {}" << endl;
+ f_gen_ << endl;
+}
+
+void t_rs_generator::render_sync_client_marker_trait_impls(t_service *tservice, const string &impl_struct_name) {
+ f_gen_
+ << indent()
+ << "impl " << rust_namespace(tservice) << rust_sync_client_marker_trait_name(tservice)
+ << " for " << impl_struct_name
+ << " {}"
+ << endl;
+
+ t_service* extends = tservice->get_extends();
+ if (extends) {
+ render_sync_client_marker_trait_impls(extends, impl_struct_name);
+ }
+}
+
+void t_rs_generator::render_sync_client_definition_and_impl(const string& client_impl_name) {
+ // render the definition for the client struct
+ f_gen_ << "pub struct " << client_impl_name << " {" << endl;
+ indent_up();
+ f_gen_ << indent() << "_i_prot: Box<TInputProtocol>," << endl;
+ f_gen_ << indent() << "_o_prot: Box<TOutputProtocol>," << endl;
+ f_gen_ << indent() << "_sequence_number: i32," << endl;
+ indent_down();
+ f_gen_ << "}" << endl;
+ f_gen_ << endl;
+
+ // render the struct implementation
+ // this includes the new() function as well as the helper send/recv methods for each service call
+ f_gen_ << "impl " << client_impl_name << " {" << endl;
+ indent_up();
+ render_sync_client_lifecycle_functions(client_impl_name);
+ indent_down();
+ f_gen_ << "}" << endl;
+ f_gen_ << endl;
+}
+
+void t_rs_generator::render_sync_client_lifecycle_functions(const string& client_struct) {
+ f_gen_
+ << indent()
+ << "pub fn new(input_protocol: Box<TInputProtocol>, output_protocol: Box<TOutputProtocol>) -> "
+ << client_struct
+ << " {"
+ << endl;
+ indent_up();
+
+ f_gen_
+ << indent()
+ << client_struct
+ << " { _i_prot: input_protocol, _o_prot: output_protocol, _sequence_number: 0 }"
+ << endl;
+
+ indent_down();
+ f_gen_ << indent() << "}" << endl;
+}
+
+void t_rs_generator::render_sync_client_tthriftclient_impl(const string &client_impl_name) {
+ f_gen_ << indent() << "impl TThriftClient for " << client_impl_name << " {" << endl;
+ indent_up();
+
+ f_gen_ << indent() << "fn i_prot_mut(&mut self) -> &mut TInputProtocol { &mut *self._i_prot }" << endl;
+ f_gen_ << indent() << "fn o_prot_mut(&mut self) -> &mut TOutputProtocol { &mut *self._o_prot }" << endl;
+ f_gen_ << indent() << "fn sequence_number(&self) -> i32 { self._sequence_number }" << endl;
+ f_gen_
+ << indent()
+ << "fn increment_sequence_number(&mut self) -> i32 { self._sequence_number += 1; self._sequence_number }"
+ << endl;
+
+ indent_down();
+ f_gen_ << indent() << "}" << endl;
+ f_gen_ << endl;
+}
+
+void t_rs_generator::render_sync_client_process_impl(t_service* tservice) {
+ string marker_extension = "" + sync_client_marker_traits_for_extension(tservice);
+
+ f_gen_
+ << "impl <C: TThriftClient + " << rust_sync_client_marker_trait_name(tservice) << marker_extension << "> "
+ << rust_sync_client_trait_name(tservice)
+ << " for C {" << endl;
+ indent_up();
+
+ const std::vector<t_function*> functions = tservice->get_functions();
+ std::vector<t_function*>::const_iterator func_iter;
+ for(func_iter = functions.begin(); func_iter != functions.end(); ++func_iter) {
+ t_function* func = (*func_iter);
+ render_sync_send_recv_wrapper(func);
+ }
+
+ indent_down();
+ f_gen_ << "}" << endl;
+ f_gen_ << endl;
+}
+
+string t_rs_generator::sync_client_marker_traits_for_extension(t_service *tservice) {
+ string marker_extension;
+
+ t_service* extends = tservice->get_extends();
+ if (extends) {
+ marker_extension = " + " + rust_namespace(extends) + rust_sync_client_marker_trait_name(extends);
+ marker_extension = marker_extension + sync_client_marker_traits_for_extension(extends);
+ }
+
+ return marker_extension;
+}
+
+void t_rs_generator::render_sync_send_recv_wrapper(t_function* tfunc) {
+ string func_name = service_call_client_function_name(tfunc);
+ string func_decl_args = rust_sync_service_call_declaration(tfunc);
+ string func_call_args = rust_sync_service_call_invocation(tfunc);
+ string func_return = to_rust_type(tfunc->get_returntype());
+
+ f_gen_
+ << indent()
+ << "fn " << func_name << func_decl_args << " -> thrift::Result<" << func_return
+ << "> {"
+ << endl;
+ indent_up();
+
+ f_gen_ << indent() << "(" << endl;
+ indent_up();
+ render_sync_send(tfunc);
+ indent_down();
+ f_gen_ << indent() << ")?;" << endl;
+ if (tfunc->is_oneway()) {
+ f_gen_ << indent() << "Ok(())" << endl;
+ } else {
+ render_sync_recv(tfunc);
+ }
+
+ indent_down();
+ f_gen_ << indent() << "}" << endl;
+}
+
+void t_rs_generator::render_sync_send(t_function* tfunc) {
+ f_gen_ << indent() << "{" << endl;
+ indent_up();
+
+ // increment the sequence number and generate the call header
+ string message_type = tfunc->is_oneway() ? "TMessageType::OneWay" : "TMessageType::Call";
+ f_gen_ << indent() << "self.increment_sequence_number();" << endl;
+ f_gen_
+ << indent()
+ << "let message_ident = "
+ << "TMessageIdentifier::new(\"" << tfunc->get_name() << "\", " // note: use *original* name
+ << message_type << ", "
+ << "self.sequence_number());"
+ << endl;
+ // pack the arguments into the containing struct that we'll write out over the wire
+ // note that this struct is generated even if we have 0 args
+ ostringstream struct_definition;
+ vector<t_field*> members = tfunc->get_arglist()->get_sorted_members();
+ vector<t_field*>::iterator members_iter;
+ for (members_iter = members.begin(); members_iter != members.end(); ++members_iter) {
+ t_field* member = (*members_iter);
+ string member_name(rust_field_name(member));
+ struct_definition << member_name << ": " << member_name << ", ";
+ }
+ string struct_fields = struct_definition.str();
+ if (struct_fields.size() > 0) {
+ struct_fields = struct_fields.substr(0, struct_fields.size() - 2); // strip trailing comma
+ }
+ f_gen_
+ << indent()
+ << "let call_args = "
+ << rust_struct_name(tfunc->get_arglist())
+ << " { "
+ << struct_fields
+ << " };"
+ << endl;
+ // write everything over the wire
+ f_gen_ << indent() << "self.o_prot_mut().write_message_begin(&message_ident)?;" << endl;
+ f_gen_ << indent() << "call_args.write_to_out_protocol(self.o_prot_mut())?;" << endl; // written even if we have 0 args
+ f_gen_ << indent() << "self.o_prot_mut().write_message_end()?;" << endl;
+ f_gen_ << indent() << "self.o_prot_mut().flush()" << endl;
+
+ indent_down();
+ f_gen_ << indent() << "}" << endl;
+}
+
+void t_rs_generator::render_sync_recv(t_function* tfunc) {
+ f_gen_ << indent() << "{" << endl;
+ indent_up();
+
+ f_gen_ << indent() << "let message_ident = self.i_prot_mut().read_message_begin()?;" << endl;
+ f_gen_ << indent() << "verify_expected_sequence_number(self.sequence_number(), message_ident.sequence_number)?;" << endl;
+ f_gen_ << indent() << "verify_expected_service_call(\"" << tfunc->get_name() <<"\", &message_ident.name)?;" << endl; // note: use *original* name
+ // FIXME: replace with a "try" block
+ f_gen_ << indent() << "if message_ident.message_type == TMessageType::Exception {" << endl;
+ indent_up();
+ f_gen_ << indent() << "let remote_error = thrift::Error::read_application_error_from_in_protocol(self.i_prot_mut())?;" << endl;
+ f_gen_ << indent() << "self.i_prot_mut().read_message_end()?;" << endl;
+ f_gen_ << indent() << "return Err(thrift::Error::Application(remote_error))" << endl;
+ indent_down();
+ f_gen_ << indent() << "}" << endl;
+ f_gen_ << indent() << "verify_expected_message_type(TMessageType::Reply, message_ident.message_type)?;" << endl;
+ f_gen_ << indent() << "let result = " << service_call_result_struct_name(tfunc) << "::read_from_in_protocol(self.i_prot_mut())?;" << endl;
+ f_gen_ << indent() << "self.i_prot_mut().read_message_end()?;" << endl;
+ f_gen_ << indent() << "result.ok_or()" << endl;
+
+ indent_down();
+ f_gen_ << indent() << "}" << endl;
+}
+
+string t_rs_generator::rust_sync_service_call_declaration(t_function* tfunc) {
+ ostringstream func_args;
+ func_args << "(&mut self";
+
+ if (has_args(tfunc)) {
+ func_args << ", "; // put comma after "&mut self"
+ func_args << struct_to_declaration(tfunc->get_arglist(), T_ARGS);
+ }
+
+ func_args << ")";
+ return func_args.str();
+}
+
+string t_rs_generator::rust_sync_service_call_invocation(t_function* tfunc, const string& field_prefix) {
+ ostringstream func_args;
+ func_args << "(";
+
+ if (has_args(tfunc)) {
+ func_args << struct_to_invocation(tfunc->get_arglist(), field_prefix);
+ }
+
+ func_args << ")";
+ return func_args.str();
+}
+
+string t_rs_generator::struct_to_declaration(t_struct* tstruct, t_rs_generator::e_struct_type struct_type) {
+ ostringstream args;
+
+ bool first_arg = true;
+ std::vector<t_field*> fields = tstruct->get_sorted_members();
+ std::vector<t_field*>::iterator field_iter;
+ for (field_iter = fields.begin(); field_iter != fields.end(); ++field_iter) {
+ t_field* tfield = (*field_iter);
+ t_field::e_req field_req = actual_field_req(tfield, struct_type);
+ string rust_type = to_rust_type(tfield->get_type());
+ rust_type = is_optional(field_req) ? "Option<" + rust_type + ">" : rust_type;
+
+ if (first_arg) {
+ first_arg = false;
+ } else {
+ args << ", ";
+ }
+
+ args << rust_field_name(tfield) << ": " << rust_type;
+ }
+
+ return args.str();
+}
+
+string t_rs_generator::struct_to_invocation(t_struct* tstruct, const string& field_prefix) {
+ ostringstream args;
+
+ bool first_arg = true;
+ std::vector<t_field*> fields = tstruct->get_sorted_members();
+ std::vector<t_field*>::iterator field_iter;
+ for (field_iter = fields.begin(); field_iter != fields.end(); ++field_iter) {
+ t_field* tfield = (*field_iter);
+
+ if (first_arg) {
+ first_arg = false;
+ } else {
+ args << ", ";
+ }
+
+ args << field_prefix << rust_field_name(tfield);
+ }
+
+ return args.str();
+}
+
+void t_rs_generator::render_result_value_struct(t_function* tfunc) {
+ string result_struct_name = service_call_result_struct_name(tfunc);
+ t_struct result(program_, result_struct_name);
+
+ t_field return_value(tfunc->get_returntype(), SERVICE_RESULT_VARIABLE, 0);
+ return_value.set_req(t_field::T_OPTIONAL);
+ if (!tfunc->get_returntype()->is_void()) {
+ result.append(&return_value);
+ }
+
+ t_struct* exceptions = tfunc->get_xceptions();
+ const vector<t_field*>& exception_types = exceptions->get_members();
+ vector<t_field*>::const_iterator exception_iter;
+ for(exception_iter = exception_types.begin(); exception_iter != exception_types.end(); ++exception_iter) {
+ t_field* exception_type = *exception_iter;
+ exception_type->set_req(t_field::T_OPTIONAL);
+ result.append(exception_type);
+ }
+
+ render_struct(result_struct_name, &result, t_rs_generator::T_RESULT);
+}
+
+//-----------------------------------------------------------------------------
+//
+// Sync Processor
+//
+//-----------------------------------------------------------------------------
+
+void t_rs_generator::render_sync_processor(t_service *tservice) {
+ render_type_comment(tservice->get_name() + " service processor"); // note: use *original* name
+ render_sync_handler_trait(tservice);
+ render_sync_processor_definition_and_impl(tservice);
+}
+
+void t_rs_generator::render_sync_handler_trait(t_service *tservice) {
+ string extension = "";
+ if (tservice->get_extends() != NULL) {
+ t_service* extends = tservice->get_extends();
+ extension = " : " + rust_namespace(extends) + rust_sync_handler_trait_name(extends);
+ }
+
+ const std::vector<t_function*> functions = tservice->get_functions();
+ std::vector<t_function*>::const_iterator func_iter;
+
+ render_rustdoc((t_doc*) tservice);
+ f_gen_ << "pub trait " << rust_sync_handler_trait_name(tservice) << extension << " {" << endl;
+ indent_up();
+ for(func_iter = functions.begin(); func_iter != functions.end(); ++func_iter) {
+ t_function* tfunc = (*func_iter);
+ string func_name = service_call_handler_function_name(tfunc);
+ string func_args = rust_sync_service_call_declaration(tfunc);
+ string func_return = to_rust_type(tfunc->get_returntype());
+ render_rustdoc((t_doc*) tfunc);
+ f_gen_
+ << indent()
+ << "fn "
+ << func_name << func_args
+ << " -> thrift::Result<" << func_return << ">;"
+ << endl;
+ }
+ indent_down();
+ f_gen_ << indent() << "}" << endl;
+ f_gen_ << endl;
+}
+
+void t_rs_generator::render_sync_processor_definition_and_impl(t_service *tservice) {
+ string service_processor_name = rust_sync_processor_name(tservice);
+ string handler_trait_name = rust_sync_handler_trait_name(tservice);
+
+ // struct
+ f_gen_
+ << indent()
+ << "pub struct " << service_processor_name
+ << "<H: " << handler_trait_name
+ << "> {"
+ << endl;
+ indent_up();
+ f_gen_ << indent() << "handler: H," << endl;
+ indent_down();
+ f_gen_ << indent() << "}" << endl;
+ f_gen_ << endl;
+
+ // delegating impl
+ f_gen_
+ << indent()
+ << "impl <H: " << handler_trait_name << "> "
+ << service_processor_name
+ << "<H> {"
+ << endl;
+ indent_up();
+ f_gen_ << indent() << "pub fn new(handler: H) -> " << service_processor_name << "<H> {" << endl;
+ indent_up();
+ f_gen_ << indent() << service_processor_name << " {" << endl;
+ indent_up();
+ f_gen_ << indent() << "handler: handler," << endl;
+ indent_down();
+ f_gen_ << indent() << "}" << endl;
+ indent_down();
+ f_gen_ << indent() << "}" << endl;
+ render_sync_process_delegation_functions(tservice);
+ indent_down();
+ f_gen_ << indent() << "}" << endl;
+ f_gen_ << endl;
+
+ // actual impl
+ string service_actual_processor_name = rust_sync_processor_impl_name(tservice);
+ f_gen_ << indent() << "pub struct " << service_actual_processor_name << ";" << endl;
+ f_gen_ << endl;
+ f_gen_ << indent() << "impl " << service_actual_processor_name << " {" << endl;
+ indent_up();
+
+ vector<t_function*> functions = tservice->get_functions();
+ vector<t_function*>::iterator func_iter;
+ for(func_iter = functions.begin(); func_iter != functions.end(); ++func_iter) {
+ t_function* tfunc = (*func_iter);
+ render_sync_process_function(tfunc, handler_trait_name);
+ }
+
+ indent_down();
+ f_gen_ << indent() << "}" << endl;
+ f_gen_ << endl;
+
+ // processor impl
+ f_gen_
+ << indent()
+ << "impl <H: "
+ << handler_trait_name << "> TProcessor for "
+ << service_processor_name
+ << "<H> {"
+ << endl;
+ indent_up();
+
+ f_gen_
+ << indent()
+ << "fn process(&mut self, i_prot: &mut TInputProtocol, o_prot: &mut TOutputProtocol) -> thrift::Result<()> {"
+ << endl;
+ indent_up();
+ f_gen_ << indent() << "let message_ident = i_prot.read_message_begin()?;" << endl;
+ f_gen_ << indent() << "match &*message_ident.name {" << endl; // [sigh] explicit deref coercion
+ indent_up();
+ render_process_match_statements(tservice);
+ f_gen_ << indent() << "method => {" << endl;
+ indent_up();
+ render_rift_error(
+ "Application",
+ "ApplicationError",
+ "ApplicationErrorKind::UnknownMethod",
+ "format!(\"unknown method {}\", method)"
+ );
+ indent_down();
+ f_gen_ << indent() << "}," << endl;
+
+ indent_down();
+ f_gen_ << indent() << "}" << endl;
+ indent_down();
+ f_gen_ << indent() << "}" << endl;
+
+ indent_down();
+ f_gen_ << indent() << "}" << endl;
+ f_gen_ << endl;
+}
+
+void t_rs_generator::render_sync_process_delegation_functions(t_service *tservice) {
+ string actual_processor(rust_namespace(tservice) + rust_sync_processor_impl_name(tservice));
+
+ vector<t_function*> functions = tservice->get_functions();
+ vector<t_function*>::iterator func_iter;
+ for(func_iter = functions.begin(); func_iter != functions.end(); ++func_iter) {
+ t_function* tfunc = (*func_iter);
+ string function_name("process_" + rust_snake_case(tfunc->get_name()));
+ f_gen_
+ << indent()
+ << "fn " << function_name
+ << "(&mut self, "
+ << "incoming_sequence_number: i32, "
+ << "i_prot: &mut TInputProtocol, "
+ << "o_prot: &mut TOutputProtocol) "
+ << "-> thrift::Result<()> {"
+ << endl;
+ indent_up();
+
+ f_gen_
+ << indent()
+ << actual_processor
+ << "::" << function_name
+ << "("
+ << "&mut self.handler, "
+ << "incoming_sequence_number, "
+ << "i_prot, "
+ << "o_prot"
+ << ")"
+ << endl;
+
+ indent_down();
+ f_gen_ << indent() << "}" << endl;
+ }
+
+ t_service* extends = tservice->get_extends();
+ if (extends) {
+ render_sync_process_delegation_functions(extends);
+ }
+}
+
+void t_rs_generator::render_process_match_statements(t_service* tservice) {
+ vector<t_function*> functions = tservice->get_functions();
+ vector<t_function*>::iterator func_iter;
+ for(func_iter = functions.begin(); func_iter != functions.end(); ++func_iter) {
+ t_function* tfunc = (*func_iter);
+ f_gen_ << indent() << "\"" << tfunc->get_name() << "\"" << " => {" << endl; // note: use *original* name
+ indent_up();
+ f_gen_
+ << indent()
+ << "self.process_" << rust_snake_case(tfunc->get_name())
+ << "(message_ident.sequence_number, i_prot, o_prot)"
+ << endl;
+ indent_down();
+ f_gen_ << indent() << "}," << endl;
+ }
+
+ t_service* extends = tservice->get_extends();
+ if (extends) {
+ render_process_match_statements(extends);
+ }
+}
+
+void t_rs_generator::render_sync_process_function(t_function *tfunc, const string &handler_type) {
+ string sequence_number_param("incoming_sequence_number");
+ string output_protocol_param("o_prot");
+
+ if (tfunc->is_oneway()) {
+ sequence_number_param = "_";
+ output_protocol_param = "_";
+ }
+
+ f_gen_
+ << indent()
+ << "pub fn process_" << rust_snake_case(tfunc->get_name())
+ << "<H: " << handler_type << ">"
+ << "(handler: &mut H, "
+ << sequence_number_param << ": i32, "
+ << "i_prot: &mut TInputProtocol, "
+ << output_protocol_param << ": &mut TOutputProtocol) "
+ << "-> thrift::Result<()> {"
+ << endl;
+
+ indent_up();
+
+ // *always* read arguments from the input protocol
+ f_gen_
+ << indent()
+ << "let "
+ << (has_non_void_args(tfunc) ? "args" : "_")
+ << " = "
+ << rust_struct_name(tfunc->get_arglist())
+ << "::read_from_in_protocol(i_prot)?;"
+ << endl;
+
+ f_gen_
+ << indent()
+ << "match handler."
+ << service_call_handler_function_name(tfunc)
+ << rust_sync_service_call_invocation(tfunc, "args.")
+ << " {"
+ << endl; // start match
+ indent_up();
+
+ // handler succeeded
+ string handler_return_variable = tfunc->is_oneway() || tfunc->get_returntype()->is_void() ? "_" : "handler_return";
+ f_gen_ << indent() << "Ok(" << handler_return_variable << ") => {" << endl;
+ indent_up();
+ render_sync_handler_succeeded(tfunc);
+ indent_down();
+ f_gen_ << indent() << "}," << endl;
+ // handler failed
+ f_gen_ << indent() << "Err(e) => {" << endl;
+ indent_up();
+ render_sync_handler_failed(tfunc);
+ indent_down();
+ f_gen_ << indent() << "}," << endl;
+
+ indent_down();
+ f_gen_ << indent() << "}" << endl; // end match
+
+ indent_down();
+ f_gen_ << indent() << "}" << endl; // end function
+}
+
+void t_rs_generator::render_sync_handler_succeeded(t_function *tfunc) {
+ if (tfunc->is_oneway()) {
+ f_gen_ << indent() << "Ok(())" << endl;
+ } else {
+ f_gen_
+ << indent()
+ << "let message_ident = TMessageIdentifier::new("
+ << "\"" << tfunc->get_name() << "\", " // note: use *original* name
+ << "TMessageType::Reply, "
+ << "incoming_sequence_number);"
+ << endl;
+ f_gen_ << indent() << "o_prot.write_message_begin(&message_ident)?;" << endl;
+ f_gen_ << indent() << "let ret = " << handler_successful_return_struct(tfunc) <<";" << endl;
+ f_gen_ << indent() << "ret.write_to_out_protocol(o_prot)?;" << endl;
+ f_gen_ << indent() << "o_prot.write_message_end()?;" << endl;
+ f_gen_ << indent() << "o_prot.flush()" << endl;
+ }
+}
+
+void t_rs_generator::render_sync_handler_failed(t_function *tfunc) {
+ string err_var("e");
+
+ f_gen_ << indent() << "match " << err_var << " {" << endl;
+ indent_up();
+
+ // if there are any user-defined exceptions for this service call handle them first
+ if (tfunc->get_xceptions() != NULL && tfunc->get_xceptions()->get_sorted_members().size() > 0) {
+ string user_err_var("usr_err");
+ f_gen_ << indent() << "thrift::Error::User(" << user_err_var << ") => {" << endl;
+ indent_up();
+ render_sync_handler_failed_user_exception_branch(tfunc);
+ indent_down();
+ f_gen_ << indent() << "}," << endl;
+ }
+
+ // application error
+ string app_err_var("app_err");
+ f_gen_ << indent() << "thrift::Error::Application(" << app_err_var << ") => {" << endl;
+ indent_up();
+ render_sync_handler_failed_application_exception_branch(tfunc, app_err_var);
+ indent_down();
+ f_gen_ << indent() << "}," << endl;
+
+ // default case
+ f_gen_ << indent() << "_ => {" << endl;
+ indent_up();
+ render_sync_handler_failed_default_exception_branch(tfunc);
+ indent_down();
+ f_gen_ << indent() << "}," << endl;
+
+ indent_down();
+ f_gen_ << indent() << "}" << endl;
+}
+
+void t_rs_generator::render_sync_handler_failed_user_exception_branch(t_function *tfunc) {
+ if (tfunc->get_xceptions() == NULL || tfunc->get_xceptions()->get_sorted_members().empty()) {
+ throw "cannot render user exception branches if no user exceptions defined";
+ }
+
+ const vector<t_field*> txceptions = tfunc->get_xceptions()->get_sorted_members();
+ vector<t_field*>::const_iterator xception_iter;
+ int branches_rendered = 0;
+
+ // run through all user-defined exceptions
+ for (xception_iter = txceptions.begin(); xception_iter != txceptions.end(); ++xception_iter) {
+ t_field* xception_field = (*xception_iter);
+
+ string if_statement(branches_rendered == 0 ? "if usr_err" : "} else if usr_err");
+ string exception_type(to_rust_type(xception_field->get_type()));
+ f_gen_ << indent() << if_statement << ".downcast_ref::<" << exception_type << ">().is_some() {" << endl;
+ indent_up();
+
+ f_gen_
+ << indent()
+ << "let err = usr_err.downcast::<" << exception_type << ">().expect(\"downcast already checked\");"
+ << endl;
+
+ // render the members of the return struct
+ ostringstream members;
+
+ bool has_result_variable = !(tfunc->is_oneway() || tfunc->get_returntype()->is_void());
+ if (has_result_variable) {
+ members << SERVICE_RESULT_VARIABLE << ": None, ";
+ }
+
+ vector<t_field*>::const_iterator xception_members_iter;
+ for(xception_members_iter = txceptions.begin(); xception_members_iter != txceptions.end(); ++xception_members_iter) {
+ t_field* member = (*xception_members_iter);
+ string member_name(rust_field_name(member));
+ if (member == xception_field) {
+ members << member_name << ": Some(*err), ";
+ } else {
+ members << member_name << ": None, ";
+ }
+ }
+
+ string member_string = members.str();
+ member_string.replace(member_string.size() - 2, 2, " "); // trim trailing comma
+
+ // now write out the return struct
+ f_gen_
+ << indent()
+ << "let ret_err = "
+ << service_call_result_struct_name(tfunc)
+ << "{ " << member_string << "};"
+ << endl;
+
+ f_gen_
+ << indent()
+ << "let message_ident = "
+ << "TMessageIdentifier::new("
+ << "\"" << tfunc->get_name() << "\", " // note: use *original* name
+ << "TMessageType::Reply, "
+ << "incoming_sequence_number);"
+ << endl;
+ f_gen_ << indent() << "o_prot.write_message_begin(&message_ident)?;" << endl;
+ f_gen_ << indent() << "ret_err.write_to_out_protocol(o_prot)?;" << endl;
+ f_gen_ << indent() << "o_prot.write_message_end()?;" << endl;
+ f_gen_ << indent() << "o_prot.flush()" << endl;
+
+ indent_down();
+
+ branches_rendered++;
+ }
+
+ // the catch all, if somehow it was a user exception that we don't support
+ f_gen_ << indent() << "} else {" << endl;
+ indent_up();
+
+ // FIXME: same as default block below
+
+ f_gen_ << indent() << "let ret_err = {" << endl;
+ indent_up();
+ render_rift_error_struct("ApplicationError", "ApplicationErrorKind::Unknown", "usr_err.description()");
+ indent_down();
+ f_gen_ << indent() << "};" << endl;
+ render_sync_handler_send_exception_response(tfunc, "ret_err");
+
+ indent_down();
+ f_gen_ << indent() << "}" << endl;
+}
+
+void t_rs_generator::render_sync_handler_failed_application_exception_branch(
+ t_function *tfunc,
+ const string &app_err_var
+) {
+ if (tfunc->is_oneway()) {
+ f_gen_ << indent() << "Err(thrift::Error::Application(" << app_err_var << "))" << endl;
+ } else {
+ render_sync_handler_send_exception_response(tfunc, app_err_var);
+ }
+}
+
+void t_rs_generator::render_sync_handler_failed_default_exception_branch(t_function *tfunc) {
+ f_gen_ << indent() << "let ret_err = {" << endl;
+ indent_up();
+ render_rift_error_struct("ApplicationError", "ApplicationErrorKind::Unknown", "e.description()");
+ indent_down();
+ f_gen_ << indent() << "};" << endl;
+ if (tfunc->is_oneway()) {
+ f_gen_ << indent() << "Err(thrift::Error::Application(ret_err))" << endl;
+ } else {
+ render_sync_handler_send_exception_response(tfunc, "ret_err");
+ }
+}
+
+void t_rs_generator::render_sync_handler_send_exception_response(t_function *tfunc, const string &err_var) {
+ f_gen_
+ << indent()
+ << "let message_ident = TMessageIdentifier::new("
+ << "\"" << tfunc->get_name() << "\", " // note: use *original* name
+ << "TMessageType::Exception, "
+ << "incoming_sequence_number);"
+ << endl;
+ f_gen_ << indent() << "o_prot.write_message_begin(&message_ident)?;" << endl;
+ f_gen_ << indent() << "thrift::Error::write_application_error_to_out_protocol(&" << err_var << ", o_prot)?;" << endl;
+ f_gen_ << indent() << "o_prot.write_message_end()?;" << endl;
+ f_gen_ << indent() << "o_prot.flush()" << endl;
+}
+
+string t_rs_generator::handler_successful_return_struct(t_function* tfunc) {
+ int member_count = 0;
+ ostringstream return_struct;
+
+ return_struct << service_call_result_struct_name(tfunc) << " { ";
+
+ // actual return
+ if (!tfunc->get_returntype()->is_void()) {
+ return_struct << "result_value: Some(handler_return)";
+ member_count++;
+ }
+
+ // any user-defined exceptions
+ if (tfunc->get_xceptions() != NULL) {
+ t_struct* txceptions = tfunc->get_xceptions();
+ const vector<t_field*> members = txceptions->get_sorted_members();
+ vector<t_field*>::const_iterator members_iter;
+ for (members_iter = members.begin(); members_iter != members.end(); ++members_iter) {
+ t_field* xception_field = (*members_iter);
+ if (member_count > 0) { return_struct << ", "; }
+ return_struct << rust_field_name(xception_field) << ": None";
+ member_count++;
+ }
+ }
+
+ return_struct << " }";
+
+ return return_struct.str();
+}
+
+//-----------------------------------------------------------------------------
+//
+// Utility
+//
+//-----------------------------------------------------------------------------
+
+void t_rs_generator::render_type_comment(const string& type_name) {
+ f_gen_ << "//" << endl;
+ f_gen_ << "// " << type_name << endl;
+ f_gen_ << "//" << endl;
+ f_gen_ << endl;
+}
+
+// NOTE: do *not* put in an extra newline after doc is generated.
+// This is because rust docs have to abut the line they're documenting.
+void t_rs_generator::render_rustdoc(t_doc* tdoc) {
+ if (!tdoc->has_doc()) {
+ return;
+ }
+
+ generate_docstring_comment(f_gen_, "", "///", tdoc->get_doc(), "");
+}
+
+void t_rs_generator::render_rift_error(
+ const string& error_kind,
+ const string& error_struct,
+ const string& sub_error_kind,
+ const string& error_message
+) {
+ f_gen_ << indent() << "Err(" << endl;
+ indent_up();
+ f_gen_ << indent() << "thrift::Error::" << error_kind << "(" << endl;
+ indent_up();
+ render_rift_error_struct(error_struct, sub_error_kind, error_message);
+ indent_down();
+ f_gen_ << indent() << ")" << endl;
+ indent_down();
+ f_gen_ << indent() << ")" << endl;
+}
+
+void t_rs_generator::render_rift_error_struct(
+ const string& error_struct,
+ const string& sub_error_kind,
+ const string& error_message
+) {
+ f_gen_ << indent() << error_struct << "::new(" << endl;
+ indent_up();
+ f_gen_ << indent() << sub_error_kind << "," << endl;
+ f_gen_ << indent() << error_message << endl;
+ indent_down();
+ f_gen_ << indent() << ")" << endl;
+}
+
+string t_rs_generator::to_rust_type(t_type* ttype, bool ordered_float) {
+ // ttype = get_true_type(ttype); <-- recurses through as many typedef layers as necessary
+ if (ttype->is_base_type()) {
+ t_base_type* tbase_type = ((t_base_type*)ttype);
+ switch (tbase_type->get_base()) {
+ case t_base_type::TYPE_VOID:
+ return "()";
+ case t_base_type::TYPE_STRING:
+ if (tbase_type->is_binary()) {
+ return "Vec<u8>";
+ } else {
+ return "String";
+ }
+ case t_base_type::TYPE_BOOL:
+ return "bool";
+ case t_base_type::TYPE_I8:
+ return "i8";
+ case t_base_type::TYPE_I16:
+ return "i16";
+ case t_base_type::TYPE_I32:
+ return "i32";
+ case t_base_type::TYPE_I64:
+ return "i64";
+ case t_base_type::TYPE_DOUBLE:
+ if (ordered_float) {
+ return "OrderedFloat<f64>";
+ } else {
+ return "f64";
+ }
+ }
+ } else if (ttype->is_typedef()) {
+ return rust_namespace(ttype) + ((t_typedef*)ttype)->get_symbolic();
+ } else if (ttype->is_enum()) {
+ return rust_namespace(ttype) + ttype->get_name();
+ } else if (ttype->is_struct() || ttype->is_xception()) {
+ return rust_namespace(ttype) + rust_camel_case(ttype->get_name());
+ } else if (ttype->is_map()) {
+ t_map* tmap = (t_map*)ttype;
+ return "BTreeMap<" + to_rust_type(tmap->get_key_type()) + ", " + to_rust_type(tmap->get_val_type()) + ">";
+ } else if (ttype->is_set()) {
+ t_set* tset = (t_set*)ttype;
+ return "BTreeSet<" + to_rust_type(tset->get_elem_type()) + ">";
+ } else if (ttype->is_list()) {
+ t_list* tlist = (t_list*)ttype;
+ return "Vec<" + to_rust_type(tlist->get_elem_type()) + ">";
+ }
+
+ throw "cannot find rust type for " + ttype->get_name();
+}
+
+string t_rs_generator::to_rust_field_type_enum(t_type* ttype) {
+ ttype = get_true_type(ttype);
+ if (ttype->is_base_type()) {
+ t_base_type::t_base tbase = ((t_base_type*)ttype)->get_base();
+ switch (tbase) {
+ case t_base_type::TYPE_VOID:
+ throw "will not generate protocol::TType for TYPE_VOID";
+ case t_base_type::TYPE_STRING: // both strings and binary are actually encoded as TType::String
+ return "TType::String";
+ case t_base_type::TYPE_BOOL:
+ return "TType::Bool";
+ case t_base_type::TYPE_I8:
+ return "TType::I08";
+ case t_base_type::TYPE_I16:
+ return "TType::I16";
+ case t_base_type::TYPE_I32:
+ return "TType::I32";
+ case t_base_type::TYPE_I64:
+ return "TType::I64";
+ case t_base_type::TYPE_DOUBLE:
+ return "TType::Double";
+ }
+ } else if (ttype->is_enum()) {
+ return "TType::I32";
+ } else if (ttype->is_struct() || ttype->is_xception()) {
+ return "TType::Struct";
+ } else if (ttype->is_map()) {
+ return "TType::Map";
+ } else if (ttype->is_set()) {
+ return "TType::Set";
+ } else if (ttype->is_list()) {
+ return "TType::List";
+ }
+
+ throw "cannot find TType for " + ttype->get_name();
+}
+
+string t_rs_generator::opt_in_req_out_value(t_type* ttype) {
+ ttype = get_true_type(ttype);
+ if (ttype->is_base_type()) {
+ t_base_type* tbase_type = ((t_base_type*)ttype);
+ switch (tbase_type->get_base()) {
+ case t_base_type::TYPE_VOID:
+ throw "cannot generate OPT_IN_REQ_OUT value for void";
+ case t_base_type::TYPE_STRING:
+ if (tbase_type->is_binary()) {
+ return "Some(Vec::new())";
+ } else {
+ return "Some(\"\".to_owned())";
+ }
+ case t_base_type::TYPE_BOOL:
+ return "Some(false)";
+ case t_base_type::TYPE_I8:
+ case t_base_type::TYPE_I16:
+ case t_base_type::TYPE_I32:
+ case t_base_type::TYPE_I64:
+ return "Some(0)";
+ case t_base_type::TYPE_DOUBLE:
+ return "Some(OrderedFloat::from(0.0))";
+ }
+
+ } else if (ttype->is_enum() || ttype->is_struct() || ttype->is_xception()) {
+ return "None";
+ } else if (ttype->is_list()) {
+ return "Some(Vec::new())";
+ } else if (ttype->is_set()) {
+ return "Some(BTreeSet::new())";
+ } else if (ttype->is_map()) {
+ return "Some(BTreeMap::new())";
+ }
+
+ throw "cannot generate opt-in-req-out value for type " + ttype->get_name();
+}
+
+bool t_rs_generator::can_generate_simple_const(t_type* ttype) {
+ t_type* actual_type = get_true_type(ttype);
+ if (actual_type->is_base_type()) {
+ t_base_type* tbase_type = (t_base_type*)actual_type;
+ return !(tbase_type->get_base() == t_base_type::TYPE_DOUBLE);
+ } else {
+ return false;
+ }
+}
+
+bool t_rs_generator::can_generate_const_holder(t_type* ttype) {
+ t_type* actual_type = get_true_type(ttype);
+ return !can_generate_simple_const(actual_type) && !actual_type->is_service();
+}
+
+bool t_rs_generator::is_void(t_type* ttype) {
+ return ttype->is_base_type() && ((t_base_type*)ttype)->get_base() == t_base_type::TYPE_VOID;
+}
+
+bool t_rs_generator::is_optional(t_field::e_req req) {
+ return req == t_field::T_OPTIONAL || req == t_field::T_OPT_IN_REQ_OUT;
+}
+
+t_field::e_req t_rs_generator::actual_field_req(t_field* tfield, t_rs_generator::e_struct_type struct_type) {
+ return struct_type == t_rs_generator::T_ARGS ? t_field::T_REQUIRED : tfield->get_req();
+}
+
+bool t_rs_generator::has_args(t_function* tfunc) {
+ return tfunc->get_arglist() != NULL && !tfunc->get_arglist()->get_sorted_members().empty();
+}
+
+bool t_rs_generator::has_non_void_args(t_function* tfunc) {
+ bool has_non_void_args = false;
+
+ const vector<t_field*> args = tfunc->get_arglist()->get_sorted_members();
+ vector<t_field*>::const_iterator args_iter;
+ for (args_iter = args.begin(); args_iter != args.end(); ++args_iter) {
+ t_field* tfield = (*args_iter);
+ if (!tfield->get_type()->is_void()) {
+ has_non_void_args = true;
+ break;
+ }
+ }
+
+ return has_non_void_args;
+}
+
+string t_rs_generator::visibility_qualifier(t_rs_generator::e_struct_type struct_type) {
+ switch(struct_type) {
+ case t_rs_generator::T_ARGS:
+ case t_rs_generator::T_RESULT:
+ return "";
+ default:
+ return "pub ";
+ }
+}
+
+string t_rs_generator::rust_namespace(t_service* tservice) {
+ if (tservice->get_program()->get_name() != get_program()->get_name()) {
+ return rust_snake_case(tservice->get_program()->get_name()) + "::";
+ } else {
+ return "";
+ }
+}
+
+string t_rs_generator::rust_namespace(t_type* ttype) {
+ if (ttype->get_program()->get_name() != get_program()->get_name()) {
+ return rust_snake_case(ttype->get_program()->get_name()) + "::";
+ } else {
+ return "";
+ }
+}
+
+bool t_rs_generator::is_reserved(const string& name) {
+ return RUST_RESERVED_WORDS_SET.find(name) != RUST_RESERVED_WORDS_SET.end();
+}
+
+string t_rs_generator::rust_struct_name(t_struct* tstruct) {
+ string base_struct_name(rust_camel_case(tstruct->get_name()));
+ return rust_safe_name(base_struct_name);
+}
+
+string t_rs_generator::rust_field_name(t_field* tfield) {
+ string base_field_name(rust_snake_case(tfield->get_name()));
+ return rust_safe_name(base_field_name);
+}
+
+string t_rs_generator::rust_union_field_name(t_field* tfield) {
+ string base_field_name(rust_camel_case(tfield->get_name()));
+ return rust_safe_name(base_field_name);
+}
+
+string t_rs_generator::rust_safe_name(const string& name) {
+ if (is_reserved(name)) {
+ return name + "_";
+ } else {
+ return name;
+ }
+}
+
+string t_rs_generator::service_call_client_function_name(t_function* tfunc) {
+ return rust_snake_case(tfunc->get_name());
+}
+
+string t_rs_generator::service_call_handler_function_name(t_function* tfunc) {
+ return "handle_" + rust_snake_case(tfunc->get_name());
+}
+
+string t_rs_generator::service_call_result_struct_name(t_function* tfunc) {
+ return rust_camel_case(tfunc->get_name()) + RESULT_STRUCT_SUFFIX;
+}
+
+string t_rs_generator::rust_sync_client_marker_trait_name(t_service* tservice) {
+ return "T" + rust_camel_case(tservice->get_name()) + "SyncClientMarker";
+}
+
+string t_rs_generator::rust_sync_client_trait_name(t_service* tservice) {
+ return "T" + rust_camel_case(tservice->get_name()) + "SyncClient";
+}
+
+string t_rs_generator::rust_sync_client_impl_name(t_service* tservice) {
+ return rust_camel_case(tservice->get_name()) + "SyncClient";
+}
+
+string t_rs_generator::rust_sync_handler_trait_name(t_service* tservice) {
+ return rust_camel_case(tservice->get_name()) + "SyncHandler";
+}
+
+string t_rs_generator::rust_sync_processor_name(t_service* tservice) {
+ return rust_camel_case(tservice->get_name()) + "SyncProcessor";
+}
+
+string t_rs_generator::rust_sync_processor_impl_name(t_service *tservice) {
+ return "T" + rust_camel_case(tservice->get_name()) + "ProcessFunctions";
+}
+
+string t_rs_generator::rust_upper_case(const string& name) {
+ string str(uppercase(underscore(name)));
+ string_replace(str, "__", "_");
+ return str;
+}
+
+string t_rs_generator::rust_snake_case(const string& name) {
+ string str(decapitalize(underscore(name)));
+ string_replace(str, "__", "_");
+ return str;
+}
+
+string t_rs_generator::rust_camel_case(const string& name) {
+ string str(capitalize(camelcase(name)));
+ string_replace(str, "_", "");
+ return str;
+}
+
+void t_rs_generator::string_replace(string& target, const string& search_string, const string& replace_string) {
+ if (target.empty()) {
+ return;
+ }
+
+ size_t match_len = search_string.length();
+ size_t replace_len = replace_string.length();
+
+ size_t search_idx = 0;
+ size_t match_idx;
+ while ((match_idx = target.find(search_string, search_idx)) != string::npos) {
+ target.replace(match_idx, match_len, replace_string);
+ search_idx = match_idx + replace_len;
+ }
+}
+
+THRIFT_REGISTER_GENERATOR(
+ rs,
+ "Rust",
+ "\n") // no Rust-generator-specific options
diff --git a/configure.ac b/configure.ac
index dad10a7..0452a15 100755
--- a/configure.ac
+++ b/configure.ac
@@ -129,6 +129,7 @@
with_d="no"
with_nodejs="no"
with_lua="no"
+ with_rs="no"
fi
@@ -410,6 +411,29 @@
fi
AM_CONDITIONAL(WITH_GO, [test "$have_go" = "yes"])
+AX_THRIFT_LIB(rs, [Rust], yes)
+have_rs="no"
+if test "$with_rs" = "yes"; then
+ AC_PATH_PROG([CARGO], [cargo])
+ AC_PATH_PROG([RUSTC], [rustc])
+ if [[ -x "$CARGO" ]] && [[ -x "$RUSTC" ]]; then
+ min_rustc_version="1.13"
+
+ AC_MSG_CHECKING([for rustc version])
+ rustc_version=`$RUSTC --version 2>&1 | $SED -e 's/\(rustc \)\([0-9]\)\.\([0-9][0-9]*\)\.\([0-9][0-9]*\).*/\2.\3/'`
+ AC_MSG_RESULT($rustc_version)
+ AC_SUBST([rustc_version],[$rustc_version])
+
+ AX_COMPARE_VERSION([$min_rustc_version],[le],[$rustc_version],[
+ :
+ have_rs="yes"
+ ],[
+ :
+ have_rs="no"
+ ])
+ fi
+fi
+AM_CONDITIONAL(WITH_RS, [test "$have_rs" = "yes"])
AX_THRIFT_LIB(haxe, [Haxe], yes)
if test "$with_haxe" = "yes"; then
@@ -777,6 +801,8 @@
lib/dart/Makefile
lib/py/Makefile
lib/rb/Makefile
+ lib/rs/Makefile
+ lib/rs/test/Makefile
lib/lua/Makefile
lib/xml/Makefile
lib/xml/test/Makefile
@@ -798,6 +824,7 @@
test/py.twisted/Makefile
test/py.tornado/Makefile
test/rb/Makefile
+ test/rs/Makefile
tutorial/Makefile
tutorial/c_glib/Makefile
tutorial/cpp/Makefile
@@ -814,6 +841,7 @@
tutorial/py.twisted/Makefile
tutorial/py.tornado/Makefile
tutorial/rb/Makefile
+ tutorial/rs/Makefile
])
if test "$have_cpp" = "yes" ; then MAYBE_CPP="cpp" ; else MAYBE_CPP="" ; fi
@@ -848,6 +876,8 @@
AC_SUBST([MAYBE_ERLANG])
if test "$have_lua" = "yes" ; then MAYBE_LUA="lua" ; else MAYBE_LUA="" ; fi
AC_SUBST([MAYBE_LUA])
+if test "$have_rs" = "yes" ; then MAYBE_RS="rs" ; else MAYBE_RS="" ; fi
+AC_SUBST([MAYBE_RS])
AC_OUTPUT
@@ -873,6 +903,7 @@
echo "Building D Library ........... : $have_d"
echo "Building NodeJS Library ...... : $have_nodejs"
echo "Building Lua Library ......... : $have_lua"
+echo "Building Rust Library ........ : $have_rs"
if test "$have_cpp" = "yes" ; then
echo
@@ -974,6 +1005,13 @@
echo "Lua Library:"
echo " Using Lua .............. : $LUA"
fi
+if test "$have_rs" = "yes" ; then
+ echo
+ echo "Rust Library:"
+ echo " Using Cargo................ : $CARGO"
+ echo " Using rustc................ : $RUSTC"
+ echo " Using Rust version......... : $($RUSTC --version)"
+fi
echo
echo "If something is missing that you think should be present,"
echo "please skim the output of configure to find the missing"
diff --git a/lib/Makefile.am b/lib/Makefile.am
index 21d807a..636f42c 100644
--- a/lib/Makefile.am
+++ b/lib/Makefile.am
@@ -93,6 +93,10 @@
SUBDIRS += lua
endif
+if WITH_RS
+SUBDIRS += rs
+endif
+
# All of the libs that don't use Automake need to go in here
# so they will end up in our release tarballs.
EXTRA_DIST = \
diff --git a/lib/rs/Cargo.toml b/lib/rs/Cargo.toml
new file mode 100644
index 0000000..07c5e67
--- /dev/null
+++ b/lib/rs/Cargo.toml
@@ -0,0 +1,18 @@
+[package]
+name = "thrift"
+description = "Rust bindings for the Apache Thrift RPC system"
+version = "1.0.0"
+license = "Apache-2.0"
+authors = ["Apache Thrift Developers <dev@thrift.apache.org>"]
+homepage = "http://thrift.apache.org"
+documentation = "https://thrift.apache.org"
+readme = "README.md"
+exclude = ["Makefile*", "test/**"]
+keywords = ["thrift"]
+
+[dependencies]
+integer-encoding = "1.0.3"
+log = "~0.3.6"
+byteorder = "0.5.3"
+try_from = "0.2.0"
+
diff --git a/lib/rs/Makefile.am b/lib/rs/Makefile.am
new file mode 100644
index 0000000..0a34120
--- /dev/null
+++ b/lib/rs/Makefile.am
@@ -0,0 +1,46 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+SUBDIRS = .
+
+if WITH_TESTS
+SUBDIRS += test
+endif
+
+install:
+ @echo '##############################################################'
+ @echo '##############################################################'
+ @echo 'The Rust client library should be installed via a Cargo.toml dependency - please see /lib/rs/README.md'
+ @echo '##############################################################'
+ @echo '##############################################################'
+
+check-local:
+ $(CARGO) test
+
+all-local:
+ $(CARGO) build
+
+clean-local:
+ $(CARGO) clean
+ -$(RM) Cargo.lock
+
+EXTRA_DIST = \
+ src \
+ Cargo.toml \
+ README.md
diff --git a/lib/rs/README.md b/lib/rs/README.md
new file mode 100644
index 0000000..8b35eda
--- /dev/null
+++ b/lib/rs/README.md
@@ -0,0 +1,60 @@
+# Rust Thrift library
+
+## Overview
+
+This crate implements the components required to build a working Thrift server
+and client. It is divided into the following modules:
+
+ 1. errors
+ 2. protocol
+ 3. transport
+ 4. server
+ 5. autogen
+
+The modules are layered as shown. The `generated` layer is code generated by the
+Thrift compiler's Rust plugin. It uses the components defined in this crate to
+serialize and deserialize types and implement RPC. Users interact with these
+types and services by writing their own code on top.
+
+ ```text
+ +-----------+
+ | app dev |
+ +-----------+
+ | generated | <-> errors/results
+ +-----------+
+ | protocol |
+ +-----------+
+ | transport |
+ +-----------+
+ ```
+
+## Using this crate
+
+Add `thrift = "x.y.z"` to your `Cargo.toml`, where `x.y.z` is the version of the
+Thrift compiler you're using.
+
+## API Documentation
+
+Full [Rustdoc](https://docs.rs/thrift/)
+
+## Contributing
+
+Bug reports and PRs are always welcome! Please see the
+[Thrift website](https://thrift.apache.org/) for more details.
+
+Thrift Rust support requires code in several directories:
+
+* `compiler/cpp/src/thrift/generate/t_rs_generator.cc`: binding code generator
+* `lib/rs`: runtime library
+* `lib/rs/test`: supplemental tests
+* `tutorial/rs`: tutorial client and server
+* `test/rs`: cross-language test client and server
+
+All library code, test code and auto-generated code compiles and passes clippy
+without warnings. All new code must do the same! When making changes ensure that:
+
+* `rustc` does does output any warnings
+* `clippy` with default settings does not output any warnings (includes auto-generated code)
+* `cargo test` is successful
+* `make precross` and `make check` are successful
+* `tutorial/bin/tutorial_client` and `tutorial/bin/tutorial_server` communicate
diff --git a/lib/rs/src/autogen.rs b/lib/rs/src/autogen.rs
new file mode 100644
index 0000000..289c7be
--- /dev/null
+++ b/lib/rs/src/autogen.rs
@@ -0,0 +1,45 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Thrift compiler auto-generated support.
+//!
+//!
+//! Types and functions used internally by the Thrift compiler's Rust plugin
+//! to implement required functionality. Users should never have to use code
+//! in this module directly.
+
+use ::protocol::{TInputProtocol, TOutputProtocol};
+
+/// Specifies the minimum functionality an auto-generated client should provide
+/// to communicate with a Thrift server.
+pub trait TThriftClient {
+ /// Returns the input protocol used to read serialized Thrift messages
+ /// from the Thrift server.
+ fn i_prot_mut(&mut self) -> &mut TInputProtocol;
+ /// Returns the output protocol used to write serialized Thrift messages
+ /// to the Thrift server.
+ fn o_prot_mut(&mut self) -> &mut TOutputProtocol;
+ /// Returns the sequence number of the last message written to the Thrift
+ /// server. Returns `0` if no messages have been written. Sequence
+ /// numbers should *never* be negative, and this method returns an `i32`
+ /// simply because the Thrift protocol encodes sequence numbers as `i32` on
+ /// the wire.
+ fn sequence_number(&self) -> i32; // FIXME: consider returning a u32
+ /// Increments the sequence number, indicating that a message with that
+ /// number has been sent to the Thrift server.
+ fn increment_sequence_number(&mut self) -> i32;
+}
diff --git a/lib/rs/src/errors.rs b/lib/rs/src/errors.rs
new file mode 100644
index 0000000..a6049d5
--- /dev/null
+++ b/lib/rs/src/errors.rs
@@ -0,0 +1,678 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::convert::{From, Into};
+use std::error::Error as StdError;
+use std::fmt::{Debug, Display, Formatter};
+use std::{error, fmt, io, string};
+use try_from::TryFrom;
+
+use ::protocol::{TFieldIdentifier, TInputProtocol, TOutputProtocol, TStructIdentifier, TType};
+
+// FIXME: should all my error structs impl error::Error as well?
+// FIXME: should all fields in TransportError, ProtocolError and ApplicationError be optional?
+
+/// Error type returned by all runtime library functions.
+///
+/// `thrift::Error` is used throughout this crate as well as in auto-generated
+/// Rust code. It consists of four variants defined by convention across Thrift
+/// implementations:
+///
+/// 1. `Transport`: errors encountered while operating on I/O channels
+/// 2. `Protocol`: errors encountered during runtime-library processing
+/// 3. `Application`: errors encountered within auto-generated code
+/// 4. `User`: IDL-defined exception structs
+///
+/// The `Application` variant also functions as a catch-all: all handler errors
+/// are automatically turned into application errors.
+///
+/// All error variants except `Error::User` take an eponymous struct with two
+/// required fields:
+///
+/// 1. `kind`: variant-specific enum identifying the error sub-type
+/// 2. `message`: human-readable error info string
+///
+/// `kind` is defined by convention while `message` is freeform. If none of the
+/// enumerated kinds are suitable use `Unknown`.
+///
+/// To simplify error creation convenience constructors are defined for all
+/// variants, and conversions from their structs (`thrift::TransportError`,
+/// `thrift::ProtocolError` and `thrift::ApplicationError` into `thrift::Error`.
+///
+/// # Examples
+///
+/// Create a `TransportError`.
+///
+/// ```
+/// use thrift;
+/// use thrift::{TransportError, TransportErrorKind};
+///
+/// // explicit
+/// let err0: thrift::Result<()> = Err(
+/// thrift::Error::Transport(
+/// TransportError {
+/// kind: TransportErrorKind::TimedOut,
+/// message: format!("connection to server timed out")
+/// }
+/// )
+/// );
+///
+/// // use conversion
+/// let err1: thrift::Result<()> = Err(
+/// thrift::Error::from(
+/// TransportError {
+/// kind: TransportErrorKind::TimedOut,
+/// message: format!("connection to server timed out")
+/// }
+/// )
+/// );
+///
+/// // use struct constructor
+/// let err2: thrift::Result<()> = Err(
+/// thrift::Error::Transport(
+/// TransportError::new(
+/// TransportErrorKind::TimedOut,
+/// "connection to server timed out"
+/// )
+/// )
+/// );
+///
+///
+/// // use error variant constructor
+/// let err3: thrift::Result<()> = Err(
+/// thrift::new_transport_error(
+/// TransportErrorKind::TimedOut,
+/// "connection to server timed out"
+/// )
+/// );
+/// ```
+///
+/// Create an error from a string.
+///
+/// ```
+/// use thrift;
+/// use thrift::{ApplicationError, ApplicationErrorKind};
+///
+/// // we just use `From::from` to convert a `String` into a `thrift::Error`
+/// let err0: thrift::Result<()> = Err(
+/// thrift::Error::from("This is an error")
+/// );
+///
+/// // err0 is equivalent to...
+/// let err1: thrift::Result<()> = Err(
+/// thrift::Error::Application(
+/// ApplicationError {
+/// kind: ApplicationErrorKind::Unknown,
+/// message: format!("This is an error")
+/// }
+/// )
+/// );
+/// ```
+///
+/// Return an IDL-defined exception.
+///
+/// ```text
+/// // Thrift IDL exception definition.
+/// exception Xception {
+/// 1: i32 errorCode,
+/// 2: string message
+/// }
+/// ```
+///
+/// ```
+/// use std::convert::From;
+/// use std::error::Error;
+/// use std::fmt;
+/// use std::fmt::{Display, Formatter};
+///
+/// // auto-generated by the Thrift compiler
+/// #[derive(Clone, Debug, Eq, Ord, PartialEq, PartialOrd)]
+/// pub struct Xception {
+/// pub error_code: Option<i32>,
+/// pub message: Option<String>,
+/// }
+///
+/// // auto-generated by the Thrift compiler
+/// impl Error for Xception {
+/// fn description(&self) -> &str {
+/// "remote service threw Xception"
+/// }
+/// }
+///
+/// // auto-generated by the Thrift compiler
+/// impl From<Xception> for thrift::Error {
+/// fn from(e: Xception) -> Self {
+/// thrift::Error::User(Box::new(e))
+/// }
+/// }
+///
+/// // auto-generated by the Thrift compiler
+/// impl Display for Xception {
+/// fn fmt(&self, f: &mut Formatter) -> fmt::Result {
+/// self.description().fmt(f)
+/// }
+/// }
+///
+/// // in user code...
+/// let err: thrift::Result<()> = Err(
+/// thrift::Error::from(Xception { error_code: Some(1), message: None })
+/// );
+/// ```
+pub enum Error {
+ /// Errors encountered while operating on I/O channels.
+ ///
+ /// These include *connection closed* and *bind failure*.
+ Transport(TransportError),
+ /// Errors encountered during runtime-library processing.
+ ///
+ /// These include *message too large* and *unsupported protocol version*.
+ Protocol(ProtocolError),
+ /// Errors encountered within auto-generated code, or when incoming
+ /// or outgoing messages violate the Thrift spec.
+ ///
+ /// These include *out-of-order messages* and *missing required struct
+ /// fields*.
+ ///
+ /// This variant also functions as a catch-all: errors from handler
+ /// functions are automatically returned as an `ApplicationError`.
+ Application(ApplicationError),
+ /// IDL-defined exception structs.
+ User(Box<error::Error + Sync + Send>),
+}
+
+impl Error {
+ /// Create an `ApplicationError` from its wire representation.
+ ///
+ /// Application code **should never** call this method directly.
+ pub fn read_application_error_from_in_protocol(i: &mut TInputProtocol)
+ -> ::Result<ApplicationError> {
+ let mut message = "general remote error".to_owned();
+ let mut kind = ApplicationErrorKind::Unknown;
+
+ i.read_struct_begin()?;
+
+ loop {
+ let field_ident = i.read_field_begin()?;
+
+ if field_ident.field_type == TType::Stop {
+ break;
+ }
+
+ let id = field_ident.id.expect("sender should always specify id for non-STOP field");
+
+ match id {
+ 1 => {
+ let remote_message = i.read_string()?;
+ i.read_field_end()?;
+ message = remote_message;
+ }
+ 2 => {
+ let remote_type_as_int = i.read_i32()?;
+ let remote_kind: ApplicationErrorKind = TryFrom::try_from(remote_type_as_int)
+ .unwrap_or(ApplicationErrorKind::Unknown);
+ i.read_field_end()?;
+ kind = remote_kind;
+ }
+ _ => {
+ i.skip(field_ident.field_type)?;
+ }
+ }
+ }
+
+ i.read_struct_end()?;
+
+ Ok(ApplicationError {
+ kind: kind,
+ message: message,
+ })
+ }
+
+ /// Convert an `ApplicationError` into its wire representation and write
+ /// it to the remote.
+ ///
+ /// Application code **should never** call this method directly.
+ pub fn write_application_error_to_out_protocol(e: &ApplicationError,
+ o: &mut TOutputProtocol)
+ -> ::Result<()> {
+ o.write_struct_begin(&TStructIdentifier { name: "TApplicationException".to_owned() })?;
+
+ let message_field = TFieldIdentifier::new("message", TType::String, 1);
+ let type_field = TFieldIdentifier::new("type", TType::I32, 2);
+
+ o.write_field_begin(&message_field)?;
+ o.write_string(&e.message)?;
+ o.write_field_end()?;
+
+ o.write_field_begin(&type_field)?;
+ o.write_i32(e.kind as i32)?;
+ o.write_field_end()?;
+
+ o.write_field_stop()?;
+ o.write_struct_end()?;
+
+ o.flush()
+ }
+}
+
+impl error::Error for Error {
+ fn description(&self) -> &str {
+ match *self {
+ Error::Transport(ref e) => TransportError::description(e),
+ Error::Protocol(ref e) => ProtocolError::description(e),
+ Error::Application(ref e) => ApplicationError::description(e),
+ Error::User(ref e) => e.description(),
+ }
+ }
+}
+
+impl Debug for Error {
+ fn fmt(&self, f: &mut Formatter) -> fmt::Result {
+ match *self {
+ Error::Transport(ref e) => Debug::fmt(e, f),
+ Error::Protocol(ref e) => Debug::fmt(e, f),
+ Error::Application(ref e) => Debug::fmt(e, f),
+ Error::User(ref e) => Debug::fmt(e, f),
+ }
+ }
+}
+
+impl Display for Error {
+ fn fmt(&self, f: &mut Formatter) -> fmt::Result {
+ match *self {
+ Error::Transport(ref e) => Display::fmt(e, f),
+ Error::Protocol(ref e) => Display::fmt(e, f),
+ Error::Application(ref e) => Display::fmt(e, f),
+ Error::User(ref e) => Display::fmt(e, f),
+ }
+ }
+}
+
+impl From<String> for Error {
+ fn from(s: String) -> Self {
+ Error::Application(ApplicationError {
+ kind: ApplicationErrorKind::Unknown,
+ message: s,
+ })
+ }
+}
+
+impl<'a> From<&'a str> for Error {
+ fn from(s: &'a str) -> Self {
+ Error::Application(ApplicationError {
+ kind: ApplicationErrorKind::Unknown,
+ message: String::from(s),
+ })
+ }
+}
+
+impl From<TransportError> for Error {
+ fn from(e: TransportError) -> Self {
+ Error::Transport(e)
+ }
+}
+
+impl From<ProtocolError> for Error {
+ fn from(e: ProtocolError) -> Self {
+ Error::Protocol(e)
+ }
+}
+
+impl From<ApplicationError> for Error {
+ fn from(e: ApplicationError) -> Self {
+ Error::Application(e)
+ }
+}
+
+/// Create a new `Error` instance of type `Transport` that wraps a
+/// `TransportError`.
+pub fn new_transport_error<S: Into<String>>(kind: TransportErrorKind, message: S) -> Error {
+ Error::Transport(TransportError::new(kind, message))
+}
+
+/// Information about I/O errors.
+#[derive(Debug)]
+pub struct TransportError {
+ /// I/O error variant.
+ ///
+ /// If a specific `TransportErrorKind` does not apply use
+ /// `TransportErrorKind::Unknown`.
+ pub kind: TransportErrorKind,
+ /// Human-readable error message.
+ pub message: String,
+}
+
+impl TransportError {
+ /// Create a new `TransportError`.
+ pub fn new<S: Into<String>>(kind: TransportErrorKind, message: S) -> TransportError {
+ TransportError {
+ kind: kind,
+ message: message.into(),
+ }
+ }
+}
+
+/// I/O error categories.
+///
+/// This list may grow, and it is not recommended to match against it.
+#[derive(Clone, Copy, Eq, Debug, PartialEq)]
+pub enum TransportErrorKind {
+ /// Catch-all I/O error.
+ Unknown = 0,
+ /// An I/O operation was attempted when the transport channel was not open.
+ NotOpen = 1,
+ /// The transport channel cannot be opened because it was opened previously.
+ AlreadyOpen = 2,
+ /// An I/O operation timed out.
+ TimedOut = 3,
+ /// A read could not complete because no bytes were available.
+ EndOfFile = 4,
+ /// An invalid (buffer/message) size was requested or received.
+ NegativeSize = 5,
+ /// Too large a buffer or message size was requested or received.
+ SizeLimit = 6,
+}
+
+impl TransportError {
+ fn description(&self) -> &str {
+ match self.kind {
+ TransportErrorKind::Unknown => "transport error",
+ TransportErrorKind::NotOpen => "not open",
+ TransportErrorKind::AlreadyOpen => "already open",
+ TransportErrorKind::TimedOut => "timed out",
+ TransportErrorKind::EndOfFile => "end of file",
+ TransportErrorKind::NegativeSize => "negative size message",
+ TransportErrorKind::SizeLimit => "message too long",
+ }
+ }
+}
+
+impl Display for TransportError {
+ fn fmt(&self, f: &mut Formatter) -> fmt::Result {
+ write!(f, "{}", self.description())
+ }
+}
+
+impl TryFrom<i32> for TransportErrorKind {
+ type Err = Error;
+ fn try_from(from: i32) -> Result<Self, Self::Err> {
+ match from {
+ 0 => Ok(TransportErrorKind::Unknown),
+ 1 => Ok(TransportErrorKind::NotOpen),
+ 2 => Ok(TransportErrorKind::AlreadyOpen),
+ 3 => Ok(TransportErrorKind::TimedOut),
+ 4 => Ok(TransportErrorKind::EndOfFile),
+ 5 => Ok(TransportErrorKind::NegativeSize),
+ 6 => Ok(TransportErrorKind::SizeLimit),
+ _ => {
+ Err(Error::Protocol(ProtocolError {
+ kind: ProtocolErrorKind::Unknown,
+ message: format!("cannot convert {} to TransportErrorKind", from),
+ }))
+ }
+ }
+ }
+}
+
+impl From<io::Error> for Error {
+ fn from(err: io::Error) -> Self {
+ match err.kind() {
+ io::ErrorKind::ConnectionReset |
+ io::ErrorKind::ConnectionRefused |
+ io::ErrorKind::NotConnected => {
+ Error::Transport(TransportError {
+ kind: TransportErrorKind::NotOpen,
+ message: err.description().to_owned(),
+ })
+ }
+ io::ErrorKind::AlreadyExists => {
+ Error::Transport(TransportError {
+ kind: TransportErrorKind::AlreadyOpen,
+ message: err.description().to_owned(),
+ })
+ }
+ io::ErrorKind::TimedOut => {
+ Error::Transport(TransportError {
+ kind: TransportErrorKind::TimedOut,
+ message: err.description().to_owned(),
+ })
+ }
+ io::ErrorKind::UnexpectedEof => {
+ Error::Transport(TransportError {
+ kind: TransportErrorKind::EndOfFile,
+ message: err.description().to_owned(),
+ })
+ }
+ _ => {
+ Error::Transport(TransportError {
+ kind: TransportErrorKind::Unknown,
+ message: err.description().to_owned(), // FIXME: use io error's debug string
+ })
+ }
+ }
+ }
+}
+
+impl From<string::FromUtf8Error> for Error {
+ fn from(err: string::FromUtf8Error) -> Self {
+ Error::Protocol(ProtocolError {
+ kind: ProtocolErrorKind::InvalidData,
+ message: err.description().to_owned(), // FIXME: use fmt::Error's debug string
+ })
+ }
+}
+
+/// Create a new `Error` instance of type `Protocol` that wraps a
+/// `ProtocolError`.
+pub fn new_protocol_error<S: Into<String>>(kind: ProtocolErrorKind, message: S) -> Error {
+ Error::Protocol(ProtocolError::new(kind, message))
+}
+
+/// Information about errors that occur in the runtime library.
+#[derive(Debug)]
+pub struct ProtocolError {
+ /// Protocol error variant.
+ ///
+ /// If a specific `ProtocolErrorKind` does not apply use
+ /// `ProtocolErrorKind::Unknown`.
+ pub kind: ProtocolErrorKind,
+ /// Human-readable error message.
+ pub message: String,
+}
+
+impl ProtocolError {
+ /// Create a new `ProtocolError`.
+ pub fn new<S: Into<String>>(kind: ProtocolErrorKind, message: S) -> ProtocolError {
+ ProtocolError {
+ kind: kind,
+ message: message.into(),
+ }
+ }
+}
+
+/// Runtime library error categories.
+///
+/// This list may grow, and it is not recommended to match against it.
+#[derive(Clone, Copy, Eq, Debug, PartialEq)]
+pub enum ProtocolErrorKind {
+ /// Catch-all runtime-library error.
+ Unknown = 0,
+ /// An invalid argument was supplied to a library function, or invalid data
+ /// was received from a Thrift endpoint.
+ InvalidData = 1,
+ /// An invalid size was received in an encoded field.
+ NegativeSize = 2,
+ /// Thrift message or field was too long.
+ SizeLimit = 3,
+ /// Unsupported or unknown Thrift protocol version.
+ BadVersion = 4,
+ /// Unsupported Thrift protocol, server or field type.
+ NotImplemented = 5,
+ /// Reached the maximum nested depth to which an encoded Thrift field could
+ /// be skipped.
+ DepthLimit = 6,
+}
+
+impl ProtocolError {
+ fn description(&self) -> &str {
+ match self.kind {
+ ProtocolErrorKind::Unknown => "protocol error",
+ ProtocolErrorKind::InvalidData => "bad data",
+ ProtocolErrorKind::NegativeSize => "negative message size",
+ ProtocolErrorKind::SizeLimit => "message too long",
+ ProtocolErrorKind::BadVersion => "invalid thrift version",
+ ProtocolErrorKind::NotImplemented => "not implemented",
+ ProtocolErrorKind::DepthLimit => "maximum skip depth reached",
+ }
+ }
+}
+
+impl Display for ProtocolError {
+ fn fmt(&self, f: &mut Formatter) -> fmt::Result {
+ write!(f, "{}", self.description())
+ }
+}
+
+impl TryFrom<i32> for ProtocolErrorKind {
+ type Err = Error;
+ fn try_from(from: i32) -> Result<Self, Self::Err> {
+ match from {
+ 0 => Ok(ProtocolErrorKind::Unknown),
+ 1 => Ok(ProtocolErrorKind::InvalidData),
+ 2 => Ok(ProtocolErrorKind::NegativeSize),
+ 3 => Ok(ProtocolErrorKind::SizeLimit),
+ 4 => Ok(ProtocolErrorKind::BadVersion),
+ 5 => Ok(ProtocolErrorKind::NotImplemented),
+ 6 => Ok(ProtocolErrorKind::DepthLimit),
+ _ => {
+ Err(Error::Protocol(ProtocolError {
+ kind: ProtocolErrorKind::Unknown,
+ message: format!("cannot convert {} to ProtocolErrorKind", from),
+ }))
+ }
+ }
+ }
+}
+
+/// Create a new `Error` instance of type `Application` that wraps an
+/// `ApplicationError`.
+pub fn new_application_error<S: Into<String>>(kind: ApplicationErrorKind, message: S) -> Error {
+ Error::Application(ApplicationError::new(kind, message))
+}
+
+/// Information about errors in auto-generated code or in user-implemented
+/// service handlers.
+#[derive(Debug)]
+pub struct ApplicationError {
+ /// Application error variant.
+ ///
+ /// If a specific `ApplicationErrorKind` does not apply use
+ /// `ApplicationErrorKind::Unknown`.
+ pub kind: ApplicationErrorKind,
+ /// Human-readable error message.
+ pub message: String,
+}
+
+impl ApplicationError {
+ /// Create a new `ApplicationError`.
+ pub fn new<S: Into<String>>(kind: ApplicationErrorKind, message: S) -> ApplicationError {
+ ApplicationError {
+ kind: kind,
+ message: message.into(),
+ }
+ }
+}
+
+/// Auto-generated or user-implemented code error categories.
+///
+/// This list may grow, and it is not recommended to match against it.
+#[derive(Clone, Copy, Debug, Eq, PartialEq)]
+pub enum ApplicationErrorKind {
+ /// Catch-all application error.
+ Unknown = 0,
+ /// Made service call to an unknown service method.
+ UnknownMethod = 1,
+ /// Received an unknown Thrift message type. That is, not one of the
+ /// `thrift::protocol::TMessageType` variants.
+ InvalidMessageType = 2,
+ /// Method name in a service reply does not match the name of the
+ /// receiving service method.
+ WrongMethodName = 3,
+ /// Received an out-of-order Thrift message.
+ BadSequenceId = 4,
+ /// Service reply is missing required fields.
+ MissingResult = 5,
+ /// Auto-generated code failed unexpectedly.
+ InternalError = 6,
+ /// Thrift protocol error. When possible use `Error::ProtocolError` with a
+ /// specific `ProtocolErrorKind` instead.
+ ProtocolError = 7,
+ /// *Unknown*. Included only for compatibility with existing Thrift implementations.
+ InvalidTransform = 8, // ??
+ /// Thrift endpoint requested, or is using, an unsupported encoding.
+ InvalidProtocol = 9, // ??
+ /// Thrift endpoint requested, or is using, an unsupported auto-generated client type.
+ UnsupportedClientType = 10, // ??
+}
+
+impl ApplicationError {
+ fn description(&self) -> &str {
+ match self.kind {
+ ApplicationErrorKind::Unknown => "service error",
+ ApplicationErrorKind::UnknownMethod => "unknown service method",
+ ApplicationErrorKind::InvalidMessageType => "wrong message type received",
+ ApplicationErrorKind::WrongMethodName => "unknown method reply received",
+ ApplicationErrorKind::BadSequenceId => "out of order sequence id",
+ ApplicationErrorKind::MissingResult => "missing method result",
+ ApplicationErrorKind::InternalError => "remote service threw exception",
+ ApplicationErrorKind::ProtocolError => "protocol error",
+ ApplicationErrorKind::InvalidTransform => "invalid transform",
+ ApplicationErrorKind::InvalidProtocol => "invalid protocol requested",
+ ApplicationErrorKind::UnsupportedClientType => "unsupported protocol client",
+ }
+ }
+}
+
+impl Display for ApplicationError {
+ fn fmt(&self, f: &mut Formatter) -> fmt::Result {
+ write!(f, "{}", self.description())
+ }
+}
+
+impl TryFrom<i32> for ApplicationErrorKind {
+ type Err = Error;
+ fn try_from(from: i32) -> Result<Self, Self::Err> {
+ match from {
+ 0 => Ok(ApplicationErrorKind::Unknown),
+ 1 => Ok(ApplicationErrorKind::UnknownMethod),
+ 2 => Ok(ApplicationErrorKind::InvalidMessageType),
+ 3 => Ok(ApplicationErrorKind::WrongMethodName),
+ 4 => Ok(ApplicationErrorKind::BadSequenceId),
+ 5 => Ok(ApplicationErrorKind::MissingResult),
+ 6 => Ok(ApplicationErrorKind::InternalError),
+ 7 => Ok(ApplicationErrorKind::ProtocolError),
+ 8 => Ok(ApplicationErrorKind::InvalidTransform),
+ 9 => Ok(ApplicationErrorKind::InvalidProtocol),
+ 10 => Ok(ApplicationErrorKind::UnsupportedClientType),
+ _ => {
+ Err(Error::Application(ApplicationError {
+ kind: ApplicationErrorKind::Unknown,
+ message: format!("cannot convert {} to ApplicationErrorKind", from),
+ }))
+ }
+ }
+ }
+}
diff --git a/lib/rs/src/lib.rs b/lib/rs/src/lib.rs
new file mode 100644
index 0000000..ad18721
--- /dev/null
+++ b/lib/rs/src/lib.rs
@@ -0,0 +1,87 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Rust runtime library for the Apache Thrift RPC system.
+//!
+//! This crate implements the components required to build a working
+//! Thrift server and client. It is divided into the following modules:
+//!
+//! 1. errors
+//! 2. protocol
+//! 3. transport
+//! 4. server
+//! 5. autogen
+//!
+//! The modules are layered as shown in the diagram below. The `generated`
+//! layer is generated by the Thrift compiler's Rust plugin. It uses the
+//! types and functions defined in this crate to serialize and deserialize
+//! messages and implement RPC. Users interact with these types and services
+//! by writing their own code on top.
+//!
+//! ```text
+//! +-----------+
+//! | user app |
+//! +-----------+
+//! | autogen'd | (uses errors, autogen)
+//! +-----------+
+//! | protocol |
+//! +-----------+
+//! | transport |
+//! +-----------+
+//! ```
+
+#![crate_type = "lib"]
+#![doc(test(attr(allow(unused_variables), deny(warnings))))]
+
+extern crate byteorder;
+extern crate integer_encoding;
+extern crate try_from;
+
+#[macro_use]
+extern crate log;
+
+// NOTE: this macro has to be defined before any modules. See:
+// https://danielkeep.github.io/quick-intro-to-macros.html#some-more-gotchas
+
+/// Assert that an expression returning a `Result` is a success. If it is,
+/// return the value contained in the result, i.e. `expr.unwrap()`.
+#[cfg(test)]
+macro_rules! assert_success {
+ ($e: expr) => {
+ {
+ let res = $e;
+ assert!(res.is_ok());
+ res.unwrap()
+ }
+ }
+}
+
+pub mod protocol;
+pub mod server;
+pub mod transport;
+
+mod errors;
+pub use errors::*;
+
+mod autogen;
+pub use autogen::*;
+
+/// Result type returned by all runtime library functions.
+///
+/// As is convention this is a typedef of `std::result::Result`
+/// with `E` defined as the `thrift::Error` type.
+pub type Result<T> = std::result::Result<T, self::Error>;
diff --git a/lib/rs/src/protocol/binary.rs b/lib/rs/src/protocol/binary.rs
new file mode 100644
index 0000000..f3c9ea2
--- /dev/null
+++ b/lib/rs/src/protocol/binary.rs
@@ -0,0 +1,817 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use byteorder::{BigEndian, ByteOrder, ReadBytesExt, WriteBytesExt};
+use std::cell::RefCell;
+use std::convert::From;
+use std::io::{Read, Write};
+use std::rc::Rc;
+use try_from::TryFrom;
+
+use ::{ProtocolError, ProtocolErrorKind};
+use ::transport::TTransport;
+use super::{TFieldIdentifier, TInputProtocol, TInputProtocolFactory, TListIdentifier,
+ TMapIdentifier, TMessageIdentifier, TMessageType};
+use super::{TOutputProtocol, TOutputProtocolFactory, TSetIdentifier, TStructIdentifier, TType};
+
+const BINARY_PROTOCOL_VERSION_1: u32 = 0x80010000;
+
+/// Read messages encoded in the Thrift simple binary encoding.
+///
+/// There are two available modes: `strict` and `non-strict`, where the
+/// `non-strict` version does not check for the protocol version in the
+/// received message header.
+///
+/// # Examples
+///
+/// Create and use a `TBinaryInputProtocol`.
+///
+/// ```no_run
+/// use std::cell::RefCell;
+/// use std::rc::Rc;
+/// use thrift::protocol::{TBinaryInputProtocol, TInputProtocol};
+/// use thrift::transport::{TTcpTransport, TTransport};
+///
+/// let mut transport = TTcpTransport::new();
+/// transport.open("localhost:9090").unwrap();
+/// let transport = Rc::new(RefCell::new(Box::new(transport) as Box<TTransport>));
+///
+/// let mut i_prot = TBinaryInputProtocol::new(transport, true);
+///
+/// let recvd_bool = i_prot.read_bool().unwrap();
+/// let recvd_string = i_prot.read_string().unwrap();
+/// ```
+pub struct TBinaryInputProtocol {
+ strict: bool,
+ transport: Rc<RefCell<Box<TTransport>>>,
+}
+
+impl TBinaryInputProtocol {
+ /// Create a `TBinaryInputProtocol` that reads bytes from `transport`.
+ ///
+ /// Set `strict` to `true` if all incoming messages contain the protocol
+ /// version number in the protocol header.
+ pub fn new(transport: Rc<RefCell<Box<TTransport>>>, strict: bool) -> TBinaryInputProtocol {
+ TBinaryInputProtocol {
+ strict: strict,
+ transport: transport,
+ }
+ }
+}
+
+impl TInputProtocol for TBinaryInputProtocol {
+ #[cfg_attr(feature = "cargo-clippy", allow(collapsible_if))]
+ fn read_message_begin(&mut self) -> ::Result<TMessageIdentifier> {
+ let mut first_bytes = vec![0; 4];
+ self.transport.borrow_mut().read_exact(&mut first_bytes[..])?;
+
+ // the thrift version header is intentionally negative
+ // so the first check we'll do is see if the sign bit is set
+ // and if so - assume it's the protocol-version header
+ if first_bytes[0] >= 8 {
+ // apparently we got a protocol-version header - check
+ // it, and if it matches, read the rest of the fields
+ if first_bytes[0..2] != [0x80, 0x01] {
+ Err(::Error::Protocol(ProtocolError {
+ kind: ProtocolErrorKind::BadVersion,
+ message: format!("received bad version: {:?}", &first_bytes[0..2]),
+ }))
+ } else {
+ let message_type: TMessageType = TryFrom::try_from(first_bytes[3])?;
+ let name = self.read_string()?;
+ let sequence_number = self.read_i32()?;
+ Ok(TMessageIdentifier::new(name, message_type, sequence_number))
+ }
+ } else {
+ // apparently we didn't get a protocol-version header,
+ // which happens if the sender is not using the strict protocol
+ if self.strict {
+ // we're in strict mode however, and that always
+ // requires the protocol-version header to be written first
+ Err(::Error::Protocol(ProtocolError {
+ kind: ProtocolErrorKind::BadVersion,
+ message: format!("received bad version: {:?}", &first_bytes[0..2]),
+ }))
+ } else {
+ // in the non-strict version the first message field
+ // is the message name. strings (byte arrays) are length-prefixed,
+ // so we've just read the length in the first 4 bytes
+ let name_size = BigEndian::read_i32(&first_bytes) as usize;
+ let mut name_buf: Vec<u8> = Vec::with_capacity(name_size);
+ self.transport.borrow_mut().read_exact(&mut name_buf)?;
+ let name = String::from_utf8(name_buf)?;
+
+ // read the rest of the fields
+ let message_type: TMessageType = self.read_byte().and_then(TryFrom::try_from)?;
+ let sequence_number = self.read_i32()?;
+ Ok(TMessageIdentifier::new(name, message_type, sequence_number))
+ }
+ }
+ }
+
+ fn read_message_end(&mut self) -> ::Result<()> {
+ Ok(())
+ }
+
+ fn read_struct_begin(&mut self) -> ::Result<Option<TStructIdentifier>> {
+ Ok(None)
+ }
+
+ fn read_struct_end(&mut self) -> ::Result<()> {
+ Ok(())
+ }
+
+ fn read_field_begin(&mut self) -> ::Result<TFieldIdentifier> {
+ let field_type_byte = self.read_byte()?;
+ let field_type = field_type_from_u8(field_type_byte)?;
+ let id = match field_type {
+ TType::Stop => Ok(0),
+ _ => self.read_i16(),
+ }?;
+ Ok(TFieldIdentifier::new::<Option<String>, String, i16>(None, field_type, id))
+ }
+
+ fn read_field_end(&mut self) -> ::Result<()> {
+ Ok(())
+ }
+
+ fn read_bytes(&mut self) -> ::Result<Vec<u8>> {
+ let num_bytes = self.transport.borrow_mut().read_i32::<BigEndian>()? as usize;
+ let mut buf = vec![0u8; num_bytes];
+ self.transport.borrow_mut().read_exact(&mut buf).map(|_| buf).map_err(From::from)
+ }
+
+ fn read_bool(&mut self) -> ::Result<bool> {
+ let b = self.read_i8()?;
+ match b {
+ 0 => Ok(false),
+ _ => Ok(true),
+ }
+ }
+
+ fn read_i8(&mut self) -> ::Result<i8> {
+ self.transport.borrow_mut().read_i8().map_err(From::from)
+ }
+
+ fn read_i16(&mut self) -> ::Result<i16> {
+ self.transport.borrow_mut().read_i16::<BigEndian>().map_err(From::from)
+ }
+
+ fn read_i32(&mut self) -> ::Result<i32> {
+ self.transport.borrow_mut().read_i32::<BigEndian>().map_err(From::from)
+ }
+
+ fn read_i64(&mut self) -> ::Result<i64> {
+ self.transport.borrow_mut().read_i64::<BigEndian>().map_err(From::from)
+ }
+
+ fn read_double(&mut self) -> ::Result<f64> {
+ self.transport.borrow_mut().read_f64::<BigEndian>().map_err(From::from)
+ }
+
+ fn read_string(&mut self) -> ::Result<String> {
+ let bytes = self.read_bytes()?;
+ String::from_utf8(bytes).map_err(From::from)
+ }
+
+ fn read_list_begin(&mut self) -> ::Result<TListIdentifier> {
+ let element_type: TType = self.read_byte().and_then(field_type_from_u8)?;
+ let size = self.read_i32()?;
+ Ok(TListIdentifier::new(element_type, size))
+ }
+
+ fn read_list_end(&mut self) -> ::Result<()> {
+ Ok(())
+ }
+
+ fn read_set_begin(&mut self) -> ::Result<TSetIdentifier> {
+ let element_type: TType = self.read_byte().and_then(field_type_from_u8)?;
+ let size = self.read_i32()?;
+ Ok(TSetIdentifier::new(element_type, size))
+ }
+
+ fn read_set_end(&mut self) -> ::Result<()> {
+ Ok(())
+ }
+
+ fn read_map_begin(&mut self) -> ::Result<TMapIdentifier> {
+ let key_type: TType = self.read_byte().and_then(field_type_from_u8)?;
+ let value_type: TType = self.read_byte().and_then(field_type_from_u8)?;
+ let size = self.read_i32()?;
+ Ok(TMapIdentifier::new(key_type, value_type, size))
+ }
+
+ fn read_map_end(&mut self) -> ::Result<()> {
+ Ok(())
+ }
+
+ // utility
+ //
+
+ fn read_byte(&mut self) -> ::Result<u8> {
+ self.transport.borrow_mut().read_u8().map_err(From::from)
+ }
+}
+
+/// Factory for creating instances of `TBinaryInputProtocol`.
+#[derive(Default)]
+pub struct TBinaryInputProtocolFactory;
+
+impl TBinaryInputProtocolFactory {
+ /// Create a `TBinaryInputProtocolFactory`.
+ pub fn new() -> TBinaryInputProtocolFactory {
+ TBinaryInputProtocolFactory {}
+ }
+}
+
+impl TInputProtocolFactory for TBinaryInputProtocolFactory {
+ fn create(&mut self, transport: Rc<RefCell<Box<TTransport>>>) -> Box<TInputProtocol> {
+ Box::new(TBinaryInputProtocol::new(transport, true)) as Box<TInputProtocol>
+ }
+}
+
+/// Write messages using the Thrift simple binary encoding.
+///
+/// There are two available modes: `strict` and `non-strict`, where the
+/// `strict` version writes the protocol version number in the outgoing message
+/// header and the `non-strict` version does not.
+///
+/// # Examples
+///
+/// Create and use a `TBinaryOutputProtocol`.
+///
+/// ```no_run
+/// use std::cell::RefCell;
+/// use std::rc::Rc;
+/// use thrift::protocol::{TBinaryOutputProtocol, TOutputProtocol};
+/// use thrift::transport::{TTcpTransport, TTransport};
+///
+/// let mut transport = TTcpTransport::new();
+/// transport.open("localhost:9090").unwrap();
+/// let transport = Rc::new(RefCell::new(Box::new(transport) as Box<TTransport>));
+///
+/// let mut o_prot = TBinaryOutputProtocol::new(transport, true);
+///
+/// o_prot.write_bool(true).unwrap();
+/// o_prot.write_string("test_string").unwrap();
+/// ```
+pub struct TBinaryOutputProtocol {
+ strict: bool,
+ transport: Rc<RefCell<Box<TTransport>>>,
+}
+
+impl TBinaryOutputProtocol {
+ /// Create a `TBinaryOutputProtocol` that writes bytes to `transport`.
+ ///
+ /// Set `strict` to `true` if all outgoing messages should contain the
+ /// protocol version number in the protocol header.
+ pub fn new(transport: Rc<RefCell<Box<TTransport>>>, strict: bool) -> TBinaryOutputProtocol {
+ TBinaryOutputProtocol {
+ strict: strict,
+ transport: transport,
+ }
+ }
+
+ fn write_transport(&mut self, buf: &[u8]) -> ::Result<()> {
+ self.transport.borrow_mut().write(buf).map(|_| ()).map_err(From::from)
+ }
+}
+
+impl TOutputProtocol for TBinaryOutputProtocol {
+ fn write_message_begin(&mut self, identifier: &TMessageIdentifier) -> ::Result<()> {
+ if self.strict {
+ let message_type: u8 = identifier.message_type.into();
+ let header = BINARY_PROTOCOL_VERSION_1 | (message_type as u32);
+ self.transport.borrow_mut().write_u32::<BigEndian>(header)?;
+ self.write_string(&identifier.name)?;
+ self.write_i32(identifier.sequence_number)
+ } else {
+ self.write_string(&identifier.name)?;
+ self.write_byte(identifier.message_type.into())?;
+ self.write_i32(identifier.sequence_number)
+ }
+ }
+
+ fn write_message_end(&mut self) -> ::Result<()> {
+ Ok(())
+ }
+
+ fn write_struct_begin(&mut self, _: &TStructIdentifier) -> ::Result<()> {
+ Ok(())
+ }
+
+ fn write_struct_end(&mut self) -> ::Result<()> {
+ Ok(())
+ }
+
+ fn write_field_begin(&mut self, identifier: &TFieldIdentifier) -> ::Result<()> {
+ if identifier.id.is_none() && identifier.field_type != TType::Stop {
+ return Err(::Error::Protocol(ProtocolError {
+ kind: ProtocolErrorKind::Unknown,
+ message: format!("cannot write identifier {:?} without sequence number",
+ &identifier),
+ }));
+ }
+
+ self.write_byte(field_type_to_u8(identifier.field_type))?;
+ if let Some(id) = identifier.id {
+ self.write_i16(id)
+ } else {
+ Ok(())
+ }
+ }
+
+ fn write_field_end(&mut self) -> ::Result<()> {
+ Ok(())
+ }
+
+ fn write_field_stop(&mut self) -> ::Result<()> {
+ self.write_byte(field_type_to_u8(TType::Stop))
+ }
+
+ fn write_bytes(&mut self, b: &[u8]) -> ::Result<()> {
+ self.write_i32(b.len() as i32)?;
+ self.write_transport(b)
+ }
+
+ fn write_bool(&mut self, b: bool) -> ::Result<()> {
+ if b {
+ self.write_i8(1)
+ } else {
+ self.write_i8(0)
+ }
+ }
+
+ fn write_i8(&mut self, i: i8) -> ::Result<()> {
+ self.transport.borrow_mut().write_i8(i).map_err(From::from)
+ }
+
+ fn write_i16(&mut self, i: i16) -> ::Result<()> {
+ self.transport.borrow_mut().write_i16::<BigEndian>(i).map_err(From::from)
+ }
+
+ fn write_i32(&mut self, i: i32) -> ::Result<()> {
+ self.transport.borrow_mut().write_i32::<BigEndian>(i).map_err(From::from)
+ }
+
+ fn write_i64(&mut self, i: i64) -> ::Result<()> {
+ self.transport.borrow_mut().write_i64::<BigEndian>(i).map_err(From::from)
+ }
+
+ fn write_double(&mut self, d: f64) -> ::Result<()> {
+ self.transport.borrow_mut().write_f64::<BigEndian>(d).map_err(From::from)
+ }
+
+ fn write_string(&mut self, s: &str) -> ::Result<()> {
+ self.write_bytes(s.as_bytes())
+ }
+
+ fn write_list_begin(&mut self, identifier: &TListIdentifier) -> ::Result<()> {
+ self.write_byte(field_type_to_u8(identifier.element_type))?;
+ self.write_i32(identifier.size)
+ }
+
+ fn write_list_end(&mut self) -> ::Result<()> {
+ Ok(())
+ }
+
+ fn write_set_begin(&mut self, identifier: &TSetIdentifier) -> ::Result<()> {
+ self.write_byte(field_type_to_u8(identifier.element_type))?;
+ self.write_i32(identifier.size)
+ }
+
+ fn write_set_end(&mut self) -> ::Result<()> {
+ Ok(())
+ }
+
+ fn write_map_begin(&mut self, identifier: &TMapIdentifier) -> ::Result<()> {
+ let key_type = identifier.key_type
+ .expect("map identifier to write should contain key type");
+ self.write_byte(field_type_to_u8(key_type))?;
+ let val_type = identifier.value_type
+ .expect("map identifier to write should contain value type");
+ self.write_byte(field_type_to_u8(val_type))?;
+ self.write_i32(identifier.size)
+ }
+
+ fn write_map_end(&mut self) -> ::Result<()> {
+ Ok(())
+ }
+
+ fn flush(&mut self) -> ::Result<()> {
+ self.transport.borrow_mut().flush().map_err(From::from)
+ }
+
+ // utility
+ //
+
+ fn write_byte(&mut self, b: u8) -> ::Result<()> {
+ self.transport.borrow_mut().write_u8(b).map_err(From::from)
+ }
+}
+
+/// Factory for creating instances of `TBinaryOutputProtocol`.
+#[derive(Default)]
+pub struct TBinaryOutputProtocolFactory;
+
+impl TBinaryOutputProtocolFactory {
+ /// Create a `TBinaryOutputProtocolFactory`.
+ pub fn new() -> TBinaryOutputProtocolFactory {
+ TBinaryOutputProtocolFactory {}
+ }
+}
+
+impl TOutputProtocolFactory for TBinaryOutputProtocolFactory {
+ fn create(&mut self, transport: Rc<RefCell<Box<TTransport>>>) -> Box<TOutputProtocol> {
+ Box::new(TBinaryOutputProtocol::new(transport, true)) as Box<TOutputProtocol>
+ }
+}
+
+fn field_type_to_u8(field_type: TType) -> u8 {
+ match field_type {
+ TType::Stop => 0x00,
+ TType::Void => 0x01,
+ TType::Bool => 0x02,
+ TType::I08 => 0x03, // equivalent to TType::Byte
+ TType::Double => 0x04,
+ TType::I16 => 0x06,
+ TType::I32 => 0x08,
+ TType::I64 => 0x0A,
+ TType::String | TType::Utf7 => 0x0B,
+ TType::Struct => 0x0C,
+ TType::Map => 0x0D,
+ TType::Set => 0x0E,
+ TType::List => 0x0F,
+ TType::Utf8 => 0x10,
+ TType::Utf16 => 0x11,
+ }
+}
+
+fn field_type_from_u8(b: u8) -> ::Result<TType> {
+ match b {
+ 0x00 => Ok(TType::Stop),
+ 0x01 => Ok(TType::Void),
+ 0x02 => Ok(TType::Bool),
+ 0x03 => Ok(TType::I08), // Equivalent to TType::Byte
+ 0x04 => Ok(TType::Double),
+ 0x06 => Ok(TType::I16),
+ 0x08 => Ok(TType::I32),
+ 0x0A => Ok(TType::I64),
+ 0x0B => Ok(TType::String), // technically, also a UTF7, but we'll treat it as string
+ 0x0C => Ok(TType::Struct),
+ 0x0D => Ok(TType::Map),
+ 0x0E => Ok(TType::Set),
+ 0x0F => Ok(TType::List),
+ 0x10 => Ok(TType::Utf8),
+ 0x11 => Ok(TType::Utf16),
+ unkn => {
+ Err(::Error::Protocol(ProtocolError {
+ kind: ProtocolErrorKind::InvalidData,
+ message: format!("cannot convert {} to TType", unkn),
+ }))
+ }
+ }
+}
+
+#[cfg(test)]
+mod tests {
+
+ use std::rc::Rc;
+ use std::cell::RefCell;
+
+ use ::protocol::{TFieldIdentifier, TMessageIdentifier, TMessageType, TInputProtocol,
+ TListIdentifier, TMapIdentifier, TOutputProtocol, TSetIdentifier,
+ TStructIdentifier, TType};
+ use ::transport::{TPassThruTransport, TTransport};
+ use ::transport::mem::TBufferTransport;
+
+ use super::*;
+
+ #[test]
+ fn must_write_message_call_begin() {
+ let (trans, _, mut o_prot) = test_objects();
+
+ let ident = TMessageIdentifier::new("test", TMessageType::Call, 1);
+ assert!(o_prot.write_message_begin(&ident).is_ok());
+
+ let buf = trans.borrow().write_buffer_to_vec();
+
+ let expected: [u8; 16] = [0x80, 0x01, 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x74, 0x65,
+ 0x73, 0x74, 0x00, 0x00, 0x00, 0x01];
+
+ assert_eq!(&expected, buf.as_slice());
+ }
+
+
+ #[test]
+ fn must_write_message_reply_begin() {
+ let (trans, _, mut o_prot) = test_objects();
+
+ let ident = TMessageIdentifier::new("test", TMessageType::Reply, 10);
+ assert!(o_prot.write_message_begin(&ident).is_ok());
+
+ let buf = trans.borrow().write_buffer_to_vec();
+
+ let expected: [u8; 16] = [0x80, 0x01, 0x00, 0x02, 0x00, 0x00, 0x00, 0x04, 0x74, 0x65,
+ 0x73, 0x74, 0x00, 0x00, 0x00, 0x0A];
+
+ assert_eq!(&expected, buf.as_slice());
+ }
+
+ #[test]
+ fn must_round_trip_strict_message_begin() {
+ let (trans, mut i_prot, mut o_prot) = test_objects();
+
+ let sent_ident = TMessageIdentifier::new("test", TMessageType::Call, 1);
+ assert!(o_prot.write_message_begin(&sent_ident).is_ok());
+
+ trans.borrow_mut().copy_write_buffer_to_read_buffer();
+
+ let received_ident = assert_success!(i_prot.read_message_begin());
+ assert_eq!(&received_ident, &sent_ident);
+ }
+
+ #[test]
+ fn must_write_message_end() {
+ assert_no_write(|o| o.write_message_end());
+ }
+
+ #[test]
+ fn must_write_struct_begin() {
+ assert_no_write(|o| o.write_struct_begin(&TStructIdentifier::new("foo")));
+ }
+
+ #[test]
+ fn must_write_struct_end() {
+ assert_no_write(|o| o.write_struct_end());
+ }
+
+ #[test]
+ fn must_write_field_begin() {
+ let (trans, _, mut o_prot) = test_objects();
+
+ assert!(o_prot.write_field_begin(&TFieldIdentifier::new("some_field", TType::String, 22))
+ .is_ok());
+
+ let expected: [u8; 3] = [0x0B, 0x00, 0x16];
+ let buf = trans.borrow().write_buffer_to_vec();
+ assert_eq!(&expected, buf.as_slice());
+ }
+
+ #[test]
+ fn must_round_trip_field_begin() {
+ let (trans, mut i_prot, mut o_prot) = test_objects();
+
+ let sent_field_ident = TFieldIdentifier::new("foo", TType::I64, 20);
+ assert!(o_prot.write_field_begin(&sent_field_ident).is_ok());
+
+ trans.borrow_mut().copy_write_buffer_to_read_buffer();
+
+ let expected_ident = TFieldIdentifier {
+ name: None,
+ field_type: TType::I64,
+ id: Some(20),
+ }; // no name
+ let received_ident = assert_success!(i_prot.read_field_begin());
+ assert_eq!(&received_ident, &expected_ident);
+ }
+
+ #[test]
+ fn must_write_stop_field() {
+ let (trans, _, mut o_prot) = test_objects();
+
+ assert!(o_prot.write_field_stop().is_ok());
+
+ let expected: [u8; 1] = [0x00];
+ let buf = trans.borrow().write_buffer_to_vec();
+ assert_eq!(&expected, buf.as_slice());
+ }
+
+ #[test]
+ fn must_round_trip_field_stop() {
+ let (trans, mut i_prot, mut o_prot) = test_objects();
+
+ assert!(o_prot.write_field_stop().is_ok());
+
+ trans.borrow_mut().copy_write_buffer_to_read_buffer();
+
+ let expected_ident = TFieldIdentifier {
+ name: None,
+ field_type: TType::Stop,
+ id: Some(0),
+ }; // we get id 0
+
+ let received_ident = assert_success!(i_prot.read_field_begin());
+ assert_eq!(&received_ident, &expected_ident);
+ }
+
+ #[test]
+ fn must_write_field_end() {
+ assert_no_write(|o| o.write_field_end());
+ }
+
+ #[test]
+ fn must_write_list_begin() {
+ let (trans, _, mut o_prot) = test_objects();
+
+ assert!(o_prot.write_list_begin(&TListIdentifier::new(TType::Bool, 5)).is_ok());
+
+ let expected: [u8; 5] = [0x02, 0x00, 0x00, 0x00, 0x05];
+ let buf = trans.borrow().write_buffer_to_vec();
+ assert_eq!(&expected, buf.as_slice());
+ }
+
+ #[test]
+ fn must_round_trip_list_begin() {
+ let (trans, mut i_prot, mut o_prot) = test_objects();
+
+ let ident = TListIdentifier::new(TType::List, 900);
+ assert!(o_prot.write_list_begin(&ident).is_ok());
+
+ trans.borrow_mut().copy_write_buffer_to_read_buffer();
+
+ let received_ident = assert_success!(i_prot.read_list_begin());
+ assert_eq!(&received_ident, &ident);
+ }
+
+ #[test]
+ fn must_write_list_end() {
+ assert_no_write(|o| o.write_list_end());
+ }
+
+ #[test]
+ fn must_write_set_begin() {
+ let (trans, _, mut o_prot) = test_objects();
+
+ assert!(o_prot.write_set_begin(&TSetIdentifier::new(TType::I16, 7)).is_ok());
+
+ let expected: [u8; 5] = [0x06, 0x00, 0x00, 0x00, 0x07];
+ let buf = trans.borrow().write_buffer_to_vec();
+ assert_eq!(&expected, buf.as_slice());
+ }
+
+ #[test]
+ fn must_round_trip_set_begin() {
+ let (trans, mut i_prot, mut o_prot) = test_objects();
+
+ let ident = TSetIdentifier::new(TType::I64, 2000);
+ assert!(o_prot.write_set_begin(&ident).is_ok());
+
+ trans.borrow_mut().copy_write_buffer_to_read_buffer();
+
+ let received_ident_result = i_prot.read_set_begin();
+ assert!(received_ident_result.is_ok());
+ assert_eq!(&received_ident_result.unwrap(), &ident);
+ }
+
+ #[test]
+ fn must_write_set_end() {
+ assert_no_write(|o| o.write_set_end());
+ }
+
+ #[test]
+ fn must_write_map_begin() {
+ let (trans, _, mut o_prot) = test_objects();
+
+ assert!(o_prot.write_map_begin(&TMapIdentifier::new(TType::I64, TType::Struct, 32))
+ .is_ok());
+
+ let expected: [u8; 6] = [0x0A, 0x0C, 0x00, 0x00, 0x00, 0x20];
+ let buf = trans.borrow().write_buffer_to_vec();
+ assert_eq!(&expected, buf.as_slice());
+ }
+
+ #[test]
+ fn must_round_trip_map_begin() {
+ let (trans, mut i_prot, mut o_prot) = test_objects();
+
+ let ident = TMapIdentifier::new(TType::Map, TType::Set, 100);
+ assert!(o_prot.write_map_begin(&ident).is_ok());
+
+ trans.borrow_mut().copy_write_buffer_to_read_buffer();
+
+ let received_ident = assert_success!(i_prot.read_map_begin());
+ assert_eq!(&received_ident, &ident);
+ }
+
+ #[test]
+ fn must_write_map_end() {
+ assert_no_write(|o| o.write_map_end());
+ }
+
+ #[test]
+ fn must_write_bool_true() {
+ let (trans, _, mut o_prot) = test_objects();
+
+ assert!(o_prot.write_bool(true).is_ok());
+
+ let expected: [u8; 1] = [0x01];
+ let buf = trans.borrow().write_buffer_to_vec();
+ assert_eq!(&expected, buf.as_slice());
+ }
+
+ #[test]
+ fn must_write_bool_false() {
+ let (trans, _, mut o_prot) = test_objects();
+
+ assert!(o_prot.write_bool(false).is_ok());
+
+ let expected: [u8; 1] = [0x00];
+ let buf = trans.borrow().write_buffer_to_vec();
+ assert_eq!(&expected, buf.as_slice());
+ }
+
+ #[test]
+ fn must_read_bool_true() {
+ let (trans, mut i_prot, _) = test_objects();
+
+ trans.borrow_mut().set_readable_bytes(&[0x01]);
+
+ let read_bool = assert_success!(i_prot.read_bool());
+ assert_eq!(read_bool, true);
+ }
+
+ #[test]
+ fn must_read_bool_false() {
+ let (trans, mut i_prot, _) = test_objects();
+
+ trans.borrow_mut().set_readable_bytes(&[0x00]);
+
+ let read_bool = assert_success!(i_prot.read_bool());
+ assert_eq!(read_bool, false);
+ }
+
+ #[test]
+ fn must_allow_any_non_zero_value_to_be_interpreted_as_bool_true() {
+ let (trans, mut i_prot, _) = test_objects();
+
+ trans.borrow_mut().set_readable_bytes(&[0xAC]);
+
+ let read_bool = assert_success!(i_prot.read_bool());
+ assert_eq!(read_bool, true);
+ }
+
+ #[test]
+ fn must_write_bytes() {
+ let (trans, _, mut o_prot) = test_objects();
+
+ let bytes: [u8; 10] = [0x0A, 0xCC, 0xD1, 0x84, 0x99, 0x12, 0xAB, 0xBB, 0x45, 0xDF];
+
+ assert!(o_prot.write_bytes(&bytes).is_ok());
+
+ let buf = trans.borrow().write_buffer_to_vec();
+ assert_eq!(&buf[0..4], [0x00, 0x00, 0x00, 0x0A]); // length
+ assert_eq!(&buf[4..], bytes); // actual bytes
+ }
+
+ #[test]
+ fn must_round_trip_bytes() {
+ let (trans, mut i_prot, mut o_prot) = test_objects();
+
+ let bytes: [u8; 25] = [0x20, 0xFD, 0x18, 0x84, 0x99, 0x12, 0xAB, 0xBB, 0x45, 0xDF, 0x34,
+ 0xDC, 0x98, 0xA4, 0x6D, 0xF3, 0x99, 0xB4, 0xB7, 0xD4, 0x9C, 0xA5,
+ 0xB3, 0xC9, 0x88];
+
+ assert!(o_prot.write_bytes(&bytes).is_ok());
+
+ trans.borrow_mut().copy_write_buffer_to_read_buffer();
+
+ let received_bytes = assert_success!(i_prot.read_bytes());
+ assert_eq!(&received_bytes, &bytes);
+ }
+
+ fn test_objects
+ ()
+ -> (Rc<RefCell<Box<TBufferTransport>>>, TBinaryInputProtocol, TBinaryOutputProtocol)
+ {
+ let mem = Rc::new(RefCell::new(Box::new(TBufferTransport::with_capacity(40, 40))));
+
+ let inner: Box<TTransport> = Box::new(TPassThruTransport { inner: mem.clone() });
+ let inner = Rc::new(RefCell::new(inner));
+
+ let i_prot = TBinaryInputProtocol::new(inner.clone(), true);
+ let o_prot = TBinaryOutputProtocol::new(inner.clone(), true);
+
+ (mem, i_prot, o_prot)
+ }
+
+ fn assert_no_write<F: FnMut(&mut TBinaryOutputProtocol) -> ::Result<()>>(mut write_fn: F) {
+ let (trans, _, mut o_prot) = test_objects();
+ assert!(write_fn(&mut o_prot).is_ok());
+ assert_eq!(trans.borrow().write_buffer_as_ref().len(), 0);
+ }
+}
diff --git a/lib/rs/src/protocol/compact.rs b/lib/rs/src/protocol/compact.rs
new file mode 100644
index 0000000..96fa8ef
--- /dev/null
+++ b/lib/rs/src/protocol/compact.rs
@@ -0,0 +1,2085 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
+use integer_encoding::{VarIntReader, VarIntWriter};
+use std::cell::RefCell;
+use std::convert::From;
+use std::rc::Rc;
+use std::io::{Read, Write};
+use try_from::TryFrom;
+
+use ::transport::TTransport;
+use super::{TFieldIdentifier, TListIdentifier, TMapIdentifier, TMessageIdentifier, TMessageType,
+ TInputProtocol, TInputProtocolFactory};
+use super::{TOutputProtocol, TOutputProtocolFactory, TSetIdentifier, TStructIdentifier, TType};
+
+const COMPACT_PROTOCOL_ID: u8 = 0x82;
+const COMPACT_VERSION: u8 = 0x01;
+const COMPACT_VERSION_MASK: u8 = 0x1F;
+
+/// Read messages encoded in the Thrift compact protocol.
+///
+/// # Examples
+///
+/// Create and use a `TCompactInputProtocol`.
+///
+/// ```no_run
+/// use std::cell::RefCell;
+/// use std::rc::Rc;
+/// use thrift::protocol::{TCompactInputProtocol, TInputProtocol};
+/// use thrift::transport::{TTcpTransport, TTransport};
+///
+/// let mut transport = TTcpTransport::new();
+/// transport.open("localhost:9090").unwrap();
+/// let transport = Rc::new(RefCell::new(Box::new(transport) as Box<TTransport>));
+///
+/// let mut i_prot = TCompactInputProtocol::new(transport);
+///
+/// let recvd_bool = i_prot.read_bool().unwrap();
+/// let recvd_string = i_prot.read_string().unwrap();
+/// ```
+pub struct TCompactInputProtocol {
+ // Identifier of the last field deserialized for a struct.
+ last_read_field_id: i16,
+ // Stack of the last read field ids (a new entry is added each time a nested struct is read).
+ read_field_id_stack: Vec<i16>,
+ // Boolean value for a field.
+ // Saved because boolean fields and their value are encoded in a single byte,
+ // and reading the field only occurs after the field id is read.
+ pending_read_bool_value: Option<bool>,
+ // Underlying transport used for byte-level operations.
+ transport: Rc<RefCell<Box<TTransport>>>,
+}
+
+impl TCompactInputProtocol {
+ /// Create a `TCompactInputProtocol` that reads bytes from `transport`.
+ pub fn new(transport: Rc<RefCell<Box<TTransport>>>) -> TCompactInputProtocol {
+ TCompactInputProtocol {
+ last_read_field_id: 0,
+ read_field_id_stack: Vec::new(),
+ pending_read_bool_value: None,
+ transport: transport,
+ }
+ }
+
+ fn read_list_set_begin(&mut self) -> ::Result<(TType, i32)> {
+ let header = self.read_byte()?;
+ let element_type = collection_u8_to_type(header & 0x0F)?;
+
+ let element_count;
+ let possible_element_count = (header & 0xF0) >> 4;
+ if possible_element_count != 15 {
+ // high bits set high if count and type encoded separately
+ element_count = possible_element_count as i32;
+ } else {
+ element_count = self.transport.borrow_mut().read_varint::<u32>()? as i32;
+ }
+
+ Ok((element_type, element_count))
+ }
+}
+
+impl TInputProtocol for TCompactInputProtocol {
+ fn read_message_begin(&mut self) -> ::Result<TMessageIdentifier> {
+ let compact_id = self.read_byte()?;
+ if compact_id != COMPACT_PROTOCOL_ID {
+ Err(::Error::Protocol(::ProtocolError {
+ kind: ::ProtocolErrorKind::BadVersion,
+ message: format!("invalid compact protocol header {:?}", compact_id),
+ }))
+ } else {
+ Ok(())
+ }?;
+
+ let type_and_byte = self.read_byte()?;
+ let received_version = type_and_byte & COMPACT_VERSION_MASK;
+ if received_version != COMPACT_VERSION {
+ Err(::Error::Protocol(::ProtocolError {
+ kind: ::ProtocolErrorKind::BadVersion,
+ message: format!("cannot process compact protocol version {:?}",
+ received_version),
+ }))
+ } else {
+ Ok(())
+ }?;
+
+ // NOTE: unsigned right shift will pad with 0s
+ let message_type: TMessageType = TMessageType::try_from(type_and_byte >> 5)?;
+ let sequence_number = self.read_i32()?;
+ let service_call_name = self.read_string()?;
+
+ self.last_read_field_id = 0;
+
+ Ok(TMessageIdentifier::new(service_call_name, message_type, sequence_number))
+ }
+
+ fn read_message_end(&mut self) -> ::Result<()> {
+ Ok(())
+ }
+
+ fn read_struct_begin(&mut self) -> ::Result<Option<TStructIdentifier>> {
+ self.read_field_id_stack.push(self.last_read_field_id);
+ self.last_read_field_id = 0;
+ Ok(None)
+ }
+
+ fn read_struct_end(&mut self) -> ::Result<()> {
+ self.last_read_field_id = self.read_field_id_stack
+ .pop()
+ .expect("should have previous field ids");
+ Ok(())
+ }
+
+ fn read_field_begin(&mut self) -> ::Result<TFieldIdentifier> {
+ // we can read at least one byte, which is:
+ // - the type
+ // - the field delta and the type
+ let field_type = self.read_byte()?;
+ let field_delta = (field_type & 0xF0) >> 4;
+ let field_type = match field_type & 0x0F {
+ 0x01 => {
+ self.pending_read_bool_value = Some(true);
+ Ok(TType::Bool)
+ }
+ 0x02 => {
+ self.pending_read_bool_value = Some(false);
+ Ok(TType::Bool)
+ }
+ ttu8 => u8_to_type(ttu8),
+ }?;
+
+ match field_type {
+ TType::Stop => {
+ Ok(TFieldIdentifier::new::<Option<String>, String, Option<i16>>(None,
+ TType::Stop,
+ None))
+ }
+ _ => {
+ if field_delta != 0 {
+ self.last_read_field_id += field_delta as i16;
+ } else {
+ self.last_read_field_id = self.read_i16()?;
+ };
+
+ Ok(TFieldIdentifier {
+ name: None,
+ field_type: field_type,
+ id: Some(self.last_read_field_id),
+ })
+ }
+ }
+ }
+
+ fn read_field_end(&mut self) -> ::Result<()> {
+ Ok(())
+ }
+
+ fn read_bool(&mut self) -> ::Result<bool> {
+ match self.pending_read_bool_value.take() {
+ Some(b) => Ok(b),
+ None => {
+ let b = self.read_byte()?;
+ match b {
+ 0x01 => Ok(true),
+ 0x02 => Ok(false),
+ unkn => {
+ Err(::Error::Protocol(::ProtocolError {
+ kind: ::ProtocolErrorKind::InvalidData,
+ message: format!("cannot convert {} into bool", unkn),
+ }))
+ }
+ }
+ }
+ }
+ }
+
+ fn read_bytes(&mut self) -> ::Result<Vec<u8>> {
+ let len = self.transport.borrow_mut().read_varint::<u32>()?;
+ let mut buf = vec![0u8; len as usize];
+ self.transport.borrow_mut().read_exact(&mut buf).map_err(From::from).map(|_| buf)
+ }
+
+ fn read_i8(&mut self) -> ::Result<i8> {
+ self.read_byte().map(|i| i as i8)
+ }
+
+ fn read_i16(&mut self) -> ::Result<i16> {
+ self.transport.borrow_mut().read_varint::<i16>().map_err(From::from)
+ }
+
+ fn read_i32(&mut self) -> ::Result<i32> {
+ self.transport.borrow_mut().read_varint::<i32>().map_err(From::from)
+ }
+
+ fn read_i64(&mut self) -> ::Result<i64> {
+ self.transport.borrow_mut().read_varint::<i64>().map_err(From::from)
+ }
+
+ fn read_double(&mut self) -> ::Result<f64> {
+ self.transport.borrow_mut().read_f64::<BigEndian>().map_err(From::from)
+ }
+
+ fn read_string(&mut self) -> ::Result<String> {
+ let bytes = self.read_bytes()?;
+ String::from_utf8(bytes).map_err(From::from)
+ }
+
+ fn read_list_begin(&mut self) -> ::Result<TListIdentifier> {
+ let (element_type, element_count) = self.read_list_set_begin()?;
+ Ok(TListIdentifier::new(element_type, element_count))
+ }
+
+ fn read_list_end(&mut self) -> ::Result<()> {
+ Ok(())
+ }
+
+ fn read_set_begin(&mut self) -> ::Result<TSetIdentifier> {
+ let (element_type, element_count) = self.read_list_set_begin()?;
+ Ok(TSetIdentifier::new(element_type, element_count))
+ }
+
+ fn read_set_end(&mut self) -> ::Result<()> {
+ Ok(())
+ }
+
+ fn read_map_begin(&mut self) -> ::Result<TMapIdentifier> {
+ let element_count = self.transport.borrow_mut().read_varint::<u32>()? as i32;
+ if element_count == 0 {
+ Ok(TMapIdentifier::new(None, None, 0))
+ } else {
+ let type_header = self.read_byte()?;
+ let key_type = collection_u8_to_type((type_header & 0xF0) >> 4)?;
+ let val_type = collection_u8_to_type(type_header & 0x0F)?;
+ Ok(TMapIdentifier::new(key_type, val_type, element_count))
+ }
+ }
+
+ fn read_map_end(&mut self) -> ::Result<()> {
+ Ok(())
+ }
+
+ // utility
+ //
+
+ fn read_byte(&mut self) -> ::Result<u8> {
+ let mut buf = [0u8; 1];
+ self.transport.borrow_mut().read_exact(&mut buf).map_err(From::from).map(|_| buf[0])
+ }
+}
+
+/// Factory for creating instances of `TCompactInputProtocol`.
+#[derive(Default)]
+pub struct TCompactInputProtocolFactory;
+
+impl TCompactInputProtocolFactory {
+ /// Create a `TCompactInputProtocolFactory`.
+ pub fn new() -> TCompactInputProtocolFactory {
+ TCompactInputProtocolFactory {}
+ }
+}
+
+impl TInputProtocolFactory for TCompactInputProtocolFactory {
+ fn create(&mut self, transport: Rc<RefCell<Box<TTransport>>>) -> Box<TInputProtocol> {
+ Box::new(TCompactInputProtocol::new(transport)) as Box<TInputProtocol>
+ }
+}
+
+/// Write messages using the Thrift compact protocol.
+///
+/// # Examples
+///
+/// Create and use a `TCompactOutputProtocol`.
+///
+/// ```no_run
+/// use std::cell::RefCell;
+/// use std::rc::Rc;
+/// use thrift::protocol::{TCompactOutputProtocol, TOutputProtocol};
+/// use thrift::transport::{TTcpTransport, TTransport};
+///
+/// let mut transport = TTcpTransport::new();
+/// transport.open("localhost:9090").unwrap();
+/// let transport = Rc::new(RefCell::new(Box::new(transport) as Box<TTransport>));
+///
+/// let mut o_prot = TCompactOutputProtocol::new(transport);
+///
+/// o_prot.write_bool(true).unwrap();
+/// o_prot.write_string("test_string").unwrap();
+/// ```
+pub struct TCompactOutputProtocol {
+ // Identifier of the last field serialized for a struct.
+ last_write_field_id: i16,
+ // Stack of the last written field ids (a new entry is added each time a nested struct is written).
+ write_field_id_stack: Vec<i16>,
+ // Field identifier of the boolean field to be written.
+ // Saved because boolean fields and their value are encoded in a single byte
+ pending_write_bool_field_identifier: Option<TFieldIdentifier>,
+ // Underlying transport used for byte-level operations.
+ transport: Rc<RefCell<Box<TTransport>>>,
+}
+
+impl TCompactOutputProtocol {
+ /// Create a `TCompactOutputProtocol` that writes bytes to `transport`.
+ pub fn new(transport: Rc<RefCell<Box<TTransport>>>) -> TCompactOutputProtocol {
+ TCompactOutputProtocol {
+ last_write_field_id: 0,
+ write_field_id_stack: Vec::new(),
+ pending_write_bool_field_identifier: None,
+ transport: transport,
+ }
+ }
+
+ // FIXME: field_type as unconstrained u8 is bad
+ fn write_field_header(&mut self, field_type: u8, field_id: i16) -> ::Result<()> {
+ let field_delta = field_id - self.last_write_field_id;
+ if field_delta > 0 && field_delta < 15 {
+ self.write_byte(((field_delta as u8) << 4) | field_type)?;
+ } else {
+ self.write_byte(field_type)?;
+ self.write_i16(field_id)?;
+ }
+ self.last_write_field_id = field_id;
+ Ok(())
+ }
+
+ fn write_list_set_begin(&mut self, element_type: TType, element_count: i32) -> ::Result<()> {
+ let elem_identifier = collection_type_to_u8(element_type);
+ if element_count <= 14 {
+ let header = (element_count as u8) << 4 | elem_identifier;
+ self.write_byte(header)
+ } else {
+ let header = 0xF0 | elem_identifier;
+ self.write_byte(header)?;
+ self.transport
+ .borrow_mut()
+ .write_varint(element_count as u32)
+ .map_err(From::from)
+ .map(|_| ())
+ }
+ }
+
+ fn assert_no_pending_bool_write(&self) {
+ if let Some(ref f) = self.pending_write_bool_field_identifier {
+ panic!("pending bool field {:?} not written", f)
+ }
+ }
+}
+
+impl TOutputProtocol for TCompactOutputProtocol {
+ fn write_message_begin(&mut self, identifier: &TMessageIdentifier) -> ::Result<()> {
+ self.write_byte(COMPACT_PROTOCOL_ID)?;
+ self.write_byte((u8::from(identifier.message_type) << 5) | COMPACT_VERSION)?;
+ self.write_i32(identifier.sequence_number)?;
+ self.write_string(&identifier.name)?;
+ Ok(())
+ }
+
+ fn write_message_end(&mut self) -> ::Result<()> {
+ self.assert_no_pending_bool_write();
+ Ok(())
+ }
+
+ fn write_struct_begin(&mut self, _: &TStructIdentifier) -> ::Result<()> {
+ self.write_field_id_stack.push(self.last_write_field_id);
+ self.last_write_field_id = 0;
+ Ok(())
+ }
+
+ fn write_struct_end(&mut self) -> ::Result<()> {
+ self.assert_no_pending_bool_write();
+ self.last_write_field_id =
+ self.write_field_id_stack.pop().expect("should have previous field ids");
+ Ok(())
+ }
+
+ fn write_field_begin(&mut self, identifier: &TFieldIdentifier) -> ::Result<()> {
+ match identifier.field_type {
+ TType::Bool => {
+ if self.pending_write_bool_field_identifier.is_some() {
+ panic!("should not have a pending bool while writing another bool with id: \
+ {:?}",
+ identifier)
+ }
+ self.pending_write_bool_field_identifier = Some(identifier.clone());
+ Ok(())
+ }
+ _ => {
+ let field_type = type_to_u8(identifier.field_type);
+ let field_id = identifier.id.expect("non-stop field should have field id");
+ self.write_field_header(field_type, field_id)
+ }
+ }
+ }
+
+ fn write_field_end(&mut self) -> ::Result<()> {
+ self.assert_no_pending_bool_write();
+ Ok(())
+ }
+
+ fn write_field_stop(&mut self) -> ::Result<()> {
+ self.assert_no_pending_bool_write();
+ self.write_byte(type_to_u8(TType::Stop))
+ }
+
+ fn write_bool(&mut self, b: bool) -> ::Result<()> {
+ match self.pending_write_bool_field_identifier.take() {
+ Some(pending) => {
+ let field_id = pending.id.expect("bool field should have a field id");
+ let field_type_as_u8 = if b { 0x01 } else { 0x02 };
+ self.write_field_header(field_type_as_u8, field_id)
+ }
+ None => {
+ if b {
+ self.write_byte(0x01)
+ } else {
+ self.write_byte(0x02)
+ }
+ }
+ }
+ }
+
+ fn write_bytes(&mut self, b: &[u8]) -> ::Result<()> {
+ self.transport.borrow_mut().write_varint(b.len() as u32)?;
+ self.transport.borrow_mut().write_all(b).map_err(From::from)
+ }
+
+ fn write_i8(&mut self, i: i8) -> ::Result<()> {
+ self.write_byte(i as u8)
+ }
+
+ fn write_i16(&mut self, i: i16) -> ::Result<()> {
+ self.transport.borrow_mut().write_varint(i).map_err(From::from).map(|_| ())
+ }
+
+ fn write_i32(&mut self, i: i32) -> ::Result<()> {
+ self.transport.borrow_mut().write_varint(i).map_err(From::from).map(|_| ())
+ }
+
+ fn write_i64(&mut self, i: i64) -> ::Result<()> {
+ self.transport.borrow_mut().write_varint(i).map_err(From::from).map(|_| ())
+ }
+
+ fn write_double(&mut self, d: f64) -> ::Result<()> {
+ self.transport.borrow_mut().write_f64::<BigEndian>(d).map_err(From::from)
+ }
+
+ fn write_string(&mut self, s: &str) -> ::Result<()> {
+ self.write_bytes(s.as_bytes())
+ }
+
+ fn write_list_begin(&mut self, identifier: &TListIdentifier) -> ::Result<()> {
+ self.write_list_set_begin(identifier.element_type, identifier.size)
+ }
+
+ fn write_list_end(&mut self) -> ::Result<()> {
+ Ok(())
+ }
+
+ fn write_set_begin(&mut self, identifier: &TSetIdentifier) -> ::Result<()> {
+ self.write_list_set_begin(identifier.element_type, identifier.size)
+ }
+
+ fn write_set_end(&mut self) -> ::Result<()> {
+ Ok(())
+ }
+
+ fn write_map_begin(&mut self, identifier: &TMapIdentifier) -> ::Result<()> {
+ if identifier.size == 0 {
+ self.write_byte(0)
+ } else {
+ self.transport.borrow_mut().write_varint(identifier.size as u32)?;
+
+ let key_type = identifier.key_type
+ .expect("map identifier to write should contain key type");
+ let key_type_byte = collection_type_to_u8(key_type) << 4;
+
+ let val_type = identifier.value_type
+ .expect("map identifier to write should contain value type");
+ let val_type_byte = collection_type_to_u8(val_type);
+
+ let map_type_header = key_type_byte | val_type_byte;
+ self.write_byte(map_type_header)
+ }
+ }
+
+ fn write_map_end(&mut self) -> ::Result<()> {
+ Ok(())
+ }
+
+ fn flush(&mut self) -> ::Result<()> {
+ self.transport.borrow_mut().flush().map_err(From::from)
+ }
+
+ // utility
+ //
+
+ fn write_byte(&mut self, b: u8) -> ::Result<()> {
+ self.transport.borrow_mut().write(&[b]).map_err(From::from).map(|_| ())
+ }
+}
+
+/// Factory for creating instances of `TCompactOutputProtocol`.
+#[derive(Default)]
+pub struct TCompactOutputProtocolFactory;
+
+impl TCompactOutputProtocolFactory {
+ /// Create a `TCompactOutputProtocolFactory`.
+ pub fn new() -> TCompactOutputProtocolFactory {
+ TCompactOutputProtocolFactory {}
+ }
+}
+
+impl TOutputProtocolFactory for TCompactOutputProtocolFactory {
+ fn create(&mut self, transport: Rc<RefCell<Box<TTransport>>>) -> Box<TOutputProtocol> {
+ Box::new(TCompactOutputProtocol::new(transport)) as Box<TOutputProtocol>
+ }
+}
+
+fn collection_type_to_u8(field_type: TType) -> u8 {
+ match field_type {
+ TType::Bool => 0x01,
+ f => type_to_u8(f),
+ }
+}
+
+fn type_to_u8(field_type: TType) -> u8 {
+ match field_type {
+ TType::Stop => 0x00,
+ TType::I08 => 0x03, // equivalent to TType::Byte
+ TType::I16 => 0x04,
+ TType::I32 => 0x05,
+ TType::I64 => 0x06,
+ TType::Double => 0x07,
+ TType::String => 0x08,
+ TType::List => 0x09,
+ TType::Set => 0x0A,
+ TType::Map => 0x0B,
+ TType::Struct => 0x0C,
+ _ => panic!(format!("should not have attempted to convert {} to u8", field_type)),
+ }
+}
+
+fn collection_u8_to_type(b: u8) -> ::Result<TType> {
+ match b {
+ 0x01 => Ok(TType::Bool),
+ o => u8_to_type(o),
+ }
+}
+
+fn u8_to_type(b: u8) -> ::Result<TType> {
+ match b {
+ 0x00 => Ok(TType::Stop),
+ 0x03 => Ok(TType::I08), // equivalent to TType::Byte
+ 0x04 => Ok(TType::I16),
+ 0x05 => Ok(TType::I32),
+ 0x06 => Ok(TType::I64),
+ 0x07 => Ok(TType::Double),
+ 0x08 => Ok(TType::String),
+ 0x09 => Ok(TType::List),
+ 0x0A => Ok(TType::Set),
+ 0x0B => Ok(TType::Map),
+ 0x0C => Ok(TType::Struct),
+ unkn => {
+ Err(::Error::Protocol(::ProtocolError {
+ kind: ::ProtocolErrorKind::InvalidData,
+ message: format!("cannot convert {} into TType", unkn),
+ }))
+ }
+ }
+}
+
+#[cfg(test)]
+mod tests {
+
+ use std::rc::Rc;
+ use std::cell::RefCell;
+
+ use ::protocol::{TFieldIdentifier, TMessageIdentifier, TMessageType, TInputProtocol,
+ TListIdentifier, TMapIdentifier, TOutputProtocol, TSetIdentifier,
+ TStructIdentifier, TType};
+ use ::transport::{TPassThruTransport, TTransport};
+ use ::transport::mem::TBufferTransport;
+
+ use super::*;
+
+ #[test]
+ fn must_write_message_begin_0() {
+ let (trans, _, mut o_prot) = test_objects();
+
+ assert_success!(o_prot.write_message_begin(&TMessageIdentifier::new("foo", TMessageType::Call, 431)));
+
+ let expected: [u8; 8] =
+ [0x82 /* protocol ID */, 0x21 /* message type | protocol version */, 0xDE,
+ 0x06 /* zig-zag varint sequence number */, 0x03 /* message-name length */,
+ 0x66, 0x6F, 0x6F /* "foo" */];
+
+ assert_eq!(trans.borrow().write_buffer_as_ref(), &expected);
+ }
+
+ #[test]
+ fn must_write_message_begin_1() {
+ let (trans, _, mut o_prot) = test_objects();
+
+ assert_success!(o_prot.write_message_begin(&TMessageIdentifier::new("bar", TMessageType::Reply, 991828)));
+
+ let expected: [u8; 9] =
+ [0x82 /* protocol ID */, 0x41 /* message type | protocol version */, 0xA8,
+ 0x89, 0x79 /* zig-zag varint sequence number */,
+ 0x03 /* message-name length */, 0x62, 0x61, 0x72 /* "bar" */];
+
+ assert_eq!(trans.borrow().write_buffer_as_ref(), &expected);
+ }
+
+ #[test]
+ fn must_round_trip_message_begin() {
+ let (trans, mut i_prot, mut o_prot) = test_objects();
+
+ let ident = TMessageIdentifier::new("service_call", TMessageType::Call, 1283948);
+
+ assert_success!(o_prot.write_message_begin(&ident));
+
+ trans.borrow_mut().copy_write_buffer_to_read_buffer();
+
+ let res = assert_success!(i_prot.read_message_begin());
+ assert_eq!(&res, &ident);
+ }
+
+ #[test]
+ fn must_write_message_end() {
+ assert_no_write(|o| o.write_message_end());
+ }
+
+ // NOTE: structs and fields are tested together
+ //
+
+ #[test]
+ fn must_write_struct_with_delta_fields() {
+ let (trans, _, mut o_prot) = test_objects();
+
+ // no bytes should be written however
+ assert_success!(o_prot.write_struct_begin(&TStructIdentifier::new("foo")));
+
+ // write three fields with tiny field ids
+ // since they're small the field ids will be encoded as deltas
+
+ // since this is the first field (and it's zero) it gets the full varint write
+ assert_success!(o_prot.write_field_begin(&TFieldIdentifier::new("foo", TType::I08, 0)));
+ assert_success!(o_prot.write_field_end());
+
+ // since this delta > 0 and < 15 it can be encoded as a delta
+ assert_success!(o_prot.write_field_begin(&TFieldIdentifier::new("foo", TType::I16, 4)));
+ assert_success!(o_prot.write_field_end());
+
+ // since this delta > 0 and < 15 it can be encoded as a delta
+ assert_success!(o_prot.write_field_begin(&TFieldIdentifier::new("foo", TType::List, 9)));
+ assert_success!(o_prot.write_field_end());
+
+ // now, finish the struct off
+ assert_success!(o_prot.write_field_stop());
+ assert_success!(o_prot.write_struct_end());
+
+ // get bytes written
+ let buf = trans.borrow_mut().write_buffer_to_vec();
+
+ let expected: [u8; 5] = [0x03 /* field type */, 0x00 /* first field id */,
+ 0x44 /* field delta (4) | field type */,
+ 0x59 /* field delta (5) | field type */,
+ 0x00 /* field stop */];
+
+ assert_eq!(&buf, &expected);
+ }
+
+ #[test]
+ fn must_round_trip_struct_with_delta_fields() {
+ let (trans, mut i_prot, mut o_prot) = test_objects();
+
+ // no bytes should be written however
+ assert_success!(o_prot.write_struct_begin(&TStructIdentifier::new("foo")));
+
+ // write three fields with tiny field ids
+ // since they're small the field ids will be encoded as deltas
+
+ // since this is the first field (and it's zero) it gets the full varint write
+ let field_ident_1 = TFieldIdentifier::new("foo", TType::I08, 0);
+ assert_success!(o_prot.write_field_begin(&field_ident_1));
+ assert_success!(o_prot.write_field_end());
+
+ // since this delta > 0 and < 15 it can be encoded as a delta
+ let field_ident_2 = TFieldIdentifier::new("foo", TType::I16, 4);
+ assert_success!(o_prot.write_field_begin(&field_ident_2));
+ assert_success!(o_prot.write_field_end());
+
+ // since this delta > 0 and < 15 it can be encoded as a delta
+ let field_ident_3 = TFieldIdentifier::new("foo", TType::List, 9);
+ assert_success!(o_prot.write_field_begin(&field_ident_3));
+ assert_success!(o_prot.write_field_end());
+
+ // now, finish the struct off
+ assert_success!(o_prot.write_field_stop());
+ assert_success!(o_prot.write_struct_end());
+
+ trans.borrow_mut().copy_write_buffer_to_read_buffer();
+
+ // read the struct back
+ assert_success!(i_prot.read_struct_begin());
+
+ let read_ident_1 = assert_success!(i_prot.read_field_begin());
+ assert_eq!(read_ident_1,
+ TFieldIdentifier { name: None, ..field_ident_1 });
+ assert_success!(i_prot.read_field_end());
+
+ let read_ident_2 = assert_success!(i_prot.read_field_begin());
+ assert_eq!(read_ident_2,
+ TFieldIdentifier { name: None, ..field_ident_2 });
+ assert_success!(i_prot.read_field_end());
+
+ let read_ident_3 = assert_success!(i_prot.read_field_begin());
+ assert_eq!(read_ident_3,
+ TFieldIdentifier { name: None, ..field_ident_3 });
+ assert_success!(i_prot.read_field_end());
+
+ let read_ident_4 = assert_success!(i_prot.read_field_begin());
+ assert_eq!(read_ident_4,
+ TFieldIdentifier {
+ name: None,
+ field_type: TType::Stop,
+ id: None,
+ });
+
+ assert_success!(i_prot.read_struct_end());
+ }
+
+ #[test]
+ fn must_write_struct_with_non_zero_initial_field_and_delta_fields() {
+ let (trans, _, mut o_prot) = test_objects();
+
+ // no bytes should be written however
+ assert_success!(o_prot.write_struct_begin(&TStructIdentifier::new("foo")));
+
+ // write three fields with tiny field ids
+ // since they're small the field ids will be encoded as deltas
+
+ // gets a delta write
+ assert_success!(o_prot.write_field_begin(&TFieldIdentifier::new("foo", TType::I32, 1)));
+ assert_success!(o_prot.write_field_end());
+
+ // since this delta > 0 and < 15 it can be encoded as a delta
+ assert_success!(o_prot.write_field_begin(&TFieldIdentifier::new("foo", TType::Set, 2)));
+ assert_success!(o_prot.write_field_end());
+
+ // since this delta > 0 and < 15 it can be encoded as a delta
+ assert_success!(o_prot.write_field_begin(&TFieldIdentifier::new("foo", TType::String, 6)));
+ assert_success!(o_prot.write_field_end());
+
+ // now, finish the struct off
+ assert_success!(o_prot.write_field_stop());
+ assert_success!(o_prot.write_struct_end());
+
+ // get bytes written
+ let buf = trans.borrow_mut().write_buffer_to_vec();
+
+ let expected: [u8; 4] = [0x15 /* field delta (1) | field type */,
+ 0x1A /* field delta (1) | field type */,
+ 0x48 /* field delta (4) | field type */,
+ 0x00 /* field stop */];
+
+ assert_eq!(&buf, &expected);
+ }
+
+ #[test]
+ fn must_round_trip_struct_with_non_zero_initial_field_and_delta_fields() {
+ let (trans, mut i_prot, mut o_prot) = test_objects();
+
+ // no bytes should be written however
+ assert_success!(o_prot.write_struct_begin(&TStructIdentifier::new("foo")));
+
+ // write three fields with tiny field ids
+ // since they're small the field ids will be encoded as deltas
+
+ // gets a delta write
+ let field_ident_1 = TFieldIdentifier::new("foo", TType::I32, 1);
+ assert_success!(o_prot.write_field_begin(&field_ident_1));
+ assert_success!(o_prot.write_field_end());
+
+ // since this delta > 0 and < 15 it can be encoded as a delta
+ let field_ident_2 = TFieldIdentifier::new("foo", TType::Set, 2);
+ assert_success!(o_prot.write_field_begin(&field_ident_2));
+ assert_success!(o_prot.write_field_end());
+
+ // since this delta > 0 and < 15 it can be encoded as a delta
+ let field_ident_3 = TFieldIdentifier::new("foo", TType::String, 6);
+ assert_success!(o_prot.write_field_begin(&field_ident_3));
+ assert_success!(o_prot.write_field_end());
+
+ // now, finish the struct off
+ assert_success!(o_prot.write_field_stop());
+ assert_success!(o_prot.write_struct_end());
+
+ trans.borrow_mut().copy_write_buffer_to_read_buffer();
+
+ // read the struct back
+ assert_success!(i_prot.read_struct_begin());
+
+ let read_ident_1 = assert_success!(i_prot.read_field_begin());
+ assert_eq!(read_ident_1,
+ TFieldIdentifier { name: None, ..field_ident_1 });
+ assert_success!(i_prot.read_field_end());
+
+ let read_ident_2 = assert_success!(i_prot.read_field_begin());
+ assert_eq!(read_ident_2,
+ TFieldIdentifier { name: None, ..field_ident_2 });
+ assert_success!(i_prot.read_field_end());
+
+ let read_ident_3 = assert_success!(i_prot.read_field_begin());
+ assert_eq!(read_ident_3,
+ TFieldIdentifier { name: None, ..field_ident_3 });
+ assert_success!(i_prot.read_field_end());
+
+ let read_ident_4 = assert_success!(i_prot.read_field_begin());
+ assert_eq!(read_ident_4,
+ TFieldIdentifier {
+ name: None,
+ field_type: TType::Stop,
+ id: None,
+ });
+
+ assert_success!(i_prot.read_struct_end());
+ }
+
+ #[test]
+ fn must_write_struct_with_long_fields() {
+ let (trans, _, mut o_prot) = test_objects();
+
+ // no bytes should be written however
+ assert_success!(o_prot.write_struct_begin(&TStructIdentifier::new("foo")));
+
+ // write three fields with field ids that cannot be encoded as deltas
+
+ // since this is the first field (and it's zero) it gets the full varint write
+ assert_success!(o_prot.write_field_begin(&TFieldIdentifier::new("foo", TType::I32, 0)));
+ assert_success!(o_prot.write_field_end());
+
+ // since this delta is > 15 it is encoded as a zig-zag varint
+ assert_success!(o_prot.write_field_begin(&TFieldIdentifier::new("foo", TType::I64, 16)));
+ assert_success!(o_prot.write_field_end());
+
+ // since this delta is > 15 it is encoded as a zig-zag varint
+ assert_success!(o_prot.write_field_begin(&TFieldIdentifier::new("foo", TType::Set, 99)));
+ assert_success!(o_prot.write_field_end());
+
+ // now, finish the struct off
+ assert_success!(o_prot.write_field_stop());
+ assert_success!(o_prot.write_struct_end());
+
+ // get bytes written
+ let buf = trans.borrow_mut().write_buffer_to_vec();
+
+ let expected: [u8; 8] =
+ [0x05 /* field type */, 0x00 /* first field id */,
+ 0x06 /* field type */, 0x20 /* zig-zag varint field id */,
+ 0x0A /* field type */, 0xC6, 0x01 /* zig-zag varint field id */,
+ 0x00 /* field stop */];
+
+ assert_eq!(&buf, &expected);
+ }
+
+ #[test]
+ fn must_round_trip_struct_with_long_fields() {
+ let (trans, mut i_prot, mut o_prot) = test_objects();
+
+ // no bytes should be written however
+ assert_success!(o_prot.write_struct_begin(&TStructIdentifier::new("foo")));
+
+ // write three fields with field ids that cannot be encoded as deltas
+
+ // since this is the first field (and it's zero) it gets the full varint write
+ let field_ident_1 = TFieldIdentifier::new("foo", TType::I32, 0);
+ assert_success!(o_prot.write_field_begin(&field_ident_1));
+ assert_success!(o_prot.write_field_end());
+
+ // since this delta is > 15 it is encoded as a zig-zag varint
+ let field_ident_2 = TFieldIdentifier::new("foo", TType::I64, 16);
+ assert_success!(o_prot.write_field_begin(&field_ident_2));
+ assert_success!(o_prot.write_field_end());
+
+ // since this delta is > 15 it is encoded as a zig-zag varint
+ let field_ident_3 = TFieldIdentifier::new("foo", TType::Set, 99);
+ assert_success!(o_prot.write_field_begin(&field_ident_3));
+ assert_success!(o_prot.write_field_end());
+
+ // now, finish the struct off
+ assert_success!(o_prot.write_field_stop());
+ assert_success!(o_prot.write_struct_end());
+
+ trans.borrow_mut().copy_write_buffer_to_read_buffer();
+
+ // read the struct back
+ assert_success!(i_prot.read_struct_begin());
+
+ let read_ident_1 = assert_success!(i_prot.read_field_begin());
+ assert_eq!(read_ident_1,
+ TFieldIdentifier { name: None, ..field_ident_1 });
+ assert_success!(i_prot.read_field_end());
+
+ let read_ident_2 = assert_success!(i_prot.read_field_begin());
+ assert_eq!(read_ident_2,
+ TFieldIdentifier { name: None, ..field_ident_2 });
+ assert_success!(i_prot.read_field_end());
+
+ let read_ident_3 = assert_success!(i_prot.read_field_begin());
+ assert_eq!(read_ident_3,
+ TFieldIdentifier { name: None, ..field_ident_3 });
+ assert_success!(i_prot.read_field_end());
+
+ let read_ident_4 = assert_success!(i_prot.read_field_begin());
+ assert_eq!(read_ident_4,
+ TFieldIdentifier {
+ name: None,
+ field_type: TType::Stop,
+ id: None,
+ });
+
+ assert_success!(i_prot.read_struct_end());
+ }
+
+ #[test]
+ fn must_write_struct_with_mix_of_long_and_delta_fields() {
+ let (trans, _, mut o_prot) = test_objects();
+
+ // no bytes should be written however
+ assert_success!(o_prot.write_struct_begin(&TStructIdentifier::new("foo")));
+
+ // write three fields with field ids that cannot be encoded as deltas
+
+ // since the delta is > 0 and < 15 it gets a delta write
+ assert_success!(o_prot.write_field_begin(&TFieldIdentifier::new("foo", TType::I64, 1)));
+ assert_success!(o_prot.write_field_end());
+
+ // since this delta > 0 and < 15 it gets a delta write
+ assert_success!(o_prot.write_field_begin(&TFieldIdentifier::new("foo", TType::I32, 9)));
+ assert_success!(o_prot.write_field_end());
+
+ // since this delta is > 15 it is encoded as a zig-zag varint
+ assert_success!(o_prot.write_field_begin(&TFieldIdentifier::new("foo", TType::Set, 1000)));
+ assert_success!(o_prot.write_field_end());
+
+ // since this delta is > 15 it is encoded as a zig-zag varint
+ assert_success!(o_prot.write_field_begin(&TFieldIdentifier::new("foo", TType::Set, 2001)));
+ assert_success!(o_prot.write_field_end());
+
+ // since this is only 3 up from the previous it is recorded as a delta
+ assert_success!(o_prot.write_field_begin(&TFieldIdentifier::new("foo", TType::Set, 2004)));
+ assert_success!(o_prot.write_field_end());
+
+ // now, finish the struct off
+ assert_success!(o_prot.write_field_stop());
+ assert_success!(o_prot.write_struct_end());
+
+ // get bytes written
+ let buf = trans.borrow_mut().write_buffer_to_vec();
+
+ let expected: [u8; 10] =
+ [0x16 /* field delta (1) | field type */,
+ 0x85 /* field delta (8) | field type */, 0x0A /* field type */, 0xD0,
+ 0x0F /* zig-zag varint field id */, 0x0A /* field type */, 0xA2,
+ 0x1F /* zig-zag varint field id */,
+ 0x3A /* field delta (3) | field type */, 0x00 /* field stop */];
+
+ assert_eq!(&buf, &expected);
+ }
+
+ #[test]
+ fn must_round_trip_struct_with_mix_of_long_and_delta_fields() {
+ let (trans, mut i_prot, mut o_prot) = test_objects();
+
+ // no bytes should be written however
+ let struct_ident = TStructIdentifier::new("foo");
+ assert_success!(o_prot.write_struct_begin(&struct_ident));
+
+ // write three fields with field ids that cannot be encoded as deltas
+
+ // since the delta is > 0 and < 15 it gets a delta write
+ let field_ident_1 = TFieldIdentifier::new("foo", TType::I64, 1);
+ assert_success!(o_prot.write_field_begin(&field_ident_1));
+ assert_success!(o_prot.write_field_end());
+
+ // since this delta > 0 and < 15 it gets a delta write
+ let field_ident_2 = TFieldIdentifier::new("foo", TType::I32, 9);
+ assert_success!(o_prot.write_field_begin(&field_ident_2));
+ assert_success!(o_prot.write_field_end());
+
+ // since this delta is > 15 it is encoded as a zig-zag varint
+ let field_ident_3 = TFieldIdentifier::new("foo", TType::Set, 1000);
+ assert_success!(o_prot.write_field_begin(&field_ident_3));
+ assert_success!(o_prot.write_field_end());
+
+ // since this delta is > 15 it is encoded as a zig-zag varint
+ let field_ident_4 = TFieldIdentifier::new("foo", TType::Set, 2001);
+ assert_success!(o_prot.write_field_begin(&field_ident_4));
+ assert_success!(o_prot.write_field_end());
+
+ // since this is only 3 up from the previous it is recorded as a delta
+ let field_ident_5 = TFieldIdentifier::new("foo", TType::Set, 2004);
+ assert_success!(o_prot.write_field_begin(&field_ident_5));
+ assert_success!(o_prot.write_field_end());
+
+ // now, finish the struct off
+ assert_success!(o_prot.write_field_stop());
+ assert_success!(o_prot.write_struct_end());
+
+ trans.borrow_mut().copy_write_buffer_to_read_buffer();
+
+ // read the struct back
+ assert_success!(i_prot.read_struct_begin());
+
+ let read_ident_1 = assert_success!(i_prot.read_field_begin());
+ assert_eq!(read_ident_1,
+ TFieldIdentifier { name: None, ..field_ident_1 });
+ assert_success!(i_prot.read_field_end());
+
+ let read_ident_2 = assert_success!(i_prot.read_field_begin());
+ assert_eq!(read_ident_2,
+ TFieldIdentifier { name: None, ..field_ident_2 });
+ assert_success!(i_prot.read_field_end());
+
+ let read_ident_3 = assert_success!(i_prot.read_field_begin());
+ assert_eq!(read_ident_3,
+ TFieldIdentifier { name: None, ..field_ident_3 });
+ assert_success!(i_prot.read_field_end());
+
+ let read_ident_4 = assert_success!(i_prot.read_field_begin());
+ assert_eq!(read_ident_4,
+ TFieldIdentifier { name: None, ..field_ident_4 });
+ assert_success!(i_prot.read_field_end());
+
+ let read_ident_5 = assert_success!(i_prot.read_field_begin());
+ assert_eq!(read_ident_5,
+ TFieldIdentifier { name: None, ..field_ident_5 });
+ assert_success!(i_prot.read_field_end());
+
+ let read_ident_6 = assert_success!(i_prot.read_field_begin());
+ assert_eq!(read_ident_6,
+ TFieldIdentifier {
+ name: None,
+ field_type: TType::Stop,
+ id: None,
+ });
+
+ assert_success!(i_prot.read_struct_end());
+ }
+
+ #[test]
+ fn must_write_nested_structs_0() {
+ // last field of the containing struct is a delta
+ // first field of the the contained struct is a delta
+
+ let (trans, _, mut o_prot) = test_objects();
+
+ // start containing struct
+ assert_success!(o_prot.write_struct_begin(&TStructIdentifier::new("foo")));
+
+ // containing struct
+ // since the delta is > 0 and < 15 it gets a delta write
+ assert_success!(o_prot.write_field_begin(&TFieldIdentifier::new("foo", TType::I64, 1)));
+ assert_success!(o_prot.write_field_end());
+
+ // containing struct
+ // since this delta > 0 and < 15 it gets a delta write
+ assert_success!(o_prot.write_field_begin(&TFieldIdentifier::new("foo", TType::I32, 9)));
+ assert_success!(o_prot.write_field_end());
+
+ // start contained struct
+ assert_success!(o_prot.write_struct_begin(&TStructIdentifier::new("foo")));
+
+ // contained struct
+ // since the delta is > 0 and < 15 it gets a delta write
+ assert_success!(o_prot.write_field_begin(&TFieldIdentifier::new("foo", TType::I08, 7)));
+ assert_success!(o_prot.write_field_end());
+
+ // contained struct
+ // since this delta > 15 it gets a full write
+ assert_success!(o_prot.write_field_begin(&TFieldIdentifier::new("foo", TType::Double, 24)));
+ assert_success!(o_prot.write_field_end());
+
+ // end contained struct
+ assert_success!(o_prot.write_field_stop());
+ assert_success!(o_prot.write_struct_end());
+
+ // end containing struct
+ assert_success!(o_prot.write_field_stop());
+ assert_success!(o_prot.write_struct_end());
+
+ // get bytes written
+ let buf = trans.borrow_mut().write_buffer_to_vec();
+
+ let expected: [u8; 7] =
+ [0x16 /* field delta (1) | field type */,
+ 0x85 /* field delta (8) | field type */,
+ 0x73 /* field delta (7) | field type */, 0x07 /* field type */,
+ 0x30 /* zig-zag varint field id */, 0x00 /* field stop - contained */,
+ 0x00 /* field stop - containing */];
+
+ assert_eq!(&buf, &expected);
+ }
+
+ #[test]
+ fn must_round_trip_nested_structs_0() {
+ // last field of the containing struct is a delta
+ // first field of the the contained struct is a delta
+
+ let (trans, mut i_prot, mut o_prot) = test_objects();
+
+ // start containing struct
+ assert_success!(o_prot.write_struct_begin(&TStructIdentifier::new("foo")));
+
+ // containing struct
+ // since the delta is > 0 and < 15 it gets a delta write
+ let field_ident_1 = TFieldIdentifier::new("foo", TType::I64, 1);
+ assert_success!(o_prot.write_field_begin(&field_ident_1));
+ assert_success!(o_prot.write_field_end());
+
+ // containing struct
+ // since this delta > 0 and < 15 it gets a delta write
+ let field_ident_2 = TFieldIdentifier::new("foo", TType::I32, 9);
+ assert_success!(o_prot.write_field_begin(&field_ident_2));
+ assert_success!(o_prot.write_field_end());
+
+ // start contained struct
+ assert_success!(o_prot.write_struct_begin(&TStructIdentifier::new("foo")));
+
+ // contained struct
+ // since the delta is > 0 and < 15 it gets a delta write
+ let field_ident_3 = TFieldIdentifier::new("foo", TType::I08, 7);
+ assert_success!(o_prot.write_field_begin(&field_ident_3));
+ assert_success!(o_prot.write_field_end());
+
+ // contained struct
+ // since this delta > 15 it gets a full write
+ let field_ident_4 = TFieldIdentifier::new("foo", TType::Double, 24);
+ assert_success!(o_prot.write_field_begin(&field_ident_4));
+ assert_success!(o_prot.write_field_end());
+
+ // end contained struct
+ assert_success!(o_prot.write_field_stop());
+ assert_success!(o_prot.write_struct_end());
+
+ // end containing struct
+ assert_success!(o_prot.write_field_stop());
+ assert_success!(o_prot.write_struct_end());
+
+ trans.borrow_mut().copy_write_buffer_to_read_buffer();
+
+ // read containing struct back
+ assert_success!(i_prot.read_struct_begin());
+
+ let read_ident_1 = assert_success!(i_prot.read_field_begin());
+ assert_eq!(read_ident_1,
+ TFieldIdentifier { name: None, ..field_ident_1 });
+ assert_success!(i_prot.read_field_end());
+
+ let read_ident_2 = assert_success!(i_prot.read_field_begin());
+ assert_eq!(read_ident_2,
+ TFieldIdentifier { name: None, ..field_ident_2 });
+ assert_success!(i_prot.read_field_end());
+
+ // read contained struct back
+ assert_success!(i_prot.read_struct_begin());
+
+ let read_ident_3 = assert_success!(i_prot.read_field_begin());
+ assert_eq!(read_ident_3,
+ TFieldIdentifier { name: None, ..field_ident_3 });
+ assert_success!(i_prot.read_field_end());
+
+ let read_ident_4 = assert_success!(i_prot.read_field_begin());
+ assert_eq!(read_ident_4,
+ TFieldIdentifier { name: None, ..field_ident_4 });
+ assert_success!(i_prot.read_field_end());
+
+ // end contained struct
+ let read_ident_6 = assert_success!(i_prot.read_field_begin());
+ assert_eq!(read_ident_6,
+ TFieldIdentifier {
+ name: None,
+ field_type: TType::Stop,
+ id: None,
+ });
+ assert_success!(i_prot.read_struct_end());
+
+ // end containing struct
+ let read_ident_7 = assert_success!(i_prot.read_field_begin());
+ assert_eq!(read_ident_7,
+ TFieldIdentifier {
+ name: None,
+ field_type: TType::Stop,
+ id: None,
+ });
+ assert_success!(i_prot.read_struct_end());
+ }
+
+ #[test]
+ fn must_write_nested_structs_1() {
+ // last field of the containing struct is a delta
+ // first field of the the contained struct is a full write
+
+ let (trans, _, mut o_prot) = test_objects();
+
+ // start containing struct
+ assert_success!(o_prot.write_struct_begin(&TStructIdentifier::new("foo")));
+
+ // containing struct
+ // since the delta is > 0 and < 15 it gets a delta write
+ assert_success!(o_prot.write_field_begin(&TFieldIdentifier::new("foo", TType::I64, 1)));
+ assert_success!(o_prot.write_field_end());
+
+ // containing struct
+ // since this delta > 0 and < 15 it gets a delta write
+ assert_success!(o_prot.write_field_begin(&TFieldIdentifier::new("foo", TType::I32, 9)));
+ assert_success!(o_prot.write_field_end());
+
+ // start contained struct
+ assert_success!(o_prot.write_struct_begin(&TStructIdentifier::new("foo")));
+
+ // contained struct
+ // since this delta > 15 it gets a full write
+ assert_success!(o_prot.write_field_begin(&TFieldIdentifier::new("foo", TType::Double, 24)));
+ assert_success!(o_prot.write_field_end());
+
+ // contained struct
+ // since the delta is > 0 and < 15 it gets a delta write
+ assert_success!(o_prot.write_field_begin(&TFieldIdentifier::new("foo", TType::I08, 27)));
+ assert_success!(o_prot.write_field_end());
+
+ // end contained struct
+ assert_success!(o_prot.write_field_stop());
+ assert_success!(o_prot.write_struct_end());
+
+ // end containing struct
+ assert_success!(o_prot.write_field_stop());
+ assert_success!(o_prot.write_struct_end());
+
+ // get bytes written
+ let buf = trans.borrow_mut().write_buffer_to_vec();
+
+ let expected: [u8; 7] =
+ [0x16 /* field delta (1) | field type */,
+ 0x85 /* field delta (8) | field type */, 0x07 /* field type */,
+ 0x30 /* zig-zag varint field id */,
+ 0x33 /* field delta (3) | field type */, 0x00 /* field stop - contained */,
+ 0x00 /* field stop - containing */];
+
+ assert_eq!(&buf, &expected);
+ }
+
+ #[test]
+ fn must_round_trip_nested_structs_1() {
+ // last field of the containing struct is a delta
+ // first field of the the contained struct is a full write
+
+ let (trans, mut i_prot, mut o_prot) = test_objects();
+
+ // start containing struct
+ assert_success!(o_prot.write_struct_begin(&TStructIdentifier::new("foo")));
+
+ // containing struct
+ // since the delta is > 0 and < 15 it gets a delta write
+ let field_ident_1 = TFieldIdentifier::new("foo", TType::I64, 1);
+ assert_success!(o_prot.write_field_begin(&field_ident_1));
+ assert_success!(o_prot.write_field_end());
+
+ // containing struct
+ // since this delta > 0 and < 15 it gets a delta write
+ let field_ident_2 = TFieldIdentifier::new("foo", TType::I32, 9);
+ assert_success!(o_prot.write_field_begin(&field_ident_2));
+ assert_success!(o_prot.write_field_end());
+
+ // start contained struct
+ assert_success!(o_prot.write_struct_begin(&TStructIdentifier::new("foo")));
+
+ // contained struct
+ // since this delta > 15 it gets a full write
+ let field_ident_3 = TFieldIdentifier::new("foo", TType::Double, 24);
+ assert_success!(o_prot.write_field_begin(&field_ident_3));
+ assert_success!(o_prot.write_field_end());
+
+ // contained struct
+ // since the delta is > 0 and < 15 it gets a delta write
+ let field_ident_4 = TFieldIdentifier::new("foo", TType::I08, 27);
+ assert_success!(o_prot.write_field_begin(&field_ident_4));
+ assert_success!(o_prot.write_field_end());
+
+ // end contained struct
+ assert_success!(o_prot.write_field_stop());
+ assert_success!(o_prot.write_struct_end());
+
+ // end containing struct
+ assert_success!(o_prot.write_field_stop());
+ assert_success!(o_prot.write_struct_end());
+
+ trans.borrow_mut().copy_write_buffer_to_read_buffer();
+
+ // read containing struct back
+ assert_success!(i_prot.read_struct_begin());
+
+ let read_ident_1 = assert_success!(i_prot.read_field_begin());
+ assert_eq!(read_ident_1,
+ TFieldIdentifier { name: None, ..field_ident_1 });
+ assert_success!(i_prot.read_field_end());
+
+ let read_ident_2 = assert_success!(i_prot.read_field_begin());
+ assert_eq!(read_ident_2,
+ TFieldIdentifier { name: None, ..field_ident_2 });
+ assert_success!(i_prot.read_field_end());
+
+ // read contained struct back
+ assert_success!(i_prot.read_struct_begin());
+
+ let read_ident_3 = assert_success!(i_prot.read_field_begin());
+ assert_eq!(read_ident_3,
+ TFieldIdentifier { name: None, ..field_ident_3 });
+ assert_success!(i_prot.read_field_end());
+
+ let read_ident_4 = assert_success!(i_prot.read_field_begin());
+ assert_eq!(read_ident_4,
+ TFieldIdentifier { name: None, ..field_ident_4 });
+ assert_success!(i_prot.read_field_end());
+
+ // end contained struct
+ let read_ident_6 = assert_success!(i_prot.read_field_begin());
+ assert_eq!(read_ident_6,
+ TFieldIdentifier {
+ name: None,
+ field_type: TType::Stop,
+ id: None,
+ });
+ assert_success!(i_prot.read_struct_end());
+
+ // end containing struct
+ let read_ident_7 = assert_success!(i_prot.read_field_begin());
+ assert_eq!(read_ident_7,
+ TFieldIdentifier {
+ name: None,
+ field_type: TType::Stop,
+ id: None,
+ });
+ assert_success!(i_prot.read_struct_end());
+ }
+
+ #[test]
+ fn must_write_nested_structs_2() {
+ // last field of the containing struct is a full write
+ // first field of the the contained struct is a delta write
+
+ let (trans, _, mut o_prot) = test_objects();
+
+ // start containing struct
+ assert_success!(o_prot.write_struct_begin(&TStructIdentifier::new("foo")));
+
+ // containing struct
+ // since the delta is > 0 and < 15 it gets a delta write
+ assert_success!(o_prot.write_field_begin(&TFieldIdentifier::new("foo", TType::I64, 1)));
+ assert_success!(o_prot.write_field_end());
+
+ // containing struct
+ // since this delta > 15 it gets a full write
+ assert_success!(o_prot.write_field_begin(&TFieldIdentifier::new("foo", TType::String, 21)));
+ assert_success!(o_prot.write_field_end());
+
+ // start contained struct
+ assert_success!(o_prot.write_struct_begin(&TStructIdentifier::new("foo")));
+
+ // contained struct
+ // since this delta > 0 and < 15 it gets a delta write
+ assert_success!(o_prot.write_field_begin(&TFieldIdentifier::new("foo", TType::Double, 7)));
+ assert_success!(o_prot.write_field_end());
+
+ // contained struct
+ // since the delta is > 0 and < 15 it gets a delta write
+ assert_success!(o_prot.write_field_begin(&TFieldIdentifier::new("foo", TType::I08, 10)));
+ assert_success!(o_prot.write_field_end());
+
+ // end contained struct
+ assert_success!(o_prot.write_field_stop());
+ assert_success!(o_prot.write_struct_end());
+
+ // end containing struct
+ assert_success!(o_prot.write_field_stop());
+ assert_success!(o_prot.write_struct_end());
+
+ // get bytes written
+ let buf = trans.borrow_mut().write_buffer_to_vec();
+
+ let expected: [u8; 7] =
+ [0x16 /* field delta (1) | field type */, 0x08 /* field type */,
+ 0x2A /* zig-zag varint field id */, 0x77 /* field delta(7) | field type */,
+ 0x33 /* field delta (3) | field type */, 0x00 /* field stop - contained */,
+ 0x00 /* field stop - containing */];
+
+ assert_eq!(&buf, &expected);
+ }
+
+ #[test]
+ fn must_round_trip_nested_structs_2() {
+ let (trans, mut i_prot, mut o_prot) = test_objects();
+
+ // start containing struct
+ assert_success!(o_prot.write_struct_begin(&TStructIdentifier::new("foo")));
+
+ // containing struct
+ // since the delta is > 0 and < 15 it gets a delta write
+ let field_ident_1 = TFieldIdentifier::new("foo", TType::I64, 1);
+ assert_success!(o_prot.write_field_begin(&field_ident_1));
+ assert_success!(o_prot.write_field_end());
+
+ // containing struct
+ // since this delta > 15 it gets a full write
+ let field_ident_2 = TFieldIdentifier::new("foo", TType::String, 21);
+ assert_success!(o_prot.write_field_begin(&field_ident_2));
+ assert_success!(o_prot.write_field_end());
+
+ // start contained struct
+ assert_success!(o_prot.write_struct_begin(&TStructIdentifier::new("foo")));
+
+ // contained struct
+ // since this delta > 0 and < 15 it gets a delta write
+ let field_ident_3 = TFieldIdentifier::new("foo", TType::Double, 7);
+ assert_success!(o_prot.write_field_begin(&field_ident_3));
+ assert_success!(o_prot.write_field_end());
+
+ // contained struct
+ // since the delta is > 0 and < 15 it gets a delta write
+ let field_ident_4 = TFieldIdentifier::new("foo", TType::I08, 10);
+ assert_success!(o_prot.write_field_begin(&field_ident_4));
+ assert_success!(o_prot.write_field_end());
+
+ // end contained struct
+ assert_success!(o_prot.write_field_stop());
+ assert_success!(o_prot.write_struct_end());
+
+ // end containing struct
+ assert_success!(o_prot.write_field_stop());
+ assert_success!(o_prot.write_struct_end());
+
+ trans.borrow_mut().copy_write_buffer_to_read_buffer();
+
+ // read containing struct back
+ assert_success!(i_prot.read_struct_begin());
+
+ let read_ident_1 = assert_success!(i_prot.read_field_begin());
+ assert_eq!(read_ident_1,
+ TFieldIdentifier { name: None, ..field_ident_1 });
+ assert_success!(i_prot.read_field_end());
+
+ let read_ident_2 = assert_success!(i_prot.read_field_begin());
+ assert_eq!(read_ident_2,
+ TFieldIdentifier { name: None, ..field_ident_2 });
+ assert_success!(i_prot.read_field_end());
+
+ // read contained struct back
+ assert_success!(i_prot.read_struct_begin());
+
+ let read_ident_3 = assert_success!(i_prot.read_field_begin());
+ assert_eq!(read_ident_3,
+ TFieldIdentifier { name: None, ..field_ident_3 });
+ assert_success!(i_prot.read_field_end());
+
+ let read_ident_4 = assert_success!(i_prot.read_field_begin());
+ assert_eq!(read_ident_4,
+ TFieldIdentifier { name: None, ..field_ident_4 });
+ assert_success!(i_prot.read_field_end());
+
+ // end contained struct
+ let read_ident_6 = assert_success!(i_prot.read_field_begin());
+ assert_eq!(read_ident_6,
+ TFieldIdentifier {
+ name: None,
+ field_type: TType::Stop,
+ id: None,
+ });
+ assert_success!(i_prot.read_struct_end());
+
+ // end containing struct
+ let read_ident_7 = assert_success!(i_prot.read_field_begin());
+ assert_eq!(read_ident_7,
+ TFieldIdentifier {
+ name: None,
+ field_type: TType::Stop,
+ id: None,
+ });
+ assert_success!(i_prot.read_struct_end());
+ }
+
+ #[test]
+ fn must_write_nested_structs_3() {
+ // last field of the containing struct is a full write
+ // first field of the the contained struct is a full write
+
+ let (trans, _, mut o_prot) = test_objects();
+
+ // start containing struct
+ assert_success!(o_prot.write_struct_begin(&TStructIdentifier::new("foo")));
+
+ // containing struct
+ // since the delta is > 0 and < 15 it gets a delta write
+ assert_success!(o_prot.write_field_begin(&TFieldIdentifier::new("foo", TType::I64, 1)));
+ assert_success!(o_prot.write_field_end());
+
+ // containing struct
+ // since this delta > 15 it gets a full write
+ assert_success!(o_prot.write_field_begin(&TFieldIdentifier::new("foo", TType::String, 21)));
+ assert_success!(o_prot.write_field_end());
+
+ // start contained struct
+ assert_success!(o_prot.write_struct_begin(&TStructIdentifier::new("foo")));
+
+ // contained struct
+ // since this delta > 15 it gets a full write
+ assert_success!(o_prot.write_field_begin(&TFieldIdentifier::new("foo", TType::Double, 21)));
+ assert_success!(o_prot.write_field_end());
+
+ // contained struct
+ // since the delta is > 0 and < 15 it gets a delta write
+ assert_success!(o_prot.write_field_begin(&TFieldIdentifier::new("foo", TType::I08, 27)));
+ assert_success!(o_prot.write_field_end());
+
+ // end contained struct
+ assert_success!(o_prot.write_field_stop());
+ assert_success!(o_prot.write_struct_end());
+
+ // end containing struct
+ assert_success!(o_prot.write_field_stop());
+ assert_success!(o_prot.write_struct_end());
+
+ // get bytes written
+ let buf = trans.borrow_mut().write_buffer_to_vec();
+
+ let expected: [u8; 8] =
+ [0x16 /* field delta (1) | field type */, 0x08 /* field type */,
+ 0x2A /* zig-zag varint field id */, 0x07 /* field type */,
+ 0x2A /* zig-zag varint field id */,
+ 0x63 /* field delta (6) | field type */, 0x00 /* field stop - contained */,
+ 0x00 /* field stop - containing */];
+
+ assert_eq!(&buf, &expected);
+ }
+
+ #[test]
+ fn must_round_trip_nested_structs_3() {
+ // last field of the containing struct is a full write
+ // first field of the the contained struct is a full write
+
+ let (trans, mut i_prot, mut o_prot) = test_objects();
+
+ // start containing struct
+ assert_success!(o_prot.write_struct_begin(&TStructIdentifier::new("foo")));
+
+ // containing struct
+ // since the delta is > 0 and < 15 it gets a delta write
+ let field_ident_1 = TFieldIdentifier::new("foo", TType::I64, 1);
+ assert_success!(o_prot.write_field_begin(&field_ident_1));
+ assert_success!(o_prot.write_field_end());
+
+ // containing struct
+ // since this delta > 15 it gets a full write
+ let field_ident_2 = TFieldIdentifier::new("foo", TType::String, 21);
+ assert_success!(o_prot.write_field_begin(&field_ident_2));
+ assert_success!(o_prot.write_field_end());
+
+ // start contained struct
+ assert_success!(o_prot.write_struct_begin(&TStructIdentifier::new("foo")));
+
+ // contained struct
+ // since this delta > 15 it gets a full write
+ let field_ident_3 = TFieldIdentifier::new("foo", TType::Double, 21);
+ assert_success!(o_prot.write_field_begin(&field_ident_3));
+ assert_success!(o_prot.write_field_end());
+
+ // contained struct
+ // since the delta is > 0 and < 15 it gets a delta write
+ let field_ident_4 = TFieldIdentifier::new("foo", TType::I08, 27);
+ assert_success!(o_prot.write_field_begin(&field_ident_4));
+ assert_success!(o_prot.write_field_end());
+
+ // end contained struct
+ assert_success!(o_prot.write_field_stop());
+ assert_success!(o_prot.write_struct_end());
+
+ // end containing struct
+ assert_success!(o_prot.write_field_stop());
+ assert_success!(o_prot.write_struct_end());
+
+ trans.borrow_mut().copy_write_buffer_to_read_buffer();
+
+ // read containing struct back
+ assert_success!(i_prot.read_struct_begin());
+
+ let read_ident_1 = assert_success!(i_prot.read_field_begin());
+ assert_eq!(read_ident_1,
+ TFieldIdentifier { name: None, ..field_ident_1 });
+ assert_success!(i_prot.read_field_end());
+
+ let read_ident_2 = assert_success!(i_prot.read_field_begin());
+ assert_eq!(read_ident_2,
+ TFieldIdentifier { name: None, ..field_ident_2 });
+ assert_success!(i_prot.read_field_end());
+
+ // read contained struct back
+ assert_success!(i_prot.read_struct_begin());
+
+ let read_ident_3 = assert_success!(i_prot.read_field_begin());
+ assert_eq!(read_ident_3,
+ TFieldIdentifier { name: None, ..field_ident_3 });
+ assert_success!(i_prot.read_field_end());
+
+ let read_ident_4 = assert_success!(i_prot.read_field_begin());
+ assert_eq!(read_ident_4,
+ TFieldIdentifier { name: None, ..field_ident_4 });
+ assert_success!(i_prot.read_field_end());
+
+ // end contained struct
+ let read_ident_6 = assert_success!(i_prot.read_field_begin());
+ assert_eq!(read_ident_6,
+ TFieldIdentifier {
+ name: None,
+ field_type: TType::Stop,
+ id: None,
+ });
+ assert_success!(i_prot.read_struct_end());
+
+ // end containing struct
+ let read_ident_7 = assert_success!(i_prot.read_field_begin());
+ assert_eq!(read_ident_7,
+ TFieldIdentifier {
+ name: None,
+ field_type: TType::Stop,
+ id: None,
+ });
+ assert_success!(i_prot.read_struct_end());
+ }
+
+ #[test]
+ fn must_write_bool_field() {
+ let (trans, _, mut o_prot) = test_objects();
+
+ // no bytes should be written however
+ assert_success!(o_prot.write_struct_begin(&TStructIdentifier::new("foo")));
+
+ // write three fields with field ids that cannot be encoded as deltas
+
+ // since the delta is > 0 and < 16 it gets a delta write
+ assert_success!(o_prot.write_field_begin(&TFieldIdentifier::new("foo", TType::Bool, 1)));
+ assert_success!(o_prot.write_bool(true));
+ assert_success!(o_prot.write_field_end());
+
+ // since this delta > 0 and < 15 it gets a delta write
+ assert_success!(o_prot.write_field_begin(&TFieldIdentifier::new("foo", TType::Bool, 9)));
+ assert_success!(o_prot.write_bool(false));
+ assert_success!(o_prot.write_field_end());
+
+ // since this delta > 15 it gets a full write
+ assert_success!(o_prot.write_field_begin(&TFieldIdentifier::new("foo", TType::Bool, 26)));
+ assert_success!(o_prot.write_bool(true));
+ assert_success!(o_prot.write_field_end());
+
+ // since this delta > 15 it gets a full write
+ assert_success!(o_prot.write_field_begin(&TFieldIdentifier::new("foo", TType::Bool, 45)));
+ assert_success!(o_prot.write_bool(false));
+ assert_success!(o_prot.write_field_end());
+
+ // now, finish the struct off
+ assert_success!(o_prot.write_field_stop());
+ assert_success!(o_prot.write_struct_end());
+
+ // get bytes written
+ let buf = trans.borrow_mut().write_buffer_to_vec();
+
+ let expected: [u8; 7] = [0x11 /* field delta (1) | true */,
+ 0x82 /* field delta (8) | false */, 0x01 /* true */,
+ 0x34 /* field id */, 0x02 /* false */,
+ 0x5A /* field id */, 0x00 /* stop field */];
+
+ assert_eq!(&buf, &expected);
+ }
+
+ #[test]
+ fn must_round_trip_bool_field() {
+ let (trans, mut i_prot, mut o_prot) = test_objects();
+
+ // no bytes should be written however
+ let struct_ident = TStructIdentifier::new("foo");
+ assert_success!(o_prot.write_struct_begin(&struct_ident));
+
+ // write two fields
+
+ // since the delta is > 0 and < 16 it gets a delta write
+ let field_ident_1 = TFieldIdentifier::new("foo", TType::Bool, 1);
+ assert_success!(o_prot.write_field_begin(&field_ident_1));
+ assert_success!(o_prot.write_bool(true));
+ assert_success!(o_prot.write_field_end());
+
+ // since this delta > 0 and < 15 it gets a delta write
+ let field_ident_2 = TFieldIdentifier::new("foo", TType::Bool, 9);
+ assert_success!(o_prot.write_field_begin(&field_ident_2));
+ assert_success!(o_prot.write_bool(false));
+ assert_success!(o_prot.write_field_end());
+
+ // since this delta > 15 it gets a full write
+ let field_ident_3 = TFieldIdentifier::new("foo", TType::Bool, 26);
+ assert_success!(o_prot.write_field_begin(&field_ident_3));
+ assert_success!(o_prot.write_bool(true));
+ assert_success!(o_prot.write_field_end());
+
+ // since this delta > 15 it gets a full write
+ let field_ident_4 = TFieldIdentifier::new("foo", TType::Bool, 45);
+ assert_success!(o_prot.write_field_begin(&field_ident_4));
+ assert_success!(o_prot.write_bool(false));
+ assert_success!(o_prot.write_field_end());
+
+ // now, finish the struct off
+ assert_success!(o_prot.write_field_stop());
+ assert_success!(o_prot.write_struct_end());
+
+ trans.borrow_mut().copy_write_buffer_to_read_buffer();
+
+ // read the struct back
+ assert_success!(i_prot.read_struct_begin());
+
+ let read_ident_1 = assert_success!(i_prot.read_field_begin());
+ assert_eq!(read_ident_1,
+ TFieldIdentifier { name: None, ..field_ident_1 });
+ let read_value_1 = assert_success!(i_prot.read_bool());
+ assert_eq!(read_value_1, true);
+ assert_success!(i_prot.read_field_end());
+
+ let read_ident_2 = assert_success!(i_prot.read_field_begin());
+ assert_eq!(read_ident_2,
+ TFieldIdentifier { name: None, ..field_ident_2 });
+ let read_value_2 = assert_success!(i_prot.read_bool());
+ assert_eq!(read_value_2, false);
+ assert_success!(i_prot.read_field_end());
+
+ let read_ident_3 = assert_success!(i_prot.read_field_begin());
+ assert_eq!(read_ident_3,
+ TFieldIdentifier { name: None, ..field_ident_3 });
+ let read_value_3 = assert_success!(i_prot.read_bool());
+ assert_eq!(read_value_3, true);
+ assert_success!(i_prot.read_field_end());
+
+ let read_ident_4 = assert_success!(i_prot.read_field_begin());
+ assert_eq!(read_ident_4,
+ TFieldIdentifier { name: None, ..field_ident_4 });
+ let read_value_4 = assert_success!(i_prot.read_bool());
+ assert_eq!(read_value_4, false);
+ assert_success!(i_prot.read_field_end());
+
+ let read_ident_5 = assert_success!(i_prot.read_field_begin());
+ assert_eq!(read_ident_5,
+ TFieldIdentifier {
+ name: None,
+ field_type: TType::Stop,
+ id: None,
+ });
+
+ assert_success!(i_prot.read_struct_end());
+ }
+
+ #[test]
+ #[should_panic]
+ fn must_fail_if_write_field_end_without_writing_bool_value() {
+ let (_, _, mut o_prot) = test_objects();
+ assert_success!(o_prot.write_struct_begin(&TStructIdentifier::new("foo")));
+ assert_success!(o_prot.write_field_begin(&TFieldIdentifier::new("foo", TType::Bool, 1)));
+ o_prot.write_field_end().unwrap();
+ }
+
+ #[test]
+ #[should_panic]
+ fn must_fail_if_write_stop_field_without_writing_bool_value() {
+ let (_, _, mut o_prot) = test_objects();
+ assert_success!(o_prot.write_struct_begin(&TStructIdentifier::new("foo")));
+ assert_success!(o_prot.write_field_begin(&TFieldIdentifier::new("foo", TType::Bool, 1)));
+ o_prot.write_field_stop().unwrap();
+ }
+
+ #[test]
+ #[should_panic]
+ fn must_fail_if_write_struct_end_without_writing_bool_value() {
+ let (_, _, mut o_prot) = test_objects();
+ assert_success!(o_prot.write_struct_begin(&TStructIdentifier::new("foo")));
+ assert_success!(o_prot.write_field_begin(&TFieldIdentifier::new("foo", TType::Bool, 1)));
+ o_prot.write_struct_end().unwrap();
+ }
+
+ #[test]
+ #[should_panic]
+ fn must_fail_if_write_struct_end_without_any_fields() {
+ let (_, _, mut o_prot) = test_objects();
+ o_prot.write_struct_end().unwrap();
+ }
+
+ #[test]
+ fn must_write_field_end() {
+ assert_no_write(|o| o.write_field_end());
+ }
+
+ #[test]
+ fn must_write_small_sized_list_begin() {
+ let (trans, _, mut o_prot) = test_objects();
+
+ assert_success!(o_prot.write_list_begin(&TListIdentifier::new(TType::I64, 4)));
+
+ let expected: [u8; 1] = [0x46 /* size | elem_type */];
+
+ assert_eq!(trans.borrow().write_buffer_as_ref(), &expected);
+ }
+
+ #[test]
+ fn must_round_trip_small_sized_list_begin() {
+ let (trans, mut i_prot, mut o_prot) = test_objects();
+
+ let ident = TListIdentifier::new(TType::I08, 10);
+
+ assert_success!(o_prot.write_list_begin(&ident));
+
+ trans.borrow_mut().copy_write_buffer_to_read_buffer();
+
+ let res = assert_success!(i_prot.read_list_begin());
+ assert_eq!(&res, &ident);
+ }
+
+ #[test]
+ fn must_write_large_sized_list_begin() {
+ let (trans, _, mut o_prot) = test_objects();
+
+ let res = o_prot.write_list_begin(&TListIdentifier::new(TType::List, 9999));
+ assert!(res.is_ok());
+
+ let expected: [u8; 3] = [0xF9 /* 0xF0 | elem_type */, 0x8F,
+ 0x4E /* size as varint */];
+
+ assert_eq!(trans.borrow().write_buffer_as_ref(), &expected);
+ }
+
+ #[test]
+ fn must_round_trip_large_sized_list_begin() {
+ let (trans, mut i_prot, mut o_prot) = test_objects();
+
+ let ident = TListIdentifier::new(TType::Set, 47381);
+
+ assert_success!(o_prot.write_list_begin(&ident));
+
+ trans.borrow_mut().copy_write_buffer_to_read_buffer();
+
+ let res = assert_success!(i_prot.read_list_begin());
+ assert_eq!(&res, &ident);
+ }
+
+ #[test]
+ fn must_write_list_end() {
+ assert_no_write(|o| o.write_list_end());
+ }
+
+ #[test]
+ fn must_write_small_sized_set_begin() {
+ let (trans, _, mut o_prot) = test_objects();
+
+ assert_success!(o_prot.write_set_begin(&TSetIdentifier::new(TType::Struct, 2)));
+
+ let expected: [u8; 1] = [0x2C /* size | elem_type */];
+
+ assert_eq!(trans.borrow().write_buffer_as_ref(), &expected);
+ }
+
+ #[test]
+ fn must_round_trip_small_sized_set_begin() {
+ let (trans, mut i_prot, mut o_prot) = test_objects();
+
+ let ident = TSetIdentifier::new(TType::I16, 7);
+
+ assert_success!(o_prot.write_set_begin(&ident));
+
+ trans.borrow_mut().copy_write_buffer_to_read_buffer();
+
+ let res = assert_success!(i_prot.read_set_begin());
+ assert_eq!(&res, &ident);
+ }
+
+ #[test]
+ fn must_write_large_sized_set_begin() {
+ let (trans, _, mut o_prot) = test_objects();
+
+ assert_success!(o_prot.write_set_begin(&TSetIdentifier::new(TType::Double, 23891)));
+
+ let expected: [u8; 4] = [0xF7 /* 0xF0 | elem_type */, 0xD3, 0xBA,
+ 0x01 /* size as varint */];
+
+ assert_eq!(trans.borrow().write_buffer_as_ref(), &expected);
+ }
+
+ #[test]
+ fn must_round_trip_large_sized_set_begin() {
+ let (trans, mut i_prot, mut o_prot) = test_objects();
+
+ let ident = TSetIdentifier::new(TType::Map, 3928429);
+
+ assert_success!(o_prot.write_set_begin(&ident));
+
+ trans.borrow_mut().copy_write_buffer_to_read_buffer();
+
+ let res = assert_success!(i_prot.read_set_begin());
+ assert_eq!(&res, &ident);
+ }
+
+ #[test]
+ fn must_write_set_end() {
+ assert_no_write(|o| o.write_set_end());
+ }
+
+ #[test]
+ fn must_write_zero_sized_map_begin() {
+ let (trans, _, mut o_prot) = test_objects();
+
+ assert_success!(o_prot.write_map_begin(&TMapIdentifier::new(TType::String, TType::I32, 0)));
+
+ let expected: [u8; 1] = [0x00]; // since size is zero we don't write anything
+
+ assert_eq!(trans.borrow().write_buffer_as_ref(), &expected);
+ }
+
+ #[test]
+ fn must_read_zero_sized_map_begin() {
+ let (trans, mut i_prot, mut o_prot) = test_objects();
+
+ assert_success!(o_prot.write_map_begin(&TMapIdentifier::new(TType::Double, TType::I32, 0)));
+
+ trans.borrow_mut().copy_write_buffer_to_read_buffer();
+
+ let res = assert_success!(i_prot.read_map_begin());
+ assert_eq!(&res,
+ &TMapIdentifier {
+ key_type: None,
+ value_type: None,
+ size: 0,
+ });
+ }
+
+ #[test]
+ fn must_write_map_begin() {
+ let (trans, _, mut o_prot) = test_objects();
+
+ assert_success!(o_prot.write_map_begin(&TMapIdentifier::new(TType::Double, TType::String, 238)));
+
+ let expected: [u8; 3] = [0xEE, 0x01 /* size as varint */,
+ 0x78 /* key type | val type */];
+
+ assert_eq!(trans.borrow().write_buffer_as_ref(), &expected);
+ }
+
+ #[test]
+ fn must_round_trip_map_begin() {
+ let (trans, mut i_prot, mut o_prot) = test_objects();
+
+ let ident = TMapIdentifier::new(TType::Map, TType::List, 1928349);
+
+ assert_success!(o_prot.write_map_begin(&ident));
+
+ trans.borrow_mut().copy_write_buffer_to_read_buffer();
+
+ let res = assert_success!(i_prot.read_map_begin());
+ assert_eq!(&res, &ident);
+ }
+
+ #[test]
+ fn must_write_map_end() {
+ assert_no_write(|o| o.write_map_end());
+ }
+
+ #[test]
+ fn must_write_map_with_bool_key_and_value() {
+ let (trans, _, mut o_prot) = test_objects();
+
+ assert_success!(o_prot.write_map_begin(&TMapIdentifier::new(TType::Bool, TType::Bool, 1)));
+ assert_success!(o_prot.write_bool(true));
+ assert_success!(o_prot.write_bool(false));
+ assert_success!(o_prot.write_map_end());
+
+ let expected: [u8; 4] = [0x01 /* size as varint */,
+ 0x11 /* key type | val type */, 0x01 /* key: true */,
+ 0x02 /* val: false */];
+
+ assert_eq!(trans.borrow().write_buffer_as_ref(), &expected);
+ }
+
+ #[test]
+ fn must_round_trip_map_with_bool_value() {
+ let (trans, mut i_prot, mut o_prot) = test_objects();
+
+ let map_ident = TMapIdentifier::new(TType::Bool, TType::Bool, 2);
+ assert_success!(o_prot.write_map_begin(&map_ident));
+ assert_success!(o_prot.write_bool(true));
+ assert_success!(o_prot.write_bool(false));
+ assert_success!(o_prot.write_bool(false));
+ assert_success!(o_prot.write_bool(true));
+ assert_success!(o_prot.write_map_end());
+
+ trans.borrow_mut().copy_write_buffer_to_read_buffer();
+
+ // map header
+ let rcvd_ident = assert_success!(i_prot.read_map_begin());
+ assert_eq!(&rcvd_ident, &map_ident);
+ // key 1
+ let b = assert_success!(i_prot.read_bool());
+ assert_eq!(b, true);
+ // val 1
+ let b = assert_success!(i_prot.read_bool());
+ assert_eq!(b, false);
+ // key 2
+ let b = assert_success!(i_prot.read_bool());
+ assert_eq!(b, false);
+ // val 2
+ let b = assert_success!(i_prot.read_bool());
+ assert_eq!(b, true);
+ // map end
+ assert_success!(i_prot.read_map_end());
+ }
+
+ #[test]
+ fn must_read_map_end() {
+ let (_, mut i_prot, _) = test_objects();
+ assert!(i_prot.read_map_end().is_ok()); // will blow up if we try to read from empty buffer
+ }
+
+ fn test_objects
+ ()
+ -> (Rc<RefCell<Box<TBufferTransport>>>, TCompactInputProtocol, TCompactOutputProtocol)
+ {
+ let mem = Rc::new(RefCell::new(Box::new(TBufferTransport::with_capacity(80, 80))));
+
+ let inner: Box<TTransport> = Box::new(TPassThruTransport { inner: mem.clone() });
+ let inner = Rc::new(RefCell::new(inner));
+
+ let i_prot = TCompactInputProtocol::new(inner.clone());
+ let o_prot = TCompactOutputProtocol::new(inner.clone());
+
+ (mem, i_prot, o_prot)
+ }
+
+ fn assert_no_write<F: FnMut(&mut TCompactOutputProtocol) -> ::Result<()>>(mut write_fn: F) {
+ let (trans, _, mut o_prot) = test_objects();
+ assert!(write_fn(&mut o_prot).is_ok());
+ assert_eq!(trans.borrow().write_buffer_as_ref().len(), 0);
+ }
+}
diff --git a/lib/rs/src/protocol/mod.rs b/lib/rs/src/protocol/mod.rs
new file mode 100644
index 0000000..b230d63
--- /dev/null
+++ b/lib/rs/src/protocol/mod.rs
@@ -0,0 +1,709 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Types used to send and receive primitives between a Thrift client and server.
+//!
+//! # Examples
+//!
+//! Create and use a `TOutputProtocol`.
+//!
+//! ```no_run
+//! use std::cell::RefCell;
+//! use std::rc::Rc;
+//! use thrift::protocol::{TBinaryOutputProtocol, TFieldIdentifier, TOutputProtocol, TType};
+//! use thrift::transport::{TTcpTransport, TTransport};
+//!
+//! // create the I/O channel
+//! let mut transport = TTcpTransport::new();
+//! transport.open("127.0.0.1:9090").unwrap();
+//! let transport = Rc::new(RefCell::new(Box::new(transport) as Box<TTransport>));
+//!
+//! // create the protocol to encode types into bytes
+//! let mut o_prot = TBinaryOutputProtocol::new(transport.clone(), true);
+//!
+//! // write types
+//! o_prot.write_field_begin(&TFieldIdentifier::new("string_thing", TType::String, 1)).unwrap();
+//! o_prot.write_string("foo").unwrap();
+//! o_prot.write_field_end().unwrap();
+//! ```
+//!
+//! Create and use a `TInputProtocol`.
+//!
+//! ```no_run
+//! use std::cell::RefCell;
+//! use std::rc::Rc;
+//! use thrift::protocol::{TBinaryInputProtocol, TInputProtocol};
+//! use thrift::transport::{TTcpTransport, TTransport};
+//!
+//! // create the I/O channel
+//! let mut transport = TTcpTransport::new();
+//! transport.open("127.0.0.1:9090").unwrap();
+//! let transport = Rc::new(RefCell::new(Box::new(transport) as Box<TTransport>));
+//!
+//! // create the protocol to decode bytes into types
+//! let mut i_prot = TBinaryInputProtocol::new(transport.clone(), true);
+//!
+//! // read types from the wire
+//! let field_identifier = i_prot.read_field_begin().unwrap();
+//! let field_contents = i_prot.read_string().unwrap();
+//! let field_end = i_prot.read_field_end().unwrap();
+//! ```
+
+use std::cell::RefCell;
+use std::fmt;
+use std::fmt::{Display, Formatter};
+use std::convert::From;
+use std::rc::Rc;
+use try_from::TryFrom;
+
+use ::{ProtocolError, ProtocolErrorKind};
+use ::transport::TTransport;
+
+mod binary;
+mod compact;
+mod multiplexed;
+mod stored;
+
+pub use self::binary::{TBinaryInputProtocol, TBinaryInputProtocolFactory, TBinaryOutputProtocol,
+ TBinaryOutputProtocolFactory};
+pub use self::compact::{TCompactInputProtocol, TCompactInputProtocolFactory,
+ TCompactOutputProtocol, TCompactOutputProtocolFactory};
+pub use self::multiplexed::TMultiplexedOutputProtocol;
+pub use self::stored::TStoredInputProtocol;
+
+// Default maximum depth to which `TInputProtocol::skip` will skip a Thrift
+// field. A default is necessary because Thrift structs or collections may
+// contain nested structs and collections, which could result in indefinite
+// recursion.
+const MAXIMUM_SKIP_DEPTH: i8 = 64;
+
+/// Converts a stream of bytes into Thrift identifiers, primitives,
+/// containers, or structs.
+///
+/// This trait does not deal with higher-level Thrift concepts like structs or
+/// exceptions - only with primitives and message or container boundaries. Once
+/// bytes are read they are deserialized and an identifier (for example
+/// `TMessageIdentifier`) or a primitive is returned.
+///
+/// All methods return a `thrift::Result`. If an `Err` is returned the protocol
+/// instance and its underlying transport should be terminated.
+///
+/// # Examples
+///
+/// Create and use a `TInputProtocol`
+///
+/// ```no_run
+/// use std::cell::RefCell;
+/// use std::rc::Rc;
+/// use thrift::protocol::{TBinaryInputProtocol, TInputProtocol};
+/// use thrift::transport::{TTcpTransport, TTransport};
+///
+/// let mut transport = TTcpTransport::new();
+/// transport.open("127.0.0.1:9090").unwrap();
+/// let transport = Rc::new(RefCell::new(Box::new(transport) as Box<TTransport>));
+///
+/// let mut i_prot = TBinaryInputProtocol::new(transport.clone(), true);
+///
+/// let field_identifier = i_prot.read_field_begin().unwrap();
+/// let field_contents = i_prot.read_string().unwrap();
+/// let field_end = i_prot.read_field_end().unwrap();
+/// ```
+pub trait TInputProtocol {
+ /// Read the beginning of a Thrift message.
+ fn read_message_begin(&mut self) -> ::Result<TMessageIdentifier>;
+ /// Read the end of a Thrift message.
+ fn read_message_end(&mut self) -> ::Result<()>;
+ /// Read the beginning of a Thrift struct.
+ fn read_struct_begin(&mut self) -> ::Result<Option<TStructIdentifier>>;
+ /// Read the end of a Thrift struct.
+ fn read_struct_end(&mut self) -> ::Result<()>;
+ /// Read the beginning of a Thrift struct field.
+ fn read_field_begin(&mut self) -> ::Result<TFieldIdentifier>;
+ /// Read the end of a Thrift struct field.
+ fn read_field_end(&mut self) -> ::Result<()>;
+ /// Read a bool.
+ fn read_bool(&mut self) -> ::Result<bool>;
+ /// Read a fixed-length byte array.
+ fn read_bytes(&mut self) -> ::Result<Vec<u8>>;
+ /// Read a word.
+ fn read_i8(&mut self) -> ::Result<i8>;
+ /// Read a 16-bit signed integer.
+ fn read_i16(&mut self) -> ::Result<i16>;
+ /// Read a 32-bit signed integer.
+ fn read_i32(&mut self) -> ::Result<i32>;
+ /// Read a 64-bit signed integer.
+ fn read_i64(&mut self) -> ::Result<i64>;
+ /// Read a 64-bit float.
+ fn read_double(&mut self) -> ::Result<f64>;
+ /// Read a fixed-length string (not null terminated).
+ fn read_string(&mut self) -> ::Result<String>;
+ /// Read the beginning of a list.
+ fn read_list_begin(&mut self) -> ::Result<TListIdentifier>;
+ /// Read the end of a list.
+ fn read_list_end(&mut self) -> ::Result<()>;
+ /// Read the beginning of a set.
+ fn read_set_begin(&mut self) -> ::Result<TSetIdentifier>;
+ /// Read the end of a set.
+ fn read_set_end(&mut self) -> ::Result<()>;
+ /// Read the beginning of a map.
+ fn read_map_begin(&mut self) -> ::Result<TMapIdentifier>;
+ /// Read the end of a map.
+ fn read_map_end(&mut self) -> ::Result<()>;
+ /// Skip a field with type `field_type` recursively until the default
+ /// maximum skip depth is reached.
+ fn skip(&mut self, field_type: TType) -> ::Result<()> {
+ self.skip_till_depth(field_type, MAXIMUM_SKIP_DEPTH)
+ }
+ /// Skip a field with type `field_type` recursively up to `depth` levels.
+ fn skip_till_depth(&mut self, field_type: TType, depth: i8) -> ::Result<()> {
+ if depth == 0 {
+ return Err(::Error::Protocol(ProtocolError {
+ kind: ProtocolErrorKind::DepthLimit,
+ message: format!("cannot parse past {:?}", field_type),
+ }));
+ }
+
+ match field_type {
+ TType::Bool => self.read_bool().map(|_| ()),
+ TType::I08 => self.read_i8().map(|_| ()),
+ TType::I16 => self.read_i16().map(|_| ()),
+ TType::I32 => self.read_i32().map(|_| ()),
+ TType::I64 => self.read_i64().map(|_| ()),
+ TType::Double => self.read_double().map(|_| ()),
+ TType::String => self.read_string().map(|_| ()),
+ TType::Struct => {
+ self.read_struct_begin()?;
+ loop {
+ let field_ident = self.read_field_begin()?;
+ if field_ident.field_type == TType::Stop {
+ break;
+ }
+ self.skip_till_depth(field_ident.field_type, depth - 1)?;
+ }
+ self.read_struct_end()
+ }
+ TType::List => {
+ let list_ident = self.read_list_begin()?;
+ for _ in 0..list_ident.size {
+ self.skip_till_depth(list_ident.element_type, depth - 1)?;
+ }
+ self.read_list_end()
+ }
+ TType::Set => {
+ let set_ident = self.read_set_begin()?;
+ for _ in 0..set_ident.size {
+ self.skip_till_depth(set_ident.element_type, depth - 1)?;
+ }
+ self.read_set_end()
+ }
+ TType::Map => {
+ let map_ident = self.read_map_begin()?;
+ for _ in 0..map_ident.size {
+ let key_type = map_ident.key_type
+ .expect("non-zero sized map should contain key type");
+ let val_type = map_ident.value_type
+ .expect("non-zero sized map should contain value type");
+ self.skip_till_depth(key_type, depth - 1)?;
+ self.skip_till_depth(val_type, depth - 1)?;
+ }
+ self.read_map_end()
+ }
+ u => {
+ Err(::Error::Protocol(ProtocolError {
+ kind: ProtocolErrorKind::Unknown,
+ message: format!("cannot skip field type {:?}", &u),
+ }))
+ }
+ }
+ }
+
+ // utility (DO NOT USE IN GENERATED CODE!!!!)
+ //
+
+ /// Read an unsigned byte.
+ ///
+ /// This method should **never** be used in generated code.
+ fn read_byte(&mut self) -> ::Result<u8>;
+}
+
+/// Converts Thrift identifiers, primitives, containers or structs into a
+/// stream of bytes.
+///
+/// This trait does not deal with higher-level Thrift concepts like structs or
+/// exceptions - only with primitives and message or container boundaries.
+/// Write methods take an identifier (for example, `TMessageIdentifier`) or a
+/// primitive. Any or all of the fields in an identifier may be omitted when
+/// writing to the transport. Write methods may even be noops. All of this is
+/// transparent to the caller; as long as a matching `TInputProtocol`
+/// implementation is used, received messages will be decoded correctly.
+///
+/// All methods return a `thrift::Result`. If an `Err` is returned the protocol
+/// instance and its underlying transport should be terminated.
+///
+/// # Examples
+///
+/// Create and use a `TOutputProtocol`
+///
+/// ```no_run
+/// use std::cell::RefCell;
+/// use std::rc::Rc;
+/// use thrift::protocol::{TBinaryOutputProtocol, TFieldIdentifier, TOutputProtocol, TType};
+/// use thrift::transport::{TTcpTransport, TTransport};
+///
+/// let mut transport = TTcpTransport::new();
+/// transport.open("127.0.0.1:9090").unwrap();
+/// let transport = Rc::new(RefCell::new(Box::new(transport) as Box<TTransport>));
+///
+/// let mut o_prot = TBinaryOutputProtocol::new(transport.clone(), true);
+///
+/// o_prot.write_field_begin(&TFieldIdentifier::new("string_thing", TType::String, 1)).unwrap();
+/// o_prot.write_string("foo").unwrap();
+/// o_prot.write_field_end().unwrap();
+/// ```
+pub trait TOutputProtocol {
+ /// Write the beginning of a Thrift message.
+ fn write_message_begin(&mut self, identifier: &TMessageIdentifier) -> ::Result<()>;
+ /// Write the end of a Thrift message.
+ fn write_message_end(&mut self) -> ::Result<()>;
+ /// Write the beginning of a Thrift struct.
+ fn write_struct_begin(&mut self, identifier: &TStructIdentifier) -> ::Result<()>;
+ /// Write the end of a Thrift struct.
+ fn write_struct_end(&mut self) -> ::Result<()>;
+ /// Write the beginning of a Thrift field.
+ fn write_field_begin(&mut self, identifier: &TFieldIdentifier) -> ::Result<()>;
+ /// Write the end of a Thrift field.
+ fn write_field_end(&mut self) -> ::Result<()>;
+ /// Write a STOP field indicating that all the fields in a struct have been
+ /// written.
+ fn write_field_stop(&mut self) -> ::Result<()>;
+ /// Write a bool.
+ fn write_bool(&mut self, b: bool) -> ::Result<()>;
+ /// Write a fixed-length byte array.
+ fn write_bytes(&mut self, b: &[u8]) -> ::Result<()>;
+ /// Write an 8-bit signed integer.
+ fn write_i8(&mut self, i: i8) -> ::Result<()>;
+ /// Write a 16-bit signed integer.
+ fn write_i16(&mut self, i: i16) -> ::Result<()>;
+ /// Write a 32-bit signed integer.
+ fn write_i32(&mut self, i: i32) -> ::Result<()>;
+ /// Write a 64-bit signed integer.
+ fn write_i64(&mut self, i: i64) -> ::Result<()>;
+ /// Write a 64-bit float.
+ fn write_double(&mut self, d: f64) -> ::Result<()>;
+ /// Write a fixed-length string.
+ fn write_string(&mut self, s: &str) -> ::Result<()>;
+ /// Write the beginning of a list.
+ fn write_list_begin(&mut self, identifier: &TListIdentifier) -> ::Result<()>;
+ /// Write the end of a list.
+ fn write_list_end(&mut self) -> ::Result<()>;
+ /// Write the beginning of a set.
+ fn write_set_begin(&mut self, identifier: &TSetIdentifier) -> ::Result<()>;
+ /// Write the end of a set.
+ fn write_set_end(&mut self) -> ::Result<()>;
+ /// Write the beginning of a map.
+ fn write_map_begin(&mut self, identifier: &TMapIdentifier) -> ::Result<()>;
+ /// Write the end of a map.
+ fn write_map_end(&mut self) -> ::Result<()>;
+ /// Flush buffered bytes to the underlying transport.
+ fn flush(&mut self) -> ::Result<()>;
+
+ // utility (DO NOT USE IN GENERATED CODE!!!!)
+ //
+
+ /// Write an unsigned byte.
+ ///
+ /// This method should **never** be used in generated code.
+ fn write_byte(&mut self, b: u8) -> ::Result<()>; // FIXME: REMOVE
+}
+
+/// Helper type used by servers to create `TInputProtocol` instances for
+/// accepted client connections.
+///
+/// # Examples
+///
+/// Create a `TInputProtocolFactory` and use it to create a `TInputProtocol`.
+///
+/// ```no_run
+/// use std::cell::RefCell;
+/// use std::rc::Rc;
+/// use thrift::protocol::{TBinaryInputProtocolFactory, TInputProtocolFactory};
+/// use thrift::transport::{TTcpTransport, TTransport};
+///
+/// let mut transport = TTcpTransport::new();
+/// transport.open("127.0.0.1:9090").unwrap();
+/// let transport = Rc::new(RefCell::new(Box::new(transport) as Box<TTransport>));
+///
+/// let mut i_proto_factory = TBinaryInputProtocolFactory::new();
+/// let i_prot = i_proto_factory.create(transport);
+/// ```
+pub trait TInputProtocolFactory {
+ /// Create a `TInputProtocol` that reads bytes from `transport`.
+ fn create(&mut self, transport: Rc<RefCell<Box<TTransport>>>) -> Box<TInputProtocol>;
+}
+
+/// Helper type used by servers to create `TOutputProtocol` instances for
+/// accepted client connections.
+///
+/// # Examples
+///
+/// Create a `TOutputProtocolFactory` and use it to create a `TOutputProtocol`.
+///
+/// ```no_run
+/// use std::cell::RefCell;
+/// use std::rc::Rc;
+/// use thrift::protocol::{TBinaryOutputProtocolFactory, TOutputProtocolFactory};
+/// use thrift::transport::{TTcpTransport, TTransport};
+///
+/// let mut transport = TTcpTransport::new();
+/// transport.open("127.0.0.1:9090").unwrap();
+/// let transport = Rc::new(RefCell::new(Box::new(transport) as Box<TTransport>));
+///
+/// let mut o_proto_factory = TBinaryOutputProtocolFactory::new();
+/// let o_prot = o_proto_factory.create(transport);
+/// ```
+pub trait TOutputProtocolFactory {
+ /// Create a `TOutputProtocol` that writes bytes to `transport`.
+ fn create(&mut self, transport: Rc<RefCell<Box<TTransport>>>) -> Box<TOutputProtocol>;
+}
+
+/// Thrift message identifier.
+#[derive(Clone, Debug, Eq, PartialEq)]
+pub struct TMessageIdentifier {
+ /// Service call the message is associated with.
+ pub name: String,
+ /// Message type.
+ pub message_type: TMessageType,
+ /// Ordered sequence number identifying the message.
+ pub sequence_number: i32,
+}
+
+impl TMessageIdentifier {
+ /// Create a `TMessageIdentifier` for a Thrift service-call named `name`
+ /// with message type `message_type` and sequence number `sequence_number`.
+ pub fn new<S: Into<String>>(name: S,
+ message_type: TMessageType,
+ sequence_number: i32)
+ -> TMessageIdentifier {
+ TMessageIdentifier {
+ name: name.into(),
+ message_type: message_type,
+ sequence_number: sequence_number,
+ }
+ }
+}
+
+/// Thrift struct identifier.
+#[derive(Clone, Debug, Eq, PartialEq)]
+pub struct TStructIdentifier {
+ /// Name of the encoded Thrift struct.
+ pub name: String,
+}
+
+impl TStructIdentifier {
+ /// Create a `TStructIdentifier` for a struct named `name`.
+ pub fn new<S: Into<String>>(name: S) -> TStructIdentifier {
+ TStructIdentifier { name: name.into() }
+ }
+}
+
+/// Thrift field identifier.
+#[derive(Clone, Debug, Eq, PartialEq)]
+pub struct TFieldIdentifier {
+ /// Name of the Thrift field.
+ ///
+ /// `None` if it's not sent over the wire.
+ pub name: Option<String>,
+ /// Field type.
+ ///
+ /// This may be a primitive, container, or a struct.
+ pub field_type: TType,
+ /// Thrift field id.
+ ///
+ /// `None` only if `field_type` is `TType::Stop`.
+ pub id: Option<i16>,
+}
+
+impl TFieldIdentifier {
+ /// Create a `TFieldIdentifier` for a field named `name` with type
+ /// `field_type` and field id `id`.
+ ///
+ /// `id` should be `None` if `field_type` is `TType::Stop`.
+ pub fn new<N, S, I>(name: N, field_type: TType, id: I) -> TFieldIdentifier
+ where N: Into<Option<S>>,
+ S: Into<String>,
+ I: Into<Option<i16>>
+ {
+ TFieldIdentifier {
+ name: name.into().map(|n| n.into()),
+ field_type: field_type,
+ id: id.into(),
+ }
+ }
+}
+
+/// Thrift list identifier.
+#[derive(Clone, Debug, Eq, PartialEq)]
+pub struct TListIdentifier {
+ /// Type of the elements in the list.
+ pub element_type: TType,
+ /// Number of elements in the list.
+ pub size: i32,
+}
+
+impl TListIdentifier {
+ /// Create a `TListIdentifier` for a list with `size` elements of type
+ /// `element_type`.
+ pub fn new(element_type: TType, size: i32) -> TListIdentifier {
+ TListIdentifier {
+ element_type: element_type,
+ size: size,
+ }
+ }
+}
+
+/// Thrift set identifier.
+#[derive(Clone, Debug, Eq, PartialEq)]
+pub struct TSetIdentifier {
+ /// Type of the elements in the set.
+ pub element_type: TType,
+ /// Number of elements in the set.
+ pub size: i32,
+}
+
+impl TSetIdentifier {
+ /// Create a `TSetIdentifier` for a set with `size` elements of type
+ /// `element_type`.
+ pub fn new(element_type: TType, size: i32) -> TSetIdentifier {
+ TSetIdentifier {
+ element_type: element_type,
+ size: size,
+ }
+ }
+}
+
+/// Thrift map identifier.
+#[derive(Clone, Debug, Eq, PartialEq)]
+pub struct TMapIdentifier {
+ /// Map key type.
+ pub key_type: Option<TType>,
+ /// Map value type.
+ pub value_type: Option<TType>,
+ /// Number of entries in the map.
+ pub size: i32,
+}
+
+impl TMapIdentifier {
+ /// Create a `TMapIdentifier` for a map with `size` entries of type
+ /// `key_type -> value_type`.
+ pub fn new<K, V>(key_type: K, value_type: V, size: i32) -> TMapIdentifier
+ where K: Into<Option<TType>>,
+ V: Into<Option<TType>>
+ {
+ TMapIdentifier {
+ key_type: key_type.into(),
+ value_type: value_type.into(),
+ size: size,
+ }
+ }
+}
+
+/// Thrift message types.
+#[derive(Clone, Copy, Debug, Eq, PartialEq)]
+pub enum TMessageType {
+ /// Service-call request.
+ Call,
+ /// Service-call response.
+ Reply,
+ /// Unexpected error in the remote service.
+ Exception,
+ /// One-way service-call request (no response is expected).
+ OneWay,
+}
+
+impl Display for TMessageType {
+ fn fmt(&self, f: &mut Formatter) -> fmt::Result {
+ match *self {
+ TMessageType::Call => write!(f, "Call"),
+ TMessageType::Reply => write!(f, "Reply"),
+ TMessageType::Exception => write!(f, "Exception"),
+ TMessageType::OneWay => write!(f, "OneWay"),
+ }
+ }
+}
+
+impl From<TMessageType> for u8 {
+ fn from(message_type: TMessageType) -> Self {
+ match message_type {
+ TMessageType::Call => 0x01,
+ TMessageType::Reply => 0x02,
+ TMessageType::Exception => 0x03,
+ TMessageType::OneWay => 0x04,
+ }
+ }
+}
+
+impl TryFrom<u8> for TMessageType {
+ type Err = ::Error;
+ fn try_from(b: u8) -> ::Result<Self> {
+ match b {
+ 0x01 => Ok(TMessageType::Call),
+ 0x02 => Ok(TMessageType::Reply),
+ 0x03 => Ok(TMessageType::Exception),
+ 0x04 => Ok(TMessageType::OneWay),
+ unkn => {
+ Err(::Error::Protocol(ProtocolError {
+ kind: ProtocolErrorKind::InvalidData,
+ message: format!("cannot convert {} to TMessageType", unkn),
+ }))
+ }
+ }
+ }
+}
+
+/// Thrift struct-field types.
+#[derive(Clone, Copy, Debug, Eq, PartialEq)]
+pub enum TType {
+ /// Indicates that there are no more serialized fields in this Thrift struct.
+ Stop,
+ /// Void (`()`) field.
+ Void,
+ /// Boolean.
+ Bool,
+ /// Signed 8-bit int.
+ I08,
+ /// Double-precision number.
+ Double,
+ /// Signed 16-bit int.
+ I16,
+ /// Signed 32-bit int.
+ I32,
+ /// Signed 64-bit int.
+ I64,
+ /// UTF-8 string.
+ String,
+ /// UTF-7 string. *Unsupported*.
+ Utf7,
+ /// Thrift struct.
+ Struct,
+ /// Map.
+ Map,
+ /// Set.
+ Set,
+ /// List.
+ List,
+ /// UTF-8 string.
+ Utf8,
+ /// UTF-16 string. *Unsupported*.
+ Utf16,
+}
+
+impl Display for TType {
+ fn fmt(&self, f: &mut Formatter) -> fmt::Result {
+ match *self {
+ TType::Stop => write!(f, "STOP"),
+ TType::Void => write!(f, "void"),
+ TType::Bool => write!(f, "bool"),
+ TType::I08 => write!(f, "i08"),
+ TType::Double => write!(f, "double"),
+ TType::I16 => write!(f, "i16"),
+ TType::I32 => write!(f, "i32"),
+ TType::I64 => write!(f, "i64"),
+ TType::String => write!(f, "string"),
+ TType::Utf7 => write!(f, "UTF7"),
+ TType::Struct => write!(f, "struct"),
+ TType::Map => write!(f, "map"),
+ TType::Set => write!(f, "set"),
+ TType::List => write!(f, "list"),
+ TType::Utf8 => write!(f, "UTF8"),
+ TType::Utf16 => write!(f, "UTF16"),
+ }
+ }
+}
+
+/// Compare the expected message sequence number `expected` with the received
+/// message sequence number `actual`.
+///
+/// Return `()` if `actual == expected`, `Err` otherwise.
+pub fn verify_expected_sequence_number(expected: i32, actual: i32) -> ::Result<()> {
+ if expected == actual {
+ Ok(())
+ } else {
+ Err(::Error::Application(::ApplicationError {
+ kind: ::ApplicationErrorKind::BadSequenceId,
+ message: format!("expected {} got {}", expected, actual),
+ }))
+ }
+}
+
+/// Compare the expected service-call name `expected` with the received
+/// service-call name `actual`.
+///
+/// Return `()` if `actual == expected`, `Err` otherwise.
+pub fn verify_expected_service_call(expected: &str, actual: &str) -> ::Result<()> {
+ if expected == actual {
+ Ok(())
+ } else {
+ Err(::Error::Application(::ApplicationError {
+ kind: ::ApplicationErrorKind::WrongMethodName,
+ message: format!("expected {} got {}", expected, actual),
+ }))
+ }
+}
+
+/// Compare the expected message type `expected` with the received message type
+/// `actual`.
+///
+/// Return `()` if `actual == expected`, `Err` otherwise.
+pub fn verify_expected_message_type(expected: TMessageType, actual: TMessageType) -> ::Result<()> {
+ if expected == actual {
+ Ok(())
+ } else {
+ Err(::Error::Application(::ApplicationError {
+ kind: ::ApplicationErrorKind::InvalidMessageType,
+ message: format!("expected {} got {}", expected, actual),
+ }))
+ }
+}
+
+/// Check if a required Thrift struct field exists.
+///
+/// Return `()` if it does, `Err` otherwise.
+pub fn verify_required_field_exists<T>(field_name: &str, field: &Option<T>) -> ::Result<()> {
+ match *field {
+ Some(_) => Ok(()),
+ None => {
+ Err(::Error::Protocol(::ProtocolError {
+ kind: ::ProtocolErrorKind::Unknown,
+ message: format!("missing required field {}", field_name),
+ }))
+ }
+ }
+}
+
+/// Extract the field id from a Thrift field identifier.
+///
+/// `field_ident` must *not* have `TFieldIdentifier.field_type` of type `TType::Stop`.
+///
+/// Return `TFieldIdentifier.id` if an id exists, `Err` otherwise.
+pub fn field_id(field_ident: &TFieldIdentifier) -> ::Result<i16> {
+ field_ident.id.ok_or_else(|| {
+ ::Error::Protocol(::ProtocolError {
+ kind: ::ProtocolErrorKind::Unknown,
+ message: format!("missing field in in {:?}", field_ident),
+ })
+ })
+}
diff --git a/lib/rs/src/protocol/multiplexed.rs b/lib/rs/src/protocol/multiplexed.rs
new file mode 100644
index 0000000..15fe608
--- /dev/null
+++ b/lib/rs/src/protocol/multiplexed.rs
@@ -0,0 +1,219 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use super::{TFieldIdentifier, TListIdentifier, TMapIdentifier, TMessageIdentifier, TMessageType,
+ TOutputProtocol, TSetIdentifier, TStructIdentifier};
+
+/// `TOutputProtocol` that prefixes the service name to all outgoing Thrift
+/// messages.
+///
+/// A `TMultiplexedOutputProtocol` should be used when multiple Thrift services
+/// send messages over a single I/O channel. By prefixing service identifiers
+/// to outgoing messages receivers are able to demux them and route them to the
+/// appropriate service processor. Rust receivers must use a `TMultiplexedProcessor`
+/// to process incoming messages, while other languages must use their
+/// corresponding multiplexed processor implementations.
+///
+/// For example, given a service `TestService` and a service call `test_call`,
+/// this implementation would identify messages as originating from
+/// `TestService:test_call`.
+///
+/// # Examples
+///
+/// Create and use a `TMultiplexedOutputProtocol`.
+///
+/// ```no_run
+/// use std::cell::RefCell;
+/// use std::rc::Rc;
+/// use thrift::protocol::{TMessageIdentifier, TMessageType, TOutputProtocol};
+/// use thrift::protocol::{TBinaryOutputProtocol, TMultiplexedOutputProtocol};
+/// use thrift::transport::{TTcpTransport, TTransport};
+///
+/// let mut transport = TTcpTransport::new();
+/// transport.open("localhost:9090").unwrap();
+/// let transport = Rc::new(RefCell::new(Box::new(transport) as Box<TTransport>));
+///
+/// let o_prot = TBinaryOutputProtocol::new(transport, true);
+/// let mut o_prot = TMultiplexedOutputProtocol::new("service_name", Box::new(o_prot));
+///
+/// let ident = TMessageIdentifier::new("svc_call", TMessageType::Call, 1);
+/// o_prot.write_message_begin(&ident).unwrap();
+/// ```
+pub struct TMultiplexedOutputProtocol {
+ service_name: String,
+ inner: Box<TOutputProtocol>,
+}
+
+impl TMultiplexedOutputProtocol {
+ /// Create a `TMultiplexedOutputProtocol` that identifies outgoing messages
+ /// as originating from a service named `service_name` and sends them over
+ /// the `wrapped` `TOutputProtocol`. Outgoing messages are encoded and sent
+ /// by `wrapped`, not by this instance.
+ pub fn new(service_name: &str, wrapped: Box<TOutputProtocol>) -> TMultiplexedOutputProtocol {
+ TMultiplexedOutputProtocol {
+ service_name: service_name.to_owned(),
+ inner: wrapped,
+ }
+ }
+}
+
+// FIXME: avoid passthrough methods
+impl TOutputProtocol for TMultiplexedOutputProtocol {
+ fn write_message_begin(&mut self, identifier: &TMessageIdentifier) -> ::Result<()> {
+ match identifier.message_type { // FIXME: is there a better way to override identifier here?
+ TMessageType::Call | TMessageType::OneWay => {
+ let identifier = TMessageIdentifier {
+ name: format!("{}:{}", self.service_name, identifier.name),
+ ..*identifier
+ };
+ self.inner.write_message_begin(&identifier)
+ }
+ _ => self.inner.write_message_begin(identifier),
+ }
+ }
+
+ fn write_message_end(&mut self) -> ::Result<()> {
+ self.inner.write_message_end()
+ }
+
+ fn write_struct_begin(&mut self, identifier: &TStructIdentifier) -> ::Result<()> {
+ self.inner.write_struct_begin(identifier)
+ }
+
+ fn write_struct_end(&mut self) -> ::Result<()> {
+ self.inner.write_struct_end()
+ }
+
+ fn write_field_begin(&mut self, identifier: &TFieldIdentifier) -> ::Result<()> {
+ self.inner.write_field_begin(identifier)
+ }
+
+ fn write_field_end(&mut self) -> ::Result<()> {
+ self.inner.write_field_end()
+ }
+
+ fn write_field_stop(&mut self) -> ::Result<()> {
+ self.inner.write_field_stop()
+ }
+
+ fn write_bytes(&mut self, b: &[u8]) -> ::Result<()> {
+ self.inner.write_bytes(b)
+ }
+
+ fn write_bool(&mut self, b: bool) -> ::Result<()> {
+ self.inner.write_bool(b)
+ }
+
+ fn write_i8(&mut self, i: i8) -> ::Result<()> {
+ self.inner.write_i8(i)
+ }
+
+ fn write_i16(&mut self, i: i16) -> ::Result<()> {
+ self.inner.write_i16(i)
+ }
+
+ fn write_i32(&mut self, i: i32) -> ::Result<()> {
+ self.inner.write_i32(i)
+ }
+
+ fn write_i64(&mut self, i: i64) -> ::Result<()> {
+ self.inner.write_i64(i)
+ }
+
+ fn write_double(&mut self, d: f64) -> ::Result<()> {
+ self.inner.write_double(d)
+ }
+
+ fn write_string(&mut self, s: &str) -> ::Result<()> {
+ self.inner.write_string(s)
+ }
+
+ fn write_list_begin(&mut self, identifier: &TListIdentifier) -> ::Result<()> {
+ self.inner.write_list_begin(identifier)
+ }
+
+ fn write_list_end(&mut self) -> ::Result<()> {
+ self.inner.write_list_end()
+ }
+
+ fn write_set_begin(&mut self, identifier: &TSetIdentifier) -> ::Result<()> {
+ self.inner.write_set_begin(identifier)
+ }
+
+ fn write_set_end(&mut self) -> ::Result<()> {
+ self.inner.write_set_end()
+ }
+
+ fn write_map_begin(&mut self, identifier: &TMapIdentifier) -> ::Result<()> {
+ self.inner.write_map_begin(identifier)
+ }
+
+ fn write_map_end(&mut self) -> ::Result<()> {
+ self.inner.write_map_end()
+ }
+
+ fn flush(&mut self) -> ::Result<()> {
+ self.inner.flush()
+ }
+
+ // utility
+ //
+
+ fn write_byte(&mut self, b: u8) -> ::Result<()> {
+ self.inner.write_byte(b)
+ }
+}
+
+#[cfg(test)]
+mod tests {
+
+ use std::cell::RefCell;
+ use std::rc::Rc;
+
+ use ::protocol::{TBinaryOutputProtocol, TMessageIdentifier, TMessageType, TOutputProtocol};
+ use ::transport::{TPassThruTransport, TTransport};
+ use ::transport::mem::TBufferTransport;
+
+ use super::*;
+
+ #[test]
+ fn must_write_message_begin_with_prefixed_service_name() {
+ let (trans, mut o_prot) = test_objects();
+
+ let ident = TMessageIdentifier::new("bar", TMessageType::Call, 2);
+ assert_success!(o_prot.write_message_begin(&ident));
+
+ let expected: [u8; 19] =
+ [0x80, 0x01 /* protocol identifier */, 0x00, 0x01 /* message type */, 0x00,
+ 0x00, 0x00, 0x07, 0x66, 0x6F, 0x6F /* "foo" */, 0x3A /* ":" */, 0x62, 0x61,
+ 0x72 /* "bar" */, 0x00, 0x00, 0x00, 0x02 /* sequence number */];
+
+ assert_eq!(&trans.borrow().write_buffer_to_vec(), &expected);
+ }
+
+ fn test_objects() -> (Rc<RefCell<Box<TBufferTransport>>>, TMultiplexedOutputProtocol) {
+ let mem = Rc::new(RefCell::new(Box::new(TBufferTransport::with_capacity(40, 40))));
+
+ let inner: Box<TTransport> = Box::new(TPassThruTransport { inner: mem.clone() });
+ let inner = Rc::new(RefCell::new(inner));
+
+ let o_prot = TBinaryOutputProtocol::new(inner.clone(), true);
+ let o_prot = TMultiplexedOutputProtocol::new("foo", Box::new(o_prot));
+
+ (mem, o_prot)
+ }
+}
diff --git a/lib/rs/src/protocol/stored.rs b/lib/rs/src/protocol/stored.rs
new file mode 100644
index 0000000..6826c00
--- /dev/null
+++ b/lib/rs/src/protocol/stored.rs
@@ -0,0 +1,191 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::convert::Into;
+
+use ::ProtocolErrorKind;
+use super::{TFieldIdentifier, TListIdentifier, TMapIdentifier, TMessageIdentifier, TInputProtocol,
+ TSetIdentifier, TStructIdentifier};
+
+/// `TInputProtocol` required to use a `TMultiplexedProcessor`.
+///
+/// A `TMultiplexedProcessor` reads incoming message identifiers to determine to
+/// which `TProcessor` requests should be forwarded. However, once read, those
+/// message identifier bytes are no longer on the wire. Since downstream
+/// processors expect to read message identifiers from the given input protocol
+/// we need some way of supplying a `TMessageIdentifier` with the service-name
+/// stripped. This implementation stores the received `TMessageIdentifier`
+/// (without the service name) and passes it to the wrapped `TInputProtocol`
+/// when `TInputProtocol::read_message_begin(...)` is called. It delegates all
+/// other calls directly to the wrapped `TInputProtocol`.
+///
+/// This type **should not** be used by application code.
+///
+/// # Examples
+///
+/// Create and use a `TStoredInputProtocol`.
+///
+/// ```no_run
+/// use std::cell::RefCell;
+/// use std::rc::Rc;
+/// use thrift;
+/// use thrift::protocol::{TInputProtocol, TMessageIdentifier, TMessageType, TOutputProtocol};
+/// use thrift::protocol::{TBinaryInputProtocol, TBinaryOutputProtocol, TStoredInputProtocol};
+/// use thrift::server::TProcessor;
+/// use thrift::transport::{TTcpTransport, TTransport};
+///
+/// // sample processor
+/// struct ActualProcessor;
+/// impl TProcessor for ActualProcessor {
+/// fn process(
+/// &mut self,
+/// _: &mut TInputProtocol,
+/// _: &mut TOutputProtocol
+/// ) -> thrift::Result<()> {
+/// unimplemented!()
+/// }
+/// }
+/// let mut processor = ActualProcessor {};
+///
+/// // construct the shared transport
+/// let mut transport = TTcpTransport::new();
+/// transport.open("localhost:9090").unwrap();
+/// let transport = Rc::new(RefCell::new(Box::new(transport) as Box<TTransport>));
+///
+/// // construct the actual input and output protocols
+/// let mut i_prot = TBinaryInputProtocol::new(transport.clone(), true);
+/// let mut o_prot = TBinaryOutputProtocol::new(transport.clone(), true);
+///
+/// // message identifier received from remote and modified to remove the service name
+/// let new_msg_ident = TMessageIdentifier::new("service_call", TMessageType::Call, 1);
+///
+/// // construct the proxy input protocol
+/// let mut proxy_i_prot = TStoredInputProtocol::new(&mut i_prot, new_msg_ident);
+/// let res = processor.process(&mut proxy_i_prot, &mut o_prot);
+/// ```
+pub struct TStoredInputProtocol<'a> {
+ inner: &'a mut TInputProtocol,
+ message_ident: Option<TMessageIdentifier>,
+}
+
+impl<'a> TStoredInputProtocol<'a> {
+ /// Create a `TStoredInputProtocol` that delegates all calls other than
+ /// `TInputProtocol::read_message_begin(...)` to a `wrapped`
+ /// `TInputProtocol`. `message_ident` is the modified message identifier -
+ /// with service name stripped - that will be passed to
+ /// `wrapped.read_message_begin(...)`.
+ pub fn new(wrapped: &mut TInputProtocol,
+ message_ident: TMessageIdentifier)
+ -> TStoredInputProtocol {
+ TStoredInputProtocol {
+ inner: wrapped,
+ message_ident: message_ident.into(),
+ }
+ }
+}
+
+impl<'a> TInputProtocol for TStoredInputProtocol<'a> {
+ fn read_message_begin(&mut self) -> ::Result<TMessageIdentifier> {
+ self.message_ident.take().ok_or_else(|| {
+ ::errors::new_protocol_error(ProtocolErrorKind::Unknown,
+ "message identifier already read")
+ })
+ }
+
+ fn read_message_end(&mut self) -> ::Result<()> {
+ self.inner.read_message_end()
+ }
+
+ fn read_struct_begin(&mut self) -> ::Result<Option<TStructIdentifier>> {
+ self.inner.read_struct_begin()
+ }
+
+ fn read_struct_end(&mut self) -> ::Result<()> {
+ self.inner.read_struct_end()
+ }
+
+ fn read_field_begin(&mut self) -> ::Result<TFieldIdentifier> {
+ self.inner.read_field_begin()
+ }
+
+ fn read_field_end(&mut self) -> ::Result<()> {
+ self.inner.read_field_end()
+ }
+
+ fn read_bytes(&mut self) -> ::Result<Vec<u8>> {
+ self.inner.read_bytes()
+ }
+
+ fn read_bool(&mut self) -> ::Result<bool> {
+ self.inner.read_bool()
+ }
+
+ fn read_i8(&mut self) -> ::Result<i8> {
+ self.inner.read_i8()
+ }
+
+ fn read_i16(&mut self) -> ::Result<i16> {
+ self.inner.read_i16()
+ }
+
+ fn read_i32(&mut self) -> ::Result<i32> {
+ self.inner.read_i32()
+ }
+
+ fn read_i64(&mut self) -> ::Result<i64> {
+ self.inner.read_i64()
+ }
+
+ fn read_double(&mut self) -> ::Result<f64> {
+ self.inner.read_double()
+ }
+
+ fn read_string(&mut self) -> ::Result<String> {
+ self.inner.read_string()
+ }
+
+ fn read_list_begin(&mut self) -> ::Result<TListIdentifier> {
+ self.inner.read_list_begin()
+ }
+
+ fn read_list_end(&mut self) -> ::Result<()> {
+ self.inner.read_list_end()
+ }
+
+ fn read_set_begin(&mut self) -> ::Result<TSetIdentifier> {
+ self.inner.read_set_begin()
+ }
+
+ fn read_set_end(&mut self) -> ::Result<()> {
+ self.inner.read_set_end()
+ }
+
+ fn read_map_begin(&mut self) -> ::Result<TMapIdentifier> {
+ self.inner.read_map_begin()
+ }
+
+ fn read_map_end(&mut self) -> ::Result<()> {
+ self.inner.read_map_end()
+ }
+
+ // utility
+ //
+
+ fn read_byte(&mut self) -> ::Result<u8> {
+ self.inner.read_byte()
+ }
+}
diff --git a/lib/rs/src/server/mod.rs b/lib/rs/src/server/mod.rs
new file mode 100644
index 0000000..ceac18a
--- /dev/null
+++ b/lib/rs/src/server/mod.rs
@@ -0,0 +1,95 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Types required to implement a Thrift server.
+
+use ::protocol::{TInputProtocol, TOutputProtocol};
+
+mod simple;
+mod multiplexed;
+
+pub use self::simple::TSimpleServer;
+pub use self::multiplexed::TMultiplexedProcessor;
+
+/// Handles incoming Thrift messages and dispatches them to the user-defined
+/// handler functions.
+///
+/// An implementation is auto-generated for each Thrift service. When used by a
+/// server (for example, a `TSimpleServer`), it will demux incoming service
+/// calls and invoke the corresponding user-defined handler function.
+///
+/// # Examples
+///
+/// Create and start a server using the auto-generated `TProcessor` for
+/// a Thrift service `SimpleService`.
+///
+/// ```no_run
+/// use thrift;
+/// use thrift::protocol::{TInputProtocol, TOutputProtocol};
+/// use thrift::server::TProcessor;
+///
+/// //
+/// // auto-generated
+/// //
+///
+/// // processor for `SimpleService`
+/// struct SimpleServiceSyncProcessor;
+/// impl SimpleServiceSyncProcessor {
+/// fn new<H: SimpleServiceSyncHandler>(processor: H) -> SimpleServiceSyncProcessor {
+/// unimplemented!();
+/// }
+/// }
+///
+/// // `TProcessor` implementation for `SimpleService`
+/// impl TProcessor for SimpleServiceSyncProcessor {
+/// fn process(&mut self, i: &mut TInputProtocol, o: &mut TOutputProtocol) -> thrift::Result<()> {
+/// unimplemented!();
+/// }
+/// }
+///
+/// // service functions for SimpleService
+/// trait SimpleServiceSyncHandler {
+/// fn service_call(&mut self) -> thrift::Result<()>;
+/// }
+///
+/// //
+/// // user-code follows
+/// //
+///
+/// // define a handler that will be invoked when `service_call` is received
+/// struct SimpleServiceHandlerImpl;
+/// impl SimpleServiceSyncHandler for SimpleServiceHandlerImpl {
+/// fn service_call(&mut self) -> thrift::Result<()> {
+/// unimplemented!();
+/// }
+/// }
+///
+/// // instantiate the processor
+/// let processor = SimpleServiceSyncProcessor::new(SimpleServiceHandlerImpl {});
+///
+/// // at this point you can pass the processor to the server
+/// // let server = TSimpleServer::new(..., processor);
+/// ```
+pub trait TProcessor {
+ /// Process a Thrift service call.
+ ///
+ /// Reads arguments from `i`, executes the user's handler code, and writes
+ /// the response to `o`.
+ ///
+ /// Returns `()` if the handler was executed; `Err` otherwise.
+ fn process(&mut self, i: &mut TInputProtocol, o: &mut TOutputProtocol) -> ::Result<()>;
+}
diff --git a/lib/rs/src/server/multiplexed.rs b/lib/rs/src/server/multiplexed.rs
new file mode 100644
index 0000000..d2314a1
--- /dev/null
+++ b/lib/rs/src/server/multiplexed.rs
@@ -0,0 +1,92 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::collections::HashMap;
+use std::convert::Into;
+
+use ::{new_application_error, ApplicationErrorKind};
+use ::protocol::{TInputProtocol, TMessageIdentifier, TOutputProtocol, TStoredInputProtocol};
+
+use super::TProcessor;
+
+/// A `TProcessor` that can demux service calls to multiple underlying
+/// Thrift services.
+///
+/// Users register service-specific `TProcessor` instances with a
+/// `TMultiplexedProcessor`, and then register that processor with a server
+/// implementation. Following that, all incoming service calls are automatically
+/// routed to the service-specific `TProcessor`.
+///
+/// A `TMultiplexedProcessor` can only handle messages sent by a
+/// `TMultiplexedOutputProtocol`.
+pub struct TMultiplexedProcessor {
+ processors: HashMap<String, Box<TProcessor>>,
+}
+
+impl TMultiplexedProcessor {
+ /// Register a service-specific `processor` for the service named
+ /// `service_name`.
+ ///
+ /// Return `true` if this is the first registration for `service_name`.
+ ///
+ /// Return `false` if a mapping previously existed (the previous mapping is
+ /// *not* overwritten).
+ #[cfg_attr(feature = "cargo-clippy", allow(map_entry))]
+ pub fn register_processor<S: Into<String>>(&mut self,
+ service_name: S,
+ processor: Box<TProcessor>)
+ -> bool {
+ let name = service_name.into();
+ if self.processors.contains_key(&name) {
+ false
+ } else {
+ self.processors.insert(name, processor);
+ true
+ }
+ }
+}
+
+impl TProcessor for TMultiplexedProcessor {
+ fn process(&mut self,
+ i_prot: &mut TInputProtocol,
+ o_prot: &mut TOutputProtocol)
+ -> ::Result<()> {
+ let msg_ident = i_prot.read_message_begin()?;
+ let sep_index = msg_ident.name
+ .find(':')
+ .ok_or_else(|| {
+ new_application_error(ApplicationErrorKind::Unknown,
+ "no service separator found in incoming message")
+ })?;
+
+ let (svc_name, svc_call) = msg_ident.name.split_at(sep_index);
+
+ match self.processors.get_mut(svc_name) {
+ Some(ref mut processor) => {
+ let new_msg_ident = TMessageIdentifier::new(svc_call,
+ msg_ident.message_type,
+ msg_ident.sequence_number);
+ let mut proxy_i_prot = TStoredInputProtocol::new(i_prot, new_msg_ident);
+ processor.process(&mut proxy_i_prot, o_prot)
+ }
+ None => {
+ Err(new_application_error(ApplicationErrorKind::Unknown,
+ format!("no processor found for service {}", svc_name)))
+ }
+ }
+ }
+}
diff --git a/lib/rs/src/server/simple.rs b/lib/rs/src/server/simple.rs
new file mode 100644
index 0000000..89ed977
--- /dev/null
+++ b/lib/rs/src/server/simple.rs
@@ -0,0 +1,189 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::cell::RefCell;
+use std::net::{TcpListener, TcpStream};
+use std::rc::Rc;
+
+use ::{ApplicationError, ApplicationErrorKind};
+use ::protocol::{TInputProtocolFactory, TOutputProtocolFactory};
+use ::transport::{TTcpTransport, TTransport, TTransportFactory};
+
+use super::TProcessor;
+
+/// Single-threaded blocking Thrift socket server.
+///
+/// A `TSimpleServer` listens on a given address and services accepted
+/// connections *synchronously* and *sequentially* - i.e. in a blocking manner,
+/// one at a time - on the main thread. Each accepted connection has an input
+/// half and an output half, each of which uses a `TTransport` and `TProtocol`
+/// to translate messages to and from byes. Any combination of `TProtocol` and
+/// `TTransport` may be used.
+///
+/// # Examples
+///
+/// Creating and running a `TSimpleServer` using Thrift-compiler-generated
+/// service code.
+///
+/// ```no_run
+/// use thrift;
+/// use thrift::protocol::{TInputProtocolFactory, TOutputProtocolFactory};
+/// use thrift::protocol::{TBinaryInputProtocolFactory, TBinaryOutputProtocolFactory};
+/// use thrift::protocol::{TInputProtocol, TOutputProtocol};
+/// use thrift::transport::{TBufferedTransportFactory, TTransportFactory};
+/// use thrift::server::{TProcessor, TSimpleServer};
+///
+/// //
+/// // auto-generated
+/// //
+///
+/// // processor for `SimpleService`
+/// struct SimpleServiceSyncProcessor;
+/// impl SimpleServiceSyncProcessor {
+/// fn new<H: SimpleServiceSyncHandler>(processor: H) -> SimpleServiceSyncProcessor {
+/// unimplemented!();
+/// }
+/// }
+///
+/// // `TProcessor` implementation for `SimpleService`
+/// impl TProcessor for SimpleServiceSyncProcessor {
+/// fn process(&mut self, i: &mut TInputProtocol, o: &mut TOutputProtocol) -> thrift::Result<()> {
+/// unimplemented!();
+/// }
+/// }
+///
+/// // service functions for SimpleService
+/// trait SimpleServiceSyncHandler {
+/// fn service_call(&mut self) -> thrift::Result<()>;
+/// }
+///
+/// //
+/// // user-code follows
+/// //
+///
+/// // define a handler that will be invoked when `service_call` is received
+/// struct SimpleServiceHandlerImpl;
+/// impl SimpleServiceSyncHandler for SimpleServiceHandlerImpl {
+/// fn service_call(&mut self) -> thrift::Result<()> {
+/// unimplemented!();
+/// }
+/// }
+///
+/// // instantiate the processor
+/// let processor = SimpleServiceSyncProcessor::new(SimpleServiceHandlerImpl {});
+///
+/// // instantiate the server
+/// let i_tr_fact: Box<TTransportFactory> = Box::new(TBufferedTransportFactory::new());
+/// let i_pr_fact: Box<TInputProtocolFactory> = Box::new(TBinaryInputProtocolFactory::new());
+/// let o_tr_fact: Box<TTransportFactory> = Box::new(TBufferedTransportFactory::new());
+/// let o_pr_fact: Box<TOutputProtocolFactory> = Box::new(TBinaryOutputProtocolFactory::new());
+///
+/// let mut server = TSimpleServer::new(
+/// i_tr_fact,
+/// i_pr_fact,
+/// o_tr_fact,
+/// o_pr_fact,
+/// processor
+/// );
+///
+/// // start listening for incoming connections
+/// match server.listen("127.0.0.1:8080") {
+/// Ok(_) => println!("listen completed"),
+/// Err(e) => println!("listen failed with error {:?}", e),
+/// }
+/// ```
+pub struct TSimpleServer<PR: TProcessor> {
+ i_trans_factory: Box<TTransportFactory>,
+ i_proto_factory: Box<TInputProtocolFactory>,
+ o_trans_factory: Box<TTransportFactory>,
+ o_proto_factory: Box<TOutputProtocolFactory>,
+ processor: PR,
+}
+
+impl<PR: TProcessor> TSimpleServer<PR> {
+ /// Create a `TSimpleServer`.
+ ///
+ /// Each accepted connection has an input and output half, each of which
+ /// requires a `TTransport` and `TProtocol`. `TSimpleServer` uses
+ /// `input_transport_factory` and `input_protocol_factory` to create
+ /// implementations for the input, and `output_transport_factory` and
+ /// `output_protocol_factory` to create implementations for the output.
+ pub fn new(input_transport_factory: Box<TTransportFactory>,
+ input_protocol_factory: Box<TInputProtocolFactory>,
+ output_transport_factory: Box<TTransportFactory>,
+ output_protocol_factory: Box<TOutputProtocolFactory>,
+ processor: PR)
+ -> TSimpleServer<PR> {
+ TSimpleServer {
+ i_trans_factory: input_transport_factory,
+ i_proto_factory: input_protocol_factory,
+ o_trans_factory: output_transport_factory,
+ o_proto_factory: output_protocol_factory,
+ processor: processor,
+ }
+ }
+
+ /// Listen for incoming connections on `listen_address`.
+ ///
+ /// `listen_address` should be in the form `host:port`,
+ /// for example: `127.0.0.1:8080`.
+ ///
+ /// Return `()` if successful.
+ ///
+ /// Return `Err` when the server cannot bind to `listen_address` or there
+ /// is an unrecoverable error.
+ pub fn listen(&mut self, listen_address: &str) -> ::Result<()> {
+ let listener = TcpListener::bind(listen_address)?;
+ for stream in listener.incoming() {
+ match stream {
+ Ok(s) => self.handle_incoming_connection(s),
+ Err(e) => warn!("failed to accept remote connection with error {:?}", e),
+ }
+ }
+
+ Err(::Error::Application(ApplicationError {
+ kind: ApplicationErrorKind::Unknown,
+ message: "aborted listen loop".into(),
+ }))
+ }
+
+ fn handle_incoming_connection(&mut self, stream: TcpStream) {
+ // create the shared tcp stream
+ let stream = TTcpTransport::with_stream(stream);
+ let stream: Box<TTransport> = Box::new(stream);
+ let stream = Rc::new(RefCell::new(stream));
+
+ // input protocol and transport
+ let i_tran = self.i_trans_factory.create(stream.clone());
+ let i_tran = Rc::new(RefCell::new(i_tran));
+ let mut i_prot = self.i_proto_factory.create(i_tran);
+
+ // output protocol and transport
+ let o_tran = self.o_trans_factory.create(stream.clone());
+ let o_tran = Rc::new(RefCell::new(o_tran));
+ let mut o_prot = self.o_proto_factory.create(o_tran);
+
+ // process loop
+ loop {
+ let r = self.processor.process(&mut *i_prot, &mut *o_prot);
+ if let Err(e) = r {
+ warn!("processor failed with error: {:?}", e);
+ break; // FIXME: close here
+ }
+ }
+ }
+}
diff --git a/lib/rs/src/transport/buffered.rs b/lib/rs/src/transport/buffered.rs
new file mode 100644
index 0000000..3f240d8
--- /dev/null
+++ b/lib/rs/src/transport/buffered.rs
@@ -0,0 +1,400 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::cell::RefCell;
+use std::cmp;
+use std::io;
+use std::io::{Read, Write};
+use std::rc::Rc;
+
+use super::{TTransport, TTransportFactory};
+
+/// Default capacity of the read buffer in bytes.
+const DEFAULT_RBUFFER_CAPACITY: usize = 4096;
+
+/// Default capacity of the write buffer in bytes..
+const DEFAULT_WBUFFER_CAPACITY: usize = 4096;
+
+/// Transport that communicates with endpoints using a byte stream.
+///
+/// A `TBufferedTransport` maintains a fixed-size internal write buffer. All
+/// writes are made to this buffer and are sent to the wrapped transport only
+/// when `TTransport::flush()` is called. On a flush a fixed-length header with a
+/// count of the buffered bytes is written, followed by the bytes themselves.
+///
+/// A `TBufferedTransport` also maintains a fixed-size internal read buffer.
+/// On a call to `TTransport::read(...)` one full message - both fixed-length
+/// header and bytes - is read from the wrapped transport and buffered.
+/// Subsequent read calls are serviced from the internal buffer until it is
+/// exhausted, at which point the next full message is read from the wrapped
+/// transport.
+///
+/// # Examples
+///
+/// Create and use a `TBufferedTransport`.
+///
+/// ```no_run
+/// use std::cell::RefCell;
+/// use std::rc::Rc;
+/// use std::io::{Read, Write};
+/// use thrift::transport::{TBufferedTransport, TTcpTransport, TTransport};
+///
+/// let mut t = TTcpTransport::new();
+/// t.open("localhost:9090").unwrap();
+///
+/// let t = Rc::new(RefCell::new(Box::new(t) as Box<TTransport>));
+/// let mut t = TBufferedTransport::new(t);
+///
+/// // read
+/// t.read(&mut vec![0u8; 1]).unwrap();
+///
+/// // write
+/// t.write(&[0x00]).unwrap();
+/// t.flush().unwrap();
+/// ```
+pub struct TBufferedTransport {
+ rbuf: Box<[u8]>,
+ rpos: usize,
+ rcap: usize,
+ wbuf: Vec<u8>,
+ inner: Rc<RefCell<Box<TTransport>>>,
+}
+
+impl TBufferedTransport {
+ /// Create a `TBufferedTransport` with default-sized internal read and
+ /// write buffers that wraps an `inner` `TTransport`.
+ pub fn new(inner: Rc<RefCell<Box<TTransport>>>) -> TBufferedTransport {
+ TBufferedTransport::with_capacity(DEFAULT_RBUFFER_CAPACITY, DEFAULT_WBUFFER_CAPACITY, inner)
+ }
+
+ /// Create a `TBufferedTransport` with an internal read buffer of size
+ /// `read_buffer_capacity` and an internal write buffer of size
+ /// `write_buffer_capacity` that wraps an `inner` `TTransport`.
+ pub fn with_capacity(read_buffer_capacity: usize,
+ write_buffer_capacity: usize,
+ inner: Rc<RefCell<Box<TTransport>>>)
+ -> TBufferedTransport {
+ TBufferedTransport {
+ rbuf: vec![0; read_buffer_capacity].into_boxed_slice(),
+ rpos: 0,
+ rcap: 0,
+ wbuf: Vec::with_capacity(write_buffer_capacity),
+ inner: inner,
+ }
+ }
+
+ fn get_bytes(&mut self) -> io::Result<&[u8]> {
+ if self.rcap - self.rpos == 0 {
+ self.rpos = 0;
+ self.rcap = self.inner.borrow_mut().read(&mut self.rbuf)?;
+ }
+
+ Ok(&self.rbuf[self.rpos..self.rcap])
+ }
+
+ fn consume(&mut self, consumed: usize) {
+ // TODO: was a bug here += <-- test somehow
+ self.rpos = cmp::min(self.rcap, self.rpos + consumed);
+ }
+}
+
+impl Read for TBufferedTransport {
+ fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
+ let mut bytes_read = 0;
+
+ loop {
+ let nread = {
+ let avail_bytes = self.get_bytes()?;
+ let avail_space = buf.len() - bytes_read;
+ let nread = cmp::min(avail_space, avail_bytes.len());
+ buf[bytes_read..(bytes_read + nread)].copy_from_slice(&avail_bytes[..nread]);
+ nread
+ };
+
+ self.consume(nread);
+ bytes_read += nread;
+
+ if bytes_read == buf.len() || nread == 0 {
+ break;
+ }
+ }
+
+ Ok(bytes_read)
+ }
+}
+
+impl Write for TBufferedTransport {
+ fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
+ let avail_bytes = cmp::min(buf.len(), self.wbuf.capacity() - self.wbuf.len());
+ self.wbuf.extend_from_slice(&buf[..avail_bytes]);
+ assert!(self.wbuf.len() <= self.wbuf.capacity(),
+ "copy overflowed buffer");
+ Ok(avail_bytes)
+ }
+
+ fn flush(&mut self) -> io::Result<()> {
+ self.inner.borrow_mut().write_all(&self.wbuf)?;
+ self.inner.borrow_mut().flush()?;
+ self.wbuf.clear();
+ Ok(())
+ }
+}
+
+/// Factory for creating instances of `TBufferedTransport`
+#[derive(Default)]
+pub struct TBufferedTransportFactory;
+
+impl TBufferedTransportFactory {
+ /// Create a `TBufferedTransportFactory`.
+ pub fn new() -> TBufferedTransportFactory {
+ TBufferedTransportFactory {}
+ }
+}
+
+impl TTransportFactory for TBufferedTransportFactory {
+ fn create(&self, inner: Rc<RefCell<Box<TTransport>>>) -> Box<TTransport> {
+ Box::new(TBufferedTransport::new(inner)) as Box<TTransport>
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use std::cell::RefCell;
+ use std::io::{Read, Write};
+ use std::rc::Rc;
+
+ use super::*;
+ use ::transport::{TPassThruTransport, TTransport};
+ use ::transport::mem::TBufferTransport;
+
+ macro_rules! new_transports {
+ ($wbc:expr, $rbc:expr) => (
+ {
+ let mem = Rc::new(RefCell::new(Box::new(TBufferTransport::with_capacity($wbc, $rbc))));
+ let thru: Box<TTransport> = Box::new(TPassThruTransport { inner: mem.clone() });
+ let thru = Rc::new(RefCell::new(thru));
+ (mem, thru)
+ }
+ );
+ }
+
+ #[test]
+ fn must_return_zero_if_read_buffer_is_empty() {
+ let (_, thru) = new_transports!(10, 0);
+ let mut t = TBufferedTransport::with_capacity(10, 0, thru);
+
+ let mut b = vec![0; 10];
+ let read_result = t.read(&mut b);
+
+ assert_eq!(read_result.unwrap(), 0);
+ }
+
+ #[test]
+ fn must_return_zero_if_caller_reads_into_zero_capacity_buffer() {
+ let (_, thru) = new_transports!(10, 0);
+ let mut t = TBufferedTransport::with_capacity(10, 0, thru);
+
+ let read_result = t.read(&mut []);
+
+ assert_eq!(read_result.unwrap(), 0);
+ }
+
+ #[test]
+ fn must_return_zero_if_nothing_more_can_be_read() {
+ let (mem, thru) = new_transports!(4, 0);
+ let mut t = TBufferedTransport::with_capacity(4, 0, thru);
+
+ mem.borrow_mut().set_readable_bytes(&[0, 1, 2, 3]);
+
+ // read buffer is exactly the same size as bytes available
+ let mut buf = vec![0u8; 4];
+ let read_result = t.read(&mut buf);
+
+ // we've read exactly 4 bytes
+ assert_eq!(read_result.unwrap(), 4);
+ assert_eq!(&buf, &[0, 1, 2, 3]);
+
+ // try read again
+ let buf_again = vec![0u8; 4];
+ let read_result = t.read(&mut buf);
+
+ // this time, 0 bytes and we haven't changed the buffer
+ assert_eq!(read_result.unwrap(), 0);
+ assert_eq!(&buf_again, &[0, 0, 0, 0])
+ }
+
+ #[test]
+ fn must_fill_user_buffer_with_only_as_many_bytes_as_available() {
+ let (mem, thru) = new_transports!(4, 0);
+ let mut t = TBufferedTransport::with_capacity(4, 0, thru);
+
+ mem.borrow_mut().set_readable_bytes(&[0, 1, 2, 3]);
+
+ // read buffer is much larger than the bytes available
+ let mut buf = vec![0u8; 8];
+ let read_result = t.read(&mut buf);
+
+ // we've read exactly 4 bytes
+ assert_eq!(read_result.unwrap(), 4);
+ assert_eq!(&buf[..4], &[0, 1, 2, 3]);
+
+ // try read again
+ let read_result = t.read(&mut buf[4..]);
+
+ // this time, 0 bytes and we haven't changed the buffer
+ assert_eq!(read_result.unwrap(), 0);
+ assert_eq!(&buf, &[0, 1, 2, 3, 0, 0, 0, 0])
+ }
+
+ #[test]
+ fn must_read_successfully() {
+ // this test involves a few loops within the buffered transport
+ // itself where it has to drain the underlying transport in order
+ // to service a read
+
+ // we have a much smaller buffer than the
+ // underlying transport has bytes available
+ let (mem, thru) = new_transports!(10, 0);
+ let mut t = TBufferedTransport::with_capacity(2, 0, thru);
+
+ // fill the underlying transport's byte buffer
+ let mut readable_bytes = [0u8; 10];
+ for i in 0..10 {
+ readable_bytes[i] = i as u8;
+ }
+ mem.borrow_mut().set_readable_bytes(&readable_bytes);
+
+ // we ask to read into a buffer that's much larger
+ // than the one the buffered transport has; as a result
+ // it's going to have to keep asking the underlying
+ // transport for more bytes
+ let mut buf = [0u8; 8];
+ let read_result = t.read(&mut buf);
+
+ // we should have read 8 bytes
+ assert_eq!(read_result.unwrap(), 8);
+ assert_eq!(&buf, &[0, 1, 2, 3, 4, 5, 6, 7]);
+
+ // let's clear out the buffer and try read again
+ for i in 0..8 {
+ buf[i] = 0;
+ }
+ let read_result = t.read(&mut buf);
+
+ // this time we were only able to read 2 bytes
+ // (all that's remaining from the underlying transport)
+ // let's also check that the remaining bytes are untouched
+ assert_eq!(read_result.unwrap(), 2);
+ assert_eq!(&buf[0..2], &[8, 9]);
+ assert_eq!(&buf[2..], &[0, 0, 0, 0, 0, 0]);
+
+ // try read again (we should get 0)
+ // and all the existing bytes were untouched
+ let read_result = t.read(&mut buf);
+ assert_eq!(read_result.unwrap(), 0);
+ assert_eq!(&buf[0..2], &[8, 9]);
+ assert_eq!(&buf[2..], &[0, 0, 0, 0, 0, 0]);
+ }
+
+ #[test]
+ fn must_return_zero_if_nothing_can_be_written() {
+ let (_, thru) = new_transports!(0, 0);
+ let mut t = TBufferedTransport::with_capacity(0, 0, thru);
+
+ let b = vec![0; 10];
+ let r = t.write(&b);
+
+ assert_eq!(r.unwrap(), 0);
+ }
+
+ #[test]
+ fn must_return_zero_if_caller_calls_write_with_empty_buffer() {
+ let (mem, thru) = new_transports!(0, 10);
+ let mut t = TBufferedTransport::with_capacity(0, 10, thru);
+
+ let r = t.write(&[]);
+
+ assert_eq!(r.unwrap(), 0);
+ assert_eq!(mem.borrow_mut().write_buffer_as_ref(), &[]);
+ }
+
+ #[test]
+ fn must_return_zero_if_write_buffer_full() {
+ let (_, thru) = new_transports!(0, 0);
+ let mut t = TBufferedTransport::with_capacity(0, 4, thru);
+
+ let b = [0x00, 0x01, 0x02, 0x03];
+
+ // we've now filled the write buffer
+ let r = t.write(&b);
+ assert_eq!(r.unwrap(), 4);
+
+ // try write the same bytes again - nothing should be writable
+ let r = t.write(&b);
+ assert_eq!(r.unwrap(), 0);
+ }
+
+ #[test]
+ fn must_only_write_to_inner_transport_on_flush() {
+ let (mem, thru) = new_transports!(10, 10);
+ let mut t = TBufferedTransport::new(thru);
+
+ let b: [u8; 5] = [0, 1, 2, 3, 4];
+ assert_eq!(t.write(&b).unwrap(), 5);
+ assert_eq!(mem.borrow_mut().write_buffer_as_ref().len(), 0);
+
+ assert!(t.flush().is_ok());
+
+ {
+ let inner = mem.borrow_mut();
+ let underlying_buffer = inner.write_buffer_as_ref();
+ assert_eq!(b, underlying_buffer);
+ }
+ }
+
+ #[test]
+ fn must_write_successfully_after_flush() {
+ let (mem, thru) = new_transports!(0, 5);
+ let mut t = TBufferedTransport::with_capacity(0, 5, thru);
+
+ // write and flush
+ let b: [u8; 5] = [0, 1, 2, 3, 4];
+ assert_eq!(t.write(&b).unwrap(), 5);
+ assert!(t.flush().is_ok());
+
+ // check the flushed bytes
+ {
+ let inner = mem.borrow_mut();
+ let underlying_buffer = inner.write_buffer_as_ref();
+ assert_eq!(b, underlying_buffer);
+ }
+
+ // reset our underlying transport
+ mem.borrow_mut().empty_write_buffer();
+
+ // write and flush again
+ assert_eq!(t.write(&b).unwrap(), 5);
+ assert!(t.flush().is_ok());
+
+ // check the flushed bytes
+ {
+ let inner = mem.borrow_mut();
+ let underlying_buffer = inner.write_buffer_as_ref();
+ assert_eq!(b, underlying_buffer);
+ }
+ }
+}
diff --git a/lib/rs/src/transport/framed.rs b/lib/rs/src/transport/framed.rs
new file mode 100644
index 0000000..75c12f4
--- /dev/null
+++ b/lib/rs/src/transport/framed.rs
@@ -0,0 +1,187 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
+use std::cell::RefCell;
+use std::cmp;
+use std::io;
+use std::io::{ErrorKind, Read, Write};
+use std::rc::Rc;
+
+use super::{TTransport, TTransportFactory};
+
+/// Default capacity of the read buffer in bytes.
+const WRITE_BUFFER_CAPACITY: usize = 4096;
+
+/// Default capacity of the write buffer in bytes..
+const DEFAULT_WBUFFER_CAPACITY: usize = 4096;
+
+/// Transport that communicates with endpoints using framed messages.
+///
+/// A `TFramedTransport` maintains a fixed-size internal write buffer. All
+/// writes are made to this buffer and are sent to the wrapped transport only
+/// when `TTransport::flush()` is called. On a flush a fixed-length header with a
+/// count of the buffered bytes is written, followed by the bytes themselves.
+///
+/// A `TFramedTransport` also maintains a fixed-size internal read buffer.
+/// On a call to `TTransport::read(...)` one full message - both fixed-length
+/// header and bytes - is read from the wrapped transport and buffered.
+/// Subsequent read calls are serviced from the internal buffer until it is
+/// exhausted, at which point the next full message is read from the wrapped
+/// transport.
+///
+/// # Examples
+///
+/// Create and use a `TFramedTransport`.
+///
+/// ```no_run
+/// use std::cell::RefCell;
+/// use std::rc::Rc;
+/// use std::io::{Read, Write};
+/// use thrift::transport::{TFramedTransport, TTcpTransport, TTransport};
+///
+/// let mut t = TTcpTransport::new();
+/// t.open("localhost:9090").unwrap();
+///
+/// let t = Rc::new(RefCell::new(Box::new(t) as Box<TTransport>));
+/// let mut t = TFramedTransport::new(t);
+///
+/// // read
+/// t.read(&mut vec![0u8; 1]).unwrap();
+///
+/// // write
+/// t.write(&[0x00]).unwrap();
+/// t.flush().unwrap();
+/// ```
+pub struct TFramedTransport {
+ rbuf: Box<[u8]>,
+ rpos: usize,
+ rcap: usize,
+ wbuf: Box<[u8]>,
+ wpos: usize,
+ inner: Rc<RefCell<Box<TTransport>>>,
+}
+
+impl TFramedTransport {
+ /// Create a `TFramedTransport` with default-sized internal read and
+ /// write buffers that wraps an `inner` `TTransport`.
+ pub fn new(inner: Rc<RefCell<Box<TTransport>>>) -> TFramedTransport {
+ TFramedTransport::with_capacity(WRITE_BUFFER_CAPACITY, DEFAULT_WBUFFER_CAPACITY, inner)
+ }
+
+ /// Create a `TFramedTransport` with an internal read buffer of size
+ /// `read_buffer_capacity` and an internal write buffer of size
+ /// `write_buffer_capacity` that wraps an `inner` `TTransport`.
+ pub fn with_capacity(read_buffer_capacity: usize,
+ write_buffer_capacity: usize,
+ inner: Rc<RefCell<Box<TTransport>>>)
+ -> TFramedTransport {
+ TFramedTransport {
+ rbuf: vec![0; read_buffer_capacity].into_boxed_slice(),
+ rpos: 0,
+ rcap: 0,
+ wbuf: vec![0; write_buffer_capacity].into_boxed_slice(),
+ wpos: 0,
+ inner: inner,
+ }
+ }
+}
+
+impl Read for TFramedTransport {
+ fn read(&mut self, b: &mut [u8]) -> io::Result<usize> {
+ if self.rcap - self.rpos == 0 {
+ let message_size = self.inner.borrow_mut().read_i32::<BigEndian>()? as usize;
+ if message_size > self.rbuf.len() {
+ return Err(io::Error::new(ErrorKind::Other,
+ format!("bytes to be read ({}) exceeds buffer \
+ capacity ({})",
+ message_size,
+ self.rbuf.len())));
+ }
+ self.inner.borrow_mut().read_exact(&mut self.rbuf[..message_size])?;
+ self.rpos = 0;
+ self.rcap = message_size as usize;
+ }
+
+ let nread = cmp::min(b.len(), self.rcap - self.rpos);
+ b[..nread].clone_from_slice(&self.rbuf[self.rpos..self.rpos + nread]);
+ self.rpos += nread;
+
+ Ok(nread)
+ }
+}
+
+impl Write for TFramedTransport {
+ fn write(&mut self, b: &[u8]) -> io::Result<usize> {
+ if b.len() > (self.wbuf.len() - self.wpos) {
+ return Err(io::Error::new(ErrorKind::Other,
+ format!("bytes to be written ({}) exceeds buffer \
+ capacity ({})",
+ b.len(),
+ self.wbuf.len() - self.wpos)));
+ }
+
+ let nwrite = b.len(); // always less than available write buffer capacity
+ self.wbuf[self.wpos..(self.wpos + nwrite)].clone_from_slice(b);
+ self.wpos += nwrite;
+ Ok(nwrite)
+ }
+
+ fn flush(&mut self) -> io::Result<()> {
+ let message_size = self.wpos;
+
+ if let 0 = message_size {
+ return Ok(());
+ } else {
+ self.inner.borrow_mut().write_i32::<BigEndian>(message_size as i32)?;
+ }
+
+ let mut byte_index = 0;
+ while byte_index < self.wpos {
+ let nwrite = self.inner.borrow_mut().write(&self.wbuf[byte_index..self.wpos])?;
+ byte_index = cmp::min(byte_index + nwrite, self.wpos);
+ }
+
+ self.wpos = 0;
+ self.inner.borrow_mut().flush()
+ }
+}
+
+/// Factory for creating instances of `TFramedTransport`.
+#[derive(Default)]
+pub struct TFramedTransportFactory;
+
+impl TFramedTransportFactory {
+ // Create a `TFramedTransportFactory`.
+ pub fn new() -> TFramedTransportFactory {
+ TFramedTransportFactory {}
+ }
+}
+
+impl TTransportFactory for TFramedTransportFactory {
+ fn create(&self, inner: Rc<RefCell<Box<TTransport>>>) -> Box<TTransport> {
+ Box::new(TFramedTransport::new(inner)) as Box<TTransport>
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ // use std::io::{Read, Write};
+ //
+ // use super::*;
+ // use ::transport::mem::TBufferTransport;
+}
diff --git a/lib/rs/src/transport/mem.rs b/lib/rs/src/transport/mem.rs
new file mode 100644
index 0000000..8ec2a98
--- /dev/null
+++ b/lib/rs/src/transport/mem.rs
@@ -0,0 +1,342 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::cmp;
+use std::io;
+
+/// Simple transport that contains both a fixed-length internal read buffer and
+/// a fixed-length internal write buffer.
+///
+/// On a `write` bytes are written to the internal write buffer. Writes are no
+/// longer accepted once this buffer is full. Callers must `empty_write_buffer()`
+/// before subsequent writes are accepted.
+///
+/// You can set readable bytes in the internal read buffer by filling it with
+/// `set_readable_bytes(...)`. Callers can then read until the buffer is
+/// depleted. No further reads are accepted until the internal read buffer is
+/// replenished again.
+pub struct TBufferTransport {
+ rbuf: Box<[u8]>,
+ rpos: usize,
+ ridx: usize,
+ rcap: usize,
+ wbuf: Box<[u8]>,
+ wpos: usize,
+ wcap: usize,
+}
+
+impl TBufferTransport {
+ /// Constructs a new, empty `TBufferTransport` with the given
+ /// read buffer capacity and write buffer capacity.
+ pub fn with_capacity(read_buffer_capacity: usize,
+ write_buffer_capacity: usize)
+ -> TBufferTransport {
+ TBufferTransport {
+ rbuf: vec![0; read_buffer_capacity].into_boxed_slice(),
+ ridx: 0,
+ rpos: 0,
+ rcap: read_buffer_capacity,
+ wbuf: vec![0; write_buffer_capacity].into_boxed_slice(),
+ wpos: 0,
+ wcap: write_buffer_capacity,
+ }
+ }
+
+ /// Return a slice containing the bytes held by the internal read buffer.
+ /// Returns an empty slice if no readable bytes are present.
+ pub fn read_buffer(&self) -> &[u8] {
+ &self.rbuf[..self.ridx]
+ }
+
+ // FIXME: do I really need this API call?
+ // FIXME: should this simply reset to the last set of readable bytes?
+ /// Reset the number of readable bytes to zero.
+ ///
+ /// Subsequent calls to `read` will return nothing.
+ pub fn empty_read_buffer(&mut self) {
+ self.rpos = 0;
+ self.ridx = 0;
+ }
+
+ /// Copy bytes from the source buffer `buf` into the internal read buffer,
+ /// overwriting any existing bytes. Returns the number of bytes copied,
+ /// which is `min(buf.len(), internal_read_buf.len())`.
+ pub fn set_readable_bytes(&mut self, buf: &[u8]) -> usize {
+ self.empty_read_buffer();
+ let max_bytes = cmp::min(self.rcap, buf.len());
+ self.rbuf[..max_bytes].clone_from_slice(&buf[..max_bytes]);
+ self.ridx = max_bytes;
+ max_bytes
+ }
+
+ /// Return a slice containing the bytes held by the internal write buffer.
+ /// Returns an empty slice if no bytes were written.
+ pub fn write_buffer_as_ref(&self) -> &[u8] {
+ &self.wbuf[..self.wpos]
+ }
+
+ /// Return a vector with a copy of the bytes held by the internal write buffer.
+ /// Returns an empty vector if no bytes were written.
+ pub fn write_buffer_to_vec(&self) -> Vec<u8> {
+ let mut buf = vec![0u8; self.wpos];
+ buf.copy_from_slice(&self.wbuf[..self.wpos]);
+ buf
+ }
+
+ /// Resets the internal write buffer, making it seem like no bytes were
+ /// written. Calling `write_buffer` after this returns an empty slice.
+ pub fn empty_write_buffer(&mut self) {
+ self.wpos = 0;
+ }
+
+ /// Overwrites the contents of the read buffer with the contents of the
+ /// write buffer. The write buffer is emptied after this operation.
+ pub fn copy_write_buffer_to_read_buffer(&mut self) {
+ let buf = {
+ let b = self.write_buffer_as_ref();
+ let mut b_ret = vec![0; b.len()];
+ b_ret.copy_from_slice(&b);
+ b_ret
+ };
+
+ let bytes_copied = self.set_readable_bytes(&buf);
+ assert_eq!(bytes_copied, buf.len());
+
+ self.empty_write_buffer();
+ }
+}
+
+impl io::Read for TBufferTransport {
+ fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
+ let nread = cmp::min(buf.len(), self.ridx - self.rpos);
+ buf[..nread].clone_from_slice(&self.rbuf[self.rpos..self.rpos + nread]);
+ self.rpos += nread;
+ Ok(nread)
+ }
+}
+
+impl io::Write for TBufferTransport {
+ fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
+ let nwrite = cmp::min(buf.len(), self.wcap - self.wpos);
+ self.wbuf[self.wpos..self.wpos + nwrite].clone_from_slice(&buf[..nwrite]);
+ self.wpos += nwrite;
+ Ok(nwrite)
+ }
+
+ fn flush(&mut self) -> io::Result<()> {
+ Ok(()) // nothing to do on flush
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use std::io::{Read, Write};
+
+ use super::TBufferTransport;
+
+ #[test]
+ fn must_empty_write_buffer() {
+ let mut t = TBufferTransport::with_capacity(0, 1);
+
+ let bytes_to_write: [u8; 1] = [0x01];
+ let result = t.write(&bytes_to_write);
+ assert_eq!(result.unwrap(), 1);
+ assert_eq!(&t.write_buffer_as_ref(), &bytes_to_write);
+
+ t.empty_write_buffer();
+ assert_eq!(t.write_buffer_as_ref().len(), 0);
+ }
+
+ #[test]
+ fn must_accept_writes_after_buffer_emptied() {
+ let mut t = TBufferTransport::with_capacity(0, 2);
+
+ let bytes_to_write: [u8; 2] = [0x01, 0x02];
+
+ // first write (all bytes written)
+ let result = t.write(&bytes_to_write);
+ assert_eq!(result.unwrap(), 2);
+ assert_eq!(&t.write_buffer_as_ref(), &bytes_to_write);
+
+ // try write again (nothing should be written)
+ let result = t.write(&bytes_to_write);
+ assert_eq!(result.unwrap(), 0);
+ assert_eq!(&t.write_buffer_as_ref(), &bytes_to_write); // still the same as before
+
+ // now reset the buffer
+ t.empty_write_buffer();
+ assert_eq!(t.write_buffer_as_ref().len(), 0);
+
+ // now try write again - the write should succeed
+ let result = t.write(&bytes_to_write);
+ assert_eq!(result.unwrap(), 2);
+ assert_eq!(&t.write_buffer_as_ref(), &bytes_to_write);
+ }
+
+ #[test]
+ fn must_accept_multiple_writes_until_buffer_is_full() {
+ let mut t = TBufferTransport::with_capacity(0, 10);
+
+ // first write (all bytes written)
+ let bytes_to_write_0: [u8; 2] = [0x01, 0x41];
+ let write_0_result = t.write(&bytes_to_write_0);
+ assert_eq!(write_0_result.unwrap(), 2);
+ assert_eq!(t.write_buffer_as_ref(), &bytes_to_write_0);
+
+ // second write (all bytes written, starting at index 2)
+ let bytes_to_write_1: [u8; 7] = [0x24, 0x41, 0x32, 0x33, 0x11, 0x98, 0xAF];
+ let write_1_result = t.write(&bytes_to_write_1);
+ assert_eq!(write_1_result.unwrap(), 7);
+ assert_eq!(&t.write_buffer_as_ref()[2..], &bytes_to_write_1);
+
+ // third write (only 1 byte written - that's all we have space for)
+ let bytes_to_write_2: [u8; 3] = [0xBF, 0xDA, 0x98];
+ let write_2_result = t.write(&bytes_to_write_2);
+ assert_eq!(write_2_result.unwrap(), 1);
+ assert_eq!(&t.write_buffer_as_ref()[9..], &bytes_to_write_2[0..1]); // how does this syntax work?!
+
+ // fourth write (no writes are accepted)
+ let bytes_to_write_3: [u8; 3] = [0xBF, 0xAA, 0xFD];
+ let write_3_result = t.write(&bytes_to_write_3);
+ assert_eq!(write_3_result.unwrap(), 0);
+
+ // check the full write buffer
+ let mut expected: Vec<u8> = Vec::with_capacity(10);
+ expected.extend_from_slice(&bytes_to_write_0);
+ expected.extend_from_slice(&bytes_to_write_1);
+ expected.extend_from_slice(&bytes_to_write_2[0..1]);
+ assert_eq!(t.write_buffer_as_ref(), &expected[..]);
+ }
+
+ #[test]
+ fn must_empty_read_buffer() {
+ let mut t = TBufferTransport::with_capacity(1, 0);
+
+ let bytes_to_read: [u8; 1] = [0x01];
+ let result = t.set_readable_bytes(&bytes_to_read);
+ assert_eq!(result, 1);
+ assert_eq!(&t.read_buffer(), &bytes_to_read);
+
+ t.empty_read_buffer();
+ assert_eq!(t.read_buffer().len(), 0);
+ }
+
+ #[test]
+ fn must_allow_readable_bytes_to_be_set_after_read_buffer_emptied() {
+ let mut t = TBufferTransport::with_capacity(1, 0);
+
+ let bytes_to_read_0: [u8; 1] = [0x01];
+ let result = t.set_readable_bytes(&bytes_to_read_0);
+ assert_eq!(result, 1);
+ assert_eq!(&t.read_buffer(), &bytes_to_read_0);
+
+ t.empty_read_buffer();
+ assert_eq!(t.read_buffer().len(), 0);
+
+ let bytes_to_read_1: [u8; 1] = [0x02];
+ let result = t.set_readable_bytes(&bytes_to_read_1);
+ assert_eq!(result, 1);
+ assert_eq!(&t.read_buffer(), &bytes_to_read_1);
+ }
+
+ #[test]
+ fn must_accept_multiple_reads_until_all_bytes_read() {
+ let mut t = TBufferTransport::with_capacity(10, 0);
+
+ let readable_bytes: [u8; 10] = [0xFF, 0xEE, 0xDD, 0xCC, 0xBB, 0x00, 0x1A, 0x2B, 0x3C, 0x4D];
+
+ // check that we're able to set the bytes to be read
+ let result = t.set_readable_bytes(&readable_bytes);
+ assert_eq!(result, 10);
+ assert_eq!(&t.read_buffer(), &readable_bytes);
+
+ // first read
+ let mut read_buf_0 = vec![0; 5];
+ let read_result = t.read(&mut read_buf_0);
+ assert_eq!(read_result.unwrap(), 5);
+ assert_eq!(read_buf_0.as_slice(), &(readable_bytes[0..5]));
+
+ // second read
+ let mut read_buf_1 = vec![0; 4];
+ let read_result = t.read(&mut read_buf_1);
+ assert_eq!(read_result.unwrap(), 4);
+ assert_eq!(read_buf_1.as_slice(), &(readable_bytes[5..9]));
+
+ // third read (only 1 byte remains to be read)
+ let mut read_buf_2 = vec![0; 3];
+ let read_result = t.read(&mut read_buf_2);
+ assert_eq!(read_result.unwrap(), 1);
+ read_buf_2.truncate(1); // FIXME: does the caller have to do this?
+ assert_eq!(read_buf_2.as_slice(), &(readable_bytes[9..]));
+
+ // fourth read (nothing should be readable)
+ let mut read_buf_3 = vec![0; 10];
+ let read_result = t.read(&mut read_buf_3);
+ assert_eq!(read_result.unwrap(), 0);
+ read_buf_3.truncate(0);
+
+ // check that all the bytes we received match the original (again!)
+ let mut bytes_read = Vec::with_capacity(10);
+ bytes_read.extend_from_slice(&read_buf_0);
+ bytes_read.extend_from_slice(&read_buf_1);
+ bytes_read.extend_from_slice(&read_buf_2);
+ bytes_read.extend_from_slice(&read_buf_3);
+ assert_eq!(&bytes_read, &readable_bytes);
+ }
+
+ #[test]
+ fn must_allow_reads_to_succeed_after_read_buffer_replenished() {
+ let mut t = TBufferTransport::with_capacity(3, 0);
+
+ let readable_bytes_0: [u8; 3] = [0x02, 0xAB, 0x33];
+
+ // check that we're able to set the bytes to be read
+ let result = t.set_readable_bytes(&readable_bytes_0);
+ assert_eq!(result, 3);
+ assert_eq!(&t.read_buffer(), &readable_bytes_0);
+
+ let mut read_buf = vec![0; 4];
+
+ // drain the read buffer
+ let read_result = t.read(&mut read_buf);
+ assert_eq!(read_result.unwrap(), 3);
+ assert_eq!(t.read_buffer(), &read_buf[0..3]);
+
+ // check that a subsequent read fails
+ let read_result = t.read(&mut read_buf);
+ assert_eq!(read_result.unwrap(), 0);
+
+ // we don't modify the read buffer on failure
+ let mut expected_bytes = Vec::with_capacity(4);
+ expected_bytes.extend_from_slice(&readable_bytes_0);
+ expected_bytes.push(0x00);
+ assert_eq!(&read_buf, &expected_bytes);
+
+ // replenish the read buffer again
+ let readable_bytes_1: [u8; 2] = [0x91, 0xAA];
+
+ // check that we're able to set the bytes to be read
+ let result = t.set_readable_bytes(&readable_bytes_1);
+ assert_eq!(result, 2);
+ assert_eq!(&t.read_buffer(), &readable_bytes_1);
+
+ // read again
+ let read_result = t.read(&mut read_buf);
+ assert_eq!(read_result.unwrap(), 2);
+ assert_eq!(t.read_buffer(), &read_buf[0..2]);
+ }
+}
diff --git a/lib/rs/src/transport/mod.rs b/lib/rs/src/transport/mod.rs
new file mode 100644
index 0000000..bbabd66
--- /dev/null
+++ b/lib/rs/src/transport/mod.rs
@@ -0,0 +1,51 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Types required to send and receive bytes over an I/O channel.
+//!
+//! The core type is the `TTransport` trait, through which a `TProtocol` can
+//! send and receive primitives over the wire. While `TProtocol` instances deal
+//! with primitive types, `TTransport` instances understand only bytes.
+
+use std::cell::RefCell;
+use std::io;
+use std::rc::Rc;
+
+mod buffered;
+mod framed;
+mod passthru;
+mod socket;
+
+#[cfg(test)]
+pub mod mem;
+
+pub use self::buffered::{TBufferedTransport, TBufferedTransportFactory};
+pub use self::framed::{TFramedTransport, TFramedTransportFactory};
+pub use self::passthru::TPassThruTransport;
+pub use self::socket::TTcpTransport;
+
+/// Identifies an I/O channel that can be used to send and receive bytes.
+pub trait TTransport: io::Read + io::Write {}
+impl<I: io::Read + io::Write> TTransport for I {}
+
+/// Helper type used by servers to create `TTransport` instances for accepted
+/// client connections.
+pub trait TTransportFactory {
+ /// Create a `TTransport` that wraps an `inner` transport, thus creating
+ /// a transport stack.
+ fn create(&self, inner: Rc<RefCell<Box<TTransport>>>) -> Box<TTransport>;
+}
diff --git a/lib/rs/src/transport/passthru.rs b/lib/rs/src/transport/passthru.rs
new file mode 100644
index 0000000..60dc3a6
--- /dev/null
+++ b/lib/rs/src/transport/passthru.rs
@@ -0,0 +1,73 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::cell::RefCell;
+use std::rc::Rc;
+use std::io;
+use std::io::{Read, Write};
+
+use super::TTransport;
+
+/// Proxy that wraps an inner `TTransport` and delegates all calls to it.
+///
+/// Unlike other `TTransport` wrappers, `TPassThruTransport` is generic with
+/// regards to the wrapped transport. This allows callers to use methods
+/// specific to the type being wrapped instead of being constrained to methods
+/// on the `TTransport` trait.
+///
+/// # Examples
+///
+/// Create and use a `TPassThruTransport`.
+///
+/// ```no_run
+/// use std::cell::RefCell;
+/// use std::rc::Rc;
+/// use thrift::transport::{TPassThruTransport, TTcpTransport};
+///
+/// let t = TTcpTransport::new();
+/// let t = TPassThruTransport::new(Rc::new(RefCell::new(Box::new(t))));
+///
+/// // since the type parameter is maintained, we are able
+/// // to use functions specific to `TTcpTransport`
+/// t.inner.borrow_mut().open("localhost:9090").unwrap();
+/// ```
+pub struct TPassThruTransport<I: TTransport> {
+ pub inner: Rc<RefCell<Box<I>>>,
+}
+
+impl<I: TTransport> TPassThruTransport<I> {
+ /// Create a `TPassThruTransport` that wraps an `inner` TTransport.
+ pub fn new(inner: Rc<RefCell<Box<I>>>) -> TPassThruTransport<I> {
+ TPassThruTransport { inner: inner }
+ }
+}
+
+impl<I: TTransport> Read for TPassThruTransport<I> {
+ fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
+ self.inner.borrow_mut().read(buf)
+ }
+}
+
+impl<I: TTransport> Write for TPassThruTransport<I> {
+ fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
+ self.inner.borrow_mut().write(buf)
+ }
+
+ fn flush(&mut self) -> io::Result<()> {
+ self.inner.borrow_mut().flush()
+ }
+}
diff --git a/lib/rs/src/transport/socket.rs b/lib/rs/src/transport/socket.rs
new file mode 100644
index 0000000..9f2b8ba
--- /dev/null
+++ b/lib/rs/src/transport/socket.rs
@@ -0,0 +1,141 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::convert::From;
+use std::io;
+use std::io::{ErrorKind, Read, Write};
+use std::net::{Shutdown, TcpStream};
+use std::ops::Drop;
+
+use ::{TransportError, TransportErrorKind};
+
+/// Communicate with a Thrift service over a TCP socket.
+///
+/// # Examples
+///
+/// Create a `TTcpTransport`.
+///
+/// ```no_run
+/// use std::io::{Read, Write};
+/// use thrift::transport::TTcpTransport;
+///
+/// let mut t = TTcpTransport::new();
+/// t.open("localhost:9090").unwrap();
+///
+/// let mut buf = vec![0u8; 4];
+/// t.read(&mut buf).unwrap();
+/// t.write(&vec![0, 1, 2]).unwrap();
+/// ```
+///
+/// Create a `TTcpTransport` by wrapping an existing `TcpStream`.
+///
+/// ```no_run
+/// use std::io::{Read, Write};
+/// use std::net::TcpStream;
+/// use thrift::transport::TTcpTransport;
+///
+/// let stream = TcpStream::connect("127.0.0.1:9189").unwrap();
+/// let mut t = TTcpTransport::with_stream(stream);
+///
+/// // no need to call t.open() since we've already connected above
+///
+/// let mut buf = vec![0u8; 4];
+/// t.read(&mut buf).unwrap();
+/// t.write(&vec![0, 1, 2]).unwrap();
+/// ```
+#[derive(Default)]
+pub struct TTcpTransport {
+ stream: Option<TcpStream>,
+}
+
+impl TTcpTransport {
+ /// Create an uninitialized `TTcpTransport`.
+ ///
+ /// The returned instance must be opened using `TTcpTransport::open(...)`
+ /// before it can be used.
+ pub fn new() -> TTcpTransport {
+ TTcpTransport { stream: None }
+ }
+
+ /// Create a `TTcpTransport` that wraps an existing `TcpStream`.
+ ///
+ /// The passed-in stream is assumed to have been opened before being wrapped
+ /// by the created `TTcpTransport` instance.
+ pub fn with_stream(stream: TcpStream) -> TTcpTransport {
+ TTcpTransport { stream: Some(stream) }
+ }
+
+ /// Connect to `remote_address`, which should have the form `host:port`.
+ pub fn open(&mut self, remote_address: &str) -> ::Result<()> {
+ if self.stream.is_some() {
+ Err(::Error::Transport(TransportError::new(TransportErrorKind::AlreadyOpen,
+ "transport previously opened")))
+ } else {
+ match TcpStream::connect(&remote_address) {
+ Ok(s) => {
+ self.stream = Some(s);
+ Ok(())
+ }
+ Err(e) => Err(From::from(e)),
+ }
+ }
+ }
+
+ /// Shutdown this transport.
+ ///
+ /// Both send and receive halves are closed, and this instance can no
+ /// longer be used to communicate with another endpoint.
+ pub fn close(&mut self) -> ::Result<()> {
+ self.if_set(|s| s.shutdown(Shutdown::Both)).map_err(From::from)
+ }
+
+ fn if_set<F, T>(&mut self, mut stream_operation: F) -> io::Result<T>
+ where F: FnMut(&mut TcpStream) -> io::Result<T>
+ {
+
+ if let Some(ref mut s) = self.stream {
+ stream_operation(s)
+ } else {
+ Err(io::Error::new(ErrorKind::NotConnected, "tcp endpoint not connected"))
+ }
+ }
+}
+
+impl Read for TTcpTransport {
+ fn read(&mut self, b: &mut [u8]) -> io::Result<usize> {
+ self.if_set(|s| s.read(b))
+ }
+}
+
+impl Write for TTcpTransport {
+ fn write(&mut self, b: &[u8]) -> io::Result<usize> {
+ self.if_set(|s| s.write(b))
+ }
+
+ fn flush(&mut self) -> io::Result<()> {
+ self.if_set(|s| s.flush())
+ }
+}
+
+// Do I have to implement the Drop trait? TcpStream closes the socket on drop.
+impl Drop for TTcpTransport {
+ fn drop(&mut self) {
+ if let Err(e) = self.close() {
+ warn!("error while closing socket transport: {:?}", e)
+ }
+ }
+}
diff --git a/lib/rs/test/Cargo.toml b/lib/rs/test/Cargo.toml
new file mode 100644
index 0000000..8655a76
--- /dev/null
+++ b/lib/rs/test/Cargo.toml
@@ -0,0 +1,15 @@
+[package]
+name = "kitchen-sink"
+version = "0.1.0"
+license = "Apache-2.0"
+authors = ["Apache Thrift Developers <dev@thrift.apache.org>"]
+publish = false
+
+[dependencies]
+clap = "2.18.0"
+ordered-float = "0.3.0"
+try_from = "0.2.0"
+
+[dependencies.thrift]
+path = "../"
+
diff --git a/lib/rs/test/Makefile.am b/lib/rs/test/Makefile.am
new file mode 100644
index 0000000..8896940
--- /dev/null
+++ b/lib/rs/test/Makefile.am
@@ -0,0 +1,49 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+THRIFT = $(top_builddir)/compiler/cpp/thrift
+
+stubs: thrifts/Base_One.thrift thrifts/Base_Two.thrift thrifts/Midlayer.thrift thrifts/Ultimate.thrift $(THRIFT)
+ $(THRIFT) -I ./thrifts -out src --gen rs thrifts/Base_One.thrift
+ $(THRIFT) -I ./thrifts -out src --gen rs thrifts/Base_Two.thrift
+ $(THRIFT) -I ./thrifts -out src --gen rs thrifts/Midlayer.thrift
+ $(THRIFT) -I ./thrifts -out src --gen rs thrifts/Ultimate.thrift
+
+check: stubs
+ $(CARGO) build
+ $(CARGO) test
+ [ -d bin ] || mkdir bin
+ cp target/debug/kitchen_sink_server bin/kitchen_sink_server
+ cp target/debug/kitchen_sink_client bin/kitchen_sink_client
+
+clean-local:
+ $(CARGO) clean
+ -$(RM) Cargo.lock
+ -$(RM) src/base_one.rs
+ -$(RM) src/base_two.rs
+ -$(RM) src/midlayer.rs
+ -$(RM) src/ultimate.rs
+ -$(RM) -r bin
+
+EXTRA_DIST = \
+ Cargo.toml \
+ src/lib.rs \
+ src/bin/kitchen_sink_server.rs \
+ src/bin/kitchen_sink_client.rs
+
diff --git a/lib/rs/test/src/bin/kitchen_sink_client.rs b/lib/rs/test/src/bin/kitchen_sink_client.rs
new file mode 100644
index 0000000..27171be
--- /dev/null
+++ b/lib/rs/test/src/bin/kitchen_sink_client.rs
@@ -0,0 +1,142 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#[macro_use]
+extern crate clap;
+
+extern crate kitchen_sink;
+extern crate thrift;
+
+use std::cell::RefCell;
+use std::rc::Rc;
+
+use kitchen_sink::base_two::{TNapkinServiceSyncClient, TRamenServiceSyncClient};
+use kitchen_sink::midlayer::{MealServiceSyncClient, TMealServiceSyncClient};
+use kitchen_sink::ultimate::{FullMealServiceSyncClient, TFullMealServiceSyncClient};
+use thrift::transport::{TFramedTransport, TTcpTransport, TTransport};
+use thrift::protocol::{TBinaryInputProtocol, TBinaryOutputProtocol, TCompactInputProtocol,
+ TCompactOutputProtocol, TInputProtocol, TOutputProtocol};
+
+fn main() {
+ match run() {
+ Ok(()) => println!("kitchen sink client completed successfully"),
+ Err(e) => {
+ println!("kitchen sink client failed with error {:?}", e);
+ std::process::exit(1);
+ }
+ }
+}
+
+fn run() -> thrift::Result<()> {
+ let matches = clap_app!(rust_kitchen_sink_client =>
+ (version: "0.1.0")
+ (author: "Apache Thrift Developers <dev@thrift.apache.org>")
+ (about: "Thrift Rust kitchen sink client")
+ (@arg host: --host +takes_value "Host on which the Thrift test server is located")
+ (@arg port: --port +takes_value "Port on which the Thrift test server is listening")
+ (@arg protocol: --protocol +takes_value "Thrift protocol implementation to use (\"binary\", \"compact\")")
+ (@arg service: --service +takes_value "Service type to contact (\"part\", \"full\")")
+ ).get_matches();
+
+ let host = matches.value_of("host").unwrap_or("127.0.0.1");
+ let port = value_t!(matches, "port", u16).unwrap_or(9090);
+ let protocol = matches.value_of("protocol").unwrap_or("compact");
+ let service = matches.value_of("service").unwrap_or("part");
+
+ let t = open_tcp_transport(host, port)?;
+ let t = Rc::new(RefCell::new(Box::new(TFramedTransport::new(t)) as Box<TTransport>));
+
+ let (i_prot, o_prot): (Box<TInputProtocol>, Box<TOutputProtocol>) = match protocol {
+ "binary" => {
+ (Box::new(TBinaryInputProtocol::new(t.clone(), true)),
+ Box::new(TBinaryOutputProtocol::new(t.clone(), true)))
+ }
+ "compact" => {
+ (Box::new(TCompactInputProtocol::new(t.clone())),
+ Box::new(TCompactOutputProtocol::new(t.clone())))
+ }
+ unmatched => return Err(format!("unsupported protocol {}", unmatched).into()),
+ };
+
+ run_client(service, i_prot, o_prot)
+}
+
+fn run_client(service: &str,
+ i_prot: Box<TInputProtocol>,
+ o_prot: Box<TOutputProtocol>)
+ -> thrift::Result<()> {
+ match service {
+ "full" => run_full_meal_service(i_prot, o_prot),
+ "part" => run_meal_service(i_prot, o_prot),
+ _ => Err(thrift::Error::from(format!("unknown service type {}", service))),
+ }
+}
+
+fn open_tcp_transport(host: &str, port: u16) -> thrift::Result<Rc<RefCell<Box<TTransport>>>> {
+ let mut t = TTcpTransport::new();
+ match t.open(&format!("{}:{}", host, port)) {
+ Ok(()) => Ok(Rc::new(RefCell::new(Box::new(t) as Box<TTransport>))),
+ Err(e) => Err(e),
+ }
+}
+
+fn run_meal_service(i_prot: Box<TInputProtocol>,
+ o_prot: Box<TOutputProtocol>)
+ -> thrift::Result<()> {
+ let mut client = MealServiceSyncClient::new(i_prot, o_prot);
+
+ // client.full_meal(); // <-- IMPORTANT: if you uncomment this, compilation *should* fail
+ // this is because the MealService struct does not contain the appropriate service marker
+
+ // only the following three calls work
+ execute_call("part", "ramen", || client.ramen(50))?;
+ execute_call("part", "meal", || client.meal())?;
+ execute_call("part", "napkin", || client.napkin())?;
+
+ Ok(())
+}
+
+fn run_full_meal_service(i_prot: Box<TInputProtocol>,
+ o_prot: Box<TOutputProtocol>)
+ -> thrift::Result<()> {
+ let mut client = FullMealServiceSyncClient::new(i_prot, o_prot);
+
+ execute_call("full", "ramen", || client.ramen(100))?;
+ execute_call("full", "meal", || client.meal())?;
+ execute_call("full", "napkin", || client.napkin())?;
+ execute_call("full", "full meal", || client.full_meal())?;
+
+ Ok(())
+}
+
+fn execute_call<F, R>(service_type: &str, call_name: &str, mut f: F) -> thrift::Result<()>
+ where F: FnMut() -> thrift::Result<R>
+{
+ let res = f();
+
+ match res {
+ Ok(_) => println!("{}: completed {} call", service_type, call_name),
+ Err(ref e) => {
+ println!("{}: failed {} call with error {:?}",
+ service_type,
+ call_name,
+ e)
+ }
+ }
+
+ res.map(|_| ())
+}
diff --git a/lib/rs/test/src/bin/kitchen_sink_server.rs b/lib/rs/test/src/bin/kitchen_sink_server.rs
new file mode 100644
index 0000000..4ce4fa3
--- /dev/null
+++ b/lib/rs/test/src/bin/kitchen_sink_server.rs
@@ -0,0 +1,225 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#[macro_use]
+extern crate clap;
+
+extern crate kitchen_sink;
+extern crate thrift;
+
+use kitchen_sink::base_one::Noodle;
+use kitchen_sink::base_two::{Napkin, Ramen, NapkinServiceSyncHandler, RamenServiceSyncHandler};
+use kitchen_sink::midlayer::{Dessert, Meal, MealServiceSyncHandler, MealServiceSyncProcessor};
+use kitchen_sink::ultimate::{Drink, FullMeal, FullMealAndDrinks,
+ FullMealAndDrinksServiceSyncProcessor, FullMealServiceSyncHandler};
+use kitchen_sink::ultimate::FullMealAndDrinksServiceSyncHandler;
+use thrift::protocol::{TBinaryInputProtocolFactory, TBinaryOutputProtocolFactory,
+ TCompactInputProtocolFactory, TCompactOutputProtocolFactory,
+ TInputProtocolFactory, TOutputProtocolFactory};
+use thrift::transport::{TFramedTransportFactory, TTransportFactory};
+use thrift::server::TSimpleServer;
+
+fn main() {
+ match run() {
+ Ok(()) => println!("kitchen sink server completed successfully"),
+ Err(e) => {
+ println!("kitchen sink server failed with error {:?}", e);
+ std::process::exit(1);
+ }
+ }
+}
+
+fn run() -> thrift::Result<()> {
+
+ let matches = clap_app!(rust_kitchen_sink_server =>
+ (version: "0.1.0")
+ (author: "Apache Thrift Developers <dev@thrift.apache.org>")
+ (about: "Thrift Rust kitchen sink test server")
+ (@arg port: --port +takes_value "port on which the test server listens")
+ (@arg protocol: --protocol +takes_value "Thrift protocol implementation to use (\"binary\", \"compact\")")
+ (@arg service: --service +takes_value "Service type to contact (\"part\", \"full\")")
+ ).get_matches();
+
+ let port = value_t!(matches, "port", u16).unwrap_or(9090);
+ let protocol = matches.value_of("protocol").unwrap_or("compact");
+ let service = matches.value_of("service").unwrap_or("part");
+ let listen_address = format!("127.0.0.1:{}", port);
+
+ println!("binding to {}", listen_address);
+
+ let (i_transport_factory, o_transport_factory): (Box<TTransportFactory>,
+ Box<TTransportFactory>) =
+ (Box::new(TFramedTransportFactory {}), Box::new(TFramedTransportFactory {}));
+
+ let (i_protocol_factory, o_protocol_factory): (Box<TInputProtocolFactory>,
+ Box<TOutputProtocolFactory>) =
+ match &*protocol {
+ "binary" => {
+ (Box::new(TBinaryInputProtocolFactory::new()),
+ Box::new(TBinaryOutputProtocolFactory::new()))
+ }
+ "compact" => {
+ (Box::new(TCompactInputProtocolFactory::new()),
+ Box::new(TCompactOutputProtocolFactory::new()))
+ }
+ unknown => {
+ return Err(format!("unsupported transport type {}", unknown).into());
+ }
+ };
+
+ // FIXME: should processor be boxed as well?
+ //
+ // [sigh] I hate Rust generics implementation
+ //
+ // I would have preferred to build a server here, return it, and then do
+ // the common listen-and-handle stuff, but since the server doesn't have a
+ // common type (because each match arm instantiates a server with a
+ // different processor) this isn't possible.
+ //
+ // Since what I'm doing is uncommon I'm just going to duplicate the code
+ match &*service {
+ "part" => {
+ run_meal_server(&listen_address,
+ i_transport_factory,
+ i_protocol_factory,
+ o_transport_factory,
+ o_protocol_factory)
+ }
+ "full" => {
+ run_full_meal_server(&listen_address,
+ i_transport_factory,
+ i_protocol_factory,
+ o_transport_factory,
+ o_protocol_factory)
+ }
+ unknown => Err(format!("unsupported service type {}", unknown).into()),
+ }
+}
+
+fn run_meal_server(listen_address: &str,
+ i_transport_factory: Box<TTransportFactory>,
+ i_protocol_factory: Box<TInputProtocolFactory>,
+ o_transport_factory: Box<TTransportFactory>,
+ o_protocol_factory: Box<TOutputProtocolFactory>)
+ -> thrift::Result<()> {
+ let processor = MealServiceSyncProcessor::new(PartHandler {});
+ let mut server = TSimpleServer::new(i_transport_factory,
+ i_protocol_factory,
+ o_transport_factory,
+ o_protocol_factory,
+ processor);
+
+ server.listen(listen_address)
+}
+
+fn run_full_meal_server(listen_address: &str,
+ i_transport_factory: Box<TTransportFactory>,
+ i_protocol_factory: Box<TInputProtocolFactory>,
+ o_transport_factory: Box<TTransportFactory>,
+ o_protocol_factory: Box<TOutputProtocolFactory>)
+ -> thrift::Result<()> {
+ let processor = FullMealAndDrinksServiceSyncProcessor::new(FullHandler {});
+ let mut server = TSimpleServer::new(i_transport_factory,
+ i_protocol_factory,
+ o_transport_factory,
+ o_protocol_factory,
+ processor);
+
+ server.listen(listen_address)
+}
+
+struct PartHandler;
+
+impl MealServiceSyncHandler for PartHandler {
+ fn handle_meal(&mut self) -> thrift::Result<Meal> {
+ println!("part: handling meal call");
+ Ok(meal())
+ }
+}
+
+impl RamenServiceSyncHandler for PartHandler {
+ fn handle_ramen(&mut self, _: i32) -> thrift::Result<Ramen> {
+ println!("part: handling ramen call");
+ Ok(ramen())
+ }
+}
+
+impl NapkinServiceSyncHandler for PartHandler {
+ fn handle_napkin(&mut self) -> thrift::Result<Napkin> {
+ println!("part: handling napkin call");
+ Ok(napkin())
+ }
+}
+
+// full service
+//
+
+struct FullHandler;
+
+impl FullMealAndDrinksServiceSyncHandler for FullHandler {
+ fn handle_full_meal_and_drinks(&mut self) -> thrift::Result<FullMealAndDrinks> {
+ Ok(FullMealAndDrinks::new(full_meal(), Drink::WHISKEY))
+ }
+}
+
+impl FullMealServiceSyncHandler for FullHandler {
+ fn handle_full_meal(&mut self) -> thrift::Result<FullMeal> {
+ println!("full: handling full meal call");
+ Ok(full_meal())
+ }
+}
+
+impl MealServiceSyncHandler for FullHandler {
+ fn handle_meal(&mut self) -> thrift::Result<Meal> {
+ println!("full: handling meal call");
+ Ok(meal())
+ }
+}
+
+impl RamenServiceSyncHandler for FullHandler {
+ fn handle_ramen(&mut self, _: i32) -> thrift::Result<Ramen> {
+ println!("full: handling ramen call");
+ Ok(ramen())
+ }
+}
+
+impl NapkinServiceSyncHandler for FullHandler {
+ fn handle_napkin(&mut self) -> thrift::Result<Napkin> {
+ println!("full: handling napkin call");
+ Ok(napkin())
+ }
+}
+
+fn full_meal() -> FullMeal {
+ FullMeal::new(meal(), Dessert::Port("Graham's Tawny".to_owned()))
+}
+
+fn meal() -> Meal {
+ Meal::new(noodle(), ramen())
+}
+
+fn noodle() -> Noodle {
+ Noodle::new("spelt".to_owned(), 100)
+}
+
+fn ramen() -> Ramen {
+ Ramen::new("Mr Ramen".to_owned(), 72)
+}
+
+fn napkin() -> Napkin {
+ Napkin {}
+}
diff --git a/lib/rs/test/src/lib.rs b/lib/rs/test/src/lib.rs
new file mode 100644
index 0000000..8a7ccd0
--- /dev/null
+++ b/lib/rs/test/src/lib.rs
@@ -0,0 +1,53 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+extern crate ordered_float;
+extern crate thrift;
+extern crate try_from;
+
+pub mod base_one;
+pub mod base_two;
+pub mod midlayer;
+pub mod ultimate;
+
+#[cfg(test)]
+mod tests {
+
+ use std::default::Default;
+
+ use super::*;
+
+ #[test]
+ fn must_be_able_to_use_constructor() {
+ let _ = midlayer::Meal::new(Some(base_one::Noodle::default()), None);
+ }
+
+ #[test]
+ fn must_be_able_to_use_constructor_with_no_fields() {
+ let _ = midlayer::Meal::new(None, None);
+ }
+
+ #[test]
+ fn must_be_able_to_use_constructor_without_option_wrap() {
+ let _ = midlayer::Meal::new(base_one::Noodle::default(), None);
+ }
+
+ #[test]
+ fn must_be_able_to_use_defaults() {
+ let _ = midlayer::Meal { noodle: Some(base_one::Noodle::default()), ..Default::default() };
+ }
+}
diff --git a/lib/rs/test/thrifts/Base_One.thrift b/lib/rs/test/thrifts/Base_One.thrift
new file mode 100644
index 0000000..ceb1207
--- /dev/null
+++ b/lib/rs/test/thrifts/Base_One.thrift
@@ -0,0 +1,62 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ *
+ * Contains some contributions under the Thrift Software License.
+ * Please see doc/old-thrift-license.txt in the Thrift distribution for
+ * details.
+ */
+
+typedef i64 Temperature
+
+typedef i8 Size
+
+typedef string Location
+
+const i32 BoilingPoint = 100
+
+const list<Temperature> Temperatures = [10, 11, 22, 33]
+
+const double MealsPerDay = 2.5;
+
+struct Noodle {
+ 1: string flourType
+ 2: Temperature cookTemp
+}
+
+struct Spaghetti {
+ 1: optional list<Noodle> noodles
+}
+
+const Noodle SpeltNoodle = { "flourType": "spelt", "cookTemp": 110 }
+
+struct MeasuringSpoon {
+ 1: Size size
+}
+
+struct Recipe {
+ 1: string recipeName
+ 2: string cuisine
+ 3: i8 page
+}
+
+union CookingTools {
+ 1: set<MeasuringSpoon> measuringSpoons
+ 2: map<Size, Location> measuringCups,
+ 3: list<Recipe> recipes
+}
+
diff --git a/lib/rs/test/thrifts/Base_Two.thrift b/lib/rs/test/thrifts/Base_Two.thrift
new file mode 100644
index 0000000..b4b4ea1
--- /dev/null
+++ b/lib/rs/test/thrifts/Base_Two.thrift
@@ -0,0 +1,44 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ *
+ * Contains some contributions under the Thrift Software License.
+ * Please see doc/old-thrift-license.txt in the Thrift distribution for
+ * details.
+ */
+
+const i32 WaterWeight = 200
+
+struct Ramen {
+ 1: optional string ramenType
+ 2: required i32 noodleCount
+}
+
+struct Napkin {
+ // empty
+}
+
+service NapkinService {
+ Napkin napkin()
+}
+
+service RamenService extends NapkinService {
+ Ramen ramen(1: i32 requestedNoodleCount)
+}
+
+/* const struct CookedRamen = { "bar": 10 } */
+
diff --git a/lib/rs/test/thrifts/Midlayer.thrift b/lib/rs/test/thrifts/Midlayer.thrift
new file mode 100644
index 0000000..cf1157c
--- /dev/null
+++ b/lib/rs/test/thrifts/Midlayer.thrift
@@ -0,0 +1,62 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ *
+ * Contains some contributions under the Thrift Software License.
+ * Please see doc/old-thrift-license.txt in the Thrift distribution for
+ * details.
+ */
+
+include "Base_One.thrift"
+include "Base_Two.thrift"
+
+const i32 WaterBoilingPoint = Base_One.BoilingPoint
+
+const map<string, Base_One.Temperature> TemperatureNames = { "freezing": 0, "boiling": 100 }
+
+const map<set<i32>, map<list<string>, string>> MyConstNestedMap = {
+ [0, 1, 2, 3]: { ["foo"]: "bar" },
+ [20]: { ["nut", "ton"] : "bar" },
+ [30, 40]: { ["bouncy", "tinkly"]: "castle" }
+}
+
+const list<list<i32>> MyConstNestedList = [
+ [0, 1, 2],
+ [3, 4, 5],
+ [6, 7, 8]
+]
+
+const set<set<i32>> MyConstNestedSet = [
+ [0, 1, 2],
+ [3, 4, 5],
+ [6, 7, 8]
+]
+
+struct Meal {
+ 1: Base_One.Noodle noodle
+ 2: Base_Two.Ramen ramen
+}
+
+union Dessert {
+ 1: string port
+ 2: string iceWine
+}
+
+service MealService extends Base_Two.RamenService {
+ Meal meal()
+}
+
diff --git a/lib/rs/test/thrifts/Ultimate.thrift b/lib/rs/test/thrifts/Ultimate.thrift
new file mode 100644
index 0000000..8154d91
--- /dev/null
+++ b/lib/rs/test/thrifts/Ultimate.thrift
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ *
+ * Contains some contributions under the Thrift Software License.
+ * Please see doc/old-thrift-license.txt in the Thrift distribution for
+ * details.
+ */
+
+include "Midlayer.thrift"
+
+enum Drink {
+ WATER,
+ WHISKEY,
+ WINE,
+}
+
+struct FullMeal {
+ 1: required Midlayer.Meal meal
+ 2: required Midlayer.Dessert dessert
+}
+
+struct FullMealAndDrinks {
+ 1: required FullMeal fullMeal
+ 2: optional Drink drink
+}
+
+service FullMealService extends Midlayer.MealService {
+ FullMeal fullMeal()
+}
+
+service FullMealAndDrinksService extends FullMealService {
+ FullMealAndDrinks fullMealAndDrinks()
+}
+
diff --git a/test/Makefile.am b/test/Makefile.am
index 51da3ba..01fab4f 100755
--- a/test/Makefile.am
+++ b/test/Makefile.am
@@ -91,6 +91,11 @@
PRECROSS_TARGET += precross-lua
endif
+if WITH_RS
+SUBDIRS += rs
+PRECROSS_TARGET += precross-rs
+endif
+
#
# generate html for ThriftTest.thrift
#
@@ -117,6 +122,7 @@
py.twisted \
py.tornado \
rb \
+ rs \
threads \
AnnotationTest.thrift \
BrokenConstants.thrift \
diff --git a/test/rs/Cargo.toml b/test/rs/Cargo.toml
new file mode 100644
index 0000000..8167390
--- /dev/null
+++ b/test/rs/Cargo.toml
@@ -0,0 +1,15 @@
+[package]
+name = "thrift-test"
+version = "0.1.0"
+license = "Apache-2.0"
+authors = ["Apache Thrift Developers <dev@thrift.apache.org>"]
+publish = false
+
+[dependencies]
+clap = "2.18.0"
+ordered-float = "0.3.0"
+try_from = "0.2.0"
+
+[dependencies.thrift]
+path = "../../lib/rs"
+
diff --git a/test/rs/Makefile.am b/test/rs/Makefile.am
new file mode 100644
index 0000000..1a409b8
--- /dev/null
+++ b/test/rs/Makefile.am
@@ -0,0 +1,42 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+THRIFT = $(top_builddir)/compiler/cpp/thrift
+
+stubs: ../ThriftTest.thrift
+ $(THRIFT) -I ./thrifts -out src --gen rs ../ThriftTest.thrift
+
+precross: stubs
+ $(CARGO) build
+ [ -d bin ] || mkdir bin
+ cp target/debug/test_server bin/test_server
+ cp target/debug/test_client bin/test_client
+
+clean-local:
+ $(CARGO) clean
+ -$(RM) Cargo.lock
+ -$(RM) src/thrift_test.rs
+ -$(RM) -r bin
+
+EXTRA_DIST = \
+ Cargo.toml \
+ src/lib.rs \
+ src/bin/test_server.rs \
+ src/bin/test_client.rs
+
diff --git a/test/rs/src/bin/test_client.rs b/test/rs/src/bin/test_client.rs
new file mode 100644
index 0000000..a2ea832
--- /dev/null
+++ b/test/rs/src/bin/test_client.rs
@@ -0,0 +1,500 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#[macro_use]
+extern crate clap;
+extern crate ordered_float;
+extern crate thrift;
+extern crate thrift_test; // huh. I have to do this to use my lib
+
+use ordered_float::OrderedFloat;
+use std::cell::RefCell;
+use std::collections::{BTreeMap, BTreeSet};
+use std::fmt::Debug;
+use std::rc::Rc;
+
+use thrift::protocol::{TBinaryInputProtocol, TBinaryOutputProtocol, TCompactInputProtocol,
+ TCompactOutputProtocol, TInputProtocol, TOutputProtocol};
+use thrift::transport::{TBufferedTransport, TFramedTransport, TTcpTransport, TTransport};
+use thrift_test::*;
+
+fn main() {
+ match run() {
+ Ok(()) => println!("cross-test client succeeded"),
+ Err(e) => {
+ println!("cross-test client failed with error {:?}", e);
+ std::process::exit(1);
+ }
+ }
+}
+
+fn run() -> thrift::Result<()> {
+ // unsupported options:
+ // --domain-socket
+ // --named-pipe
+ // --anon-pipes
+ // --ssl
+ // --threads
+ let matches = clap_app!(rust_test_client =>
+ (version: "1.0")
+ (author: "Apache Thrift Developers <dev@thrift.apache.org>")
+ (about: "Rust Thrift test client")
+ (@arg host: --host +takes_value "Host on which the Thrift test server is located")
+ (@arg port: --port +takes_value "Port on which the Thrift test server is listening")
+ (@arg transport: --transport +takes_value "Thrift transport implementation to use (\"buffered\", \"framed\")")
+ (@arg protocol: --protocol +takes_value "Thrift protocol implementation to use (\"binary\", \"compact\")")
+ (@arg testloops: -n --testloops +takes_value "Number of times to run tests")
+ ).get_matches();
+
+ let host = matches.value_of("host").unwrap_or("127.0.0.1");
+ let port = value_t!(matches, "port", u16).unwrap_or(9090);
+ let testloops = value_t!(matches, "testloops", u8).unwrap_or(1);
+ let transport = matches.value_of("transport").unwrap_or("buffered");
+ let protocol = matches.value_of("protocol").unwrap_or("binary");
+
+ let t = open_tcp_transport(host, port)?;
+
+ let t: Box<TTransport> = match transport {
+ "buffered" => Box::new(TBufferedTransport::new(t)),
+ "framed" => Box::new(TFramedTransport::new(t)),
+ unmatched => return Err(format!("unsupported transport {}", unmatched).into()),
+ };
+ let t = Rc::new(RefCell::new(t));
+
+ let (i_prot, o_prot): (Box<TInputProtocol>, Box<TOutputProtocol>) = match protocol {
+ "binary" => {
+ (Box::new(TBinaryInputProtocol::new(t.clone(), true)),
+ Box::new(TBinaryOutputProtocol::new(t.clone(), true)))
+ }
+ "compact" => {
+ (Box::new(TCompactInputProtocol::new(t.clone())),
+ Box::new(TCompactOutputProtocol::new(t.clone())))
+ }
+ unmatched => return Err(format!("unsupported protocol {}", unmatched).into()),
+ };
+
+ println!("connecting to {}:{} with {}+{} stack",
+ host,
+ port,
+ protocol,
+ transport);
+
+ let mut client = ThriftTestSyncClient::new(i_prot, o_prot);
+
+ for _ in 0..testloops {
+ make_thrift_calls(&mut client)?
+ }
+
+ Ok(())
+}
+
+// FIXME: expose "open" through the client interface so I don't have to early open the transport
+fn open_tcp_transport(host: &str, port: u16) -> thrift::Result<Rc<RefCell<Box<TTransport>>>> {
+ let mut t = TTcpTransport::new();
+ match t.open(&format!("{}:{}", host, port)) {
+ Ok(()) => Ok(Rc::new(RefCell::new(Box::new(t) as Box<TTransport>))),
+ Err(e) => Err(e),
+ }
+}
+
+fn make_thrift_calls(client: &mut ThriftTestSyncClient) -> Result<(), thrift::Error> {
+ println!("testVoid");
+ client.test_void()?;
+
+ println!("testString");
+ verify_expected_result(client.test_string("thing".to_owned()), "thing".to_owned())?;
+
+ println!("testBool");
+ verify_expected_result(client.test_bool(true), true)?;
+
+ println!("testBool");
+ verify_expected_result(client.test_bool(false), false)?;
+
+ println!("testByte");
+ verify_expected_result(client.test_byte(42), 42)?;
+
+ println!("testi32");
+ verify_expected_result(client.test_i32(1159348374), 1159348374)?;
+
+ println!("testi64");
+ // try!(verify_expected_result(client.test_i64(-8651829879438294565), -8651829879438294565));
+ verify_expected_result(client.test_i64(i64::min_value()), i64::min_value())?;
+
+ println!("testDouble");
+ verify_expected_result(client.test_double(OrderedFloat::from(42.42)),
+ OrderedFloat::from(42.42))?;
+
+ println!("testTypedef");
+ {
+ let u_snd: UserId = 2348;
+ let u_cmp: UserId = 2348;
+ verify_expected_result(client.test_typedef(u_snd), u_cmp)?;
+ }
+
+ println!("testEnum");
+ {
+ verify_expected_result(client.test_enum(Numberz::TWO), Numberz::TWO)?;
+ }
+
+ println!("testBinary");
+ {
+ let b_snd = vec![0x77, 0x30, 0x30, 0x74, 0x21, 0x20, 0x52, 0x75, 0x73, 0x74];
+ let b_cmp = vec![0x77, 0x30, 0x30, 0x74, 0x21, 0x20, 0x52, 0x75, 0x73, 0x74];
+ verify_expected_result(client.test_binary(b_snd), b_cmp)?;
+ }
+
+ println!("testStruct");
+ {
+ let x_snd = Xtruct {
+ string_thing: Some("foo".to_owned()),
+ byte_thing: Some(12),
+ i32_thing: Some(219129),
+ i64_thing: Some(12938492818),
+ };
+ let x_cmp = Xtruct {
+ string_thing: Some("foo".to_owned()),
+ byte_thing: Some(12),
+ i32_thing: Some(219129),
+ i64_thing: Some(12938492818),
+ };
+ verify_expected_result(client.test_struct(x_snd), x_cmp)?;
+ }
+
+ // Xtruct again, with optional values
+ // FIXME: apparently the erlang thrift server does not like opt-in-req-out parameters that are undefined. Joy.
+ // {
+ // let x_snd = Xtruct { string_thing: Some("foo".to_owned()), byte_thing: None, i32_thing: None, i64_thing: Some(12938492818) };
+ // let x_cmp = Xtruct { string_thing: Some("foo".to_owned()), byte_thing: Some(0), i32_thing: Some(0), i64_thing: Some(12938492818) }; // the C++ server is responding correctly
+ // try!(verify_expected_result(client.test_struct(x_snd), x_cmp));
+ // }
+ //
+
+
+ println!("testNest"); // (FIXME: try Xtruct2 with optional values)
+ {
+ let x_snd = Xtruct2 {
+ byte_thing: Some(32),
+ struct_thing: Some(Xtruct {
+ string_thing: Some("foo".to_owned()),
+ byte_thing: Some(1),
+ i32_thing: Some(324382098),
+ i64_thing: Some(12938492818),
+ }),
+ i32_thing: Some(293481098),
+ };
+ let x_cmp = Xtruct2 {
+ byte_thing: Some(32),
+ struct_thing: Some(Xtruct {
+ string_thing: Some("foo".to_owned()),
+ byte_thing: Some(1),
+ i32_thing: Some(324382098),
+ i64_thing: Some(12938492818),
+ }),
+ i32_thing: Some(293481098),
+ };
+ verify_expected_result(client.test_nest(x_snd), x_cmp)?;
+ }
+
+ println!("testList");
+ {
+ let mut v_snd: Vec<i32> = Vec::new();
+ v_snd.push(29384);
+ v_snd.push(238);
+ v_snd.push(32498);
+
+ let mut v_cmp: Vec<i32> = Vec::new();
+ v_cmp.push(29384);
+ v_cmp.push(238);
+ v_cmp.push(32498);
+
+ verify_expected_result(client.test_list(v_snd), v_cmp)?;
+ }
+
+ println!("testSet");
+ {
+ let mut s_snd: BTreeSet<i32> = BTreeSet::new();
+ s_snd.insert(293481);
+ s_snd.insert(23);
+ s_snd.insert(3234);
+
+ let mut s_cmp: BTreeSet<i32> = BTreeSet::new();
+ s_cmp.insert(293481);
+ s_cmp.insert(23);
+ s_cmp.insert(3234);
+
+ verify_expected_result(client.test_set(s_snd), s_cmp)?;
+ }
+
+ println!("testMap");
+ {
+ let mut m_snd: BTreeMap<i32, i32> = BTreeMap::new();
+ m_snd.insert(2, 4);
+ m_snd.insert(4, 6);
+ m_snd.insert(8, 7);
+
+ let mut m_cmp: BTreeMap<i32, i32> = BTreeMap::new();
+ m_cmp.insert(2, 4);
+ m_cmp.insert(4, 6);
+ m_cmp.insert(8, 7);
+
+ verify_expected_result(client.test_map(m_snd), m_cmp)?;
+ }
+
+ println!("testStringMap");
+ {
+ let mut m_snd: BTreeMap<String, String> = BTreeMap::new();
+ m_snd.insert("2".to_owned(), "4_string".to_owned());
+ m_snd.insert("4".to_owned(), "6_string".to_owned());
+ m_snd.insert("8".to_owned(), "7_string".to_owned());
+
+ let mut m_rcv: BTreeMap<String, String> = BTreeMap::new();
+ m_rcv.insert("2".to_owned(), "4_string".to_owned());
+ m_rcv.insert("4".to_owned(), "6_string".to_owned());
+ m_rcv.insert("8".to_owned(), "7_string".to_owned());
+
+ verify_expected_result(client.test_string_map(m_snd), m_rcv)?;
+ }
+
+ // nested map
+ // expect : {-4 => {-4 => -4, -3 => -3, -2 => -2, -1 => -1, }, 4 => {1 => 1, 2 => 2, 3 => 3, 4 => 4, }, }
+ println!("testMapMap");
+ {
+ let mut m_cmp_nested_0: BTreeMap<i32, i32> = BTreeMap::new();
+ for i in (-4 as i32)..0 {
+ m_cmp_nested_0.insert(i, i);
+ }
+ let mut m_cmp_nested_1: BTreeMap<i32, i32> = BTreeMap::new();
+ for i in 1..5 {
+ m_cmp_nested_1.insert(i, i);
+ }
+
+ let mut m_cmp: BTreeMap<i32, BTreeMap<i32, i32>> = BTreeMap::new();
+ m_cmp.insert(-4, m_cmp_nested_0);
+ m_cmp.insert(4, m_cmp_nested_1);
+
+ verify_expected_result(client.test_map_map(42), m_cmp)?;
+ }
+
+ println!("testMulti");
+ {
+ let mut m_snd: BTreeMap<i16, String> = BTreeMap::new();
+ m_snd.insert(1298, "fizz".to_owned());
+ m_snd.insert(-148, "buzz".to_owned());
+
+ let s_cmp = Xtruct {
+ string_thing: Some("Hello2".to_owned()),
+ byte_thing: Some(1),
+ i32_thing: Some(-123948),
+ i64_thing: Some(-19234123981),
+ };
+
+ verify_expected_result(client.test_multi(1,
+ -123948,
+ -19234123981,
+ m_snd,
+ Numberz::EIGHT,
+ 81),
+ s_cmp)?;
+ }
+
+ // Insanity
+ // returns:
+ // { 1 => { 2 => argument,
+ // 3 => argument,
+ // },
+ // 2 => { 6 => <empty Insanity struct>, },
+ // }
+ {
+ let mut arg_map_usermap: BTreeMap<Numberz, i64> = BTreeMap::new();
+ arg_map_usermap.insert(Numberz::ONE, 4289);
+ arg_map_usermap.insert(Numberz::EIGHT, 19);
+
+ let mut arg_vec_xtructs: Vec<Xtruct> = Vec::new();
+ arg_vec_xtructs.push(Xtruct {
+ string_thing: Some("foo".to_owned()),
+ byte_thing: Some(8),
+ i32_thing: Some(29),
+ i64_thing: Some(92384),
+ });
+ arg_vec_xtructs.push(Xtruct {
+ string_thing: Some("bar".to_owned()),
+ byte_thing: Some(28),
+ i32_thing: Some(2),
+ i64_thing: Some(-1281),
+ });
+ arg_vec_xtructs.push(Xtruct {
+ string_thing: Some("baz".to_owned()),
+ byte_thing: Some(0),
+ i32_thing: Some(3948539),
+ i64_thing: Some(-12938492),
+ });
+
+ let mut s_cmp_nested_1: BTreeMap<Numberz, Insanity> = BTreeMap::new();
+ let insanity = Insanity {
+ user_map: Some(arg_map_usermap),
+ xtructs: Some(arg_vec_xtructs),
+ };
+ s_cmp_nested_1.insert(Numberz::TWO, insanity.clone());
+ s_cmp_nested_1.insert(Numberz::THREE, insanity.clone());
+
+ let mut s_cmp_nested_2: BTreeMap<Numberz, Insanity> = BTreeMap::new();
+ let empty_insanity = Insanity {
+ user_map: Some(BTreeMap::new()),
+ xtructs: Some(Vec::new()),
+ };
+ s_cmp_nested_2.insert(Numberz::SIX, empty_insanity);
+
+ let mut s_cmp: BTreeMap<UserId, BTreeMap<Numberz, Insanity>> = BTreeMap::new();
+ s_cmp.insert(1 as UserId, s_cmp_nested_1);
+ s_cmp.insert(2 as UserId, s_cmp_nested_2);
+
+ verify_expected_result(client.test_insanity(insanity.clone()), s_cmp)?;
+ }
+
+ println!("testException - remote throws Xception");
+ {
+ let r = client.test_exception("Xception".to_owned());
+ let x = match r {
+ Err(thrift::Error::User(ref e)) => {
+ match e.downcast_ref::<Xception>() {
+ Some(x) => Ok(x),
+ None => Err(thrift::Error::User("did not get expected Xception struct".into())),
+ }
+ }
+ _ => Err(thrift::Error::User("did not get exception".into())),
+ }?;
+
+ let x_cmp = Xception {
+ error_code: Some(1001),
+ message: Some("Xception".to_owned()),
+ };
+
+ verify_expected_result(Ok(x), &x_cmp)?;
+ }
+
+ println!("testException - remote throws TApplicationException");
+ {
+ let r = client.test_exception("TException".to_owned());
+ match r {
+ Err(thrift::Error::Application(ref e)) => {
+ println!("received an {:?}", e);
+ Ok(())
+ }
+ _ => Err(thrift::Error::User("did not get exception".into())),
+ }?;
+ }
+
+ println!("testException - remote succeeds");
+ {
+ let r = client.test_exception("foo".to_owned());
+ match r {
+ Ok(_) => Ok(()),
+ _ => Err(thrift::Error::User("received an exception".into())),
+ }?;
+ }
+
+ println!("testMultiException - remote throws Xception");
+ {
+ let r = client.test_multi_exception("Xception".to_owned(), "ignored".to_owned());
+ let x = match r {
+ Err(thrift::Error::User(ref e)) => {
+ match e.downcast_ref::<Xception>() {
+ Some(x) => Ok(x),
+ None => Err(thrift::Error::User("did not get expected Xception struct".into())),
+ }
+ }
+ _ => Err(thrift::Error::User("did not get exception".into())),
+ }?;
+
+ let x_cmp = Xception {
+ error_code: Some(1001),
+ message: Some("This is an Xception".to_owned()),
+ };
+
+ verify_expected_result(Ok(x), &x_cmp)?;
+ }
+
+ println!("testMultiException - remote throws Xception2");
+ {
+ let r = client.test_multi_exception("Xception2".to_owned(), "ignored".to_owned());
+ let x = match r {
+ Err(thrift::Error::User(ref e)) => {
+ match e.downcast_ref::<Xception2>() {
+ Some(x) => Ok(x),
+ None => Err(thrift::Error::User("did not get expected Xception struct".into())),
+ }
+ }
+ _ => Err(thrift::Error::User("did not get exception".into())),
+ }?;
+
+ let x_cmp = Xception2 {
+ error_code: Some(2002),
+ struct_thing: Some(Xtruct {
+ string_thing: Some("This is an Xception2".to_owned()),
+ byte_thing: Some(0), /* since this is an OPT_IN_REQ_OUT field the sender sets a default */
+ i32_thing: Some(0), /* since this is an OPT_IN_REQ_OUT field the sender sets a default */
+ i64_thing: Some(0), /* since this is an OPT_IN_REQ_OUT field the sender sets a default */
+ }),
+ };
+
+ verify_expected_result(Ok(x), &x_cmp)?;
+ }
+
+ println!("testMultiException - remote succeeds");
+ {
+ let r = client.test_multi_exception("haha".to_owned(), "RETURNED".to_owned());
+ let x = match r {
+ Err(e) => {
+ Err(thrift::Error::User(format!("received an unexpected exception {:?}", e).into()))
+ }
+ _ => r,
+ }?;
+
+ let x_cmp = Xtruct {
+ string_thing: Some("RETURNED".to_owned()),
+ byte_thing: Some(0), // since this is an OPT_IN_REQ_OUT field the sender sets a default
+ i32_thing: Some(0), // since this is an OPT_IN_REQ_OUT field the sender sets a default
+ i64_thing: Some(0), // since this is an OPT_IN_REQ_OUT field the sender sets a default
+ };
+
+ verify_expected_result(Ok(x), x_cmp)?;
+ }
+
+ println!("testOneWay - remote sleeps for 1 second");
+ {
+ client.test_oneway(1)?;
+ }
+
+ // final test to verify that the connection is still writable after the one-way call
+ client.test_void()
+}
+
+fn verify_expected_result<T: Debug + PartialEq + Sized>(actual: Result<T, thrift::Error>,
+ expected: T)
+ -> Result<(), thrift::Error> {
+ match actual {
+ Ok(v) => {
+ if v == expected {
+ Ok(())
+ } else {
+ Err(thrift::Error::User(format!("expected {:?} but got {:?}", &expected, &v)
+ .into()))
+ }
+ }
+ Err(e) => Err(e),
+ }
+}
diff --git a/test/rs/src/bin/test_server.rs b/test/rs/src/bin/test_server.rs
new file mode 100644
index 0000000..613cd55
--- /dev/null
+++ b/test/rs/src/bin/test_server.rs
@@ -0,0 +1,337 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#[macro_use]
+extern crate clap;
+extern crate ordered_float;
+extern crate thrift;
+extern crate thrift_test;
+
+use ordered_float::OrderedFloat;
+use std::collections::{BTreeMap, BTreeSet};
+use std::thread;
+use std::time::Duration;
+
+use thrift::protocol::{TBinaryInputProtocolFactory, TBinaryOutputProtocolFactory,
+ TCompactInputProtocolFactory, TCompactOutputProtocolFactory,
+ TInputProtocolFactory, TOutputProtocolFactory};
+use thrift::server::TSimpleServer;
+use thrift::transport::{TBufferedTransportFactory, TFramedTransportFactory, TTransportFactory};
+use thrift_test::*;
+
+fn main() {
+ match run() {
+ Ok(()) => println!("cross-test server succeeded"),
+ Err(e) => {
+ println!("cross-test server failed with error {:?}", e);
+ std::process::exit(1);
+ }
+ }
+}
+
+fn run() -> thrift::Result<()> {
+
+ // unsupported options:
+ // --domain-socket
+ // --named-pipe
+ // --ssl
+ // --workers
+ let matches = clap_app!(rust_test_client =>
+ (version: "1.0")
+ (author: "Apache Thrift Developers <dev@thrift.apache.org>")
+ (about: "Rust Thrift test server")
+ (@arg port: --port +takes_value "port on which the test server listens")
+ (@arg transport: --transport +takes_value "transport implementation to use (\"buffered\", \"framed\")")
+ (@arg protocol: --protocol +takes_value "protocol implementation to use (\"binary\", \"compact\")")
+ (@arg server_type: --server_type +takes_value "type of server instantiated (\"simple\", \"thread-pool\", \"threaded\", \"non-blocking\")")
+ ).get_matches();
+
+ let port = value_t!(matches, "port", u16).unwrap_or(9090);
+ let transport = matches.value_of("transport").unwrap_or("buffered");
+ let protocol = matches.value_of("protocol").unwrap_or("binary");
+ let server_type = matches.value_of("server_type").unwrap_or("simple");
+ let listen_address = format!("127.0.0.1:{}", port);
+
+ println!("binding to {}", listen_address);
+
+ let (i_transport_factory, o_transport_factory): (Box<TTransportFactory>,
+ Box<TTransportFactory>) = match &*transport {
+ "buffered" => {
+ (Box::new(TBufferedTransportFactory::new()), Box::new(TBufferedTransportFactory::new()))
+ }
+ "framed" => {
+ (Box::new(TFramedTransportFactory::new()), Box::new(TFramedTransportFactory::new()))
+ }
+ unknown => {
+ return Err(format!("unsupported transport type {}", unknown).into());
+ }
+ };
+
+ let (i_protocol_factory, o_protocol_factory): (Box<TInputProtocolFactory>,
+ Box<TOutputProtocolFactory>) =
+ match &*protocol {
+ "binary" => {
+ (Box::new(TBinaryInputProtocolFactory::new()),
+ Box::new(TBinaryOutputProtocolFactory::new()))
+ }
+ "compact" => {
+ (Box::new(TCompactInputProtocolFactory::new()),
+ Box::new(TCompactOutputProtocolFactory::new()))
+ }
+ unknown => {
+ return Err(format!("unsupported transport type {}", unknown).into());
+ }
+ };
+
+ let processor = ThriftTestSyncProcessor::new(ThriftTestSyncHandlerImpl {});
+
+ let mut server = match &*server_type {
+ "simple" => {
+ TSimpleServer::new(i_transport_factory,
+ i_protocol_factory,
+ o_transport_factory,
+ o_protocol_factory,
+ processor)
+ }
+ unknown => {
+ return Err(format!("unsupported server type {}", unknown).into());
+ }
+ };
+
+ server.listen(&listen_address)
+}
+
+struct ThriftTestSyncHandlerImpl;
+impl ThriftTestSyncHandler for ThriftTestSyncHandlerImpl {
+ fn handle_test_void(&mut self) -> thrift::Result<()> {
+ println!("testVoid()");
+ Ok(())
+ }
+
+ fn handle_test_string(&mut self, thing: String) -> thrift::Result<String> {
+ println!("testString({})", &thing);
+ Ok(thing)
+ }
+
+ fn handle_test_bool(&mut self, thing: bool) -> thrift::Result<bool> {
+ println!("testBool({})", thing);
+ Ok(thing)
+ }
+
+ fn handle_test_byte(&mut self, thing: i8) -> thrift::Result<i8> {
+ println!("testByte({})", thing);
+ Ok(thing)
+ }
+
+ fn handle_test_i32(&mut self, thing: i32) -> thrift::Result<i32> {
+ println!("testi32({})", thing);
+ Ok(thing)
+ }
+
+ fn handle_test_i64(&mut self, thing: i64) -> thrift::Result<i64> {
+ println!("testi64({})", thing);
+ Ok(thing)
+ }
+
+ fn handle_test_double(&mut self,
+ thing: OrderedFloat<f64>)
+ -> thrift::Result<OrderedFloat<f64>> {
+ println!("testDouble({})", thing);
+ Ok(thing)
+ }
+
+ fn handle_test_binary(&mut self, thing: Vec<u8>) -> thrift::Result<Vec<u8>> {
+ println!("testBinary({:?})", thing);
+ Ok(thing)
+ }
+
+ fn handle_test_struct(&mut self, thing: Xtruct) -> thrift::Result<Xtruct> {
+ println!("testStruct({:?})", thing);
+ Ok(thing)
+ }
+
+ fn handle_test_nest(&mut self, thing: Xtruct2) -> thrift::Result<Xtruct2> {
+ println!("testNest({:?})", thing);
+ Ok(thing)
+ }
+
+ fn handle_test_map(&mut self, thing: BTreeMap<i32, i32>) -> thrift::Result<BTreeMap<i32, i32>> {
+ println!("testMap({:?})", thing);
+ Ok(thing)
+ }
+
+ fn handle_test_string_map(&mut self,
+ thing: BTreeMap<String, String>)
+ -> thrift::Result<BTreeMap<String, String>> {
+ println!("testStringMap({:?})", thing);
+ Ok(thing)
+ }
+
+ fn handle_test_set(&mut self, thing: BTreeSet<i32>) -> thrift::Result<BTreeSet<i32>> {
+ println!("testSet({:?})", thing);
+ Ok(thing)
+ }
+
+ fn handle_test_list(&mut self, thing: Vec<i32>) -> thrift::Result<Vec<i32>> {
+ println!("testList({:?})", thing);
+ Ok(thing)
+ }
+
+ fn handle_test_enum(&mut self, thing: Numberz) -> thrift::Result<Numberz> {
+ println!("testEnum({:?})", thing);
+ Ok(thing)
+ }
+
+ fn handle_test_typedef(&mut self, thing: UserId) -> thrift::Result<UserId> {
+ println!("testTypedef({})", thing);
+ Ok(thing)
+ }
+
+ /// @return map<i32,map<i32,i32>> - returns a dictionary with these values:
+ /// {-4 => {-4 => -4, -3 => -3, -2 => -2, -1 => -1, }, 4 => {1 => 1, 2 => 2, 3 => 3, 4 => 4, }, }
+ fn handle_test_map_map(&mut self,
+ hello: i32)
+ -> thrift::Result<BTreeMap<i32, BTreeMap<i32, i32>>> {
+ println!("testMapMap({})", hello);
+
+ let mut inner_map_0: BTreeMap<i32, i32> = BTreeMap::new();
+ for i in -4..(0 as i32) {
+ inner_map_0.insert(i, i);
+ }
+
+ let mut inner_map_1: BTreeMap<i32, i32> = BTreeMap::new();
+ for i in 1..5 {
+ inner_map_1.insert(i, i);
+ }
+
+ let mut ret_map: BTreeMap<i32, BTreeMap<i32, i32>> = BTreeMap::new();
+ ret_map.insert(-4, inner_map_0);
+ ret_map.insert(4, inner_map_1);
+
+ Ok(ret_map)
+ }
+
+ /// Creates a the returned map with these values and prints it out:
+ /// { 1 => { 2 => argument,
+ /// 3 => argument,
+ /// },
+ /// 2 => { 6 => <empty Insanity struct>, },
+ /// }
+ /// return map<UserId, map<Numberz,Insanity>> - a map with the above values
+ fn handle_test_insanity(&mut self,
+ argument: Insanity)
+ -> thrift::Result<BTreeMap<UserId, BTreeMap<Numberz, Insanity>>> {
+ println!("testInsanity({:?})", argument);
+ let mut map_0: BTreeMap<Numberz, Insanity> = BTreeMap::new();
+ map_0.insert(Numberz::TWO, argument.clone());
+ map_0.insert(Numberz::THREE, argument.clone());
+
+ let mut map_1: BTreeMap<Numberz, Insanity> = BTreeMap::new();
+ let insanity = Insanity {
+ user_map: None,
+ xtructs: None,
+ };
+ map_1.insert(Numberz::SIX, insanity);
+
+ let mut ret: BTreeMap<UserId, BTreeMap<Numberz, Insanity>> = BTreeMap::new();
+ ret.insert(1, map_0);
+ ret.insert(2, map_1);
+
+ Ok(ret)
+ }
+
+ /// returns an Xtruct with string_thing = "Hello2", byte_thing = arg0, i32_thing = arg1 and i64_thing = arg2
+ fn handle_test_multi(&mut self,
+ arg0: i8,
+ arg1: i32,
+ arg2: i64,
+ _: BTreeMap<i16, String>,
+ _: Numberz,
+ _: UserId)
+ -> thrift::Result<Xtruct> {
+ let x_ret = Xtruct {
+ string_thing: Some("Hello2".to_owned()),
+ byte_thing: Some(arg0),
+ i32_thing: Some(arg1),
+ i64_thing: Some(arg2),
+ };
+
+ Ok(x_ret)
+ }
+
+ /// if arg == "Xception" throw Xception with errorCode = 1001 and message = arg
+ /// else if arg == "TException" throw TException
+ /// else do not throw anything
+ fn handle_test_exception(&mut self, arg: String) -> thrift::Result<()> {
+ println!("testException({})", arg);
+
+ match &*arg {
+ "Xception" => {
+ Err((Xception {
+ error_code: Some(1001),
+ message: Some(arg),
+ })
+ .into())
+ }
+ "TException" => Err("this is a random error".into()),
+ _ => Ok(()),
+ }
+ }
+
+ /// if arg0 == "Xception" throw Xception with errorCode = 1001 and message = "This is an Xception"
+ /// else if arg0 == "Xception2" throw Xception2 with errorCode = 2002 and struct_thing.string_thing = "This is an Xception2"
+ // else do not throw anything and return Xtruct with string_thing = arg1
+ fn handle_test_multi_exception(&mut self,
+ arg0: String,
+ arg1: String)
+ -> thrift::Result<Xtruct> {
+ match &*arg0 {
+ "Xception" => {
+ Err((Xception {
+ error_code: Some(1001),
+ message: Some("This is an Xception".to_owned()),
+ })
+ .into())
+ }
+ "Xception2" => {
+ Err((Xception2 {
+ error_code: Some(2002),
+ struct_thing: Some(Xtruct {
+ string_thing: Some("This is an Xception2".to_owned()),
+ byte_thing: None,
+ i32_thing: None,
+ i64_thing: None,
+ }),
+ })
+ .into())
+ }
+ _ => {
+ Ok(Xtruct {
+ string_thing: Some(arg1),
+ byte_thing: None,
+ i32_thing: None,
+ i64_thing: None,
+ })
+ }
+ }
+ }
+
+ fn handle_test_oneway(&mut self, seconds_to_sleep: i32) -> thrift::Result<()> {
+ thread::sleep(Duration::from_secs(seconds_to_sleep as u64));
+ Ok(())
+ }
+}
diff --git a/test/rs/src/lib.rs b/test/rs/src/lib.rs
new file mode 100644
index 0000000..479bf90
--- /dev/null
+++ b/test/rs/src/lib.rs
@@ -0,0 +1,23 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+extern crate ordered_float;
+extern crate thrift;
+extern crate try_from;
+
+mod thrift_test;
+pub use thrift_test::*;
diff --git a/test/tests.json b/test/tests.json
index b101bfd..09d4c89 100644
--- a/test/tests.json
+++ b/test/tests.json
@@ -574,5 +574,31 @@
]
},
"workdir": "lua"
+ },
+ {
+ "name": "rs",
+ "server": {
+ "command": [
+ "test_server"
+ ]
+ },
+ "client": {
+ "timeout": 6,
+ "command": [
+ "test_client"
+ ]
+ },
+ "transports": [
+ "buffered",
+ "framed"
+ ],
+ "sockets": [
+ "ip"
+ ],
+ "protocols": [
+ "binary",
+ "compact"
+ ],
+ "workdir": "rs/bin"
}
]
diff --git a/tutorial/Makefile.am b/tutorial/Makefile.am
index efa314a..d8ad09c 100755
--- a/tutorial/Makefile.am
+++ b/tutorial/Makefile.am
@@ -74,6 +74,10 @@
SUBDIRS += dart
endif
+if WITH_RS
+SUBDIRS += rs
+endif
+
#
# generate html for ThriftTest.thrift
#
diff --git a/tutorial/rs/Cargo.toml b/tutorial/rs/Cargo.toml
new file mode 100644
index 0000000..9075db7
--- /dev/null
+++ b/tutorial/rs/Cargo.toml
@@ -0,0 +1,16 @@
+[package]
+name = "thrift-tutorial"
+version = "0.1.0"
+license = "Apache-2.0"
+authors = ["Apache Thrift Developers <dev@thrift.apache.org>"]
+exclude = ["Makefile*", "shared.rs", "tutorial.rs"]
+publish = false
+
+[dependencies]
+clap = "2.18.0"
+ordered-float = "0.3.0"
+try_from = "0.2.0"
+
+[dependencies.thrift]
+path = "../../lib/rs"
+
diff --git a/tutorial/rs/Makefile.am b/tutorial/rs/Makefile.am
new file mode 100644
index 0000000..666331e
--- /dev/null
+++ b/tutorial/rs/Makefile.am
@@ -0,0 +1,52 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+THRIFT = $(top_builddir)/compiler/cpp/thrift
+
+gen-rs/tutorial.rs gen-rs/shared.rs: $(top_srcdir)/tutorial/tutorial.thrift
+ $(THRIFT) -out src --gen rs -r $<
+
+all-local: gen-rs/tutorial.rs
+ $(CARGO) build
+ [ -d bin ] || mkdir bin
+ cp target/debug/tutorial_server bin/tutorial_server
+ cp target/debug/tutorial_client bin/tutorial_client
+
+check: all
+
+tutorialserver: all
+ bin/tutorial_server
+
+tutorialclient: all
+ bin/tutorial_client
+
+clean-local:
+ $(CARGO) clean
+ -$(RM) Cargo.lock
+ -$(RM) src/shared.rs
+ -$(RM) src/tutorial.rs
+ -$(RM) -r bin
+
+EXTRA_DIST = \
+ Cargo.toml \
+ src/lib.rs \
+ src/bin/tutorial_server.rs \
+ src/bin/tutorial_client.rs \
+ README.md
+
diff --git a/tutorial/rs/README.md b/tutorial/rs/README.md
new file mode 100644
index 0000000..4d0d7c8
--- /dev/null
+++ b/tutorial/rs/README.md
@@ -0,0 +1,330 @@
+# Rust Language Bindings for Thrift
+
+## Getting Started
+
+1. Get the [Thrift compiler](https://thrift.apache.org).
+
+2. Add the following crates to your `Cargo.toml`.
+
+```toml
+thrift = "x.y.z" # x.y.z is the version of the thrift compiler
+ordered_float = "0.3.0"
+try_from = "0.2.0"
+```
+
+3. Add the same crates to your `lib.rs` or `main.rs`.
+
+```rust
+extern crate ordered_float;
+extern crate thrift;
+extern crate try_from;
+```
+
+4. Generate Rust sources for your IDL (for example, `Tutorial.thrift`).
+
+```shell
+thrift -out my_rust_program/src --gen rs -r Tutorial.thrift
+```
+
+5. Use the generated source in your code.
+
+```rust
+// add extern crates here, or in your lib.rs
+extern crate ordered_float;
+extern crate thrift;
+extern crate try_from;
+
+// generated Rust module
+mod tutorial;
+
+use std::cell::RefCell;
+use std::rc::Rc;
+use thrift::protocol::{TInputProtocol, TOutputProtocol};
+use thrift::protocol::{TCompactInputProtocol, TCompactOutputProtocol};
+use thrift::transport::{TFramedTransport, TTcpTransport, TTransport};
+use tutorial::{CalculatorSyncClient, TCalculatorSyncClient};
+use tutorial::{Operation, Work};
+
+fn main() {
+ match run() {
+ Ok(()) => println!("client ran successfully"),
+ Err(e) => {
+ println!("client failed with {:?}", e);
+ std::process::exit(1);
+ }
+ }
+}
+
+fn run() -> thrift::Result<()> {
+ //
+ // build client
+ //
+
+ println!("connect to server on 127.0.0.1:9090");
+ let mut t = TTcpTransport::new();
+ let t = match t.open("127.0.0.1:9090") {
+ Ok(()) => t,
+ Err(e) => {
+ return Err(
+ format!("failed to connect with {:?}", e).into()
+ );
+ }
+ };
+
+ let t = Rc::new(RefCell::new(
+ Box::new(t) as Box<TTransport>
+ ));
+ let t = Rc::new(RefCell::new(
+ Box::new(TFramedTransport::new(t)) as Box<TTransport>
+ ));
+
+ let i_prot: Box<TInputProtocol> = Box::new(
+ TCompactInputProtocol::new(t.clone())
+ );
+ let o_prot: Box<TOutputProtocol> = Box::new(
+ TCompactOutputProtocol::new(t.clone())
+ );
+
+ let client = CalculatorSyncClient::new(i_prot, o_prot);
+
+ //
+ // alright! - let's make some calls
+ //
+
+ // two-way, void return
+ client.ping()?;
+
+ // two-way with some return
+ let res = client.calculate(
+ 72,
+ Work::new(7, 8, Operation::MULTIPLY, None)
+ )?;
+ println!("multiplied 7 and 8, got {}", res);
+
+ // two-way and returns a Thrift-defined exception
+ let res = client.calculate(
+ 77,
+ Work::new(2, 0, Operation::DIVIDE, None)
+ );
+ match res {
+ Ok(v) => panic!("shouldn't have succeeded with result {}", v),
+ Err(e) => println!("divide by zero failed with {:?}", e),
+ }
+
+ // one-way
+ client.zip()?;
+
+ // done!
+ Ok(())
+}
+```
+
+## Code Generation
+
+### Thrift Files and Generated Modules
+
+The Thrift code generator takes each Thrift file and generates a Rust module
+with the same name snake-cased. For example, running the compiler on
+`ThriftTest.thrift` creates `thrift_test.rs`. To use these generated files add
+`mod ...` and `use ...` declarations to your `lib.rs` or `main.rs` - one for
+each generated file.
+
+### Results and Errors
+
+The Thrift runtime library defines a `thrift::Result` and a `thrift::Error` type,
+both of which are used throught the runtime library and in all generated code.
+Conversions are defined from `std::io::Error`, `str` and `String` into
+`thrift::Error`.
+
+### Thrift Type and their Rust Equivalents
+
+Thrift defines a number of types, each of which is translated into its Rust
+equivalent by the code generator.
+
+* Primitives (bool, i8, i16, i32, i64, double, string, binary)
+* Typedefs
+* Enums
+* Containers
+* Structs
+* Unions
+* Exceptions
+* Services
+* Constants (primitives, containers, structs)
+
+In addition, unless otherwise noted, thrift includes are translated into
+`use ...` statements in the generated code, and all declarations, parameters,
+traits and types in the generated code are namespaced appropriately.
+
+The following subsections cover each type and their generated Rust equivalent.
+
+### Primitives
+
+Thrift primitives have straightforward Rust equivalents.
+
+* bool: `bool`
+* i8: `i8`
+* i16: `i16`
+* i32: `i32`
+* i64: `i64`
+* double: `OrderedFloat<f64>`
+* string: `String`
+* binary: `Vec<u8>`
+
+### Typedefs
+
+A typedef is translated to a `pub type` declaration.
+
+```thrift
+typedef i64 UserId
+
+typedef map<string, Bonk> MapType
+```
+```rust
+pub type UserId = 164;
+
+pub type MapType = BTreeMap<String, Bonk>;
+```
+
+### Enums
+
+A Thrift enum is represented as a Rust enum, and each variant is transcribed 1:1.
+
+```thrift
+enum Numberz
+{
+ ONE = 1,
+ TWO,
+ THREE,
+ FIVE = 5,
+ SIX,
+ EIGHT = 8
+}
+```
+
+```rust
+#[derive(Copy, Clone, Debug, Eq, Ord, PartialEq, PartialOrd)]
+pub enum Numberz {
+ ONE = 1,
+ TWO = 2,
+ THREE = 3,
+ FIVE = 5,
+ SIX = 6,
+ EIGHT = 8,
+}
+
+impl TryFrom<i32> for Numberz {
+ // ...
+}
+
+```
+
+### Containers
+
+Thrift has three container types: list, set and map. They are translated into
+Rust `Vec`, `BTreeSet` and `BTreeMap` respectively. Any Thrift type (this
+includes structs, enums and typedefs) can be a list/set element or a map
+key/value.
+
+#### List
+
+```thrift
+list <i32> numbers
+```
+
+```rust
+numbers: Vec<i32>
+```
+
+#### Set
+
+```thrift
+set <i32> numbers
+```
+
+```rust
+numbers: BTreeSet<i32>
+```
+
+#### Map
+
+```thrift
+map <string, i32> numbers
+```
+
+```rust
+numbers: BTreeMap<String, i32>
+```
+
+### Structs
+
+A Thrift struct is represented as a Rust struct, and each field transcribed 1:1.
+
+```thrift
+struct CrazyNesting {
+ 1: string string_field,
+ 2: optional set<Insanity> set_field,
+ 3: required list<
+ map<set<i32>, map<i32,set<list<map<Insanity,string>>>>>
+ >
+ 4: binary binary_field
+}
+```
+```rust
+#[derive(Clone, Debug, Eq, Ord, PartialEq, PartialOrd)]
+pub struct CrazyNesting {
+ pub string_field: Option<String>,
+ pub set_field: Option<BTreeSet<Insanity>>,
+ pub list_field: Vec<
+ BTreeMap<
+ BTreeSet<i32>,
+ BTreeMap<i32, BTreeSet<Vec<BTreeMap<Insanity, String>>>>
+ >
+ >,
+ pub binary_field: Option<Vec<u8>>,
+}
+
+impl CrazyNesting {
+ pub fn read_from_in_protocol(i_prot: &mut TInputProtocol)
+ ->
+ thrift::Result<CrazyNesting> {
+ // ...
+ }
+ pub fn write_to_out_protocol(&self, o_prot: &mut TOutputProtocol)
+ ->
+ thrift::Result<()> {
+ // ...
+ }
+}
+
+```
+##### Optionality
+
+Thrift has 3 "optionality" types:
+
+1. Required
+2. Optional
+3. Default
+
+The Rust code generator encodes *Required* fields as the bare type itself, while
+*Optional* and *Default* fields are encoded as `Option<TypeName>`.
+
+```thrift
+struct Foo {
+ 1: required string bar // 1. required
+ 2: optional string baz // 2. optional
+ 3: string qux // 3. default
+}
+```
+
+```rust
+pub struct Foo {
+ bar: String, // 1. required
+ baz: Option<String>, // 2. optional
+ qux: Option<String>, // 3. default
+}
+```
+
+## Known Issues
+
+* Struct constants are not supported
+* Map, list and set constants require a const holder struct
\ No newline at end of file
diff --git a/tutorial/rs/src/bin/tutorial_client.rs b/tutorial/rs/src/bin/tutorial_client.rs
new file mode 100644
index 0000000..2b0d4f9
--- /dev/null
+++ b/tutorial/rs/src/bin/tutorial_client.rs
@@ -0,0 +1,136 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#[macro_use]
+extern crate clap;
+
+extern crate thrift;
+extern crate thrift_tutorial;
+
+use std::cell::RefCell;
+use std::rc::Rc;
+
+use thrift::protocol::{TInputProtocol, TOutputProtocol};
+use thrift::protocol::{TCompactInputProtocol, TCompactOutputProtocol};
+use thrift::transport::{TFramedTransport, TTcpTransport, TTransport};
+
+use thrift_tutorial::shared::TSharedServiceSyncClient;
+use thrift_tutorial::tutorial::{CalculatorSyncClient, TCalculatorSyncClient, Operation, Work};
+
+fn main() {
+ match run() {
+ Ok(()) => println!("tutorial client ran successfully"),
+ Err(e) => {
+ println!("tutorial client failed with error {:?}", e);
+ std::process::exit(1);
+ }
+ }
+}
+
+fn run() -> thrift::Result<()> {
+ let options = clap_app!(rust_tutorial_client =>
+ (version: "0.1.0")
+ (author: "Apache Thrift Developers <dev@thrift.apache.org>")
+ (about: "Thrift Rust tutorial client")
+ (@arg host: --host +takes_value "host on which the tutorial server listens")
+ (@arg port: --port +takes_value "port on which the tutorial server listens")
+ );
+ let matches = options.get_matches();
+
+ // get any passed-in args or the defaults
+ let host = matches.value_of("host").unwrap_or("127.0.0.1");
+ let port = value_t!(matches, "port", u16).unwrap_or(9090);
+
+ // build our client and connect to the host:port
+ let mut client = new_client(host, port)?;
+
+ // alright!
+ // let's start making some calls
+
+ // let's start with a ping; the server should respond
+ println!("ping!");
+ client.ping()?;
+
+ // simple add
+ println!("add");
+ let res = client.add(1, 2)?;
+ println!("added 1, 2 and got {}", res);
+
+ let logid = 32;
+
+ // let's do...a multiply!
+ let res = client.calculate(logid, Work::new(7, 8, Operation::MULTIPLY, None))?;
+ println!("multiplied 7 and 8 and got {}", res);
+
+ // let's get the log for it
+ let res = client.get_struct(32)?;
+ println!("got log {:?} for operation {}", res, logid);
+
+ // ok - let's be bad :(
+ // do a divide by 0
+ // logid doesn't matter; won't be recorded
+ let res = client.calculate(77, Work::new(2, 0, Operation::DIVIDE, "we bad".to_owned()));
+
+ // we should have gotten an exception back
+ match res {
+ Ok(v) => panic!("should not have succeeded with result {}", v),
+ Err(e) => println!("divide by zero failed with error {:?}", e),
+ }
+
+ // let's do a one-way call
+ println!("zip");
+ client.zip()?;
+
+ // and then close out with a final ping
+ println!("ping!");
+ client.ping()?;
+
+ Ok(())
+}
+
+fn new_client(host: &str, port: u16) -> thrift::Result<CalculatorSyncClient> {
+ let mut t = TTcpTransport::new();
+
+ // open the underlying TCP stream
+ println!("connecting to tutorial server on {}:{}", host, port);
+ let t = match t.open(&format!("{}:{}", host, port)) {
+ Ok(()) => t,
+ Err(e) => {
+ return Err(format!("failed to open tcp stream to {}:{} error:{:?}",
+ host,
+ port,
+ e)
+ .into());
+ }
+ };
+
+ // refcounted because it's shared by both input and output transports
+ let t = Rc::new(RefCell::new(Box::new(t) as Box<TTransport>));
+
+ // wrap a raw socket (slow) with a buffered transport of some kind
+ let t = Box::new(TFramedTransport::new(t)) as Box<TTransport>;
+
+ // refcounted again because it's shared by both input and output protocols
+ let t = Rc::new(RefCell::new(t));
+
+ // now create the protocol implementations
+ let i_prot = Box::new(TCompactInputProtocol::new(t.clone())) as Box<TInputProtocol>;
+ let o_prot = Box::new(TCompactOutputProtocol::new(t.clone())) as Box<TOutputProtocol>;
+
+ // we're done!
+ Ok(CalculatorSyncClient::new(i_prot, o_prot))
+}
diff --git a/tutorial/rs/src/bin/tutorial_server.rs b/tutorial/rs/src/bin/tutorial_server.rs
new file mode 100644
index 0000000..9cc1866
--- /dev/null
+++ b/tutorial/rs/src/bin/tutorial_server.rs
@@ -0,0 +1,168 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#[macro_use]
+extern crate clap;
+
+extern crate thrift;
+extern crate thrift_tutorial;
+
+use std::collections::HashMap;
+use std::convert::{From, Into};
+use std::default::Default;
+
+use thrift::protocol::{TInputProtocolFactory, TOutputProtocolFactory};
+use thrift::protocol::{TCompactInputProtocolFactory, TCompactOutputProtocolFactory};
+use thrift::server::TSimpleServer;
+
+use thrift::transport::{TFramedTransportFactory, TTransportFactory};
+use thrift_tutorial::shared::{SharedServiceSyncHandler, SharedStruct};
+use thrift_tutorial::tutorial::{CalculatorSyncHandler, CalculatorSyncProcessor};
+use thrift_tutorial::tutorial::{InvalidOperation, Operation, Work};
+
+fn main() {
+ match run() {
+ Ok(()) => println!("tutorial server ran successfully"),
+ Err(e) => {
+ println!("tutorial server failed with error {:?}", e);
+ std::process::exit(1);
+ }
+ }
+}
+
+fn run() -> thrift::Result<()> {
+ let options = clap_app!(rust_tutorial_server =>
+ (version: "0.1.0")
+ (author: "Apache Thrift Developers <dev@thrift.apache.org>")
+ (about: "Thrift Rust tutorial server")
+ (@arg port: --port +takes_value "port on which the tutorial server listens")
+ );
+ let matches = options.get_matches();
+
+ let port = value_t!(matches, "port", u16).unwrap_or(9090);
+ let listen_address = format!("127.0.0.1:{}", port);
+
+ println!("binding to {}", listen_address);
+
+ let i_tran_fact: Box<TTransportFactory> = Box::new(TFramedTransportFactory::new());
+ let i_prot_fact: Box<TInputProtocolFactory> = Box::new(TCompactInputProtocolFactory::new());
+
+ let o_tran_fact: Box<TTransportFactory> = Box::new(TFramedTransportFactory::new());
+ let o_prot_fact: Box<TOutputProtocolFactory> = Box::new(TCompactOutputProtocolFactory::new());
+
+ // demux incoming messages
+ let processor = CalculatorSyncProcessor::new(CalculatorServer { ..Default::default() });
+
+ // create the server and start listening
+ let mut server = TSimpleServer::new(i_tran_fact,
+ i_prot_fact,
+ o_tran_fact,
+ o_prot_fact,
+ processor);
+
+ server.listen(&listen_address)
+}
+
+/// Handles incoming Calculator service calls.
+struct CalculatorServer {
+ log: HashMap<i32, SharedStruct>,
+}
+
+impl Default for CalculatorServer {
+ fn default() -> CalculatorServer {
+ CalculatorServer { log: HashMap::new() }
+ }
+}
+
+// since Calculator extends SharedService we have to implement the
+// handler for both traits.
+//
+
+// SharedService handler
+impl SharedServiceSyncHandler for CalculatorServer {
+ fn handle_get_struct(&mut self, key: i32) -> thrift::Result<SharedStruct> {
+ self.log
+ .get(&key)
+ .cloned()
+ .ok_or_else(|| format!("could not find log for key {}", key).into())
+ }
+}
+
+// Calculator handler
+impl CalculatorSyncHandler for CalculatorServer {
+ fn handle_ping(&mut self) -> thrift::Result<()> {
+ println!("pong!");
+ Ok(())
+ }
+
+ fn handle_add(&mut self, num1: i32, num2: i32) -> thrift::Result<i32> {
+ println!("handling add: n1:{} n2:{}", num1, num2);
+ Ok(num1 + num2)
+ }
+
+ fn handle_calculate(&mut self, logid: i32, w: Work) -> thrift::Result<i32> {
+ println!("handling calculate: l:{}, w:{:?}", logid, w);
+
+ let res = if let Some(ref op) = w.op {
+ if w.num1.is_none() || w.num2.is_none() {
+ Err(InvalidOperation {
+ what_op: Some(*op as i32),
+ why: Some("no operands specified".to_owned()),
+ })
+ } else {
+ // so that I don't have to call unwrap() multiple times below
+ let num1 = w.num1.as_ref().expect("operands checked");
+ let num2 = w.num2.as_ref().expect("operands checked");
+
+ match *op {
+ Operation::ADD => Ok(num1 + num2),
+ Operation::SUBTRACT => Ok(num1 - num2),
+ Operation::MULTIPLY => Ok(num1 * num2),
+ Operation::DIVIDE => {
+ if *num2 == 0 {
+ Err(InvalidOperation {
+ what_op: Some(*op as i32),
+ why: Some("divide by 0".to_owned()),
+ })
+ } else {
+ Ok(num1 / num2)
+ }
+ }
+ }
+ }
+ } else {
+ Err(InvalidOperation::new(None, "no operation specified".to_owned()))
+ };
+
+ // if the operation was successful log it
+ if let Ok(ref v) = res {
+ self.log.insert(logid, SharedStruct::new(logid, format!("{}", v)));
+ }
+
+ // the try! macro automatically maps errors
+ // but, since we aren't using that here we have to map errors manually
+ //
+ // exception structs defined in the IDL have an auto-generated
+ // impl of From::from
+ res.map_err(From::from)
+ }
+
+ fn handle_zip(&mut self) -> thrift::Result<()> {
+ println!("handling zip");
+ Ok(())
+ }
+}
diff --git a/tutorial/rs/src/lib.rs b/tutorial/rs/src/lib.rs
new file mode 100644
index 0000000..40007e5
--- /dev/null
+++ b/tutorial/rs/src/lib.rs
@@ -0,0 +1,23 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+extern crate ordered_float;
+extern crate thrift;
+extern crate try_from;
+
+pub mod shared;
+pub mod tutorial;