THRIFT-3580 THeader for Haskell
Client: hs
This closes #820
This closes #1423
diff --git a/lib/hs/src/Thrift/Protocol/Binary.hs b/lib/hs/src/Thrift/Protocol/Binary.hs
index 2d35305..7b0acd9 100644
--- a/lib/hs/src/Thrift/Protocol/Binary.hs
+++ b/lib/hs/src/Thrift/Protocol/Binary.hs
@@ -25,6 +25,8 @@
module Thrift.Protocol.Binary
( module Thrift.Protocol
, BinaryProtocol(..)
+ , versionMask
+ , version1
) where
import Control.Exception ( throw )
@@ -35,6 +37,7 @@
import Data.Int
import Data.Monoid
import Data.Text.Lazy.Encoding ( decodeUtf8, encodeUtf8 )
+import Data.Word
import Thrift.Protocol
import Thrift.Transport
@@ -47,37 +50,55 @@
import qualified Data.HashMap.Strict as Map
import qualified Data.Text.Lazy as LT
-data BinaryProtocol a = BinaryProtocol a
+versionMask :: Int32
+versionMask = fromIntegral (0xffff0000 :: Word32)
+
+version1 :: Int32
+version1 = fromIntegral (0x80010000 :: Word32)
+
+data BinaryProtocol a = Transport a => BinaryProtocol a
+
+getTransport :: Transport t => BinaryProtocol t -> t
+getTransport (BinaryProtocol t) = t
-- NOTE: Reading and Writing functions rely on Builders and Data.Binary to
-- encode and decode data. Data.Binary assumes that the binary values it is
-- encoding to and decoding from are in BIG ENDIAN format, and converts the
-- endianness as necessary to match the local machine.
-instance Protocol BinaryProtocol where
- getTransport (BinaryProtocol t) = t
+instance Transport t => Protocol (BinaryProtocol t) where
+ readByte p = tReadAll (getTransport p) 1
+ -- flushTransport p = tFlush (getTransport p)
+ -- | Write one message: the header (the 'version1' word OR-ed with the enum
+ -- value of the message type, the UTF-8 encoded name, the sequence id),
+ -- then the caller's body action, then a flush of the transport.
+ writeMessage proto (name, msgType, seqId) body =
+   tWrite out header >> body >> tFlush out
+   where
+     out = getTransport proto
+     versionWord = TI32 (version1 .|. fromIntegral (fromEnum msgType))
+     header = toLazyByteString $ mconcat
+       [ buildBinaryValue versionWord
+       , buildBinaryValue (TString (encodeUtf8 name))
+       , buildBinaryValue (TI32 seqId)
+       ]
- writeMessageBegin p (n, t, s) = tWrite (getTransport p) $ toLazyByteString $
- buildBinaryValue (TI32 (version1 .|. fromIntegral (fromEnum t))) <>
- buildBinaryValue (TString $ encodeUtf8 n) <>
- buildBinaryValue (TI32 s)
+ -- | Parse the message header and hand it to the continuation @f@.
+ -- Raises 'ProtocolExn' with 'PE_BAD_VERSION' when the leading i32, masked
+ -- with 'versionMask', is not 'version1'.
+ readMessage p f = readMessageBegin >>= f
+   where
+     -- The helper closes over the method's @p@ instead of re-binding it, so
+     -- no name is shadowed.
+     readMessageBegin = runParser p $ do
+       TI32 ver <- parseBinaryValue T_I32
+       if ver .&. versionMask /= version1
+         then throw $ ProtocolExn PE_BAD_VERSION "Missing version identifier"
+         else do
+           TString s <- parseBinaryValue T_STRING
+           TI32 sz <- parseBinaryValue T_I32
+           -- The low byte of the version word is turned into the second
+           -- tuple component with 'toEnum'.
+           return (decodeUtf8 s, toEnum $ fromIntegral $ ver .&. 0xFF, sz)
- readMessageBegin p = runParser p $ do
- TI32 ver <- parseBinaryValue T_I32
- if ver .&. versionMask /= version1
- then throw $ ProtocolExn PE_BAD_VERSION "Missing version identifier"
- else do
- TString s <- parseBinaryValue T_STRING
- TI32 sz <- parseBinaryValue T_I32
- return (decodeUtf8 s, toEnum $ fromIntegral $ ver .&. 0xFF, sz)
+ writeVal p = tWrite (getTransport p) . toLazyByteString . buildBinaryValue
+ readVal p = runParser p . parseBinaryValue
+instance Transport t => StatelessProtocol (BinaryProtocol t) where
serializeVal _ = toLazyByteString . buildBinaryValue
deserializeVal _ ty bs =
case LP.eitherResult $ LP.parse (parseBinaryValue ty) bs of
Left s -> error s
Right val -> val
- readVal p = runParser p . parseBinaryValue
-
-- | Writing Functions
buildBinaryValue :: ThriftVal -> Builder
buildBinaryValue (TStruct fields) = buildBinaryStruct fields <> buildType T_STOP
diff --git a/lib/hs/src/Thrift/Protocol/Compact.hs b/lib/hs/src/Thrift/Protocol/Compact.hs
index 07113df..f23970a 100644
--- a/lib/hs/src/Thrift/Protocol/Compact.hs
+++ b/lib/hs/src/Thrift/Protocol/Compact.hs
@@ -25,10 +25,11 @@
module Thrift.Protocol.Compact
( module Thrift.Protocol
, CompactProtocol(..)
+ , parseVarint
+ , buildVarint
) where
import Control.Applicative
-import Control.Exception ( throw )
import Control.Monad
import Data.Attoparsec.ByteString as P
import Data.Attoparsec.ByteString.Lazy as LP
@@ -40,7 +41,7 @@
import Data.Word
import Data.Text.Lazy.Encoding ( decodeUtf8, encodeUtf8 )
-import Thrift.Protocol hiding (versionMask)
+import Thrift.Protocol
import Thrift.Transport
import Thrift.Types
@@ -64,38 +65,47 @@
typeShiftAmount :: Int
typeShiftAmount = 5
+getTransport :: Transport t => CompactProtocol t -> t
+getTransport (CompactProtocol t) = t
-instance Protocol CompactProtocol where
- getTransport (CompactProtocol t) = t
+instance Transport t => Protocol (CompactProtocol t) where
+ readByte p = tReadAll (getTransport p) 1
+ -- | Write one message: the header (protocol id byte, a byte combining the
+ -- masked version with the shifted message type, the zig-zag varint
+ -- sequence id, the UTF-8 encoded name), then the caller's body action,
+ -- then a flush of the transport.
+ writeMessage proto (name, msgType, seqId) body =
+   tWrite out header >> body >> tFlush out
+   where
+     out = getTransport proto
+     versionPart = version .&. versionMask
+     shiftedType =
+       (fromIntegral (fromEnum msgType) `shiftL` typeShiftAmount) .&. typeMask
+     header = toLazyByteString $ mconcat
+       [ B.word8 protocolID
+       , B.word8 (versionPart .|. shiftedType)
+       , buildVarint (i32ToZigZag seqId)
+       , buildCompactValue (TString (encodeUtf8 name))
+       ]
- writeMessageBegin p (n, t, s) = tWrite (getTransport p) $ toLazyByteString $
- B.word8 protocolID <>
- B.word8 ((version .&. versionMask) .|.
- (((fromIntegral $ fromEnum t) `shiftL`
- typeShiftAmount) .&. typeMask)) <>
- buildVarint (i32ToZigZag s) <>
- buildCompactValue (TString $ encodeUtf8 n)
-
- readMessageBegin p = runParser p $ do
- pid <- fromIntegral <$> P.anyWord8
- when (pid /= protocolID) $ error "Bad Protocol ID"
- w <- fromIntegral <$> P.anyWord8
- let ver = w .&. versionMask
- when (ver /= version) $ error "Bad Protocol version"
- let typ = (w `shiftR` typeShiftAmount) .&. typeBits
- seqId <- parseVarint zigZagToI32
- TString name <- parseCompactValue T_STRING
- return (decodeUtf8 name, toEnum $ fromIntegral $ typ, seqId)
+ -- | Parse the message header (protocol id byte, version/type byte, varint
+ -- sequence id, name) and hand the resulting triple to the continuation.
+ -- A wrong protocol id or version is reported through 'error'.
+ readMessage proto k = runParser proto header >>= k
+   where
+     header = do
+       magic <- fromIntegral <$> P.anyWord8
+       when (magic /= protocolID) $ error "Bad Protocol ID"
+       flags <- fromIntegral <$> P.anyWord8
+       when (flags .&. versionMask /= version) $ error "Bad Protocol version"
+       seqId <- parseVarint zigZagToI32
+       TString rawName <- parseCompactValue T_STRING
+       -- The message type sits in the bits above 'typeShiftAmount'.
+       let msgType =
+             toEnum (fromIntegral ((flags `shiftR` typeShiftAmount) .&. typeBits))
+       return (decodeUtf8 rawName, msgType, seqId)
+ writeVal p = tWrite (getTransport p) . toLazyByteString . buildCompactValue
+ readVal p ty = runParser p $ parseCompactValue ty
+
+instance Transport t => StatelessProtocol (CompactProtocol t) where
serializeVal _ = toLazyByteString . buildCompactValue
deserializeVal _ ty bs =
case LP.eitherResult $ LP.parse (parseCompactValue ty) bs of
Left s -> error s
Right val -> val
- readVal p ty = runParser p $ parseCompactValue ty
-
-
-- | Writing Functions
buildCompactValue :: ThriftVal -> Builder
buildCompactValue (TStruct fields) = buildCompactStruct fields
@@ -283,7 +293,7 @@
TSet{} -> 0x0A
TMap{} -> 0x0B
TStruct{} -> 0x0C
-
+
typeFrom :: Word8 -> ThriftType
typeFrom w = case w of
0x01 -> T_BOOL
diff --git a/lib/hs/src/Thrift/Protocol/Header.hs b/lib/hs/src/Thrift/Protocol/Header.hs
new file mode 100644
index 0000000..5f42db4
--- /dev/null
+++ b/lib/hs/src/Thrift/Protocol/Header.hs
@@ -0,0 +1,141 @@
+--
+-- Licensed to the Apache Software Foundation (ASF) under one
+-- or more contributor license agreements. See the NOTICE file
+-- distributed with this work for additional information
+-- regarding copyright ownership. The ASF licenses this file
+-- to you under the Apache License, Version 2.0 (the
+-- "License"); you may not use this file except in compliance
+-- with the License. You may obtain a copy of the License at
+--
+-- http://www.apache.org/licenses/LICENSE-2.0
+--
+-- Unless required by applicable law or agreed to in writing,
+-- software distributed under the License is distributed on an
+-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+-- KIND, either express or implied. See the License for the
+-- specific language governing permissions and limitations
+-- under the License.
+--
+
+
+module Thrift.Protocol.Header
+ ( module Thrift.Protocol
+ , HeaderProtocol(..)
+ , getProtocolType
+ , setProtocolType
+ , getHeaders
+ , getWriteHeaders
+ , setHeader
+ , setHeaders
+ , createHeaderProtocol
+ , createHeaderProtocol1
+ ) where
+
+import Thrift.Protocol
+import Thrift.Protocol.Binary
+import Thrift.Protocol.JSON
+import Thrift.Protocol.Compact
+import Thrift.Transport
+import Thrift.Transport.Header
+import Data.IORef
+import qualified Data.Map as Map
+
+data ProtocolWrap = forall a. (Protocol a) => ProtocolWrap(a)
+
+-- | Every method unwraps the existential and forwards the call, with all of
+-- its arguments, to the protocol stored inside.
+instance Protocol ProtocolWrap where
+  readByte (ProtocolWrap inner) = readByte inner
+  readVal (ProtocolWrap inner) ty = readVal inner ty
+  readMessage (ProtocolWrap inner) k = readMessage inner k
+  writeVal (ProtocolWrap inner) v = writeVal inner v
+  writeMessage (ProtocolWrap inner) hdr body = writeMessage inner hdr body
+
+data HeaderProtocol i o = (Transport i, Transport o) => HeaderProtocol {
+ trans :: HeaderTransport i o,
+ wrappedProto :: IORef ProtocolWrap
+ }
+
+-- | Wrap transport @t@ in the concrete protocol selected by the given
+-- 'ProtocolType' and hide the result behind 'ProtocolWrap'.
+createProtocolWrap :: Transport t => ProtocolType -> t -> ProtocolWrap
+createProtocolWrap TBinary t = ProtocolWrap (BinaryProtocol t)
+createProtocolWrap TCompact t = ProtocolWrap (CompactProtocol t)
+createProtocolWrap TJSON t = ProtocolWrap (JSONProtocol t)
+
+-- | Open a header transport over the given input and output transports and
+-- build a 'HeaderProtocol' around it. The wrapped protocol is created for
+-- the protocol type the freshly opened transport reports.
+createHeaderProtocol :: (Transport i, Transport o) => i -> o -> IO (HeaderProtocol i o)
+createHeaderProtocol input output = do
+  ht <- openHeaderTransport input output
+  kind <- readIORef (protocolType ht)
+  -- Constructor fields in declaration order: transport, wrapped protocol.
+  fmap (HeaderProtocol ht) (newIORef (createProtocolWrap kind ht))
+
+createHeaderProtocol1 :: Transport t => t -> IO(HeaderProtocol t t)
+createHeaderProtocol1 t = createHeaderProtocol t t
+
+-- | Re-read the transport's protocol type and replace the wrapped protocol
+-- with a fresh one of that type.
+resetProtocol :: (Transport i, Transport o) => HeaderProtocol i o -> IO ()
+resetProtocol p = readIORef (protocolType ht) >>= install
+  where
+    ht = trans p
+    install kind = writeIORef (wrappedProto p) (createProtocolWrap kind ht)
+
+-- | Read the protocol currently stored in the 'wrappedProto' reference.
+getWrapped :: (Transport i, Transport o) => HeaderProtocol i o -> IO ProtocolWrap
+getWrapped = readIORef . wrappedProto
+
+setTransport :: (Transport i, Transport o) => HeaderProtocol i o -> HeaderTransport i o -> HeaderProtocol i o
+setTransport p t = p { trans = t }
+
+-- | Apply @f@ to the protocol's header transport and return the protocol
+-- with the result stored in 'trans'.
+updateTransport :: (Transport i, Transport o) => HeaderProtocol i o -> (HeaderTransport i o -> HeaderTransport i o) -> HeaderProtocol i o
+updateTransport p f = p { trans = f (trans p) }
+
+type Headers = Map.Map String String
+
+-- TODO: we want to set headers without recreating client...
+-- | Return a copy of the protocol whose write headers additionally map key
+-- @k@ to value @v@ (inserted with 'Map.insert').
+setHeader :: (Transport i, Transport o) => HeaderProtocol i o -> String -> String -> HeaderProtocol i o
+setHeader p k v = setHeaders p (Map.insert k v (getWriteHeaders p))
+
+setHeaders :: (Transport i, Transport o) => HeaderProtocol i o -> Headers -> HeaderProtocol i o
+setHeaders p h = updateTransport p $ \t -> t { writeHeaders = h }
+
+-- TODO: make it public once we have first transform implementation for Haskell
+setTransforms :: (Transport i, Transport o) => HeaderProtocol i o -> [TransformType] -> HeaderProtocol i o
+setTransforms p trs = updateTransport p $ \t -> t { writeTransforms = trs }
+
+-- | Return a copy of the protocol with @tr@ prepended to the transport's
+-- write transforms.
+setTransform :: (Transport i, Transport o) => HeaderProtocol i o -> TransformType -> HeaderProtocol i o
+setTransform p tr = setTransforms p (tr : writeTransforms (trans p))
+
+getWriteHeaders :: (Transport i, Transport o) => HeaderProtocol i o -> Headers
+getWriteHeaders = writeHeaders . trans
+
+getHeaders :: (Transport i, Transport o) => HeaderProtocol i o -> IO [(String, String)]
+getHeaders = readIORef . headers . trans
+
+-- | Read the protocol type currently recorded in the header transport.
+getProtocolType :: (Transport i, Transport o) => HeaderProtocol i o -> IO ProtocolType
+getProtocolType = readIORef . protocolType . trans
+
+-- | Switch the protocol type. When the requested type equals the current
+-- one nothing happens; otherwise the transport is told about the new type
+-- via 'tSetProtocol' and the wrapped protocol is rebuilt.
+setProtocolType :: (Transport i, Transport o) => HeaderProtocol i o -> ProtocolType -> IO ()
+setProtocolType p wanted = getProtocolType p >>= switchFrom
+  where
+    switchFrom current
+      | wanted == current = return ()
+      | otherwise = tSetProtocol (trans p) wanted >> resetProtocol p
+
+-- | 'readByte' reads straight from the header transport; every other method
+-- forwards to whichever protocol is currently stored in 'wrappedProto'.
+instance (Transport i, Transport o) => Protocol (HeaderProtocol i o) where
+  readByte p = tReadAll (trans p) 1
+
+  readVal p tp = getWrapped p >>= \proto -> readVal proto tp
+
+  -- Calls 'tResetProtocol' on the transport and rebuilds the wrapped
+  -- protocol before delegating the read to it.
+  readMessage p f =
+    tResetProtocol (trans p)
+      >> resetProtocol p
+      >> getWrapped p
+      >>= \proto -> readMessage proto f
+
+  writeVal p v = getWrapped p >>= \proto -> writeVal proto v
+
+  writeMessage p x f = getWrapped p >>= \proto -> writeMessage proto x f
+
diff --git a/lib/hs/src/Thrift/Protocol/JSON.hs b/lib/hs/src/Thrift/Protocol/JSON.hs
index 7f619e8..839eddc 100644
--- a/lib/hs/src/Thrift/Protocol/JSON.hs
+++ b/lib/hs/src/Thrift/Protocol/JSON.hs
@@ -29,12 +29,12 @@
) where
import Control.Applicative
+import Control.Exception (bracket)
import Control.Monad
import Data.Attoparsec.ByteString as P
import Data.Attoparsec.ByteString.Char8 as PC
import Data.Attoparsec.ByteString.Lazy as LP
import Data.ByteString.Base64.Lazy as B64C
-import Data.ByteString.Base64 as B64
import Data.ByteString.Lazy.Builder as B
import Data.ByteString.Internal (c2w, w2c)
import Data.Functor
@@ -58,38 +58,48 @@
-- encoded as a JSON 'ByteString'
data JSONProtocol t = JSONProtocol t
-- ^ Construct a 'JSONProtocol' with a 'Transport'
+getTransport :: Transport t => JSONProtocol t -> t
+getTransport (JSONProtocol t) = t
-instance Protocol JSONProtocol where
- getTransport (JSONProtocol t) = t
+instance Transport t => Protocol (JSONProtocol t) where
+ readByte p = tReadAll (getTransport p) 1
- writeMessageBegin (JSONProtocol t) (s, ty, sq) = tWrite t $ toLazyByteString $
- B.char8 '[' <> buildShowable (1 :: Int32) <>
- B.string8 ",\"" <> escape (encodeUtf8 s) <> B.char8 '\"' <>
- B.char8 ',' <> buildShowable (fromEnum ty) <>
- B.char8 ',' <> buildShowable sq <>
- B.char8 ','
- writeMessageEnd (JSONProtocol t) = tWrite t "]"
- readMessageBegin p = runParser p $ skipSpace *> do
- _ver :: Int32 <- lexeme (PC.char8 '[') *> lexeme (signed decimal)
- bs <- lexeme (PC.char8 ',') *> lexeme escapedString
- case decodeUtf8' bs of
- Left _ -> fail "readMessage: invalid text encoding"
- Right str -> do
- ty <- toEnum <$> (lexeme (PC.char8 ',') *> lexeme (signed decimal))
- seqNum <- lexeme (PC.char8 ',') *> lexeme (signed decimal)
- _ <- PC.char8 ','
- return (str, ty, seqNum)
- readMessageEnd p = void $ runParser p (PC.char8 ']')
+ -- | Write the opening of the JSON message array (version @1@, the escaped
+ -- name, the numeric message type, the sequence id), run the body action
+ -- @f@, then write the closing @]@ and flush the transport. Because the
+ -- steps are tied together with 'bracket', the closing step also runs when
+ -- @f@ throws.
+ writeMessage (JSONProtocol t) (s, ty, sq) f =
+   bracket writeMessageBegin writeMessageEnd (const f)
+   where
+     -- Named for what they do: both helpers only emit message framing.
+     writeMessageBegin = tWrite t $ toLazyByteString $
+       B.char8 '[' <> buildShowable (1 :: Int32) <>
+       B.string8 ",\"" <> escape (encodeUtf8 s) <> B.char8 '\"' <>
+       B.char8 ',' <> buildShowable (fromEnum ty) <>
+       B.char8 ',' <> buildShowable sq <>
+       B.char8 ','
+     writeMessageEnd _ = do
+       tWrite t "]"
+       tFlush t
+ -- | Parse the opening of the JSON message array, pass the resulting
+ -- (name, type, sequence number) triple to the continuation, and consume
+ -- the closing @]@ afterwards. 'bracket' runs the closing parse even when
+ -- the continuation throws.
+ readMessage proto k = bracket parseHeader parseClose k
+   where
+     comma = lexeme (PC.char8 ',')
+     parseHeader = runParser proto $ skipSpace *> do
+       _ver :: Int32 <- lexeme (PC.char8 '[') *> lexeme (signed decimal)
+       rawName <- comma *> lexeme escapedString
+       -- A name that is not valid UTF-8 fails the parser.
+       name <- either (const $ fail "readMessage: invalid text encoding") return $
+         decodeUtf8' rawName
+       ty <- toEnum <$> (comma *> lexeme (signed decimal))
+       seqNum <- comma *> lexeme (signed decimal)
+       _ <- PC.char8 ','
+       return (name, ty, seqNum)
+     parseClose _ = void (runParser proto (PC.char8 ']'))
+
+ writeVal p = tWrite (getTransport p) . toLazyByteString . buildJSONValue
+ readVal p ty = runParser p $ skipSpace *> parseJSONValue ty
+
+instance Transport t => StatelessProtocol (JSONProtocol t) where
serializeVal _ = toLazyByteString . buildJSONValue
deserializeVal _ ty bs =
case LP.eitherResult $ LP.parse (parseJSONValue ty) bs of
Left s -> error s
Right val -> val
- readVal p ty = runParser p $ skipSpace *> parseJSONValue ty
-
-
-- Writing Functions
buildJSONValue :: ThriftVal -> Builder