Chris Piro | 20c81ad | 2013-03-07 11:32:48 -0500 | [diff] [blame] | 1 | # |
| 2 | # Licensed to the Apache Software Foundation (ASF) under one |
| 3 | # or more contributor license agreements. See the NOTICE file |
| 4 | # distributed with this work for additional information |
| 5 | # regarding copyright ownership. The ASF licenses this file |
| 6 | # to you under the Apache License, Version 2.0 (the |
| 7 | # "License"); you may not use this file except in compliance |
| 8 | # with the License. You may obtain a copy of the License at |
| 9 | # |
| 10 | # http://www.apache.org/licenses/LICENSE-2.0 |
| 11 | # |
| 12 | # Unless required by applicable law or agreed to in writing, |
| 13 | # software distributed under the License is distributed on an |
| 14 | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| 15 | # KIND, either express or implied. See the License for the |
| 16 | # specific language governing permissions and limitations |
| 17 | # under the License. |
| 18 | # |
| 19 | |
Roger Meier | d52edba | 2014-08-07 17:03:47 +0200 | [diff] [blame] | 20 | from __future__ import absolute_import |
Chris Piro | 20c81ad | 2013-03-07 11:32:48 -0500 | [diff] [blame] | 21 | import socket |
| 22 | import struct |
Konrad Grochowski | 3a724e3 | 2014-08-12 11:48:29 -0400 | [diff] [blame] | 23 | import logging |
| 24 | logger = logging.getLogger(__name__) |
| 25 | |
Nobuaki Sukegawa | 760511f | 2015-11-06 21:24:16 +0900 | [diff] [blame] | 26 | from .transport.TTransport import TTransportException, TTransportBase, TMemoryBuffer |
Chris Piro | 20c81ad | 2013-03-07 11:32:48 -0500 | [diff] [blame] | 27 | |
Roger Meier | d52edba | 2014-08-07 17:03:47 +0200 | [diff] [blame] | 28 | from io import BytesIO |
| 29 | from collections import deque |
| 30 | from contextlib import contextmanager |
| 31 | from tornado import gen, iostream, ioloop, tcpserver, concurrent |
| 32 | |
| 33 | __all__ = ['TTornadoServer', 'TTornadoStreamTransport'] |
Chris Piro | 20c81ad | 2013-03-07 11:32:48 -0500 | [diff] [blame] | 34 | |
| 35 | |
class _Lock(object):
    """FIFO lock for Tornado coroutines.

    Every caller of ``acquire()`` enqueues its own future and then waits on
    the future enqueued just before it, so holders are served in arrival
    order.  ``acquire()`` resolves to a context manager whose exit releases
    the lock.
    """

    def __init__(self):
        # One future per holder/waiter; the leftmost belongs to the holder.
        self._waiters = deque()

    def acquired(self):
        """Return True while any coroutine holds or is queued for the lock."""
        return bool(self._waiters)

    @gen.coroutine
    def acquire(self):
        """Queue for the lock and resolve once the previous entry releases."""
        previous = None
        if self.acquired():
            previous = self._waiters[-1]
        self._waiters.append(concurrent.Future())
        if previous:
            # Wait until whoever queued right before us calls release().
            yield previous

        raise gen.Return(self._lock_context())

    def release(self):
        """Drop the oldest queued future and resolve it, waking its successor."""
        assert self.acquired(), 'Lock not aquired'
        self._waiters.popleft().set_result(None)

    @contextmanager
    def _lock_context(self):
        # Release on exit from the ``with`` block, even if the body raised.
        try:
            yield
        finally:
            self.release()
| 64 | |
| 65 | |
class TTornadoStreamTransport(TTransportBase):
    """a framed, buffered transport over a Tornado stream

    Each frame on the wire is a 4-byte big-endian signed length followed by
    that many payload bytes.  ``write()`` only buffers in memory; ``flush()``
    sends the buffered bytes as a single frame.
    """
    def __init__(self, host, port, stream=None, io_loop=None):
        self.host = host
        self.port = port
        self.io_loop = io_loop or ioloop.IOLoop.current()
        self.__wbuf = BytesIO()
        # Serializes readFrame() calls (see the comment there).
        self._read_lock = _Lock()

        # servers provide a ready-to-go stream
        self.stream = stream

    def with_timeout(self, timeout, future):
        """Wrap ``future`` in ``gen.with_timeout`` on this transport's loop."""
        return gen.with_timeout(timeout, future, self.io_loop)

    @gen.coroutine
    def open(self, timeout=None):
        """Connect a new IPv4 TCP IOStream to (host, port); resolve to self.

        :param timeout: optional deadline passed to ``gen.with_timeout``.
        :raises TTransportException: NOT_OPEN if connecting fails or times out.
        """
        logger.debug('socket connecting')
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
        self.stream = iostream.IOStream(sock)

        try:
            connect = self.stream.connect((self.host, self.port))
            if timeout is not None:
                yield self.with_timeout(timeout, connect)
            else:
                yield connect
        except (socket.error, IOError, ioloop.TimeoutError) as e:
            # After a timeout the connect attempt is still pending; close the
            # stream so the socket is not leaked (close() is a no-op if the
            # stream already closed itself on a connect error).
            self.stream.close()
            message = 'could not connect to {}:{} ({})'.format(self.host, self.port, e)
            raise TTransportException(
                type=TTransportException.NOT_OPEN,
                message=message)

        raise gen.Return(self)

    def set_close_callback(self, callback):
        """
        Should be called only after open() returns
        """
        self.stream.set_close_callback(callback)

    def close(self):
        """Close the underlying stream without firing its close callback."""
        if self.stream is None:
            # Never opened: nothing to close.
            return
        # don't raise if we intend to close
        self.stream.set_close_callback(None)
        self.stream.close()

    def read(self, _):
        # The generated code for Tornado shouldn't do individual reads -- only
        # frames at a time.  Raise explicitly (not ``assert``) so this still
        # fails when Python runs with -O.
        raise AssertionError("you're doing it wrong")

    @contextmanager
    def io_exception_context(self):
        """Translate stream errors raised inside the block.

        socket/IO errors (including Tornado's StreamClosedError) become
        TTransportException(END_OF_FILE); StreamBufferFullError becomes
        TTransportException(UNKNOWN).
        """
        try:
            yield
        except (socket.error, IOError) as e:
            raise TTransportException(
                type=TTransportException.END_OF_FILE,
                message=str(e))
        except iostream.StreamBufferFullError as e:
            raise TTransportException(
                type=TTransportException.UNKNOWN,
                message=str(e))

    @gen.coroutine
    def readFrame(self):
        """Read one length-prefixed frame; resolve to its payload bytes.

        :raises TTransportException: END_OF_FILE when the stream closes,
            UNKNOWN for a negative frame length or an overfull read buffer.
        """
        # IOStream processes reads one at a time
        with (yield self._read_lock.acquire()):
            with self.io_exception_context():
                frame_header = yield self.stream.read_bytes(4)
                if len(frame_header) == 0:
                    raise iostream.StreamClosedError('Read zero bytes from stream')
                frame_length, = struct.unpack('!i', frame_header)
                # The header is a signed int from the peer; a negative size
                # must never reach IOStream.read_bytes.
                if frame_length < 0:
                    raise TTransportException(
                        type=TTransportException.UNKNOWN,
                        message='negative frame length ({})'.format(frame_length))
                frame = yield self.stream.read_bytes(frame_length)
                raise gen.Return(frame)

    def write(self, buf):
        """Append ``buf`` to the pending outgoing frame."""
        self.__wbuf.write(buf)

    def flush(self):
        """Send the buffered bytes as one frame; return IOStream.write's result."""
        frame = self.__wbuf.getvalue()
        # reset wbuf before write/flush to preserve state on underlying failure
        frame_length = struct.pack('!i', len(frame))
        self.__wbuf = BytesIO()
        with self.io_exception_context():
            return self.stream.write(frame_length + frame)
Chris Piro | 20c81ad | 2013-03-07 11:32:48 -0500 | [diff] [blame] | 152 | |
| 153 | |
class TTornadoServer(tcpserver.TCPServer):
    """Tornado TCPServer that serves framed Thrift requests.

    Each accepted connection is wrapped in a TTornadoStreamTransport.  Frames
    are read one at a time, decoded with ``iprot_factory`` and dispatched to
    ``processor``; replies are written through ``oprot_factory`` (defaults to
    ``iprot_factory``).
    """
    def __init__(self, processor, iprot_factory, oprot_factory=None,
                 *args, **kwargs):
        super(TTornadoServer, self).__init__(*args, **kwargs)

        self._processor = processor
        self._iprot_factory = iprot_factory
        self._oprot_factory = (oprot_factory if oprot_factory is not None
                               else iprot_factory)

    @staticmethod
    def _split_address(address):
        """Return (host, port) for a peer address given to handle_stream.

        IPv4 peers give (host, port) and IPv6 peers give
        (host, port, flowinfo, scope_id); any other shape is returned as
        (address, None) instead of failing to unpack.
        """
        if isinstance(address, tuple) and len(address) >= 2:
            return address[0], address[1]
        return address, None

    @gen.coroutine
    def handle_stream(self, stream, address):
        """Serve requests on one client connection until it closes."""
        host, port = self._split_address(address)
        # Newer Tornado TCPServer no longer has an ``io_loop`` attribute; the
        # transport falls back to IOLoop.current() when given None.
        trans = TTornadoStreamTransport(host=host, port=port, stream=stream,
                                        io_loop=getattr(self, 'io_loop', None))
        oprot = self._oprot_factory.getProtocol(trans)

        try:
            while not trans.stream.closed():
                try:
                    frame = yield trans.readFrame()
                except TTransportException as e:
                    # A clean disconnect ends the loop; anything else is an error.
                    if e.type == TTransportException.END_OF_FILE:
                        break
                    else:
                        raise
                tr = TMemoryBuffer(frame)
                iprot = self._iprot_factory.getProtocol(tr)
                yield self._processor.process(iprot, oprot)
        except Exception:
            logger.exception('thrift exception in handle_stream')
            trans.close()

        logger.info('client disconnected %s:%s', host, port)