parent
9fc3788266
commit
8ef0f5626d
@ -0,0 +1,260 @@
|
||||
import time
|
||||
import contextlib
|
||||
import threading
|
||||
from collections import namedtuple
|
||||
|
||||
|
||||
class HoldLock():
|
||||
"""
|
||||
A sort-of thread lock, intended to allow one thread to wait until all others are finished
|
||||
using a multi-user resource.
|
||||
|
||||
Once created, threads may call `hold()` on the HoldLock to acquire a hold. If a thread then
|
||||
calls `wait()` or iterates `waiting_for()`, those calls will block until all holds are
|
||||
released with `release()`.
|
||||
|
||||
In this simple use case, the HoldLock almost behaves like a reverse semaphore - `hold()`
|
||||
increases a counter by 1, `release()` reduces it by 1, and calling `wait()` blocks until the
|
||||
counter comes back down to 0. The closest example of a similar thing I've found is Golang
|
||||
WaitGroups, which work like this.
|
||||
|
||||
Additionally, the HoldLock allows an identifier to be passed to `hold()`. This same identifier
|
||||
must be referred to with `release()`, but can be any object - rather than a simple counter,
|
||||
the HoldLock maintains a list of these identifiers. These only really become useful when the
|
||||
main waiting thread calls `holders()` or iterates `waiting_for()` - as then it gets access
|
||||
to these identifiers. The common use case here is to use a string explaining the reason for the
|
||||
`hold()` as the identifier, which then allows the main thread to print a list of things it's
|
||||
waiting for by iterating `waiting_for()`.
|
||||
"""
|
||||
Holder = namedtuple("Holder", ['identifier', 'expiry'])
|
||||
|
||||
class HoldContext(contextlib.AbstractContextManager):
|
||||
def __init__(self, hold_wait, holder):
|
||||
self.hold_wait = hold_wait
|
||||
self.holder = holder
|
||||
|
||||
def __enter__(self):
|
||||
return self.hold_wait
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
self.hold_wait.release(self.holder)
|
||||
|
||||
def __bool__(self):
|
||||
if self.holder.expiry is not None:
|
||||
if self.holder.expiry < self.hold_wait.time_func():
|
||||
return False
|
||||
return True
|
||||
|
||||
def __init__(self, time_func=time.monotonic):
|
||||
"""
|
||||
Create a HoldLock instance. By default, time.monotonic is used for all timeouts, but this
|
||||
can be supplied as any function that returns a current absolute time in seconds as a float.
|
||||
"""
|
||||
self._holders = []
|
||||
self._cv = threading.Condition()
|
||||
self.time_func = time_func
|
||||
self._closed = False
|
||||
|
||||
def hold(self, identifier=None, timeout=None):
|
||||
"""
|
||||
Acquire a hold on this HoldLock, blocking any `wait()` call until all holds are released.
|
||||
Multiple threads may acquire a hold simultaneously, and an identifier may be used more than
|
||||
once.
|
||||
|
||||
The default `None` identifier works like any other, but will result in calls to `holders`
|
||||
or `waiting_for()` to return a tuple containing None values.
|
||||
|
||||
Can either be called directly or used as a context manager - `with holdlock.hold():`
|
||||
|
||||
The returned object is a context manager, but a bool comparison with it will return False
|
||||
if the timeout has expired:
|
||||
|
||||
with holdlock.hold(timeout=5) as hold:
|
||||
while True:
|
||||
time.sleep(1)
|
||||
if not hold:
|
||||
print("Timeout has expired)
|
||||
"""
|
||||
with self._cv:
|
||||
if self._closed:
|
||||
raise Exception("Cannot get new hold on closed HoldWait instance")
|
||||
if timeout is not None:
|
||||
timeout = self.time_func()+timeout
|
||||
new_holder = self.Holder(identifier, timeout) # technically timeout is expiry here
|
||||
self._holders.append(new_holder)
|
||||
# Sort to make sure earliest expiry is first, with None at the end
|
||||
self._holders.sort(key=lambda holder: (holder.expiry is None, holder.expiry))
|
||||
self._cv.notify_all()
|
||||
# cheat a bit by passing the holder as the identifier - this means the context manager
|
||||
# will always release the relevant one
|
||||
return self.HoldContext(self, new_holder)
|
||||
|
||||
def release(self, identifier=None):
|
||||
"""
|
||||
Release a hold on this HoldLock. If there are mutiple holders with the supplied identifier,
|
||||
the one with the earliest timeout will be released.
|
||||
"""
|
||||
with self._cv:
|
||||
if isinstance(identifier, self.Holder):
|
||||
matching_holder = identifier
|
||||
|
||||
else:
|
||||
# _holders is already sorted for us
|
||||
for holder in self._holders:
|
||||
if holder.identifier == identifier:
|
||||
matching_holder = holder
|
||||
break
|
||||
else:
|
||||
raise Exception(F"Release identifier '{identifier}' is not currently held")
|
||||
|
||||
self._holders.remove(matching_holder)
|
||||
self._cv.notify_all()
|
||||
|
||||
def close(self):
|
||||
"""
|
||||
Stop any threads from acquiring a new hold on this HoldLock (the will raise an exception)
|
||||
"""
|
||||
with self._cv:
|
||||
self._closed = True
|
||||
|
||||
def reopen(self):
|
||||
"""
|
||||
Start allowing threads to get a hold on this HoldLock again (after having called `close()`)
|
||||
"""
|
||||
with self._cv:
|
||||
self._closed = False
|
||||
|
||||
@property
|
||||
def holders(self):
|
||||
"""
|
||||
Return a tuple if current holder identities. The tuple itself is a copy, but the values in
|
||||
it are the same objects that `hold()` calls have passed in as identifiers.
|
||||
"""
|
||||
with self._cv:
|
||||
return(tuple(holder.identifier for holder in self._holders))
|
||||
|
||||
@property
|
||||
def hold_count(self):
|
||||
"""
|
||||
Return the current number of holds on this HoldLock
|
||||
"""
|
||||
with self._cv:
|
||||
return len(self._holders)
|
||||
|
||||
def wait(self, timeout=None):
|
||||
"""
|
||||
Wait for all threads currently holding this HoldLock to release it, returning True unless
|
||||
the timeout is hit, where it will return False.
|
||||
|
||||
Note that unless `close()` is called first, _more threads may get a hold_ while waiting.
|
||||
|
||||
If `timeout` is specified, this must be a relative float value in seconds. If
|
||||
`timeout` is None, `wait()` will block indefinitely for all holds to be released.
|
||||
"""
|
||||
expiry = None
|
||||
if timeout is not None:
|
||||
expiry = self.time_func()+timeout
|
||||
|
||||
with self._cv:
|
||||
while len(self._holders) > 0:
|
||||
cv_timeout = None
|
||||
now = self.time_func()
|
||||
|
||||
# Pull out any holders that have expired
|
||||
while (self._holders[0].expiry is not None):
|
||||
if self._holders[0].expiry <= now:
|
||||
self._holders.pop(0)
|
||||
if len(self._holders) == 0:
|
||||
return True
|
||||
else:
|
||||
cv_timeout = self._holders[0].expiry - now
|
||||
break
|
||||
|
||||
if expiry is not None:
|
||||
if expiry <= now:
|
||||
return False
|
||||
cv_timeout = min(cv_timeout, expiry - now) if cv_timeout else expiry - now
|
||||
|
||||
self._cv.wait(cv_timeout)
|
||||
return True
|
||||
|
||||
def waiting_for(self, timeout=None, update_period=None):
|
||||
"""
|
||||
Behaves the same as `wait()`, but is a generator that will return sequences of remaining
|
||||
holder identifiers while waiting for all holds to be released. By default, returns a new
|
||||
sequence of remaining holders whenever it changes, but can also be supplied with
|
||||
`update_period` to add more intermediate updates.
|
||||
|
||||
When all holds are released, the last returned sequence by the generator will be empty (no
|
||||
longer waiting on any holds).
|
||||
If `timeout` is not None and the timeout expires instead, the last sequence returned
|
||||
will _not_ be empty (was still waiting on holds when the timeout expired).
|
||||
"""
|
||||
expiry = None
|
||||
if timeout is not None:
|
||||
expiry = self.time_func()+timeout
|
||||
|
||||
with self._cv:
|
||||
|
||||
# We effectively have 2 sections where holders can be released/timed out, and time can
|
||||
# pass - the wait, and the yield, so things that check for changes
|
||||
# in those need to be done after both.
|
||||
|
||||
while len(self._holders) > 0:
|
||||
now = self.time_func()
|
||||
|
||||
# check main timeout
|
||||
if expiry is not None:
|
||||
if expiry <= now:
|
||||
return
|
||||
|
||||
# expire any holders
|
||||
while (self._holders[0].expiry is not None):
|
||||
if self._holders[0].expiry <= now:
|
||||
self._holders.pop(0)
|
||||
if len(self._holders) == 0:
|
||||
# Generate empty holder tuple and finish
|
||||
self._cv.release()
|
||||
yield tuple()
|
||||
self._cv.acquire()
|
||||
return
|
||||
else:
|
||||
break
|
||||
|
||||
# Yield holders
|
||||
yielded_holders = self.holders
|
||||
self._cv.release()
|
||||
yield yielded_holders
|
||||
self._cv.acquire()
|
||||
|
||||
# If holders has changed since before yield, continue (no need to wait for change)
|
||||
if self.holders != yielded_holders:
|
||||
continue
|
||||
|
||||
# Holders haven't changed, so we have at least 1
|
||||
cv_timeout = update_period
|
||||
now = self.time_func()
|
||||
|
||||
# Check main timeout again
|
||||
if expiry is not None:
|
||||
if expiry <= now:
|
||||
return
|
||||
cv_timeout = min(cv_timeout, expiry - now) if cv_timeout else expiry - now
|
||||
|
||||
# Check holder expiry again
|
||||
if self._holders[0].expiry is not None:
|
||||
if self._holders[0].expiry <= now:
|
||||
# next holder has expired, continue and let original check deal with it
|
||||
continue
|
||||
else:
|
||||
holder_timeout = self._holders[0].expiry - now
|
||||
cv_timeout = min(
|
||||
holder_timeout, cv_timeout) if cv_timeout else holder_timeout
|
||||
|
||||
self._cv.wait(cv_timeout)
|
||||
|
||||
# Generate empty holder tuple and finish
|
||||
self._cv.release()
|
||||
yield tuple()
|
||||
self._cv.acquire()
|
||||
return
|
||||
@ -0,0 +1,103 @@
|
||||
from threading import Timer
|
||||
import time
|
||||
|
||||
from hold_lock import HoldLock
|
||||
|
||||
|
||||
def test_basic_hold_release():
|
||||
lock = HoldLock()
|
||||
lock.wait()
|
||||
lock.hold()
|
||||
assert lock.holders == (None,)
|
||||
lock.hold("hold1")
|
||||
assert lock.holders == (None, "hold1")
|
||||
|
||||
assert not lock.wait(0.1)
|
||||
|
||||
lock.release("hold1")
|
||||
lock.release()
|
||||
assert lock.wait()
|
||||
|
||||
|
||||
def test_thread_release():
|
||||
lock = HoldLock()
|
||||
lock.hold()
|
||||
Timer(0.1, lambda: lock.release()).start()
|
||||
assert lock.wait(1)
|
||||
|
||||
|
||||
def test_hold_timeout():
|
||||
lock = HoldLock()
|
||||
lock.hold(timeout=0.1)
|
||||
assert lock.wait()
|
||||
|
||||
|
||||
def test_waiting_for():
|
||||
lock = HoldLock()
|
||||
|
||||
# Wait with no holders
|
||||
res = []
|
||||
for holders in lock.waiting_for():
|
||||
res.append(holders)
|
||||
assert res == [()]
|
||||
|
||||
# Release holders while waiting
|
||||
lock.hold("hold1")
|
||||
lock.hold("hold2")
|
||||
Timer(0.1, lambda: lock.release("hold1")).start()
|
||||
Timer(0.2, lambda: lock.release("hold2")).start()
|
||||
res = []
|
||||
for holders in lock.waiting_for():
|
||||
res.append(holders)
|
||||
assert res == [("hold1", "hold2"), ("hold2",), ()]
|
||||
|
||||
# Release holders during the yield
|
||||
lock.hold("hold1")
|
||||
lock.hold("hold2")
|
||||
Timer(0.1, lambda: lock.release("hold2")).start()
|
||||
res = []
|
||||
for loopnum, holders in enumerate(lock.waiting_for()):
|
||||
res.append(holders)
|
||||
if loopnum == 0:
|
||||
lock.release("hold1")
|
||||
assert res == [("hold1", "hold2"), ("hold2",), ()]
|
||||
|
||||
|
||||
def test_waiting_for_timeout():
|
||||
lock = HoldLock()
|
||||
lock.hold("hold1")
|
||||
res = []
|
||||
for holders in lock.waiting_for(timeout=0.1):
|
||||
res.append(holders)
|
||||
assert res == [("hold1",)]
|
||||
|
||||
|
||||
def test_waiting_for_timeout_during_yield():
|
||||
lock = HoldLock()
|
||||
lock.hold("hold1")
|
||||
res = []
|
||||
for holders in lock.waiting_for(timeout=0.1):
|
||||
res.append(holders)
|
||||
time.sleep(0.2)
|
||||
assert res == [("hold1",)]
|
||||
|
||||
|
||||
def test_waiting_for_hold_timeouts():
|
||||
lock = HoldLock()
|
||||
lock.hold("hold1", timeout=0.1)
|
||||
lock.hold("hold2", timeout=0.2)
|
||||
res = []
|
||||
for holders in lock.waiting_for():
|
||||
res.append(holders)
|
||||
assert res == [("hold1", "hold2"), ("hold2",), ()]
|
||||
|
||||
|
||||
def test_waiting_for_hold_timeout_during_yield():
|
||||
lock = HoldLock()
|
||||
lock.hold("hold1", timeout=0.1)
|
||||
lock.hold("hold2", timeout=0.3)
|
||||
res = []
|
||||
for holders in lock.waiting_for():
|
||||
res.append(holders)
|
||||
time.sleep(0.2)
|
||||
assert res == [("hold1", "hold2"), ("hold2",), ()]
|
||||
Loading…
Reference in new issue