You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
353 lines
13 KiB
353 lines
13 KiB
import threading
|
|
import secrets
|
|
from types import SimpleNamespace
|
|
from pathlib import Path
|
|
from urllib.parse import urlparse, urlunparse, urljoin
|
|
from hashlib import blake2b
|
|
import time
|
|
import logging
|
|
import toml
|
|
import requests
|
|
from configspec import *
|
|
import statesman
|
|
|
|
|
|
# Namespace of types intended for server-side use.
|
|
def get_export():
    """
    Build the namespace of Control types exposed for server-side use.

    Returns:
        SimpleNamespace: a namespace whose ``InterfaceCall`` attribute is
        ``plugin.InterfaceCall``.
    """
    # Imported inside the function rather than at module level, presumably to avoid an
    # import cycle with the plugin module — TODO confirm.
    from . import plugin

    return SimpleNamespace(InterfaceCall=plugin.InterfaceCall)
|
|
|
|
|
|
log = logging.getLogger(name="shepherd.agent.control")

# Notified by _update_required_callback() whenever a topic bundle has state to send; the
# Control update loop waits on it. Condition() uses an RLock by default — made explicit here.
_control_update_required = threading.Condition(threading.RLock())
|
|
|
|
|
|
def _update_required_callback():
    """
    Wake the Control update loop so it re-checks the topic bundles for pending updates.

    Registered as the update-required callback on every topic bundle, and also called by
    stop() to interrupt a loop that is waiting for work.
    """
    condition = _control_update_required
    with condition:
        condition.notify()
|
|
|
|
|
|
def register_on(core_interface):
    """
    Register the control confspec on the core interface.

    Adds an optional ``control`` section to ``core_interface.confspec``. The section has two
    string fields: ``server`` (the Control server address handed to clean_https_url by the
    update loop) and ``intro_key``.
    """
    control_section = ConfigSpecification()
    for field_name in ("server", "intro_key"):
        control_section.add_spec(field_name, StringSpec())

    core_interface.confspec.add_spec("control", control_section, optional=True)
|
|
|
|
|
|
class CoreUpdateState():
    """
    A container for all state that might need communicating remotely to Control. Abstracts the
    Statesman topics away from other parts of the Agent.

    ``topic_bundle`` holds four state topics the Agent writes (``status``, ``config-spec``,
    ``device-config`` and ``applied-config``) plus the two command topics supplied by the
    caller (``control-commands`` and ``command-results``). The Control update loop posts
    this bundle to the ``/update`` endpoint.
    """

    def __init__(self, cmd_reader, cmd_result_writer):
        """
        Build the topic bundle and make it wake the Control update loop on changes.

        Args:
            cmd_reader: Statesman reader used as the ``control-commands`` topic (for
                example ``CommandRunner.cmd_reader``).
            cmd_result_writer: Statesman writer used as the ``command-results`` topic (for
                example ``CommandRunner.cmd_result_writer``).
        """
        topics = {
            topic_name: statesman.StateWriter()
            for topic_name in ('status', 'config-spec', 'device-config', 'applied-config')
        }
        topics['control-commands'] = cmd_reader
        topics['command-results'] = cmd_result_writer

        self.topic_bundle = statesman.TopicBundle(topics)
        self.topic_bundle.set_update_required_callback(_update_required_callback)

    def set_static_state(self, local_config, applied_config, confspec):
        """Publish the device config, applied config and config spec (expected to be static)."""
        static_topics = (
            ('device-config', local_config),
            ('applied-config', applied_config),
            ('config-spec', confspec),
        )
        for topic_name, state in static_topics:
            self.topic_bundle[topic_name].set_state(state)

    def set_status(self, status_dict):
        """Publish the core status."""
        status_topic = self.topic_bundle['status']
        status_topic.set_state(status_dict)
|
|
|
|
|
|
class CommandRunner():
    """
    Executes commands received from Control and publishes their results.

    ``cmd_reader`` receives command messages and ``cmd_result_writer`` carries results
    back; they are Statesman sequence topics (CoreUpdateState takes them as its
    ``control-commands`` and ``command-results`` topics). Every command runs on its own
    thread, and ``current_commands`` maps the IDs of commands still running to the thread
    running them.
    """

    def __init__(self, interface_functions):
        """
        Args:
            interface_functions: Handed to each command's ``resolve()`` before the command
                is called.
        """
        self.cmd_reader = statesman.SequenceReader(
            new_message_callback=self.on_new_command_message)
        self.cmd_result_writer = statesman.SequenceWriter()
        self._functions = interface_functions
        self.current_commands = {}
        # Command threads register and unregister themselves concurrently, so the
        # check-then-insert on current_commands has to be atomic.
        self._current_commands_lock = threading.Lock()

    def on_new_command_message(self, message):
        """
        Start a thread to run the command carried by ``message``.

        ``message`` should be a list whose first value is the command ID and whose second
        is the call object (presumably a plugin.FunctionCall — per the original comment).
        """
        commandID = message[0]
        command_call = message[1]

        command_thread = threading.Thread(target=self._process_command,
                                          args=(commandID, command_call))
        command_thread.start()

    def _process_command(self, commandID, command_call):
        """
        Resolve and run one command, then publish ``[commandID, result]``.

        Runs on the command's own thread. A failing command is logged and publishes no
        result.

        Raises:
            ValueError: A command with the same ID is already running.
        """
        with self._current_commands_lock:
            if commandID in self.current_commands:
                raise ValueError(F"Already running a command with ID {commandID}")
            self.current_commands[commandID] = threading.current_thread()

        try:
            command_call.resolve(self._functions)
            result = command_call.call()
        except Exception:
            # Without this the failure would only reach threading.excepthook (stderr),
            # not the agent log.
            log.exception(F"Command {commandID} failed")
        else:
            self.cmd_result_writer.add_message([commandID, result])
        finally:
            with self._current_commands_lock:
                self.current_commands.pop(commandID, None)
|
|
|
|
|
|
class PluginUpdateState():
    """
    Statesman topics carrying one plugin's state to Control: ``status``, ``config-spec``
    and ``command-spec``. The Control update loop posts the bundle to
    ``/pluginupdate/<plugin name>``.
    """

    # Topics owned by this object, in the order they are added to the bundle.
    _TOPIC_NAMES = ('status', 'config-spec', 'command-spec')

    def __init__(self):
        self.topic_bundle = statesman.TopicBundle()

        # config-spec should be static, but isn't known yet when this is created
        for topic_name in self._TOPIC_NAMES:
            self.topic_bundle.add(topic_name, statesman.StateWriter())

        # NOTE(review): open design question from the original author — plugin config is
        # loaded once at start-up just like the device and applied config, so why is it
        # kept per plugin? Possibly just because it's easy to reach from the
        # PluginInterface that creates this object.

        self.topic_bundle.set_update_required_callback(_update_required_callback)

    def _set_topic_state(self, topic_name, state):
        """Write ``state`` to the named topic of the bundle."""
        self.topic_bundle[topic_name].set_state(state)

    def set_status(self, status_dict):
        """Publish the plugin status."""
        self._set_topic_state('status', status_dict)

    def set_confspec(self, config_spec):
        """Publish the plugin config specification."""
        self._set_topic_state('config-spec', config_spec)

    def set_commandspec(self, command_spec):
        """Publish the plugin command specification."""
        self._set_topic_state('command-spec', command_spec)
|
|
|
|
|
|
def clean_https_url(dirty_url):
    """
    Take a url with or without the leading "https://" scheme, and convert it to one that
    does. Change HTTP to HTTPS if present.

    Args:
        dirty_url: A URL or bare host, e.g. ``"www.google.com"``, ``"localhost:8000"`` or
            ``"http://example.com/path"``.

    Returns:
        str: The same URL with its scheme set to ``https``.
    """
    # urlparse only recognises the host (netloc) when it is introduced by "//"; a bare
    # "www.google.com" otherwise parses as a relative path.
    # https://stackoverflow.com/questions/53816559/python-3-netloc-value-in-urllib-parse-is-empty-if-url-doesnt-have
    # Test the parsed netloc instead of searching for "//" anywhere in the string, so a "//"
    # inside the path or query (e.g. "host/a//b") doesn't stop the host from being found.
    if not urlparse(dirty_url).netloc:
        dirty_url = "//" + dirty_url

    return urlunparse(urlparse(dirty_url)._replace(scheme="https"))
|
|
|
|
|
|
def load_device_identity(root_dir):
    """
    Attempt to load the device identity from the shepherd.identity file in ``root_dir``.

    The file is TOML and must have a ``device_secret`` entry holding 16 bytes as hex. The
    device ID is the first 8 hex characters of a 16-byte BLAKE2b digest of the secret bytes.

    Returns:
        tuple: ``(device_secret, device_id)``, with ``device_secret`` as stored (hex string).

    Raises:
        FileNotFoundError: The identity file does not exist.
        KeyError: The file has no ``device_secret`` entry.
        ValueError: The secret is not valid hex, or is not exactly 16 bytes long.
        Whatever ``toml.load`` raises if the file is not valid TOML.
    """
    identity_filepath = Path(root_dir, 'shepherd.identity')
    if not identity_filepath.exists():
        log.warning(F"{identity_filepath} file does not exist")
        raise FileNotFoundError(F"{identity_filepath} file does not exist")

    # TOML files are UTF-8 by definition; don't depend on the platform's default encoding.
    with identity_filepath.open(encoding="utf-8") as identity_file:
        identity_dict = toml.load(identity_file)

    dev_secret = identity_dict["device_secret"]

    try:
        dev_secret_bytes = bytes.fromhex(dev_secret)
    except ValueError as err:
        log.error(F"Device secret loaded from file {identity_filepath} is not valid hex")
        raise ValueError(F"Device secret in {identity_filepath} is not valid hex") from err

    if len(dev_secret_bytes) != 16:
        log.error(F"Device secret loaded from file {identity_filepath} does not contain the "
                  "required 16 bytes")
        raise ValueError(F"Device secret in {identity_filepath} is not 16 bytes long")

    secret_hash = blake2b(dev_secret_bytes, digest_size=16).hexdigest()
    dev_id = secret_hash[:8]
    log.info(F"Loaded device identity. ID: {dev_id}")
    return (dev_secret, dev_id)
|
|
|
|
|
|
def generate_device_identity(root_dir):
    """
    Create a new random device identity and save it to shepherd.identity in ``root_dir``.

    The secret is 16 random bytes from ``secrets``, stored hex-encoded under the
    ``device_secret`` key of a TOML file; an existing file is overwritten. The device ID is
    the first 8 hex characters of a 16-byte BLAKE2b digest of the secret bytes.

    Returns:
        tuple: ``(device_secret, device_id)``, where ``device_secret`` is the hex string.
    """
    new_secret = secrets.token_hex(16)

    identity_path = Path(root_dir, 'shepherd.identity')
    with identity_path.open('w+') as out_file:
        toml.dump({'device_secret': new_secret}, out_file)

    digest = blake2b(bytes.fromhex(new_secret), digest_size=16).hexdigest()
    new_id = digest[:8]

    log.info(F"Generated new device identity. ID: {new_id}")
    return new_secret, new_id
|
|
|
|
|
|
# _update_thread_init_done: set by the update thread once the device identity has been
# loaded or generated; start_control() waits on it, and the thread clears it on exit.
# _stop_event: set by stop() to ask the update thread for a final update and exit; the
# thread also waits on it between updates, so a stop cuts the rate-limit delay short.
_update_thread_init_done, _stop_event = threading.Event(), threading.Event()
|
|
|
|
|
|
def stop():
    """
    Ask the Control update thread to finish.

    Sets the stop flag and wakes the update loop so it can make a final update and exit.
    Returns immediately without waiting for the thread to end.
    """
    _stop_event.set()
    # Wake the loop in case it is blocked waiting for a topic to need updating.
    with _control_update_required:
        _control_update_required.notify()
    log.info("Control thread stop requested.")
|
|
|
|
|
|
def start_control(config, root_dir, core_update_state, plugin_update_states):
    """
    Start the Control update thread and initialise the Shepherd Control systems.

    Blocks until the thread has finished initialising (device identity loaded or
    generated), or until the thread exits without doing so.

    Args:
        config: Mapping with the ``server`` and ``intro_key`` entries the update loop reads.
        root_dir: Directory holding the ``shepherd.identity`` file.
        core_update_state: CoreUpdateState whose bundle is posted to ``/update``.
        plugin_update_states: Mapping of plugin name to PluginUpdateState; each bundle is
            posted to ``/pluginupdate/<plugin name>``.

    Returns:
        threading.Thread: The control thread.
    """
    _stop_event.clear()
    _update_thread_init_done.clear()
    control_thread = threading.Thread(target=_control_update_loop, args=(
        config, root_dir, core_update_state, plugin_update_states))
    control_thread.start()

    # Wait for init so our log makes sense. Poll instead of blocking indefinitely: if the
    # thread raises before setting the event (e.g. a missing config key), an unbounded
    # wait() would hang the caller forever.
    while not _update_thread_init_done.wait(timeout=0.5):
        if not control_thread.is_alive():
            log.error("Control thread exited before completing initialisation")
            break

    return control_thread
|
|
|
|
|
|
# Seconds to wait for the Control server before abandoning a request. Without a timeout
# requests can block forever, which would also keep stop() from taking effect.
_CONTROL_REQUEST_TIMEOUT = 30


def _wait_for_pending_updates(core_update_state, plugin_update_states):
    """
    Block until at least one topic bundle needs sending to Control, or stop() is called.

    Returns:
        dict: Endpoint path (relative to the ``/agent`` API root) mapped to the topic bundle
        to post there, for every bundle whose ``is_update_required()`` is true. Empty only
        when a stop was requested and nothing is pending.
    """
    with _control_update_required:
        while True:
            pending = {}
            if core_update_state.topic_bundle.is_update_required():
                pending['/update'] = core_update_state.topic_bundle

            for plugin_name, state in plugin_update_states.items():
                if state.topic_bundle.is_update_required():
                    pending[f"/pluginupdate/{plugin_name}"] = state.topic_bundle

            if pending or _stop_event.is_set():
                return pending

            # Releases the condition while waiting; _update_required_callback() wakes us.
            _control_update_required.wait()


def _control_update_loop(config, root_dir, core_update_state, plugin_update_states):
    """
    Body of the Control update thread.

    Loads (or generates) the device identity, then repeatedly waits until a topic bundle
    needs updating, POSTs each pending bundle's payload to its Control endpoint, and feeds
    successful (200) replies back into the bundle. Rounds are paced by a
    SmoothTokenBucketLimit. After stop() is called, one final round is made and the loop
    exits.
    """
    control_api_url = urljoin(clean_https_url(config["server"]), "/agent")
    log.info(F"Control server API endpoint is {control_api_url}")
    intro_key = config["intro_key"]
    # NOTE(review): this writes a credential to the log in plain text; consider masking it.
    log.info(F"Using intro key: {intro_key}")

    try:
        device_secret, device_id = load_device_identity(root_dir)
    except Exception:
        log.warning("Could not load device identity from shepherd.identity file")
        device_secret, device_id = generate_device_identity(root_dir)

    _update_thread_init_done.set()

    # Bucket of 10 tokens refilling one every 60 s (10 per 10*60 s); the first 3 rounds
    # taken from a full bucket get no added delay.
    update_rate_limiter = SmoothTokenBucketLimit(10, 10*60, 3, time.monotonic())

    # The context manager closes the session's pooled connections when the loop exits.
    with requests.Session() as session:
        while True:
            # Spin here until something needs updating
            new_endpoint_updates = _wait_for_pending_updates(
                core_update_state, plugin_update_states)

            for endpoint, topic_bundle in new_endpoint_updates.items():
                try:
                    response = session.post(control_api_url+endpoint,
                                            json=topic_bundle.get_payload(),
                                            auth=(device_secret, intro_key),
                                            timeout=_CONTROL_REQUEST_TIMEOUT)
                except requests.exceptions.RequestException:
                    log.exception("Failed to make Shepherd Control request")
                    continue

                if response.status_code == requests.codes['conflict']:
                    # Server replies with this when trying to add our device ID and failing
                    # due to it already existing (device secret hash is a mismatch). We need
                    # to regenerate our ID
                    log.info(F"Control server has indicated that device ID {device_id} already"
                             " exists. Generating new one...")
                    device_secret, device_id = generate_device_identity(root_dir)
                elif response.status_code == requests.codes['ok']:
                    try:
                        reply = response.json()
                    except ValueError:
                        # Covers json.JSONDecodeError and requests' JSONDecodeError.
                        log.exception(F"Invalid JSON in Shepherd Control reply from {endpoint}")
                    else:
                        topic_bundle.process_message(reply)
                else:
                    log.warning(F"Unexpected HTTP {response.status_code} response from "
                                F"Shepherd Control endpoint {endpoint}")

            if _stop_event.is_set():
                # Breaking here is a clean way of killing any delay and allowing a final
                # update before the thread ends.
                log.warning("Control thread stopping...")
                _stop_event.clear()
                break

            delay = update_rate_limiter.new_event(time.monotonic())
            # Returns early if stop() is called during the delay.
            _stop_event.wait(delay)

    _update_thread_init_done.clear()
|
|
|
|
|
|
def get_cached_config(config_dir):
    """
    Placeholder for reading cached Control config from ``config_dir``.

    Currently ignores ``config_dir`` and returns a new empty dict on every call.
    """
    return dict()
|
|
|
|
|
|
def clear_cached_config(config_dir):
    """
    Placeholder for clearing cached Control config in ``config_dir``.

    Currently does nothing: ``config_dir`` is ignored and ``None`` is returned.
    """
    return None
|
|
|
|
|
|
class SmoothTokenBucketLimit():
    """
    Event rate limiter implementing a modified Token Bucket algorithm. Delay returned ramps
    up as the bucket empties.

    The bucket holds at most ``allowed_events`` tokens and regains one every
    ``period / allowed_events`` time units. While more than ``allowed_events - allowed_burst``
    tokens remain after an event, no delay is requested. Below that, the delay ramps
    linearly up to one token interval as the bucket empties. Once the bucket is empty, the
    delay is the time remaining until the next token is due.

    Args:
        allowed_events: Bucket capacity, i.e. the number of events allowed per ``period``.
        period: Length of the rate window, in the same units as the timestamps passed in.
        allowed_burst: Number of events a full bucket serves with no delay.
        initial_time: Timestamp the refill clock starts from.
    """

    def __init__(self, allowed_events, period, allowed_burst, initial_time):
        self.allowed_events = allowed_events
        self.period = period
        self.allowed_burst = allowed_burst
        self.last_token_timestamp = initial_time
        self.tokens = allowed_events
        self._is_saturated = False

    def _refill(self, time_now, token_interval):
        """Credit the whole tokens accrued since ``last_token_timestamp``, capped at capacity."""
        elapsed = time_now - self.last_token_timestamp
        # The saturated branch of new_event() can push last_token_timestamp past time_now;
        # never credit a negative number of tokens.
        tokens_added = max(0, int(elapsed / token_interval))

        if self.tokens + tokens_added >= self.allowed_events:
            # Full: cap the bucket and restart the refill clock now, so time spent full
            # (including time since construction) is not banked as extra tokens.
            self.tokens = self.allowed_events
            self.last_token_timestamp = time_now
        else:
            # Advance only by whole token intervals, keeping partial progress to the next one.
            self.tokens = self.tokens + tokens_added
            self.last_token_timestamp = self.last_token_timestamp + token_interval*tokens_added

    def new_event(self, time_now):
        """
        Register a new event for the rate limiter. Return a required delay to ignore future
        events for, in the same time units as ``time_now``. Conceptually, the "token" we're
        grabbing here is for the _next_ event.
        """
        token_interval = self.period/self.allowed_events
        self._refill(time_now, token_interval)

        if self.tokens > 0:
            self.tokens = self.tokens - 1
            # Add a delay that ramps from 0 when the bucket is allowed_burst from full to p/x
            # when it is empty
            ramp_token_count = self.allowed_events-self.allowed_burst
            if self.tokens > ramp_token_count:
                delay = 0
            elif ramp_token_count > 0:
                delay = (token_interval/ramp_token_count) * (ramp_token_count-self.tokens)
            else:
                # allowed_burst == allowed_events leaves no ramp region: the bucket has just
                # emptied, so use the full interval (the ramp's end value) rather than
                # dividing by zero.
                delay = token_interval

            self._is_saturated = False
        else:
            delay = token_interval - (time_now-self.last_token_timestamp)
            # The timestamp is set to account for the next token that would otherwise be
            # added, without relying on the returned delay _actually_ occurring exactly.
            self.last_token_timestamp = time_now+delay
            self._is_saturated = True

        return delay

    def is_saturated(self):
        """
        Return True if the last new_event() call found the bucket empty, i.e. the rate
        limiter delay is at its maximum value.
        """
        return self._is_saturated
|