From a82e447913624ffd59ee84607f642ee4325ece4d Mon Sep 17 00:00:00 2001 From: novirium Date: Mon, 30 Mar 2020 21:08:29 +0800 Subject: [PATCH] Better context managers and tests --- hold_lock.py | 86 ++++++++++++++++++++++++++++------------------- test_hold_lock.py | 32 ++++++++++++++++-- 2 files changed, 81 insertions(+), 37 deletions(-) diff --git a/hold_lock.py b/hold_lock.py index 18df86d..e26c360 100644 --- a/hold_lock.py +++ b/hold_lock.py @@ -1,10 +1,9 @@ import time import contextlib import threading -from collections import namedtuple -class HoldLock(): +class HoldLock(contextlib.AbstractContextManager): """ A sort-of thread lock, intended to allow one thread to wait until all others are finished using a multi-user resource. @@ -26,24 +25,42 @@ class HoldLock(): `hold()` as the identifier, which then allows the main thread to print a list of things it's waiting for by iterating `waiting_for()`. """ - Holder = namedtuple("Holder", ['identifier', 'expiry']) - class HoldContext(contextlib.AbstractContextManager): - def __init__(self, hold_wait, holder): - self.hold_wait = hold_wait - self.holder = holder + class AnonHolder(): + pass + + class Holder(contextlib.AbstractContextManager): + """ + An object representing something that has a hold on a HoldLock. Can be used as a context + manager. Only intended to be used once. + """ + + def __init__(self, hold_lock, identifier, expiry): + self.hold_lock = hold_lock + self.identifier = identifier + self.expiry = expiry def __enter__(self): - return self.hold_wait + return self def __exit__(self, exc_type, exc_value, traceback): - self.hold_wait.release(self.holder) + self.hold_lock._release(self) + + def release(self): + self.hold_lock._release(self) + + def expired(self): + if self.expiry is not None: + if self.expiry <= self.hold_lock.time_func(): + return True + return False + + def __enter__(self): + self.hold() + return self - def __bool__(self): - if self.holder.expiry is not None: - if self.holder.expiry < self.hold_wait.time_func(): - return False - return True + def __exit__(self, exc_type, exc_value, traceback): + self.release() def __init__(self, time_func=time.monotonic): """ @@ -55,7 +72,7 @@ class HoldLock(): self.time_func = time_func self._closed = False - def hold(self, identifier=None, timeout=None): + def hold(self, identifier=AnonHolder, timeout=None): """ Acquire a hold on this HoldLock, blocking any `wait()` call until all holds are released. Multiple threads may acquire a hold simultaneously, and an identifier may be used more than @@ -78,36 +95,37 @@ class HoldLock(): with self._cv: if self._closed: raise Exception("Cannot get new hold on closed HoldWait instance") - if timeout is not None: - timeout = self.time_func()+timeout - new_holder = self.Holder(identifier, timeout) # technically timeout is expiry here - self._holders.append(new_holder) + new_holder = self.Holder(self, identifier, self.time_func() + + timeout if timeout else None) + self._hold(new_holder) + return new_holder + + def _hold(self, holder): + with self._cv: + self._holders.append(holder) # Sort to make sure earliest expiry is first, with None at the end self._holders.sort(key=lambda holder: (holder.expiry is None, holder.expiry)) self._cv.notify_all() - # cheat a bit by passing the holder as the identifier - this means the context manager - # will always release the relevant one - return self.HoldContext(self, new_holder) - def release(self, identifier=None): + def release(self, identifier=AnonHolder): """ Release a hold on this HoldLock. If there are mutiple holders with the supplied identifier, the one with the earliest timeout will be released. """ with self._cv: - if isinstance(identifier, self.Holder): - matching_holder = identifier - + # _holders is already sorted for us + for holder in self._holders: + if holder.identifier == identifier: + matching_holder = holder + break else: - # _holders is already sorted for us - for holder in self._holders: - if holder.identifier == identifier: - matching_holder = holder - break - else: - raise Exception(F"Release identifier '{identifier}' is not currently held") + raise Exception(F"Release identifier '{identifier}' is not currently held") + + self._release(matching_holder) - self._holders.remove(matching_holder) + def _release(self, holder): + with self._cv: + self._holders.remove(holder) self._cv.notify_all() def close(self): diff --git a/test_hold_lock.py b/test_hold_lock.py index 3f3450b..19ec749 100644 --- a/test_hold_lock.py +++ b/test_hold_lock.py @@ -1,4 +1,4 @@ -from threading import Timer +from threading import Timer, Thread import time from hold_lock import HoldLock @@ -8,9 +8,9 @@ def test_basic_hold_release(): lock = HoldLock() lock.wait() lock.hold() - assert lock.holders == (None,) + assert lock.holders == (HoldLock.AnonHolder,) lock.hold("hold1") - assert lock.holders == (None, "hold1") + assert lock.holders == (HoldLock.AnonHolder, "hold1") assert not lock.wait(0.1) @@ -18,6 +18,10 @@ def test_basic_hold_release(): lock.release() assert lock.wait() + holder = lock.hold("hold1") + holder.release() + assert lock.wait() + def test_thread_release(): lock = HoldLock() @@ -26,6 +30,28 @@ def test_thread_release(): assert lock.wait(1) +def test_context(): + lock = HoldLock() + + def with_func(hold_lock): + with hold_lock: + time.sleep(0.2) + Thread(target=with_func, args=[lock]).start() + assert not lock.wait(0.1) + assert lock.wait(1) + + +def test_hold_context(): + lock = HoldLock() + + def with_func(hold_lock): + with hold_lock.hold("hold1"): + time.sleep(0.2) + Thread(target=with_func, args=[lock]).start() + assert lock.holders == ("hold1",) + assert lock.wait() + + def test_hold_timeout(): lock = HoldLock() lock.hold(timeout=0.1)