blob: 43cc9bde6f8339790668278ad399b433e7611bc3 [file] [log] [blame]
koder aka kdanilove7e1a4d2016-12-17 20:29:52 +02001import os
2import contextlib
3from unittest.mock import patch
4from typing import Iterator
koder aka kdanilov39e449e2016-12-17 15:15:26 +02005
koder aka kdanilov39e449e2016-12-17 15:15:26 +02006
koder aka kdanilove7e1a4d2016-12-17 20:29:52 +02007from wally import ssh_utils, ssh, node, node_interfaces
koder aka kdanilov39e449e2016-12-17 15:15:26 +02008
9
# NOTE(review): this module-level value is not referenced anywhere in the visible
# code; functions below that use a local variable named ``creds`` shadow it.
# Candidate for removal - confirm nothing else imports it.
creds = "root@osd-0"
11
12
def test_ssh_url_parser():
    """Check that ``ssh_utils.parse_ssh_uri`` decodes every supported URI form.

    Each case pairs a URI string with the ``ConnCreds`` it is expected to parse
    into. ``getpass.getuser`` is patched while parsing so that URIs without an
    explicit user are expected to resolve to a predictable user name.
    """
    fallback_user = "default_user"

    cases = [
        # host[:port[:key_file]] - no user given, expected to use the patched getuser()
        ("test", ssh_utils.ConnCreds("test", fallback_user, port=22)),
        ("test:13", ssh_utils.ConnCreds("test", fallback_user, port=13)),
        ("test::xxx.key", ssh_utils.ConnCreds("test", fallback_user, port=22, key_file="xxx.key")),
        ("test:123:xxx.key", ssh_utils.ConnCreds("test", fallback_user, port=123, key_file="xxx.key")),
        # user@host[:port[:key_file]]
        ("user@test", ssh_utils.ConnCreds("test", "user", port=22)),
        ("user@test:13", ssh_utils.ConnCreds("test", "user", port=13)),
        ("user@test::xxx.key", ssh_utils.ConnCreds("test", "user", port=22, key_file="xxx.key")),
        ("user@test:123:xxx.key", ssh_utils.ConnCreds("test", "user", port=123, key_file="xxx.key")),
        # user:passwd@host[:port] - everything after the first ':' up to '@' is the password
        ("user:passwd@test", ssh_utils.ConnCreds("test", "user", port=22, passwd="passwd")),
        ("user:passwd:@test", ssh_utils.ConnCreds("test", "user", port=22, passwd="passwd:")),
        ("user:passwd:@test:123", ssh_utils.ConnCreds("test", "user", port=123, passwd="passwd:")),
    ]

    # Expected values are built above, outside the patch; only parsing runs patched.
    with patch('getpass.getuser', lambda: fallback_user):
        for uri, want in cases:
            got = ssh_utils.parse_ssh_uri(uri)

            # Compare field by field; the URI is the failure message.
            assert got.user == want.user, uri
            assert got.addr.port == want.addr.port, uri
            assert got.addr.host == want.addr.host, uri
            assert got.key_file == want.key_file, uri
            assert got.passwd == want.passwd, uri
39
40
# SSH target used by the connection tests. It must be "localhost": the tests
# verify remote results with local calls (os.listdir("/"), open(path)), which
# only works when the remote host is this same machine.
CONNECT_URI = "localhost"
42
43
@contextlib.contextmanager
def conn_ctx(uri: str, *args) -> Iterator:
    """Connect to ``uri`` for the duration of a ``with`` block.

    :param uri: SSH URI in any form accepted by ``ssh_utils.parse_ssh_uri``.
    :param args: extra positional arguments forwarded unchanged to ``node.connect``.
    :yields: the connection object returned by ``node.connect``; its
        ``disconnect()`` is called when the block exits, even on error.
    """
    # Parse the caller-supplied URI. The previous version parsed the module-level
    # CONNECT_URI here and silently ignored ``uri``.
    creds = ssh_utils.parse_ssh_uri(uri)
    # Second NodeInfo argument is an empty set (presumably the node's roles -
    # confirm against node_interfaces.NodeInfo).
    node_info = node_interfaces.NodeInfo(creds, set())
    conn = node.connect(node_info, *args)
    try:
        yield conn
    finally:
        conn.disconnect()
53
54
def test_ssh_connect():
    """Listing ``/`` over SSH must match the local ``os.listdir("/")``.

    ``CONNECT_URI`` targets localhost, so the remote and local root directories
    are the same. Dot-entries are dropped from the local listing because ``ls``
    without ``-a`` does not print them.
    """
    expected = {fname for fname in os.listdir("/") if not fname.startswith('.')}

    with conn_ctx(CONNECT_URI) as conn:
        # ``ls -1`` prints exactly one entry per line. Split on line breaks only,
        # so an entry whose name contains spaces is not broken into pieces.
        listed = set(conn.run("ls -1 /").splitlines())

    assert listed == expected
58
59
def test_ssh_complex():
    """Placeholder for a multi-step SSH scenario; it currently checks nothing."""
62
63
def _read_local_bytes(path: str) -> bytes:
    """Return the whole content of local file ``path``, closing the handle afterwards."""
    with open(path, 'rb') as fd:
        return fd.read()


def test_file_copy():
    """Check that ``put_to_file`` creates, overwrites and truncates a file.

    ``put_to_file(None, data)`` is expected to create a new file (presumably a
    temporary one) and return its path. Writing to that path again must return
    the same path and replace the contents completely, including when the new
    data is shorter than the old. The connection targets localhost
    (``CONNECT_URI``), so the file is read back with a local ``open``.
    """
    data1 = b"-" * 1024
    data2 = b"+" * 1024
    # The truncation step drops a 10-byte prefix, so data2 must be longer than that.
    assert len(data2) > 10

    with conn_ctx(CONNECT_URI) as conn:
        path = conn.put_to_file(None, data1)
        try:
            assert data1 == _read_local_bytes(path)

            # Same-size overwrite: the path must be reused and the content replaced.
            assert path == conn.put_to_file(path, data2)
            assert data2 == _read_local_bytes(path)

            # Shorter overwrite: no stale tail from the previous content may remain.
            assert path == conn.put_to_file(path, data2[10:])
            assert data2[10:] == _read_local_bytes(path)
        finally:
            # The file exists only for this test; remove it so runs leave no litter.
            with contextlib.suppress(FileNotFoundError):
                os.remove(path)