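Resume working on wally/storage.py: split storage into filesystem and sqlite3 backends

* ISimpleStorage: replace __setitem__/__getitem__/__delitem__ with put/get/rm, rename get_stream to get_fd, add sync, and drop list and clear.
* Add DBStorage, an ISimpleStorage implementation backed by a sqlite3 table (wally_storage); get_fd is not supported by it.
* Import Storable and IStorable from .result_classes instead of defining IStorable in this module.
* Add SAFEYAMLSerializer (yaml.safe_dump / yaml.safe_load) and use it in make_storage.
* Storage now takes both an FS and a DB backend: put and put_list write to both, get reads from the DB, raw/array/fd access goes to the FS.
* load and load_list build objects through obj_class.fromraw; Storage.construct is removed.
* Storage.__exit__ now calls sync().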
diff --git a/wally/storage.py b/wally/storage.py
index 6879dcf..2c0a26b 100644
--- a/wally/storage.py
+++ b/wally/storage.py
@@ -6,8 +6,8 @@
import abc
import array
import shutil
-from typing import Any, Iterator, TypeVar, Type, IO, Tuple, cast, List, Dict, Union, Iterable
-
+import sqlite3
+from typing import Any, TypeVar, Type, IO, Tuple, cast, List, Dict, Iterable
import yaml
try:
@@ -16,16 +16,7 @@
from yaml import Loader, Dumper # type: ignore
-class IStorable(metaclass=abc.ABCMeta):
- """Interface for type, which can be stored"""
-
-basic_types = {list, dict, tuple, set, type(None), int, str, bytes, bool, float}
-for btype in basic_types:
- # pylint: disable=E1101
- IStorable.register(btype) # type: ignore
-
-
-ObjClass = TypeVar('ObjClass')
+from .result_classes import Storable, IStorable
class ISimpleStorage(metaclass=abc.ABCMeta):
@@ -33,15 +24,19 @@
and can operate only on bytes"""
@abc.abstractmethod
- def __setitem__(self, path: str, value: bytes) -> None:
+ def put(self, value: bytes, path: str) -> None:
pass
@abc.abstractmethod
- def __getitem__(self, path: str) -> bytes:
+ def get(self, path: str) -> bytes:
pass
@abc.abstractmethod
- def __delitem__(self, path: str) -> None:
+ def rm(self, path: str) -> None:
+ pass
+
+ @abc.abstractmethod
+ def sync(self) -> None:
pass
@abc.abstractmethod
@@ -49,75 +44,153 @@
pass
@abc.abstractmethod
- def list(self, path: str) -> Iterator[Tuple[bool, str]]:
- pass
-
- @abc.abstractmethod
- def get_stream(self, path: str, mode: str = "rb+") -> IO:
+ def get_fd(self, path: str, mode: str = "rb+") -> IO:
pass
@abc.abstractmethod
def sub_storage(self, path: str) -> 'ISimpleStorage':
pass
- @abc.abstractmethod
- def clear(self, path: str) -> None:
- pass
-
class ISerializer(metaclass=abc.ABCMeta):
"""Interface for serialization class"""
@abc.abstractmethod
- def pack(self, value: IStorable) -> bytes:
+ def pack(self, value: Storable) -> bytes:
pass
@abc.abstractmethod
- def unpack(self, data: bytes) -> IStorable:
+ def unpack(self, data: bytes) -> Any:
pass
+class DBStorage(ISimpleStorage):
+
+ create_tb_sql = "CREATE TABLE IF NOT EXISTS wally_storage (key text, data blob, type text)"
+ insert_sql = "INSERT INTO wally_storage VALUES (?, ?, ?)"
+ update_sql = "UPDATE wally_storage SET data=?, type=? WHERE key=?"
+ select_sql = "SELECT data, type FROM wally_storage WHERE key=?"
+ contains_sql = "SELECT 1 FROM wally_storage WHERE key=?"
+ rm_sql = "DELETE FROM wally_storage WHERE key LIKE '{}%'"
+ list2_sql = "SELECT key, length(data), type FROM wally_storage"
+
+ def __init__(self, db_path: str = None, existing: bool = False,
+ prefix: str = None, db: sqlite3.Connection = None) -> None:
+
+ assert not prefix or "'" not in prefix, "Broken sql prefix {!r}".format(prefix)
+
+ if db_path:
+ self.existing = existing
+ if existing:
+ if not os.path.isfile(db_path):
+ raise IOError("No storage found at {!r}".format(db_path))
+
+ os.makedirs(os.path.dirname(db_path), exist_ok=True)
+ try:
+ self.db = sqlite3.connect(db_path)
+ except sqlite3.OperationalError as exc:
+ raise IOError("Can't open database at {!r}".format(db_path)) from exc
+
+ self.db.execute(self.create_tb_sql)
+ else:
+ if db is None:
+ raise ValueError("Either db or db_path parameter must be passed")
+ self.db = db
+
+ if prefix is None:
+ self.prefix = ""
+ elif not prefix.endswith('/'):
+ self.prefix = prefix + '/'
+ else:
+ self.prefix = prefix
+
+ def put(self, value: bytes, path: str) -> None:
+ c = self.db.cursor()
+ fpath = self.prefix + path
+ c.execute(self.contains_sql, (fpath,))
+ if len(c.fetchall()) == 0:
+ c.execute(self.insert_sql, (fpath, value, 'yaml'))
+ else:
+ c.execute(self.update_sql, (value, 'yaml', fpath))
+
+ def get(self, path: str) -> bytes:
+ c = self.db.cursor()
+ c.execute(self.select_sql, (self.prefix + path,))
+ res = cast(List[Tuple[bytes, str]], c.fetchall()) # type: List[Tuple[bytes, str]]
+ if not res:
+ raise KeyError(path)
+ assert len(res) == 1
+ val, tp = res[0]
+ assert tp == 'yaml'
+ return val
+
+ def rm(self, path: str) -> None:
+ c = self.db.cursor()
+ path = self.prefix + path
+ assert "'" not in path, "Broken sql path {!r}".format(path)
+ c.execute(self.rm_sql.format(path))
+
+ def __contains__(self, path: str) -> bool:
+ c = self.db.cursor()
+ path = self.prefix + path
+ c.execute(self.contains_sql, (self.prefix + path,))
+ return len(c.fetchall()) != 0
+
+ def print_tree(self):
+ c = self.db.cursor()
+ c.execute(self.list2_sql)
+ data = list(c.fetchall())
+ data.sort()
+ print("------------------ DB ---------------------")
+ for key, data_ln, type in data:
+ print(key, data_ln, type)
+ print("------------------ END --------------------")
+
+ def get_fd(self, path: str, mode: str = "rb+") -> IO[bytes]:
+ raise NotImplementedError("SQLITE3 doesn't provide fd-like interface")
+
+ def sub_storage(self, path: str) -> 'DBStorage':
+ return self.__class__(prefix=self.prefix + path, db=self.db)
+
+ def sync(self):
+ self.db.commit()
+
+
+DB_REL_PATH = "__db__.db"
+
+
class FSStorage(ISimpleStorage):
"""Store all data in files on FS"""
def __init__(self, root_path: str, existing: bool) -> None:
self.root_path = root_path
self.existing = existing
- if existing:
- if not os.path.isdir(self.root_path):
- raise IOError("No storage found at {!r}".format(root_path))
def j(self, path: str) -> str:
return os.path.join(self.root_path, path)
- def __setitem__(self, path: str, value: bytes) -> None:
+ def put(self, value: bytes, path: str) -> None:
jpath = self.j(path)
os.makedirs(os.path.dirname(jpath), exist_ok=True)
with open(jpath, "wb") as fd:
fd.write(value)
- def __delitem__(self, path: str) -> None:
+ def get(self, path: str) -> bytes:
try:
- os.unlink(path)
- except FileNotFoundError:
- pass
+ with open(self.j(path), "rb") as fd:
+ return fd.read()
+ except FileNotFoundError as exc:
+ raise KeyError(path) from exc
- def __getitem__(self, path: str) -> bytes:
- with open(self.j(path), "rb") as fd:
- return fd.read()
+ def rm(self, path: str) -> None:
+ if os.path.isdir(path):
+ shutil.rmtree(path, ignore_errors=True)
+ elif os.path.exists(path):
+ os.unlink(path)
def __contains__(self, path: str) -> bool:
return os.path.exists(self.j(path))
- def list(self, path: str = "") -> Iterator[Tuple[bool, str]]:
- jpath = self.j(path)
- if not os.path.exists(jpath):
- return
-
- for entry in os.scandir(jpath):
- if not entry.name in ('..', '.'):
- yield entry.is_file(), entry.name
-
- def get_stream(self, path: str, mode: str = "rb+") -> IO[bytes]:
+ def get_fd(self, path: str, mode: str = "rb+") -> IO[bytes]:
jpath = self.j(path)
if "cb" == mode:
@@ -140,77 +213,89 @@
def sub_storage(self, path: str) -> 'FSStorage':
return self.__class__(self.j(path), self.existing)
- def clear(self, path: str) -> None:
- if os.path.exists(path):
- shutil.rmtree(self.j(path))
+ def sync(self):
+ pass
class YAMLSerializer(ISerializer):
"""Serialize data to yaml"""
- def pack(self, value: Any) -> bytes:
- if type(value) not in basic_types:
- # for name, val in value.__dict__.items():
- # if type(val) not in basic_types:
- # raise ValueError(("Can't pack {!r}. Attribute {} has value {!r} (type: {}), but only" +
- # " basic types accepted as attributes").format(value, name, val, type(val)))
- value = value.__dict__
- return yaml.dump(value, Dumper=Dumper, encoding="utf8")
+ def pack(self, value: Storable) -> bytes:
+ try:
+ return yaml.dump(value, Dumper=Dumper, encoding="utf8")
+ except Exception as exc:
+ raise ValueError("Can't pickle object {!r} to yaml".format(type(value))) from exc
- def unpack(self, data: bytes) -> IStorable:
+ def unpack(self, data: bytes) -> Any:
return yaml.load(data, Loader=Loader)
+class SAFEYAMLSerializer(ISerializer):
+ """Serialize data to yaml"""
+ def pack(self, value: Storable) -> bytes:
+ try:
+ return yaml.safe_dump(value, encoding="utf8")
+ except Exception as exc:
+ raise ValueError("Can't pickle object {!r} to yaml".format(type(value))) from exc
+
+ def unpack(self, data: bytes) -> Any:
+ return yaml.safe_load(data)
+
+
+ObjClass = TypeVar('ObjClass', bound=IStorable)
+
+
class Storage:
"""interface for storage"""
- def __init__(self, storage: ISimpleStorage, serializer: ISerializer) -> None:
- self.storage = storage
+ def __init__(self, fs_storage: ISimpleStorage, db_storage: ISimpleStorage, serializer: ISerializer) -> None:
+ self.fs = fs_storage
+ self.db = db_storage
self.serializer = serializer
def sub_storage(self, *path: str) -> 'Storage':
- return self.__class__(self.storage.sub_storage("/".join(path)), self.serializer)
+ fpath = "/".join(path)
+ return self.__class__(self.fs.sub_storage(fpath), self.db.sub_storage(fpath), self.serializer)
- def __setitem__(self, path: Union[str, Iterable[str]], value: Any) -> None:
- if not isinstance(path, str):
- path = "/".join(path)
+ def put(self, value: Storable, *path: str) -> None:
+ dct_value = value.raw() if isinstance(value, IStorable) else value
+ serialized = self.serializer.pack(dct_value)
+ fpath = "/".join(path)
+ self.db.put(serialized, fpath)
+ self.fs.put(serialized, fpath)
- self.storage[path] = self.serializer.pack(cast(IStorable, value))
+ def put_list(self, value: Iterable[IStorable], *path: str) -> None:
+ serialized = self.serializer.pack([obj.raw() for obj in value])
+ fpath = "/".join(path)
+ self.db.put(serialized, fpath)
+ self.fs.put(serialized, fpath)
- def __getitem__(self, path: Union[str, Iterable[str]]) -> IStorable:
- if not isinstance(path, str):
- path = "/".join(path)
+ def get(self, *path: str) -> Any:
+ return self.serializer.unpack(self.db.get("/".join(path)))
- return self.serializer.unpack(self.storage[path])
+ def rm(self, *path: str) -> None:
+ fpath = "/".join(path)
+ self.fs.rm(fpath)
+ self.db.rm(fpath)
- def __delitem__(self, path: Union[str, Iterable[str]]) -> None:
- if not isinstance(path, str):
- path = "/".join(path)
- del self.storage[path]
+ def __contains__(self, path: str) -> bool:
+ return path in self.fs or path in self.db
- def __contains__(self, path: Union[str, Iterable[str]]) -> bool:
- if not isinstance(path, str):
- path = "/".join(path)
- return path in self.storage
-
- def store_raw(self, val: bytes, *path: str) -> None:
- self.storage["/".join(path)] = val
-
- def clear(self, *path: str) -> None:
- self.storage.clear("/".join(path))
+ def put_raw(self, val: bytes, *path: str) -> None:
+ self.fs.put(val, "/".join(path))
def get_raw(self, *path: str) -> bytes:
- return self.storage["/".join(path)]
+ return self.fs.get("/".join(path))
- def list(self, *path: str) -> Iterator[Tuple[bool, str]]:
- return self.storage.list("/".join(path))
+ def get_fd(self, path: str, mode: str = "r") -> IO:
+ return self.fs.get_fd(path, mode)
- def set_array(self, value: array.array, *path: str) -> None:
- with self.get_stream("/".join(path), "wb") as fd:
+ def put_array(self, value: array.array, *path: str) -> None:
+ with self.get_fd("/".join(path), "wb") as fd:
value.tofile(fd) # type: ignore
def get_array(self, typecode: str, *path: str) -> array.array:
res = array.array(typecode)
path_s = "/".join(path)
- with self.get_stream(path_s, "rb") as fd:
+ with self.get_fd(path_s, "rb") as fd:
fd.seek(0, os.SEEK_END)
size = fd.tell()
fd.seek(0, os.SEEK_SET)
@@ -220,55 +305,33 @@
return res
def append(self, value: array.array, *path: str) -> None:
- with self.get_stream("/".join(path), "cb") as fd:
+ with self.get_fd("/".join(path), "cb") as fd:
fd.seek(0, os.SEEK_END)
value.tofile(fd) # type: ignore
- def construct(self, path: str, raw_val: Dict, obj_class: Type[ObjClass]) -> ObjClass:
- "Internal function, used to construct user type from raw unpacked value"
- if obj_class in (int, str, dict, list, None):
- raise ValueError("Can't load into build-in value - {!r} into type {}")
-
- if not isinstance(raw_val, dict):
- raise ValueError("Can't load path {!r} into python type. Raw value not dict".format(path))
-
- if not all(isinstance(key, str) for key in raw_val.keys()):
- raise ValueError("Can't load path {!r} into python type.".format(path) +
- "Raw not all keys in raw value is strings")
-
- obj = obj_class.__new__(obj_class) # type: ObjClass
- obj.__dict__.update(raw_val)
- return obj
-
def load_list(self, obj_class: Type[ObjClass], *path: str) -> List[ObjClass]:
path_s = "/".join(path)
- raw_val = self[path_s]
+ raw_val = cast(List[Dict[str, Any]], self.get(path_s))
assert isinstance(raw_val, list)
- return [self.construct(path_s, val, obj_class) for val in cast(list, raw_val)]
+ return [obj_class.fromraw(val) for val in raw_val]
def load(self, obj_class: Type[ObjClass], *path: str) -> ObjClass:
path_s = "/".join(path)
- return self.construct(path_s, cast(Dict, self[path_s]), obj_class)
+ return obj_class.fromraw(self.get(path_s))
- def get_stream(self, path: str, mode: str = "r") -> IO:
- return self.storage.get_stream(path, mode)
-
- def get(self, path: Union[str, Iterable[str]], default: Any = None) -> Any:
- if not isinstance(path, str):
- path = "/".join(path)
-
- try:
- return self[path]
- except Exception:
- return default
+ def sync(self) -> None:
+ self.db.sync()
+ self.fs.sync()
def __enter__(self) -> 'Storage':
return self
def __exit__(self, x: Any, y: Any, z: Any) -> None:
- return
+ self.sync()
def make_storage(url: str, existing: bool = False) -> Storage:
- return Storage(FSStorage(url, existing), YAMLSerializer())
+ return Storage(FSStorage(url, existing),
+ DBStorage(os.path.join(url, DB_REL_PATH)),
+ SAFEYAMLSerializer())