pre-release bug fixes
diff --git a/ssh_utils.py b/ssh_utils.py
index e546b72..c4a18e8 100644
--- a/ssh_utils.py
+++ b/ssh_utils.py
@@ -1,10 +1,12 @@
import re
import time
+import socket
import logging
import os.path
import getpass
+import threading
-import socket
+
import paramiko
@@ -215,6 +217,7 @@
if rrm is not None:
res.__dict__.update(rrm.groupdict())
return res
+
raise ValueError("Can't parse {0!r} as ssh uri value".format(uri))
@@ -224,42 +227,63 @@
return ssh_connect(creds)
-# def get_ssh_runner(uris,
-# conn_func,
-# latest_start_time=None,
-# keep_temp_files=False):
-# logger.debug("Connecting to servers")
+# Exec sessions opened by run_over_ssh() are tracked in this list so that
+# close_all_sessions() can interrupt them; every access to the list is
+# made while holding all_sessions_lock.
+all_sessions = list()
+all_sessions_lock = threading.Lock()
-# with ThreadPoolExecutor(max_workers=16) as executor:
-# connections = list(executor.map(connect, uris))
-# result_queue = Queue.Queue()
-# barrier = get_barrier(len(uris), threaded=True)
-# def closure(obj):
-# ths = []
-# obj.set_result_cb(result_queue.put)
-# params = (obj, barrier, latest_start_time)
-# logger.debug("Start tests")
-# for conn in connections:
-# th = threading.Thread(None, conn_func, None,
-# params + (conn,))
-# th.daemon = True
-# th.start()
-# ths.append(th)
-# for th in ths:
-# th.join()
-# test_result = []
-# while not result_queue.empty():
-# test_result.append(result_queue.get())
-# logger.debug("Done. Closing connection")
-# for conn in connections:
-# conn.close()
-# return test_result
-# return closure
+
+
+def _join_output(chunks):
+    """Concatenate received channel chunks into a text string.
+
+    On Python 2 the chunks are ``str`` and are returned joined as-is; on
+    Python 3 they are ``bytes`` and are decoded as UTF-8, replacing any
+    undecodable bytes so partial or binary output can still be reported.
+    """
+    output = b"".join(chunks)
+    if not isinstance(output, str):
+        output = output.decode("utf-8", "replace")
+    return output
+
+
+def run_over_ssh(conn, cmd, stdin_data=None, timeout=60, nolog=False):
+    """Execute *cmd* over *conn* and return its combined output.
+
+    The command runs in a new session opened on ``conn``'s transport, with
+    stderr merged into stdout. Output is polled with a 1-second receive
+    timeout instead of select(); this should eventually be replaced by a
+    proper select-based implementation.
+
+    While the command runs, its session is registered in ``all_sessions`` so
+    that it can be interrupted from another thread; it is unregistered again
+    once it has been closed.
+
+    :param conn: connected client exposing ``get_transport()`` (presumably a
+        ``paramiko.SSHClient`` -- confirm against callers).
+    :param cmd: command line to execute remotely.
+    :param stdin_data: optional data sent to the command's stdin before the
+        write side of the channel is shut down.
+    :param timeout: overall limit in seconds, measured from just before the
+        command is started.
+    :param nolog: when true, the command line is not logged.
+    :returns: combined stdout/stderr of the command as text.
+    :raises OSError: if the timeout expires, or if the command exits with a
+        non-zero status; the message includes the output collected so far.
+    """
+    transport = conn.get_transport()
+    session = transport.open_session()
+
+    with all_sessions_lock:
+        all_sessions.append(session)
+
+    chunks = []
+    try:
+        session.set_combine_stderr(True)
+
+        stime = time.time()
+
+        if not nolog:
+            logger.debug("SSH: Exec %r", cmd)
+
+        session.exec_command(cmd)
+
+        if stdin_data is not None:
+            session.sendall(stdin_data)
+
+        # Short receive timeout so the overall deadline below is re-checked
+        # at least once per second even while the command prints nothing.
+        session.settimeout(1)
+        session.shutdown_write()
+
+        while True:
+            try:
+                ndata = session.recv(1024)
+            except socket.timeout:
+                pass
+            else:
+                # An empty chunk ("" on Python 2, b"" on Python 3) means EOF.
+                if not ndata:
+                    break
+                chunks.append(ndata)
+
+            if time.time() - stime > timeout:
+                raise OSError(_join_output(chunks) + "\nExecution timeout")
+
+        code = session.recv_exit_status()
+    finally:
+        session.close()
+        # Unregister the finished session so the registry does not grow
+        # without bound; close_all_sessions() may already have removed it.
+        with all_sessions_lock:
+            try:
+                all_sessions.remove(session)
+            except ValueError:
+                pass
+
+    output = _join_output(chunks)
+
+    if code != 0:
+        templ = "Cmd {0!r} failed with code {1}. Output: {2}"
+        raise OSError(templ.format(cmd, code, output))
+
+    return output
+
+
+def close_all_sessions():
+    """Interrupt and close every session registered in ``all_sessions``.
+
+    Each session is sent a Ctrl-C byte (0x03) and then closed. This is
+    best-effort: a failure on one session is logged at debug level and does
+    not prevent the remaining sessions from being handled. Closing is
+    attempted even when sending the interrupt fails. The registry is
+    emptied afterwards so closed sessions are no longer retained.
+    """
+    with all_sessions_lock:
+        for session in all_sessions:
+            try:
+                session.sendall('\x03')
+            except Exception:
+                logger.debug("Failed to interrupt SSH session",
+                             exc_info=True)
+            try:
+                session.close()
+            except Exception:
+                logger.debug("Failed to close SSH session", exc_info=True)
+        # Slice deletion keeps the same list object (other code holds a
+        # reference to it) and works on both Python 2 and 3.
+        del all_sessions[:]