import re
from copy import deepcopy

import toml

from .specification import ConfigSpecification, InvalidConfigError


class ConfigManager:
    """
    Holds a root config dict (bundle name -> config dict), the
    ConfigSpecifications used to validate each bundle, and a set of "frozen"
    values that are re-applied after every load so they cannot be changed.
    """

    def __init__(self):
        # Bundle name -> config dict for the currently loaded config.
        self.root_config = {}
        # Bundle name -> ConfigSpecification, used by get_config_bundle.
        self.confspecs = {}
        # Nested dicts mirroring the shape of root_config; overlaid on top of
        # root_config after every load/load_overlay.
        self.frozen_config = {}

    @staticmethod
    def _load_source(source):
        """
        Accept a filepath or opened file representing a TOML file, or a direct
        dict, and return a plain parsed dict.

        A dict source is deep-copied so the ConfigManager never aliases (and
        later mutates, e.g. via overlays or validation defaults) a
        caller-owned data structure.
        """
        if isinstance(source, dict):
            return deepcopy(source)
        if isinstance(source, str):
            # TOML documents are required to be UTF-8 encoded.
            with open(source, 'r', encoding='utf-8') as conf_file:
                return toml.load(conf_file)
        # Otherwise treat it as an already opened file (or anything toml.load accepts).
        return toml.load(source)

    def load(self, source):
        """
        Load a config source into the ConfigManager, replacing any existing
        config. Frozen values are re-applied on top afterwards.

        Args:
            source: Either a dict config to load directly, a filepath to a
                TOML file, or an open file.
        """
        self.root_config = self._load_source(source)
        self._overlay(self.frozen_config, self.root_config)

    def load_overlay(self, source):
        """
        Load a config source into the ConfigManager, merging it over the top of
        any existing config. Dicts will be recursively processed with keys
        being merged and existing values being replaced by the new source.
        This includes lists, which will be treated as any other value and
        completely replaced. Frozen values are re-applied on top afterwards.

        Args:
            source: Either the root dict of a data structure to load directly,
                a filepath to a TOML file, or an open TOML file.
        """
        self._overlay(self._load_source(source), self.root_config)
        self._overlay(self.frozen_config, self.root_config)

    def freeze_value(self, bundle_name, *field_names):
        """
        Freeze the given config field so that subsequent calls to ``load`` and
        ``load_overlay`` cannot change it. The field is located by walking a
        path of nested dict keys, starting at the bundle.

        A snapshot (deep copy) of the current value is stored, so later
        in-place merges into ``root_config`` cannot alter the frozen value.

        Args:
            bundle_name: The name of the bundle to look for the field in.
            *field_names: a series of strings that locate the config field,
                either a single key or series of nested keys.

        Raises:
            KeyError: If any key along the path is missing from root_config.
        """
        # Bundle names are really no different from any other nested dict names.
        *parents, leaf = (bundle_name,) + field_names
        target_field = self.root_config
        frozen_value = self.frozen_config
        # Walk the nested names, creating frozen_config nested dicts as necessary.
        for name in parents:
            target_field = target_field[name]
            frozen_value = frozen_value.setdefault(name, {})
        frozen_value[leaf] = deepcopy(target_field[leaf])

    def add_confspec(self, bundle_name, confspec):
        """
        Stores a ConfigSpecification for future use when validating the
        corresponding config bundle.

        Args:
            bundle_name (str): The name to store the config specification under.
            confspec (ConfigSpecification): The populated ConfigSpecification
                to store.
        """
        self.confspecs[bundle_name] = confspec

    def add_confspecs(self, confspecs):
        """
        Stores multiple ConfigSpecifications at once for future use when
        validating the corresponding config bundles.

        Args:
            confspecs: A dict of populated ConfigSpecifications to store,
                using their keys as names.
        """
        self.confspecs.update(confspecs)

    def list_missing_confspecs(self):
        """
        Returns a list of config bundle names that do not have a corresponding
        ConfigSpecification stored in the ConfigManager (order unspecified).
        """
        return list(self.root_config.keys() - self.confspecs.keys())

    def _overlay(self, src, dest):
        """
        Recursively merge ``src`` into ``dest`` in place. Where both sides hold
        a dict for a key, the dicts are merged; otherwise the ``src`` value
        replaces (or is added to) ``dest``.

        Values are deep-copied on assignment so that ``dest`` never shares
        mutable objects with ``src`` (in particular with ``frozen_config``,
        which would otherwise be mutated by later overlays).
        """
        for key, value in src.items():
            if key in dest and isinstance(value, dict) and isinstance(dest[key], dict):
                self._overlay(value, dest[key])
            else:
                dest[key] = deepcopy(value)

    def get_config_bundle(self, bundle_name, conf_spec=None):
        """
        Get a config bundle called ``bundle_name`` and validate it against the
        corresponding config specification stored in the ConfigManager. If
        ``conf_spec`` is supplied, it gets used instead. Returns a copy of the
        validated config bundle dict.

        Note that as part of validation, optional keys that are missing will be
        filled in with their default values (see ``DictSpec``). This function
        will copy the config bundle *after* validation, and so config loaded in
        the ConfigManager will be modified, but future ConfigManager
        manipulations won't change the returned config bundle.

        Args:
            bundle_name: (str) Name of the config bundle to find.
            conf_spec: (ConfigSpecification) Optional config specification to
                validate against.

        Raises:
            KeyError: If no ``conf_spec`` is given and none is stored for the
                bundle.
            InvalidConfigError: If the bundle is missing or fails validation
                (the bundle name is prepended to the error's args).
        """
        if not isinstance(conf_spec, ConfigSpecification):
            conf_spec = self.confspecs[bundle_name]

        if bundle_name not in self.root_config:
            raise InvalidConfigError(
                "Config must contain dict: " + bundle_name)

        try:
            conf_spec.validate(self.root_config[bundle_name])
        except InvalidConfigError as e:
            e.args = ("Bundle: " + bundle_name,) + e.args
            raise

        return deepcopy(self.root_config[bundle_name])

    def get_config_bundles(self, bundle_names):
        """
        Get multiple config bundles from the root dict at once, validating each
        one with the corresponding ConfigSpecification stored in the
        ConfigManager. See ``get_config_bundle``.

        Args:
            bundle_names: A list of config bundle names to get. If a dictionary
                is supplied, uses the values as ConfigSpecifications rather
                than looking up a stored one in the ConfigManager.

        Returns:
            A dict of config dicts, with keys matching those passed in
            ``bundle_names``.
        """
        if isinstance(bundle_names, dict):
            return {name: self.get_config_bundle(name, conf_spec)
                    for name, conf_spec in bundle_names.items()}
        return {name: self.get_config_bundle(name) for name in bundle_names}

    def get_bundle_names(self):
        """
        Returns a list of names of top level config bundles.
        """
        return list(self.root_config.keys())

    def dump_toml(self):
        """Serialise the current root config to a TOML string."""
        return toml.dumps(self.root_config)

    def dump_to_file(self, filepath, message=None):
        """
        Write the current root config to ``filepath`` as TOML, optionally
        followed by a comment message block (see ``gen_comment``).

        The content is serialised before the file is opened so that a
        serialisation error does not leave behind a truncated file.
        """
        content = self.dump_toml()
        if message is not None:
            content = content.rstrip() + gen_comment(message)
        with open(filepath, 'w', encoding='utf-8') as f:
            f.write(content)


# Matches a message block as written by gen_comment: the "# <tag>_message:"
# header line plus any directly following comment lines (the continuation
# lines of a multi-line message). The legacy "shepherd" tag is still
# recognised so files written by older versions are cleaned up as well.
_TOML_MESSAGE_RE = re.compile(
    r"(?m)^# (?:config-spec|shepherd)_message:[^\n]*$\n?(?:^#[^\n]+$\n?)*")


def strip_toml_message(string):
    """Return ``string`` with any comment message blocks removed."""
    return _TOML_MESSAGE_RE.sub('', string)


def update_toml_message(filepath, message):
    """
    Replace the comment message block at the end of the TOML file at
    ``filepath`` with ``message``, rewriting the file in place.
    """
    with open(filepath, 'r+', encoding='utf-8') as f:
        content = f.read()
        content = strip_toml_message(content).rstrip()
        content += gen_comment(message)
        f.seek(0)
        f.write(content)
        f.truncate()


def gen_comment(string):
    """
    Build a TOML comment message block from ``string``. Each line of the
    message becomes a comment line; '#' characters are removed from the
    message text itself.
    """
    return '\n# config-spec_message: ' + '\n# '.join(string.replace('#', '').splitlines()) + '\n'