blob: a1ba1502d71c2593f6faf88e4a68134c5781aad8 [file] [log] [blame]
David Reissea2cba82009-03-30 21:35:00 +00001#
2# Licensed to the Apache Software Foundation (ASF) under one
3# or more contributor license agreements. See the NOTICE file
4# distributed with this work for additional information
5# regarding copyright ownership. The ASF licenses this file
6# to you under the Apache License, Version 2.0 (the
7# "License"); you may not use this file except in compliance
8# with the License. You may obtain a copy of the License at
9#
10# http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing,
13# software distributed under the License is distributed on an
14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15# KIND, either express or implied. See the License for the
16# specific language governing permissions and limitations
17# under the License.
18#
Kevin Clarke43f7e02009-03-03 22:03:57 +000019from zope.interface import implements, Interface, Attribute
Kevin Clarke8d3c472009-03-03 22:13:46 +000020from twisted.internet.protocol import Protocol, ServerFactory, ClientFactory, \
Kevin Clarke43f7e02009-03-03 22:03:57 +000021 connectionDone
Kevin Clarke8d3c472009-03-03 22:13:46 +000022from twisted.internet import defer
Kevin Clarke43f7e02009-03-03 22:03:57 +000023from twisted.protocols import basic
24from twisted.python import log
David Reissea2cba82009-03-30 21:35:00 +000025
26
Kevin Clarke43f7e02009-03-03 22:03:57 +000027from thrift.transport import TTransport
28from cStringIO import StringIO
29
30
class TMessageSenderTransport(TTransport.TTransportBase):
    """Transport that buffers writes and emits them as one message on flush.

    Subclasses decide what "sending a message" means by overriding
    ``sendMessage``.
    """

    def __init__(self):
        self.__wbuf = StringIO()

    def write(self, buf):
        """Append ``buf`` to the pending outgoing message."""
        self.__wbuf.write(buf)

    def flush(self):
        """Hand everything written since the last flush to ``sendMessage``.

        The buffer is swapped for a fresh one before ``sendMessage`` runs,
        so writes made from inside ``sendMessage`` start a new message.
        """
        pending, self.__wbuf = self.__wbuf, StringIO()
        self.sendMessage(pending.getvalue())

    def sendMessage(self, message):
        """Deliver one complete message; must be overridden."""
        raise NotImplementedError
46
47
class TCallbackTransport(TMessageSenderTransport):
    """Message transport that passes every flushed message to a callable.

    :param func: called as ``func(message)`` once per flush.
    """

    def __init__(self, func):
        TMessageSenderTransport.__init__(self)
        self.func = func

    def sendMessage(self, message):
        """Forward ``message`` to the configured callable."""
        deliver = self.func
        deliver(message)
56
57
class ThriftClientProtocol(basic.Int32StringReceiver):
    """Client side of a framed Thrift connection.

    Outgoing calls are written through a TCallbackTransport that sends each
    flushed message as one int32-length-prefixed frame. Each incoming frame
    is routed to the client's ``recv_<name>`` method.

    :param client_class: Thrift client class, instantiated on connect as
        ``client_class(transport, oprot_factory)``.
    :param iprot_factory: protocol factory used to decode incoming frames.
    :param oprot_factory: protocol factory used to encode outgoing calls;
        defaults to ``iprot_factory``.
    """

    # Accept frames up to the largest length an int32 prefix can express.
    MAX_LENGTH = 2 ** 31 - 1

    def __init__(self, client_class, iprot_factory, oprot_factory=None):
        self._client_class = client_class
        self._iprot_factory = iprot_factory
        self._oprot_factory = (iprot_factory if oprot_factory is None
                               else oprot_factory)

        # Cache of message name -> bound ``client.recv_<name>`` method.
        self.recv_map = {}
        # Fired with the client instance once the connection is made.
        self.started = defer.Deferred()
        # Set by connectionMade. None until then, so connectionLost can
        # tell whether a client (and its pending-request map) exists.
        self.client = None

    def dispatch(self, msg):
        """Send ``msg`` to the peer as one length-prefixed frame."""
        self.sendString(msg)

    def connectionMade(self):
        """Build the Thrift client on this connection and fire ``started``."""
        tmo = TCallbackTransport(self.dispatch)
        self.client = self._client_class(tmo, self._oprot_factory)
        self.started.callback(self.client)

    def connectionLost(self, reason=connectionDone):
        """Fail every outstanding request with an END_OF_FILE error.

        Pending requests are popped from the client's ``_reqs`` map one at
        a time instead of being iterated. An errback may issue a new request
        and so add to the map, which would break a live iterator. Any request
        added that way is failed too, and the map ends up empty.
        """
        if self.client is None:
            # connectionMade never ran: there is no client and nothing
            # pending to fail.
            return
        pending = self.client._reqs
        while pending:
            _, d = pending.popitem()
            d.errback(TTransport.TTransportException(
                type=TTransport.TTransportException.END_OF_FILE,
                message='Connection closed'))

    def stringReceived(self, frame):
        """Decode the message header of ``frame`` and hand the rest to the
        matching ``client.recv_<fname>`` method."""
        tr = TTransport.TMemoryBuffer(frame)
        iprot = self._iprot_factory.getProtocol(tr)
        (fname, mtype, rseqid) = iprot.readMessageBegin()

        try:
            method = self.recv_map[fname]
        except KeyError:
            method = getattr(self.client, 'recv_' + fname)
            self.recv_map[fname] = method

        method(iprot, mtype, rseqid)
100
101
class ThriftServerProtocol(basic.Int32StringReceiver):
    """Server side of a framed Thrift connection.

    Each int32-length-prefixed frame received is handed to the factory's
    ``processor``. Any reply the processor writes is sent back as a single
    frame. Reads ``processor``, ``iprot_factory`` and ``oprot_factory``
    from ``self.factory``.
    """

    # Accept frames up to the largest length an int32 prefix can express.
    MAX_LENGTH = 2 ** 31 - 1

    def dispatch(self, msg):
        """Send ``msg`` to the peer as one length-prefixed frame."""
        self.sendString(msg)

    def processError(self, error):
        """Log a failed request and drop the connection.

        :param error: the Failure produced while processing a frame.
        """
        # Without logging, the failure would be consumed here and vanish,
        # leaving only an unexplained disconnect.
        log.err(error, 'Unhandled error in Thrift processor')
        self.transport.loseConnection()

    def processOk(self, _, tmo):
        """Send whatever reply the processor wrote into ``tmo``.

        Nothing is sent when ``tmo`` is empty (presumably a oneway call).
        """
        msg = tmo.getvalue()
        if msg:
            self.dispatch(msg)

    def stringReceived(self, frame):
        """Run one request frame through the processor and schedule the
        reply (or the error handling)."""
        tmi = TTransport.TMemoryBuffer(frame)
        tmo = TTransport.TMemoryBuffer()

        iprot = self.factory.iprot_factory.getProtocol(tmi)
        oprot = self.factory.oprot_factory.getProtocol(tmo)

        # maybeDeferred turns a synchronous exception (or a plain return
        # value) from process() into a Deferred. Every outcome then goes
        # through processOk / processError.
        d = defer.maybeDeferred(self.factory.processor.process, iprot, oprot)
        d.addCallbacks(self.processOk, self.processError,
                       callbackArgs=(tmo,))
128
129
class IThriftServerFactory(Interface):
    """Attributes a Thrift server factory exposes to its protocols."""

    processor = Attribute("Thrift processor")
    iprot_factory = Attribute("Input protocol factory")
    oprot_factory = Attribute("Output protocol factory")
137
138
class IThriftClientFactory(Interface):
    """Attributes a Thrift client factory exposes to its protocols."""

    client_class = Attribute("Thrift client class")
    iprot_factory = Attribute("Input protocol factory")
    oprot_factory = Attribute("Output protocol factory")
146
147
class ThriftServerFactory(ServerFactory):
    """Server factory whose connections speak framed Thrift.

    :param processor: object whose ``process(iprot, oprot)`` handles one
        request frame (used by ThriftServerProtocol).
    :param iprot_factory: protocol factory used to decode requests.
    :param oprot_factory: protocol factory used to encode replies;
        defaults to ``iprot_factory``.
    """

    implements(IThriftServerFactory)

    protocol = ThriftServerProtocol

    def __init__(self, processor, iprot_factory, oprot_factory=None):
        self.processor = processor
        self.iprot_factory = iprot_factory
        self.oprot_factory = (iprot_factory if oprot_factory is None
                              else oprot_factory)
161
162
class ThriftClientFactory(ClientFactory):
    """Client factory whose connections speak framed Thrift.

    :param client_class: Thrift client class each protocol instantiates
        once its connection is made.
    :param iprot_factory: protocol factory used to decode incoming frames.
    :param oprot_factory: protocol factory used to encode outgoing calls;
        defaults to ``iprot_factory``.
    """

    implements(IThriftClientFactory)

    protocol = ThriftClientProtocol

    def __init__(self, client_class, iprot_factory, oprot_factory=None):
        self.client_class = client_class
        self.iprot_factory = iprot_factory
        self.oprot_factory = (iprot_factory if oprot_factory is None
                              else oprot_factory)

    def buildProtocol(self, addr):
        """Create a protocol wired to this factory's client settings."""
        proto = self.protocol(self.client_class, self.iprot_factory,
                              self.oprot_factory)
        proto.factory = self
        return proto
181 return p