blob: b5c2147b269e1e6235e7e4f96b8e932518ffd4dd [file] [log] [blame]
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
Kevin Clarke43f7e02009-03-03 22:03:57 +000019from zope.interface import implements, Interface, Attribute
Kevin Clarke8d3c472009-03-03 22:13:46 +000020from twisted.internet.protocol import Protocol, ServerFactory, ClientFactory, \
Kevin Clarke43f7e02009-03-03 22:03:57 +000021 connectionDone
Kevin Clarke8d3c472009-03-03 22:13:46 +000022from twisted.internet import defer
Kevin Clarke43f7e02009-03-03 22:03:57 +000023from twisted.protocols import basic
24from twisted.python import log
David Reissea2cba82009-03-30 21:35:00 +000025
26
Kevin Clarke43f7e02009-03-03 22:03:57 +000027from thrift.transport import TTransport
28from cStringIO import StringIO
29
30
class TMessageSenderTransport(TTransport.TTransportBase):
    """Write-buffering transport that emits each flushed buffer as one message.

    Everything passed to write() accumulates in a private buffer; flush()
    hands the accumulated bytes to sendMessage() in a single call.
    Subclasses decide where a message goes by overriding sendMessage().
    """

    def __init__(self):
        self.__wbuf = StringIO()

    def write(self, buf):
        """Append C{buf} to the pending outgoing message."""
        self.__wbuf.write(buf)

    def flush(self):
        """Pass the buffered bytes to sendMessage() and start a fresh buffer."""
        # The buffer is swapped out before sendMessage() runs, so anything
        # written during delivery lands in the next message.
        pending, self.__wbuf = self.__wbuf.getvalue(), StringIO()
        self.sendMessage(pending)

    def sendMessage(self, message):
        """Deliver one complete message.  Subclasses must override this."""
        raise NotImplementedError
46
47
class TCallbackTransport(TMessageSenderTransport):
    """Message transport that hands every flushed message to a callable."""

    def __init__(self, func):
        """
        @param func: one-argument callable invoked with each flushed message.
        """
        TMessageSenderTransport.__init__(self)
        self.func = func

    def sendMessage(self, message):
        """Forward C{message} to the callable supplied at construction."""
        callback = self.func
        callback(message)
56
57
class ThriftClientProtocol(basic.Int32StringReceiver):
    """Client side of a framed Thrift connection over Twisted.

    Each Thrift message travels as one Int32-length-prefixed string.  Once
    connected, C{self.client} writes calls through a TCallbackTransport that
    forwards every flushed message to sendString(); incoming frames are
    routed to the client's C{recv_<name>} method.
    """

    def __init__(self, client_class, iprot_factory, oprot_factory=None):
        """
        @param client_class: Thrift client class, instantiated as
            C{client_class(transport, oprot_factory)} in connectionMade().
        @param iprot_factory: protocol factory used to decode incoming frames.
        @param oprot_factory: protocol factory handed to the client for
            outgoing messages; defaults to C{iprot_factory}.
        """
        self._client_class = client_class
        self._iprot_factory = iprot_factory
        if oprot_factory is None:
            self._oprot_factory = iprot_factory
        else:
            self._oprot_factory = oprot_factory

        # Cache of message name -> bound recv_<name> method of self.client.
        self.recv_map = {}
        # Fired with the client instance once the connection is made.
        self.started = defer.Deferred()
        # Created in connectionMade(); None until then.
        self.client = None

    def dispatch(self, msg):
        """Send one serialized Thrift message to the peer as a framed string."""
        self.sendString(msg)

    def connectionMade(self):
        """Create the client on top of this connection and fire C{started}."""
        tmo = TCallbackTransport(self.dispatch)
        self.client = self._client_class(tmo, self._oprot_factory)
        self.started.callback(self.client)

    def connectionLost(self, reason=connectionDone):
        """Fail every outstanding request with an END_OF_FILE transport error."""
        if self.client is None:
            # connectionMade() never ran, so no request can be pending.
            return
        reqs = self.client._reqs
        # Errbacks may issue new calls, which add entries to _reqs while we
        # are failing the old ones.  Iterating the dict directly would then
        # raise "dictionary changed size during iteration", and the new
        # requests could never be answered on a closed connection.  Popping
        # until the dict is empty fails those late additions as well.
        while reqs:
            _, d = reqs.popitem()
            tex = TTransport.TTransportException(
                type=TTransport.TTransportException.END_OF_FILE,
                message='Connection closed')
            d.errback(tex)

    def stringReceived(self, frame):
        """Decode one reply frame and pass it to C{client.recv_<name>}."""
        tr = TTransport.TMemoryBuffer(frame)
        iprot = self._iprot_factory.getProtocol(tr)
        (fname, mtype, rseqid) = iprot.readMessageBegin()

        try:
            method = self.recv_map[fname]
        except KeyError:
            # Raises AttributeError if the client has no handler for fname.
            method = getattr(self.client, 'recv_' + fname)
            self.recv_map[fname] = method

        method(iprot, mtype, rseqid)
98
99
class ThriftServerProtocol(basic.Int32StringReceiver):
    """Server side of a framed Thrift connection over Twisted.

    Each received Int32-length-prefixed string is decoded as one request and
    passed to C{self.factory.processor}.  Any reply the processor writes is
    sent back as one framed string.  C{self.factory} must provide
    C{processor}, C{iprot_factory} and C{oprot_factory}.
    """

    def dispatch(self, msg):
        """Send one serialized Thrift message to the peer as a framed string."""
        self.sendString(msg)

    def processError(self, error):
        """Log a failed request and close the connection.

        @param error: the Failure from the processor's Deferred.
        """
        # The Failure is consumed here (nothing is returned).  Without
        # logging, the reason for dropping the connection would be lost.
        log.err(error, 'Thrift processor failed; closing connection')
        self.transport.loseConnection()

    def processOk(self, _, tmo):
        """Send the buffered reply if the processor wrote one.

        @param tmo: the memory buffer the output protocol wrote into.
        """
        msg = tmo.getvalue()

        # An empty buffer means nothing was written (presumably a oneway
        # call), so there is nothing to send back.
        if len(msg) > 0:
            self.dispatch(msg)

    def stringReceived(self, frame):
        """Run the factory's processor on one request frame."""
        tmi = TTransport.TMemoryBuffer(frame)
        tmo = TTransport.TMemoryBuffer()

        iprot = self.factory.iprot_factory.getProtocol(tmi)
        oprot = self.factory.oprot_factory.getProtocol(tmo)

        d = self.factory.processor.process(iprot, oprot)
        d.addCallbacks(self.processOk, self.processError,
                       callbackArgs=(tmo,))
124
125
class IThriftServerFactory(Interface):
    """Server factory exposing a Thrift processor and its protocol factories."""

    processor = Attribute("Thrift processor")
    iprot_factory = Attribute("Input protocol factory")
    oprot_factory = Attribute("Output protocol factory")
133
134
class IThriftClientFactory(Interface):
    """Client factory exposing a Thrift client class and its protocol factories."""

    client_class = Attribute("Thrift client class")
    iprot_factory = Attribute("Input protocol factory")
    oprot_factory = Attribute("Output protocol factory")
142
143
class ThriftServerFactory(ServerFactory):
    """Builds ThriftServerProtocol connections that share one processor."""

    implements(IThriftServerFactory)

    protocol = ThriftServerProtocol

    def __init__(self, processor, iprot_factory, oprot_factory=None):
        """
        @param processor: Thrift processor that handles decoded requests.
        @param iprot_factory: protocol factory for reading requests.
        @param oprot_factory: protocol factory for writing replies; falls
            back to C{iprot_factory} when omitted.
        """
        self.processor = processor
        self.iprot_factory = iprot_factory
        self.oprot_factory = (iprot_factory if oprot_factory is None
                              else oprot_factory)
157
158
class ThriftClientFactory(ClientFactory):
    """Builds ThriftClientProtocol connections for a given client class."""

    implements(IThriftClientFactory)

    protocol = ThriftClientProtocol

    def __init__(self, client_class, iprot_factory, oprot_factory=None):
        """
        @param client_class: Thrift client class each protocol instantiates.
        @param iprot_factory: protocol factory for reading replies.
        @param oprot_factory: protocol factory for writing calls; falls back
            to C{iprot_factory} when omitted.
        """
        self.client_class = client_class
        self.iprot_factory = iprot_factory
        self.oprot_factory = (iprot_factory if oprot_factory is None
                              else oprot_factory)

    def buildProtocol(self, addr):
        """Create a protocol for C{addr} and attach this factory to it."""
        proto = self.protocol(self.client_class,
                              self.iprot_factory,
                              self.oprot_factory)
        proto.factory = self
        return proto