blob: deec708a1fd0ff971338e996e22988a91c77db85 [file] [log] [blame]
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
"""Implementation of a non-blocking server.

The main idea of the server is that requests are received and
sent only from the main thread.

It also implements a thread pool server that works in terms of
tasks (individual requests), not connections.
"""
import functools
import logging
import Queue
import select
import socket
import struct
import threading

from thrift.protocol.TBinaryProtocol import TBinaryProtocolFactory
from thrift.transport import TTransport
35
# Names exported by "from ... import *": only the server class.
__all__ = ['TNonblockingServer']
37
class Worker(threading.Thread):
    """Worker thread that processes requests taken from a task queue."""

    def __init__(self, queue):
        threading.Thread.__init__(self)
        self.queue = queue

    def run(self):
        """Process tasks from the queue; stop when a task's processor is None.

        Each task is a sequence (processor, iprot, oprot, otrans, callback).
        The callback is invoked exactly once per task:
            callback(True, otrans.getvalue()) if processing succeeded,
            callback(False, '') if processing raised.
        Exceptions from processing or from the callback are logged and never
        terminate the thread, so the worker pool does not silently shrink.
        """
        while True:
            processor, iprot, oprot, otrans, callback = self.queue.get()
            if processor is None:
                break
            try:
                processor.process(iprot, oprot)
                all_ok, reply = True, otrans.getvalue()
            except Exception:
                logging.exception("Exception while processing request")
                all_ok, reply = False, ''
            # Deliver the result separately so that a failing callback is not
            # invoked a second time with (False, '').
            try:
                callback(all_ok, reply)
            except Exception:
                logging.exception("Exception while delivering response")
56
# States of a Connection's request/response cycle.
(
    WAIT_LEN,      # reading the 4-byte frame length
    WAIT_MESSAGE,  # reading the frame body
    WAIT_PROCESS,  # whole frame read; waiting for a worker to call ready()
    SEND_ANSWER,   # sending the length-prefixed answer
    CLOSED,        # socket closed; connection should be deleted
) = range(5)
62
def locked(func):
    """Decorator which runs the wrapped method while holding self.lock.

    The lock is released even if the method raises. functools.wraps keeps
    the wrapped method's name and docstring for introspection and logging.
    """
    @functools.wraps(func)
    def nested(self, *args, **kwargs):
        with self.lock:
            return func(self, *args, **kwargs)
    return nested
72
def socket_exception(func):
    """Decorator which calls self.close() when the method raises socket.error.

    The socket.error is swallowed and the wrapped method returns None.
    Any other exception propagates unchanged.
    """
    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        try:
            return func(self, *args, **kwargs)
        except socket.error:
            self.close()
            return None
    return wrapper
81
class Connection:
    """A single client connection, driven as a small state machine.

    It can be in state:
        WAIT_LEN --- connection is reading the request length.
        WAIT_MESSAGE --- connection is reading the request body.
        WAIT_PROCESS --- connection has read the whole request and is
                         waiting for ready() to be called.
        SEND_ANSWER --- connection is sending the answer string (including
                        the 4-byte length prefix).
        CLOSED --- socket was closed and the connection should be deleted.

    ready() is the only method meant to be called asynchronously (from a
    worker thread); the status queries are locked because of it.
    """

    # Upper bound on the size passed to a single recv() call. The frame
    # length comes from the client, and socket.recv(n) may allocate an
    # n-byte buffer up front, so it is not passed through unbounded.
    _RECV_CHUNK = 1024 * 1024

    def __init__(self, new_socket, wake_up):
        self.socket = new_socket
        self.socket.setblocking(False)
        self.status = WAIT_LEN
        self.len = 0
        self.message = b''
        self.lock = threading.Lock()
        self.wake_up = wake_up

    def _read_len(self):
        """Read (part of) the 4-byte big-endian frame length.

        Partial reads are accumulated in self.message until all four bytes
        have arrived. This is deliberately paranoid; a plain
        self.socket.recv(4) would usually suffice.
        """
        read = self.socket.recv(4 - len(self.message))
        if len(read) == 0:
            # If we read 0 bytes and self.message is empty, the client closed
            # the connection; otherwise the length prefix was truncated.
            if len(self.message) != 0:
                logging.error("can't read frame size from socket")
            self.close()
            return
        self.message += read
        if len(self.message) == 4:
            self.len, = struct.unpack('!i', self.message)
            if self.len < 0:
                logging.error("negative frame size, it seems client"
                              " doesn't use FramedTransport")
                self.close()
            elif self.len == 0:
                logging.error("empty frame, it's really strange")
                self.close()
            else:
                self.message = b''
                self.status = WAIT_MESSAGE

    @socket_exception
    def read(self):
        """Read available data from the socket and switch state."""
        assert self.status in (WAIT_LEN, WAIT_MESSAGE)
        if self.status == WAIT_LEN:
            self._read_len()
            # go back to the main loop here for simplicity instead of
            # falling through, even though there is a good chance that
            # the message is already available
        elif self.status == WAIT_MESSAGE:
            wanted = min(self.len - len(self.message), self._RECV_CHUNK)
            read = self.socket.recv(wanted)
            if len(read) == 0:
                logging.error("can't read frame from socket (get %d of %d bytes)" %
                              (len(self.message), self.len))
                self.close()
                return
            self.message += read
            if len(self.message) == self.len:
                self.status = WAIT_PROCESS

    @socket_exception
    def write(self):
        """Write pending answer data to the socket and switch state."""
        assert self.status == SEND_ANSWER
        sent = self.socket.send(self.message)
        if sent == len(self.message):
            self.status = WAIT_LEN
            self.message = b''
            self.len = 0
        else:
            self.message = self.message[sent:]

    @locked
    def ready(self, all_ok, message):
        """Callback that switches state and wakes up the main thread.

        This is the only method that can be called asynchronously.

        ready() switches the Connection to one of three states:
            WAIT_LEN if the request was oneway (empty answer).
            SEND_ANSWER if the request was processed in the normal way.
            CLOSED if processing threw an unexpected exception.

        In every case it wakes up the main thread.
        """
        assert self.status == WAIT_PROCESS
        if not all_ok:
            self.close()
            self.wake_up()
            return
        self.len = 0
        if len(message) == 0:
            # It was a oneway request: do not write an answer. The buffer
            # must be left empty, because _read_len() uses len(self.message)
            # to track how many length bytes of the next frame it has read.
            self.message = b''
            self.status = WAIT_LEN
        else:
            self.message = struct.pack('!i', len(message)) + message
            self.status = SEND_ANSWER
        self.wake_up()

    @locked
    def is_writeable(self):
        "Returns True if connection should be added to write list of select."
        return self.status == SEND_ANSWER

    # Locked like the other status queries, since ready() may change status
    # from a worker thread.
    @locked
    def is_readable(self):
        "Returns True if connection should be added to read list of select."
        return self.status in (WAIT_LEN, WAIT_MESSAGE)

    @locked
    def is_closed(self):
        "Returns True if connection is closed."
        return self.status == CLOSED

    def fileno(self):
        "Returns the file descriptor of the associated socket."
        return self.socket.fileno()

    def close(self):
        "Marks the connection CLOSED and closes its socket."
        self.status = CLOSED
        self.socket.close()
213
class TNonblockingServer:
    """Non-blocking server.

    The main thread multiplexes the listening socket and all client
    connections with select(); each complete request is put on a task
    queue consumed by a pool of Worker threads.
    """
    def __init__(self, processor, lsocket, inputProtocolFactory=None,
                 outputProtocolFactory=None, threads=10):
        self.processor = processor
        self.socket = lsocket
        self.in_protocol = inputProtocolFactory or TBinaryProtocolFactory()
        self.out_protocol = outputProtocolFactory or self.in_protocol
        self.threads = int(threads)
        self.clients = {}
        self.tasks = Queue.Queue()
        # Socket pair used by wake_up() to interrupt a blocking select().
        self._read, self._write = socket.socketpair()
        self.prepared = False

    def setNumThreads(self, num):
        """Set the number of worker threads that should be created."""
        # implement ThreadPool interface
        assert not self.prepared, "You can't change number of threads for working server"
        # Coerce like __init__ does, so prepare()/close() can range() over it.
        self.threads = int(num)

    def prepare(self):
        """Prepare the server to serve requests.

        Starts listening on the server socket and starts self.threads daemon
        Worker threads that share the task queue.
        """
        self.socket.listen()
        for _ in range(self.threads):
            thread = Worker(self.tasks)
            thread.daemon = True
            thread.start()
        self.prepared = True

    def wake_up(self):
        """Wake up the main thread.

        The server usually waits in a select() call, which must be
        interrupted when a worker changes a connection's state. The simplest
        way is a socketpair: select() always waits to read from the first
        socket of the pair, so writing anything to the second socket
        wakes it up.
        """
        self._write.send(b'1')

    def _select(self):
        """Run select() on the server socket, wake-up socket and clients.

        Closed connections are removed from self.clients here.
        """
        readable = [self.socket.handle.fileno(), self._read.fileno()]
        writable = []
        # Iterate over a snapshot because closed connections are deleted
        # from self.clients inside the loop.
        for fd, connection in list(self.clients.items()):
            if connection.is_readable():
                readable.append(connection.fileno())
            if connection.is_writeable():
                writable.append(connection.fileno())
            if connection.is_closed():
                del self.clients[fd]
        return select.select(readable, writable, readable)

    def handle(self):
        """Handle one round of socket events.

        WARNING! You must call prepare BEFORE calling handle.
        """
        assert self.prepared, "You have to call prepare before handle"
        rset, wset, xset = self._select()
        for readable in rset:
            if readable == self._read.fileno():
                # don't care i just need to clean readable flag
                self._read.recv(1024)
            elif readable == self.socket.handle.fileno():
                client = self.socket.accept().handle
                self.clients[client.fileno()] = Connection(client, self.wake_up)
            else:
                connection = self.clients[readable]
                connection.read()
                if connection.status == WAIT_PROCESS:
                    itransport = TTransport.TMemoryBuffer(connection.message)
                    otransport = TTransport.TMemoryBuffer()
                    iprot = self.in_protocol.getProtocol(itransport)
                    oprot = self.out_protocol.getProtocol(otransport)
                    self.tasks.put([self.processor, iprot, oprot,
                                    otransport, connection.ready])
        for writeable in wset:
            self.clients[writeable].write()
        for oob in xset:
            # The exceptional set is built from the read list, which also
            # contains the listening and wake-up sockets; those are not
            # client connections and must not be looked up in self.clients.
            connection = self.clients.pop(oob, None)
            if connection is not None:
                connection.close()

    def close(self):
        """Close the server.

        Puts one stop sentinel per worker thread on the task queue and closes
        the listening socket. Client connections are not closed here.
        """
        for _ in range(self.threads):
            self.tasks.put([None, None, None, None, None])
        self.socket.close()
        self.prepared = False

    def serve(self):
        """Serve forever."""
        self.prepare()
        while True:
            self.handle()