refactoring is on the way
diff --git a/wally/node.py b/wally/node.py
index 3bc52fc..2b58571 100644
--- a/wally/node.py
+++ b/wally/node.py
@@ -4,34 +4,36 @@
import socket
import logging
import subprocess
-from typing import Callable
+from typing import Union, cast, Any
+
import agent
+import paramiko
-from .node_interfaces import IRPCNode, NodeInfo, ISSHHost, RPCBeforeConnCallback
-from .ssh_utils import parse_ssh_uri, ssh_connect
+
+from .node_interfaces import IRPCNode, NodeInfo, ISSHHost
+from .ssh import connect as ssh_connect
logger = logging.getLogger("wally")
class SSHHost(ISSHHost):
+ """ISSHHost implementation that runs commands over a paramiko SSHClient."""
+
- def __init__(self, ssh_conn, node_name: str, ip: str) -> None:
- self.conn = ssh_conn
- self.node_name = node_name
- self.ip = ip
-
- def get_ip(self) -> str:
- return self.ip
+ def __init__(self, conn: paramiko.SSHClient, info: NodeInfo) -> None:
+ self.conn = conn
+ self.info = info
def __str__(self) -> str:
- return self.node_name
+ return self.info.node_id()
def put_to_file(self, path: str, content: bytes) -> None:
with self.conn.open_sftp() as sftp:
with sftp.open(path, "wb") as fd:
fd.write(content)
+ def disconnect(self) -> None:
+ """Close the underlying paramiko SSH client."""
+ self.conn.close()
+
def run(self, cmd: str, timeout: int = 60, nolog: bool = False) -> str:
transport = self.conn.get_transport()
session = transport.open_session()
@@ -102,12 +104,16 @@
return stdout_data
+ def disconnect(self) -> None:
+ """No-op for the local host; present for interface parity with SSHHost.disconnect."""
-def connect(conn_url: str, conn_timeout: int = 60) -> ISSHHost:
- if conn_url == 'local':
- return LocalHost()
- else:
- return SSHHost(*ssh_connect(parse_ssh_uri(conn_url), conn_timeout))
+
+def connect(info: Union[str, NodeInfo], conn_timeout: int = 60) -> ISSHHost:
+ """Return a command-execution handle for a node.
+
+ :param info: the string 'local' for the current host, or a NodeInfo whose
+ ssh_creds are used to open an SSH connection
+ :param conn_timeout: timeout forwarded to the SSH connect helper
+ :raises ValueError: if info is a string other than 'local'
+ """
+ if isinstance(info, str):
+ if info == 'local':
+ return LocalHost()
+ # Strings used to be SSH URLs; that form is gone, so fail loudly instead of
+ # crashing later with an AttributeError on str.ssh_creds.
+ raise ValueError("Unsupported node connection string {!r}: only 'local' is accepted".format(info))
+ return SSHHost(ssh_connect(info.ssh_creds, conn_timeout), info)
class RPCNode(IRPCNode):
@@ -117,46 +123,50 @@
self.info = info
self.conn = conn
- # if self.ssh_conn_url is not None:
- # self.ssh_cred = parse_ssh_uri(self.ssh_conn_url)
- # self.node_id = "{0.host}:{0.port}".format(self.ssh_cred)
- # else:
- # self.ssh_cred = None
- # self.node_id = None
-
def __str__(self) -> str:
- return "<Node: url={!r} roles={!r} hops=/>".format(self.info.ssh_conn_url, ",".join(self.info.roles))
+ return "<Node: url={!s} roles={!r} hops=/>".format(self.info.ssh_creds, ",".join(self.info.roles))
def __repr__(self) -> str:
return str(self)
- def get_file_content(self, path: str) -> bytes:
+ def get_file_content(self, path: str) -> bytes:
raise NotImplementedError()
- def forward_port(self, ip: str, remote_port: int, local_port: int = None) -> int:
- raise NotImplementedError()
+ # Unimplemented stubs: raise NotImplementedError (the NotImplemented constant
+ # is not callable, so `raise NotImplemented()` would produce a TypeError).
+ def run(self, cmd: str, timeout: int = 60, nolog: bool = False) -> str:
+ raise NotImplementedError()
+
+ def copy_file(self, local_path: str, remote_path: str = None) -> str:
+ raise NotImplementedError()
+
+ def put_to_file(self, path: str, content: bytes) -> None:
+ raise NotImplementedError()
+
+ def get_interface(self, ip: str) -> str:
+ raise NotImplementedError()
+
+ def stat_file(self, path: str) -> Any:
+ raise NotImplementedError()
+
+ def disconnect(self) -> None:
+ raise NotImplementedError()
-def setup_rpc(node: ISSHHost, rpc_server_code: bytes, port: int = 0,
- rpc_conn_callback: RPCBeforeConnCallback = None) -> IRPCNode:
+def setup_rpc(node: ISSHHost, rpc_server_code: bytes, port: int = 0) -> IRPCNode:
+ """Upload the RPC server code to node, start it as a daemon and connect to it.
+
+ :param node: host to start the RPC server on
+ :param rpc_server_code: server source; written to a mktemp file on node
+ :param port: port passed to --listen-addr; the port actually connected to is
+ read back from the 'addr' the server reports
+ :returns: RPCNode wrapping the agent connection; node.info.params is updated
+ with the server's reported settings plus 'log_file'
+ """
code_file = node.run("mktemp").strip()
log_file = node.run("mktemp").strip()
node.put_to_file(code_file, rpc_server_code)
cmd = "python {code_file} server --listen-addr={listen_ip}:{port} --daemon " + \
"--show-settings --stdout-file={out_file}"
+
+ # NOTE(review): requires node.info.ssh_creds (SSHHost has it); a LocalHost from
+ # connect('local') is built without info — confirm it is never passed here.
+ ip = node.info.ssh_creds.addr.host
+
params_js = node.run(cmd.format(code_file=code_file,
- listen_addr=node.get_ip(),
+ listen_ip=ip,  # must match the {listen_ip} placeholder, else KeyError
out_file=log_file,
port=port)).strip()
params = json.loads(params_js)
params['log_file'] = log_file
-
- if rpc_conn_callback:
- ip, port = rpc_conn_callback(node, port)
- else:
- ip = node.get_ip()
- port = int(params['addr'].split(":")[1])
-
+ port = int(params['addr'].split(":")[1])
rpc_conn = agent.connect((ip, port))
node.info.params.update(params)
return RPCNode(rpc_conn, node.info)