blob: cabe345ed78c892bf6000eafb9eefc7e9e018937 [file] [log] [blame]
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
Bryan Duxbury69720412012-01-03 17:32:30 +000019
from io import BytesIO
import struct

from zope.interface import implements, implementer, Interface, Attribute
from twisted.internet.protocol import ServerFactory, ClientFactory, \
    connectionDone
from twisted.internet import defer
from twisted.internet.threads import deferToThread
from twisted.protocols import basic
from twisted.python import log
from twisted.web import server, resource, http

from thrift.transport import TTransport
import six
Kevin Clarke43f7e02009-03-03 22:03:57 +000033
34
class TMessageSenderTransport(TTransport.TTransportBase):
    """Transport that accumulates writes and emits them as one message.

    Everything written since the previous ``flush`` is handed, as a single
    bytes object, to ``sendMessage``; subclasses must implement that hook.
    """

    def __init__(self):
        self.__wbuf = BytesIO()

    def write(self, buf):
        """Append ``buf`` to the message currently being assembled."""
        self.__wbuf.write(buf)

    def flush(self):
        """Emit the assembled message and start a fresh buffer.

        Returns whatever ``sendMessage`` returns.
        """
        # RHS is evaluated first: grab the bytes, then swap in a new buffer.
        pending, self.__wbuf = self.__wbuf.getvalue(), BytesIO()
        return self.sendMessage(pending)

    def sendMessage(self, message):
        """Deliver one complete message; must be overridden."""
        raise NotImplementedError
50
51
class TCallbackTransport(TMessageSenderTransport):
    """Message transport that forwards each flushed message to a callable.

    ``func`` is called with the bytes of every flushed message, and its
    return value is passed back out of ``flush``.
    """

    def __init__(self, func):
        TMessageSenderTransport.__init__(self)
        # Callable invoked once per flushed message.
        self.func = func

    def sendMessage(self, message):
        deliver = self.func
        return deliver(message)
Kevin Clarke43f7e02009-03-03 22:03:57 +000060
61
class ThriftClientProtocol(basic.Int32StringReceiver):
    """Client side of a framed Thrift connection over Twisted.

    Each Thrift message travels as one Int32StringReceiver frame (a 4-byte
    length prefix followed by the serialized message).  Once the connection
    is made, ``started`` fires with the generated client instance.
    """

    MAX_LENGTH = 2 ** 31 - 1

    def __init__(self, client_class, iprot_factory, oprot_factory=None):
        """
        client_class: generated Thrift client class, constructed as
            ``client_class(transport, oprot_factory)``
        iprot_factory: protocol factory used to decode replies
        oprot_factory: protocol factory used to encode calls; defaults to
            ``iprot_factory``
        """
        self._client_class = client_class
        self._iprot_factory = iprot_factory
        if oprot_factory is None:
            self._oprot_factory = iprot_factory
        else:
            self._oprot_factory = oprot_factory

        # Cache of message name -> bound ``recv_<name>`` method of the client.
        self.recv_map = {}
        self.started = defer.Deferred()

    def dispatch(self, msg):
        """Send one serialized Thrift message as a single frame."""
        self.sendString(msg)

    def connectionMade(self):
        """Build the client on top of this connection and fire ``started``."""
        tmo = TCallbackTransport(self.dispatch)
        self.client = self._client_class(tmo, self._oprot_factory)
        self.started.callback(self.client)

    def connectionLost(self, reason=connectionDone):
        """Fail every request that is still waiting for a reply."""
        client = getattr(self, 'client', None)
        if client is None:
            return
        reqs = client._reqs
        # Snapshot and clear before firing: errback handlers run
        # synchronously and may issue new requests, which would otherwise
        # mutate ``_reqs`` while it is being iterated.
        pending = list(six.itervalues(reqs))
        reqs.clear()
        for d in pending:
            tex = TTransport.TTransportException(
                type=TTransport.TTransportException.END_OF_FILE,
                message='Connection closed')
            d.errback(tex)

    def stringReceived(self, frame):
        """Decode one reply frame and route it to ``client.recv_<name>``."""
        tr = TTransport.TMemoryBuffer(frame)
        iprot = self._iprot_factory.getProtocol(tr)
        (fname, mtype, rseqid) = iprot.readMessageBegin()

        try:
            method = self.recv_map[fname]
        except KeyError:
            # First reply for this method name: resolve and cache the handler.
            method = getattr(self.client, 'recv_' + fname)
            self.recv_map[fname] = method

        method(iprot, mtype, rseqid)
104
105
class ThriftSASLClientProtocol(ThriftClientProtocol):
    """Thrift client protocol that performs a SASL handshake before any RPC.

    Negotiation messages are framed as a 1-byte status, a 4-byte big-endian
    length and the payload.  After negotiation completes, each Thrift
    message is wrapped by the SASL layer and sent with an extra inner 4-byte
    length prefix inside the normal Int32 frame.
    """

    # SASL negotiation status codes.
    START = 1
    OK = 2
    BAD = 3
    ERROR = 4
    COMPLETE = 5

    MAX_LENGTH = 2 ** 31 - 1

    def __init__(self, client_class, iprot_factory, oprot_factory=None,
                 host=None, service=None, mechanism='GSSAPI', **sasl_kwargs):
        """
        host: the name of the server, from a SASL perspective
        service: the name of the server's service, from a SASL perspective
        mechanism: the name of the preferred mechanism to use

        All other kwargs will be passed to the puresasl.client.SASLClient
        constructor.
        """

        # Imported lazily so puresasl is only needed when SASL is used.
        from puresasl.client import SASLClient
        self.SASLClient = SASLClient
        # Historical misspelled alias, kept for backward compatibility.
        self.SASLCLient = SASLClient

        ThriftClientProtocol.__init__(self, client_class, iprot_factory,
                                      oprot_factory)

        self._sasl_negotiation_deferred = None
        self._sasl_negotiation_status = None
        self.client = None

        if host is not None:
            self.createSASLClient(host, service, mechanism, **sasl_kwargs)

    def createSASLClient(self, host, service, mechanism, **kwargs):
        """Create the SASL client used for negotiation and wrap/unwrap."""
        self.sasl = self.SASLClient(host, service, mechanism, **kwargs)

    def dispatch(self, msg):
        """Wrap ``msg`` with the SASL security layer and send it."""
        encoded = self.sasl.wrap(msg)
        # Inner length prefix; sendString adds the outer Int32 frame.
        len_and_encoded = b''.join((struct.pack('!i', len(encoded)), encoded))
        ThriftClientProtocol.dispatch(self, len_and_encoded)

    @defer.inlineCallbacks
    def connectionMade(self):
        """Run SASL negotiation, then set up the Thrift client.

        If negotiation fails, ``started`` errbacks and the connection is
        dropped instead of leaving ``started`` pending forever.
        """
        try:
            yield self._negotiateSASL()
        except Exception:
            self.started.errback()
            self.transport.loseConnection()
        else:
            ThriftClientProtocol.connectionMade(self)

    @defer.inlineCallbacks
    def _negotiateSASL(self):
        """Exchange SASL messages with the server until it reports COMPLETE.

        Raises TTransportException on an unexpected status or a premature
        COMPLETE.
        """
        self._sendSASLMessage(self.START, self.sasl.mechanism)
        # SASL processing may block (e.g. GSSAPI), so run it off the reactor.
        initial_message = yield deferToThread(self.sasl.process)
        self._sendSASLMessage(self.OK, initial_message)

        while True:
            status, challenge = yield self._receiveSASLMessage()
            if status == self.OK:
                response = yield deferToThread(self.sasl.process, challenge)
                self._sendSASLMessage(self.OK, response)
            elif status == self.COMPLETE:
                if not self.sasl.complete:
                    msg = "The server erroneously indicated that SASL " \
                          "negotiation was complete"
                    raise TTransport.TTransportException(message=msg)
                break
            else:
                msg = "Bad SASL negotiation status: %d (%s)" % (status, challenge)
                raise TTransport.TTransportException(message=msg)

        self._sasl_negotiation_deferred = None

    def _sendSASLMessage(self, status, body):
        """Write one negotiation message: status byte, length, payload."""
        if body is None:
            body = b""
        elif isinstance(body, six.text_type):
            # e.g. the mechanism name; the wire format requires bytes.
            body = body.encode('utf-8')
        header = struct.pack(">BI", status, len(body))
        self.transport.write(header + body)

    def _receiveSASLMessage(self):
        """Return a Deferred fired with ``(status, payload)`` of the next
        negotiation message."""
        self._sasl_negotiation_deferred = defer.Deferred()
        self._sasl_negotiation_status = None
        return self._sasl_negotiation_deferred

    def connectionLost(self, reason=connectionDone):
        """Abort a pending negotiation step and fail outstanding requests."""
        pending = self._sasl_negotiation_deferred
        if pending is not None and not pending.called:
            self._sasl_negotiation_deferred = None
            pending.errback(TTransport.TTransportException(
                type=TTransport.TTransportException.END_OF_FILE,
                message='Connection closed during SASL negotiation'))
        if self.client is not None:
            ThriftClientProtocol.connectionLost(self, reason)

    def dataReceived(self, data):
        if (self._sasl_negotiation_deferred is not None and
                self._sasl_negotiation_status is None and data):
            # First byte of a negotiation message is its status; the rest
            # (length + challenge) is reassembled by IntNStringReceiver,
            # even if it arrives split across several chunks.  Slicing
            # keeps a bytes object on Python 3.
            self._sasl_negotiation_status, = struct.unpack("B", data[:1])
            data = data[1:]
        ThriftClientProtocol.dataReceived(self, data)

    def stringReceived(self, frame):
        if self._sasl_negotiation_deferred is not None:
            # The frame is just a SASL challenge.
            response = (self._sasl_negotiation_status, frame)
            self._sasl_negotiation_deferred.callback(response)
        else:
            # There's a second 4-byte length prefix inside the frame.
            decoded_frame = self.sasl.unwrap(frame[4:])
            ThriftClientProtocol.stringReceived(self, decoded_frame)
206
207
class ThriftServerProtocol(basic.Int32StringReceiver):
    """Server side of a framed Thrift connection.

    Each received frame is decoded with the factory's input protocol and
    passed to the factory's processor; a non-empty reply is written back as
    a single frame, and a processor failure drops the connection.
    """

    MAX_LENGTH = 2 ** 31 - 1

    def dispatch(self, msg):
        """Send ``msg`` to the peer as one length-prefixed frame."""
        self.sendString(msg)

    def processError(self, error):
        """Close the connection after the processor reports a failure."""
        self.transport.loseConnection()

    def processOk(self, _, tmo):
        """Send the reply buffered in ``tmo``; empty replies are not sent."""
        reply = tmo.getvalue()
        if reply:
            self.dispatch(reply)

    def stringReceived(self, frame):
        factory = self.factory
        reply_buffer = TTransport.TMemoryBuffer()
        iprot = factory.iprot_factory.getProtocol(
            TTransport.TMemoryBuffer(frame))
        oprot = factory.oprot_factory.getProtocol(reply_buffer)
        factory.processor.process(iprot, oprot).addCallbacks(
            self.processOk, self.processError, callbackArgs=(reply_buffer,))
234
235
class IThriftServerFactory(Interface):
    """Attributes a Thrift server factory exposes to its protocols."""

    # Object whose process(iprot, oprot) handles one request.
    processor = Attribute("Thrift processor")
    # Factory for the protocol that decodes requests.
    iprot_factory = Attribute("Input protocol factory")
    # Factory for the protocol that encodes replies.
    oprot_factory = Attribute("Output protocol factory")
243
244
class IThriftClientFactory(Interface):
    """Attributes a Thrift client factory exposes to its protocols."""

    # Generated client class instantiated once the connection is made.
    client_class = Attribute("Thrift client class")
    # Factory for the protocol that decodes replies.
    iprot_factory = Attribute("Input protocol factory")
    # Factory for the protocol that encodes calls.
    oprot_factory = Attribute("Output protocol factory")
252
253
@implementer(IThriftServerFactory)
class ThriftServerFactory(ServerFactory):
    """Factory for ThriftServerProtocol connections sharing one processor.

    Uses the ``@implementer`` decorator rather than ``implements()`` class
    advice, which zope.interface rejects on Python 3.
    """

    protocol = ThriftServerProtocol

    def __init__(self, processor, iprot_factory, oprot_factory=None):
        """
        processor: Thrift processor; its ``process(iprot, oprot)`` result
            must support ``addCallbacks`` (a Deferred)
        iprot_factory: protocol factory used to decode requests
        oprot_factory: protocol factory used to encode replies; defaults to
            ``iprot_factory``
        """
        self.processor = processor
        self.iprot_factory = iprot_factory
        if oprot_factory is None:
            self.oprot_factory = iprot_factory
        else:
            self.oprot_factory = oprot_factory
267
268
@implementer(IThriftClientFactory)
class ThriftClientFactory(ClientFactory):
    """Factory building ThriftClientProtocol instances for one client class.

    Uses the ``@implementer`` decorator rather than ``implements()`` class
    advice, which zope.interface rejects on Python 3.
    """

    protocol = ThriftClientProtocol

    def __init__(self, client_class, iprot_factory, oprot_factory=None):
        """
        client_class: generated Thrift client class
        iprot_factory: protocol factory used to decode replies
        oprot_factory: protocol factory used to encode calls; defaults to
            ``iprot_factory``
        """
        self.client_class = client_class
        self.iprot_factory = iprot_factory
        if oprot_factory is None:
            self.oprot_factory = iprot_factory
        else:
            self.oprot_factory = oprot_factory

    def buildProtocol(self, addr):
        """Create a protocol for a new connection and link it to this factory."""
        p = self.protocol(self.client_class, self.iprot_factory,
                          self.oprot_factory)
        p.factory = self
        return p
Bryan Duxburycb6d9702010-04-29 18:14:54 +0000288
289
class ThriftResource(resource.Resource):
    """twisted.web resource serving Thrift calls over HTTP POST.

    The request body is decoded with ``inputProtocolFactory`` and handed to
    ``processor``; the reply is returned as ``application/x-thrift``.  Every
    child path resolves to this same resource.
    """

    allowedMethods = ('POST',)

    def __init__(self, processor, inputProtocolFactory,
                 outputProtocolFactory=None):
        resource.Resource.__init__(self)
        self.inputProtocolFactory = inputProtocolFactory
        if outputProtocolFactory is None:
            self.outputProtocolFactory = inputProtocolFactory
        else:
            self.outputProtocolFactory = outputProtocolFactory
        self.processor = processor

    def getChild(self, path, request):
        return self

    def _cbProcess(self, _, request, tmo):
        """Write the serialized reply held in ``tmo`` and finish ``request``."""
        msg = tmo.getvalue()
        request.setResponseCode(http.OK)
        request.setHeader("content-type", "application/x-thrift")
        request.write(msg)
        request.finish()

    def _ebProcess(self, reason, request):
        """Log a processor failure and answer HTTP 500.

        Without this the request would never be finished and the HTTP
        client would wait forever.
        """
        log.err(reason, "Thrift processor failed")
        request.setResponseCode(http.INTERNAL_SERVER_ERROR)
        request.finish()

    def render_POST(self, request):
        request.content.seek(0, 0)
        data = request.content.read()
        tmi = TTransport.TMemoryBuffer(data)
        tmo = TTransport.TMemoryBuffer()

        iprot = self.inputProtocolFactory.getProtocol(tmi)
        oprot = self.outputProtocolFactory.getProtocol(tmo)

        d = self.processor.process(iprot, oprot)
        # The errback covers processor failures only, not errors raised
        # while writing a successful reply in _cbProcess.
        d.addCallbacks(self._cbProcess, self._ebProcess,
                       callbackArgs=(request, tmo), errbackArgs=(request,))
        return server.NOT_DONE_YET