Fix SSH tests and add RPC tests
diff --git a/tests/test_rpc.py b/tests/test_rpc.py
new file mode 100644
index 0000000..fd35555
--- /dev/null
+++ b/tests/test_rpc.py
@@ -0,0 +1,38 @@
+import contextlib
+
+from wally import ssh_utils, node, node_interfaces
+
+
+# SSH URI (parsed by ssh_utils.parse_ssh_uri) of the host the RPC tests connect to;
+# presumably a reachable SSH server is required there for these tests to run.
+CONNECT_URI = "localhost"
+
+
+@contextlib.contextmanager
+def rpc_conn_ctx(uri):
+    """Yield an RPC connection to ``uri`` and tear everything down on exit.
+
+    Opens an SSH connection described by ``uri``, sets up the RPC server code
+    from ``node.get_rpc_server_code()`` (together with its plugin modules) over
+    it via ``node.setup_rpc`` and yields the resulting RPC connection.
+
+    On exit the RPC server is stopped, the RPC connection is disconnected and
+    the SSH connection is disconnected.  Each step runs even if an earlier one
+    raises, so a failing ``stop()`` cannot leak either connection.
+    """
+    creds = ssh_utils.parse_ssh_uri(uri)
+    rpc_code, modules = node.get_rpc_server_code()
+
+    ssh_conn = node.connect(node_interfaces.NodeInfo(creds, set()))
+    try:
+        rpc_conn = node.setup_rpc(ssh_conn, rpc_code, plugins=modules)
+        try:
+            yield rpc_conn
+        finally:
+            # stop() talks to the remote side and may fail; the local
+            # disconnect must happen regardless.
+            try:
+                rpc_conn.conn.server.stop()
+            finally:
+                rpc_conn.disconnect()
+    finally:
+        ssh_conn.disconnect()
+
+
+def test_rpc_simple():
+    """The RPC server must list its ``server.*`` control methods in rpc_info()."""
+    required = (
+        'server.list_modules',
+        'server.load_module',
+        'server.rpc_info',
+        'server.stop',
+    )
+    with rpc_conn_ctx(CONNECT_URI) as conn:
+        advertised = conn.conn.server.rpc_info()
+        for method in required:
+            assert method in advertised
+
+
+def test_rpc_plugins():
+    """A plugin-provided call (``fs.file_exists``) must work over the RPC connection.
+
+    ``fs`` is presumably supplied by the plugin modules that rpc_conn_ctx passes
+    to ``node.setup_rpc`` — confirm against node.get_rpc_server_code().
+    """
+    with rpc_conn_ctx(CONNECT_URI) as conn:
+        # The assertion message is only evaluated on failure, so the list of
+        # registered RPC methods is fetched for diagnostics instead of being
+        # printed (with an extra round trip) on every run.
+        assert conn.conn.fs.file_exists("/"), conn.conn.server.rpc_info()
diff --git a/tests/test_ssh.py b/tests/test_ssh.py
index efc5f09..43cc9bd 100644
--- a/tests/test_ssh.py
+++ b/tests/test_ssh.py
@@ -1,33 +1,77 @@
-import getpass
+import os
+import contextlib
+from unittest.mock import patch
+from typing import Iterator
-from oktest import ok
-from wally import ssh_utils, ssh
+from wally import ssh_utils, ssh, node, node_interfaces
+# NOTE(review): this module-level creds appears unused here — test_ssh_url_parser and conn_ctx each bind their own local creds. Consider removing it.
creds = "root@osd-0"
def test_ssh_url_parser():
+ """Check ssh_utils.parse_ssh_uri on each URI form in the table below (user, password, host, port, key file)."""
- curr_user = getpass.getuser()
- creds = {
- "test": ssh_utils.ConnCreds("test", curr_user, port=23),
- "test:13": ssh_utils.ConnCreds("test", curr_user, port=13),
- "test::xxx.key": ssh_utils.ConnCreds("test", curr_user, port=23, key_file="xxx.key"),
- "test:123:xxx.key": ssh_utils.ConnCreds("test", curr_user, port=123, key_file="xxx.key"),
- "user@test": ssh_utils.ConnCreds("test", "user", port=23),
- "user@test:13": ssh_utils.ConnCreds("test", "user", port=13),
- "user@test::xxx.key": ssh_utils.ConnCreds("test", "user", port=23, key_file="xxx.key"),
- "user@test:123:xxx.key": ssh_utils.ConnCreds("test", "user", port=123, key_file="xxx.key"),
- "user:passwd:@test": ssh_utils.ConnCreds("test", curr_user, port=23, passwd="passwd:"),
- "user:passwd:@test:123": ssh_utils.ConnCreds("test", curr_user, port=123, passwd="passwd:"),
- }
+ default_user = "default_user"
- for uri, expected in creds.items():
- parsed = ssh_utils.parse_ssh_uri(uri)
- ok(parsed.user) == expected.user
- ok(parsed.addr.port) == expected.addr.port
- ok(parsed.addr.host) == expected.addr.host
- ok(parsed.key_file) == expected.key_file
- ok(parsed.passwd) == expected.passwd
+ # (uri, expected) pairs: forms without a user expect getpass.getuser() (patched to default_user below),
+ # forms without a port expect port 22.
+ creds = [
+ ("test", ssh_utils.ConnCreds("test", default_user, port=22)),
+ ("test:13", ssh_utils.ConnCreds("test", default_user, port=13)),
+ ("test::xxx.key", ssh_utils.ConnCreds("test", default_user, port=22, key_file="xxx.key")),
+ ("test:123:xxx.key", ssh_utils.ConnCreds("test", default_user, port=123, key_file="xxx.key")),
+ ("user@test", ssh_utils.ConnCreds("test", "user", port=22)),
+ ("user@test:13", ssh_utils.ConnCreds("test", "user", port=13)),
+ ("user@test::xxx.key", ssh_utils.ConnCreds("test", "user", port=22, key_file="xxx.key")),
+ ("user@test:123:xxx.key", ssh_utils.ConnCreds("test", "user", port=123, key_file="xxx.key")),
+ ("user:passwd@test", ssh_utils.ConnCreds("test", "user", port=22, passwd="passwd")),
+ ("user:passwd:@test", ssh_utils.ConnCreds("test", "user", port=22, passwd="passwd:")),
+ ("user:passwd:@test:123", ssh_utils.ConnCreds("test", "user", port=123, passwd="passwd:"))
+ ]
+ for uri, expected in creds:
+ # Pin getpass.getuser() so the expected default user does not depend on who runs the tests.
+ with patch('getpass.getuser', lambda : default_user):
+ parsed = ssh_utils.parse_ssh_uri(uri)
+
+ # The uri is the assertion message, so a failure names the offending case.
+ assert parsed.user == expected.user, uri
+ assert parsed.addr.port == expected.addr.port, uri
+ assert parsed.addr.host == expected.addr.host, uri
+ assert parsed.key_file == expected.key_file, uri
+ assert parsed.passwd == expected.passwd, uri
+
+
+# SSH URI of the host the SSH tests connect to. test_ssh_connect and test_file_copy
+# compare remote results with the local filesystem, so this must be the local machine.
+CONNECT_URI = "localhost"
+
+
+@contextlib.contextmanager
+def conn_ctx(uri, *args):
+    """Open an SSH connection to ``uri`` for the duration of a ``with`` block.
+
+    ``uri`` is parsed with ``ssh_utils.parse_ssh_uri``; any extra positional
+    ``args`` are forwarded unchanged to ``node.connect``.  The connection is
+    disconnected on exit, even if the ``with`` body raises.
+    """
+    # Honour the caller's URI (previously the module-level CONNECT_URI was
+    # parsed here and the argument silently ignored).
+    creds = ssh_utils.parse_ssh_uri(uri)
+    node_info = node_interfaces.NodeInfo(creds, set())
+    conn = node.connect(node_info, *args)
+    try:
+        yield conn
+    finally:
+        conn.disconnect()
+
+
+def test_ssh_connect():
+    """Listing / over SSH must match the local non-hidden entries of /.
+
+    The result is compared with os.listdir, so CONNECT_URI must refer to the
+    machine running the tests.
+    """
+    with conn_ctx(CONNECT_URI) as conn:
+        # ``ls -1`` prints one entry per line: split on line breaks rather than
+        # on any whitespace, so names containing spaces stay intact.
+        remote = set(conn.run("ls -1 /").splitlines())
+        local = {fname for fname in os.listdir("/") if not fname.startswith('.')}
+        assert remote == local
+
+
+def test_ssh_complex():
+    """Placeholder for a more complex SSH scenario; not implemented yet."""
+    # Report the test as skipped instead of letting an empty body count as a
+    # pass. pytest treats unittest.SkipTest raised from any test as a skip.
+    import unittest
+    raise unittest.SkipTest("test_ssh_complex is not implemented yet")
+
+
+def test_file_copy():
+    """put_to_file must create a file for path=None and overwrite an existing path.
+
+    Written files are read back through the local filesystem, so this relies
+    on CONNECT_URI referring to the machine running the tests.
+    """
+    def read_back(fpath):
+        # The context manager closes the handle; a bare open(...).read()
+        # leaves it open until garbage collection (ResourceWarning).
+        with open(fpath, 'rb') as fobj:
+            return fobj.read()
+
+    data1 = b"-" * 1024
+    data2 = b"+" * 1024
+
+    with conn_ctx(CONNECT_URI) as conn:
+        path = conn.put_to_file(None, data1)
+        assert data1 == read_back(path)
+
+        # Writing to an existing path returns that same path.
+        assert path == conn.put_to_file(path, data2)
+        assert data2 == read_back(path)
+
+        # A shorter write must truncate the file rather than leave the old tail.
+        assert len(data2) > 10
+        assert path == conn.put_to_file(path, data2[10:])
+        assert data2[10:] == read_back(path)