Merge pull request #93 from salt-formulas/andrewp-fix-class-mappings-regression

allow class mappings to match on node name or node path
diff --git a/doc/source/operations.rst b/doc/source/operations.rst
index f744148..08b34e5 100644
--- a/doc/source/operations.rst
+++ b/doc/source/operations.rst
@@ -101,6 +101,13 @@
 can be assigned to each mapping by providing a space-separated list (class
 names cannot contain spaces anyway).
 
+By default the class mappings (glob or regex) are matched against the node name.
+This can be changed so that the match is done against the path of the node file
+relative to the nodes directory, with the file extension (e.g. .yml) dropped. This
+is controlled with the setting class_mappings_match_path. When False (the
+default) the match is done against the node name; when True the match is done
+against the node file path, e.g. ``alpha/node1`` for ``nodes/alpha/node1.yml``.
+
 .. warning::
 
   The class mappings do not really belong in the configuration file, as they
diff --git a/reclass/core.py b/reclass/core.py
index 3e0ab34..1ce74ed 100644
--- a/reclass/core.py
+++ b/reclass/core.py
@@ -72,26 +72,30 @@
             key = '/{0}/'.format(key)
         return key, list(lexer)
 
-    def _get_class_mappings_entity(self, nodename):
+    def _get_class_mappings_entity(self, entity):
         if not self._class_mappings:
             return Entity(self._settings, name='empty (class mappings)')
         c = Classes()
+        if self._settings.class_mappings_match_path:
+            matchname = entity.pathname
+        else:
+            matchname = entity.name
         for mapping in self._class_mappings:
             matched = False
             key, klasses = Core._shlex_split(mapping)
             if key[0] == ('/'):
-                matched = Core._match_regexp(key[1:-1], nodename)
+                matched = Core._match_regexp(key[1:-1], matchname)
                 if matched:
                     for klass in klasses:
                         c.append_if_new(matched.expand(klass))
 
             else:
-                if Core._match_glob(key, nodename):
+                if Core._match_glob(key, matchname):
                     for klass in klasses:
                         c.append_if_new(klass)
 
         return Entity(self._settings, classes=c,
-                      name='class mappings for node {0}'.format(nodename))
+                      name='class mappings for node {0}'.format(entity.name))
 
     def _get_input_data_entity(self):
         if not self._input_data:
@@ -207,7 +211,7 @@
         if node_entity.environment == None:
             node_entity.environment = self._settings.default_environment
         base_entity = Entity(self._settings, name='base')
-        base_entity.merge(self._get_class_mappings_entity(node_entity.name))
+        base_entity.merge(self._get_class_mappings_entity(node_entity))
         base_entity.merge(self._get_input_data_entity())
         base_entity.merge_parameters(self._get_automatic_parameters(nodename, node_entity.environment))
         seen = {}
diff --git a/reclass/datatypes/entity.py b/reclass/datatypes/entity.py
index 2e0e1e4..88b5afe 100644
--- a/reclass/datatypes/entity.py
+++ b/reclass/datatypes/entity.py
@@ -24,9 +24,10 @@
     '''
     def __init__(self, settings, classes=None, applications=None,
                  parameters=None, exports=None, uri=None, name=None,
-                 environment=None):
+                 pathname=None, environment=None):
         self._uri = '' if uri is None else uri
         self._name = '' if name is None else name
+        self._pathname = '' if pathname is None else pathname
         self._classes = self._set_field(classes, Classes)
         self._applications = self._set_field(applications, Applications)
         pars = [None, settings, uri]
@@ -36,6 +37,7 @@
 
     name = property(lambda s: s._name)
     uri = property(lambda s: s._uri)
+    pathname = property(lambda s: s._pathname)
     classes = property(lambda s: s._classes)
     applications = property(lambda s: s._applications)
     parameters = property(lambda s: s._parameters)
@@ -101,10 +103,10 @@
         return not self.__eq__(other)
 
     def __repr__(self):
-        return "%s(%r, %r, %r, %r, uri=%r, name=%r, environment=%r)" % (
+        return "%s(%r, %r, %r, %r, uri=%r, name=%r, pathname=%r, environment=%r)" % (
                    self.__class__.__name__, self.classes, self.applications,
                    self.parameters, self.exports, self.uri, self.name,
-                   self.environment)
+                   self.pathname, self.environment)
 
     def as_dict(self):
         return {'classes': self._classes.as_list(),
diff --git a/reclass/defaults.py b/reclass/defaults.py
index f240f3f..f50a8ad 100644
--- a/reclass/defaults.py
+++ b/reclass/defaults.py
@@ -57,3 +57,5 @@
 
 AUTOMATIC_RECLASS_PARAMETERS = True
 DEFAULT_ENVIRONMENT = 'base'
+
+CLASS_MAPPINGS_MATCH_PATH = False
diff --git a/reclass/settings.py b/reclass/settings.py
index 62af976..e9e8a36 100644
--- a/reclass/settings.py
+++ b/reclass/settings.py
@@ -19,6 +19,7 @@
         'allow_dict_over_scalar': defaults.OPT_ALLOW_DICT_OVER_SCALAR,
         'allow_none_override': defaults.OPT_ALLOW_NONE_OVERRIDE,
         'automatic_parameters': defaults.AUTOMATIC_RECLASS_PARAMETERS,
+        'class_mappings_match_path': defaults.CLASS_MAPPINGS_MATCH_PATH,
         'default_environment': defaults.DEFAULT_ENVIRONMENT,
         'delimiter': defaults.PARAMETER_INTERPOLATION_DELIMITER,
         'dict_key_override_prefix':
diff --git a/reclass/storage/yaml_fs/__init__.py b/reclass/storage/yaml_fs/__init__.py
index 3577b36..ee49df3 100644
--- a/reclass/storage/yaml_fs/__init__.py
+++ b/reclass/storage/yaml_fs/__init__.py
@@ -95,18 +95,20 @@
         try:
             relpath = self._nodes[name]
             path = os.path.join(self.nodes_uri, relpath)
+            pathname = os.path.splitext(relpath)[0]
         except KeyError as e:
             raise reclass.errors.NodeNotFound(self.name, name, self.nodes_uri)
-        entity = YamlData.from_file(path).get_entity(name, settings)
+        entity = YamlData.from_file(path).get_entity(name, pathname, settings)
         return entity
 
     def get_class(self, name, environment, settings):
         vvv('GET CLASS {0}'.format(name))
         try:
             path = os.path.join(self.classes_uri, self._classes[name])
+            pathname = os.path.splitext(self._classes[name])[0]
         except KeyError as e:
             raise reclass.errors.ClassNotFound(self.name, name, self.classes_uri)
-        entity = YamlData.from_file(path).get_entity(name, settings)
+        entity = YamlData.from_file(path).get_entity(name, pathname, settings)
         return entity
 
     def enumerate_nodes(self):
diff --git a/reclass/storage/yaml_git/__init__.py b/reclass/storage/yaml_git/__init__.py
index 4c73998..d9c84cd 100644
--- a/reclass/storage/yaml_git/__init__.py
+++ b/reclass/storage/yaml_git/__init__.py
@@ -273,7 +273,9 @@
     def get_node(self, name, settings):
         file = self._nodes[name]
         blob = self._repos[self._nodes_uri.repo].get(file.id)
-        entity = YamlData.from_string(blob.data, 'git_fs://{0} {1} {2}'.format(self._nodes_uri.repo, self._nodes_uri.branch, file.path)).get_entity(name, settings)
+        uri = 'git_fs://{0} {1} {2}'.format(self._nodes_uri.repo, self._nodes_uri.branch, file.path)
+        pathname = os.path.splitext(file.path)[0]
+        entity = YamlData.from_string(blob.data, uri).get_entity(name, pathname, settings)
         return entity
 
     def get_class(self, name, environment, settings):
@@ -288,7 +290,9 @@
             raise reclass.errors.NotFoundError("File " + name + " missing from " + uri.repo + " branch " + uri.branch)
         file = self._repos[uri.repo].files[uri.branch][name]
         blob = self._repos[uri.repo].get(file.id)
-        entity = YamlData.from_string(blob.data, 'git_fs://{0} {1} {2}'.format(uri.repo, uri.branch, file.path)).get_entity(name, settings)
+        uri = 'git_fs://{0} {1} {2}'.format(uri.repo, uri.branch, file.path)
+        pathname = os.path.splitext(file.path)[0]
+        entity = YamlData.from_string(blob.data, uri).get_entity(name, pathname, settings)
         return entity
 
     def enumerate_nodes(self):
diff --git a/reclass/storage/yamldata.py b/reclass/storage/yamldata.py
index a38b589..f68d803 100644
--- a/reclass/storage/yamldata.py
+++ b/reclass/storage/yamldata.py
@@ -80,10 +80,7 @@
     def count_dots(self, value):
         return len(list(self.yield_dots(value)))
 
-    def get_entity(self, name, settings):
-        #if name is None:
-        #    name = self._uri
-
+    def get_entity(self, name, pathname, settings):
         classes = self._data.get('classes')
         if classes is None:
             classes = []
@@ -108,7 +105,7 @@
         env = self._data.get('environment', None)
 
         return datatypes.Entity(settings, classes=classes, applications=applications, parameters=parameters,
-                                exports=exports, name=name, environment=env, uri=self.uri)
+                                exports=exports, name=name, pathname=pathname, environment=env, uri=self.uri)
 
     def __str__(self):
         return '<{0} {1}, {2}>'.format(self.__class__.__name__, self._uri,
diff --git a/reclass/tests/data/04/classes/one.yml b/reclass/tests/data/04/classes/one.yml
new file mode 100644
index 0000000..37ee5e8
--- /dev/null
+++ b/reclass/tests/data/04/classes/one.yml
@@ -0,0 +1,2 @@
+parameters:
+  test1: 1
diff --git a/reclass/tests/data/04/classes/three.yml b/reclass/tests/data/04/classes/three.yml
new file mode 100644
index 0000000..f71f8ce
--- /dev/null
+++ b/reclass/tests/data/04/classes/three.yml
@@ -0,0 +1,2 @@
+parameters:
+  test3: 3
diff --git a/reclass/tests/data/04/classes/two.yml b/reclass/tests/data/04/classes/two.yml
new file mode 100644
index 0000000..80d5209
--- /dev/null
+++ b/reclass/tests/data/04/classes/two.yml
@@ -0,0 +1,2 @@
+parameters:
+  test2: 2
diff --git a/reclass/tests/data/04/nodes/alpha/node1.yml b/reclass/tests/data/04/nodes/alpha/node1.yml
new file mode 100644
index 0000000..f0f59f5
--- /dev/null
+++ b/reclass/tests/data/04/nodes/alpha/node1.yml
@@ -0,0 +1,2 @@
+classes:
+  - one
diff --git a/reclass/tests/test_core.py b/reclass/tests/test_core.py
index 4827177..c1e283d 100644
--- a/reclass/tests/test_core.py
+++ b/reclass/tests/test_core.py
@@ -23,13 +23,13 @@
 
 class TestCore(unittest.TestCase):
 
-    def _core(self, dataset, opts={}):
+    def _core(self, dataset, opts={}, class_mappings=[]):
         inventory_uri = os.path.dirname(os.path.abspath(__file__)) + '/data/' + dataset
         path_mangler = get_path_mangler('yaml_fs')
         nodes_uri, classes_uri = path_mangler(inventory_uri, 'nodes', 'classes')
         settings = Settings(opts)
         storage = get_storage('yaml_fs', nodes_uri, classes_uri, settings.compose_node_name)
-        return Core(storage, None, settings)
+        return Core(storage, class_mappings, settings)
 
     def test_type_conversion(self):
         reclass = self._core('01')
@@ -72,7 +72,7 @@
         self.assertEqual(node['parameters'], params)
 
     def test_compose_node_names(self):
-        reclass = self._core('03', {'compose_node_name': True})
+        reclass = self._core('03', opts={'compose_node_name': True})
         alpha_one_node = reclass.nodeinfo('alpha.one')
         alpha_one_res = {'a': 1, 'alpha': [1, 2], 'beta': {'a': 1, 'b': 2}, 'b': 2, '_reclass_': {'environment': 'base', 'name': {'full': 'alpha.one', 'short': 'alpha'}}}
         alpha_two_node = reclass.nodeinfo('alpha.two')
@@ -86,6 +86,18 @@
         self.assertEqual(beta_one_node['parameters'], beta_one_res)
         self.assertEqual(beta_two_node['parameters'], beta_two_res)
 
+    def test_class_mappings_match_path_false(self):
+        reclass = self._core('04', opts={'class_mappings_match_path': False}, class_mappings=['node*    two', 'alpha/node*    three'])
+        node = reclass.nodeinfo('node1')
+        params = { 'test1': 1, 'test2': 2, '_reclass_': {'environment': u'base', 'name': {'full': 'node1', 'short': 'node1'}}}
+        self.assertEqual(node['parameters'], params)
+
+    def test_class_mappings_match_path_true(self):
+        reclass = self._core('04', opts={'class_mappings_match_path': True}, class_mappings=['node*    two', 'alpha/node*    three'])
+        node = reclass.nodeinfo('node1')
+        params = { 'test1': 1, 'test3': 3, '_reclass_': {'environment': u'base', 'name': {'full': 'node1', 'short': 'node1'}}}
+        self.assertEqual(node['parameters'], params)
+
 
 if __name__ == '__main__':
     unittest.main()