diff --git a/shepherd/config.py b/shepherd/config.py index 6c7a8bc..241d8ed 100644 --- a/shepherd/config.py +++ b/shepherd/config.py @@ -1,6 +1,10 @@ import re import toml +from shepherd.freezedry import freezedryable, rehydrate + + + class InvalidConfigError(Exception): pass @@ -31,28 +35,37 @@ class InvalidConfigError(Exception): # config def required interface: # Validate values. + + +# The Table and Array terms used here are directly from the TOML convention, but they essentially +# map directly to Dictionaries (Tables), and Lists (Arrays) + + + class _ConfigDefinition(): - def __init__(self, default=None, optional=False): + def __init__(self, default=None, optional=False, helptext=""): self.default = default self.optional = optional + self.helptext = helptext - def validate(self, value): # pylint: disable=W0613 + def validate(self, value): # pylint: disable=W0613 raise TypeError("_ConfigDefinition.validate() is an abstract method") + - +@freezedryable class BoolDef(_ConfigDefinition): - def __init__(self, default=None, optional=False): # pylint: disable=W0235 - super().__init__(default, optional) + def __init__(self, default=None, optional=False, helptext=""): # pylint: disable=W0235 + super().__init__(default, optional, helptext) def validate(self, value): if not isinstance(value, bool): raise InvalidConfigError("Config value must be a boolean") - +@freezedryable class IntDef(_ConfigDefinition): def __init__(self, default=None, minval=None, maxval=None, - optional=False): - super().__init__(default, optional) + optional=False, helptext=""): + super().__init__(default, optional, helptext) self.minval = minval self.maxval = maxval @@ -66,11 +79,11 @@ class IntDef(_ConfigDefinition): raise InvalidConfigError("Config value must be <= " + str(self.maxval)) - +@freezedryable class StringDef(_ConfigDefinition): def __init__(self, default=None, minlength=None, maxlength=None, - optional=False): - super().__init__(default, optional) + optional=False, helptext=""): + super().__init__(default, optional, helptext) self.minlength = minlength self.maxlength = maxlength @@ -84,10 +97,10 @@ class StringDef(_ConfigDefinition): raise InvalidConfigError("Config string length must be <= " + str(self.maxlength)) - +@freezedryable class TableDef(_ConfigDefinition): - def __init__(self, default=None, optional=False): - super().__init__(default, optional) + def __init__(self, default=None, optional=False, helptext=""): + super().__init__(default, optional, helptext) self.def_table = {} def add_def(self, name, newdef): @@ -99,7 +112,7 @@ class TableDef(_ConfigDefinition): self.def_table[name] = newdef return newdef - def validate(self, value_table): # pylint: disable=W0221 + def validate(self, value_table): # pylint: disable=W0221 def_set = set(self.def_table.keys()) value_set = set(value_table.keys()) @@ -133,23 +146,23 @@ class _ArrayDefMixin(): e.args = ("Array index: " + str(index),) + e.args raise - +@freezedryable class BoolArrayDef(_ArrayDefMixin, BoolDef): pass - +@freezedryable class IntArrayDef(_ArrayDefMixin, IntDef): pass - +@freezedryable class StringArrayDef(_ArrayDefMixin, StringDef): pass - +@freezedryable class TableArrayDef(_ArrayDefMixin, TableDef): pass - +@freezedryable class ConfDefinition(TableDef): pass @@ -172,7 +185,8 @@ class ConfigManager(): raise TypeError("Supplied config definition must be an instance " "of ConfDefinition") if table_name not in self.root_config: - raise InvalidConfigError("Config must contain table: " + table_name) + raise InvalidConfigError( + "Config must contain table: " + table_name) try: conf_def.validate(self.root_config[table_name]) except InvalidConfigError as e: @@ -186,10 +200,12 @@ class ConfigManager(): config_values[name] = self.get_config(name, conf_def) return config_values - def get_module_configs(self, modules): + def get_plugin_configs(self, plugin_classes): config_values = {} - for name, module in modules.items(): - config_values[name] = self.get_config(name, module.conf_def) + for plugin_name, plugin_class in plugin_classes.items(): + conf_def = ConfDefinition() + plugin_class.define_config(conf_def) + config_values[plugin_name] = self.get_config(plugin_name, conf_def) return config_values def dump_toml(self): @@ -221,3 +237,4 @@ def update_toml_message(filepath, message): def gen_comment(string): return '\n# shepherd_message: ' + '\n# '.join(string.replace('#', '').splitlines()) + '\n' +