Better context managers and tests

master
Tom Wilson 6 years ago
parent 8ef0f5626d
commit a82e447913

@ -1,10 +1,9 @@
import time import time
import contextlib import contextlib
import threading import threading
from collections import namedtuple
class HoldLock(): class HoldLock(contextlib.AbstractContextManager):
""" """
A sort-of thread lock, intended to allow one thread to wait until all others are finished A sort-of thread lock, intended to allow one thread to wait until all others are finished
using a multi-user resource. using a multi-user resource.
@ -26,24 +25,42 @@ class HoldLock():
`hold()` as the identifier, which then allows the main thread to print a list of things it's `hold()` as the identifier, which then allows the main thread to print a list of things it's
waiting for by iterating `waiting_for()`. waiting for by iterating `waiting_for()`.
""" """
Holder = namedtuple("Holder", ['identifier', 'expiry'])
class HoldContext(contextlib.AbstractContextManager): class AnonHolder():
def __init__(self, hold_wait, holder): pass
self.hold_wait = hold_wait
self.holder = holder class Holder(contextlib.AbstractContextManager):
"""
An object representing something that has a hold on a HoldLock. Can be used as a context
manager. Only intended to be used once.
"""
def __init__(self, hold_lock, identifier, expiry):
self.hold_lock = hold_lock
self.identifier = identifier
self.expiry = expiry
def __enter__(self): def __enter__(self):
return self.hold_wait return self
def __exit__(self, exc_type, exc_value, traceback): def __exit__(self, exc_type, exc_value, traceback):
self.hold_wait.release(self.holder) self.hold_lock._release(self)
def release(self):
self.hold_lock._release(self)
def expired(self):
if self.expiry is not None:
if self.expiry <= self.hold_lock.time_func():
return True
return False
def __enter__(self):
self.hold()
return self
def __bool__(self): def __exit__(self, exc_type, exc_value, traceback):
if self.holder.expiry is not None: self.release()
if self.holder.expiry < self.hold_wait.time_func():
return False
return True
def __init__(self, time_func=time.monotonic): def __init__(self, time_func=time.monotonic):
""" """
@ -55,7 +72,7 @@ class HoldLock():
self.time_func = time_func self.time_func = time_func
self._closed = False self._closed = False
def hold(self, identifier=None, timeout=None): def hold(self, identifier=AnonHolder, timeout=None):
""" """
Acquire a hold on this HoldLock, blocking any `wait()` call until all holds are released. Acquire a hold on this HoldLock, blocking any `wait()` call until all holds are released.
Multiple threads may acquire a hold simultaneously, and an identifier may be used more than Multiple threads may acquire a hold simultaneously, and an identifier may be used more than
@ -78,36 +95,37 @@ class HoldLock():
with self._cv: with self._cv:
if self._closed: if self._closed:
raise Exception("Cannot get new hold on closed HoldWait instance") raise Exception("Cannot get new hold on closed HoldWait instance")
if timeout is not None: new_holder = self.Holder(self, identifier, self.time_func() +
timeout = self.time_func()+timeout timeout if timeout else None)
new_holder = self.Holder(identifier, timeout) # technically timeout is expiry here self._hold(new_holder)
self._holders.append(new_holder) return new_holder
def _hold(self, holder):
with self._cv:
self._holders.append(holder)
# Sort to make sure earliest expiry is first, with None at the end # Sort to make sure earliest expiry is first, with None at the end
self._holders.sort(key=lambda holder: (holder.expiry is None, holder.expiry)) self._holders.sort(key=lambda holder: (holder.expiry is None, holder.expiry))
self._cv.notify_all() self._cv.notify_all()
# cheat a bit by passing the holder as the identifier - this means the context manager
# will always release the relevant one
return self.HoldContext(self, new_holder)
def release(self, identifier=None): def release(self, identifier=AnonHolder):
""" """
Release a hold on this HoldLock. If there are mutiple holders with the supplied identifier, Release a hold on this HoldLock. If there are mutiple holders with the supplied identifier,
the one with the earliest timeout will be released. the one with the earliest timeout will be released.
""" """
with self._cv: with self._cv:
if isinstance(identifier, self.Holder): # _holders is already sorted for us
matching_holder = identifier for holder in self._holders:
if holder.identifier == identifier:
matching_holder = holder
break
else: else:
# _holders is already sorted for us raise Exception(F"Release identifier '{identifier}' is not currently held")
for holder in self._holders:
if holder.identifier == identifier: self._release(matching_holder)
matching_holder = holder
break
else:
raise Exception(F"Release identifier '{identifier}' is not currently held")
self._holders.remove(matching_holder) def _release(self, holder):
with self._cv:
self._holders.remove(holder)
self._cv.notify_all() self._cv.notify_all()
def close(self): def close(self):

@ -1,4 +1,4 @@
from threading import Timer from threading import Timer, Thread
import time import time
from hold_lock import HoldLock from hold_lock import HoldLock
@ -8,9 +8,9 @@ def test_basic_hold_release():
lock = HoldLock() lock = HoldLock()
lock.wait() lock.wait()
lock.hold() lock.hold()
assert lock.holders == (None,) assert lock.holders == (HoldLock.AnonHolder,)
lock.hold("hold1") lock.hold("hold1")
assert lock.holders == (None, "hold1") assert lock.holders == (HoldLock.AnonHolder, "hold1")
assert not lock.wait(0.1) assert not lock.wait(0.1)
@ -18,6 +18,10 @@ def test_basic_hold_release():
lock.release() lock.release()
assert lock.wait() assert lock.wait()
holder = lock.hold("hold1")
holder.release()
assert lock.wait()
def test_thread_release(): def test_thread_release():
lock = HoldLock() lock = HoldLock()
@ -26,6 +30,28 @@ def test_thread_release():
assert lock.wait(1) assert lock.wait(1)
def test_context():
lock = HoldLock()
def with_func(hold_lock):
with hold_lock:
time.sleep(0.2)
Thread(target=with_func, args=[lock]).start()
assert not lock.wait(0.1)
assert lock.wait(1)
def test_hold_context():
lock = HoldLock()
def with_func(hold_lock):
with hold_lock.hold("hold1"):
time.sleep(0.2)
Thread(target=with_func, args=[lock]).start()
assert lock.holders == ("hold1",)
assert lock.wait()
def test_hold_timeout(): def test_hold_timeout():
lock = HoldLock() lock = HoldLock()
lock.hold(timeout=0.1) lock.hold(timeout=0.1)

Loading…
Cancel
Save