blob: 444b768217fb17d46bf861df4ed3a921467768fa [file] [log] [blame]
Kevin Clarke43f7e02009-03-03 22:03:57 +00001from zope.interface import implements, Interface, Attribute
Kevin Clarke8d3c472009-03-03 22:13:46 +00002from twisted.internet.protocol import Protocol, ServerFactory, ClientFactory, \
Kevin Clarke43f7e02009-03-03 22:03:57 +00003 connectionDone
Kevin Clarke8d3c472009-03-03 22:13:46 +00004from twisted.internet import defer
Kevin Clarke43f7e02009-03-03 22:03:57 +00005from twisted.protocols import basic
6from twisted.python import log
7from thrift.transport import TTransport
8from cStringIO import StringIO
9
10
class TMessageSenderTransport(TTransport.TTransportBase):
    """Write-buffering transport that emits one message per flush().

    Bytes passed to write() accumulate in an in-memory buffer; flush()
    hands the accumulated bytes to sendMessage() as a single message and
    starts a fresh buffer. Subclasses must implement sendMessage().
    """

    def __init__(self):
        self.__wbuf = StringIO()

    def write(self, buf):
        """Append buf to the message currently being built."""
        self.__wbuf.write(buf)

    def flush(self):
        """Send everything written since the previous flush as one message."""
        # Swap in an empty buffer before sending, so writes made during
        # sendMessage() start a new message.
        pending, self.__wbuf = self.__wbuf, StringIO()
        self.sendMessage(pending.getvalue())

    def sendMessage(self, message):
        """Transmit one complete message; subclasses must override."""
        raise NotImplementedError
26
27
class TCallbackTransport(TMessageSenderTransport):
    """Message transport that forwards each flushed message to a callable."""

    def __init__(self, func):
        TMessageSenderTransport.__init__(self)
        # Called with every complete message produced by flush().
        self.func = func

    def sendMessage(self, message):
        """Deliver message by invoking the stored callable."""
        deliver = self.func
        deliver(message)
36
37
class ThriftClientProtocol(basic.Int32StringReceiver):
    """Client-side Twisted protocol exchanging int32-length-prefixed Thrift frames.

    When the connection is made, an instance of ``client_class`` is created.
    Its outgoing messages are sent through ``sendString``, and ``self.started``
    fires with that client. Each incoming frame is decoded with the input
    protocol factory and dispatched to ``client.recv_<name>``.
    """

    # Int32StringReceiver's default MAX_LENGTH (99999 bytes) makes any
    # larger frame abort the connection. Allow the largest length that a
    # signed 32-bit prefix can express.
    MAX_LENGTH = 2 ** 31 - 1

    def __init__(self, client_class, iprot_factory, oprot_factory=None):
        """Store the client class and protocol factories.

        client_class -- Thrift client class, instantiated as
            ``client_class(transport, oprot_factory)`` on connect.
        iprot_factory -- protocol factory used to decode incoming frames.
        oprot_factory -- protocol factory for outgoing messages; defaults
            to iprot_factory.
        """
        self._client_class = client_class
        self._iprot_factory = iprot_factory
        if oprot_factory is None:
            self._oprot_factory = iprot_factory
        else:
            self._oprot_factory = oprot_factory

        # Memoized map: message name -> bound client.recv_<name> method.
        self.recv_map = {}
        # Fires with the client instance once connectionMade runs.
        self.started = defer.Deferred()
        # Assigned in connectionMade. Starts as None so that connectionLost
        # is safe even if the connection was never fully established.
        self.client = None

    def dispatch(self, msg):
        """Send one encoded message as a length-prefixed frame."""
        self.sendString(msg)

    def connectionMade(self):
        """Build the client on a callback transport and fire ``started``."""
        tmo = TCallbackTransport(self.dispatch)
        self.client = self._client_class(tmo, self._oprot_factory)
        self.started.callback(self.client)

    def connectionLost(self, reason=connectionDone):
        """Fail every outstanding client request with END_OF_FILE.

        Requests are popped one at a time instead of iterating the dict in
        place. An errback may issue a new call or otherwise mutate
        ``_reqs``, which would break a live iteration. Popping also empties
        the dict, so no Deferred can be errbacked twice.
        """
        client = self.client
        if client is None:
            # connectionMade never ran, so no request can be pending.
            return
        while client._reqs:
            _, d = client._reqs.popitem()
            d.errback(TTransport.TTransportException(
                type=TTransport.TTransportException.END_OF_FILE,
                message='Connection closed'))

    def stringReceived(self, frame):
        """Decode one reply frame and hand it to ``client.recv_<fname>``."""
        tr = TTransport.TMemoryBuffer(frame)
        iprot = self._iprot_factory.getProtocol(tr)
        (fname, mtype, rseqid) = iprot.readMessageBegin()

        try:
            method = self.recv_map[fname]
        except KeyError:
            # First reply with this name: look up the handler and memoize it.
            method = getattr(self.client, 'recv_' + fname)
            self.recv_map[fname] = method

        method(iprot, mtype, rseqid)
78
79
class ThriftServerProtocol(basic.Int32StringReceiver):
    """Server-side Twisted protocol exchanging int32-length-prefixed Thrift frames.

    Each incoming frame is run through ``self.factory.processor.process``,
    using protocols from the factory's iprot/oprot factories. Any reply the
    processor writes is sent back as a single frame.
    """

    # Int32StringReceiver's default MAX_LENGTH (99999 bytes) makes any
    # larger request frame abort the connection. Allow the largest length
    # that a signed 32-bit prefix can express.
    MAX_LENGTH = 2 ** 31 - 1

    def dispatch(self, msg):
        """Send one encoded message as a length-prefixed frame."""
        self.sendString(msg)

    def processError(self, error):
        """Log a failed process() call and drop the connection."""
        # Record the failure before consuming it. Otherwise processor errors
        # only show up as unexplained disconnects.
        log.err(error)
        self.transport.loseConnection()

    def processOk(self, _, tmo):
        """Send whatever the processor wrote into tmo, if anything.

        An empty output buffer produces no reply frame.
        """
        msg = tmo.getvalue()

        if len(msg) > 0:
            self.dispatch(msg)

    def stringReceived(self, frame):
        """Process one request frame and reply asynchronously."""
        tmi = TTransport.TMemoryBuffer(frame)
        tmo = TTransport.TMemoryBuffer()

        iprot = self.factory.iprot_factory.getProtocol(tmi)
        oprot = self.factory.oprot_factory.getProtocol(tmo)

        d = self.factory.processor.process(iprot, oprot)
        d.addCallbacks(self.processOk, self.processError,
                       callbackArgs=(tmo,))
103 callbackArgs=(tmo,))
104
105
class IThriftServerFactory(Interface):
    """Attributes a factory must expose to serve Thrift over Twisted."""

    # Object whose process(iprot, oprot) handles one request and returns
    # a Deferred.
    processor = Attribute("Thrift processor")
    # Factory that builds protocols for decoding incoming frames.
    iprot_factory = Attribute("Input protocol factory")
    # Factory that builds protocols for encoding outgoing replies.
    oprot_factory = Attribute("Output protocol factory")
112 oprot_factory = Attribute("Output protocol factory")
113
114
class IThriftClientFactory(Interface):
    """Attributes a factory must expose to build Thrift client protocols."""

    # Thrift client class instantiated once a connection is made.
    client_class = Attribute("Thrift client class")
    # Factory that builds protocols for decoding incoming frames.
    iprot_factory = Attribute("Input protocol factory")
    # Factory that builds protocols for encoding outgoing requests.
    oprot_factory = Attribute("Output protocol factory")
121 oprot_factory = Attribute("Output protocol factory")
122
123
class ThriftServerFactory(ServerFactory):
    """Server factory whose protocols are ThriftServerProtocol instances.

    Those protocols read ``processor``, ``iprot_factory`` and
    ``oprot_factory`` from this factory while handling each frame.
    """

    implements(IThriftServerFactory)

    protocol = ThriftServerProtocol

    def __init__(self, processor, iprot_factory, oprot_factory=None):
        """Store the processor and protocol factories.

        When oprot_factory is omitted, iprot_factory is used for output too.
        """
        self.processor = processor
        self.iprot_factory = iprot_factory
        self.oprot_factory = (iprot_factory if oprot_factory is None
                              else oprot_factory)
136 self.oprot_factory = oprot_factory
137
138
class ThriftClientFactory(ClientFactory):
    """Client factory that builds ThriftClientProtocol instances."""

    implements(IThriftClientFactory)

    protocol = ThriftClientProtocol

    def __init__(self, client_class, iprot_factory, oprot_factory=None):
        """Store the client class and protocol factories.

        When oprot_factory is omitted, iprot_factory is used for output too.
        """
        self.client_class = client_class
        self.iprot_factory = iprot_factory
        self.oprot_factory = (iprot_factory if oprot_factory is None
                              else oprot_factory)

    def buildProtocol(self, addr):
        """Create a protocol configured with this factory's settings."""
        proto = self.protocol(self.client_class,
                              self.iprot_factory,
                              self.oprot_factory)
        proto.factory = self
        return proto
157 return p