Thrift: OCaml library and generator

Summary: Added (minimal) library and code generator for OCaml.
Reviewed by: mcslee
Test plan: Test client and server (included).
Revert plan: yes


git-svn-id: https://svn.apache.org/repos/asf/incubator/thrift/trunk@665163 13f79535-47bb-0310-9956-ffa450edef68
diff --git a/lib/ocaml/src/Makefile b/lib/ocaml/src/Makefile
new file mode 100644
index 0000000..0b989ce
--- /dev/null
+++ b/lib/ocaml/src/Makefile
@@ -0,0 +1,9 @@
+# Build the Thrift OCaml runtime: native-code and byte-code libraries plus the OCamlMakefile "top" target.
+OCAMLMAKEFILE = ../OCamlMakefile
+RESULT = thrift
+# Modules in dependency order: Thrift.ml first, the servers (which use the protocol/transports) last.
+SOURCES = Thrift.ml TBinaryProtocol.ml TSocket.ml TChannelTransport.ml TServer.ml TSimpleServer.ml
+LIBS = unix
+# Declared before the include so that "all" stays the default goal.
+all: native-code-library byte-code-library top
+include $(OCAMLMAKEFILE)
diff --git a/lib/ocaml/src/TBinaryProtocol.ml b/lib/ocaml/src/TBinaryProtocol.ml
new file mode 100644
index 0000000..44433d6
--- /dev/null
+++ b/lib/ocaml/src/TBinaryProtocol.ml
@@ -0,0 +1,156 @@
+open Thrift
+
+module P = Protocol
+
+(* [get_byte i b] is byte number [b] of [i], counting from the least
+   significant byte (b = 0). *)
+let get_byte i b = 255 land (i lsr (8*b))
+
+(* [get_byte] for Int64.t values. *)
+let get_byte64 i b = 255 land (Int64.to_int (Int64.shift_right i (8*b)))
+
+(* Conversions between [Protocol.t_type] and its wire code. *)
+let tv = P.t_type_to_i
+let vt = P.t_type_of_i
+
+(* [comp_int b n] decodes the first [n] bytes of [b] as a big-endian
+   two's-complement integer and sign-extends it into an OCaml int.
+   An OCaml int is one bit narrower than the machine word, so the shift that
+   moves the value's top bit onto the int's sign bit is
+   Sys.word_size - 1 - 8*n (Sys.word_size - 8*n overshoots by one bit and, on
+   64-bit, decodes e.g. the byte 0x40 as -64 and the byte 0x80 as 0).
+   The shift is clamped at 0 when the int is too narrow for n bytes. *)
+let comp_int b n =
+  let s = ref 0 in
+  let sb = max 0 (Sys.word_size - 1 - 8*n) in
+    for i=0 to (n-1) do
+      s:=!s lor ((int_of_char b.[i]) lsl (8*(n-1-i)))
+    done;
+    (!s lsl sb) asr sb
+
+(* [comp_int64 b n] decodes the first [n] bytes of [b] as a big-endian
+   Int64.t, without sign extension (callers in this file pass n = 8). *)
+let comp_int64 b n =
+  let s = ref 0L in
+    for i=0 to (n-1) do
+      s:=Int64.logor !s (Int64.shift_left (Int64.of_int (int_of_char b.[i])) (8*(n-1-i)))
+    done;
+    !s
+
+class t trans = (* Thrift binary protocol: fixed-width, big-endian encoding *)
+object (self)
+  inherit P.t trans
+  val ibyte = String.create 8 (* scratch buffer reused by every fixed-width read/write *)
+  method writeBool b = (* one byte: 1 for true, 0 for false *)
+    ibyte.[0] <- char_of_int (if b then 1 else 0);
+    trans#write ibyte 0 1
+  method writeByte i = (* low 8 bits of [i] *)
+    ibyte.[0] <- char_of_int (get_byte i 0);
+    trans#write ibyte 0 1
+  method writeI16 i = (* 2 bytes, most significant first *)
+    let gb = get_byte i in
+      ibyte.[1] <- char_of_int (gb 0);
+      ibyte.[0] <- char_of_int (gb 1);
+      trans#write ibyte 0 2
+  method writeI32 i = (* 4 bytes, most significant first *)
+    let gb = get_byte i in
+      for k=0 to 3 do
+        ibyte.[3-k] <- char_of_int (gb k)
+      done;
+      trans#write ibyte 0 4
+  method writeI64 i = (* 8 bytes, most significant first *)
+    let gb = get_byte64 i in
+      for k=0 to 7 do
+        ibyte.[7-k] <- char_of_int (gb k)
+      done;
+      trans#write ibyte 0 8
+  method writeDouble d = (* IEEE-754 bit pattern sent as an i64 *)
+    self#writeI64 (Int64.bits_of_float d)
+  method writeString s = (* i32 length, then the raw bytes *)
+    let n = String.length s in
+      self#writeI32 n;
+      trans#write s 0 n
+  method writeBinary a = self#writeString a
+  method writeMessageBegin (n,t,s) = (* name, message-type byte, i32 sequence id *)
+    self#writeString n;
+    self#writeByte (P.message_type_to_i t);
+    self#writeI32 s
+  method writeMessageEnd = ()
+  method writeStructBegin _ = () (* struct names are not encoded *)
+  method writeStructEnd = ()
+  method writeFieldBegin (_,t,i) = (* type byte + i16 id; the field name is not encoded *)
+    self#writeByte (tv t);
+    self#writeI16 i
+  method writeFieldEnd = ()
+  method writeFieldStop =
+    self#writeByte (tv P.T_STOP)
+  method writeMapBegin (k,v,s) = (* key type, value type, i32 size *)
+    self#writeByte (tv k);
+    self#writeByte (tv v);
+    self#writeI32 s
+  method writeMapEnd = ()
+  method writeListBegin (t,s) = (* element type, i32 size *)
+    self#writeByte (tv t);
+    self#writeI32 s
+  method writeListEnd = ()
+  method writeSetBegin (t,s) = (* element type, i32 size *)
+    self#writeByte (tv t);
+    self#writeI32 s
+  method writeSetEnd = ()
+  method readByte = (* integer reads are sign-extended by comp_int *)
+    ignore (trans#readAll ibyte 0 1);
+    comp_int ibyte 1
+  method readI16 =
+    ignore (trans#readAll ibyte 0 2);
+    comp_int ibyte 2
+  method readI32 =
+    ignore (trans#readAll ibyte 0 4);
+    comp_int ibyte 4
+  method readI64 =
+    ignore (trans#readAll ibyte 0 8);
+    comp_int64 ibyte 8
+  method readDouble =
+    Int64.float_of_bits (self#readI64)
+  method readBool = (* any byte other than 1 reads as false *)
+    self#readByte = 1
+  method readString =
+    let sz = self#readI32 in
+    let buf = String.create sz in (* NOTE(review): a negative length raises Invalid_argument here *)
+      ignore (trans#readAll buf 0 sz);
+      buf
+  method readBinary = self#readString
+  method readMessageBegin =
+    let s = self#readString in
+    let mt = P.message_type_of_i (self#readByte) in
+      (s,mt, self#readI32)
+  method readMessageEnd = ()
+  method readStructBegin =
+    ""
+  method readStructEnd = ()
+  method readFieldBegin =
+    let t = vt (self#readByte)
+    in
+      if t <> P.T_STOP then (* structural [<>], not physical [!=] *)
+        ("",t,self#readI16)
+      else ("",t,0) (* T_STOP is not followed by a field id *)
+  method readFieldEnd = ()
+  method readMapBegin =
+    let kt = vt (self#readByte) in
+    let vtyp = vt (self#readByte) in
+      (kt,vtyp, self#readI32)
+  method readMapEnd = ()
+  method readListBegin =
+    let t = vt (self#readByte) in
+    (t,self#readI32)
+  method readListEnd = ()
+  method readSetBegin =
+    let t = vt (self#readByte) in
+    (t, self#readI32)
+  method readSetEnd = ()
+end
+
+class factory = (* hands out a binary protocol layered on any transport *)
+object
+  inherit P.factory
+  method getProtocol = new t
+end
diff --git a/lib/ocaml/src/TChannelTransport.ml b/lib/ocaml/src/TChannelTransport.ml
new file mode 100644
index 0000000..89ae352
--- /dev/null
+++ b/lib/ocaml/src/TChannelTransport.ml
@@ -0,0 +1,25 @@
+open Thrift
+module T = Transport
+
+(* Transport over an existing (input, output) channel pair.  [opn] and
+   [close] are no-ops: this class never opens or closes the channels. *)
+class t (i,o) =
+object
+  inherit Transport.t
+  method isOpen = true
+  method opn = ()
+  method close = ()
+  (* Reads exactly [len] bytes into [buf] at [off]; end of input is reported
+     as END_OF_FILE, any other failure as UNKNOWN. *)
+  method read buf off len =
+    let fail typ =
+      T.raise_TTransportExn ("TChannelTransport: Could not read "^(string_of_int len)) typ
+    in
+      try
+        really_input i buf off len; len
+      with
+        | End_of_file -> fail T.END_OF_FILE
+        | _ -> fail T.UNKNOWN
+  method write buf off len = output o buf off len
+  method flush = flush o
+end
diff --git a/lib/ocaml/src/TServer.ml b/lib/ocaml/src/TServer.ml
new file mode 100644
index 0000000..d8509ff
--- /dev/null
+++ b/lib/ocaml/src/TServer.ml
@@ -0,0 +1,37 @@
+open Thrift
+
+(* Abstract base for Thrift servers.  It only stores the six collaborators
+   passed at construction; concrete servers must implement [serve]. *)
+
+class virtual t
+    (pf : Processor.factory)
+    (st : Transport.server_t)
+    (itf : Transport.factory)
+    (otf : Transport.factory)
+    (ipf : Protocol.factory)
+    (opf : Protocol.factory) =
+object
+  (* source of request processors *)
+  val processorFactory : Processor.factory = pf
+  (* listening endpoint *)
+  val serverTransport : Transport.server_t = st
+  (* wrap accepted connections for reading / writing *)
+  val inputTransportFactory : Transport.factory = itf
+  val outputTransportFactory : Transport.factory = otf
+  (* build the protocols layered on those transports *)
+  val inputProtocolFactory : Protocol.factory = ipf
+  val outputProtocolFactory : Protocol.factory = opf
+  method virtual serve : unit
+end
+
+(* Serve [proc] with the binary protocol on [host]:[port] (host defaults to 127.0.0.1);
+   a connection ends when [proc#process] returns false or raises, e.g. when the peer closes. *)
+let run_basic_server ?(host="127.0.0.1") proc port =
+  let handle inp out =
+    let trans = new TChannelTransport.t (inp,out) in
+    let proto = new TBinaryProtocol.t (trans :> Transport.t) in
+      try while proc#process proto proto do () done
+      with _ -> () (* best-effort: an error just ends this connection *)
+  in
+    Unix.establish_server handle (Unix.ADDR_INET (Unix.inet_addr_of_string host, port))
+
diff --git a/lib/ocaml/src/TSimpleServer.ml b/lib/ocaml/src/TSimpleServer.ml
new file mode 100644
index 0000000..1a85809
--- /dev/null
+++ b/lib/ocaml/src/TSimpleServer.ml
@@ -0,0 +1,30 @@
+open Thrift
+module S = TServer
+
+(* Single-connection server: [serve] listens, accepts ONE client, runs the
+   processor from [pf] until it returns false or raises, closes both
+   transports and returns.  Every exception (including from listen/accept)
+   is caught, so [serve] itself never raises. *)
+class t pf st itf otf ipf opf =
+object
+  inherit S.t pf st itf otf ipf opf
+  method serve =
+    let close_both a b = a#close; b#close in
+    let run () =
+      st#listen;
+      let conn = st#accept in
+      let processor = pf#getProcessor conn in
+      let in_trans = itf#getTransport conn in
+      (* release the input transport if the output one cannot be built *)
+      let out_trans =
+        try otf#getTransport conn
+        with e -> in_trans#close; raise e
+      in
+      let in_proto = ipf#getProtocol in_trans in
+      let out_proto = opf#getProtocol out_trans in
+      let rec loop () = if processor#process in_proto out_proto then loop () in
+        try loop (); close_both in_trans out_trans
+        with e -> close_both in_trans out_trans; raise e
+    in
+      try run () with _ -> ()
+end
diff --git a/lib/ocaml/src/TSocket.ml b/lib/ocaml/src/TSocket.ml
new file mode 100644
index 0000000..c02f1eb
--- /dev/null
+++ b/lib/ocaml/src/TSocket.ml
@@ -0,0 +1,48 @@
+open Thrift
+
+module T = Transport
+
+(* Client-side TCP transport to [host]:[port].  [host] may be a numeric
+   address or a name resolved with Unix.gethostbyname; [opn] connects. *)
+class t host port=
+object
+  inherit T.t
+  val mutable chans = None
+  method isOpen = (match chans with None -> false | Some _ -> true)
+  method opn =
+    try
+      let addr =
+        try Unix.inet_addr_of_string host
+        with Failure _ -> (Unix.gethostbyname host).Unix.h_addr_list.(0)
+      in
+        chans <- Some (Unix.open_connection (Unix.ADDR_INET (addr,port)))
+    with _ ->
+      T.raise_TTransportExn
+        ("Could not connect to "^host^":"^(string_of_int port))
+        T.NOT_OPEN
+  (* Both channels share one socket: closing the output channel flushes
+     pending output and releases the descriptor, whereas
+     Unix.shutdown_connection alone only signals end-of-file to the peer. *)
+  method close = match chans with
+      None -> ()
+    | Some(_,o) ->
+        chans <- None;
+        (try close_out o with Sys_error _ -> ())
+  method read buf off len = match chans with
+      None -> T.raise_TTransportExn "Socket not open" T.NOT_OPEN
+    | Some(i,_) ->
+        let fail typ =
+          T.raise_TTransportExn ("TSocket: Could not read "^(string_of_int len)^" from "^host^":"^(string_of_int port)) typ
+        in
+          (try
+             really_input i buf off len; len
+           with
+             | End_of_file -> fail T.END_OF_FILE
+             | _ -> fail T.UNKNOWN)
+  method write buf off len = match chans with
+      None -> T.raise_TTransportExn "Socket not open" T.NOT_OPEN
+    | Some(_,o) -> output o buf off len
+  method flush = match chans with
+      None -> T.raise_TTransportExn "Socket not open" T.NOT_OPEN
+    | Some(_,o) -> flush o
+end
diff --git a/lib/ocaml/src/Thrift.ml b/lib/ocaml/src/Thrift.ml
new file mode 100644
index 0000000..224febb
--- /dev/null
+++ b/lib/ocaml/src/Thrift.ml
@@ -0,0 +1,363 @@
+(* Raised internally to leave field-reading loops once T_STOP is seen. *)
+exception Break
+(* Raised when a wire code maps to no known constant. *)
+exception Thrift_error
+(* Not raised in this file; presumably used by generated code to report an
+   unset required field, the string naming it. *)
+exception Field_empty of string
+
+(* Base class of Thrift exception payloads: a mutable message, initially "". *)
+class t_exn =
+object
+  val mutable message = ""
+  method set_message m = message <- m
+  method get_message = message
+end
+
+(* Lets any [t_exn] be raised as an OCaml exception. *)
+exception TExn of t_exn
+
+(* Byte transports.  Failures are raised as [TTransportExn]; its payload is the
+   class [exn] below (which shadows the predefined [exn] type inside this
+   module) and carries a message and an [exn_type]. *)
+module Transport =
+struct
+  type exn_type =
+      | UNKNOWN
+      | NOT_OPEN
+      | ALREADY_OPEN
+      | TIMED_OUT
+      | END_OF_FILE;;
+  (* Payload of [TTransportExn]: message plus failure category. *)
+  class exn =
+  object
+    inherit t_exn
+    val mutable typ = UNKNOWN
+    method get_type = typ
+    method set_type t = typ <- t
+  end
+  exception TTransportExn of exn
+  let raise_TTransportExn message typ = (* build the payload and raise it *)
+    let e = new exn in
+      e#set_message message;
+      e#set_type typ;
+      raise (TTransportExn e)
+  (* Base transport class; [readAll] is implemented on top of the abstract [read]. *)
+  class virtual t =
+  object (self)
+    method virtual isOpen : bool
+    method virtual opn : unit
+    method virtual close : unit
+    method virtual read : string -> int -> int -> int
+    method readAll buf off len = (* keeps calling [read] until [len] bytes have arrived *)
+      let got = ref 0 in
+      let ret = ref 0 in
+        while !got < len do
+          ret := self#read buf (off+(!got)) (len - (!got));
+          if !ret <= 0 then (* no progress: treated as the peer having closed *)
+            raise_TTransportExn
+              "Cannot read. Remote side has closed."
+              END_OF_FILE
+          else ();
+          got := !got + !ret
+        done;
+        !got
+    method virtual write : string -> int -> int -> unit
+    method virtual flush : unit
+  end
+  (* Identity factory: hands back the transport it is given. *)
+  class factory =
+  object
+    method getTransport (t : t) = t
+  end
+  (* Listening endpoint; [accept] delegates to the subclass's [acceptImpl]. *)
+  class virtual server_t =
+  object (self)
+    method virtual listen : unit
+    method accept = self#acceptImpl
+    method virtual close : unit
+    method virtual acceptImpl : t
+  end
+
+end;;
+
+(* Thrift type codes, message kinds, and the abstract protocol class that
+   concrete encodings (such as TBinaryProtocol.t) implement. *)
+module Protocol =
+struct
+  type t_type =
+      | T_STOP
+      | T_VOID
+      | T_BOOL
+      | T_BYTE
+      | T_I08
+      | T_I16
+      | T_I32
+      | T_U64
+      | T_I64
+      | T_DOUBLE
+      | T_STRING
+      | T_UTF7
+      | T_STRUCT
+      | T_MAP
+      | T_SET
+      | T_LIST
+      | T_UTF8
+      | T_UTF16
+  (* Wire codes; aliases share one: T_I08 = T_BYTE (3), T_UTF7 = T_STRING (11). *)
+  let t_type_to_i = function
+      T_STOP       -> 0
+    | T_VOID       -> 1
+    | T_BOOL       -> 2
+    | T_BYTE       -> 3
+    | T_I08        -> 3
+    | T_I16        -> 6
+    | T_I32        -> 8
+    | T_U64        -> 9
+    | T_I64        -> 10
+    | T_DOUBLE     -> 4
+    | T_STRING     -> 11
+    | T_UTF7       -> 11
+    | T_STRUCT     -> 12
+    | T_MAP        -> 13
+    | T_SET        -> 14
+    | T_LIST       -> 15
+    | T_UTF8       -> 16
+    | T_UTF16      -> 17
+  (* Inverse of [t_type_to_i] (codes 3 and 11 decode to T_BYTE / T_STRING); raises Thrift_error on an unknown code. *)
+  let t_type_of_i = function
+      0 -> T_STOP
+    | 1 -> T_VOID
+    | 2 -> T_BOOL
+    | 3 -> T_BYTE
+    | 6 -> T_I16
+    | 8 -> T_I32
+    | 9 -> T_U64
+    | 10 -> T_I64
+    | 4 -> T_DOUBLE
+    | 11 -> T_STRING
+    | 12 -> T_STRUCT
+    | 13 -> T_MAP
+    | 14 -> T_SET
+    | 15 -> T_LIST
+    | 16 -> T_UTF8
+    | 17 -> T_UTF16
+    | _ -> raise Thrift_error
+  (* Kinds of RPC message, encoded as 1-3. *)
+  type message_type =
+    | CALL
+    | REPLY
+    | EXCEPTION
+
+  let message_type_to_i = function
+    | CALL -> 1
+    | REPLY -> 2
+    | EXCEPTION -> 3
+
+  let message_type_of_i = function
+    | 1 -> CALL
+    | 2 -> REPLY
+    | 3 -> EXCEPTION
+    | _ -> raise Thrift_error
+  (* Abstract protocol over [trans]; concrete encodings implement the virtual methods. *)
+  class virtual t (trans: Transport.t) =
+  object (self)
+    val mutable trans_ = trans
+    method getTransport = trans_
+      (* writing methods *)
+    method virtual writeMessageBegin : string * message_type * int -> unit
+    method virtual writeMessageEnd : unit
+    method virtual writeStructBegin : string -> unit
+    method virtual writeStructEnd : unit
+    method virtual writeFieldBegin : string * t_type * int -> unit
+    method virtual writeFieldEnd : unit
+    method virtual writeFieldStop : unit
+    method virtual writeMapBegin : t_type * t_type * int -> unit
+    method virtual writeMapEnd : unit
+    method virtual writeListBegin : t_type * int -> unit
+    method virtual writeListEnd : unit
+    method virtual writeSetBegin : t_type * int -> unit
+    method virtual writeSetEnd : unit
+    method virtual writeBool : bool -> unit
+    method virtual writeByte : int -> unit
+    method virtual writeI16 : int -> unit
+    method virtual writeI32 : int -> unit
+    method virtual writeI64 : Int64.t -> unit
+    method virtual writeDouble : float -> unit
+    method virtual writeString : string -> unit
+    method virtual writeBinary : string -> unit
+      (* reading methods *)
+    method virtual readMessageBegin : string * message_type * int
+    method virtual readMessageEnd : unit
+    method virtual readStructBegin : string
+    method virtual readStructEnd : unit
+    method virtual readFieldBegin : string * t_type * int
+    method virtual readFieldEnd : unit
+    method virtual readMapBegin : t_type * t_type * int
+    method virtual readMapEnd : unit
+    method virtual readListBegin : t_type * int
+    method virtual readListEnd : unit
+    method virtual readSetBegin : t_type * int
+    method virtual readSetEnd : unit
+    method virtual readBool : bool
+    method virtual readByte : int
+    method virtual readI16 : int
+    method virtual readI32: int
+    method virtual readI64 : Int64.t
+    method virtual readDouble : float
+    method virtual readString : string
+    method virtual readBinary : string
+    (* [skip typ] reads and discards one complete value of type [typ]. *)
+    method skip typ =
+      match typ with
+        | T_STOP -> ()
+        | T_VOID -> ()
+        | T_BOOL -> ignore self#readBool
+        | T_BYTE
+        | T_I08 -> ignore self#readByte
+        | T_I16 -> ignore self#readI16
+        | T_I32 -> ignore self#readI32
+        | T_U64
+        | T_I64 -> ignore self#readI64
+        | T_DOUBLE -> ignore self#readDouble
+        | T_STRING
+        | T_UTF7 -> ignore self#readString (* same wire code as T_STRING *)
+        | T_STRUCT -> ignore ((ignore self#readStructBegin);
+                              (try
+                                   while true do
+                                     let (_,t,_) = self#readFieldBegin in
+                                       if t = T_STOP then
+                                         raise Break
+                                       else
+                                         (self#skip t;
+                                          self#readFieldEnd)
+                                   done
+                               with Break -> ());
+                              self#readStructEnd)
+        | T_MAP -> ignore (let (k,v,s) = self#readMapBegin in
+                             for _i=1 to s do (* [for] bounds are inclusive: 1..s is exactly s entries *)
+                               self#skip k;
+                               self#skip v;
+                             done;
+                             self#readMapEnd)
+        | T_SET -> ignore (let (t,s) = self#readSetBegin in
+                             for _i=1 to s do
+                               self#skip t
+                             done;
+                             self#readSetEnd)
+        | T_LIST -> ignore (let (t,s) = self#readListBegin in
+                              for _i=1 to s do
+                                self#skip t
+                              done;
+                              self#readListEnd)
+        | T_UTF8 -> () (* NOTE(review): nothing is consumed for this code *)
+        | T_UTF16 -> ()
+  end
+  (* Builds a protocol instance on a given transport. *)
+  class virtual factory =
+  object
+    method virtual getProtocol : Transport.t -> t
+  end
+
+end;;
+
+
+module Processor =
+struct
+  (* [process iprot oprot] is called in a loop by the servers (TServer,
+     TSimpleServer), which stop as soon as it returns false. *)
+  class virtual t =
+  object
+    method virtual process : Protocol.t -> Protocol.t -> bool
+  end
+
+  (* Always hands back the same processor, whatever the connection. *)
+  class factory (processor : t) =
+  object
+    method getProcessor (_ : Transport.t) = processor
+  end
+end
+
+module Application_Exn =
+struct
+  type typ =
+      | UNKNOWN
+      | UNKNOWN_METHOD
+      | INVALID_MESSAGE_TYPE
+      | WRONG_METHOD_NAME
+      | BAD_SEQUENCE_ID
+      | MISSING_RESULT
+  (* Wire codes of [typ]: 0-5 in declaration order. *)
+  let typ_of_i = function
+      0 -> UNKNOWN
+    | 1 -> UNKNOWN_METHOD
+    | 2 -> INVALID_MESSAGE_TYPE
+    | 3 -> WRONG_METHOD_NAME
+    | 4 -> BAD_SEQUENCE_ID
+    | 5 -> MISSING_RESULT
+    | _ -> raise Thrift_error;;
+  let typ_to_i = function
+    | UNKNOWN -> 0
+    | UNKNOWN_METHOD -> 1
+    | INVALID_MESSAGE_TYPE -> 2
+    | WRONG_METHOD_NAME -> 3
+    | BAD_SEQUENCE_ID -> 4
+    | MISSING_RESULT -> 5
+  (* Application-level exception: an inherited message plus a [typ]. *)
+  class t =
+  object (self)
+    inherit t_exn
+    val mutable typ = UNKNOWN
+    method get_type = typ
+    method set_type t = typ <- t
+    method write (oprot : Protocol.t) = (* field 1: message (omitted when empty), field 2: i32 type *)
+      oprot#writeStructBegin "TApplicationException";
+      if self#get_message <> "" then (* compare contents; physical [!=] is unreliable on strings *)
+        (oprot#writeFieldBegin ("message",Protocol.T_STRING, 1);
+         oprot#writeString self#get_message;
+         oprot#writeFieldEnd)
+      else ();
+      oprot#writeFieldBegin ("type",Protocol.T_I32,2);
+      oprot#writeI32 (typ_to_i typ);
+      oprot#writeFieldEnd;
+      oprot#writeFieldStop;
+      oprot#writeStructEnd
+  end;;
+  (* [create typ msg] builds an exception with the given type and message. *)
+  let create typ msg =
+    let e = new t in
+      e#set_type typ;
+      e#set_message msg;
+      e
+  (* Reads a struct written by [t#write]; unknown ids and mistyped fields are skipped. *)
+  let read (iprot : Protocol.t) =
+    let msg = ref "" in
+    let typ = ref 0 in
+      ignore iprot#readStructBegin;
+      (try
+           while true do
+             let (_,ft,id) = iprot#readFieldBegin in
+               if ft = Protocol.T_STOP then
+                 raise Break
+               else ();
+               (match id with
+             | 1 -> (if ft = Protocol.T_STRING then
+                         msg := (iprot#readString)
+                     else
+                         iprot#skip ft)
+             | 2 -> (if ft = Protocol.T_I32 then
+                         typ := iprot#readI32
+                     else
+                         iprot#skip ft)
+             | _ -> iprot#skip ft);
+               iprot#readFieldEnd
+      done
+       with Break -> ());
+      iprot#readStructEnd;
+      (* an unrecognized type code yields UNKNOWN instead of raising Thrift_error *)
+      create
+        (try typ_of_i !typ with Thrift_error -> UNKNOWN)
+        !msg;;
+  (* Raisable wrapper for [t]. *)
+  exception E of t
+end;;