koder aka kdanilov | e7e1a4d | 2016-12-17 20:29:52 +0200 | [diff] [blame] | 1 | import os |
| 2 | import contextlib |
| 3 | from unittest.mock import patch |
| 4 | from typing import Iterator |
koder aka kdanilov | 39e449e | 2016-12-17 15:15:26 +0200 | [diff] [blame] | 5 | |
koder aka kdanilov | 39e449e | 2016-12-17 15:15:26 +0200 | [diff] [blame] | 6 | |
koder aka kdanilov | e7e1a4d | 2016-12-17 20:29:52 +0200 | [diff] [blame] | 7 | from wally import ssh_utils, ssh, node, node_interfaces |
koder aka kdanilov | 39e449e | 2016-12-17 15:15:26 +0200 | [diff] [blame] | 8 | |
| 9 | |
# Sample SSH credentials string in "user@host" form.
# NOTE(review): not referenced by the code visible here -- test_ssh_url_parser
# and conn_ctx each bind their own local `creds`, shadowing this name. Confirm
# whether anything else in the file/package uses it before removing.
creds = "root@osd-0"
| 11 | |
| 12 | |
def test_ssh_url_parser():
    """Check that ssh_utils.parse_ssh_uri decodes every supported URI form.

    Every case pairs a URI with the ConnCreds it is expected to parse to; the
    user, port, host, key file and password fields are compared one by one.
    getpass.getuser is patched so that URIs without an explicit user resolve
    to a fixed placeholder rather than the account running the tests.
    """
    fallback_user = "default_user"

    cases = [
        # host[:port[:key_file]] -- an empty port segment still yields 22
        ("test", ssh_utils.ConnCreds("test", fallback_user, port=22)),
        ("test:13", ssh_utils.ConnCreds("test", fallback_user, port=13)),
        ("test::xxx.key", ssh_utils.ConnCreds("test", fallback_user, port=22, key_file="xxx.key")),
        ("test:123:xxx.key", ssh_utils.ConnCreds("test", fallback_user, port=123, key_file="xxx.key")),
        # same shapes with an explicit user
        ("user@test", ssh_utils.ConnCreds("test", "user", port=22)),
        ("user@test:13", ssh_utils.ConnCreds("test", "user", port=13)),
        ("user@test::xxx.key", ssh_utils.ConnCreds("test", "user", port=22, key_file="xxx.key")),
        ("user@test:123:xxx.key", ssh_utils.ConnCreds("test", "user", port=123, key_file="xxx.key")),
        # user:password@host -- everything after the first ':' up to '@' is
        # the password, so the password itself may contain ':'
        ("user:passwd@test", ssh_utils.ConnCreds("test", "user", port=22, passwd="passwd")),
        ("user:passwd:@test", ssh_utils.ConnCreds("test", "user", port=22, passwd="passwd:")),
        ("user:passwd:@test:123", ssh_utils.ConnCreds("test", "user", port=123, passwd="passwd:")),
    ]

    for uri, want in cases:
        with patch('getpass.getuser', lambda: fallback_user):
            got = ssh_utils.parse_ssh_uri(uri)

        # The URI is the failure message so a broken case is easy to spot.
        assert got.user == want.user, uri
        assert got.addr.port == want.addr.port, uri
        assert got.addr.host == want.addr.host, uri
        assert got.key_file == want.key_file, uri
        assert got.passwd == want.passwd, uri
| 39 | |
| 40 | |
# Target host for the live-connection tests. test_ssh_connect compares the
# remote `ls -1 /` against the local os.listdir("/"), and test_file_copy reads
# the uploaded file back with the local open(), so this must refer to the
# machine the tests run on.
CONNECT_URI = "localhost"
| 42 | |
| 43 | |
@contextlib.contextmanager
def conn_ctx(uri, *args):
    """Open a connection to *uri* for the duration of a ``with`` block.

    The URI is parsed with ``ssh_utils.parse_ssh_uri`` and wrapped in a
    ``node_interfaces.NodeInfo`` (with an empty set as its second argument);
    any extra positional *args* are forwarded unchanged to ``node.connect``.

    Yields:
        The connection object returned by ``node.connect``. It is
        disconnected when the block exits, including when the body raises.
    """
    # Honour the caller's URI; this used to parse CONNECT_URI unconditionally,
    # silently ignoring the `uri` argument.
    creds = ssh_utils.parse_ssh_uri(uri)
    node_info = node_interfaces.NodeInfo(creds, set())
    # Connect outside the try: if connect() itself raises there is no
    # connection to tear down.
    conn = node.connect(node_info, *args)
    try:
        yield conn
    finally:
        conn.disconnect()
| 53 | |
| 54 | |
def test_ssh_connect():
    """Running ``ls -1 /`` over the connection must list the local root dir.

    Only meaningful when CONNECT_URI points at the machine running the test,
    since the remote listing is compared with the local ``os.listdir("/")``.
    """
    with conn_ctx(CONNECT_URI) as conn:
        # `ls -1` prints one entry per line, so split on line breaks rather
        # than on any whitespace -- names containing spaces must stay whole.
        listed = {line for line in conn.run("ls -1 /").splitlines() if line}
        # Plain `ls` omits dot-entries, so drop them from the local side too.
        expected = {name for name in os.listdir("/") if not name.startswith('.')}
        assert listed == expected
| 58 | |
| 59 | |
def test_ssh_complex():
    """Placeholder for a more involved SSH scenario; currently a no-op."""
| 62 | |
| 63 | |
def test_file_copy():
    """put_to_file must create a file, overwrite it, and truncate on shorter writes.

    The uploaded content is read back through the local filesystem, so this
    relies on CONNECT_URI being the local machine. The file is removed at the
    end (best effort) so repeated runs do not leave debris behind.
    """
    data1 = b"-" * 1024
    data2 = b"+" * 1024

    def read_back(fname):
        # Use a context manager so each read closes its handle immediately
        # instead of leaking it until garbage collection.
        with open(fname, 'rb') as fd:
            return fd.read()

    with conn_ctx(CONNECT_URI) as conn:
        # Passing None lets put_to_file choose the destination; it returns the
        # path it wrote to (presumably a fresh temp file -- confirm in node).
        path = conn.put_to_file(None, data1)
        try:
            assert data1 == read_back(path)

            # Writing to an explicit path must return that same path.
            assert path == conn.put_to_file(path, data2)
            assert data2 == read_back(path)

            # A shorter payload must replace the old content entirely: without
            # truncation the file would still hold all len(data2) bytes.
            assert len(data2) > 10
            assert path == conn.put_to_file(path, data2[10:])
            assert data2[10:] == read_back(path)
        finally:
            # Best-effort cleanup; never let it mask a real assertion failure.
            with contextlib.suppress(OSError):
                os.remove(path)