THRIFT-3580 THeader for Haskell
Client: hs
This closes #820
This closes #1423
diff --git a/lib/hs/src/Thrift.hs b/lib/hs/src/Thrift.hs
index 58a304b..6580209 100644
--- a/lib/hs/src/Thrift.hs
+++ b/lib/hs/src/Thrift.hs
@@ -90,13 +90,13 @@
deriving ( Show, Typeable )
instance Exception AppExn
-writeAppExn :: (Protocol p, Transport t) => p t -> AppExn -> IO ()
+writeAppExn :: Protocol p => p -> AppExn -> IO ()
writeAppExn pt ae = writeVal pt $ TStruct $ Map.fromList
[ (1, ("message", TString $ encodeUtf8 $ pack $ ae_message ae))
, (2, ("type", TI32 $ fromIntegral $ fromEnum (ae_type ae)))
]
-readAppExn :: (Protocol p, Transport t) => p t -> IO AppExn
+readAppExn :: Protocol p => p -> IO AppExn
readAppExn pt = do
let typemap = Map.fromList [(1,("message",T_STRING)),(2,("type",T_I32))]
TStruct fields <- readVal pt $ T_STRUCT typemap
diff --git a/lib/hs/src/Thrift/Protocol.hs b/lib/hs/src/Thrift/Protocol.hs
index ed779a2..67a9175 100644
--- a/lib/hs/src/Thrift/Protocol.hs
+++ b/lib/hs/src/Thrift/Protocol.hs
@@ -22,12 +22,11 @@
module Thrift.Protocol
( Protocol(..)
+ , StatelessProtocol(..)
, ProtocolExn(..)
, ProtocolExnType(..)
, getTypeOf
, runParser
- , versionMask
- , version1
, bsToDouble
, bsToDoubleLE
) where
@@ -35,7 +34,6 @@
import Control.Exception
import Data.Attoparsec.ByteString
import Data.Bits
-import Data.ByteString.Lazy (ByteString, toStrict)
import Data.ByteString.Unsafe
import Data.Functor ((<$>))
import Data.Int
@@ -44,37 +42,26 @@
import Data.Typeable (Typeable)
import Data.Word
import Foreign.Ptr (castPtr)
-import Foreign.Storable (Storable, peek, poke)
+import Foreign.Storable (peek, poke)
import System.IO.Unsafe
import qualified Data.ByteString as BS
import qualified Data.HashMap.Strict as Map
+import qualified Data.ByteString.Lazy as LBS
-import Thrift.Types
import Thrift.Transport
-
-versionMask :: Int32
-versionMask = fromIntegral (0xffff0000 :: Word32)
-
-version1 :: Int32
-version1 = fromIntegral (0x80010000 :: Word32)
+import Thrift.Types
class Protocol a where
- getTransport :: Transport t => a t -> t
+ readByte :: a -> IO LBS.ByteString
+ readVal :: a -> ThriftType -> IO ThriftVal
+ readMessage :: a -> ((Text, MessageType, Int32) -> IO b) -> IO b
- writeMessageBegin :: Transport t => a t -> (Text, MessageType, Int32) -> IO ()
- writeMessageEnd :: Transport t => a t -> IO ()
- writeMessageEnd _ = return ()
-
- readMessageBegin :: Transport t => a t -> IO (Text, MessageType, Int32)
- readMessageEnd :: Transport t => a t -> IO ()
- readMessageEnd _ = return ()
+ writeVal :: a -> ThriftVal -> IO ()
+ writeMessage :: a -> (Text, MessageType, Int32) -> IO () -> IO ()
- serializeVal :: Transport t => a t -> ThriftVal -> ByteString
- deserializeVal :: Transport t => a t -> ThriftType -> ByteString -> ThriftVal
-
- writeVal :: Transport t => a t -> ThriftVal -> IO ()
- writeVal p = tWrite (getTransport p) . serializeVal p
- readVal :: Transport t => a t -> ThriftType -> IO ThriftVal
+class Protocol a => StatelessProtocol a where
+ serializeVal :: a -> ThriftVal -> LBS.ByteString
+ deserializeVal :: a -> ThriftType -> LBS.ByteString -> ThriftVal
data ProtocolExnType
= PE_UNKNOWN
@@ -105,10 +92,10 @@
TBinary{} -> T_BINARY
TDouble{} -> T_DOUBLE
-runParser :: (Protocol p, Transport t, Show a) => p t -> Parser a -> IO a
+runParser :: (Protocol p, Show a) => p -> Parser a -> IO a
runParser prot p = refill >>= getResult . parse p
where
- refill = handle handleEOF $ toStrict <$> tReadAll (getTransport prot) 1
+ refill = handle handleEOF $ LBS.toStrict <$> readByte prot
getResult (Done _ a) = return a
getResult (Partial k) = refill >>= getResult . k
getResult f = throw $ ProtocolExn PE_INVALID_DATA (show f)
diff --git a/lib/hs/src/Thrift/Protocol/Binary.hs b/lib/hs/src/Thrift/Protocol/Binary.hs
index 2d35305..7b0acd9 100644
--- a/lib/hs/src/Thrift/Protocol/Binary.hs
+++ b/lib/hs/src/Thrift/Protocol/Binary.hs
@@ -25,6 +25,8 @@
module Thrift.Protocol.Binary
( module Thrift.Protocol
, BinaryProtocol(..)
+ , versionMask
+ , version1
) where
import Control.Exception ( throw )
@@ -35,6 +37,7 @@
import Data.Int
import Data.Monoid
import Data.Text.Lazy.Encoding ( decodeUtf8, encodeUtf8 )
+import Data.Word
import Thrift.Protocol
import Thrift.Transport
@@ -47,37 +50,55 @@
import qualified Data.HashMap.Strict as Map
import qualified Data.Text.Lazy as LT
-data BinaryProtocol a = BinaryProtocol a
+versionMask :: Int32
+versionMask = fromIntegral (0xffff0000 :: Word32)
+
+version1 :: Int32
+version1 = fromIntegral (0x80010000 :: Word32)
+
+data BinaryProtocol a = Transport a => BinaryProtocol a
+
+getTransport :: Transport t => BinaryProtocol t -> t
+getTransport (BinaryProtocol t) = t
-- NOTE: Reading and Writing functions rely on Builders and Data.Binary to
-- encode and decode data. Data.Binary assumes that the binary values it is
-- encoding to and decoding from are in BIG ENDIAN format, and converts the
-- endianness as necessary to match the local machine.
-instance Protocol BinaryProtocol where
- getTransport (BinaryProtocol t) = t
+instance Transport t => Protocol (BinaryProtocol t) where
+ readByte p = tReadAll (getTransport p) 1
+ -- flushTransport p = tFlush (getTransport p)
+ writeMessage p (n, t, s) f = do
+ tWrite (getTransport p) messageBegin
+ f
+ tFlush $ getTransport p
+ where
+ messageBegin = toLazyByteString $
+ buildBinaryValue (TI32 (version1 .|. fromIntegral (fromEnum t))) <>
+ buildBinaryValue (TString $ encodeUtf8 n) <>
+ buildBinaryValue (TI32 s)
- writeMessageBegin p (n, t, s) = tWrite (getTransport p) $ toLazyByteString $
- buildBinaryValue (TI32 (version1 .|. fromIntegral (fromEnum t))) <>
- buildBinaryValue (TString $ encodeUtf8 n) <>
- buildBinaryValue (TI32 s)
+ -- Parse the message header (version word, name, sequence id) and hand the
+ -- resulting triple to the continuation.
+ readMessage p f = messageBegin >>= f
+   where
+     messageBegin = runParser p $ do
+       TI32 ver <- parseBinaryValue T_I32
+       if ver .&. versionMask == version1
+         then do
+           TString s <- parseBinaryValue T_STRING
+           TI32 sz <- parseBinaryValue T_I32
+           return (decodeUtf8 s, toEnum $ fromIntegral $ ver .&. 0xFF, sz)
+         else throw $ ProtocolExn PE_BAD_VERSION "Missing version identifier"
- readMessageBegin p = runParser p $ do
- TI32 ver <- parseBinaryValue T_I32
- if ver .&. versionMask /= version1
- then throw $ ProtocolExn PE_BAD_VERSION "Missing version identifier"
- else do
- TString s <- parseBinaryValue T_STRING
- TI32 sz <- parseBinaryValue T_I32
- return (decodeUtf8 s, toEnum $ fromIntegral $ ver .&. 0xFF, sz)
+ writeVal p = tWrite (getTransport p) . toLazyByteString . buildBinaryValue
+ readVal p = runParser p . parseBinaryValue
+instance Transport t => StatelessProtocol (BinaryProtocol t) where
serializeVal _ = toLazyByteString . buildBinaryValue
deserializeVal _ ty bs =
case LP.eitherResult $ LP.parse (parseBinaryValue ty) bs of
Left s -> error s
Right val -> val
- readVal p = runParser p . parseBinaryValue
-
-- | Writing Functions
buildBinaryValue :: ThriftVal -> Builder
buildBinaryValue (TStruct fields) = buildBinaryStruct fields <> buildType T_STOP
diff --git a/lib/hs/src/Thrift/Protocol/Compact.hs b/lib/hs/src/Thrift/Protocol/Compact.hs
index 07113df..f23970a 100644
--- a/lib/hs/src/Thrift/Protocol/Compact.hs
+++ b/lib/hs/src/Thrift/Protocol/Compact.hs
@@ -25,10 +25,11 @@
module Thrift.Protocol.Compact
( module Thrift.Protocol
, CompactProtocol(..)
+ , parseVarint
+ , buildVarint
) where
import Control.Applicative
-import Control.Exception ( throw )
import Control.Monad
import Data.Attoparsec.ByteString as P
import Data.Attoparsec.ByteString.Lazy as LP
@@ -40,7 +41,7 @@
import Data.Word
import Data.Text.Lazy.Encoding ( decodeUtf8, encodeUtf8 )
-import Thrift.Protocol hiding (versionMask)
+import Thrift.Protocol
import Thrift.Transport
import Thrift.Types
@@ -64,38 +65,47 @@
typeShiftAmount :: Int
typeShiftAmount = 5
+getTransport :: Transport t => CompactProtocol t -> t
+getTransport (CompactProtocol t) = t
-instance Protocol CompactProtocol where
- getTransport (CompactProtocol t) = t
+instance Transport t => Protocol (CompactProtocol t) where
+ readByte p = tReadAll (getTransport p) 1
+ writeMessage p (n, t, s) f = do
+ tWrite (getTransport p) messageBegin
+ f
+ tFlush $ getTransport p
+ where
+ messageBegin = toLazyByteString $
+ B.word8 protocolID <>
+ B.word8 ((version .&. versionMask) .|.
+ (((fromIntegral $ fromEnum t) `shiftL`
+ typeShiftAmount) .&. typeMask)) <>
+ buildVarint (i32ToZigZag s) <>
+ buildCompactValue (TString $ encodeUtf8 n)
- writeMessageBegin p (n, t, s) = tWrite (getTransport p) $ toLazyByteString $
- B.word8 protocolID <>
- B.word8 ((version .&. versionMask) .|.
- (((fromIntegral $ fromEnum t) `shiftL`
- typeShiftAmount) .&. typeMask)) <>
- buildVarint (i32ToZigZag s) <>
- buildCompactValue (TString $ encodeUtf8 n)
-
- readMessageBegin p = runParser p $ do
- pid <- fromIntegral <$> P.anyWord8
- when (pid /= protocolID) $ error "Bad Protocol ID"
- w <- fromIntegral <$> P.anyWord8
- let ver = w .&. versionMask
- when (ver /= version) $ error "Bad Protocol version"
- let typ = (w `shiftR` typeShiftAmount) .&. typeBits
- seqId <- parseVarint zigZagToI32
- TString name <- parseCompactValue T_STRING
- return (decodeUtf8 name, toEnum $ fromIntegral $ typ, seqId)
+ readMessage p f = readMessageBegin >>= f
+ where
+ readMessageBegin = runParser p $ do
+ pid <- fromIntegral <$> P.anyWord8
+ when (pid /= protocolID) $ error "Bad Protocol ID"
+ w <- fromIntegral <$> P.anyWord8
+ let ver = w .&. versionMask
+ when (ver /= version) $ error "Bad Protocol version"
+ let typ = (w `shiftR` typeShiftAmount) .&. typeBits
+ seqId <- parseVarint zigZagToI32
+ TString name <- parseCompactValue T_STRING
+ return (decodeUtf8 name, toEnum $ fromIntegral $ typ, seqId)
+ writeVal p = tWrite (getTransport p) . toLazyByteString . buildCompactValue
+ readVal p ty = runParser p $ parseCompactValue ty
+
+instance Transport t => StatelessProtocol (CompactProtocol t) where
serializeVal _ = toLazyByteString . buildCompactValue
deserializeVal _ ty bs =
case LP.eitherResult $ LP.parse (parseCompactValue ty) bs of
Left s -> error s
Right val -> val
- readVal p ty = runParser p $ parseCompactValue ty
-
-
-- | Writing Functions
buildCompactValue :: ThriftVal -> Builder
buildCompactValue (TStruct fields) = buildCompactStruct fields
@@ -283,7 +293,7 @@
TSet{} -> 0x0A
TMap{} -> 0x0B
TStruct{} -> 0x0C
-
+
typeFrom :: Word8 -> ThriftType
typeFrom w = case w of
0x01 -> T_BOOL
diff --git a/lib/hs/src/Thrift/Protocol/Header.hs b/lib/hs/src/Thrift/Protocol/Header.hs
new file mode 100644
index 0000000..5f42db4
--- /dev/null
+++ b/lib/hs/src/Thrift/Protocol/Header.hs
@@ -0,0 +1,141 @@
+--
+-- Licensed to the Apache Software Foundation (ASF) under one
+-- or more contributor license agreements. See the NOTICE file
+-- distributed with this work for additional information
+-- regarding copyright ownership. The ASF licenses this file
+-- to you under the Apache License, Version 2.0 (the
+-- "License"); you may not use this file except in compliance
+-- with the License. You may obtain a copy of the License at
+--
+-- http://www.apache.org/licenses/LICENSE-2.0
+--
+-- Unless required by applicable law or agreed to in writing,
+-- software distributed under the License is distributed on an
+-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+-- KIND, either express or implied. See the License for the
+-- specific language governing permissions and limitations
+-- under the License.
+--
+
+
+module Thrift.Protocol.Header
+ ( module Thrift.Protocol
+ , HeaderProtocol(..)
+ , getProtocolType
+ , setProtocolType
+ , getHeaders
+ , getWriteHeaders
+ , setHeader
+ , setHeaders
+ , createHeaderProtocol
+ , createHeaderProtocol1
+ ) where
+
+import Thrift.Protocol
+import Thrift.Protocol.Binary
+import Thrift.Protocol.JSON
+import Thrift.Protocol.Compact
+import Thrift.Transport
+import Thrift.Transport.Header
+import Data.IORef
+import qualified Data.Map as Map
+
+data ProtocolWrap = forall a. (Protocol a) => ProtocolWrap(a)
+
+instance Protocol ProtocolWrap where
+ readByte (ProtocolWrap p) = readByte p
+ readVal (ProtocolWrap p) = readVal p
+ readMessage (ProtocolWrap p) = readMessage p
+ writeVal (ProtocolWrap p) = writeVal p
+ writeMessage (ProtocolWrap p) = writeMessage p
+
+-- | Protocol state for THeader: the header transport plus the protocol that
+-- is currently wrapped around it. The wrapped protocol lives in an 'IORef'
+-- so it can be replaced when the protocol type changes.
+data HeaderProtocol i o = (Transport i, Transport o) => HeaderProtocol {
+ trans :: HeaderTransport i o,
+ wrappedProto :: IORef ProtocolWrap
+ }
+
+-- | Wrap the transport in the concrete protocol selected by the protocol type.
+createProtocolWrap :: Transport t => ProtocolType -> t -> ProtocolWrap
+createProtocolWrap TBinary transport = ProtocolWrap (BinaryProtocol transport)
+createProtocolWrap TCompact transport = ProtocolWrap (CompactProtocol transport)
+createProtocolWrap TJSON transport = ProtocolWrap (JSONProtocol transport)
+
+-- | Open a 'HeaderTransport' over the given input and output transports and
+-- wrap it in the protocol matching the transport's initial protocol type.
+createHeaderProtocol :: (Transport i, Transport o) => i -> o -> IO (HeaderProtocol i o)
+createHeaderProtocol i o = do
+  headerTrans <- openHeaderTransport i o
+  initialType <- readIORef (protocolType headerTrans)
+  wrapped <- newIORef (createProtocolWrap initialType headerTrans)
+  return (HeaderProtocol headerTrans wrapped)
+
+createHeaderProtocol1 :: Transport t => t -> IO(HeaderProtocol t t)
+createHeaderProtocol1 t = createHeaderProtocol t t
+
+-- | Rebuild the wrapped protocol from the transport's current protocol type.
+resetProtocol :: (Transport i, Transport o) => HeaderProtocol i o -> IO ()
+resetProtocol p = do
+  let transport = trans p
+  currentType <- readIORef (protocolType transport)
+  writeIORef (wrappedProto p) (createProtocolWrap currentType transport)
+
+-- | Read the protocol that is currently wrapped around the header transport.
+getWrapped :: HeaderProtocol i o -> IO ProtocolWrap
+getWrapped p = readIORef (wrappedProto p)
+
+setTransport :: (Transport i, Transport o) => HeaderProtocol i o -> HeaderTransport i o -> HeaderProtocol i o
+setTransport p t = p { trans = t }
+
+updateTransport :: (Transport i, Transport o) => HeaderProtocol i o -> (HeaderTransport i o -> HeaderTransport i o)-> HeaderProtocol i o
+updateTransport p f = setTransport p (f $ trans p)
+
+type Headers = Map.Map String String
+
+-- TODO: we want to set headers without recreating client...
+setHeader :: (Transport i, Transport o) => HeaderProtocol i o -> String -> String -> HeaderProtocol i o
+setHeader p k v = updateTransport p $ \t -> t { writeHeaders = Map.insert k v $ writeHeaders t }
+
+setHeaders :: (Transport i, Transport o) => HeaderProtocol i o -> Headers -> HeaderProtocol i o
+setHeaders p h = updateTransport p $ \t -> t { writeHeaders = h }
+
+-- TODO: make it public once we have first transform implementation for Haskell
+setTransforms :: (Transport i, Transport o) => HeaderProtocol i o -> [TransformType] -> HeaderProtocol i o
+setTransforms p trs = updateTransport p $ \t -> t { writeTransforms = trs }
+
+setTransform :: (Transport i, Transport o) => HeaderProtocol i o -> TransformType -> HeaderProtocol i o
+setTransform p tr = updateTransport p $ \t -> t { writeTransforms = tr:(writeTransforms t) }
+
+getWriteHeaders :: (Transport i, Transport o) => HeaderProtocol i o -> Headers
+getWriteHeaders = writeHeaders . trans
+
+getHeaders :: (Transport i, Transport o) => HeaderProtocol i o -> IO [(String, String)]
+getHeaders = readIORef . headers . trans
+
+getProtocolType :: (Transport i, Transport o) => HeaderProtocol i o -> IO ProtocolType
+getProtocolType p = readIORef $ protocolType $ trans p
+
+-- | Switch to the given protocol type. When it differs from the current one
+-- the transport is updated and the wrapped protocol is rebuilt; otherwise
+-- nothing happens.
+setProtocolType :: (Transport i, Transport o) => HeaderProtocol i o -> ProtocolType -> IO ()
+setProtocolType p typ = do
+  current <- getProtocolType p
+  if typ /= current
+    then tSetProtocol (trans p) typ >> resetProtocol p
+    else return ()
+
+-- | Everything except 'readByte' is delegated to the currently wrapped
+-- protocol. 'readMessage' first reads the next frame from the transport and
+-- rebuilds the wrapped protocol before delegating.
+instance (Transport i, Transport o) => Protocol (HeaderProtocol i o) where
+  readByte p = tReadAll (trans p) 1
+
+  readVal p tp = getWrapped p >>= \proto -> readVal proto tp
+
+  readMessage p f = do
+    _ <- tResetProtocol (trans p)
+    resetProtocol p
+    getWrapped p >>= \proto -> readMessage proto f
+
+  writeVal p v = getWrapped p >>= \proto -> writeVal proto v
+
+  writeMessage p x f = getWrapped p >>= \proto -> writeMessage proto x f
+
diff --git a/lib/hs/src/Thrift/Protocol/JSON.hs b/lib/hs/src/Thrift/Protocol/JSON.hs
index 7f619e8..839eddc 100644
--- a/lib/hs/src/Thrift/Protocol/JSON.hs
+++ b/lib/hs/src/Thrift/Protocol/JSON.hs
@@ -29,12 +29,12 @@
) where
import Control.Applicative
+import Control.Exception (bracket)
import Control.Monad
import Data.Attoparsec.ByteString as P
import Data.Attoparsec.ByteString.Char8 as PC
import Data.Attoparsec.ByteString.Lazy as LP
import Data.ByteString.Base64.Lazy as B64C
-import Data.ByteString.Base64 as B64
import Data.ByteString.Lazy.Builder as B
import Data.ByteString.Internal (c2w, w2c)
import Data.Functor
@@ -58,38 +58,48 @@
-- encoded as a JSON 'ByteString'
data JSONProtocol t = JSONProtocol t
-- ^ Construct a 'JSONProtocol' with a 'Transport'
+getTransport :: Transport t => JSONProtocol t -> t
+getTransport (JSONProtocol t) = t
-instance Protocol JSONProtocol where
- getTransport (JSONProtocol t) = t
+instance Transport t => Protocol (JSONProtocol t) where
+ readByte p = tReadAll (getTransport p) 1
- writeMessageBegin (JSONProtocol t) (s, ty, sq) = tWrite t $ toLazyByteString $
- B.char8 '[' <> buildShowable (1 :: Int32) <>
- B.string8 ",\"" <> escape (encodeUtf8 s) <> B.char8 '\"' <>
- B.char8 ',' <> buildShowable (fromEnum ty) <>
- B.char8 ',' <> buildShowable sq <>
- B.char8 ','
- writeMessageEnd (JSONProtocol t) = tWrite t "]"
- readMessageBegin p = runParser p $ skipSpace *> do
- _ver :: Int32 <- lexeme (PC.char8 '[') *> lexeme (signed decimal)
- bs <- lexeme (PC.char8 ',') *> lexeme escapedString
- case decodeUtf8' bs of
- Left _ -> fail "readMessage: invalid text encoding"
- Right str -> do
- ty <- toEnum <$> (lexeme (PC.char8 ',') *> lexeme (signed decimal))
- seqNum <- lexeme (PC.char8 ',') *> lexeme (signed decimal)
- _ <- PC.char8 ','
- return (str, ty, seqNum)
- readMessageEnd p = void $ runParser p (PC.char8 ']')
+ -- Write the JSON message envelope around the caller's action: the opening
+ -- part before it, the closing "]" and a flush after it. 'bracket' runs the
+ -- closing step even when the action throws.
+ writeMessage (JSONProtocol t) (s, ty, sq) body =
+     bracket writeEnvelopeBegin writeEnvelopeEnd (const body)
+   where
+     writeEnvelopeBegin = tWrite t $ toLazyByteString $ mconcat
+       [ B.char8 '[', buildShowable (1 :: Int32)
+       , B.string8 ",\"", escape (encodeUtf8 s), B.char8 '\"'
+       , B.char8 ',', buildShowable (fromEnum ty)
+       , B.char8 ',', buildShowable sq
+       , B.char8 ','
+       ]
+     writeEnvelopeEnd _ = tWrite t "]" >> tFlush t
+ readMessage p = bracket readMessageBegin readMessageEnd
+ where
+ readMessageBegin = runParser p $ skipSpace *> do
+ _ver :: Int32 <- lexeme (PC.char8 '[') *> lexeme (signed decimal)
+ bs <- lexeme (PC.char8 ',') *> lexeme escapedString
+ case decodeUtf8' bs of
+ Left _ -> fail "readMessage: invalid text encoding"
+ Right str -> do
+ ty <- toEnum <$> (lexeme (PC.char8 ',') *> lexeme (signed decimal))
+ seqNum <- lexeme (PC.char8 ',') *> lexeme (signed decimal)
+ _ <- PC.char8 ','
+ return (str, ty, seqNum)
+ readMessageEnd _ = void $ runParser p (PC.char8 ']')
+
+ writeVal p = tWrite (getTransport p) . toLazyByteString . buildJSONValue
+ readVal p ty = runParser p $ skipSpace *> parseJSONValue ty
+
+instance Transport t => StatelessProtocol (JSONProtocol t) where
serializeVal _ = toLazyByteString . buildJSONValue
deserializeVal _ ty bs =
case LP.eitherResult $ LP.parse (parseJSONValue ty) bs of
Left s -> error s
Right val -> val
- readVal p ty = runParser p $ skipSpace *> parseJSONValue ty
-
-
-- Writing Functions
buildJSONValue :: ThriftVal -> Builder
diff --git a/lib/hs/src/Thrift/Server.hs b/lib/hs/src/Thrift/Server.hs
index ed74ceb..543f338 100644
--- a/lib/hs/src/Thrift/Server.hs
+++ b/lib/hs/src/Thrift/Server.hs
@@ -38,10 +38,10 @@
-- | A threaded server that is capable of using any Transport or Protocol
-- instances.
-runThreadedServer :: (Transport t, Protocol i, Protocol o)
- => (Socket -> IO (i t, o t))
+runThreadedServer :: (Protocol i, Protocol o)
+ => (Socket -> IO (i, o))
-> h
- -> (h -> (i t, o t) -> IO Bool)
+ -> (h -> (i, o) -> IO Bool)
-> PortID
-> IO a
runThreadedServer accepter hand proc_ port = do
diff --git a/lib/hs/src/Thrift/Transport/Handle.hs b/lib/hs/src/Thrift/Transport/Handle.hs
index b7d16e4..ff6295b 100644
--- a/lib/hs/src/Thrift/Transport/Handle.hs
+++ b/lib/hs/src/Thrift/Transport/Handle.hs
@@ -44,7 +44,13 @@
instance Transport Handle where
tIsOpen = hIsOpen
tClose = hClose
- tRead h n = LBS.hGet h n `Control.Exception.catch` handleEOF mempty
+ tRead h n = readAvailable `Control.Exception.catch` handleEOF mempty
+   where
+     -- Look ahead on the handle first, then take whatever is available
+     -- (at most n bytes) with a non-blocking read.
+     readAvailable = hLookAhead h >> LBS.hGetNonBlocking h n
+ tReadAll _ 0 = return mempty
+ tReadAll h n = LBS.hGet h n `Control.Exception.catch` throwTransportExn
tPeek h = (Just . c2w <$> hLookAhead h) `Control.Exception.catch` handleEOF Nothing
tWrite = LBS.hPut
tFlush = hFlush
@@ -61,8 +67,12 @@
instance HandleSource (HostName, PortID) where
hOpen = uncurry connectTo
+throwTransportExn :: IOError -> IO a
+throwTransportExn e = if isEOFError e
+ then throw $ TransportExn "Cannot read. Remote side has closed." TE_UNKNOWN
+ else throw $ TransportExn "Handle tReadAll: Could not read" TE_UNKNOWN
handleEOF :: a -> IOError -> IO a
handleEOF a e = if isEOFError e
then return a
- else throw $ TransportExn "TChannelTransport: Could not read" TE_UNKNOWN
+ else throw $ TransportExn "Handle: Could not read" TE_UNKNOWN
diff --git a/lib/hs/src/Thrift/Transport/Header.hs b/lib/hs/src/Thrift/Transport/Header.hs
new file mode 100644
index 0000000..2dacad2
--- /dev/null
+++ b/lib/hs/src/Thrift/Transport/Header.hs
@@ -0,0 +1,354 @@
+--
+-- Licensed to the Apache Software Foundation (ASF) under one
+-- or more contributor license agreements. See the NOTICE file
+-- distributed with this work for additional information
+-- regarding copyright ownership. The ASF licenses this file
+-- to you under the Apache License, Version 2.0 (the
+-- "License"); you may not use this file except in compliance
+-- with the License. You may obtain a copy of the License at
+--
+-- http://www.apache.org/licenses/LICENSE-2.0
+--
+-- Unless required by applicable law or agreed to in writing,
+-- software distributed under the License is distributed on an
+-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+-- KIND, either express or implied. See the License for the
+-- specific language governing permissions and limitations
+-- under the License.
+--
+
+module Thrift.Transport.Header
+ ( module Thrift.Transport
+ , HeaderTransport(..)
+ , openHeaderTransport
+ , ProtocolType(..)
+ , TransformType(..)
+ , ClientType(..)
+ , tResetProtocol
+ , tSetProtocol
+ ) where
+
+import Thrift.Transport
+import Thrift.Protocol.Compact
+import Control.Applicative
+import Control.Exception ( throw )
+import Control.Monad
+import Data.Bits
+import Data.IORef
+import Data.Int
+import Data.Monoid
+import Data.Word
+
+import qualified Data.Attoparsec.ByteString as P
+import qualified Data.Binary as Binary
+import qualified Data.ByteString as BS
+import qualified Data.ByteString.Char8 as C
+import qualified Data.ByteString.Lazy as LBS
+import qualified Data.ByteString.Lazy.Builder as B
+import qualified Data.Map as Map
+
+data ProtocolType = TBinary | TCompact | TJSON deriving (Enum, Eq)
+data ClientType = HeaderClient | Framed | Unframed deriving (Enum, Eq)
+
+infoIdKeyValue = 1
+
+type Headers = Map.Map String String
+
+data TransformType = ZlibTransform deriving (Enum, Eq)
+
+fromTransportType :: TransformType -> Int16
+fromTransportType ZlibTransform = 1
+
+toTransportType :: Int16 -> TransformType
+toTransportType 1 = ZlibTransform
+toTransportType _ = throw $ TransportExn "HeaderTransport: Unknown transform ID" TE_UNKNOWN
+
+data HeaderTransport i o = (Transport i, Transport o) => HeaderTransport
+ { readBuffer :: IORef LBS.ByteString
+ , writeBuffer :: IORef B.Builder
+ , inTrans :: i
+ , outTrans :: o
+ , clientType :: IORef ClientType
+ , protocolType :: IORef ProtocolType
+ , headers :: IORef [(String, String)]
+ , writeHeaders :: Headers
+ , transforms :: IORef [TransformType]
+ , writeTransforms :: [TransformType]
+ }
+
+-- | Create a header transport over the given input and output transports.
+-- Buffers, received headers and received transforms start empty; the initial
+-- client type is 'HeaderClient' and the initial protocol type is 'TCompact'.
+openHeaderTransport :: (Transport i, Transport o) => i -> o -> IO (HeaderTransport i o)
+openHeaderTransport i o = do
+  protoRef <- newIORef TCompact
+  readRef <- newIORef LBS.empty
+  writeRef <- newIORef mempty
+  clientRef <- newIORef HeaderClient
+  headersRef <- newIORef []
+  transformsRef <- newIORef []
+  return HeaderTransport
+    { readBuffer = readRef
+    , writeBuffer = writeRef
+    , inTrans = i
+    , outTrans = o
+    , clientType = clientRef
+    , protocolType = protoRef
+    , headers = headersRef
+    , writeHeaders = Map.empty
+    , transforms = transformsRef
+    , writeTransforms = []
+    }
+
+-- | True unless the client type is currently 'Unframed'.
+isFramed :: HeaderTransport i o -> IO Bool
+isFramed t = fmap (/= Unframed) (readIORef (clientType t))
+
+-- | Read the next message from the input transport into the read buffer,
+-- detecting the client type from the leading magic bytes.
+--
+-- Returns 'False' when the input yields no bytes at all and 'True' once data
+-- has been buffered. Throws 'TransportExn' when the data is neither binary,
+-- compact nor header encoded.
+readFrame :: (Transport i, Transport o) => HeaderTransport i o -> IO Bool
+readFrame t = do
+  first <- tRead input 4
+  if LBS.null first
+    then return False
+    else do
+      -- 'tRead' may return fewer bytes than requested, but both the magic
+      -- check and the length prefix need all four: fetch what is missing.
+      let missing = 4 - fromIntegral (LBS.length first)
+      rest <- if missing > 0 then tReadAll input missing else return LBS.empty
+      let lsz = first <> rest
+      case detect (LBS.toStrict lsz) of
+        -- Unframed: these four bytes already belong to the message itself.
+        Just pType -> buffered Unframed pType lsz
+        Nothing -> do
+          -- Otherwise the four bytes are the big-endian frame length.
+          lbuf <- tReadAll input $ fromIntegral (Binary.decode lsz :: Int32)
+          let buf = LBS.toStrict lbuf
+          case detect buf of
+            Just pType -> buffered Framed pType lbuf
+            Nothing ->
+              case parseHeaderMagic buf of
+                Right _ -> do
+                  let (_, _, header, body) = extractHeader buf
+                  writeIORef (clientType t) HeaderClient
+                  handleHeader t header
+                  payload <- untransform t body
+                  writeIORef (readBuffer t) $ LBS.fromStrict payload
+                  return True
+                Left _ ->
+                  throw $ TransportExn "HeaderTransport: unknown client type" TE_UNKNOWN
+  where
+    input = inTrans t
+    -- Protocol implied by a binary or compact magic prefix, if there is one.
+    detect bs =
+      case (parseBinaryMagic bs, parseCompactMagic bs) of
+        (Right _, _) -> Just TBinary
+        (_, Right _) -> Just TCompact
+        _ -> Nothing
+    -- Record the detected client/protocol type and buffer the bytes.
+    buffered cType pType bytes = do
+      writeIORef (clientType t) cType
+      writeIORef (protocolType t) pType
+      writeIORef (readBuffer t) bytes
+      return True
+
+-- | Match the leading bytes 0x80 0x01 0x00 and return the byte that follows.
+parseBinaryMagic :: BS.ByteString -> Either String Word8
+parseBinaryMagic = P.parseOnly (P.word8 0x80 *> P.word8 0x01 *> P.word8 0x00 *> P.anyWord8)
+
+-- | Match the leading byte 0x82 followed by a byte whose low five bits are 0x01.
+parseCompactMagic :: BS.ByteString -> Either String Word8
+parseCompactMagic = P.parseOnly (P.word8 0x82 *> P.satisfy (\b -> b .&. 0x1f == 0x01))
+
+-- | Match the leading bytes 0x0f 0xff and return the two bytes that follow.
+parseHeaderMagic :: BS.ByteString -> Either String [Word8]
+parseHeaderMagic = P.parseOnly (P.word8 0x0f *> P.word8 0xff *> P.count 2 P.anyWord8)
+
+parseI32 :: P.Parser Int32
+parseI32 = Binary.decode . LBS.fromStrict <$> P.take 4
+parseI16 :: P.Parser Int16
+parseI16 = Binary.decode . LBS.fromStrict <$> P.take 2
+
+-- | Split a header frame into (flags, sequence number, header bytes,
+-- remaining bytes). Throws 'TransportExn' unless the parse completes.
+extractHeader :: BS.ByteString -> (Int16, Int32, BS.ByteString, BS.ByteString)
+extractHeader bs =
+  case P.parse headerParser bs of
+    P.Done rest (flags, seqNum, header) -> (flags, seqNum, header, rest)
+    _ -> throw $ TransportExn "HeaderTransport: Invalid header" TE_UNKNOWN
+  where
+    headerParser = do
+      _ <- P.word8 0x0f *> P.word8 0xff
+      flags <- parseI16
+      seqNum <- parseI32
+      -- The size field counts 4-byte words, not bytes.
+      sizeInWords <- parseI16
+      header <- P.take (4 * fromIntegral sizeInWords)
+      return (flags, seqNum, header)
+
+-- | Parse the header bytes and store the protocol type, transforms and
+-- key/value headers they carry in the transport's mutable state. Throws
+-- 'TransportExn' when the header does not parse.
+handleHeader :: HeaderTransport i o -> BS.ByteString -> IO ()
+handleHeader t header =
+  case P.parseOnly parseHeader header of
+    Left _ -> throw $ TransportExn "HeaderTransport: Invalid header" TE_UNKNOWN
+    Right (pType, transformList, info) -> do
+      writeIORef (protocolType t) pType
+      writeIORef (transforms t) transformList
+      writeIORef (headers t) info
+
+
+iw16 :: Int16 -> Word16
+iw16 = fromIntegral
+iw32 :: Int32 -> Word32
+iw32 = fromIntegral
+wi16 :: Word16 -> Int16
+wi16 = fromIntegral
+wi32 :: Word32 -> Int32
+wi32 = fromIntegral
+
+-- | Parse the variable part of a header: the protocol id, the transform
+-- count with that many transform ids, and the key/value info section.
+parseHeader :: P.Parser (ProtocolType, [TransformType], [(String, String)])
+parseHeader = do
+  pType <- fmap toProtocolType (parseVarint wi16)
+  transformCount <- fmap fromIntegral (parseVarint wi16)
+  transformList <- replicateM transformCount parseTransform
+  info <- parseInfo
+  return (pType, transformList, info)
+
+-- | Decode a wire protocol id. Ids other than 0, 1 and 2 raise
+-- 'TransportExn' rather than an incomplete-pattern failure.
+toProtocolType :: Int16 -> ProtocolType
+toProtocolType 0 = TBinary
+toProtocolType 1 = TJSON
+toProtocolType 2 = TCompact
+toProtocolType _ = throw $ TransportExn "HeaderTransport: Unknown protocol ID" TE_UNKNOWN
+
+fromProtocolType :: ProtocolType -> Int16
+fromProtocolType TBinary = 0
+fromProtocolType TJSON = 1
+fromProtocolType TCompact = 2
+
+parseTransform :: P.Parser TransformType
+parseTransform = toTransportType <$> parseVarint wi16
+
+parseInfo :: P.Parser [(String, String)]
+parseInfo = do
+ n <- P.eitherP P.endOfInput (parseVarint wi32)
+ case n of
+ Left _ -> return []
+ Right n0 ->
+ replicateM (fromIntegral n0) $ do
+ klen <- parseVarint wi16
+ k <- P.take $ fromIntegral klen
+ vlen <- parseVarint wi16
+ v <- P.take $ fromIntegral vlen
+ return (C.unpack k, C.unpack v)
+
+parseString :: P.Parser BS.ByteString
+parseString = parseVarint wi32 >>= (P.take . fromIntegral)
+
+-- | Build the header that precedes the payload of an outgoing frame: the
+-- 0x0fff magic, flags (0), sequence number (0), the header size in 4-byte
+-- words, and the header content zero-padded to a 4-byte boundary.
+buildHeader :: HeaderTransport i o -> IO B.Builder
+buildHeader t = do
+  pType <- readIORef $ protocolType t
+  let pId = buildVarint $ iw16 $ fromProtocolType pType
+  let headerContent = pId <> (buildTransforms t) <> (buildInfo t)
+  let len = fromIntegral $ LBS.length $ B.toLazyByteString headerContent
+  -- TODO: length limit check
+  -- Pad up to the next multiple of four so that the content plus padding is
+  -- exactly the number of words announced in the size field.
+  let padding = mconcat $ replicate ((4 - len `mod` 4) `mod` 4) $ B.word8 0
+  let codedLen = B.int16BE (fromIntegral $ (len + 3) `quot` 4)
+  let flags = 0
+  let seqNum = 0
+  return $ B.int16BE 0x0fff <> B.int16BE flags <> B.int32BE seqNum <> codedLen <> headerContent <> padding
+
+-- | Encode the write transforms as a varint count followed by one varint
+-- transform id per entry.
+-- TODO: check length limit
+buildTransforms :: HeaderTransport i o -> B.Builder
+buildTransforms t = transformCount <> transformIds
+  where
+    trans = writeTransforms t
+    transformCount = buildVarint $ iw16 $ fromIntegral $ length trans
+    transformIds = mconcat $ map (buildVarint . iw16 . fromTransportType) trans
+
+-- | Encode the write headers as a varint entry count followed by every key
+-- and value as a varint-length-prefixed string. Nothing is emitted when
+-- there are no headers.
+-- TODO: check length limit (entry count and string lengths)
+buildInfo :: HeaderTransport i o -> B.Builder
+buildInfo t
+  | null entries = mempty
+  | otherwise = entryCount <> mconcat (map buildInfoEntry entries)
+  where
+    entries = Map.assocs $ writeHeaders t
+    entryCount = buildVarint $ iw16 $ fromIntegral $ length entries
+    buildInfoEntry (k, v) = buildVarStr k <> buildVarStr v
+    buildVarStr s = (buildVarint $ iw16 $ fromIntegral $ length s) <> B.string8 s
+
+-- | Reset the client type to 'HeaderClient' and read the next frame. The
+-- result is that of 'readFrame'.
+tResetProtocol :: (Transport i, Transport o) => HeaderTransport i o -> IO Bool
+tResetProtocol t = do
+  writeIORef (clientType t) HeaderClient
+  readFrame t
+
+tSetProtocol :: (Transport i, Transport o) => HeaderTransport i o -> ProtocolType -> IO ()
+tSetProtocol t = writeIORef (protocolType t)
+
+-- | Apply the configured write transforms to an outgoing payload. No
+-- transform is implemented yet: with none configured the payload is
+-- returned unchanged, otherwise 'TransportExn' is raised.
+transform :: HeaderTransport i o -> LBS.ByteString -> LBS.ByteString
+transform t bs = foldr applyTransform bs (writeTransforms t)
+  where
+    -- 'foldr' passes the transform first and the accumulated payload second.
+    -- applyTransform ZlibTransform payload =
+    --   throw $ TransportExn "HeaderTransport: not implemented: ZlibTransform " TE_UNKNOWN
+    applyTransform _ _ =
+      throw $ TransportExn "HeaderTransport: Unknown transform" TE_UNKNOWN
+
+untransform :: HeaderTransport i o -> BS.ByteString -> IO BS.ByteString
+untransform t bs = do
+ trans <- readIORef $ transforms t
+ return $ foldl unapplyTransform bs trans
+ where
+ -- unapplyTransform bs ZlibTransform =
+ -- throw $ TransportExn "HeaderTransport: not implemented: ZlibTransform " TE_UNKNOWN
+ unapplyTransform bs _ =
+ throw $ TransportExn "HeaderTransport: Unknown transform" TE_UNKNOWN
+
+instance (Transport i, Transport o) => Transport (HeaderTransport i o) where
+ -- The header transport is open only while both wrapped transports are;
+ -- the input transport's status is no longer discarded.
+ tIsOpen t = (&&) <$> tIsOpen (inTrans t) <*> tIsOpen (outTrans t)
+
+ tClose t = do
+ tClose(outTrans t)
+ tClose(inTrans t)
+
+ -- Serve reads from the read buffer. When it is empty, a framed client
+ -- reads the next frame and retries, while an unframed client reads
+ -- directly from the input transport.
+ tRead t len = do
+   pending <- readIORef (readBuffer t)
+   if LBS.null pending
+     then do
+       framed <- isFramed t
+       if framed
+         then do
+           ok <- readFrame t
+           if ok then tRead t len else return LBS.empty
+         else tRead (inTrans t) len
+     else do
+       let (consumed, remain) = LBS.splitAt (fromIntegral len) pending
+       writeIORef (readBuffer t) remain
+       return consumed
+
+ tPeek t = do
+ rBuf <- readIORef (readBuffer t)
+ if not $ LBS.null rBuf
+ then return $ Just $ LBS.head rBuf
+ else do
+ framed <- isFramed t
+ if not framed
+ then tPeek (inTrans t)
+ else do
+ ok <- readFrame t
+ if ok
+ then tPeek t
+ else return Nothing
+
+ tWrite t buf = do
+ let wBuf = writeBuffer t
+ framed <- isFramed t
+ if framed
+ then modifyIORef wBuf (<> B.lazyByteString buf)
+ else
+ -- TODO: what should we do when switched to unframed in the middle ?
+ tWrite(outTrans t) buf
+
+ -- Send everything buffered so far. Unframed clients only flush the output
+ -- transport; framed and header clients emit the buffered bytes as a single
+ -- length-prefixed frame (header clients put the built header first).
+ tFlush t = do
+   cType <- readIORef $ clientType t
+   case cType of
+     Unframed -> tFlush $ outTrans t
+     Framed -> flushBuffer id mempty
+     HeaderClient -> buildHeader t >>= flushBuffer (transform t)
+   where
+     flushBuffer f header = do
+       -- Take the pending bytes and leave the write buffer empty.
+       wBuf <- readIORef $ writeBuffer t
+       writeIORef (writeBuffer t) mempty
+       -- The reader consumes exactly as many bytes as the prefix announces,
+       -- so the length is taken from what is actually written, i.e. after
+       -- the write transforms have been applied.
+       let frame = f $ B.toLazyByteString (header <> wBuf)
+       tWrite (outTrans t) $ Binary.encode (fromIntegral $ LBS.length frame :: Int32)
+       tWrite (outTrans t) frame
+       tFlush (outTrans t)
diff --git a/lib/hs/thrift.cabal b/lib/hs/thrift.cabal
index fb33d9a..4e9cb18 100644
--- a/lib/hs/thrift.cabal
+++ b/lib/hs/thrift.cabal
@@ -49,6 +49,7 @@
Thrift,
Thrift.Arbitraries
Thrift.Protocol,
+ Thrift.Protocol.Header,
Thrift.Protocol.Binary,
Thrift.Protocol.Compact,
Thrift.Protocol.JSON,
@@ -57,6 +58,7 @@
Thrift.Transport.Empty,
Thrift.Transport.Framed,
Thrift.Transport.Handle,
+ Thrift.Transport.Header,
Thrift.Transport.HttpClient,
Thrift.Transport.IOBuffer,
Thrift.Transport.Memory,