Noam Zilberstein | af5d64a | 2014-07-31 15:44:13 -0700 | [diff] [blame] | 1 | -- |
| 2 | -- Licensed to the Apache Software Foundation (ASF) under one |
| 3 | -- or more contributor license agreements. See the NOTICE file |
| 4 | -- distributed with this work for additional information |
| 5 | -- regarding copyright ownership. The ASF licenses this file |
| 6 | -- to you under the Apache License, Version 2.0 (the |
| 7 | -- "License"); you may not use this file except in compliance |
| 8 | -- with the License. You may obtain a copy of the License at |
| 9 | -- |
| 10 | -- http://www.apache.org/licenses/LICENSE-2.0 |
| 11 | -- |
| 12 | -- Unless required by applicable law or agreed to in writing, |
| 13 | -- software distributed under the License is distributed on an |
| 14 | -- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| 15 | -- KIND, either express or implied. See the License for the |
| 16 | -- specific language governing permissions and limitations |
| 17 | -- under the License. |
| 18 | -- |
| 19 | |
| 20 | {-# LANGUAGE CPP #-} |
| 21 | {-# LANGUAGE ExistentialQuantification #-} |
| 22 | {-# LANGUAGE OverloadedStrings #-} |
| 23 | {-# LANGUAGE ScopedTypeVariables #-} |
| 24 | |
| 25 | module Thrift.Protocol.Compact |
| 26 | ( module Thrift.Protocol |
| 27 | , CompactProtocol(..) |
| 28 | ) where |
| 29 | |
| 30 | import Control.Applicative |
| 31 | import Control.Exception ( throw ) |
| 32 | import Control.Monad |
| 33 | import Data.Attoparsec.ByteString as P |
| 34 | import Data.Attoparsec.ByteString.Lazy as LP |
| 35 | import Data.Bits |
| 36 | import Data.ByteString.Lazy.Builder as B |
| 37 | import Data.Int |
| 38 | import Data.List as List |
| 39 | import Data.Monoid |
| 40 | import Data.Word |
| 41 | import Data.Text.Lazy.Encoding ( decodeUtf8, encodeUtf8 ) |
| 42 | |
| 43 | import Thrift.Protocol hiding (versionMask) |
| 44 | import Thrift.Transport |
| 45 | import Thrift.Types |
| 46 | |
| 47 | import qualified Data.ByteString as BS |
| 48 | import qualified Data.ByteString.Lazy as LBS |
| 49 | import qualified Data.HashMap.Strict as Map |
| 50 | import qualified Data.Text.Lazy as LT |
| 51 | |
-- | The Compact Protocol implements the standard Thrift @TCompactProtocol@,
-- which is similar to the @TBinaryProtocol@ but takes less space on the wire:
-- 16-, 32- and 64-bit integers are zig-zag encoded and written as varints.
data CompactProtocol a = CompactProtocol a
  -- ^ Construct a 'CompactProtocol' with a 'Transport'
| 57 | |
-- | Message header constants.  A message starts with the byte 'protocolID';
-- the next byte carries 'version' in its low five bits ('versionMask') and
-- the message type in its high three bits ('typeMask', shifted by
-- 'typeShiftAmount').
protocolID, version, versionMask, typeMask :: Int8
-- 0x82 and 0xe0 do not fit in a signed byte, so spell them as Word8 and
-- reinterpret the bit pattern instead of relying on an overflowing literal.
protocolID = fromIntegral (0x82 :: Word8) -- 1000 0010
version = 0x01
versionMask = 0x1f -- 0001 1111
typeMask = fromIntegral (0xe0 :: Word8) -- 1110 0000
typeShiftAmount :: Int
typeShiftAmount = 5
| 65 | |
| 66 | |
-- | Compact-protocol implementation of 'Protocol'.  Message headers are the
-- protocol id byte, a version/type byte, the zig-zag varint sequence id and
-- the method name as a compact string; values are delegated to
-- 'buildCompactValue' / 'parseCompactValue'.
instance Protocol CompactProtocol where
  getTransport (CompactProtocol t) = t

  writeMessageBegin p (n, t, s) = tWrite (getTransport p) $ toLazyByteString $
    B.int8 protocolID <>
    B.int8 ((version .&. versionMask) .|.
            (((fromIntegral $ fromEnum t) `shiftL`
              typeShiftAmount) .&. typeMask)) <>
    buildVarint (i32ToZigZag s) <>
    buildCompactValue (TString $ encodeUtf8 n)

  readMessageBegin p = runParser p $ do
    pid <- fromIntegral <$> P.anyWord8
    when (pid /= protocolID) $ error "Bad Protocol ID"
    w <- fromIntegral <$> P.anyWord8
    let ver = w .&. versionMask
    when (ver /= version) $ error "Bad Protocol version"
    -- The writer packs the message type into the top three bits (typeMask
    -- 0xe0), so keep three bits here.  Masking with only 0x03 turned type 4
    -- (ONEWAY) into 0.  The mask also drops the sign extension that the
    -- arithmetic Int8 shift introduces when the top bit is set.
    let typ = (w `shiftR` typeShiftAmount) .&. 0x07
    seqId <- parseVarint zigZagToI32
    TString name <- parseCompactValue T_STRING
    return (decodeUtf8 name, toEnum $ fromIntegral $ typ, seqId)

  serializeVal _ = toLazyByteString . buildCompactValue
  deserializeVal _ ty bs =
    case LP.eitherResult $ LP.parse (parseCompactValue ty) bs of
      Left s -> error s
      Right val -> val

  readVal p ty = runParser p $ parseCompactValue ty
| 96 | |
| 97 | |
-- Writing functions

-- | Serialize a single value in the compact encoding.
--
-- * A map with no entries is the single byte @0x00@; otherwise it is a
--   varint entry count, one byte holding the key type (high nibble) and the
--   value type (low nibble), then the entries in list order.
-- * Lists and sets with fewer than 15 elements pack the count into the high
--   nibble of the header byte; longer ones put @0xF@ there and follow the
--   header with a varint count.
-- * i16/i32/i64 are zig-zag encoded varints; strings are a varint length
--   followed by the raw bytes.
buildCompactValue :: ThriftVal -> Builder
buildCompactValue (TStruct fields) = buildCompactStruct fields
buildCompactValue (TMap kt vt entries) =
  let len = fromIntegral $ length entries :: Word32 in
  if len == 0
    then B.word8 0x00
    else buildVarint len <>
         B.word8 ((fromTType kt `shiftL` 4) .|. fromTType vt) <>
         buildCompactMap entries
buildCompactValue (TList ty entries) =
  let len = length entries in
  (if len < 15
     then B.word8 $ (fromIntegral len `shiftL` 4) .|. fromTType ty
     else B.word8 (0xF0 .|. fromTType ty) <>
          buildVarint (fromIntegral len :: Word32)) <>
  buildCompactList entries
buildCompactValue (TSet ty entries) = buildCompactValue (TList ty entries)
buildCompactValue (TBool b) = B.word8 (if b then 1 else 0)
buildCompactValue (TByte b) = int8 b
buildCompactValue (TI16 i) = buildVarint $ i16ToZigZag i
buildCompactValue (TI32 i) = buildVarint $ i32ToZigZag i
buildCompactValue (TI64 i) = buildVarint $ i64ToZigZag i
-- Unlike the binary protocol, the compact protocol sends doubles
-- little-endian.
buildCompactValue (TDouble d) = doubleLE d
buildCompactValue (TString s) = buildVarint len <> lazyByteString s
  where
    len = fromIntegral (LBS.length s) :: Word32

-- | Serialize a struct's fields in ascending field-id order, terminated by a
-- 'T_STOP' byte.  Boolean field values are carried entirely in the field
-- header's type nibble (0x01 true / 0x02 false) and have no body.
buildCompactStruct :: Map.HashMap Int16 (LT.Text, ThriftVal) -> Builder
buildCompactStruct =
  loop 0 . List.sortBy (\(a, _) (b, _) -> compare a b) . Map.toList
  where
    loop :: Int16 -> [(Int16, (LT.Text, ThriftVal))] -> Builder
    loop _ [] = B.word8 (fromTType T_STOP)
    loop lastId ((fid, (_, val)) : rest) =
      fieldHeader lastId fid val <> fieldBody val <> loop fid rest

    -- Short form: an id delta of 1..15 shares one byte with the type nibble.
    -- Long form: a zero delta nibble, then the zig-zag varint field id.
    -- The delta is computed in Int so widely separated ids cannot wrap the
    -- Int16 subtraction into a bogus "small" delta.
    fieldHeader :: Int16 -> Int16 -> ThriftVal -> Builder
    fieldHeader lastId fid val
      | delta > 0 && delta <= 15 =
          B.word8 ((fromIntegral delta `shiftL` 4) .|. typeOf val)
      | otherwise = B.word8 (typeOf val) <> buildVarint (i16ToZigZag fid)
      where
        delta = fromIntegral fid - fromIntegral lastId :: Int

    -- Booleans (type codes 0x01/0x02) are fully described by the header.
    fieldBody :: ThriftVal -> Builder
    fieldBody val
      | typeOf val > 0x02 = buildCompactValue val
      | otherwise = mempty

-- | Serialize map entries, key then value, preserving list order.
buildCompactMap :: [(ThriftVal, ThriftVal)] -> Builder
buildCompactMap = foldr entry mempty
  where
    entry (key, val) rest =
      buildCompactValue key <> buildCompactValue val <> rest

-- | Serialize list/set elements back to back, in order.
buildCompactList :: [ThriftVal] -> Builder
buildCompactList = mconcat . map buildCompactValue

-- Reading functions

-- | Parse one value of the given type.
--
-- For lists, sets and non-empty maps the element types are taken from the
-- wire header.  A value parsed through a placeholder type such as
-- @T_LIST T_VOID@ (what 'typeFrom' yields inside a struct) therefore
-- carries a concrete element type and can be re-serialized.
parseCompactValue :: ThriftType -> Parser ThriftVal
parseCompactValue (T_STRUCT _) = TStruct <$> parseCompactStruct
parseCompactValue (T_MAP kt' vt') = do
  n <- parseVarint id
  if n == 0
    then return $ TMap kt' vt' []
    else do
      w <- P.anyWord8
      let kt = typeFrom $ w `shiftR` 4
          vt = typeFrom $ w .&. 0x0F
      TMap kt vt <$> parseCompactMap kt vt n
parseCompactValue (T_LIST _) = uncurry TList <$> parseCompactList
parseCompactValue (T_SET _) = uncurry TSet <$> parseCompactList
-- Only 0x01 is true.  Any other byte is false, including 0x00 (what
-- buildCompactValue writes) and 0x02 (the compact "false" code, see typeOf).
parseCompactValue T_BOOL = TBool . (== 0x01) <$> P.anyWord8
parseCompactValue T_BYTE = TByte . fromIntegral <$> P.anyWord8
parseCompactValue T_I16 = TI16 <$> parseVarint zigZagToI16
parseCompactValue T_I32 = TI32 <$> parseVarint zigZagToI32
parseCompactValue T_I64 = TI64 <$> parseVarint zigZagToI64
-- Compact doubles are little-endian.  bsToDouble is the big-endian decoder
-- paired with doubleBE elsewhere, so reverse the 8 bytes first.  The reverse
-- also leaves a fresh copy rather than a slice of the parser's input.
parseCompactValue T_DOUBLE = TDouble . bsToDouble . BS.reverse <$> P.take 8
parseCompactValue T_STRING = do
  len :: Word32 <- parseVarint id
  TString . LBS.fromStrict <$> P.take (fromIntegral len)
parseCompactValue ty = error $ "Cannot read value of type " ++ show ty

-- | Parse struct fields until a 0x00 stop byte.  The field id is either the
-- previous id plus the header's high nibble, or (nibble 0) a zig-zag varint.
-- Field names are not on the wire, so they are left empty.
parseCompactStruct :: Parser (Map.HashMap Int16 (LT.Text, ThriftVal))
parseCompactStruct = Map.fromList <$> parseFields 0
  where
    parseFields :: Int16 -> Parser [(Int16, (LT.Text, ThriftVal))]
    parseFields lastId = do
      w <- P.anyWord8
      if w == 0x00
        then return []
        else do
          let ty = typeFrom (w .&. 0x0F)
              modifier = (w .&. 0xF0) `shiftR` 4
          fid <- if modifier /= 0
                   then return (lastId + fromIntegral modifier)
                   else parseVarint zigZagToI16
          -- A boolean field's value is its type nibble (0x01 = true).
          val <- if ty == T_BOOL
                   then return (TBool $ (w .&. 0x0F) == 0x01)
                   else parseCompactValue ty
          ((fid, (LT.empty, val)) :) <$> parseFields fid

-- | Parse @n@ key/value pairs of the given types.
parseCompactMap :: ThriftType -> ThriftType -> Int32 ->
                   Parser [(ThriftVal, ThriftVal)]
parseCompactMap kt vt n
  | n <= 0 = return []
  | otherwise = do
      k <- parseCompactValue kt
      v <- parseCompactValue vt
      ((k, v) :) <$> parseCompactMap kt vt (n - 1)

-- | Parse a list/set header and its elements.  Returns the element type
-- from the header together with the elements.  A high nibble below 0xF is
-- the element count; 0xF means a varint count follows.
parseCompactList :: Parser (ThriftType, [ThriftVal])
parseCompactList = do
  w <- P.anyWord8
  let ty = typeFrom $ w .&. 0x0F
      lsize = w `shiftR` 4
  size <- if lsize == 0xF
            then parseVarint id
            else return $ fromIntegral lsize
  vals <- loop ty size
  return (ty, vals)
  where
    loop :: ThriftType -> Int32 -> Parser [ThriftVal]
    loop ty n
      | n <= 0 = return []
      | otherwise = liftM2 (:) (parseCompactValue ty) (loop ty (n - 1))
| 212 | |
-- Signed numbers must be converted to "Zig Zag" format before they can be
-- serialized in the Varint format: 0, -1, 1, -2, ... map to 0, 1, 2, 3, ...
-- so values of small magnitude and either sign produce short varints.
i16ToZigZag :: Int16 -> Word16
i16ToZigZag n = fromIntegral (doubled `xor` signFill)
  where
    doubled  = n `shiftL` 1
    -- shiftR on a signed type sign-extends: all ones for negative n
    signFill = n `shiftR` 15

-- | Inverse of 'i16ToZigZag'.
zigZagToI16 :: Word16 -> Int16
zigZagToI16 n = fromIntegral (magnitude `xor` signFill)
  where
    magnitude = n `shiftR` 1
    -- 0xFFFF when the low (sign) bit is set, 0 otherwise
    signFill  = negate (n .&. 0x1)
| 220 | |
-- | Zig-zag encode an 'Int32' so small negative values become small
-- unsigned values.
i32ToZigZag :: Int32 -> Word32
i32ToZigZag n = fromIntegral (doubled `xor` signFill)
  where
    doubled  = n `shiftL` 1
    -- sign-extending shift: all ones for negative n
    signFill = n `shiftR` 31

-- | Inverse of 'i32ToZigZag'.
zigZagToI32 :: Word32 -> Int32
zigZagToI32 n = fromIntegral (magnitude `xor` signFill)
  where
    magnitude = n `shiftR` 1
    signFill  = negate (n .&. 0x1)
| 226 | |
-- | Zig-zag encode an 'Int64' so small negative values become small
-- unsigned values.
i64ToZigZag :: Int64 -> Word64
i64ToZigZag n = fromIntegral (doubled `xor` signFill)
  where
    doubled  = n `shiftL` 1
    -- sign-extending shift: all ones for negative n
    signFill = n `shiftR` 63

-- | Inverse of 'i64ToZigZag'.
zigZagToI64 :: Word64 -> Int64
zigZagToI64 n = fromIntegral (magnitude `xor` signFill)
  where
    magnitude = n `shiftR` 1
    signFill  = negate (n .&. 0x1)
| 232 | |
-- | Emit @n@ as a base-128 varint: seven bits per byte, least-significant
-- group first, with the high bit set on every byte except the last.
-- Intended for unsigned inputs.  With a negative signed input the
-- sign-extending 'shiftR' never reaches zero, so it would not terminate.
buildVarint :: (Bits a, Integral a) => a -> Builder
buildVarint n
  | rest == 0 = B.word8 low7
  | otherwise = B.word8 (0x80 .|. low7) <> buildVarint rest
  where
    low7 = fromIntegral (n .&. 0x7F)
    rest = n `shiftR` 7
| 237 | |
-- | Parse a base-128 varint and pass the raw unsigned value through
-- @fromZigZag@ (use 'id' for plain unsigned values).  Every byte except the
-- last has its high bit set; the first byte on the wire holds the
-- least-significant seven bits.
parseVarint :: (Bits a, Integral a, Ord a) => (a -> b) -> Parser b
parseVarint fromZigZag = do
  continued <- P.takeTill (not . (`testBit` 7))
  final <- P.anyWord8
  -- Most significant group first, so each step shifts the total left.
  let groups = final : List.reverse (BS.unpack continued)
      step acc byte = (acc `shiftL` 7) .|. (fromIntegral byte .&. 0x7f)
  return . fromZigZag $ List.foldl' step 0x00 groups
| 245 | |
-- | The compact-protocol type code for a 'ThriftType'.  'T_VOID' has no
-- code and raises an error.
fromTType :: ThriftType -> Word8
fromTType T_STOP     = 0x00
fromTType T_BOOL     = 0x01
fromTType T_BYTE     = 0x03
fromTType T_I16      = 0x04
fromTType T_I32      = 0x05
fromTType T_I64      = 0x06
fromTType T_DOUBLE   = 0x07
fromTType T_STRING   = 0x08
fromTType T_LIST{}   = 0x09
fromTType T_SET{}    = 0x0A
fromTType T_MAP{}    = 0x0B
fromTType T_STRUCT{} = 0x0C
fromTType T_VOID     = error "No Compact type for T_VOID"
| 262 | |
-- | The compact type code describing a value.  Booleans get separate codes
-- for true (0x01) and false (0x02), so a boolean struct field can be
-- carried entirely in its header nibble.
typeOf :: ThriftVal -> Word8
typeOf (TBool b)   = if b then 0x01 else 0x02
typeOf (TByte _)   = 0x03
typeOf (TI16 _)    = 0x04
typeOf (TI32 _)    = 0x05
typeOf (TI64 _)    = 0x06
typeOf (TDouble _) = 0x07
typeOf (TString _) = 0x08
typeOf TList{}     = 0x09
typeOf TSet{}      = 0x0A
typeOf TMap{}      = 0x0B
typeOf TStruct{}   = 0x0C
| 277 | |
-- | Decode a compact type code.  Both boolean codes (0x01/0x02) map to
-- 'T_BOOL'.  Container and struct results use placeholder element types
-- ('T_VOID' / an empty field map).  Unknown codes, including 0x00, are
-- errors.
typeFrom :: Word8 -> ThriftType
typeFrom 0x01 = T_BOOL
typeFrom 0x02 = T_BOOL
typeFrom 0x03 = T_BYTE
typeFrom 0x04 = T_I16
typeFrom 0x05 = T_I32
typeFrom 0x06 = T_I64
typeFrom 0x07 = T_DOUBLE
typeFrom 0x08 = T_STRING
typeFrom 0x09 = T_LIST T_VOID
typeFrom 0x0A = T_SET T_VOID
typeFrom 0x0B = T_MAP T_VOID T_VOID
typeFrom 0x0C = T_STRUCT Map.empty
typeFrom n    = error $ "typeFrom: " ++ show n ++ " is not a compact type"