You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
307 lines
12 KiB
307 lines
12 KiB
import re
|
|
from copy import deepcopy
|
|
import toml
|
|
from .specification import ConfigSpecification, InvalidConfigError
|
|
|
|
|
|
class ConfigManager:
    """
    Holds a raw config source, the ConfigSpecifications used to validate it, and the
    validated result.

    Attributes are plain dicts keyed by bundle name:

    * ``config_source`` - the values as loaded (after frozen values are re-applied).
    * ``root_config`` - bundles that have passed validation.
    * ``confspecs`` - the ConfigSpecification registered for each bundle.
    * ``frozen_config`` - values re-applied over the source on every load.

    The manager keeps its own copies of any dicts it is given, so callers' data structures
    are never modified by loading, overlaying or freezing.
    """

    # Attributes captured by save_fallback() and restored by fallback().
    _STATE_KEYS = ("config_source", "root_config", "confspecs", "frozen_config")

    def __init__(self):
        self.config_source = {}  # source values
        self.root_config = {}  # validated config
        self.confspecs = {}
        self.frozen_config = {}  # validated config to reapply each load
        self._saved_state = {}

    def save_fallback(self):
        """
        Save the current state of the ConfigManager for future restoration with ``fallback()``.
        Includes the loaded config source, the validated config, any added config specifications,
        and any frozen values.
        """
        for key in self._STATE_KEYS:
            self._saved_state[key] = deepcopy(getattr(self, key))

    def fallback(self):
        """
        Restore the state of the ConfigManager to what it was when ``save_fallback()`` was last
        called. Includes the loaded config source, the validated config, any added config
        specifications, and any frozen values.

        The saved state is copied on restore, so ``fallback()`` can be called repeatedly and
        always returns to the same snapshot.

        Raises:
            Exception: If ``save_fallback()`` has not been called yet.
        """
        if not all(key in self._saved_state for key in self._STATE_KEYS):
            raise Exception("Can't fallback ConfigManager without calling save_fallback() first!")
        for key in self._STATE_KEYS:
            # Copy rather than alias: later in-place merges must not alter the snapshot.
            setattr(self, key, deepcopy(self._saved_state[key]))

    @staticmethod
    def _get_source(source):
        """
        Accept a filepath or opened file representing a TOML file, or a direct dict,
        and return a plain parsed dict.

        A dict argument is deep-copied so the manager never shares structure with the caller.
        """
        if isinstance(source, dict):  # load from dict
            return deepcopy(source)

        if isinstance(source, str):  # load from pathname
            with open(source, 'r') as conf_file:
                return toml.load(conf_file)

        # Anything else is handed to the TOML parser as an open file
        return toml.load(source)

    def load(self, source):
        """
        Load a config source into the ConfigManager, replacing any existing config.

        Args:
            source: Either a dict config to load directly, a filepath to a TOML file,
                or an open file.
        """
        self.config_source = self._get_source(source)
        self._overlay(self.frozen_config, self.config_source)
        # New source, so wipe validated config
        self.root_config = {}

    def load_overlay(self, source):
        """
        Load a config source into the ConfigManager, merging it over the top of any existing
        config. Dicts will be recursively processed with keys being merged and existing values
        being replaced by the new source. This includes lists, which will be treated as any other
        value and completely replaced.

        Args:
            source: Either the root dict of a data structure to load directly, a filepath to a TOML
                file, or an open TOML file.
        """
        self._overlay(self._get_source(source), self.config_source)
        # Frozen values always win, so re-apply them after the new source
        self._overlay(self.frozen_config, self.config_source)
        # New source, so wipe validated config
        self.root_config = {}

    def freeze_value(self, bundle_name, *field_names):
        """
        Freeze the given config field so that subsequent calls to ``load`` and ``load_overlay``
        cannot change it. Can only be used for dict values or dict values nested in parent dicts.

        The value is read from the validated config, so the bundle must have been validated
        first.

        Args:
            bundle_name: The name of the bundle to look for the field in.
            *field_names: a series of strings that locate the config field, either a single
                key or series of nested keys.

        Raises:
            KeyError: If the field cannot be found in the validated config. In that case the
                frozen config is left untouched.
        """
        # Bundle names are really no different from any other nested dict
        names = (bundle_name,) + field_names

        # Resolve the value first so a bad path fails before frozen_config is modified
        target_field = self.root_config
        for name in names:
            target_field = target_field[name]

        # Walk the parent names, creating frozen_config nested dicts as necessary
        frozen_value = self.frozen_config
        for name in names[:-1]:
            frozen_value = frozen_value.setdefault(name, {})

        # Store a copy so later edits to the validated config can't alter the frozen value
        frozen_value[names[-1]] = deepcopy(target_field)

    def add_confspec(self, bundle_name, confspec):
        """
        Stores a ConfigSpecification for future use when validating the corresponding config bundle

        Args:
            bundle_name (str) : The name to store the config specification under.
            confspec (ConfigSpecification): The populated ConfigSpecification to store.
        """
        self.confspecs[bundle_name] = confspec

    def add_confspecs(self, confspecs):
        """
        Stores multiple ConfigSpecifications at once for future use when validating the
        corresponding config bundles

        Args:
            confspecs : A dict of populated ConfigSpecifications to store, using their keys as
                names.
        """
        self.confspecs.update(confspecs)

    def list_missing_confspecs(self):
        """
        Returns a list of config bundle names that do not have a corresponding ConfigSpecification
        stored in the ConfigManager.
        """
        return list(self.config_source.keys() - self.confspecs.keys())

    def _overlay(self, src, dest):
        """
        Recursively merge ``src`` into ``dest`` in place.

        Where a key holds a dict on both sides the two dicts are merged; otherwise the value
        from ``src`` replaces (or is added to) ``dest``. Values are copied into ``dest`` so the
        two structures never share nested containers.
        """
        for key, value in src.items():
            # If the key is also in the dest and both are dicts, merge them.
            if key in dest and isinstance(value, dict) and isinstance(dest[key], dict):
                self._overlay(value, dest[key])
            else:
                # Otherwise it's either an existing value to be replaced or needs to be added.
                dest[key] = deepcopy(value)

    def validate_bundle(self, bundle_name, conf_spec=None):
        """
        Validate the config bundle called ``bundle_name`` against the corresponding specification
        stored in the ConfigManager. If ``conf_spec`` is supplied, it gets used instead.

        Stores the resulting validated config bundle for later retrieval with
        ``get_config_bundle()``.

        Note that as part of validation, optional keys that are missing will be filled in with
        their default values (see ``DictSpec``).

        Args:
            bundle_name: (str) Name of the config dict to find.
            conf_spec: (ConfigSpecification) Optional config specification to validate against.

        Returns:
            dict: The validated config bundle.

        Raises:
            InvalidConfigError: If the config source fails validation, or a matching config
                specification can't be found.
        """
        if not isinstance(conf_spec, ConfigSpecification):
            if bundle_name not in self.confspecs:
                raise InvalidConfigError(
                    "No ConfigSpecification supplied for bundle: " + bundle_name)
            conf_spec = self.confspecs[bundle_name]

        if bundle_name not in self.config_source:
            raise InvalidConfigError("Config source must contain dict: " + bundle_name)

        # Validate a copy so the raw source is left exactly as loaded
        bundle_source = deepcopy(self.config_source[bundle_name])
        try:
            conf_spec.validate(bundle_source)
        except InvalidConfigError as e:
            # Prefix the bundle name so the error says where the problem is
            e.args = ("Bundle: " + bundle_name,) + e.args
            raise

        self.root_config[bundle_name] = bundle_source
        return self.root_config[bundle_name]

    def validate_bundles(self, bundle_names=None):
        """
        Validate multiple config bundles at once, validating each one with the corresponding
        ConfigSpecification stored in the ConfigManager. See ``validate_bundle()``.

        Args:
            bundle_names: A list of config bundle names to get. If dictionary is supplied, uses
                the values as ConfigSpecifications rather than looking up the ones stored in the
                ConfigManager. If None, will use all bundle names in the config source.

        Returns:
            dict: A dict of the config bundles, with keys matching those passed in
                ``bundle_names``.
        """
        if bundle_names is None:
            bundle_names = self.get_bundle_names()

        if isinstance(bundle_names, dict):
            return {name: self.validate_bundle(name, conf_spec)
                    for name, conf_spec in bundle_names.items()}

        if isinstance(bundle_names, str):
            # Protection against a single name being passed anyway and otherwise
            # being parsed as letters
            return {bundle_names: self.validate_bundle(bundle_names)}

        return {name: self.validate_bundle(name) for name in bundle_names}

    def get_config_bundle(self, bundle_name):
        """
        Get a validated config bundle called ``bundle_name``. If not yet validated, will validate
        the config source against the corresponding config specification stored in the
        ConfigManager (see ``validate_bundle()``).

        Args:
            bundle_name: (str) Name of the config bundle to find.

        Returns:
            dict: The validated config bundle.
        """
        if bundle_name not in self.root_config:
            return self.validate_bundle(bundle_name)

        return self.root_config[bundle_name]

    def get_config_bundles(self, bundle_names=None):
        """
        Get multiple config bundles at once. If not yet validated, each will validate
        their config source against the corresponding config specification stored in the
        ConfigManager (see ``validate_bundle()``).

        Args:
            bundle_names: A list of config bundle names to get. If None, will use all bundle
                names in the config source.

        Returns:
            dict: A dict of the validated config bundles, with keys matching those passed in
                ``bundle_names``.
        """
        if bundle_names is None:
            bundle_names = self.get_bundle_names()

        if isinstance(bundle_names, str):
            # Protection against a single name being passed anyway and otherwise
            # being parsed as letters
            return {bundle_names: self.get_config_bundle(bundle_names)}

        return {name: self.get_config_bundle(name) for name in bundle_names}

    def get_bundle_names(self):
        """
        Returns a list of config bundle names contained in the source.

        Note that this may include bundles that have not been verified yet. Calling
        `validate_bundles()` or `get_config_bundles()` first will make sure all config source
        bundles are verified.
        """
        return list(self.config_source.keys())

    def dump_toml(self):
        """
        Serialise the validated config to a TOML string.

        Raises:
            Exception: If the set of validated bundles does not match the set of bundles in
                the config source, i.e. some bundle has not been validated.
        """
        if self.root_config.keys() != self.config_source.keys():
            raise Exception("Can't dump an unvalidated config table!")
        return toml.dumps(self.root_config)

    def dump_to_file(self, filepath, message=None):
        """
        Write the validated config to ``filepath`` as TOML, replacing any existing content.

        Args:
            filepath: Path of the file to write.
            message: Optional text appended to the end of the file as a comment block
                (see ``gen_comment()``).

        Raises:
            Exception: If the config is not fully validated (see ``dump_toml()``). The target
                file is not touched in that case.
        """
        # Build the content before opening the file: opening for write truncates it, so a
        # failure in dump_toml() must happen first or an existing file would be emptied.
        content = self.dump_toml()
        if message is not None:
            content = content.rstrip() + gen_comment(message)
        with open(filepath, 'w') as f:
            f.write(content)
|
|
|
|
|
|
def strip_toml_message(string):
    """
    Remove any trailing-message comment block from TOML text.

    A message block is a comment line starting with ``# config-spec_message:`` (the marker
    written by ``gen_comment()``) or the older ``# shepherd_message:`` marker, together with
    every non-empty ``#`` comment line that directly follows it.

    Args:
        string: The TOML text to clean.

    Returns:
        str: The text with all message blocks removed. Other comments are left alone.
    """
    # Both markers are accepted so that files written under either name get cleaned.
    pattern = r"^# (?:shepherd|config-spec)_message:[^\n]*$\n?(?:^#[^\n]+$\n?)*"
    return re.sub(pattern, '', string, flags=re.MULTILINE)
|
|
|
|
|
|
def update_toml_message(filepath, message):
    """
    Replace the message comment at the end of the TOML file at ``filepath``.

    The file is read, passed through ``strip_toml_message()``, stripped of trailing
    whitespace, and then rewritten in place with the comment produced by
    ``gen_comment(message)`` appended.

    Args:
        filepath: Path of the existing file to update.
        message: Text of the new message.
    """
    with open(filepath, 'r+') as toml_file:
        existing = toml_file.read()
        body = strip_toml_message(existing).rstrip()
        updated = body + gen_comment(message)
        # Rewind and overwrite from the start, then cut off whatever is left of the old text
        toml_file.seek(0)
        toml_file.write(updated)
        toml_file.truncate()
|
|
|
|
|
|
def gen_comment(string):
    """
    Format ``string`` as a comment block introduced by ``# config-spec_message:``.

    Every ``#`` character is removed from the text, and each line after the first is
    prefixed with ``# ``. The result starts and ends with a newline.
    """
    message_lines = string.replace('#', '').splitlines()
    commented = '\n# '.join(message_lines)
    return '\n# config-spec_message: ' + commented + '\n'
|