import time
import itertools
import contextlib
import threading


class HoldLock(contextlib.AbstractContextManager):
    """
    A sort-of thread lock, intended to allow one thread to wait until all
    others are finished using a multi-user resource.

    Once created, threads may call `hold()` on the HoldLock to acquire a hold.
    If a thread then calls `wait()` or iterates `waiting_for()`, those calls
    will block until all holds are released with `release()`.

    In this simple use case, the HoldLock almost behaves like a reverse
    semaphore: `hold()` increases a counter by 1, `release()` reduces it by 1,
    and calling `wait()` blocks until the counter comes back down to 0. The
    closest example of a similar thing I've found is Golang WaitGroups, which
    work like this.

    Additionally, the HoldLock allows an identifier to be passed to `hold()`.
    This same identifier must be referred to with `release()`, but can be any
    object. Rather than a simple counter, the HoldLock maintains a list of
    these identifiers. They only really become useful when the main waiting
    thread reads `holders` or iterates `waiting_for()`, because then it gets
    access to these identifiers. The common use case is to use a string
    explaining the reason for the `hold()` as the identifier. This lets the
    main thread print a list of the things it's waiting for by iterating
    `waiting_for()`.

    By default, the `HoldLock.AnonHolder` identifier is used in all calls, so
    the identifier can be ignored completely if it's not useful.

    A hold may be given a timeout. Once a waiter (`wait()` / `waiting_for()`)
    observes that a hold is past its expiry, that hold is no longer waited
    for. It must still be released by its owner, and `release()` then
    returns False.

    The HoldLock object itself can be used as a context manager in `with`
    statements. This works the same as calling `hold()` and `release()` with
    defaults.
    """

    class AnonHolder:
        """Default identifier used by `hold()`/`release()` when none is given."""

    class Holder(contextlib.AbstractContextManager):
        """
        An object representing something that has a hold on a HoldLock.

        Can be used as a context manager; leaving the `with` block releases
        the hold. Only intended to be released once.
        """

        def __init__(self, hold_lock, identifier, expiry):
            self.hold_lock = hold_lock
            self.identifier = identifier
            # Absolute time (in hold_lock.time_func units) or None for "never".
            self.expiry = expiry

        def __enter__(self):
            return self

        def __exit__(self, exc_type, exc_value, traceback):
            # Deliberately returns None so exceptions raised in the body propagate.
            self.hold_lock._release(self)

        def release(self):
            """
            Release this hold.

            Returns False if a waiter had already treated the hold as expired,
            otherwise True.
            """
            return self.hold_lock._release(self)

        def expired(self):
            """Return True if this hold has a timeout and it has passed."""
            if self.expiry is not None:
                if self.expiry <= self.hold_lock.time_func():
                    return True
            return False

    def __enter__(self):
        self.hold()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Must not return release()'s bool: a True return would suppress exceptions.
        self.release()

    def __init__(self, time_func=time.monotonic):
        """
        Create a HoldLock instance.

        By default, time.monotonic is used for all timeouts. Any function that
        returns the current absolute time in seconds as a float can be
        supplied instead.
        """
        # Live holders, sorted by expiry: earliest first, None (no expiry) last.
        self._holders = []
        # Holders that a waiter has seen expire but that have not been released yet.
        self._expired_holders = []
        self._cv = threading.Condition()
        self.time_func = time_func
        self._closed = False

    def hold(self, identifier=AnonHolder, timeout=None):
        """
        Acquire a hold on this HoldLock, blocking any `wait()` call until all
        holds are released.

        Multiple threads may acquire a hold simultaneously, and an identifier
        may be used more than once. A hold must later be released with
        `release()`, providing the same identifier. The default identifier is
        `HoldLock.AnonHolder`. Any other object works the same way and is what
        `holders` / `waiting_for()` report.

        `timeout`, if not None, is a relative time in seconds after which
        waiters stop waiting for this hold. A timeout of 0 means the hold is
        already expired.

        Raises Exception if the HoldLock has been closed with `close()`.

        Can either be called directly or used as a context manager:
        `with holdlock.hold():`. The returned Holder object also provides a
        way to see if the hold has expired (`holder.expired()`). It also
        provides an alternate way to release the hold without passing the
        identifier again (`holder.release()`):

            holder1 = holdlock.hold("annoying to reference identifier")
            holder1.release()

            with holdlock.hold(timeout=5) as holder2:
                while True:
                    time.sleep(1)
                    if holder2.expired():
                        print("Timeout has expired")
                        break
        """
        with self._cv:
            if self._closed:
                raise Exception("Cannot get new hold on closed HoldLock instance")
            # `is None` (not truthiness) so that timeout=0 means "expires now".
            expiry = None if timeout is None else self.time_func() + timeout
            new_holder = self.Holder(self, identifier, expiry)
            self._hold(new_holder)
        return new_holder

    def _hold(self, holder):
        """Register `holder`, keeping `_holders` sorted by expiry (None last)."""
        with self._cv:
            self._holders.append(holder)
            self._holders.sort(key=lambda h: (h.expiry is None, h.expiry))
            self._cv.notify_all()

    def release(self, identifier=AnonHolder):
        """
        Release a hold on this HoldLock.

        If there are multiple holders with the supplied identifier, holders a
        waiter has already seen expire are matched first. After those, live
        holders are matched in order of earliest expiry.

        Returns False if the hold had expired, otherwise True. Technically a
        hold only expires if a waiter observed it after its timeout passed.

        Raises Exception if no hold with `identifier` is currently held.
        """
        with self._cv:
            # _holders is already sorted by expiry.
            for holder in itertools.chain(self._expired_holders, self._holders):
                if holder.identifier == identifier:
                    matching_holder = holder
                    break
            else:
                raise Exception(F"Release identifier '{identifier}' is not currently held")
            return self._release(matching_holder)

    def _release(self, holder):
        """
        Remove `holder` and wake any waiters.

        Returns False if `holder` had been moved to the expired list,
        otherwise True. Raises ValueError if `holder` is not held at all.
        """
        with self._cv:
            if holder in self._expired_holders:
                self._expired_holders.remove(holder)
                was_live = False
            else:
                self._holders.remove(holder)
                was_live = True
            self._cv.notify_all()
            return was_live

    def close(self):
        """
        Stop any threads from acquiring a new hold on this HoldLock (they will
        get an exception from `hold()`).
        """
        with self._cv:
            self._closed = True

    def reopen(self):
        """
        Start allowing threads to get a hold on this HoldLock again (after
        having called `close()`).
        """
        with self._cv:
            self._closed = False

    @property
    def holders(self):
        """
        Return a tuple of the current (non-expired) holder identities.

        The tuple itself is a copy, but the values in it are the same objects
        that `hold()` calls have passed in as identifiers.
        """
        with self._cv:
            return tuple(holder.identifier for holder in self._holders)

    @property
    def hold_count(self):
        """Return the current number of (non-expired) holds on this HoldLock."""
        with self._cv:
            return len(self._holders)

    def _expire_holders(self, now):
        """
        Move every holder whose expiry is <= `now` to `_expired_holders`.

        Must be called with `_cv` held. Relies on `_holders` being sorted with
        the earliest expiry first and None expiries last.
        """
        while (self._holders
               and self._holders[0].expiry is not None
               and self._holders[0].expiry <= now):
            self._expired_holders.append(self._holders.pop(0))

    @staticmethod
    def _remaining(now, *deadlines, period=None):
        """
        Return how long to block on the condition variable.

        The result is the smallest of `period` and the time left until each
        non-None absolute deadline. It is None (block until notified) when
        nothing bounds the wait.
        """
        candidates = [deadline - now for deadline in deadlines if deadline is not None]
        if period is not None:
            candidates.append(period)
        return min(candidates) if candidates else None

    @contextlib.contextmanager
    def _unlocked(self):
        """
        Temporarily release `_cv` (which the caller must hold).

        The lock is re-acquired even if an exception, such as GeneratorExit
        when a `waiting_for()` consumer stops early, is raised inside the
        `with` block.
        """
        self._cv.release()
        try:
            yield
        finally:
            self._cv.acquire()

    def wait(self, timeout=None):
        """
        Wait for all threads currently holding this HoldLock to release it.

        Returns True unless the timeout is hit, in which case it returns
        False. Holds whose own timeout has passed are no longer waited for.
        Note that unless `close()` is called first, _more threads may get a
        hold_ while waiting.

        If `timeout` is specified, it must be a relative float value in
        seconds. If `timeout` is None, `wait()` will block indefinitely until
        all holds are released.
        """
        expiry = None if timeout is None else self.time_func() + timeout
        with self._cv:
            while self._holders:
                now = self.time_func()
                self._expire_holders(now)
                if not self._holders:
                    return True
                if expiry is not None and expiry <= now:
                    return False
                # Wake up for the next holder expiry or our own timeout, whichever is first.
                self._cv.wait(self._remaining(now, expiry, self._holders[0].expiry))
            return True

    def waiting_for(self, timeout=None, update_period=None):
        """
        Behaves the same as `wait()`, but is a generator. It yields tuples of
        the remaining holder identifiers while waiting for all holds to be
        released.

        By default, a new tuple is yielded whenever the remaining holders
        change. `update_period` (seconds) can be supplied to add more
        intermediate updates.

        When all holds are released (or have expired), the last yielded tuple
        is empty (no longer waiting on any holds). If `timeout` is not None and
        the timeout expires instead, the generator stops without yielding an
        empty tuple.

        The lock is not held while the consumer handles a yielded value. It is
        safe to stop iterating early (e.g. `break`).
        """
        expiry = None if timeout is None else self.time_func() + timeout
        with self._cv:
            # Holders can be released or time can pass in two places: the
            # condition wait and the yield. So changes are checked after both.
            while self._holders:
                now = self.time_func()
                if expiry is not None and expiry <= now:
                    return
                self._expire_holders(now)
                if not self._holders:
                    break

                yielded_holders = self.holders
                with self._unlocked():
                    yield yielded_holders

                # Holders changed while the consumer ran: report again without waiting.
                if self.holders != yielded_holders:
                    continue

                # Holders haven't changed, so at least one is still present.
                now = self.time_func()
                if expiry is not None and expiry <= now:
                    return
                next_expiry = self._holders[0].expiry
                if next_expiry is not None and next_expiry <= now:
                    # Let the check at the top of the loop expire it.
                    continue
                self._cv.wait(self._remaining(now, expiry, next_expiry, period=update_period))

            # All holds released or expired: finish with an empty tuple.
            with self._unlocked():
                yield tuple()