THRIFT-906. hs: Improve type mappings

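This patch makes the Haskell type mappings more precise: byte now maps to Word8, i16 to Int16, i32 to Int32, and binary to a lazy ByteString, in both the generated code and the Protocol class. It *will* break existing code, but the breakage should be well worth it.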

Patch: Christian Lavoie

git-svn-id: https://svn.apache.org/repos/asf/incubator/thrift/trunk@999700 13f79535-47bb-0310-9956-ffa450edef68
diff --git a/compiler/cpp/src/generate/t_hs_generator.cc b/compiler/cpp/src/generate/t_hs_generator.cc
index 7b178df..47d4faa 100644
--- a/compiler/cpp/src/generate/t_hs_generator.cc
+++ b/compiler/cpp/src/generate/t_hs_generator.cc
@@ -232,7 +232,7 @@
     result += "\n";
   }
 
-  result += "import Thrift\nimport Data.Typeable ( Typeable )\nimport Control.Exception\nimport qualified Data.Map as Map\nimport qualified Data.Set as Set\nimport Data.Int;\nimport Prelude ((==), String, Eq, Show, Ord, Maybe(..), (&&), (||), return, IO, Enum, fromEnum, toEnum, Bool(..), (++), ($), Double, (-))";
+  result += "import Thrift\nimport Data.Typeable ( Typeable )\nimport Control.Exception\nimport qualified Data.Map as Map\nimport qualified Data.Set as Set\nimport Data.ByteString.Lazy\nimport Data.Int\nimport Data.Word\nimport Prelude ((==), String, Eq, Show, Ord, Maybe(..), (&&), (||), return, IO, Enum, fromInteger, toInteger, fromEnum, toEnum, Bool(..), (++), ($), Double, (-), length)";
   return result;
 }
 
@@ -337,18 +337,31 @@
   if (type->is_base_type()) {
     t_base_type::t_base tbase = ((t_base_type*)type)->get_base();
     switch (tbase) {
+
     case t_base_type::TYPE_STRING:
       out << '"' << get_escaped_string(value) << '"';
       break;
+
     case t_base_type::TYPE_BOOL:
       out << (value->get_integer() > 0 ? "True" : "False");
       break;
+
     case t_base_type::TYPE_BYTE:
-    case t_base_type::TYPE_I16:
-    case t_base_type::TYPE_I32:
-    case t_base_type::TYPE_I64:
-      out << value->get_integer();
+      out << "(" << value->get_integer() << " :: Word8)";
       break;
+
+    case t_base_type::TYPE_I16:
+      out << "(" << value->get_integer() << " :: Int16)";
+      break;
+
+    case t_base_type::TYPE_I32:
+      out << "(" << value->get_integer() << " :: Int32)";
+      break;
+
+    case t_base_type::TYPE_I64:
+      out << "(" << value->get_integer() << " :: Int64)";
+      break;
+
     case t_base_type::TYPE_DOUBLE:
       if (value->get_type() == t_const_value::CV_INTEGER) {
         out << value->get_integer();
@@ -1116,7 +1129,7 @@
       throw "compiler error: cannot serialize void field in a struct";
       break;
     case t_base_type::TYPE_STRING:
-      out << "readString";
+      out << (((t_base_type*)type)->is_binary() ? "readBinary" : "readString");
       break;
     case t_base_type::TYPE_BOOL:
       out << "readBool";
@@ -1234,7 +1247,7 @@
           "compiler error: cannot serialize void field in a struct: " + name;
         break;
       case t_base_type::TYPE_STRING:
-        out << "writeString oprot " << name;
+        out << (((t_base_type*)type)->is_binary() ? "writeBinary" : "writeString") << " oprot " << name;
         break;
       case t_base_type::TYPE_BOOL:
         out << "writeBool oprot " << name;
@@ -1300,7 +1313,7 @@
     string v = tmp("_viter");
     out << "(let {f [] = return (); f ("<<v<<":t) = do {";
     generate_serialize_list_element(out, (t_list*)ttype, v);
-    out << ";f t}} in do {writeListBegin oprot ("<< type_to_enum(((t_list*)ttype)->get_elem_type())<<",length " << prefix << "); f " << prefix << ";writeListEnd oprot})";
+    out << ";f t}} in do {writeListBegin oprot ("<< type_to_enum(((t_list*)ttype)->get_elem_type())<<",fromInteger $ toInteger $ Prelude.length " << prefix << "); f " << prefix << ";writeListEnd oprot})";
   }
 
 }
@@ -1434,15 +1447,15 @@
     case t_base_type::TYPE_VOID:
       return "()";
     case t_base_type::TYPE_STRING:
-      return "String";
+      return (((t_base_type*)type)->is_binary() ? "ByteString" : "String");
     case t_base_type::TYPE_BOOL:
       return "Bool";
     case t_base_type::TYPE_BYTE:
-      return "Int";
+      return "Word8";
     case t_base_type::TYPE_I16:
-      return "Int";
+      return "Int16";
     case t_base_type::TYPE_I32:
-      return "Int";
+      return "Int32";
     case t_base_type::TYPE_I64:
       return "Int64";
     case t_base_type::TYPE_DOUBLE:
diff --git a/lib/hs/src/Thrift.hs b/lib/hs/src/Thrift.hs
index 182df3f..71957c4 100644
--- a/lib/hs/src/Thrift.hs
+++ b/lib/hs/src/Thrift.hs
@@ -82,7 +82,7 @@
         writeFieldEnd pt
 
     writeFieldBegin pt ("type", T_I32, 2);
-    writeI32 pt (fromEnum (ae_type ae))
+    writeI32 pt (fromIntegral $ fromEnum (ae_type ae))
     writeFieldEnd pt
     writeFieldStop pt
     writeStructEnd pt
@@ -107,7 +107,7 @@
                                   readAppExnFields pt rec
                  2 -> if ft == T_I32 then
                           do i <- readI32 pt
-                             readAppExnFields pt rec{ae_type = (toEnum  i)}
+                             readAppExnFields pt rec{ae_type = (toEnum $ fromIntegral i)}
                           else do skip pt ft
                                   readAppExnFields pt rec
                  _ -> do skip pt ft
diff --git a/lib/hs/src/Thrift/Protocol.hs b/lib/hs/src/Thrift/Protocol.hs
index c7c2d69..b34e806 100644
--- a/lib/hs/src/Thrift/Protocol.hs
+++ b/lib/hs/src/Thrift/Protocol.hs
@@ -29,9 +29,10 @@
 
 import Control.Monad ( replicateM_, unless )
 import Control.Exception
-
-import Data.Typeable ( Typeable )
 import Data.Int
+import Data.Typeable ( Typeable )
+import Data.Word
+import Data.ByteString.Lazy
 
 import Thrift.Transport
 
@@ -102,53 +103,53 @@
 class Protocol a where
     getTransport :: Transport t => a t -> t
 
-    writeMessageBegin :: Transport t => a t -> (String, MessageType, Int) -> IO ()
+    writeMessageBegin :: Transport t => a t -> (String, MessageType, Int32) -> IO ()
     writeMessageEnd   :: Transport t => a t -> IO ()
 
     writeStructBegin :: Transport t => a t -> String -> IO ()
     writeStructEnd   :: Transport t => a t -> IO ()
-    writeFieldBegin  :: Transport t => a t -> (String, ThriftType, Int) -> IO ()
+    writeFieldBegin  :: Transport t => a t -> (String, ThriftType, Int16) -> IO ()
     writeFieldEnd    :: Transport t => a t -> IO ()
     writeFieldStop   :: Transport t => a t -> IO ()
-    writeMapBegin    :: Transport t => a t -> (ThriftType, ThriftType, Int) -> IO ()
+    writeMapBegin    :: Transport t => a t -> (ThriftType, ThriftType, Int32) -> IO ()
     writeMapEnd      :: Transport t => a t -> IO ()
-    writeListBegin   :: Transport t => a t -> (ThriftType, Int) -> IO ()
+    writeListBegin   :: Transport t => a t -> (ThriftType, Int32) -> IO ()
     writeListEnd     :: Transport t => a t -> IO ()
-    writeSetBegin    :: Transport t => a t -> (ThriftType, Int) -> IO ()
+    writeSetBegin    :: Transport t => a t -> (ThriftType, Int32) -> IO ()
     writeSetEnd      :: Transport t => a t -> IO ()
 
     writeBool   :: Transport t => a t -> Bool -> IO ()
-    writeByte   :: Transport t => a t -> Int -> IO ()
-    writeI16    :: Transport t => a t -> Int -> IO ()
-    writeI32    :: Transport t => a t -> Int -> IO ()
+    writeByte   :: Transport t => a t -> Word8 -> IO ()
+    writeI16    :: Transport t => a t -> Int16 -> IO ()
+    writeI32    :: Transport t => a t -> Int32 -> IO ()
     writeI64    :: Transport t => a t -> Int64 -> IO ()
     writeDouble :: Transport t => a t -> Double -> IO ()
     writeString :: Transport t => a t -> String -> IO ()
-    writeBinary :: Transport t => a t -> String -> IO ()
+    writeBinary :: Transport t => a t -> ByteString -> IO ()
 
 
-    readMessageBegin :: Transport t => a t -> IO (String, MessageType, Int)
+    readMessageBegin :: Transport t => a t -> IO (String, MessageType, Int32)
     readMessageEnd   :: Transport t => a t -> IO ()
 
     readStructBegin :: Transport t => a t -> IO String
     readStructEnd   :: Transport t => a t -> IO ()
-    readFieldBegin  :: Transport t => a t -> IO (String, ThriftType, Int)
+    readFieldBegin  :: Transport t => a t -> IO (String, ThriftType, Int16)
     readFieldEnd    :: Transport t => a t -> IO ()
-    readMapBegin    :: Transport t => a t -> IO (ThriftType, ThriftType, Int)
+    readMapBegin    :: Transport t => a t -> IO (ThriftType, ThriftType, Int32)
     readMapEnd      :: Transport t => a t -> IO ()
-    readListBegin   :: Transport t => a t -> IO (ThriftType, Int)
+    readListBegin   :: Transport t => a t -> IO (ThriftType, Int32)
     readListEnd     :: Transport t => a t -> IO ()
-    readSetBegin    :: Transport t => a t -> IO (ThriftType, Int)
+    readSetBegin    :: Transport t => a t -> IO (ThriftType, Int32)
     readSetEnd      :: Transport t => a t -> IO ()
 
     readBool   :: Transport t => a t -> IO Bool
-    readByte   :: Transport t => a t -> IO Int
-    readI16    :: Transport t => a t -> IO Int
-    readI32    :: Transport t => a t -> IO Int
+    readByte   :: Transport t => a t -> IO Word8
+    readI16    :: Transport t => a t -> IO Int16
+    readI32    :: Transport t => a t -> IO Int32
     readI64    :: Transport t => a t -> IO Int64
     readDouble :: Transport t => a t -> IO Double
     readString :: Transport t => a t -> IO String
-    readBinary :: Transport t => a t -> IO String
+    readBinary :: Transport t => a t -> IO ByteString
 
 
 skip :: (Protocol p, Transport t) => p t -> ThriftType -> IO ()
@@ -165,13 +166,13 @@
                      skipFields p
                      readStructEnd p
 skip p T_MAP = do (k, v, s) <- readMapBegin p
-                  replicateM_ s (skip p k >> skip p v)
+                  replicateM_ (fromIntegral s) (skip p k >> skip p v)
                   readMapEnd p
 skip p T_SET = do (t, n) <- readSetBegin p
-                  replicateM_ n (skip p t)
+                  replicateM_ (fromIntegral n) (skip p t)
                   readSetEnd p
 skip p T_LIST = do (t, n) <- readListBegin p
-                   replicateM_ n (skip p t)
+                   replicateM_ (fromIntegral n) (skip p t)
                    readListEnd p
 
 
diff --git a/lib/hs/src/Thrift/Protocol/Binary.hs b/lib/hs/src/Thrift/Protocol/Binary.hs
index 308ab48..cd95965 100644
--- a/lib/hs/src/Thrift/Protocol/Binary.hs
+++ b/lib/hs/src/Thrift/Protocol/Binary.hs
@@ -30,6 +30,7 @@
 import Data.Bits
 import Data.Int
 import Data.List ( foldl' )
+import Data.Word
 
 import GHC.Exts
 import GHC.Word
@@ -37,12 +38,13 @@
 import Thrift.Protocol
 import Thrift.Transport
 
-import qualified Data.ByteString.Lazy.Char8 as LBS
+import qualified Data.ByteString.Lazy.Char8 as LBSChar8
+import qualified Data.ByteString.Lazy as LBS
 
-version_mask :: Int
+version_mask :: Int32
 version_mask = 0xffff0000
 
-version_1 :: Int
+version_1 :: Int32
 version_1    = 0x80010000
 
 data BinaryProtocol a = Transport a => BinaryProtocol a
@@ -52,7 +54,7 @@
     getTransport (BinaryProtocol t) = t
 
     writeMessageBegin p (n, t, s) = do
-        writeI32 p (version_1 .|. (fromEnum t))
+        writeI32 p (version_1 .|. (fromIntegral $ fromEnum t))
         writeString p n
         writeI32 p s
     writeMessageEnd _ = return ()
@@ -69,14 +71,14 @@
     writeSetBegin p (t, n) = writeType p t >> writeI32 p n
     writeSetEnd _ = return ()
 
-    writeBool p b = tWrite (getTransport p) $ LBS.singleton $ toEnum $ if b then 1 else 0
+    writeBool p b = tWrite (getTransport p) $ LBSChar8.singleton $ toEnum $ if b then 1 else 0
     writeByte p b = tWrite (getTransport p) (getBytes b 1)
     writeI16 p b = tWrite (getTransport p) (getBytes b 2)
     writeI32 p b = tWrite (getTransport p) (getBytes b 4)
     writeI64 p b = tWrite (getTransport p) (getBytes b 8)
     writeDouble p d = writeI64 p (fromIntegral $ floatBits d)
-    writeString p s = writeI32 p (length s) >> tWrite (getTransport p) (LBS.pack s)
-    writeBinary = writeString
+    writeString p s = writeI32 p (fromIntegral $ length s) >> tWrite (getTransport p) (LBSChar8.pack s)
+    writeBinary p s = writeI32 p (fromIntegral $ LBS.length s) >> tWrite (getTransport p) s
 
     readMessageBegin p = do
         ver <- readI32 p
@@ -85,7 +87,7 @@
             else do
               s <- readString p
               sz <- readI32 p
-              return (s, toEnum $ ver .&. 0xFF, sz)
+              return (s, toEnum $ fromIntegral $ ver .&. 0xFF, sz)
     readMessageEnd _ = return ()
     readStructBegin _ = return ""
     readStructEnd _ = return ()
@@ -125,29 +127,32 @@
         return $ floatOfBits $ fromIntegral bs
     readString p = do
         i <- readI32 p
-        LBS.unpack `liftM` tReadAll (getTransport p) i
-
-    readBinary = readString
+        LBSChar8.unpack `liftM` tReadAll (getTransport p) (fromIntegral i)
+    readBinary p = do
+        i <- readI32 p
+        tReadAll (getTransport p) (fromIntegral i)
 
 
 -- | Write a type as a byte
 writeType :: (Protocol p, Transport t) => p t -> ThriftType -> IO ()
-writeType p t = writeByte p (fromEnum t)
+writeType p t = writeByte p (fromIntegral $ fromEnum t)
 
 -- | Read a byte as though it were a ThriftType
 readType :: (Protocol p, Transport t) => p t -> IO ThriftType
-readType p = toEnum `fmap` readByte p
+readType p = do
+    b <- readByte p
+    return $ toEnum $ fromIntegral b
 
-composeBytes :: (Bits b) => LBS.ByteString -> b
-composeBytes = (foldl' fn 0) . (map (fromIntegral . fromEnum)) . LBS.unpack
+composeBytes :: (Bits b) => LBSChar8.ByteString -> b
+composeBytes = (foldl' fn 0) . (map (fromIntegral . fromEnum)) . LBSChar8.unpack
     where fn acc b = (acc `shiftL` 8) .|. b
 
 getByte :: Bits a => a -> Int -> a
 getByte i n = 255 .&. (i `shiftR` (8 * n))
 
-getBytes :: (Bits a, Integral a) => a -> Int -> LBS.ByteString
-getBytes _ 0 = LBS.empty
-getBytes i n = (toEnum $ fromIntegral $ getByte i (n-1)) `LBS.cons` (getBytes i (n-1))
+getBytes :: (Bits a, Integral a) => a -> Int -> LBSChar8.ByteString
+getBytes _ 0 = LBSChar8.empty
+getBytes i n = (toEnum $ fromIntegral $ getByte i (n-1)) `LBSChar8.cons` (getBytes i (n-1))
 
 floatBits :: Double -> Word64
 floatBits (D# d#) = W64# (unsafeCoerce# d#)