blob: 32ec58a500b47b36e50760280c7c4c998c3eec39 [file] [log] [blame]
import os
import zlib
import time
import json
import socket
import logging
import tempfile
import subprocess
from typing import Union, cast, Optional, Tuple, Dict
from agent import agent
import paramiko
from .node_interfaces import IRPCNode, NodeInfo, ISSHHost
from .ssh import connect as ssh_connect
# Module-level logger: everything in this file logs through the "wally" logger.
logger = logging.getLogger("wally")
class SSHHost(ISSHHost):
    """Host driven through an already established paramiko SSH connection.

    Commands are executed in a fresh channel opened on the connection's
    transport; files are uploaded over SFTP.
    """

    def __init__(self, conn: paramiko.SSHClient, info: NodeInfo) -> None:
        self.conn = conn  # SSH client used for both command execution and SFTP
        self.info = info  # node description, supplies ``node_id``

    def __str__(self) -> str:
        return self.node_id

    @property
    def node_id(self) -> str:
        """Identifier of the node, taken from ``self.info``."""
        return self.info.node_id

    def put_to_file(self, path: Optional[str], content: bytes) -> str:
        """Store ``content`` in a file on the host and return the file path.

        :param path: destination path; when None a temporary file is created
            on the host with ``mktemp`` and its path is used.
        :param content: raw bytes to write.
        :return: path of the written file.
        """
        if path is None:
            path = self.run("mktemp", nolog=True).strip()

        logger.debug("PUT %s bytes to %s", len(content), path)

        with self.conn.open_sftp() as sftp:
            with sftp.open(path, "wb") as fd:
                fd.write(content)

        return path

    def disconnect(self):
        """Close the underlying SSH connection."""
        self.conn.close()

    def run(self, cmd: str, timeout: int = 60, nolog: bool = False) -> str:
        """Execute ``cmd`` on the host and return its output.

        stderr is merged into stdout. The channel is polled with a one second
        socket timeout so that the overall ``timeout`` can be checked between
        reads.

        :param cmd: shell command line to execute.
        :param timeout: maximum execution time in seconds.
        :param nolog: do not log the command when True.
        :return: decoded (UTF-8) combined output of the command.
        :raises OSError: if the command exceeds ``timeout`` or exits with a
            non-zero code.
        """
        if not nolog:
            logger.debug("SSH:{0} Exec {1!r}".format(self, cmd))

        transport = self.conn.get_transport()
        session = transport.open_session()

        # Raw bytes are collected and decoded once at the end: decoding each
        # 1024-byte chunk separately fails when a multi-byte UTF-8 character
        # is split across two chunks.
        chunks = []

        try:
            session.set_combine_stderr(True)
            stime = time.time()
            session.exec_command(cmd)
            session.settimeout(1)
            session.shutdown_write()

            while True:
                try:
                    ndata = session.recv(1024)
                    if not ndata:
                        # empty read means the remote side closed the stream
                        break
                    chunks.append(ndata)
                except socket.timeout:
                    pass

                if time.time() - stime > timeout:
                    # the partial output may end in the middle of a character,
                    # so decode it leniently for the error message
                    partial = b"".join(chunks).decode("utf-8", errors="replace")
                    raise OSError(partial + "\nExecution timeout")

            code = session.recv_exit_status()
        finally:
            # always release the channel, on success as well as on error
            session.close()

        output = b"".join(chunks).decode("utf-8")

        if code != 0:
            templ = "SSH:{0} Cmd {1!r} failed with code {2}. Output: {3}"
            raise OSError(templ.format(self, cmd, code, output))

        return output
class LocalHost(ISSHHost):
    """``ISSHHost`` implementation that operates on the local machine.

    Commands are run through the local shell and files are written directly
    to the local filesystem; no SSH connection is involved.
    """

    def __str__(self):
        return "<Local>"

    def get_ip(self) -> str:
        """Return the address of this host, always ``'localhost'``."""
        return 'localhost'

    def put_to_file(self, path: Optional[str], content: bytes) -> str:
        """Write ``content`` to a local file and return its path.

        :param path: destination path; when None a temporary file is created
            with ``tempfile.mkstemp``. Missing parent directories of an
            explicit path are created.
        :param content: raw bytes to write.
        :return: path of the written file.
        """
        if path is None:
            fd, path = tempfile.mkstemp(text=False)
            os.close(fd)
        else:
            dir_name = os.path.dirname(path)
            # A bare file name has an empty dirname, and os.makedirs("")
            # raises FileNotFoundError, so only create real directories.
            if dir_name:
                os.makedirs(dir_name, exist_ok=True)

        with open(path, "wb") as fd2:
            fd2.write(content)

        return path

    def run(self, cmd: str, timeout: int = 60, nolog: bool = False) -> str:
        """Run ``cmd`` through the local shell and return its output.

        stderr is merged into stdout.

        :param cmd: shell command line to execute.
        :param timeout: accepted for interface compatibility; this
            implementation waits for the command to finish and does not
            enforce the timeout.
        :param nolog: do not log the command when True.
        :return: decoded (UTF-8) combined output of the command.
        :raises OSError: if the command exits with a non-zero code.
        """
        if not nolog:
            logger.debug("SSH:{0} Exec {1!r}".format(self, cmd))

        proc = subprocess.Popen(cmd, shell=True,
                                stdin=subprocess.PIPE,
                                stdout=subprocess.PIPE,
                                stderr=subprocess.STDOUT)

        stdout_data_b, _ = proc.communicate()
        stdout_data = stdout_data_b.decode("utf8")

        if proc.returncode != 0:
            templ = "SSH:{0} Cmd {1!r} failed with code {2}. Output: {3}"
            raise OSError(templ.format(self, cmd, proc.returncode, stdout_data))

        return stdout_data

    def disconnect(self):
        """Nothing to release for the local host."""
        pass
def get_rpc_server_code() -> Tuple[bytes, Dict[str, bytes]]:
    """Load the RPC agent source and its plugins from disk.

    The agent source is the file of the imported ``agent`` module; the
    ``cli`` and ``fs`` plugins are read from ``cli_plugin.py`` and
    ``fs_plugin.py`` in the same directory.

    :return: tuple of (agent source bytes, mapping of plugin name to plugin
        source bytes).
    """
    path = agent.__file__
    if path.endswith(".pyc"):
        # drop the trailing "c" to read the ".py" source instead of bytecode
        path = path[:-1]

    # context managers make sure every file handle is closed after reading
    with open(path, "rb") as fd:
        master_code = fd.read()

    plugins = {}  # type: Dict[str, bytes]
    agent_dir = os.path.dirname(path)

    for name in ("cli", "fs"):
        plugin_path = os.path.join(agent_dir, name + "_plugin.py")
        with open(plugin_path, "rb") as fd:
            plugins[name] = fd.read()

    return master_code, plugins
def connect(info: Union[str, NodeInfo], conn_timeout: int = 60) -> ISSHHost:
    """Create a host object for ``info``.

    The literal string ``'local'`` yields a ``LocalHost``; anything else is
    treated as a ``NodeInfo`` and connected to over SSH using its
    ``ssh_creds``, with ``conn_timeout`` passed to the SSH connect helper.
    """
    if info == 'local':
        return LocalHost()

    node_info = cast(NodeInfo, info)
    ssh_conn = ssh_connect(node_info.ssh_creds, conn_timeout)
    return SSHHost(ssh_conn, node_info)
class RPCNode(IRPCNode):
    """Node object: wraps an RPC client connection to an agent on a node.

    File operations go through the client's ``fs`` plugin, command execution
    through its ``cli`` plugin and server control through ``server``.
    """

    def __init__(self, conn: agent.SimpleRPCClient, info: NodeInfo) -> None:
        self.info = info
        self.conn = conn

    def __str__(self) -> str:
        return "Node({!r})".format(self.info)

    def __repr__(self) -> str:
        return str(self)

    @property
    def node_id(self) -> str:
        """Identifier of the node, taken from ``self.info``."""
        return self.info.node_id

    def get_file_content(self, path: str, expanduser: bool = False, compress: bool = True) -> bytes:
        """Fetch the content of a file from the node.

        :param path: file path on the node.
        :param expanduser: expand ``path`` on the node side before reading.
        :param compress: request the data with the compress flag set and
            zlib-decompress the result locally.
        :return: file content as bytes.
        """
        logger.debug("GET %s from %s", path, self.info)
        if expanduser:
            path = self.conn.fs.expanduser(path)
        res = self.conn.fs.get_file(path, compress)
        logger.debug("Download %s bytes from remote file %s from %s", len(res), path, self.info)
        if compress:
            res = zlib.decompress(res)
        return res

    def run(self, cmd: str, timeout: int = 60, nolog: bool = False, check_timeout: float = 0.01) -> str:
        """Run ``cmd`` on the node through the ``cli`` plugin.

        :param cmd: command line to execute.
        :param timeout: passed to the remote ``spawn`` call.
        :param nolog: do not log the command when True.
        :param check_timeout: seconds to sleep between polls for updates.
        :return: decoded (UTF-8) output of the command (spawned with
            ``merge_out=True``).
        :raises OSError: if the command finishes with a non-zero code.
        """
        if not nolog:
            logger.debug("Node %s - run %s", self.node_id, cmd)

        cmd_b = cmd.encode("utf8")
        proc_id = self.conn.cli.spawn(cmd_b, timeout=timeout, merge_out=True)

        # Collect raw bytes and decode once: a multi-byte UTF-8 character may
        # be split between two successive updates.
        chunks = []

        while True:
            code, outb, _ = self.conn.cli.get_updates(proc_id)
            chunks.append(outb)
            if code is not None:
                break
            time.sleep(check_timeout)

        out = b"".join(chunks).decode("utf8")

        if code != 0:
            templ = "Node {} - cmd {!r} failed with code {}. Output: {!r}."
            raise OSError(templ.format(self.node_id, cmd, code, out))

        return out

    def copy_file(self, local_path: str, remote_path: Optional[str] = None,
                  expanduser: bool = False,
                  compress: bool = False) -> str:
        """Upload a local file to the node.

        :param local_path: path of the local file to read.
        :param remote_path: destination path on the node, passed on to
            ``put_to_file``.
        :param expanduser: expand ``remote_path`` on the node side first.
        :param compress: zlib-compress the data for the transfer.
        :return: value returned by ``put_to_file``.
        """
        if expanduser:
            remote_path = self.conn.fs.expanduser(remote_path)

        # read through a context manager so the local file handle is closed
        with open(local_path, 'rb') as fd:
            data = fd.read()  # type: bytes

        return self.put_to_file(remote_path, data, compress=compress)

    def put_to_file(self, path: Optional[str], content: bytes, expanduser: bool = False, compress: bool = False) -> str:
        """Store ``content`` in a file on the node.

        :param path: destination path on the node.
        :param content: raw bytes to store.
        :param expanduser: expand ``path`` on the node side first.
        :param compress: zlib-compress ``content`` before sending; the flag is
            also forwarded to the remote ``store_file`` call.
        :return: value returned by the remote ``store_file`` call.
        """
        if expanduser:
            path = self.conn.fs.expanduser(path)
        if compress:
            content = zlib.compress(content)
        return self.conn.fs.store_file(path, content, compress)

    def stat_file(self, path: str, expanduser: bool = False) -> Dict[str, int]:
        """Return the remote ``file_stat`` result for ``path``."""
        if expanduser:
            path = self.conn.fs.expanduser(path)
        return self.conn.fs.file_stat(path)

    def __exit__(self, x, y, z) -> bool:
        # Leaving a ``with`` block stops the remote server and disconnects;
        # returning False lets any exception propagate.
        self.disconnect(stop=True)
        return False

    def upload_plugin(self, name: str, code: bytes, version: Optional[str] = None) -> None:
        """Load module ``name`` with source ``code`` into the remote server."""
        self.conn.server.load_module(name, version, code)

    def disconnect(self, stop: bool = False) -> None:
        """Close the RPC connection, optionally stopping the remote server.

        :param stop: stop the remote RPC server before disconnecting.
        """
        if stop:
            logger.debug("Stopping RPC server on %s", self.info)
            self.conn.server.stop()

        logger.debug("Disconnecting from %s", self.info)
        self.conn.disconnect()
        # drop the reference so the closed connection cannot be reused
        self.conn = None
def get_node_python_27(node: ISSHHost) -> Optional[str]:
    """Locate a Python 2.7 interpreter on ``node``.

    First asks for ``python2.7`` explicitly; if that fails, falls back to the
    default ``python`` when its ``--version`` output mentions 2.7. Both probes
    are best-effort: a failing command means "not found", not an error.

    :param node: host on which to run the probe commands.
    :return: path of the interpreter, or None if no Python 2.7 was found.
    """
    try:
        return node.run('which python2.7').strip()
    except Exception as exc:
        # not fatal: fall through to the generic ``python`` probe
        logger.debug("No python2.7 binary found on %s: %s", node, exc)

    try:
        if '2.7' in node.run('python --version'):
            return node.run('which python').strip()
    except Exception as exc:
        logger.debug("Default python probe failed on %s: %s", node, exc)

    return None
def setup_rpc(node: ISSHHost,
              rpc_server_code: bytes,
              plugins: Optional[Dict[str, bytes]] = None,
              port: int = 0,
              log_level: Optional[str] = None) -> IRPCNode:
    """Start the RPC agent on ``node`` and return a connected ``RPCNode``.

    The agent source is uploaded to a temporary file on the node, started
    with Python 2.7 as a daemon, and its reported settings (JSON printed by
    ``--show-settings``) are merged into ``node.info.params``. The RPC
    connection is then opened to the port found in the reported ``addr``.

    :param node: SSH-level host used to upload and start the agent.
    :param rpc_server_code: source code of the agent.
    :param plugins: optional mapping of plugin name to plugin source, uploaded
        after the connection is established.
    :param port: port to listen on, passed to the agent in ``--listen-addr``.
    :param log_level: when set, agent output is stored in a temporary file on
        the node with this log level; otherwise the level is CRITICAL and no
        log file is requested.
    :return: connected RPC node; its ``rpc_log_file`` attribute holds the
        remote log file path or None.
    :raises ValueError: if no Python 2.7 interpreter is found on the node.
    """
    logger.debug("Setting up RPC connection to {}".format(node.info))
    python_cmd = get_node_python_27(node)
    if python_cmd:
        logger.debug("python2.7 on node {} path is {}".format(node.info, python_cmd))
    else:
        logger.error(("Can't find python2.7 on node {}. " +
                      "Install python2.7 and rerun test").format(node.info))
        raise ValueError("Python not found")

    code_file = node.put_to_file(None, rpc_server_code)
    ip = node.info.ssh_creds.addr.host

    cmd_templ = "{} {} --log-level={} server --listen-addr={}:{} --daemon --show-settings"

    log_file = None  # type: Optional[str]
    if log_level:
        log_file = node.run("mktemp", nolog=True).strip()
        cmd = cmd_templ.format(python_cmd, code_file, log_level, ip, port) + \
            " --stdout-file={}".format(log_file)
        logger.info("Agent logs for node {} stored remotely in file {}, log level is {}".format(
            node.node_id, log_file, log_level))
    else:
        cmd = cmd_templ.format(python_cmd, code_file, "CRITICAL", ip, port)

    params_js = node.run(cmd).strip()
    params = json.loads(params_js)

    node.info.params.update(params)

    # The port is the part after the LAST colon; splitting from the right
    # stays correct when the host part itself contains colons.
    port = int(params['addr'].rsplit(":", 1)[1])
    rpc_conn = agent.connect((ip, port))

    rpc_node = RPCNode(rpc_conn, node.info)
    rpc_node.rpc_log_file = log_file

    if plugins is not None:
        try:
            for name, code in plugins.items():
                rpc_node.upload_plugin(name, code)
        except Exception:
            # do not leave a half-configured server running on the node
            rpc_node.disconnect(True)
            raise

    return rpc_node
# class RemoteNode(node_interfaces.IRPCNode):
# def __init__(self, node_info: node_interfaces.NodeInfo, rpc_conn: agent.RPCClient):
# self.info = node_info
# self.rpc = rpc_conn
#
# def get_interface(self, ip: str) -> str:
# """Get node external interface for given IP"""
# data = self.run("ip a", nolog=True)
# curr_iface = None
#
# for line in data.split("\n"):
# match1 = re.match(r"\d+:\s+(?P<name>.*?):\s\<", line)
# if match1 is not None:
# curr_iface = match1.group('name')
#
# match2 = re.match(r"\s+inet\s+(?P<ip>[0-9.]+)/", line)
# if match2 is not None:
# if match2.group('ip') == ip:
# assert curr_iface is not None
# return curr_iface
#
# raise KeyError("Can't find interface for ip {0}".format(ip))
#
# def get_user(self) -> str:
# """"get ssh connection username"""
# if self.ssh_conn_url == 'local':
# return getpass.getuser()
# return self.ssh_cred.user
#
#
# def run(self, cmd: str, stdin_data: str = None, timeout: int = 60, nolog: bool = False) -> Tuple[int, str]:
# """Run command on node. Will use rpc connection, if available"""
#
# if self.rpc_conn is None:
# return run_over_ssh(self.ssh_conn, cmd,
# stdin_data=stdin_data, timeout=timeout,
# nolog=nolog, node=self)
# assert not stdin_data
# proc_id = self.rpc_conn.cli.spawn(cmd)
# exit_code = None
# output = ""
#
# while exit_code is None:
# exit_code, stdout_data, stderr_data = self.rpc_conn.cli.get_updates(proc_id)
# output += stdout_data + stderr_data
#
# return exit_code, output