You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

306 lines
12 KiB

import os
import re
from copy import deepcopy

import toml

from .specification import ConfigSpecification, InvalidConfigError
class ConfigManager:
    """
    Load, merge, validate and dump named config "bundles" (the top-level keys of a config
    source dict), validating each bundle against a stored ``ConfigSpecification``.
    """

    def __init__(self):
        self.config_source = {}  # source values
        self.root_config = {}  # validated config
        self.confspecs = {}  # bundle name -> ConfigSpecification
        self.frozen_config = {}  # validated config to reapply each load
        self._saved_state = {}  # snapshot written by save_fallback()

    def save_fallback(self):
        """
        Save the current state of the ConfigManager for future restoration with ``fallback()``.
        Includes the loaded config source, the validated config, any added config specifications,
        and any frozen values. Everything is deep-copied into the snapshot.
        """
        self._saved_state["config_source"] = deepcopy(self.config_source)
        self._saved_state["root_config"] = deepcopy(self.root_config)
        self._saved_state["confspecs"] = deepcopy(self.confspecs)
        self._saved_state["frozen_config"] = deepcopy(self.frozen_config)

    def fallback(self):
        """
        Restore the state of the ConfigManager to what it was when ``save_fallback()`` was last
        called. Includes the loaded config source, the validated config, any added config
        specifications, and any frozen values.

        The snapshot is copied again on restore, so it stays intact and ``fallback()`` can be
        called repeatedly to return to the same saved state.

        Raises:
            Exception: If ``save_fallback()`` has not been called yet.
        """
        if not all(k in self._saved_state for k in ("config_source", "root_config",
                                                    "confspecs", "frozen_config")):
            raise Exception("Can't fallback ConfigManager without calling save_fallback() first!")
        # Handing out the snapshot objects themselves would let later in-place edits
        # (load_overlay, freeze_value, add_confspec) corrupt it for the next fallback().
        self.config_source = deepcopy(self._saved_state["config_source"])
        self.root_config = deepcopy(self._saved_state["root_config"])
        self.confspecs = deepcopy(self._saved_state["confspecs"])
        self.frozen_config = deepcopy(self._saved_state["frozen_config"])

    @staticmethod
    def _get_source(source):
        """
        Accept a filepath or opened file representing a TOML file, or a direct dict,
        and return a plain parsed dict.

        Dict sources are deep-copied so that merging frozen values or overlays into the result
        never mutates the caller's object. Paths may be ``str`` or any ``os.PathLike``
        (e.g. ``pathlib.Path``) and are read as UTF-8, the encoding TOML mandates.
        """
        if isinstance(source, dict):  # load from dict
            return deepcopy(source)
        if isinstance(source, (str, os.PathLike)):  # load from pathname
            with open(source, 'r', encoding='utf-8') as conf_file:
                return toml.load(conf_file)
        # load from an already-opened file object
        return toml.load(source)

    def load(self, source):
        """
        Load a config source into the ConfigManager, replacing any existing config.
        Frozen values (see ``freeze_value()``) are merged back over the new source, and any
        previously validated config is discarded.

        Args:
            source: Either a dict config to load directly, a filepath to a TOML file,
                or an open file.
        """
        self.config_source = self._get_source(source)
        self._overlay(self.frozen_config, self.config_source)
        # New source, so wipe validated config
        self.root_config = {}

    def load_overlay(self, source):
        """
        Load a config source into the ConfigManager, merging it over the top of any existing
        config. Dicts will be recursively processed with keys being merged and existing values
        being replaced by the new source. This includes lists, which will be treated as any other
        value and completely replaced. Frozen values are then merged back on top, and any
        previously validated config is discarded.

        Args:
            source: Either the root dict of a data structure to load directly, a filepath to a TOML
                file, or an open TOML file.
        """
        self._overlay(self._get_source(source), self.config_source)
        self._overlay(self.frozen_config, self.config_source)
        # New source, so wipe validated config
        self.root_config = {}

    def freeze_value(self, bundle_name, *field_names):
        """
        Freeze the given config field so that subsequent calls to ``load`` and ``load_overlay``
        cannot change it. The value is taken from the *validated* config, so the bundle must
        already have been validated (otherwise a ``KeyError`` is raised). Every name except the
        last must refer to a dict.

        Args:
            bundle_name: The name of the bundle to look for the field in.
            *field_names: a series of strings that locate the config field, either a single
                key or series of nested keys.
        """
        # Bundle names are really no different from any other nested dict
        names = (bundle_name,) + field_names
        target_field = self.root_config
        frozen_value = self.frozen_config
        # Walk the nested names, creating frozen_config nested dicts as necessary
        for name in names[:-1]:
            target_field = target_field[name]
            frozen_value = frozen_value.setdefault(name, {})
        # Copy, so later edits to the validated config (or to config_source, which the frozen
        # value is overlaid into on every load) can't change what is frozen.
        frozen_value[names[-1]] = deepcopy(target_field[names[-1]])

    def add_confspec(self, bundle_name, confspec):
        """
        Stores a ConfigSpecification for future use when validating the corresponding config bundle

        Args:
            bundle_name (str) : The name to store the config specification under.
            confspec (ConfigSpecification): The populated ConfigSpecification to store.
        """
        self.confspecs[bundle_name] = confspec

    def add_confspecs(self, confspecs):
        """
        Stores multiple ConfigSpecifications at once for future use when validating the
        corresponding config bundles

        Args:
            confspecs : A dict of populated ConfigSpecifications to store, using their keys as
                names.
        """
        self.confspecs.update(confspecs)

    def list_missing_confspecs(self):
        """
        Returns a list of config bundle names that do not have a corresponding ConfigSpecification
        stored in the ConfigManager, in the order they appear in the config source.
        """
        return [name for name in self.config_source if name not in self.confspecs]

    def _overlay(self, src, dest):
        """
        Recursively merge ``src`` into ``dest`` in place. Dicts present on both sides are
        merged; any other value from ``src`` replaces the one in ``dest``.
        """
        for key, value in src.items():
            # If the key is also in the dest and both are dicts, merge them.
            if key in dest and isinstance(value, dict) and isinstance(dest[key], dict):
                self._overlay(value, dest[key])
            else:
                # Otherwise it's either an existing value to be replaced or needs to be added.
                # Copy it so dest never shares mutable objects with src (e.g. frozen_config);
                # a later merge into dest would otherwise write through into src.
                dest[key] = deepcopy(value)

    def validate_bundle(self, bundle_name, conf_spec=None):
        """
        Validate the config bundle called ``bundle_name`` against the corresponding specification
        stored in the ConfigManager. If ``conf_spec`` is supplied, it gets used instead.
        Stores the resulting validated config bundle for later retrieval with
        ``get_config_bundle()``.
        Note that as part of validation, optional keys that are missing will be filled in with
        their default values (see ``DictSpec``).

        Args:
            bundle_name: (str) Name of the config dict to find.
            conf_spec: (ConfigSpecification) Optional config specification to validate against.
        Returns:
            dict: The validated config bundle.
        Raises:
            InvalidConfigError: If the config source fails validation, or a matching config
                specification can't be found.
        """
        if not isinstance(conf_spec, ConfigSpecification):
            if bundle_name not in self.confspecs:
                raise InvalidConfigError(
                    "No ConfigSpecification supplied for bundle: " + bundle_name)
            conf_spec = self.confspecs[bundle_name]
        if bundle_name not in self.config_source:
            raise InvalidConfigError("Config source must contain dict: " + bundle_name)
        # Validate a copy so the raw source is left untouched.
        bundle_source = deepcopy(self.config_source[bundle_name])
        try:
            conf_spec.validate(bundle_source)
        except InvalidConfigError as e:
            # Prefix the error with the bundle name for context.
            e.args = ("Bundle: " + bundle_name,) + e.args
            raise
        self.root_config[bundle_name] = bundle_source
        return self.root_config[bundle_name]

    def validate_bundles(self, bundle_names=None):
        """
        Validate multiple config bundles at once, validating each one with the corresponding
        ConfigSpecification stored in the ConfigManager. See ``validate_bundle()``.

        Args:
            bundle_names: A list of config bundle names to get. If dictionary is supplied, uses
                the values as ConfigSpecifications rather than looking up the ones stored in the
                ConfigManager. If None, will use all bundle names in the config source.
        Returns:
            dict: A dict of the config bundles, with keys matching those passed in
                ``bundle_names``.
        """
        if bundle_names is None:
            bundle_names = self.get_bundle_names()
        config_values = {}
        if isinstance(bundle_names, dict):
            for name, conf_spec in bundle_names.items():
                config_values[name] = self.validate_bundle(name, conf_spec)
        elif isinstance(bundle_names, str):
            # Protection against a single name being passed anyway and otherwise
            # being parsed as letters
            config_values[bundle_names] = self.validate_bundle(bundle_names)
        else:
            for name in bundle_names:
                config_values[name] = self.validate_bundle(name)
        return config_values

    def get_config_bundle(self, bundle_name):
        """
        Get a validated config bundle called ``bundle_name``. If not yet validated, will validate
        the config source against the corresponding config specification stored in the
        ConfigManager (see ``validate_bundle()``).

        Args:
            bundle_name: (str) Name of the config bundle to find.
        Returns:
            dict: The validated config bundle.
        """
        if bundle_name not in self.root_config:
            return self.validate_bundle(bundle_name)
        return self.root_config[bundle_name]

    def get_config_bundles(self, bundle_names=None):
        """
        Get multiple config bundles at once. If not yet validated, each will validate
        their config source against the corresponding config specification stored in the
        ConfigManager (see ``validate_bundle()``).

        Args:
            bundle_names: A list of config bundle names to get. If None, will use all bundle
                names in the config source.
        Returns:
            dict: A dict of the validated config bundles, with keys matching those passed in
                ``bundle_names``.
        """
        if bundle_names is None:
            bundle_names = self.get_bundle_names()
        config_values = {}
        if isinstance(bundle_names, str):
            # Protection against a single name being passed anyway and otherwise
            # being parsed as letters
            config_values[bundle_names] = self.get_config_bundle(bundle_names)
        else:
            for name in bundle_names:
                config_values[name] = self.get_config_bundle(name)
        return config_values

    def get_bundle_names(self):
        """
        Returns a list of config bundle names contained in the source.
        Note that this may include bundles that have not been verified yet. Calling
        `validate_bundles()` or `get_config_bundles()` first will make sure all config source
        bundles are verified.
        """
        return list(self.config_source.keys())

    def dump_toml(self):
        """
        Return the validated config serialised as a TOML string.

        Raises:
            Exception: If the validated bundles do not match the bundles in the config source
                (e.g. some bundles have not been validated yet).
        """
        if self.root_config.keys() != self.config_source.keys():
            raise Exception("Can't dump an unvalidated config table!")
        return toml.dumps(self.root_config)

    def dump_to_file(self, filepath, message=None):
        """
        Write the validated config to ``filepath`` as UTF-8 TOML, optionally appending
        ``message`` as a trailing comment block (see ``gen_comment``).

        The content is generated before the file is opened, so a failure in ``dump_toml()``
        leaves any existing file untouched instead of truncating it.
        """
        content = self.dump_toml()
        if message is not None:
            content = content.rstrip() + gen_comment(message)
        with open(filepath, 'w', encoding='utf-8') as f:
            f.write(content)
def strip_toml_message(string):
    """
    Remove a ``config-spec_message`` comment block (as produced by ``gen_comment``) from
    TOML text.

    A block is a line starting with ``# config-spec_message:`` together with every
    immediately following line that starts with ``#`` and has at least one more character.

    Args:
        string: TOML document text.
    Returns:
        str: ``string`` with any such block removed.
    """
    # Single backslashes only: inside a raw string "\n" must reach the regex engine as the
    # newline escape. Doubling them would make the pattern look for literal backslashes,
    # and it would then never match.
    return re.sub(r"(?m)^# config-spec_message:[^\n]*$\n?(?:^#[^\n]+$\n?)*", '', string)
def update_toml_message(filepath, message):
    """
    Rewrite the TOML file at ``filepath`` in place so that it ends with a
    ``config-spec_message`` comment block built from ``message``.

    Any existing message block is removed with ``strip_toml_message``, trailing whitespace
    is trimmed, and the new block from ``gen_comment`` is appended.

    Args:
        filepath: Path of the TOML file to update.
        message: Message text; may span several lines.
    """
    # TOML documents are UTF-8 by specification; don't depend on the platform's default
    # encoding when reading or rewriting them.
    with open(filepath, 'r+', encoding='utf-8') as f:
        content = strip_toml_message(f.read()).rstrip()
        content += gen_comment(message)
        # Rewrite from the start, then cut off any leftover tail of longer old content.
        f.seek(0)
        f.write(content)
        f.truncate()
def gen_comment(string):
    """
    Format ``string`` as a trailing TOML comment block.

    All ``#`` characters are removed from ``string`` first. The first line is prefixed with
    ``# config-spec_message: `` and every later line with ``# ``. The result begins and
    ends with a newline.
    """
    cleaned_lines = string.replace('#', '').splitlines()
    body = '\n# '.join(cleaned_lines)
    return '\n# config-spec_message: {}\n'.format(body)