Refactor result classes and the code that stores results in, and loads them from, storage
diff --git a/wally/storage.py b/wally/storage.py
index d33f8e5..e4e010c 100644
--- a/wally/storage.py
+++ b/wally/storage.py
@@ -7,6 +7,7 @@
import array
import shutil
import sqlite3
+import threading
from typing import Any, TypeVar, Type, IO, Tuple, cast, List, Dict, Iterable, Iterator
import yaml
@@ -16,7 +17,7 @@
from yaml import Loader, Dumper # type: ignore
-from .result_classes import Storable, IStorable
+from .result_classes import IStorable
class ISimpleStorage(metaclass=abc.ABCMeta):
@@ -59,7 +60,7 @@
class ISerializer(metaclass=abc.ABCMeta):
"""Interface for serialization class"""
@abc.abstractmethod
- def pack(self, value: Storable) -> bytes:
+ def pack(self, value: IStorable) -> bytes:
pass
@abc.abstractmethod
@@ -76,6 +77,7 @@
contains_sql = "SELECT 1 FROM wally_storage WHERE key=?"
rm_sql = "DELETE FROM wally_storage WHERE key LIKE '{}%'"
list2_sql = "SELECT key, length(data), type FROM wally_storage"
+ SQLITE3_THREADSAFE = 1
def __init__(self, db_path: str = None, existing: bool = False,
prefix: str = None, db: sqlite3.Connection = None) -> None:
@@ -89,8 +91,12 @@
raise IOError("No storage found at {!r}".format(db_path))
os.makedirs(os.path.dirname(db_path), exist_ok=True)
+ # Python < 3.11 always reports threadsafety == 1; 3.11+ reports the real SQLite mode (0 single-thread, 1 multi-thread, 3 serialized), only 0 lacks thread support. In multi-thread mode (1) the shared connection must not be used by two threads at once - NOTE(review): confirm callers serialize DB access
+ if sqlite3.threadsafety not in (self.SQLITE3_THREADSAFE, 3):
+ raise RuntimeError("Sqlite3 compiled without threadsafe support, can't use DB storage on it")
+
try:
- self.db = sqlite3.connect(db_path)
+ self.db = sqlite3.connect(db_path, check_same_thread=False)
except sqlite3.OperationalError as exc:
raise IOError("Can't open database at {!r}".format(db_path)) from exc
@@ -224,7 +229,15 @@
pass
def list(self, path: str) -> Iterator[Tuple[bool, str]]:
- for fobj in os.scandir(self.j(path)):
+ path = self.j(path)
+
+ if not os.path.exists(path):
+ return
+
+ if not os.path.isdir(path):
+ raise OSError("{!r} is not a directory".format(path))
+
+ for fobj in os.scandir(path):
if fobj.path not in self.ignored:
if fobj.is_dir():
yield False, fobj.name
@@ -234,7 +247,7 @@
class YAMLSerializer(ISerializer):
"""Serialize data to yaml"""
- def pack(self, value: Storable) -> bytes:
+ def pack(self, value: IStorable) -> bytes:
try:
return yaml.dump(value, Dumper=Dumper, encoding="utf8")
except Exception as exc:
@@ -246,7 +259,7 @@
class SAFEYAMLSerializer(ISerializer):
"""Serialize data to yaml"""
- def pack(self, value: Storable) -> bytes:
+ def pack(self, value: IStorable) -> bytes:
try:
return yaml.safe_dump(value, encoding="utf8")
except Exception as exc:
@@ -274,7 +287,7 @@
fpath = "/".join(path)
return self.__class__(self.fs.sub_storage(fpath), self.db.sub_storage(fpath), self.serializer)
- def put(self, value: Storable, *path: str) -> None:
+ def put(self, value: IStorable, *path: str) -> None:
dct_value = value.raw() if isinstance(value, IStorable) else value
serialized = self.serializer.pack(dct_value)
fpath = "/".join(path)
@@ -313,7 +326,7 @@
def append_raw(self, value: bytes, *path: str) -> None:
with self.fs.get_fd("/".join(path), "rb+") as fd:
- fd.seek(offset=0, whence=os.SEEK_END)
+ fd.seek(0, os.SEEK_END)
fd.write(value)
def get_fd(self, path: str, mode: str = "r") -> IO: