fixes, add complete test suite
diff --git a/ssh_utils.py b/ssh_utils.py
index a7dda3f..7c859cf 100644
--- a/ssh_utils.py
+++ b/ssh_utils.py
@@ -1,20 +1,66 @@
import re
+import time
import Queue
import logging
import os.path
-import traceback
+import getpass
import threading
+import socket
+import paramiko
from concurrent.futures import ThreadPoolExecutor
-import itest
-from utils import ssh_connect
-from utils import get_barrier, log_error, wait_on_barrier
+from utils import get_barrier
logger = logging.getLogger("io-perf-tool")
conn_uri_attrs = ("user", "passwd", "host", "port", "path")
+def ssh_connect(creds, retry_count=60, timeout=1):
+    """Open an SSH connection to ``creds.host``, retrying on socket errors.
+
+    Authentication is selected once, in this order of preference:
+
+    1. password, when ``creds.passwd`` is not None (SSH agent disabled);
+    2. the key file named by ``creds.key_file``, when it is not None;
+    3. the default key ``~/.ssh/id_rsa``.
+
+    ``look_for_keys`` is disabled in every mode, so only the credential
+    chosen above is offered to the server.
+
+    :param creds: object exposing ``user``, ``passwd``, ``host``, ``port``
+        and ``key_file`` attributes. ``user`` falls back to the current
+        OS user when it is None.
+    :param retry_count: maximum number of connection attempts. Values
+        below 1 are treated as 1 so that the function never returns None.
+    :param timeout: seconds to sleep between failed attempts. This is the
+        retry delay, not a socket timeout.
+    :returns: a connected ``paramiko.SSHClient``.
+    :raises socket.error: when the final attempt fails.
+
+    Any other exception raised by ``connect`` (for example
+    ``paramiko.PasswordRequiredException``) is not retried and propagates
+    to the caller immediately.
+    """
+    ssh = paramiko.SSHClient()
+    # NOTE(security): an empty host-key store plus AutoAddPolicy means any
+    # host key is accepted without verification. Presumably intended for
+    # trusted test environments only - confirm before wider use.
+    ssh.load_host_keys('/dev/null')
+    ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
+    # NOTE(review): kept from the original; nothing in this function reads
+    # the attribute - confirm whether any caller depends on it.
+    ssh.known_hosts = None
+
+    # The user and the authentication arguments do not change between
+    # attempts, so resolve them once instead of on every retry.
+    user = creds.user
+    if user is None:
+        user = getpass.getuser()
+
+    if creds.passwd is not None:
+        connect_kwargs = dict(username=user,
+                              password=creds.passwd,
+                              port=creds.port,
+                              allow_agent=False,
+                              look_for_keys=False)
+    else:
+        key_file = creds.key_file
+        if key_file is None:
+            key_file = os.path.expanduser('~/.ssh/id_rsa')
+        connect_kwargs = dict(username=user,
+                              key_filename=key_file,
+                              look_for_keys=False,
+                              port=creds.port)
+
+    # Always make at least one attempt: with retry_count <= 0 the original
+    # loop body never ran and the function silently returned None.
+    attempts = max(retry_count, 1)
+
+    connected = False
+    try:
+        for attempt in range(attempts):
+            try:
+                ssh.connect(creds.host, **connect_kwargs)
+                connected = True
+                return ssh
+            except socket.error:
+                # Only network-level failures are retried; give up and
+                # re-raise once the last attempt has failed.
+                if attempt == attempts - 1:
+                    raise
+                time.sleep(timeout)
+    finally:
+        # Release the client whenever no connection is handed back to the
+        # caller, so a partially established session is not leaked.
+        if not connected:
+            ssh.close()
+
+
def normalize_dirpath(dirpath):
while dirpath.endswith("/"):
dirpath = dirpath[:-1]
@@ -178,21 +224,8 @@
return ssh_connect(creds)
-def conn_func(obj, barrier, latest_start_time, conn):
- try:
- test_iter = itest.run_test_iter(obj, conn)
- next(test_iter)
-
- wait_on_barrier(barrier, latest_start_time)
-
- with log_error("!Run test"):
- return next(test_iter)
- except:
- print traceback.format_exc()
- raise
-
-
def get_ssh_runner(uris,
+ conn_func,
latest_start_time=None,
keep_temp_files=False):
logger.debug("Connecting to servers")