diff --git a/lazy_player/reactive.py b/lazy_player/reactive.py new file mode 100644 index 0000000..aa3c5e9 --- /dev/null +++ b/lazy_player/reactive.py @@ -0,0 +1,99 @@ +from __future__ import annotations + +from contextvars import ContextVar +from typing import Any, Callable, ClassVar, Generic, Self, TypeVar +from weakref import WeakSet + +__all__ = ["Ref", "Computed", "update_all_computed"] + +T = TypeVar("T") + + +class Ref(Generic[T]): + _users: WeakSet[Computed[Any]] + _value: T + + def __init__(self, initial: T): + self._value = initial + self._users = WeakSet() + + @property + def value(self): + computed = Computed._stack.get() + + if computed is not None: + self._users.add(computed) + + return self._value + + @value.setter + def value(self, new_value: T): + self._value = new_value + self.trigger() + + def trigger(self): + Computed._dirty.update(self._users) + + def __repr__(self): + return f"Ref({self._value!r})" + + def __str__(self): + return repr(self) + + +class Computed(Ref[T]): + _dirty: ClassVar[WeakSet[Self]] = WeakSet() + _stack: ClassVar[ContextVar[Computed[Any] | None]] = ContextVar("stack", default=None) + _update: Callable[[], T] + + def __init__(self, update: Callable[[], T]): + self._update = update + self._users = WeakSet() + self.update() + + def update(self): + token = self._stack.set(self) + + try: + self.value = self._update() + finally: + self._stack.reset(token) + + def __repr__(self): + if hasattr(self, "_value"): + return f"Computed({self._value!r})" + else: + return "Computed(...)" + + +class Watcher: + _watches: set[Computed[Any]] + + def watch(self, handler: Callable[[], Any]): + if not hasattr(self, "_watches"): + self._watches = set() + + def run_handler(): + handler() + + self._watches.add(Computed(run_handler)) + + +def update_all_computed(max_iters: int = 100): + """ + Update all computed values. + If they do not settle in given number of iterations, raise RuntimeError. + """ + + for _ in range(max_iters): + if not Computed._dirty: + return + + dirty = set(Computed._dirty) + + Computed._dirty.clear() + + for computed in dirty: + computed.update() + + raise RuntimeError("Infinite loop in computed values")