from __future__ import annotations

import sys
import traceback
from contextvars import ContextVar
from typing import Any, Callable, ClassVar, Generic, Self, TypeVar
from weakref import WeakSet

# Public names exported by `from <module> import *`.
__all__ = ["Ref", "Computed", "Watcher", "update_all_computed"]

# Type of the value held by a Ref / Computed.
T = TypeVar("T")


class Ref(Generic[T]):
    """A mutable reactive cell.

    Reading ``value`` while a ``Computed`` is evaluating records that
    ``Computed`` as a user of this cell. Assigning ``value`` (or calling
    ``trigger``) adds every recorded user to ``Computed._dirty``, from which
    ``update_all_computed`` re-runs them.
    """

    _error: Exception | None  # when not None, reading ``value`` raises it
    _users: WeakSet[Computed[Any]]  # computeds that have read this cell (weakly held)
    _value: T  # the currently stored value

    def __init__(self, initial: T):
        self._users = WeakSet()
        self._value = initial
        self._error = None

    @property
    def value(self):
        """Return the stored value.

        If a ``Computed`` is currently evaluating, it is registered as a user
        of this cell first. If an error is stored, it is raised instead of
        returning the value.
        """
        reader = Computed._stack.get()
        if reader is not None:
            self._users.add(reader)

        stored_error = self._error
        if stored_error is not None:
            raise stored_error

        return self._value

    @value.setter
    def value(self, new_value: T):
        """Store *new_value* and mark every recorded user dirty."""
        self._value = new_value
        self.trigger()

    def trigger(self):
        """Mark all recorded users dirty without changing the stored value."""
        Computed._dirty.update(self._users)

    def __repr__(self):
        return "Ref(%r)" % (self._value,)

    def __str__(self):
        return repr(self)


class Computed(Ref[T]):
    """A reactive value derived from other ``Ref``/``Computed`` cells.

    ``update`` runs the callback with this instance set as the active
    computation, so every ``Ref.value`` read during the callback registers
    this instance as a user of that cell. When such a cell changes, this
    instance is added to ``Computed._dirty`` and re-run by
    ``update_all_computed``.
    """

    # Computeds whose inputs changed since they last ran; drained by
    # update_all_computed. Held weakly, so unreferenced computeds can be collected.
    # (Annotated with Computed[Any] rather than Self: Self is not allowed
    # inside ClassVar for a generic class.)
    _dirty: ClassVar[WeakSet[Computed[Any]]] = WeakSet()
    # The Computed currently running its callback, or None outside any.
    _stack: ClassVar[ContextVar[Computed[Any] | None]] = ContextVar("stack", default=None)
    _update: Callable[[], T]

    def __init__(self, update: Callable[[], T]):
        """Wrap *update* and evaluate it immediately.

        ``Ref.__init__`` is intentionally not called: ``_value`` is only
        assigned once the callback succeeds.
        """
        self._error = None
        self._update = update
        self._users = WeakSet()
        self.update()

    def update(self):
        """Re-run the callback and store either its result or its exception.

        On success the new value is stored, any previous error is cleared,
        and users are marked dirty. On failure (any ``Exception``), the
        traceback is printed to stderr and the exception is stored, so later
        reads of ``value`` re-raise it. Users are marked dirty in this case
        too, so the error propagates to dependents.
        """
        token = self._stack.set(self)

        try:
            new_value = self._update()
        except Exception as err:
            self._error = err
            traceback.print_exception(err, file=sys.stderr)
            # Readers of this cell would now get an exception instead of the
            # value they last saw; mark them dirty so they re-run and observe it.
            self.trigger()
        else:
            self._error = None
            # Goes through the Ref.value setter, which also marks users dirty.
            self.value = new_value
        finally:
            self._stack.reset(token)

    def __repr__(self):
        # _value does not exist yet if the very first evaluation failed.
        if hasattr(self, "_value"):
            return f"Computed({self._value!r})"
        else:
            return "Computed(...)"


class Watcher:
    """Holds strong references to side-effect computations.

    ``Ref`` and ``Computed._dirty`` only hold their users weakly, so a
    ``Computed`` created purely for its side effects must be kept alive
    somewhere. This class stores them in ``_watches``. No ``__init__`` call
    is required; the set is created on first use.
    """

    _watches: set[Computed[Any]]  # strong refs to the handlers' Computed wrappers

    def watch(self, handler: Callable[[], Any]):
        """Run *handler* now and again whenever a value it read changes.

        The handler's return value is discarded.
        """
        try:
            register = self._watches.add
        except AttributeError:
            self._watches = set()
            register = self._watches.add

        def run_handler():
            handler()

        register(Computed(run_handler))


def update_all_computed(max_iters: int = 100):
    """
    Update all computed values.

    Repeatedly re-runs every dirty ``Computed`` until none remain dirty.
    Each round may mark further computeds dirty (those that read the values
    just updated); they are handled in the next round.

    :param max_iters: maximum number of update rounds to run.
    :raises RuntimeError: if computeds are still dirty after ``max_iters``
        rounds.
    """

    for _ in range(max_iters):
        if not Computed._dirty:
            return

        # Snapshot and clear before updating: updates may re-dirty computeds,
        # and those must survive into the next round rather than be cleared.
        dirty = set(Computed._dirty)

        Computed._dirty.clear()

        for computed in dirty:
            computed.update()

    # The final round may have settled everything; only fail if work remains.
    if Computed._dirty:
        raise RuntimeError("Infinite loop in computed values")