You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
shepherd-agent/tests/test_control.py

193 lines
7.1 KiB

# pylint: disable=redefined-outer-name
import secrets
from base64 import b64encode
import json
import logging
import time
import pytest
import responses
import statesman
from collections import namedtuple
from shepherd.agent import control
from shepherd.agent import plugin
def test_device_id(monkeypatch, tmpdir):
    """Check that a device identity can be generated and then loaded back.

    Loading from an empty directory must raise FileNotFoundError. Once an
    identity has been generated, both the generate and the load calls must
    report the same secret and device ID.
    """
    expected_secret = '0123456789abcdef0123456789abcdef'
    expected_id = '3dead5e4'

    # Nothing has been generated in tmpdir yet, so there is nothing to load.
    with pytest.raises(FileNotFoundError):
        control.load_device_identity(tmpdir)

    # Pin the output of secrets.token_hex so the resulting identity is predictable.
    monkeypatch.setattr(secrets, "token_hex", lambda _: expected_secret)

    # Generate first, then load: both must return the identical identity pair.
    for get_identity in (control.generate_device_identity, control.load_device_identity):
        secret_value, device_id = get_identity(tmpdir)
        assert secret_value == expected_secret
        assert device_id == expected_id
@pytest.fixture
def control_config():
    """Provide a control config dict holding a test server name and an intro key."""
    return dict(
        server='api.shepherd.test',
        intro_key='abcdefabcdefabcdef',
    )
def test_config(control_config):
    """Validate the fixture config against the control config specification."""
    confspec = control.control_confspec()
    confspec.validate(control_config)
def test_url():
    """Check clean_https_url against a table of (input, expected https URL) pairs."""
    cases = (
        # Bare host name gains the https scheme.
        ('api.shepherd.test', 'https://api.shepherd.test'),
        # A path after the host is kept.
        ('api.shepherd.test/foo', 'https://api.shepherd.test/foo'),
        # An explicit http scheme is replaced by https.
        ('http://api.shepherd.test', 'https://api.shepherd.test'),
    )
    for raw_url, expected in cases:
        assert control.clean_https_url(raw_url) == expected
@responses.activate
def test_control_thread(control_config, tmpdir, caplog):
    """Start the control thread, stop it straight away, and check the logs afterwards.

    Exceptions (including failed assertions) raised inside a thread do not fail
    the test by themselves. The 'responses' mock raises a
    requests.exceptions.ConnectionError for any unregistered request, and Control
    already logs those, so inspecting the captured log records is enough here.
    """
    mocked_urls = (
        'https://api.shepherd.test/agent/update',
        'https://api.shepherd.test/agent/pluginupdate/plugin_A',
        'https://api.shepherd.test/agent/pluginupdate/plugin_B',
    )
    for url in mocked_urls:
        responses.add(responses.POST, url, json={})

    reader = statesman.SequenceReader()
    writer = statesman.SequenceWriter()
    core_state = control.CoreUpdateState(reader, writer)
    core_state.set_static_state(
        {'the_local_config': 'val'},
        {'the_applied_config': 'val'},
        {},
    )

    plugin_states = {
        plugin_name: control.PluginUpdateState()
        for plugin_name in ('plugin_A', 'plugin_B')
    }

    worker = control.start_control(control_config, tmpdir, core_state, plugin_states)
    control.stop()
    worker.join()

    # Nothing above WARNING may have been logged, i.e. no connection exceptions.
    assert all(record.levelno <= logging.WARNING for record in caplog.records)

    # This log line is present if the thread stopped properly.
    stop_record = ("shepherd.agent.control", logging.WARNING, "Control thread stopping...")
    assert stop_record in caplog.record_tuples
@responses.activate
def test_control(control_config, tmpdir, caplog, monkeypatch):
    """Run the control update loop once, in this thread, against a mocked server.

    control_init is skipped and the update loop is called directly so that
    everything (including the assertions in the HTTP callback) stays in the
    test's own thread.
    """
    monkeypatch.setattr(secrets, "token_hex", lambda _: '0123456789abcdef0123456789abcdef')

    # Topic bundle playing the server-side counterpart of the core update exchange.
    server_bundle = statesman.TopicBundle()
    for state_topic in ('status', 'config-spec', 'device-config', 'applied-config'):
        server_bundle.add(state_topic, statesman.StateReader())
    server_bundle.add('control-commands', statesman.SequenceWriter())
    server_bundle.add('command-results', statesman.SequenceReader())

    # Credentials the request is expected to carry: "<fixed token>:<intro_key>".
    expected_credentials = b64encode(
        b"0123456789abcdef0123456789abcdef:abcdefabcdefabcdef").decode("ascii")

    # One entry is appended per call to the core update endpoint.
    update_calls = []

    def core_update_callback(request):
        update_calls.append(request)
        payload = json.loads(request.body)
        assert 'applied-config' in payload
        assert 'device-config' in payload
        server_bundle.process_message(payload)
        reply = server_bundle.get_payload()
        assert request.headers['authorization'] == F"Basic {expected_credentials}"
        return (200, {}, json.dumps(reply))

    responses.add_callback(
        responses.POST,
        'https://api.shepherd.test/agent/update',
        callback=core_update_callback,
        content_type='application/json',
    )
    for plugin_url in (
            'https://api.shepherd.test/agent/pluginupdate/plugin_A',
            'https://api.shepherd.test/agent/pluginupdate/plugin_B'):
        responses.add(responses.POST, plugin_url, json={})

    reader = statesman.SequenceReader()
    writer = statesman.SequenceWriter()
    core_state = control.CoreUpdateState(reader, writer)
    core_state.set_static_state(
        {'the_local_config': 'val'},
        {'the_applied_config': 'val'},
        {},
    )

    plugin_states = {
        plugin_name: control.PluginUpdateState()
        for plugin_name in ('plugin_A', 'plugin_B')
    }
    plugin_states['plugin_A'].set_status({"status1": '1'})

    # With the stop event already set, the loop should run through and update
    # everything once before breaking.
    control._stop_event.set()
    control._control_update_loop(control_config, tmpdir, core_state, plugin_states)

    assert len(update_calls) == 1
    assert not core_state.topic_bundle.is_update_required()

    # Nothing above WARNING may have been logged, i.e. no connection exceptions.
    assert all(record.levelno <= logging.WARNING for record in caplog.records)
def test_command_runner():
    """Exercise CommandRunner with two stub interface functions.

    Commands are first pushed through _process_command directly and then via
    on_new_command_message. For function_b the most recent message held by the
    result writer must be [command ID, return value].
    """
    func_a_was_called = False

    def func_a():
        nonlocal func_a_was_called
        func_a_was_called = True

    test_function_a = plugin.InterfaceFunction(func_a, 'function_a')

    func_b_was_called = False

    def func_b(arg1):
        nonlocal func_b_was_called
        func_b_was_called = True
        return arg1+1

    test_function_b = plugin.InterfaceFunction(func_b, 'function_b')

    func_tuple = namedtuple('test_functions', ('function_a', 'function_b')
                            )(test_function_a, test_function_b)
    if_functions = {'test_plugin': func_tuple}
    cmd_runner = control.CommandRunner(if_functions)

    # function_a takes no arguments and must only run once its command is processed.
    assert not func_a_was_called
    cmd_runner._process_command(10, plugin.InterfaceCall('test_plugin', 'function_a', None))
    assert func_a_was_called

    assert not func_b_was_called
    cmd_runner._process_command(12, plugin.InterfaceCall('test_plugin', 'function_b', {'arg1': 5}))
    assert func_b_was_called
    # Get most recent writer message
    wr_msg = list(cmd_runner.cmd_result_writer._messages.values())[-1]
    assert wr_msg == [12, 6]

    func_b_was_called = False
    cmd_runner.on_new_command_message(
        [15, plugin.InterfaceCall('test_plugin', 'function_b', {'arg1': 8})])
    # Poll until command 15 leaves current_commands (presumably it is handled off
    # this thread - confirm against CommandRunner). The deadline makes a command
    # that never completes fail the test rather than hang the whole run.
    deadline = time.monotonic() + 5
    while 15 in cmd_runner.current_commands:
        assert time.monotonic() < deadline, "Command 15 did not complete within 5 seconds"
        time.sleep(0.01)
    assert func_b_was_called
    wr_msg = list(cmd_runner.cmd_result_writer._messages.values())[-1]
    assert wr_msg == [15, 9]
# Control/Plugin integration tests
# Test command_runner with actual plugin