blob: c8498c51eaa0da670c52443275f7c79449ac5663 [file] [log] [blame]
Chris Piro20c81ad2013-03-07 11:32:48 -05001#
2# Licensed to the Apache Software Foundation (ASF) under one
3# or more contributor license agreements. See the NOTICE file
4# distributed with this work for additional information
5# regarding copyright ownership. The ASF licenses this file
6# to you under the Apache License, Version 2.0 (the
7# "License"); you may not use this file except in compliance
8# with the License. You may obtain a copy of the License at
9#
10# http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing,
13# software distributed under the License is distributed on an
14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15# KIND, either express or implied. See the License for the
16# specific language governing permissions and limitations
17# under the License.
18#
19
Roger Meierd52edba2014-08-07 17:03:47 +020020from __future__ import absolute_import
Chris Piro20c81ad2013-03-07 11:32:48 -050021import logging
22import socket
23import struct
24
Roger Meierd52edba2014-08-07 17:03:47 +020025from thrift.transport.TTransport import TTransportException, TTransportBase, TMemoryBuffer
Chris Piro20c81ad2013-03-07 11:32:48 -050026
Roger Meierd52edba2014-08-07 17:03:47 +020027from io import BytesIO
28from collections import deque
29from contextlib import contextmanager
30from tornado import gen, iostream, ioloop, tcpserver, concurrent
31
32__all__ = ['TTornadoServer', 'TTornadoStreamTransport']
Chris Piro20c81ad2013-03-07 11:32:48 -050033
34
class _Lock(object):
    """First-come, first-served lock for coroutines.

    Every ``acquire()`` call appends its own future to a queue and waits on
    the future that was queued immediately before it (if any).  ``release()``
    resolves the oldest queued future, which wakes the next acquirer.
    """

    def __init__(self):
        # Futures of the current holder (leftmost) and of everyone queued
        # behind it, in arrival order.
        self._waiters = deque()

    def acquired(self):
        """Return True while at least one holder or waiter is queued."""
        return len(self._waiters) > 0

    @gen.coroutine
    def acquire(self):
        """Wait for the lock and return a context manager that releases it.

        Intended use: ``with (yield lock.acquire()): ...``
        """
        # Remember who is ahead of us *before* queueing our own future.
        predecessor = None
        if self.acquired():
            predecessor = self._waiters[-1]

        ticket = concurrent.Future()
        self._waiters.append(ticket)

        if predecessor is not None:
            # Resolved when the caller ahead of us calls release().
            yield predecessor

        raise gen.Return(self._lock_context())

    def release(self):
        """Resolve the oldest queued future, letting the next waiter proceed."""
        assert self.acquired(), 'Lock not aquired'
        self._waiters.popleft().set_result(None)

    @contextmanager
    def _lock_context(self):
        """Context manager that calls ``release()`` on exit, error or not."""
        try:
            yield
        finally:
            self.release()
63
64
class TTornadoStreamTransport(TTransportBase):
    """A framed, buffered transport over a Tornado stream.

    Outgoing bytes are collected by ``write()`` and sent by ``flush()`` as a
    single frame: a 4-byte network-order signed length followed by the
    payload.  Incoming data is consumed one whole frame at a time through
    ``readFrame()``; byte-level ``read()`` is deliberately unsupported.
    """

    def __init__(self, host, port, stream=None, io_loop=None):
        """Remember the peer address and, optionally, a ready-made stream.

        Args:
            host: peer host; used by ``open()`` and in error messages.
            port: peer port; used by ``open()`` and in error messages.
            stream: an existing stream to use (servers provide one).  When
                left as ``None`` the caller must ``open()`` the transport.
            io_loop: loop to use; defaults to ``ioloop.IOLoop.current()``.
        """
        self.host = host
        self.port = port
        self.io_loop = io_loop or ioloop.IOLoop.current()
        self.__wbuf = BytesIO()
        self._read_lock = _Lock()

        # servers provide a ready-to-go stream
        self.stream = stream

    def with_timeout(self, timeout, future):
        """Wrap ``future`` with ``gen.with_timeout`` on this transport's loop.

        NOTE(review): ``self.io_loop`` is passed as the third positional
        argument -- confirm the targeted Tornado release still accepts an
        ``io_loop`` in that position.
        """
        return gen.with_timeout(timeout, future, self.io_loop)

    @gen.coroutine
    def open(self, timeout=None):
        """Connect a new IPv4 TCP stream to ``(self.host, self.port)``.

        Args:
            timeout: optional limit handed to ``with_timeout``; ``None``
                waits for the connect without a limit.

        Returns:
            (via ``gen.Return``) this transport.

        Raises:
            TTransportException: type ``NOT_OPEN`` when the connect fails or
                times out.  The partially opened stream is closed first.
        """
        logging.debug('socket connecting')
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
        self.stream = iostream.IOStream(sock)

        try:
            connect = self.stream.connect((self.host, self.port))
            if timeout is not None:
                yield self.with_timeout(timeout, connect)
            else:
                yield connect
        except (socket.error, IOError,
                ioloop.TimeoutError, gen.TimeoutError) as e:
            # gen.with_timeout reports expiry with gen.TimeoutError, which is
            # not the same class as ioloop.TimeoutError on every Tornado
            # release, so both are listed.
            # Release the socket: after a timeout the connect attempt would
            # otherwise stay pending with the descriptor still open.
            self.stream.close()
            message = 'could not connect to {}:{} ({})'.format(self.host, self.port, e)
            raise TTransportException(
                type=TTransportException.NOT_OPEN,
                message=message)

        raise gen.Return(self)

    def set_close_callback(self, callback):
        """
        Should be called only after open() returns
        """
        self.stream.set_close_callback(callback)

    def close(self):
        """Close the underlying stream, if any, without firing the callback."""
        if self.stream is None:
            # Never opened and no stream supplied: nothing to close.
            return
        # don't raise if we intend to close
        self.stream.set_close_callback(None)
        self.stream.close()

    def read(self, _):
        """Unsupported: always raises ``AssertionError``."""
        # The generated code for Tornado shouldn't do individual reads -- only
        # frames at a time.  Raised explicitly (rather than ``assert False``)
        # so the guard is still active when Python runs with -O.
        raise AssertionError("you're doing it wrong")

    @contextmanager
    def io_exception_context(self):
        """Translate low-level I/O errors into ``TTransportException``.

        ``socket.error`` / ``IOError`` become type ``END_OF_FILE``;
        ``iostream.StreamBufferFullError`` becomes type ``UNKNOWN``.
        """
        try:
            yield
        except (socket.error, IOError) as e:
            raise TTransportException(
                type=TTransportException.END_OF_FILE,
                message=str(e))
        except iostream.StreamBufferFullError as e:
            raise TTransportException(
                type=TTransportException.UNKNOWN,
                message=str(e))

    @gen.coroutine
    def readFrame(self):
        """Read one complete frame from the stream.

        Returns:
            (via ``gen.Return``) the frame payload, without its length header.

        Raises:
            TTransportException: on I/O failure (see
                ``io_exception_context``) or when the header announces a
                negative frame length.
        """
        # IOStream processes reads one at a time
        with (yield self._read_lock.acquire()):
            with self.io_exception_context():
                frame_header = yield self.stream.read_bytes(4)
                if len(frame_header) == 0:
                    raise iostream.StreamClosedError('Read zero bytes from stream')
                frame_length, = struct.unpack('!i', frame_header)
                logging.debug('received frame header, frame length = %d', frame_length)
                if frame_length < 0:
                    # The header is a *signed* int taken from the wire; a
                    # negative count must never reach read_bytes().
                    raise TTransportException(
                        type=TTransportException.UNKNOWN,
                        message='negative frame length: %d' % frame_length)
                frame = yield self.stream.read_bytes(frame_length)
                logging.debug('received frame payload: %r', frame)
                raise gen.Return(frame)

    def write(self, buf):
        """Append ``buf`` to the pending frame; nothing is sent until flush()."""
        self.__wbuf.write(buf)

    def flush(self):
        """Send everything written so far as one length-prefixed frame.

        Returns:
            whatever ``self.stream.write`` returns for the framed bytes.
        """
        frame = self.__wbuf.getvalue()
        # reset wbuf before write/flush to preserve state on underlying failure
        frame_length = struct.pack('!i', len(frame))
        self.__wbuf = BytesIO()
        with self.io_exception_context():
            return self.stream.write(frame_length + frame)
Chris Piro20c81ad2013-03-07 11:32:48 -0500153
154
class TTornadoServer(tcpserver.TCPServer):
    """Tornado TCP server that feeds framed Thrift requests to a processor."""

    def __init__(self, processor, iprot_factory, oprot_factory=None,
                 *args, **kwargs):
        """Set up the server.

        Args:
            processor: object whose ``process(iprot, oprot)`` handles one
                request; its result is yielded by ``handle_stream``.
            iprot_factory: factory whose ``getProtocol`` builds the input
                protocol for each received frame.
            oprot_factory: factory for the output protocol; falls back to
                ``iprot_factory`` when ``None``.
            *args, **kwargs: forwarded to ``tcpserver.TCPServer``.
        """
        super(TTornadoServer, self).__init__(*args, **kwargs)

        self._processor = processor
        self._iprot_factory = iprot_factory
        self._oprot_factory = (oprot_factory if oprot_factory is not None
                               else iprot_factory)

    @gen.coroutine
    def handle_stream(self, stream, address):
        """Serve one client connection until its stream closes or an error occurs.

        Args:
            stream: the connected stream for this client.
            address: the peer address; only its first two items (host, port)
                are used.
        """
        # IPv6 peers are reported as (host, port, flowinfo, scope_id), so
        # unpacking the whole tuple into two names would raise ValueError.
        host, port = address[:2]
        trans = TTornadoStreamTransport(host=host, port=port, stream=stream,
                                        io_loop=self.io_loop)
        oprot = self._oprot_factory.getProtocol(trans)

        try:
            while not trans.stream.closed():
                frame = yield trans.readFrame()
                # Each frame gets its own in-memory input transport/protocol;
                # replies all go out through the shared stream transport.
                tr = TMemoryBuffer(frame)
                iprot = self._iprot_factory.getProtocol(tr)
                yield self._processor.process(iprot, oprot)
        except Exception:
            logging.exception('thrift exception in handle_stream')
            trans.close()

        logging.info('client disconnected %s:%d', host, port)