blob: 3ce3eb2202ba6a2645694eed1b712e6437282bce [file] [log] [blame]
David Reissea2cba82009-03-30 21:35:00 +00001#
2# Licensed to the Apache Software Foundation (ASF) under one
3# or more contributor license agreements. See the NOTICE file
4# distributed with this work for additional information
5# regarding copyright ownership. The ASF licenses this file
6# to you under the Apache License, Version 2.0 (the
7# "License"); you may not use this file except in compliance
8# with the License. You may obtain a copy of the License at
9#
10# http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing,
13# software distributed under the License is distributed on an
14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15# KIND, either express or implied. See the License for the
16# specific language governing permissions and limitations
17# under the License.
18#
Bryan Duxbury69720412012-01-03 17:32:30 +000019
20from cStringIO import StringIO
21
Kevin Clarke43f7e02009-03-03 22:03:57 +000022from zope.interface import implements, Interface, Attribute
Kevin Clarke8d3c472009-03-03 22:13:46 +000023from twisted.internet.protocol import Protocol, ServerFactory, ClientFactory, \
Kevin Clarke43f7e02009-03-03 22:03:57 +000024 connectionDone
Kevin Clarke8d3c472009-03-03 22:13:46 +000025from twisted.internet import defer
Kevin Clarke43f7e02009-03-03 22:03:57 +000026from twisted.protocols import basic
27from twisted.python import log
Bryan Duxburycb6d9702010-04-29 18:14:54 +000028from twisted.web import server, resource, http
David Reissea2cba82009-03-30 21:35:00 +000029
Kevin Clarke43f7e02009-03-03 22:03:57 +000030from thrift.transport import TTransport
Kevin Clarke43f7e02009-03-03 22:03:57 +000031
32
class TMessageSenderTransport(TTransport.TTransportBase):
    """Write-only transport that buffers output and emits it per message.

    Bytes passed to ``write`` accumulate in an in-memory buffer; ``flush``
    hands everything accumulated so far to ``sendMessage`` as one string and
    starts a fresh buffer. Subclasses supply ``sendMessage``.
    """

    def __init__(self):
        self.__wbuf = StringIO()

    def write(self, buf):
        """Append ``buf`` to the pending message."""
        self.__wbuf.write(buf)

    def flush(self):
        """Send the buffered bytes as one message and reset the buffer."""
        # Swap in a new buffer before sending so that the pending data is
        # already detached from this transport when sendMessage runs.
        filled, self.__wbuf = self.__wbuf, StringIO()
        self.sendMessage(filled.getvalue())

    def sendMessage(self, message):
        """Deliver one complete message; must be overridden by subclasses."""
        raise NotImplementedError()
48
49
class TCallbackTransport(TMessageSenderTransport):
    """Message-sender transport that forwards each message to a callable."""

    def __init__(self, func):
        """Store ``func``; it is called with every flushed message."""
        self.func = func
        TMessageSenderTransport.__init__(self)

    def sendMessage(self, message):
        """Pass the completed ``message`` to the stored callable."""
        self.func(message)
58
59
class ThriftClientProtocol(basic.Int32StringReceiver):
    """Client-side Twisted protocol carrying Thrift messages as framed strings.

    Outgoing messages written by the generated client are sent with
    ``sendString``; every received string is decoded far enough to read the
    message header and then handed to the client's ``recv_<name>`` method.

    ``started`` is a Deferred that fires with the client instance once the
    connection has been made.
    """

    # Raise the receiver's frame-size limit to 2**31 - 1 bytes.
    MAX_LENGTH = 2 ** 31 - 1

    def __init__(self, client_class, iprot_factory, oprot_factory=None):
        """Remember how to build the client.

        ``client_class`` is called as ``client_class(transport, oprot_factory)``
        in ``connectionMade``. ``oprot_factory`` defaults to ``iprot_factory``
        when not given.
        """
        self._client_class = client_class
        self._iprot_factory = iprot_factory
        if oprot_factory is None:
            self._oprot_factory = iprot_factory
        else:
            self._oprot_factory = oprot_factory

        # Cache of message name -> bound ``recv_<name>`` method of the client.
        self.recv_map = {}
        self.started = defer.Deferred()

    def dispatch(self, msg):
        """Send one serialized message as a length-prefixed string."""
        self.sendString(msg)

    def connectionMade(self):
        """Create the client on top of a callback transport and announce it."""
        tmo = TCallbackTransport(self.dispatch)
        self.client = self._client_class(tmo, self._oprot_factory)
        self.started.callback(self.client)

    def connectionLost(self, reason=connectionDone):
        """Fail every outstanding request with an END_OF_FILE exception."""
        # ``client`` is only assigned in connectionMade; if that never
        # completed there is nothing to fail and no attribute to read.
        client = getattr(self, 'client', None)
        if client is None:
            return

        # An errback may issue further calls on the client, which adds
        # entries to ``_reqs``. Iterating the dict directly would then raise
        # "dictionary changed size during iteration" and leave the remaining
        # requests pending, so drain it one entry at a time until it is
        # empty instead.
        # NOTE(review): ``_reqs`` is presumably the generated client's map of
        # sequence id -> Deferred; confirm against the generated code.
        reqs = client._reqs
        while reqs:
            _, pending = reqs.popitem()
            tex = TTransport.TTransportException(
                type=TTransport.TTransportException.END_OF_FILE,
                message='Connection closed')
            pending.errback(tex)

    def stringReceived(self, frame):
        """Decode one received frame and route it to the client's handler."""
        tr = TTransport.TMemoryBuffer(frame)
        iprot = self._iprot_factory.getProtocol(tr)
        (fname, mtype, rseqid) = iprot.readMessageBegin()

        try:
            method = self.recv_map[fname]
        except KeyError:
            # First message of this name: look the handler up once and
            # cache it for subsequent frames.
            method = getattr(self.client, 'recv_' + fname)
            self.recv_map[fname] = method

        method(iprot, mtype, rseqid)
102
103
class ThriftServerProtocol(basic.Int32StringReceiver):
    """Server-side Twisted protocol carrying Thrift messages as framed strings.

    Each received string is run through the factory's processor; whatever
    the processor wrote to the output buffer is sent back as one string.
    """

    # Raise the receiver's frame-size limit to 2**31 - 1 bytes.
    MAX_LENGTH = 2 ** 31 - 1

    def dispatch(self, msg):
        """Send one serialized reply as a length-prefixed string."""
        self.sendString(msg)

    def processError(self, error):
        """Errback for the processor: drop the connection."""
        self.transport.loseConnection()

    def processOk(self, _, tmo):
        """Callback for the processor: send the reply if one was produced."""
        reply = tmo.getvalue()

        # Nothing is sent when the processor wrote no output.
        if reply:
            self.dispatch(reply)

    def stringReceived(self, frame):
        """Process one received frame with the factory's processor."""
        factory = self.factory

        in_buf = TTransport.TMemoryBuffer(frame)
        out_buf = TTransport.TMemoryBuffer()

        iprot = factory.iprot_factory.getProtocol(in_buf)
        oprot = factory.oprot_factory.getProtocol(out_buf)

        result = factory.processor.process(iprot, oprot)
        # The output buffer travels with the callback so processOk can read
        # the reply once processing has finished.
        result.addCallbacks(self.processOk, self.processError,
                            callbackArgs=(out_buf,))
130
131
class IThriftServerFactory(Interface):
    """Attributes a Thrift server factory exposes to its protocols."""

    # Object whose ``process(iprot, oprot)`` handles one incoming message.
    processor = Attribute("Thrift processor")

    # Factory used to build the protocol that decodes requests.
    iprot_factory = Attribute("Input protocol factory")

    # Factory used to build the protocol that encodes replies.
    oprot_factory = Attribute("Output protocol factory")
139
140
class IThriftClientFactory(Interface):
    """Attributes a Thrift client factory exposes to its protocols."""

    # Class instantiated to create the client for each connection.
    client_class = Attribute("Thrift client class")

    # Factory used to build the protocol that decodes replies.
    iprot_factory = Attribute("Input protocol factory")

    # Factory used to build the protocol that encodes requests.
    oprot_factory = Attribute("Output protocol factory")
148
149
class ThriftServerFactory(ServerFactory):
    """Server factory holding the processor and protocol factories."""

    implements(IThriftServerFactory)

    protocol = ThriftServerProtocol

    def __init__(self, processor, iprot_factory, oprot_factory=None):
        """Store the processor and protocol factories.

        When ``oprot_factory`` is omitted, ``iprot_factory`` is used for
        output as well.
        """
        self.processor = processor
        self.iprot_factory = iprot_factory
        self.oprot_factory = (
            iprot_factory if oprot_factory is None else oprot_factory)
163
164
class ThriftClientFactory(ClientFactory):
    """Client factory that builds ``ThriftClientProtocol`` instances."""

    implements(IThriftClientFactory)

    protocol = ThriftClientProtocol

    def __init__(self, client_class, iprot_factory, oprot_factory=None):
        """Store the client class and protocol factories.

        When ``oprot_factory`` is omitted, ``iprot_factory`` is used for
        output as well.
        """
        self.client_class = client_class
        self.iprot_factory = iprot_factory
        self.oprot_factory = (
            iprot_factory if oprot_factory is None else oprot_factory)

    def buildProtocol(self, addr):
        """Create a protocol configured from this factory's settings."""
        # The protocol takes constructor arguments, so it is built here
        # explicitly and then linked back to this factory.
        proto = self.protocol(
            self.client_class, self.iprot_factory, self.oprot_factory)
        proto.factory = self
        return proto
Bryan Duxburycb6d9702010-04-29 18:14:54 +0000184
185
class ThriftResource(resource.Resource):
    """Twisted Web resource that serves a Thrift processor over HTTP POST.

    The request body is treated as one serialized Thrift message; the
    processor's output is returned as the response body.
    """

    allowedMethods = ('POST',)

    def __init__(self, processor, inputProtocolFactory,
                 outputProtocolFactory=None):
        """Store the processor and protocol factories.

        When ``outputProtocolFactory`` is omitted, ``inputProtocolFactory``
        is used for output as well.
        """
        resource.Resource.__init__(self)
        self.inputProtocolFactory = inputProtocolFactory
        if outputProtocolFactory is None:
            self.outputProtocolFactory = inputProtocolFactory
        else:
            self.outputProtocolFactory = outputProtocolFactory
        self.processor = processor

    def getChild(self, path, request):
        """Return this resource for every child path."""
        return self

    def _cbProcess(self, _, request, tmo):
        """Write the processor's output as the response and finish it."""
        msg = tmo.getvalue()
        request.setResponseCode(http.OK)
        request.setHeader("content-type", "application/x-thrift")
        request.write(msg)
        request.finish()

    def _ebProcess(self, failure, request):
        """Log a processor failure and answer with a 500 response.

        render_POST returns NOT_DONE_YET, so the request stays open until
        ``finish`` is called; without this errback a failed processor call
        would leave the request hanging.
        """
        log.err(failure, 'Thrift processor failed while handling request')
        request.setResponseCode(http.INTERNAL_SERVER_ERROR)
        request.finish()

    def render_POST(self, request):
        """Run the request body through the processor asynchronously."""
        # Rewind before reading so the whole body is consumed.
        request.content.seek(0, 0)
        data = request.content.read()
        tmi = TTransport.TMemoryBuffer(data)
        tmo = TTransport.TMemoryBuffer()

        iprot = self.inputProtocolFactory.getProtocol(tmi)
        oprot = self.outputProtocolFactory.getProtocol(tmo)

        d = self.processor.process(iprot, oprot)
        # addCallbacks (not addCallback + addErrback) keeps the two paths
        # exclusive: an exception raised inside _cbProcess is not routed to
        # _ebProcess, which would otherwise finish the request twice.
        d.addCallbacks(self._cbProcess, self._ebProcess,
                       callbackArgs=(request, tmo),
                       errbackArgs=(request,))
        return server.NOT_DONE_YET