blob: 3e8bbab0d02999d93ff422508b98063fe9cdbcdf [file] [log] [blame]
"""
This module contains interfaces for storage classes
"""
import os
import re
import abc
import array
import shutil
import sqlite3
import logging
from typing import Any, TypeVar, Type, IO, Tuple, cast, List, Dict, Iterable, Iterator
import yaml
try:
from yaml import CLoader as Loader, CDumper as Dumper # type: ignore
except ImportError:
from yaml import Loader, Dumper # type: ignore
from .common_types import IStorable
# Module-level logger; messages from this module are emitted under the "wally" logger name.
logger = logging.getLogger("wally")
class ISimpleStorage(metaclass=abc.ABCMeta):
    """Abstract low-level storage backend.

    A backend stores and returns raw ``bytes`` addressed by a string path;
    it performs no serialization of its own.
    """

    @abc.abstractmethod
    def put(self, value: bytes, path: str) -> None:
        """Store ``value`` under ``path``."""

    @abc.abstractmethod
    def get(self, path: str) -> bytes:
        """Return the bytes stored under ``path``.

        The implementations in this module raise ``KeyError`` when nothing
        is stored under ``path``.
        """

    @abc.abstractmethod
    def rm(self, path: str) -> None:
        """Remove whatever is stored under ``path``."""

    @abc.abstractmethod
    def sync(self) -> None:
        """Persist pending changes, if the backend buffers any."""

    @abc.abstractmethod
    def __contains__(self, path: str) -> bool:
        """Tell whether something is stored under ``path``."""

    @abc.abstractmethod
    def get_fd(self, path: str, mode: str = "rb+") -> IO:
        """Return a file object for ``path`` opened with ``mode``."""

    @abc.abstractmethod
    def sub_storage(self, path: str) -> 'ISimpleStorage':
        """Return a storage whose paths are resolved relative to ``path``."""

    @abc.abstractmethod
    def list(self, path: str) -> Iterator[Tuple[bool, str]]:
        """Iterate over the entries under ``path`` as ``(is_file, name)`` pairs."""
class ISerializer(metaclass=abc.ABCMeta):
    """Abstract converter between Python values and their ``bytes`` form."""

    @abc.abstractmethod
    def pack(self, value: IStorable) -> bytes:
        """Turn ``value`` into bytes suitable for a storage backend."""

    @abc.abstractmethod
    def unpack(self, data: bytes) -> Any:
        """Rebuild a value from bytes previously produced by ``pack``."""
class DBStorage(ISimpleStorage):
    """Bytes storage backed by one sqlite3 table (``wally_storage``).

    Every record is addressed by a string key built as ``prefix + path``;
    sub-storages share the parent's connection and only extend the prefix.
    Changes are committed by ``sync()``.
    """

    create_tb_sql = "CREATE TABLE IF NOT EXISTS wally_storage (key text, data blob, type text)"
    insert_sql = "INSERT INTO wally_storage VALUES (?, ?, ?)"
    update_sql = "UPDATE wally_storage SET data=?, type=? WHERE key=?"
    select_sql = "SELECT data, type FROM wally_storage WHERE key=?"
    contains_sql = "SELECT 1 FROM wally_storage WHERE key=?"
    # Legacy string-formatted statement, kept only so the attribute still exists
    # for external users. rm() itself uses rm_prefix_sql below.
    rm_sql = "DELETE FROM wally_storage WHERE key LIKE '{}%'"
    # Exact, case-sensitive prefix match with bound parameters. Unlike LIKE,
    # '_' and '%' inside a path are not treated as wildcards.
    rm_prefix_sql = "DELETE FROM wally_storage WHERE substr(key, 1, length(?)) = ?"
    list2_sql = "SELECT key, length(data), type FROM wally_storage"
    # Minimal sqlite3.threadsafety level this class accepts.
    SQLITE3_THREADSAFE = 1

    def __init__(self, db_path: str = None, existing: bool = False,
                 prefix: str = None, db: sqlite3.Connection = None) -> None:
        """Open (or create) the database at ``db_path`` or reuse connection ``db``.

        Parameters:
            db_path: path of the sqlite file; takes precedence over ``db``.
            existing: when true, ``db_path`` must already be a file.
            prefix: key prefix; a trailing '/' is appended when missing.
            db: already opened connection, used when ``db_path`` is not given.

        Raises:
            IOError: ``existing`` is set but the file is missing, or sqlite
                can't open the database.
            RuntimeError: the sqlite3 library reports no thread safety.
            ValueError: neither ``db_path`` nor ``db`` was passed.
        """
        # Set unconditionally so the attribute exists for connection-sharing
        # instances as well.
        self.existing = existing

        if db_path:
            if existing and not os.path.isfile(db_path):
                raise IOError("No storage found at {!r}".format(db_path))

            # dirname is '' for a bare file name (or ':memory:'), and
            # os.makedirs('') raises, so only create a real directory.
            db_dir = os.path.dirname(db_path)
            if db_dir:
                os.makedirs(db_dir, exist_ok=True)

            # sqlite3.threadsafety is 0, 1 or 3 (hard-coded to 1 before
            # Python 3.11); anything at or above the minimum is acceptable.
            if sqlite3.threadsafety < self.SQLITE3_THREADSAFE:
                raise RuntimeError("Sqlite3 compiled without threadsafe support, can't use DB storage on it")

            try:
                self.db = sqlite3.connect(db_path, check_same_thread=False)
            except sqlite3.OperationalError as exc:
                raise IOError("Can't open database at {!r}".format(db_path)) from exc

            self.db.execute(self.create_tb_sql)
        else:
            if db is None:
                raise ValueError("Either db or db_path parameter must be passed")
            self.db = db

        if prefix is None:
            self.prefix = ""
        elif not prefix.endswith('/'):
            self.prefix = prefix + '/'
        else:
            self.prefix = prefix

    def put(self, value: bytes, path: str) -> None:
        """Insert ``value`` under ``path`` or overwrite the existing record."""
        c = self.db.cursor()
        fpath = self.prefix + path
        c.execute(self.contains_sql, (fpath,))
        if c.fetchone() is None:
            c.execute(self.insert_sql, (fpath, value, 'yaml'))
        else:
            c.execute(self.update_sql, (value, 'yaml', fpath))

    def get(self, path: str) -> bytes:
        """Return the data stored under ``path``; raise ``KeyError`` if absent."""
        c = self.db.cursor()
        c.execute(self.select_sql, (self.prefix + path,))
        res = cast(List[Tuple[bytes, str]], c.fetchall())  # type: List[Tuple[bytes, str]]
        if not res:
            raise KeyError(path)
        assert len(res) == 1
        val, tp = res[0]
        assert tp == 'yaml'
        return val

    def rm(self, path: str) -> None:
        """Delete every record whose key starts with ``prefix + path``."""
        fpath = self.prefix + path
        self.db.cursor().execute(self.rm_prefix_sql, (fpath, fpath))

    def __contains__(self, path: str) -> bool:
        """Tell whether a record with key ``prefix + path`` exists."""
        c = self.db.cursor()
        # The prefix must be applied exactly once, as in put()/get().
        c.execute(self.contains_sql, (self.prefix + path,))
        return c.fetchone() is not None

    def print_tree(self):
        """Print key, data length and type of every record, sorted."""
        c = self.db.cursor()
        c.execute(self.list2_sql)
        print("------------------ DB ---------------------")
        for key, data_ln, tp in sorted(c.fetchall()):
            print(key, data_ln, tp)
        print("------------------ END --------------------")

    def sub_storage(self, path: str) -> 'DBStorage':
        """Return a storage sharing this connection with ``path`` added to the prefix."""
        return self.__class__(existing=self.existing, prefix=self.prefix + path, db=self.db)

    def sync(self):
        """Commit the current transaction."""
        self.db.commit()

    def get_fd(self, path: str, mode: str = "rb+") -> IO[bytes]:
        """Not supported by this backend; always raises ``NotImplementedError``."""
        raise NotImplementedError("SQLITE3 doesn't provide fd-like interface")

    def list(self, path: str) -> Iterator[Tuple[bool, str]]:
        """Not supported by this backend; always raises ``NotImplementedError``."""
        raise NotImplementedError("SQLITE3 doesn't provide list method")
# File name of the sqlite database, relative to a storage root directory
# (joined with the root in FSStorage.__init__ and in make_storage).
DB_REL_PATH = "__db__.db"
class FSStorage(ISimpleStorage):
    """Store all data in files on FS.

    Every storage path is resolved relative to ``root_path``.
    """

    def __init__(self, root_path: str, existing: bool) -> None:
        """Remember the root directory; nothing is created on disk here."""
        self.root_path = root_path
        self.existing = existing
        # Full paths that list() must not report (the sqlite file lives here too).
        self.ignored = {self.j(DB_REL_PATH), '.', '..'}

    def j(self, path: str) -> str:
        """Return ``path`` joined onto the storage root."""
        return os.path.join(self.root_path, path)

    @staticmethod
    def _make_parent_dir(fpath: str) -> None:
        """Create the directory that will contain ``fpath``, if it has one."""
        parent = os.path.dirname(fpath)
        # dirname is '' for a bare file name and os.makedirs('') raises.
        if parent:
            os.makedirs(parent, exist_ok=True)

    def put(self, value: bytes, path: str) -> None:
        """Write ``value`` to the file at ``path``, creating parent directories."""
        jpath = self.j(path)
        self._make_parent_dir(jpath)
        with open(jpath, "wb") as fd:
            fd.write(value)

    def get(self, path: str) -> bytes:
        """Return the file content at ``path``; raise ``KeyError`` if it is missing."""
        try:
            with open(self.j(path), "rb") as fd:
                return fd.read()
        except FileNotFoundError as exc:
            raise KeyError(path) from exc

    def rm(self, path: str) -> None:
        """Remove the file or directory tree at ``path``; a missing path is ignored."""
        # Resolve against the storage root like every other method does;
        # the raw path would be interpreted relative to the process CWD.
        jpath = self.j(path)
        if os.path.isdir(jpath):
            shutil.rmtree(jpath, ignore_errors=True)
        elif os.path.exists(jpath):
            os.unlink(jpath)

    def __contains__(self, path: str) -> bool:
        """Tell whether ``path`` exists under the storage root."""
        return os.path.exists(self.j(path))

    def get_fd(self, path: str, mode: str = "rb+") -> IO[bytes]:
        """Open the file at ``path`` and return the file object.

        ``mode`` is a regular ``open()`` mode, plus the special value ``"cb"``:
        open for update (``"rb+"``) and create the file (``"wb"``) when it
        does not exist yet. Parent directories are created for ``"cb"`` and
        for every mode that may create the file.
        """
        jpath = self.j(path)

        create_on_fail = (mode == "cb")
        if create_on_fail:
            mode = "rb+"

        if create_on_fail or any(flag in mode for flag in "wax"):
            self._make_parent_dir(jpath)

        try:
            fd = open(jpath, mode)
        except FileNotFoundError:
            # Only a missing file may trigger the creating fallback; on any
            # other error a "wb" open could truncate existing data.
            if not create_on_fail:
                raise
            fd = open(jpath, "wb")

        return cast(IO[bytes], fd)

    def sub_storage(self, path: str) -> 'FSStorage':
        """Return a storage rooted at ``path`` inside this one."""
        return self.__class__(self.j(path), self.existing)

    def sync(self):
        """Nothing to flush for this backend."""
        pass

    def list(self, path: str) -> Iterator[Tuple[bool, str]]:
        """Yield ``(is_file, name)`` for every entry of the directory at ``path``.

        Yields nothing when ``path`` does not exist and raises ``OSError``
        when it is not a directory. Entries listed in ``self.ignored`` are skipped.
        """
        path = self.j(path)

        if not os.path.exists(path):
            return

        if not os.path.isdir(path):
            raise OSError("{!r} is not a directory".format(path))

        for fobj in os.scandir(path):
            if fobj.path in self.ignored:
                continue
            yield not fobj.is_dir(), fobj.name
class YAMLSerializer(ISerializer):
    """Serializer built on ``yaml.dump``/``yaml.load`` with the module's Loader/Dumper."""

    def pack(self, value: IStorable) -> bytes:
        """Dump ``value`` to utf8 yaml; any failure is re-raised as ``ValueError``."""
        try:
            packed = yaml.dump(value, Dumper=Dumper, encoding="utf8")
        except Exception as exc:
            raise ValueError("Can't pickle object {!r} to yaml".format(type(value))) from exc
        return packed

    def unpack(self, data: bytes) -> Any:
        """Load a value back from yaml bytes."""
        # NOTE(review): Loader/CLoader is the full-power yaml loader and can
        # construct arbitrary Python objects - confirm the data is always trusted.
        return yaml.load(data, Loader=Loader)
class SAFEYAMLSerializer(ISerializer):
    """Serializer built on ``yaml.safe_dump``/``yaml.safe_load``."""

    def pack(self, value: IStorable) -> bytes:
        """Dump ``value`` to utf8 yaml; any failure is re-raised as ``ValueError``."""
        try:
            packed = yaml.safe_dump(value, encoding="utf8")
        except Exception as exc:
            raise ValueError("Can't pickle object {!r} to yaml".format(type(value))) from exc
        return packed

    def unpack(self, data: bytes) -> Any:
        """Load a value back from yaml bytes with the safe loader."""
        return yaml.safe_load(data)
# Type variable bounded by IStorable; used by Storage.load()/load_list() so the
# returned objects carry the class that was passed in.
ObjClass = TypeVar('ObjClass', bound=IStorable)
class _Raise:
    """Sentinel default for Storage.get(): no default given, so re-raise."""
class Storage:
    """Facade over a file backend, a DB backend and a serializer.

    Serialized objects (``put``/``put_list``) are written to both backends
    and read back (``get``/``load``/``load_list``) from the DB backend.
    Raw bytes, arrays, file objects and listings use the file backend only.
    Can be used as a context manager; leaving the block calls ``sync()``.
    """

    typechar_pad_size = 16
    # Array files start with a typechar_pad_size-byte header: one typecode
    # byte followed by this zero padding.
    typepad = bytes(typechar_pad_size - 1)

    def __init__(self, fs_storage: ISimpleStorage, db_storage: ISimpleStorage, serializer: ISerializer) -> None:
        self.fs = fs_storage
        self.db = db_storage
        self.serializer = serializer

    def sub_storage(self, *path: str) -> 'Storage':
        """Return a Storage whose both backends are rooted at ``path``."""
        fpath = "/".join(path)
        return self.__class__(self.fs.sub_storage(fpath), self.db.sub_storage(fpath), self.serializer)

    def put(self, value: Any, *path: str) -> None:
        """Serialize ``value`` (its ``raw()`` form for IStorable) into both backends."""
        dct_value = cast(IStorable, value).raw() if isinstance(value, IStorable) else value
        serialized = self.serializer.pack(dct_value)  # type: ignore
        fpath = "/".join(path)
        self.db.put(serialized, fpath)
        self.fs.put(serialized, fpath)

    def put_list(self, value: Iterable[IStorable], *path: str) -> None:
        """Serialize the ``raw()`` form of every item as one list into both backends."""
        serialized = self.serializer.pack([obj.raw() for obj in value])  # type: ignore
        fpath = "/".join(path)
        self.db.put(serialized, fpath)
        self.fs.put(serialized, fpath)

    def get(self, path: str, default: Any = _Raise) -> Any:
        """Return the deserialized value stored under ``path`` in the DB backend.

        When nothing is stored (the backend raises ``KeyError``) ``default``
        is returned if given, otherwise the ``KeyError`` propagates. Other
        backend errors are never replaced by ``default``.
        """
        try:
            vl = self.db.get(path)
        except KeyError:
            if default is _Raise:
                raise
            return default
        return self.serializer.unpack(vl)

    def rm(self, *path: str) -> None:
        """Remove ``path`` from both backends."""
        fpath = "/".join(path)
        self.fs.rm(fpath)
        self.db.rm(fpath)

    def __contains__(self, path: str) -> bool:
        """Tell whether ``path`` is present in either backend."""
        return path in self.fs or path in self.db

    def put_raw(self, val: bytes, *path: str) -> str:
        """Write ``val`` to the file backend and return the resolved file path."""
        fpath = "/".join(path)
        self.fs.put(val, fpath)
        # TODO: dirty hack
        return self.resolve_raw(fpath)

    def resolve_raw(self, fpath) -> str:
        """Return the file-backend path for ``fpath`` (treats ``self.fs`` as FSStorage)."""
        return cast(FSStorage, self.fs).j(fpath)

    def get_raw(self, *path: str) -> bytes:
        """Return the raw bytes stored under ``path`` in the file backend."""
        return self.fs.get("/".join(path))

    def append_raw(self, value: bytes, *path: str) -> None:
        """Append ``value`` to the end of the file at ``path`` (opened "rb+")."""
        with self.fs.get_fd("/".join(path), "rb+") as fd:
            fd.seek(0, os.SEEK_END)
            fd.write(value)

    def get_fd(self, path: str, mode: str = "r") -> IO:
        """Return a file object from the file backend."""
        return self.fs.get_fd(path, mode)

    def put_array(self, value: array.array, *path: str) -> None:
        """Write the type header followed by the array data, replacing the file."""
        typechar = value.typecode.encode('ascii')
        assert len(typechar) == 1
        with self.get_fd("/".join(path), "wb") as fd:
            fd.write(typechar + self.typepad)
            value.tofile(fd)  # type: ignore

    def get_array(self, *path: str) -> array.array:
        """Read an array written by ``put_array``/``append``."""
        path_s = "/".join(path)
        with self.get_fd(path_s, "rb") as fd:
            # Payload size is the file size minus the fixed-size header.
            fd.seek(0, os.SEEK_END)
            size = fd.tell() - self.typechar_pad_size
            fd.seek(0, os.SEEK_SET)
            typecode = chr(fd.read(self.typechar_pad_size)[0])
            res = array.array(typecode)
            assert size % res.itemsize == 0, "Storage object at path {} contains no array of {} or corrupted."\
                .format(path_s, typecode)
            res.fromfile(fd, size // res.itemsize)  # type: ignore
        return res

    def append(self, value: array.array, *path: str) -> None:
        """Append ``value`` to the array file at ``path``, creating it if needed.

        Raises ``StopIteration`` when the file already holds an array with a
        different typecode.
        """
        typechar = value.typecode.encode('ascii')
        assert len(typechar) == 1
        expected_typeheader = typechar + self.typepad
        with self.get_fd("/".join(path), "cb") as fd:
            fd.seek(0, os.SEEK_END)
            if fd.tell() != 0:
                fd.seek(0, os.SEEK_SET)
                real_typecode = fd.read(self.typechar_pad_size)
                if real_typecode[0] != expected_typeheader[0]:
                    # Report the typecode found in the file, not the new one twice.
                    logger.error("Try to append array with typechar %r to array with typechar %r at path %r",
                                 value.typecode, chr(real_typecode[0]), "/".join(path))
                    # NOTE(review): StopIteration is an unusual error type here;
                    # kept because callers may already catch it.
                    raise StopIteration()
                fd.seek(0, os.SEEK_END)
            else:
                fd.write(expected_typeheader)
            value.tofile(fd)  # type: ignore

    def load_list(self, obj_class: Type[ObjClass], *path: str) -> List[ObjClass]:
        """Load a stored list and rebuild each item with ``obj_class.fromraw``."""
        path_s = "/".join(path)
        raw_val = cast(List[Dict[str, Any]], self.get(path_s))
        assert isinstance(raw_val, list)
        return [cast(ObjClass, obj_class.fromraw(val)) for val in raw_val]

    def load(self, obj_class: Type[ObjClass], *path: str) -> ObjClass:
        """Load a stored value and rebuild it with ``obj_class.fromraw``."""
        path_s = "/".join(path)
        return cast(ObjClass, obj_class.fromraw(self.get(path_s)))

    def sync(self) -> None:
        """Sync the DB backend, then the file backend."""
        self.db.sync()
        self.fs.sync()

    def __enter__(self) -> 'Storage':
        return self

    def __exit__(self, x: Any, y: Any, z: Any) -> None:
        self.sync()

    def list(self, *path: str) -> Iterator[Tuple[bool, str]]:
        """Iterate ``(is_file, name)`` pairs from the file backend for ``path``."""
        return self.fs.list("/".join(path))

    def _iter_paths(self,
                    root: str,
                    path_parts: List[str],
                    groups: Dict[str, str]) -> Iterator[Tuple[bool, str, Dict[str, str]]]:
        """Walk the file backend matching one regex from ``path_parts`` per level.

        Yields ``(is_file, path, groups)`` for entries matching the last
        pattern; ``groups`` merges the named groups of every matched level,
        with values passed in by the caller taking precedence.
        """
        curr = path_parts[0]
        rest = path_parts[1:]

        for is_file, name in self.list(root):
            # A file can't be descended into, so it can't match a non-final part.
            if rest and is_file:
                continue

            rr = re.match(pattern=curr + "$", string=name)
            if rr:
                if root:
                    path = root + "/" + name
                else:
                    path = name

                new_groups = rr.groupdict().copy()
                new_groups.update(groups)

                if rest:
                    yield from self._iter_paths(path, rest, new_groups)
                else:
                    yield is_file, path, new_groups
def make_storage(url: str, existing: bool = False) -> Storage:
    """Build a Storage rooted at directory ``url``.

    Files go through an FSStorage at ``url``, the sqlite database lives at
    ``DB_REL_PATH`` inside it, and values are serialized with SAFEYAMLSerializer.
    """
    fs_backend = FSStorage(url, existing)
    # NOTE(review): ``existing`` is not forwarded to DBStorage, so the database
    # file is created when missing even for existing=True - confirm intended.
    db_backend = DBStorage(os.path.join(url, DB_REL_PATH))
    return Storage(fs_backend, db_backend, SAFEYAMLSerializer())