THRIFT-5208: fix OCaml struct and exn raising/handling codegen
Client: ocaml
Patch: Yawar Amin
diff --git a/compiler/cpp/tests/CMakeLists.txt b/compiler/cpp/tests/CMakeLists.txt
index b8b2777..d9c5209 100644
--- a/compiler/cpp/tests/CMakeLists.txt
+++ b/compiler/cpp/tests/CMakeLists.txt
@@ -80,7 +80,7 @@
     ${THRIFT_COMPILER_SOURCE_DIR}/src/thrift/generate/validator_parser.h
     ${THRIFT_COMPILER_SOURCE_DIR}/src/thrift/parse/t_typedef.cc
     ${THRIFT_COMPILER_SOURCE_DIR}/src/thrift/parse/parse.cc
-    ${THRIFT_COMPILER_SOURCE_DIR}/thrift/version.h
+    ${THRIFT_COMPILER_SOURCE_DIR}/src/thrift/version.h
 )
 
 # This macro adds an option THRIFT_COMPILER_${NAME}
@@ -128,7 +128,7 @@
 THRIFT_ADD_COMPILER(json    "Enable compiler for JSON" OFF)
 THRIFT_ADD_COMPILER(lua     "Enable compiler for Lua" OFF)
 THRIFT_ADD_COMPILER(netstd  "Enable compiler for .NET Standard" ON)
-THRIFT_ADD_COMPILER(ocaml   "Enable compiler for OCaml" OFF)
+THRIFT_ADD_COMPILER(ocaml   "Enable compiler for OCaml" ON)
 THRIFT_ADD_COMPILER(perl    "Enable compiler for Perl" OFF)
 THRIFT_ADD_COMPILER(php     "Enable compiler for PHP" OFF)
 THRIFT_ADD_COMPILER(py      "Enable compiler for Python 2.0" OFF)
diff --git a/compiler/cpp/tests/README.md b/compiler/cpp/tests/README.md
index 91c0625..e45e298 100644
--- a/compiler/cpp/tests/README.md
+++ b/compiler/cpp/tests/README.md
@@ -26,7 +26,7 @@
 ## How to add your tests
 
 - Open **CMakeLists.txt**
-- Set **On** to call of **THRIFT_ADD_COMPILER** for your language
+- Set the `THRIFT_ADD_COMPILER` call for your language to `ON`
 
 ``` cmake 
 THRIFT_ADD_COMPILER(netstd "Enable compiler for .NET Standard" ON)
@@ -85,4 +85,4 @@
 cmake ..
 cmake --build .
 ctest -C Debug -V
-```
\ No newline at end of file
+```
diff --git a/compiler/cpp/tests/ocaml/README.md b/compiler/cpp/tests/ocaml/README.md
new file mode 100644
index 0000000..e79a887
--- /dev/null
+++ b/compiler/cpp/tests/ocaml/README.md
@@ -0,0 +1,16 @@
+## Testing approach
+
+1. Programmatically construct parsed instances of Thrift IDLs using internal
+   types
+2. Generate the OCaml output using the OCaml generator
+3. Capture the generated output in `ostringstream`
+4. Compare the captured strings against the snapshots stored in the
+   `snapshot_*.cc` files
+
+Run the tests from the parent `compiler/cpp/tests` directory:
+
+      # Only on changing build definition:
+      cmake -DCMAKE_PREFIX_PATH=/usr/local/opt/bison -DCMAKE_CXX_STANDARD=11 .
+
+      # On each iteration:
+      rm -rf gen-ocaml; cmake --build . && ctest --output-on-failure
diff --git a/compiler/cpp/tests/ocaml/snapshot_exception_types_i.cc b/compiler/cpp/tests/ocaml/snapshot_exception_types_i.cc
new file mode 100644
index 0000000..a7d908f
--- /dev/null
+++ b/compiler/cpp/tests/ocaml/snapshot_exception_types_i.cc
@@ -0,0 +1,27 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+const char* snapshot(R"""(
+open Thrift
+class serverError :
+object ('a)
+  method copy : 'a
+  method write : Protocol.t -> unit
+end
+exception ServerError of serverError
+val read_serverError : Protocol.t -> serverError
+)"""); // expected types_i() output for the ServerError exception; textually included by the test
diff --git a/compiler/cpp/tests/ocaml/snapshot_service_handle_ex.cc b/compiler/cpp/tests/ocaml/snapshot_service_handle_ex.cc
new file mode 100644
index 0000000..f20d698
--- /dev/null
+++ b/compiler/cpp/tests/ocaml/snapshot_service_handle_ex.cc
@@ -0,0 +1,161 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+const char* snapshot(R"""(
+open Thrift
+open Service_types
+
+(* HELPER FUNCTIONS AND STRUCTURES *)
+
+class ping_args =
+object (self)
+  method copy =
+      let _new = Oo.copy self in
+    _new
+  method write (oprot : Protocol.t) =
+    oprot#writeStructBegin "ping_args";
+    oprot#writeFieldStop;
+    oprot#writeStructEnd
+end
+let rec read_ping_args (iprot : Protocol.t) =
+  let _str2 = new ping_args in
+    ignore(iprot#readStructBegin);
+    (try while true do
+        let (_,_t3,_id4) = iprot#readFieldBegin in
+        if _t3 = Protocol.T_STOP then
+          raise Break
+        else ();
+        (match _id4 with 
+          | _ -> iprot#skip _t3);
+        iprot#readFieldEnd;
+      done; ()
+    with Break -> ());
+    iprot#readStructEnd;
+    _str2
+
+class ping_result =
+object (self)
+  val mutable _serverError : Errors_types.serverError option = None
+  method get_serverError = _serverError
+  method grab_serverError = match _serverError with None->raise (Field_empty "ping_result.serverError") | Some _x5 -> _x5
+  method set_serverError _x5 = _serverError <- Some _x5
+  method unset_serverError = _serverError <- None
+  method reset_serverError = _serverError <- None
+
+  method copy =
+      let _new = Oo.copy self in
+      if _serverError <> None then
+        _new#set_serverError self#grab_serverError#copy;
+    _new
+  method write (oprot : Protocol.t) =
+    oprot#writeStructBegin "ping_result";
+    (match _serverError with None -> () | Some _v -> 
+      oprot#writeFieldBegin("serverError",Protocol.T_STRUCT,1);
+      _v#write(oprot);
+      oprot#writeFieldEnd
+    );
+    oprot#writeFieldStop;
+    oprot#writeStructEnd
+end
+let rec read_ping_result (iprot : Protocol.t) =
+  let _str8 = new ping_result in
+    ignore(iprot#readStructBegin);
+    (try while true do
+        let (_,_t9,_id10) = iprot#readFieldBegin in
+        if _t9 = Protocol.T_STOP then
+          raise Break
+        else ();
+        (match _id10 with 
+          | 1 -> (if _t9 = Protocol.T_STRUCT then
+              _str8#set_serverError (Errors_types.read_serverError iprot)
+            else
+              iprot#skip _t9)
+          | _ -> iprot#skip _t9);
+        iprot#readFieldEnd;
+      done; ()
+    with Break -> ());
+    iprot#readStructEnd;
+    _str8
+
+class virtual iface =
+object (self)
+  method virtual ping : unit
+end
+
+class client (iprot : Protocol.t) (oprot : Protocol.t) =
+object (self)
+  val mutable seqid = 0
+  method ping  = 
+    self#send_ping;
+    self#recv_ping
+  method private send_ping  = 
+    oprot#writeMessageBegin ("ping", Protocol.CALL, seqid);
+    let args = new ping_args in
+      args#write oprot;
+      oprot#writeMessageEnd;
+      oprot#getTransport#flush
+  method private recv_ping  =
+    let (fname, mtype, rseqid) = iprot#readMessageBegin in
+      (if mtype = Protocol.EXCEPTION then
+        let x = Application_Exn.read iprot in
+          (iprot#readMessageEnd;           raise (Application_Exn.E x))
+      else ());
+      let result = read_ping_result iprot in
+        iprot#readMessageEnd;
+        (match result#get_serverError with None -> () | Some _v ->
+          raise (Errors_types.ServerError _v));
+        ()
+end
+
+class processor (handler : iface) =
+object (self)
+  inherit Processor.t
+
+  val processMap = Hashtbl.create 1
+  method process iprot oprot =
+    let (name, typ, seqid)  = iprot#readMessageBegin in
+      if Hashtbl.mem processMap name then
+        (Hashtbl.find processMap name) (seqid, iprot, oprot)
+      else (
+        iprot#skip(Protocol.T_STRUCT);
+        iprot#readMessageEnd;
+        let x = Application_Exn.create Application_Exn.UNKNOWN_METHOD ("Unknown function "^name) in
+          oprot#writeMessageBegin(name, Protocol.EXCEPTION, seqid);
+          x#write oprot;
+          oprot#writeMessageEnd;
+          oprot#getTransport#flush
+      );
+      true
+  method private process_ping (seqid, iprot, oprot) =
+    let _ = read_ping_args iprot in
+      iprot#readMessageEnd;
+      let result = new ping_result in
+        (try
+          (handler#ping);
+        with
+          | Errors_types.ServerError serverError -> 
+              result#set_serverError serverError
+        );
+        oprot#writeMessageBegin ("ping", Protocol.REPLY, seqid);
+        result#write oprot;
+        oprot#writeMessageEnd;
+        oprot#getTransport#flush
+  initializer
+    Hashtbl.add processMap "ping" self#process_ping;
+end
+
+)"""); // expected service() output for a Service whose ping throws Errors_types.ServerError
diff --git a/compiler/cpp/tests/ocaml/snapshot_typedefs.cc b/compiler/cpp/tests/ocaml/snapshot_typedefs.cc
new file mode 100644
index 0000000..473b2e8
--- /dev/null
+++ b/compiler/cpp/tests/ocaml/snapshot_typedefs.cc
@@ -0,0 +1,22 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+const char* snapshot(R"""(
+open Thrift
+type decimal = string
+
+)"""); // expected types() output for the "Decimal" string typedef; textually included by the test
diff --git a/compiler/cpp/tests/ocaml/t_ocaml_generator_tests.cc b/compiler/cpp/tests/ocaml/t_ocaml_generator_tests.cc
new file mode 100644
index 0000000..ea788fc
--- /dev/null
+++ b/compiler/cpp/tests/ocaml/t_ocaml_generator_tests.cc
@@ -0,0 +1,111 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <list>
+#include <memory>
+#include "../catch/catch.hpp"
+#include <thrift/generate/t_ocaml_generator.cc>
+
+using defs_t = std::list<t_type*>;
+
+/** Test-only subclass of the OCaml generator: it exposes the generated output
+    streams as strings and suppresses the version-bearing autogen comment so
+    snapshots stay stable across Thrift versions. Otherwise it is unchanged. */
+class t_test_ocaml_generator : public t_ocaml_generator {
+public:
+    explicit t_test_ocaml_generator(t_program* program) : t_ocaml_generator(program, {}, "") {}
+
+    /** Override and turn off comment generation which contains a version number
+        to make tests version-independent. */
+    std::string ocaml_autogen_comment() override { return ""; }
+
+    // Accessors returning the text accumulated in each generator output stream.
+
+    std::string types() { return f_types_.str(); }
+    std::string consts() { return f_consts_.str(); }
+    std::string service() { return f_service_.str(); }
+    std::string types_i() { return f_types_i_.str(); }
+    std::string service_i() { return f_service_i_.str(); }
+};
+
+/** Registers each definition in `defs` with `program` through the add_* call
+    matching its kind, then runs `gen` once to produce the OCaml outputs. */
+void gen_program(t_generator& gen, t_program& program, const defs_t& defs) {
+    for (t_type* def : defs) {
+        if (def->is_typedef()) program.add_typedef(static_cast<t_typedef*>(def));
+        else if (def->is_enum()) program.add_enum(static_cast<t_enum*>(def));
+        else if (def->is_struct()) program.add_struct(static_cast<t_struct*>(def));
+        else if (def->is_xception()) program.add_xception(static_cast<t_struct*>(def));
+        else if (def->is_service()) program.add_service(static_cast<t_service*>(def));
+    }
+
+    gen.generate_program();
+}
+
+TEST_CASE( "t_ocaml_generator - typedefs", "[functional]" )
+{
+    t_program program("Typedefs.thrift", "Typedefs");
+    t_base_type ty_string("string", t_base_type::TYPE_STRING);
+    t_typedef tydef_decimal(&program, &ty_string, "Decimal"); // typedef "Decimal" over the string base type
+    t_test_ocaml_generator gen(&program);
+
+    gen_program(gen, program, defs_t {
+        &tydef_decimal
+    });
+
+    #include "snapshot_typedefs.cc" // textual include: defines the local `snapshot` expected-output string
+    REQUIRE( snapshot == gen.types() ); // snapshot expects the lower-cased `type decimal = string`
+}
+
+TEST_CASE( "t_ocaml_generator - handle exception from different module", "[functional]" )
+{
+    t_program errors_thrift("Errors.thrift", "Errors");
+    t_struct server_error(&errors_thrift, "ServerError");
+    server_error.set_xception(true); // flag the struct as an exception so gen_program registers it via add_xception
+
+    t_test_ocaml_generator errors_gen(&errors_thrift);
+    gen_program(errors_gen, errors_thrift, defs_t {
+        &server_error
+    });
+
+    { // inner scope: each included snapshot file defines a local named `snapshot`
+        #include "snapshot_exception_types_i.cc"
+        REQUIRE( snapshot == errors_gen.types_i() );
+    }
+
+    t_program service_thrift("Service.thrift", "Service"); // second program; its service refers to the exception built above
+    t_service service(&service_thrift);
+    service.set_name("Service");
+    t_base_type ret_type("void", t_base_type::TYPE_VOID);
+    t_struct args(&service_thrift, "ping_args");
+    t_struct throws(&service_thrift, "ping_throws");
+    t_field ex_server_error(&server_error, "serverError", 1); // field 1, typed as the exception owned by errors_thrift
+    throws.append(&ex_server_error);
+    t_function ping(&ret_type, "ping", &args, &throws);
+    service.add_function(&ping);
+
+    t_test_ocaml_generator service_gen(&service_thrift);
+
+    gen_program(service_gen, service_thrift, defs_t {
+        &service
+    });
+
+    { // inner scope again so this `snapshot` does not clash with the earlier one
+        #include "snapshot_service_handle_ex.cc"
+        REQUIRE( snapshot == service_gen.service() ); // snapshot expects the exception qualified as Errors_types.ServerError
+    }
+}