moving code to cephlib: drop DBStorage and the csv/numpy array helpers from wally.storage, keep a single ISimpleStorage backend behind Storage
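For orientation, the essence of the change to how Storage gets constructed, adapted from the make_storage hunk at the bottom of this diff (the "before" call is the one being deleted):

    # before: every value was written to both backends
    storage = Storage(FSStorage(url, existing),
                      DBStorage(os.path.join(url, DB_REL_PATH)),
                      SAFEYAMLSerializer())

    # after: one ISimpleStorage implementation is enough
    storage = Storage(FSStorage(url, existing), SAFEYAMLSerializer())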
diff --git a/wally/storage.py b/wally/storage.py
index c8edf5d..ab52e12 100644
--- a/wally/storage.py
+++ b/wally/storage.py
@@ -18,7 +18,6 @@
import numpy
from .common_types import IStorable
-from .utils import shape2str, str2shape
logger = logging.getLogger("wally")
@@ -61,6 +60,43 @@
        pass
+class ITSStorage(metaclass=abc.ABCMeta):
+    """interface for low-level storage, which doesn't support serialization
+    and can operate only on bytes"""
+
+    @abc.abstractmethod
+    def put(self, value: bytes, path: str) -> None:
+        pass
+
+    @abc.abstractmethod
+    def get(self, path: str) -> bytes:
+        pass
+
+    @abc.abstractmethod
+    def rm(self, path: str) -> None:
+        pass
+
+    @abc.abstractmethod
+    def sync(self) -> None:
+        pass
+
+    @abc.abstractmethod
+    def __contains__(self, path: str) -> bool:
+        pass
+
+    @abc.abstractmethod
+    def get_fd(self, path: str, mode: str = "rb+") -> IO:
+        pass
+
+    @abc.abstractmethod
+    def sub_storage(self, path: str) -> 'ISimpleStorage':
+        pass
+
+    @abc.abstractmethod
+    def list(self, path: str) -> Iterator[Tuple[bool, str]]:
+        pass
+
+
class ISerializer(metaclass=abc.ABCMeta):
    """Interface for serialization class"""
    @abc.abstractmethod
@@ -72,115 +108,13 @@
        pass
-class DBStorage(ISimpleStorage):
-
-    create_tb_sql = "CREATE TABLE IF NOT EXISTS wally_storage (key text, data blob, type text)"
-    insert_sql = "INSERT INTO wally_storage VALUES (?, ?, ?)"
-    update_sql = "UPDATE wally_storage SET data=?, type=? WHERE key=?"
-    select_sql = "SELECT data, type FROM wally_storage WHERE key=?"
-    contains_sql = "SELECT 1 FROM wally_storage WHERE key=?"
-    rm_sql = "DELETE FROM wally_storage WHERE key LIKE '{}%'"
-    list2_sql = "SELECT key, length(data), type FROM wally_storage"
-    SQLITE3_THREADSAFE = 1
-
-    def __init__(self, db_path: str = None, existing: bool = False,
-                 prefix: str = None, db: sqlite3.Connection = None) -> None:
-
-        assert not prefix or "'" not in prefix, "Broken sql prefix {!r}".format(prefix)
-
-        if db_path:
-            self.existing = existing
-            if existing:
-                if not os.path.isfile(db_path):
-                    raise IOError("No storage found at {!r}".format(db_path))
-
-            os.makedirs(os.path.dirname(db_path), exist_ok=True)
-            if sqlite3.threadsafety != self.SQLITE3_THREADSAFE:
-                raise RuntimeError("Sqlite3 compiled without threadsafe support, can't use DB storage on it")
-
-            try:
-                self.db = sqlite3.connect(db_path, check_same_thread=False)
-            except sqlite3.OperationalError as exc:
-                raise IOError("Can't open database at {!r}".format(db_path)) from exc
-
-            self.db.execute(self.create_tb_sql)
-        else:
-            if db is None:
-                raise ValueError("Either db or db_path parameter must be passed")
-            self.db = db
-
-        if prefix is None:
-            self.prefix = ""
-        elif not prefix.endswith('/'):
-            self.prefix = prefix + '/'
-        else:
-            self.prefix = prefix
-
-    def put(self, value: bytes, path: str) -> None:
-        c = self.db.cursor()
-        fpath = self.prefix + path
-        c.execute(self.contains_sql, (fpath,))
-        if len(c.fetchall()) == 0:
-            c.execute(self.insert_sql, (fpath, value, 'yaml'))
-        else:
-            c.execute(self.update_sql, (value, 'yaml', fpath))
-
-    def get(self, path: str) -> bytes:
-        c = self.db.cursor()
-        c.execute(self.select_sql, (self.prefix + path,))
-        res = cast(List[Tuple[bytes, str]], c.fetchall()) # type: List[Tuple[bytes, str]]
-        if not res:
-            raise KeyError(path)
-        assert len(res) == 1
-        val, tp = res[0]
-        assert tp == 'yaml'
-        return val
-
-    def rm(self, path: str) -> None:
-        c = self.db.cursor()
-        path = self.prefix + path
-        assert "'" not in path, "Broken sql path {!r}".format(path)
-        c.execute(self.rm_sql.format(path))
-
-    def __contains__(self, path: str) -> bool:
-        c = self.db.cursor()
-        path = self.prefix + path
-        c.execute(self.contains_sql, (self.prefix + path,))
-        return len(c.fetchall()) != 0
-
-    def print_tree(self):
-        c = self.db.cursor()
-        c.execute(self.list2_sql)
-        data = list(c.fetchall())
-        data.sort()
-        print("------------------ DB ---------------------")
-        for key, data_ln, type in data:
-            print(key, data_ln, type)
-        print("------------------ END --------------------")
-
-    def sub_storage(self, path: str) -> 'DBStorage':
-        return self.__class__(prefix=self.prefix + path, db=self.db)
-
-    def sync(self):
-        self.db.commit()
-
-    def get_fd(self, path: str, mode: str = "rb+") -> IO[bytes]:
-        raise NotImplementedError("SQLITE3 doesn't provide fd-like interface")
-
-    def list(self, path: str) -> Iterator[Tuple[bool, str]]:
-        raise NotImplementedError("SQLITE3 doesn't provide list method")
-
-
-DB_REL_PATH = "__db__.db"
-
-
class FSStorage(ISimpleStorage):
    """Store all data in files on FS"""
    def __init__(self, root_path: str, existing: bool) -> None:
        self.root_path = root_path
        self.existing = existing
-        self.ignored = {self.j(DB_REL_PATH), '.', '..'}
+        self.ignored = {'.', '..'}
    def j(self, path: str) -> str:
        return os.path.join(self.root_path, path)
@@ -288,37 +222,31 @@
    pass
-csv_file_encoding = 'ascii'
-
-
class Storage:
    """interface for storage"""
-    def __init__(self, fs_storage: ISimpleStorage, db_storage: ISimpleStorage, serializer: ISerializer) -> None:
-        self.fs = fs_storage
-        self.db = db_storage
+    def __init__(self, sstorage: ISimpleStorage, serializer: ISerializer) -> None:
+        self.sstorage = sstorage
        self.serializer = serializer
    def sub_storage(self, *path: str) -> 'Storage':
        fpath = "/".join(path)
-        return self.__class__(self.fs.sub_storage(fpath), self.db.sub_storage(fpath), self.serializer)
+        return self.__class__(self.sstorage.sub_storage(fpath), self.serializer)
    def put(self, value: Any, *path: str) -> None:
        dct_value = cast(IStorable, value).raw() if isinstance(value, IStorable) else value
        serialized = self.serializer.pack(dct_value) # type: ignore
        fpath = "/".join(path)
-        self.db.put(serialized, fpath)
-        self.fs.put(serialized, fpath)
+        self.sstorage.put(serialized, fpath)
    def put_list(self, value: Iterable[IStorable], *path: str) -> None:
        serialized = self.serializer.pack([obj.raw() for obj in value]) # type: ignore
        fpath = "/".join(path)
-        self.db.put(serialized, fpath)
-        self.fs.put(serialized, fpath)
+        self.sstorage.put(serialized, fpath)
    def get(self, path: str, default: Any = _Raise) -> Any:
        try:
-            vl = self.db.get(path)
+            vl = self.sstorage.get(path)
        except:
            if default is _Raise:
                raise
@@ -328,97 +256,30 @@
    def rm(self, *path: str) -> None:
        fpath = "/".join(path)
-        self.fs.rm(fpath)
-        self.db.rm(fpath)
+        self.sstorage.rm(fpath)
    def __contains__(self, path: str) -> bool:
-        return path in self.fs or path in self.db
+        return path in self.sstorage
    def put_raw(self, val: bytes, *path: str) -> str:
        fpath = "/".join(path)
-        self.fs.put(val, fpath)
+        self.sstorage.put(val, fpath)
        # TODO: dirty hack
        return self.resolve_raw(fpath)
    def resolve_raw(self, fpath) -> str:
-        return cast(FSStorage, self.fs).j(fpath)
+        return cast(FSStorage, self.sstorage).j(fpath)
    def get_raw(self, *path: str) -> bytes:
-        return self.fs.get("/".join(path))
+        return self.sstorage.get("/".join(path))
    def append_raw(self, value: bytes, *path: str) -> None:
-        with self.fs.get_fd("/".join(path), "rb+") as fd:
+        with self.sstorage.get_fd("/".join(path), "rb+") as fd:
            fd.seek(0, os.SEEK_END)
            fd.write(value)
    def get_fd(self, path: str, mode: str = "r") -> IO:
-        return self.fs.get_fd(path, mode)
-
-    def put_array(self, header: List[str], value: numpy.array, *path: str) -> None:
-        for val in header:
-            assert isinstance(val, str) and ',' not in val, \
-                "Can't convert {!r} to array header, as it's values contains comma".format(header)
-
-        fpath = "/".join(path)
-        with self.get_fd(fpath, "wb") as fd:
-            self.do_append(fd, header, value, fpath)
-
-    def get_array(self, *path: str) -> Tuple[List[str], numpy.array]:
-        path_s = "/".join(path)
-        with self.get_fd(path_s, "rb") as fd:
-            header = fd.readline().decode(csv_file_encoding).rstrip().split(",")
-            type_code, second_axis = header[-2:]
-            res = numpy.genfromtxt(fd, dtype=type_code, delimiter=',')
-
-        if '0' == second_axis:
-            res.shape = (len(res),)
-
-        return header[:-2], res
-
-    def append(self, header: List[str], value: numpy.array, *path: str) -> None:
-        for val in header:
-            assert isinstance(val, str) and ',' not in val, \
-                "Can't convert {!r} to array header, as it's values contains comma".format(header)
-
-        fpath = "/".join(path)
-        with self.get_fd(fpath, "cb") as fd:
-            self.do_append(fd, header, value, fpath, maybe_append=True)
-
-    def do_append(self, fd, header: List[str], value: numpy.array, path: str, fmt="%lu",
-                  maybe_append: bool = False) -> None:
-
-        if len(value.shape) == 1:
-            second_axis = 0
-        else:
-            second_axis = value.shape[1]
-        header += [value.dtype.name, str(second_axis)]
-
-        write_header = False
-
-        if maybe_append:
-            fd.seek(0, os.SEEK_END)
-            if fd.tell() != 0:
-                fd.seek(0, os.SEEK_SET)
-                # check header match
-                curr_header = fd.readline().decode(csv_file_encoding).rstrip().split(",")
-                assert header == curr_header, \
-                    "Path {!r}. Expected header ({!r}) and current header ({!r}) don't match"\
-                        .format(path, header, curr_header)
-                fd.seek(0, os.SEEK_END)
-            else:
-                write_header = True
-        else:
-            write_header = True
-
-        if write_header:
-            fd.write((",".join(header) + "\n").encode(csv_file_encoding))
-
-        if len(value.shape) == 1:
-            # make array vertical to simplify reading
-            vw = value.view().reshape((value.shape[0], 1))
-        else:
-            vw = value
-        numpy.savetxt(fd, vw, delimiter=',', newline="\n", fmt=fmt)
+        return self.sstorage.get_fd(path, mode)
    def load_list(self, obj_class: Type[ObjClass], *path: str) -> List[ObjClass]:
        path_s = "/".join(path)
@@ -431,8 +292,7 @@
        return cast(ObjClass, obj_class.fromraw(self.get(path_s)))
    def sync(self) -> None:
-        self.db.sync()
-        self.fs.sync()
+        self.sstorage.sync()
    def __enter__(self) -> 'Storage':
        return self
@@ -441,7 +301,7 @@
        self.sync()
    def list(self, *path: str) -> Iterator[Tuple[bool, str]]:
-        return self.fs.list("/".join(path))
+        return self.sstorage.list("/".join(path))
    def _iter_paths(self,
                    root: str,
@@ -472,7 +332,5 @@
def make_storage(url: str, existing: bool = False) -> Storage:
-    return Storage(FSStorage(url, existing),
-                   DBStorage(os.path.join(url, DB_REL_PATH)),
-                   SAFEYAMLSerializer())
+    return Storage(FSStorage(url, existing), SAFEYAMLSerializer())
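For reference, a minimal usage sketch of the refactored API. It assumes only what the hunks above show (make_storage wiring FSStorage and SAFEYAMLSerializer into Storage, and the put/get/put_raw/get_raw/sub_storage/sync methods); the scratch directory and the stored keys and values are purely illustrative:

    import tempfile

    from wally.storage import make_storage

    # hypothetical scratch location, only for illustration
    root = tempfile.mkdtemp(prefix="wally_storage_")

    with make_storage(root, existing=False) as storage:
        # values go through SAFEYAMLSerializer into the single FS backend
        storage.put({"nodes": 3, "role": "osd"}, "config", "cluster")
        cfg = storage.get("config/cluster")

        # raw bytes bypass the serializer entirely
        storage.put_raw(b"some raw bytes\n", "meta", "version")
        raw = storage.get_raw("meta", "version")

        # sub_storage returns a Storage rooted at the given prefix
        res = storage.sub_storage("results", "node-0")
        res.put([1, 2, 3], "latencies")
    # __exit__ calls sync() on the single backend

With DBStorage gone, none of these calls is duplicated into sqlite3 any more; the only write path left in this module is FSStorage.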