You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
shepherd-agent/shepherd/config.py

372 lines
13 KiB

"""
Configuration management module. Enables configuration to be validated against
requirement definitions before being loaded and used.

Both raw config data structures and TOML files are supported. Config data must
start with a root dict containing named "config bundles". These are intended to
refer to different modular parts of the application needing configuration, and
the config data structure must contain at least one.

Each config bundle itself needs to have a dict at the root, so in practice a
minimal TOML config file would look like::

    [myapp]
    config_thingy_a = "foooooo!"
    important_number = 8237

This would resolve to a config bundle named "myapp" that results in the dict::

    {"config_thingy_a": "foooooo!", "important_number": 8237}

Root items that are not dicts are not supported. For instance, both of the
following TOML files would fail::

    [[myapp]]
    important_number = 8237

    [[myapp]]
    another_important_number = 2963

(the root object in the bundle is a list)

::

    root_thingy = 46

(the root object in the config is a single value)
"""
import copy
import re
from abc import ABC, abstractmethod

import toml

from .freezedry import freezedryable, rehydrate
class InvalidConfigError(Exception):
    """Raised when config data does not satisfy its config definition.

    As the error propagates outward through nested validators, context such
    as ``"Key: ..."``, ``"List index: ..."`` or ``"Module: ..."`` is prepended
    to ``args``, so ``args`` reads from the outermost location inward.
    """
# TOML tables map directly to Python dicts and TOML arrays to Python lists,
# so the definitions below use "dict" and "list" for those TOML terms.
class _ConfigDefinition(ABC):
    """Abstract base class shared by every config definition type.

    Args:
        default: Value stored with the definition; used to fill in omitted
            optional keys and when building templates.
        optional: Whether the key may be left out of the config.
        helptext: Help text stored alongside the definition (not used by
            validation in this module).
    """

    def __init__(self, default=None, optional=False, helptext=""):
        self.default, self.optional, self.helptext = default, optional, helptext

    @abstractmethod
    def validate(self, value):
        """Check ``value``, raising ``InvalidConfigError`` if it is unacceptable."""
@freezedryable
class BoolDef(_ConfigDefinition):
    """Definition for a single boolean config value."""

    def __init__(self, default=None, optional=False, helptext=""):
        super().__init__(default=default, optional=optional, helptext=helptext)

    def validate(self, value):
        """Raise ``InvalidConfigError`` unless ``value`` is a ``bool``."""
        if isinstance(value, bool):
            return
        raise InvalidConfigError("Config value must be a boolean")
@freezedryable
class IntDef(_ConfigDefinition):
    """Definition for a single integer config value with optional bounds.

    Args:
        default: Value stored with the definition (used for omitted optional
            keys and templates).
        minval: Inclusive lower bound, or ``None`` for no lower bound.
        maxval: Inclusive upper bound, or ``None`` for no upper bound.
        optional: Whether the key may be omitted.
        helptext: Help text stored with the definition.
    """

    def __init__(self, default=None, minval=None, maxval=None,
                 optional=False, helptext=""):
        super().__init__(default, optional, helptext)
        self.minval = minval
        self.maxval = maxval

    def validate(self, value):
        """Raise ``InvalidConfigError`` unless ``value`` is an in-range integer.

        ``bool`` is rejected even though it subclasses ``int`` in Python:
        TOML distinguishes ``true``/``false`` from integers, so a boolean in
        an integer field is a config mistake.
        """
        if isinstance(value, bool) or not isinstance(value, int):
            raise InvalidConfigError("Config value must be an integer")
        if self.minval is not None and value < self.minval:
            raise InvalidConfigError("Config value must be >= " +
                                     str(self.minval))
        if self.maxval is not None and value > self.maxval:
            raise InvalidConfigError("Config value must be <= " +
                                     str(self.maxval))
@freezedryable
class StringDef(_ConfigDefinition):
    """Definition for a single string config value with optional length limits.

    Args:
        default: Value stored with the definition; the empty string if not given.
        minlength: Inclusive minimum length, or ``None`` for no minimum.
        maxlength: Inclusive maximum length, or ``None`` for no maximum.
        optional: Whether the key may be omitted.
        helptext: Help text stored with the definition.
    """

    def __init__(self, default="", minlength=None, maxlength=None,
                 optional=False, helptext=""):
        super().__init__(default, optional, helptext)
        self.minlength, self.maxlength = minlength, maxlength

    def validate(self, value):
        """Raise ``InvalidConfigError`` unless ``value`` is a string within the length limits."""
        if not isinstance(value, str):
            raise InvalidConfigError(f"Config value must be a string and is {value}")
        length = len(value)
        too_short = self.minlength is not None and length < self.minlength
        if too_short:
            raise InvalidConfigError("Config string length must be >= " + str(self.minlength))
        too_long = self.maxlength is not None and length > self.maxlength
        if too_long:
            raise InvalidConfigError("Config string length must be <= " + str(self.maxlength))
@freezedryable
class DictDef(_ConfigDefinition):
    """Definition for a dict (TOML table) whose keys have their own definitions.

    Child definitions are registered with :meth:`add_def` and kept in
    ``def_dict``.
    """

    def __init__(self, default=None, optional=False, helptext=""):
        super().__init__(default, optional, helptext)
        # Child key name -> config definition for that key.
        self.def_dict = {}

    def add_def(self, name, newdef):
        """Register ``newdef`` as the definition for key ``name`` and return it.

        Raises:
            TypeError: If ``newdef`` is not a config definition instance or
                ``name`` is not a string.
        """
        if not isinstance(newdef, _ConfigDefinition):
            raise TypeError("Config definition must be an instance of a "
                            "ConfigDefinition subclass")
        if not isinstance(name, str):
            raise TypeError("Config definition name must be a string")
        self.def_dict[name] = newdef
        return newdef

    def validate(self, value_dict):  # pylint: disable=W0221
        """Validate ``value_dict`` against the child definitions, in place.

        Omitted optional keys are filled in with a deep copy of their
        definition's default, so the config never shares a mutable object
        with the definition. An optional key whose value is ``None`` is
        treated as unset and is not validated. TOML has no null, so ``None``
        normally comes from a filled-in default, and it is the default
        ``default`` of BoolDef, IntDef and DictDef. This also keeps repeated
        validation of the same dict working.

        Raises:
            InvalidConfigError: If ``value_dict`` is not a dict, a required key
                is missing, an unknown key is present, or a value fails its own
                definition (``"Key: <name>"`` is prepended to ``args``).
        """
        if not isinstance(value_dict, dict):
            raise InvalidConfigError("Config value must be a dict")
        missing = self.def_dict.keys() - value_dict.keys()
        unknown = value_dict.keys() - self.def_dict.keys()
        # Report structural problems before value_dict is modified.
        for missing_key in missing:
            if not self.def_dict[missing_key].optional:
                raise InvalidConfigError(f"Dict must contain key: {missing_key}")
        if unknown:
            raise InvalidConfigError(f"Dict contains unknown key: {next(iter(unknown))}")
        for missing_key in missing:
            value_dict[missing_key] = copy.deepcopy(self.def_dict[missing_key].default)
        for key, value in value_dict.items():
            confdef = self.def_dict[key]
            if value is None and confdef.optional:
                continue
            try:
                confdef.validate(value)
            except InvalidConfigError as e:
                e.args = (f"Key: {key}",) + e.args
                raise

    def get_template(self, include_optional=False):
        """
        Return a config dict with the minimum structure required for this ConfigDefinition.
        Default values will be included, though not all required fields will necessarily have
        defaults that successfully validate. Defaults are deep-copied, so filling in the
        template never modifies the definitions.

        Args:
            include_optional: If set true, will include *all* config fields, not just the
                required ones.

        Returns:
            Dict containing the structure that should be passed back in (with values) to comply
            with this ConfigDefinition.
        """
        template = {}
        for key, confdef in self.def_dict.items():
            if confdef.optional and not include_optional:
                continue
            if hasattr(confdef, "get_template"):
                # Nested dict / list definitions build their own structure.
                template[key] = confdef.get_template(include_optional)
            else:
                template[key] = copy.deepcopy(confdef.default)
        return template
class _ListDefMixin():
    """Mixin that turns a single-value definition into a list-of-values definition.

    It must precede the element definition in the bases (e.g.
    ``class IntListDef(_ListDefMixin, IntDef)``) so that ``super()`` resolves
    to the element definition.
    """

    def validate(self, value_list):
        """Require a list, then validate each element with the element definition.

        On failure, ``"List index: <i>"`` is prepended to the error's ``args``.
        """
        if not isinstance(value_list, list):
            raise InvalidConfigError("Config item must be a list")
        for position, item in enumerate(value_list):
            try:
                super().validate(item)
            except InvalidConfigError as err:
                err.args = ("List index: " + str(position),) + err.args
                raise

    def get_template(self, include_optional=False):
        """Return a one-element list showing the expected element structure."""
        parent = super()
        if not hasattr(parent, "get_template"):
            # Scalar element definitions: show the stored default as the element.
            return [self.default]
        return [parent.get_template(include_optional)]
@freezedryable
class BoolListDef(_ListDefMixin, BoolDef):
    """List of booleans; each element is validated as a ``BoolDef`` value."""
@freezedryable
class IntListDef(_ListDefMixin, IntDef):
    """List of integers; each element is validated as an ``IntDef`` value."""
@freezedryable
class StringListDef(_ListDefMixin, StringDef):
    """List of strings; each element is validated as a ``StringDef`` value."""
@freezedryable
class DictListDef(_ListDefMixin, DictDef):
    """List of dicts (TOML array of tables); each element is validated as a ``DictDef`` value."""
@freezedryable
class ConfDefinition(DictDef):
    """Definition for an entire config bundle (the root dict of a bundle).

    Adds nothing to ``DictDef``; the distinct type marks bundle-level definitions.
    """
class ConfigManager():
    """Holds a root config dict of named bundles plus the definitions that validate them.

    ``root_config`` maps bundle names to bundle dicts. ``confdefs`` maps bundle
    names to the definitions used by :meth:`get_config_bundle`.
    """

    def __init__(self):
        self.root_config = {}
        self.confdefs = {}

    @staticmethod
    def _read_source(source):
        """Return the root dict for a dict, a TOML file path, or an open TOML file."""
        if isinstance(source, dict):  # load from dict (used as-is, not copied)
            return source
        if isinstance(source, str):  # load from pathname
            # The TOML spec mandates UTF-8, independent of the platform locale.
            with open(source, 'r', encoding='utf-8') as conf_file:
                return toml.load(conf_file)
        return toml.load(source)  # load from an open file

    def load(self, source):
        """
        Load a config source into the ConfigManager, replacing any existing config.

        Args:
            source: Either a dict config to load directly (stored by reference, not
                copied), a filepath to a UTF-8 TOML file, or an open file.
        """
        self.root_config = self._read_source(source)

    def load_overlay(self, source):
        """
        Load a config source into the ConfigManager, merging it over the top of any existing
        config. Dicts will be recursively processed with keys being merged and existing values
        being replaced by the new source. This includes lists, which will be treated as any other
        value and completely replaced. Merged-in values are deep-copied, so ``source`` itself is
        never modified by later overlays or validation.

        Args:
            source: Either the root dict of a data structure to load directly, a filepath to a
                UTF-8 TOML file, or an open TOML file.
        """
        self._overlay(self._read_source(source), self.root_config)

    def add_confdef(self, bundle_name, confdef):
        """
        Stores a ConfigDefinition for future use when validating the corresponding config bundle.

        Args:
            bundle_name (str): The name to store the config definition under.
            confdef (ConfigDefinition): The populated ConfigDefinition to store.
        """
        self.confdefs[bundle_name] = confdef

    def add_confdefs(self, confdefs):
        """
        Stores multiple ConfigDefinitions at once for future use when validating the
        corresponding config bundles.

        Args:
            confdefs: A dict of populated ConfigDefinitions to store, using their keys as names.
        """
        self.confdefs.update(confdefs)

    def list_missing_confdefs(self):
        """
        Returns a list of config bundle names that do not have a corresponding ConfigDefinition
        stored in the ConfigManager.
        """
        return list(self.root_config.keys() - self.confdefs.keys())

    def _overlay(self, src, dest):
        """Recursively merge ``src`` into ``dest`` in place.

        Dicts present on both sides are merged key by key. Anything else from
        ``src`` (including lists) replaces the value in ``dest``. Replacements are
        deep-copied so that ``dest`` never aliases objects owned by ``src``.
        """
        for key, value in src.items():
            existing = dest.get(key)
            if isinstance(value, dict) and isinstance(existing, dict):
                self._overlay(value, existing)
            else:
                dest[key] = copy.deepcopy(value)

    def get_config_bundle(self, bundle_name, conf_def=None):
        """
        Get a config bundle called ``bundle_name`` and validate it against the corresponding
        config definition stored in the ConfigManager. If ``conf_def`` is supplied, it gets used
        instead. Returns a validated config bundle dict.

        Note that as part of validation, optional keys that are missing will be filled in with
        their default values (see ``DictDef``), modifying the stored bundle in place.

        Args:
            bundle_name (str): Name of the config bundle to find.
            conf_def (DictDef): Optional definition to validate against. Any ``DictDef``
                (including ``ConfDefinition``) is accepted; other values are ignored and the
                stored definition is used.

        Returns:
            The validated bundle dict (the same object held in ``root_config``).

        Raises:
            KeyError: If no usable ``conf_def`` is given and none is stored for ``bundle_name``.
            InvalidConfigError: If the bundle is missing or fails validation
                (``"Module: <name>"`` is prepended to ``args``).
        """
        if not isinstance(conf_def, DictDef):
            conf_def = self.confdefs[bundle_name]
        if bundle_name not in self.root_config:
            raise InvalidConfigError(
                "Config must contain dict: " + bundle_name)
        bundle = self.root_config[bundle_name]
        try:
            conf_def.validate(bundle)
        except InvalidConfigError as e:
            e.args = ("Module: " + bundle_name,) + e.args
            raise
        return bundle

    def get_config_bundles(self, bundle_names):
        """
        Get multiple config bundles from the root dict at once, validating each one with the
        corresponding confdef stored in the ConfigManager.

        Args:
            bundle_names: A list of config bundle names to get. If a dict is supplied, its values
                are used as the ConfigDefinitions instead of looking up stored ones.

        Returns:
            A dict of config dicts, with keys matching those passed in ``bundle_names``.
        """
        if isinstance(bundle_names, dict):
            return {name: self.get_config_bundle(name, conf_def)
                    for name, conf_def in bundle_names.items()}
        return {name: self.get_config_bundle(name) for name in bundle_names}

    def get_bundle_names(self):
        """
        Returns a list of names of top level config bundles.
        """
        return list(self.root_config)

    def dump_toml(self):
        """Return the current root config serialized as a TOML string."""
        return toml.dumps(self.root_config)

    def dump_to_file(self, filepath, message=None):
        """Write the current config to ``filepath`` as UTF-8 TOML.

        If ``message`` is given, it is appended as a ``# shepherd_message:``
        comment block. The content is rendered before the file is opened, so a
        serialization error leaves any existing file untouched.
        """
        content = self.dump_toml()
        if message is not None:
            content = content.rstrip() + gen_comment(message)
        with open(filepath, 'w', encoding='utf-8') as f:
            f.write(content)
def strip_toml_message(string):
    """Remove ``# shepherd_message:`` comment blocks from TOML text.

    A block is a line starting with ``# shepherd_message:``, plus each
    immediately following line that starts with ``#`` followed by at least
    one more character.

    Args:
        string: TOML document text.

    Returns:
        ``string`` with every such block removed.
    """
    # Library code: no debug output to stdout.
    return re.sub(r"(?m)^#\ shepherd_message:[^\n]*$\n?(?:^#[^\n]+$\n?)*",
                  '', string)
def update_toml_message(filepath, message):
    """Replace the ``# shepherd_message:`` comment block of a TOML file in place.

    Any existing block is removed with ``strip_toml_message``, trailing
    whitespace is trimmed, and a new block built by ``gen_comment(message)``
    is appended. The file is read and rewritten as UTF-8, the encoding the
    TOML spec requires.
    """
    with open(filepath, 'r+', encoding='utf-8') as f:
        content = strip_toml_message(f.read()).rstrip() + gen_comment(message)
        f.seek(0)
        f.write(content)
        # Drop leftover bytes if the new content is shorter than the old.
        f.truncate()
def gen_comment(string):
    """Format ``string`` as a ``# shepherd_message:`` TOML comment block.

    ``#`` characters are removed from the message. The first line is prefixed
    with ``# shepherd_message: `` and every later line with ``# ``. The result
    starts and ends with a newline so it can be appended after trimmed content.
    """
    message_lines = string.replace('#', '').splitlines()
    body = '\n# '.join(message_lines)
    return '\n# shepherd_message: ' + body + '\n'