blob: b27a7499539e2fbcbbbe6be90387cad2aae409b2 [file] [log] [blame]
Neil Williams66a44c52018-08-13 16:12:24 -07001#
2# Licensed to the Apache Software Foundation (ASF) under one
3# or more contributor license agreements. See the NOTICE file
4# distributed with this work for additional information
5# regarding copyright ownership. The ASF licenses this file
6# to you under the Apache License, Version 2.0 (the
7# "License"); you may not use this file except in compliance
8# with the License. You may obtain a copy of the License at
9#
10# http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing,
13# software distributed under the License is distributed on an
14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15# KIND, either express or implied. See the License for the
16# specific language governing permissions and limitations
17# under the License.
18#
19
20from thrift.protocol.TBinaryProtocol import TBinaryProtocolAccelerated
21from thrift.protocol.TCompactProtocol import TCompactProtocolAccelerated
22from thrift.protocol.TProtocol import TProtocolBase, TProtocolException
23from thrift.Thrift import TApplicationException, TMessageType
24from thrift.transport.THeaderTransport import THeaderTransport, THeaderSubprotocolID, THeaderClientType
25
26
# Lookup table from the subprotocol id reported by the transport
# (``THeaderTransport.protocol_id``) to the protocol class used to encode and
# decode the framed payload. Only the binary and compact subprotocols are
# supported; any other id is rejected by THeaderProtocol._set_protocol.
PROTOCOLS_BY_ID = dict([
    (THeaderSubprotocolID.BINARY, TBinaryProtocolAccelerated),
    (THeaderSubprotocolID.COMPACT, TCompactProtocolAccelerated),
])
31
32
class THeaderProtocol(TProtocolBase):
    """A framed protocol with headers and payload transforms.

    THeaderProtocol frames other Thrift protocols and adds support for optional
    out-of-band headers. The currently supported subprotocols are
    TBinaryProtocol and TCompactProtocol.

    It's also possible to apply transforms to the encoded message payload. The
    only transform currently supported is to gzip.

    When used in a server, THeaderProtocol can accept messages from
    non-THeaderProtocol clients if allowed (see `allowed_client_types`). This
    includes framed and unframed transports and both TBinaryProtocol and
    TCompactProtocol. The server will respond in the appropriate dialect for
    the connected client. HTTP clients are not currently supported.

    THeaderProtocol does not currently support THTTPServer, TNonblockingServer,
    or TProcessPoolServer.

    See doc/specs/HeaderFormat.md for details of the wire format.

    Implementation note: every read*/write* method forwards to the currently
    selected subprotocol (``self._protocol``). That subprotocol is chosen from
    the transport's ``protocol_id`` at construction time, and again each time
    ``readMessageBegin`` reads a new frame.
    """

    def __init__(self, transport, allowed_client_types):
        """Wrap ``transport`` and select the initial subprotocol.

        ``transport`` may already be a THeaderTransport, in which case it is
        used as-is and ``allowed_client_types`` is ignored. Otherwise it is
        wrapped in a new THeaderTransport built with ``allowed_client_types``.

        Raises TApplicationException if the transport's ``protocol_id`` has no
        entry in PROTOCOLS_BY_ID.
        """
        # much of the actual work for THeaderProtocol happens down in
        # THeaderTransport since we need to do low-level shenanigans to detect
        # if the client is sending us headers or one of the headerless formats
        # we support. this wraps the real transport with the one that does all
        # the magic.
        if not isinstance(transport, THeaderTransport):
            transport = THeaderTransport(transport, allowed_client_types)
        super(THeaderProtocol, self).__init__(transport)
        self._set_protocol()

    # -- Header / transform helpers: thin pass-throughs to the transport. --

    def get_headers(self):
        """Return the result of the transport's ``get_headers()``."""
        return self.trans.get_headers()

    def set_header(self, key, value):
        """Set header ``key`` to ``value`` on the transport."""
        self.trans.set_header(key, value)

    def clear_headers(self):
        """Clear the headers held by the transport."""
        self.trans.clear_headers()

    def add_transform(self, transform_id):
        """Register payload transform ``transform_id`` on the transport."""
        self.trans.add_transform(transform_id)

    # -- Write side: delegate to the active subprotocol. --

    def writeMessageBegin(self, name, ttype, seqid):
        """Record ``seqid`` on the transport, then begin the message."""
        # Presumably the transport needs the sequence id to build the frame
        # header (see HeaderFormat.md) -- TODO confirm in THeaderTransport.
        self.trans.sequence_id = seqid
        return self._protocol.writeMessageBegin(name, ttype, seqid)

    def writeMessageEnd(self):
        return self._protocol.writeMessageEnd()

    def writeStructBegin(self, name):
        return self._protocol.writeStructBegin(name)

    def writeStructEnd(self):
        return self._protocol.writeStructEnd()

    def writeFieldBegin(self, name, ttype, fid):
        return self._protocol.writeFieldBegin(name, ttype, fid)

    def writeFieldEnd(self):
        return self._protocol.writeFieldEnd()

    def writeFieldStop(self):
        return self._protocol.writeFieldStop()

    def writeMapBegin(self, ktype, vtype, size):
        return self._protocol.writeMapBegin(ktype, vtype, size)

    def writeMapEnd(self):
        return self._protocol.writeMapEnd()

    def writeListBegin(self, etype, size):
        return self._protocol.writeListBegin(etype, size)

    def writeListEnd(self):
        return self._protocol.writeListEnd()

    def writeSetBegin(self, etype, size):
        return self._protocol.writeSetBegin(etype, size)

    def writeSetEnd(self):
        return self._protocol.writeSetEnd()

    def writeBool(self, bool_val):
        return self._protocol.writeBool(bool_val)

    def writeByte(self, byte):
        return self._protocol.writeByte(byte)

    def writeI16(self, i16):
        return self._protocol.writeI16(i16)

    def writeI32(self, i32):
        return self._protocol.writeI32(i32)

    def writeI64(self, i64):
        return self._protocol.writeI64(i64)

    def writeDouble(self, dub):
        return self._protocol.writeDouble(dub)

    def writeBinary(self, str_val):
        return self._protocol.writeBinary(str_val)

    # -- Subprotocol selection. --

    def _set_protocol(self):
        """(Re)select the subprotocol matching ``self.trans.protocol_id``.

        Also copies the subprotocol's ``_fast_encode``/``_fast_decode`` onto
        this object. Presumably this is so accelerated encode/decode paths
        that inspect the outer protocol use the active subprotocol's codecs
        -- verify against callers.

        Raises TApplicationException (type INVALID_PROTOCOL) when the id has
        no entry in PROTOCOLS_BY_ID.
        """
        protocol_cls = PROTOCOLS_BY_ID.get(self.trans.protocol_id)
        if protocol_cls is None:
            # The first argument is a TApplicationException *type code*. The
            # previous code passed TProtocolException.INVALID_PROTOCOL, which
            # belongs to a different code namespace, so the peer received a
            # mismatched type. Use the TApplicationException constant instead.
            raise TApplicationException(
                TApplicationException.INVALID_PROTOCOL,
                "Unknown protocol requested.",
            )

        self._protocol = protocol_cls(self.trans)
        self._fast_encode = self._protocol._fast_encode
        self._fast_decode = self._protocol._fast_decode

    # -- Read side. --

    def readMessageBegin(self):
        """Read the next frame, pick its subprotocol, and begin the message.

        If reading the frame or selecting the subprotocol raises
        TApplicationException, the exception is written back to the peer as
        an EXCEPTION message (empty name, seqid 0). That reply uses whichever
        subprotocol was active before the failure, and the transport is then
        flushed.

        NOTE(review): the exception is not re-raised. Execution falls through
        to ``readMessageBegin()`` on the subprotocol, against whatever state
        the transport is left in -- confirm this is intended. Also note the
        error reply bypasses ``self.writeMessageBegin``, so
        ``trans.sequence_id`` is not updated to 0 for it.
        """
        try:
            self.trans.readFrame(0)
            self._set_protocol()
        except TApplicationException as exc:
            self._protocol.writeMessageBegin(b"", TMessageType.EXCEPTION, 0)
            exc.write(self._protocol)
            self._protocol.writeMessageEnd()
            self.trans.flush()

        return self._protocol.readMessageBegin()

    def readMessageEnd(self):
        return self._protocol.readMessageEnd()

    def readStructBegin(self):
        return self._protocol.readStructBegin()

    def readStructEnd(self):
        return self._protocol.readStructEnd()

    def readFieldBegin(self):
        return self._protocol.readFieldBegin()

    def readFieldEnd(self):
        return self._protocol.readFieldEnd()

    def readMapBegin(self):
        return self._protocol.readMapBegin()

    def readMapEnd(self):
        return self._protocol.readMapEnd()

    def readListBegin(self):
        return self._protocol.readListBegin()

    def readListEnd(self):
        return self._protocol.readListEnd()

    def readSetBegin(self):
        return self._protocol.readSetBegin()

    def readSetEnd(self):
        return self._protocol.readSetEnd()

    def readBool(self):
        return self._protocol.readBool()

    def readByte(self):
        return self._protocol.readByte()

    def readI16(self):
        return self._protocol.readI16()

    def readI32(self):
        return self._protocol.readI32()

    def readI64(self):
        return self._protocol.readI64()

    def readDouble(self):
        return self._protocol.readDouble()

    def readBinary(self):
        return self._protocol.readBinary()
218
219
class THeaderProtocolFactory(object):
    """Create a THeaderProtocol for each transport handed to ``getProtocol``.

    ``allowed_client_types`` is stored as given and forwarded unchanged to
    every THeaderProtocol this factory builds. It defaults to accepting only
    header-framed clients.
    """

    def __init__(self, allowed_client_types=(THeaderClientType.HEADERS,)):
        self.allowed_client_types = allowed_client_types

    def getProtocol(self, trans):
        """Return a new THeaderProtocol wrapping ``trans``."""
        client_types = self.allowed_client_types
        protocol = THeaderProtocol(trans, client_types)
        return protocol
225 return THeaderProtocol(trans, self.allowed_client_types)