# blob: af309c3d971ca46f8a6f22337f85c1a5b6b6ee86 [file] [log] [blame]
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
from cStringIO import StringIO
import logging
import socket
import struct
from thrift.transport import TTransport
from thrift.transport.TTransport import TTransportException
from tornado import gen
from tornado import iostream
from tornado import netutil
class TTornadoStreamTransport(TTransport.TTransportBase):
    """A framed, buffered transport over a Tornado stream.

    Outgoing bytes are accumulated by write() and sent by flush() as one
    frame: a 4-byte network-order signed length header followed by the
    payload.  Incoming data is consumed one whole frame at a time through
    readFrame(); byte-level read() is deliberately unsupported.
    """

    def __init__(self, host, port, stream=None):
        """Record the peer address and optionally adopt an existing stream.

        Args:
            host: host that open() connects to.
            port: port that open() connects to.
            stream: an already-usable stream (servers provide one).  When
                None, open() must be called to create and connect one.
        """
        self.host = host
        self.port = port
        # True while a readFrame() invocation is draining read_queue.
        self.is_queuing_reads = False
        # Callbacks waiting for a frame, oldest first.
        self.read_queue = []
        self.__wbuf = StringIO()

        # servers provide a ready-to-go stream
        self.stream = stream
        if self.stream is not None:
            self._set_close_callback()

    # not the same number of parameters as TTransportBase.open
    def open(self, callback):
        """Create a TCP IOStream, connect it, then invoke callback().

        Args:
            callback: zero-argument callable invoked once the connection is
                established.

        Raises:
            TTransportException: (NOT_OPEN) raised from the stream's close
                callback if the stream closes before the connect completes.
        """
        logging.debug('socket connecting')
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, 0)
        self.stream = iostream.IOStream(sock)

        def on_close_in_connect(*_):
            message = 'could not connect to {}:{}'.format(self.host, self.port)
            raise TTransportException(
                type=TTransportException.NOT_OPEN,
                message=message)

        self.stream.set_close_callback(on_close_in_connect)

        def finish(*_):
            # Connected: swap the connect-failure handler for the normal one
            # before handing control back to the caller.
            self._set_close_callback()
            callback()

        self.stream.connect((self.host, self.port), callback=finish)

    def _set_close_callback(self):
        """Register self.close as the stream's close callback."""
        self.stream.set_close_callback(self.close)

    def close(self):
        """Close the underlying stream."""
        # don't raise if we intend to close
        self.stream.set_close_callback(None)
        self.stream.close()

    def read(self, _):
        """Unsupported: always raises AssertionError.

        The generated code for Tornado shouldn't do individual reads -- only
        frames at a time (see readFrame).
        """
        # Raised explicitly rather than via ``assert`` so the guard also
        # fires when Python runs with -O.
        raise AssertionError("you're doing it wrong")

    @gen.engine
    def readFrame(self, callback):
        """Queue callback to receive the next complete frame payload.

        Callbacks are served strictly in the order readFrame() was called,
        one frame each.

        Args:
            callback: callable invoked with the frame payload.
        """
        self.read_queue.append(callback)
        logging.debug('read queue: %s', self.read_queue)

        if self.is_queuing_reads:
            # If a read is already in flight, then the while loop below should
            # pull it from self.read_queue
            return

        self.is_queuing_reads = True
        try:
            while self.read_queue:
                # FIFO: the oldest waiter gets the next frame off the wire.
                next_callback = self.read_queue.pop(0)
                result = yield gen.Task(self._readFrameFromStream)
                next_callback(result)
        finally:
            # Always release the flag so a failure in one read/callback does
            # not leave later readFrame() calls queued forever.
            self.is_queuing_reads = False

    @gen.engine
    def _readFrameFromStream(self, callback):
        """Read one length-prefixed frame from the stream.

        Args:
            callback: callable invoked with the frame payload (header
                excluded).
        """
        logging.debug('_readFrameFromStream')
        frame_header = yield gen.Task(self.stream.read_bytes, 4)
        frame_length, = struct.unpack('!i', frame_header)
        logging.debug('received frame header, frame length = %i', frame_length)
        frame = yield gen.Task(self.stream.read_bytes, frame_length)
        logging.debug('received frame payload')
        callback(frame)

    def write(self, buf):
        """Append buf to the pending frame; nothing is sent until flush()."""
        self.__wbuf.write(buf)

    def flush(self, callback=None):
        """Send everything buffered by write() as one frame.

        Args:
            callback: optional callable passed through to the stream's
                write().
        """
        wout = self.__wbuf.getvalue()
        wsz = len(wout)
        # reset wbuf before write/flush to preserve state on underlying failure
        self.__wbuf = StringIO()
        # N.B.: Doing this string concatenation is WAY cheaper than making
        # two separate calls to the underlying socket object. Socket writes in
        # Python turn out to be REALLY expensive, but it seems to do a pretty
        # good job of managing string buffer operations without excessive
        # copies
        buf = struct.pack("!i", wsz) + wout

        logging.debug('writing frame length = %i', wsz)
        self.stream.write(buf, callback)
class TTornadoServer(netutil.TCPServer):
    """TCP server that feeds each accepted stream to a Thrift processor."""

    def __init__(self, processor, iprot_factory, oprot_factory=None,
                 *args, **kwargs):
        """Store the processor and protocol factories.

        Args:
            processor: object whose process(trans, iprot_factory, oprot,
                callback=...) method handles one request.
            iprot_factory: input protocol factory handed to the processor.
            oprot_factory: factory whose getProtocol(trans) builds the output
                protocol; defaults to iprot_factory when None.
            *args, **kwargs: forwarded unchanged to the base server class.
        """
        super(TTornadoServer, self).__init__(*args, **kwargs)

        self._processor = processor
        self._iprot_factory = iprot_factory
        self._oprot_factory = (oprot_factory if oprot_factory is not None
                               else iprot_factory)

    def handle_stream(self, stream, address):
        """Wrap a newly accepted stream in a transport and start serving it.

        Requests are processed one after another: each completed
        process() call schedules the next until the stream is closed.

        Args:
            stream: the accepted connection's stream.
            address: peer address; only its first two items (host, port) are
                used, so longer tuples such as IPv6 addresses are accepted.
        """
        # Bound up front so the error handler can tell whether the transport
        # was ever created.
        trans = None
        try:
            host, port = address[:2]
            trans = TTornadoStreamTransport(host=host, port=port,
                                            stream=stream)
            oprot = self._oprot_factory.getProtocol(trans)

            def next_pass():
                if not trans.stream.closed():
                    self._processor.process(trans, self._iprot_factory, oprot,
                                            callback=next_pass)

            next_pass()

        except Exception:
            logging.exception('thrift exception in handle_stream')
            if trans is not None:
                trans.close()
            else:
                # Failed before the transport existed; still release the
                # connection instead of raising a second error here.
                stream.close()