blob: 816827cc8a907efc9bb51f642d400c948c7f9414 [file] [log] [blame]
"""Implementation of a non-blocking server.

The main idea of the server is that requests are received and answers are
sent only from the main thread.

It also makes the thread pool work in terms of tasks, not connections.
"""
8import threading
9import socket
10import Queue
11import select
12import struct
13import logging
14
15from thrift.transport import TTransport
16from thrift.protocol.TBinaryProtocol import TBinaryProtocolFactory
17
18__all__ = ['TNonblockingServer']
19
class Worker(threading.Thread):
    """Thread that executes queued requests and reports each result.

    Every queue item is a 5-element sequence
    ``(processor, iprot, oprot, otrans, callback)``.  An item whose
    processor is None is the shutdown marker.
    """

    def __init__(self, queue):
        threading.Thread.__init__(self)
        self.queue = queue

    def run(self):
        """Process tasks from the queue until a None processor is received.

        For every task ``processor.process(iprot, oprot)`` is called.  On
        success the callback receives ``(True, otrans.getvalue())``; if
        processing raises, the error is logged and the callback receives
        ``(False, '')``.  The callback is invoked exactly once per task.
        """
        while True:
            try:
                processor, iprot, oprot, otrans, callback = self.queue.get()
            except Exception:
                # A malformed task has no usable callback; never fall back
                # to the callback left over from a previous iteration.
                logging.exception("Exception while fetching task")
                continue
            if processor is None:
                break
            try:
                processor.process(iprot, oprot)
                all_ok, answer = True, otrans.getvalue()
            except Exception:
                logging.exception("Exception while processing request")
                all_ok, answer = False, ''
            # Report outside of the processing try-block so that a failing
            # callback is not called a second time and cannot kill the
            # worker thread.
            try:
                callback(all_ok, answer)
            except Exception:
                logging.exception("Exception while reporting request result")
38
# Connection states, numbered 0..4 in declaration order:
#   WAIT_LEN     - reading the 4-byte length of the next request
#   WAIT_MESSAGE - reading the request body
#   WAIT_PROCESS - request fully read, waiting for the ready() callback
#   SEND_ANSWER  - sending the answer back to the client
#   CLOSED       - socket closed, connection should be dropped
(WAIT_LEN, WAIT_MESSAGE, WAIT_PROCESS, SEND_ANSWER, CLOSED) = range(5)
44
def locked(func):
    """Decorator that runs a method while holding ``self.lock``.

    The lock is acquired before the wrapped method is called and is
    released afterwards, also when the method raises.  The wrapper keeps
    the name and docstring of the wrapped method.
    """
    def nested(self, *args, **kwargs):
        self.lock.acquire()
        try:
            return func(self, *args, **kwargs)
        finally:
            self.lock.release()
    # Copy the metadata by hand so no additional import is required.
    nested.__name__ = func.__name__
    nested.__doc__ = func.__doc__
    return nested
54
def socket_exception(func):
    """Decorator that closes the object when the method raises socket.error.

    When the wrapped method raises ``socket.error`` the error is swallowed,
    ``self.close()`` is called and the wrapper returns None.  Any other
    exception propagates unchanged.  The wrapper keeps the name and
    docstring of the wrapped method.
    """
    def read(self, *args, **kwargs):
        try:
            return func(self, *args, **kwargs)
        except socket.error:
            self.close()
    # Copy the metadata by hand so that decorated methods are not all
    # reported as "read".
    read.__name__ = func.__name__
    read.__doc__ = func.__doc__
    return read
63
class Connection:
    """State machine for one client connection using framed messages.

    ``status`` is always one of:
        WAIT_LEN --- connection is reading the 4-byte request length.
        WAIT_MESSAGE --- connection is reading the request body.
        WAIT_PROCESS --- connection has read the whole request and waits
                         for the ready() callback.
        SEND_ANSWER --- connection is sending the answer string (including
                        the length of the answer).
        CLOSED --- socket was closed and the connection should be deleted.
    """

    def __init__(self, new_socket, wake_up):
        """Wraps new_socket, switching it to non-blocking mode.

        wake_up is a callable without arguments; ready() calls it after
        changing the state.
        """
        self.socket = new_socket
        self.socket.setblocking(False)
        self.status = WAIT_LEN
        self.len = 0
        self.message = ''
        self.lock = threading.Lock()
        self.wake_up = wake_up

    def _read_len(self):
        """Reads the length of the request.

        The 4 length bytes are accumulated in self.message over as many
        calls as needed.  Once all of them have arrived the length is
        decoded and, if it is positive, the state becomes WAIT_MESSAGE.
        A negative or zero length closes the connection.
        """
        read = self.socket.recv(4 - len(self.message))
        if len(read) == 0:
            # if we read 0 bytes and self.message is empty, it means client
            # close connection
            if len(self.message) != 0:
                logging.error("can't read frame size from socket")
            self.close()
            return
        self.message += read
        if len(self.message) == 4:
            self.len, = struct.unpack('!i', self.message)
            if self.len < 0:
                logging.error("negative frame size, it seems client"
                              " doesn't use FramedTransport")
                self.close()
            elif self.len == 0:
                logging.error("empty frame, it's really strange")
                self.close()
            else:
                self.message = ''
                self.status = WAIT_MESSAGE

    @socket_exception
    def read(self):
        """Reads data from the socket and switches state when complete."""
        assert self.status in (WAIT_LEN, WAIT_MESSAGE)
        if self.status == WAIT_LEN:
            self._read_len()
            # go back to the main loop here for simplicity instead of
            # falling through, even though there is a good chance that
            # the message is already available
        elif self.status == WAIT_MESSAGE:
            read = self.socket.recv(self.len - len(self.message))
            if len(read) == 0:
                logging.error(
                    "can't read frame from socket (get %d of %d bytes)",
                    len(self.message), self.len)
                self.close()
                return
            self.message += read
            if len(self.message) == self.len:
                self.status = WAIT_PROCESS

    @socket_exception
    def write(self):
        """Writes pending answer bytes to the socket and switches state.

        When the whole answer has been sent the connection is reset to
        WAIT_LEN; otherwise the unsent tail is kept for the next call.
        """
        assert self.status == SEND_ANSWER
        sent = self.socket.send(self.message)
        if sent == len(self.message):
            self.status = WAIT_LEN
            self.message = ''
            self.len = 0
        else:
            self.message = self.message[sent:]

    @locked
    def ready(self, all_ok, message):
        """Callback for switching state and waking up the main thread.

        It is passed to the worker threads as the task callback, so it
        runs while holding self.lock.

        ready() can switch the connection to three states:
            WAIT_LEN if the request was oneway (empty answer).
            SEND_ANSWER if the request was processed in the normal way.
            CLOSED if the request raised an unexpected exception
                (all_ok is false).

        In every case wake_up() is called afterwards.
        """
        assert self.status == WAIT_PROCESS
        if not all_ok:
            self.close()
            self.wake_up()
            return
        self.len = 0
        if len(message) == 0:
            # it was a oneway request, do not write answer; the buffer has
            # to be empty, otherwise _read_len() would ask recv() for fewer
            # than 4 bytes and misread the next frame
            self.message = ''
            self.status = WAIT_LEN
        else:
            self.message = struct.pack('!i', len(message)) + message
            self.status = SEND_ANSWER
        self.wake_up()

    @locked
    def is_writeable(self):
        "Returns True if connection should be added to write list of select."
        return self.status == SEND_ANSWER

    # it's not necessary, but...
    @locked
    def is_readable(self):
        "Returns True if connection should be added to read list of select."
        return self.status in (WAIT_LEN, WAIT_MESSAGE)

    @locked
    def is_closed(self):
        "Returns True if connection is closed."
        return self.status == CLOSED

    def fileno(self):
        "Returns the file descriptor of the associated socket."
        return self.socket.fileno()

    def close(self):
        "Marks the connection as CLOSED and closes the socket."
        self.status = CLOSED
        self.socket.close()
195
class TNonblockingServer:
    """Non-blocking server.

    All socket reading and writing happens in the thread that calls
    handle(); complete requests are put on a task queue that is consumed
    by Worker threads.
    """

    def __init__(self, processor, lsocket, inputProtocolFactory=None,
                 outputProtocolFactory=None, threads=10):
        """Stores the configuration; no threads are started yet.

        lsocket is the listening socket object; this class uses its
        listen(), accept() and close() methods and its ``handle``
        attribute.  The input protocol factory defaults to
        TBinaryProtocolFactory() and the output factory defaults to the
        input one.
        """
        self.processor = processor
        self.socket = lsocket
        self.in_protocol = inputProtocolFactory or TBinaryProtocolFactory()
        self.out_protocol = outputProtocolFactory or self.in_protocol
        self.threads = int(threads)
        self.clients = {}
        self.tasks = Queue.Queue()
        self._read, self._write = socket.socketpair()
        self.prepared = False

    def setNumThreads(self, num):
        """Set the number of worker threads that should be created."""
        # implement ThreadPool interface
        assert not self.prepared, "You can't change number of threads for working server"
        # coerce like __init__ does
        self.threads = int(num)

    def prepare(self):
        """Prepares server for serve requests.

        Starts listening and starts the worker threads.  Calling it again
        on an already prepared server does nothing, so prepare() followed
        by serve() does not listen twice or start a second set of workers.
        """
        if self.prepared:
            return
        self.socket.listen()
        for _ in xrange(self.threads):
            thread = Worker(self.tasks)
            thread.setDaemon(True)
            thread.start()
        self.prepared = True

    def wake_up(self):
        """Wake up main thread.

        The server usually waits in the select call and we should
        interrupt it.  The simplest way is using a socketpair.

        Select always waits to read from the first socket of the
        socketpair, so it is enough to write anything to the second one.
        """
        self._write.send('1')

    def _select(self):
        """Does select on open connections.

        Closed connections are dropped from self.clients on the way.
        Returns the result of select.select().
        """
        readable = [self.socket.handle.fileno(), self._read.fileno()]
        writable = []
        # iterate over a copy because closed connections are deleted below
        for i, connection in list(self.clients.items()):
            if connection.is_readable():
                readable.append(connection.fileno())
            if connection.is_writeable():
                writable.append(connection.fileno())
            if connection.is_closed():
                del self.clients[i]
        return select.select(readable, writable, readable)

    def handle(self):
        """Handle requests.

        Does one select() round: accepts new clients, reads from and
        writes to the ready connections and queues every completely read
        request for the worker threads.

        WARNING! You must call prepare BEFORE calling handle.
        """
        assert self.prepared, "You have to call prepare before handle"
        rset, wset, xset = self._select()
        for readable in rset:
            if readable == self._read.fileno():
                # don't care i just need to clean readable flag
                self._read.recv(1024)
            elif readable == self.socket.handle.fileno():
                client = self.socket.accept().handle
                self.clients[client.fileno()] = Connection(client, self.wake_up)
            else:
                connection = self.clients[readable]
                connection.read()
                if connection.status == WAIT_PROCESS:
                    itransport = TTransport.TMemoryBuffer(connection.message)
                    otransport = TTransport.TMemoryBuffer()
                    iprot = self.in_protocol.getProtocol(itransport)
                    oprot = self.out_protocol.getProtocol(otransport)
                    self.tasks.put([self.processor, iprot, oprot,
                                    otransport, connection.ready])
        for writeable in wset:
            self.clients[writeable].write()
        for oob in xset:
            # The exceptional set is built from the same list as the read
            # set, so it may report the listening or the wake-up socket,
            # which are not in self.clients.
            connection = self.clients.pop(oob, None)
            if connection is not None:
                connection.close()

    def close(self):
        """Closes the server.

        Queues one stop marker per worker thread and closes the listening
        socket.
        """
        for _ in xrange(self.threads):
            self.tasks.put([None, None, None, None, None])
        self.socket.close()
        self.prepared = False

    def serve(self):
        """Serve forever."""
        self.prepare()
        while True:
            self.handle()
291 self.handle()