diff --git a/hold_lock.py b/hold_lock.py new file mode 100644 index 0000000..18df86d --- /dev/null +++ b/hold_lock.py @@ -0,0 +1,260 @@ +import time +import contextlib +import threading +from collections import namedtuple + + +class HoldLock(): + """ + A sort-of thread lock, intended to allow one thread to wait until all others are finished + using a multi-user resource. + + Once created, threads may call `hold()` on the HoldLock to acquire a hold. If a thread then + calls `wait()` or iterates `waiting_for()`, those calls will block until all holds are + released with `release()`. + + In this simple use case, the HoldLock almost behaves like a reverse semaphore - `hold()` + increases a counter by 1, `release()` reduces it by 1, and calling `wait()` blocks until the + counter comes back down to 0. The closest example of a similar thing I've found is Golang + WaitGroups, which work like this. + + Additionally, the HoldLock allows an identifier to be passed to `hold()`. This same identifier + must be referred to with `release()`, but can be any object - rather than a simple counter, + the HoldLock maintains a list of these identifiers. These only really become useful when the + main waiting thread calls `holders()` or iterates `waiting_for()` - as then it gets access + to these identifiers. The common use case here is to use a string explaining the reason for the + `hold()` as the identifier, which then allows the main thread to print a list of things it's + waiting for by iterating `waiting_for()`. + """ + Holder = namedtuple("Holder", ['identifier', 'expiry']) + + class HoldContext(contextlib.AbstractContextManager): + def __init__(self, hold_wait, holder): + self.hold_wait = hold_wait + self.holder = holder + + def __enter__(self): + return self.hold_wait + + def __exit__(self, exc_type, exc_value, traceback): + self.hold_wait.release(self.holder) + + def __bool__(self): + if self.holder.expiry is not None: + if self.holder.expiry < self.hold_wait.time_func(): + return False + return True + + def __init__(self, time_func=time.monotonic): + """ + Create a HoldLock instance. By default, time.monotonic is used for all timeouts, but this + can be supplied as any function that returns a current absolute time in seconds as a float. + """ + self._holders = [] + self._cv = threading.Condition() + self.time_func = time_func + self._closed = False + + def hold(self, identifier=None, timeout=None): + """ + Acquire a hold on this HoldLock, blocking any `wait()` call until all holds are released. + Multiple threads may acquire a hold simultaneously, and an identifier may be used more than + once. + + The default `None` identifier works like any other, but will result in calls to `holders` + or `waiting_for()` to return a tuple containing None values. + + Can either be called directly or used as a context manager - `with holdlock.hold():` + + The returned object is a context manager, but a bool comparison with it will return False + if the timeout has expired: + + with holdlock.hold(timeout=5) as hold: + while True: + time.sleep(1) + if not hold: + print("Timeout has expired) + """ + with self._cv: + if self._closed: + raise Exception("Cannot get new hold on closed HoldWait instance") + if timeout is not None: + timeout = self.time_func()+timeout + new_holder = self.Holder(identifier, timeout) # technically timeout is expiry here + self._holders.append(new_holder) + # Sort to make sure earliest expiry is first, with None at the end + self._holders.sort(key=lambda holder: (holder.expiry is None, holder.expiry)) + self._cv.notify_all() + # cheat a bit by passing the holder as the identifier - this means the context manager + # will always release the relevant one + return self.HoldContext(self, new_holder) + + def release(self, identifier=None): + """ + Release a hold on this HoldLock. If there are mutiple holders with the supplied identifier, + the one with the earliest timeout will be released. + """ + with self._cv: + if isinstance(identifier, self.Holder): + matching_holder = identifier + + else: + # _holders is already sorted for us + for holder in self._holders: + if holder.identifier == identifier: + matching_holder = holder + break + else: + raise Exception(F"Release identifier '{identifier}' is not currently held") + + self._holders.remove(matching_holder) + self._cv.notify_all() + + def close(self): + """ + Stop any threads from acquiring a new hold on this HoldLock (the will raise an exception) + """ + with self._cv: + self._closed = True + + def reopen(self): + """ + Start allowing threads to get a hold on this HoldLock again (after having called `close()`) + """ + with self._cv: + self._closed = False + + @property + def holders(self): + """ + Return a tuple if current holder identities. The tuple itself is a copy, but the values in + it are the same objects that `hold()` calls have passed in as identifiers. + """ + with self._cv: + return(tuple(holder.identifier for holder in self._holders)) + + @property + def hold_count(self): + """ + Return the current number of holds on this HoldLock + """ + with self._cv: + return len(self._holders) + + def wait(self, timeout=None): + """ + Wait for all threads currently holding this HoldLock to release it, returning True unless + the timeout is hit, where it will return False. + + Note that unless `close()` is called first, _more threads may get a hold_ while waiting. + + If `timeout` is specified, this must be a relative float value in seconds. If + `timeout` is None, `wait()` will block indefinitely for all holds to be released. + """ + expiry = None + if timeout is not None: + expiry = self.time_func()+timeout + + with self._cv: + while len(self._holders) > 0: + cv_timeout = None + now = self.time_func() + + # Pull out any holders that have expired + while (self._holders[0].expiry is not None): + if self._holders[0].expiry <= now: + self._holders.pop(0) + if len(self._holders) == 0: + return True + else: + cv_timeout = self._holders[0].expiry - now + break + + if expiry is not None: + if expiry <= now: + return False + cv_timeout = min(cv_timeout, expiry - now) if cv_timeout else expiry - now + + self._cv.wait(cv_timeout) + return True + + def waiting_for(self, timeout=None, update_period=None): + """ + Behaves the same as `wait()`, but is a generator that will return sequences of remaining + holder identifiers while waiting for all holds to be released. By default, returns a new + sequence of remaining holders whenever it changes, but can also be supplied with + `update_period` to add more intermediate updates. + + When all holds are released, the last returned sequence by the generator will be empty (no + longer waiting on any holds). + If `timeout` is not None and the timeout expires instead, the last sequence returned + will _not_ be empty (was still waiting on holds when the timeout expired). + """ + expiry = None + if timeout is not None: + expiry = self.time_func()+timeout + + with self._cv: + + # We effectively have 2 sections where holders can be released/timed out, and time can + # pass - the wait, and the yield, so things that check for changes + # in those need to be done after both. + + while len(self._holders) > 0: + now = self.time_func() + + # check main timeout + if expiry is not None: + if expiry <= now: + return + + # expire any holders + while (self._holders[0].expiry is not None): + if self._holders[0].expiry <= now: + self._holders.pop(0) + if len(self._holders) == 0: + # Generate empty holder tuple and finish + self._cv.release() + yield tuple() + self._cv.acquire() + return + else: + break + + # Yield holders + yielded_holders = self.holders + self._cv.release() + yield yielded_holders + self._cv.acquire() + + # If holders has changed since before yield, continue (no need to wait for change) + if self.holders != yielded_holders: + continue + + # Holders haven't changed, so we have at least 1 + cv_timeout = update_period + now = self.time_func() + + # Check main timeout again + if expiry is not None: + if expiry <= now: + return + cv_timeout = min(cv_timeout, expiry - now) if cv_timeout else expiry - now + + # Check holder expiry again + if self._holders[0].expiry is not None: + if self._holders[0].expiry <= now: + # next holder has expired, continue and let original check deal with it + continue + else: + holder_timeout = self._holders[0].expiry - now + cv_timeout = min( + holder_timeout, cv_timeout) if cv_timeout else holder_timeout + + self._cv.wait(cv_timeout) + + # Generate empty holder tuple and finish + self._cv.release() + yield tuple() + self._cv.acquire() + return diff --git a/test_hold_lock.py b/test_hold_lock.py new file mode 100644 index 0000000..3f3450b --- /dev/null +++ b/test_hold_lock.py @@ -0,0 +1,103 @@ +from threading import Timer +import time + +from hold_lock import HoldLock + + +def test_basic_hold_release(): + lock = HoldLock() + lock.wait() + lock.hold() + assert lock.holders == (None,) + lock.hold("hold1") + assert lock.holders == (None, "hold1") + + assert not lock.wait(0.1) + + lock.release("hold1") + lock.release() + assert lock.wait() + + +def test_thread_release(): + lock = HoldLock() + lock.hold() + Timer(0.1, lambda: lock.release()).start() + assert lock.wait(1) + + +def test_hold_timeout(): + lock = HoldLock() + lock.hold(timeout=0.1) + assert lock.wait() + + +def test_waiting_for(): + lock = HoldLock() + + # Wait with no holders + res = [] + for holders in lock.waiting_for(): + res.append(holders) + assert res == [()] + + # Release holders while waiting + lock.hold("hold1") + lock.hold("hold2") + Timer(0.1, lambda: lock.release("hold1")).start() + Timer(0.2, lambda: lock.release("hold2")).start() + res = [] + for holders in lock.waiting_for(): + res.append(holders) + assert res == [("hold1", "hold2"), ("hold2",), ()] + + # Release holders during the yield + lock.hold("hold1") + lock.hold("hold2") + Timer(0.1, lambda: lock.release("hold2")).start() + res = [] + for loopnum, holders in enumerate(lock.waiting_for()): + res.append(holders) + if loopnum == 0: + lock.release("hold1") + assert res == [("hold1", "hold2"), ("hold2",), ()] + + +def test_waiting_for_timeout(): + lock = HoldLock() + lock.hold("hold1") + res = [] + for holders in lock.waiting_for(timeout=0.1): + res.append(holders) + assert res == [("hold1",)] + + +def test_waiting_for_timeout_during_yield(): + lock = HoldLock() + lock.hold("hold1") + res = [] + for holders in lock.waiting_for(timeout=0.1): + res.append(holders) + time.sleep(0.2) + assert res == [("hold1",)] + + +def test_waiting_for_hold_timeouts(): + lock = HoldLock() + lock.hold("hold1", timeout=0.1) + lock.hold("hold2", timeout=0.2) + res = [] + for holders in lock.waiting_for(): + res.append(holders) + assert res == [("hold1", "hold2"), ("hold2",), ()] + + +def test_waiting_for_hold_timeout_during_yield(): + lock = HoldLock() + lock.hold("hold1", timeout=0.1) + lock.hold("hold2", timeout=0.3) + res = [] + for holders in lock.waiting_for(): + res.append(holders) + time.sleep(0.2) + assert res == [("hold1", "hold2"), ("hold2",), ()]