Chris Piro | 20c81ad | 2013-03-07 11:32:48 -0500 | [diff] [blame^] | 1 | # |
| 2 | # Licensed to the Apache Software Foundation (ASF) under one |
| 3 | # or more contributor license agreements. See the NOTICE file |
| 4 | # distributed with this work for additional information |
| 5 | # regarding copyright ownership. The ASF licenses this file |
| 6 | # to you under the Apache License, Version 2.0 (the |
| 7 | # "License"); you may not use this file except in compliance |
| 8 | # with the License. You may obtain a copy of the License at |
| 9 | # |
| 10 | # http://www.apache.org/licenses/LICENSE-2.0 |
| 11 | # |
| 12 | # Unless required by applicable law or agreed to in writing, |
| 13 | # software distributed under the License is distributed on an |
| 14 | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| 15 | # KIND, either express or implied. See the License for the |
| 16 | # specific language governing permissions and limitations |
| 17 | # under the License. |
| 18 | # |
| 19 | |
| 20 | from cStringIO import StringIO |
| 21 | import logging |
| 22 | import socket |
| 23 | import struct |
| 24 | |
| 25 | from thrift.transport import TTransport |
| 26 | from thrift.transport.TTransport import TTransportException |
| 27 | |
| 28 | from tornado import gen |
| 29 | from tornado import iostream |
| 30 | from tornado import netutil |
| 31 | |
| 32 | |
class TTornadoStreamTransport(TTransport.TTransportBase):
    """A framed, buffered Thrift transport over a Tornado stream.

    Outgoing bytes are buffered by ``write`` and sent as one frame by
    ``flush``. Each frame is prefixed with its payload length, packed with
    ``struct`` format ``'!i'`` (4-byte big-endian signed int). Incoming data
    is consumed a whole frame at a time through ``readFrame``; byte-level
    ``read`` is not supported.
    """

    def __init__(self, host, port, stream=None):
        """Create the transport.

        :param host: remote host; used by ``open`` and in error messages.
        :param port: remote port; used by ``open`` and in error messages.
        :param stream: an already-connected stream (servers pass the
            accepted stream). When omitted, call ``open`` to connect.
        """
        self.host = host
        self.port = port
        # True while a readFrame drain loop is running.
        self.is_queuing_reads = False
        # Callbacks waiting for a frame, oldest request first.
        self.read_queue = []
        # Pending outgoing bytes; emptied by flush().
        self.__wbuf = StringIO()

        # servers provide a ready-to-go stream
        self.stream = stream
        if self.stream is not None:
            self._set_close_callback()

    # not the same number of parameters as TTransportBase.open
    def open(self, callback):
        """Connect to ``(self.host, self.port)``, then call ``callback()``.

        While connecting, the stream's close callback raises
        ``TTransportException(NOT_OPEN)``. Once connected, it is replaced
        by ``close``.
        NOTE(review): the NOT_OPEN exception is raised from inside the
        stream's close callback, presumably reaching the IOLoop rather than
        the caller, and ``callback`` is then never invoked -- confirm this
        failure mode is intended.
        """
        logging.debug('socket connecting')
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
        self.stream = iostream.IOStream(sock)

        def on_close_in_connect(*_):
            message = 'could not connect to {}:{}'.format(self.host, self.port)
            raise TTransportException(
                type=TTransportException.NOT_OPEN,
                message=message)
        self.stream.set_close_callback(on_close_in_connect)

        def finish(*_):
            self._set_close_callback()
            callback()

        self.stream.connect((self.host, self.port), callback=finish)

    def _set_close_callback(self):
        """Install ``self.close`` as the stream's close callback."""
        self.stream.set_close_callback(self.close)

    def close(self):
        """Close the underlying stream, if there is one.

        The close callback is cleared first so an intentional close does
        not re-enter it. Calling this on a transport that was never opened
        (``self.stream is None``) is a no-op.
        """
        if self.stream is None:
            return
        # don't raise if we intend to close
        self.stream.set_close_callback(None)
        self.stream.close()

    def read(self, _):
        """Unsupported: this transport only delivers whole frames.

        The generated code for Tornado shouldn't do individual reads -- only
        frames at a time, via ``readFrame``.

        :raises NotImplementedError: always.
        """
        # An explicit raise is not stripped under ``python -O`` the way an
        # ``assert`` is, so misuse can never silently return None.
        raise NotImplementedError(
            'TTornadoStreamTransport only supports readFrame')

    @gen.engine
    def readFrame(self, callback):
        """Read the next frame and pass its payload to ``callback``.

        Only one drain loop runs at a time. A call made while a loop is
        running just enqueues its callback for that loop to serve. Callbacks
        are served in request order, because frames arrive on the stream in
        order.
        """
        self.read_queue.append(callback)
        logging.debug('read queue: %s', self.read_queue)

        if self.is_queuing_reads:
            # If a read is already in flight, then the while loop below should
            # pull it from self.read_queue
            return

        self.is_queuing_reads = True
        try:
            while self.read_queue:
                # FIFO: take the oldest request. pop() from the end would
                # hand the next frame on the wire to the newest caller.
                next_callback = self.read_queue.pop(0)
                result = yield gen.Task(self._readFrameFromStream)
                next_callback(result)
        finally:
            # Reset even when a read fails. Otherwise every later readFrame
            # would enqueue behind a drain loop that no longer exists.
            self.is_queuing_reads = False

    @gen.engine
    def _readFrameFromStream(self, callback):
        """Read one length-prefixed frame and pass its payload to ``callback``."""
        logging.debug('_readFrameFromStream')
        frame_header = yield gen.Task(self.stream.read_bytes, 4)
        frame_length, = struct.unpack('!i', frame_header)
        logging.debug('received frame header, frame length = %i', frame_length)
        frame = yield gen.Task(self.stream.read_bytes, frame_length)
        logging.debug('received frame payload')
        callback(frame)

    def write(self, buf):
        """Append ``buf`` to the pending output; nothing is sent until flush()."""
        self.__wbuf.write(buf)

    def flush(self, callback=None):
        """Send all pending output as a single length-prefixed frame.

        :param callback: passed through to ``self.stream.write``;
            presumably invoked once the write completes -- confirm against
            the IOStream version in use.
        """
        wout = self.__wbuf.getvalue()
        wsz = len(wout)
        # reset wbuf before write/flush to preserve state on underlying failure
        self.__wbuf = StringIO()
        # N.B.: Doing this string concatenation is WAY cheaper than making
        # two separate calls to the underlying socket object. Socket writes in
        # Python turn out to be REALLY expensive, but it seems to do a pretty
        # good job of managing string buffer operations without excessive copies
        buf = struct.pack("!i", wsz) + wout

        logging.debug('writing frame length = %i', wsz)
        self.stream.write(buf, callback)
| 126 | |
| 127 | |
class TTornadoServer(netutil.TCPServer):
    """A Tornado TCP server that serves framed Thrift requests.

    Each accepted connection is wrapped in a ``TTornadoStreamTransport``.
    Requests on it are processed one after another until the stream closes.
    """

    def __init__(self, processor, iprot_factory, oprot_factory=None,
                 *args, **kwargs):
        """Create the server.

        :param processor: object whose ``process(trans, iprot_factory,
            oprot, callback=...)`` handles one request.
        :param iprot_factory: input protocol factory, passed to
            ``processor.process`` on every request.
        :param oprot_factory: factory used to build the per-connection
            output protocol; defaults to ``iprot_factory``.
        Remaining positional and keyword arguments go to ``TCPServer``.
        """
        super(TTornadoServer, self).__init__(*args, **kwargs)

        self._processor = processor
        self._iprot_factory = iprot_factory
        self._oprot_factory = (oprot_factory if oprot_factory is not None
                               else iprot_factory)

    def handle_stream(self, stream, address):
        """Serve Thrift requests on a newly accepted ``stream``.

        Any exception raised while setting up the connection or starting
        the first request is logged, and the connection is closed.
        """
        # Bound before the try so the error handler can always test it.
        trans = None
        try:
            # IPv6 peer addresses are 4-tuples (host, port, flowinfo,
            # scope_id); only host and port are needed here.
            host, port = address[:2]
            trans = TTornadoStreamTransport(host=host, port=port, stream=stream)
            oprot = self._oprot_factory.getProtocol(trans)

            def next_pass():
                # Start the next request once the previous one has called
                # back, stopping when the stream reports it is closed.
                if not trans.stream.closed():
                    self._processor.process(trans, self._iprot_factory, oprot,
                                            callback=next_pass)

            next_pass()

        except Exception:
            logging.exception('thrift exception in handle_stream')
            if trans is not None:
                trans.close()
            else:
                # Failed before a transport existed: close the raw stream
                # so the accepted connection is not leaked.
                stream.close()