| # |
| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| # |
| |
| import errno |
| import logging |
| import os |
| import socket |
| import sys |
| |
| from .TTransport import TTransportBase, TTransportException, TServerTransportBase |
| |
| logger = logging.getLogger(__name__) |
| |
| |
class TSocketBase(TTransportBase):
    """Address resolution and teardown shared by the socket transports.

    The methods here read ``host``, ``port``, ``handle``, ``_unix_socket``
    and ``_socket_family`` from the instance; this class does not set them
    itself.
    """

    def _resolveAddr(self):
        """Return a list of getaddrinfo-style 5-tuples for this endpoint.

        When a unix socket path is configured, a single synthetic AF_UNIX
        entry is returned and ``host``/``port`` are not consulted.
        """
        unix_path = self._unix_socket
        if unix_path is None:
            return socket.getaddrinfo(self.host, self.port,
                                      self._socket_family,
                                      socket.SOCK_STREAM, 0,
                                      socket.AI_PASSIVE)
        # Same tuple layout as getaddrinfo(): the path sits in the sockaddr
        # slot so callers can treat both cases uniformly.
        return [(socket.AF_UNIX, socket.SOCK_STREAM, None, None, unix_path)]

    def close(self):
        """Close the current handle, if there is one, and drop the reference."""
        if not self.handle:
            return
        self.handle.close()
        self.handle = None
| |
| |
class TSocket(TSocketBase):
    """Socket implementation of TTransport base."""

    def __init__(self, host='localhost', port=9090, unix_socket=None,
                 socket_family=socket.AF_UNSPEC,
                 socket_keepalive=False):
        """Initialize a TSocket

        @param host(str)  The host to connect to.
        @param port(int)  The (TCP) port to connect to.
        @param unix_socket(str)  The filename of a unix socket to connect to.
                                 (host and port will be ignored.)
        @param socket_family(int)  The socket family to use with this socket.
        @param socket_keepalive(bool) enable TCP keepalive, default off.
        """
        self.host = host
        self.port = port
        self.handle = None
        self._unix_socket = unix_socket
        self._timeout = None
        self._socket_family = socket_family
        self._socket_keepalive = socket_keepalive

    def setHandle(self, h):
        """Use the already-created socket object ``h`` as this transport's handle."""
        self.handle = h

    def isOpen(self):
        """Return True if a handle is set and the peer does not appear closed.

        The check peeks one byte in non-blocking mode and restores the
        handle's previous timeout afterwards.
        """
        if self.handle is None:
            return False

        # this lets us cheaply see if the other end of the socket is still
        # connected. if disconnected, we'll get EOF back (expressed as zero
        # bytes of data) otherwise we'll get one byte or an error indicating
        # we'd have to block for data.
        #
        # note that we're not doing this with socket.MSG_DONTWAIT because 1)
        # it's linux-specific and 2) gevent-patched sockets hide EAGAIN from us
        # when timeout is non-zero.
        original_timeout = self.handle.gettimeout()
        try:
            self.handle.settimeout(0)
            try:
                peeked_bytes = self.handle.recv(1, socket.MSG_PEEK)
            except (socket.error, OSError) as exc:  # on modern python this is just BlockingIOError
                if exc.errno in (errno.EWOULDBLOCK, errno.EAGAIN):
                    return True
                return False
            except ValueError:
                # SSLSocket fails on recv with non-zero flags; fallback to the old behavior
                return True
        finally:
            self.handle.settimeout(original_timeout)

        # the length will be zero if we got EOF (indicating connection closed)
        return len(peeked_bytes) == 1

    def setTimeout(self, ms):
        """Set the socket timeout.

        @param ms  Timeout in milliseconds, or None for no timeout. The value
                   is stored in seconds and applied immediately when a handle
                   is already set.
        """
        if ms is None:
            self._timeout = None
        else:
            self._timeout = ms / 1000.0

        if self.handle is not None:
            self.handle.settimeout(self._timeout)

    def _do_open(self, family, socktype):
        """Create the raw socket object for one resolved address."""
        return socket.socket(family, socktype)

    @property
    def _address(self):
        """Printable endpoint: the unix socket path if set, else ``host:port``."""
        if self._unix_socket:
            return self._unix_socket
        # %s rather than %d: getaddrinfo also accepts a service name for the
        # port, and this is evaluated while reporting a resolution failure,
        # where a TypeError would mask the real error.
        return '%s:%s' % (self.host, self.port)

    def open(self):
        """Connect to the first resolved address that accepts the connection.

        @raise TTransportException  ALREADY_OPEN if a handle is already set;
                                    NOT_OPEN if resolution fails or no
                                    address could be connected to.
        """
        if self.handle:
            raise TTransportException(type=TTransportException.ALREADY_OPEN, message="already open")
        try:
            addrs = self._resolveAddr()
        except socket.gaierror as gai:
            msg = 'failed to resolve sockaddr for ' + str(self._address)
            logger.exception(msg)
            raise TTransportException(type=TTransportException.NOT_OPEN, message=msg, inner=gai)
        for family, socktype, _, _, sockaddr in addrs:
            handle = self._do_open(family, socktype)

            try:
                # TCP keep-alive
                if self._socket_keepalive:
                    handle.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)

                handle.settimeout(self._timeout)
            except Exception:
                # Configuration errors still propagate unchanged, but the
                # freshly created socket must not be leaked.
                handle.close()
                raise

            try:
                handle.connect(sockaddr)
            except socket.error:
                handle.close()
                logger.info('Could not connect to %s', sockaddr, exc_info=True)
                continue
            self.handle = handle
            return
        msg = 'Could not connect to any of %s' % [addr[4] for addr in addrs]
        logger.error(msg)
        raise TTransportException(type=TTransportException.NOT_OPEN, message=msg)

    def read(self, sz):
        """Receive up to ``sz`` bytes from the socket.

        @raise TTransportException  NOT_OPEN if no handle is set, TIMED_OUT
                                    on a socket timeout, END_OF_FILE when
                                    zero bytes are received, and an untyped
                                    exception wrapping any other socket
                                    error.
        """
        if not self.handle:
            # Mirrors write(): without this, reading after close() fails with
            # an AttributeError on None instead of a transport error.
            raise TTransportException(type=TTransportException.NOT_OPEN,
                                      message='Transport not open')
        try:
            buff = self.handle.recv(sz)
        except socket.timeout as e:
            raise TTransportException(type=TTransportException.TIMED_OUT, message="read timeout", inner=e)
        except socket.error as e:
            if (e.args[0] == errno.ECONNRESET and
                    (sys.platform == 'darwin' or sys.platform.startswith('freebsd'))):
                # freebsd and Mach don't follow POSIX semantic of recv
                # and fail with ECONNRESET if peer performed shutdown.
                # See corresponding comment and code in TSocket::read()
                # in lib/cpp/src/transport/TSocket.cpp.
                self.close()
                # Trigger the check to raise the END_OF_FILE exception below.
                buff = b''
            else:
                raise TTransportException(message="unexpected exception", inner=e)
        if len(buff) == 0:
            raise TTransportException(type=TTransportException.END_OF_FILE,
                                      message='TSocket read 0 bytes')
        return buff

    def write(self, buff):
        """Send all of ``buff``, looping until every byte has been accepted.

        @raise TTransportException  NOT_OPEN if no handle is set, TIMED_OUT
                                    on a socket timeout, END_OF_FILE if a
                                    send reports zero bytes, and an untyped
                                    exception wrapping any other socket
                                    error.
        """
        if not self.handle:
            raise TTransportException(type=TTransportException.NOT_OPEN,
                                      message='Transport not open')
        sent = 0
        have = len(buff)
        while sent < have:
            try:
                plus = self.handle.send(buff)
            except socket.timeout as e:
                # socket.timeout is a socket.error subclass; handle it first
                # so timeouts are typed the same way read() types them.
                raise TTransportException(type=TTransportException.TIMED_OUT, message="write timeout", inner=e)
            except socket.error as e:
                raise TTransportException(message="unexpected exception", inner=e)
            if plus == 0:
                raise TTransportException(type=TTransportException.END_OF_FILE,
                                          message='TSocket sent 0 bytes')
            sent += plus
            buff = buff[plus:]

    def flush(self):
        """No-op: write() hands data straight to the socket."""
        pass
| |
| |
class TServerSocket(TSocketBase, TServerTransportBase):
    """Socket implementation of TServerTransport base."""

    def __init__(self, host=None, port=9090, unix_socket=None, socket_family=socket.AF_UNSPEC):
        """Initialize a TServerSocket

        @param host(str)  The host/interface to bind to.
        @param port(int)  The (TCP) port to listen on.
        @param unix_socket(str)  The filename of a unix socket to listen on.
        @param socket_family(int)  The socket family to use with this socket.
        """
        self.host = host
        self.port = port
        self._unix_socket = unix_socket
        self._socket_family = socket_family
        self.handle = None
        self._backlog = 128

    def setBacklog(self, backlog=None):
        """Set the listen backlog; only takes effect before listen() is called."""
        if not self.handle:
            self._backlog = backlog
        else:
            # We can't update the backlog when it is already listening, since
            # the handle has been created.
            logger.warning('You have to set backlog before listen.')

    def listen(self):
        """Resolve the configured address, then create, bind and listen on it.

        When no family was requested, an AF_INET6 result is preferred; if no
        result matches the wanted family, the last resolved entry is used.
        For unix sockets, a stale socket file that refuses connections is
        removed before binding.
        """
        res0 = self._resolveAddr()
        if self._socket_family == socket.AF_UNSPEC:
            socket_family = socket.AF_INET6
        else:
            socket_family = self._socket_family
        for res in res0:
            # Compare by value: a caller may pass the family as a plain int,
            # which is never identical to the enum member getaddrinfo returns.
            if res[0] == socket_family or res is res0[-1]:
                break

        # We need remove the old unix socket if the file exists and
        # nobody is listening on it.
        if self._unix_socket:
            tmp = socket.socket(res[0], res[1])
            try:
                tmp.connect(res[4])
            except socket.error as err:
                # Read errno from the attribute; unpacking err.args breaks on
                # errors that do not carry exactly (errno, message).
                if err.errno == errno.ECONNREFUSED:
                    os.unlink(res[4])
            finally:
                # The probe socket is only needed for this check.
                tmp.close()

        self.handle = s = socket.socket(res[0], res[1])
        if s.family == socket.AF_INET6:
            # Clear IPV6_V6ONLY so the listener is not restricted to IPv6.
            s.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 0)
        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        if hasattr(s, 'settimeout'):
            s.settimeout(None)
        s.bind(res[4])
        s.listen(self._backlog)

    def accept(self):
        """Accept one connection and return it wrapped in a TSocket."""
        client, _addr = self.handle.accept()
        result = TSocket()
        result.setHandle(client)
        return result