THRIFT-3580 THeader for Haskell
Client: hs

This closes #820
This closes #1423
diff --git a/lib/hs/src/Thrift/Protocol/Binary.hs b/lib/hs/src/Thrift/Protocol/Binary.hs
index 2d35305..7b0acd9 100644
--- a/lib/hs/src/Thrift/Protocol/Binary.hs
+++ b/lib/hs/src/Thrift/Protocol/Binary.hs
@@ -25,6 +25,8 @@
 module Thrift.Protocol.Binary
     ( module Thrift.Protocol
     , BinaryProtocol(..)
+    , versionMask
+    , version1
     ) where
 
 import Control.Exception ( throw )
@@ -35,6 +37,7 @@
 import Data.Int
 import Data.Monoid
 import Data.Text.Lazy.Encoding ( decodeUtf8, encodeUtf8 )
+import Data.Word
 
 import Thrift.Protocol
 import Thrift.Transport
@@ -47,37 +50,55 @@
 import qualified Data.HashMap.Strict as Map
 import qualified Data.Text.Lazy as LT
 
-data BinaryProtocol a = BinaryProtocol a
+versionMask :: Int32  -- selects the upper 16 bits of the leading I32 of a message
+versionMask = fromIntegral (0xffff0000 :: Word32)  -- literal exceeds Int32 range, so it is written as Word32
+
+version1 :: Int32  -- value the masked leading I32 must equal when a message is read
+version1 = fromIntegral (0x80010000 :: Word32)  -- OR-ed with the message type when a message is written
+
+data BinaryProtocol a = Transport a => BinaryProtocol a
+
+getTransport :: Transport t => BinaryProtocol t -> t
+getTransport (BinaryProtocol t) = t
 
 -- NOTE: Reading and Writing functions rely on Builders and Data.Binary to
 -- encode and decode data.  Data.Binary assumes that the binary values it is
 -- encoding to and decoding from are in BIG ENDIAN format, and converts the
 -- endianness as necessary to match the local machine.
-instance Protocol BinaryProtocol where
-    getTransport (BinaryProtocol t) = t
+instance Transport t => Protocol (BinaryProtocol t) where
+    readByte p = tReadAll (getTransport p) 1
+    -- Emit the message header, run the body writer, then flush the transport.
+    writeMessage p (name, msgType, seqId) body = do
+      tWrite transport header
+      body
+      tFlush transport
+      where
+        transport = getTransport p
+        versionAndType = version1 .|. fromIntegral (fromEnum msgType)
+        header = toLazyByteString $ buildBinaryValue (TI32 versionAndType)
+          <> buildBinaryValue (TString (encodeUtf8 name)) <> buildBinaryValue (TI32 seqId)
 
-    writeMessageBegin p (n, t, s) = tWrite (getTransport p) $ toLazyByteString $
-        buildBinaryValue (TI32 (version1 .|. fromIntegral (fromEnum t))) <>
-        buildBinaryValue (TString $ encodeUtf8 n) <>
-        buildBinaryValue (TI32 s)
+    readMessage p f = readMessageBegin >>= f  -- parse the header, then pass it to the handler
+      where
+        readMessageBegin = runParser p $ do
+          TI32 ver <- parseBinaryValue T_I32
+          if ver .&. versionMask /= version1  -- upper 16 bits must carry the version-1 marker
+            then throw $ ProtocolExn PE_BAD_VERSION "Missing version identifier"
+            else do
+              TString s <- parseBinaryValue T_STRING
+              TI32 sz <- parseBinaryValue T_I32
+              return (decodeUtf8 s, toEnum $ fromIntegral $ ver .&. 0xFF, sz)  -- low byte is the message type
 
-    readMessageBegin p = runParser p $ do
-      TI32 ver <- parseBinaryValue T_I32
-      if ver .&. versionMask /= version1
-        then throw $ ProtocolExn PE_BAD_VERSION "Missing version identifier"
-        else do
-          TString s <- parseBinaryValue T_STRING
-          TI32 sz <- parseBinaryValue T_I32
-          return (decodeUtf8 s, toEnum $ fromIntegral $ ver .&. 0xFF, sz)
+    writeVal p = tWrite (getTransport p) . toLazyByteString . buildBinaryValue
+    readVal p = runParser p . parseBinaryValue
 
+instance Transport t => StatelessProtocol (BinaryProtocol t) where
     serializeVal _ = toLazyByteString . buildBinaryValue
     deserializeVal _ ty bs =
       case LP.eitherResult $ LP.parse (parseBinaryValue ty) bs of
         Left s -> error s
         Right val -> val
 
-    readVal p = runParser p . parseBinaryValue
-
 -- | Writing Functions
 buildBinaryValue :: ThriftVal -> Builder
 buildBinaryValue (TStruct fields) = buildBinaryStruct fields <> buildType T_STOP
diff --git a/lib/hs/src/Thrift/Protocol/Compact.hs b/lib/hs/src/Thrift/Protocol/Compact.hs
index 07113df..f23970a 100644
--- a/lib/hs/src/Thrift/Protocol/Compact.hs
+++ b/lib/hs/src/Thrift/Protocol/Compact.hs
@@ -25,10 +25,11 @@
 module Thrift.Protocol.Compact
     ( module Thrift.Protocol
     , CompactProtocol(..)
+    , parseVarint
+    , buildVarint
     ) where
 
 import Control.Applicative
-import Control.Exception ( throw )
 import Control.Monad
 import Data.Attoparsec.ByteString as P
 import Data.Attoparsec.ByteString.Lazy as LP
@@ -40,7 +41,7 @@
 import Data.Word
 import Data.Text.Lazy.Encoding ( decodeUtf8, encodeUtf8 )
 
-import Thrift.Protocol hiding (versionMask)
+import Thrift.Protocol
 import Thrift.Transport
 import Thrift.Types
 
@@ -64,38 +65,47 @@
 typeShiftAmount :: Int
 typeShiftAmount = 5
 
+getTransport :: Transport t => CompactProtocol t -> t
+getTransport (CompactProtocol t) = t
 
-instance Protocol CompactProtocol where
-    getTransport (CompactProtocol t) = t
+instance Transport t => Protocol (CompactProtocol t) where
+    readByte p = tReadAll (getTransport p) 1
+    writeMessage p (name, msgType, seqId) body = do
+      tWrite transport header
+      body
+      tFlush transport
+      where
+        transport = getTransport p
+        -- message type shifted into position and masked, then merged with the version bits
+        shiftedType = (fromIntegral (fromEnum msgType) `shiftL` typeShiftAmount) .&. typeMask
+        versionAndType = (version .&. versionMask) .|. shiftedType
+        header = toLazyByteString $
+          B.word8 protocolID <> B.word8 versionAndType <>
+          buildVarint (i32ToZigZag seqId) <> buildCompactValue (TString $ encodeUtf8 name)
 
-    writeMessageBegin p (n, t, s) = tWrite (getTransport p) $ toLazyByteString $
-      B.word8 protocolID <>
-      B.word8 ((version .&. versionMask) .|.
-              (((fromIntegral $ fromEnum t) `shiftL`
-                typeShiftAmount) .&. typeMask)) <>
-      buildVarint (i32ToZigZag s) <>
-      buildCompactValue (TString $ encodeUtf8 n)
-    
-    readMessageBegin p = runParser p $ do
-      pid <- fromIntegral <$> P.anyWord8
-      when (pid /= protocolID) $ error "Bad Protocol ID"
-      w <- fromIntegral <$> P.anyWord8
-      let ver = w .&. versionMask 
-      when (ver /= version) $ error "Bad Protocol version"
-      let typ = (w `shiftR` typeShiftAmount) .&. typeBits
-      seqId <- parseVarint zigZagToI32
-      TString name <- parseCompactValue T_STRING
-      return (decodeUtf8 name, toEnum $ fromIntegral $ typ, seqId)
+    readMessage p handler = header >>= handler
+      where
+        header = runParser p $ do
+          protoId <- fromIntegral <$> P.anyWord8
+          unless (protoId == protocolID) $ error "Bad Protocol ID"
+          verAndType <- fromIntegral <$> P.anyWord8
+          unless (verAndType .&. versionMask == version) $
+            error "Bad Protocol version"
+          let msgType = (verAndType `shiftR` typeShiftAmount) .&. typeBits
+          seqId <- parseVarint zigZagToI32
+          TString name <- parseCompactValue T_STRING
+          return (decodeUtf8 name, toEnum (fromIntegral msgType), seqId)
 
+    writeVal p = tWrite (getTransport p) . toLazyByteString . buildCompactValue
+    readVal p ty = runParser p $ parseCompactValue ty
+
+instance Transport t => StatelessProtocol (CompactProtocol t) where
     serializeVal _ = toLazyByteString . buildCompactValue
     deserializeVal _ ty bs =
       case LP.eitherResult $ LP.parse (parseCompactValue ty) bs of
         Left s -> error s
         Right val -> val
 
-    readVal p ty = runParser p $ parseCompactValue ty
-
-
 -- | Writing Functions
 buildCompactValue :: ThriftVal -> Builder
 buildCompactValue (TStruct fields) = buildCompactStruct fields
@@ -283,7 +293,7 @@
   TSet{} -> 0x0A
   TMap{} -> 0x0B
   TStruct{} -> 0x0C
-  
+
 typeFrom :: Word8 -> ThriftType
 typeFrom w = case w of
   0x01 -> T_BOOL
diff --git a/lib/hs/src/Thrift/Protocol/Header.hs b/lib/hs/src/Thrift/Protocol/Header.hs
new file mode 100644
index 0000000..5f42db4
--- /dev/null
+++ b/lib/hs/src/Thrift/Protocol/Header.hs
@@ -0,0 +1,141 @@
+--
+-- Licensed to the Apache Software Foundation (ASF) under one
+-- or more contributor license agreements. See the NOTICE file
+-- distributed with this work for additional information
+-- regarding copyright ownership. The ASF licenses this file
+-- to you under the Apache License, Version 2.0 (the
+-- "License"); you may not use this file except in compliance
+-- with the License. You may obtain a copy of the License at
+--
+--   http://www.apache.org/licenses/LICENSE-2.0
+--
+-- Unless required by applicable law or agreed to in writing,
+-- software distributed under the License is distributed on an
+-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+-- KIND, either express or implied. See the License for the
+-- specific language governing permissions and limitations
+-- under the License.
+--
+
+
+module Thrift.Protocol.Header
+    ( module Thrift.Protocol
+    , HeaderProtocol(..)
+    , getProtocolType
+    , setProtocolType
+    , getHeaders
+    , getWriteHeaders
+    , setHeader
+    , setHeaders
+    , createHeaderProtocol
+    , createHeaderProtocol1
+    ) where
+
+import Thrift.Protocol
+import Thrift.Protocol.Binary
+import Thrift.Protocol.JSON
+import Thrift.Protocol.Compact
+import Thrift.Transport
+import Thrift.Transport.Header
+import Data.IORef
+import qualified Data.Map as Map
+
+data ProtocolWrap = forall a. (Protocol a) => ProtocolWrap(a)  -- existential: hides which Protocol instance is inside
+
+instance Protocol ProtocolWrap where  -- every method unwraps and delegates to the wrapped protocol
+  readByte (ProtocolWrap p) = readByte p
+  readVal (ProtocolWrap p) = readVal p
+  readMessage (ProtocolWrap p) = readMessage p
+  writeVal (ProtocolWrap p) = writeVal p
+  writeMessage (ProtocolWrap p) = writeMessage p
+
+data HeaderProtocol i o = (Transport i, Transport o) => HeaderProtocol {
+    trans :: HeaderTransport i o,
+    wrappedProto :: IORef ProtocolWrap
+  }
+
+-- Wrap a transport in the concrete protocol selected by the protocol type;
+-- the existential 'ProtocolWrap' hides which protocol was chosen.
+createProtocolWrap :: Transport t => ProtocolType -> t -> ProtocolWrap
+createProtocolWrap TBinary t = ProtocolWrap (BinaryProtocol t)
+createProtocolWrap TCompact t = ProtocolWrap (CompactProtocol t)
+createProtocolWrap TJSON t = ProtocolWrap (JSONProtocol t)
+
+createHeaderProtocol :: (Transport i, Transport o) => i -> o -> IO (HeaderProtocol i o)
+createHeaderProtocol inp out = do
+  hdrTrans <- openHeaderTransport inp out
+  protoType <- readIORef (protocolType hdrTrans)
+  wrappedRef <- newIORef (createProtocolWrap protoType hdrTrans)
+  return HeaderProtocol { trans = hdrTrans, wrappedProto = wrappedRef }
+
+createHeaderProtocol1 :: Transport t => t -> IO(HeaderProtocol t t)
+createHeaderProtocol1 t = createHeaderProtocol t t
+
+resetProtocol :: (Transport i, Transport o) => HeaderProtocol i o -> IO ()
+resetProtocol p =
+  readIORef (protocolType (trans p)) >>= \pid ->
+    writeIORef (wrappedProto p) (createProtocolWrap pid (trans p))
+
+getWrapped = readIORef . wrappedProto  -- current ProtocolWrap held in the protocol's IORef
+
+setTransport :: (Transport i, Transport o) => HeaderProtocol i o -> HeaderTransport i o -> HeaderProtocol i o
+setTransport p t = p { trans = t }
+
+updateTransport :: (Transport i, Transport o) => HeaderProtocol i o -> (HeaderTransport i o -> HeaderTransport i o)-> HeaderProtocol i o
+updateTransport p f = setTransport p (f $ trans p)
+
+type Headers = Map.Map String String
+
+-- TODO: we want to set headers without recreating client...
+setHeader :: (Transport i, Transport o) => HeaderProtocol i o -> String -> String -> HeaderProtocol i o
+setHeader p k v = updateTransport p $ \t -> t { writeHeaders = Map.insert k v $ writeHeaders t }
+
+setHeaders :: (Transport i, Transport o) => HeaderProtocol i o -> Headers -> HeaderProtocol i o
+setHeaders p h = updateTransport p $ \t -> t { writeHeaders = h }
+
+-- TODO: make it public once we have first transform implementation for Haskell
+setTransforms :: (Transport i, Transport o) => HeaderProtocol i o -> [TransformType] -> HeaderProtocol i o
+setTransforms p trs = updateTransport p $ \t -> t { writeTransforms = trs }
+
+setTransform :: (Transport i, Transport o) => HeaderProtocol i o -> TransformType -> HeaderProtocol i o
+setTransform p tr = updateTransport p $ \t -> t { writeTransforms = tr:(writeTransforms t) }
+
+getWriteHeaders :: (Transport i, Transport o) => HeaderProtocol i o -> Headers
+getWriteHeaders = writeHeaders . trans
+
+getHeaders :: (Transport i, Transport o) => HeaderProtocol i o -> IO [(String, String)]
+getHeaders = readIORef . headers . trans
+
+getProtocolType :: (Transport i, Transport o) => HeaderProtocol i o -> IO ProtocolType
+getProtocolType p = readIORef $ protocolType $ trans p
+
+setProtocolType :: (Transport i, Transport o) => HeaderProtocol i o -> ProtocolType -> IO ()
+setProtocolType p typ = do
+  current <- getProtocolType p
+  if typ /= current  -- nothing to do when the requested type is already active
+    then do
+      tSetProtocol (trans p) typ
+      resetProtocol p
+    else return ()
+
+instance (Transport i, Transport o) => Protocol (HeaderProtocol i o) where
+  -- Single bytes are read directly from the header transport.
+  readByte p = tReadAll (trans p) 1
+
+  -- Values are read through whichever protocol is currently wrapped.
+  readVal p tp = getWrapped p >>= \inner -> readVal inner tp
+
+  -- Call tResetProtocol on the transport and rebuild the wrapped protocol
+  -- from the transport's protocol type before delegating the read.
+  readMessage p f = do
+    tResetProtocol (trans p)
+    resetProtocol p
+    inner <- getWrapped p
+    readMessage inner f
+
+  -- Values are written through the currently wrapped protocol.
+  writeVal p v = getWrapped p >>= \inner -> writeVal inner v
+
+  -- Message writing is delegated to the currently wrapped protocol.
+  writeMessage p x f = getWrapped p >>= \inner -> writeMessage inner x f
+
diff --git a/lib/hs/src/Thrift/Protocol/JSON.hs b/lib/hs/src/Thrift/Protocol/JSON.hs
index 7f619e8..839eddc 100644
--- a/lib/hs/src/Thrift/Protocol/JSON.hs
+++ b/lib/hs/src/Thrift/Protocol/JSON.hs
@@ -29,12 +29,12 @@
     ) where
 
 import Control.Applicative
+import Control.Exception (bracket)
 import Control.Monad
 import Data.Attoparsec.ByteString as P
 import Data.Attoparsec.ByteString.Char8 as PC
 import Data.Attoparsec.ByteString.Lazy as LP
 import Data.ByteString.Base64.Lazy as B64C
-import Data.ByteString.Base64 as B64
 import Data.ByteString.Lazy.Builder as B
 import Data.ByteString.Internal (c2w, w2c)
 import Data.Functor
@@ -58,38 +58,48 @@
 -- encoded as a JSON 'ByteString'
 data JSONProtocol t = JSONProtocol t
                       -- ^ Construct a 'JSONProtocol' with a 'Transport'
+getTransport :: Transport t => JSONProtocol t -> t
+getTransport (JSONProtocol t) = t
 
-instance Protocol JSONProtocol where
-    getTransport (JSONProtocol t) = t
+instance Transport t => Protocol (JSONProtocol t) where
+    readByte p = tReadAll (getTransport p) 1
 
-    writeMessageBegin (JSONProtocol t) (s, ty, sq) = tWrite t $ toLazyByteString $
-      B.char8 '[' <> buildShowable (1 :: Int32) <>
-      B.string8 ",\"" <> escape (encodeUtf8 s) <> B.char8 '\"' <>
-      B.char8 ',' <> buildShowable (fromEnum ty) <>
-      B.char8 ',' <> buildShowable sq <>
-      B.char8 ','
-    writeMessageEnd (JSONProtocol t) = tWrite t "]"
-    readMessageBegin p = runParser p $ skipSpace *> do
-      _ver :: Int32 <- lexeme (PC.char8 '[') *> lexeme (signed decimal)
-      bs <- lexeme (PC.char8 ',') *> lexeme escapedString
-      case decodeUtf8' bs of
-        Left _ -> fail "readMessage: invalid text encoding"
-        Right str -> do
-          ty <- toEnum <$> (lexeme (PC.char8 ',') *> lexeme (signed decimal))
-          seqNum <- lexeme (PC.char8 ',') *> lexeme (signed decimal)
-          _ <- PC.char8 ','
-          return (str, ty, seqNum)
-    readMessageEnd p = void $ runParser p (PC.char8 ']')
+    writeMessage (JSONProtocol t) (s, ty, sq) = bracket writeMessageBegin writeMessageEnd . const
+      where
+        writeMessageBegin = tWrite t $ toLazyByteString $  -- opens the JSON array and writes the header fields
+          B.char8 '[' <> buildShowable (1 :: Int32) <>
+          B.string8 ",\"" <> escape (encodeUtf8 s) <> B.char8 '\"' <>
+          B.char8 ',' <> buildShowable (fromEnum ty) <>
+          B.char8 ',' <> buildShowable sq <>
+          B.char8 ','
+        writeMessageEnd _ = do  -- bracket's release step: it runs even if the body action throws
+          tWrite t "]"
+          tFlush t
 
+    readMessage p = bracket readMessageBegin readMessageEnd  -- parse header, run the handler, then consume ']'
+      where
+        readMessageBegin = runParser p $ skipSpace *> do
+          _ver :: Int32 <- lexeme (PC.char8 '[') *> lexeme (signed decimal)  -- version is parsed but not checked
+          bs <- lexeme (PC.char8 ',') *> lexeme escapedString
+          case decodeUtf8' bs of
+            Left _ -> fail "readMessage: invalid text encoding"
+            Right str -> do
+              ty <- toEnum <$> (lexeme (PC.char8 ',') *> lexeme (signed decimal))
+              seqNum <- lexeme (PC.char8 ',') *> lexeme (signed decimal)
+              _ <- PC.char8 ','  -- separator before the message body
+              return (str, ty, seqNum)
+        readMessageEnd _ = void $ runParser p (PC.char8 ']')  -- release step: bracket also runs it if the handler throws
+
+    writeVal p = tWrite (getTransport p) . toLazyByteString . buildJSONValue
+    readVal p ty = runParser p $ skipSpace *> parseJSONValue ty
+
+instance Transport t => StatelessProtocol (JSONProtocol t) where
     serializeVal _ = toLazyByteString . buildJSONValue
     deserializeVal _ ty bs =
       case LP.eitherResult $ LP.parse (parseJSONValue ty) bs of
         Left s -> error s
         Right val -> val
 
-    readVal p ty = runParser p $ skipSpace *> parseJSONValue ty
-
-
 -- Writing Functions
 
 buildJSONValue :: ThriftVal -> Builder