import threading
import secrets
from types import SimpleNamespace
from pathlib import Path
from urllib.parse import urlparse, urlunparse, urljoin
from hashlib import blake2b
import time
import logging

import toml
import requests
from configspec import *

import statesman


# Namespace of types intended for server-side use.
def get_export():
    from . import plugin
    export = SimpleNamespace()
    export.InterfaceCall = plugin.InterfaceCall
    return export


log = logging.getLogger("shepherd.agent.control")

_control_update_required = threading.Condition()


def _update_required_callback():
    with _control_update_required:
        _control_update_required.notify()


def control_confspec():
    """ Returns the control config specification """
    confspec = ConfigSpecification()
    confspec.add_spec("server", StringSpec())
    confspec.add_spec("intro_key", StringSpec())
    return confspec


class CoreUpdateState():
    def __init__(self, cmd_reader, cmd_result_writer):
        """ Control update handler for the `/update` core endpoint. """
        self.topic_bundle = statesman.TopicBundle({
            'status': statesman.StateWriter(),
            'config-spec': statesman.StateWriter(),
            'device-config': statesman.StateWriter(),
            'applied-config': statesman.StateWriter(),
            'control-commands': cmd_reader,
            'command-results': cmd_result_writer})
        self.topic_bundle.set_update_required_callback(_update_required_callback)

    def set_static_state(self, local_config, applied_config, confspec):
        # These should all effectively be static
        self.topic_bundle['device-config'].set_state(local_config)
        self.topic_bundle['applied-config'].set_state(applied_config)
        self.topic_bundle['config-spec'].set_state(confspec)

    def set_status(self, status_dict):
        self.topic_bundle['status'].set_state(status_dict)


class CommandRunner():
    def __init__(self, interface_functions):
        self.cmd_reader = statesman.SequenceReader(
            new_message_callback=self.on_new_command_message)
        self.cmd_result_writer = statesman.SequenceWriter()
        self._functions = interface_functions
        self.current_commands = {}

    def on_new_command_message(self, message):
        # This should be a single list, where the first value is the command ID and the second
        # value is a plugin.FunctionCall
        commandID = message[0]
        command_call = message[1]
        command_thread = threading.Thread(target=self._process_command,
                                          args=(commandID, command_call))
        command_thread.start()

    def _process_command(self, commandID, command_call):
        if commandID in self.current_commands:
            raise ValueError(F"Already running a command with ID {commandID}")
        self.current_commands[commandID] = threading.current_thread()
        try:
            command_call.resolve(self._functions)
            result = command_call.call()
            self.cmd_result_writer.add_message([commandID, result])
        finally:
            self.current_commands.pop(commandID)


class PluginUpdateState():
    def __init__(self):
        self.topic_bundle = statesman.TopicBundle()
        # config-spec should be static, but isn't known yet when this is created
        self.topic_bundle.add('status', statesman.StateWriter())
        self.topic_bundle.add('config-spec', statesman.StateWriter())
        self.topic_bundle.add('command-spec', statesman.StateWriter())
        # Why is config split out into plugins? Just like the device config and applied config,
        # it's only loaded once at the start. Is this purely because it's easy to get at from the
        # PluginInterface where this object is created?
        self.topic_bundle.set_update_required_callback(_update_required_callback)

    def set_status(self, status_dict):
        self.topic_bundle['status'].set_state(status_dict)

    def set_confspec(self, config_spec):
        self.topic_bundle['config-spec'].set_state(config_spec)

    def set_commandspec(self, command_spec):
        self.topic_bundle['command-spec'].set_state(command_spec)


def clean_https_url(dirty_url):
    """
    Take a URL with or without a leading "https://" scheme and return one that has it.
    An existing "http" scheme is upgraded to "https".
    """
    # Some weirdness with URL parsing means that by default most urls (like www.google.com)
    # get treated as relative
    # https://stackoverflow.com/questions/53816559/python-3-netloc-value-in-urllib-parse-is-empty-if-url-doesnt-have
    if "//" not in dirty_url:
        dirty_url = "//"+dirty_url
    return urlunparse(urlparse(dirty_url)._replace(scheme="https"))


def load_device_identity(root_dir):
    """
    Attempt to load the device identity from the shepherd.identity file.
    Raises an exception if this fails.

    Returns a tuple of (device_secret, device_id).
    """
    identity_filepath = Path(root_dir, 'shepherd.identity')
    if not identity_filepath.exists():
        log.warning(F"{identity_filepath} file does not exist")
        raise FileNotFoundError()

    with identity_filepath.open() as identity_file:
        identity_dict = toml.load(identity_file)

    dev_secret = identity_dict["device_secret"]
    dev_secret_bytes = bytes.fromhex(dev_secret)
    if len(dev_secret_bytes) != 16:
        log.error(F"Device secret loaded from file {identity_filepath} does not contain the "
                  "required 16 bytes")
        raise ValueError()

    secret_hash = blake2b(dev_secret_bytes, digest_size=16).hexdigest()
    dev_id = secret_hash[:8]
    log.info(F"Loaded device identity. ID: {dev_id}")
    return (dev_secret, dev_id)


def generate_device_identity(root_dir):
    """
    Generate a new device identity and save it to the shepherd.identity file.

    Returns a tuple of (device_secret, device_id).
    """
    dev_secret = secrets.token_hex(16)
    identity_dict = {}
    identity_dict['device_secret'] = dev_secret
    identity_filepath = Path(root_dir, 'shepherd.identity')
    with identity_filepath.open('w+') as identity_file:
        toml.dump(identity_dict, identity_file)

    dev_secret_bytes = bytes.fromhex(dev_secret)
    secret_hash = blake2b(dev_secret_bytes, digest_size=16).hexdigest()
    dev_id = secret_hash[:8]
    log.info(F"Generated new device identity. ID: {dev_id}")
    return (dev_secret, dev_id)


_update_thread_init_done = threading.Event()
_stop_event = threading.Event()


def stop():
    _stop_event.set()
    _update_required_callback()
    log.info("Control thread stop requested.")


def start_control(config, root_dir, core_update_state, plugin_update_states):
    """
    Start the Control update thread and initialise the Shepherd Control systems.
""" _stop_event.clear() _update_thread_init_done.clear() control_thread = threading.Thread(target=_control_update_loop, args=( config, root_dir, core_update_state, plugin_update_states)) control_thread.start() # Wait for init so our log makes sense _update_thread_init_done.wait() return control_thread def _control_update_loop(config, root_dir, core_update_state, plugin_update_states): control_api_url = urljoin(clean_https_url(config["server"]), "/agent") log.info(F"Control server API endpoint is {control_api_url}") intro_key = config["intro_key"] log.info(F"Using intro key: {intro_key}") try: device_secret, device_id = load_device_identity(root_dir) except Exception: log.warning("Could not load device identity from shepherd.identity file") device_secret, device_id = generate_device_identity(root_dir) _update_thread_init_done.set() update_rate_limiter = SmoothTokenBucketLimit(10, 10*60, 3, time.monotonic()) session = requests.Session() # r=session.post('https://api.shepherd.test/agent/update') while True: # Spin here until something needs updating with _control_update_required: new_endpoint_updates = {} # a dict of url:topic_bundle pairs while True: if core_update_state.topic_bundle.is_update_required(): new_endpoint_updates['/update'] = core_update_state.topic_bundle for plugin_name, state in plugin_update_states.items(): if state.topic_bundle.is_update_required(): new_endpoint_updates[f"/pluginupdate/{plugin_name}"] = state.topic_bundle if (len(new_endpoint_updates) > 0) or _stop_event.is_set(): break _control_update_required.wait() for endpoint, topic_bundle in new_endpoint_updates.items(): try: r = session.post(control_api_url+endpoint, json=topic_bundle.get_payload(), auth=(device_secret, intro_key)) if r.status_code == requests.codes['conflict']: # Server replies with this when trying to add our device ID and failing # due to it already existing (device secret hash is a mismatch). We need to # regenerate our ID log.info(F"Control server has indicated that device ID {device_id} already" " exists. Generating new one...") device_secret, device_id = generate_device_identity(root_dir) elif r.status_code == requests.codes['ok']: topic_bundle.process_message(r.json()) except requests.exceptions.RequestException: log.exception("Failed to make Shepherd Control request") if _stop_event.is_set(): # Breaking here is a clean way of killing any delay and allowing a final update before # the thread ends. log.warning("Control thread stopping...") _stop_event.clear() break delay = update_rate_limiter.new_event(time.monotonic()) _stop_event.wait(delay) _update_thread_init_done.clear() def get_cached_config(config_dir): return {} def clear_cached_config(config_dir): pass class SmoothTokenBucketLimit(): """ Event rate limiter implementing a modified Token Bucket algorithm. Delay returned ramps up as the bucket empties. """ def __init__(self, allowed_events, period, allowed_burst, initial_time): self.allowed_events = allowed_events self.period = period self.allowed_burst = allowed_burst self.last_token_timestamp = initial_time self.tokens = allowed_events self._is_saturated = False def new_event(self, time_now): """ Register a new event for the rate limiter. Return a required delay to ignore future events for in seconds. Conceptually, the "token" we're grabbing here is for the _next_ event. 
""" if self.tokens < self.allowed_events: time_since_last_token = time_now - self.last_token_timestamp tokens_added = int(time_since_last_token/(self.period/self.allowed_events)) self.tokens = self.tokens + tokens_added self.last_token_timestamp = self.last_token_timestamp + \ (self.period/self.allowed_events)*tokens_added if self.tokens >= self.allowed_events: self.last_token_timestamp = time_now if self.tokens > 0: self.tokens = self.tokens - 1 # Add a delay that ramps from 0 when the bucket is allowed_burst from full to p/x when # it is empty ramp_token_count = self.allowed_events-self.allowed_burst if self.tokens > ramp_token_count: delay = 0 else: delay = ((self.period/self.allowed_events)/ramp_token_count) * \ (ramp_token_count-self.tokens) self._is_saturated = False else: delay = (self.period/self.allowed_events) - (time_now-self.last_token_timestamp) self.last_token_timestamp = time_now+delay # This delay makes is set to account for the next token that would otherwise be added, # without relying on the returned delay _actually_ occurring exactly. self._is_saturated = True return delay def is_saturated(self): """ Returns true if the rate limiter delay is at it's maximum value """ return self._is_saturated