From 7250882a65ac312f2c1dc281bc5a6fb9ac8532bd Mon Sep 17 00:00:00 2001 From: jelleas Date: Fri, 17 Aug 2018 00:21:50 -0400 Subject: [PATCH] Loader --- lib50/config.py | 152 ++++++++++++++++++++++++++++++++++++++++++ tests/config_tests.py | 105 ++++++++++++++++++++++++----- 2 files changed, 240 insertions(+), 17 deletions(-) diff --git a/lib50/config.py b/lib50/config.py index ab8406e..54c4aee 100644 --- a/lib50/config.py +++ b/lib50/config.py @@ -1,5 +1,6 @@ import enum import yaml +import collections from . import errors from . import _ @@ -9,6 +10,157 @@ from yaml import SafeLoader +class TaggedValue: + def __init__(self, value, tag, *tags): + for t in tags: + setattr(self, t[1:], False) + setattr(self, tag[1:], True) + self.tag = tag + self.tags = set(tags) + self.value = value + + def __repr__(self): + return f"TaggedValue(tag={self.tag}, tags={self.tags})" + + +class Loader: + def __init__(self, tool, *global_tags, default=None): + self._global_tags = self._ensure_exclamation(global_tags) + self._global_default = default if not default or default.startswith("!") else "!" + default + self._scopes = collections.defaultdict(list) + self.tool = tool + + def scope(self, key, *tags, default=None): + """Only apply tags and default for top-level key, effectively scoping the tags.""" + scope = self._scopes[key] + tags = self._ensure_exclamation(tags) + default = default if not default or default.startswith("!") else "!" + default + + if scope: + scope[0] = scope[0] + tags + scope[1] = default if default else scope[1] + else: + scope.append(tags) + scope.append(default) + + def load(self, content): + """Parse yaml content.""" + # Try parsing the YAML with global tags + try: + config = yaml.load(content, Loader=self._loader(self._global_tags)) + except yaml.YAMLError: + raise errors.InvalidConfigError(_("Config is not valid yaml.")) + + # Try extracting just the tool portion + try: + config = config[self.tool] + except (TypeError, KeyError): + return None + + # If no scopes, just apply global default + if not isinstance(config, dict): + config = self._apply_default(config, self._global_default) + else: + # Figure out what scopes exist + scoped_keys = set(key for key in self._scopes) + + # For every scope + for key in config: + # If scope has custom tags, apply + if key in scoped_keys: + # local tags, and local default + tags, default = self._scopes[key] + + # Inherit global default if no local default + if not default: + default = self._global_default + + config[key] = self._apply_default(config[key], default) + self._apply_scope(config[key], tags) + # Otherwise just apply global default + else: + config[key] = self._apply_default(config[key], self._global_default) + + self._validate(config) + + return config + + def _loader(self, tags): + """Create a yaml Loader.""" + class ConfigLoader(SafeLoader): + pass + ConfigLoader.add_multi_constructor("", lambda loader, prefix, node: TaggedValue(node.value, node.tag, *tags)) + return ConfigLoader + + def _validate(self, config): + """Check whether every TaggedValue has a valid tag, otherwise raise InvalidConfigError""" + if isinstance(config, dict): + # Recursively validate each item in the config + for val in config.values(): + self._validate(val) + + elif isinstance(config, list): + # Recursively validate each item in the config + for item in config: + self._validate(item) + + elif isinstance(config, TaggedValue): + tagged_value = config + + # if tagged_value is invalid, error + if tagged_value.tag not in tagged_value.tags: + raise errors.InvalidConfigError(_("{} is not a valid tag for {}".format(tagged_value.tag, self.tool))) + + def _apply_default(self, config, default): + """ + Apply default value to every str in config. + Also ensure every TaggedValue has default in .tags + """ + # No default, nothing to be done here + if not default: + return config + + # If the entire config is just a string, return default TaggedValue + if isinstance(config, str): + return TaggedValue(config, default, default, *self._global_tags) + + if isinstance(config, dict): + # Recursively apply defaults for each item in the config + for key, val in config.items(): + config[key] = self._apply_default(val, default) + + elif isinstance(config, list): + # Recursively apply defaults for each item in the config + for i, val in enumerate(config): + config[i] = self._apply_default(val, default) + + elif isinstance(config, TaggedValue): + # Make sure each TaggedValue knows about the default tag + config.tags.add(default) + + return config + + def _apply_scope(self, config, tags): + """Add locally scoped tags to config""" + if isinstance(config, dict): + # Recursively _apply_scope for each item in the config + for val in config.values(): + self._apply_scope(val, tags) + + elif isinstance(config, list): + # Recursively _apply_scope for each item in the config + for item in config: + self._apply_scope(item, tags) + + elif isinstance(config, TaggedValue): + # add all local tags + config.tags |= set(tags) + + @staticmethod + def _ensure_exclamation(tags): + return [tag if tag.startswith("!") else "!" + tag for tag in tags] + + class InvalidTag: """Class representing unrecognized tags""" def __init__(self, loader, prefix, node): diff --git a/tests/config_tests.py b/tests/config_tests.py index 575127d..bac6673 100644 --- a/tests/config_tests.py +++ b/tests/config_tests.py @@ -1,20 +1,22 @@ import unittest +import sys +import lib50.errors import lib50.config -class TestLoad(unittest.TestCase): +class TestLoader(unittest.TestCase): def test_no_tool(self): content = "" - config = lib50.config.load(content, "check50") + config = lib50.config.Loader("check50").load(content) self.assertEqual(config, None) def test_falsy_tool(self): content = "check50: false" - config = lib50.config.load(content, "check50") + config = lib50.config.Loader("check50").load(content) self.assertFalse(config) def test_truthy_tool(self): content = "check50: true" - config = lib50.config.load(content, "check50") + config = lib50.config.Loader("check50").load(content) self.assertTrue(config) def test_no_files(self): @@ -22,32 +24,101 @@ def test_no_files(self): "check50:\n" \ " dependencies:\n" \ " - foo" - config = lib50.config.load(content, "check50") + config = lib50.config.Loader("check50").load(content) self.assertEqual(config, {"dependencies" : ["foo"]}) - def test_include_file(self): + def test_global_tag(self): + content = \ + "check50:\n" \ + " foo:\n" \ + " - !include baz\n" \ + " bar:\n" \ + " - !include qux" + config = lib50.config.Loader("check50", "include").load(content) + self.assertTrue(config["foo"][0].include) + self.assertEqual(config["foo"][0].value, "baz") + self.assertTrue(config["bar"][0].include) + self.assertEqual(config["bar"][0].value, "qux") + + def test_local_tag(self): content = \ "check50:\n" \ " files:\n" \ " - !include foo" - config = lib50.config.load(content, "check50") - self.assertTrue(config["files"][0].type == lib50.config.PatternType.Included) + loader = lib50.config.Loader("check50") + loader.scope("files", "include") + config = loader.load(content) + self.assertTrue(config["files"][0].include) + self.assertEqual(config["files"][0].value, "foo") - def test_exclude_file(self): + content = \ + "check50:\n" \ + " bar:\n" \ + " - !include foo" + loader = lib50.config.Loader("check50") + loader.scope("files", "include", default=False) + with self.assertRaises(lib50.errors.InvalidConfigError): + config = loader.load(content) + + def test_no_default(self): content = \ "check50:\n" \ " files:\n" \ - " - !exclude foo" - config = lib50.config.load(content, "check50") - self.assertTrue(config["files"][0].type == lib50.config.PatternType.Excluded) + " - !INVALID foo" + loader = lib50.config.Loader("check50") + loader.scope("files", "include", default=False) + with self.assertRaises(lib50.errors.InvalidConfigError): + config = loader.load(content) - def test_require_file(self): + def test_local_default(self): content = \ "check50:\n" \ " files:\n" \ - " - !require foo" - config = lib50.config.load(content, "check50") - self.assertTrue(config["files"][0].type == lib50.config.PatternType.Required) + " - foo" + loader = lib50.config.Loader("check50") + loader.scope("files", default="bar") + config = loader.load(content) + self.assertTrue(config["files"][0].bar) + self.assertEqual(config["files"][0].value, "foo") + + def test_global_default(self): + content = \ + "check50:\n" \ + " files:\n" \ + " - foo" + config = lib50.config.Loader("check50", default="bar").load(content) + self.assertTrue(config["files"][0].bar) + self.assertEqual(config["files"][0].value, "foo") + + def test_multiple_defaults(self): + content = \ + "check50:\n" \ + " foo:\n" \ + " - baz\n" \ + " bar:\n" \ + " - qux" + loader = lib50.config.Loader("check50", default="include") + loader.scope("bar", default="exclude") + config = loader.load(content) + self.assertTrue(config["foo"][0].include) + self.assertEqual(config["foo"][0].value, "baz") + self.assertTrue(config["bar"][0].exclude) + self.assertEqual(config["bar"][0].value, "qux") + + def test_same_tag_default(self): + content = \ + "check50:\n" \ + " foo:\n" \ + " - !include bar\n" \ + " - baz" + config = lib50.config.Loader("check50", "include", default="include").load(content) + self.assertTrue(config["foo"][0].include) + self.assertEqual(config["foo"][0].value, "bar") + self.assertTrue(config["foo"][1].include) + self.assertEqual(config["foo"][1].value, "baz") + + if __name__ == '__main__': - unittest.main() + suite = unittest.TestLoader().loadTestsFromModule(sys.modules[__name__]) + unittest.TextTestRunner(verbosity=2).run(suite)