blob: 3bc52fcc2c1989ae3b7a092f7482b2f159dc6c2e [file] [log] [blame]
koder aka kdanilov3d2bc4f2016-11-12 18:31:18 +02001import os
2import time
3import json
4import socket
5import logging
6import subprocess
7from typing import Callable
koder aka kdanilov3b4da8b2016-10-17 00:17:53 +03008
koder aka kdanilov3d2bc4f2016-11-12 18:31:18 +02009import agent
10
11from .node_interfaces import IRPCNode, NodeInfo, ISSHHost, RPCBeforeConnCallback
12from .ssh_utils import parse_ssh_uri, ssh_connect
koder aka kdanilov3b4da8b2016-10-17 00:17:53 +030013
14
koder aka kdanilov3d2bc4f2016-11-12 18:31:18 +020015logger = logging.getLogger("wally")
koder aka kdanilov3b4da8b2016-10-17 00:17:53 +030016
koder aka kdanilov22d134e2016-11-08 11:33:19 +020017
class SSHHost(ISSHHost):
    """Host accessed through an already established SSH connection.

    ``ssh_conn`` must provide ``open_sftp()`` and ``get_transport()``
    (presumably a paramiko ``SSHClient`` -- confirm against ``ssh_connect``).
    """

    def __init__(self, ssh_conn, node_name: str, ip: str) -> None:
        self.conn = ssh_conn
        self.node_name = node_name
        self.ip = ip

    def get_ip(self) -> str:
        """Return the ip address this host was created with."""
        return self.ip

    def __str__(self) -> str:
        return self.node_name

    def put_to_file(self, path: str, content: bytes) -> None:
        """Write ``content`` to ``path`` on the host over SFTP."""
        with self.conn.open_sftp() as sftp:
            with sftp.open(path, "wb") as fd:
                fd.write(content)

    def run(self, cmd: str, timeout: int = 60, nolog: bool = False) -> str:
        """Execute ``cmd`` on the host and return its combined stdout/stderr.

        Parameters:
            cmd: command line passed to the session's ``exec_command``.
            timeout: maximum number of seconds to wait for the output to end.
            nolog: when true, the command is not written to the debug log.

        Returns:
            Command output decoded as UTF-8 (undecodable bytes are replaced).

        Raises:
            OSError: if the command runs longer than ``timeout`` or exits
                with a non-zero status.
        """
        transport = self.conn.get_transport()
        session = transport.open_session()

        try:
            session.set_combine_stderr(True)

            stime = time.time()

            if not nolog:
                logger.debug("SSH:%s Exec %r", self, cmd)

            session.exec_command(cmd)
            # Short receive timeout so the overall deadline below gets
            # re-checked about once a second even when no data arrives.
            session.settimeout(1)
            session.shutdown_write()

            # Collect raw bytes and decode once at the end, so a multi-byte
            # character split across two recv() calls is decoded correctly.
            raw_output = bytearray()

            while True:
                try:
                    ndata = session.recv(1024)
                    if not ndata:
                        # Empty read: remote side closed the stream.
                        break
                    raw_output.extend(ndata)
                except socket.timeout:
                    pass

                if time.time() - stime > timeout:
                    partial = raw_output.decode("utf8", errors="replace")
                    raise OSError(partial + "\nExecution timeout")

            code = session.recv_exit_status()
        finally:
            # Always release the channel, on success and on error alike.
            session.close()

        output = raw_output.decode("utf8", errors="replace")

        if code != 0:
            templ = "SSH:{0} Cmd {1!r} failed with code {2}. Output: {3}"
            raise OSError(templ.format(self, cmd, code, output))

        return output
76
77
class LocalHost(ISSHHost):
    """Host implementation that operates on the local machine."""

    def __str__(self):
        return "<Local>"

    def get_ip(self) -> str:
        """Return the address used for the local machine."""
        return 'localhost'

    def put_to_file(self, path: str, content: bytes) -> None:
        """Write ``content`` to local file ``path``, creating parent dirs."""
        dir_name = os.path.dirname(path)

        # dirname is '' for a bare file name; os.makedirs('') would raise.
        if dir_name:
            os.makedirs(dir_name, exist_ok=True)

        with open(path, "wb") as fd:
            fd.write(content)

    def run(self, cmd: str, timeout: int = 60, nolog: bool = False) -> str:
        """Execute ``cmd`` in a local shell and return its combined output.

        Parameters:
            cmd: shell command line (executed with ``shell=True``).
            timeout: maximum number of seconds to wait for the command.
            nolog: when true, the command is not written to the debug log.

        Returns:
            stdout and stderr of the command, decoded as UTF-8
            (undecodable bytes are replaced).

        Raises:
            OSError: if the command runs longer than ``timeout`` or exits
                with a non-zero status.
        """
        if not nolog:
            logger.debug("SSH:%s Exec %r", self, cmd)

        proc = subprocess.Popen(cmd, shell=True,
                                stdin=subprocess.PIPE,
                                stdout=subprocess.PIPE,
                                stderr=subprocess.STDOUT)

        try:
            stdout_data, _ = proc.communicate(timeout=timeout)
        except subprocess.TimeoutExpired:
            # Kill the shell, then collect whatever it managed to print.
            proc.kill()
            stdout_data, _ = proc.communicate()
            partial = stdout_data.decode("utf8", errors="replace")
            raise OSError(partial + "\nExecution timeout")

        output = stdout_data.decode("utf8", errors="replace")

        if proc.returncode != 0:
            templ = "SSH:{0} Cmd {1!r} failed with code {2}. Output: {3}"
            raise OSError(templ.format(self, cmd, proc.returncode, output))

        return output
104
105
def connect(conn_url: str, conn_timeout: int = 60) -> ISSHHost:
    """Create a host object for ``conn_url``.

    The literal url ``'local'`` yields a ``LocalHost``; any other value is
    parsed with ``parse_ssh_uri`` and connected with ``ssh_connect`` using
    ``conn_timeout``, and the result is unpacked into ``SSHHost``.
    """
    if conn_url != 'local':
        creds = parse_ssh_uri(conn_url)
        ssh_host_args = ssh_connect(creds, conn_timeout)
        return SSHHost(*ssh_host_args)

    return LocalHost()
111
112
class RPCNode(IRPCNode):
    """Node reachable through an RPC agent connection.

    Stores the agent client as ``conn`` and the node description as ``info``.
    """

    def __init__(self, conn: agent.Client, info: NodeInfo) -> None:
        self.conn = conn
        self.info = info

    def __str__(self) -> str:
        url = self.info.ssh_conn_url
        roles = ",".join(self.info.roles)
        return "<Node: url={!r} roles={!r} hops=/>".format(url, roles)

    def __repr__(self) -> str:
        # The repr is deliberately the same text as the str form.
        return str(self)

    def get_file_content(self, path: str) -> str:
        """Not implemented yet; always raises ``NotImplementedError``."""
        raise NotImplementedError()

    def forward_port(self, ip: str, remote_port: int, local_port: int = None) -> int:
        """Not implemented yet; always raises ``NotImplementedError``."""
        raise NotImplementedError()
138
koder aka kdanilov3b4da8b2016-10-17 00:17:53 +0300139
def setup_rpc(node: ISSHHost, rpc_server_code: bytes, port: int = 0,
              rpc_conn_callback: RPCBeforeConnCallback = None) -> IRPCNode:
    """Upload and start the RPC server on ``node``, then connect to it.

    Parameters:
        node: host used to run commands and upload the server code.
        rpc_server_code: source of the RPC server, written to a temp file.
        port: port passed to the server's ``--listen-addr`` option.
        rpc_conn_callback: optional callable invoked as
            ``rpc_conn_callback(node, port)``; it must return the
            ``(ip, port)`` pair to connect to. When omitted, the node ip and
            the port reported by the server in its ``addr`` setting are used.

    Returns:
        An ``RPCNode`` wrapping the new agent connection and ``node.info``.
    """
    code_file = node.run("mktemp").strip()
    log_file = node.run("mktemp").strip()
    node.put_to_file(code_file, rpc_server_code)

    cmd = "python {code_file} server --listen-addr={listen_ip}:{port} --daemon " + \
          "--show-settings --stdout-file={out_file}"

    # Keyword names must match the template placeholders exactly; a
    # mismatch makes str.format raise KeyError.
    params_js = node.run(cmd.format(code_file=code_file,
                                    listen_ip=node.get_ip(),
                                    out_file=log_file,
                                    port=port)).strip()

    # The server prints its settings as JSON (requested by --show-settings).
    params = json.loads(params_js)
    params['log_file'] = log_file

    if rpc_conn_callback:
        ip, port = rpc_conn_callback(node, port)
    else:
        ip = node.get_ip()
        # Take the text after the last ':' so that an address containing
        # additional colons still yields the port.
        port = int(params['addr'].rsplit(":", 1)[1])

    rpc_conn = agent.connect((ip, port))

    # NOTE(review): the ISSHHost implementations visible in this file do not
    # set an ``info`` attribute -- confirm that callers attach one.
    node.info.params.update(params)
    return RPCNode(rpc_conn, node.info)
koder aka kdanilov22d134e2016-11-08 11:33:19 +0200163
koder aka kdanilov22d134e2016-11-08 11:33:19 +0200164
koder aka kdanilov3d2bc4f2016-11-12 18:31:18 +0200165
166 # class RemoteNode(node_interfaces.IRPCNode):
167# def __init__(self, node_info: node_interfaces.NodeInfo, rpc_conn: agent.RPCClient):
168# self.info = node_info
169# self.rpc = rpc_conn
170#
171 # def get_interface(self, ip: str) -> str:
172 # """Get node external interface for given IP"""
173 # data = self.run("ip a", nolog=True)
174 # curr_iface = None
175 #
176 # for line in data.split("\n"):
177 # match1 = re.match(r"\d+:\s+(?P<name>.*?):\s\<", line)
178 # if match1 is not None:
179 # curr_iface = match1.group('name')
180 #
181 # match2 = re.match(r"\s+inet\s+(?P<ip>[0-9.]+)/", line)
182 # if match2 is not None:
183 # if match2.group('ip') == ip:
184 # assert curr_iface is not None
185 # return curr_iface
186 #
187 # raise KeyError("Can't found interface for ip {0}".format(ip))
188 #
189 # def get_user(self) -> str:
190 # """"get ssh connection username"""
191 # if self.ssh_conn_url == 'local':
192 # return getpass.getuser()
193 # return self.ssh_cred.user
194 #
195 #
196 # def run(self, cmd: str, stdin_data: str = None, timeout: int = 60, nolog: bool = False) -> Tuple[int, str]:
197 # """Run command on node. Will use rpc connection, if available"""
198 #
199 # if self.rpc_conn is None:
200 # return run_over_ssh(self.ssh_conn, cmd,
201 # stdin_data=stdin_data, timeout=timeout,
202 # nolog=nolog, node=self)
203 # assert not stdin_data
204 # proc_id = self.rpc_conn.cli.spawn(cmd)
205 # exit_code = None
206 # output = ""
207 #
208 # while exit_code is None:
209 # exit_code, stdout_data, stderr_data = self.rpc_conn.cli.get_updates(proc_id)
210 # output += stdout_data + stderr_data
211 #
212 # return exit_code, output
213
214