THRIFT-5208: fix OCaml struct and exn raising/handling codegen
Client: ocaml
Patch: Yawar Amin
diff --git a/compiler/cpp/src/thrift/generate/t_ocaml_generator.cc b/compiler/cpp/src/thrift/generate/t_ocaml_generator.cc
index 5e86de4..747adcb 100644
--- a/compiler/cpp/src/thrift/generate/t_ocaml_generator.cc
+++ b/compiler/cpp/src/thrift/generate/t_ocaml_generator.cc
@@ -61,12 +61,9 @@
     out_dir_base_ = "gen-ocaml";
   }
 
-  /**
-   * Init and close methods
-   */
+  ~t_ocaml_generator() override;
 
   void init_generator() override;
-  void close_generator() override;
 
   /**
    * Program-level generation functions
@@ -147,16 +144,20 @@
    * Helper rendering functions
    */
 
-  std::string ocaml_autogen_comment();
+  /** Virtual so unit tests can override it and drop the version-stamped comment */
+  virtual std::string ocaml_autogen_comment();
+
   std::string ocaml_imports();
   std::string type_name(t_type* ttype);
+  std::string exception_ctor(t_type* ttype);
   std::string function_signature(t_function* tfunction, std::string prefix = "");
   std::string function_type(t_function* tfunc, bool method = false, bool options = false);
   std::string argument_list(t_struct* tstruct);
   std::string type_to_enum(t_type* ttype);
   std::string render_ocaml_type(t_type* type);
 
-private:
+// Need access to output file streams for testing.
+protected:
   /**
    * File streams
    */
@@ -216,9 +217,6 @@
   // Generate constants
   vector<t_const*> consts = program_->get_consts();
   generate_consts(consts);
-
-  // Close the generator
-  close_generator();
 }
 
 /**
@@ -262,12 +260,12 @@
   return "open Thrift";
 }
 
-/**
- * Closes the type files
- */
-void t_ocaml_generator::close_generator() {
-  // Close types file
+t_ocaml_generator::~t_ocaml_generator() {  // closes the consts, types and service output streams
+  f_consts_.close();
   f_types_.close();
+  f_types_i_.close();
+  f_service_.close();
+  f_service_i_.close();
 }
 
 /**
@@ -914,10 +912,6 @@
   generate_service_interface(tservice);
   generate_service_client(tservice);
   generate_service_server(tservice);
-
-  // Close service file
-  f_service_.close();
-  f_service_i_.close();
 }
 
 /**
@@ -1108,7 +1102,7 @@
       for (x_iter = xceptions.begin(); x_iter != xceptions.end(); ++x_iter) {
         f_service_ << indent() << "(match result#get_" << (*x_iter)->get_name()
                    << " with None -> () | Some _v ->" << endl;
-        indent(f_service_) << "  raise (" << capitalize(type_name((*x_iter)->get_type()))
+        indent(f_service_) << "  raise (" << capitalize(exception_ctor((*x_iter)->get_type()))
                            << " _v));" << endl;
       }
 
@@ -1270,7 +1264,7 @@
     indent(f_service_) << "with" << endl;
     indent_up();
     for (x_iter = xceptions.begin(); x_iter != xceptions.end(); ++x_iter) {
-      f_service_ << indent() << "| " << capitalize(type_name((*x_iter)->get_type())) << " "
+      f_service_ << indent() << "| " << capitalize(exception_ctor((*x_iter)->get_type())) << " "
                  << (*x_iter)->get_name() << " -> " << endl;
       indent_up();
       indent_up();
@@ -1665,7 +1659,7 @@
   }
 
   string name = ttype->get_name();
-  if (ttype->is_service() || ttype->is_xception()) {
+  if (ttype->is_service()) {
     name = capitalize(name);
   } else {
     name = decapitalize(name);
@@ -1673,6 +1667,18 @@
   return prefix + name;
 }
 
+string t_ocaml_generator::exception_ctor(t_type* ttype) {
+  // An exception declared by another program lives in that program's
+  // "<Program>_types" OCaml module, so its constructor has to be qualified.
+  t_program* owner = ttype->get_program();
+  const bool foreign = owner != nullptr && owner != program_ && !ttype->is_service();
+  const string ctor = capitalize(ttype->get_name());
+  if (!foreign) {
+    return ctor;
+  }
+  return capitalize(owner->get_name()) + "_types." + ctor;
+}
+
 /**
  * Converts the parse type to a Protocol.t_type enum
  */
diff --git a/compiler/cpp/tests/CMakeLists.txt b/compiler/cpp/tests/CMakeLists.txt
index b8b2777..d9c5209 100644
--- a/compiler/cpp/tests/CMakeLists.txt
+++ b/compiler/cpp/tests/CMakeLists.txt
@@ -80,7 +80,7 @@
     ${THRIFT_COMPILER_SOURCE_DIR}/src/thrift/generate/validator_parser.h
     ${THRIFT_COMPILER_SOURCE_DIR}/src/thrift/parse/t_typedef.cc
     ${THRIFT_COMPILER_SOURCE_DIR}/src/thrift/parse/parse.cc
-    ${THRIFT_COMPILER_SOURCE_DIR}/thrift/version.h
+    ${THRIFT_COMPILER_SOURCE_DIR}/src/thrift/version.h
 )
 
 # This macro adds an option THRIFT_COMPILER_${NAME}
@@ -128,7 +128,7 @@
 THRIFT_ADD_COMPILER(json    "Enable compiler for JSON" OFF)
 THRIFT_ADD_COMPILER(lua     "Enable compiler for Lua" OFF)
 THRIFT_ADD_COMPILER(netstd  "Enable compiler for .NET Standard" ON)
-THRIFT_ADD_COMPILER(ocaml   "Enable compiler for OCaml" OFF)
+THRIFT_ADD_COMPILER(ocaml   "Enable compiler for OCaml" ON)
 THRIFT_ADD_COMPILER(perl    "Enable compiler for Perl" OFF)
 THRIFT_ADD_COMPILER(php     "Enable compiler for PHP" OFF)
 THRIFT_ADD_COMPILER(py      "Enable compiler for Python 2.0" OFF)
diff --git a/compiler/cpp/tests/README.md b/compiler/cpp/tests/README.md
index 91c0625..e45e298 100644
--- a/compiler/cpp/tests/README.md
+++ b/compiler/cpp/tests/README.md
@@ -26,7 +26,7 @@
 ## How to add your tests
 
 - Open **CMakeLists.txt**
-- Set **On** to call of **THRIFT_ADD_COMPILER** for your language
+- Set the last argument of the `THRIFT_ADD_COMPILER` call for your language to `ON`
 
 ``` cmake 
 THRIFT_ADD_COMPILER(netstd "Enable compiler for .NET Standard" ON)
@@ -85,4 +85,4 @@
 cmake ..
 cmake --build .
 ctest -C Debug -V
-```
\ No newline at end of file
+```
diff --git a/compiler/cpp/tests/ocaml/README.md b/compiler/cpp/tests/ocaml/README.md
new file mode 100644
index 0000000..e79a887
--- /dev/null
+++ b/compiler/cpp/tests/ocaml/README.md
@@ -0,0 +1,16 @@
+## Testing approach
+
+1. Programmatically construct parsed instances of Thrift IDLs using internal
+   types
+2. Generate the OCaml output using the OCaml generator
+3. Capture the generated output in `ostringstream`
+4. Compare the captured strings against the expected snapshots stored in the
+   `snapshot_*.cc` files
+
+Run the tests from the `../tests` directory:
+
+      # Only on changing build definition:
+      cmake -DCMAKE_PREFIX_PATH=/usr/local/opt/bison -DCMAKE_CXX_STANDARD=11 .
+
+      # On each iteration:
+      rm -rf gen-ocaml; cmake --build . && ctest --output-on-failure
diff --git a/compiler/cpp/tests/ocaml/snapshot_exception_types_i.cc b/compiler/cpp/tests/ocaml/snapshot_exception_types_i.cc
new file mode 100644
index 0000000..a7d908f
--- /dev/null
+++ b/compiler/cpp/tests/ocaml/snapshot_exception_types_i.cc
@@ -0,0 +1,27 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+const char* snapshot(R"""(
+open Thrift
+class serverError :
+object ('a)
+  method copy : 'a
+  method write : Protocol.t -> unit
+end
+exception ServerError of serverError
+val read_serverError : Protocol.t -> serverError
+)""");
diff --git a/compiler/cpp/tests/ocaml/snapshot_service_handle_ex.cc b/compiler/cpp/tests/ocaml/snapshot_service_handle_ex.cc
new file mode 100644
index 0000000..f20d698
--- /dev/null
+++ b/compiler/cpp/tests/ocaml/snapshot_service_handle_ex.cc
@@ -0,0 +1,161 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+const char* snapshot(R"""(
+open Thrift
+open Service_types
+
+(* HELPER FUNCTIONS AND STRUCTURES *)
+
+class ping_args =
+object (self)
+  method copy =
+      let _new = Oo.copy self in
+    _new
+  method write (oprot : Protocol.t) =
+    oprot#writeStructBegin "ping_args";
+    oprot#writeFieldStop;
+    oprot#writeStructEnd
+end
+let rec read_ping_args (iprot : Protocol.t) =
+  let _str2 = new ping_args in
+    ignore(iprot#readStructBegin);
+    (try while true do
+        let (_,_t3,_id4) = iprot#readFieldBegin in
+        if _t3 = Protocol.T_STOP then
+          raise Break
+        else ();
+        (match _id4 with 
+          | _ -> iprot#skip _t3);
+        iprot#readFieldEnd;
+      done; ()
+    with Break -> ());
+    iprot#readStructEnd;
+    _str2
+
+class ping_result =
+object (self)
+  val mutable _serverError : Errors_types.serverError option = None
+  method get_serverError = _serverError
+  method grab_serverError = match _serverError with None->raise (Field_empty "ping_result.serverError") | Some _x5 -> _x5
+  method set_serverError _x5 = _serverError <- Some _x5
+  method unset_serverError = _serverError <- None
+  method reset_serverError = _serverError <- None
+
+  method copy =
+      let _new = Oo.copy self in
+      if _serverError <> None then
+        _new#set_serverError self#grab_serverError#copy;
+    _new
+  method write (oprot : Protocol.t) =
+    oprot#writeStructBegin "ping_result";
+    (match _serverError with None -> () | Some _v -> 
+      oprot#writeFieldBegin("serverError",Protocol.T_STRUCT,1);
+      _v#write(oprot);
+      oprot#writeFieldEnd
+    );
+    oprot#writeFieldStop;
+    oprot#writeStructEnd
+end
+let rec read_ping_result (iprot : Protocol.t) =
+  let _str8 = new ping_result in
+    ignore(iprot#readStructBegin);
+    (try while true do
+        let (_,_t9,_id10) = iprot#readFieldBegin in
+        if _t9 = Protocol.T_STOP then
+          raise Break
+        else ();
+        (match _id10 with 
+          | 1 -> (if _t9 = Protocol.T_STRUCT then
+              _str8#set_serverError (Errors_types.read_serverError iprot)
+            else
+              iprot#skip _t9)
+          | _ -> iprot#skip _t9);
+        iprot#readFieldEnd;
+      done; ()
+    with Break -> ());
+    iprot#readStructEnd;
+    _str8
+
+class virtual iface =
+object (self)
+  method virtual ping : unit
+end
+
+class client (iprot : Protocol.t) (oprot : Protocol.t) =
+object (self)
+  val mutable seqid = 0
+  method ping  = 
+    self#send_ping;
+    self#recv_ping
+  method private send_ping  = 
+    oprot#writeMessageBegin ("ping", Protocol.CALL, seqid);
+    let args = new ping_args in
+      args#write oprot;
+      oprot#writeMessageEnd;
+      oprot#getTransport#flush
+  method private recv_ping  =
+    let (fname, mtype, rseqid) = iprot#readMessageBegin in
+      (if mtype = Protocol.EXCEPTION then
+        let x = Application_Exn.read iprot in
+          (iprot#readMessageEnd;           raise (Application_Exn.E x))
+      else ());
+      let result = read_ping_result iprot in
+        iprot#readMessageEnd;
+        (match result#get_serverError with None -> () | Some _v ->
+          raise (Errors_types.ServerError _v));
+        ()
+end
+
+class processor (handler : iface) =
+object (self)
+  inherit Processor.t
+
+  val processMap = Hashtbl.create 1
+  method process iprot oprot =
+    let (name, typ, seqid)  = iprot#readMessageBegin in
+      if Hashtbl.mem processMap name then
+        (Hashtbl.find processMap name) (seqid, iprot, oprot)
+      else (
+        iprot#skip(Protocol.T_STRUCT);
+        iprot#readMessageEnd;
+        let x = Application_Exn.create Application_Exn.UNKNOWN_METHOD ("Unknown function "^name) in
+          oprot#writeMessageBegin(name, Protocol.EXCEPTION, seqid);
+          x#write oprot;
+          oprot#writeMessageEnd;
+          oprot#getTransport#flush
+      );
+      true
+  method private process_ping (seqid, iprot, oprot) =
+    let _ = read_ping_args iprot in
+      iprot#readMessageEnd;
+      let result = new ping_result in
+        (try
+          (handler#ping);
+        with
+          | Errors_types.ServerError serverError -> 
+              result#set_serverError serverError
+        );
+        oprot#writeMessageBegin ("ping", Protocol.REPLY, seqid);
+        result#write oprot;
+        oprot#writeMessageEnd;
+        oprot#getTransport#flush
+  initializer
+    Hashtbl.add processMap "ping" self#process_ping;
+end
+
+)""");
diff --git a/compiler/cpp/tests/ocaml/snapshot_typedefs.cc b/compiler/cpp/tests/ocaml/snapshot_typedefs.cc
new file mode 100644
index 0000000..473b2e8
--- /dev/null
+++ b/compiler/cpp/tests/ocaml/snapshot_typedefs.cc
@@ -0,0 +1,22 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+const char* snapshot(R"""(
+open Thrift
+type decimal = string
+
+)""");
diff --git a/compiler/cpp/tests/ocaml/t_ocaml_generator_tests.cc b/compiler/cpp/tests/ocaml/t_ocaml_generator_tests.cc
new file mode 100644
index 0000000..ea788fc
--- /dev/null
+++ b/compiler/cpp/tests/ocaml/t_ocaml_generator_tests.cc
@@ -0,0 +1,111 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <list>
+#include <memory>
+#include "../catch/catch.hpp"
+#include <thrift/generate/t_ocaml_generator.cc>
+
+using defs_t = std::list<t_type*>;
+
+/** Test double for the OCaml generator: it suppresses the version-stamped
+    autogen comment and exposes the contents of each generated output stream.
+    Everything else is inherited unchanged from t_ocaml_generator. */
+class t_test_ocaml_generator : public t_ocaml_generator {
+public:
+    explicit t_test_ocaml_generator(t_program* program) : t_ocaml_generator(program, {}, "") {}
+
+    /** Returns an empty comment so that snapshots do not depend on the
+        version number the real autogen comment contains. */
+    std::string ocaml_autogen_comment() override { return ""; }
+
+    // Accessors exposing the contents of each output stream for inspection.
+
+    std::string types() { return f_types_.str(); }
+    std::string consts() { return f_consts_.str(); }
+    std::string service() { return f_service_.str(); }
+    std::string types_i() { return f_types_i_.str(); }
+    std::string service_i() { return f_service_i_.str(); }
+};
+
+/** Registers each definition with `program` through the add_* call matching
+    its kind (other kinds are ignored), then calls `gen.generate_program()`. */
+void gen_program(t_generator& gen, t_program& program, const defs_t& defs) {
+    for (t_type* def : defs) {
+        if (def->is_typedef()) program.add_typedef(static_cast<t_typedef*>(def));
+        else if (def->is_enum()) program.add_enum(static_cast<t_enum*>(def));
+        else if (def->is_struct()) program.add_struct(static_cast<t_struct*>(def));
+        else if (def->is_xception()) program.add_xception(static_cast<t_struct*>(def));
+        else if (def->is_service()) program.add_service(static_cast<t_service*>(def));
+    }
+
+    gen.generate_program();
+}
+
+TEST_CASE( "t_ocaml_generator - typedefs", "[functional]" )
+{
+    // IDL under test: a single `Decimal` typedef aliasing the string base type.
+    t_program typedefs_thrift("Typedefs.thrift", "Typedefs");
+    t_base_type string_type("string", t_base_type::TYPE_STRING);
+    t_typedef decimal_alias(&typedefs_thrift, &string_type, "Decimal");
+    t_test_ocaml_generator gen(&typedefs_thrift);
+
+    gen_program(gen, typedefs_thrift, defs_t { &decimal_alias });
+
+    // The include defines a local `snapshot` holding the expected types output.
+    #include "snapshot_typedefs.cc"
+    REQUIRE( snapshot == gen.types() );
+}
+
+TEST_CASE( "t_ocaml_generator - handle exception from different module", "[functional]" )
+{
+    // Errors.thrift: declares only the `ServerError` exception.
+    t_program errors_thrift("Errors.thrift", "Errors");
+    t_struct server_error_exn(&errors_thrift, "ServerError");
+    server_error_exn.set_xception(true);
+    t_test_ocaml_generator errors_gen(&errors_thrift);
+
+    gen_program(errors_gen, errors_thrift, defs_t { &server_error_exn });
+    {
+        #include "snapshot_exception_types_i.cc"
+        REQUIRE( snapshot == errors_gen.types_i() );
+    }
+
+    // Service.thrift: a void `ping` function that throws the exception owned
+    // by Errors.thrift, so the generated code must qualify it by module.
+    t_program service_thrift("Service.thrift", "Service");
+    t_service ping_service(&service_thrift);
+    ping_service.set_name("Service");
+    t_base_type void_type("void", t_base_type::TYPE_VOID);
+    t_struct ping_args(&service_thrift, "ping_args");
+    t_struct ping_throws(&service_thrift, "ping_throws");
+    t_field thrown_server_error(&server_error_exn, "serverError", 1);
+    ping_throws.append(&thrown_server_error);
+    t_function ping(&void_type, "ping", &ping_args, &ping_throws);
+    ping_service.add_function(&ping);
+
+    t_test_ocaml_generator service_gen(&service_thrift);
+
+    gen_program(service_gen, service_thrift, defs_t { &ping_service });
+
+    // Compare the service implementation output, where the exception is
+    // raised by the client and matched by the processor.
+    {
+        #include "snapshot_service_handle_ex.cc"
+        REQUIRE( snapshot == service_gen.service() );
+    }
+}