blob: 43cc9bde6f8339790668278ad399b433e7611bc3 [file] [log] [blame]
import os
import contextlib
from unittest.mock import patch
from typing import Iterator
from wally import ssh_utils, ssh, node, node_interfaces
creds = "root@osd-0"
def test_ssh_url_parser():
default_user = "default_user"
creds = [
("test", ssh_utils.ConnCreds("test", default_user, port=22)),
("test:13", ssh_utils.ConnCreds("test", default_user, port=13)),
("test::xxx.key", ssh_utils.ConnCreds("test", default_user, port=22, key_file="xxx.key")),
("test:123:xxx.key", ssh_utils.ConnCreds("test", default_user, port=123, key_file="xxx.key")),
("user@test", ssh_utils.ConnCreds("test", "user", port=22)),
("user@test:13", ssh_utils.ConnCreds("test", "user", port=13)),
("user@test::xxx.key", ssh_utils.ConnCreds("test", "user", port=22, key_file="xxx.key")),
("user@test:123:xxx.key", ssh_utils.ConnCreds("test", "user", port=123, key_file="xxx.key")),
("user:passwd@test", ssh_utils.ConnCreds("test", "user", port=22, passwd="passwd")),
("user:passwd:@test", ssh_utils.ConnCreds("test", "user", port=22, passwd="passwd:")),
("user:passwd:@test:123", ssh_utils.ConnCreds("test", "user", port=123, passwd="passwd:"))
]
for uri, expected in creds:
with patch('getpass.getuser', lambda : default_user):
parsed = ssh_utils.parse_ssh_uri(uri)
assert parsed.user == expected.user, uri
assert parsed.addr.port == expected.addr.port, uri
assert parsed.addr.host == expected.addr.host, uri
assert parsed.key_file == expected.key_file, uri
assert parsed.passwd == expected.passwd, uri
CONNECT_URI = "localhost"
@contextlib.contextmanager
def conn_ctx(uri, *args):
creds = ssh_utils.parse_ssh_uri(CONNECT_URI)
node_info = node_interfaces.NodeInfo(creds, set())
conn = node.connect(node_info, *args)
try:
yield conn
finally:
conn.disconnect()
def test_ssh_connect():
with conn_ctx(CONNECT_URI) as conn:
assert set(conn.run("ls -1 /").split()) == set(fname for fname in os.listdir("/") if not fname.startswith('.'))
def test_ssh_complex():
pass
def test_file_copy():
data1 = b"-" * 1024
data2 = b"+" * 1024
with conn_ctx(CONNECT_URI) as conn:
path = conn.put_to_file(None, data1)
assert data1 == open(path, 'rb').read()
assert path == conn.put_to_file(path, data2)
assert data2 == open(path, 'rb').read()
assert len(data2) > 10
assert path == conn.put_to_file(path, data2[10:])
assert data2[10:] == open(path, 'rb').read()