typing and refactoring on the way
diff --git a/wally/ssh_utils.py b/wally/ssh_utils.py
index ada4af6..2941e7c 100644
--- a/wally/ssh_utils.py
+++ b/wally/ssh_utils.py
@@ -1,298 +1,25 @@
 import re
+import json
 import time
-import errno
-import random
 import socket
-import shutil
 import logging
 import os.path
 import getpass
-import StringIO
-import threading
+from io import BytesIO
 import subprocess
+from typing import Union, Optional, cast, Dict, List, Tuple, Any, Callable
+from concurrent.futures import ThreadPoolExecutor
 
 import paramiko
 
+import agent
+
+from . import interfaces, utils
+
 
 logger = logging.getLogger("wally")
 
 
-class Local(object):
-    "simulate ssh connection to local"
-    @classmethod
-    def open_sftp(cls):
-        return cls()
-
-    @classmethod
-    def mkdir(cls, remotepath, mode=None):
-        os.mkdir(remotepath)
-        if mode is not None:
-            os.chmod(remotepath, mode)
-
-    @classmethod
-    def put(cls, localfile, remfile):
-        dirname = os.path.dirname(remfile)
-        if not os.path.exists(dirname):
-            os.makedirs(dirname)
-        shutil.copyfile(localfile, remfile)
-
-    @classmethod
-    def get(cls, remfile, localfile):
-        dirname = os.path.dirname(localfile)
-        if not os.path.exists(dirname):
-            os.makedirs(dirname)
-        shutil.copyfile(remfile, localfile)
-
-    @classmethod
-    def chmod(cls, path, mode):
-        os.chmod(path, mode)
-
-    @classmethod
-    def copytree(cls, src, dst):
-        shutil.copytree(src, dst)
-
-    @classmethod
-    def remove(cls, path):
-        os.unlink(path)
-
-    @classmethod
-    def close(cls):
-        pass
-
-    @classmethod
-    def open(cls, *args, **kwarhgs):
-        return open(*args, **kwarhgs)
-
-    @classmethod
-    def stat(cls, path):
-        return os.stat(path)
-
-    def __enter__(self):
-        return self
-
-    def __exit__(self, x, y, z):
-        return False
-
-
-NODE_KEYS = {}
-
-
-def exists(sftp, path):
-    "os.path.exists for paramiko's SCP object"
-    try:
-        sftp.stat(path)
-        return True
-    except IOError as e:
-        if e.errno == errno.ENOENT:
-            return False
-        raise
-
-
-def set_key_for_node(host_port, key):
-    sio = StringIO.StringIO(key)
-    NODE_KEYS[host_port] = paramiko.RSAKey.from_private_key(sio)
-    sio.close()
-
-
-def ssh_connect(creds, conn_timeout=60, reuse_conn=None):
-    if creds == 'local':
-        return Local()
-
-    tcp_timeout = 15
-    banner_timeout = 30
-
-    if reuse_conn is None:
-        ssh = paramiko.SSHClient()
-        ssh.load_host_keys('/dev/null')
-        ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
-        ssh.known_hosts = None
-    else:
-        ssh = reuse_conn
-
-    etime = time.time() + conn_timeout
-
-    while True:
-        try:
-            tleft = etime - time.time()
-            c_tcp_timeout = min(tcp_timeout, tleft)
-
-            if paramiko.__version_info__ >= (1, 15, 2):
-                banner_timeout = {'banner_timeout': min(banner_timeout, tleft)}
-            else:
-                banner_timeout = {}
-
-            if creds.passwd is not None:
-                ssh.connect(creds.host,
-                            timeout=c_tcp_timeout,
-                            username=creds.user,
-                            password=creds.passwd,
-                            port=creds.port,
-                            allow_agent=False,
-                            look_for_keys=False,
-                            **banner_timeout)
-            elif creds.key_file is not None:
-                ssh.connect(creds.host,
-                            username=creds.user,
-                            timeout=c_tcp_timeout,
-                            key_filename=creds.key_file,
-                            look_for_keys=False,
-                            port=creds.port,
-                            **banner_timeout)
-            elif (creds.host, creds.port) in NODE_KEYS:
-                ssh.connect(creds.host,
-                            username=creds.user,
-                            timeout=c_tcp_timeout,
-                            pkey=NODE_KEYS[(creds.host, creds.port)],
-                            look_for_keys=False,
-                            port=creds.port,
-                            **banner_timeout)
-            else:
-                key_file = os.path.expanduser('~/.ssh/id_rsa')
-                ssh.connect(creds.host,
-                            username=creds.user,
-                            timeout=c_tcp_timeout,
-                            key_filename=key_file,
-                            look_for_keys=False,
-                            port=creds.port,
-                            **banner_timeout)
-            return ssh
-        except paramiko.PasswordRequiredException:
-            raise
-        except (socket.error, paramiko.SSHException):
-            if time.time() > etime:
-                raise
-            time.sleep(1)
-
-
-def save_to_remote(sftp, path, content):
-    with sftp.open(path, "wb") as fd:
-        fd.write(content)
-
-
-def read_from_remote(sftp, path):
-    with sftp.open(path, "rb") as fd:
-        return fd.read()
-
-
-def normalize_dirpath(dirpath):
-    while dirpath.endswith("/"):
-        dirpath = dirpath[:-1]
-    return dirpath
-
-
-ALL_RWX_MODE = ((1 << 9) - 1)
-
-
-def ssh_mkdir(sftp, remotepath, mode=ALL_RWX_MODE, intermediate=False):
-    remotepath = normalize_dirpath(remotepath)
-    if intermediate:
-        try:
-            sftp.mkdir(remotepath, mode=mode)
-        except (IOError, OSError):
-            upper_dir = remotepath.rsplit("/", 1)[0]
-
-            if upper_dir == '' or upper_dir == '/':
-                raise
-
-            ssh_mkdir(sftp, upper_dir, mode=mode, intermediate=True)
-            return sftp.mkdir(remotepath, mode=mode)
-    else:
-        sftp.mkdir(remotepath, mode=mode)
-
-
-def ssh_copy_file(sftp, localfile, remfile, preserve_perm=True):
-    sftp.put(localfile, remfile)
-    if preserve_perm:
-        sftp.chmod(remfile, os.stat(localfile).st_mode & ALL_RWX_MODE)
-
-
-def put_dir_recursively(sftp, localpath, remotepath, preserve_perm=True):
-    "upload local directory to remote recursively"
-
-    # hack for localhost connection
-    if hasattr(sftp, "copytree"):
-        sftp.copytree(localpath, remotepath)
-        return
-
-    assert remotepath.startswith("/"), "%s must be absolute path" % remotepath
-
-    # normalize
-    localpath = normalize_dirpath(localpath)
-    remotepath = normalize_dirpath(remotepath)
-
-    try:
-        sftp.chdir(remotepath)
-        localsuffix = localpath.rsplit("/", 1)[1]
-        remotesuffix = remotepath.rsplit("/", 1)[1]
-        if localsuffix != remotesuffix:
-            remotepath = os.path.join(remotepath, localsuffix)
-    except IOError:
-        pass
-
-    for root, dirs, fls in os.walk(localpath):
-        prefix = os.path.commonprefix([localpath, root])
-        suffix = root.split(prefix, 1)[1]
-        if suffix.startswith("/"):
-            suffix = suffix[1:]
-
-        remroot = os.path.join(remotepath, suffix)
-
-        try:
-            sftp.chdir(remroot)
-        except IOError:
-            if preserve_perm:
-                mode = os.stat(root).st_mode & ALL_RWX_MODE
-            else:
-                mode = ALL_RWX_MODE
-            ssh_mkdir(sftp, remroot, mode=mode, intermediate=True)
-            sftp.chdir(remroot)
-
-        for f in fls:
-            remfile = os.path.join(remroot, f)
-            localfile = os.path.join(root, f)
-            ssh_copy_file(sftp, localfile, remfile, preserve_perm)
-
-
-def delete_file(conn, path):
-    sftp = conn.open_sftp()
-    sftp.remove(path)
-    sftp.close()
-
-
-def copy_paths(conn, paths):
-    sftp = conn.open_sftp()
-    try:
-        for src, dst in paths.items():
-            try:
-                if os.path.isfile(src):
-                    ssh_copy_file(sftp, src, dst)
-                elif os.path.isdir(src):
-                    put_dir_recursively(sftp, src, dst)
-                else:
-                    templ = "Can't copy {0!r} - " + \
-                            "it neither a file not a directory"
-                    raise OSError(templ.format(src))
-            except Exception as exc:
-                tmpl = "Scp {0!r} => {1!r} failed - {2!r}"
-                raise OSError(tmpl.format(src, dst, exc))
-    finally:
-        sftp.close()
-
-
-class ConnCreds(object):
-    conn_uri_attrs = ("user", "passwd", "host", "port", "path")
-
-    def __init__(self):
-        for name in self.conn_uri_attrs:
-            setattr(self, name, None)
-
-    def __str__(self):
-        return str(self.__dict__)
-
-
-uri_reg_exprs = []
-
-
 class URIsNamespace(object):
     class ReParts(object):
         user_rr = "[^:]*?"
@@ -323,11 +50,35 @@
         "^{user_rr}:{passwd_rr}@{host_rr}:{port_rr}$",
     ]
 
+    uri_reg_exprs = []  # type: List[str]
     for templ in templs:
         uri_reg_exprs.append(templ.format(**re_dct))
 
 
-def parse_ssh_uri(uri):
+class ConnCreds:
+    """Connection parameters for ssh_connect, normally built by parse_ssh_uri."""
+
+    conn_uri_attrs = ("user", "passwd", "host", "port", "key_file")
+
+    def __init__(self) -> None:
+        self.user = None  # type: Optional[str]
+        self.passwd = None  # type: Optional[str]
+        self.host = None  # type: str
+        self.port = 22  # type: int
+        self.key_file = None  # type: Optional[str]
+
+    def __str__(self) -> str:
+        shown = dict(self.__dict__)
+        # never leak the password into logs or error messages
+        if shown.get('passwd') is not None:
+            shown['passwd'] = '***'
+        return str(shown)
+
+
+SSHCredsType = Union[str, ConnCreds]
+
+
+def parse_ssh_uri(uri: str) -> ConnCreds:
     # [ssh://]+
     # user:passwd@ip_host:port
     # user:passwd@ip_host
@@ -344,12 +89,12 @@
         uri = uri[len("ssh://"):]
 
     res = ConnCreds()
-    res.port = "22"
+    res.port = 22
     res.key_file = None
     res.passwd = None
     res.user = getpass.getuser()
 
-    for rr in uri_reg_exprs:
+    for rr in URIsNamespace.uri_reg_exprs:
         rrm = re.match(rr, uri)
         if rrm is not None:
             res.__dict__.update(rrm.groupdict())
@@ -358,18 +103,202 @@
     raise ValueError("Can't parse {0!r} as ssh uri value".format(uri))
 
 
-def reconnect(conn, uri, **params):
+class LocalHost(interfaces.IHost):
+    """IHost implementation that runs commands and writes files locally."""
+
+    def __str__(self):
+        return "<Local>"
+
+    def get_ip(self) -> str:
+        return 'localhost'
+
+    def put_to_file(self, path: str, content: bytes) -> None:
+        """Write *content* to *path*, creating missing parent directories."""
+        dirname = os.path.dirname(path)
+        # a bare file name has dirname '', and os.makedirs('') would raise
+        if dirname and not os.path.exists(dirname):
+            os.makedirs(dirname)
+        with open(path, "wb") as fd:
+            fd.write(content)
+
+    def run(self, cmd: str, timeout: int = 60, nolog: bool = False) -> str:
+        """Run *cmd* through the local shell and return its decoded output.
+
+        stderr is merged into stdout. Raises OSError if the command exits
+        with a non-zero code or does not finish within *timeout* seconds.
+        """
+        if not nolog:
+            logger.debug("SSH:{0} Exec {1!r}".format(self, cmd))
+
+        with subprocess.Popen(cmd, shell=True,
+                              stdin=subprocess.PIPE,
+                              stdout=subprocess.PIPE,
+                              stderr=subprocess.STDOUT) as proc:
+            try:
+                stdout_data, _ = proc.communicate(timeout=timeout)
+            except subprocess.TimeoutExpired:
+                # don't read the pipe again: a child of the shell may still
+                # hold it open; leaving the with-block closes it and reaps proc
+                proc.kill()
+                raise OSError("SSH:{0} Cmd {1!r}\nExecution timeout".format(self, cmd))
+
+        # communicate() returns bytes; decode so callers get the annotated str
+        output = stdout_data.decode("utf8", "replace")
+        if proc.returncode != 0:
+            templ = "SSH:{0} Cmd {1!r} failed with code {2}. Output: {3}"
+            raise OSError(templ.format(self, cmd, proc.returncode, output))
+
+        return output
+
+
+class SSHHost(interfaces.IHost):
+    """IHost implementation on top of a connected paramiko SSHClient."""
+
+    def __init__(self, ssh_conn, node_name: str, ip: str) -> None:
+        self.conn = ssh_conn
+        self.node_name = node_name
+        self.ip = ip
+
+    def get_ip(self) -> str:
+        return self.ip
+
+    def __str__(self) -> str:
+        return self.node_name
+
+    def put_to_file(self, path: str, content: bytes) -> None:
+        """Write *content* to the remote file *path* over SFTP."""
+        with self.conn.open_sftp() as sftp:
+            with sftp.open(path, "wb") as fd:
+                fd.write(content)
+
+    def run(self, cmd: str, timeout: int = 60, nolog: bool = False) -> str:
+        """Execute *cmd* on the remote host and return its decoded output.
+
+        stderr is merged into stdout. Raises OSError if the command exits
+        with a non-zero code or runs longer than *timeout* seconds.
+        """
+        transport = self.conn.get_transport()
+        session = transport.open_session()
+
+        try:
+            session.set_combine_stderr(True)
+
+            stime = time.time()
+
+            if not nolog:
+                logger.debug("SSH:{0} Exec {1!r}".format(self, cmd))
+
+            session.exec_command(cmd)
+            session.settimeout(1)
+            session.shutdown_write()
+
+            # recv() yields bytes on Python 3: collect chunks, decode once
+            chunks = []  # type: List[bytes]
+            while True:
+                try:
+                    ndata = session.recv(1024)
+                    if not ndata:  # zero-length read: remote stream closed
+                        break
+                    chunks.append(ndata)
+                except socket.timeout:
+                    pass
+
+                if time.time() - stime > timeout:
+                    partial = b"".join(chunks).decode("utf8", "replace")
+                    raise OSError(partial + "\nExecution timeout")
+
+            code = session.recv_exit_status()
+        finally:
+            # release the channel on every path, including timeouts and errors
+            session.close()
+
+        output = b"".join(chunks).decode("utf8", "replace")
+        if code != 0:
+            templ = "SSH:{0} Cmd {1!r} failed with code {2}. Output: {3}"
+            raise OSError(templ.format(self, cmd, code, output))
+
+        return output
+
+
+NODE_KEYS = {}  # type: Dict[Tuple[str, int], paramiko.RSAKey]
+
+
+def set_key_for_node(host_port: Tuple[str, int], key: bytes) -> None:
+    """Register the RSA private key *key* for logins to *host_port*."""
+    from io import StringIO
+
+    # paramiko's private-key reader parses text lines, so a BytesIO is
+    # never recognised as a key; decode and give it a text stream instead
+    text = key.decode("utf8") if isinstance(key, bytes) else key
+    with StringIO(text) as sio:
+        NODE_KEYS[host_port] = paramiko.RSAKey.from_private_key(sio)
+
+
+def ssh_connect(creds: SSHCredsType, conn_timeout: int = 60) -> interfaces.IHost:
+    """Open an SSH connection for *creds*, retrying until *conn_timeout* expires.
+
+    ``'local'`` returns a LocalHost. Otherwise the auth method is chosen in
+    this order: password, explicit key file, key registered through
+    set_key_for_node, then ``~/.ssh/id_rsa``. Socket and SSH errors are
+    retried once a second; PasswordRequiredException is re-raised at once.
+    """
+    if creds == 'local':
+        return LocalHost()
+
+    creds = cast(ConnCreds, creds)
+
+    tcp_timeout = 15
+    default_banner_timeout = 30
+
+    ssh = paramiko.SSHClient()
+    ssh.load_host_keys('/dev/null')
+    ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
+    ssh.known_hosts = None
+
+    end_time = time.time() + conn_timeout  # type: float
+
+    while True:
+        # the previous attempt plus the 1s back-off can overrun end_time;
+        # never hand paramiko/socket a zero or negative timeout
+        time_left = max(end_time - time.time(), 1.0)
+
+        connect_kw = {
+            'username': creds.user,
+            'timeout': min(tcp_timeout, time_left),
+            'port': creds.port,
+            'look_for_keys': False,
+        }  # type: Dict[str, Any]
+
+        if paramiko.__version_info__ >= (1, 15, 2):
+            connect_kw['banner_timeout'] = int(min(default_banner_timeout, time_left))
+
+        if creds.passwd is not None:
+            connect_kw['password'] = creds.passwd
+            connect_kw['allow_agent'] = False
+        elif creds.key_file is not None:
+            connect_kw['key_filename'] = creds.key_file
+        elif (creds.host, creds.port) in NODE_KEYS:
+            connect_kw['pkey'] = NODE_KEYS[(creds.host, creds.port)]
+        else:
+            connect_kw['key_filename'] = os.path.expanduser('~/.ssh/id_rsa')
+
+        try:
+            ssh.connect(creds.host, **connect_kw)
+        except paramiko.PasswordRequiredException:
+            ssh.close()
+            raise
+        except (socket.error, paramiko.SSHException):
+            if time.time() > end_time:
+                ssh.close()
+                raise
+            time.sleep(1)
+        else:
+            return SSHHost(ssh, "{0.host}:{0.port}".format(creds), creds.host)
+
+
+def connect(uri: str, **params) -> interfaces.IHost:
     if uri == 'local':
-        return conn
-
-    creds = parse_ssh_uri(uri)
-    creds.port = int(creds.port)
-    return ssh_connect(creds, reuse_conn=conn, **params)
-
-
-def connect(uri, **params):
-    if uri == 'local':
-        res = Local()
+        res = LocalHost()
     else:
         creds = parse_ssh_uri(uri)
         creds.port = int(creds.port)
@@ -377,180 +278,85 @@
     return res
 
 
-all_sessions_lock = threading.Lock()
-all_sessions = {}
+SetupResult = Tuple[interfaces.IRPC, Dict[str, Any]]
 
 
-class BGSSHTask(object):
-    CHECK_RETRY = 5
+# Invoked as callback(node, server_port) right before the RPC connection is
+# opened; must return the (ip, port) to connect to instead of the node's ip.
+RPCBeforeConnCallback = Callable[[interfaces.IHost, int], Tuple[str, int]]
 
-    def __init__(self, node, use_sudo):
-        self.node = node
-        self.pid = None
-        self.use_sudo = use_sudo
 
-    def start(self, orig_cmd, **params):
-        uniq_name = 'test'
-        cmd = "screen -S {0} -d -m {1}".format(uniq_name, orig_cmd)
-        run_over_ssh(self.node.connection, cmd,
-                     timeout=10, node=self.node.get_conn_id(),
-                     **params)
-        processes = run_over_ssh(self.node.connection, "ps aux", nolog=True)
+def setup_rpc(node: interfaces.IHost,
+              rpc_server_code: bytes,
+              port: int = 0,
+              rpc_conn_callback: Optional[RPCBeforeConnCallback] = None) -> SetupResult:
+    """Upload the RPC server code to *node*, start it as a daemon and connect.
+
+    The code and the daemon's stdout go to ``mktemp`` files on the node; the
+    daemon listens on node.get_ip():*port* and prints its settings as JSON.
+    Returns the RPC connection and those settings (plus ``log_file``). If
+    *rpc_conn_callback* is given it receives the node and the port the
+    daemon actually bound to and returns the (ip, port) to connect to;
+    otherwise the node's own ip is used.
+    """
+    code_file = node.run("mktemp").strip()
+    log_file = node.run("mktemp").strip()
+    node.put_to_file(code_file, rpc_server_code)
+    cmd = "python {code_file} server --listen-addr={listen_ip}:{port} --daemon " + \
+          "--show-settings --stdout-file={out_file}"
+    params_js = node.run(cmd.format(code_file=code_file,
+                                    listen_ip=node.get_ip(),
+                                    out_file=log_file,
+                                    port=port)).strip()
+    params = json.loads(params_js)
+    params['log_file'] = log_file
 
-        for iter in range(self.CHECK_RETRY):
-            for proc in processes.split("\n"):
-                if orig_cmd in proc and "SCREEN" not in proc:
-                    self.pid = proc.split()[1]
-                    break
-            if self.pid is not None:
-                break
-            time.sleep(1)
+    # 'addr' is the daemon's "ip:port"; for port=0 (presumably "any free
+    # port") it is the only place the really bound port can be read from
+    server_port = int(params['addr'].split(":")[1])
+    if rpc_conn_callback:
+        ip, port = rpc_conn_callback(node, server_port)
+    else:
+        ip = node.get_ip()
+        port = server_port
 
-        if self.pid is None:
-            self.pid = -1
+    return agent.connect((ip, port)), params
 
-    def check_running(self):
-        assert self.pid is not None
-        if -1 == self.pid:
-            return False
+
+def wait_ssh_awailable(addrs: List[Tuple[str, int]],
+                       timeout: int = 300,
+                       tcp_timeout: float = 1.0,
+                       max_threads: int = 32) -> None:
+    """Wait until a TCP connection can be opened to every address in *addrs*.
+
+    Addresses are probed in parallel by up to *max_threads* threads and the
+    failing ones are re-probed every round; utils.Timeout(timeout).tick()
+    runs between rounds (presumably it sleeps and raises once *timeout*
+    expires - not visible here).
+    """
+    addrs = addrs[:]
+    tout = utils.Timeout(timeout)
+
+    def check_sock(addr):
+        """Return True if a TCP connect to *addr* succeeds within tcp_timeout."""
+        s = socket.socket()
+        s.settimeout(tcp_timeout)
         try:
-            run_over_ssh(self.node.connection,
-                         "ls /proc/{0}".format(self.pid),
-                         timeout=10, nolog=True)
+            s.connect(addr)
             return True
         except OSError:
+            # timeouts, refusals and unreachable hosts all mean "not up yet"
             return False
+        finally:
+            s.close()
 
-    def kill(self, soft=True, use_sudo=True):
-        assert self.pid is not None
-        if self.pid == -1:
-            return True
-        try:
-            if soft:
-                cmd = "kill {0}"
-            else:
-                cmd = "kill -9 {0}"
-
-            if self.use_sudo:
-                cmd = "sudo " + cmd
-
-            run_over_ssh(self.node.connection,
-                         cmd.format(self.pid), nolog=True)
-            return True
-        except OSError:
-            return False
-
-    def wait(self, soft_timeout, timeout):
-        end_of_wait_time = timeout + time.time()
-        soft_end_of_wait_time = soft_timeout + time.time()
-
-        # time_till_check = random.randint(5, 10)
-        time_till_check = 2
-
-        # time_till_first_check = random.randint(2, 6)
-        time_till_first_check = 2
-        time.sleep(time_till_first_check)
-        if not self.check_running():
-            return True
-
-        while self.check_running() and time.time() < soft_end_of_wait_time:
-            # time.sleep(soft_end_of_wait_time - time.time())
-            time.sleep(time_till_check)
-
-        while end_of_wait_time > time.time():
-            time.sleep(time_till_check)
-            if not self.check_running():
-                break
-        else:
-            self.kill()
-            time.sleep(1)
-            if self.check_running():
-                self.kill(soft=False)
-            return False
-        return True
+    with ThreadPoolExecutor(max_workers=max_threads) as pool:
+        while addrs:
+            check_result = pool.map(check_sock, addrs)
+            addrs = [addr for ok, addr in zip(check_result, addrs) if not ok]
+            # skip the tick after the last round: nothing is left to wait for
+            if addrs:
+                tout.tick()
 
 
-def run_over_ssh(conn, cmd, stdin_data=None, timeout=60,
-                 nolog=False, node=None):
-    "should be replaces by normal implementation, with select"
 
-    if isinstance(conn, Local):
-        if not nolog:
-            logger.debug("SSH:local Exec {0!r}".format(cmd))
-        proc = subprocess.Popen(cmd, shell=True,
-                                stdin=subprocess.PIPE,
-                                stdout=subprocess.PIPE,
-                                stderr=subprocess.STDOUT)
-
-        stdoutdata, _ = proc.communicate(input=stdin_data)
-        if proc.returncode != 0:
-            templ = "SSH:{0} Cmd {1!r} failed with code {2}. Output: {3}"
-            raise OSError(templ.format(node, cmd, proc.returncode, stdoutdata))
-
-        return stdoutdata
-
-    transport = conn.get_transport()
-    session = transport.open_session()
-
-    if node is None:
-        node = ""
-
-    with all_sessions_lock:
-        all_sessions[id(session)] = session
-
-    try:
-        session.set_combine_stderr(True)
-
-        stime = time.time()
-
-        if not nolog:
-            logger.debug("SSH:{0} Exec {1!r}".format(node, cmd))
-
-        session.exec_command(cmd)
-
-        if stdin_data is not None:
-            session.sendall(stdin_data)
-
-        session.settimeout(1)
-        session.shutdown_write()
-        output = ""
-
-        while True:
-            try:
-                ndata = session.recv(1024)
-                output += ndata
-                if "" == ndata:
-                    break
-            except socket.timeout:
-                pass
-
-            if time.time() - stime > timeout:
-                raise OSError(output + "\nExecution timeout")
-
-        code = session.recv_exit_status()
-    finally:
-        found = False
-        with all_sessions_lock:
-            if id(session) in all_sessions:
-                found = True
-                del all_sessions[id(session)]
-
-        if found:
-            session.close()
-
-    if code != 0:
-        templ = "SSH:{0} Cmd {1!r} failed with code {2}. Output: {3}"
-        raise OSError(templ.format(node, cmd, code, output))
-
-    return output
-
-
-def close_all_sessions():
-    with all_sessions_lock:
-        for session in all_sessions.values():
-            try:
-                session.sendall('\x03')
-                session.close()
-            except:
-                pass
-        all_sessions.clear()