Reimplement Parameters
Parameters is not a dict, so let's not pretend to be one. Instead, be
a very specific type doing exactly what it should and exposing only as
much functionality as needed.
This is a complete rewrite and pulls in the entire mergers logic.
Also switches from nose to unittest.
Signed-off-by: martin f. krafft <madduck@madduck.net>
diff --git a/reclass/datatypes/parameters.py b/reclass/datatypes/parameters.py
index 43dcb02..ade290d 100644
--- a/reclass/datatypes/parameters.py
+++ b/reclass/datatypes/parameters.py
@@ -8,13 +8,110 @@
#
from reclass.mergers.dict import DictRecursivePolicyUpdate
-class Parameters(dict):
+class Parameters(object):
+ '''
+ A class to hold nested dictionaries with the following speciality:
- def __init__(self, *args, **kwargs):
- super(Parameters, self).__init__(*args, **kwargs)
+ "merging" a dictionary (the "new" dictionary) into the current Parameters
+ causes a recursive walk of the new dict, during which
- def merge(self, other, merger=DictRecursivePolicyUpdate()):
- self.update(merger.merge(self, other))
+ - scalars (incl. tuples) are replaced with the value from the new
+ dictionary;
+ - lists are extended, not replaced;
+ - dictionaries are updated (using dict.update), not replaced;
+
+ To support this speciality, this class only exposes very limited
+ functionality and does not try to be a really mapping object.
+ '''
+ DEFAULT_PATH_DELIMITER = ':' # useful default for YAML
+
+ def __init__(self, mapping=None, delimiter=None):
+ if delimiter is None:
+ delimiter = Parameters.DEFAULT_PATH_DELIMITER
+ self._delimiter = delimiter
+ self._base = {}
+ if mapping is not None:
+ # we initialise by merging, otherwise the list of references might
+ # not be updated
+ self.merge(mapping)
+
+ delimiter = property(lambda self: self._delimiter)
+
+ def __len__(self):
+ return len(self._base)
def __repr__(self):
- return '<Parameters {0}>'.format(super(Parameters, self).__repr__())
+ return '%s(%r, %r)' % (self.__class__.__name__, self._base,
+ self.delimiter)
+
+ def __eq__(self, other):
+ return self._base == other._base \
+ and self._delimiter == other._delimiter
+
+ def __ne__(self, other):
+ return not self.__eq__(other)
+
+ def as_dict(self):
+ return self._base.copy()
+
+ def _update_scalar(self, cur, new, delim, parent):
+ if delim is None:
+ return new
+
+ else:
+ return new #TODO
+
+ def _extend_list(self, cur, new, delim, parent):
+ if isinstance(cur, list):
+ ret = cur
+ else:
+ ret = [cur]
+ for i in new:
+ ret.append(self._merge_recurse(None, i, delim, parent))
+ return ret
+
+ def _merge_dict(self, cur, new, delim, parent):
+ if isinstance(cur, dict):
+ ret = cur
+ else:
+ # nothing sensible to do
+ raise TypeError('Cannot merge dict into {0} '
+ 'objects'.format(type(cur)))
+
+ if delim is None:
+ # a delimiter of None indicates that there is no value
+ # processing to be done, and since there is no current
+ # value, we do not need to walk the new dictionary:
+ ret.update(new)
+ return ret
+
+ for key, newvalue in new.iteritems():
+ ret[key] = self._merge_recurse(ret.get(key), newvalue, delim,
+ (ret, key))
+ return ret
+
+ def _merge_recurse(self, cur, new, delim, parent=None):
+ if isinstance(new, dict):
+ if cur is None:
+ cur = {}
+ return self._merge_dict(cur, new, delim, parent)
+
+ elif isinstance(new, list):
+ if cur is None:
+ cur = []
+ return self._extend_list(cur, new, delim, parent)
+
+ else:
+ return self._update_scalar(cur, new, delim, parent)
+
+ def merge(self, other):
+ if isinstance(other, dict):
+ self._base = self._merge_recurse(self._base, other, None)
+
+ elif isinstance(other, self.__class__):
+ self._base = self._merge_recurse(self._base, other._base,
+ other.delimiter)
+
+ else:
+ raise TypeError('Cannot merge %s objects into %s' % (type(other),
+ self.__class__.__name__))
diff --git a/reclass/datatypes/tests/test_parameters.py b/reclass/datatypes/tests/test_parameters.py
index 504b2e7..dde5573 100644
--- a/reclass/datatypes/tests/test_parameters.py
+++ b/reclass/datatypes/tests/test_parameters.py
@@ -7,16 +7,186 @@
# Released under the terms of the Artistic Licence 2.0
#
from reclass.datatypes import Parameters
+import unittest
+try:
+ import unittest.mock as mock
+except ImportError:
+ import mock
-class TestParameters:
+SIMPLE = {'one': 1, 'two': 2, 'three': 3}
- def test_constructor0(self):
- c = Parameters()
- assert len(c) == 0
+class TestParameters(unittest.TestCase):
- def test_constructor1(self):
- DATA = {'blue':'white', 'black':'yellow'}
- c = Parameters(DATA)
- assert len(c) == len(DATA)
- for i in c.iterkeys():
- assert DATA[i] == c[i]
+ def _construct_mocked_params(self, iterable=None, delimiter=None):
+ p = Parameters(iterable, delimiter)
+ self._base = base = p._base
+ p._base = mock.MagicMock(spec_set=dict, wraps=base)
+ p._base.__repr__ = mock.MagicMock(autospec=dict.__repr__,
+ return_value=repr(base))
+ return p, p._base
+
+ def test_len_empty(self):
+ p, b = self._construct_mocked_params()
+ l = 0
+ b.__len__.return_value = l
+ self.assertEqual(len(p), l)
+ b.__len__.assert_called_with()
+
+ def test_constructor(self):
+ p, b = self._construct_mocked_params(SIMPLE)
+ l = len(SIMPLE)
+ b.__len__.return_value = l
+ self.assertEqual(len(p), l)
+ b.__len__.assert_called_with()
+
+ def test_repr_empty(self):
+ p, b = self._construct_mocked_params()
+ b.__repr__.return_value = repr({})
+ self.assertEqual('%r' % p, '%s(%r, %r)' % (p.__class__.__name__, {},
+ Parameters.DEFAULT_PATH_DELIMITER))
+ b.__repr__.assert_called_once_with()
+
+ def test_repr(self):
+ p, b = self._construct_mocked_params(SIMPLE)
+ b.__repr__.return_value = repr(SIMPLE)
+ self.assertEqual('%r' % p, '%s(%r, %r)' % (p.__class__.__name__, SIMPLE,
+ Parameters.DEFAULT_PATH_DELIMITER))
+ b.__repr__.assert_called_once_with()
+
+ def test_repr_delimiter(self):
+ delim = '%'
+ p, b = self._construct_mocked_params(SIMPLE, delim)
+ b.__repr__.return_value = repr(SIMPLE)
+ self.assertEqual('%r' % p, '%s(%r, %r)' % (p.__class__.__name__, SIMPLE, delim))
+ b.__repr__.assert_called_once_with()
+
+ def test_equal_empty(self):
+ p1, b1 = self._construct_mocked_params()
+ p2, b2 = self._construct_mocked_params()
+ b1.__eq__.return_value = True
+ self.assertEqual(p1, p2)
+ b1.__eq__.assert_called_once_with(b2)
+
+ def test_equal_default_delimiter(self):
+ p1, b1 = self._construct_mocked_params(SIMPLE)
+ p2, b2 = self._construct_mocked_params(SIMPLE,
+ Parameters.DEFAULT_PATH_DELIMITER)
+ b1.__eq__.return_value = True
+ self.assertEqual(p1, p2)
+ b1.__eq__.assert_called_once_with(b2)
+
+ def test_equal_contents(self):
+ p1, b1 = self._construct_mocked_params(SIMPLE)
+ p2, b2 = self._construct_mocked_params(SIMPLE)
+ b1.__eq__.return_value = True
+ self.assertEqual(p1, p2)
+ b1.__eq__.assert_called_once_with(b2)
+
+ def test_unequal_content(self):
+ p1, b1 = self._construct_mocked_params()
+ p2, b2 = self._construct_mocked_params(SIMPLE)
+ b1.__eq__.return_value = False
+ self.assertNotEqual(p1, p2)
+ b1.__eq__.assert_called_once_with(b2)
+
+ def test_unequal_delimiter(self):
+ p1, b1 = self._construct_mocked_params(delimiter=':')
+ p2, b2 = self._construct_mocked_params(delimiter='%')
+ b1.__eq__.return_value = False
+ self.assertNotEqual(p1, p2)
+ b1.__eq__.assert_called_once_with(b2)
+
+ def test_construct_wrong_type(self):
+ with self.assertRaises(TypeError):
+ self._construct_mocked_params('wrong type')
+
+ def test_merge_wrong_type(self):
+ p, b = self._construct_mocked_params()
+ with self.assertRaises(TypeError):
+ p.merge('wrong type')
+
+ def test_get_dict(self):
+ p, b = self._construct_mocked_params(SIMPLE)
+ self.assertDictEqual(p.as_dict(), SIMPLE)
+
+ def test_merge_scalars(self):
+ p1, b1 = self._construct_mocked_params(SIMPLE)
+ mergee = {'five':5,'four':4,'None':None,'tuple':(1,2,3)}
+ p2, b2 = self._construct_mocked_params(mergee)
+ p1.merge(p2)
+ for key, value in mergee.iteritems():
+ # check that each key, value in mergee resulted in a get call and
+ # a __setitem__ call against b1 (the merge target)
+ self.assertIn(mock.call(key), b1.get.call_args_list)
+ self.assertIn(mock.call(key, value), b1.__setitem__.call_args_list)
+
+
+class TestParametersNoMock(unittest.TestCase):
+
+ def test_merge_scalars(self):
+ p = Parameters(SIMPLE)
+ mergee = {'five':5,'four':4,'None':None,'tuple':(1,2,3)}
+ p.merge(mergee)
+ goal = SIMPLE.copy()
+ goal.update(mergee)
+ self.assertDictEqual(p.as_dict(), goal)
+
+ def test_merge_scalars_overwrite(self):
+ p = Parameters(SIMPLE)
+ mergee = {'two':5,'four':4,'three':None,'one':(1,2,3)}
+ p.merge(mergee)
+ goal = SIMPLE.copy()
+ goal.update(mergee)
+ self.assertDictEqual(p.as_dict(), goal)
+
+ def test_merge_lists(self):
+ l1 = [1,2,3]
+ l2 = [2,3,4]
+ p1 = Parameters(dict(list=l1[:]))
+ p2 = Parameters(dict(list=l2))
+ p1.merge(p2)
+ self.assertListEqual(p1.as_dict()['list'], l1+l2)
+
+ def test_merge_list_into_scalar(self):
+ l = ['foo', 1, 2]
+ p1 = Parameters(dict(key=l[0]))
+ p1.merge(Parameters(dict(key=l[1:])))
+ self.assertListEqual(p1.as_dict()['key'], l)
+
+ def test_merge_scalar_over_list(self):
+ l = ['foo', 1, 2]
+ p1 = Parameters(dict(key=l[:2]))
+ p1.merge(Parameters(dict(key=l[2])))
+ self.assertEqual(p1.as_dict()['key'], l[2])
+
+ def test_merge_dicts(self):
+ mergee = {'five':5,'four':4,'None':None,'tuple':(1,2,3)}
+ p = Parameters(dict(dict=SIMPLE))
+ p.merge(Parameters(dict(dict=mergee)))
+ goal = SIMPLE.copy()
+ goal.update(mergee)
+ self.assertDictEqual(p.as_dict(), dict(dict=goal))
+
+ def test_merge_dicts_overwrite(self):
+ mergee = {'two':5,'four':4,'three':None,'one':(1,2,3)}
+ p = Parameters(dict(dict=SIMPLE))
+ p.merge(Parameters(dict(dict=mergee)))
+ goal = SIMPLE.copy()
+ goal.update(mergee)
+ self.assertDictEqual(p.as_dict(), dict(dict=goal))
+
+ def test_merge_dict_into_scalar(self):
+ p = Parameters(dict(base='foo'))
+ with self.assertRaises(TypeError):
+ p.merge(Parameters(dict(base=SIMPLE)))
+
+ def test_merge_scalar_over_dict(self):
+ p = Parameters(dict(base=SIMPLE))
+ mergee = {'base':'foo'}
+ p.merge(Parameters(mergee))
+ self.assertDictEqual(p.as_dict(), mergee)
+
+ test_cur = test_merge_scalar_over_dict
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/reclass/storage/__init__.py b/reclass/storage/__init__.py
index 7cfb60c..be9440a 100644
--- a/reclass/storage/__init__.py
+++ b/reclass/storage/__init__.py
@@ -31,7 +31,7 @@
},
'classes': entity.classes.as_list(),
'applications': entity.applications.as_list(),
- 'parameters': entity.parameters
+ 'parameters': entity.parameters.as_dict()
}
def _list_inventory(self):