blob: a27f0adade2b732517d20ab1e28a2223f96224ed [file] [log] [blame]
David Reissea2cba82009-03-30 21:35:00 +00001#
2# Licensed to the Apache Software Foundation (ASF) under one
3# or more contributor license agreements. See the NOTICE file
4# distributed with this work for additional information
5# regarding copyright ownership. The ASF licenses this file
6# to you under the Apache License, Version 2.0 (the
7# "License"); you may not use this file except in compliance
8# with the License. You may obtain a copy of the License at
9#
10# http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing,
13# software distributed under the License is distributed on an
14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15# KIND, either express or implied. See the License for the
16# specific language governing permissions and limitations
17# under the License.
18#
Bryan Duxbury69720412012-01-03 17:32:30 +000019
Nobuaki Sukegawa760511f2015-11-06 21:24:16 +090020from io import BytesIO
jfarrell8b3ca022014-02-21 12:11:14 -050021import struct
Bryan Duxbury69720412012-01-03 17:32:30 +000022
ciarancourtneydb3a92e2016-12-20 11:12:15 +000023from zope.interface import implementer, Interface, Attribute
jfarrell8b3ca022014-02-21 12:11:14 -050024from twisted.internet.protocol import ServerFactory, ClientFactory, \
Kevin Clarke43f7e02009-03-03 22:03:57 +000025 connectionDone
Kevin Clarke8d3c472009-03-03 22:13:46 +000026from twisted.internet import defer
jfarrell8b3ca022014-02-21 12:11:14 -050027from twisted.internet.threads import deferToThread
Kevin Clarke43f7e02009-03-03 22:03:57 +000028from twisted.protocols import basic
Bryan Duxburycb6d9702010-04-29 18:14:54 +000029from twisted.web import server, resource, http
David Reissea2cba82009-03-30 21:35:00 +000030
Kevin Clarke43f7e02009-03-03 22:03:57 +000031from thrift.transport import TTransport
Kevin Clarke43f7e02009-03-03 22:03:57 +000032
33
class TMessageSenderTransport(TTransport.TTransportBase):
    """Write-buffering transport that emits one message per flush().

    Bytes given to write() accumulate in an in-memory buffer; flush()
    hands the accumulated bytes to sendMessage() and returns whatever it
    returns. Subclasses supply sendMessage().
    """

    def __init__(self):
        self.__wbuf = BytesIO()

    def write(self, buf):
        """Append ``buf`` to the message currently being built."""
        self.__wbuf.write(buf)

    def flush(self):
        """Send the buffered bytes as one message and start an empty buffer.

        The buffer is swapped out before sendMessage() runs, so anything
        written during the send lands in the next message.
        """
        finished, self.__wbuf = self.__wbuf, BytesIO()
        return self.sendMessage(finished.getvalue())

    def sendMessage(self, message):
        """Deliver one complete message; subclasses must override this."""
        raise NotImplementedError
49
50
class TCallbackTransport(TMessageSenderTransport):
    """Message transport that forwards each flushed message to a callable.

    ``func`` receives the message bytes, and its return value is
    propagated back out of flush().
    """

    def __init__(self, func):
        TMessageSenderTransport.__init__(self)
        self.func = func

    def sendMessage(self, message):
        """Pass ``message`` to the stored callable and return its result."""
        callback = self.func
        return callback(message)
Kevin Clarke43f7e02009-03-03 22:03:57 +000059
60
class ThriftClientProtocol(basic.Int32StringReceiver):
    """Client protocol that carries Thrift messages as int32-prefixed frames.

    When the connection is made, a ``client_class`` instance is built on a
    transport that sends every flushed message as one frame, and the
    ``started`` Deferred fires with that client. Each incoming frame is
    decoded and dispatched to the client's ``recv_<method>`` handler.
    """

    MAX_LENGTH = 2 ** 31 - 1

    def __init__(self, client_class, iprot_factory, oprot_factory=None):
        self._client_class = client_class
        self._iprot_factory = iprot_factory
        # Use the input protocol factory for output too unless one is given.
        if oprot_factory is None:
            self._oprot_factory = iprot_factory
        else:
            self._oprot_factory = oprot_factory

        # Cache of method name -> bound ``client.recv_<name>`` method.
        self.recv_map = {}
        self.started = defer.Deferred()
        # No client exists until connectionMade(). connectionLost() reads
        # this attribute, so it must exist even if connectionMade() never
        # completed (e.g. a subclass that delays it).
        self.client = None

    def dispatch(self, msg):
        """Send ``msg`` to the peer as one int32-length-prefixed frame."""
        self.sendString(msg)

    def connectionMade(self):
        """Build the Thrift client on a frame-sending transport; fire ``started``."""
        tmo = TCallbackTransport(self.dispatch)
        self.client = self._client_class(tmo, self._oprot_factory)
        self.started.callback(self.client)

    def connectionLost(self, reason=connectionDone):
        """Errback every outstanding request, then drop the client.

        Errbacks fired here may issue new requests, which adds entries to the
        client's ``_reqs`` mapping. Popping until it is empty also fails
        those late additions.
        """
        if self.client is not None:
            tex = TTransport.TTransportException(
                type=TTransport.TTransportException.END_OF_FILE,
                message='Connection closed (%s)' % reason)
            while self.client._reqs:
                _, v = self.client._reqs.popitem()
                v.errback(tex)
            del self.client._reqs
            self.client = None

    def stringReceived(self, frame):
        """Decode one reply frame and hand it to the matching ``recv_`` method."""
        tr = TTransport.TMemoryBuffer(frame)
        iprot = self._iprot_factory.getProtocol(tr)
        (fname, mtype, rseqid) = iprot.readMessageBegin()

        try:
            method = self.recv_map[fname]
        except KeyError:
            method = getattr(self.client, 'recv_' + fname)
            self.recv_map[fname] = method

        method(iprot, mtype, rseqid)
110
111
class ThriftSASLClientProtocol(ThriftClientProtocol):
    """ThriftClientProtocol that performs SASL negotiation before Thrift traffic.

    Negotiation messages consist of a 1-byte status, a 4-byte big-endian
    length and a payload. After negotiation completes, each outgoing
    Thrift message is wrapped with ``self.sasl.wrap`` and given an inner
    4-byte length prefix inside the normal int32-prefixed frame. Incoming
    frames are unwrapped the same way.
    """

    # Negotiation status codes (first byte of each negotiation message).
    START = 1
    OK = 2
    BAD = 3
    ERROR = 4
    COMPLETE = 5

    MAX_LENGTH = 2 ** 31 - 1

    def __init__(self, client_class, iprot_factory, oprot_factory=None,
                 host=None, service=None, mechanism='GSSAPI', **sasl_kwargs):
        """
        host: the name of the server, from a SASL perspective
        service: the name of the server's service, from a SASL perspective
        mechanism: the name of the preferred mechanism to use

        All other kwargs will be passed to the puresasl.client.SASLClient
        constructor.

        If host is None, no SASL client is created here. createSASLClient()
        must then be called before the connection is made.
        """

        from puresasl.client import SASLClient
        # createSASLClient() reads this attribute name.
        self.SASLClient = SASLClient
        # Misspelled alias kept for any code that read the old attribute.
        self.SASLCLient = SASLClient

        ThriftClientProtocol.__init__(self, client_class, iprot_factory, oprot_factory)

        # Deferred fired by stringReceived() with (status, payload) for the
        # negotiation message currently being awaited.
        self._sasl_negotiation_deferred = None
        # Status byte of the negotiation message being received, or None if
        # it has not arrived yet.
        self._sasl_negotiation_status = None
        self._connection_lost = False
        self.client = None

        if host is not None:
            self.createSASLClient(host, service, mechanism, **sasl_kwargs)

    def createSASLClient(self, host, service, mechanism, **kwargs):
        """Create the puresasl client used for negotiation and wrap/unwrap."""
        self.sasl = self.SASLClient(host, service, mechanism, **kwargs)

    def dispatch(self, msg):
        """SASL-wrap ``msg`` and send it with an inner int32 length prefix."""
        encoded = self.sasl.wrap(msg)
        # struct.pack returns bytes, so join with b'' (a str separator
        # fails on Python 3).
        len_and_encoded = b''.join((struct.pack('!i', len(encoded)), encoded))
        ThriftClientProtocol.dispatch(self, len_and_encoded)

    @defer.inlineCallbacks
    def connectionMade(self):
        """Negotiate SASL, then build the Thrift client and fire ``started``.

        Any failure errbacks ``started`` and closes the connection, so
        callers waiting on ``started`` are not left hanging. Failures include
        a bad status, a premature COMPLETE, a lost connection, or
        ``self.sasl`` never having been created.
        """
        try:
            yield self._negotiateSASL()
            ThriftClientProtocol.connectionMade(self)
        except Exception:
            self._sasl_negotiation_deferred = None
            if not self.started.called:
                # With no argument, errback() wraps the exception currently
                # being handled, keeping its traceback.
                self.started.errback()
            if not self._connection_lost:
                self.transport.loseConnection()

    @defer.inlineCallbacks
    def _negotiateSASL(self):
        """Exchange SASL messages until the server reports COMPLETE.

        Raises TTransportException if the server sends an unexpected status,
        or reports completion before the local SASL client is complete.
        """
        self._sendSASLMessage(self.START, self.sasl.mechanism)
        # sasl.process() runs in the thread pool, presumably because it may
        # block (e.g. GSSAPI); confirm against puresasl.
        initial_message = yield deferToThread(self.sasl.process)
        self._sendSASLMessage(self.OK, initial_message)

        while True:
            status, challenge = yield self._receiveSASLMessage()
            if status == self.OK:
                response = yield deferToThread(self.sasl.process, challenge)
                self._sendSASLMessage(self.OK, response)
            elif status == self.COMPLETE:
                if not self.sasl.complete:
                    msg = "The server erroneously indicated that SASL " \
                          "negotiation was complete"
                    # The first positional parameter of TTransportException
                    # is the type code, so pass the text only as ``message``.
                    raise TTransport.TTransportException(message=msg)
                break
            else:
                msg = "Bad SASL negotiation status: %d (%s)" % (status, challenge)
                raise TTransport.TTransportException(message=msg)

        self._sasl_negotiation_deferred = None

    def _sendSASLMessage(self, status, body):
        """Write one negotiation message: status byte, uint32 length, body.

        A ``body`` of None is sent as empty. Text is UTF-8 encoded, so the
        length counts bytes and the body concatenates with the bytes header.
        """
        if body is None:
            body = b""
        elif not isinstance(body, bytes):
            body = body.encode('utf-8')
        header = struct.pack(">BI", status, len(body))
        self.transport.write(header + body)

    def _receiveSASLMessage(self):
        """Return a Deferred fired with (status, payload) for the next message."""
        self._sasl_negotiation_deferred = defer.Deferred()
        self._sasl_negotiation_status = None
        return self._sasl_negotiation_deferred

    def connectionLost(self, reason=connectionDone):
        """Fail any pending negotiation step, then clean up the Thrift client."""
        self._connection_lost = True
        pending = self._sasl_negotiation_deferred
        if pending is not None and not pending.called:
            # Otherwise negotiation would wait forever for a message that
            # can no longer arrive.
            self._sasl_negotiation_deferred = None
            pending.errback(TTransport.TTransportException(
                type=TTransport.TTransportException.END_OF_FILE,
                message='Connection closed during SASL negotiation (%s)'
                        % reason))
        if self.client:
            ThriftClientProtocol.connectionLost(self, reason)

    def dataReceived(self, data):
        """Feed bytes to the frame parser, peeling off the SASL status byte.

        During negotiation each server message is a 1-byte status followed by
        an int32-length-prefixed challenge. The status byte is taken only from
        the first chunk of a message (while no status is recorded yet), so a
        message split across several reads is still reassembled correctly by
        IntNStringReceiver.
        """
        if (self._sasl_negotiation_deferred is not None
                and self._sasl_negotiation_status is None and data):
            # data[:1] rather than data[0]: indexing bytes yields an int on
            # Python 3, which struct.unpack does not accept.
            self._sasl_negotiation_status, = struct.unpack("B", data[:1])
            data = data[1:]
        ThriftClientProtocol.dataReceived(self, data)

    def stringReceived(self, frame):
        """Route a complete frame to negotiation or to the Thrift client."""
        if self._sasl_negotiation_deferred is not None:
            # the frame is just a SASL challenge
            response = (self._sasl_negotiation_status, frame)
            self._sasl_negotiation_deferred.callback(response)
        else:
            # there's a second 4 byte length prefix inside the frame
            decoded_frame = self.sasl.unwrap(frame[4:])
            ThriftClientProtocol.stringReceived(self, decoded_frame)
212
213
class ThriftServerProtocol(basic.Int32StringReceiver):
    """Server protocol: each int32-prefixed frame is one Thrift request.

    ``self.factory`` supplies ``processor``, ``iprot_factory`` and
    ``oprot_factory``. The reply, if any, is sent back as a single frame.
    """

    MAX_LENGTH = 2 ** 31 - 1

    def dispatch(self, msg):
        """Send ``msg`` to the peer as one length-prefixed frame."""
        self.sendString(msg)

    def processError(self, error):
        """Log a failed request and drop the connection.

        The failure is logged rather than discarded, so the cause of the
        disconnect is visible to operators.
        """
        from twisted.python import log
        log.err(error, "Thrift processor failed; closing connection")
        self.transport.loseConnection()

    def processOk(self, _, tmo):
        """Send the buffered reply, if the processor wrote one."""
        msg = tmo.getvalue()

        # Nothing is sent when the processor produced no output
        # (presumably a oneway call).
        if len(msg) > 0:
            self.dispatch(msg)

    def stringReceived(self, frame):
        """Run one request frame through the processor."""
        tmi = TTransport.TMemoryBuffer(frame)
        tmo = TTransport.TMemoryBuffer()

        iprot = self.factory.iprot_factory.getProtocol(tmi)
        oprot = self.factory.oprot_factory.getProtocol(tmo)

        d = self.factory.processor.process(iprot, oprot)
        d.addCallbacks(self.processOk, self.processError,
                       callbackArgs=(tmo,))
Kevin Clarke43f7e02009-03-03 22:03:57 +0000240
241
class IThriftServerFactory(Interface):
    """Attributes ThriftServerProtocol reads from its factory."""

    processor = Attribute("Thrift processor")
    iprot_factory = Attribute("Input protocol factory")
    oprot_factory = Attribute("Output protocol factory")
249
250
class IThriftClientFactory(Interface):
    """Attributes a factory provides for building Thrift client protocols."""

    client_class = Attribute("Thrift client class")
    iprot_factory = Attribute("Input protocol factory")
    oprot_factory = Attribute("Output protocol factory")
258
259
@implementer(IThriftServerFactory)
class ThriftServerFactory(ServerFactory):
    """Factory producing ThriftServerProtocol instances.

    Stores the processor and the protocol factories that each protocol
    reads through ``self.factory``. ``oprot_factory`` defaults to
    ``iprot_factory``.
    """

    protocol = ThriftServerProtocol

    def __init__(self, processor, iprot_factory, oprot_factory=None):
        self.processor = processor
        self.iprot_factory = iprot_factory
        self.oprot_factory = (iprot_factory if oprot_factory is None
                              else oprot_factory)
272
273
@implementer(IThriftClientFactory)
class ThriftClientFactory(ClientFactory):
    """Factory that builds ThriftClientProtocol instances.

    ``oprot_factory`` defaults to ``iprot_factory``.
    """

    protocol = ThriftClientProtocol

    def __init__(self, client_class, iprot_factory, oprot_factory=None):
        self.client_class = client_class
        self.iprot_factory = iprot_factory
        self.oprot_factory = (iprot_factory if oprot_factory is None
                              else oprot_factory)

    def buildProtocol(self, addr):
        """Create a protocol instance and link it back to this factory."""
        proto = self.protocol(self.client_class, self.iprot_factory,
                              self.oprot_factory)
        proto.factory = self
        return proto
Bryan Duxburycb6d9702010-04-29 18:14:54 +0000292
293
class ThriftResource(resource.Resource):
    """twisted.web resource that serves Thrift calls over HTTP POST.

    The request body is decoded with ``inputProtocolFactory`` and passed to
    ``processor``. The reply is written back as ``application/x-thrift``.
    The same resource answers for every child path.
    """

    allowedMethods = ('POST',)

    def __init__(self, processor, inputProtocolFactory,
                 outputProtocolFactory=None):
        resource.Resource.__init__(self)
        self.inputProtocolFactory = inputProtocolFactory
        # Use the input protocol factory for output too unless one is given.
        if outputProtocolFactory is None:
            self.outputProtocolFactory = inputProtocolFactory
        else:
            self.outputProtocolFactory = outputProtocolFactory
        self.processor = processor

    def getChild(self, path, request):
        # Serve every sub-path with this same resource.
        return self

    def _cbProcess(self, _, request, tmo):
        """Write the processor's buffered reply as a 200 response."""
        msg = tmo.getvalue()
        request.setResponseCode(http.OK)
        request.setHeader("content-type", "application/x-thrift")
        request.write(msg)
        request.finish()

    def _ebProcess(self, failure, request):
        """Log a processor failure and answer 500.

        Without this errback the request was never finished and the HTTP
        client would wait indefinitely.
        """
        from twisted.python import log
        log.err(failure, "Thrift processor failed")
        request.setResponseCode(http.INTERNAL_SERVER_ERROR)
        request.finish()

    def render_POST(self, request):
        """Process the POSTed Thrift request asynchronously."""
        request.content.seek(0, 0)
        data = request.content.read()
        tmi = TTransport.TMemoryBuffer(data)
        tmo = TTransport.TMemoryBuffer()

        iprot = self.inputProtocolFactory.getProtocol(tmi)
        oprot = self.outputProtocolFactory.getProtocol(tmo)

        d = self.processor.process(iprot, oprot)
        # addCallbacks (not addCallback + addErrback) so _ebProcess handles
        # only processor failures, not errors raised while writing the reply.
        d.addCallbacks(self._cbProcess, self._ebProcess,
                       callbackArgs=(request, tmo),
                       errbackArgs=(request,))
        return server.NOT_DONE_YET