Thrift: OCaml library and generator
Summary: Added (minimal) library and code generator for OCaml.
Reviewed by: mcslee
Test plan: Test client and server (included).
Revert plan: yes
git-svn-id: https://svn.apache.org/repos/asf/incubator/thrift/trunk@665163 13f79535-47bb-0310-9956-ffa450edef68
diff --git a/lib/ocaml/src/Makefile b/lib/ocaml/src/Makefile
new file mode 100644
index 0000000..0b989ce
--- /dev/null
+++ b/lib/ocaml/src/Makefile
@@ -0,0 +1,6 @@
+SOURCES = Thrift.ml TBinaryProtocol.ml TSocket.ml TChannelTransport.ml TServer.ml TSimpleServer.ml
+RESULT = thrift
+LIBS = unix
+all: native-code-library byte-code-library top  # these goals are presumably defined by the included OCamlMakefile
+OCAMLMAKEFILE = ../OCamlMakefile
+include $(OCAMLMAKEFILE)
diff --git a/lib/ocaml/src/TBinaryProtocol.ml b/lib/ocaml/src/TBinaryProtocol.ml
new file mode 100644
index 0000000..44433d6
--- /dev/null
+++ b/lib/ocaml/src/TBinaryProtocol.ml
@@ -0,0 +1,145 @@
+open Thrift
+
+module P = Protocol
+
+(* [get_byte i b] / [get_byte64 i b]: byte number [b] of [i] (0 = least significant), in 0..255. *)
+let get_byte i b = (i lsr (8 * b)) land 0xff
+let get_byte64 i b = (Int64.to_int (Int64.shift_right i (8 * b))) land 0xff
+
+(* Short aliases for the wire-code conversions of [Protocol.t_type]: *)
+let tv = P.t_type_to_i  (* type tag -> integer code *)
+let vt = P.t_type_of_i  (* integer code -> type tag; raises [Thrift_error] on unknown codes *)
+
+(* [comp_int b n] decodes the first [n] bytes of [b] as a big-endian two's-complement
+   integer.  A native [int] has [Sys.word_size - 1] bits, so the encoded sign bit must be
+   shifted by [Sys.word_size - 1 - 8*n]; a shift of [Sys.word_size - 8*n] would push it out. *)
+let comp_int b n =
+  let s = ref 0 in
+  let sb = max 0 (Sys.word_size - 1 - 8*n) in
+  for i = 0 to n - 1 do s := (!s lsl 8) lor (int_of_char b.[i]) done;
+  (!s lsl sb) asr sb
+
+(* [comp_int64 b n] folds the first [n] bytes of [b], most significant first, into an
+   [Int64.t]; with [n = 8] the top bit of [b.[0]] lands on the Int64 sign bit. *)
+let comp_int64 b n =
+  let rec go acc k =
+    if k >= n then acc
+    else go (Int64.logor (Int64.shift_left acc 8) (Int64.of_int (int_of_char b.[k]))) (k + 1)
+  in go 0L 0
+(* Thrift binary protocol on top of [trans]: fixed-width integers go out most significant
+   byte first, doubles as their Int64 bit pattern, strings as an i32 length plus the raw
+   bytes.  Struct and field names are not written; the read side returns "" for them. *)
+class t trans =
+object (self)
+  inherit P.t trans
+  val ibyte = String.create 8 (* scratch buffer reused by every fixed-width read and write *)
+  method writeBool b =
+    ibyte.[0] <- (if b then '\001' else '\000');
+    trans#write ibyte 0 1
+  method writeByte i =
+    ibyte.[0] <- char_of_int (get_byte i 0);
+    trans#write ibyte 0 1
+  method writeI16 i =
+    ibyte.[0] <- char_of_int (get_byte i 1);
+    ibyte.[1] <- char_of_int (get_byte i 0);
+    trans#write ibyte 0 2
+  method writeI32 i =
+    for pos = 0 to 3 do
+      ibyte.[pos] <- char_of_int (get_byte i (3 - pos))
+    done;
+    trans#write ibyte 0 4
+  method writeI64 i =
+    for pos = 0 to 7 do
+      ibyte.[pos] <- char_of_int (get_byte64 i (7 - pos))
+    done;
+    trans#write ibyte 0 8
+  method writeDouble d = self#writeI64 (Int64.bits_of_float d)
+  (* Length prefix first, then the bytes of [s] unchanged. *)
+  method writeString s =
+    let len = String.length s in
+    self#writeI32 len;
+    trans#write s 0 len
+  method writeBinary a = self#writeString a
+  (* Message header: name, one-byte message-type code, then the sequence id [s]. *)
+  method writeMessageBegin (n, t, s) =
+    self#writeString n;
+    self#writeByte (P.message_type_to_i t);
+    self#writeI32 s
+  method writeMessageEnd = ()
+  method writeStructBegin _ = ()
+  method writeStructEnd = ()
+  (* Field header: type code, then field id; the field name is dropped. *)
+  method writeFieldBegin (_, t, i) =
+    self#writeByte (tv t);
+    self#writeI16 i
+  method writeFieldEnd = ()
+  method writeFieldStop = self#writeByte (tv P.T_STOP)
+  method writeMapBegin (k, v, s) =
+    self#writeByte (tv k);
+    self#writeByte (tv v);
+    self#writeI32 s
+  method writeMapEnd = ()
+  method writeListBegin (t, s) =
+    self#writeByte (tv t);
+    self#writeI32 s
+  method writeListEnd = ()
+  method writeSetBegin (t, s) =
+    self#writeByte (tv t);
+    self#writeI32 s
+  method writeSetEnd = ()
+  (* Every fixed-width read first fills [ibyte] through [readAll], then decodes it. *)
+  method readByte =
+    let _ = trans#readAll ibyte 0 1 in
+    comp_int ibyte 1
+  method readI16 =
+    let _ = trans#readAll ibyte 0 2 in
+    comp_int ibyte 2
+  method readI32 =
+    let _ = trans#readAll ibyte 0 4 in
+    comp_int ibyte 4
+  method readI64 =
+    let _ = trans#readAll ibyte 0 8 in
+    comp_int64 ibyte 8
+  method readDouble = Int64.float_of_bits self#readI64
+  method readBool = (self#readByte = 1)
+  (* Reads the i32 length, then exactly that many bytes into a fresh string. *)
+  method readString =
+    let len = self#readI32 in
+    let data = String.create len in
+    let _ = trans#readAll data 0 len in
+    data
+  method readBinary = self#readString
+  (* Name and type are let-bound before the tuple so the parts are read in wire order. *)
+  method readMessageBegin =
+    let name = self#readString in
+    let kind = P.message_type_of_i self#readByte in
+    (name, kind, self#readI32)
+  method readMessageEnd = ()
+  method readStructBegin = ""
+  method readStructEnd = ()
+  (* A T_STOP type byte has no field id after it; 0 is returned in its place. *)
+  method readFieldBegin =
+    match vt self#readByte with
+    | P.T_STOP -> ("", P.T_STOP, 0)
+    | t -> ("", t, self#readI16)
+  method readFieldEnd = ()
+  method readMapBegin =
+    let key_type = vt self#readByte in
+    let value_type = vt self#readByte in
+    (key_type, value_type, self#readI32)
+  method readMapEnd = ()
+  method readListBegin =
+    let elt_type = vt self#readByte in
+    (elt_type, self#readI32)
+  method readListEnd = ()
+  method readSetBegin =
+    let elt_type = vt self#readByte in
+    (elt_type, self#readI32)
+  method readSetEnd = ()
+end
+(* Protocol factory that wraps any transport in a binary protocol [t]. *)
+class factory =
+object
+  inherit P.factory
+  method getProtocol transport = new t transport
+end
diff --git a/lib/ocaml/src/TChannelTransport.ml b/lib/ocaml/src/TChannelTransport.ml
new file mode 100644
index 0000000..89ae352
--- /dev/null
+++ b/lib/ocaml/src/TChannelTransport.ml
@@ -0,0 +1,16 @@
+open Thrift
+module T = Transport
+
+(* Transport over an existing channel pair [(i,o)]: always open, [opn] and [close] do nothing. *)
+class t (i,o) = object
+  inherit Transport.t
+  method isOpen = true
+  method opn = ()
+  method close = ()
+  method read buf off len = (* exactly [len] bytes, or TTransportExn (END_OF_FILE on end of input) *)
+    try really_input i buf off len; len
+    with End_of_file -> T.raise_TTransportExn ("TChannelTransport: Could not read "^(string_of_int len)) T.END_OF_FILE
+       | _ -> T.raise_TTransportExn ("TChannelTransport: Could not read "^(string_of_int len)) T.UNKNOWN
+  method write buf off len = output o buf off len
+  method flush = flush o
+end
diff --git a/lib/ocaml/src/TServer.ml b/lib/ocaml/src/TServer.ml
new file mode 100644
index 0000000..d8509ff
--- /dev/null
+++ b/lib/ocaml/src/TServer.ml
@@ -0,0 +1,30 @@
+open Thrift
+
+(* Abstract server: stores the six collaborators as instance variables; subclasses supply [serve]. *)
+class virtual t
+  (processor_factory : Processor.factory)
+  (server_transport : Transport.server_t)
+  (input_transport_factory : Transport.factory)
+  (output_transport_factory : Transport.factory)
+  (input_protocol_factory : Protocol.factory)
+  (output_protocol_factory : Protocol.factory) =
+object
+  val processorFactory = processor_factory
+  val serverTransport = server_transport
+  val inputTransportFactory = input_transport_factory
+  val outputTransportFactory = output_transport_factory
+  val inputProtocolFactory = input_protocol_factory
+  val outputProtocolFactory = output_protocol_factory
+  method virtual serve : unit
+end;;
+
+(* Serves [proc] on 127.0.0.1:[port] via [Unix.establish_server], binary protocol over each
+   connection; transport errors end a connection quietly, other exceptions go to stderr. *)
+let run_basic_server proc port =
+  Unix.establish_server (fun inp out ->
+    let trans = new TChannelTransport.t (inp,out) in
+    let proto = new TBinaryProtocol.t (trans :> Transport.t) in
+    try while proc#process proto proto do () done
+    with Transport.TTransportExn _ -> ()
+       | e -> prerr_endline ("run_basic_server: "^(Printexc.to_string e)))
+    (Unix.ADDR_INET (Unix.inet_addr_of_string "127.0.0.1",port))
diff --git a/lib/ocaml/src/TSimpleServer.ml b/lib/ocaml/src/TSimpleServer.ml
new file mode 100644
index 0000000..1a85809
--- /dev/null
+++ b/lib/ocaml/src/TSimpleServer.ml
@@ -0,0 +1,24 @@
+open Thrift
+module S = TServer
+
+(* Single-connection server: [serve] listens, accepts one client, runs the processor on it
+   until [process] returns false or raises, closes both transports, and swallows any exception. *)
+class t pf st itf otf ipf opf =
+object
+  inherit S.t pf st itf otf ipf opf
+  method serve =
+    try
+      st#listen;
+      let conn = st#accept in
+      let processor = pf#getProcessor conn in
+      let itrans = itf#getTransport conn in
+      let otrans = (try otf#getTransport conn with e -> itrans#close; raise e) in
+      let iprot = ipf#getProtocol itrans in
+      let oprot = opf#getProtocol otrans in
+      let close_both () = (itrans#close; otrans#close) in
+      (try
+         while processor#process iprot oprot do () done;
+         close_both ()
+       with e -> close_both (); raise e)
+    with _ -> ()
+end
diff --git a/lib/ocaml/src/TSocket.ml b/lib/ocaml/src/TSocket.ml
new file mode 100644
index 0000000..c02f1eb
--- /dev/null
+++ b/lib/ocaml/src/TSocket.ml
@@ -0,0 +1,32 @@
+open Thrift
+
+module T = Transport
+
+(* Client socket transport; [host] is parsed with Unix.inet_addr_of_string; [chans] is [Some] while connected. *)
+class t host port=
+object
+  inherit T.t
+  val mutable chans = None
+  method isOpen = (match chans with None -> false | Some _ -> true)
+  method opn =
+    try chans <- Some(Unix.open_connection (Unix.ADDR_INET ((Unix.inet_addr_of_string host),port)))
+    with _ -> T.raise_TTransportExn ("Could not connect to "^host^":"^(string_of_int port)) T.NOT_OPEN
+  (* shutdown_connection only signals end-of-file; closing the output channel frees the shared descriptor. *)
+  method close = match chans with None -> () | Some(inc,outc) ->
+    chans <- None;
+    (try Unix.shutdown_connection inc with e -> close_out_noerr outc; raise e);
+    close_out_noerr outc
+  (* Reads exactly [len] bytes; End_of_file is reported as END_OF_FILE, other failures as UNKNOWN. *)
+  method read buf off len = match chans with
+    | None -> T.raise_TTransportExn "Socket not open" T.NOT_OPEN
+    | Some(i,_) ->
+        try really_input i buf off len; len
+        with End_of_file -> T.raise_TTransportExn ("TSocket: Could not read "^(string_of_int len)^" from "^host^":"^(string_of_int port)) T.END_OF_FILE
+           | _ -> T.raise_TTransportExn ("TSocket: Could not read "^(string_of_int len)^" from "^host^":"^(string_of_int port)) T.UNKNOWN
+  method write buf off len = match chans with
+    | None -> T.raise_TTransportExn "Socket not open" T.NOT_OPEN
+    | Some(_,o) -> output o buf off len
+  method flush = match chans with
+    | None -> T.raise_TTransportExn "Socket not open" T.NOT_OPEN
+    | Some(_,o) -> flush o
+end
diff --git a/lib/ocaml/src/Thrift.ml b/lib/ocaml/src/Thrift.ml
new file mode 100644
index 0000000..224febb
--- /dev/null
+++ b/lib/ocaml/src/Thrift.ml
@@ -0,0 +1,357 @@
+(* Break leaves the [while true] read loops; Thrift_error is raised by the *_of_i conversions for unknown codes; Field_empty is not raised in this file (presumably meant for generated code - confirm). *)
+exception Break;;
+exception Thrift_error;;
+exception Field_empty of string;;
+(* Base class of the exception payload objects: a mutable message with getter and setter. *)
+class t_exn =
+object
+  val mutable message = ""
+  method set_message text = message <- text
+  method get_message = message
+end;;
+
+(* Exception wrapper around a [t_exn] payload. *)
+exception TExn of t_exn;;
+
+
+module Transport =
+struct
+ (* Category attached to a transport exception object (see [set_type]/[get_type] on [exn]). *)
+ type exn_type =
+ | UNKNOWN
+ | NOT_OPEN
+ | ALREADY_OPEN
+ | TIMED_OUT
+ | END_OF_FILE;;
+ (* Payload of [TTransportExn]: a [t_exn] message plus an [exn_type], UNKNOWN until set. *)
+ class exn = object
+   inherit t_exn
+   val mutable typ = UNKNOWN
+   method set_type kind = typ <- kind
+   method get_type = typ
+ end
+ exception TTransportExn of exn
+ (* Builds an [exn] carrying [message] and [typ] and raises it as [TTransportExn]. *)
+ let raise_TTransportExn message typ =
+   let err = new exn in
+   err#set_type typ;
+   err#set_message message;
+   raise (TTransportExn err)
+ (* Abstract byte transport.  Subclasses provide open/close, [read], [write] and [flush];
+    [readAll] is built on [read]. *)
+ class virtual t =
+ object (self)
+   method virtual isOpen : bool
+   method virtual opn : unit
+   method virtual close : unit
+   method virtual read : string -> int -> int -> int
+   (* Calls [read] until [len] bytes are stored in [buf] from [off] on; returns the count read.
+      A read that yields no bytes means the peer is gone: raises [TTransportExn] typed END_OF_FILE. *)
+   method readAll buf off len =
+     let got = ref 0 in
+     while !got < len do
+       let n = self#read buf (off + !got) (len - !got) in
+       if n <= 0 then
+         raise_TTransportExn "Cannot read. Remote side has closed." END_OF_FILE;
+       got := !got + n
+     done;
+     !got
+   method virtual write : string -> int -> int -> unit
+   method virtual flush : unit
+ end
+
+ (* Identity transport factory: [getTransport] hands back the transport it is given. *)
+ class factory =
+ object
+   method getTransport (trans : t) = trans
+ end
+ (* Abstract listening transport: [accept] simply delegates to the subclass's [acceptImpl]. *)
+ class virtual server_t =
+ object (self)
+   method virtual listen : unit
+   method virtual acceptImpl : t
+   method virtual close : unit
+   method accept = self#acceptImpl
+ end
+end;;
+
+
+
+module Protocol =
+struct
+ (* Thrift type tags; [t_type_to_i] and [t_type_of_i] map them to and from integer codes. *)
+ type t_type =
+ | T_STOP
+ | T_VOID
+ | T_BOOL
+ | T_BYTE
+ | T_I08
+ | T_I16
+ | T_I32
+ | T_U64
+ | T_I64
+ | T_DOUBLE
+ | T_STRING
+ | T_UTF7
+ | T_STRUCT
+ | T_MAP
+ | T_SET
+ | T_LIST
+ | T_UTF8
+ | T_UTF16
+ (* Integer code of a type tag.  T_BYTE/T_I08 share code 3 and T_STRING/T_UTF7 share 11. *)
+ let t_type_to_i ty =
+   match ty with
+   | T_STOP -> 0
+   | T_VOID -> 1
+   | T_BOOL -> 2
+   | T_BYTE | T_I08 -> 3
+   | T_DOUBLE -> 4
+   | T_I16 -> 6
+   | T_I32 -> 8
+   | T_U64 -> 9
+   | T_I64 -> 10
+   | T_STRING | T_UTF7 -> 11
+   | T_STRUCT -> 12
+   | T_MAP -> 13
+   | T_SET -> 14
+   | T_LIST -> 15
+   | T_UTF8 -> 16
+   | T_UTF16 -> 17
+
+ (* Type tag for an integer code (3 -> T_BYTE, 11 -> T_STRING); raises [Thrift_error] otherwise. *)
+ let t_type_of_i code = match code with
+   | 0 -> T_STOP
+   | 1 -> T_VOID
+   | 2 -> T_BOOL
+   | 3 -> T_BYTE
+   | 4 -> T_DOUBLE
+   | 6 -> T_I16
+   | 8 -> T_I32
+   | 9 -> T_U64
+   | 10 -> T_I64
+   | 11 -> T_STRING
+   | 12 -> T_STRUCT
+   | 13 -> T_MAP
+   | 14 -> T_SET
+   | 15 -> T_LIST
+   | 16 -> T_UTF8
+   | 17 -> T_UTF16
+   | _ -> raise Thrift_error
+ (* Kind of a message; converted to and from integers by [message_type_to_i]/[message_type_of_i]. *)
+ type message_type =
+ | CALL
+ | REPLY
+ | EXCEPTION
+ (* Wire code of a message type: CALL = 1, REPLY = 2, EXCEPTION = 3. *)
+ let message_type_to_i mt = match mt with
+   | CALL -> 1
+   | REPLY -> 2
+   | EXCEPTION -> 3
+ (* Message type for a wire code; raises [Thrift_error] unless the code is 1, 2 or 3. *)
+ let message_type_of_i code = match code with
+   | 1 -> CALL
+   | 2 -> REPLY
+   | 3 -> EXCEPTION
+   | _ -> raise Thrift_error
+ (* Abstract protocol bound to a transport.  Concrete protocols implement every virtual
+    write/read method below; [skip] is implemented here on top of the read methods. *)
+ class virtual t (trans: Transport.t) =
+ object (self)
+   val mutable trans_ = trans
+   method getTransport = trans_
+   (* writing methods *)
+   method virtual writeMessageBegin : string * message_type * int -> unit
+   method virtual writeMessageEnd : unit
+   method virtual writeStructBegin : string -> unit
+   method virtual writeStructEnd : unit
+   method virtual writeFieldBegin : string * t_type * int -> unit
+   method virtual writeFieldEnd : unit
+   method virtual writeFieldStop : unit
+   method virtual writeMapBegin : t_type * t_type * int -> unit
+   method virtual writeMapEnd : unit
+   method virtual writeListBegin : t_type * int -> unit
+   method virtual writeListEnd : unit
+   method virtual writeSetBegin : t_type * int -> unit
+   method virtual writeSetEnd : unit
+   method virtual writeBool : bool -> unit
+   method virtual writeByte : int -> unit
+   method virtual writeI16 : int -> unit
+   method virtual writeI32 : int -> unit
+   method virtual writeI64 : Int64.t -> unit
+   method virtual writeDouble : float -> unit
+   method virtual writeString : string -> unit
+   method virtual writeBinary : string -> unit
+   (* reading methods *)
+   method virtual readMessageBegin : string * message_type * int
+   method virtual readMessageEnd : unit
+   method virtual readStructBegin : string
+   method virtual readStructEnd : unit
+   method virtual readFieldBegin : string * t_type * int
+   method virtual readFieldEnd : unit
+   method virtual readMapBegin : t_type * t_type * int
+   method virtual readMapEnd : unit
+   method virtual readListBegin : t_type * int
+   method virtual readListEnd : unit
+   method virtual readSetBegin : t_type * int
+   method virtual readSetEnd : unit
+   method virtual readBool : bool
+   method virtual readByte : int
+   method virtual readI16 : int
+   method virtual readI32: int
+   method virtual readI64 : Int64.t
+   method virtual readDouble : float
+   method virtual readString : string
+   method virtual readBinary : string
+   (* skippage *)
+   (* [skip typ] reads and discards one value of type [typ].  Structs are skipped field by
+      field up to the stop marker; containers read their header and then skip exactly as
+      many elements (or key/value pairs) as the header's size says. *)
+   method skip typ =
+     match typ with
+     | T_STOP | T_VOID -> ()
+     | T_BOOL -> ignore self#readBool
+     | T_BYTE | T_I08 -> ignore self#readByte
+     | T_I16 -> ignore self#readI16
+     | T_I32 -> ignore self#readI32
+     | T_U64 | T_I64 -> ignore self#readI64
+     | T_DOUBLE -> ignore self#readDouble
+     | T_STRING -> ignore self#readString
+     (* Nothing is consumed for the UTF tags. *)
+     | T_UTF7 | T_UTF8 | T_UTF16 -> ()
+     | T_STRUCT ->
+         ignore self#readStructBegin;
+         (try
+            while true do
+              let (_,t,_) = self#readFieldBegin in
+              if t = T_STOP then raise Break;
+              self#skip t;
+              self#readFieldEnd
+            done
+          with Break -> ());
+         self#readStructEnd
+     | T_MAP ->
+         let (k,v,s) = self#readMapBegin in
+         (* [1 to s] runs exactly s times ([0 to s] would consume one entry too many). *)
+         for _i = 1 to s do
+           self#skip k;
+           self#skip v
+         done;
+         self#readMapEnd
+     | T_SET ->
+         let (t,s) = self#readSetBegin in
+         for _i = 1 to s do self#skip t done;
+         self#readSetEnd
+     | T_LIST ->
+         let (t,s) = self#readListBegin in
+         for _i = 1 to s do self#skip t done;
+         self#readListEnd
+ end
+
+ (* Abstract factory: [getProtocol] takes a transport and returns a protocol [t]. *)
+ class virtual factory =
+ object
+ method virtual getProtocol : Transport.t -> t
+ end
+end;;
+
+
+module Processor =
+struct
+  (* One processing step: run_basic_server and TSimpleServer loop while [process] returns true. *)
+  class virtual t =
+  object
+    method virtual process : Protocol.t -> Protocol.t -> bool
+  end;;
+
+  (* Factory that hands out the same [processor] whatever transport it is asked about. *)
+  class factory (processor : t) =
+  object
+    val processor_ = processor
+    method getProcessor (_ : Transport.t) = processor_
+  end;;
+end
+
+module Application_Exn =
+struct
+ (* Kind of application exception; [typ_to_i]/[typ_of_i] give the integer codes 0..5. *)
+ type typ=
+ | UNKNOWN
+ | UNKNOWN_METHOD
+ | INVALID_MESSAGE_TYPE
+ | WRONG_METHOD_NAME
+ | BAD_SEQUENCE_ID
+ | MISSING_RESULT
+ (* Error kind for wire codes 0..5; raises [Thrift_error] for any other integer. *)
+ let typ_of_i code = match code with
+   | 0 -> UNKNOWN
+   | 1 -> UNKNOWN_METHOD
+   | 2 -> INVALID_MESSAGE_TYPE
+   | 3 -> WRONG_METHOD_NAME
+   | 4 -> BAD_SEQUENCE_ID | 5 -> MISSING_RESULT
+   | _ -> raise Thrift_error;;
+ (* Wire code (0..5) of an error kind; the inverse of [typ_of_i]. *)
+ let typ_to_i kind = match kind with
+   | UNKNOWN -> 0
+   | UNKNOWN_METHOD -> 1
+   | INVALID_MESSAGE_TYPE -> 2
+   | WRONG_METHOD_NAME -> 3
+   | BAD_SEQUENCE_ID -> 4
+   | MISSING_RESULT -> 5
+ (* Application-level exception object: a [t_exn] message plus a [typ], UNKNOWN until set. *)
+ class t = object (self)
+   inherit t_exn
+   val mutable typ = UNKNOWN
+   method get_type = typ
+   method set_type t = typ <- t
+   (* Writes a struct: field 1 (string) is the message, omitted when empty; field 2 (i32) is
+      the type code.  [<>] is used because physical [!=] on strings tests identity, not content. *)
+   method write (oprot : Protocol.t) =
+     oprot#writeStructBegin "TApplicationException";
+     if self#get_message <> "" then
+       (oprot#writeFieldBegin ("message",Protocol.T_STRING, 1);
+        oprot#writeString self#get_message;
+        oprot#writeFieldEnd);
+     oprot#writeFieldBegin ("type",Protocol.T_I32,2);
+     oprot#writeI32 (typ_to_i typ);
+     oprot#writeFieldEnd;
+     oprot#writeFieldStop;
+     oprot#writeStructEnd
+ end;;
+ (* [create typ msg] returns a fresh exception object with that type and message. *)
+ let create typ msg =
+   let exn_obj = new t in
+   exn_obj#set_message msg;
+   exn_obj#set_type typ;
+   exn_obj
+ (* Decodes an application exception struct from [iprot]: field 1 (string) is the message,
+    field 2 (i32) the type code; a field with another id, or with one of these ids but an
+    unexpected wire type, is skipped.  Missing fields leave the message "" and the code 0.
+    Raises [Thrift_error] (from [typ_of_i]) if the decoded type code is unknown. *)
+ let read (iprot : Protocol.t) =
+   let msg = ref "" in
+   let typ = ref 0 in
+   (* readStructBegin returns the struct name, which is not needed here. *)
+   ignore iprot#readStructBegin;
+   (try
+      while true do
+        let (_,ft,id) = iprot#readFieldBegin in
+        if ft = Protocol.T_STOP then
+          raise Break;
+        (* Id and wire type are matched together so a mistyped field is never misread. *)
+        (match id, ft with
+         | 1, Protocol.T_STRING -> msg := iprot#readString
+         | 2, Protocol.T_I32 -> typ := iprot#readI32
+         | _ -> iprot#skip ft);
+        iprot#readFieldEnd
+      done
+    with Break -> ());
+   iprot#readStructEnd;
+   (* Build the result object from whatever was decoded. *)
+   let e = new t in
+   e#set_type (typ_of_i !typ);
+   e#set_message !msg;
+   e;;
+
+ exception E of t (* exception carrying an application exception object [t] *)
+end;;