blob: ef3e0f21ce7a0caadc4ed9bb3c9d88ea1332a240 [file] [log] [blame]
Chris Piro20c81ad2013-03-07 11:32:48 -05001#
2# Licensed to the Apache Software Foundation (ASF) under one
3# or more contributor license agreements. See the NOTICE file
4# distributed with this work for additional information
5# regarding copyright ownership. The ASF licenses this file
6# to you under the Apache License, Version 2.0 (the
7# "License"); you may not use this file except in compliance
8# with the License. You may obtain a copy of the License at
9#
10# http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing,
13# software distributed under the License is distributed on an
14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15# KIND, either express or implied. See the License for the
16# specific language governing permissions and limitations
17# under the License.
18#
19
Roger Meierd52edba2014-08-07 17:03:47 +020020from __future__ import absolute_import
Chris Piro20c81ad2013-03-07 11:32:48 -050021import socket
22import struct
23
Konrad Grochowski3a724e32014-08-12 11:48:29 -040024import logging
25logger = logging.getLogger(__name__)
26
Roger Meierd52edba2014-08-07 17:03:47 +020027from thrift.transport.TTransport import TTransportException, TTransportBase, TMemoryBuffer
Chris Piro20c81ad2013-03-07 11:32:48 -050028
Roger Meierd52edba2014-08-07 17:03:47 +020029from io import BytesIO
30from collections import deque
31from contextlib import contextmanager
32from tornado import gen, iostream, ioloop, tcpserver, concurrent
33
34__all__ = ['TTornadoServer', 'TTornadoStreamTransport']
Chris Piro20c81ad2013-03-07 11:32:48 -050035
36
class _Lock(object):
    """A minimal FIFO lock for Tornado coroutines.

    Every holder/waiter owns one future in ``_waiters``.  A new waiter blocks
    on the future of whoever queued immediately before it; ``release()``
    resolves the oldest future, which wakes the next waiter in line.
    """

    def __init__(self):
        # One future per holder/waiter; the leftmost belongs to the holder.
        self._waiters = deque()

    def acquired(self):
        """Return True if the lock is currently held."""
        return len(self._waiters) > 0

    @gen.coroutine
    def acquire(self):
        """Wait for the lock; resolve to a context manager that releases it.

        Intended usage::

            with (yield lock.acquire()):
                ...
        """
        # Block on whoever queued just before us (if anyone); our own future
        # is what the next waiter will block on.
        blocker = self._waiters[-1] if self.acquired() else None
        future = concurrent.Future()
        self._waiters.append(future)
        if blocker is not None:
            yield blocker

        raise gen.Return(self._lock_context())

    def release(self):
        """Release the lock and wake the next queued waiter, if any.

        Raises:
            RuntimeError: if the lock is not held.
        """
        # Explicit raise instead of ``assert`` so the check survives ``-O``.
        if not self.acquired():
            raise RuntimeError('Lock not acquired')
        future = self._waiters.popleft()
        future.set_result(None)

    @contextmanager
    def _lock_context(self):
        """Yield once, then release the lock even if the body raised."""
        try:
            yield
        finally:
            self.release()
65
66
class TTornadoStreamTransport(TTransportBase):
    """a framed, buffered transport over a Tornado stream

    Each frame on the wire is a 4-byte big-endian signed length followed by
    that many payload bytes.  Writes are buffered until ``flush()``.
    """
    def __init__(self, host, port, stream=None, io_loop=None):
        self.host = host
        self.port = port
        self.io_loop = io_loop or ioloop.IOLoop.current()
        self.__wbuf = BytesIO()
        # serializes readFrame() calls on the shared stream
        self._read_lock = _Lock()

        # servers provide a ready-to-go stream
        self.stream = stream

    def with_timeout(self, timeout, future):
        """Wrap *future* with ``gen.with_timeout`` on this transport's IOLoop."""
        return gen.with_timeout(timeout, future, self.io_loop)

    @gen.coroutine
    def open(self, timeout=None):
        """Connect a new IPv4 TCP stream to ``(host, port)``.

        Resolves to ``self`` on success.

        Raises:
            TTransportException: NOT_OPEN if the connect fails or times out.
        """
        logger.debug('socket connecting')
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
        self.stream = iostream.IOStream(sock)

        try:
            connect = self.stream.connect((self.host, self.port))
            if timeout is not None:
                yield self.with_timeout(timeout, connect)
            else:
                yield connect
        # Before Tornado 5.0, with_timeout raised gen.TimeoutError, a class
        # distinct from ioloop.TimeoutError; catch both so a timeout always
        # surfaces as a TTransportException.
        except (socket.error, IOError,
                gen.TimeoutError, ioloop.TimeoutError) as e:
            # Release the socket so a failed/timed-out connect doesn't leak it.
            self.stream.close()
            message = 'could not connect to {}:{} ({})'.format(self.host, self.port, e)
            raise TTransportException(
                type=TTransportException.NOT_OPEN,
                message=message)

        raise gen.Return(self)

    def set_close_callback(self, callback):
        """
        Should be called only after open() returns
        """
        self.stream.set_close_callback(callback)

    def close(self):
        """Close the underlying stream; a no-op if it was never opened."""
        if self.stream is None:
            return
        # don't raise if we intend to close
        self.stream.set_close_callback(None)
        self.stream.close()

    def read(self, _):
        # The generated code for Tornado shouldn't do individual reads -- only
        # frames at a time.  Raise explicitly so this still fails under -O.
        raise AssertionError("you're doing it wrong")

    @contextmanager
    def io_exception_context(self):
        """Translate stream-level errors into TTransportException.

        socket.error/IOError become END_OF_FILE; a full read buffer becomes
        UNKNOWN.
        """
        try:
            yield
        except (socket.error, IOError) as e:
            raise TTransportException(
                type=TTransportException.END_OF_FILE,
                message=str(e))
        except iostream.StreamBufferFullError as e:
            raise TTransportException(
                type=TTransportException.UNKNOWN,
                message=str(e))

    @gen.coroutine
    def readFrame(self):
        """Read one length-prefixed frame and resolve to its payload bytes.

        Raises:
            TTransportException: END_OF_FILE if the stream fails or closes;
                UNKNOWN if the read buffer overflows or the header announces
                a negative length.
        """
        # IOStream processes reads one at a time
        with (yield self._read_lock.acquire()):
            with self.io_exception_context():
                frame_header = yield self.stream.read_bytes(4)
                if len(frame_header) == 0:
                    raise iostream.StreamClosedError('Read zero bytes from stream')
                frame_length, = struct.unpack('!i', frame_header)
                # A corrupt or hostile peer can send a negative signed length;
                # never hand that to read_bytes().
                if frame_length < 0:
                    raise TTransportException(
                        type=TTransportException.UNKNOWN,
                        message='Invalid negative frame length: %d' % frame_length)
                frame = yield self.stream.read_bytes(frame_length)
                raise gen.Return(frame)

    def write(self, buf):
        """Append *buf* to the pending outgoing frame."""
        self.__wbuf.write(buf)

    def flush(self):
        """Send the buffered bytes as one length-prefixed frame.

        Returns whatever ``IOStream.write`` returns (a Future on Tornado 4+).
        """
        frame = self.__wbuf.getvalue()
        # reset wbuf before write/flush to preserve state on underlying failure
        frame_length = struct.pack('!i', len(frame))
        self.__wbuf = BytesIO()
        with self.io_exception_context():
            return self.stream.write(frame_length + frame)
Chris Piro20c81ad2013-03-07 11:32:48 -0500153
154
class TTornadoServer(tcpserver.TCPServer):
    """Serve a Thrift processor over framed connections accepted by Tornado.

    Extra positional/keyword arguments are forwarded to ``TCPServer``.
    """

    def __init__(self, processor, iprot_factory, oprot_factory=None,
                 *args, **kwargs):
        super(TTornadoServer, self).__init__(*args, **kwargs)

        self._processor = processor
        self._iprot_factory = iprot_factory
        # reuse the input protocol factory for output unless told otherwise
        self._oprot_factory = (oprot_factory if oprot_factory is not None
                               else iprot_factory)

    @gen.coroutine
    def handle_stream(self, stream, address):
        """Read frames from *stream* and dispatch each to the processor
        until the client disconnects or an error occurs."""
        # IPv6 peers report (host, port, flowinfo, scope_id); keep host/port.
        host, port = address[:2]
        # Newer Tornado TCPServer objects have no io_loop attribute; None
        # makes the transport fall back to IOLoop.current().
        trans = TTornadoStreamTransport(host=host, port=port, stream=stream,
                                        io_loop=getattr(self, 'io_loop', None))
        oprot = self._oprot_factory.getProtocol(trans)

        try:
            while not trans.stream.closed():
                try:
                    frame = yield trans.readFrame()
                except TTransportException as e:
                    # A client hanging up is a normal disconnect, not an
                    # error that deserves a logged stack trace.
                    if e.type == TTransportException.END_OF_FILE:
                        break
                    raise
                tr = TMemoryBuffer(frame)
                iprot = self._iprot_factory.getProtocol(tr)
                yield self._processor.process(iprot, oprot)
        except Exception:
            logger.exception('thrift exception in handle_stream')
            trans.close()

        logger.info('client disconnected %s:%d', host, port)