diff --git a/shepherd/agent/control.py b/shepherd/agent/control.py index fa28e70..680b233 100644 --- a/shepherd/agent/control.py +++ b/shepherd/agent/control.py @@ -1,132 +1,291 @@ -import os -import uuid -import subprocess -import requests import threading -import json +import secrets +from pathlib import Path from urllib.parse import urlparse, urlunparse, urljoin -from collections import namedtuple +from hashlib import blake2b +import time +import logging +import toml +import requests +from configspec import * +import statesman -from . import plugin -# Check for shepherd.new file in edit conf dir. If there, -# or if no shepherd.id file can be found, generate a new one. -# For now, also attempt to delete /var/lib/zerotier-one/identity.public and identity.secret -# Once generated, if it was due to shepherd.new file, delete it. +log = logging.getLogger("shepherd.agent.control") +_control_update_required = threading.Condition() -#Start new thread, and push ID and core config to api.shepherd.distreon.net/client/update -class UpdateManager(): - def __init__(self): - pass +def _update_required_callback(): + with _control_update_required: + _control_update_required.notify() + + +def control_confspec(): + """ + Returns the control config specification + """ + confspec = ConfigSpecification(optional=True) + confspec.add_spec("server", StringSpec()) + confspec.add_spec("intro_key", StringSpec()) + + return confspec + + +class CoreUpdateState(): + def __init__(self, local_config, applied_config): + self.topic_bundle = statesman.TopicBundle() -class SequenceUpdate(): - Item = namedtuple('Item', ['sequence_number', 'content']) + self.topic_bundle.add_writer('status', statesman.StateWriter()) + self.topic_bundle.add_writer('config-spec', statesman.StateWriter()) + self.topic_bundle.add_writer('device-config', statesman.StateWriter()) + self.topic_bundle.add_writer('applied-config', statesman.StateWriter()) + + self.topic_bundle.set_update_required_callback(_update_required_callback) + + # These should all effectively be static + self.topic_bundle['device-config'].set_state(local_config) + self.topic_bundle['applied-config'].set_state(applied_config) + + def set_status(self, status_dict): + self.topic_bundle['status'].set_state(status_dict) + + +class PluginUpdateState(): def __init__(self): - self.items = [] - self._sequence_count = 0 - self._dirty = False - - def _next_sequence_number(self): - # TODO: need to establish a max sequence number, so that it can be compared to half - # that range and wrap around. - self._sequence_count +=1 - return self._sequence_count - - def mark_as_dirty(self): - self._dirty = True - - def add_item(self, item): - self.items.append(self.Item(self._next_sequence_number(), item)) - self.mark_as_dirty() - - def get_payload(): - pass - def process_ack(): - pass - -client_id = None -control_url = None -api_key = None - -def _update_job(core_config, plugin_config): - payload = {"client_id":client_id, "core_config":core_config,"plugin_config":plugin_config} - #json_string = json.dumps(payload) - try: - # Using the json arg rather than json.dumps ourselves automatically sets the Content-Type - # header to application/json, which Flask expects to work correctly - r = requests.post(control_url, json=payload, auth=(client_id, api_key)) - except requests.exceptions.ConnectionError: - raise - -def generate_new_id(root_dir): - global client_id - with open(os.path.join(root_dir, "shepherd.id"), 'w+') as f: - new_id = uuid.uuid1() - client_id = str(new_id) - f.write(client_id) - generate_new_zerotier_id() - -def get_config(config_dir): - return {} + self.topic_bundle = statesman.TopicBundle() + + # config-spec should be static, but isn't known yet when this is created + self.topic_bundle.add_writer('status', statesman.StateWriter()) + self.topic_bundle.add_writer('config-spec', statesman.StateWriter()) -def init_control(core_config, plugin_config): - global client_id - global control_url - global api_key + self.topic_bundle.set_update_required_callback(_update_required_callback) - # On init, need to be able to quickly return the cached shepherd control layer if necessary. - - # Create the /update endpoint structure + def set_status(self, status_dict): + self.topic_bundle['status'].set_state(status_dict) - root_dir = os.path.expanduser(core_config["root_dir"]) - editconf_dir = os.path.dirname(os.path.expanduser(core_config["conf_edit_path"])) + def set_confspec(self, config_spec): + self.topic_bundle['config-spec'].set_state(config_spec) - #Some weirdness with URL parsing means that by default most urls (like www.google.com) + +def clean_https_url(dirty_url): + """ + Take a url with or without the leading "https://" scheme, and convert it to one that + does. Change HTTP to HTTPS if present. + """ + # Some weirdness with URL parsing means that by default most urls (like www.google.com) # get treated as relative # https://stackoverflow.com/questions/53816559/python-3-netloc-value-in-urllib-parse-is-empty-if-url-doesnt-have - control_url = core_config["control_server"] - if "//" not in control_url: - control_url = "//"+control_url - control_url = urlunparse(urlparse(control_url)._replace(scheme="https")) - control_url = urljoin(control_url, "/client/update") - print(F"Control url is: {control_url}") - - api_key = core_config["api_key"] - - if os.path.isfile(os.path.join(editconf_dir, "shepherd.new")): - generate_new_id(root_dir) - os.remove(os.path.join(editconf_dir, "shepherd.new")) - print(F"Config hostname: {core_config['hostname']}") - if not (core_config["hostname"] == ""): - print("Attempting to change hostname") - subprocess.run(["raspi-config", "nonint", "do_hostname", core_config["hostname"]]) - elif not os.path.isfile(os.path.join(root_dir, "shepherd.id")): - generate_new_id(root_dir) - else: - with open(os.path.join(root_dir, "shepherd.id"), 'r') as id_file: - client_id = id_file.readline().strip() - - print(F"Client ID is: {client_id}") - - control_thread = threading.Thread(target=_update_job, args=(core_config,plugin_config)) + if "//" not in dirty_url: + dirty_url = "//"+dirty_url + return urlunparse(urlparse(dirty_url)._replace(scheme="https")) + + +def load_device_identity(root_dir): + """ + Attempt to load the device identity from the shepherd.identity file. Will throw exceptions if + this fails. Returns a tuple of (device_secret, device_id) + """ + identity_filepath = Path(root_dir, 'shepherd.identity') + if not identity_filepath.exists(): + log.warning(F"{identity_filepath} file does not exist") + raise FileNotFoundError() + + with identity_filepath.open() as identity_file: + identity_dict = toml.load(identity_file) + + dev_secret = identity_dict["device_secret"] + + dev_secret_bytes = bytes.fromhex(dev_secret) + + if len(dev_secret_bytes) != 16: + log.error(F"Device secret loaded from file {identity_filepath} does not contain the " + "required 16 bytes") + raise ValueError() + + secret_hash = blake2b(dev_secret_bytes, digest_size=16).hexdigest() + dev_id = secret_hash[:8] + log.info(F"Loaded device identity. ID: {dev_id}") + return (dev_secret, dev_id) + + +def generate_device_identity(root_dir): + """ + Generate a new device identity and save it to the shepherd.identity file. + Returns a tuple of (device_secret, device_id). + """ + + dev_secret = secrets.token_hex(16) + + identity_dict = {} + identity_dict['device_secret'] = dev_secret + + identity_filepath = Path(root_dir, 'shepherd.identity') + with identity_filepath.open('w+') as identity_file: + toml.dump(identity_dict, identity_file) + + dev_secret_bytes = bytes.fromhex(dev_secret) + secret_hash = blake2b(dev_secret_bytes, digest_size=16).hexdigest() + dev_id = secret_hash[:8] + + log.info(F"Generated new device identity. ID: {dev_id}") + return (dev_secret, dev_id) + + +_update_thread_init_done = threading.Event() + +_stop_event = threading.Event() + + +def stop(): + _stop_event.set() + _update_required_callback() + log.info("Control thread stop requested.") + + +def init_control(config, root_dir, core_update_state, plugin_update_states): + """ + Start the Control update thread and initialise the Shepherd Control systems. + """ + _stop_event.clear() + _update_thread_init_done.clear() + control_thread = threading.Thread(target=_control_update_loop, args=( + config, root_dir, core_update_state, plugin_update_states)) control_thread.start() - -def _post_logs_job(): - logs = shepherd.plugin.plugin_functions["scout"].get_logs() - measurements = shepherd.plugin.plugin_functions["scout"].get_measurements() + # Wait for init so our log makes sense + _update_thread_init_done.wait() + + return control_thread + +def _control_update_loop(config, root_dir, core_update_state, plugin_update_states): + control_api_url = urljoin(clean_https_url(config["server"]), "/agent") + log.info(F"Control server API endpoint is {control_api_url}") + intro_key = config["intro_key"] + log.info(F"Using intro key: {intro_key}") - payload = {"client_id":client_id, "logs":logs, "measurements":measurements} - try: - r = requests.post(control_url, json=payload, auth=(client_id, api_key)) - except requests.exceptions.ConnectionError: - pass + device_secret, device_id = load_device_identity(root_dir) + except Exception: + log.warning("Could not load device identity from shepherd.identity file") + device_secret, device_id = generate_device_identity(root_dir) + + _update_thread_init_done.set() + + update_rate_limiter = SmoothTokenBucketLimit(10, 10*60, 3, time.monotonic()) + + session = requests.Session() + # r=session.post('https://api.shepherd.test/agent/update') + while True: + # Spin here until something needs updating + with _control_update_required: + + new_endpoint_updates = {} # a dict of url:topic_bundle pairs + while True: + if core_update_state.topic_bundle.is_update_required(): + new_endpoint_updates['/update'] = core_update_state.topic_bundle + + for plugin_name, state in plugin_update_states.items(): + if state.topic_bundle.is_update_required(): + new_endpoint_updates[f"/pluginupdate/{plugin_name}"] = state.topic_bundle + + if (len(new_endpoint_updates) > 0) or _stop_event.is_set(): + break + + _control_update_required.wait() + + for endpoint, topic_bundle in new_endpoint_updates.items(): + try: + r = session.post(control_api_url+endpoint, + json=topic_bundle.get_payload(), + auth=(device_secret, intro_key)) + + if r.status_code == requests.codes['conflict']: + # Server replies with this when trying to add our device ID and failing + # due to it already existing (device secret hash is a mismatch). We need to + # regenerate our ID + log.info(F"Control server has indicated that device ID {device_id} already" + " exists. Generating new one...") + device_secret, device_id = generate_device_identity(root_dir) + elif r.status_code == requests.codes['ok']: + topic_bundle.process_message(r.json()) + except requests.exceptions.RequestException: + log.exception("Failed to make Shepherd Control request") + + if _stop_event.is_set(): + # Breaking here is a clean way of killing any delay and allowing a final update before + # the thread ends. + log.warning("Control thread stopping...") + break + + delay = update_rate_limiter.new_event(time.monotonic()) + _stop_event.wait(delay) + + _update_thread_init_done.clear() + + +def get_cached_config(config_dir): + return {} + + +def clear_cached_config(config_dir): + pass + + +class SmoothTokenBucketLimit(): + """ + Event rate limiter implementing a modified Token Bucket algorithm. Delay returned ramps + up as the bucket empties. + """ + + def __init__(self, allowed_events, period, allowed_burst, initial_time): + self.allowed_events = allowed_events + self.period = period + self.allowed_burst = allowed_burst + self.last_token_timestamp = initial_time + self.tokens = allowed_events + self._is_saturated = False + + def new_event(self, time_now): + """ + Register a new event for the rate limiter. Return a required delay to ignore future + events for in seconds. Conceptually, the "token" we're grabbing here is for the _next_ + event. + """ + if self.tokens < self.allowed_events: + time_since_last_token = time_now - self.last_token_timestamp + tokens_added = int(time_since_last_token/(self.period/self.allowed_events)) + + self.tokens = self.tokens + tokens_added + self.last_token_timestamp = self.last_token_timestamp + \ + (self.period/self.allowed_events)*tokens_added + if self.tokens >= self.allowed_events: + self.last_token_timestamp = time_now + + if self.tokens > 0: + self.tokens = self.tokens - 1 + # Add a delay that ramps from 0 when the bucket is allowed_burst from full to p/x when + # it is empty + ramp_token_count = self.allowed_events-self.allowed_burst + if self.tokens > ramp_token_count: + delay = 0 + else: + delay = ((self.period/self.allowed_events)/ramp_token_count) * \ + (ramp_token_count-self.tokens) + + self._is_saturated = False + else: + delay = (self.period/self.allowed_events) - (time_now-self.last_token_timestamp) + self.last_token_timestamp = time_now+delay + # This delay makes is set to account for the next token that would otherwise be added, + # without relying on the returned delay _actually_ occurring exactly. + self._is_saturated = True -def post_logs(): - post_logs_thread = threading.Thread(target=_post_logs_job, args=()) - post_logs_thread.start() + return delay + def is_saturated(self): + """ + Returns true if the rate limiter delay is at it's maximum value + """ + return self._is_saturated diff --git a/shepherd/agent/core.py b/shepherd/agent/core.py index b6e1184..463ce32 100644 --- a/shepherd/agent/core.py +++ b/shepherd/agent/core.py @@ -7,6 +7,7 @@ import os import sys from pathlib import Path import glob +from copy import deepcopy from datetime import datetime from types import SimpleNamespace from pprint import pprint @@ -53,8 +54,12 @@ def echo_section(title, input_text=None, on_nl=True): " Shepherd Control remote features") @click.option('-d', '--default-config-only', 'only_default_layer', is_flag=True, help="Ignore the custom config layer (still uses the Control config above that)") +@click.option('-n', '--new', 'new_run', is_flag=True, + help="Clear existing device identity and cached Shepherd Control config layer." + " Also triggered by the presence of a shepherd.new file in the" + " same directory as the custom config layer file.") @click.pass_context -def cli(ctx, default_config_path, local_operation, only_default_layer): +def cli(ctx, default_config_path, local_operation, only_default_layer, new_run): """ Core service. If default config file is not provided with '-c' option, the first filename in the current working directory beginning with "shepherd" and @@ -94,7 +99,7 @@ def cli(ctx, default_config_path, local_operation, only_default_layer): confman = ConfigManager() - compile_config(confman, default_config_path, layers_disabled) + compile_config(confman, default_config_path, layers_disabled, new_run) plugin_configs = confman.get_config_bundles() del plugin_configs["shepherd"] @@ -105,7 +110,21 @@ def cli(ctx, default_config_path, local_operation, only_default_layer): ctx.obj.core_config = core_config return - # control.init_control(core_config, plugin_configs) + applied_config = confman.get_config_bundles() + # Not part of normal ConfigManager, we just saved this here to pass it to Control + local_config = confman.saved_local_config + + core_update_state = control.CoreUpdateState(local_config, applied_config) + + plugin_update_states = {name: iface._update_state + for name, iface in plugin.plugin_interfaces.items()} + + if core_config["control"] is not None: + control.init_control(core_config["control"], core_config["root_dir"], + core_update_state, plugin_update_states) + else: + log.warning("Shepherd control config section not present. Will not attempt to connect to" + " Shepherd Control server.") scheduler.init_scheduler(core_config) plugin.init_plugins(plugin_configs, core_config) @@ -264,7 +283,20 @@ def template(ctx, plugin_name, include_all, config_path, plugin_dir): f.write(template_toml) -def compile_config(confman, default_config_path, layers_disabled): +def check_new_run_file(custom_config_path): + if not custom_config_path: + return False + + trigger_path = Path(Path(custom_config_path).parent, 'shepherd.new') + if trigger_path.exists(): + trigger_path.unlink() + log.info("'shepherd.new' file detected, removing file and clearing old state...") + return True + + return False + + +def compile_config(confman, default_config_path, layers_disabled, new_run): """ Run through the process of assembling the various config layers, falling back to working ones where necessary. As part of this, loads in the required plugins. @@ -327,14 +359,22 @@ def compile_config(confman, default_config_path, layers_disabled): confman.freeze_value("shepherd", "control_server") confman.freeze_value("shepherd", "control_api_key") - # Save current good config + # Save current good local config confman.save_fallback() + # Tuck it away so we can pass the local config to Control + confman.saved_local_config = deepcopy(confman.get_config_bundles()) + + if check_new_run_file(custom_config_path) or new_run: + if new_run: + log.info("'new run' selected, clearing old state...") + control.generate_device_identity(core_conf["root_dir"]) + control.clear_cached_config(core_conf["root_dir"]) # ====Control Remote Config Layer==== # If this fails, maintain current local config. if "control" not in layers_disabled: try: - control_config = control.get_config(core_conf["root_dir"]) + control_config = control.get_cached_config(core_conf["root_dir"]) try: load_config_layer_and_plugins(confman, control_config) log.info(F"Loaded cached Shepherd Control config layer") @@ -383,9 +423,7 @@ def core_confspec(): " errors in validation.")) ]) - confspec.add_spec("control_server", StringSpec()) - confspec.add_spec("control_api_key", StringSpec()) - + confspec.add_spec("control", control.control_confspec()) return confspec diff --git a/shepherd/agent/plugin.py b/shepherd/agent/plugin.py index 1a218b2..43ddbb8 100644 --- a/shepherd/agent/plugin.py +++ b/shepherd/agent/plugin.py @@ -9,6 +9,7 @@ from functools import partial from types import MappingProxyType import pkg_resources from configspec import ConfigSpecification +from . import control from .. import base_plugins @@ -136,6 +137,7 @@ class PluginInterface(): self.config = None self.plugins = None self._plugin_name = "" + self._update_state = control.PluginUpdateState() def _load_pluginclass(self, module): pass @@ -178,6 +180,15 @@ class PluginInterface(): self._functions[ifunc.name] = ifunc + def set_status(self, status): + """ + Set the plugin status, to be sent to Shepherd Control if configured. + + Args: + status: A flat dictionary of fields with string keys. + """ + self._update_state.set_status(status) + @property def confspec(self): return self._confspec @@ -313,6 +324,8 @@ def load_plugin(plugin_name, plugin_dir=None): attr, unbound=True, **attr._shepherd_load_marker._asdict()) interface.register_function(unbound_func) + interface._update_state.set_confspec(interface.confspec) + interface._loaded = True _loaded_plugins[plugin_name] = interface diff --git a/tests/test_control.py b/tests/test_control.py new file mode 100644 index 0000000..5cd9c6e --- /dev/null +++ b/tests/test_control.py @@ -0,0 +1,135 @@ +# pylint: disable=redefined-outer-name +import secrets +from base64 import b64encode +import json +import logging +import pytest +import responses +import statesman + +from shepherd.agent import control + + +def test_device_id(monkeypatch, tmpdir): + with pytest.raises(FileNotFoundError): + control.load_device_identity(tmpdir) + + def fixed_token_hex(_): + return '0123456789abcdef0123456789abcdef' + monkeypatch.setattr(secrets, "token_hex", fixed_token_hex) + + dev_secret, dev_id = control.generate_device_identity(tmpdir) + assert dev_secret == '0123456789abcdef0123456789abcdef' + assert dev_id == '3dead5e4' + + dev_secret, dev_id = control.load_device_identity(tmpdir) + assert dev_secret == '0123456789abcdef0123456789abcdef' + assert dev_id == '3dead5e4' + + +@pytest.fixture +def control_config(): + return {'server': 'api.shepherd.test', 'intro_key': 'abcdefabcdefabcdef'} + + +def test_config(control_config): + control.control_confspec().validate(control_config) + + +def test_url(): + assert control.clean_https_url('api.shepherd.test') == 'https://api.shepherd.test' + assert control.clean_https_url('api.shepherd.test/foo') == 'https://api.shepherd.test/foo' + assert control.clean_https_url('http://api.shepherd.test') == 'https://api.shepherd.test' + + +@responses.activate +def test_control_thread(control_config, tmpdir, caplog): + # Testing threads is a pain, as exceptions (including assertions) thrown in the thread don't + # cause the test to fail. We can cheat a little here, as the 'responses' mock framework will + # throw a requests.exceptions.ConnectionError if the request isn't recognised, and we're + # already logging those in Control. + + responses.add(responses.POST, 'https://api.shepherd.test/agent/update', json={}) + responses.add(responses.POST, 'https://api.shepherd.test/agent/pluginupdate/plugin_A', json={}) + responses.add(responses.POST, 'https://api.shepherd.test/agent/pluginupdate/plugin_B', json={}) + + core_update_state = control.CoreUpdateState( + {'the_local_config': 'val'}, {'the_applied_config': 'val'}) + plugin_update_states = {'plugin_A': control.PluginUpdateState(), + 'plugin_B': control.PluginUpdateState()} + + control_thread = control.init_control( + control_config, tmpdir, core_update_state, plugin_update_states) + control.stop() + control_thread.join() + + # Check there were no connection exceptions + for record in caplog.records: + assert record.levelno <= logging.WARNING + + # There is a log line present if the thread stopped properly + assert ("shepherd.agent.control", logging.WARNING, + "Control thread stopping...") in caplog.record_tuples + + +@responses.activate +def test_control(control_config, tmpdir, caplog, monkeypatch): + # Here we skip control_init and just run the update loop directly, to keep things in the same + # thread + + def fixed_token_hex(_): + return '0123456789abcdef0123456789abcdef' + monkeypatch.setattr(secrets, "token_hex", fixed_token_hex) + + core_topic_bundle = statesman.TopicBundle() + + core_topic_bundle.add_reader('status', statesman.StateReader()) + core_topic_bundle.add_reader('config-spec', statesman.StateReader()) + core_topic_bundle.add_reader('device-config', statesman.StateReader()) + core_topic_bundle.add_reader('applied-config', statesman.StateReader()) + + core_callback_count = 0 + + def core_update_callback(request): + nonlocal core_callback_count + core_callback_count += 1 + payload = json.loads(request.body) + assert 'applied-config' in payload + assert 'device-config' in payload + + core_topic_bundle.process_message(payload) + resp_body = core_topic_bundle.get_payload() + + basic_auth = b64encode( + b"0123456789abcdef0123456789abcdef:abcdefabcdefabcdef").decode("ascii") + assert request.headers['authorization'] == F"Basic {basic_auth}" + + return (200, {}, json.dumps(resp_body)) + + responses.add_callback( + responses.POST, 'https://api.shepherd.test/agent/update', + callback=core_update_callback, + content_type='application/json') + + responses.add(responses.POST, 'https://api.shepherd.test/agent/pluginupdate/plugin_A', json={}) + responses.add(responses.POST, 'https://api.shepherd.test/agent/pluginupdate/plugin_B', json={}) + + core_update_state = control.CoreUpdateState( + {'the_local_config': 'val'}, {'the_applied_config': 'val'}) + plugin_update_states = {'plugin_A': control.PluginUpdateState(), + 'plugin_B': control.PluginUpdateState()} + plugin_update_states['plugin_A'].set_status({"status1": '1'}) + + # control._stop_event.clear() + control._stop_event.set() + # With the stop event set, the loop should run through and update everything once before + # breaking + control._control_update_loop(control_config, tmpdir, core_update_state, plugin_update_states) + + assert core_callback_count == 1 + + assert not core_update_state.topic_bundle.is_update_required() + + # Check there were no connection exceptions + for record in caplog.records: + assert record.levelno <= logging.WARNING