Better context managers and tests

master
Tom Wilson 6 years ago
parent 8ef0f5626d
commit a82e447913

@ -1,10 +1,9 @@
import time
import contextlib
import threading
from collections import namedtuple
class HoldLock():
class HoldLock(contextlib.AbstractContextManager):
"""
A sort-of thread lock, intended to allow one thread to wait until all others are finished
using a multi-user resource.
@ -26,24 +25,42 @@ class HoldLock():
`hold()` as the identifier, which then allows the main thread to print a list of things it's
waiting for by iterating `waiting_for()`.
"""
Holder = namedtuple("Holder", ['identifier', 'expiry'])
class HoldContext(contextlib.AbstractContextManager):
def __init__(self, hold_wait, holder):
self.hold_wait = hold_wait
self.holder = holder
class AnonHolder():
pass
class Holder(contextlib.AbstractContextManager):
"""
An object representing something that has a hold on a HoldLock. Can be used as a context
manager. Only intended to be used once.
"""
def __init__(self, hold_lock, identifier, expiry):
self.hold_lock = hold_lock
self.identifier = identifier
self.expiry = expiry
def __enter__(self):
return self.hold_wait
return self
def __exit__(self, exc_type, exc_value, traceback):
self.hold_wait.release(self.holder)
self.hold_lock._release(self)
def release(self):
self.hold_lock._release(self)
def expired(self):
if self.expiry is not None:
if self.expiry <= self.hold_lock.time_func():
return True
return False
def __enter__(self):
self.hold()
return self
def __bool__(self):
if self.holder.expiry is not None:
if self.holder.expiry < self.hold_wait.time_func():
return False
return True
def __exit__(self, exc_type, exc_value, traceback):
self.release()
def __init__(self, time_func=time.monotonic):
"""
@ -55,7 +72,7 @@ class HoldLock():
self.time_func = time_func
self._closed = False
def hold(self, identifier=None, timeout=None):
def hold(self, identifier=AnonHolder, timeout=None):
"""
Acquire a hold on this HoldLock, blocking any `wait()` call until all holds are released.
Multiple threads may acquire a hold simultaneously, and an identifier may be used more than
@ -78,36 +95,37 @@ class HoldLock():
with self._cv:
if self._closed:
raise Exception("Cannot get new hold on closed HoldWait instance")
if timeout is not None:
timeout = self.time_func()+timeout
new_holder = self.Holder(identifier, timeout) # technically timeout is expiry here
self._holders.append(new_holder)
new_holder = self.Holder(self, identifier, self.time_func() +
timeout if timeout else None)
self._hold(new_holder)
return new_holder
def _hold(self, holder):
with self._cv:
self._holders.append(holder)
# Sort to make sure earliest expiry is first, with None at the end
self._holders.sort(key=lambda holder: (holder.expiry is None, holder.expiry))
self._cv.notify_all()
# cheat a bit by passing the holder as the identifier - this means the context manager
# will always release the relevant one
return self.HoldContext(self, new_holder)
def release(self, identifier=None):
def release(self, identifier=AnonHolder):
"""
Release a hold on this HoldLock. If there are mutiple holders with the supplied identifier,
the one with the earliest timeout will be released.
"""
with self._cv:
if isinstance(identifier, self.Holder):
matching_holder = identifier
# _holders is already sorted for us
for holder in self._holders:
if holder.identifier == identifier:
matching_holder = holder
break
else:
# _holders is already sorted for us
for holder in self._holders:
if holder.identifier == identifier:
matching_holder = holder
break
else:
raise Exception(F"Release identifier '{identifier}' is not currently held")
raise Exception(F"Release identifier '{identifier}' is not currently held")
self._release(matching_holder)
self._holders.remove(matching_holder)
def _release(self, holder):
with self._cv:
self._holders.remove(holder)
self._cv.notify_all()
def close(self):

@ -1,4 +1,4 @@
from threading import Timer
from threading import Timer, Thread
import time
from hold_lock import HoldLock
@ -8,9 +8,9 @@ def test_basic_hold_release():
lock = HoldLock()
lock.wait()
lock.hold()
assert lock.holders == (None,)
assert lock.holders == (HoldLock.AnonHolder,)
lock.hold("hold1")
assert lock.holders == (None, "hold1")
assert lock.holders == (HoldLock.AnonHolder, "hold1")
assert not lock.wait(0.1)
@ -18,6 +18,10 @@ def test_basic_hold_release():
lock.release()
assert lock.wait()
holder = lock.hold("hold1")
holder.release()
assert lock.wait()
def test_thread_release():
lock = HoldLock()
@ -26,6 +30,28 @@ def test_thread_release():
assert lock.wait(1)
def test_context():
lock = HoldLock()
def with_func(hold_lock):
with hold_lock:
time.sleep(0.2)
Thread(target=with_func, args=[lock]).start()
assert not lock.wait(0.1)
assert lock.wait(1)
def test_hold_context():
lock = HoldLock()
def with_func(hold_lock):
with hold_lock.hold("hold1"):
time.sleep(0.2)
Thread(target=with_func, args=[lock]).start()
assert lock.holders == ("hold1",)
assert lock.wait()
def test_hold_timeout():
lock = HoldLock()
lock.hold(timeout=0.1)

Loading…
Cancel
Save