parent
17225a1b39
commit
199acb7e3b
@ -1,132 +1,291 @@
|
||||
import os
|
||||
import uuid
|
||||
import subprocess
|
||||
import requests
|
||||
import threading
|
||||
import json
|
||||
import secrets
|
||||
from pathlib import Path
|
||||
from urllib.parse import urlparse, urlunparse, urljoin
|
||||
from collections import namedtuple
|
||||
from hashlib import blake2b
|
||||
import time
|
||||
import logging
|
||||
import toml
|
||||
import requests
|
||||
from configspec import *
|
||||
import statesman
|
||||
|
||||
from . import plugin
|
||||
# Check for shepherd.new file in edit conf dir. If there,
|
||||
# or if no shepherd.id file can be found, generate a new one.
|
||||
# For now, also attempt to delete /var/lib/zerotier-one/identity.public and identity.secret
|
||||
# Once generated, if it was due to shepherd.new file, delete it.
|
||||
log = logging.getLogger("shepherd.agent.control")
|
||||
|
||||
_control_update_required = threading.Condition()
|
||||
|
||||
#Start new thread, and push ID and core config to api.shepherd.distreon.net/client/update
|
||||
|
||||
class UpdateManager():
|
||||
def __init__(self):
|
||||
pass
|
||||
def _update_required_callback():
    # Wake the control update thread. Registered with each statesman TopicBundle
    # (see set_update_required_callback) so any new state notifies the thread that
    # is waiting on _control_update_required; also invoked by stop() to unblock it.
    with _control_update_required:
        _control_update_required.notify()
|
||||
|
||||
|
||||
def control_confspec():
    """
    Returns the control config specification
    """
    spec = ConfigSpecification(optional=True)
    # Both Control settings are plain strings: the server address and the
    # introduction key used to authenticate with it.
    for field_name in ("server", "intro_key"):
        spec.add_spec(field_name, StringSpec())
    return spec
|
||||
|
||||
|
||||
class CoreUpdateState():
    """
    Holds the statesman TopicBundle of core agent state shared with the Control
    server: status, config-spec, device-config and applied-config topics.
    """

    def __init__(self, local_config, applied_config):
        # local_config: the device's local config, published on 'device-config'.
        # applied_config: the config actually in effect, published on 'applied-config'.
        self.topic_bundle = statesman.TopicBundle()

        self.topic_bundle.add_writer('status', statesman.StateWriter())
        # config-spec writer is added here, but its state is not set at
        # construction time (it isn't known yet).
        self.topic_bundle.add_writer('config-spec', statesman.StateWriter())
        self.topic_bundle.add_writer('device-config', statesman.StateWriter())
        self.topic_bundle.add_writer('applied-config', statesman.StateWriter())

        # Wake the control thread whenever any topic gains new state.
        self.topic_bundle.set_update_required_callback(_update_required_callback)

        # These should all effectively be static
        self.topic_bundle['device-config'].set_state(local_config)
        self.topic_bundle['applied-config'].set_state(applied_config)

    def set_status(self, status_dict):
        """Publish a new status dict on the 'status' topic."""
        self.topic_bundle['status'].set_state(status_dict)
||||
|
||||
class PluginUpdateState():
    """
    Tracks per-plugin update items, each tagged with a monotonically increasing
    sequence number, plus a dirty flag indicating unsent changes.
    """

    # Each queued item pairs its sequence number with the item content.
    Item = namedtuple('Item', ['sequence_number', 'content'])

    def __init__(self):
        self.items = []            # queued Item tuples, in insertion order
        self._sequence_count = 0   # last sequence number handed out
        self._dirty = False        # True when items have changed since last send

    def _next_sequence_number(self):
        # TODO: need to establish a max sequence number, so that it can be compared to half
        # that range and wrap around.
        self._sequence_count += 1
        return self._sequence_count

    def mark_as_dirty(self):
        """Flag that this plugin's state has changed and needs sending."""
        self._dirty = True

    def add_item(self, item):
        """Queue *item* with the next sequence number and mark state dirty."""
        self.items.append(self.Item(self._next_sequence_number(), item))
        self.mark_as_dirty()

    # NOTE: the following were missing the `self` parameter and would have raised
    # TypeError when called on an instance; both remain unimplemented stubs.
    def get_payload(self):
        pass

    def process_ack(self):
        pass
|
||||
|
||||
client_id = None
|
||||
control_url = None
|
||||
api_key = None
|
||||
|
||||
def _update_job(core_config, plugin_config):
|
||||
payload = {"client_id":client_id, "core_config":core_config,"plugin_config":plugin_config}
|
||||
#json_string = json.dumps(payload)
|
||||
try:
|
||||
# Using the json arg rather than json.dumps ourselves automatically sets the Content-Type
|
||||
# header to application/json, which Flask expects to work correctly
|
||||
r = requests.post(control_url, json=payload, auth=(client_id, api_key))
|
||||
except requests.exceptions.ConnectionError:
|
||||
raise
|
||||
|
||||
def generate_new_id(root_dir):
|
||||
global client_id
|
||||
with open(os.path.join(root_dir, "shepherd.id"), 'w+') as f:
|
||||
new_id = uuid.uuid1()
|
||||
client_id = str(new_id)
|
||||
f.write(client_id)
|
||||
generate_new_zerotier_id()
|
||||
|
||||
def get_config(config_dir):
|
||||
return {}
|
||||
self.topic_bundle = statesman.TopicBundle()
|
||||
|
||||
# config-spec should be static, but isn't known yet when this is created
|
||||
self.topic_bundle.add_writer('status', statesman.StateWriter())
|
||||
self.topic_bundle.add_writer('config-spec', statesman.StateWriter())
|
||||
|
||||
def init_control(core_config, plugin_config):
|
||||
global client_id
|
||||
global control_url
|
||||
global api_key
|
||||
self.topic_bundle.set_update_required_callback(_update_required_callback)
|
||||
|
||||
# On init, need to be able to quickly return the cached shepherd control layer if necessary.
|
||||
def set_status(self, status_dict):
|
||||
self.topic_bundle['status'].set_state(status_dict)
|
||||
|
||||
# Create the /update endpoint structure
|
||||
def set_confspec(self, config_spec):
|
||||
self.topic_bundle['config-spec'].set_state(config_spec)
|
||||
|
||||
root_dir = os.path.expanduser(core_config["root_dir"])
|
||||
editconf_dir = os.path.dirname(os.path.expanduser(core_config["conf_edit_path"]))
|
||||
|
||||
#Some weirdness with URL parsing means that by default most urls (like www.google.com)
|
||||
def clean_https_url(dirty_url):
    """
    Take a url with or without the leading "https://" scheme, and convert it to one that
    does. Change HTTP to HTTPS if present.
    """
    # Some weirdness with URL parsing means that by default most urls (like www.google.com)
    # get treated as relative
    # https://stackoverflow.com/questions/53816559/python-3-netloc-value-in-urllib-parse-is-empty-if-url-doesnt-have
    if "//" not in dirty_url:
        dirty_url = "//" + dirty_url
    return urlunparse(urlparse(dirty_url)._replace(scheme="https"))
|
||||
|
||||
|
||||
def load_device_identity(root_dir):
    """
    Attempt to load the device identity from the shepherd.identity file. Will throw exceptions if
    this fails. Returns a tuple of (device_secret, device_id).

    Raises:
        FileNotFoundError: if shepherd.identity does not exist under root_dir.
        ValueError: if the stored secret is not valid hex for exactly 16 bytes.
    """
    identity_filepath = Path(root_dir, 'shepherd.identity')
    if not identity_filepath.exists():
        log.warning(F"{identity_filepath} file does not exist")
        # Include the path so the exception is actionable without reading the log.
        raise FileNotFoundError(str(identity_filepath))

    with identity_filepath.open() as identity_file:
        identity_dict = toml.load(identity_file)

    dev_secret = identity_dict["device_secret"]

    dev_secret_bytes = bytes.fromhex(dev_secret)

    if len(dev_secret_bytes) != 16:
        log.error(F"Device secret loaded from file {identity_filepath} does not contain the "
                  "required 16 bytes")
        raise ValueError("device secret must contain exactly 16 bytes")

    # The public device ID is the first 8 hex chars of a keyed-size blake2b hash
    # of the secret, so the secret never leaves the device.
    secret_hash = blake2b(dev_secret_bytes, digest_size=16).hexdigest()
    dev_id = secret_hash[:8]
    log.info(F"Loaded device identity. ID: {dev_id}")
    return (dev_secret, dev_id)
|
||||
|
||||
|
||||
def generate_device_identity(root_dir):
    """
    Generate a new device identity and save it to the shepherd.identity file.
    Returns a tuple of (device_secret, device_id).
    """
    # 16 random bytes, stored as a 32-char hex string.
    dev_secret = secrets.token_hex(16)

    # Persist the secret before deriving the ID so the file always exists for
    # a subsequent load_device_identity().
    identity_filepath = Path(root_dir, 'shepherd.identity')
    with identity_filepath.open('w+') as identity_file:
        toml.dump({'device_secret': dev_secret}, identity_file)

    # Derive the short public ID from a hash of the secret.
    secret_hash = blake2b(bytes.fromhex(dev_secret), digest_size=16).hexdigest()
    dev_id = secret_hash[:8]

    log.info(F"Generated new device identity. ID: {dev_id}")
    return (dev_secret, dev_id)
|
||||
|
||||
|
||||
# Set by the update thread once its startup (identity load/generation) is done;
# init_control() blocks on this so the startup log output stays ordered.
_update_thread_init_done = threading.Event()

# Signals the update loop to exit; also waited on to make the inter-update delay
# interruptible (see _stop_event.wait(delay) in the loop).
_stop_event = threading.Event()
|
||||
|
||||
|
||||
def stop():
    """Request the Control update thread to exit, waking it if it is waiting."""
    # Order matters: set the flag first, then notify, so the loop observes the
    # stop request when it wakes.
    _stop_event.set()
    _update_required_callback()
    log.info("Control thread stop requested.")
|
||||
|
||||
|
||||
def init_control(config, root_dir, core_update_state, plugin_update_states):
    """
    Start the Control update thread and initialise the Shepherd Control systems.

    Returns the started Thread so callers can join() it after stop().
    """
    _stop_event.clear()
    _update_thread_init_done.clear()
    control_thread = threading.Thread(target=_control_update_loop, args=(
        config, root_dir, core_update_state, plugin_update_states))
    control_thread.start()

    # Wait for init so our log makes sense
    _update_thread_init_done.wait()

    return control_thread
|
||||
|
||||
|
||||
payload = {"client_id":client_id, "logs":logs, "measurements":measurements}
|
||||
def _control_update_loop(config, root_dir, core_update_state, plugin_update_states):
    """
    Body of the Control update thread: load (or generate) the device identity,
    then repeatedly push any dirty topic bundles to the Control server until
    stop() is requested. Runs until _stop_event is set.
    """
    control_api_url = urljoin(clean_https_url(config["server"]), "/agent")
    log.info(F"Control server API endpoint is {control_api_url}")
    intro_key = config["intro_key"]
    log.info(F"Using intro key: {intro_key}")

    try:
        device_secret, device_id = load_device_identity(root_dir)
    except Exception:
        # Any failure (missing file, bad hex, short secret) falls back to
        # generating a fresh identity.
        log.warning("Could not load device identity from shepherd.identity file")
        device_secret, device_id = generate_device_identity(root_dir)

    # Identity is ready; unblock init_control().
    _update_thread_init_done.set()

    # At most 10 updates per 10 minutes, with a burst of 3 allowed delay-free.
    update_rate_limiter = SmoothTokenBucketLimit(10, 10*60, 3, time.monotonic())

    session = requests.Session()
    while True:
        # Spin here until something needs updating
        with _control_update_required:

            new_endpoint_updates = {}  # a dict of url:topic_bundle pairs
            while True:
                if core_update_state.topic_bundle.is_update_required():
                    new_endpoint_updates['/update'] = core_update_state.topic_bundle

                for plugin_name, state in plugin_update_states.items():
                    if state.topic_bundle.is_update_required():
                        new_endpoint_updates[f"/pluginupdate/{plugin_name}"] = state.topic_bundle

                if (len(new_endpoint_updates) > 0) or _stop_event.is_set():
                    break

                _control_update_required.wait()

        for endpoint, topic_bundle in new_endpoint_updates.items():
            try:
                r = session.post(control_api_url+endpoint,
                                 json=topic_bundle.get_payload(),
                                 auth=(device_secret, intro_key))

                if r.status_code == requests.codes['conflict']:
                    # Server replies with this when trying to add our device ID and failing
                    # due to it already existing (device secret hash is a mismatch). We need to
                    # regenerate our ID
                    log.info(F"Control server has indicated that device ID {device_id} already"
                             " exists. Generating new one...")
                    device_secret, device_id = generate_device_identity(root_dir)
                elif r.status_code == requests.codes['ok']:
                    topic_bundle.process_message(r.json())
            except requests.exceptions.RequestException:
                log.exception("Failed to make Shepherd Control request")

        if _stop_event.is_set():
            # Breaking here is a clean way of killing any delay and allowing a final update before
            # the thread ends.
            log.warning("Control thread stopping...")
            break

        # Rate-limit: wait out the limiter's delay, but wake early on stop().
        delay = update_rate_limiter.new_event(time.monotonic())
        _stop_event.wait(delay)

    _update_thread_init_done.clear()
|
||||
|
||||
|
||||
def get_cached_config(config_dir):
    # Stub: should return the last config cached from Control under config_dir.
    # Currently always returns an empty dict — TODO confirm intended caching behavior.
    return {}
|
||||
|
||||
|
||||
def clear_cached_config(config_dir):
    # Stub: should remove any config cached under config_dir. Currently a no-op.
    pass
|
||||
|
||||
|
||||
class SmoothTokenBucketLimit():
    """
    Event rate limiter implementing a modified Token Bucket algorithm. Delay returned ramps
    up as the bucket empties.
    """

    def __init__(self, allowed_events, period, allowed_burst, initial_time):
        # allowed_events: bucket capacity — events allowed per `period` seconds.
        # period: replenishment window in seconds.
        # allowed_burst: number of events allowed back-to-back with zero delay.
        # initial_time: starting timestamp, same clock as passed to new_event().
        self.allowed_events = allowed_events
        self.period = period
        self.allowed_burst = allowed_burst
        self.last_token_timestamp = initial_time
        self.tokens = allowed_events
        self._is_saturated = False

    def new_event(self, time_now):
        """
        Register a new event for the rate limiter. Return a required delay to ignore future
        events for in seconds. Conceptually, the "token" we're grabbing here is for the _next_
        event.
        """
        if self.tokens < self.allowed_events:
            # Refill: add the whole tokens accrued since the last refill timestamp.
            time_since_last_token = time_now - self.last_token_timestamp
            tokens_added = int(time_since_last_token/(self.period/self.allowed_events))

            self.tokens = self.tokens + tokens_added
            self.last_token_timestamp = self.last_token_timestamp + \
                (self.period/self.allowed_events)*tokens_added
            if self.tokens >= self.allowed_events:
                self.last_token_timestamp = time_now

        if self.tokens > 0:
            self.tokens = self.tokens - 1
            # Add a delay that ramps from 0 when the bucket is allowed_burst from full to p/x when
            # it is empty
            ramp_token_count = self.allowed_events-self.allowed_burst
            if self.tokens > ramp_token_count:
                delay = 0
            else:
                delay = ((self.period/self.allowed_events)/ramp_token_count) * \
                    (ramp_token_count-self.tokens)

            self._is_saturated = False
        else:
            delay = (self.period/self.allowed_events) - (time_now-self.last_token_timestamp)
            self.last_token_timestamp = time_now+delay
            # This delay is set to account for the next token that would otherwise be added,
            # without relying on the returned delay _actually_ occurring exactly.
            self._is_saturated = True

        return delay

    def is_saturated(self):
        """
        Returns true if the rate limiter delay is at its maximum value
        """
        return self._is_saturated
|
||||
|
||||
@ -0,0 +1,135 @@
|
||||
# pylint: disable=redefined-outer-name
|
||||
import secrets
|
||||
from base64 import b64encode
|
||||
import json
|
||||
import logging
|
||||
import pytest
|
||||
import responses
|
||||
import statesman
|
||||
|
||||
from shepherd.agent import control
|
||||
|
||||
|
||||
def test_device_id(monkeypatch, tmpdir):
    """Loading fails on an empty dir; generation then round-trips through the file."""
    with pytest.raises(FileNotFoundError):
        control.load_device_identity(tmpdir)

    # Pin secrets.token_hex so the generated secret (and derived ID) is deterministic.
    def fixed_token_hex(_):
        return '0123456789abcdef0123456789abcdef'
    monkeypatch.setattr(secrets, "token_hex", fixed_token_hex)

    dev_secret, dev_id = control.generate_device_identity(tmpdir)
    assert dev_secret == '0123456789abcdef0123456789abcdef'
    assert dev_id == '3dead5e4'

    # Loading back from the written shepherd.identity file must reproduce both values.
    dev_secret, dev_id = control.load_device_identity(tmpdir)
    assert dev_secret == '0123456789abcdef0123456789abcdef'
    assert dev_id == '3dead5e4'
|
||||
|
||||
|
||||
@pytest.fixture
def control_config():
    """Minimal valid Control config dict used across these tests."""
    return {'server': 'api.shepherd.test', 'intro_key': 'abcdefabcdefabcdef'}
|
||||
|
||||
|
||||
def test_config(control_config):
    """The fixture config must validate against the declared confspec."""
    control.control_confspec().validate(control_config)
|
||||
|
||||
|
||||
def test_url():
    """clean_https_url normalises bare, path-bearing, and http:// URLs to https://."""
    assert control.clean_https_url('api.shepherd.test') == 'https://api.shepherd.test'
    assert control.clean_https_url('api.shepherd.test/foo') == 'https://api.shepherd.test/foo'
    assert control.clean_https_url('http://api.shepherd.test') == 'https://api.shepherd.test'
|
||||
|
||||
|
||||
@responses.activate
def test_control_thread(control_config, tmpdir, caplog):
    """Start and stop the real control thread and check it exits cleanly."""
    # Testing threads is a pain, as exceptions (including assertions) thrown in the thread don't
    # cause the test to fail. We can cheat a little here, as the 'responses' mock framework will
    # throw a requests.exceptions.ConnectionError if the request isn't recognised, and we're
    # already logging those in Control.

    responses.add(responses.POST, 'https://api.shepherd.test/agent/update', json={})
    responses.add(responses.POST, 'https://api.shepherd.test/agent/pluginupdate/plugin_A', json={})
    responses.add(responses.POST, 'https://api.shepherd.test/agent/pluginupdate/plugin_B', json={})

    core_update_state = control.CoreUpdateState(
        {'the_local_config': 'val'}, {'the_applied_config': 'val'})
    plugin_update_states = {'plugin_A': control.PluginUpdateState(),
                            'plugin_B': control.PluginUpdateState()}

    control_thread = control.init_control(
        control_config, tmpdir, core_update_state, plugin_update_states)
    control.stop()
    control_thread.join()

    # Check there were no connection exceptions
    for record in caplog.records:
        assert record.levelno <= logging.WARNING

    # There is a log line present if the thread stopped properly
    assert ("shepherd.agent.control", logging.WARNING,
            "Control thread stopping...") in caplog.record_tuples
|
||||
|
||||
|
||||
@responses.activate
def test_control(control_config, tmpdir, caplog, monkeypatch):
    """Run the update loop synchronously once and verify the core update round-trip."""
    # Here we skip control_init and just run the update loop directly, to keep things in the same
    # thread

    # Pin the generated device secret so the expected Basic auth header is known.
    def fixed_token_hex(_):
        return '0123456789abcdef0123456789abcdef'
    monkeypatch.setattr(secrets, "token_hex", fixed_token_hex)

    # Server-side mirror of the agent's core topics, to decode what the agent posts.
    core_topic_bundle = statesman.TopicBundle()

    core_topic_bundle.add_reader('status', statesman.StateReader())
    core_topic_bundle.add_reader('config-spec', statesman.StateReader())
    core_topic_bundle.add_reader('device-config', statesman.StateReader())
    core_topic_bundle.add_reader('applied-config', statesman.StateReader())

    core_callback_count = 0

    def core_update_callback(request):
        # Mock /agent/update endpoint: validates payload and auth, echoes a reply.
        nonlocal core_callback_count
        core_callback_count += 1
        payload = json.loads(request.body)
        assert 'applied-config' in payload
        assert 'device-config' in payload

        core_topic_bundle.process_message(payload)
        resp_body = core_topic_bundle.get_payload()

        # Basic auth is device_secret:intro_key (see _control_update_loop's session.post).
        basic_auth = b64encode(
            b"0123456789abcdef0123456789abcdef:abcdefabcdefabcdef").decode("ascii")
        assert request.headers['authorization'] == F"Basic {basic_auth}"

        return (200, {}, json.dumps(resp_body))

    responses.add_callback(
        responses.POST, 'https://api.shepherd.test/agent/update',
        callback=core_update_callback,
        content_type='application/json')

    responses.add(responses.POST, 'https://api.shepherd.test/agent/pluginupdate/plugin_A', json={})
    responses.add(responses.POST, 'https://api.shepherd.test/agent/pluginupdate/plugin_B', json={})

    core_update_state = control.CoreUpdateState(
        {'the_local_config': 'val'}, {'the_applied_config': 'val'})
    plugin_update_states = {'plugin_A': control.PluginUpdateState(),
                            'plugin_B': control.PluginUpdateState()}
    plugin_update_states['plugin_A'].set_status({"status1": '1'})

    # control._stop_event.clear()
    control._stop_event.set()
    # With the stop event set, the loop should run through and update everything once before
    # breaking
    control._control_update_loop(control_config, tmpdir, core_update_state, plugin_update_states)

    assert core_callback_count == 1

    assert not core_update_state.topic_bundle.is_update_required()

    # Check there were no connection exceptions
    for record in caplog.records:
        assert record.levelno <= logging.WARNING
|
||||
Loading…
Reference in new issue