"""
Shepherd Control client for the agent.

Runs a background thread that posts statesman topic bundles (core agent state and per-plugin
state) to the Shepherd Control server whenever they report that an update is required, and
manages the device identity used to authenticate those requests.
"""
import threading
import secrets
from pathlib import Path
from urllib.parse import urlparse, urlunparse, urljoin
from hashlib import blake2b
import time
import logging

import toml
import requests

from configspec import *
import statesman

log = logging.getLogger("shepherd.agent.control")

# Number of random bytes in a device secret (stored in the identity file as hex).
_DEVICE_SECRET_BYTES = 16

# Seconds to wait on a single Shepherd Control request before giving up. Without a timeout a
# stalled server would block the control thread forever (and stop() could never take effect).
_CONTROL_REQUEST_TIMEOUT = 30

# Notified whenever a topic bundle reports it needs sending, or when a stop is requested, to
# wake the control thread.
_control_update_required = threading.Condition()


def _update_required_callback():
    """Wake the control update thread so it re-checks for pending updates."""
    with _control_update_required:
        _control_update_required.notify()


def control_confspec():
    """
    Returns the control config specification
    """
    confspec = ConfigSpecification(optional=True)
    confspec.add_spec("server", StringSpec())
    confspec.add_spec("intro_key", StringSpec())
    return confspec


class CoreUpdateState:
    """
    Topic bundle for the core agent, posted to the control server's '/update' endpoint.

    Carries the 'status', 'config-spec', 'device-config' and 'applied-config' topics. Any topic
    change triggers the control thread via _update_required_callback.
    """

    def __init__(self, local_config, applied_config):
        self.topic_bundle = statesman.TopicBundle()
        self.topic_bundle.add_writer('status', statesman.StateWriter())
        self.topic_bundle.add_writer('config-spec', statesman.StateWriter())
        self.topic_bundle.add_writer('device-config', statesman.StateWriter())
        self.topic_bundle.add_writer('applied-config', statesman.StateWriter())
        self.topic_bundle.set_update_required_callback(_update_required_callback)
        # These should all effectively be static
        self.topic_bundle['device-config'].set_state(local_config)
        self.topic_bundle['applied-config'].set_state(applied_config)

    def set_status(self, status_dict):
        """Replace the contents of the 'status' topic."""
        self.topic_bundle['status'].set_state(status_dict)


class PluginUpdateState:
    """
    Topic bundle for a single plugin, posted to '/pluginupdate/<plugin_name>'.

    Carries the 'status' and 'config-spec' topics.
    """

    def __init__(self):
        self.topic_bundle = statesman.TopicBundle()
        # config-spec should be static, but isn't known yet when this is created
        self.topic_bundle.add_writer('status', statesman.StateWriter())
        self.topic_bundle.add_writer('config-spec', statesman.StateWriter())
        self.topic_bundle.set_update_required_callback(_update_required_callback)

    def set_status(self, status_dict):
        """Replace the contents of the 'status' topic."""
        self.topic_bundle['status'].set_state(status_dict)

    def set_confspec(self, config_spec):
        """Set the plugin's 'config-spec' topic once it is known."""
        self.topic_bundle['config-spec'].set_state(config_spec)


def clean_https_url(dirty_url):
    """
    Take a url with or without the leading "https://" scheme, and convert it to one that does.
    Change HTTP to HTTPS if present.
    """
    # Some weirdness with URL parsing means that by default most urls (like www.google.com)
    # get treated as relative
    # https://stackoverflow.com/questions/53816559/python-3-netloc-value-in-urllib-parse-is-empty-if-url-doesnt-have
    if "//" not in dirty_url:
        dirty_url = "//"+dirty_url
    return urlunparse(urlparse(dirty_url)._replace(scheme="https"))


def _device_id_from_secret(dev_secret_bytes):
    """Derive the short device ID (first 8 hex chars of a 16 byte blake2b hash) from a secret."""
    return blake2b(dev_secret_bytes, digest_size=16).hexdigest()[:8]


def load_device_identity(root_dir):
    """
    Attempt to load the device identity from the shepherd.identity file in ``root_dir``.

    Returns a tuple of (device_secret, device_id), where device_secret is the hex string stored
    in the file and device_id is derived from it.

    Raises FileNotFoundError if the file does not exist, KeyError if it has no "device_secret"
    entry, and ValueError if the secret is not valid hex or is not 16 bytes long. Errors from
    parsing the TOML file are propagated unchanged.
    """
    identity_filepath = Path(root_dir, 'shepherd.identity')
    try:
        with identity_filepath.open() as identity_file:
            identity_dict = toml.load(identity_file)
    except FileNotFoundError:
        log.warning(F"{identity_filepath} file does not exist")
        raise

    dev_secret = identity_dict["device_secret"]
    dev_secret_bytes = bytes.fromhex(dev_secret)
    if len(dev_secret_bytes) != _DEVICE_SECRET_BYTES:
        log.error(F"Device secret loaded from file {identity_filepath} does not contain the "
                  "required 16 bytes")
        raise ValueError(F"Device secret in {identity_filepath} is not 16 bytes long")

    dev_id = _device_id_from_secret(dev_secret_bytes)
    log.info(F"Loaded device identity. ID: {dev_id}")
    return (dev_secret, dev_id)


def generate_device_identity(root_dir):
    """
    Generate a new device identity and save it to the shepherd.identity file in ``root_dir``,
    overwriting any existing file.

    Returns a tuple of (device_secret, device_id).
    """
    dev_secret = secrets.token_hex(_DEVICE_SECRET_BYTES)
    identity_filepath = Path(root_dir, 'shepherd.identity')
    with identity_filepath.open('w') as identity_file:
        toml.dump({'device_secret': dev_secret}, identity_file)

    dev_id = _device_id_from_secret(bytes.fromhex(dev_secret))
    log.info(F"Generated new device identity. ID: {dev_id}")
    return (dev_secret, dev_id)


_update_thread_init_done = threading.Event()
_stop_event = threading.Event()


def stop():
    """Request the control thread to stop, waking it if it is idle or rate-limit waiting."""
    _stop_event.set()
    _update_required_callback()
    log.info("Control thread stop requested.")


def init_control(config, root_dir, core_update_state, plugin_update_states):
    """
    Start the Control update thread and initialise the Shepherd Control systems.

    Blocks until the thread has finished its initialisation (successfully or not), then
    returns the started thread.
    """
    _stop_event.clear()
    _update_thread_init_done.clear()
    control_thread = threading.Thread(target=_control_update_loop, args=(
        config, root_dir, core_update_state, plugin_update_states))
    control_thread.start()
    # Wait for init so our log makes sense
    _update_thread_init_done.wait()
    return control_thread


def _wait_for_pending_updates(core_update_state, plugin_update_states):
    """
    Block until at least one topic bundle needs sending, or a stop has been requested.

    Returns a dict mapping endpoint path (relative to the control API URL) to the topic bundle
    to post there. The dict is empty only if woken by a stop request with nothing pending.
    """
    with _control_update_required:
        while True:
            pending_updates = {}
            if core_update_state.topic_bundle.is_update_required():
                pending_updates['/update'] = core_update_state.topic_bundle
            for plugin_name, state in plugin_update_states.items():
                if state.topic_bundle.is_update_required():
                    pending_updates[f"/pluginupdate/{plugin_name}"] = state.topic_bundle
            if pending_updates or _stop_event.is_set():
                return pending_updates
            _control_update_required.wait()


def _control_update_loop(config, root_dir, core_update_state, plugin_update_states):
    """
    Body of the control thread.

    Loads (or generates) the device identity, then repeatedly waits for topic bundles that
    need updating and posts them to the control server, rate limited by a
    SmoothTokenBucketLimit, until stop() is called.
    """
    try:
        control_api_url = urljoin(clean_https_url(config["server"]), "/agent")
        log.info(F"Control server API endpoint is {control_api_url}")

        intro_key = config["intro_key"]
        # NOTE(review): this logs the intro key credential in plain text; consider redacting.
        log.info(F"Using intro key: {intro_key}")

        try:
            device_secret, device_id = load_device_identity(root_dir)
        except Exception:
            # Deliberately broad: a missing, unreadable or malformed identity file all fall
            # back to generating a fresh identity.
            log.warning("Could not load device identity from shepherd.identity file")
            device_secret, device_id = generate_device_identity(root_dir)
    finally:
        # Always release init_control(), even if initialisation raised, so it can never block
        # forever waiting on a thread that has already died.
        _update_thread_init_done.set()

    update_rate_limiter = SmoothTokenBucketLimit(10, 10*60, 3, time.monotonic())

    with requests.Session() as session:
        while True:
            # Spin here until something needs updating
            pending_updates = _wait_for_pending_updates(core_update_state, plugin_update_states)

            for endpoint, topic_bundle in pending_updates.items():
                try:
                    r = session.post(control_api_url+endpoint, json=topic_bundle.get_payload(),
                                     auth=(device_secret, intro_key),
                                     timeout=_CONTROL_REQUEST_TIMEOUT)
                    if r.status_code == requests.codes['conflict']:
                        # Server replies with this when trying to add our device ID and failing
                        # due to it already existing (device secret hash is a mismatch). We need
                        # to regenerate our ID
                        log.info(F"Control server has indicated that device ID {device_id} already"
                                 " exists. Generating new one...")
                        device_secret, device_id = generate_device_identity(root_dir)
                    elif r.status_code == requests.codes['ok']:
                        topic_bundle.process_message(r.json())
                    else:
                        log.warning(F"Control server returned unexpected status "
                                    F"{r.status_code} for {endpoint}")
                except (requests.exceptions.RequestException, ValueError):
                    # ValueError covers a reply body that is not valid JSON (r.json()); either
                    # way, log it and keep the thread alive so later updates still get sent.
                    log.exception("Failed to make Shepherd Control request")

            if _stop_event.is_set():
                # Breaking here is a clean way of killing any delay and allowing a final update
                # before the thread ends.
                log.warning("Control thread stopping...")
                break

            delay = update_rate_limiter.new_event(time.monotonic())
            _stop_event.wait(delay)

    _update_thread_init_done.clear()


def get_cached_config(config_dir):
    """Placeholder: config caching is not implemented yet, so always returns an empty dict."""
    return {}


def clear_cached_config(config_dir):
    """Placeholder: config caching is not implemented yet, so this does nothing."""
    pass


class SmoothTokenBucketLimit:
    """
    Event rate limiter implementing a modified Token Bucket algorithm. Delay returned ramps up
    as the bucket empties.

    The bucket holds at most ``allowed_events`` tokens and regains one every
    ``period / allowed_events`` seconds. The first ``allowed_burst`` events from a full bucket
    get no delay. After that, the delay ramps linearly up to one refill interval as the bucket
    empties. Once the bucket is empty, the delay is whatever remains until the next token.
    """

    def __init__(self, allowed_events, period, allowed_burst, initial_time):
        self.allowed_events = allowed_events
        self.period = period
        self.allowed_burst = allowed_burst
        self.last_token_timestamp = initial_time
        self.tokens = allowed_events
        self._is_saturated = False

    def new_event(self, time_now):
        """
        Register a new event for the rate limiter. Return a required delay to ignore future
        events for in seconds.

        Conceptually, the "token" we're grabbing here is for the _next_ event.
        """
        token_interval = self.period/self.allowed_events

        if self.tokens < self.allowed_events:
            tokens_added = int((time_now - self.last_token_timestamp)/token_interval)
            self.tokens = self.tokens + tokens_added
            self.last_token_timestamp = self.last_token_timestamp + token_interval*tokens_added
        if self.tokens >= self.allowed_events:
            # The bucket can never hold more than allowed_events tokens; without this cap a long
            # idle period would allow an unbounded burst. A full bucket also restarts the
            # refill clock.
            self.tokens = self.allowed_events
            self.last_token_timestamp = time_now

        if self.tokens > 0:
            self.tokens = self.tokens - 1
            # Add a delay that ramps from 0 when the bucket is allowed_burst from full to p/x
            # when it is empty
            ramp_token_count = self.allowed_events - self.allowed_burst
            if ramp_token_count <= 0 or self.tokens > ramp_token_count:
                # No ramp region (burst covers the whole bucket) or still within the burst.
                delay = 0
            else:
                delay = (token_interval/ramp_token_count)*(ramp_token_count - self.tokens)
            self._is_saturated = False
        else:
            delay = token_interval - (time_now - self.last_token_timestamp)
            # The timestamp is moved forward to account for the next token that would otherwise
            # be added, without relying on the returned delay _actually_ occurring exactly.
            self.last_token_timestamp = time_now + delay
            self._is_saturated = True
        return delay

    def is_saturated(self):
        """
        Returns true if the rate limiter delay is at its maximum value
        """
        return self._is_saturated