Add HoldLock

master
Tom Wilson 6 years ago
parent 9fc3788266
commit 8ef0f5626d

@ -0,0 +1,260 @@
import time
import contextlib
import threading
from collections import namedtuple
class HoldLock():
"""
A sort-of thread lock, intended to allow one thread to wait until all others are finished
using a multi-user resource.
Once created, threads may call `hold()` on the HoldLock to acquire a hold. If a thread then
calls `wait()` or iterates `waiting_for()`, those calls will block until all holds are
released with `release()`.
In this simple use case, the HoldLock almost behaves like a reverse semaphore - `hold()`
increases a counter by 1, `release()` reduces it by 1, and calling `wait()` blocks until the
counter comes back down to 0. The closest example of a similar thing I've found is Golang
WaitGroups, which work like this.
Additionally, the HoldLock allows an identifier to be passed to `hold()`. This same identifier
must be referred to with `release()`, but can be any object - rather than a simple counter,
the HoldLock maintains a list of these identifiers. These only really become useful when the
main waiting thread calls `holders()` or iterates `waiting_for()` - as then it gets access
to these identifiers. The common use case here is to use a string explaining the reason for the
`hold()` as the identifier, which then allows the main thread to print a list of things it's
waiting for by iterating `waiting_for()`.
"""
Holder = namedtuple("Holder", ['identifier', 'expiry'])
class HoldContext(contextlib.AbstractContextManager):
def __init__(self, hold_wait, holder):
self.hold_wait = hold_wait
self.holder = holder
def __enter__(self):
return self.hold_wait
def __exit__(self, exc_type, exc_value, traceback):
self.hold_wait.release(self.holder)
def __bool__(self):
if self.holder.expiry is not None:
if self.holder.expiry < self.hold_wait.time_func():
return False
return True
def __init__(self, time_func=time.monotonic):
"""
Create a HoldLock instance. By default, time.monotonic is used for all timeouts, but this
can be supplied as any function that returns a current absolute time in seconds as a float.
"""
self._holders = []
self._cv = threading.Condition()
self.time_func = time_func
self._closed = False
def hold(self, identifier=None, timeout=None):
"""
Acquire a hold on this HoldLock, blocking any `wait()` call until all holds are released.
Multiple threads may acquire a hold simultaneously, and an identifier may be used more than
once.
The default `None` identifier works like any other, but will result in calls to `holders`
or `waiting_for()` to return a tuple containing None values.
Can either be called directly or used as a context manager - `with holdlock.hold():`
The returned object is a context manager, but a bool comparison with it will return False
if the timeout has expired:
with holdlock.hold(timeout=5) as hold:
while True:
time.sleep(1)
if not hold:
print("Timeout has expired)
"""
with self._cv:
if self._closed:
raise Exception("Cannot get new hold on closed HoldWait instance")
if timeout is not None:
timeout = self.time_func()+timeout
new_holder = self.Holder(identifier, timeout) # technically timeout is expiry here
self._holders.append(new_holder)
# Sort to make sure earliest expiry is first, with None at the end
self._holders.sort(key=lambda holder: (holder.expiry is None, holder.expiry))
self._cv.notify_all()
# cheat a bit by passing the holder as the identifier - this means the context manager
# will always release the relevant one
return self.HoldContext(self, new_holder)
def release(self, identifier=None):
"""
Release a hold on this HoldLock. If there are mutiple holders with the supplied identifier,
the one with the earliest timeout will be released.
"""
with self._cv:
if isinstance(identifier, self.Holder):
matching_holder = identifier
else:
# _holders is already sorted for us
for holder in self._holders:
if holder.identifier == identifier:
matching_holder = holder
break
else:
raise Exception(F"Release identifier '{identifier}' is not currently held")
self._holders.remove(matching_holder)
self._cv.notify_all()
def close(self):
"""
Stop any threads from acquiring a new hold on this HoldLock (the will raise an exception)
"""
with self._cv:
self._closed = True
def reopen(self):
"""
Start allowing threads to get a hold on this HoldLock again (after having called `close()`)
"""
with self._cv:
self._closed = False
@property
def holders(self):
"""
Return a tuple if current holder identities. The tuple itself is a copy, but the values in
it are the same objects that `hold()` calls have passed in as identifiers.
"""
with self._cv:
return(tuple(holder.identifier for holder in self._holders))
@property
def hold_count(self):
"""
Return the current number of holds on this HoldLock
"""
with self._cv:
return len(self._holders)
def wait(self, timeout=None):
"""
Wait for all threads currently holding this HoldLock to release it, returning True unless
the timeout is hit, where it will return False.
Note that unless `close()` is called first, _more threads may get a hold_ while waiting.
If `timeout` is specified, this must be a relative float value in seconds. If
`timeout` is None, `wait()` will block indefinitely for all holds to be released.
"""
expiry = None
if timeout is not None:
expiry = self.time_func()+timeout
with self._cv:
while len(self._holders) > 0:
cv_timeout = None
now = self.time_func()
# Pull out any holders that have expired
while (self._holders[0].expiry is not None):
if self._holders[0].expiry <= now:
self._holders.pop(0)
if len(self._holders) == 0:
return True
else:
cv_timeout = self._holders[0].expiry - now
break
if expiry is not None:
if expiry <= now:
return False
cv_timeout = min(cv_timeout, expiry - now) if cv_timeout else expiry - now
self._cv.wait(cv_timeout)
return True
def waiting_for(self, timeout=None, update_period=None):
"""
Behaves the same as `wait()`, but is a generator that will return sequences of remaining
holder identifiers while waiting for all holds to be released. By default, returns a new
sequence of remaining holders whenever it changes, but can also be supplied with
`update_period` to add more intermediate updates.
When all holds are released, the last returned sequence by the generator will be empty (no
longer waiting on any holds).
If `timeout` is not None and the timeout expires instead, the last sequence returned
will _not_ be empty (was still waiting on holds when the timeout expired).
"""
expiry = None
if timeout is not None:
expiry = self.time_func()+timeout
with self._cv:
# We effectively have 2 sections where holders can be released/timed out, and time can
# pass - the wait, and the yield, so things that check for changes
# in those need to be done after both.
while len(self._holders) > 0:
now = self.time_func()
# check main timeout
if expiry is not None:
if expiry <= now:
return
# expire any holders
while (self._holders[0].expiry is not None):
if self._holders[0].expiry <= now:
self._holders.pop(0)
if len(self._holders) == 0:
# Generate empty holder tuple and finish
self._cv.release()
yield tuple()
self._cv.acquire()
return
else:
break
# Yield holders
yielded_holders = self.holders
self._cv.release()
yield yielded_holders
self._cv.acquire()
# If holders has changed since before yield, continue (no need to wait for change)
if self.holders != yielded_holders:
continue
# Holders haven't changed, so we have at least 1
cv_timeout = update_period
now = self.time_func()
# Check main timeout again
if expiry is not None:
if expiry <= now:
return
cv_timeout = min(cv_timeout, expiry - now) if cv_timeout else expiry - now
# Check holder expiry again
if self._holders[0].expiry is not None:
if self._holders[0].expiry <= now:
# next holder has expired, continue and let original check deal with it
continue
else:
holder_timeout = self._holders[0].expiry - now
cv_timeout = min(
holder_timeout, cv_timeout) if cv_timeout else holder_timeout
self._cv.wait(cv_timeout)
# Generate empty holder tuple and finish
self._cv.release()
yield tuple()
self._cv.acquire()
return

@ -0,0 +1,103 @@
from threading import Timer
import time
from hold_lock import HoldLock
def test_basic_hold_release():
lock = HoldLock()
lock.wait()
lock.hold()
assert lock.holders == (None,)
lock.hold("hold1")
assert lock.holders == (None, "hold1")
assert not lock.wait(0.1)
lock.release("hold1")
lock.release()
assert lock.wait()
def test_thread_release():
lock = HoldLock()
lock.hold()
Timer(0.1, lambda: lock.release()).start()
assert lock.wait(1)
def test_hold_timeout():
lock = HoldLock()
lock.hold(timeout=0.1)
assert lock.wait()
def test_waiting_for():
lock = HoldLock()
# Wait with no holders
res = []
for holders in lock.waiting_for():
res.append(holders)
assert res == [()]
# Release holders while waiting
lock.hold("hold1")
lock.hold("hold2")
Timer(0.1, lambda: lock.release("hold1")).start()
Timer(0.2, lambda: lock.release("hold2")).start()
res = []
for holders in lock.waiting_for():
res.append(holders)
assert res == [("hold1", "hold2"), ("hold2",), ()]
# Release holders during the yield
lock.hold("hold1")
lock.hold("hold2")
Timer(0.1, lambda: lock.release("hold2")).start()
res = []
for loopnum, holders in enumerate(lock.waiting_for()):
res.append(holders)
if loopnum == 0:
lock.release("hold1")
assert res == [("hold1", "hold2"), ("hold2",), ()]
def test_waiting_for_timeout():
lock = HoldLock()
lock.hold("hold1")
res = []
for holders in lock.waiting_for(timeout=0.1):
res.append(holders)
assert res == [("hold1",)]
def test_waiting_for_timeout_during_yield():
lock = HoldLock()
lock.hold("hold1")
res = []
for holders in lock.waiting_for(timeout=0.1):
res.append(holders)
time.sleep(0.2)
assert res == [("hold1",)]
def test_waiting_for_hold_timeouts():
lock = HoldLock()
lock.hold("hold1", timeout=0.1)
lock.hold("hold2", timeout=0.2)
res = []
for holders in lock.waiting_for():
res.append(holders)
assert res == [("hold1", "hold2"), ("hold2",), ()]
def test_waiting_for_hold_timeout_during_yield():
lock = HoldLock()
lock.hold("hold1", timeout=0.1)
lock.hold("hold2", timeout=0.3)
res = []
for holders in lock.waiting_for():
res.append(holders)
time.sleep(0.2)
assert res == [("hold1", "hold2"), ("hold2",), ()]
Loading…
Cancel
Save