blob: 0e03f08341d556f6aa3431cea142b48ced42aa74 [file] [log] [blame]
David Reissea2cba82009-03-30 21:35:00 +00001#
2# Licensed to the Apache Software Foundation (ASF) under one
3# or more contributor license agreements. See the NOTICE file
4# distributed with this work for additional information
5# regarding copyright ownership. The ASF licenses this file
6# to you under the Apache License, Version 2.0 (the
7# "License"); you may not use this file except in compliance
8# with the License. You may obtain a copy of the License at
9#
10# http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing,
13# software distributed under the License is distributed on an
14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15# KIND, either express or implied. See the License for the
16# specific language governing permissions and limitations
17# under the License.
18#
Bryan Duxbury69720412012-01-03 17:32:30 +000019
jfarrell8b3ca022014-02-21 12:11:14 -050020import struct
Bryan Duxbury69720412012-01-03 17:32:30 +000021from cStringIO import StringIO
22
Kevin Clarke43f7e02009-03-03 22:03:57 +000023from zope.interface import implements, Interface, Attribute
jfarrell8b3ca022014-02-21 12:11:14 -050024from twisted.internet.protocol import ServerFactory, ClientFactory, \
Kevin Clarke43f7e02009-03-03 22:03:57 +000025 connectionDone
Kevin Clarke8d3c472009-03-03 22:13:46 +000026from twisted.internet import defer
jfarrell8b3ca022014-02-21 12:11:14 -050027from twisted.internet.threads import deferToThread
Kevin Clarke43f7e02009-03-03 22:03:57 +000028from twisted.protocols import basic
Bryan Duxburycb6d9702010-04-29 18:14:54 +000029from twisted.web import server, resource, http
David Reissea2cba82009-03-30 21:35:00 +000030
Kevin Clarke43f7e02009-03-03 22:03:57 +000031from thrift.transport import TTransport
Kevin Clarke43f7e02009-03-03 22:03:57 +000032
33
class TMessageSenderTransport(TTransport.TTransportBase):
    """Write-only transport that buffers output and sends it on flush.

    Data passed to write() accumulates in an in-memory buffer.  flush()
    hands everything written so far to sendMessage(), which subclasses
    must implement.
    """

    def __init__(self):
        self.__wbuf = StringIO()

    def write(self, buf):
        """Append ``buf`` to the pending message."""
        self.__wbuf.write(buf)

    def flush(self):
        """Send the buffered message and start over with an empty buffer.

        Returns the Deferred built by ``defer.maybeDeferred`` around
        sendMessage().
        """
        # Swap in a fresh buffer first so the old one can be drained.
        pending, self.__wbuf = self.__wbuf, StringIO()
        return defer.maybeDeferred(self.sendMessage, pending.getvalue())

    def sendMessage(self, message):
        """Deliver one complete message; subclasses must override this."""
        raise NotImplementedError
49
50
class TCallbackTransport(TMessageSenderTransport):
    """Message-sender transport that forwards each message to a callable."""

    def __init__(self, func):
        """Remember ``func``; it is called with every flushed message."""
        self.func = func
        TMessageSenderTransport.__init__(self)

    def sendMessage(self, message):
        """Pass ``message`` to the stored callable and return its result."""
        deliver = self.func
        return deliver(message)
Kevin Clarke43f7e02009-03-03 22:03:57 +000059
60
class ThriftClientProtocol(basic.Int32StringReceiver):
    """Twisted protocol that drives a Thrift client over length-prefixed
    frames.

    Once the connection is made, a client object is built from
    ``client_class`` and the ``started`` Deferred fires with it.  Each
    received frame is decoded with the input protocol factory and handed to
    the matching ``recv_<method>`` handler on the client.
    """

    MAX_LENGTH = 2 ** 31 - 1

    def __init__(self, client_class, iprot_factory, oprot_factory=None):
        """
        client_class: callable invoked as client_class(transport,
            oprot_factory) when the connection is made
        iprot_factory: protocol factory used to decode received frames
        oprot_factory: protocol factory handed to the client; defaults to
            iprot_factory when None
        """
        self._client_class = client_class
        self._iprot_factory = iprot_factory
        if oprot_factory is None:
            self._oprot_factory = iprot_factory
        else:
            self._oprot_factory = oprot_factory

        # Cache of frame method name -> bound recv_<name> handler.
        self.recv_map = {}
        # Fires with the client object once connectionMade() has run.
        self.started = defer.Deferred()

    def dispatch(self, msg):
        """Send one serialized message as a length-prefixed frame."""
        self.sendString(msg)

    def connectionMade(self):
        """Create the client and fire ``started`` with it."""
        tmo = TCallbackTransport(self.dispatch)
        self.client = self._client_class(tmo, self._oprot_factory)
        self.started.callback(self.client)

    def connectionLost(self, reason=connectionDone):
        """Fail every request still pending on the client.

        Each pending Deferred in ``client._reqs`` is removed and errbacked
        with an END_OF_FILE TTransportException.
        """
        client = getattr(self, 'client', None)
        if client is None:
            # No client was ever created, so there is nothing to fail.
            return

        # An errback may issue a new request and thereby add to _reqs.
        # Pop entries one at a time instead of iterating the live dict, and
        # keep going until nothing is left.
        while client._reqs:
            _, pending = client._reqs.popitem()
            tex = TTransport.TTransportException(
                type=TTransport.TTransportException.END_OF_FILE,
                message='Connection closed')
            pending.errback(tex)

    def stringReceived(self, frame):
        """Decode one frame and pass it to the client's recv_ handler."""
        tr = TTransport.TMemoryBuffer(frame)
        iprot = self._iprot_factory.getProtocol(tr)
        (fname, mtype, rseqid) = iprot.readMessageBegin()

        try:
            method = self.recv_map[fname]
        except KeyError:
            method = getattr(self.client, 'recv_' + fname)
            self.recv_map[fname] = method

        method(iprot, mtype, rseqid)
103
104
class ThriftSASLClientProtocol(ThriftClientProtocol):
    """Thrift client protocol that negotiates SASL before regular traffic.

    While negotiating, each message on the wire is a one-byte status, a
    four-byte length and a body.  After negotiation, outgoing messages are
    wrapped by the SASL client and incoming frames are unwrapped before
    the normal Thrift handling runs.
    """

    # SASL negotiation status codes (first byte of a negotiation message).
    START = 1
    OK = 2
    BAD = 3
    ERROR = 4
    COMPLETE = 5

    MAX_LENGTH = 2 ** 31 - 1

    def __init__(self, client_class, iprot_factory, oprot_factory=None,
                 host=None, service=None, mechanism='GSSAPI', **sasl_kwargs):
        """
        host: the name of the server, from a SASL perspective
        service: the name of the server's service, from a SASL perspective
        mechanism: the name of the preferred mechanism to use

        All other kwargs will be passed to the puresasl.client.SASLClient
        constructor.

        The SASL client is only created here when ``host`` is given;
        otherwise createSASLClient() has to be called before connecting.
        """

        from puresasl.client import SASLClient
        # createSASLClient() reads self.SASLClient.  Previously only the
        # misspelled "SASLCLient" was assigned, so passing a host raised
        # AttributeError.  The misspelled name is kept as an alias.
        self.SASLClient = SASLClient
        self.SASLCLient = SASLClient

        ThriftClientProtocol.__init__(self, client_class, iprot_factory,
                                      oprot_factory)

        # Deferred awaiting the next negotiation message; None when idle.
        self._sasl_negotiation_deferred = None
        # Status byte of the negotiation message being assembled, or None
        # if that byte has not been read yet.
        self._sasl_negotiation_status = None
        self.client = None

        if host is not None:
            self.createSASLClient(host, service, mechanism, **sasl_kwargs)

    def createSASLClient(self, host, service, mechanism, **kwargs):
        """Build the SASL client used for negotiation and wrapping."""
        self.sasl = self.SASLClient(host, service, mechanism, **kwargs)

    def dispatch(self, msg):
        """SASL-wrap ``msg``, add its own length prefix, and send it."""
        encoded = self.sasl.wrap(msg)
        len_and_encoded = ''.join((struct.pack('!i', len(encoded)), encoded))
        ThriftClientProtocol.dispatch(self, len_and_encoded)

    @defer.inlineCallbacks
    def connectionMade(self):
        """Run the SASL handshake, then start the regular Thrift client.

        A TTransportException (type NOT_OPEN) is raised into the returned
        Deferred when the server reports COMPLETE before the local SASL
        client is complete, or sends any status other than OK or COMPLETE.
        """
        self._sendSASLMessage(self.START, self.sasl.mechanism)
        # sasl.process() is run in a thread for every negotiation step.
        initial_message = yield deferToThread(self.sasl.process)
        self._sendSASLMessage(self.OK, initial_message)

        while True:
            status, challenge = yield self._receiveSASLMessage()
            if status == self.OK:
                response = yield deferToThread(self.sasl.process, challenge)
                self._sendSASLMessage(self.OK, response)
            elif status == self.COMPLETE:
                if not self.sasl.complete:
                    msg = "The server erroneously indicated that SASL " \
                          "negotiation was complete"
                    # The first positional argument is the numeric type
                    # code, so the text must go in ``message`` only.
                    raise TTransport.TTransportException(
                        type=TTransport.TTransportException.NOT_OPEN,
                        message=msg)
                break
            else:
                msg = "Bad SASL negotiation status: %d (%s)" % (
                    status, challenge)
                raise TTransport.TTransportException(
                    type=TTransport.TTransportException.NOT_OPEN,
                    message=msg)

        self._sasl_negotiation_deferred = None
        ThriftClientProtocol.connectionMade(self)

    def _sendSASLMessage(self, status, body):
        """Write one negotiation message: status byte, length, body."""
        if body is None:
            body = ""
        header = struct.pack(">BI", status, len(body))
        self.transport.write(header + body)

    def _receiveSASLMessage(self):
        """Return a Deferred firing with the next (status, body) pair."""
        self._sasl_negotiation_deferred = defer.Deferred()
        self._sasl_negotiation_status = None
        return self._sasl_negotiation_deferred

    def connectionLost(self, reason=connectionDone):
        """Fail pending requests, if the Thrift client was ever created."""
        if self.client:
            ThriftClientProtocol.connectionLost(self, reason)

    def dataReceived(self, data):
        """Strip the status byte of a negotiation message, then let the
        length-prefix receiver assemble the rest."""
        if (self._sasl_negotiation_deferred
                and self._sasl_negotiation_status is None):
            # we got a sasl challenge in the format (status, length,
            # challenge).  Only the first chunk of a message starts with
            # the status byte; later chunks of the same message must reach
            # the receiver untouched or the frame is corrupted.
            if not data:
                return
            self._sasl_negotiation_status, = struct.unpack("B", data[:1])
            data = data[1:]
        # let IntNStringReceiver piece the frame together
        ThriftClientProtocol.dataReceived(self, data)

    def stringReceived(self, frame):
        """Handle a complete frame: a SASL challenge or a wrapped message."""
        if self._sasl_negotiation_deferred:
            # the frame is just a SASL challenge
            response = (self._sasl_negotiation_status, frame)
            # This message is complete, so the next bytes received start a
            # new message with its own status byte.
            self._sasl_negotiation_status = None
            self._sasl_negotiation_deferred.callback(response)
        else:
            # there's a second 4 byte length prefix inside the frame
            decoded_frame = self.sasl.unwrap(frame[4:])
            ThriftClientProtocol.stringReceived(self, decoded_frame)
205
206
class ThriftServerProtocol(basic.Int32StringReceiver):
    """Twisted protocol that serves Thrift requests from length-prefixed
    frames, using the processor and protocol factories of its factory."""

    MAX_LENGTH = 2 ** 31 - 1

    def dispatch(self, msg):
        """Send one serialized reply as a length-prefixed frame."""
        self.sendString(msg)

    def processError(self, error):
        """Drop the connection when the processor fails."""
        self.transport.loseConnection()

    def processOk(self, _, tmo):
        """Send whatever the processor wrote to ``tmo``, if anything."""
        reply = tmo.getvalue()
        if not reply:
            # Nothing was written to the output buffer; send nothing.
            return
        self.dispatch(reply)

    def stringReceived(self, frame):
        """Run the processor on one request frame."""
        factory = self.factory
        in_buf = TTransport.TMemoryBuffer(frame)
        out_buf = TTransport.TMemoryBuffer()

        in_proto = factory.iprot_factory.getProtocol(in_buf)
        out_proto = factory.oprot_factory.getProtocol(out_buf)

        result = factory.processor.process(in_proto, out_proto)
        result.addCallbacks(self.processOk, self.processError,
                            callbackArgs=(out_buf,))
233
234
class IThriftServerFactory(Interface):
    """Attributes a Thrift server factory exposes to its protocols."""

    processor = Attribute("Thrift processor")
    iprot_factory = Attribute("Input protocol factory")
    oprot_factory = Attribute("Output protocol factory")
242
243
class IThriftClientFactory(Interface):
    """Attributes a Thrift client factory exposes."""

    client_class = Attribute("Thrift client class")
    iprot_factory = Attribute("Input protocol factory")
    oprot_factory = Attribute("Output protocol factory")
251
252
class ThriftServerFactory(ServerFactory):
    """Server factory holding the processor and protocol factories that
    ThriftServerProtocol instances read."""

    implements(IThriftServerFactory)

    protocol = ThriftServerProtocol

    def __init__(self, processor, iprot_factory, oprot_factory=None):
        """Store the processor and protocol factories.

        When ``oprot_factory`` is None the input protocol factory is used
        for output as well.
        """
        self.processor = processor
        self.iprot_factory = iprot_factory
        self.oprot_factory = (
            iprot_factory if oprot_factory is None else oprot_factory)
266
267
class ThriftClientFactory(ClientFactory):
    """Client factory that builds ThriftClientProtocol instances."""

    implements(IThriftClientFactory)

    protocol = ThriftClientProtocol

    def __init__(self, client_class, iprot_factory, oprot_factory=None):
        """Store the client class and protocol factories.

        When ``oprot_factory`` is None the input protocol factory is used
        for output as well.
        """
        self.client_class = client_class
        self.iprot_factory = iprot_factory
        self.oprot_factory = (
            iprot_factory if oprot_factory is None else oprot_factory)

    def buildProtocol(self, addr):
        """Create a protocol instance and attach this factory to it."""
        proto = self.protocol(
            self.client_class, self.iprot_factory, self.oprot_factory)
        proto.factory = self
        return proto
Bryan Duxburycb6d9702010-04-29 18:14:54 +0000287
288
class ThriftResource(resource.Resource):
    """twisted.web resource that serves Thrift calls sent as POST bodies.

    The request body is fed to the processor through the input protocol;
    whatever the processor writes through the output protocol becomes the
    response body.
    """

    allowedMethods = ('POST',)

    def __init__(self, processor, inputProtocolFactory,
                 outputProtocolFactory=None):
        """
        processor: object whose process(iprot, oprot) returns a Deferred
        inputProtocolFactory: protocol factory used to decode request bodies
        outputProtocolFactory: protocol factory used to encode responses;
            defaults to inputProtocolFactory when None
        """
        resource.Resource.__init__(self)
        self.inputProtocolFactory = inputProtocolFactory
        if outputProtocolFactory is None:
            self.outputProtocolFactory = inputProtocolFactory
        else:
            self.outputProtocolFactory = outputProtocolFactory
        self.processor = processor

    def getChild(self, path, request):
        """Return this resource for every child path."""
        return self

    def _cbProcess(self, _, request, tmo):
        """Write the processor's output as the HTTP response and finish."""
        msg = tmo.getvalue()
        request.setResponseCode(http.OK)
        request.setHeader("content-type", "application/x-thrift")
        request.write(msg)
        request.finish()

    def _ebProcess(self, failure, request):
        """Log a processor failure and answer with a 500 response."""
        # Imported locally so the module's import block stays unchanged.
        from twisted.python import log
        log.err(failure, "Thrift processor failed while handling request")
        request.setResponseCode(http.INTERNAL_SERVER_ERROR)
        request.finish()

    def render_POST(self, request):
        """Process the POST body as a Thrift message.

        Returns server.NOT_DONE_YET; the response is completed once the
        processor's Deferred fires.
        """
        request.content.seek(0, 0)
        data = request.content.read()
        tmi = TTransport.TMemoryBuffer(data)
        tmo = TTransport.TMemoryBuffer()

        iprot = self.inputProtocolFactory.getProtocol(tmi)
        oprot = self.outputProtocolFactory.getProtocol(tmo)

        d = self.processor.process(iprot, oprot)
        # Without an errback a failed processor Deferred would leave the
        # request unfinished and the client waiting.  addCallbacks keeps
        # the errback limited to processor failures; errors raised while
        # writing the response in _cbProcess do not reach it.
        d.addCallbacks(self._cbProcess, self._ebProcess,
                       callbackArgs=(request, tmo),
                       errbackArgs=(request,))
        return server.NOT_DONE_YET