diff --git a/shepherd/config.py b/shepherd/config.py index e8d152c..ac732fe 100644 --- a/shepherd/config.py +++ b/shepherd/config.py @@ -176,14 +176,15 @@ class ConfDefinition(TableDef): class ConfigManager(): def __init__(self): self.root_config = {} + self.confdefs = {} def load(self, source): """ - Load a config source into the ConfigManager. + Load a config source into the ConfigManager, replacing any existing config. Args: source: Either a dict config to load directly, a filepath to a TOML file, - or a open file. + or an open file. """ if isinstance(source, dict): # load from dict self.root_config = source @@ -193,54 +194,117 @@ class ConfigManager(): else: # load from file self.root_config = toml.load(source) - def get_config(self, table_name, conf_def): + def load_overlay(self, source): + """ + Load a config source into the ConfigManager, merging it over the top of any existing + config. Dicts will be recursively processed with keys being merged and existing values + being replaced by the new source. This includes lists, which will be treated as any other + value and completely replaced. + + Args: + source: Either a dict config to load directly, a filepath to a TOML file, + or an open file. + """ + + if isinstance(source, dict): # load from dict + new_source = source + elif isinstance(source, str): # load from pathname + with open(source, 'r') as conf_file: + new_source = toml.load(conf_file) + else: # load from file + new_source = toml.load(source) + + self._overlay(new_source, self.root_config) + + def add_confdef(self, name, confdef): + """ + Stores a ConfigDefinition for future use when validating the corresponding config table + + Args: + name (str) : Then name to store the config definition under. + confdef (ConfigDefinition): The populated ConfigDefinition to store. + """ + self.confdefs[name]=confdef + + def add_confdefs(self, confdefs): + """ + Stores multiple ConfigDefinitions at once for future use when validating the corresponding config tables + + Args: + confdefs : A dict of populated ConfigDefinitions to store, using their keys as names. + """ + self.confdefs.update(confdefs) + + def get_missing_confdefs(self): + """ + Returns a list of config table names that do not have a corresponding ConfigDefinition + stored in the ConfigManager. + """ + return list(self.root_config.keys() - self.confdefs.keys()) + + + def _overlay(self, src, dest): + for key in src: + # If the key is also in the dest and both are dicts, merge them. + if key in dest and isinstance(src[key], dict) and isinstance(dest[key], dict): + self._overlay(src[key], dest[key]) + else: + # Otherwise it's either an existing value to be replaced or needs to be added. + dest[key] = src[key] + + def validate_and_get_config(self, config_name, conf_def=None): """ Get a config dict called ``table_name`` and validate - it against ``conf_def`` before returning it. + it against the corresponding config definition stored in the ConfigManager. + If ``conf_def`` is supplied, it gets used instead. Returns a validated + config dictionary. Note that as part of validation, optional keys that are missing will be filled in with their default values (see ``TableDef``). Args: table_name: (str) Name of the config dict to find. - conf_def: (ConfDefinition) Config definition to validate against. + conf_def: (ConfDefinition) Optional config definition to validate against. """ if not isinstance(conf_def, ConfDefinition): - raise TypeError("Supplied config definition must be an instance " - "of ConfDefinition") - if table_name not in self.root_config: + conf_def = self.confdefs[config_name] + + if config_name not in self.root_config: raise InvalidConfigError( - "Config must contain table: " + table_name) + "Config must contain table: " + config_name) try: - conf_def.validate(self.root_config[table_name]) + conf_def.validate(self.root_config[config_name]) except InvalidConfigError as e: - e.args = ("Module: " + table_name,) + e.args + e.args = ("Module: " + config_name,) + e.args raise - return self.root_config[table_name] + return self.root_config[config_name] - def get_configs(self, conf_defs): + def validate_and_get_configs(self, config_names): """ - Get multiple configs at once, validating each one. + Get multiple configs from the root table at once, validating each one with the + corresponding confdef stored in the ConfigManager. Args: - conf_defs: (dict) A dictionary of ConfigDefinitions. The keys are used - as the name to find each config dict, which is then validated against - the corresponding conf def. + conf_defs: A list of config names to get. If dictionary is supplied, uses the values + as ConfigDefinitions rather than looking up a stored one in the ConfigManager. + Returns: - A dict of config dicts, with keys matching those passed in ``conf_defs``. + A dict of config dicts, with keys matching those passed in ``config_names``. """ config_values = {} - for name, conf_def in conf_defs.items(): - config_values[name] = self.get_config(name, conf_def) + if isinstance(config_names, dict): + for name, conf_def in config_names.items(): + config_values[name] = self.validate_and_get_config(name, conf_def) + else: + for name in config_names: + config_values[name] = self.validate_and_get_config(name) return config_values - def get_plugin_configs(self, plugin_classes): - config_values = {} - for plugin_name, plugin_class in plugin_classes.items(): - conf_def = ConfDefinition() - plugin_class.define_config(conf_def) - config_values[plugin_name] = self.get_config(plugin_name, conf_def) - return config_values + def get_config_names(self): + """ + Returns a list of names of top level config tables + """ + return list(self.root_config.keys()) def dump_toml(self): return toml.dumps(self.root_config)