fixing code
diff --git a/wally/node.py b/wally/node.py
index fae7879..54a6291 100644
--- a/wally/node.py
+++ b/wally/node.py
@@ -4,7 +4,7 @@
import socket
import logging
import subprocess
-from typing import Union, cast, Any
+from typing import Union, cast, Any, Optional, Tuple, Dict, List
import agent
@@ -26,11 +26,16 @@
def __str__(self) -> str:
return self.info.node_id()
- def put_to_file(self, path: str, content: bytes) -> None:
+ def put_to_file(self, path: Optional[str], content: bytes) -> str:
+ if path is None:
+ path = self.run("mktemp").strip()
+
with self.conn.open_sftp() as sftp:
with sftp.open(path, "wb") as fd:
fd.write(content)
+ return path
+
def disconnect(self):
self.conn.close()
@@ -53,10 +58,10 @@
while True:
try:
- ndata = session.recv(1024)
- output += ndata
- if "" == ndata:
+ ndata = session.recv(1024).decode("utf-8")
+ if not ndata:
break
+ output += ndata
except socket.timeout:
pass
@@ -97,7 +102,9 @@
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT)
- stdout_data, _ = proc.communicate()
+ stdout_data_b, _ = proc.communicate()
+ stdout_data = stdout_data_b.decode("utf8")
+
if proc.returncode != 0:
templ = "SSH:{0} Cmd {1!r} failed with code {2}. Output: {3}"
raise OSError(templ.format(self, cmd, proc.returncode, stdout_data))
@@ -108,6 +115,25 @@
pass
+def get_rpc_server_code() -> Tuple[bytes, Dict[str, bytes]]:
+ # setup rpc data
+ if agent.__file__.endswith(".pyc"):
+ path = agent.__file__[:-1]
+ else:
+ path = agent.__file__
+
+ master_code = open(path, "rb").read()
+
+ plugins = {} # type: Dict[str, bytes]
+ cli_path = os.path.join(os.path.dirname(path), "cli_plugin.py")
+ plugins["cli"] = open(cli_path, "rb").read()
+
+ fs_path = os.path.join(os.path.dirname(path), "fs_plugin.py")
+ plugins["fs"] = open(fs_path, "rb").read()
+
+ return master_code, plugins
+
+
def connect(info: Union[str, NodeInfo], conn_timeout: int = 60) -> ISSHHost:
if info == 'local':
return LocalHost()
@@ -119,12 +145,12 @@
class RPCNode(IRPCNode):
"""Node object"""
- def __init__(self, conn: agent.Client, info: NodeInfo) -> None:
+ def __init__(self, conn: agent.SimpleRPCClient, info: NodeInfo) -> None:
self.info = info
self.conn = conn
def __str__(self) -> str:
- return "<Node: url={!s} roles={!r} hops=/>".format(self.info.ssh_creds, ",".join(self.info.roles))
+ return "Node(url={!r}, roles={!r})".format(self.info.ssh_creds, ",".join(self.info.roles))
def __repr__(self) -> str:
return str(self)
@@ -138,7 +164,7 @@
def copy_file(self, local_path: str, remote_path: str = None) -> str:
raise NotImplementedError()
- def put_to_file(self, path: str, content: bytes) -> None:
+ def put_to_file(self, path: Optional[str], content: bytes) -> str:
raise NotImplementedError()
def get_interface(self, ip: str) -> str:
@@ -148,27 +174,35 @@
raise NotImplementedError()
def disconnect(self) -> str:
- raise NotImplementedError()
+ self.conn.disconnect()
+ self.conn = None
-def setup_rpc(node: ISSHHost, rpc_server_code: bytes, port: int = 0) -> IRPCNode:
- code_file = node.run("mktemp").strip()
+def setup_rpc(node: ISSHHost, rpc_server_code: bytes, plugins: Dict[str, bytes] = None, port: int = 0) -> IRPCNode:
log_file = node.run("mktemp").strip()
- node.put_to_file(code_file, rpc_server_code)
- cmd = "python {code_file} server --listen-addr={listen_ip}:{port} --daemon " + \
- "--show-settings --stdout-file={out_file}"
-
+ code_file = node.put_to_file(None, rpc_server_code)
ip = node.info.ssh_creds.addr.host
- params_js = node.run(cmd.format(code_file=code_file,
- listen_addr=ip,
- out_file=log_file,
- port=port)).strip()
+ cmd = "python {code_file} server --listen-addr={listen_ip}:{port} --daemon " + \
+ "--show-settings --stdout-file={out_file}"
+ cmd = cmd.format(code_file=code_file, listen_ip=ip, out_file=log_file, port=port)
+ params_js = node.run(cmd).strip()
params = json.loads(params_js)
params['log_file'] = log_file
+ node.info.params.update(params)
+
port = int(params['addr'].split(":")[1])
rpc_conn = agent.connect((ip, port))
- node.info.params.update(params)
+
+ if plugins is not None:
+ try:
+ for name, code in plugins.items():
+ rpc_conn.server.load_module(name, None, code)
+ except Exception:
+ rpc_conn.server.stop()
+ rpc_conn.disconnect()
+ raise
+
return RPCNode(rpc_conn, node.info)