from __future__ import annotations import sys import traceback from contextvars import ContextVar from typing import Any, Callable, ClassVar, Generic, Self, TypeVar from weakref import WeakSet __all__ = ["Ref", "Computed", "Watcher", "update_all_computed"] T = TypeVar("T") class Ref(Generic[T]): _error: Exception | None _users: WeakSet[Computed[Any]] _value: T def __init__(self, initial: T): self._error = None self._value = initial self._users = WeakSet() @property def value(self): computed = Computed._stack.get() if computed is not None: self._users.add(computed) if self._error is not None: raise self._error return self._value @value.setter def value(self, new_value: T): self._value = new_value self.trigger() def trigger(self): Computed._dirty.update(self._users) def __repr__(self): return f"Ref({self._value!r})" def __str__(self): return repr(self) class Computed(Ref[T]): _dirty: ClassVar[WeakSet[Self]] = WeakSet() _stack: ClassVar[ContextVar[Computed[Any] | None]] = ContextVar("stack", default=None) _update: Callable[[], T] def __init__(self, update: Callable[[], T]): self._error = None self._update = update self._users = WeakSet() self.update() def update(self): token = self._stack.set(self) try: self.value = self._update() self._error = None except Exception as err: self._error = err traceback.print_exception(err, file=sys.stderr) finally: self._stack.reset(token) def __repr__(self): if hasattr(self, "_value"): return f"Computed({self._value!r})" else: return "Computed(...)" class Watcher: _watches: set[Computed[Any]] def watch(self, handler: Callable[[], Any]): if not hasattr(self, "_watches"): self._watches = set() def run_handler(): handler() self._watches.add(Computed(run_handler)) def update_all_computed(max_iters: int = 100): """ Update all computed values. If they do not settle in given number of iterations, raise RuntimeError. """ for _ in range(max_iters): if not Computed._dirty: return dirty = set(Computed._dirty) Computed._dirty.clear() for computed in dirty: computed.update() raise RuntimeError("Infinite loop in computed values")