blob: 619ea7002575acabd8c9957f2e190377b05281db [file] [log] [blame]
Kevin Clarke43f7e02009-03-03 22:03:57 +00001from zope.interface import implements, Interface, Attribute
2from twisted.internet.protocol import Protocol, ServerFactory, ClientFactory,
3 connectionDone
4from twisted.protocols import basic
5from twisted.python import log
6from thrift.transport import TTransport
7from cStringIO import StringIO
8
9
class TMessageSenderTransport(TTransport.TTransportBase):
    """Write-only transport that turns each flush into one whole message.

    Bytes given to write() accumulate in an in-memory buffer.  flush() takes
    everything accumulated since the previous flush, starts a fresh buffer,
    and passes the accumulated bytes to sendMessage().  Subclasses choose
    where messages go by overriding sendMessage().
    """

    def __init__(self):
        self.__wbuf = StringIO()

    def write(self, buf):
        """Append ``buf`` to the pending outgoing message."""
        self.__wbuf.write(buf)

    def flush(self):
        """Emit the pending bytes as a single message and reset the buffer."""
        # Swap in a new buffer before sending, so that a sendMessage()
        # implementation which writes again starts a new message.
        pending, self.__wbuf = self.__wbuf.getvalue(), StringIO()
        self.sendMessage(pending)

    def sendMessage(self, message):
        """Deliver one complete message; subclasses must override this."""
        raise NotImplementedError
25
26
class TCallbackTransport(TMessageSenderTransport):
    """Message-sender transport that forwards every flushed message to a callable.

    :param func: callable invoked with one argument, the complete serialized
        message, each time the transport is flushed.
    """

    def __init__(self, func):
        TMessageSenderTransport.__init__(self)
        self.func = func

    def sendMessage(self, message):
        """Pass ``message`` to the configured callback."""
        callback = self.func
        callback(message)
35
36
class ThriftClientProtocol(basic.Int32StringReceiver):
    """Client side of a framed Thrift connection over Twisted.

    Every Thrift message travels as one Int32StringReceiver string, which is
    a length-prefixed frame.  Calls made on ``self.client`` are serialized
    through a TCallbackTransport that sends each flushed message as one
    frame.  Each received frame is decoded and routed to the client's
    ``recv_<method>`` handler.

    :param client_class: called as ``client_class(transport, oprot_factory)``
        once the connection is made.
    :param iprot_factory: protocol factory used to decode incoming frames.
    :param oprot_factory: protocol factory used to encode outgoing calls.
        Defaults to ``iprot_factory`` when omitted or None.
    """

    # Twisted's Int32StringReceiver has a small default MAX_LENGTH (99999
    # bytes) and drops the connection on any larger frame.  That breaks
    # legitimate Thrift replies with large payloads, so allow up to the
    # largest signed 32-bit length.  Subclasses may lower this.
    MAX_LENGTH = 2 ** 31 - 1

    def __init__(self, client_class, iprot_factory, oprot_factory=None):
        self._client_class = client_class
        self._iprot_factory = iprot_factory
        if oprot_factory is None:
            self._oprot_factory = iprot_factory
        else:
            self._oprot_factory = oprot_factory

        # Cache of message name -> bound ``recv_<name>`` method of self.client.
        self.recv_map = {}
        # Created in connectionMade(); None until the connection is up.
        self.client = None

    def dispatch(self, msg):
        """Send one serialized call as a length-prefixed frame."""
        self.sendString(msg)

    def connectionMade(self):
        """Build the Thrift client; its flushed calls are sent via dispatch()."""
        tmo = TCallbackTransport(self.dispatch)
        self.client = self._client_class(tmo, self._oprot_factory)

    def connectionLost(self, reason=connectionDone):
        """Fail every outstanding request with an END_OF_FILE transport error.

        Errbacks run synchronously and may issue new requests, which land in
        ``client._reqs``.  Iterating the dict while that happens raises
        RuntimeError and leaves the remaining Deferreds unfired.  Instead,
        pop entries one at a time until the dict is empty.
        """
        client = getattr(self, 'client', None)
        if client is None:
            # connectionMade() never ran, so there are no requests to fail.
            return
        while client._reqs:
            _, d = client._reqs.popitem()
            d.errback(TTransport.TTransportException(
                type=TTransport.TTransportException.END_OF_FILE,
                message='Connection closed'))

    def stringReceived(self, frame):
        """Decode one incoming frame and hand it to the matching recv_ handler."""
        tr = TTransport.TMemoryBuffer(frame)
        iprot = self._iprot_factory.getProtocol(tr)
        (fname, mtype, rseqid) = iprot.readMessageBegin()

        # Resolve client.recv_<fname> once per message name, then reuse it.
        try:
            method = self.recv_map[fname]
        except KeyError:
            method = getattr(self.client, 'recv_' + fname)
            self.recv_map[fname] = method

        # The handler gets iprot positioned just after the message header.
        method(iprot, mtype, rseqid)
75
76
class ThriftServerProtocol(basic.Int32StringReceiver):
    """Server side of a framed Thrift connection over Twisted.

    Each received frame is decoded with the factory's input protocol and
    passed to the factory's processor.  Whatever the processor writes to the
    output protocol is sent back as one frame.  ``self.factory`` must provide
    ``processor``, ``iprot_factory`` and ``oprot_factory``.
    """

    # Twisted's Int32StringReceiver has a small default MAX_LENGTH (99999
    # bytes) and drops the connection on any larger frame.  That breaks
    # legitimate large Thrift requests, so allow up to the largest signed
    # 32-bit length.  Lower this in a subclass if clients are untrusted,
    # since a whole frame is buffered in memory before processing.
    MAX_LENGTH = 2 ** 31 - 1

    def dispatch(self, msg):
        """Send one serialized reply as a length-prefixed frame."""
        self.sendString(msg)

    def processError(self, error):
        """Record a failed request, then drop the connection."""
        # Previously the failure was discarded silently, which made server
        # side errors impossible to diagnose.
        log.err(error, 'Error processing Thrift request')
        self.transport.loseConnection()

    def processOk(self, _, tmo):
        """Send whatever the processor wrote to ``tmo``, if anything."""
        msg = tmo.getvalue()

        # If nothing was written (presumably a oneway call), send no frame.
        if msg:
            self.dispatch(msg)

    def stringReceived(self, frame):
        """Run one request frame through the processor, replying when done."""
        tmi = TTransport.TMemoryBuffer(frame)
        tmo = TTransport.TMemoryBuffer()

        iprot = self.factory.iprot_factory.getProtocol(tmi)
        oprot = self.factory.oprot_factory.getProtocol(tmo)

        d = self.factory.processor.process(iprot, oprot)
        d.addCallbacks(self.processOk, self.processError,
                       callbackArgs=(tmo,))
101
102
class IThriftServerFactory(Interface):
    """Attributes that ThriftServerProtocol reads from its factory."""

    # Object whose process(iprot, oprot) returns a Deferred; it is invoked
    # once per received frame.
    processor = Attribute("Thrift processor")

    # Protocol factory used to decode request frames.
    iprot_factory = Attribute("Input protocol factory")

    # Protocol factory used to encode replies.
    oprot_factory = Attribute("Output protocol factory")
110
111
class IThriftClientFactory(Interface):
    """Attributes a client factory provides for building ThriftClientProtocol."""

    # Called as client_class(transport, oprot_factory) once connected.
    client_class = Attribute("Thrift client class")

    # Protocol factory used to decode reply frames.
    iprot_factory = Attribute("Input protocol factory")

    # Protocol factory used to encode outgoing calls.
    oprot_factory = Attribute("Output protocol factory")
119
120
class ThriftServerFactory(ServerFactory):
    """Twisted server factory that produces ThriftServerProtocol instances.

    :param processor: Thrift processor whose ``process(iprot, oprot)``
        handles each incoming request.
    :param iprot_factory: protocol factory used to decode requests.
    :param oprot_factory: protocol factory used to encode replies.  Falls
        back to ``iprot_factory`` when omitted or None.
    """

    implements(IThriftServerFactory)

    protocol = ThriftServerProtocol

    def __init__(self, processor, iprot_factory, oprot_factory=None):
        self.processor = processor
        self.iprot_factory = iprot_factory
        self.oprot_factory = (iprot_factory if oprot_factory is None
                              else oprot_factory)
133 self.oprot_factory = oprot_factory
134
135
class ThriftClientFactory(ClientFactory):
    """Twisted client factory that produces ThriftClientProtocol instances.

    :param client_class: Thrift client class, instantiated by the protocol
        once the connection is made.
    :param iprot_factory: protocol factory used to decode replies.
    :param oprot_factory: protocol factory used to encode calls.  Falls back
        to ``iprot_factory`` when omitted or None.
    """

    implements(IThriftClientFactory)

    protocol = ThriftClientProtocol

    def __init__(self, client_class, iprot_factory, oprot_factory=None):
        self.client_class = client_class
        self.iprot_factory = iprot_factory
        self.oprot_factory = (iprot_factory if oprot_factory is None
                              else oprot_factory)

    def buildProtocol(self, addr):
        """Create the protocol for a new connection and link it back here."""
        proto = self.protocol(self.client_class, self.iprot_factory,
                              self.oprot_factory)
        proto.factory = self
        return proto
154 return p