From 14d8d30a7d79b439ca77ca594bc9a993a7858400 Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Sun, 8 Mar 2026 01:51:14 -0800 Subject: [PATCH 01/81] Add StateToken[TOKEN_TYPE] for flexible state manager Special case access patterns for BaseState retrieval to allow for other types of state that have different semantics. --- reflex/app.py | 35 ++++--- reflex/event.py | 12 ++- reflex/istate/manager/__init__.py | 38 ++++--- reflex/istate/manager/disk.py | 127 +++++++++++----------- reflex/istate/manager/token.py | 169 ++++++++++++++++++++++++++++++ reflex/istate/proxy.py | 9 +- reflex/istate/shared.py | 16 ++- reflex/istate/wrappers.py | 7 +- reflex/plugins/_screenshot.py | 10 +- reflex/state.py | 3 +- reflex/testing.py | 16 ++- reflex/utils/token_manager.py | 6 +- 12 files changed, 325 insertions(+), 123 deletions(-) create mode 100644 reflex/istate/manager/token.py diff --git a/reflex/app.py b/reflex/app.py index 54682543a7d..73f3cedbc6e 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -83,6 +83,7 @@ get_hydrate_event, noop, ) +from reflex.istate.manager.token import BaseStateToken from reflex.istate.proxy import StateProxy from reflex.page import DECORATED_PAGES from reflex.route import ( @@ -96,8 +97,6 @@ State, StateManager, StateUpdate, - _split_substate_key, - _substate_key, all_base_state_classes, code_uses_state_contexts, ) @@ -525,7 +524,7 @@ def _setup_state(self) -> None: config = get_config() # Set up the state manager. - self._state_manager = StateManager.create(state=self._state) + self._state_manager = StateManager.create() # Set up the Socket.IO AsyncServer. if not self.sio: @@ -1568,7 +1567,7 @@ def all_routes(_request: Request) -> Response: @contextlib.asynccontextmanager async def modify_state( self, - token: str, + token: BaseStateToken, background: bool = False, previous_dirty_vars: dict[str, set[str]] | None = None, ) -> AsyncIterator[BaseState]: @@ -1604,7 +1603,7 @@ async def modify_state( delta=delta, final=True if not background else None, ), - token=token, + token=token.ident, ) def _process_background( @@ -1976,8 +1975,12 @@ async def _create_upload_event() -> Event: ) # Get the state for the session. - substate_token = _substate_key(token, handler.rpartition(".")[0]) - state = await app.state_manager.get_state(substate_token) + if app._state is None: + msg = "Upload failed, app has no state defined." + raise UploadValueError(msg) + state = await app.state_manager.get_state( + BaseStateToken(ident=token, cls=app._state) + ) handler_upload_param = () @@ -2157,17 +2160,14 @@ async def emit_update(self, update: StateUpdate, token: str) -> None: update: The state update to send. token: The client token (tab) associated with the event. """ - client_token, _ = _split_substate_key(token) - socket_record = self._token_manager.token_to_socket.get(client_token) + socket_record = self._token_manager.token_to_socket.get(token) if ( socket_record is None or socket_record.instance_id != self._token_manager.instance_id ): if isinstance(self._token_manager, RedisTokenManager): # The socket belongs to another instance of the app, send it to the lost and found. - if not await self._token_manager.emit_lost_and_found( - client_token, update - ): + if not await self._token_manager.emit_lost_and_found(token, update): console.warn( f"Failed to send delta to lost and found for client {token!r}" ) @@ -2294,8 +2294,9 @@ async def link_token_to_sid(self, sid: str, token: str): await self.emit("new_token", new_token, to=sid) # Update client state to apply new sid/token for running background tasks. - async with self.app.state_manager.modify_state( - _substate_key(new_token or token, self.app.state_manager.state) - ) as state: - state.router_data[constants.RouteVar.SESSION_ID] = sid - state.router = RouterData.from_router_data(state.router_data) + if self.app._state is not None: + async with self.app.state_manager.modify_state( + BaseStateToken(ident=new_token or token, cls=self.app._state) + ) as state: + state.router_data[constants.RouteVar.SESSION_ID] = sid + state.router = RouterData.from_router_data(state.router_data) diff --git a/reflex/event.py b/reflex/event.py index ff75e3bd3cb..730ecd81b84 100644 --- a/reflex/event.py +++ b/reflex/event.py @@ -55,6 +55,9 @@ ) from reflex.vars.object import ObjectVar +if TYPE_CHECKING: + from reflex.istate.manager.token import BaseStateToken + @dataclasses.dataclass( init=True, @@ -76,14 +79,19 @@ class Event: payload: dict[str, Any] = dataclasses.field(default_factory=dict) @property - def substate_token(self) -> str: + def substate_token(self) -> BaseStateToken: """Get the substate token for the event. Returns: The substate token. """ + from reflex.istate.manager.token import BaseStateToken + from reflex.state import State + substate = self.name.rpartition(".")[0] - return f"{self.token}_{substate}" + return BaseStateToken( + ident=self.token, cls=State.get_class_substate(tuple(substate.split("."))) + ) _EVENT_FIELDS: set[str] = {f.name for f in dataclasses.fields(Event)} diff --git a/reflex/istate/manager/__init__.py b/reflex/istate/manager/__init__.py index 1eae7550de3..0f8f42be3a7 100644 --- a/reflex/istate/manager/__init__.py +++ b/reflex/istate/manager/__init__.py @@ -11,7 +11,7 @@ from reflex import constants from reflex.config import get_config from reflex.event import Event -from reflex.state import BaseState +from reflex.istate.manager.token import TOKEN_TYPE, StateToken from reflex.utils import console, prerequisites from reflex.utils.exceptions import InvalidStateManagerModeError @@ -29,16 +29,10 @@ class StateModificationContext(TypedDict, total=False): class StateManager(ABC): """A class to manage many client states.""" - # The state class to use. - state: type[BaseState] - @classmethod - def create(cls, state: type[BaseState]): + def create(cls): """Create a new state manager. - Args: - state: The state class to use. - Raises: InvalidStateManagerModeError: If the state manager mode is invalid. @@ -51,11 +45,11 @@ def create(cls, state: type[BaseState]): if config.state_manager_mode == constants.StateManagerMode.MEMORY: from reflex.istate.manager.memory import StateManagerMemory - return StateManagerMemory(state=state) + return StateManagerMemory() if config.state_manager_mode == constants.StateManagerMode.DISK: from reflex.istate.manager.disk import StateManagerDisk - return StateManagerDisk(state=state) + return StateManagerDisk() if config.state_manager_mode == constants.StateManagerMode.REDIS: redis = prerequisites.get_redis() if redis is not None: @@ -63,7 +57,6 @@ def create(cls, state: type[BaseState]): # make sure expiration values are obtained only from the config object on creation return StateManagerRedis( - state=state, redis=redis, token_expiration=config.redis_token_expiration, lock_expiration=config.redis_lock_expiration, @@ -73,7 +66,7 @@ def create(cls, state: type[BaseState]): raise InvalidStateManagerModeError(msg) @abstractmethod - async def get_state(self, token: str) -> BaseState: + async def get_state(self, token: StateToken[TOKEN_TYPE]) -> TOKEN_TYPE: """Get the state for a token. Args: @@ -86,8 +79,8 @@ async def get_state(self, token: str) -> BaseState: @abstractmethod async def set_state( self, - token: str, - state: BaseState, + token: StateToken[TOKEN_TYPE], + state: TOKEN_TYPE, **context: Unpack[StateModificationContext], ): """Set the state for a token. @@ -101,8 +94,8 @@ async def set_state( @abstractmethod @contextlib.asynccontextmanager async def modify_state( - self, token: str, **context: Unpack[StateModificationContext] - ) -> AsyncIterator[BaseState]: + self, token: StateToken[TOKEN_TYPE], **context: Unpack[StateModificationContext] + ) -> AsyncIterator[TOKEN_TYPE]: """Modify the state for a token while holding exclusive lock. Args: @@ -112,15 +105,15 @@ async def modify_state( Yields: The state for the token. """ - yield self.state() + yield # pyright: ignore[reportReturnType] @contextlib.asynccontextmanager async def modify_state_with_links( self, - token: str, + token: StateToken[TOKEN_TYPE], previous_dirty_vars: dict[str, set[str]] | None = None, **context: Unpack[StateModificationContext], - ) -> AsyncIterator[BaseState]: + ) -> AsyncIterator[TOKEN_TYPE]: """Modify the state for a token, including linked substates, while holding exclusive lock. Args: @@ -131,8 +124,13 @@ async def modify_state_with_links( Yields: The state for the token with linked states patched in. """ + from reflex.state import BaseState + async with self.modify_state(token, **context) as root_state: - if getattr(root_state, "_reflex_internal_links", None) is not None: + if ( + isinstance(root_state, BaseState) + and getattr(root_state, "_reflex_internal_links", None) is not None + ): from reflex.istate.shared import SharedStateBaseInternal shared_state = await root_state.get_state(SharedStateBaseInternal) diff --git a/reflex/istate/manager/disk.py b/reflex/istate/manager/disk.py index 30db2f25481..5a6ee7ecef9 100644 --- a/reflex/istate/manager/disk.py +++ b/reflex/istate/manager/disk.py @@ -8,6 +8,7 @@ from collections.abc import AsyncIterator from hashlib import md5 from pathlib import Path +from typing import Any, Generic, cast from typing_extensions import Unpack, override @@ -17,17 +18,18 @@ StateModificationContext, _default_token_expiration, ) -from reflex.state import BaseState, _split_substate_key, _substate_key +from reflex.istate.manager.token import TOKEN_TYPE, BaseStateToken, StateToken +from reflex.state import BaseState from reflex.utils import console, path_ops, prerequisites from reflex.utils.misc import run_in_thread @dataclasses.dataclass(frozen=True) -class QueueItem: +class QueueItem(Generic[TOKEN_TYPE]): """An item in the write queue.""" - token: str - state: BaseState + token: StateToken[TOKEN_TYPE] + state: TOKEN_TYPE timestamp: float @@ -36,7 +38,7 @@ class StateManagerDisk(StateManager): """A state manager that stores states on disk.""" # The mapping of client ids to states. - states: dict[str, BaseState] = dataclasses.field(default_factory=dict) + states: dict[str, Any] = dataclasses.field(default_factory=dict) # The mutex ensures the dict of mutexes is updated exclusively _state_manager_lock: asyncio.Lock = dataclasses.field(default=asyncio.Lock()) @@ -57,7 +59,7 @@ class StateManagerDisk(StateManager): ) # Pending writes - _write_queue: dict[str, QueueItem] = dataclasses.field( + _write_queue: dict[StateToken, QueueItem] = dataclasses.field( default_factory=dict, init=False, ) @@ -96,7 +98,7 @@ def _purge_expired_states(self): # remove the file path.unlink() - def token_path(self, token: str) -> Path: + def token_path(self, token: StateToken) -> Path: """Get the path for a token. Args: @@ -106,10 +108,10 @@ def token_path(self, token: str) -> Path: The path for the token. """ return ( - self.states_directory / f"{md5(token.encode()).hexdigest()}.pkl" + self.states_directory / f"{md5(str(token).encode()).hexdigest()}.pkl" ).absolute() - async def load_state(self, token: str) -> BaseState | None: + async def load_state(self, token: StateToken[TOKEN_TYPE]) -> TOKEN_TYPE | None: """Load a state object based on the provided token. Args: @@ -123,23 +125,23 @@ async def load_state(self, token: str) -> BaseState | None: if token_path.exists(): try: with token_path.open(mode="rb") as file: - return BaseState._deserialize(fp=file) + return token.deserialize(fp=file) except Exception: pass return None async def populate_substates( - self, client_token: str, state: BaseState, root_state: BaseState + self, token: BaseStateToken, state: BaseState, root_state: BaseState ): """Populate the substates of a state object. Args: - client_token: The client token. + token: The token used to identify the state object. state: The state object to populate. root_state: The root state object. """ for substate in state.get_substates(): - substate_token = _substate_key(client_token, substate) + substate_token = token.with_cls(substate) fresh_instance = await root_state.get_state(substate) instance = await self.load_state(substate_token) @@ -151,13 +153,13 @@ async def populate_substates( state.substates[substate.get_name()] = instance instance.parent_state = state - await self.populate_substates(client_token, instance, root_state) + await self.populate_substates(token, instance, root_state) @override async def get_state( self, - token: str, - ) -> BaseState: + token: StateToken[TOKEN_TYPE], + ) -> TOKEN_TYPE: """Get the state for a token. Args: @@ -166,38 +168,46 @@ async def get_state( Returns: The state for the token. """ - client_token = _split_substate_key(token)[0] - self._token_last_touched[client_token] = time.time() - root_state = self.states.get(client_token) + self._token_last_touched[token.ident] = time.time() + root_state = self.states.get(token.ident) if root_state is not None: # Retrieved state from memory. return root_state # Deserialize root state from disk. - root_state = await self.load_state(_substate_key(client_token, self.state)) - # Create a new root state tree with all substates instantiated. - fresh_root_state = self.state(_reflex_internal_init=True) - if root_state is None: - root_state = fresh_root_state + root_state = await self.load_state(token) + if isinstance(token, BaseStateToken): + # Create a new root state tree with all substates instantiated. + fresh_root_state = token.cls(_reflex_internal_init=True) + if root_state is None: + root_state = fresh_root_state + elif not isinstance(root_state, BaseState): + msg = "Deserialized state is not an instance of BaseState, cannot populate substates." + raise TypeError(msg) + else: + # Ensure all substates exist, even if they were not serialized previously. + root_state.substates = fresh_root_state.substates + await self.populate_substates(token, root_state, root_state) else: - # Ensure all substates exist, even if they were not serialized previously. - root_state.substates = fresh_root_state.substates - self.states[client_token] = root_state - await self.populate_substates(client_token, root_state, root_state) - return root_state - - async def set_state_for_substate(self, client_token: str, substate: BaseState): + # For non-BaseState tokens, if the deserialized state is None, we create a new instance using the token's cls. + if root_state is None: + root_state = token.cls() + self.states[token.ident] = root_state + return cast(TOKEN_TYPE, root_state) + + async def set_state_for_substate( + self, token: StateToken[TOKEN_TYPE], substate: TOKEN_TYPE + ): """Set the state for a substate. Args: - client_token: The client token. + token: The token used to identify the state object. substate: The substate to set. """ - substate_token = _substate_key(client_token, substate) + substate_token = token.with_cls(type(substate)) - if substate._get_was_touched(): - substate._was_touched = False # Reset the touched flag after serializing. - pickle_state = substate._serialize() + if token.get_and_reset_touched_state(substate): + pickle_state = token.serialize(substate) if pickle_state: if not self.states_directory.exists(): self.states_directory.mkdir(parents=True, exist_ok=True) @@ -205,8 +215,9 @@ async def set_state_for_substate(self, client_token: str, substate: BaseState): lambda: self.token_path(substate_token).write_bytes(pickle_state), ) - for substate_substate in substate.substates.values(): - await self.set_state_for_substate(client_token, substate_substate) + if isinstance(token, BaseStateToken) and isinstance(substate, BaseState): + for substate_substate in substate.substates.values(): + await self.set_state_for_substate(token, substate_substate) async def _process_write_queue_delay(self): """Wait for the debounce period before processing the write queue again.""" @@ -252,15 +263,14 @@ async def _process_write_queue(self): ) for item in items_to_write: token = item.token - client_token, _ = _split_substate_key(token) await self.set_state_for_substate( - client_token, self._write_queue.pop(token).state + token, self._write_queue.pop(token).state ) # Check for expired states to purge. - for token, last_touched in list(self._token_last_touched.items()): + for token_ident, last_touched in list(self._token_last_touched.items()): if now - last_touched > self.token_expiration: - self._token_last_touched.pop(token) - self.states.pop(token, None) + self._token_last_touched.pop(token_ident) + self.states.pop(token_ident, None) await run_in_thread(self._purge_expired_states) await self._process_write_queue_delay() except asyncio.CancelledError: # noqa: PERF203 @@ -283,10 +293,8 @@ async def _flush_write_queue(self): f"StateManagerDisk._flush_write_queue: writing {n_outstanding_items} remaining items to disk" ) for item in outstanding_items: - token = item.token - client_token, _ = _split_substate_key(token) await self.set_state_for_substate( - client_token, + item.token, item.state, ) console.debug( @@ -306,7 +314,10 @@ async def _schedule_process_write_queue(self): @override async def set_state( - self, token: str, state: BaseState, **context: Unpack[StateModificationContext] + self, + token: StateToken[TOKEN_TYPE], + state: TOKEN_TYPE, + **context: Unpack[StateModificationContext], ): """Set the state for a token. @@ -315,26 +326,25 @@ async def set_state( state: The state to set. context: The state modification context. """ - client_token, _ = _split_substate_key(token) if self._write_debounce_seconds > 0: # Deferred write to reduce disk IO overhead. - if client_token not in self._write_queue: - self._write_queue[client_token] = QueueItem( - token=client_token, + if token not in self._write_queue: + self._write_queue[token] = QueueItem( + token=token, state=state, timestamp=time.time(), ) else: # Immediate write to disk. - await self.set_state_for_substate(client_token, state) + await self.set_state_for_substate(token, state) # Ensure the processing task is scheduled to handle expirations and any deferred writes. await self._schedule_process_write_queue() @override @contextlib.asynccontextmanager async def modify_state( - self, token: str, **context: Unpack[StateModificationContext] - ) -> AsyncIterator[BaseState]: + self, token: StateToken[TOKEN_TYPE], **context: Unpack[StateModificationContext] + ) -> AsyncIterator[TOKEN_TYPE]: """Modify the state for a token while holding exclusive lock. Args: @@ -345,13 +355,12 @@ async def modify_state( The state for the token. """ # Disk state manager ignores the substate suffix and always returns the top-level state. - client_token, _ = _split_substate_key(token) - if client_token not in self._states_locks: + if token.ident not in self._states_locks: async with self._state_manager_lock: - if client_token not in self._states_locks: - self._states_locks[client_token] = asyncio.Lock() + if token.ident not in self._states_locks: + self._states_locks[token.ident] = asyncio.Lock() - async with self._states_locks[client_token]: + async with self._states_locks[token.ident]: state = await self.get_state(token) yield state await self.set_state(token, state, **context) diff --git a/reflex/istate/manager/token.py b/reflex/istate/manager/token.py new file mode 100644 index 00000000000..f9eb85b1a12 --- /dev/null +++ b/reflex/istate/manager/token.py @@ -0,0 +1,169 @@ +"""Representation of a StateManager token.""" + +import dataclasses +import pickle +from typing import TYPE_CHECKING, BinaryIO, Generic, Self, TypeVar + +if TYPE_CHECKING: + from reflex.state import BaseState + +TOKEN_TYPE = TypeVar("TOKEN_TYPE") + + +@dataclasses.dataclass(frozen=True, kw_only=True, slots=True) +class StateToken(Generic[TOKEN_TYPE]): + """Token for looking referencing a state instance in the StateManager.""" + + # Identifier, usually the client_token, but could be a linked / shared token. + ident: str + + # The class associated with the state instance. + cls: type[TOKEN_TYPE] + + def with_cls(self, cls: type[TOKEN_TYPE]) -> Self: + """Return a new token with the cls field updated to the provided class. + + Args: + cls: The class to update the cls field to. + + Returns: + A new StateToken instance with the updated cls field. + """ + return dataclasses.replace(self, cls=cls) + + def __str__(self) -> str: + """The key used in the underlying StateManager store. + + Returns: + A string representation of the token, which is a combination of the ident and cls name. + """ + # urlencode the redis token to escape the slash delimiter. + clean_ident = self.ident.replace("/", "%2F") + clean_cls_name = self.cls.__name__.replace("/", "%2F") + return f"{clean_ident}/{clean_cls_name}" + + @classmethod + def serialize(cls, state: TOKEN_TYPE) -> bytes: + """Serialize the state for redis/disk storage. + + Args: + state: The state to serialize. + + Returns: + The serialized state. + """ + return pickle.dumps(state) + + @classmethod + def deserialize( + cls, data: bytes | None = None, fp: BinaryIO | None = None + ) -> TOKEN_TYPE: + """Deserialize the state from redis/disk. + + data and fp are mutually exclusive, but one must be provided. + + Args: + data: The serialized state data. + fp: The file pointer to the serialized state data. + + Returns: + The raw deserialized state ("should match the token type"). + + Raises: + ValueError: If both data and fp are provided, or neither are provided. + """ + if data is not None and fp is None: + return pickle.loads(data) + if fp is not None: + return pickle.load(fp) + msg = "Only one of `data` or `fp` must be provided" + raise ValueError(msg) + + @classmethod + def get_and_reset_touched_state(cls, state: TOKEN_TYPE) -> bool: + """Get the touched state and reset the touched flag. + + This is used to determine if a state has been modified since it was last serialized. + + Args: + state: The state to check for modifications. + + Returns: + The touched state of the state. + """ + # Default implementation is always to write the state. + return True + + +class BaseStateToken(StateToken["BaseState"]): + """A token for the accessing reflex BaseState instances. + + This token type implies subtree hierarchy population and other semantic checks. + """ + + def with_cls(self, cls: type["BaseState"]) -> Self: + """Return a new token with the cls field updated to the provided class. + + Args: + cls: The class to update the cls field to. + + Returns: + A new StateToken instance with the updated cls field. + """ + return super().with_cls(cls) + + def __str__(self) -> str: + """The key used in the underlying StateManager store. + + Returns: + A string representation of the token, which is a combination of the ident and cls name. + """ + # urlencode the redis token to escape the slash delimiter. + return f"{self.ident}_{self.cls.get_full_name()}" + + @classmethod + def serialize(cls, state: BaseState) -> bytes: + """Serialize the BaseState for redis/disk storage. + + Args: + state: The BaseState to serialize. + + Returns: + The serialized state. + """ + return state._serialize() + + @classmethod + def deserialize( + cls, data: bytes | None = None, fp: BinaryIO | None = None + ) -> BaseState: + """Deserialize the BaseState from redis/disk. + + data and fp are mutually exclusive, but one must be provided. + + Args: + data: The serialized state data. + fp: The file pointer to the serialized state data. + + Returns: + The deserialized BaseState instance. + """ + from reflex.state import BaseState + + return BaseState._deserialize(data, fp) + + @classmethod + def get_and_reset_touched_state(cls, state: BaseState) -> bool: + """Get the touched state and reset the touched flag. + + This is used to determine if a state has been modified since it was last serialized. + + Args: + state: The BaseState to check for modifications. + + Returns: + The touched state of the BaseState. + """ + was_touched = state._get_was_touched() + state._was_touched = False # Reset the touched flag after serializing. + return was_touched diff --git a/reflex/istate/proxy.py b/reflex/istate/proxy.py index 90231eeb7fc..1ba14da5c4c 100644 --- a/reflex/istate/proxy.py +++ b/reflex/istate/proxy.py @@ -18,6 +18,7 @@ from typing_extensions import Self from reflex.base import Base +from reflex.istate.manager.token import BaseStateToken from reflex.utils import prerequisites from reflex.utils.exceptions import ImmutableStateError from reflex.utils.serializers import can_serialize, serialize, serializer @@ -71,14 +72,12 @@ def __init__( state_instance: The state instance to proxy. parent_state_proxy: The parent state proxy, for linked mutability and context tracking. """ - from reflex.state import _substate_key - super().__init__(state_instance) self._self_app = prerequisites.get_and_validate_app().app self._self_substate_path = tuple(state_instance.get_full_name().split(".")) - self._self_substate_token = _substate_key( - state_instance.router.session.client_token, - self._self_substate_path, + self._self_substate_token = BaseStateToken( + ident=state_instance.router.session.client_token, + cls=state_instance.__class__, ) self._self_actx = None self._self_mutable = False diff --git a/reflex/istate/shared.py b/reflex/istate/shared.py index 01bf154d913..3fa2f8ffa37 100644 --- a/reflex/istate/shared.py +++ b/reflex/istate/shared.py @@ -7,7 +7,8 @@ from reflex.constants import ROUTER_DATA from reflex.event import Event, get_hydrate_event -from reflex.state import BaseState, State, _override_base_method, _substate_key +from reflex.istate.manager.token import BaseStateToken +from reflex.state import BaseState, State, _override_base_method from reflex.utils import console from reflex.utils.exceptions import ReflexRuntimeError @@ -52,7 +53,7 @@ def _do_update_other_tokens( async def _update_client(token: str): async with app.modify_state( - _substate_key(token, state_type), + BaseStateToken(ident=token, cls=state_type), previous_dirty_vars=previous_dirty_vars, ): pass @@ -236,7 +237,10 @@ async def _unlink(self): # Patch in the original state, apply updates, then rehydrate. private_root_state = await get_state_manager().get_state( - _substate_key(self.router.session.client_token, type(self)) + BaseStateToken( + ident=self.router.session.client_token, + cls=type(self), + ) ) private_state = await private_root_state.get_state(type(self)) async with _patch_state( @@ -271,12 +275,14 @@ async def _internal_patch_linked_state( # Get the newly linked state and update pointers/delta for subsequent events. if token not in self._held_locks: linked_root_state = await self._exit_stack.enter_async_context( - get_state_manager().modify_state(_substate_key(token, type(self))) + get_state_manager().modify_state( + BaseStateToken(ident=token, cls=type(self)) + ) ) self._held_locks.setdefault(token, {}) else: linked_root_state = await get_state_manager().get_state( - _substate_key(token, type(self)) + BaseStateToken(ident=token, cls=type(self)) ) linked_state = await linked_root_state.get_state(type(self)) if not isinstance(linked_state, SharedState): diff --git a/reflex/istate/wrappers.py b/reflex/istate/wrappers.py index 865bd6c6383..dada8cf9e3a 100644 --- a/reflex/istate/wrappers.py +++ b/reflex/istate/wrappers.py @@ -2,8 +2,9 @@ from typing import Any +from reflex.istate.manager.token import BaseStateToken from reflex.istate.proxy import ReadOnlyStateProxy -from reflex.state import _split_substate_key, _substate_key, get_state_manager +from reflex.state import State, _split_substate_key, get_state_manager async def get_state(token: str, state_cls: Any | None = None) -> ReadOnlyStateProxy: @@ -18,9 +19,9 @@ async def get_state(token: str, state_cls: Any | None = None) -> ReadOnlyStatePr """ mng = get_state_manager() if state_cls is not None: - root_state = await mng.get_state(_substate_key(token, state_cls)) + root_state = await mng.get_state(BaseStateToken(ident=token, cls=state_cls)) else: - root_state = await mng.get_state(token) + root_state = await mng.get_state(BaseStateToken(ident=token, cls=State)) _, state_path = _split_substate_key(token) state_cls = root_state.get_class_substate(tuple(state_path.split("."))) instance = await root_state.get_state(state_cls) diff --git a/reflex/plugins/_screenshot.py b/reflex/plugins/_screenshot.py index f0741c9fd0e..53d3257c832 100644 --- a/reflex/plugins/_screenshot.py +++ b/reflex/plugins/_screenshot.py @@ -97,7 +97,8 @@ async def clone_state(request: "Request") -> "Response": from starlette.responses import JSONResponse - from reflex.state import _substate_key + from reflex.istate.manager.token import BaseStateToken + from reflex.state import State if not app.event_namespace: return JSONResponse({}) @@ -109,7 +110,9 @@ async def clone_state(request: "Request") -> "Response": {"error": "Token to clone must be a string."}, status_code=400 ) - old_state = await app.state_manager.get_state(token_to_clone) + old_state = await app.state_manager.get_state( + BaseStateToken(ident=token_to_clone, cls=State), + ) new_state = _deep_copy(old_state) @@ -132,7 +135,8 @@ async def clone_state(request: "Request") -> "Response": found_new = True await app.state_manager.set_state( - _substate_key(new_token, new_state), new_state + BaseStateToken(ident=new_token, cls=type(new_state)), + new_state, ) return JSONResponse(new_token) diff --git a/reflex/state.py b/reflex/state.py index 49c6703088b..47edd872d8f 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -1609,6 +1609,7 @@ async def _get_state_from_redis(self, state_cls: type[T_STATE]) -> T_STATE: StateMismatchError: If the state instance is not of the expected type. """ from reflex.istate.manager.redis import StateManagerRedis + from reflex.istate.manager.token import BaseStateToken # Then get the target state and all its substates. state_manager = get_state_manager() @@ -1619,7 +1620,7 @@ async def _get_state_from_redis(self, state_cls: type[T_STATE]) -> T_STATE: ) raise RuntimeError(msg) state_in_redis = await state_manager.get_state( - token=_substate_key(self.router.session.client_token, state_cls), + token=BaseStateToken(ident=self.router.session.client_token, cls=state_cls), top_level=False, for_state_instance=self, ) diff --git a/reflex/testing.py b/reflex/testing.py index 4ab72602334..46792d06c58 100644 --- a/reflex/testing.py +++ b/reflex/testing.py @@ -40,8 +40,10 @@ from reflex.istate.manager.disk import StateManagerDisk from reflex.istate.manager.memory import StateManagerMemory from reflex.istate.manager.redis import StateManagerRedis +from reflex.istate.manager.token import BaseStateToken from reflex.state import ( BaseState, + State, StateManager, _split_substate_key, reload_state_module, @@ -738,7 +740,9 @@ async def get_state(self, token: str) -> BaseState: client_token, _ = _split_substate_key(token) self.state_manager.states.pop(client_token, None) try: - return await self.state_manager.get_state(token) + return await self.state_manager.get_state( + BaseStateToken(ident=token, cls=State) + ) finally: await self.state_manager.close() @@ -759,7 +763,9 @@ async def set_state(self, token: str, **kwargs) -> None: for key, value in kwargs.items(): setattr(state, key, value) try: - await self.state_manager.set_state(token, state) + await self.state_manager.set_state( + BaseStateToken(ident=token, cls=type(state)), state + ) finally: if self.app_instance is not None and isinstance( self.app_instance.state_manager, StateManagerDisk @@ -785,7 +791,7 @@ async def modify_state(self, token: str) -> AsyncIterator[BaseState]: if self.state_manager is None: msg = "state_manager is not set." raise RuntimeError(msg) - if self.app_instance is None: + if self.app_instance is None or self.app_instance._state is None: msg = "App is not running." raise RuntimeError(msg) app_state_manager = self.app_instance.state_manager @@ -794,7 +800,9 @@ async def modify_state(self, token: str) -> AsyncIterator[BaseState]: # the redis/disk connection is on the backend_thread event loop self.app_instance._state_manager = self.state_manager try: - async with self.app_instance.modify_state(token) as state: + async with self.app_instance.modify_state( + BaseStateToken(ident=token, cls=self.app_instance._state) + ) as state: yield state finally: if isinstance(app_state_manager, StateManagerDisk): diff --git a/reflex/utils/token_manager.py b/reflex/utils/token_manager.py index 514d641cea6..f16a2ff386a 100644 --- a/reflex/utils/token_manager.py +++ b/reflex/utils/token_manager.py @@ -12,7 +12,7 @@ from typing import TYPE_CHECKING, ClassVar from reflex.istate.manager.redis import StateManagerRedis -from reflex.state import BaseState, StateUpdate +from reflex.state import StateUpdate from reflex.utils import console, prerequisites from reflex.utils.tasks import ensure_task @@ -248,9 +248,7 @@ async def _handle_socket_record_del( async def _subscribe_socket_record_updates(self) -> None: """Subscribe to Redis keyspace notifications for socket record updates.""" - await StateManagerRedis( - state=BaseState, redis=self.redis - )._enable_keyspace_notifications() + await StateManagerRedis(redis=self.redis)._enable_keyspace_notifications() redis_db = self.redis.get_connection_kwargs().get("db", 0) async with self.redis.pubsub() as pubsub: From c4b191f426ecc60db1e972c0ce4e15f7675dc523 Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Sun, 8 Mar 2026 05:00:08 -0700 Subject: [PATCH 02/81] Use StateToken with redis and memory managers Update all unit test cases to use the new StateToken / BaseStateToken API --- reflex/event.py | 8 +- reflex/istate/manager/disk.py | 28 +++-- reflex/istate/manager/memory.py | 44 ++++---- reflex/istate/manager/redis.py | 137 +++++++++++++---------- reflex/testing.py | 8 +- tests/integration/test_client_storage.py | 2 +- tests/units/istate/manager/test_redis.py | 89 ++++++++------- tests/units/test_app.py | 58 ++++++---- tests/units/test_state.py | 127 ++++++++++++--------- tests/units/test_state_tree.py | 7 +- 10 files changed, 294 insertions(+), 214 deletions(-) diff --git a/reflex/event.py b/reflex/event.py index 730ecd81b84..5dab0461e47 100644 --- a/reflex/event.py +++ b/reflex/event.py @@ -87,10 +87,16 @@ def substate_token(self) -> BaseStateToken: """ from reflex.istate.manager.token import BaseStateToken from reflex.state import State + from reflex.utils.prerequisites import get_app + + app = get_app().app + + root_state = State if app._state is None else app._state substate = self.name.rpartition(".")[0] return BaseStateToken( - ident=self.token, cls=State.get_class_substate(tuple(substate.split("."))) + ident=self.token, + cls=root_state.get_class_substate(tuple(substate.split("."))), ) diff --git a/reflex/istate/manager/disk.py b/reflex/istate/manager/disk.py index 5a6ee7ecef9..d57068e96f1 100644 --- a/reflex/istate/manager/disk.py +++ b/reflex/istate/manager/disk.py @@ -168,17 +168,23 @@ async def get_state( Returns: The state for the token. """ - self._token_last_touched[token.ident] = time.time() - root_state = self.states.get(token.ident) + if isinstance(token, BaseStateToken): + root_state = self.states.get(token.ident) + self._token_last_touched[token.ident] = time.time() + else: + root_state = self.states.get(str(token)) + self._token_last_touched[str(token)] = time.time() if root_state is not None: # Retrieved state from memory. return root_state # Deserialize root state from disk. - root_state = await self.load_state(token) if isinstance(token, BaseStateToken): + # Find the root state + root_state_cls = token.cls.get_root_state() + root_state = await self.load_state(token.with_cls(root_state_cls)) # Create a new root state tree with all substates instantiated. - fresh_root_state = token.cls(_reflex_internal_init=True) + fresh_root_state = root_state_cls(_reflex_internal_init=True) if root_state is None: root_state = fresh_root_state elif not isinstance(root_state, BaseState): @@ -188,12 +194,14 @@ async def get_state( # Ensure all substates exist, even if they were not serialized previously. root_state.substates = fresh_root_state.substates await self.populate_substates(token, root_state, root_state) - else: - # For non-BaseState tokens, if the deserialized state is None, we create a new instance using the token's cls. - if root_state is None: - root_state = token.cls() - self.states[token.ident] = root_state - return cast(TOKEN_TYPE, root_state) + self.states[token.ident] = root_state + return cast(TOKEN_TYPE, root_state) + # For non-BaseState tokens, if the deserialized state is None, we create a new instance using the token's cls. + state = await self.load_state(token) + if state is None: + state = token.cls() + self.states[str(token)] = state + return cast(TOKEN_TYPE, state) async def set_state_for_substate( self, token: StateToken[TOKEN_TYPE], substate: TOKEN_TYPE diff --git a/reflex/istate/manager/memory.py b/reflex/istate/manager/memory.py index d1040470a88..0f567749d60 100644 --- a/reflex/istate/manager/memory.py +++ b/reflex/istate/manager/memory.py @@ -4,11 +4,12 @@ import contextlib import dataclasses from collections.abc import AsyncIterator +from typing import Any, cast from typing_extensions import Unpack, override from reflex.istate.manager import StateManager, StateModificationContext -from reflex.state import BaseState, _split_substate_key +from reflex.istate.manager.token import TOKEN_TYPE, BaseStateToken, StateToken @dataclasses.dataclass @@ -16,7 +17,7 @@ class StateManagerMemory(StateManager): """A state manager that stores states in memory.""" # The mapping of client ids to states. - states: dict[str, BaseState] = dataclasses.field(default_factory=dict) + states: dict[str, Any] = dataclasses.field(default_factory=dict) # The mutex ensures the dict of mutexes is updated exclusively _state_manager_lock: asyncio.Lock = dataclasses.field(default=asyncio.Lock()) @@ -27,7 +28,7 @@ class StateManagerMemory(StateManager): ) @override - async def get_state(self, token: str) -> BaseState: + async def get_state(self, token: StateToken[TOKEN_TYPE]) -> TOKEN_TYPE: """Get the state for a token. Args: @@ -36,17 +37,21 @@ async def get_state(self, token: str) -> BaseState: Returns: The state for the token. """ - # Memory state manager ignores the substate suffix and always returns the top-level state. - token = _split_substate_key(token)[0] - if token not in self.states: - self.states[token] = self.state(_reflex_internal_init=True) - return self.states[token] + key = token.ident if isinstance(token, BaseStateToken) else str(token) + if key not in self.states: + if isinstance(token, BaseStateToken): + self.states[key] = token.cls.get_root_state()( + _reflex_internal_init=True + ) + else: + self.states[key] = token.cls() + return cast(TOKEN_TYPE, self.states[key]) @override async def set_state( self, - token: str, - state: BaseState, + token: StateToken[TOKEN_TYPE], + state: TOKEN_TYPE, **context: Unpack[StateModificationContext], ): """Set the state for a token. @@ -56,14 +61,14 @@ async def set_state( state: The state to set. context: The state modification context. """ - token = _split_substate_key(token)[0] - self.states[token] = state + key = token.ident if isinstance(token, BaseStateToken) else str(token) + self.states[key] = state @override @contextlib.asynccontextmanager async def modify_state( - self, token: str, **context: Unpack[StateModificationContext] - ) -> AsyncIterator[BaseState]: + self, token: StateToken[TOKEN_TYPE], **context: Unpack[StateModificationContext] + ) -> AsyncIterator[TOKEN_TYPE]: """Modify the state for a token while holding exclusive lock. Args: @@ -73,13 +78,12 @@ async def modify_state( Yields: The state for the token. """ - # Memory state manager ignores the substate suffix and always returns the top-level state. - token = _split_substate_key(token)[0] - if token not in self._states_locks: + if token.ident not in self._states_locks: async with self._state_manager_lock: - if token not in self._states_locks: - self._states_locks[token] = asyncio.Lock() + if token.ident not in self._states_locks: + self._states_locks[token.ident] = asyncio.Lock() - async with self._states_locks[token]: + async with self._states_locks[token.ident]: state = await self.get_state(token) yield state + await self.set_state(token, state, **context) diff --git a/reflex/istate/manager/redis.py b/reflex/istate/manager/redis.py index 63fc90586aa..e7dd95857d2 100644 --- a/reflex/istate/manager/redis.py +++ b/reflex/istate/manager/redis.py @@ -9,7 +9,7 @@ import time import uuid from collections.abc import AsyncIterator -from typing import TypedDict +from typing import Any, TypedDict, cast from redis import ResponseError from redis.asyncio import Redis @@ -22,7 +22,8 @@ StateModificationContext, _default_token_expiration, ) -from reflex.state import BaseState, _split_substate_key, _substate_key +from reflex.istate.manager.token import TOKEN_TYPE, BaseStateToken, StateToken +from reflex.state import BaseState from reflex.utils import console from reflex.utils.exceptions import ( InvalidLockWarningThresholdError, @@ -134,9 +135,7 @@ class StateManagerRedis(StateManager): ) # Cached states - _cached_states: dict[str, BaseState] = dataclasses.field( - default_factory=dict, init=False - ) + _cached_states: dict[str, Any] = dataclasses.field(default_factory=dict, init=False) _cached_states_locks: dict[str, asyncio.Lock] = dataclasses.field( default_factory=dict, init=False ) @@ -257,32 +256,31 @@ def _get_populated_states( @override async def get_state( self, - token: str, + token: StateToken[TOKEN_TYPE], top_level: bool = True, for_state_instance: BaseState | None = None, - ) -> BaseState: + ) -> TOKEN_TYPE: """Get the state for a token. Args: token: The token to get the state for. - top_level: If true, return an instance of the top-level state (self.state). + top_level: If true, return the top-level root state. for_state_instance: If provided, attach the requested states to this existing state tree. Returns: The state for the token. Raises: - RuntimeError: when the state_cls is not specified in the token, or when the parent state for a - requested state was not fetched. + RuntimeError: when the parent state for a requested state was not fetched. """ - # Split the actual token from the fully qualified substate name. - token, state_path = _split_substate_key(token) - if state_path: - # Get the State class associated with the given path. - requested_state_cls = self.state.get_class_substate(state_path) - else: - msg = f"StateManagerRedis requires token to be specified in the form of {{token}}_{{state_full_name}}, but got {token}" - raise RuntimeError(msg) + if not isinstance(token, BaseStateToken): + # Non-BaseState token: simple single-key fetch. + redis_data = await self.redis.get(str(token)) + if redis_data is not None: + return token.deserialize(data=redis_data) + return token.cls() + + requested_state_cls = token.cls # Determine which states we already have. flat_state_tree: dict[str, BaseState] = ( @@ -298,7 +296,7 @@ async def get_state( redis_pipeline = self.redis.pipeline() for state_cls in required_state_classes: - redis_pipeline.get(_substate_key(token, state_cls)) + redis_pipeline.get(str(token.with_cls(state_cls))) for state_cls, redis_state in zip( required_state_classes, @@ -336,14 +334,17 @@ async def get_state( # To retain compatibility with previous implementation, by default, we return # the top-level state which should always be fetched or already cached. if top_level: - return flat_state_tree[self.state.get_full_name()] - return flat_state_tree[requested_state_cls.get_full_name()] + return cast( + TOKEN_TYPE, + flat_state_tree[requested_state_cls.get_root_state().get_full_name()], + ) + return cast(TOKEN_TYPE, flat_state_tree[requested_state_cls.get_full_name()]) @override async def set_state( self, - token: str, - state: BaseState, + token: StateToken[TOKEN_TYPE], + state: TOKEN_TYPE, *, lock_id: bytes | None = None, **context: Unpack[StateModificationContext], @@ -377,7 +378,17 @@ async def set_state( ) raise LockExpiredError(msg) - client_token, substate_name = _split_substate_key(token) + if not isinstance(token, BaseStateToken): + # Non-BaseState token: simple single-key write. + pickle_state = token.serialize(state) + if pickle_state: + await self.redis.set(str(token), pickle_state, ex=self.token_expiration) + return + + base_state = cast(BaseState, state) + + client_token = token.ident + substate_name = token.cls.get_full_name() if lock_id is not None and client_token not in self._local_leases: time_taken = ( @@ -396,29 +407,32 @@ async def set_state( ) # If the substate name on the token doesn't match the instance name, it cannot have a parent. - if state.parent_state is not None and state.get_full_name() != substate_name: - msg = f"Cannot `set_state` with mismatching token {token} and substate {state.get_full_name()}." + if ( + base_state.parent_state is not None + and base_state.get_full_name() != substate_name + ): + msg = f"Cannot `set_state` with mismatching token {token} and substate {base_state.get_full_name()}." raise RuntimeError(msg) # Recursively set_state on all known substates. tasks = [ asyncio.create_task( self.set_state( - _substate_key(client_token, substate), + token.with_cls(type(substate)), substate, lock_id=lock_id, **context, ), name=f"reflex_set_state|{client_token}|{substate.get_full_name()}", ) - for substate in state.substates.values() + for substate in base_state.substates.values() ] # Persist only the given state (parents or substates are excluded by BaseState.__getstate__). - if state._get_was_touched(): - pickle_state = state._serialize() + if base_state._get_was_touched(): + pickle_state = base_state._serialize() if pickle_state: await self.redis.set( - _substate_key(client_token, state), + str(token), pickle_state, ex=self.token_expiration, ) @@ -429,8 +443,8 @@ async def set_state( @contextlib.asynccontextmanager async def _try_modify_state( - self, token: str, **context: Unpack[StateModificationContext] - ) -> AsyncIterator[BaseState | None]: + self, token: StateToken[TOKEN_TYPE], **context: Unpack[StateModificationContext] + ) -> AsyncIterator[TOKEN_TYPE | None]: """Modify the state for a token while holding exclusive lock. Args: @@ -456,7 +470,7 @@ async def _try_modify_state( return # Opportunistic locking is enabled, so try to hold the lock across multiple calls. - client_token, _ = _split_substate_key(token) + client_token = token.ident lock_held_ctx = contextlib.AsyncExitStack() try: lock_id = await lock_held_ctx.enter_async_context(self._lock(token)) @@ -514,8 +528,8 @@ async def _try_modify_state( @override @contextlib.asynccontextmanager async def modify_state( - self, token: str, **context: Unpack[StateModificationContext] - ) -> AsyncIterator[BaseState]: + self, token: StateToken[TOKEN_TYPE], **context: Unpack[StateModificationContext] + ) -> AsyncIterator[TOKEN_TYPE]: """Modify the state for a token while holding exclusive lock. Args: @@ -528,11 +542,13 @@ async def modify_state( while True: async with self._try_modify_state(token, **context) as state_instance: if state_instance is not None: - yield state_instance + yield cast(TOKEN_TYPE, state_instance) return @contextlib.asynccontextmanager - async def _get_state_cached(self, token: str) -> AsyncIterator[BaseState | None]: + async def _get_state_cached( + self, token: StateToken[TOKEN_TYPE] + ) -> AsyncIterator[TOKEN_TYPE | None]: """Get the cached state for a token, while holding the local lease lock. Args: @@ -540,11 +556,8 @@ async def _get_state_cached(self, token: str) -> AsyncIterator[BaseState | None] Yields: The cached state for the token, or None if not cached/uncachable. - - Raises: - RuntimeError: when the state_cls is not specified in the token. """ - client_token, state_path = _split_substate_key(token) + client_token = token.ident # Opportunistically reuse existing lock. if ( client_token in self._local_leases @@ -555,17 +568,23 @@ async def _get_state_cached(self, token: str) -> AsyncIterator[BaseState | None] if ( cached_state := self._cached_states.get(client_token) ) is not None: - # Make sure we have the substate cached (or fetch it from redis). - try: - substate = cached_state.get_substate(state_path.split(".")) - if len(substate.substates) != len( - type(substate).get_substates() - ): - # If the substate is missing substates, we need to refetch it. - raise ValueError # noqa: TRY301 - except ValueError: - await self.get_state(token, for_state_instance=cached_state) - yield cached_state + if isinstance(token, BaseStateToken): + # Make sure we have the substate cached (or fetch it from redis). + state_path = token.cls.get_full_name() + try: + substate = cached_state.get_substate( + state_path.split(".") + ) + if len(substate.substates) != len( + type(substate).get_substates() + ): + # If the substate is missing substates, we need to refetch it. + raise ValueError # noqa: TRY301 + except ValueError: + await self.get_state( + token, for_state_instance=cached_state + ) + yield cast(TOKEN_TYPE, cached_state) return elif self._debug_enabled: console.debug( @@ -595,7 +614,7 @@ def _notify_next_waiter(self, key: bytes): async def _create_lease_break_task( self, - token: str, + token: StateToken[TOKEN_TYPE], lock_id: bytes, cleanup_ctx: contextlib.AsyncExitStack, **context: Unpack[StateModificationContext], @@ -613,7 +632,7 @@ async def _create_lease_break_task( """ self._ensure_lock_task() - client_token, _ = _split_substate_key(token) + client_token = token.ident async def do_flush() -> None: if (state_lock := self._cached_states_locks.get(client_token)) is None: @@ -709,7 +728,7 @@ async def lease_breaker(): return None @staticmethod - def _lock_key(token: str) -> bytes: + def _lock_key(token: StateToken[Any]) -> bytes: """Get the redis key for a token's lock. Args: @@ -718,9 +737,7 @@ def _lock_key(token: str) -> bytes: Returns: The redis lock key for the token. """ - # All substates share the same lock domain, so ignore any substate path suffix. - client_token = _split_substate_key(token)[0] - return f"{client_token}_lock".encode() + return f"{token.ident}_lock".encode() async def _try_extend_lock(self, lock_key: bytes) -> bool | None: """Extends the current lock for another lock_expiration period. @@ -1000,7 +1017,7 @@ async def _wait_lock(self, lock_key: bytes, lock_id: bytes) -> None: ) @contextlib.asynccontextmanager - async def _lock(self, token: str): + async def _lock(self, token: StateToken[Any]): """Obtain a redis lock for a token. Args: diff --git a/reflex/testing.py b/reflex/testing.py index 46792d06c58..6d384c54a0d 100644 --- a/reflex/testing.py +++ b/reflex/testing.py @@ -289,9 +289,9 @@ def _initialize_app(self): raise RuntimeError(msg) if isinstance(self.app_instance._state_manager, StateManagerRedis): # Create our own redis connection for testing. - self.state_manager = StateManagerRedis.create(self.app_instance._state) + self.state_manager = StateManagerRedis.create() elif isinstance(self.app_instance._state_manager, StateManagerDisk): - self.state_manager = StateManagerDisk.create(self.app_instance._state) + self.state_manager = StateManagerDisk.create() if self.state_manager is None: self.state_manager = ( self.app_instance._state_manager if self.app_instance else None @@ -378,9 +378,7 @@ async def _reset_backend_state_manager(self): ) and self.app_instance._state is not None ): - self.app_instance._state_manager = StateManagerRedis.create( - state=self.app_instance._state, - ) + self.app_instance._state_manager = StateManagerRedis.create() if not isinstance(self.app_instance.state_manager, StateManagerRedis): msg = "Failed to reset state manager." raise RuntimeError(msg) diff --git a/tests/integration/test_client_storage.py b/tests/integration/test_client_storage.py index 80766407d31..01c2239c768 100644 --- a/tests/integration/test_client_storage.py +++ b/tests/integration/test_client_storage.py @@ -660,7 +660,7 @@ def set_sub_sub(var: str, value: str): ): # Purge the backend's disk manager app_state_manager.states.pop(token, None) - app_state_manager._write_queue.pop(token, None) + app_state_manager._write_queue.clear() og_token_expiration = app_state_manager.token_expiration app_state_manager.token_expiration = 0 app_state_manager._purge_expired_states() diff --git a/tests/units/istate/manager/test_redis.py b/tests/units/istate/manager/test_redis.py index 076268ca9c6..f93bf955d68 100644 --- a/tests/units/istate/manager/test_redis.py +++ b/tests/units/istate/manager/test_redis.py @@ -11,7 +11,8 @@ import pytest_asyncio from reflex.istate.manager.redis import StateManagerRedis -from reflex.state import BaseState, _substate_key +from reflex.istate.manager.token import BaseStateToken +from reflex.state import BaseState from tests.units.mock_redis import mock_redis, real_redis @@ -43,7 +44,7 @@ async def state_manager_redis( async with real_redis() as redis: if redis is None: redis = mock_redis() - state_manager = StateManagerRedis(state=root_state, redis=redis) + state_manager = StateManagerRedis(redis=redis) test_start = time.monotonic() yield state_manager # None of the tests should have triggered a lock expiration. @@ -80,10 +81,14 @@ async def test_basic_get_set( token = str(uuid.uuid4()) - fresh_state = await state_manager_redis.get_state(_substate_key(token, root_state)) + fresh_state = await state_manager_redis.get_state( + BaseStateToken(ident=token, cls=root_state) + ) fresh_state.foo = "baz" fresh_state.count = 42 - await state_manager_redis.set_state(_substate_key(token, root_state), fresh_state) + await state_manager_redis.set_state( + BaseStateToken(ident=token, cls=root_state), fresh_state + ) async def test_modify( @@ -102,19 +107,21 @@ async def test_modify( # Initial modify should set count to 1 async with state_manager_redis.modify_state( - _substate_key(token, root_state) + BaseStateToken(ident=token, cls=root_state) ) as new_state: new_state.count = 1 # Subsequent modify should set count to 2 async with state_manager_redis.modify_state( - _substate_key(token, root_state) + BaseStateToken(ident=token, cls=root_state) ) as new_state: assert isinstance(new_state, root_state) assert new_state.count == 1 new_state.count += 2 - final_state = await state_manager_redis.get_state(_substate_key(token, root_state)) + final_state = await state_manager_redis.get_state( + BaseStateToken(ident=token, cls=root_state) + ) assert isinstance(final_state, root_state) assert final_state.count == 3 @@ -136,16 +143,14 @@ async def test_modify_oplock( state_manager_redis._debug_enabled = True state_manager_redis._oplock_enabled = True - state_manager_2 = StateManagerRedis( - state=root_state, redis=state_manager_redis.redis - ) + state_manager_2 = StateManagerRedis(redis=state_manager_redis.redis) state_manager_2._debug_enabled = True state_manager_2._oplock_enabled = True # Initial modify should set count to 1 async with state_manager_redis.modify_state( - _substate_key(token, root_state), + BaseStateToken(ident=token, cls=root_state), ) as new_state: new_state.count = 1 @@ -168,7 +173,7 @@ async def test_modify_oplock( # The second modify should NOT trigger another redis lock async with state_manager_redis.modify_state( - _substate_key(token, root_state), + BaseStateToken(ident=token, cls=root_state), ) as new_state: new_state.count = 2 assert state_lock_1.locked() @@ -183,7 +188,7 @@ async def test_modify_oplock( # Contend the lock from another state manager async with state_manager_2.modify_state( - _substate_key(token, root_state), + BaseStateToken(ident=token, cls=root_state), ) as new_state: new_state.count = 3 state_lock_2 = state_manager_2._cached_states_locks.get(token) @@ -265,9 +270,7 @@ async def test_oplock_contention_queue( state_manager_redis._debug_enabled = True state_manager_redis._oplock_enabled = True - state_manager_2 = StateManagerRedis( - state=root_state, redis=state_manager_redis.redis - ) + state_manager_2 = StateManagerRedis(redis=state_manager_redis.redis) state_manager_2._debug_enabled = True state_manager_2._oplock_enabled = True @@ -279,7 +282,7 @@ async def test_oplock_contention_queue( async def modify_1(): async with state_manager_redis.modify_state( - _substate_key(token, root_state), + BaseStateToken(ident=token, cls=root_state), ) as new_state: assert isinstance(new_state, root_state) new_state.count += 1 @@ -290,7 +293,7 @@ async def modify_2(): await modify_started.wait() modify_2_started.set() async with state_manager_2.modify_state( - _substate_key(token, root_state), + BaseStateToken(ident=token, cls=root_state), ) as new_state: assert isinstance(new_state, root_state) new_state.count += 1 @@ -300,7 +303,7 @@ async def modify_3(): await modify_started.wait() modify_2_started.set() async with state_manager_2.modify_state( - _substate_key(token, root_state), + BaseStateToken(ident=token, cls=root_state), ) as new_state: assert isinstance(new_state, root_state) new_state.count += 1 @@ -323,14 +326,16 @@ async def modify_3(): await task_3 interim_state = await state_manager_redis.get_state( - _substate_key(token, root_state) + BaseStateToken(ident=token, cls=root_state) ) assert isinstance(interim_state, root_state) assert interim_state.count == 1 await state_manager_2.close() - final_state = await state_manager_redis.get_state(_substate_key(token, root_state)) + final_state = await state_manager_redis.get_state( + BaseStateToken(ident=token, cls=root_state) + ) assert isinstance(final_state, root_state) assert final_state.count == 3 @@ -360,16 +365,12 @@ async def test_oplock_contention_no_lease( state_manager_redis._debug_enabled = True state_manager_redis._oplock_enabled = True - state_manager_2 = StateManagerRedis( - state=root_state, redis=state_manager_redis.redis - ) + state_manager_2 = StateManagerRedis(redis=state_manager_redis.redis) state_manager_2._debug_enabled = True state_manager_2._oplock_enabled = True - state_manager_3 = StateManagerRedis( - state=root_state, redis=state_manager_redis.redis - ) + state_manager_3 = StateManagerRedis(redis=state_manager_redis.redis) state_manager_3._debug_enabled = True state_manager_3._oplock_enabled = True @@ -380,7 +381,7 @@ async def test_oplock_contention_no_lease( async def modify_1(): async with state_manager_redis.modify_state( - _substate_key(token, root_state), + BaseStateToken(ident=token, cls=root_state), ) as new_state: assert isinstance(new_state, root_state) new_state.count += 1 @@ -391,7 +392,7 @@ async def modify_2(): await modify_started.wait() modify_2_started.set() async with state_manager_2.modify_state( - _substate_key(token, root_state), + BaseStateToken(ident=token, cls=root_state), ) as new_state: assert isinstance(new_state, root_state) new_state.count += 1 @@ -401,7 +402,7 @@ async def modify_3(): await modify_started.wait() modify_2_started.set() async with state_manager_3.modify_state( - _substate_key(token, root_state), + BaseStateToken(ident=token, cls=root_state), ) as new_state: assert isinstance(new_state, root_state) new_state.count += 1 @@ -437,7 +438,9 @@ async def modify_3(): await state_manager_2.close() await state_manager_3.close() - final_state = await state_manager_2.get_state(_substate_key(token, root_state)) + final_state = await state_manager_2.get_state( + BaseStateToken(ident=token, cls=root_state) + ) assert isinstance(final_state, root_state) assert final_state.count == 3 @@ -469,9 +472,7 @@ async def test_oplock_contention_racers( state_manager_redis._debug_enabled = True state_manager_redis._oplock_enabled = True - state_manager_2 = StateManagerRedis( - state=root_state, redis=state_manager_redis.redis - ) + state_manager_2 = StateManagerRedis(redis=state_manager_redis.redis) state_manager_2._debug_enabled = True state_manager_2._oplock_enabled = True lease_1 = None @@ -480,7 +481,7 @@ async def test_oplock_contention_racers( async def modify_1(): nonlocal lease_1 async with state_manager_redis.modify_state( - _substate_key(token, root_state), + BaseStateToken(ident=token, cls=root_state), ) as new_state: lease_1 = await state_manager_redis._get_local_lease(token) assert isinstance(new_state, root_state) @@ -491,7 +492,7 @@ async def modify_2(): await asyncio.sleep(racer_delay) nonlocal lease_2 async with state_manager_2.modify_state( - _substate_key(token, root_state), + BaseStateToken(ident=token, cls=root_state), ) as new_state: lease_2 = await state_manager_2._get_local_lease(token) assert isinstance(new_state, root_state) @@ -540,7 +541,7 @@ async def canceller(): task = asyncio.create_task(canceller()) async with state_manager_redis.modify_state( - _substate_key(token, root_state), + BaseStateToken(ident=token, cls=root_state), ) as new_state: assert await state_manager_redis._get_local_lease(token) is None assert isinstance(new_state, root_state) @@ -575,20 +576,20 @@ class SubState2(root_state): state_manager_redis._oplock_enabled = True async with state_manager_redis.modify_state( - _substate_key(token, SubState1), + BaseStateToken(ident=token, cls=SubState1), ) as new_state: assert SubState1.get_name() in new_state.substates assert SubState2.get_name() not in new_state.substates async with state_manager_redis.modify_state( - _substate_key(token, SubState2), + BaseStateToken(ident=token, cls=SubState2), ) as new_state: # Both substates should be fetched and cached. assert SubState1.get_name() in new_state.substates assert SubState2.get_name() in new_state.substates async with state_manager_redis.modify_state( - _substate_key(token, SubState1), + BaseStateToken(ident=token, cls=SubState1), ) as new_state: # Both substates should be fetched and cached now. assert SubState1.get_name() in new_state.substates @@ -648,7 +649,7 @@ async def test_oplock_hold_oplock_after_cancel( async def modify(): async with state_manager_redis.modify_state( - _substate_key(token, root_state), + BaseStateToken(ident=token, cls=root_state), ) as new_state: modify_started.set() assert isinstance(new_state, root_state) @@ -684,7 +685,7 @@ async def modify(): # Modify the state again, this should get a new lock and lease async with state_manager_redis.modify_state( - _substate_key(token, root_state), + BaseStateToken(ident=token, cls=root_state), ) as new_state: assert isinstance(new_state, root_state) new_state.count += 1 @@ -700,6 +701,8 @@ async def modify(): await state_manager_redis.close() # Both increments should be present. - final_state = await state_manager_redis.get_state(_substate_key(token, root_state)) + final_state = await state_manager_redis.get_state( + BaseStateToken(ident=token, cls=root_state) + ) assert isinstance(final_state, root_state) assert final_state.count == 2 diff --git a/tests/units/test_app.py b/tests/units/test_app.py index 25c71c0d17e..c61a4a1f06b 100644 --- a/tests/units/test_app.py +++ b/tests/units/test_app.py @@ -39,16 +39,10 @@ from reflex.istate.manager.disk import StateManagerDisk from reflex.istate.manager.memory import StateManagerMemory from reflex.istate.manager.redis import StateManagerRedis +from reflex.istate.manager.token import BaseStateToken from reflex.middleware import HydrateMiddleware from reflex.model import Model -from reflex.state import ( - BaseState, - OnLoadInternalState, - RouterData, - State, - StateUpdate, - _substate_key, -) +from reflex.state import BaseState, OnLoadInternalState, RouterData, State, StateUpdate from reflex.style import Style from reflex.utils import console, exceptions, format from reflex.vars.base import computed_var @@ -450,7 +444,9 @@ async def test_initialize_with_state(test_state: type[ATestState], token: str): assert app._state == test_state # Get a state for a given token. - state = await app.state_manager.get_state(_substate_key(token, test_state)) + state = await app.state_manager.get_state( + BaseStateToken(ident=token, cls=test_state) + ) assert isinstance(state, test_state) assert state.var == 0 @@ -467,8 +463,8 @@ async def test_set_and_get_state(test_state: type[ATestState]): app = App(_state=test_state) # Create two tokens. - token1 = str(uuid.uuid4()) + f"_{test_state.get_full_name()}" - token2 = str(uuid.uuid4()) + f"_{test_state.get_full_name()}" + token1 = BaseStateToken(ident=str(uuid.uuid4()), cls=test_state) + token2 = BaseStateToken(ident=str(uuid.uuid4()), cls=test_state) # Get the default state for each token. state1 = await app.state_manager.get_state(token1) @@ -935,7 +931,14 @@ async def test_dict_mutation_detection__plain_list( ), ], ) -async def test_upload_file(tmp_path, state, delta, token: str, mocker: MockerFixture): +async def test_upload_file( + tmp_path, + state, + delta, + token: str, + mocker: MockerFixture, + app_module_mock: unittest.mock.Mock, +): """Test that file upload works correctly. Args: @@ -944,15 +947,16 @@ async def test_upload_file(tmp_path, state, delta, token: str, mocker: MockerFix delta: Expected delta after processing all files. token: a Token. mocker: pytest mocker object. + app_module_mock: The mock for the app module, used to patch the app instance. """ mocker.patch( "reflex.state.State.class_subclasses", {state if state is FileUploadState else FileStateBase1}, ) # The App state must be the "root" of the state tree - app = App() + app = app_module_mock.app = App() app.event_namespace.emit = AsyncMock() # pyright: ignore [reportOptionalMemberAccess] - async with app.modify_state(_substate_key(token, state)) as root_state: + async with app.modify_state(BaseStateToken(ident=token, cls=state)) as root_state: root_state.get_substate(state.get_full_name().split("."))._tmp_path = tmp_path data = b"This is binary data" @@ -999,6 +1003,7 @@ async def test_upload_file_keeps_form_open_until_stream_completes( tmp_path, token: str, mocker: MockerFixture, + app_module_mock: unittest.mock.Mock, ): """Test that upload files are not eagerly copied into memory. @@ -1010,16 +1015,19 @@ async def test_upload_file_keeps_form_open_until_stream_completes( tmp_path: Temporary path. token: A token. mocker: pytest mocker object. + app_module_mock: The mock for the app module, used to patch the app instance. """ mocker.patch( "reflex.state.State.class_subclasses", {FileUploadState}, ) - app = App() + app = app_module_mock.app = App() app.event_namespace.emit = AsyncMock() # pyright: ignore [reportOptionalMemberAccess] # Set _tmp_path via modify_state instead of setting class attribute directly. - async with app.modify_state(_substate_key(token, FileUploadState)) as root_state: + async with app.modify_state( + BaseStateToken(ident=token, cls=FileUploadState) + ) as root_state: root_state.get_substate( FileUploadState.get_full_name().split(".") )._tmp_path = tmp_path @@ -1133,16 +1141,19 @@ async def test_upload_file_closes_form_if_response_cancelled_before_stream_start tmp_path, token: str, mocker: MockerFixture, + app_module_mock: unittest.mock.Mock, ): """Test that response cancellation before iteration still closes form data.""" mocker.patch( "reflex.state.State.class_subclasses", {FileUploadState}, ) - app = App() + app = app_module_mock.app = App() app.event_namespace.emit = AsyncMock() # pyright: ignore [reportOptionalMemberAccess] - async with app.modify_state(_substate_key(token, FileUploadState)) as root_state: + async with app.modify_state( + BaseStateToken(ident=token, cls=FileUploadState) + ) as root_state: root_state.get_substate( FileUploadState.get_full_name().split(".") )._tmp_path = tmp_path @@ -1390,7 +1401,7 @@ async def test_dynamic_route_var_route_change_completed_on_load( } assert constants.ROUTER in app._state()._var_dependencies - substate_token = _substate_key(token, DynamicState) + substate_token = BaseStateToken(ident=token, cls=DynamicState) sid = "mock_sid" client_ip = "127.0.0.1" async with app.state_manager.modify_state(substate_token) as state: @@ -1557,7 +1568,9 @@ def _dynamic_state_event(name, val, **kwargs): @pytest.mark.asyncio -async def test_process_events(mocker: MockerFixture, token: str): +async def test_process_events( + mocker: MockerFixture, token: str, app_module_mock: unittest.mock.Mock +): """Test that an event is processed properly and that it is postprocessed n+1 times. Also check that the processing flag of the last stateupdate is set to False. @@ -1565,6 +1578,7 @@ async def test_process_events(mocker: MockerFixture, token: str): Args: mocker: mocker object. token: a Token. + app_module_mock: The mock for the app module, used to patch the app instance. """ router_data = { "pathname": "/", @@ -1574,7 +1588,7 @@ async def test_process_events(mocker: MockerFixture, token: str): "headers": {}, "ip": "127.0.0.1", } - app = App(_state=GenState) + app = app_module_mock.app = App(_state=GenState) mocker.patch.object(app, "_postprocess", AsyncMock()) event = Event( @@ -2153,7 +2167,7 @@ class Sub(Base): app._event_namespace = AsyncMock() async with app.modify_state( - token=_substate_key(token, Sub.get_name()) + token=BaseStateToken(ident=token, cls=Sub) ) as root_state: sub = root_state.substates[Sub.get_name()] if substate: diff --git a/tests/units/test_state.py b/tests/units/test_state.py index 0775dd000e5..303dfa91fbf 100644 --- a/tests/units/test_state.py +++ b/tests/units/test_state.py @@ -34,6 +34,7 @@ from reflex.istate.manager.disk import StateManagerDisk from reflex.istate.manager.memory import StateManagerMemory from reflex.istate.manager.redis import StateManagerRedis +from reflex.istate.manager.token import BaseStateToken from reflex.state import ( BaseState, ImmutableMutableProxy, @@ -44,7 +45,6 @@ State, StateProxy, StateUpdate, - _substate_key, ) from reflex.testing import chdir from reflex.utils import format, prerequisites, types @@ -1698,16 +1698,16 @@ async def state_manager(request) -> AsyncGenerator[StateManager, None]: Yields: A state manager instance """ - state_manager = StateManager.create(state=TestState) + state_manager = StateManager.create() if request.param == "redis": if not isinstance(state_manager, StateManagerRedis): - state_manager = StateManagerRedis(state=TestState, redis=mock_redis()) + state_manager = StateManagerRedis(redis=mock_redis()) elif request.param == "disk": # explicitly NOT using redis - state_manager = StateManagerDisk(state=TestState) + state_manager = StateManagerDisk() assert not state_manager._states_locks else: - state_manager = StateManagerMemory(state=TestState) + state_manager = StateManagerMemory() assert not state_manager._states_locks yield state_manager @@ -1716,7 +1716,7 @@ async def state_manager(request) -> AsyncGenerator[StateManager, None]: @pytest.fixture -def substate_token(state_manager, token) -> str: +def substate_token(state_manager, token) -> BaseStateToken: """A token + substate name for looking up in state manager. Args: @@ -1726,12 +1726,12 @@ def substate_token(state_manager, token) -> str: Returns: Token concatenated with the state_manager's state full_name. """ - return _substate_key(token, state_manager.state) + return BaseStateToken(ident=token, cls=TestState) @pytest.mark.asyncio async def test_state_manager_modify_state( - state_manager: StateManager, token: str, substate_token: str + state_manager: StateManager, token: str, substate_token: BaseStateToken ): """Test that the state manager can modify a state exclusively. @@ -1762,7 +1762,7 @@ async def test_state_manager_modify_state( assert not state_manager._states_locks[token].locked() # separate instances should NOT share locks - sm2 = type(state_manager)(state=TestState) + sm2 = type(state_manager)() assert sm2._state_manager_lock is state_manager._state_manager_lock assert not sm2._states_locks if state_manager._states_locks: @@ -1773,7 +1773,7 @@ async def test_state_manager_modify_state( @pytest.mark.asyncio async def test_state_manager_contend( - state_manager: StateManager, token: str, substate_token: str + state_manager: StateManager, token: str, substate_token: BaseStateToken ): """Multiple coroutines attempting to access the same state. @@ -1820,11 +1820,11 @@ async def state_manager_redis() -> AsyncGenerator[StateManager, None]: Yields: A state manager instance """ - state_manager = StateManager.create(TestState) + state_manager = StateManager.create() if not isinstance(state_manager, StateManagerRedis): # Create a mocked redis client instead of skipping. - state_manager = StateManagerRedis(state=TestState, redis=mock_redis()) + state_manager = StateManagerRedis(redis=mock_redis()) yield state_manager @@ -1842,12 +1842,14 @@ def substate_token_redis(state_manager_redis, token): Returns: Token concatenated with the state_manager's state full_name. """ - return _substate_key(token, state_manager_redis.state) + return BaseStateToken(ident=token, cls=TestState) @pytest.mark.asyncio async def test_state_manager_lock_expire( - state_manager_redis: StateManagerRedis, token: str, substate_token_redis: str + state_manager_redis: StateManagerRedis, + token: str, + substate_token_redis: BaseStateToken, ): """Test that the state manager lock expires and raises exception exiting context. @@ -1892,7 +1894,9 @@ def loop_exception_handler(loop, context): @pytest.mark.asyncio async def test_state_manager_lock_expire_contend( - state_manager_redis: StateManagerRedis, token: str, substate_token_redis: str + state_manager_redis: StateManagerRedis, + token: str, + substate_token_redis: BaseStateToken, ): """Test that the state manager lock expires and queued waiters proceed. @@ -1968,7 +1972,7 @@ async def _coro_waiter(): async def test_state_manager_lock_warning_threshold_contend( state_manager_redis: StateManagerRedis, token: str, - substate_token_redis: str, + substate_token_redis: BaseStateToken, mocker: MockerFixture, ): """Test that the state manager triggers a warning when lock contention exceeds the warning threshold. @@ -2150,7 +2154,12 @@ async def test_state_proxy( pickle_state = parent_state._serialize() if pickle_state: await mock_app.state_manager.redis.set( - _substate_key(parent_state.router.session.client_token, parent_state), + str( + BaseStateToken( + ident=parent_state.router.session.client_token, + cls=type(parent_state), + ) + ), pickle_state, ex=mock_app.state_manager.token_expiration, ) @@ -2205,7 +2214,10 @@ async def test_state_proxy( # Get the state from the state manager directly and check that the value is updated gotten_state = await mock_app.state_manager.get_state( - _substate_key(grandchild_state.router.session.client_token, grandchild_state) + BaseStateToken( + ident=grandchild_state.router.session.client_token, + cls=type(grandchild_state), + ) ) if isinstance(mock_app.state_manager, (StateManagerMemory, StateManagerDisk)): # For in-process store, only one instance of the state exists @@ -2360,7 +2372,7 @@ async def test_background_task_no_block(mock_app: rx.App, token: str): namespace._token_manager.token_to_socket[token] = SocketRecord( instance_id="mock", sid=sid ) - mock_app.state_manager.state = mock_app._state = BackgroundTaskState + mock_app._state = BackgroundTaskState async for update in rx.app.process( mock_app, Event( @@ -2426,7 +2438,7 @@ async def test_background_task_no_block(mock_app: rx.App, token: str): ] background_task_state = await mock_app.state_manager.get_state( - _substate_key(token, BackgroundTaskState) + BaseStateToken(ident=token, cls=BackgroundTaskState) ) assert isinstance(background_task_state, BackgroundTaskState) assert background_task_state.order == exp_order @@ -2491,7 +2503,7 @@ async def test_background_task_reset(mock_app: rx.App, token: str): token: A token. """ router_data = {"query": {}} - mock_app.state_manager.state = mock_app._state = BackgroundTaskState + mock_app._state = BackgroundTaskState async for update in rx.app.process( mock_app, Event( @@ -2516,7 +2528,7 @@ async def test_background_task_reset(mock_app: rx.App, token: str): await mock_app.state_manager.close() background_task_state = await mock_app.state_manager.get_state( - _substate_key(token, BackgroundTaskState) + BaseStateToken(ident=token, cls=BackgroundTaskState) ) assert isinstance(background_task_state, BackgroundTaskState) assert background_task_state.order == ["reset"] @@ -3087,7 +3099,9 @@ def index(): app.add_page(index, on_load=test_state.test_handler) app._compile_page("index") - async with app.state_manager.modify_state(_substate_key(token, State)) as state: + async with app.state_manager.modify_state( + BaseStateToken(ident=token, cls=State) + ) as state: state.router_data = {"simulate": "hydrate"} updates = [] @@ -3140,7 +3154,9 @@ def index(): app.add_page(index, on_load=[OnLoadState.test_handler, OnLoadState.test_handler]) app._compile_page("index") - async with app.state_manager.modify_state(_substate_key(token, State)) as state: + async with app.state_manager.modify_state( + BaseStateToken(ident=token, cls=State) + ) as state: state.router_data = {"simulate": "hydrate"} updates = [] @@ -3181,11 +3197,11 @@ async def test_get_state(mock_app: rx.App, token: str): mock_app: An app that will be returned by `get_app()` token: A token. """ - mock_app.state_manager.state = mock_app._state = TestState + mock_app._state = TestState # Get instance of ChildState2. test_state = await mock_app.state_manager.get_state( - _substate_key(token, ChildState2) + BaseStateToken(ident=token, cls=ChildState2) ) assert isinstance(test_state, TestState) if isinstance(mock_app.state_manager, (StateManagerMemory, StateManagerDisk)): @@ -3249,7 +3265,7 @@ async def test_get_state(mock_app: rx.App, token: str): # Get a fresh instance new_test_state = await mock_app.state_manager.get_state( - _substate_key(token, ChildState2) + BaseStateToken(ident=token, cls=ChildState2) ) assert isinstance(new_test_state, TestState) if isinstance(mock_app.state_manager, (StateManagerMemory, StateManagerDisk)): @@ -3342,10 +3358,12 @@ class GreatGrandchild3(Grandchild3): has a computed var. """ - mock_app.state_manager.state = mock_app._state = Parent + mock_app._state = Parent # Get the top level state via unconnected sibling. - root = await mock_app.state_manager.get_state(_substate_key(token, Child)) + root = await mock_app.state_manager.get_state( + BaseStateToken(ident=token, cls=Child) + ) # Set value in parent_var to assert it does not get refetched later. root.parent_var = 1 @@ -3428,8 +3446,7 @@ def foo(self) -> str: ] # Get state from state manager. - state_manager.state = State - rx_state = await state_manager.get_state(_substate_key(token, State)) + rx_state = await state_manager.get_state(BaseStateToken(ident=token, cls=State)) assert RouterVarParentState.get_name() in rx_state.substates parent_state = rx_state.substates[RouterVarParentState.get_name()] assert RouterVarDepState.get_name() in parent_state.substates @@ -3452,7 +3469,9 @@ async def test_setvar(mock_app: rx.App, token: str): mock_app: An app that will be returned by `get_app()` token: A token. """ - state = await mock_app.state_manager.get_state(_substate_key(token, TestState)) + state = await mock_app.state_manager.get_state( + BaseStateToken(ident=token, cls=TestState) + ) assert isinstance(state, TestState) # Set Var in same state (with Var type casting) @@ -3552,9 +3571,8 @@ def test_redis_state_manager_config_knobs(tmp_path, expiration_kwargs, expected_ with chdir(proj_root): # reload config for each parameter to avoid stale values reflex.config.get_config(reload=True) - from reflex.state import State - state_manager = StateManagerRedis(state=State, redis=mock_redis()) + state_manager = StateManagerRedis(redis=mock_redis()) assert state_manager.lock_expiration == expected_values[0] # pyright: ignore [reportAttributeAccessIssue] assert state_manager.token_expiration == expected_values[1] # pyright: ignore [reportAttributeAccessIssue] assert state_manager.lock_warning_threshold == expected_values[2] # pyright: ignore [reportAttributeAccessIssue] @@ -3589,10 +3607,9 @@ def test_redis_state_manager_config_knobs_invalid_lock_warning_threshold( with chdir(proj_root): # reload config for each parameter to avoid stale values reflex.config.get_config(reload=True) - from reflex.state import State with pytest.raises(InvalidLockWarningThresholdError): - StateManagerRedis(state=State, redis=mock_redis()) + StateManagerRedis(redis=mock_redis()) del sys.modules[constants.Config.MODULE] @@ -3812,8 +3829,10 @@ class State(Root): class Child(State): foo: str = "bar" - dsm = StateManagerDisk(state=Root) - async with dsm.modify_state(token) as root: + bs_token = BaseStateToken(ident=token, cls=Root) + + dsm = StateManagerDisk() + async with dsm.modify_state(bs_token) as root: s = await root.get_state(State) s.num += 1 c = await root.get_state(Child) @@ -3821,8 +3840,8 @@ class Child(State): assert not c._get_was_touched() await dsm.close() - dsm2 = StateManagerDisk(state=Root) - root = await dsm2.get_state(token) + dsm2 = StateManagerDisk() + root = await dsm2.get_state(bs_token) s = await root.get_state(State) assert s.num == 43 c = await root.get_state(Child) @@ -4197,7 +4216,9 @@ async def test_upcast_event_handler_arg(handler, payload): @pytest.mark.asyncio -async def test_get_var_value(state_manager: StateManager, substate_token: str): +async def test_get_var_value( + state_manager: StateManager, substate_token: BaseStateToken +): """Test that get_var_value works correctly. Args: @@ -4270,10 +4291,12 @@ async def v(self) -> int: child3 = await self.get_state(Child3) return child3.child3_var + p.parent_var - mock_app.state_manager.state = mock_app._state = Parent + mock_app._state = Parent # Get the top level state via unconnected sibling. - root = await mock_app.state_manager.get_state(_substate_key(token, Child)) + root = await mock_app.state_manager.get_state( + BaseStateToken(ident=token, cls=Child) + ) # Set value in parent_var to assert it does not get refetched later. root.parent_var = 1 @@ -4352,9 +4375,11 @@ class OtherState(rx.State): data: list[dict[str, Any]] = [{"foo": "bar"}] - mock_app.state_manager.state = mock_app._state = rx.State + mock_app._state = rx.State comp = Table.create(data=OtherState.data) - state = await mock_app.state_manager.get_state(_substate_key(token, OtherState)) + state = await mock_app.state_manager.get_state( + BaseStateToken(ident=token, cls=OtherState) + ) other_state = await state.get_state(OtherState) assert comp.State is not None # The state should have been pre-cached from the dependency. @@ -4412,8 +4437,10 @@ class OtherState(rx.State): async def fetch_data_state(self) -> None: print(await self.get_state(DataState)) - mock_app.state_manager.state = mock_app._state = rx.State - state = await mock_app.state_manager.get_state(_substate_key(token, OtherState)) + mock_app._state = rx.State + state = await mock_app.state_manager.get_state( + BaseStateToken(ident=token, cls=OtherState) + ) other_state = await state.get_state(OtherState) await other_state.fetch_data_state() # Should not raise exception. @@ -4427,9 +4454,9 @@ class MutableProxyState(BaseState): @pytest.mark.asyncio async def test_rebind_mutable_proxy(mock_app: rx.App, token: str) -> None: """Test that previously bound MutableProxy instances can be rebound correctly.""" - mock_app.state_manager.state = mock_app._state = MutableProxyState + mock_app._state = MutableProxyState async with mock_app.state_manager.modify_state( - _substate_key(token, MutableProxyState) + BaseStateToken(ident=token, cls=MutableProxyState) ) as state: state.router = RouterData.from_router_data({ "query": {}, @@ -4465,7 +4492,7 @@ async def test_rebind_mutable_proxy(mock_app: rx.App, token: str) -> None: state_proxy.data["a"].append(3) async with mock_app.state_manager.modify_state( - _substate_key(token, MutableProxyState) + BaseStateToken(ident=token, cls=MutableProxyState) ) as state: assert isinstance(state, MutableProxyState) assert state.data["a"] == [2, 3] diff --git a/tests/units/test_state_tree.py b/tests/units/test_state_tree.py index 7ed19500cc2..28bd4acf3a1 100644 --- a/tests/units/test_state_tree.py +++ b/tests/units/test_state_tree.py @@ -8,7 +8,8 @@ import reflex as rx from reflex.constants.state import FIELD_MARKER from reflex.istate.manager.redis import StateManagerRedis -from reflex.state import BaseState, StateManager, _substate_key +from reflex.istate.manager.token import BaseStateToken +from reflex.state import BaseState, StateManager class Root(BaseState): @@ -371,7 +372,9 @@ async def test_get_state_tree( exp_root_substates: The expected substates of the root state. exp_root_dict_keys: The expected keys of the root state dict. """ - state = await state_manager_redis.get_state(_substate_key(token, substate_cls)) + state = await state_manager_redis.get_state( + BaseStateToken(ident=token, cls=substate_cls) + ) assert isinstance(state, Root) assert sorted(state.substates) == sorted(exp_root_substates) From cdacd2291fd57b6f6eebcb28a8ba58d9a626c33f Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Sun, 8 Mar 2026 05:43:14 -0700 Subject: [PATCH 03/81] disambiguate class name --- reflex/istate/manager/token.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/reflex/istate/manager/token.py b/reflex/istate/manager/token.py index f9eb85b1a12..ae98f405ec7 100644 --- a/reflex/istate/manager/token.py +++ b/reflex/istate/manager/token.py @@ -39,7 +39,9 @@ def __str__(self) -> str: """ # urlencode the redis token to escape the slash delimiter. clean_ident = self.ident.replace("/", "%2F") - clean_cls_name = self.cls.__name__.replace("/", "%2F") + clean_cls_name = f"{self.cls.__module__}.{self.cls.__name__}".replace( + "/", "%2F" + ) return f"{clean_ident}/{clean_cls_name}" @classmethod From d463ef8af8807690d51e768a7cd778644f8296dc Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Tue, 17 Mar 2026 15:25:14 -0700 Subject: [PATCH 04/81] handle legacy tokens passed to app.modify_state --- reflex/__init__.py | 1 + reflex/app.py | 33 ++++++++++++++++++++++++++-- reflex/istate/manager/token.py | 39 ++++++++++++++++++++++++++++++++++ 3 files changed, 71 insertions(+), 2 deletions(-) diff --git a/reflex/__init__.py b/reflex/__init__.py index 066df110f02..d7e6b03feb2 100644 --- a/reflex/__init__.py +++ b/reflex/__init__.py @@ -327,6 +327,7 @@ "LocalStorage", "SessionStorage", ], + "istate.manager.token": ["StateToken", "BaseStateToken"], "middleware": ["middleware", "Middleware"], "model": ["asession", "session", "Model", "ModelRegistry"], "page": ["page"], diff --git a/reflex/app.py b/reflex/app.py index 73f3cedbc6e..fd0b3a048ea 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -29,7 +29,16 @@ from pathlib import Path from timeit import default_timer as timer from types import SimpleNamespace -from typing import TYPE_CHECKING, Any, BinaryIO, ParamSpec, get_args, get_type_hints +from typing import ( + TYPE_CHECKING, + Any, + BinaryIO, + ParamSpec, + get_args, + get_type_hints, + overload, +) +from warnings import deprecated from rich.progress import MofNCompleteColumn, Progress, TimeElapsedColumn from socketio import ASGIApp as EngineIOApp @@ -1564,10 +1573,27 @@ def all_routes(_request: Request) -> Response: str(constants.Endpoint.ALL_ROUTES), all_routes, methods=["GET"] ) + @overload + @deprecated("pass token as rx.BaseStateToken instead of str") + def modify_state( + self, + token: str, + background: bool = False, + previous_dirty_vars: dict[str, set[str]] | None = None, + ) -> contextlib.AbstractAsyncContextManager[BaseState]: ... + + @overload + def modify_state( + self, + token: BaseStateToken, + background: bool = False, + previous_dirty_vars: dict[str, set[str]] | None = None, + ) -> contextlib.AbstractAsyncContextManager[BaseState]: ... + @contextlib.asynccontextmanager async def modify_state( self, - token: BaseStateToken, + token: BaseStateToken | str, background: bool = False, previous_dirty_vars: dict[str, set[str]] | None = None, ) -> AsyncIterator[BaseState]: @@ -1588,6 +1614,9 @@ async def modify_state( msg = "App has not been initialized yet." raise RuntimeError(msg) + if isinstance(token, str): + token = BaseStateToken.from_legacy_token(token, root_state=self._state) + # Get exclusive access to the state. async with self.state_manager.modify_state_with_links( token, previous_dirty_vars=previous_dirty_vars diff --git a/reflex/istate/manager/token.py b/reflex/istate/manager/token.py index ae98f405ec7..b9b4b75cfd9 100644 --- a/reflex/istate/manager/token.py +++ b/reflex/istate/manager/token.py @@ -4,6 +4,8 @@ import pickle from typing import TYPE_CHECKING, BinaryIO, Generic, Self, TypeVar +from reflex.utils import console + if TYPE_CHECKING: from reflex.state import BaseState @@ -169,3 +171,40 @@ def get_and_reset_touched_state(cls, state: BaseState) -> bool: was_touched = state._get_was_touched() state._was_touched = False # Reset the touched flag after serializing. return was_touched + + @classmethod + def from_legacy_token( + cls, legacy_token: str, root_state: "type[BaseState] | None" + ) -> Self: + """Create a BaseStateToken from a legacy token string. + + The legacy token format is "{ident}_{module_path}.{class_name}". + + Args: + legacy_token: The legacy token string to convert. + root_state: The root state instance. + + Returns: + A BaseStateToken instance created from the legacy token. + + Raises: + ValueError: If the legacy token format is invalid or if the state class cannot be found + """ + from reflex.state import _split_substate_key + + if root_state is None: + msg = ( + "Root state must be provided to convert legacy token to BaseStateToken." + ) + raise ValueError(msg) + + console.deprecate( + feature_name="Passing a string to modify_state", + reason="Use rx.BaseStateToken(token, state_cls) instead of the legacy string format", + deprecation_version="0.9.0", + removal_version="1.0", + ) + + client_token, state_path = _split_substate_key(legacy_token) + state_cls = root_state.get_class_substate(tuple(state_path.split("."))) # type: ignore[union-attr] + return cls(ident=client_token, cls=state_cls) From b715c27739b9aec116bc7b9b1ef8abb356666369 Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Tue, 17 Mar 2026 16:14:14 -0700 Subject: [PATCH 05/81] Add EventContext module --- reflex/ievent/__init__.py | 1 + reflex/ievent/context.py | 78 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 79 insertions(+) create mode 100644 reflex/ievent/__init__.py create mode 100644 reflex/ievent/context.py diff --git a/reflex/ievent/__init__.py b/reflex/ievent/__init__.py new file mode 100644 index 00000000000..b2e670ff4b3 --- /dev/null +++ b/reflex/ievent/__init__.py @@ -0,0 +1 @@ +"""Internal event processing.""" diff --git a/reflex/ievent/context.py b/reflex/ievent/context.py new file mode 100644 index 00000000000..bf8e0e28e6f --- /dev/null +++ b/reflex/ievent/context.py @@ -0,0 +1,78 @@ +"""The context and associated metadata for handling an event.""" + +import dataclasses +import functools +import uuid +from collections.abc import Callable, Mapping +from contextvars import ContextVar +from typing import TYPE_CHECKING, Any, Protocol + +from reflex.istate.manager import StateManager +from reflex.utils.format import to_snake_case + +if TYPE_CHECKING: + from reflex.event import Event + + +@functools.lru_cache +def get_name(cls: type | Callable) -> str: + """Get the name of the state/func. + + Returns: + The name of the state/func. + """ + module = cls.__module__.replace(".", "___") + qualname = getattr(cls, "__qualname__", cls.__name__).replace(".", "___") + return to_snake_case(f"{module}___{qualname}") + + +class EnqueueProtocol(Protocol): + """Protocol for the enqueue function in the event context.""" + + async def __call__(self, event: Event) -> None: + """Enqueue an event handler to be executed. + + Args: + event: The event to enqueue. + """ + ... + + +class EmitDeltaProtocol(Protocol): + """Protocol for the emit_delta function in the event context.""" + + async def __call__( + self, + deltas: Mapping[str, Mapping[str, Any]], + ) -> None: + """Emit a delta to the frontend. + + Args: + deltas: The deltas to emit, mapping client tokens to variable updates. + """ + ... + + +@dataclasses.dataclass(frozen=True, kw_only=True) +class EventContext: + """The context for an event.""" + + # Identifies the client session. + token: str + + # Manages persistence of state across events. + state_manager: StateManager + + # Function responsible for enqueuing an event handler to be executed. + enqueue: EnqueueProtocol + + # Each event is associated with a top-level transaction id. + txid: str = dataclasses.field(default_factory=lambda: uuid.uuid4().hex[:12]) + # The txid of another EventContext that enqueued this context's event. + parent_txid: str | None = None + + emit_delta: EmitDeltaProtocol | None = None + cached_states: dict[type, Any] = dataclasses.field(default_factory=dict, init=False) + + +event_context: ContextVar[EventContext] = ContextVar("event_context") From 57d0d01cabd2fcfb957269626f8f215735eecc32 Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Thu, 19 Mar 2026 23:05:25 -0700 Subject: [PATCH 06/81] Move BaseState event processing to reflex.ievent.processor package Create the EventProcessor class to manage the backend event queue. Move event processing logic out of the BaseState and into a separate module. --- reflex/app.py | 211 +++------- reflex/app_mixins/middleware.py | 5 +- reflex/ievent/context.py | 49 ++- reflex/ievent/processor/__init__.py | 10 + .../ievent/processor/base_state_processor.py | 261 ++++++++++++ reflex/ievent/processor/event_processor.py | 379 ++++++++++++++++++ reflex/ievent/registry.py | 28 ++ reflex/istate/manager/redis.py | 13 +- reflex/state.py | 344 ++-------------- reflex/utils/format.py | 4 +- reflex/utils/serializers.py | 11 + reflex/utils/tasks.py | 28 +- tests/integration/test_background_task.py | 2 +- 13 files changed, 840 insertions(+), 505 deletions(-) create mode 100644 reflex/ievent/processor/__init__.py create mode 100644 reflex/ievent/processor/base_state_processor.py create mode 100644 reflex/ievent/processor/event_processor.py create mode 100644 reflex/ievent/registry.py diff --git a/reflex/app.py b/reflex/app.py index fd0b3a048ea..d9293512050 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -16,7 +16,6 @@ import traceback import urllib.parse from collections.abc import ( - AsyncGenerator, AsyncIterator, Awaitable, Callable, @@ -89,11 +88,10 @@ EventSpec, EventType, IndividualEventType, - get_hydrate_event, noop, ) +from reflex.ievent.processor import EventProcessor from reflex.istate.manager.token import BaseStateToken -from reflex.istate.proxy import StateProxy from reflex.page import DECORATED_PAGES from reflex.route import ( get_route_args, @@ -164,24 +162,26 @@ def default_backend_exception_handler(exception: Exception) -> EventSpec: """ from reflex.components.sonner.toast import toast - error = traceback.format_exc() + error = traceback.format_exception( + type(exception), exception, exception.__traceback__ + ) - console.error(f"[Reflex Backend Exception]\n {error}\n") + console.error(f"[Reflex Backend Exception]\n {''.join(error)}\n") error_message = ( ["Contact the website administrator."] if is_prod_mode() - else [f"{type(exception).__name__}: {exception}.", "See logs for details."] + else [f"{type(exception).__name__}: {exception}", "See logs for details."] ) return toast( "An error occurred.", level="error", fallback_to_alert=True, - description="
".join(error_message), + description="\n".join(error_message), position="top-center", id="backend_error", - style={"width": "500px"}, + style={"width": "500px", "white-space": "pre-wrap"}, ) @@ -440,8 +440,8 @@ class App(MiddlewareMixin, LifespanMixin): # The async server name space. _event_namespace: EventNamespace | None = None - # Background tasks that are currently running. - _background_tasks: set[asyncio.Task] = dataclasses.field(default_factory=set) + # The processor queue for handling events. + _event_processor: EventProcessor | None = None # Frontend Error Handler Function frontend_exception_handler: Callable[[Exception], None] = ( @@ -473,6 +473,18 @@ def event_namespace(self) -> EventNamespace | None: """ return self._event_namespace + @property + def event_processor(self) -> EventProcessor: + """Get the event processor. + + Raises: + RuntimeError: If the event processor is not initialized. + """ + if self._event_processor is None: + msg = "Event processor is not initialized." + raise RuntimeError(msg) + return self._event_processor + def __post_init__(self): """Initialize the app. @@ -606,6 +618,21 @@ async def modified_send(message: Message): # Check the exception handlers self._validate_exception_handlers() + # Ensure the event processor starts and stops with the server. + self.register_lifespan_task(self._setup_event_processor) + + @contextlib.asynccontextmanager + async def _setup_event_processor(self) -> AsyncIterator[None]: + # Create the event processor. + self._event_processor = EventProcessor( + middleware=self, backend_exception_handler=self.backend_exception_handler + ) + async with self._event_processor.configure( + state_manager=self.state_manager, + event_namespace=self.event_namespace, + ): + yield + def __repr__(self) -> str: """Get the string representation of the app. @@ -1630,62 +1657,11 @@ async def modify_state( await self.event_namespace.emit_update( update=StateUpdate( delta=delta, - final=True if not background else None, + final=True, ), token=token.ident, ) - def _process_background( - self, state: BaseState, event: Event - ) -> asyncio.Task | None: - """Process an event in the background and emit updates as they arrive. - - Args: - state: The state to process the event for. - event: The event to process. - - Returns: - Task if the event was backgroundable, otherwise None - """ - substate, handler = state._get_event_handler(event) - - if not handler.is_background: - return None - - substate = StateProxy(substate) - - async def _coro(): - """Coroutine to process the event and emit updates inside an asyncio.Task. - - Raises: - RuntimeError: If the app has not been initialized yet. - """ - if self.event_namespace is None: - msg = "App has not been initialized yet." - raise RuntimeError(msg) - - # Process the event. - async for update in state._process_event( - handler=handler, state=substate, payload=event.payload - ): - # Postprocess the event. - update = await self._postprocess(state, event, update) - - # Send the update to the client. - await self.event_namespace.emit_update( - update=update, - token=event.token, - ) - - task = asyncio.create_task( - _coro(), - name=f"reflex_background_task|{event.name}|{time.time()}|{event.token}", - ) - self._background_tasks.add(task) - # Clean up task from background_tasks set when complete. - task.add_done_callback(self._background_tasks.discard) - return task - def _validate_exception_handlers(self): """Validate the custom event exception handlers for front- and backend. @@ -1779,95 +1755,6 @@ def _validate_exception_handlers(self): raise ValueError(msg) -async def process( - app: App, event: Event, sid: str, headers: dict, client_ip: str -) -> AsyncGenerator[StateUpdate]: - """Process an event. - - Args: - app: The app to process the event for. - event: The event to process. - sid: The Socket.IO session id. - headers: The client headers. - client_ip: The client_ip. - - Raises: - Exception: If a reflex specific error occurs during processing the event. - - Yields: - The state updates after processing the event. - """ - from reflex.utils import telemetry - - try: - # Add request data to the state. - router_data = event.router_data - router_data.update({ - constants.RouteVar.QUERY: format.format_query_params(event.router_data), - constants.RouteVar.CLIENT_TOKEN: event.token, - constants.RouteVar.SESSION_ID: sid, - constants.RouteVar.HEADERS: headers, - constants.RouteVar.CLIENT_IP: client_ip, - }) - # Get the state for the session exclusively. - async with app.state_manager.modify_state_with_links( - event.substate_token, event=event - ) as state: - # When this is a brand new instance of the state, signal the - # frontend to reload before processing it. - if ( - not state.router_data - and event.name != get_hydrate_event(state) - and app.event_namespace is not None - ): - await asyncio.create_task( - app.event_namespace.emit( - "reload", - data=event, - to=sid, - ), - name=f"reflex_emit_reload|{event.name}|{time.time()}|{event.token}", - ) - return - router_data[constants.RouteVar.PATH] = "/" + ( - app.router(path) or "404" - if (path := router_data.get(constants.RouteVar.PATH)) - else "404" - ).removeprefix("/") - # re-assign only when the value is different - if state.router_data != router_data: - # assignment will recurse into substates and force recalculation of - # dependent ComputedVar (dynamic route variables) - state.router_data = router_data - state.router = RouterData.from_router_data(router_data) - - # Preprocess the event. - update = await app._preprocess(state, event) - - # If there was an update, yield it. - if update is not None: - yield update - - # Only process the event if there is no update. - else: - if app._process_background(state, event) is not None: - # `final=True` allows the frontend send more events immediately. - yield StateUpdate(final=True) - else: - # Process the event synchronously. - async for update in state._process(event): - # Postprocess the event. - update = await app._postprocess(state, event, update) - - # Yield the update. - yield update - except Exception as ex: - telemetry.send_error(ex, context="backend") - - app.backend_exception_handler(ex) - raise - - def ping(_request: Request) -> Response: """Test API endpoint. @@ -2290,14 +2177,20 @@ async def on_event(self, sid: str, data: Any): .partition(",")[0] .strip() ) - - async with contextlib.aclosing( - process(self.app, event, sid, headers, client_ip) - ) as updates_gen: - # Process the events. - async for update in updates_gen: - # Emit the update from processing the event. - await self.emit_update(update=update, token=event.token) + router_data = event.router_data + router_data.update({ + constants.RouteVar.QUERY: format.format_query_params(event.router_data), + constants.RouteVar.CLIENT_TOKEN: event.token, + constants.RouteVar.SESSION_ID: sid, + constants.RouteVar.HEADERS: headers, + constants.RouteVar.CLIENT_IP: client_ip, + }) + router_data[constants.RouteVar.PATH] = "/" + ( + self.app.router(path) or "404" + if (path := router_data.get(constants.RouteVar.PATH)) + else "404" + ).removeprefix("/") + await self.app.event_processor.enqueue(token=event.token, event=event) async def on_ping(self, sid: str): """Event for testing the API endpoint. diff --git a/reflex/app_mixins/middleware.py b/reflex/app_mixins/middleware.py index b78b96ec2dd..77895ba1058 100644 --- a/reflex/app_mixins/middleware.py +++ b/reflex/app_mixins/middleware.py @@ -6,7 +6,7 @@ import inspect from reflex.event import Event -from reflex.middleware import HydrateMiddleware, Middleware +from reflex.middleware import Middleware from reflex.state import BaseState, StateUpdate from .mixin import AppMixin @@ -19,9 +19,6 @@ class MiddlewareMixin(AppMixin): # Middleware to add to the app. Users should use `add_middleware`. _middlewares: list[Middleware] = dataclasses.field(default_factory=list) - def _init_mixin(self): - self._middlewares.append(HydrateMiddleware()) - def add_middleware(self, middleware: Middleware, index: int | None = None): """Add middleware to the app. diff --git a/reflex/ievent/context.py b/reflex/ievent/context.py index bf8e0e28e6f..cc329abfcac 100644 --- a/reflex/ievent/context.py +++ b/reflex/ievent/context.py @@ -29,11 +29,12 @@ def get_name(cls: type | Callable) -> str: class EnqueueProtocol(Protocol): """Protocol for the enqueue function in the event context.""" - async def __call__(self, event: Event) -> None: + async def __call__(self, token: str, *events: Event) -> None: """Enqueue an event handler to be executed. Args: - event: The event to enqueue. + token: The client token associated with the event. + events: The events to enqueue. """ ... @@ -43,12 +44,14 @@ class EmitDeltaProtocol(Protocol): async def __call__( self, - deltas: Mapping[str, Mapping[str, Any]], + token: str, + delta: Mapping[str, Mapping[str, Any]], ) -> None: """Emit a delta to the frontend. Args: - deltas: The deltas to emit, mapping client tokens to variable updates. + token: The client token to emit the delta to. + delta: The deltas to emit, mapping client tokens to variable updates. """ ... @@ -64,15 +67,49 @@ class EventContext: state_manager: StateManager # Function responsible for enqueuing an event handler to be executed. - enqueue: EnqueueProtocol + enqueue_impl: EnqueueProtocol # Each event is associated with a top-level transaction id. txid: str = dataclasses.field(default_factory=lambda: uuid.uuid4().hex[:12]) # The txid of another EventContext that enqueued this context's event. parent_txid: str | None = None - emit_delta: EmitDeltaProtocol | None = None + emit_delta_impl: EmitDeltaProtocol | None = None cached_states: dict[type, Any] = dataclasses.field(default_factory=dict, init=False) + def fork(self, token: str | None = None) -> "EventContext": + """Return a new EventContext with the specified fields replaced. + + Args: + token: The client token for the new context. + + Returns: + A new EventContext with the specified fields replaced. + """ + return type(self)( + token=token or self.token, + parent_txid=self.txid, + state_manager=self.state_manager, + enqueue_impl=self.enqueue_impl, + emit_delta_impl=self.emit_delta_impl, + ) + + async def emit_delta(self, delta: Mapping[str, Mapping[str, Any]]) -> None: + """Emit a delta to the frontend. + + Args: + delta: The deltas to emit, mapping client tokens to variable updates. + """ + if self.emit_delta_impl is not None: + await self.emit_delta_impl(self.token, delta) + + async def enqueue(self, *event: Event) -> None: + """Enqueue an event handler to be executed. + + Args: + event: The event to enqueue. + """ + await self.enqueue_impl(self.token, *event) + event_context: ContextVar[EventContext] = ContextVar("event_context") diff --git a/reflex/ievent/processor/__init__.py b/reflex/ievent/processor/__init__.py new file mode 100644 index 00000000000..35a4c2182ba --- /dev/null +++ b/reflex/ievent/processor/__init__.py @@ -0,0 +1,10 @@ +"""Procedures for handling events.""" + +from reflex.ievent.processor import base_state_processor +from reflex.ievent.processor.event_processor import EventProcessor, EventQueueEntry + +__all__ = [ + "EventProcessor", + "EventQueueEntry", + "base_state_processor", +] diff --git a/reflex/ievent/processor/base_state_processor.py b/reflex/ievent/processor/base_state_processor.py new file mode 100644 index 00000000000..d8596c025db --- /dev/null +++ b/reflex/ievent/processor/base_state_processor.py @@ -0,0 +1,261 @@ +"""Functions for processing BaseState-derived event handlers.""" + +import dataclasses +import functools +import inspect +import warnings +from collections.abc import Mapping, Sequence +from enum import Enum +from importlib.util import find_spec +from typing import TYPE_CHECKING, Any + +from reflex.ievent.context import event_context +from reflex.istate.proxy import StateProxy +from reflex.utils import console, types +from reflex.utils.monitoring import is_pyleak_enabled, monitor_loopblocks + +if TYPE_CHECKING: + from reflex.event import EventHandler, EventSpec + from reflex.state import BaseState + + +def _check_valid_yield(events: Any, handler_name: str = "unknown") -> Any: + """Check if the events yielded are valid. They must be EventHandlers or EventSpecs. + + Args: + events: The events to be checked. + handler_name: The name of the handler that yielded the events, used for error messages. + + Raises: + TypeError: If any of the events are not valid. + + Returns: + The events as they are if valid. + """ + from reflex.event import Event, EventHandler, EventSpec + + def _is_valid_type(events: Any) -> bool: + return isinstance(events, (Event, EventHandler, EventSpec)) + + if events is None or _is_valid_type(events): + return events + + if not (isinstance(events, Sequence) and not isinstance(events, (str, bytes))): + events = [events] + + try: + if all(_is_valid_type(e) for e in events): + return events + except TypeError: + pass + + coroutines = [e for e in events if inspect.iscoroutine(e)] + + for coroutine in coroutines: + coroutine_name = coroutine.__qualname__ + warnings.filterwarnings( + "ignore", message=f"coroutine '{coroutine_name}' was never awaited" + ) + + msg = ( + f"Your handler {handler_name} must only return/yield: None, Events or other EventHandlers referenced by their class (i.e. using `type(self)` or other class references)." + f" Returned events of types {', '.join(map(str, map(type, events)))!s}." + ) + raise TypeError(msg) + + +def _transform_event_arg(value: Any, hinted_args: Any) -> Any: + """Transform an event argument based on its type hint. + + Args: + value: The value to transform. + hinted_args: The type hint for the argument. + + Returns: + The transformed value. + + Raises: + ValueError: If a string value is received for an int or float type and cannot be converted. + """ + from reflex.model import Model + from reflex.utils.serializers import deserializers + + if hinted_args is Any: + return value + if types.is_union(hinted_args): + if value is None: + return value + hinted_args = types.value_inside_optional(hinted_args) + if ( + isinstance(value, dict) + and isinstance(hinted_args, type) + and not types.is_generic_alias(hinted_args) # py3.10 + ): + if issubclass(hinted_args, Model): + # Remove non-fields from the payload + return hinted_args(**{ + key: value + for key, value in value.items() + if key in hinted_args.__fields__ + }) + if dataclasses.is_dataclass(hinted_args): + return hinted_args(**value) + if find_spec("pydantic"): + from pydantic import BaseModel as BaseModelV2 + from pydantic.v1 import BaseModel as BaseModelV1 + + if issubclass(hinted_args, BaseModelV1): + return hinted_args.parse_obj(value) + if issubclass(hinted_args, BaseModelV2): + return hinted_args.model_validate(value) + if isinstance(value, list) and (hinted_args is set or hinted_args is set): + return set(value) + if isinstance(value, list) and (hinted_args is tuple or hinted_args is tuple): + return tuple(value) + if isinstance(hinted_args, type) and issubclass(hinted_args, Enum): + try: + return hinted_args(value) + except ValueError: + msg = f"Received an invalid enum value ({value}) for type {hinted_args}" + raise ValueError(msg) from None + if ( + isinstance(value, str) + and (deserializer := deserializers.get(hinted_args)) is not None + ): + try: + return deserializer(value) + except ValueError: + msg = f"Received a string value ({value}) but expected a {hinted_args}" + raise ValueError(msg) from None + return value + + +def _transform_event_payload( + payload: Mapping[str, Any], type_hints: Mapping[str, Any] +) -> dict[str, Any]: + """Transform an event payload based on the type hints of the handler. + + Args: + payload: The event payload to transform. + type_hints: The type hints for the handler's arguments. + + Returns: + The transformed event payload. + """ + transformed = {} + for arg, value in list(payload.items()): + hinted_args = type_hints.get(arg, Any) + try: + transformed[arg] = _transform_event_arg(value, hinted_args) + except Exception as ex: + msg = f"Error transforming event argument '{arg}' with value '{value}' and type hint '{hinted_args}'" + raise ValueError(msg) from ex + return transformed + + +async def chain_updates( + events: EventSpec | list[EventSpec] | None, + handler_name: str, + root_state: BaseState | None = None, +) -> None: + """Chain yielded events and emit a delta to the frontend. + + Check for validitity and convert the EventSpec into qualified Event objects + to be queued against the current EventContext. + + Args: + events: The events to queue with the update. + handler_name: The name of the handler that yielded the events, used for error messages. + root_state: The root state of the app, no delta emitted if omitted. + """ + from reflex.event import fix_events + + ctx = event_context.get() + token = ctx.token + + # Convert valid EventHandler and EventSpec into Event + if fixed_events := fix_events( + _check_valid_yield(events, handler_name=handler_name), token + ): + await ctx.enqueue(*fixed_events) + + if root_state is not None: + # Get the delta after processing the event. + try: + delta = await root_state._get_resolved_delta() + if delta: + await ctx.emit_delta(delta) + finally: + root_state._clean() + + +async def process_event( + handler: EventHandler, + payload: dict, + state: BaseState | StateProxy, + root_state: BaseState, +): + """Process event. + + Args: + handler: EventHandler to process. + payload: The event payload. + state: State to process the handler. + root_state: The root state of the app, used for emitting deltas. + + Raises: + ValueError: If a string value is received for an int or float type and cannot be converted. + """ + handler_name = handler.fn.__qualname__ + + # Get the function to process the event. + if is_pyleak_enabled(): + console.debug(f"Monitoring leaks for handler: {handler_name}") + fn = functools.partial(monitor_loopblocks(handler.fn), state) + else: + fn = functools.partial(handler.fn, state) + + try: + type_hints = types.get_type_hints(handler.fn) + payload = _transform_event_payload(payload, type_hints) + except Exception as ex: + # No transformation was possible, continue with the original payload + console.warn( + f"Error transforming event payload for handler {handler_name}: {ex}" + ) + + # Handle async functions. + if inspect.iscoroutinefunction(fn.func): + events = await fn(**payload) + + # Handle regular functions. + else: + events = fn(**payload) + # Handle async generators. + if inspect.isasyncgen(events): + async for event in events: + await chain_updates(event, root_state=root_state, handler_name=handler_name) + await chain_updates(None, root_state=root_state, handler_name=handler_name) + + # Handle regular generators. + elif inspect.isgenerator(events): + try: + while True: + await chain_updates( + next(events), root_state=root_state, handler_name=handler_name + ) + except StopIteration as si: + # the "return" value of the generator is not available + # in the loop, we must catch StopIteration to access it + if si.value is not None: + await chain_updates( + si.value, root_state=root_state, handler_name=handler_name + ) + await chain_updates(None, root_state=root_state, handler_name=handler_name) + + # Handle regular event chains. + else: + await chain_updates(events, root_state=root_state, handler_name=handler_name) + + +__all__ = ["chain_updates", "process_event"] diff --git a/reflex/ievent/processor/event_processor.py b/reflex/ievent/processor/event_processor.py new file mode 100644 index 00000000000..63b9a4f7c84 --- /dev/null +++ b/reflex/ievent/processor/event_processor.py @@ -0,0 +1,379 @@ +"""Base EventProcessor class for handling backend event queue.""" + +import asyncio +import dataclasses +import time +import traceback +from collections.abc import Callable, Mapping +from contextvars import Context, copy_context +from typing import TYPE_CHECKING, Any, Self + +import rich.markup + +from reflex.app_mixins.middleware import MiddlewareMixin +from reflex.ievent.context import EmitDeltaProtocol, EventContext, event_context +from reflex.ievent.processor import base_state_processor +from reflex.ievent.registry import REGISTERED_HANDLERS, RegisteredEventHandler +from reflex.istate.data import RouterData +from reflex.istate.manager import StateManager +from reflex.istate.proxy import StateProxy +from reflex.utils import console +from reflex.utils.tasks import ensure_task + +if TYPE_CHECKING: + from reflex.app import EventNamespace + from reflex.event import Event, EventSpec + + +@dataclasses.dataclass(frozen=True, kw_only=True, slots=True) +class EventQueueEntry: + """An entry in the event queue.""" + + event: Event + ctx: EventContext + + +@dataclasses.dataclass(kw_only=True, slots=True) +class EventProcessor: + """Responsible for queuing and processing events.""" + + middleware: MiddlewareMixin | None = None + backend_exception_handler: ( + Callable[[Exception], EventSpec | list[EventSpec] | None] | None + ) = None + + _queue: asyncio.Queue[EventQueueEntry] = dataclasses.field( + default_factory=asyncio.Queue, init=False + ) + _queue_task: asyncio.Task | None = dataclasses.field(default=None, init=False) + _root_context: Context | None = dataclasses.field(default=None, init=False) + _tasks: dict[str, asyncio.Task] = dataclasses.field( + default_factory=dict, init=False + ) + + def configure( + self, + state_manager: StateManager | None = None, + event_namespace: EventNamespace | None = None, + ) -> Self: + """Set up the event processor. + + Before an event processor can be used, it must be configured with a + state manager and optionally an event namespace to communicate with the + frontend. + + Args: + state_manager: The state manager to use for processing events. + event_namespace: The event namespace to use for processing events. + + Returns: + The event processor instance. + """ + from reflex.istate.manager.memory import StateManagerMemory + from reflex.state import StateUpdate + + if self._root_context is not None: + msg = "Event processor is already running" + raise RuntimeError(msg) + + emit_delta: EmitDeltaProtocol | None = None + if event_namespace is not None: + + async def _emit_delta( + token: str, delta: Mapping[str, Mapping[str, Any]] + ) -> None: + """Emit a delta to the frontend. + + Args: + token: The client token to emit the delta to. + delta: The delta to emit, mapping client tokens to variable updates. + """ + await event_namespace.emit_update( + update=StateUpdate(delta=delta, final=True), + token=token, + ) + + emit_delta = _emit_delta + + async def enqueue(token: str, *events: Event) -> None: + """Enqueue an event handler to be executed. + + Args: + token: The client token associated with the event. + events: The events to enqueue. + """ + for event in events: + if event.name.startswith("_"): + # Frontend events that start with "_" are emitted directly. + await event_namespace.emit_update( + update=StateUpdate(events=[event]), + token=token, + ) + else: + # Backend events will be processed by the internal queue. + await self.enqueue(token=token, event=event) + + else: + + async def enqueue(token: str, *events: Event) -> None: + """Enqueue an event handler to be executed. + + Args: + token: The client token associated with the event. + events: The events to enqueue. + """ + for event in events: + await self.enqueue(token=token, event=event) + + if state_manager is None: + # For testing use cases, default to a new in-memory state manager if one is not provided. + state_manager = StateManagerMemory() + + event_context.set( + EventContext( + token="", + parent_txid=None, + state_manager=state_manager, + enqueue_impl=enqueue, + emit_delta_impl=emit_delta, + ), + ) + self._root_context = copy_context() + return self + + async def __aenter__(self) -> "EventProcessor": + """Enter the event processor context manager. + + Returns: + The event processor instance. + """ + self._ensure_queue_task() + return self + + async def __aexit__(self, *exc_info) -> None: + """Exit the event processor context manager and stop the processor.""" + await self.stop() + + async def stop(self): + """Stop the event processor and cancel all running tasks.""" + from reflex.utils import telemetry + + if self._root_context is None: + msg = "Event processor is not running" + raise RuntimeError(msg) + # Cancel the queue processing task. + if self._queue_task is not None: + self._queue_task.cancel() + # Cancel all running event handler tasks. + for task in self._tasks.values(): + task.cancel() + # Warn for any non CancelledError exceptions that were raised in the tasks. + for task in self._tasks.copy().values(): + try: + await task + except asyncio.CancelledError: # noqa: PERF203 + pass + except Exception as ex: + telemetry.send_error(ex, context="backend") + if self.backend_exception_handler is not None: + try: + await self._handle_backend_exception( + ex, ctx=task.get_context().run(event_context.get) + ) + except Exception: + console.error( + rich.markup.escape( + f"Error in backend exception handler for {task.get_name()} during shutdown:\n{traceback.format_exc()}" + ) + ) + else: + return + console.error( + rich.markup.escape( + f"Error in event handler task {task.get_name()} during shutdown:\n{traceback.format_exc()}" + ) + ) + + def _ensure_queue_task(self) -> None: + """Ensure the queue processing task is running.""" + if self._root_context is None: + msg = "Event processor is not running, call .start(...) first." + raise RuntimeError(msg) + ensure_task( + self, + "_queue_task", + self._process_queue, + task_context=self._root_context, + ) + + async def enqueue( + self, token: str, event: Event, ev_ctx: EventContext | None = None + ) -> None: + """Enqueue an event to be processed. + + Args: + token: The client token associated with the event. + event: The event to enqueue. + ev_ctx: The event context to use for this event. + """ + self._ensure_queue_task() + if ev_ctx is None: + try: + ev_ctx = event_context.get().fork(token=token) + except LookupError as le: + if self._root_context is not None: + ev_ctx = self._root_context.run(event_context.get).fork(token=token) + else: + msg = "Event processor is not running, call .start(...) first." + raise RuntimeError(msg) from le + await self._queue.put(EventQueueEntry(event=event, ctx=ev_ctx)) + + async def _process_event_queue_entry( + self, entry: EventQueueEntry, registered_handler: RegisteredEventHandler + ) -> None: + """Process a single event queue entry. + + This function runs in a new task for each event. + + Args: + entry: The event queue entry to process. + registered_handler: The registered handler for the event. + """ + # Set up the event context for this task. + ctx = entry.ctx + event_context.set(ctx) + event = entry.event + router_data = event.router_data or {} + # Get the state for the session exclusively. + async with ctx.state_manager.modify_state_with_links( + entry.event.substate_token, event=entry.event + ) as state: + # TODO: handle "reload" trigger of brand new state instances + + # re-assign only when the value is set and different + if router_data and state.router_data != router_data: + # assignment will recurse into substates and force recalculation of + # dependent ComputedVar (dynamic route variables) + state.router_data = router_data + state.router = RouterData.from_router_data(router_data) + + # Preprocess the event. + if ( + self.middleware is not None + and (update := await self.middleware._preprocess(state, event)) + is not None + ): + # If there was an update, yield it. + if update.delta: + await ctx.emit_delta(update.delta) + if update.events: + await ctx.enqueue(*update.events) + return + + # Get the event's substate. + substate = await state.get_state(event.substate_token.cls) + root_state = state._get_root_state() + + # Process non-background events while holding the lock. + if not registered_handler.handler.is_background: + await base_state_processor.process_event( + handler=registered_handler.handler, + payload=event.payload, + state=substate, + root_state=root_state, + ) + return + # Otherwise drop the state lock and start processing the background task with a proxy state. + await base_state_processor.process_event( + handler=registered_handler.handler, + state=StateProxy(substate), + payload=event.payload, + root_state=root_state, + ) + + async def _process_queue(self): + """Process events from the queue in a task.""" + while True: + entry = await self._queue.get() + try: + try: + registered_handler = REGISTERED_HANDLERS[entry.event.name] + except KeyError as ke: + msg = f"No registered handler found for event: {entry.event.name}" + raise KeyError(msg) from ke + # Create a new task to handle this event. + task = asyncio.create_task( + self._process_event_queue_entry(entry, registered_handler), + name=( + f"reflex_event|{entry.event.name}|{entry.ctx.token}|{time.time()}" + ), + ) + self._tasks[entry.ctx.txid] = task + task.add_done_callback(self._finish_task) + except Exception: + # Log the error and continue processing the next events. + console.error( + rich.markup.escape( + f"Error processing event queue entry for {entry.event} [txid={entry.ctx.txid}]:\n{traceback.format_exc()}" + ) + ) + + async def _handle_backend_exception(self, ex: Exception, ctx: EventContext): + """Handle an exception raised during event processing by calling the backend exception handler if it exists. + + Args: + ex: The exception that was raised. + ctx: The event context for the event that caused the exception. + """ + if self.backend_exception_handler is not None and ( + events := self.backend_exception_handler(ex) + ): + await base_state_processor.chain_updates( + events=events, + handler_name=self.backend_exception_handler.__qualname__, + ) + + def _finish_task(self, task: asyncio.Task): + """Callback for finishing a _process_event_queue_entry task. + + This function is responsible for calling the backend exception handler + if the task raised an exception, and logging any errors that occur + during the process. + + Args: + task: The task that finished. + """ + from reflex.utils import telemetry + + task_ctx = task.get_context().run(event_context.get) + self._tasks.pop(task_ctx.txid, None) + if task.done(): + try: + task.result() + except asyncio.CancelledError: + pass + except Exception as ex: + telemetry.send_error(ex, context="backend") + if ( + not task.get_name().startswith("reflex_backend_exception_handler|") + and self.backend_exception_handler is not None + ): + # Create a new task in the same context to invoke the exception handler. + t = self._tasks[task_ctx.txid] = task.get_context().run( + asyncio.create_task, + self._handle_backend_exception(ex, task_ctx), + name=f"reflex_backend_exception_handler|task=[{task.get_name()}]|{time.time()}", + ) + t.add_done_callback(self._finish_task) + return + console.error( + rich.markup.escape( + f"Error in {task.get_name()} [txid={task_ctx.txid}]:\n{traceback.format_exc()}" + ) + ) + + +__all__ = [ + "EventProcessor", + "EventQueueEntry", +] diff --git a/reflex/ievent/registry.py b/reflex/ievent/registry.py new file mode 100644 index 00000000000..6f714b9e2b9 --- /dev/null +++ b/reflex/ievent/registry.py @@ -0,0 +1,28 @@ +"""A registry for all known event handlers.""" + +import dataclasses +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from reflex.event import EventHandler + from reflex.state import BaseState + + +@dataclasses.dataclass(frozen=True, kw_only=True, slots=True) +class RegisteredEventHandler: + """A registered event handler, which includes the handler and its full name.""" + + handler: EventHandler + states: tuple[type[BaseState], ...] + + +REGISTERED_HANDLERS: dict[str, RegisteredEventHandler] = {} + + +def register(handler: EventHandler, states: tuple[type[BaseState], ...] = ()) -> None: + """Register an event handler with its full name and associated states.""" + from reflex.utils.format import format_event_handler + + REGISTERED_HANDLERS[format_event_handler(handler)] = RegisteredEventHandler( + handler=handler, states=states + ) diff --git a/reflex/istate/manager/redis.py b/reflex/istate/manager/redis.py index e7dd95857d2..e59827be35c 100644 --- a/reflex/istate/manager/redis.py +++ b/reflex/istate/manager/redis.py @@ -388,7 +388,6 @@ async def set_state( base_state = cast(BaseState, state) client_token = token.ident - substate_name = token.cls.get_full_name() if lock_id is not None and client_token not in self._local_leases: time_taken = ( @@ -406,19 +405,11 @@ async def set_state( dedupe=True, ) - # If the substate name on the token doesn't match the instance name, it cannot have a parent. - if ( - base_state.parent_state is not None - and base_state.get_full_name() != substate_name - ): - msg = f"Cannot `set_state` with mismatching token {token} and substate {base_state.get_full_name()}." - raise RuntimeError(msg) - # Recursively set_state on all known substates. tasks = [ asyncio.create_task( self.set_state( - token.with_cls(type(substate)), + token, substate, lock_id=lock_id, **context, @@ -432,7 +423,7 @@ async def set_state( pickle_state = base_state._serialize() if pickle_state: await self.redis.set( - str(token), + str(token.with_cls(type(base_state))), pickle_state, ex=self.token_expiration, ) diff --git a/reflex/state.py b/reflex/state.py index 47edd872d8f..d1dc0297da3 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -7,19 +7,14 @@ import contextlib import copy import dataclasses -import datetime import functools import inspect import pickle import re import sys import time -import uuid -import warnings -from collections.abc import AsyncIterator, Callable, Iterator, Sequence -from enum import Enum +from collections.abc import Callable, Iterator, Mapping, Sequence from hashlib import md5 -from importlib.util import find_spec from types import FunctionType from typing import ( TYPE_CHECKING, @@ -51,9 +46,8 @@ from reflex.istate import HANDLED_PICKLE_ERRORS, debug_failed_pickles from reflex.istate.data import RouterData from reflex.istate.proxy import ImmutableMutableProxy as ImmutableMutableProxy -from reflex.istate.proxy import MutableProxy, StateProxy, is_mutable_type +from reflex.istate.proxy import MutableProxy, is_mutable_type from reflex.istate.storage import ClientStorageBase -from reflex.model import Model from reflex.utils import console, format, prerequisites, types from reflex.utils.exceptions import ( ComputedVarShadowsBaseVarsError, @@ -71,8 +65,7 @@ ) from reflex.utils.exceptions import ImmutableStateError as ImmutableStateError from reflex.utils.exec import is_testing_env -from reflex.utils.monitoring import is_pyleak_enabled, monitor_loopblocks -from reflex.utils.types import _isinstance, is_union, value_inside_optional +from reflex.utils.types import _isinstance from reflex.vars import Field, VarData, field from reflex.vars.base import ( ComputedVar, @@ -88,7 +81,8 @@ from reflex.components.component import Component -Delta = dict[str, Any] +Delta = dict[str, dict[str, Any]] +DeltaMapping = Mapping[str, Mapping[str, Any]] var = computed_var @@ -325,15 +319,6 @@ def _override_base_method(fn: Callable[PARAMS, RETURN]) -> Callable[PARAMS, RETU return fn -_deserializers = { - int: int, - float: float, - datetime.datetime: datetime.datetime.fromisoformat, - datetime.date: datetime.date.fromisoformat, - datetime.time: datetime.time.fromisoformat, - uuid.UUID: uuid.UUID, -} - all_base_state_classes: dict[str, None] = {} CLASS_VAR_NAMES = frozenset({ @@ -518,6 +503,7 @@ def __init_subclass__(cls, mixin: bool = False, **kwargs): Raises: StateValueError: If a substate class shadows another. """ + from reflex.ievent.registry import register from reflex.utils.exceptions import StateValueError super().__init_subclass__(**kwargs) @@ -642,6 +628,7 @@ def __init_subclass__(cls, mixin: bool = False, **kwargs): for name, fn in events.items(): handler = cls._create_event_handler(fn) cls.event_handlers[name] = handler + register(handler, states=(cls,)) setattr(cls, name, handler) # Initialize per-class var dependency tracking. @@ -662,9 +649,12 @@ def _add_event_handler( name: The name of the event handler. fn: The function to call when the event is triggered. """ + from reflex.ievent.registry import register + handler = cls._create_event_handler(fn) cls.event_handlers[name] = handler setattr(cls, name, handler) + register(handler, states=(cls,)) @staticmethod def _copy_fn(fn: Callable) -> Callable: @@ -1177,7 +1167,10 @@ def _create_event_handler( @classmethod def _create_setvar(cls): """Create the setvar method for the state.""" + from reflex.ievent.registry import register + cls.setvar = cls.event_handlers["setvar"] = EventHandlerSetVar(state_cls=cls) + register(cls.setvar, states=(cls,)) @classmethod def _create_setter(cls, name: str, prop: Var): @@ -1711,296 +1704,6 @@ async def get_var_value(self, var: Var[VAR_TYPE]) -> VAR_TYPE: ) return getattr(other_state, var_data.field_name) - def _get_event_handler(self, event: Event | str) -> tuple[BaseState, EventHandler]: - """Get the event handler for the given event. - - Args: - event: The event to get the handler for, or a dotted handler name string. - - - Returns: - The event handler. - - Raises: - ValueError: If the event handler or substate is not found. - """ - # Get the event handler. - name = event.name if isinstance(event, Event) else event - path = name.split(".") - path, name = path[:-1], path[-1] - substate = self.get_substate(path) - if not substate: - msg = "The value of state cannot be None when processing an event." - raise ValueError(msg) - handler = substate.event_handlers[name] - - return substate, handler - - async def _process(self, event: Event) -> AsyncIterator[StateUpdate]: - """Obtain event info and process event. - - Args: - event: The event to process. - - Yields: - The state update after processing the event. - """ - # Get the event handler. - substate, handler = self._get_event_handler(event) - - # For background tasks, proxy the state. - if handler.is_background: - substate = StateProxy(substate) - - # Run the event generator and yield state updates. - async for update in self._process_event( - handler=handler, - state=substate, - payload=event.payload, - ): - yield update - - def _check_valid(self, handler: EventHandler, events: Any) -> Any: - """Check if the events yielded are valid. They must be EventHandlers or EventSpecs. - - Args: - handler: EventHandler. - events: The events to be checked. - - Raises: - TypeError: If any of the events are not valid. - - Returns: - The events as they are if valid. - """ - - def _is_valid_type(events: Any) -> bool: - return isinstance(events, (Event, EventHandler, EventSpec)) - - if events is None or _is_valid_type(events): - return events - - if not (isinstance(events, Sequence) and not isinstance(events, (str, bytes))): - events = [events] - - try: - if all(_is_valid_type(e) for e in events): - return events - except TypeError: - pass - - coroutines = [e for e in events if inspect.iscoroutine(e)] - - for coroutine in coroutines: - coroutine_name = coroutine.__qualname__ - warnings.filterwarnings( - "ignore", message=f"coroutine '{coroutine_name}' was never awaited" - ) - - msg = ( - f"Your handler {handler.fn.__qualname__} must only return/yield: None, Events or other EventHandlers referenced by their class (i.e. using `type(self)` or other class references)." - f" Returned events of types {', '.join(map(str, map(type, events)))!s}." - ) - raise TypeError(msg) - - async def _as_state_update( - self, - handler: EventHandler, - events: EventSpec | list[EventSpec] | None, - final: bool, - ) -> StateUpdate: - """Convert the events to a StateUpdate. - - Fixes the events and checks for validity before converting. - - Args: - handler: The handler where the events originated from. - events: The events to queue with the update. - final: Whether the handler is done processing. - - Returns: - The valid StateUpdate containing the events and final flag. - """ - # get the delta from the root of the state tree - state = self._get_root_state() - - token = self.router.session.client_token - - # Convert valid EventHandler and EventSpec into Event - fixed_events = fix_events(self._check_valid(handler, events), token) - - try: - # Get the delta after processing the event. - delta = await state._get_resolved_delta() - state._clean() - - return StateUpdate( - delta=delta, - events=fixed_events, - final=final if not handler.is_background else None, - ) - except Exception as ex: - state._clean() - - event_specs = ( - prerequisites.get_and_validate_app().app.backend_exception_handler(ex) - ) - - if event_specs is None: - return StateUpdate() - - event_specs_correct_type = cast( - list[EventSpec | EventHandler] | None, - [event_specs] if isinstance(event_specs, EventSpec) else event_specs, - ) - fixed_events = fix_events( - event_specs_correct_type, - token, - router_data=state.router_data, - ) - return StateUpdate( - events=fixed_events, - final=True, - ) - - async def _process_event( - self, - handler: EventHandler, - state: BaseState | StateProxy, - payload: builtins.dict, - ) -> AsyncIterator[StateUpdate]: - """Process event. - - Args: - handler: EventHandler to process. - state: State to process the handler. - payload: The event payload. - - Yields: - StateUpdate object - - Raises: - ValueError: If a string value is received for an int or float type and cannot be converted. - """ - from reflex.utils import telemetry - - # Get the function to process the event. - if is_pyleak_enabled(): - console.debug(f"Monitoring leaks for handler: {handler.fn.__qualname__}") - fn = functools.partial(monitor_loopblocks(handler.fn), state) - else: - fn = functools.partial(handler.fn, state) - - try: - type_hints = types.get_type_hints(handler.fn) - except Exception: - type_hints = {} - - for arg, value in list(payload.items()): - hinted_args = type_hints.get(arg, Any) - if hinted_args is Any: - continue - if is_union(hinted_args): - if value is None: - continue - hinted_args = value_inside_optional(hinted_args) - if ( - isinstance(value, dict) - and isinstance(hinted_args, type) - and not types.is_generic_alias(hinted_args) # py3.10 - ): - if issubclass(hinted_args, Model): - # Remove non-fields from the payload - payload[arg] = hinted_args(**{ - key: value - for key, value in value.items() - if key in hinted_args.__fields__ - }) - elif dataclasses.is_dataclass(hinted_args): - payload[arg] = hinted_args(**value) - elif find_spec("pydantic"): - from pydantic import BaseModel as BaseModelV2 - from pydantic.v1 import BaseModel as BaseModelV1 - - if issubclass(hinted_args, BaseModelV1): - payload[arg] = hinted_args.parse_obj(value) - elif issubclass(hinted_args, BaseModelV2): - payload[arg] = hinted_args.model_validate(value) - elif isinstance(value, list) and (hinted_args is set or hinted_args is set): - payload[arg] = set(value) - elif isinstance(value, list) and ( - hinted_args is tuple or hinted_args is tuple - ): - payload[arg] = tuple(value) - elif isinstance(hinted_args, type) and issubclass(hinted_args, Enum): - try: - payload[arg] = hinted_args(value) - except ValueError: - msg = f"Received an invalid enum value ({value}) for {arg} of type {hinted_args}" - raise ValueError(msg) from None - elif ( - isinstance(value, str) - and (deserializer := _deserializers.get(hinted_args)) is not None - ): - try: - payload[arg] = deserializer(value) - except ValueError: - msg = f"Received a string value ({value}) for {arg} but expected a {hinted_args}" - raise ValueError(msg) from None - else: - console.warn( - f"Received a string value ({value}) for {arg} but expected a {hinted_args}. A simple conversion was successful." - ) - - # Wrap the function in a try/except block. - try: - # Handle async functions. - if inspect.iscoroutinefunction(fn.func): - events = await fn(**payload) - - # Handle regular functions. - else: - events = fn(**payload) - # Handle async generators. - if inspect.isasyncgen(events): - async for event in events: - yield await state._as_state_update(handler, event, final=False) - yield await state._as_state_update(handler, events=None, final=True) - - # Handle regular generators. - elif inspect.isgenerator(events): - try: - while True: - yield await state._as_state_update( - handler, next(events), final=False - ) - except StopIteration as si: - # the "return" value of the generator is not available - # in the loop, we must catch StopIteration to access it - if si.value is not None: - yield await state._as_state_update( - handler, si.value, final=False - ) - yield await state._as_state_update(handler, events=None, final=True) - - # Handle regular event chains. - else: - yield await state._as_state_update(handler, events, final=True) - - # If an error occurs, throw a window alert. - except Exception as ex: - telemetry.send_error(ex, context="backend") - - event_specs = ( - prerequisites.get_and_validate_app().app.backend_exception_handler(ex) - ) - - yield await state._as_state_update( - handler, - event_specs, - final=True, - ) - def _mark_dirty_computed_vars(self) -> None: """Mark ComputedVars that need to be recalculated based on dirty_vars.""" # Append expired computed vars to dirty_vars to trigger recalculation @@ -2506,6 +2209,25 @@ async def _get_state_from_redis(self, state_cls: type[T_STATE]) -> T_STATE: return await internal_patch_linked_state(linked_token) return state_instance + @event + async def hydrate(self) -> None: + """Send the full state to the frontend to synchronize it with the backend.""" + from reflex.ievent.context import event_context + + # Clear client storage, to respect clearing cookies + self._reset_client_storage() + + # Mark state as not hydrated (until on_loads are complete) + self.is_hydrated = False + + # Get the initial state if needed. + ctx = event_context.get() + if ctx.emit_delta_impl is not None: + await ctx.emit_delta(delta=await _resolve_delta(self.dict())) + + # since a full dict was captured, clean any dirtiness + self._clean() + @event def set_is_hydrated(self, value: bool) -> None: """Set the hydrated state. @@ -2802,7 +2524,7 @@ class StateUpdate: """A state update sent to the frontend.""" # The state delta. - delta: Delta = dataclasses.field(default_factory=dict) + delta: DeltaMapping = dataclasses.field(default_factory=dict) # Events to be added to the event queue. events: list[Event] = dataclasses.field(default_factory=list) diff --git a/reflex/utils/format.py b/reflex/utils/format.py index 3093e455d55..b283ea31443 100644 --- a/reflex/utils/format.py +++ b/reflex/utils/format.py @@ -461,9 +461,9 @@ def get_event_handler_parts(handler: EventHandler) -> tuple[str, str]: # Get the function name name = parts[-1] - from reflex.state import State + from reflex.state import BaseState - if state_full_name == FRONTEND_EVENT_STATE and name not in State.__dict__: + if state_full_name == FRONTEND_EVENT_STATE and name not in BaseState.__dict__: return ("", to_snake_case(handler.fn.__qualname__)) return (state_full_name, name) diff --git a/reflex/utils/serializers.py b/reflex/utils/serializers.py index 9693601ec5d..f2b56a00d19 100644 --- a/reflex/utils/serializers.py +++ b/reflex/utils/serializers.py @@ -8,6 +8,7 @@ import functools import inspect import json +import uuid import warnings from collections.abc import Callable, Mapping, Sequence from datetime import date, datetime, time, timedelta @@ -35,6 +36,16 @@ SERIALIZED_FUNCTION = TypeVar("SERIALIZED_FUNCTION", bound=Serializer) +deserializers = { + int: int, + float: float, + datetime: datetime.fromisoformat, + date: date.fromisoformat, + time: time.fromisoformat, + uuid.UUID: uuid.UUID, +} + + @overload def serializer( fn: None = None, diff --git a/reflex/utils/tasks.py b/reflex/utils/tasks.py index ec51d94af22..a3b8a455a10 100644 --- a/reflex/utils/tasks.py +++ b/reflex/utils/tasks.py @@ -3,6 +3,7 @@ import asyncio import time from collections.abc import Callable, Coroutine +from contextvars import Context from typing import Any from reflex.utils import console @@ -64,6 +65,7 @@ def ensure_task( exception_delay: float = 1.0, exception_limit: int = 5, exception_limit_window: float = 60.0, + task_context: Context | None = None, **kwargs: Any, ) -> asyncio.Task: """Ensure that a task is running for the given coroutine function. @@ -78,6 +80,7 @@ def ensure_task( exception_delay: The delay between retries when an exception is suppressed. exception_limit: The maximum number of suppressed exceptions within the limit window before raising. exception_limit_window: The time window in seconds for counting suppressed exceptions. + task_context: The context to use for the task. *args: The arguments to pass to the coroutine function. **kwargs: The keyword arguments to pass to the coroutine function. @@ -93,17 +96,20 @@ def ensure_task( task = getattr(owner, task_attribute, None) if task is None or task.done(): asyncio.get_running_loop() # Ensure we're in an event loop. - task = asyncio.create_task( - _run_forever( - coro_function, - *args, - suppress_exceptions=suppress_exceptions, - exception_delay=exception_delay, - exception_limit=exception_limit, - exception_limit_window=exception_limit_window, - **kwargs, - ), - name=f"reflex_ensure_task|{type(owner).__name__}.{task_attribute}={coro_function.__name__}|{time.time()}", + rf_coro = _run_forever( + coro_function, + *args, + suppress_exceptions=suppress_exceptions, + exception_delay=exception_delay, + exception_limit=exception_limit, + exception_limit_window=exception_limit_window, + **kwargs, ) + task_name = f"reflex_ensure_task|{type(owner).__name__}.{task_attribute}={coro_function.__name__}|{time.time()}" + if task_context is not None: + # Run the task in the given context (not needed after Python 3.11+ which supports passing context to create_task directly). + task = task_context.run(asyncio.create_task, rf_coro, name=task_name) + else: + task = asyncio.create_task(rf_coro, name=task_name) setattr(owner, task_attribute, task) return task diff --git a/tests/integration/test_background_task.py b/tests/integration/test_background_task.py index 1e05475b7d0..283dbe4cd12 100644 --- a/tests/integration/test_background_task.py +++ b/tests/integration/test_background_task.py @@ -335,7 +335,7 @@ def test_background_task( AppHarness.expect(lambda: counter_async_cv.text == "620", timeout=40) # all tasks should have exited and cleaned up AppHarness.expect( - lambda: not background_task.app_instance._background_tasks # pyright: ignore [reportOptionalMemberAccess] + lambda: not background_task.app_instance.event_processor._tasks # pyright: ignore [reportOptionalMemberAccess] ) From b8036bd3f599d15914651b7a554c820d002802a1 Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Mon, 23 Mar 2026 18:40:11 -0700 Subject: [PATCH 07/81] Overhaul EventProcessor lifecycle * No tasks start until `.start()` is called * add graceful shutdown timeout to allow tasks to finish before cancellation * use more keyword only parameters * move BaseState-specific processing to new BaseStateEventProcessor subclass * add test fixtures for `mock_event_processor` that can process simple registered events --- pyi_hashes.json | 2 +- reflex/app.py | 6 +- reflex/event.py | 62 ++- reflex/ievent/context.py | 40 +- reflex/ievent/processor/__init__.py | 6 +- .../ievent/processor/base_state_processor.py | 105 ++++- reflex/ievent/processor/event_processor.py | 422 +++++++++++------- reflex/ievent/registry.py | 15 +- reflex/state.py | 11 +- tests/units/conftest.py | 159 ++++++- tests/units/test_state.py | 31 +- 11 files changed, 642 insertions(+), 217 deletions(-) diff --git a/pyi_hashes.json b/pyi_hashes.json index ec9bbff8850..3513d09bf41 100644 --- a/pyi_hashes.json +++ b/pyi_hashes.json @@ -1,5 +1,5 @@ { - "reflex/__init__.pyi": "0a3ae880e256b9fd3b960e12a2cb51a7", + "reflex/__init__.pyi": "7ac0e0be65ce8822360da15124d463d5", "reflex/components/__init__.pyi": "ac05995852baa81062ba3d18fbc489fb", "reflex/components/base/__init__.pyi": "16e47bf19e0d62835a605baa3d039c5a", "reflex/components/base/app_wrap.pyi": "22e94feaa9fe675bcae51c412f5b67f1", diff --git a/reflex/app.py b/reflex/app.py index d9293512050..4d5afb40482 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -90,7 +90,7 @@ IndividualEventType, noop, ) -from reflex.ievent.processor import EventProcessor +from reflex.ievent.processor import BaseStateEventProcessor, EventProcessor from reflex.istate.manager.token import BaseStateToken from reflex.page import DECORATED_PAGES from reflex.route import ( @@ -624,7 +624,7 @@ async def modified_send(message: Message): @contextlib.asynccontextmanager async def _setup_event_processor(self) -> AsyncIterator[None]: # Create the event processor. - self._event_processor = EventProcessor( + self._event_processor = BaseStateEventProcessor( middleware=self, backend_exception_handler=self.backend_exception_handler ) async with self._event_processor.configure( @@ -2190,7 +2190,7 @@ async def on_event(self, sid: str, data: Any): if (path := router_data.get(constants.RouteVar.PATH)) else "404" ).removeprefix("/") - await self.app.event_processor.enqueue(token=event.token, event=event) + await self.app.event_processor.enqueue(event.token, event) async def on_ping(self, sid: str): """Event for testing the API endpoint. diff --git a/reflex/event.py b/reflex/event.py index 5dab0461e47..eb66e2652dd 100644 --- a/reflex/event.py +++ b/reflex/event.py @@ -66,13 +66,13 @@ class Event: """An event that describes any state change in the app.""" - # The token to specify the client that the event is for. + # The token to specify the client that the event is for (TODO: remove). token: str # The event name. name: str - # The routing data where event occurred + # The routing data where event occurred (TODO: remove). router_data: dict[str, Any] = dataclasses.field(default_factory=dict) # The event payload. @@ -99,6 +99,58 @@ def substate_token(self) -> BaseStateToken: cls=root_state.get_class_substate(tuple(substate.split("."))), ) + @classmethod + def from_event_type( + cls, events: "IndividualEventType | list[IndividualEventType] | None" + ) -> "list[Event]": + """Create a list of Events from event-like objects. + + Args: + events: The event-like objects to create Events from. + + Returns: + A list of Events created from the event-like objects. + """ + # If the event handler returns nothing, return an empty list. + if events is None: + return [] + + # If the handler returns a single event, wrap it in a list. + if not isinstance(events, list): + events = [events] + + # Fix the events created by the handler. + out = [] + for e in events: + if callable(e) and getattr(e, "__name__", "") == "": + # A lambda was returned, assume the user wants to call it with no args. + e = e() + if isinstance(e, Event): + # If the event is already an event, append it to the list. + out.append(e) + continue + # Otherwise, create an event from the event spec. + if isinstance(e, EventHandler): + e = e() + if not isinstance(e, EventSpec): + msg = f"Unexpected event type, {type(e)}." + raise ValueError(msg) + name = format.format_event_handler(e.handler) + # TODO: allow real python types to be passed through the backend queue. + payload = {k._js_expr: v._decode() for k, v in e.args} + + # Create an event and append it to the list. + out.append( + Event( + token="none", + name=name, + payload=payload, + router_data={}, + ) + ) + + return out + _EVENT_FIELDS: set[str] = {f.name for f in dataclasses.fields(Event)} @@ -266,7 +318,11 @@ def __call__(self, *args: Any, **kwargs: Any) -> "EventSpec": from reflex.utils.exceptions import EventHandlerTypeError # Get the function args. - fn_args = list(self._parameters)[1:] + if self.state_full_name: + # Skip the `self` arg for state-bound event handlers. + fn_args = list(self._parameters)[1:] + else: + fn_args = list(self._parameters) if not isinstance( repeated_arg := next( diff --git a/reflex/ievent/context.py b/reflex/ievent/context.py index cc329abfcac..343335993c2 100644 --- a/reflex/ievent/context.py +++ b/reflex/ievent/context.py @@ -39,6 +39,19 @@ async def __call__(self, token: str, *events: Event) -> None: ... +class EmitEventProtocol(Protocol): + """Protocol for the emit_event function in the event context.""" + + async def __call__(self, token: str, *events: Event) -> None: + """Emit an event to be processed immediately. + + Args: + token: The client token associated with the event. + events: The events to emit. + """ + ... + + class EmitDeltaProtocol(Protocol): """Protocol for the emit_delta function in the event context.""" @@ -64,18 +77,25 @@ class EventContext: token: str # Manages persistence of state across events. - state_manager: StateManager + state_manager: StateManager = dataclasses.field(repr=False) # Function responsible for enqueuing an event handler to be executed. - enqueue_impl: EnqueueProtocol + enqueue_impl: EnqueueProtocol = dataclasses.field(repr=False) # Each event is associated with a top-level transaction id. txid: str = dataclasses.field(default_factory=lambda: uuid.uuid4().hex[:12]) # The txid of another EventContext that enqueued this context's event. parent_txid: str | None = None - emit_delta_impl: EmitDeltaProtocol | None = None - cached_states: dict[type, Any] = dataclasses.field(default_factory=dict, init=False) + emit_delta_impl: EmitDeltaProtocol | None = dataclasses.field( + default=None, repr=False + ) + emit_event_impl: EmitEventProtocol | None = dataclasses.field( + default=None, repr=False + ) + cached_states: dict[type, Any] = dataclasses.field( + default_factory=dict, init=False, repr=False + ) def fork(self, token: str | None = None) -> "EventContext": """Return a new EventContext with the specified fields replaced. @@ -92,6 +112,7 @@ def fork(self, token: str | None = None) -> "EventContext": state_manager=self.state_manager, enqueue_impl=self.enqueue_impl, emit_delta_impl=self.emit_delta_impl, + emit_event_impl=self.emit_event_impl, ) async def emit_delta(self, delta: Mapping[str, Mapping[str, Any]]) -> None: @@ -103,6 +124,17 @@ async def emit_delta(self, delta: Mapping[str, Mapping[str, Any]]) -> None: if self.emit_delta_impl is not None: await self.emit_delta_impl(self.token, delta) + async def emit_event(self, *events: Event) -> None: + """Emit an event to be processed on the frontend. + + If no such handler exists, the event will not be processed. + + Args: + events: The events to emit. + """ + if self.emit_event_impl is not None: + await self.emit_event_impl(self.token, *events) + async def enqueue(self, *event: Event) -> None: """Enqueue an event handler to be executed. diff --git a/reflex/ievent/processor/__init__.py b/reflex/ievent/processor/__init__.py index 35a4c2182ba..fe905a50a81 100644 --- a/reflex/ievent/processor/__init__.py +++ b/reflex/ievent/processor/__init__.py @@ -1,10 +1,10 @@ """Procedures for handling events.""" -from reflex.ievent.processor import base_state_processor -from reflex.ievent.processor.event_processor import EventProcessor, EventQueueEntry +from reflex.ievent.processor.event_processor import EventProcessor, EventQueueEntry # noqa: I001 +from reflex.ievent.processor.base_state_processor import BaseStateEventProcessor __all__ = [ + "BaseStateEventProcessor", "EventProcessor", "EventQueueEntry", - "base_state_processor", ] diff --git a/reflex/ievent/processor/base_state_processor.py b/reflex/ievent/processor/base_state_processor.py index d8596c025db..8bc39857a13 100644 --- a/reflex/ievent/processor/base_state_processor.py +++ b/reflex/ievent/processor/base_state_processor.py @@ -10,6 +10,12 @@ from typing import TYPE_CHECKING, Any from reflex.ievent.context import event_context +from reflex.ievent.processor import EventProcessor +from reflex.ievent.processor.event_processor import ( + EventQueueEntry, + RegisteredEventHandler, +) +from reflex.istate.data import RouterData from reflex.istate.proxy import StateProxy from reflex.utils import console, types from reflex.utils.monitoring import is_pyleak_enabled, monitor_loopblocks @@ -26,11 +32,11 @@ def _check_valid_yield(events: Any, handler_name: str = "unknown") -> Any: events: The events to be checked. handler_name: The name of the handler that yielded the events, used for error messages. - Raises: - TypeError: If any of the events are not valid. - Returns: The events as they are if valid. + + Raises: + TypeError: If any of the events are not valid. """ from reflex.event import Event, EventHandler, EventSpec @@ -177,7 +183,10 @@ async def chain_updates( if fixed_events := fix_events( _check_valid_yield(events, handler_name=handler_name), token ): - await ctx.enqueue(*fixed_events) + # Frontend events. + await ctx.emit_event(*(e for e in fixed_events if e.name.startswith("_"))) + # Backend events. + await ctx.enqueue(*(e for e in fixed_events if not e.name.startswith("_"))) if root_state is not None: # Get the delta after processing the event. @@ -258,4 +267,90 @@ async def process_event( await chain_updates(events, root_state=root_state, handler_name=handler_name) -__all__ = ["chain_updates", "process_event"] +class BaseStateEventProcessor(EventProcessor): + """Event processor for BaseState-derived states. + + This processor is used to process events for BaseState-derived states, and + is responsible for maintaining the event queue and emitting deltas to the + frontend. + """ + + async def _process_event_queue_entry( + self, *, entry: EventQueueEntry, registered_handler: RegisteredEventHandler + ) -> None: + """Process a single event queue entry. + + This function runs in a new task for each event. + + Args: + entry: The event queue entry to process. + registered_handler: The registered handler for the event. + """ + # Set up the event context for this task. + ctx = entry.ctx + event_context.set(ctx) + event = entry.event + router_data = event.router_data or {} + # Get the state for the session exclusively. + async with ctx.state_manager.modify_state_with_links( + entry.event.substate_token, event=entry.event + ) as state: + # TODO: handle "reload" trigger of brand new state instances + + # re-assign only when the value is set and different + if router_data and state.router_data != router_data: + # assignment will recurse into substates and force recalculation of + # dependent ComputedVar (dynamic route variables) + state.router_data = router_data + state.router = RouterData.from_router_data(router_data) + + # Preprocess the event. + if ( + self.middleware is not None + and (update := await self.middleware._preprocess(state, event)) + is not None + ): + # If there was an update, yield it. + if update.delta: + await ctx.emit_delta(update.delta) + if update.events: + await ctx.enqueue(*update.events) + return + + # Get the event's substate. + substate = await state.get_state(event.substate_token.cls) + root_state = state._get_root_state() + + # Process non-background events while holding the lock. + if not registered_handler.handler.is_background: + await process_event( + handler=registered_handler.handler, + payload=event.payload, + state=substate, + root_state=root_state, + ) + return + # Otherwise drop the state lock and start processing the background task with a proxy state. + await process_event( + handler=registered_handler.handler, + state=StateProxy(substate), + payload=event.payload, + root_state=root_state, + ) + + async def _handle_backend_exception(self, ex: Exception): + """Handle an exception raised during event processing by calling the backend exception handler if it exists. + + Args: + ex: The exception that was raised. + """ + if self.backend_exception_handler is not None and ( + events := self.backend_exception_handler(ex) + ): + await chain_updates( + events=events, + handler_name=self.backend_exception_handler.__qualname__, + ) + + +__all__ = ["BaseStateEventProcessor", "chain_updates", "process_event"] diff --git a/reflex/ievent/processor/event_processor.py b/reflex/ievent/processor/event_processor.py index 63b9a4f7c84..ad243a2ae86 100644 --- a/reflex/ievent/processor/event_processor.py +++ b/reflex/ievent/processor/event_processor.py @@ -1,30 +1,69 @@ """Base EventProcessor class for handling backend event queue.""" import asyncio +import contextlib import dataclasses +import inspect import time import traceback from collections.abc import Callable, Mapping -from contextvars import Context, copy_context +from contextvars import Token, copy_context from typing import TYPE_CHECKING, Any, Self import rich.markup from reflex.app_mixins.middleware import MiddlewareMixin -from reflex.ievent.context import EmitDeltaProtocol, EventContext, event_context -from reflex.ievent.processor import base_state_processor +from reflex.ievent.context import EventContext, event_context from reflex.ievent.registry import REGISTERED_HANDLERS, RegisteredEventHandler -from reflex.istate.data import RouterData from reflex.istate.manager import StateManager -from reflex.istate.proxy import StateProxy from reflex.utils import console -from reflex.utils.tasks import ensure_task if TYPE_CHECKING: from reflex.app import EventNamespace from reflex.event import Event, EventSpec +@dataclasses.dataclass(kw_only=True, slots=True) +class DrainTimeoutManager: + """Manages an optional combined timeout over multiple calls. + + Each time the context is entered, yield the remaining time until the + overall timeout is reached, or 0 if the timeout has already been reached. + This allows multiple operations to share a single overall timeout, even if + they are not executed sequentially. + """ + + drain_deadline: float | None = None + + @classmethod + def with_timeout(cls, timeout: float | None) -> "DrainTimeoutManager": + """Create a DrainTimeoutManager with a specified timeout. + + Args: + timeout: The overall amount of time in seconds to wait. + + Returns: + A DrainTimeoutManager instance with the drain deadline set. + """ + if timeout is None: + return cls(drain_deadline=None) + return cls(drain_deadline=time.time() + timeout) + + def __enter__(self) -> float: + """Enter the context and yield the remaining time. + + Returns: + The remaining time in seconds until the overall timeout is reached, or 0 if the timeout + has already been reached. + """ + if self.drain_deadline is not None: + return max(0, self.drain_deadline - time.time()) + return 0 + + def __exit__(self, *exc_info) -> None: + """Exit the context. No cleanup necessary.""" + + @dataclasses.dataclass(frozen=True, kw_only=True, slots=True) class EventQueueEntry: """An entry in the event queue.""" @@ -35,24 +74,41 @@ class EventQueueEntry: @dataclasses.dataclass(kw_only=True, slots=True) class EventProcessor: - """Responsible for queuing and processing events.""" + """Responsible for queuing and processing events. + + Attributes: + middleware: An optional middleware mixin to apply to all events processed by this processor. + backend_exception_handler: An optional function to handle exceptions raised during event processing. The function should take an Exception as input and return an EventSpec or list of EventSpecs to be emitted in response, or None to not emit any events. + graceful_shutdown_timeout: An optional amount of time in seconds to wait for the queue to drain before forcefully cancelling tasks when stopping the processor. If None, the processor will not wait and will cancel tasks immediately. + + _queue: The asyncio queue for events to be processed. + _queue_task: The task responsible for processing the event queue. + _root_context: The root event context to use for events enqueued without an explicit context. + _attached_root_context_token: The context variable token for the attached root context, used to reset the context variable on shutdown. + _tasks: A mapping of active transaction ids to their corresponding event handler tasks, used for tracking and cancellation on shutdown. + """ middleware: MiddlewareMixin | None = None backend_exception_handler: ( Callable[[Exception], EventSpec | list[EventSpec] | None] | None ) = None + graceful_shutdown_timeout: float | None = None - _queue: asyncio.Queue[EventQueueEntry] = dataclasses.field( - default_factory=asyncio.Queue, init=False + _queue: asyncio.Queue[EventQueueEntry] | None = dataclasses.field( + default=None, init=False ) _queue_task: asyncio.Task | None = dataclasses.field(default=None, init=False) - _root_context: Context | None = dataclasses.field(default=None, init=False) + _root_context: EventContext | None = dataclasses.field(default=None, init=False) + _attached_root_context_token: Token | None = dataclasses.field( + default=None, init=False + ) _tasks: dict[str, asyncio.Task] = dataclasses.field( default_factory=dict, init=False ) def configure( self, + *, state_manager: StateManager | None = None, event_namespace: EventNamespace | None = None, ) -> Self: @@ -73,13 +129,15 @@ def configure( from reflex.state import StateUpdate if self._root_context is not None: - msg = "Event processor is already running" + msg = ( + "Event processor is already configured, call .configure(...) only once." + ) raise RuntimeError(msg) - emit_delta: EmitDeltaProtocol | None = None + emit_delta_impl = emit_event_impl = None if event_namespace is not None: - async def _emit_delta( + async def emit_delta( token: str, delta: Mapping[str, Mapping[str, Any]] ) -> None: """Emit a delta to the frontend. @@ -93,52 +151,36 @@ async def _emit_delta( token=token, ) - emit_delta = _emit_delta + emit_delta_impl = emit_delta - async def enqueue(token: str, *events: Event) -> None: - """Enqueue an event handler to be executed. + async def emit_event(token: str, *events: Event) -> None: + """Emit an event to be processed on the frontend. - Args: - token: The client token associated with the event. - events: The events to enqueue. - """ - for event in events: - if event.name.startswith("_"): - # Frontend events that start with "_" are emitted directly. - await event_namespace.emit_update( - update=StateUpdate(events=[event]), - token=token, - ) - else: - # Backend events will be processed by the internal queue. - await self.enqueue(token=token, event=event) - - else: - - async def enqueue(token: str, *events: Event) -> None: - """Enqueue an event handler to be executed. + If no such handler exists, the event will not be processed. Args: - token: The client token associated with the event. - events: The events to enqueue. + token: The client token to emit the event to. + events: The events to emit. """ - for event in events: - await self.enqueue(token=token, event=event) + await event_namespace.emit_update( + update=StateUpdate(events=list(events), final=True), + token=token, + ) + + emit_event_impl = emit_event if state_manager is None: # For testing use cases, default to a new in-memory state manager if one is not provided. state_manager = StateManagerMemory() - event_context.set( - EventContext( - token="", - parent_txid=None, - state_manager=state_manager, - enqueue_impl=enqueue, - emit_delta_impl=emit_delta, - ), + self._root_context = EventContext( + token="", + parent_txid=None, + state_manager=state_manager, + enqueue_impl=self.enqueue, + emit_delta_impl=emit_delta_impl, + emit_event_impl=emit_event_impl, ) - self._root_context = copy_context() return self async def __aenter__(self) -> "EventProcessor": @@ -147,28 +189,49 @@ async def __aenter__(self) -> "EventProcessor": Returns: The event processor instance. """ - self._ensure_queue_task() + await self.start() return self async def __aexit__(self, *exc_info) -> None: """Exit the event processor context manager and stop the processor.""" await self.stop() - async def stop(self): - """Stop the event processor and cancel all running tasks.""" - from reflex.utils import telemetry - + async def start(self) -> None: + """Start the event processor.""" if self._root_context is None: - msg = "Event processor is not running" + msg = "Event processor is not configured, call .configure(...) first." raise RuntimeError(msg) - # Cancel the queue processing task. - if self._queue_task is not None: - self._queue_task.cancel() - # Cancel all running event handler tasks. - for task in self._tasks.values(): + if self._queue is not None: + msg = "Event processor is already started" + raise RuntimeError(msg) + if self._attached_root_context_token is not None: + msg = "EventProcessor context cannot be nested." + raise RuntimeError(msg) + self._attached_root_context_token = event_context.set(self._root_context) + self._queue = asyncio.Queue() + self._ensure_queue_task() + + async def _stop_tasks(self, timeout: float | None = None) -> None: + """Stop all running tasks with an optional drain time. + + Args: + timeout: An optional amount of time in seconds to wait for the + queue to drain before cancelling tasks. If None, the processor will + not wait and will cancel tasks immediately. + """ + from reflex.utils import telemetry + + if timeout is not None and self._tasks: + with contextlib.suppress(asyncio.TimeoutError): + await asyncio.wait_for( + asyncio.gather(*self._tasks.values()), + timeout=timeout, + ) + # Cancel all outstanding event handler tasks. + for task in (outstanding_tasks := list(self._tasks.values())): task.cancel() - # Warn for any non CancelledError exceptions that were raised in the tasks. - for task in self._tasks.copy().values(): + # Wait for all tasks to finish and log any exceptions that were raised. + for task in outstanding_tasks: try: await task except asyncio.CancelledError: # noqa: PERF203 @@ -177,8 +240,9 @@ async def stop(self): telemetry.send_error(ex, context="backend") if self.backend_exception_handler is not None: try: - await self._handle_backend_exception( - ex, ctx=task.get_context().run(event_context.get) + await task.get_context().run( + self._handle_backend_exception, + ex, ) except Exception: console.error( @@ -187,54 +251,135 @@ async def stop(self): ) ) else: - return + continue console.error( rich.markup.escape( f"Error in event handler task {task.get_name()} during shutdown:\n{traceback.format_exc()}" ) ) - def _ensure_queue_task(self) -> None: - """Ensure the queue processing task is running.""" + async def stop(self, graceful_shutdown_timeout: float | None = None) -> None: + """Stop the event processor and cancel all running tasks. + + Args: + graceful_shutdown_timeout: An optional amount of time in seconds to wait for the + queue to drain before cancelling tasks. If None, the processor will + not wait and will cancel tasks immediately. + """ + from reflex.utils import telemetry + + if self._attached_root_context_token is not None: + event_context.reset(self._attached_root_context_token) + self._attached_root_context_token = None + # Optional grace period for tasks to finish before cancellation. + if graceful_shutdown_timeout is None: + graceful_shutdown_timeout = self.graceful_shutdown_timeout + drain_timeout = DrainTimeoutManager.with_timeout(graceful_shutdown_timeout) + with drain_timeout as remaining_time, contextlib.suppress(asyncio.TimeoutError): + if remaining_time > 0: + # Drain the queue first of any pending events. + await self.join(timeout=remaining_time) + # Stopping tasks may raise exceptions and chain additional deltas so the queue remains open. + with drain_timeout as remaining_time, contextlib.suppress(asyncio.TimeoutError): + await self._stop_tasks(timeout=remaining_time) + # Cancel queue processing now that all tasks have been cancelled. + if self._queue is not None: + self._queue.shutdown() + with drain_timeout as remaining_time, contextlib.suppress(asyncio.TimeoutError): + if remaining_time > 0: + await self.join(timeout=remaining_time) + with drain_timeout as remaining_time, contextlib.suppress(asyncio.TimeoutError): + # Stop all tasks again now that the queue is shut down, no additional events can be queued. + await self._stop_tasks(timeout=remaining_time) + self._queue = None + if self._queue_task is not None: + self._queue_task.cancel() + try: + await self._queue_task + except (asyncio.CancelledError, RuntimeError): + pass + except Exception as ex: + telemetry.send_error(ex, context="backend") + console.error( + rich.markup.escape( + f"Error in event processor queue task during shutdown:\n{traceback.format_exc()}" + ) + ) + self._queue_task = None + + async def join(self, timeout: float | None = None) -> None: + """Wait for the event processor to finish processing all events in the queue. + + Args: + timeout: An optional amount of time in seconds to wait for the queue to + drain before returning. If None, this method will wait indefinitely + until the queue is fully drained. + """ + if self._queue is not None: + await asyncio.wait_for(self._queue.join(), timeout=timeout) + + def _ensure_queue_task(self) -> asyncio.Queue[EventQueueEntry]: + """Ensure the queue processing task is running. + + Returns: + The event queue. + + Raises: + RuntimeError: If the event processor is not running and no queue is provided. + """ if self._root_context is None: + msg = "Event processor is not configured, call .configure(...) first." + raise RuntimeError(msg) + if self._queue is None: msg = "Event processor is not running, call .start(...) first." raise RuntimeError(msg) - ensure_task( - self, - "_queue_task", - self._process_queue, - task_context=self._root_context, - ) + if self._queue_task is None: + task_context = copy_context() + task_context.run(event_context.set, self._root_context) + self._queue_task = task_context.run( + asyncio.create_task, + self._process_queue(), + name=f"reflex_event_queue_processor|{time.time()}", + ) + return self._queue async def enqueue( - self, token: str, event: Event, ev_ctx: EventContext | None = None + self, token: str, *events: Event, ev_ctx: EventContext | None = None ) -> None: """Enqueue an event to be processed. Args: token: The client token associated with the event. - event: The event to enqueue. - ev_ctx: The event context to use for this event. + events: Remaining positional args are events to be enqueued. + ev_ctx: The event context to use for these events. """ - self._ensure_queue_task() if ev_ctx is None: try: ev_ctx = event_context.get().fork(token=token) except LookupError as le: if self._root_context is not None: - ev_ctx = self._root_context.run(event_context.get).fork(token=token) + ev_ctx = self._root_context.fork(token=token) else: msg = "Event processor is not running, call .start(...) first." raise RuntimeError(msg) from le - await self._queue.put(EventQueueEntry(event=event, ctx=ev_ctx)) + queue = self._ensure_queue_task() + for event in events: + await queue.put(EventQueueEntry(event=event, ctx=ev_ctx)) async def _process_event_queue_entry( - self, entry: EventQueueEntry, registered_handler: RegisteredEventHandler + self, *, entry: EventQueueEntry, registered_handler: RegisteredEventHandler ) -> None: """Process a single event queue entry. This function runs in a new task for each event. + The default implementation just calls the registered handler function + with the event payload as keyword arguments. + + Subclasses, such as BaseStateEventProcessor, can override this function + to provide additional functionality such as state management, event + chaining, and delta calculation. + Args: entry: The event queue entry to process. registered_handler: The registered handler for the event. @@ -243,95 +388,56 @@ async def _process_event_queue_entry( ctx = entry.ctx event_context.set(ctx) event = entry.event - router_data = event.router_data or {} - # Get the state for the session exclusively. - async with ctx.state_manager.modify_state_with_links( - entry.event.substate_token, event=entry.event - ) as state: - # TODO: handle "reload" trigger of brand new state instances - - # re-assign only when the value is set and different - if router_data and state.router_data != router_data: - # assignment will recurse into substates and force recalculation of - # dependent ComputedVar (dynamic route variables) - state.router_data = router_data - state.router = RouterData.from_router_data(router_data) - - # Preprocess the event. - if ( - self.middleware is not None - and (update := await self.middleware._preprocess(state, event)) - is not None - ): - # If there was an update, yield it. - if update.delta: - await ctx.emit_delta(update.delta) - if update.events: - await ctx.enqueue(*update.events) - return - - # Get the event's substate. - substate = await state.get_state(event.substate_token.cls) - root_state = state._get_root_state() - - # Process non-background events while holding the lock. - if not registered_handler.handler.is_background: - await base_state_processor.process_event( - handler=registered_handler.handler, - payload=event.payload, - state=substate, - root_state=root_state, - ) - return - # Otherwise drop the state lock and start processing the background task with a proxy state. - await base_state_processor.process_event( - handler=registered_handler.handler, - state=StateProxy(substate), - payload=event.payload, - root_state=root_state, - ) + result = registered_handler.handler.fn(**event.payload) + if inspect.isawaitable(result): + await result async def _process_queue(self): """Process events from the queue in a task.""" - while True: - entry = await self._queue.get() - try: + if (queue := self._queue) is None: + msg = "Event processor is not running, call .start(...) first." + raise RuntimeError(msg) + with contextlib.suppress(asyncio.QueueShutDown): + while True: + entry = await queue.get() try: - registered_handler = REGISTERED_HANDLERS[entry.event.name] - except KeyError as ke: - msg = f"No registered handler found for event: {entry.event.name}" - raise KeyError(msg) from ke - # Create a new task to handle this event. - task = asyncio.create_task( - self._process_event_queue_entry(entry, registered_handler), - name=( - f"reflex_event|{entry.event.name}|{entry.ctx.token}|{time.time()}" - ), - ) - self._tasks[entry.ctx.txid] = task - task.add_done_callback(self._finish_task) - except Exception: - # Log the error and continue processing the next events. - console.error( - rich.markup.escape( - f"Error processing event queue entry for {entry.event} [txid={entry.ctx.txid}]:\n{traceback.format_exc()}" + try: + registered_handler = REGISTERED_HANDLERS[entry.event.name] + except KeyError as ke: + msg = ( + f"No registered handler found for event: {entry.event.name}" + ) + raise KeyError(msg) from ke + # Create a new task to handle this event. + task = asyncio.create_task( + self._process_event_queue_entry( + entry=entry, registered_handler=registered_handler + ), + name=( + f"reflex_event|{entry.event.name}|{entry.ctx.token}|{time.time()}" + ), ) - ) + self._tasks[entry.ctx.txid] = task + task.add_done_callback(self._finish_task) + except Exception: + # Log the error and continue processing the next events. + console.error( + rich.markup.escape( + f"Error processing event queue entry for {entry.event} [txid={entry.ctx.txid}]:\n{traceback.format_exc()}" + ) + ) + queue.task_done() + if self._queue_task is asyncio.current_task(): + self._queue_task = None - async def _handle_backend_exception(self, ex: Exception, ctx: EventContext): + async def _handle_backend_exception(self, ex: Exception): """Handle an exception raised during event processing by calling the backend exception handler if it exists. Args: ex: The exception that was raised. - ctx: The event context for the event that caused the exception. """ - if self.backend_exception_handler is not None and ( - events := self.backend_exception_handler(ex) - ): - await base_state_processor.chain_updates( - events=events, - handler_name=self.backend_exception_handler.__qualname__, - ) + if self.backend_exception_handler is not None: + self.backend_exception_handler(ex) def _finish_task(self, task: asyncio.Task): """Callback for finishing a _process_event_queue_entry task. @@ -361,7 +467,7 @@ def _finish_task(self, task: asyncio.Task): # Create a new task in the same context to invoke the exception handler. t = self._tasks[task_ctx.txid] = task.get_context().run( asyncio.create_task, - self._handle_backend_exception(ex, task_ctx), + self._handle_backend_exception(ex), name=f"reflex_backend_exception_handler|task=[{task.get_name()}]|{time.time()}", ) t.add_done_callback(self._finish_task) diff --git a/reflex/ievent/registry.py b/reflex/ievent/registry.py index 6f714b9e2b9..b7de5bfe71c 100644 --- a/reflex/ievent/registry.py +++ b/reflex/ievent/registry.py @@ -19,10 +19,21 @@ class RegisteredEventHandler: REGISTERED_HANDLERS: dict[str, RegisteredEventHandler] = {} -def register(handler: EventHandler, states: tuple[type[BaseState], ...] = ()) -> None: - """Register an event handler with its full name and associated states.""" +def register( + handler: EventHandler, states: tuple[type[BaseState], ...] = () +) -> EventHandler: + """Register an event handler with its full name and associated states. + + Args: + handler: The event handler to register. + states: The states associated with the event handler. + + Returns: + The registered event handler. + """ from reflex.utils.format import format_event_handler REGISTERED_HANDLERS[format_event_handler(handler)] = RegisteredEventHandler( handler=handler, states=states ) + return handler diff --git a/reflex/state.py b/reflex/state.py index d1dc0297da3..8129e580575 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -627,8 +627,7 @@ def __init_subclass__(cls, mixin: bool = False, **kwargs): for name, fn in events.items(): handler = cls._create_event_handler(fn) - cls.event_handlers[name] = handler - register(handler, states=(cls,)) + cls.event_handlers[name] = register(handler, states=(cls,)) setattr(cls, name, handler) # Initialize per-class var dependency tracking. @@ -652,9 +651,8 @@ def _add_event_handler( from reflex.ievent.registry import register handler = cls._create_event_handler(fn) - cls.event_handlers[name] = handler + cls.event_handlers[name] = register(handler, states=(cls,)) setattr(cls, name, handler) - register(handler, states=(cls,)) @staticmethod def _copy_fn(fn: Callable) -> Callable: @@ -1169,8 +1167,9 @@ def _create_setvar(cls): """Create the setvar method for the state.""" from reflex.ievent.registry import register - cls.setvar = cls.event_handlers["setvar"] = EventHandlerSetVar(state_cls=cls) - register(cls.setvar, states=(cls,)) + cls.setvar = cls.event_handlers["setvar"] = register( + EventHandlerSetVar(state_cls=cls), states=(cls,) + ) @classmethod def _create_setter(cls, name: str, prop: Var): diff --git a/tests/units/conftest.py b/tests/units/conftest.py index 612d8beaf85..7e6b7185ef1 100644 --- a/tests/units/conftest.py +++ b/tests/units/conftest.py @@ -2,16 +2,25 @@ import platform import uuid -from collections.abc import Generator +from collections.abc import AsyncGenerator, Generator, Mapping +from typing import Any from unittest import mock import pytest +import pytest_asyncio from reflex.app import App -from reflex.event import EventSpec +from reflex.event import Event, EventSpec +from reflex.ievent.context import EventContext +from reflex.ievent.processor import EventProcessor +from reflex.istate.manager import StateManager +from reflex.istate.manager.disk import StateManagerDisk +from reflex.istate.manager.memory import StateManagerMemory +from reflex.istate.manager.redis import StateManagerRedis from reflex.model import ModelRegistry from reflex.testing import chdir from reflex.utils import prerequisites +from tests.units.mock_redis import mock_redis from .states.upload import SubUploadState, UploadState @@ -209,3 +218,149 @@ def model_registry() -> Generator[type[ModelRegistry], None, None]: """ yield ModelRegistry ModelRegistry._metadata = None + + +@pytest_asyncio.fixture( + loop_scope="function", scope="function", params=["in_process", "disk", "redis"] +) +async def state_manager( + request: pytest.FixtureRequest, mock_root_event_context: EventContext +) -> AsyncGenerator[StateManager, None]: + """Instance of state manager parametrized for redis and in-process. + + Args: + request: pytest request object. + mock_root_event_context: The mock root event context to use for the state manager. + + Yields: + A state manager instance + """ + state_manager = StateManager.create() + if request.param == "redis": + if not isinstance(state_manager, StateManagerRedis): + state_manager = StateManagerRedis(redis=mock_redis()) + elif request.param == "disk": + # explicitly NOT using redis + state_manager = StateManagerDisk() + assert not state_manager._states_locks + else: + state_manager = StateManagerMemory() + assert not state_manager._states_locks + + orig_state_manager = mock_root_event_context.state_manager + object.__setattr__(mock_root_event_context, "state_manager", state_manager) + + yield state_manager + + await state_manager.close() + object.__setattr__(mock_root_event_context, "state_manager", orig_state_manager) + + +@pytest.fixture +def mock_event_processor_obj() -> EventProcessor: + """Create an event processor. + + Returns: + A fresh event processor. + """ + + def handle_backend_exception(ex: Exception) -> None: + raise ex + + return EventProcessor( + backend_exception_handler=handle_backend_exception, graceful_shutdown_timeout=1 + ) + + +@pytest.fixture +def emitted_deltas() -> list[tuple[str, Mapping[str, Mapping[str, Any]]]]: + """Create a list to store emitted deltas. + + Returns: + A list to store emitted deltas. + """ + return [] + + +@pytest.fixture +def emitted_events() -> list[tuple[str, tuple[Event, ...]]]: + """Create a list to store emitted events. + + Returns: + A list to store emitted events. + """ + return [] + + +@pytest.fixture +def mock_root_event_context( + mock_event_processor_obj: EventProcessor, + emitted_deltas: list[tuple[str, Mapping[str, Mapping[str, Any]]]], + emitted_events: list[tuple[str, tuple[Event, ...]]], +) -> EventContext: + """Create a mock event context. + + Args: + mock_event_processor_obj: The mock event processor to use for the context's enqueue implementation. + emitted_deltas: The list to store emitted deltas. + emitted_events: The list to store emitted events. + + Returns: + A mock event context. + """ + + async def emit_delta_impl( # noqa: RUF029 + token: str, delta: Mapping[str, Mapping[str, Any]] + ) -> None: + """Mock emit delta implementation that records emitted deltas. + + Args: + token: The client token to emit the delta to. + delta: The delta to emit. + """ + emitted_deltas.append((token, delta)) + + async def emit_event_impl(token: str, *events: Event) -> None: # noqa: RUF029 + """Mock emit event implementation that records emitted events. + + Args: + token: The client token to emit the events to. + events: The events to emit. + """ + emitted_events.append((token, events)) + + return EventContext( + token="", + state_manager=StateManagerMemory(), + enqueue_impl=mock_event_processor_obj.enqueue, + emit_delta_impl=emit_delta_impl, + emit_event_impl=emit_event_impl, + ) + + +@pytest.fixture +def mock_event_processor( + mock_root_event_context: EventContext, mock_event_processor_obj: EventProcessor +) -> EventProcessor: + """Create an event processor with a mock root context. + + Set the mock context as the task's current context, and set the processor's + root context to the mock context. + + Events can be queued against the processor via `await + mock_event_processor.enqueue(token, *events)`. + + The `state_manager` fixture is used by the `mock_root_event_context` so any + updates will be reflected in the context's state manager, and any deltas or + frontend events can be checked via the context's `emitted_deltas` and + `emitted_events` attributes. + + Args: + mock_root_event_context: The mock event context to use as the root context for the processor. + mock_event_processor_obj: The mock event processor to use for the processor's enqueue implementation. + + Returns: + An un-started event processor with a mock root context. + """ + mock_event_processor_obj._root_context = mock_root_event_context + return mock_event_processor_obj diff --git a/tests/units/test_state.py b/tests/units/test_state.py index 303dfa91fbf..c678ef9c662 100644 --- a/tests/units/test_state.py +++ b/tests/units/test_state.py @@ -35,6 +35,7 @@ from reflex.istate.manager.memory import StateManagerMemory from reflex.istate.manager.redis import StateManagerRedis from reflex.istate.manager.token import BaseStateToken +from reflex.istate.proxy import StateProxy from reflex.state import ( BaseState, ImmutableMutableProxy, @@ -43,7 +44,6 @@ OnLoadInternalState, RouterData, State, - StateProxy, StateUpdate, ) from reflex.testing import chdir @@ -1686,35 +1686,6 @@ def invalid_handler(self): assert "must only return/yield: None, Events or other EventHandlers" in captured.err -@pytest_asyncio.fixture( - loop_scope="function", scope="function", params=["in_process", "disk", "redis"] -) -async def state_manager(request) -> AsyncGenerator[StateManager, None]: - """Instance of state manager parametrized for redis and in-process. - - Args: - request: pytest request object. - - Yields: - A state manager instance - """ - state_manager = StateManager.create() - if request.param == "redis": - if not isinstance(state_manager, StateManagerRedis): - state_manager = StateManagerRedis(redis=mock_redis()) - elif request.param == "disk": - # explicitly NOT using redis - state_manager = StateManagerDisk() - assert not state_manager._states_locks - else: - state_manager = StateManagerMemory() - assert not state_manager._states_locks - - yield state_manager - - await state_manager.close() - - @pytest.fixture def substate_token(state_manager, token) -> BaseStateToken: """A token + substate name for looking up in state manager. From 01befb514947cc8b8dfd3ee6fc6c3b92fef86697 Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Wed, 25 Mar 2026 09:38:45 -0700 Subject: [PATCH 08/81] Fix test_app to use BaseStateEventProcessor make Event.substate_token no longer work, because we're deprecating `token` as an Event field, so we cannot rely on it under the covers. --- reflex/event.py | 24 +- .../ievent/processor/base_state_processor.py | 19 +- reflex/istate/proxy.py | 3 +- reflex/state.py | 11 +- tests/units/conftest.py | 53 ++- tests/units/test_app.py | 391 +++++++++--------- 6 files changed, 278 insertions(+), 223 deletions(-) diff --git a/reflex/event.py b/reflex/event.py index eb66e2652dd..e8d5ef5b8fe 100644 --- a/reflex/event.py +++ b/reflex/event.py @@ -78,6 +78,15 @@ class Event: # The event payload. payload: dict[str, Any] = dataclasses.field(default_factory=dict) + @property + def state_cls(self) -> "type[BaseState]": + """The state class for the event.""" + from reflex.state import all_base_state_classes + + substate_name = self.name.rpartition(".")[0] + + return all_base_state_classes[substate_name] + @property def substate_token(self) -> BaseStateToken: """Get the substate token for the event. @@ -85,19 +94,8 @@ def substate_token(self) -> BaseStateToken: Returns: The substate token. """ - from reflex.istate.manager.token import BaseStateToken - from reflex.state import State - from reflex.utils.prerequisites import get_app - - app = get_app().app - - root_state = State if app._state is None else app._state - - substate = self.name.rpartition(".")[0] - return BaseStateToken( - ident=self.token, - cls=root_state.get_class_substate(tuple(substate.split("."))), - ) + msg = "Event.substate_token should no longer be used." + raise NotImplementedError(msg) @classmethod def from_event_type( diff --git a/reflex/ievent/processor/base_state_processor.py b/reflex/ievent/processor/base_state_processor.py index 8bc39857a13..76ca10e917a 100644 --- a/reflex/ievent/processor/base_state_processor.py +++ b/reflex/ievent/processor/base_state_processor.py @@ -16,6 +16,7 @@ RegisteredEventHandler, ) from reflex.istate.data import RouterData +from reflex.istate.manager.token import BaseStateToken from reflex.istate.proxy import StateProxy from reflex.utils import console, types from reflex.utils.monitoring import is_pyleak_enabled, monitor_loopblocks @@ -181,10 +182,13 @@ async def chain_updates( # Convert valid EventHandler and EventSpec into Event if fixed_events := fix_events( - _check_valid_yield(events, handler_name=handler_name), token + _check_valid_yield(events, handler_name=handler_name), + token, + router_data=root_state.router_data if root_state else None, ): # Frontend events. - await ctx.emit_event(*(e for e in fixed_events if e.name.startswith("_"))) + if frontend_events := [e for e in fixed_events if e.name.startswith("_")]: + await ctx.emit_event(*frontend_events) # Backend events. await ctx.enqueue(*(e for e in fixed_events if not e.name.startswith("_"))) @@ -293,7 +297,11 @@ async def _process_event_queue_entry( router_data = event.router_data or {} # Get the state for the session exclusively. async with ctx.state_manager.modify_state_with_links( - entry.event.substate_token, event=entry.event + BaseStateToken( + ident=ctx.token, + cls=registered_handler.states[0], + ), + event=entry.event, ) as state: # TODO: handle "reload" trigger of brand new state instances @@ -302,7 +310,8 @@ async def _process_event_queue_entry( # assignment will recurse into substates and force recalculation of # dependent ComputedVar (dynamic route variables) state.router_data = router_data - state.router = RouterData.from_router_data(router_data) + if state.router != (router := RouterData.from_router_data(router_data)): + state.router = router # Preprocess the event. if ( @@ -318,7 +327,7 @@ async def _process_event_queue_entry( return # Get the event's substate. - substate = await state.get_state(event.substate_token.cls) + substate = await state.get_state(event.state_cls) root_state = state._get_root_state() # Process non-background events while holding the lock. diff --git a/reflex/istate/proxy.py b/reflex/istate/proxy.py index 1ba14da5c4c..e3b3c981217 100644 --- a/reflex/istate/proxy.py +++ b/reflex/istate/proxy.py @@ -18,6 +18,7 @@ from typing_extensions import Self from reflex.base import Base +from reflex.ievent.context import event_context from reflex.istate.manager.token import BaseStateToken from reflex.utils import prerequisites from reflex.utils.exceptions import ImmutableStateError @@ -76,7 +77,7 @@ def __init__( self._self_app = prerequisites.get_and_validate_app().app self._self_substate_path = tuple(state_instance.get_full_name().split(".")) self._self_substate_token = BaseStateToken( - ident=state_instance.router.session.client_token, + ident=event_context.get().token, cls=state_instance.__class__, ) self._self_actx = None diff --git a/reflex/state.py b/reflex/state.py index 8129e580575..6816ec5ba2f 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -319,7 +319,7 @@ def _override_base_method(fn: Callable[PARAMS, RETURN]) -> Callable[PARAMS, RETU return fn -all_base_state_classes: dict[str, None] = {} +all_base_state_classes: dict[str, type[BaseState]] = {} CLASS_VAR_NAMES = frozenset({ "vars", @@ -634,7 +634,7 @@ def __init_subclass__(cls, mixin: bool = False, **kwargs): cls._var_dependencies = {} cls._init_var_dependency_dicts() - all_base_state_classes[cls.get_full_name()] = None + all_base_state_classes[cls.get_full_name()] = cls @classmethod def _add_event_handler( @@ -1180,6 +1180,7 @@ def _create_setter(cls, name: str, prop: Var): prop: The var to create a setter for. """ from reflex.config import get_config + from reflex.ievent.registry import register config = get_config() create_event_handler_kwargs = {} @@ -1207,7 +1208,7 @@ def __call__(self, *args, **kwargs): event_handler = cls._create_event_handler( prop._get_setter(name), **create_event_handler_kwargs ) - cls.event_handlers[setter_name] = event_handler + cls.event_handlers[setter_name] = register(event_handler) setattr(cls, setter_name, event_handler) @classmethod @@ -1887,7 +1888,9 @@ def get_value(self, key: str) -> Any: return key.__wrapped__ if isinstance(key, str): - return getattr(self, key) + if isinstance(val := getattr(self, key), MutableProxy): + return val.__wrapped__ + return val msg = f"Invalid key type: {type(key)}. Expected str." raise TypeError(msg) diff --git a/tests/units/conftest.py b/tests/units/conftest.py index 7e6b7185ef1..5ca3c5d2c33 100644 --- a/tests/units/conftest.py +++ b/tests/units/conftest.py @@ -12,7 +12,7 @@ from reflex.app import App from reflex.event import Event, EventSpec from reflex.ievent.context import EventContext -from reflex.ievent.processor import EventProcessor +from reflex.ievent.processor import BaseStateEventProcessor, EventProcessor from reflex.istate.manager import StateManager from reflex.istate.manager.disk import StateManagerDisk from reflex.istate.manager.memory import StateManagerMemory @@ -272,6 +272,22 @@ def handle_backend_exception(ex: Exception) -> None: ) +@pytest.fixture +def mock_base_state_event_processor_obj() -> BaseStateEventProcessor: + """Create a BaseState event processor. + + Returns: + A fresh BaseState event processor. + """ + + def handle_backend_exception(ex: Exception) -> None: + raise ex + + return BaseStateEventProcessor( + backend_exception_handler=handle_backend_exception, graceful_shutdown_timeout=1 + ) + + @pytest.fixture def emitted_deltas() -> list[tuple[str, Mapping[str, Mapping[str, Any]]]]: """Create a list to store emitted deltas. @@ -294,14 +310,14 @@ def emitted_events() -> list[tuple[str, tuple[Event, ...]]]: @pytest.fixture def mock_root_event_context( - mock_event_processor_obj: EventProcessor, + mock_base_state_event_processor_obj: BaseStateEventProcessor, emitted_deltas: list[tuple[str, Mapping[str, Mapping[str, Any]]]], emitted_events: list[tuple[str, tuple[Event, ...]]], ) -> EventContext: """Create a mock event context. Args: - mock_event_processor_obj: The mock event processor to use for the context's enqueue implementation. + mock_base_state_event_processor_obj: The mock event processor to use for the context's enqueue implementation. emitted_deltas: The list to store emitted deltas. emitted_events: The list to store emitted events. @@ -332,7 +348,7 @@ async def emit_event_impl(token: str, *events: Event) -> None: # noqa: RUF029 return EventContext( token="", state_manager=StateManagerMemory(), - enqueue_impl=mock_event_processor_obj.enqueue, + enqueue_impl=mock_base_state_event_processor_obj.enqueue, emit_delta_impl=emit_delta_impl, emit_event_impl=emit_event_impl, ) @@ -364,3 +380,32 @@ def mock_event_processor( """ mock_event_processor_obj._root_context = mock_root_event_context return mock_event_processor_obj + + +@pytest.fixture +def mock_base_state_event_processor( + mock_root_event_context: EventContext, + mock_base_state_event_processor_obj: BaseStateEventProcessor, +) -> BaseStateEventProcessor: + """Create a BaseState event processor with a mock root context. + + Set the mock context as the task's current context, and set the processor's + root context to the mock context. + + Events can be queued against the processor via `await + mock_base_state_event_processor.enqueue(token, *events)`. + + The `state_manager` fixture is used by the `mock_root_event_context` so any + updates will be reflected in the context's state manager, and any deltas or + frontend events can be checked via the context's `emitted_deltas` and + `emitted_events` attributes. + + Args: + mock_root_event_context: The mock event context to use as the root context for the processor. + mock_base_state_event_processor_obj: The mock BaseState event processor to use for the processor's enqueue implementation. + + Returns: + An un-started event processor with a mock root context. + """ + mock_base_state_event_processor_obj._root_context = mock_root_event_context + return mock_base_state_event_processor_obj diff --git a/tests/units/test_app.py b/tests/units/test_app.py index c61a4a1f06b..ca7e7556710 100644 --- a/tests/units/test_app.py +++ b/tests/units/test_app.py @@ -15,19 +15,15 @@ import pytest from pytest_mock import MockerFixture +from sqlalchemy.engine.base import Engine from starlette.applications import Starlette from starlette.datastructures import FormData, UploadFile from starlette.responses import StreamingResponse +from starlette_admin.auth import AuthProvider import reflex as rx from reflex import AdminDash, constants -from reflex.app import ( - App, - ComponentCallable, - default_overlay_component, - process, - upload, -) +from reflex.app import App, ComponentCallable, default_overlay_component, upload from reflex.components import Component from reflex.components.base.bare import Bare from reflex.components.base.fragment import Fragment @@ -36,13 +32,14 @@ from reflex.constants.state import FIELD_MARKER from reflex.environment import environment from reflex.event import Event +from reflex.ievent.context import EventContext +from reflex.ievent.processor import BaseStateEventProcessor from reflex.istate.manager.disk import StateManagerDisk from reflex.istate.manager.memory import StateManagerMemory from reflex.istate.manager.redis import StateManagerRedis from reflex.istate.manager.token import BaseStateToken -from reflex.middleware import HydrateMiddleware from reflex.model import Model -from reflex.state import BaseState, OnLoadInternalState, RouterData, State, StateUpdate +from reflex.state import BaseState, OnLoadInternalState, RouterData, State from reflex.style import Style from reflex.utils import console, exceptions, format from reflex.vars.base import computed_var @@ -205,12 +202,16 @@ def test_default_app(app: App): Args: app: The app to test. """ - assert app._middlewares == [HydrateMiddleware()] + assert app._middlewares == [] assert app.style == Style() assert app.admin_dash is None -def test_multiple_states_error(monkeypatch, test_state, redundant_test_state): +def test_multiple_states_error( + monkeypatch: pytest.MonkeyPatch, + test_state: BaseState, + redundant_test_state: BaseState, +): """Test that an error is thrown when multiple classes subclass rx.BaseState. Args: @@ -223,7 +224,9 @@ def test_multiple_states_error(monkeypatch, test_state, redundant_test_state): App() -def test_add_page_default_route(app: App, index_page, about_page): +def test_add_page_default_route( + app: App, index_page: ComponentCallable, about_page: ComponentCallable +): """Test adding a page to an app. Args: @@ -241,7 +244,7 @@ def test_add_page_default_route(app: App, index_page, about_page): assert app._pages.keys() == {"index", "about"} -def test_add_page_set_route(app: App, index_page): +def test_add_page_set_route(app: App, index_page: ComponentCallable): """Test adding a page to an app. Args: @@ -255,7 +258,7 @@ def test_add_page_set_route(app: App, index_page): assert app._pages.keys() == {"test"} -def test_add_page_set_route_dynamic(index_page): +def test_add_page_set_route_dynamic(index_page: ComponentCallable): """Test adding a page with dynamic route variable to an app. Args: @@ -275,7 +278,7 @@ def test_add_page_set_route_dynamic(index_page): assert constants.ROUTER in app._state()._var_dependencies -def test_add_page_set_route_nested(app: App, index_page): +def test_add_page_set_route_nested(app: App, index_page: ComponentCallable): """Test adding a page to an app. Args: @@ -288,7 +291,7 @@ def test_add_page_set_route_nested(app: App, index_page): assert app._unevaluated_pages.keys() == {route} -def test_add_page_invalid_api_route(app: App, index_page): +def test_add_page_invalid_api_route(app: App, index_page: ComponentCallable): """Test adding a page with an invalid route to an app. Args: @@ -363,7 +366,7 @@ def test_add_duplicate_page_route_error(app: App, first_page, second_page, route or not find_spec("pydantic"), reason="starlette_admin not installed or sqlmodel not installed or pydantic not installed", ) -def test_initialize_with_admin_dashboard(test_model): +def test_initialize_with_admin_dashboard(test_model: Model): """Test setting the admin dashboard of an app. Args: @@ -382,9 +385,9 @@ def test_initialize_with_admin_dashboard(test_model): reason="starlette_admin not installed or sqlmodel not installed or pydantic not installed", ) def test_initialize_with_custom_admin_dashboard( - test_get_engine, - test_custom_auth_admin, - test_model_auth, + test_get_engine: Engine, + test_custom_auth_admin: type[AuthProvider], + test_model_auth: Model, ): """Test setting the custom admin dashboard of an app. @@ -492,25 +495,35 @@ async def test_set_and_get_state(test_state: type[ATestState]): @pytest.mark.asyncio -async def test_dynamic_var_event(test_state: type[ATestState], token: str): +async def test_dynamic_var_event( + test_state: type[ATestState], + mock_base_state_event_processor: BaseStateEventProcessor, + emitted_deltas: list[tuple[str, dict[str, dict[str, Any]]]], + token: str, +): """Test that the default handler of a dynamic generated var works as expected. Args: test_state: State Fixture. + mock_base_state_event_processor: BaseStateEventProcessor Fixture. + emitted_deltas: List to store emitted deltas. token: a Token. """ state = test_state() # pyright: ignore [reportCallIssue] state.add_var("int_val", int, 0) - async for result in state._process( - Event( - token=token, - name=f"{test_state.get_name()}.set_int_val", - router_data={"pathname": "/", "query": {}}, - payload={"value": 50}, + async with mock_base_state_event_processor as processor: + await processor.enqueue( + token, + Event( + token=token, + name=f"{test_state.get_name()}.set_int_val", + payload={"value": 50}, + ), ) - ): - assert result.delta == {test_state.get_name(): {"int_val" + FIELD_MARKER: 50}} + assert emitted_deltas == [ + (token, {test_state.get_name(): {"int_val" + FIELD_MARKER: 50}}) + ] @pytest.fixture @@ -685,6 +698,8 @@ async def test_list_mutation_detection__plain_list( event_tuples: list[tuple[str, list[str]]], list_mutation_state: State, token: str, + mock_base_state_event_processor: BaseStateEventProcessor, + emitted_deltas: list[tuple[str, dict[str, dict[str, Any]]]], ): """Test list mutation detection when reassignment is not explicitly included in the logic. @@ -693,19 +708,23 @@ async def test_list_mutation_detection__plain_list( event_tuples: From parametrization. list_mutation_state: A state with list mutation features. token: a Token. + mock_base_state_event_processor: BaseStateEventProcessor Fixture. + emitted_deltas: List to store emitted deltas. """ for event_name, expected_delta in event_tuples: - async for result in list_mutation_state._process( - Event( - token=token, - name=f"{list_mutation_state.get_name()}.{event_name}", - router_data={"pathname": "/", "query": {}}, - payload={}, + async with mock_base_state_event_processor as processor: + await processor.enqueue( + token, + Event( + token="", + name=f"{list_mutation_state.get_name()}.{event_name}", + payload={}, + ), ) - ): - # prefix keys in expected_delta with the state name - expected_delta = {list_mutation_state.get_name(): expected_delta} - assert result.delta == expected_delta + # prefix keys in expected_delta with the state name + expected_delta = {list_mutation_state.get_name(): expected_delta} + assert emitted_deltas == [(token, expected_delta)] + emitted_deltas.clear() # Clear emitted deltas for the next iteration @pytest.fixture @@ -877,6 +896,8 @@ async def test_dict_mutation_detection__plain_list( event_tuples: list[tuple[str, list[str]]], dict_mutation_state: State, token: str, + mock_base_state_event_processor: BaseStateEventProcessor, + emitted_deltas: list[tuple[str, dict[str, dict[str, Any]]]], ): """Test dict mutation detection when reassignment is not explicitly included in the logic. @@ -885,20 +906,23 @@ async def test_dict_mutation_detection__plain_list( event_tuples: From parametrization. dict_mutation_state: A state with dict mutation features. token: a Token. + mock_base_state_event_processor: BaseStateEventProcessor Fixture. + emitted_deltas: List to store emitted deltas. """ for event_name, expected_delta in event_tuples: - async for result in dict_mutation_state._process( - Event( - token=token, - name=f"{dict_mutation_state.get_name()}.{event_name}", - router_data={"pathname": "/", "query": {}}, - payload={}, + async with mock_base_state_event_processor as processor: + await processor.enqueue( + token, + Event( + token="", + name=f"{dict_mutation_state.get_name()}.{event_name}", + payload={}, + ), ) - ): - # prefix keys in expected_delta with the state name - expected_delta = {dict_mutation_state.get_name(): expected_delta} - - assert result.delta == expected_delta + # prefix keys in expected_delta with the state name + expected_delta = {dict_mutation_state.get_name(): expected_delta} + assert emitted_deltas == [(token, expected_delta)] + emitted_deltas.clear() # Clear emitted deltas for the next iteration @pytest.mark.asyncio @@ -932,7 +956,7 @@ async def test_dict_mutation_detection__plain_list( ], ) async def test_upload_file( - tmp_path, + tmp_path: Path, state, delta, token: str, @@ -1000,7 +1024,7 @@ async def form(): # noqa: RUF029 @pytest.mark.asyncio async def test_upload_file_keeps_form_open_until_stream_completes( - tmp_path, + tmp_path: Path, token: str, mocker: MockerFixture, app_module_mock: unittest.mock.Mock, @@ -1138,7 +1162,7 @@ async def cancelled_get_state(*_args, **_kwargs): @pytest.mark.asyncio async def test_upload_file_closes_form_if_response_cancelled_before_stream_starts( - tmp_path, + tmp_path: Path, token: str, mocker: MockerFixture, app_module_mock: unittest.mock.Mock, @@ -1209,7 +1233,11 @@ async def send(_message): "state", [FileUploadState, ChildFileUploadState, GrandChildFileUploadState], ) -async def test_upload_file_without_annotation(state, tmp_path, token): +async def test_upload_file_without_annotation( + state: FileUploadState | ChildFileUploadState | GrandChildFileUploadState, + tmp_path: Path, + token: str, +): """Test that an error is thrown when there's no param annotated with rx.UploadFile or list[UploadFile]. Args: @@ -1248,7 +1276,11 @@ async def form(): # noqa: RUF029 "state", [FileUploadState, ChildFileUploadState, GrandChildFileUploadState], ) -async def test_upload_file_background(state, tmp_path, token): +async def test_upload_file_background( + state: FileUploadState | ChildFileUploadState | GrandChildFileUploadState, + tmp_path: Path, + token: str, +): """Test that an error is thrown handler is a background task. Args: @@ -1282,7 +1314,7 @@ async def form(): # noqa: RUF029 await app.state_manager.close() -class DynamicState(BaseState): +class DynamicState(State): """State class for testing dynamic route var. This is defined at module level because event handlers cannot be addressed @@ -1320,8 +1352,6 @@ def comp_dynamic(self) -> str: """ return self.dynamic # pyright: ignore[reportAttributeAccessIssue] - on_load_internal = OnLoadInternalState.on_load_internal.fn # pyright: ignore [reportFunctionMemberAccess] - def test_dynamic_arg_shadow( index_page: ComponentCallable, @@ -1373,7 +1403,10 @@ async def test_dynamic_route_var_route_change_completed_on_load( index_page: ComponentCallable, token: str, app_module_mock: unittest.mock.Mock, - mocker: MockerFixture, + mock_root_event_context: EventContext, + mock_base_state_event_processor: BaseStateEventProcessor, + emitted_deltas: list[tuple[str, dict[str, dict[str, Any]]]], + emitted_events: list[tuple[str, tuple[Event, ...]]], ): """Create app with dynamic route var, and simulate navigation. @@ -1384,12 +1417,16 @@ async def test_dynamic_route_var_route_change_completed_on_load( index_page: The index page. token: a Token. app_module_mock: Mocked app module. - mocker: pytest mocker object. + mock_root_event_context: Mocked root event context. + mock_base_state_event_processor: Mocked BaseStateEventProcessor. + emitted_deltas: List to store emitted deltas. + emitted_events: List to store emitted events. """ DynamicState._app_ref = None arg_name = "dynamic" route = f"test/[{arg_name}]" - app = app_module_mock.app = App(_state=DynamicState) + app = app_module_mock.app = App() + app._state_manager = mock_root_event_context.state_manager assert app._state is not None assert arg_name not in app._state.vars app.add_page(index_page, route=route, on_load=DynamicState.on_load) @@ -1402,8 +1439,6 @@ async def test_dynamic_route_var_route_change_completed_on_load( assert constants.ROUTER in app._state()._var_dependencies substate_token = BaseStateToken(ident=token, cls=DynamicState) - sid = "mock_sid" - client_ip = "127.0.0.1" async with app.state_manager.modify_state(substate_token) as state: state.router_data = {"simulate": "hydrated"} assert state.dynamic == "" # pyright: ignore[reportAttributeAccessIssue] @@ -1418,7 +1453,7 @@ def _event(name, val, **kwargs): { "pathname": "/" + route, "query": {arg_name: val}, - "asPath": "/test/something", + "asPath": f"/test/{val}", }, ), payload=kwargs.pop("payload", {}), @@ -1435,57 +1470,53 @@ def _dynamic_state_event(name, val, **kwargs): prev_exp_val = "" for exp_index, exp_val in enumerate(exp_vals): on_load_internal = _event( - name=f"{state.get_full_name()}.{constants.CompileVars.ON_LOAD_INTERNAL.rpartition('.')[2]}", + name=f"{OnLoadInternalState.get_full_name()}.{constants.CompileVars.ON_LOAD_INTERNAL.rpartition('.')[2]}", val=exp_val, ) - exp_router_data = { - "headers": {}, - "ip": client_ip, - "sid": sid, - "token": token, - **on_load_internal.router_data, - } - exp_router = RouterData.from_router_data(exp_router_data) - process_coro = process( - app, - event=on_load_internal, - sid=sid, - headers={}, - client_ip=client_ip, - ) - update = await process_coro.__anext__() - # route change (on_load_internal) triggers: [call on_load events, call set_is_hydrated(True)] - assert update == StateUpdate( - delta={ - state.get_name(): { - arg_name + FIELD_MARKER: exp_val, - f"comp_{arg_name}" + FIELD_MARKER: exp_val, - constants.CompileVars.IS_HYDRATED + FIELD_MARKER: False, - "router" + FIELD_MARKER: exp_router, - } - }, - events=[ - _dynamic_state_event( - name="on_load", - val=exp_val, - ), - _event( - name=f"{State.get_name()}.set_is_hydrated", - payload={"value": True}, - val=exp_val, - router_data={}, - ), - ], - ) + exp_router = RouterData.from_router_data(on_load_internal.router_data) + async with mock_base_state_event_processor as processor: + await processor.enqueue( + token, + on_load_internal, + ) + await processor.join() + assert emitted_deltas == [ + ( + token, + { + State.get_full_name(): { + arg_name + FIELD_MARKER: exp_val, + constants.CompileVars.IS_HYDRATED + FIELD_MARKER: False, + "router" + FIELD_MARKER: exp_router, + }, + DynamicState.get_full_name(): { + f"comp_{arg_name}" + FIELD_MARKER: exp_val, + }, + }, + ), + ( + token, + { + DynamicState.get_full_name(): { + "loaded" + FIELD_MARKER: exp_index + 1, + }, + }, + ), + ( + token, + { + State.get_full_name(): { + "is_hydrated" + FIELD_MARKER: True, + }, + }, + ), + ] + assert emitted_events == [] if isinstance(app.state_manager, StateManagerRedis): # When redis is used, the state is not updated until the processing is complete state = await app.state_manager.get_state(substate_token) assert state.dynamic == prev_exp_val # pyright: ignore[reportAttributeAccessIssue] - # complete the processing - with pytest.raises(StopAsyncIteration): - await process_coro.__anext__() - if environment.REFLEX_OPLOCK_ENABLED.get(): await app.state_manager.close() @@ -1493,125 +1524,89 @@ def _dynamic_state_event(name, val, **kwargs): state = await app.state_manager.get_state(substate_token) assert state.dynamic == exp_val # pyright: ignore[reportAttributeAccessIssue] - process_coro = process( - app, - event=_dynamic_state_event(name="on_load", val=exp_val), - sid=sid, - headers={}, - client_ip=client_ip, - ) - on_load_update = await process_coro.__anext__() - assert on_load_update == StateUpdate( - delta={ - state.get_name(): { - "loaded" + FIELD_MARKER: exp_index + 1, - }, - }, - events=[], - ) - # complete the processing - with pytest.raises(StopAsyncIteration): - await process_coro.__anext__() - process_coro = process( - app, - event=_dynamic_state_event( - name="set_is_hydrated", payload={"value": True}, val=exp_val - ), - sid=sid, - headers={}, - client_ip=client_ip, - ) - on_set_is_hydrated_update = await process_coro.__anext__() - assert on_set_is_hydrated_update == StateUpdate( - delta={ - state.get_name(): { - "is_hydrated" + FIELD_MARKER: True, - }, - }, - events=[], - ) - # complete the processing - with pytest.raises(StopAsyncIteration): - await process_coro.__anext__() - # a simple state update event should NOT trigger on_load or route var side effects - process_coro = process( - app, - event=_dynamic_state_event(name="on_counter", val=exp_val), - sid=sid, - headers={}, - client_ip=client_ip, - ) - update = await process_coro.__anext__() - assert update == StateUpdate( - delta={ - state.get_name(): { - "counter" + FIELD_MARKER: exp_index + 1, - } - }, - events=[], - ) - # complete the processing - with pytest.raises(StopAsyncIteration): - await process_coro.__anext__() - + emitted_deltas.clear() + emitted_events.clear() + async with mock_base_state_event_processor as processor: + await processor.enqueue( + token, + _dynamic_state_event(name="on_counter", val=exp_val), + ) + assert emitted_deltas == [ + ( + token, + { + DynamicState.get_full_name(): { + "counter" + FIELD_MARKER: exp_index + 1, + } + }, + ) + ] + assert emitted_events == [] + emitted_deltas.clear() + emitted_events.clear() prev_exp_val = exp_val if environment.REFLEX_OPLOCK_ENABLED.get(): await app.state_manager.close() state = await app.state_manager.get_state(substate_token) - assert isinstance(state, DynamicState) - assert state.loaded == len(exp_vals) - assert state.counter == len(exp_vals) + assert isinstance(state, State) + dynamic_state = await state.get_state(DynamicState) + assert isinstance(dynamic_state, DynamicState) + assert dynamic_state.loaded == len(exp_vals) + assert dynamic_state.counter == len(exp_vals) await app.state_manager.close() @pytest.mark.asyncio async def test_process_events( - mocker: MockerFixture, token: str, app_module_mock: unittest.mock.Mock + token: str, + app_module_mock: unittest.mock.Mock, + mock_base_state_event_processor: BaseStateEventProcessor, + mock_root_event_context: EventContext, + emitted_deltas: list[tuple[str, dict[str, dict[str, Any]]]], ): """Test that an event is processed properly and that it is postprocessed n+1 times. Also check that the processing flag of the last stateupdate is set to False. Args: - mocker: mocker object. token: a Token. app_module_mock: The mock for the app module, used to patch the app instance. + mock_base_state_event_processor: BaseStateEventProcessor Fixture. + mock_root_event_context: The mock for the root event context, used to patch the app + state manager. + emitted_deltas: List to store emitted deltas. """ - router_data = { - "pathname": "/", - "query": {}, - "token": token, - "sid": "mock_sid", - "headers": {}, - "ip": "127.0.0.1", - } - app = app_module_mock.app = App(_state=GenState) - - mocker.patch.object(app, "_postprocess", AsyncMock()) event = Event( token=token, name=f"{GenState.get_name()}.go", payload={"c": 5}, - router_data=router_data, + router_data={}, ) - async with app.state_manager.modify_state(event.substate_token) as state: + async with mock_root_event_context.state_manager.modify_state( + BaseStateToken(ident=token, cls=GenState), + ) as state: state.router_data = {"simulate": "hydrated"} - async for _update in process(app, event, "mock_sid", {}, "127.0.0.1"): - pass + async with mock_base_state_event_processor as processor: + await processor.enqueue( + token, + event, + ) if environment.REFLEX_OPLOCK_ENABLED.get(): - await app.state_manager.close() + await mock_root_event_context.state_manager.close() - gen_state = await app.state_manager.get_state(event.substate_token) + gen_state = await mock_root_event_context.state_manager.get_state( + event.substate_token + ) assert isinstance(gen_state, GenState) assert gen_state.value == 5 - assert app._postprocess.call_count == 6 # pyright: ignore [reportAttributeAccessIssue] + assert len(emitted_deltas) == 5 - await app.state_manager.close() + await mock_root_event_context.state_manager.close() @pytest.mark.parametrize( @@ -1672,7 +1667,7 @@ def test_overlay_component( @pytest.fixture -def compilable_app(tmp_path) -> Generator[tuple[App, Path], None, None]: +def compilable_app(tmp_path: Path) -> Generator[tuple[App, Path], None, None]: """Fixture for an app that can be compiled. Args: @@ -1710,7 +1705,9 @@ def compilable_app(tmp_path) -> Generator[tuple[App, Path], None, None]: [True, False], ) def test_app_wrap_compile_theme( - react_strict_mode: bool, compilable_app: tuple[App, Path], mocker + react_strict_mode: bool, + compilable_app: tuple[App, Path], + mocker: MockerFixture, ): """Test that the radix theme component wraps the app. @@ -1761,7 +1758,9 @@ def test_app_wrap_compile_theme( [True, False], ) def test_app_wrap_priority( - react_strict_mode: bool, compilable_app: tuple[App, Path], mocker + react_strict_mode: bool, + compilable_app: tuple[App, Path], + mocker: MockerFixture, ): """Test that the app wrap components are wrapped in the correct order. From cfc706adc78132f245647690ec1ac373453330b6 Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Wed, 25 Mar 2026 16:14:53 -0700 Subject: [PATCH 09/81] Avoid handling the same exception multiple times in EventProcessor --- reflex/ievent/processor/event_processor.py | 46 +++++++--------------- 1 file changed, 15 insertions(+), 31 deletions(-) diff --git a/reflex/ievent/processor/event_processor.py b/reflex/ievent/processor/event_processor.py index ad243a2ae86..6143399784d 100644 --- a/reflex/ievent/processor/event_processor.py +++ b/reflex/ievent/processor/event_processor.py @@ -219,44 +219,28 @@ async def _stop_tasks(self, timeout: float | None = None) -> None: queue to drain before cancelling tasks. If None, the processor will not wait and will cancel tasks immediately. """ - from reflex.utils import telemetry - + finished_tasks = set() + # Graceful drain time, wait for tasks to finish and handle any exceptions. if timeout is not None and self._tasks: with contextlib.suppress(asyncio.TimeoutError): - await asyncio.wait_for( - asyncio.gather(*self._tasks.values()), - timeout=timeout, - ) + async for task in asyncio.as_completed( + self._tasks.values(), timeout=timeout + ): + # Exceptions are handled in _finish_task and ignored here. + with contextlib.suppress(Exception): + await task + finished_tasks.add(task) # Cancel all outstanding event handler tasks. - for task in (outstanding_tasks := list(self._tasks.values())): + outstanding_tasks = [ + task for task in self._tasks.values() if task not in finished_tasks + ] + for task in outstanding_tasks: task.cancel() # Wait for all tasks to finish and log any exceptions that were raised. for task in outstanding_tasks: - try: + with contextlib.suppress(Exception, asyncio.CancelledError): + # Exceptions are handled in _finish_task. await task - except asyncio.CancelledError: # noqa: PERF203 - pass - except Exception as ex: - telemetry.send_error(ex, context="backend") - if self.backend_exception_handler is not None: - try: - await task.get_context().run( - self._handle_backend_exception, - ex, - ) - except Exception: - console.error( - rich.markup.escape( - f"Error in backend exception handler for {task.get_name()} during shutdown:\n{traceback.format_exc()}" - ) - ) - else: - continue - console.error( - rich.markup.escape( - f"Error in event handler task {task.get_name()} during shutdown:\n{traceback.format_exc()}" - ) - ) async def stop(self, graceful_shutdown_timeout: float | None = None) -> None: """Stop the event processor and cancel all running tasks. From 8bbeab46d0d3feffd4f5ef7bcf630616f7f2e796 Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Wed, 25 Mar 2026 16:15:28 -0700 Subject: [PATCH 10/81] Attach cls to setter event handlers --- reflex/state.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/reflex/state.py b/reflex/state.py index 6816ec5ba2f..daae277316a 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -1208,7 +1208,7 @@ def __call__(self, *args, **kwargs): event_handler = cls._create_event_handler( prop._get_setter(name), **create_event_handler_kwargs ) - cls.event_handlers[setter_name] = register(event_handler) + cls.event_handlers[setter_name] = register(event_handler, states=(cls,)) setattr(cls, setter_name, event_handler) @classmethod From 25626068f29a34c756d557365fb947840fda95fc Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Wed, 25 Mar 2026 16:19:19 -0700 Subject: [PATCH 11/81] fix test_app, test_state and friends Use the new mock_base_state_event_processor fixture to process arbitrary events and assert on emitted events or deltas. --- tests/units/conftest.py | 28 +- tests/units/istate/test_proxy.py | 2 +- tests/units/test_app.py | 38 +- tests/units/test_event.py | 4 +- tests/units/test_model.py | 27 +- tests/units/test_state.py | 666 ++++++++++++++++--------------- 6 files changed, 432 insertions(+), 333 deletions(-) diff --git a/tests/units/conftest.py b/tests/units/conftest.py index 5ca3c5d2c33..63ece7832ff 100644 --- a/tests/units/conftest.py +++ b/tests/units/conftest.py @@ -1,6 +1,7 @@ """Test fixtures.""" import platform +import traceback import uuid from collections.abc import AsyncGenerator, Generator, Mapping from typing import Any @@ -11,7 +12,7 @@ from reflex.app import App from reflex.event import Event, EventSpec -from reflex.ievent.context import EventContext +from reflex.ievent.context import EventContext, event_context from reflex.ievent.processor import BaseStateEventProcessor, EventProcessor from reflex.istate.manager import StateManager from reflex.istate.manager.disk import StateManagerDisk @@ -281,7 +282,8 @@ def mock_base_state_event_processor_obj() -> BaseStateEventProcessor: """ def handle_backend_exception(ex: Exception) -> None: - raise ex + formatted_exc = "\n".join(traceback.format_exception(ex)) + pytest.fail(f"Event processor raised an unexpected exception:\n{formatted_exc}") return BaseStateEventProcessor( backend_exception_handler=handle_backend_exception, graceful_shutdown_timeout=1 @@ -409,3 +411,25 @@ def mock_base_state_event_processor( """ mock_base_state_event_processor_obj._root_context = mock_root_event_context return mock_base_state_event_processor_obj + + +@pytest.fixture +def attached_mock_event_context( + mock_root_event_context: EventContext, token: str +) -> Generator[EventContext, None, None]: + """Fork the mock event context for the given token and attach it. + + Sets the forked context as the current event_context for the duration + of the test, then resets it afterwards. + + Args: + mock_root_event_context: The mock root event context. + token: The client token. + + Yields: + The forked EventContext. + """ + ctx = mock_root_event_context.fork(token=token) + reset_token = event_context.set(ctx) + yield ctx + event_context.reset(reset_token) diff --git a/tests/units/istate/test_proxy.py b/tests/units/istate/test_proxy.py index 5fd29725fa9..0f4d1171780 100644 --- a/tests/units/istate/test_proxy.py +++ b/tests/units/istate/test_proxy.py @@ -42,7 +42,7 @@ def test_mutable_proxy_pickle_preserves_object_identity(): assert unpickled["direct"][0] is unpickled["proxied"][0] -@pytest.mark.usefixtures("mock_app") +@pytest.mark.usefixtures("mock_app", "attached_mock_event_context") @pytest.mark.asyncio async def test_state_proxy_recovery(): """Ensure that `async with self` can be re-entered after a lock issue.""" diff --git a/tests/units/test_app.py b/tests/units/test_app.py index ca7e7556710..ec0ac7c9fe6 100644 --- a/tests/units/test_app.py +++ b/tests/units/test_app.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio +import contextlib import functools import io import json @@ -10,7 +11,7 @@ from contextlib import nullcontext as does_not_raise from importlib.util import find_spec from pathlib import Path -from typing import TYPE_CHECKING, Any, ClassVar +from typing import TYPE_CHECKING, Any from unittest.mock import AsyncMock import pytest @@ -39,7 +40,13 @@ from reflex.istate.manager.redis import StateManagerRedis from reflex.istate.manager.token import BaseStateToken from reflex.model import Model -from reflex.state import BaseState, OnLoadInternalState, RouterData, State +from reflex.state import ( + BaseState, + OnLoadInternalState, + RouterData, + State, + reload_state_module, +) from reflex.style import Style from reflex.utils import console, exceptions, format from reflex.vars.base import computed_var @@ -955,6 +962,7 @@ async def test_dict_mutation_detection__plain_list( ), ], ) +@pytest.mark.skip("Waiting for upload PR") async def test_upload_file( tmp_path: Path, state, @@ -1023,6 +1031,7 @@ async def form(): # noqa: RUF029 @pytest.mark.asyncio +@pytest.mark.skip("Waiting for upload PR") async def test_upload_file_keeps_form_open_until_stream_completes( tmp_path: Path, token: str, @@ -1117,6 +1126,7 @@ async def send(message): # noqa: RUF029 @pytest.mark.asyncio +@pytest.mark.skip("Waiting for upload PR") async def test_upload_file_closes_form_on_event_creation_cancellation( token: str, mocker: MockerFixture, @@ -1161,6 +1171,7 @@ async def cancelled_get_state(*_args, **_kwargs): @pytest.mark.asyncio +@pytest.mark.skip("Waiting for upload PR") async def test_upload_file_closes_form_if_response_cancelled_before_stream_starts( tmp_path: Path, token: str, @@ -1233,6 +1244,7 @@ async def send(_message): "state", [FileUploadState, ChildFileUploadState, GrandChildFileUploadState], ) +@pytest.mark.skip("Waiting for upload PR") async def test_upload_file_without_annotation( state: FileUploadState | ChildFileUploadState | GrandChildFileUploadState, tmp_path: Path, @@ -1276,6 +1288,7 @@ async def form(): # noqa: RUF029 "state", [FileUploadState, ChildFileUploadState, GrandChildFileUploadState], ) +@pytest.mark.skip("Waiting for upload PR") async def test_upload_file_background( state: FileUploadState | ChildFileUploadState | GrandChildFileUploadState, tmp_path: Path, @@ -1331,7 +1344,6 @@ class DynamicState(State): is_hydrated: bool = False loaded: int = 0 counter: int = 0 - _app_ref: ClassVar[Any] = None @rx.event def on_load(self): @@ -1367,7 +1379,6 @@ def test_dynamic_arg_shadow( app_module_mock: Mocked app module. mocker: pytest mocker object. """ - DynamicState._app_ref = None arg_name = "counter" route = f"/test/[{arg_name}]" app = app_module_mock.app = App(_state=DynamicState) @@ -1398,6 +1409,22 @@ def test_multiple_dynamic_args( app.add_page(index_page, route=route2) +@pytest.fixture +def cleanup_dynamic_arg(): + """Fixture to reset DynamicState class vars after each test.""" + yield + with contextlib.suppress(AttributeError): + del State.dynamic # pyright: ignore[reportAttributeAccessIssue] + + State.computed_vars.pop("dynamic", None) + State.vars.pop("dynamic", None) + State._var_dependencies = {} + State._potentially_dirty_states = set() + State._always_dirty_computed_vars = set() + reload_state_module(__name__) + + +@pytest.mark.usefixtures("cleanup_dynamic_arg") @pytest.mark.asyncio async def test_dynamic_route_var_route_change_completed_on_load( index_page: ComponentCallable, @@ -1422,7 +1449,6 @@ async def test_dynamic_route_var_route_change_completed_on_load( emitted_deltas: List to store emitted deltas. emitted_events: List to store emitted events. """ - DynamicState._app_ref = None arg_name = "dynamic" route = f"test/[{arg_name}]" app = app_module_mock.app = App() @@ -1600,7 +1626,7 @@ async def test_process_events( await mock_root_event_context.state_manager.close() gen_state = await mock_root_event_context.state_manager.get_state( - event.substate_token + BaseStateToken(ident=token, cls=GenState), ) assert isinstance(gen_state, GenState) assert gen_state.value == 5 diff --git a/tests/units/test_event.py b/tests/units/test_event.py index c413a1f225e..ff45395684a 100644 --- a/tests/units/test_event.py +++ b/tests/units/test_event.py @@ -48,7 +48,7 @@ def test_fn(): test_fn.__qualname__ = "test_fn" - def fn_with_args(_, arg1, arg2): + def fn_with_args(arg1, arg2): pass fn_with_args.__qualname__ = "fn_with_args" @@ -153,7 +153,7 @@ def test_fix_events(arg1, arg2): arg2: The second arg passed to the handler. """ - def fn_with_args(_, arg1, arg2): + def fn_with_args(arg1, arg2): pass fn_with_args.__qualname__ = "fn_with_args" diff --git a/tests/units/test_model.py b/tests/units/test_model.py index e7f64f736e5..f9d6166ad53 100644 --- a/tests/units/test_model.py +++ b/tests/units/test_model.py @@ -7,6 +7,7 @@ import reflex.constants import reflex.model from reflex.constants.state import FIELD_MARKER +from reflex.event import Event from reflex.model import Model, ModelRegistry from reflex.state import BaseState, State from tests.units.test_state import ( @@ -221,25 +222,37 @@ def rx_model(self, m: ReflexModel): # noqa: D102 @pytest.mark.asyncio -@pytest.mark.usefixtures("mock_app_simple") @pytest.mark.parametrize( ("handler", "payload"), [ (UpcastStateWithSqlAlchemy.rx_model, {"m": {"foo": "bar"}}), ], ) -async def test_upcast_event_handler_arg(handler, payload): +async def test_upcast_event_handler_arg( + handler, payload, mock_base_state_event_processor, emitted_deltas +): """Test that upcast event handler args work correctly. Args: handler: The handler to test. payload: The payload to test. + mock_base_state_event_processor: Fixture for processing events with a BaseState. + emitted_deltas: List to store emitted deltas. """ - state = UpcastStateWithSqlAlchemy() - async for update in state._process_event(handler, state, payload): - assert update.delta == { - UpcastStateWithSqlAlchemy.get_full_name(): {"passed" + FIELD_MARKER: True} - } + async with mock_base_state_event_processor as processor: + await processor.enqueue( + "test_token", *Event.from_event_type(handler(**payload)) + ) + assert emitted_deltas == [ + ( + "test_token", + { + UpcastStateWithSqlAlchemy.get_full_name(): { + "passed" + FIELD_MARKER: True + } + }, + ), + ] def test_no_rebind_mutable_proxy_for_instrumented_functions(): diff --git a/tests/units/test_state.py b/tests/units/test_state.py index c678ef9c662..fbe785d8885 100644 --- a/tests/units/test_state.py +++ b/tests/units/test_state.py @@ -29,6 +29,8 @@ from reflex.constants.state import FIELD_MARKER from reflex.environment import environment from reflex.event import Event, EventHandler +from reflex.ievent.context import EventContext +from reflex.ievent.processor import BaseStateEventProcessor from reflex.istate.data import HeaderData, _FrozenDictStrStr from reflex.istate.manager import StateManager from reflex.istate.manager.disk import StateManagerDisk @@ -807,101 +809,121 @@ def test_reset(test_state: TestState, child_state: ChildState): @pytest.mark.asyncio -async def test_process_event_simple(test_state): +async def test_process_event_simple( + token: str, + mock_base_state_event_processor: BaseStateEventProcessor, + emitted_deltas: list, +): """Test processing an event. Args: - test_state: A state. + token: A token. + mock_base_state_event_processor: The event processor. + emitted_deltas: List to capture emitted deltas. """ - assert test_state.num1 == 0 - - event = Event(token="t", name="set_num1", payload={"value": 69}) - async for update in test_state._process(event): - # The event should update the value. - assert test_state.num1 == 69 - - # The delta should contain the changes, including computed vars. - assert update.delta == { - TestState.get_full_name(): { - "num1" + FIELD_MARKER: 69, - "sum" + FIELD_MARKER: 72.15, + event = Event( + token=token, + name=f"{TestState.get_full_name()}.set_num1", + payload={"value": 69}, + ) + async with mock_base_state_event_processor as processor: + await processor.enqueue(token, event) + # The delta should contain the changes, including computed vars. + assert emitted_deltas == [ + ( + token, + { + TestState.get_full_name(): { + "num1" + FIELD_MARKER: 69, + "sum" + FIELD_MARKER: 72.15, + }, + GrandchildState3.get_full_name(): {"computed" + FIELD_MARKER: ""}, }, - GrandchildState3.get_full_name(): {"computed" + FIELD_MARKER: ""}, - } - assert update.events == [] + ) + ] @pytest.mark.asyncio -async def test_process_event_substate(test_state, child_state, grandchild_state): +async def test_process_event_substate( + token: str, + mock_base_state_event_processor: BaseStateEventProcessor, + emitted_deltas: list, +): """Test processing an event on a substate. Args: - test_state: A state. - child_state: A child state. - grandchild_state: A grandchild state. + token: A token. + mock_base_state_event_processor: The event processor. + emitted_deltas: List to capture emitted deltas. """ # Events should bubble down to the substate. - assert child_state.value == "" - assert child_state.count == 23 event = Event( - token="t", - name=f"{ChildState.get_name()}.change_both", + token=token, + name=f"{ChildState.get_full_name()}.change_both", payload={"value": "hi", "count": 12}, ) - async for update in test_state._process(event): - assert child_state.value == "HI" - assert child_state.count == 24 - assert update.delta == { - # TestState.get_full_name(): {"sum": 3.14, "upper": ""}, - ChildState.get_full_name(): { - "value" + FIELD_MARKER: "HI", - "count" + FIELD_MARKER: 24, + async with mock_base_state_event_processor as processor: + await processor.enqueue(token, event) + assert emitted_deltas == [ + ( + token, + { + ChildState.get_full_name(): { + "value" + FIELD_MARKER: "HI", + "count" + FIELD_MARKER: 24, + }, + GrandchildState3.get_full_name(): {"computed" + FIELD_MARKER: ""}, }, - GrandchildState3.get_full_name(): {"computed" + FIELD_MARKER: ""}, - } - test_state._clean() + ) + ] + emitted_deltas.clear() # Test with the grandchild state. - assert grandchild_state.value2 == "" event = Event( - token="t", + token=token, name=f"{GrandchildState.get_full_name()}.set_value2", payload={"value": "new"}, ) - async for update in test_state._process(event): - assert grandchild_state.value2 == "new" - assert update.delta == { - # TestState.get_full_name(): {"sum": 3.14, "upper": ""}, - GrandchildState.get_full_name(): {"value2" + FIELD_MARKER: "new"}, - GrandchildState3.get_full_name(): {"computed" + FIELD_MARKER: ""}, - } + async with mock_base_state_event_processor as processor: + await processor.enqueue(token, event) + assert emitted_deltas == [ + ( + token, + { + GrandchildState.get_full_name(): {"value2" + FIELD_MARKER: "new"}, + GrandchildState3.get_full_name(): {"computed" + FIELD_MARKER: ""}, + }, + ) + ] @pytest.mark.asyncio -async def test_process_event_generator(): - """Test event handlers that generate multiple updates.""" - gen_state = GenState() # pyright: ignore [reportCallIssue] +async def test_process_event_generator( + token: str, + mock_base_state_event_processor: BaseStateEventProcessor, + emitted_deltas: list, +): + """Test event handlers that generate multiple updates. + + Args: + token: A token. + mock_base_state_event_processor: The event processor. + emitted_deltas: List to capture emitted deltas. + """ event = Event( - token="t", - name="go", + token=token, + name=f"{GenState.get_full_name()}.go", payload={"c": 5}, ) - gen = gen_state._process(event) - - count = 0 - async for update in gen: - count += 1 - if count == 6: - assert update.delta == {} - assert update.final - else: - assert gen_state.value == count - assert update.delta == { - GenState.get_full_name(): {"value" + FIELD_MARKER: count}, - } - assert not update.final - - assert count == 6 + async with mock_base_state_event_processor as processor: + await processor.enqueue(token, event) + # Generator yields 5 deltas (one per increment). + assert len(emitted_deltas) == 5 + for count, (delta_token, delta) in enumerate(emitted_deltas, 1): + assert delta_token == token + assert delta == { + GenState.get_full_name(): {"value" + FIELD_MARKER: count}, + } def test_get_client_token(test_state, router_data): @@ -1644,12 +1666,15 @@ def reset(self): @pytest.mark.asyncio -async def test_state_with_invalid_yield(capsys: pytest.CaptureFixture[str], mock_app): +async def test_state_with_invalid_yield( + token: str, + mock_base_state_event_processor: BaseStateEventProcessor, +): """Test that an error is thrown when a state yields an invalid value. Args: - capsys: Pytest fixture for capture standard streams. - mock_app: Mock app fixture. + token: A token. + mock_base_state_event_processor: The event processor. """ class StateWithInvalidYield(BaseState): @@ -1663,27 +1688,26 @@ def invalid_handler(self): """ yield 1 - invalid_state = StateWithInvalidYield() - async for update in invalid_state._process( - rx.event.Event(token="fake_token", name="invalid_handler") - ): - assert not update.delta - assert update.events == rx.event.fix_events( - [ - rx.toast( - "An error occurred.", - level="error", - fallback_to_alert=True, - description="TypeError: Your handler test_state_with_invalid_yield..StateWithInvalidYield.invalid_handler must only return/yield: None, Events or other EventHandlers referenced by their class (i.e. using `type(self)` or other class references). Returned events of types ..
See logs for details.", - id="backend_error", - position="top-center", - style={"width": "500px"}, - ) - ], - token="", - ) - captured = capsys.readouterr() - assert "must only return/yield: None, Events or other EventHandlers" in captured.err + captured_exceptions: list[Exception] = [] + + def capture_exception(ex: Exception) -> None: + captured_exceptions.append(ex) + + mock_base_state_event_processor.backend_exception_handler = capture_exception + + event = Event( + token=token, + name=f"{StateWithInvalidYield.get_full_name()}.invalid_handler", + payload={}, + ) + async with mock_base_state_event_processor as processor: + await processor.enqueue(token, event) + + assert len(captured_exceptions) == 1 + assert isinstance(captured_exceptions[0], TypeError) + assert "must only return/yield: None, Events or other EventHandlers" in str( + captured_exceptions[0] + ) @pytest.fixture @@ -2091,7 +2115,10 @@ def from_dict(cls, data: dict) -> ModelDC: @pytest.mark.asyncio async def test_state_proxy( - grandchild_state: GrandchildState, mock_app: rx.App, token: str + grandchild_state: GrandchildState, + mock_app: rx.App, + token: str, + attached_mock_event_context: EventContext, ): """Test that the state proxy works. @@ -2099,6 +2126,7 @@ async def test_state_proxy( grandchild_state: A grandchild state. mock_app: An app that will be returned by `get_app()` token: A token. + attached_mock_event_context: The event context attached for this test. """ child_state = grandchild_state.parent_state assert child_state is not None @@ -2215,7 +2243,7 @@ async def test_state_proxy( "computed" + FIELD_MARKER: "", }, }, - final=None, + final=True, ) assert mcall.kwargs["to"] == grandchild_state.router.session.session_id @@ -2327,79 +2355,46 @@ async def bad_chain2(self): @pytest.mark.asyncio -async def test_background_task_no_block(mock_app: rx.App, token: str): +async def test_background_task_no_block( + mock_app: rx.App, + token: str, + mock_base_state_event_processor: BaseStateEventProcessor, + emitted_deltas: list, + state_manager: StateManager, +): """Test that a background task does not block other events. Args: mock_app: An app that will be returned by `get_app()` token: A token. + mock_base_state_event_processor: The event processor. + emitted_deltas: List to capture emitted deltas. + state_manager: A state manager instance. """ - router_data = {"query": {}, "token": token} - sid = "test_sid" - namespace = mock_app.event_namespace - assert namespace is not None - namespace.sid_to_token[sid] = token - namespace._token_manager.instance_id = "mock" - namespace._token_manager.token_to_socket[token] = SocketRecord( - instance_id="mock", sid=sid - ) - mock_app._state = BackgroundTaskState - async for update in rx.app.process( - mock_app, - Event( - token=token, - name=f"{BackgroundTaskState.get_full_name()}.background_task", - router_data=router_data, - payload={}, - ), - sid=sid, - headers={}, - client_ip="", - ): - # background task returns empty update immediately - assert update == StateUpdate() - - # wait for the coroutine to start - await asyncio.sleep(0.5 if CI else 0.1) - assert len(mock_app._background_tasks) == 1 - - # Process another normal event - async for update in rx.app.process( - mock_app, - Event( - token=token, - name=f"{BackgroundTaskState.get_full_name()}.other", - router_data=router_data, - payload={}, - ), - sid=sid, - headers={}, - client_ip="", - ): - # other task returns delta - assert update == StateUpdate( - delta={ - BackgroundTaskState.get_full_name(): { - "order" + FIELD_MARKER: [ - "background_task:start", - "other", - ], - "computed_order" + FIELD_MARKER: [ - "background_task:start", - "other", - ], - } - }, + async with mock_base_state_event_processor as processor: + # Start background task + await processor.enqueue( + token, + Event( + token=token, + name=f"{BackgroundTaskState.get_full_name()}.background_task", + payload={}, + ), + ) + # Wait for the background task coroutine to start + await asyncio.sleep(0.5 if CI else 0.1) + + # Process another normal event while background task is polling + await processor.enqueue( + token, + Event( + token=token, + name=f"{BackgroundTaskState.get_full_name()}.other", + payload={}, + ), ) - # Explicit wait for background tasks - for task in tuple(mock_app._background_tasks): - await task - assert not mock_app._background_tasks - - if environment.REFLEX_OPLOCK_ENABLED.get(): - await mock_app.state_manager.close() - + # After processor context exits, all tasks including background are done. exp_order = [ "background_task:start", "other", @@ -2408,97 +2403,39 @@ async def test_background_task_no_block(mock_app: rx.App, token: str): "private", ] - background_task_state = await mock_app.state_manager.get_state( + background_task_state = await state_manager.get_state( BaseStateToken(ident=token, cls=BackgroundTaskState) ) assert isinstance(background_task_state, BackgroundTaskState) assert background_task_state.order == exp_order - assert mock_app.event_namespace is not None - emit_mock = mock_app.event_namespace.emit - - first_ws_message = emit_mock.mock_calls[0].args[1] # pyright: ignore [reportAttributeAccessIssue] - assert ( - first_ws_message.delta[BackgroundTaskState.get_full_name()].pop( - "router" + FIELD_MARKER - ) - is not None - ) - assert first_ws_message == StateUpdate( - delta={ - BackgroundTaskState.get_full_name(): { - "order" + FIELD_MARKER: ["background_task:start"], - "computed_order" + FIELD_MARKER: ["background_task:start"], - } - }, - events=[], - final=None, - ) - for call in emit_mock.mock_calls[1:5]: # pyright: ignore [reportAttributeAccessIssue] - assert call.args[1] == StateUpdate( - delta={ - BackgroundTaskState.get_full_name(): { - "computed_order" + FIELD_MARKER: ["background_task:start"], - } - }, - events=[], - final=None, - ) - assert emit_mock.mock_calls[-2].args[1] == StateUpdate( # pyright: ignore [reportAttributeAccessIssue] - delta={ - BackgroundTaskState.get_full_name(): { - "order" + FIELD_MARKER: exp_order, - "computed_order" + FIELD_MARKER: exp_order, - "dict_list" + FIELD_MARKER: {}, - } - }, - events=[], - final=None, - ) - assert emit_mock.mock_calls[-1].args[1] == StateUpdate( # pyright: ignore [reportAttributeAccessIssue] - delta={ - BackgroundTaskState.get_full_name(): { - "computed_order" + FIELD_MARKER: exp_order, - }, - }, - events=[], - final=None, - ) @pytest.mark.asyncio -async def test_background_task_reset(mock_app: rx.App, token: str): +async def test_background_task_reset( + mock_app: rx.App, + token: str, + mock_base_state_event_processor: BaseStateEventProcessor, + state_manager: StateManager, +): """Test that a background task calling reset is protected by the state proxy. Args: mock_app: An app that will be returned by `get_app()` token: A token. + mock_base_state_event_processor: The event processor. + state_manager: A state manager instance. """ - router_data = {"query": {}} - mock_app._state = BackgroundTaskState - async for update in rx.app.process( - mock_app, - Event( - token=token, - name=f"{BackgroundTaskState.get_name()}.background_task_reset", - router_data=router_data, - payload={}, - ), - sid="", - headers={}, - client_ip="", - ): - # background task returns empty update immediately - assert update == StateUpdate() - - # Explicit wait for background tasks - for task in tuple(mock_app._background_tasks): - await task - assert not mock_app._background_tasks - - if environment.REFLEX_OPLOCK_ENABLED.get(): - await mock_app.state_manager.close() + async with mock_base_state_event_processor as processor: + await processor.enqueue( + token, + Event( + token=token, + name=f"{BackgroundTaskState.get_full_name()}.background_task_reset", + payload={}, + ), + ) - background_task_state = await mock_app.state_manager.get_state( + background_task_state = await state_manager.get_state( BaseStateToken(ident=token, cls=BackgroundTaskState) ) assert isinstance(background_task_state, BackgroundTaskState) @@ -2786,6 +2723,26 @@ def test_mutable_copy_vars(mutable_state: MutableTestState, copy_func: Callable) assert not isinstance(var_copy, MutableProxy) +@pytest.fixture +def restore_all_base_state_classes(monkeypatch: pytest.MonkeyPatch): + """Fixture to restore BaseState subclasses after test. + + Args: + monkeypatch: Pytest monkeypatch object. + """ + from reflex.ievent import registry + from reflex.state import all_base_state_classes + + monkeypatch.setattr( + "reflex.state.all_base_state_classes", all_base_state_classes.copy() + ) + monkeypatch.setattr( + "reflex.ievent.registry.REGISTERED_HANDLERS", + registry.REGISTERED_HANDLERS.copy(), + ) + + +@pytest.mark.usefixtures("restore_all_base_state_classes") def test_duplicate_substate_class(mocker: MockerFixture): # Neuter pytest escape hatch, because we want to test duplicate detection. mocker.patch("reflex.state.is_testing_env", return_value=False) @@ -3047,7 +3004,14 @@ async def test_handler(self): ], ) async def test_preprocess( - app_module_mock, token, test_state, expected, mocker: MockerFixture + app_module_mock, + token, + test_state, + expected, + mocker: MockerFixture, + mock_root_event_context: EventContext, + mock_base_state_event_processor: BaseStateEventProcessor, + emitted_deltas: list, ): """Test that a state hydrate event is processed correctly. @@ -3057,12 +3021,16 @@ async def test_preprocess( test_state: State to process event. expected: Expected delta. mocker: pytest mock object. + mock_root_event_context: The mock root event context. + mock_base_state_event_processor: The event processor. + emitted_deltas: List to capture emitted deltas. """ OnLoadInternalState._app_ref = None mocker.patch( "reflex.state.State.class_subclasses", {test_state, OnLoadInternalState} ) app = app_module_mock.app = App(_state=State) + app._state_manager = mock_root_event_context.state_manager def index(): return "hello" @@ -3070,42 +3038,66 @@ def index(): app.add_page(index, on_load=test_state.test_handler) app._compile_page("index") - async with app.state_manager.modify_state( + async with mock_root_event_context.state_manager.modify_state( BaseStateToken(ident=token, cls=State) ) as state: state.router_data = {"simulate": "hydrate"} - updates = [] - async for update in rx.app.process( - app=app, - event=Event( - token=token, - name=f"{state.get_name()}.{CompileVars.ON_LOAD_INTERNAL}", - router_data={RouteVar.PATH: "/", RouteVar.ORIGIN: "/", RouteVar.QUERY: {}}, - ), - sid="sid", - headers={}, - client_ip="", - ): - assert isinstance(update, StateUpdate) - updates.append(update) - assert len(updates) == 1 - assert updates[0].delta[State.get_name()].pop("router" + FIELD_MARKER) is not None - assert updates[0].delta == exp_is_hydrated(state, False) - - events = updates[0].events - assert len(events) == 2 - async for update in state._process(events[0]): - assert update.delta == {test_state.get_full_name(): {"num" + FIELD_MARKER: 1}} - async for update in state._process(events[1]): - assert update.delta == exp_is_hydrated(state) - - await app.state_manager.close() + on_load_internal_name = format.format_event_handler( + OnLoadInternalState.on_load_internal # pyright: ignore[reportArgumentType] + ) + + async with mock_base_state_event_processor as processor: + await processor.enqueue( + token, + Event( + token=token, + name=on_load_internal_name, + router_data={ + RouteVar.PATH: "/", + RouteVar.ORIGIN: "/", + RouteVar.QUERY: {}, + }, + ), + ) + await processor.join() + + # The processor chains all events: on_load_internal sets is_hydrated=False, + # then the on_load handler runs, then set_is_hydrated(True) runs. + # First delta: router + is_hydrated=False + assert len(emitted_deltas) >= 2 + first_delta = emitted_deltas[0][1] + assert first_delta[State.get_full_name()].pop("router" + FIELD_MARKER) is not None + assert first_delta == exp_is_hydrated(state, False) + + # Find the delta containing the test handler's state change + handler_deltas = [ + d + for _, d in emitted_deltas + if test_state.get_full_name() in d + and "num" + FIELD_MARKER in d[test_state.get_full_name()] + ] + assert len(handler_deltas) >= 1 + assert handler_deltas[0][test_state.get_full_name()]["num" + FIELD_MARKER] == 1 + + # Find the delta that sets is_hydrated back to True + hydrated_deltas = [ + d + for _, d in emitted_deltas + if State.get_full_name() in d + and d[State.get_full_name()].get(CompileVars.IS_HYDRATED + FIELD_MARKER) is True + ] + assert len(hydrated_deltas) == 1 @pytest.mark.asyncio async def test_preprocess_multiple_load_events( - app_module_mock, token, mocker: MockerFixture + app_module_mock, + token, + mocker: MockerFixture, + mock_root_event_context: EventContext, + mock_base_state_event_processor: BaseStateEventProcessor, + emitted_deltas: list, ): """Test that a state hydrate event for multiple on-load events is processed correctly. @@ -3113,51 +3105,71 @@ async def test_preprocess_multiple_load_events( app_module_mock: The app module that will be returned by get_app(). token: A token. mocker: pytest mock object. + mock_root_event_context: The mock root event context. + mock_base_state_event_processor: The event processor. + emitted_deltas: List to capture emitted deltas. """ OnLoadInternalState._app_ref = None mocker.patch( "reflex.state.State.class_subclasses", {OnLoadState, OnLoadInternalState} ) app = app_module_mock.app = App(_state=State) + app._state_manager = mock_root_event_context.state_manager def index(): return "hello" app.add_page(index, on_load=[OnLoadState.test_handler, OnLoadState.test_handler]) app._compile_page("index") - async with app.state_manager.modify_state( + async with mock_root_event_context.state_manager.modify_state( BaseStateToken(ident=token, cls=State) ) as state: state.router_data = {"simulate": "hydrate"} - updates = [] - async for update in rx.app.process( - app=app, - event=Event( - token=token, - name=f"{state.get_full_name()}.{CompileVars.ON_LOAD_INTERNAL}", - router_data={RouteVar.PATH: "/", RouteVar.ORIGIN: "/", RouteVar.QUERY: {}}, - ), - sid="sid", - headers={}, - client_ip="", - ): - assert isinstance(update, StateUpdate) - updates.append(update) - assert len(updates) == 1 - assert updates[0].delta[State.get_name()].pop("router" + FIELD_MARKER) is not None - assert updates[0].delta == exp_is_hydrated(state, False) - - events = updates[0].events - assert len(events) == 3 - async for update in state._process(events[0]): - assert update.delta == {OnLoadState.get_full_name(): {"num" + FIELD_MARKER: 1}} - async for update in state._process(events[1]): - assert update.delta == {OnLoadState.get_full_name(): {"num" + FIELD_MARKER: 2}} - async for update in state._process(events[2]): - assert update.delta == exp_is_hydrated(state) - - await app.state_manager.close() + on_load_internal_name = format.format_event_handler( + OnLoadInternalState.on_load_internal # pyright: ignore[reportArgumentType] + ) + + async with mock_base_state_event_processor as processor: + await processor.enqueue( + token, + Event( + token=token, + name=on_load_internal_name, + router_data={ + RouteVar.PATH: "/", + RouteVar.ORIGIN: "/", + RouteVar.QUERY: {}, + }, + ), + ) + await processor.join() + + # First delta: router + is_hydrated=False + assert len(emitted_deltas) >= 2 + first_delta = emitted_deltas[0][1] + assert first_delta[State.get_full_name()].pop("router" + FIELD_MARKER) is not None + assert first_delta == exp_is_hydrated(state, False) + + # Find deltas containing the test handler's state change (num incremented twice) + handler_deltas = [ + d + for _, d in emitted_deltas + if OnLoadState.get_full_name() in d + and "num" + FIELD_MARKER in d[OnLoadState.get_full_name()] + ] + assert len(handler_deltas) == 2 + assert handler_deltas[0][OnLoadState.get_full_name()]["num" + FIELD_MARKER] == 1 + assert handler_deltas[1][OnLoadState.get_full_name()]["num" + FIELD_MARKER] == 2 + + # Find the delta that sets is_hydrated back to True + hydrated_deltas = [ + d + for _, d in emitted_deltas + if State.get_full_name() in d + and d[State.get_full_name()].get(CompileVars.IS_HYDRATED + FIELD_MARKER) is True + ] + assert len(hydrated_deltas) == 1 @pytest.mark.asyncio @@ -3433,31 +3445,38 @@ def foo(self) -> str: @pytest.mark.asyncio -async def test_setvar(mock_app: rx.App, token: str): +async def test_setvar( + state_manager: StateManager, + token: str, + mock_base_state_event_processor: BaseStateEventProcessor, +): """Test that setvar works correctly. Args: - mock_app: An app that will be returned by `get_app()` + state_manager: A state manager instance. token: A token. + mock_base_state_event_processor: The event processor. """ - state = await mock_app.state_manager.get_state( - BaseStateToken(ident=token, cls=TestState) - ) - assert isinstance(state, TestState) - # Set Var in same state (with Var type casting) - for event in rx.event.fix_events( - [TestState.setvar("num1", 42), TestState.setvar("num2", "4.2")], token - ): - async for update in state._process(event): - print(update) + events = Event.from_event_type([ + TestState.setvar("num1", 42), + TestState.setvar("num2", "4.2"), + ]) + async with mock_base_state_event_processor as processor: + await processor.enqueue(token, *events) + + state = await state_manager.get_state(BaseStateToken(ident=token, cls=TestState)) + assert isinstance(state, TestState) assert state.num1 == 42 assert math.isclose(state.num2, 4.2) # Set Var in parent state - for event in rx.event.fix_events([GrandchildState.setvar("array", [43])], token): - async for update in state._process(event): - print(update) + events = Event.from_event_type([GrandchildState.setvar("array", [43])]) + async with mock_base_state_event_processor as processor: + await processor.enqueue(token, *events) + + state = await state_manager.get_state(BaseStateToken(ident=token, cls=TestState)) + assert isinstance(state, TestState) assert state.array == [43] # Cannot setvar for non-existent var @@ -4152,7 +4171,6 @@ def py_unresolvable(self, u: Unresolvable): # noqa: D102, F821 # pyright: ignor @pytest.mark.asyncio -@pytest.mark.usefixtures("mock_app_simple") @pytest.mark.parametrize( ("handler", "payload"), [ @@ -4172,18 +4190,33 @@ def py_unresolvable(self, u: Unresolvable): # noqa: D102, F821 # pyright: ignor (UpcastState.py_unresolvable, {"u": ["foo"]}), ], ) -async def test_upcast_event_handler_arg(handler, payload): +async def test_upcast_event_handler_arg( + handler, + payload, + token: str, + mock_base_state_event_processor: BaseStateEventProcessor, + emitted_deltas: list, +): """Test that upcast event handler args work correctly. Args: handler: The handler to test. payload: The payload to test. + token: A token. + mock_base_state_event_processor: The event processor. + emitted_deltas: List to capture emitted deltas. """ - state = UpcastState() - async for update in state._process_event(handler, state, payload): - assert update.delta == { - UpcastState.get_full_name(): {"passed" + FIELD_MARKER: True} - } + event = Event( + token=token, + name=format.format_event_handler(handler), + payload=payload, + ) + async with mock_base_state_event_processor as processor: + await processor.enqueue(token, event) + assert len(emitted_deltas) == 1 + assert emitted_deltas[0][1] == { + UpcastState.get_full_name(): {"passed" + FIELD_MARKER: True} + } @pytest.mark.asyncio @@ -4423,9 +4456,12 @@ class MutableProxyState(BaseState): @pytest.mark.asyncio -async def test_rebind_mutable_proxy(mock_app: rx.App, token: str) -> None: +async def test_rebind_mutable_proxy( + mock_app: rx.App, token: str, attached_mock_event_context: EventContext +) -> None: """Test that previously bound MutableProxy instances can be rebound correctly.""" mock_app._state = MutableProxyState + async with mock_app.state_manager.modify_state( BaseStateToken(ident=token, cls=MutableProxyState) ) as state: From ea93b5bf72424ce7b473d0b7550072df6fd4dcda Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Wed, 25 Mar 2026 22:26:57 -0700 Subject: [PATCH 12/81] ENG-9198: implement ContextVar-based registry for BaseState and EventHandler This allows better control over which states and events are part of a given app and avoiding true global variables makes cleanup and testing much simpler. --- reflex/_internal/__init__.py | 1 + reflex/_internal/context/__init__.py | 1 + reflex/_internal/context/base.py | 54 ++++++++++++ reflex/_internal/registry.py | 81 ++++++++++++++++++ reflex/components/component.py | 7 +- reflex/event.py | 95 +++++++++++----------- reflex/ievent/processor/event_processor.py | 6 +- reflex/ievent/registry.py | 39 --------- reflex/state.py | 41 ++++------ reflex/utils/format.py | 24 ++---- tests/units/conftest.py | 16 ++++ tests/units/test_event.py | 9 +- tests/units/test_state.py | 21 +---- 13 files changed, 236 insertions(+), 159 deletions(-) create mode 100644 reflex/_internal/__init__.py create mode 100644 reflex/_internal/context/__init__.py create mode 100644 reflex/_internal/context/base.py create mode 100644 reflex/_internal/registry.py delete mode 100644 reflex/ievent/registry.py diff --git a/reflex/_internal/__init__.py b/reflex/_internal/__init__.py new file mode 100644 index 00000000000..af1fcdf80fb --- /dev/null +++ b/reflex/_internal/__init__.py @@ -0,0 +1 @@ +"""Reflex internals: subject to change 🐉.""" diff --git a/reflex/_internal/context/__init__.py b/reflex/_internal/context/__init__.py new file mode 100644 index 00000000000..8c279f0bb08 --- /dev/null +++ b/reflex/_internal/context/__init__.py @@ -0,0 +1 @@ +"""Internal ContextVar and registration helpers for reflex.""" diff --git a/reflex/_internal/context/base.py b/reflex/_internal/context/base.py new file mode 100644 index 00000000000..bce7ebba0f3 --- /dev/null +++ b/reflex/_internal/context/base.py @@ -0,0 +1,54 @@ +import dataclasses +from contextvars import ContextVar, Token +from typing import ClassVar, Self + + +@dataclasses.dataclass(frozen=True, slots=True, kw_only=True) +class BaseContext: + """Base context class that acts as an async context manager to set the context var.""" + + _context_var: ClassVar[ContextVar[Self]] + _attached_context_token: ClassVar[dict[Self, Token[Self]]] + + @classmethod + def __init_subclass__(cls, **kwargs): + """Initialize the context variable for the subclass.""" + super().__init_subclass__(**kwargs) + cls._context_var = ContextVar(cls.__name__) + cls._attached_context_token = {} + + @classmethod + def get(cls) -> Self: + """Get the context from the context variable. + + Returns: + The context instance. + """ + return cls._context_var.get() + + def __enter__(self) -> Self: + """Enter the context. + + Returns: + This context instance. + """ + if self._attached_context_token.get(self) is not None: + msg = "Context is already attached, cannot enter context manager." + raise RuntimeError(msg) + self._attached_context_token[self] = self._context_var.set(self) + return self + + def __exit__(self, *exc_info): + """Exit the context.""" + if (token := self._attached_context_token.pop(self)) is not None: + self._context_var.reset(token) + + def ensure_context_attached(self): + """Ensure that the context is attached to the current context variable. + + Raises: + RuntimeError: If the context is not attached. + """ + if self._attached_context_token.get(self) is None: + msg = f"{type(self).__name__} must be entered before calling this method." + raise RuntimeError(msg) diff --git a/reflex/_internal/registry.py b/reflex/_internal/registry.py new file mode 100644 index 00000000000..f3162fbdf20 --- /dev/null +++ b/reflex/_internal/registry.py @@ -0,0 +1,81 @@ +"""A contextual registry for state and event handlers.""" + +import dataclasses +from typing import TYPE_CHECKING, Self + +from reflex._internal.context.base import BaseContext + +if TYPE_CHECKING: + from reflex.event import EventHandler + from reflex.state import BaseState + + +@dataclasses.dataclass(frozen=True, kw_only=True, slots=True) +class RegisteredEventHandler: + """A registered event handler, which includes the handler and its full name.""" + + handler: EventHandler + states: tuple[type[BaseState], ...] + + +@dataclasses.dataclass(frozen=True, kw_only=True, slots=True, eq=False) +class RegistrationContext(BaseContext): + """Context for registering event handlers and states.""" + + event_handlers: dict[str, RegisteredEventHandler] = dataclasses.field( + default_factory=dict, + repr=False, + ) + base_states: dict[str, type[BaseState]] = dataclasses.field( + default_factory=dict, + repr=False, + ) + + @classmethod + def ensure_context(cls) -> Self: + """Ensure the context is attached, or create a new instance and attach it. + + Returns: + The registration context instance. + """ + try: + return cls.get() + except LookupError: + # If the context is not attached, create a new instance and attach it. + ctx = cls() + cls._context_var.set(ctx) + return ctx + + @classmethod + def register_base_state(cls, state_cls: type[BaseState]) -> type[BaseState]: + """Register a base state class with its full name. + + Args: + state_cls: The base state class to register. + + Returns: + The registered base state class. + """ + cls.ensure_context().base_states[state_cls.get_full_name()] = state_cls + return state_cls + + @classmethod + def register_event_handler( + cls, handler: EventHandler, states: tuple[type[BaseState], ...] = () + ) -> EventHandler: + """Register an event handler with its full name and associated states. + + Args: + handler: The event handler to register. + states: The states associated with the event handler. + + Returns: + The registered event handler. + """ + from reflex.utils.format import format_event_handler + + full_name = format_event_handler(handler) + cls.ensure_context().event_handlers[full_name] = RegisteredEventHandler( + handler=handler, states=states + ) + return handler diff --git a/reflex/components/component.py b/reflex/components/component.py index 02773061b47..fc8e6b2c921 100644 --- a/reflex/components/component.py +++ b/reflex/components/component.py @@ -37,7 +37,7 @@ PageNames, ) from reflex.constants.compiler import SpecialAttributes -from reflex.constants.state import CAMEL_CASE_MEMO_MARKER, FRONTEND_EVENT_STATE +from reflex.constants.state import CAMEL_CASE_MEMO_MARKER from reflex.event import ( EventCallback, EventChain, @@ -1421,10 +1421,7 @@ def _event_trigger_values_use_state(self) -> bool: if isinstance(event, EventCallback): continue if isinstance(event, EventSpec): - if ( - event.handler.state_full_name - and event.handler.state_full_name != FRONTEND_EVENT_STATE - ): + if event.handler.state is not None: return True else: if event._var_state: diff --git a/reflex/event.py b/reflex/event.py index e8d5ef5b8fe..4e003ed0cfe 100644 --- a/reflex/event.py +++ b/reflex/event.py @@ -27,7 +27,6 @@ from reflex import constants from reflex.components.field import BaseField from reflex.constants.compiler import CompileVars, Hooks, Imports -from reflex.constants.state import FRONTEND_EVENT_STATE from reflex.utils import format from reflex.utils.decorator import once from reflex.utils.exceptions import ( @@ -57,6 +56,7 @@ if TYPE_CHECKING: from reflex.istate.manager.token import BaseStateToken + from reflex.state import BaseState @dataclasses.dataclass( @@ -79,13 +79,12 @@ class Event: payload: dict[str, Any] = dataclasses.field(default_factory=dict) @property - def state_cls(self) -> "type[BaseState]": + def state_cls(self) -> type[BaseState]: """The state class for the event.""" - from reflex.state import all_base_state_classes + from reflex._internal.registry import RegistrationContext substate_name = self.name.rpartition(".")[0] - - return all_base_state_classes[substate_name] + return RegistrationContext.get().base_states[substate_name] @property def substate_token(self) -> BaseStateToken: @@ -243,9 +242,25 @@ class EventHandler(EventActionsMixin): # The function to call in response to the event. fn: Any = dataclasses.field(default=None) - # The full name of the state class this event handler is attached to. - # Empty string means this event handler is a server side event. - state_full_name: str = dataclasses.field(default="") + # The state this EventHandler is directly attached to, if any. + state: type[BaseState] | None = dataclasses.field(default=None, repr=False) + + def __post_init__(self): + """Register the event handler.""" + from reflex._internal.registry import RegistrationContext + + RegistrationContext.register_event_handler( + self, states=(self.state,) if self.state else () + ) + + @property + def state_full_name(self) -> str: + """Get the full name of the state class this event handler is attached to. + + Returns: + The full name of the state class this event handler is attached to. + """ + return self.state.get_full_name() if self.state else "" def __hash__(self): """Get the hash of the event handler. @@ -253,7 +268,7 @@ def __hash__(self): Returns: The hash of the event handler. """ - return hash((tuple(self.event_actions.items()), self.fn, self.state_full_name)) + return hash((tuple(self.event_actions.items()), self.fn, self.state)) def get_parameters(self) -> Mapping[str, inspect.Parameter]: """Get the parameters of the function. @@ -316,7 +331,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> "EventSpec": from reflex.utils.exceptions import EventHandlerTypeError # Get the function args. - if self.state_full_name: + if self.state is not None: # Skip the `self` arg for state-bound event handlers. fn_args = list(self._parameters)[1:] else: @@ -440,8 +455,10 @@ def add_args(self, *args: Var) -> "EventSpec": """ from reflex.utils.exceptions import EventHandlerTypeError + n_self_args = 1 if self.handler.state is not None else 0 + # Get the remaining unfilled function args. - fn_args = list(self.handler._parameters)[1 + len(self.args) :] + fn_args = list(self.handler._parameters)[n_self_args + len(self.args) :] fn_args = (Var(_js_expr=arg) for arg in fn_args) # Construct the payload. @@ -1028,7 +1045,7 @@ def fn(): fn.__qualname__ = name fn.__signature__ = sig # pyright: ignore [reportFunctionMemberAccess] return EventSpec( - handler=EventHandler(fn=fn, state_full_name=FRONTEND_EVENT_STATE), + handler=EventHandler(fn=fn), args=tuple( ( Var(_js_expr=k), @@ -1077,7 +1094,7 @@ def redirect( """ return server_side( "_redirect", - get_fn_signature(redirect), + inspect.signature(redirect), path=path, external=is_external, popup=popup, @@ -1141,7 +1158,7 @@ def set_focus(ref: str) -> EventSpec: """ return server_side( "_set_focus", - get_fn_signature(set_focus), + inspect.signature(set_focus), ref=LiteralVar.create(format.format_ref(ref)), ) @@ -1157,7 +1174,7 @@ def blur_focus(ref: str) -> EventSpec: """ return server_side( "_blur_focus", - get_fn_signature(blur_focus), + inspect.signature(blur_focus), ref=LiteralVar.create(format.format_ref(ref)), ) @@ -1195,7 +1212,7 @@ def set_value(ref: str, value: Any) -> EventSpec: """ return server_side( "_set_value", - get_fn_signature(set_value), + inspect.signature(set_value), ref=LiteralVar.create(format.format_ref(ref)), value=value, ) @@ -1215,7 +1232,7 @@ def remove_cookie(key: str, options: dict[str, Any] | None = None) -> EventSpec: options["path"] = options.get("path", "/") return server_side( "_remove_cookie", - get_fn_signature(remove_cookie), + inspect.signature(remove_cookie), key=key, options=options, ) @@ -1229,7 +1246,7 @@ def clear_local_storage() -> EventSpec: """ return server_side( "_clear_local_storage", - get_fn_signature(clear_local_storage), + inspect.signature(clear_local_storage), ) @@ -1244,7 +1261,7 @@ def remove_local_storage(key: str) -> EventSpec: """ return server_side( "_remove_local_storage", - get_fn_signature(remove_local_storage), + inspect.signature(remove_local_storage), key=key, ) @@ -1257,7 +1274,7 @@ def clear_session_storage() -> EventSpec: """ return server_side( "_clear_session_storage", - get_fn_signature(clear_session_storage), + inspect.signature(clear_session_storage), ) @@ -1272,7 +1289,7 @@ def remove_session_storage(key: str) -> EventSpec: """ return server_side( "_remove_session_storage", - get_fn_signature(remove_session_storage), + inspect.signature(remove_session_storage), key=key, ) @@ -1369,7 +1386,7 @@ def download( return server_side( "_download", - get_fn_signature(download), + inspect.signature(download), url=url, filename=filename, ) @@ -1410,7 +1427,7 @@ def call_script( return server_side( "_call_script", - get_fn_signature(call_script), + inspect.signature(call_script), javascript_code=javascript_code, **callback_kwargs, ) @@ -1446,7 +1463,7 @@ def call_function( return server_side( "_call_function", - get_fn_signature(call_function), + inspect.signature(call_function), function=javascript_code, **callback_kwargs, ) @@ -1629,13 +1646,14 @@ def call_event_handler( if isinstance(event_callback, EventSpec): parameters = event_callback.handler._parameters + n_self_args = 1 if event_callback.handler.state is not None else 0 check_fn_match_arg_spec( event_callback.handler.fn, parameters, event_spec_args, key, - bool(event_callback.handler.state_full_name) + len(event_callback.args), + n_self_args + len(event_callback.args), event_callback.handler.fn.__qualname__, ) @@ -1651,9 +1669,7 @@ def call_event_handler( _check_event_args_subclass_of_callback( [ arg - for arg in event_callback_spec_args[ - bool(event_callback.handler.state_full_name) : - ] + for arg in event_callback_spec_args[n_self_args:] if arg not in argument_names ], event_spec_return_types, @@ -1665,6 +1681,8 @@ def call_event_handler( # Handle partial application of EventSpec args return event_callback.add_args(*event_spec_args) + n_self_args = 1 if event_callback.state is not None else 0 + parameters = event_callback._parameters check_fn_match_arg_spec( @@ -1672,7 +1690,7 @@ def call_event_handler( parameters, event_spec_args, key, - bool(event_callback.state_full_name), + n_self_args, event_callback.fn.__qualname__, ) @@ -1685,7 +1703,7 @@ def call_event_handler( type_hints_of_provided_callback = {} _check_event_args_subclass_of_callback( - event_callback_spec_args[1:], + event_callback_spec_args[n_self_args:], event_spec_return_types, type_hints_of_provided_callback, event_callback.fn.__qualname__, @@ -1980,22 +1998,6 @@ def fix_events( return out -def get_fn_signature(fn: Callable) -> inspect.Signature: - """Get the signature of a function. - - Args: - fn: The function. - - Returns: - The signature of the function. - """ - signature = inspect.signature(fn) - new_param = inspect.Parameter( - FRONTEND_EVENT_STATE, inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=Any - ) - return signature.replace(parameters=(new_param, *signature.parameters.values())) - - # These chains can be used for their side effects when no other events are desired. stop_propagation = noop().stop_propagation prevent_default = noop().prevent_default @@ -2549,7 +2551,6 @@ def wrapper( parse_args_spec = staticmethod(parse_args_spec) args_specs_from_fields = staticmethod(args_specs_from_fields) unwrap_var_annotation = staticmethod(unwrap_var_annotation) - get_fn_signature = staticmethod(get_fn_signature) # Event Spec Functions passthrough_event_spec = staticmethod(passthrough_event_spec) diff --git a/reflex/ievent/processor/event_processor.py b/reflex/ievent/processor/event_processor.py index 6143399784d..838ffa1ca74 100644 --- a/reflex/ievent/processor/event_processor.py +++ b/reflex/ievent/processor/event_processor.py @@ -12,9 +12,9 @@ import rich.markup +from reflex._internal.registry import RegisteredEventHandler, RegistrationContext from reflex.app_mixins.middleware import MiddlewareMixin from reflex.ievent.context import EventContext, event_context -from reflex.ievent.registry import REGISTERED_HANDLERS, RegisteredEventHandler from reflex.istate.manager import StateManager from reflex.utils import console @@ -386,7 +386,9 @@ async def _process_queue(self): entry = await queue.get() try: try: - registered_handler = REGISTERED_HANDLERS[entry.event.name] + registered_handler = RegistrationContext.get().event_handlers[ + entry.event.name + ] except KeyError as ke: msg = ( f"No registered handler found for event: {entry.event.name}" diff --git a/reflex/ievent/registry.py b/reflex/ievent/registry.py deleted file mode 100644 index b7de5bfe71c..00000000000 --- a/reflex/ievent/registry.py +++ /dev/null @@ -1,39 +0,0 @@ -"""A registry for all known event handlers.""" - -import dataclasses -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - from reflex.event import EventHandler - from reflex.state import BaseState - - -@dataclasses.dataclass(frozen=True, kw_only=True, slots=True) -class RegisteredEventHandler: - """A registered event handler, which includes the handler and its full name.""" - - handler: EventHandler - states: tuple[type[BaseState], ...] - - -REGISTERED_HANDLERS: dict[str, RegisteredEventHandler] = {} - - -def register( - handler: EventHandler, states: tuple[type[BaseState], ...] = () -) -> EventHandler: - """Register an event handler with its full name and associated states. - - Args: - handler: The event handler to register. - states: The states associated with the event handler. - - Returns: - The registered event handler. - """ - from reflex.utils.format import format_event_handler - - REGISTERED_HANDLERS[format_event_handler(handler)] = RegisteredEventHandler( - handler=handler, states=states - ) - return handler diff --git a/reflex/state.py b/reflex/state.py index daae277316a..956e9bb252e 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -173,8 +173,6 @@ def _split_substate_key(substate_key: str) -> tuple[str, str]: class EventHandlerSetVar(EventHandler): """A special event handler to wrap setvar functionality.""" - state_cls: type[BaseState] = dataclasses.field(init=False) - def __init__(self, state_cls: type[BaseState]): """Initialize the EventHandlerSetVar. @@ -183,9 +181,8 @@ def __init__(self, state_cls: type[BaseState]): """ super().__init__( fn=type(self).setvar, - state_full_name=state_cls.get_full_name(), + state=state_cls, ) - object.__setattr__(self, "state_cls", state_cls) def __hash__(self): """Get the hash of the event handler. @@ -197,7 +194,7 @@ def __hash__(self): tuple(self.event_actions.items()), self.fn, self.state_full_name, - self.state_cls, + self.state, )) def setvar(self, var_name: str, value: Any): @@ -229,11 +226,11 @@ def __call__(self, *args: Any) -> EventSpec: from reflex.utils.exceptions import EventHandlerValueError config = get_config() - if config.state_auto_setters is None: + if config.state_auto_setters is None and self.state is not None: console.deprecate( feature_name="state_auto_setters defaulting to True", reason="The default value will be changed to False in a future release. Set state_auto_setters explicitly or define setters explicitly. " - f"Used {self.state_cls.__name__}.setvar without defining it.", + f"Used {self.state.__name__}.setvar without defining it.", deprecation_version="0.8.9", removal_version="0.9.0", dedupe=True, @@ -244,11 +241,11 @@ def __call__(self, *args: Any) -> EventSpec: msg = f"Var name must be passed as a string, got {args[0]!r}" raise EventHandlerValueError(msg) - handler = getattr(self.state_cls, constants.SETTER_PREFIX + args[0], None) + handler = getattr(self.state, constants.SETTER_PREFIX + args[0], None) # Check that the requested Var setter exists on the State at compile time. if handler is None: - msg = f"Variable `{args[0]}` cannot be set on `{self.state_cls.get_full_name()}`" + msg = f"Variable `{args[0]}` cannot be set on `{self.state_full_name}`" raise AttributeError(msg) if inspect.iscoroutinefunction(handler.fn): @@ -319,7 +316,7 @@ def _override_base_method(fn: Callable[PARAMS, RETURN]) -> Callable[PARAMS, RETU return fn -all_base_state_classes: dict[str, type[BaseState]] = {} +all_base_state_classes: dict[str, None] = {} CLASS_VAR_NAMES = frozenset({ "vars", @@ -503,7 +500,7 @@ def __init_subclass__(cls, mixin: bool = False, **kwargs): Raises: StateValueError: If a substate class shadows another. """ - from reflex.ievent.registry import register + from reflex._internal.registry import RegistrationContext from reflex.utils.exceptions import StateValueError super().__init_subclass__(**kwargs) @@ -627,14 +624,15 @@ def __init_subclass__(cls, mixin: bool = False, **kwargs): for name, fn in events.items(): handler = cls._create_event_handler(fn) - cls.event_handlers[name] = register(handler, states=(cls,)) + cls.event_handlers[name] = handler setattr(cls, name, handler) # Initialize per-class var dependency tracking. cls._var_dependencies = {} cls._init_var_dependency_dicts() - all_base_state_classes[cls.get_full_name()] = cls + all_base_state_classes[cls.get_full_name()] = None + RegistrationContext.register_base_state(cls) @classmethod def _add_event_handler( @@ -648,10 +646,8 @@ def _add_event_handler( name: The name of the event handler. fn: The function to call when the event is triggered. """ - from reflex.ievent.registry import register - handler = cls._create_event_handler(fn) - cls.event_handlers[name] = register(handler, states=(cls,)) + cls.event_handlers[name] = handler setattr(cls, name, handler) @staticmethod @@ -1158,18 +1154,12 @@ def _create_event_handler( # Check if function has stored event_actions from decorator event_actions = getattr(fn, EVENT_ACTIONS_MARKER, {}) - return event_handler_cls( - fn=fn, state_full_name=cls.get_full_name(), event_actions=event_actions - ) + return event_handler_cls(fn=fn, state=cls, event_actions=event_actions) @classmethod def _create_setvar(cls): """Create the setvar method for the state.""" - from reflex.ievent.registry import register - - cls.setvar = cls.event_handlers["setvar"] = register( - EventHandlerSetVar(state_cls=cls), states=(cls,) - ) + cls.setvar = cls.event_handlers["setvar"] = EventHandlerSetVar(state_cls=cls) @classmethod def _create_setter(cls, name: str, prop: Var): @@ -1180,7 +1170,6 @@ def _create_setter(cls, name: str, prop: Var): prop: The var to create a setter for. """ from reflex.config import get_config - from reflex.ievent.registry import register config = get_config() create_event_handler_kwargs = {} @@ -1208,7 +1197,7 @@ def __call__(self, *args, **kwargs): event_handler = cls._create_event_handler( prop._get_setter(name), **create_event_handler_kwargs ) - cls.event_handlers[setter_name] = register(event_handler, states=(cls,)) + cls.event_handlers[setter_name] = event_handler setattr(cls, setter_name, event_handler) @classmethod diff --git a/reflex/utils/format.py b/reflex/utils/format.py index b283ea31443..ff7b710c57e 100644 --- a/reflex/utils/format.py +++ b/reflex/utils/format.py @@ -9,7 +9,6 @@ from typing import TYPE_CHECKING, Any from reflex import constants -from reflex.constants.state import FRONTEND_EVENT_STATE from reflex.utils import exceptions if TYPE_CHECKING: @@ -448,25 +447,20 @@ def get_event_handler_parts(handler: EventHandler) -> tuple[str, str]: Returns: The state and function name. """ - # Get the class that defines the event handler. - parts = handler.fn.__qualname__.split(".") + # Get the name of the event function. + name = handler.fn.__qualname__ # Get the state full name - state_full_name = handler.state_full_name + state_full_name = handler.state.get_full_name() if handler.state else "" - # If there's no enclosing class, just return the function name. - if not state_full_name: - return ("", parts[-1]) + # If there's no enclosing state, just return the full name. + if handler.state is None: + return ("", name) - # Get the function name - name = parts[-1] + # Get the event name inside the state. + func_name = name.rpartition(".")[2] - from reflex.state import BaseState - - if state_full_name == FRONTEND_EVENT_STATE and name not in BaseState.__dict__: - return ("", to_snake_case(handler.fn.__qualname__)) - - return (state_full_name, name) + return (state_full_name, func_name) def format_event_handler(handler: EventHandler) -> str: diff --git a/tests/units/conftest.py b/tests/units/conftest.py index 63ece7832ff..6c5a0c98da9 100644 --- a/tests/units/conftest.py +++ b/tests/units/conftest.py @@ -4,12 +4,14 @@ import traceback import uuid from collections.abc import AsyncGenerator, Generator, Mapping +from copy import deepcopy from typing import Any from unittest import mock import pytest import pytest_asyncio +from reflex._internal.registry import RegistrationContext from reflex.app import App from reflex.event import Event, EventSpec from reflex.ievent.context import EventContext, event_context @@ -433,3 +435,17 @@ def attached_mock_event_context( reset_token = event_context.set(ctx) yield ctx event_context.reset(reset_token) + + +@pytest.fixture +def forked_registration_context() -> Generator[RegistrationContext, None, None]: + """Fork the registration context and attach it. + + Sets the forked context as the current registration context for the duration + of the test, then resets it afterwards. + + Yields: + The forked RegistrationContext. + """ + with deepcopy(RegistrationContext.get()) as ctx: + yield ctx diff --git a/tests/units/test_event.py b/tests/units/test_event.py index ff45395684a..9a8ad1dde49 100644 --- a/tests/units/test_event.py +++ b/tests/units/test_event.py @@ -103,7 +103,7 @@ def fn_with_args(arg1, arg2): def test_call_event_handler_partial(): """Calling an EventHandler with incomplete args returns an EventSpec that can be extended.""" - def fn_with_args(_, arg1, arg2): + def fn_with_args(arg1, arg2): pass fn_with_args.__qualname__ = "fn_with_args" @@ -111,7 +111,7 @@ def fn_with_args(_, arg1, arg2): def spec(a2: Var[str]) -> list[Var[str]]: return [a2] - handler = EventHandler(fn=fn_with_args, state_full_name="BigState") + handler = EventHandler(fn=fn_with_args) event_spec = handler(make_var("first")) event_spec2 = call_event_handler(event_spec, spec) @@ -120,8 +120,7 @@ def spec(a2: Var[str]) -> list[Var[str]]: assert event_spec.args[0][0].equals(Var(_js_expr="arg1")) assert event_spec.args[0][1].equals(Var(_js_expr="first")) assert ( - format.format_event(event_spec) - == 'ReflexEvent("BigState.fn_with_args", {arg1:first})' + format.format_event(event_spec) == 'ReflexEvent("fn_with_args", {arg1:first})' ) assert event_spec2 is not event_spec @@ -133,7 +132,7 @@ def spec(a2: Var[str]) -> list[Var[str]]: assert event_spec2.args[1][1].equals(Var(_js_expr="_a2", _var_type=str)) assert ( format.format_event(event_spec2) - == 'ReflexEvent("BigState.fn_with_args", {arg1:first,arg2:_a2})' + == 'ReflexEvent("fn_with_args", {arg1:first,arg2:_a2})' ) diff --git a/tests/units/test_state.py b/tests/units/test_state.py index fbe785d8885..d967ba09916 100644 --- a/tests/units/test_state.py +++ b/tests/units/test_state.py @@ -2723,26 +2723,7 @@ def test_mutable_copy_vars(mutable_state: MutableTestState, copy_func: Callable) assert not isinstance(var_copy, MutableProxy) -@pytest.fixture -def restore_all_base_state_classes(monkeypatch: pytest.MonkeyPatch): - """Fixture to restore BaseState subclasses after test. - - Args: - monkeypatch: Pytest monkeypatch object. - """ - from reflex.ievent import registry - from reflex.state import all_base_state_classes - - monkeypatch.setattr( - "reflex.state.all_base_state_classes", all_base_state_classes.copy() - ) - monkeypatch.setattr( - "reflex.ievent.registry.REGISTERED_HANDLERS", - registry.REGISTERED_HANDLERS.copy(), - ) - - -@pytest.mark.usefixtures("restore_all_base_state_classes") +@pytest.mark.usefixtures("forked_registration_context") def test_duplicate_substate_class(mocker: MockerFixture): # Neuter pytest escape hatch, because we want to test duplicate detection. mocker.patch("reflex.state.is_testing_env", return_value=False) From 1cfd974f1f3d8c5635cc6e68c0a12c5b8bf505f0 Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Wed, 25 Mar 2026 22:47:34 -0700 Subject: [PATCH 13/81] Remove `token` field from Event --- reflex/.templates/web/utils/state.js | 1 - reflex/app.py | 19 ++++----- reflex/event.py | 41 ++++++++----------- .../ievent/processor/base_state_processor.py | 6 +-- reflex/istate/shared.py | 1 - reflex/state.py | 1 - tests/units/middleware/conftest.py | 1 - tests/units/test_app.py | 5 --- tests/units/test_event.py | 6 +-- tests/units/test_state.py | 11 ----- 10 files changed, 29 insertions(+), 63 deletions(-) diff --git a/reflex/.templates/web/utils/state.js b/reflex/.templates/web/utils/state.js index 9e937ed62cd..c5ef0a68c70 100644 --- a/reflex/.templates/web/utils/state.js +++ b/reflex/.templates/web/utils/state.js @@ -379,7 +379,6 @@ export const applyEvent = async (event, socket, navigate, params) => { } // Update token and router data (if missing). - event.token = getToken(); if ( event.router_data === undefined || Object.keys(event.router_data).length === 0 diff --git a/reflex/app.py b/reflex/app.py index 4d5afb40482..d082939c2c6 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -2113,6 +2113,13 @@ async def on_event(self, sid: str, data: Any): Raises: EventDeserializationError: If the event data is not a dictionary. """ + # Determine the token for this SID + if (token := self.sid_to_token.get(sid)) is None: + console.warn( + f"Received event from session {sid} with no associated token. This may indicate a bug. Event data: {data}" + ) + return + fields = data if isinstance(fields, str): @@ -2137,14 +2144,6 @@ async def on_event(self, sid: str, data: Any): msg = f"Failed to deserialize event data: {fields}." raise exceptions.EventDeserializationError(msg) from ex - # Correct the token if it doesn't match what we expect for this SID - expected_token = self.sid_to_token.get(sid) - if expected_token and event.token != expected_token: - # Create new event with corrected token since Event is frozen - from dataclasses import replace - - event = replace(event, token=expected_token) - # Get the event environment. if self.app.sio is None: msg = "Socket.IO is not initialized." @@ -2180,7 +2179,7 @@ async def on_event(self, sid: str, data: Any): router_data = event.router_data router_data.update({ constants.RouteVar.QUERY: format.format_query_params(event.router_data), - constants.RouteVar.CLIENT_TOKEN: event.token, + constants.RouteVar.CLIENT_TOKEN: token, constants.RouteVar.SESSION_ID: sid, constants.RouteVar.HEADERS: headers, constants.RouteVar.CLIENT_IP: client_ip, @@ -2190,7 +2189,7 @@ async def on_event(self, sid: str, data: Any): if (path := router_data.get(constants.RouteVar.PATH)) else "404" ).removeprefix("/") - await self.app.event_processor.enqueue(event.token, event) + await self.app.event_processor.enqueue(token, event) async def on_ping(self, sid: str): """Event for testing the API endpoint. diff --git a/reflex/event.py b/reflex/event.py index 4e003ed0cfe..557fbeb29d0 100644 --- a/reflex/event.py +++ b/reflex/event.py @@ -55,7 +55,6 @@ from reflex.vars.object import ObjectVar if TYPE_CHECKING: - from reflex.istate.manager.token import BaseStateToken from reflex.state import BaseState @@ -66,13 +65,10 @@ class Event: """An event that describes any state change in the app.""" - # The token to specify the client that the event is for (TODO: remove). - token: str - # The event name. name: str - # The routing data where event occurred (TODO: remove). + # The routing data where event occurred. router_data: dict[str, Any] = dataclasses.field(default_factory=dict) # The event payload. @@ -86,24 +82,18 @@ def state_cls(self) -> type[BaseState]: substate_name = self.name.rpartition(".")[0] return RegistrationContext.get().base_states[substate_name] - @property - def substate_token(self) -> BaseStateToken: - """Get the substate token for the event. - - Returns: - The substate token. - """ - msg = "Event.substate_token should no longer be used." - raise NotImplementedError(msg) - @classmethod def from_event_type( - cls, events: "IndividualEventType | list[IndividualEventType] | None" + cls, + events: "IndividualEventType | list[IndividualEventType] | None", + *, + router_data: dict[str, Any] | None = None, ) -> "list[Event]": """Create a list of Events from event-like objects. Args: events: The event-like objects to create Events from. + router_data: The routing data for the events. Returns: A list of Events created from the event-like objects. @@ -124,7 +114,12 @@ def from_event_type( e = e() if isinstance(e, Event): # If the event is already an event, append it to the list. - out.append(e) + if router_data is not None and e.router_data != router_data: + out.append( + dataclasses.replace(e, router_data=e.router_data | router_data) + ) + else: + out.append(e) continue # Otherwise, create an event from the event spec. if isinstance(e, EventHandler): @@ -139,10 +134,9 @@ def from_event_type( # Create an event and append it to the list. out.append( Event( - token="none", name=name, payload=payload, - router_data={}, + router_data=router_data or {}, ) ) @@ -1934,21 +1928,19 @@ def get_handler_args( def fix_events( events: list[EventSpec | EventHandler] | None, - token: str, router_data: dict[str, Any] | None = None, ) -> list[Event]: """Fix a list of events returned by an event handler. Args: events: The events to fix. - token: The user token. router_data: The optional router data to set in the event. - Raises: - ValueError: If the event type is not what was expected. - Returns: The fixed events. + + Raises: + ValueError: If the event type is not what was expected. """ # If the event handler returns nothing, return an empty list. if events is None: @@ -1988,7 +1980,6 @@ def fix_events( # Create an event and append it to the list. out.append( Event( - token=token, name=name, payload=payload, router_data=event_router_data, diff --git a/reflex/ievent/processor/base_state_processor.py b/reflex/ievent/processor/base_state_processor.py index 76ca10e917a..1ecda7cb610 100644 --- a/reflex/ievent/processor/base_state_processor.py +++ b/reflex/ievent/processor/base_state_processor.py @@ -175,15 +175,13 @@ async def chain_updates( handler_name: The name of the handler that yielded the events, used for error messages. root_state: The root state of the app, no delta emitted if omitted. """ - from reflex.event import fix_events + from reflex.event import Event ctx = event_context.get() - token = ctx.token # Convert valid EventHandler and EventSpec into Event - if fixed_events := fix_events( + if fixed_events := Event.from_event_type( _check_valid_yield(events, handler_name=handler_name), - token, router_data=root_state.router_data if root_state else None, ): # Frontend events. diff --git a/reflex/istate/shared.py b/reflex/istate/shared.py index 3fa2f8ffa37..d4ca0f57921 100644 --- a/reflex/istate/shared.py +++ b/reflex/istate/shared.py @@ -166,7 +166,6 @@ def _rehydrate(self): """ return [ Event( - token=self.router.session.client_token, name=get_hydrate_event(self._get_root_state()), ), State.set_is_hydrated(True), diff --git a/reflex/state.py b/reflex/state.py index 956e9bb252e..78902ef99b7 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -2389,7 +2389,6 @@ def on_load_internal(self) -> list[Event | EventSpec | event.EventCallback] | No return [ *fix_events( cast(list[EventSpec | EventHandler], load_events), - self.router.session.client_token, router_data=self.router_data, ), State.set_is_hydrated(True), diff --git a/tests/units/middleware/conftest.py b/tests/units/middleware/conftest.py index d786db6521d..6260de29606 100644 --- a/tests/units/middleware/conftest.py +++ b/tests/units/middleware/conftest.py @@ -6,7 +6,6 @@ def create_event(name): return Event( - token="", name=name, router_data={ "pathname": "/", diff --git a/tests/units/test_app.py b/tests/units/test_app.py index ec0ac7c9fe6..68f030451c0 100644 --- a/tests/units/test_app.py +++ b/tests/units/test_app.py @@ -523,7 +523,6 @@ async def test_dynamic_var_event( await processor.enqueue( token, Event( - token=token, name=f"{test_state.get_name()}.set_int_val", payload={"value": 50}, ), @@ -723,7 +722,6 @@ async def test_list_mutation_detection__plain_list( await processor.enqueue( token, Event( - token="", name=f"{list_mutation_state.get_name()}.{event_name}", payload={}, ), @@ -921,7 +919,6 @@ async def test_dict_mutation_detection__plain_list( await processor.enqueue( token, Event( - token="", name=f"{dict_mutation_state.get_name()}.{event_name}", payload={}, ), @@ -1472,7 +1469,6 @@ async def test_dynamic_route_var_route_change_completed_on_load( def _event(name, val, **kwargs): return Event( - token=kwargs.pop("token", token), name=name, router_data=kwargs.pop( "router_data", @@ -1606,7 +1602,6 @@ async def test_process_events( emitted_deltas: List to store emitted deltas. """ event = Event( - token=token, name=f"{GenState.get_name()}.go", payload={"c": 5}, router_data={}, diff --git a/tests/units/test_event.py b/tests/units/test_event.py index 9a8ad1dde49..85b5bb8a571 100644 --- a/tests/units/test_event.py +++ b/tests/units/test_event.py @@ -34,8 +34,7 @@ def make_var(value) -> Var: def test_create_event(): """Test creating an event.""" - event = Event(token="token", name="state.do_thing", payload={"arg": "value"}) - assert event.token == "token" + event = Event(name="state.do_thing", payload={"arg": "value"}) assert event.name == "state.do_thing" assert event.payload == {"arg": "value"} @@ -159,9 +158,8 @@ def fn_with_args(arg1, arg2): handler = EventHandler(fn=fn_with_args) event_spec = handler(arg1, arg2) - event = fix_events([event_spec], token="foo")[0] + event = fix_events([event_spec])[0] assert event.name == fn_with_args.__qualname__ - assert event.token == "foo" assert event.payload == {"arg1": arg1, "arg2": arg2} diff --git a/tests/units/test_state.py b/tests/units/test_state.py index d967ba09916..b5e87277e16 100644 --- a/tests/units/test_state.py +++ b/tests/units/test_state.py @@ -822,7 +822,6 @@ async def test_process_event_simple( emitted_deltas: List to capture emitted deltas. """ event = Event( - token=token, name=f"{TestState.get_full_name()}.set_num1", payload={"value": 69}, ) @@ -858,7 +857,6 @@ async def test_process_event_substate( """ # Events should bubble down to the substate. event = Event( - token=token, name=f"{ChildState.get_full_name()}.change_both", payload={"value": "hi", "count": 12}, ) @@ -880,7 +878,6 @@ async def test_process_event_substate( # Test with the grandchild state. event = Event( - token=token, name=f"{GrandchildState.get_full_name()}.set_value2", payload={"value": "new"}, ) @@ -911,7 +908,6 @@ async def test_process_event_generator( emitted_deltas: List to capture emitted deltas. """ event = Event( - token=token, name=f"{GenState.get_full_name()}.go", payload={"c": 5}, ) @@ -1696,7 +1692,6 @@ def capture_exception(ex: Exception) -> None: mock_base_state_event_processor.backend_exception_handler = capture_exception event = Event( - token=token, name=f"{StateWithInvalidYield.get_full_name()}.invalid_handler", payload={}, ) @@ -2376,7 +2371,6 @@ async def test_background_task_no_block( await processor.enqueue( token, Event( - token=token, name=f"{BackgroundTaskState.get_full_name()}.background_task", payload={}, ), @@ -2388,7 +2382,6 @@ async def test_background_task_no_block( await processor.enqueue( token, Event( - token=token, name=f"{BackgroundTaskState.get_full_name()}.other", payload={}, ), @@ -2429,7 +2422,6 @@ async def test_background_task_reset( await processor.enqueue( token, Event( - token=token, name=f"{BackgroundTaskState.get_full_name()}.background_task_reset", payload={}, ), @@ -3032,7 +3024,6 @@ def index(): await processor.enqueue( token, Event( - token=token, name=on_load_internal_name, router_data={ RouteVar.PATH: "/", @@ -3115,7 +3106,6 @@ def index(): await processor.enqueue( token, Event( - token=token, name=on_load_internal_name, router_data={ RouteVar.PATH: "/", @@ -4188,7 +4178,6 @@ async def test_upcast_event_handler_arg( emitted_deltas: List to capture emitted deltas. """ event = Event( - token=token, name=format.format_event_handler(handler), payload=payload, ) From 6be4f7253117ff359d786c3fb7bfd21d6d02d3a2 Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Wed, 25 Mar 2026 23:19:15 -0700 Subject: [PATCH 14/81] Clean up frontend Event and StateUpdate remove null/default fields when serializing Event from the frontend and StateUpdate from the backend. --- reflex/.templates/web/utils/state.js | 66 ++++++++++--------- reflex/app.py | 5 +- .../ievent/processor/base_state_processor.py | 1 - reflex/ievent/processor/event_processor.py | 4 +- reflex/state.py | 21 +++--- reflex/utils/format.py | 2 +- tests/units/test_state.py | 1 - 7 files changed, 51 insertions(+), 49 deletions(-) diff --git a/reflex/.templates/web/utils/state.js b/reflex/.templates/web/utils/state.js index c5ef0a68c70..52edd969c3c 100644 --- a/reflex/.templates/web/utils/state.js +++ b/reflex/.templates/web/utils/state.js @@ -40,8 +40,6 @@ const cookies = new Cookies(); // Dictionary holding component references. export const refs = {}; -// Flag ensures that only one event is processing on the backend concurrently. -let event_processing = false; // Array holding pending events to be processed. const event_queue = []; @@ -386,15 +384,18 @@ export const applyEvent = async (event, socket, navigate, params) => { // Since we don't have router directly, we need to get info from our hooks event.router_data = { pathname: window.location.pathname, - query: { - ...Object.fromEntries(new URLSearchParams(window.location.search)), - ...params(), - }, asPath: window.location.pathname + window.location.search + window.location.hash, }; + const query = { + ...Object.fromEntries(new URLSearchParams(window.location.search)), + ...params(), + }; + if (query && Object.keys(query).length > 0) { + event.router_data.query = query; + } } // Send the event to the server. @@ -496,13 +497,10 @@ export const processEvent = async (socket, navigate, params) => { } // Only proceed if we're not already processing an event. - if (event_queue.length === 0 || event_processing) { + if (event_queue.length === 0) { return; } - // Set processing to true to block other events from being processed. - event_processing = true; - // Apply the next event in the queue. const event = event_queue.shift(); @@ -513,9 +511,7 @@ export const processEvent = async (socket, navigate, params) => { } else { eventSent = await applyEvent(event, socket, navigate, params); } - // If no event was sent, set processing to false. if (!eventSent) { - event_processing = false; // recursively call processEvent to drain the queue, since there is // no state update to trigger the useEffect event loop. await processEvent(socket, navigate, params); @@ -639,7 +635,7 @@ export const connect = async ( ); } // Drain any initial events from the queue. - while (event_queue.length > 0 && !event_processing) { + while (event_queue.length > 0) { await processEvent(socket.current, navigate, () => params.current); } }); @@ -659,12 +655,10 @@ export const connect = async ( }, 200 * n_connect_errors); // Incremental backoff }); - // When the socket disconnects reset the event_processing flag socket.current.on("disconnect", (reason, details) => { socket.current.wait_connect = false; const try_reconnect = reason !== "io server disconnect" && reason !== "io client disconnect"; - event_processing = false; window.removeEventListener("unload", disconnectTrigger); window.removeEventListener("beforeunload", disconnectTrigger); window.removeEventListener("pagehide", pagehideHandler); @@ -676,27 +670,25 @@ export const connect = async ( // On each received message, queue the updates and events. socket.current.on("event", async (update) => { - for (const substate in update.delta) { - dispatch[substate](update.delta[substate]); - // handle events waiting for `is_hydrated` - if ( - substate === state_name && - update.delta[substate]?.is_hydrated_rx_state_ - ) { - queueEvents(on_hydrated_queue, socket, false, navigate, params); - on_hydrated_queue.length = 0; + if (update.delta && Object.keys(update.delta).length > 0) { + for (const substate in update.delta) { + dispatch[substate](update.delta[substate]); + // handle events waiting for `is_hydrated` + if ( + substate === state_name && + update.delta[substate]?.is_hydrated_rx_state_ + ) { + queueEvents(on_hydrated_queue, socket, false, navigate, params); + on_hydrated_queue.length = 0; + } } + applyClientStorageDelta(client_storage, update.delta); } - applyClientStorageDelta(client_storage, update.delta); - if (update.final !== null) { - event_processing = !update.final; - } - if (update.events) { + if (update.events && update.events.length > 0) { queueEvents(update.events, socket, false, navigate, params); } }); socket.current.on("reload", async (event) => { - event_processing = false; on_hydrated_queue.push(event); queueEvents(initialEvents(), socket, true, navigate, params); }); @@ -722,7 +714,17 @@ export const ReflexEvent = ( event_actions = {}, handler = null, ) => { - return { name, payload, handler, event_actions }; + const e = { name }; + if (payload && Object.keys(payload).length > 0) { + e.payload = payload; + } + if (event_actions && Object.keys(event_actions).length > 0) { + e.event_actions = event_actions; + } + if (handler !== null) { + e.handler = handler; + } + return e; }; /** @@ -1017,7 +1019,7 @@ export const useEventLoop = ( } (async () => { // Process all outstanding events. - while (event_queue.length > 0 && !event_processing) { + while (event_queue.length > 0) { await ensureSocketConnected(); await processEvent(socket.current, navigate, () => params.current); } diff --git a/reflex/app.py b/reflex/app.py index d082939c2c6..280c79d30eb 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -1655,10 +1655,7 @@ async def modify_state( if delta: # When the frontend vars are modified emit the delta to the frontend. await self.event_namespace.emit_update( - update=StateUpdate( - delta=delta, - final=True, - ), + update=StateUpdate(delta=delta), token=token.ident, ) diff --git a/reflex/ievent/processor/base_state_processor.py b/reflex/ievent/processor/base_state_processor.py index 1ecda7cb610..fc09eb54cf6 100644 --- a/reflex/ievent/processor/base_state_processor.py +++ b/reflex/ievent/processor/base_state_processor.py @@ -182,7 +182,6 @@ async def chain_updates( # Convert valid EventHandler and EventSpec into Event if fixed_events := Event.from_event_type( _check_valid_yield(events, handler_name=handler_name), - router_data=root_state.router_data if root_state else None, ): # Frontend events. if frontend_events := [e for e in fixed_events if e.name.startswith("_")]: diff --git a/reflex/ievent/processor/event_processor.py b/reflex/ievent/processor/event_processor.py index 838ffa1ca74..3cd6107d7ff 100644 --- a/reflex/ievent/processor/event_processor.py +++ b/reflex/ievent/processor/event_processor.py @@ -147,7 +147,7 @@ async def emit_delta( delta: The delta to emit, mapping client tokens to variable updates. """ await event_namespace.emit_update( - update=StateUpdate(delta=delta, final=True), + update=StateUpdate(delta=delta), token=token, ) @@ -163,7 +163,7 @@ async def emit_event(token: str, *events: Event) -> None: events: The events to emit. """ await event_namespace.emit_update( - update=StateUpdate(events=list(events), final=True), + update=StateUpdate(events=list(events)), token=token, ) diff --git a/reflex/state.py b/reflex/state.py index 78902ef99b7..60473d69bcb 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -65,6 +65,7 @@ ) from reflex.utils.exceptions import ImmutableStateError as ImmutableStateError from reflex.utils.exec import is_testing_env +from reflex.utils.serializers import serializer from reflex.utils.types import _isinstance from reflex.vars import Field, VarData, field from reflex.vars.base import ( @@ -2519,16 +2520,20 @@ class StateUpdate: # Events to be added to the event queue. events: list[Event] = dataclasses.field(default_factory=list) - # Whether this is the final state update for the event. - final: bool | None = True - def json(self) -> str: - """Convert the state update to a JSON string. +@serializer(to=dict) +def serialize_state_update(update: StateUpdate) -> dict: + """Serialize a StateUpdate to a dictionary. - Returns: - The state update as a JSON string. - """ - return format.json_dumps(self) + Args: + update: The StateUpdate to serialize. + + Returns: + The serialized StateUpdate. + """ + return { + k.name: v for k in dataclasses.fields(update) if (v := getattr(update, k.name)) + } def code_uses_state_contexts(javascript_code: str) -> bool: diff --git a/reflex/utils/format.py b/reflex/utils/format.py index ff7b710c57e..c326cdff352 100644 --- a/reflex/utils/format.py +++ b/reflex/utils/format.py @@ -594,7 +594,7 @@ def format_query_params(router_data: dict[str, Any]) -> dict[str, str]: Returns: The reformatted query params """ - params = router_data[constants.RouteVar.QUERY] + params = router_data.get(constants.RouteVar.QUERY, {}) return {k.replace("-", "_"): v for k, v in params.items()} diff --git a/tests/units/test_state.py b/tests/units/test_state.py index b5e87277e16..bd182418d47 100644 --- a/tests/units/test_state.py +++ b/tests/units/test_state.py @@ -2238,7 +2238,6 @@ async def test_state_proxy( "computed" + FIELD_MARKER: "", }, }, - final=True, ) assert mcall.kwargs["to"] == grandchild_state.router.session.session_id From 1e86d61d9b008ad648775767d4d0a56d608ef587 Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Wed, 25 Mar 2026 23:30:25 -0700 Subject: [PATCH 15/81] remove get_app dependency from get_state and background tasks --- reflex/istate/manager/__init__.py | 4 +++- reflex/istate/proxy.py | 14 ++++++++++---- 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/reflex/istate/manager/__init__.py b/reflex/istate/manager/__init__.py index 0f8f42be3a7..bf4dada61d9 100644 --- a/reflex/istate/manager/__init__.py +++ b/reflex/istate/manager/__init__.py @@ -169,4 +169,6 @@ def get_state_manager() -> StateManager: Returns: The state manager. """ - return prerequisites.get_and_validate_app().app.state_manager + from reflex.ievent.context import event_context + + return event_context.get().state_manager diff --git a/reflex/istate/proxy.py b/reflex/istate/proxy.py index e3b3c981217..cc1bc1ded31 100644 --- a/reflex/istate/proxy.py +++ b/reflex/istate/proxy.py @@ -20,7 +20,6 @@ from reflex.base import Base from reflex.ievent.context import event_context from reflex.istate.manager.token import BaseStateToken -from reflex.utils import prerequisites from reflex.utils.exceptions import ImmutableStateError from reflex.utils.serializers import can_serialize, serialize, serializer from reflex.vars.base import Var @@ -74,7 +73,6 @@ def __init__( parent_state_proxy: The parent state proxy, for linked mutability and context tracking. """ super().__init__(state_instance) - self._self_app = prerequisites.get_and_validate_app().app self._self_substate_path = tuple(state_instance.get_full_name().split(".")) self._self_substate_token = BaseStateToken( ident=event_context.get().token, @@ -132,11 +130,13 @@ async def __aenter__(self) -> Self: msg = "The state is already mutable. Do not nest `async with self` blocks." raise ImmutableStateError(msg) + ctx = event_context.get() + await self._self_actx_lock.acquire() try: self._self_actx_lock_holder = current_task - self._self_actx = self._self_app.modify_state( - token=self._self_substate_token, background=True + self._self_actx = ctx.state_manager.modify_state_with_links( + token=self._self_substate_token, ) mutable_state = await self._self_actx.__aenter__() self._self_mutable = True @@ -163,6 +163,12 @@ async def __aexit__(self, *exc_info: Any) -> None: try: if self._self_mutable and self._self_actx is not None: await self._self_actx.__aexit__(*exc_info) + delta = await self.__wrapped__._get_resolved_delta() + self.__wrapped__._clean() + # When the frontend vars are modified emit the delta to the frontend. + if delta: + ctx = event_context.get() + await ctx.emit_delta(delta) finally: self._self_actx = None self._self_mutable = False From a1bc4b598a89fafbecd51815bac2d611907faac8 Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Thu, 26 Mar 2026 00:09:36 -0700 Subject: [PATCH 16/81] Remove remaining get_app / mock_app dependency from tests Fix issue with background task delta calculation not starting from the root state --- reflex/event.py | 10 +- reflex/ievent/processor/event_processor.py | 2 +- reflex/istate/proxy.py | 18 +-- reflex/state.py | 6 +- tests/units/conftest.py | 16 +++ tests/units/istate/test_proxy.py | 15 ++- tests/units/test_state.py | 129 ++++++++++----------- 7 files changed, 109 insertions(+), 87 deletions(-) diff --git a/reflex/event.py b/reflex/event.py index 557fbeb29d0..bcb64e0999d 100644 --- a/reflex/event.py +++ b/reflex/event.py @@ -27,7 +27,7 @@ from reflex import constants from reflex.components.field import BaseField from reflex.constants.compiler import CompileVars, Hooks, Imports -from reflex.utils import format +from reflex.utils import console, format from reflex.utils.decorator import once from reflex.utils.exceptions import ( EventFnArgMismatchError, @@ -1932,6 +1932,8 @@ def fix_events( ) -> list[Event]: """Fix a list of events returned by an event handler. + Deprecated: use Event.from_event_type instead. + Args: events: The events to fix. router_data: The optional router data to set in the event. @@ -1942,6 +1944,12 @@ def fix_events( Raises: ValueError: If the event type is not what was expected. """ + console.deprecate( + feature_name="rx.event.fix_events()", + reason="Use Event.from_event_type() instead", + deprecation_version="0.9.0", + removal_version="1.0", + ) # If the event handler returns nothing, return an empty list. if events is None: return [] diff --git a/reflex/ievent/processor/event_processor.py b/reflex/ievent/processor/event_processor.py index 3cd6107d7ff..8008b7df64d 100644 --- a/reflex/ievent/processor/event_processor.py +++ b/reflex/ievent/processor/event_processor.py @@ -183,7 +183,7 @@ async def emit_event(token: str, *events: Event) -> None: ) return self - async def __aenter__(self) -> "EventProcessor": + async def __aenter__(self) -> Self: """Enter the event processor context manager. Returns: diff --git a/reflex/istate/proxy.py b/reflex/istate/proxy.py index cc1bc1ded31..72f0086bf84 100644 --- a/reflex/istate/proxy.py +++ b/reflex/istate/proxy.py @@ -162,18 +162,22 @@ async def __aexit__(self, *exc_info: Any) -> None: return try: if self._self_mutable and self._self_actx is not None: - await self._self_actx.__aexit__(*exc_info) - delta = await self.__wrapped__._get_resolved_delta() - self.__wrapped__._clean() + root_state = self.__wrapped__._get_root_state() + delta = await root_state._get_resolved_delta() + root_state._clean() # When the frontend vars are modified emit the delta to the frontend. if delta: ctx = event_context.get() await ctx.emit_delta(delta) finally: - self._self_actx = None - self._self_mutable = False - self._self_actx_lock_holder = None - self._self_actx_lock.release() + try: + if self._self_mutable and self._self_actx is not None: + await self._self_actx.__aexit__(*exc_info) + finally: + self._self_actx = None + self._self_mutable = False + self._self_actx_lock_holder = None + self._self_actx_lock.release() def __enter__(self): """Enter the regular context manager protocol. diff --git a/reflex/state.py b/reflex/state.py index 60473d69bcb..1066de6027a 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -23,7 +23,6 @@ ClassVar, ParamSpec, TypeVar, - cast, get_type_hints, ) @@ -41,7 +40,6 @@ EventHandler, EventSpec, call_script, - fix_events, ) from reflex.istate import HANDLED_PICKLE_ERRORS, debug_failed_pickles from reflex.istate.data import RouterData @@ -2388,8 +2386,8 @@ def on_load_internal(self) -> list[Event | EventSpec | event.EventCallback] | No return None # Fast path for navigation with no on_load events defined. self.is_hydrated = False return [ - *fix_events( - cast(list[EventSpec | EventHandler], load_events), + *Event.from_event_type( + load_events, router_data=self.router_data, ), State.set_is_hydrated(True), diff --git a/tests/units/conftest.py b/tests/units/conftest.py index 6c5a0c98da9..7a7797b6ff1 100644 --- a/tests/units/conftest.py +++ b/tests/units/conftest.py @@ -437,6 +437,22 @@ def attached_mock_event_context( event_context.reset(reset_token) +@pytest_asyncio.fixture +async def attached_mock_base_state_event_processor( + mock_base_state_event_processor: BaseStateEventProcessor, +) -> AsyncGenerator[BaseStateEventProcessor]: + """Fork the mock event context for the given token, attach it, and set the processor's root context to it. + + Args: + mock_base_state_event_processor: The mock BaseState event processor to use for the processor's enqueue implementation. + + Yields: + The mock BaseState event processor with the attached context as its root context. + """ + async with mock_base_state_event_processor as processor: + yield processor + + @pytest.fixture def forked_registration_context() -> Generator[RegistrationContext, None, None]: """Fork the registration context and attach it. diff --git a/tests/units/istate/test_proxy.py b/tests/units/istate/test_proxy.py index 0f4d1171780..1e74d99bdb5 100644 --- a/tests/units/istate/test_proxy.py +++ b/tests/units/istate/test_proxy.py @@ -4,11 +4,11 @@ import pickle from asyncio import CancelledError from contextlib import asynccontextmanager -from unittest.mock import patch import pytest import reflex as rx +from reflex.ievent.context import EventContext from reflex.istate.proxy import MutableProxy, StateProxy @@ -42,14 +42,15 @@ def test_mutable_proxy_pickle_preserves_object_identity(): assert unpickled["direct"][0] is unpickled["proxied"][0] -@pytest.mark.usefixtures("mock_app", "attached_mock_event_context") @pytest.mark.asyncio -async def test_state_proxy_recovery(): +async def test_state_proxy_recovery( + attached_mock_event_context: EventContext, monkeypatch: pytest.MonkeyPatch +): """Ensure that `async with self` can be re-entered after a lock issue.""" state = ProxyTestState() state_proxy = StateProxy(state) - with patch("reflex.app.App.modify_state") as mock_modify_state: + with monkeypatch.context() as m: @asynccontextmanager async def mock_modify_state_context(*args, **kwargs): # noqa: RUF029 @@ -57,7 +58,11 @@ async def mock_modify_state_context(*args, **kwargs): # noqa: RUF029 raise CancelledError(msg) yield - mock_modify_state.side_effect = mock_modify_state_context + m.setattr( + attached_mock_event_context.state_manager, + "modify_state", + mock_modify_state_context, + ) with pytest.raises(CancelledError, match="Simulated lock issue"): async with state_proxy: diff --git a/tests/units/test_state.py b/tests/units/test_state.py index bd182418d47..a772180495d 100644 --- a/tests/units/test_state.py +++ b/tests/units/test_state.py @@ -10,7 +10,7 @@ import os import sys import threading -from collections.abc import AsyncGenerator, Callable +from collections.abc import AsyncGenerator, Callable, Mapping from textwrap import dedent from typing import Any, ClassVar from unittest.mock import AsyncMock, Mock @@ -25,7 +25,7 @@ from reflex import constants from reflex.app import App from reflex.base import Base -from reflex.constants import CompileVars, RouteVar, SocketEvent +from reflex.constants import CompileVars, RouteVar from reflex.constants.state import FIELD_MARKER from reflex.environment import environment from reflex.event import Event, EventHandler @@ -46,7 +46,6 @@ OnLoadInternalState, RouterData, State, - StateUpdate, ) from reflex.testing import chdir from reflex.utils import format, prerequisites, types @@ -59,7 +58,6 @@ UnretrievableVarValueError, ) from reflex.utils.format import json_dumps -from reflex.utils.token_manager import SocketRecord from reflex.vars.base import Field, Var, computed_var, field from tests.units.mock_redis import mock_redis @@ -2111,16 +2109,18 @@ def from_dict(cls, data: dict) -> ModelDC: @pytest.mark.asyncio async def test_state_proxy( grandchild_state: GrandchildState, - mock_app: rx.App, token: str, + attached_mock_base_state_event_processor: BaseStateEventProcessor, + emitted_deltas: list[tuple[str, Mapping[str, Mapping[str, Any]]]], attached_mock_event_context: EventContext, ): """Test that the state proxy works. Args: grandchild_state: A grandchild state. - mock_app: An app that will be returned by `get_app()` token: A token. + attached_mock_base_state_event_processor: The event processor attached for this test. + emitted_deltas: A list to capture emitted deltas. attached_mock_event_context: The event context attached for this test. """ child_state = grandchild_state.parent_state @@ -2133,21 +2133,13 @@ async def test_state_proxy( "sid": "test_sid", }) grandchild_state.router = router_data - namespace = mock_app.event_namespace - assert namespace is not None - namespace.sid_to_token[router_data.session.session_id] = token - namespace._token_manager.instance_id = "mock" - namespace._token_manager.token_to_socket[token] = SocketRecord( - instance_id="mock", sid=router_data.session.session_id - ) - if isinstance(mock_app.state_manager, (StateManagerMemory, StateManagerDisk)): - mock_app.state_manager.states[parent_state.router.session.client_token] = ( - parent_state - ) - elif isinstance(mock_app.state_manager, StateManagerRedis): + state_manager = attached_mock_event_context.state_manager + if isinstance(state_manager, (StateManagerMemory, StateManagerDisk)): + state_manager.states[parent_state.router.session.client_token] = parent_state + elif isinstance(state_manager, StateManagerRedis): pickle_state = parent_state._serialize() if pickle_state: - await mock_app.state_manager.redis.set( + await state_manager.redis.set( str( BaseStateToken( ident=parent_state.router.session.client_token, @@ -2155,13 +2147,12 @@ async def test_state_proxy( ) ), pickle_state, - ex=mock_app.state_manager.token_expiration, + ex=state_manager.token_expiration, ) sp = StateProxy(grandchild_state) assert sp.__wrapped__ == grandchild_state assert sp._self_substate_path == tuple(grandchild_state.get_full_name().split(".")) - assert sp._self_app is mock_app assert not sp._self_mutable assert sp._self_actx is None @@ -2192,7 +2183,7 @@ async def test_state_proxy( async with sp: assert sp._self_actx is not None assert sp._self_mutable # proxy is mutable inside context - if isinstance(mock_app.state_manager, (StateManagerMemory, StateManagerDisk)): + if isinstance(state_manager, (StateManagerMemory, StateManagerDisk)): # For in-process store, only one instance of the state exists assert sp.__wrapped__ is grandchild_state else: @@ -2204,16 +2195,16 @@ async def test_state_proxy( assert sp.value2 == "42" if environment.REFLEX_OPLOCK_ENABLED.get(): - await mock_app.state_manager.close() + await state_manager.close() # Get the state from the state manager directly and check that the value is updated - gotten_state = await mock_app.state_manager.get_state( + gotten_state = await state_manager.get_state( BaseStateToken( ident=grandchild_state.router.session.client_token, cls=type(grandchild_state), ) ) - if isinstance(mock_app.state_manager, (StateManagerMemory, StateManagerDisk)): + if isinstance(state_manager, (StateManagerMemory, StateManagerDisk)): # For in-process store, only one instance of the state exists assert gotten_state is parent_state else: @@ -2224,22 +2215,21 @@ async def test_state_proxy( assert gotten_grandchild_state.value2 == "42" # ensure state update was emitted - assert mock_app.event_namespace is not None - mock_app.event_namespace.emit.assert_called_once() # pyright: ignore [reportAttributeAccessIssue] - mcall = mock_app.event_namespace.emit.mock_calls[0] # pyright: ignore [reportAttributeAccessIssue] - assert mcall.args[0] == str(SocketEvent.EVENT) - assert mcall.args[1] == StateUpdate( - delta={ - TestState.get_full_name(): {"router" + FIELD_MARKER: router_data}, - grandchild_state.get_full_name(): { - "value2" + FIELD_MARKER: "42", - }, - GrandchildState3.get_full_name(): { - "computed" + FIELD_MARKER: "", + await attached_mock_base_state_event_processor.join(timeout=1) + assert emitted_deltas == [ + ( + token, + { + TestState.get_full_name(): {"router" + FIELD_MARKER: router_data}, + grandchild_state.get_full_name(): { + "value2" + FIELD_MARKER: "42", + }, + GrandchildState3.get_full_name(): { + "computed" + FIELD_MARKER: "", + }, }, - }, - ) - assert mcall.kwargs["to"] == grandchild_state.router.session.session_id + ) + ] class BackgroundTaskState(BaseState): @@ -3143,21 +3133,21 @@ def index(): @pytest.mark.asyncio -async def test_get_state(mock_app: rx.App, token: str): +async def test_get_state(token: str, attached_mock_event_context: EventContext): """Test that a get_state populates the top level state and delta calculation is correct. Args: - mock_app: An app that will be returned by `get_app()` token: A token. + attached_mock_event_context: An event context with a state manager that has a TestState instance corresponding to the token. """ - mock_app._state = TestState + state_manager = attached_mock_event_context.state_manager # Get instance of ChildState2. - test_state = await mock_app.state_manager.get_state( + test_state = await state_manager.get_state( BaseStateToken(ident=token, cls=ChildState2) ) assert isinstance(test_state, TestState) - if isinstance(mock_app.state_manager, (StateManagerMemory, StateManagerDisk)): + if isinstance(state_manager, (StateManagerMemory, StateManagerDisk)): # All substates are available assert tuple(sorted(test_state.substates)) == ( ChildState.get_name(), @@ -3217,11 +3207,11 @@ async def test_get_state(mock_app: rx.App, token: str): } # Get a fresh instance - new_test_state = await mock_app.state_manager.get_state( + new_test_state = await state_manager.get_state( BaseStateToken(ident=token, cls=ChildState2) ) assert isinstance(new_test_state, TestState) - if isinstance(mock_app.state_manager, (StateManagerMemory, StateManagerDisk)): + if isinstance(state_manager, (StateManagerMemory, StateManagerDisk)): # In memory, it's the same instance assert new_test_state is test_state test_state._clean() @@ -3258,7 +3248,9 @@ async def test_get_state(mock_app: rx.App, token: str): @pytest.mark.asyncio -async def test_get_state_from_sibling_not_cached(mock_app: rx.App, token: str): +async def test_get_state_from_sibling_not_cached( + token: str, attached_mock_event_context: EventContext +): """A test simulating update_vars_internal when setting cookies with computed vars. In that case, a sibling state, UpdateVarsInternalState handles the fetching @@ -3271,8 +3263,8 @@ async def test_get_state_from_sibling_not_cached(mock_app: rx.App, token: str): Explicit regression test for https://github.com/reflex-dev/reflex/issues/2851. Args: - mock_app: An app that will be returned by `get_app()` token: A token. + attached_mock_event_context: An event context with a state manager that has a TestState instance corresponding to the token. """ class Parent(BaseState): @@ -3311,16 +3303,14 @@ class GreatGrandchild3(Grandchild3): has a computed var. """ - mock_app._state = Parent + state_manager = attached_mock_event_context.state_manager # Get the top level state via unconnected sibling. - root = await mock_app.state_manager.get_state( - BaseStateToken(ident=token, cls=Child) - ) + root = await state_manager.get_state(BaseStateToken(ident=token, cls=Child)) # Set value in parent_var to assert it does not get refetched later. root.parent_var = 1 - if isinstance(mock_app.state_manager, StateManagerRedis): + if isinstance(state_manager, StateManagerRedis): # When redis is used, only states with computed vars are pre-fetched. assert Child2.get_name() not in root.substates assert Child3.get_name() in root.substates # (due to @rx.var) @@ -4226,12 +4216,14 @@ async def test_get_var_value( @pytest.mark.asyncio -async def test_async_computed_var_get_state(mock_app: rx.App, token: str): +async def test_async_computed_var_get_state( + token: str, attached_mock_event_context: EventContext +): """A test where an async computed var depends on a var in another state. Args: - mock_app: An app that will be returned by `get_app()` token: A token. + attached_mock_event_context: An event context that will be attached to the app's state manager. """ class Parent(BaseState): @@ -4264,16 +4256,14 @@ async def v(self) -> int: child3 = await self.get_state(Child3) return child3.child3_var + p.parent_var - mock_app._state = Parent + state_manager = attached_mock_event_context.state_manager # Get the top level state via unconnected sibling. - root = await mock_app.state_manager.get_state( - BaseStateToken(ident=token, cls=Child) - ) + root = await state_manager.get_state(BaseStateToken(ident=token, cls=Child)) # Set value in parent_var to assert it does not get refetched later. root.parent_var = 1 - if isinstance(mock_app.state_manager, StateManagerRedis): + if isinstance(state_manager, StateManagerRedis): # When redis is used, only states with uncached computed vars are pre-fetched. assert Child2.get_name() not in root.substates assert Child3.get_name() not in root.substates @@ -4385,7 +4375,9 @@ class SecondCvState(CvMixin, rx.State): @pytest.mark.asyncio -async def test_add_dependency_get_state_regression(mock_app: rx.App, token: str): +async def test_add_dependency_get_state_regression( + token: str, attached_mock_event_context: EventContext, mock_app: rx.App +): """Ensure that a state class can be fetched separately when it's is explicit dep.""" class DataState(rx.State): @@ -4410,8 +4402,7 @@ class OtherState(rx.State): async def fetch_data_state(self) -> None: print(await self.get_state(DataState)) - mock_app._state = rx.State - state = await mock_app.state_manager.get_state( + state = await attached_mock_event_context.state_manager.get_state( BaseStateToken(ident=token, cls=OtherState) ) other_state = await state.get_state(OtherState) @@ -4426,12 +4417,12 @@ class MutableProxyState(BaseState): @pytest.mark.asyncio async def test_rebind_mutable_proxy( - mock_app: rx.App, token: str, attached_mock_event_context: EventContext + token: str, attached_mock_event_context: EventContext ) -> None: """Test that previously bound MutableProxy instances can be rebound correctly.""" - mock_app._state = MutableProxyState + state_manager = attached_mock_event_context.state_manager - async with mock_app.state_manager.modify_state( + async with state_manager.modify_state( BaseStateToken(ident=token, cls=MutableProxyState) ) as state: state.router = RouterData.from_router_data({ @@ -4456,7 +4447,7 @@ async def test_rebind_mutable_proxy( assert not isinstance(state_proxy.__wrapped__.data["a"], ImmutableMutableProxy) # Flush any oplock. - await mock_app.state_manager.close() + await state_manager.close() new_state_proxy = StateProxy(state) assert state_proxy is not new_state_proxy @@ -4467,7 +4458,7 @@ async def test_rebind_mutable_proxy( async with state_proxy: state_proxy.data["a"].append(3) - async with mock_app.state_manager.modify_state( + async with state_manager.modify_state( BaseStateToken(ident=token, cls=MutableProxyState) ) as state: assert isinstance(state, MutableProxyState) From 92562127557619a804188b9b2c1acd0ced56cd91 Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Thu, 26 Mar 2026 16:06:39 -0700 Subject: [PATCH 17/81] EventContext inherits from BaseContext remove extra `event_context` ContextVar being passed around --- reflex/_internal/context/base.py | 21 +++++++++++++++++++ reflex/ievent/context.py | 17 +++++++-------- .../ievent/processor/base_state_processor.py | 6 +++--- reflex/ievent/processor/event_processor.py | 14 ++++++------- reflex/istate/manager/__init__.py | 5 ++--- reflex/istate/proxy.py | 8 +++---- reflex/state.py | 4 ++-- tests/units/conftest.py | 8 +++---- 8 files changed, 50 insertions(+), 33 deletions(-) diff --git a/reflex/_internal/context/base.py b/reflex/_internal/context/base.py index bce7ebba0f3..790d0b8db26 100644 --- a/reflex/_internal/context/base.py +++ b/reflex/_internal/context/base.py @@ -26,6 +26,27 @@ def get(cls) -> Self: """ return cls._context_var.get() + @classmethod + def set(cls, context: Self) -> Token[Self]: + """Set the context in the context variable. + + Args: + context: The context instance to set. + + Returns: + The token for resetting the context variable. + """ + return cls._context_var.set(context) + + @classmethod + def reset(cls, token: Token[Self]) -> None: + """Reset the context variable to a previous state. + + Args: + token: The token to reset the context variable to. + """ + cls._context_var.reset(token) + def __enter__(self) -> Self: """Enter the context. diff --git a/reflex/ievent/context.py b/reflex/ievent/context.py index 343335993c2..73801ba1af1 100644 --- a/reflex/ievent/context.py +++ b/reflex/ievent/context.py @@ -4,14 +4,16 @@ import functools import uuid from collections.abc import Callable, Mapping -from contextvars import ContextVar from typing import TYPE_CHECKING, Any, Protocol -from reflex.istate.manager import StateManager -from reflex.utils.format import to_snake_case +from reflex_core.utils.format import to_snake_case + +from reflex._internal.context.base import BaseContext if TYPE_CHECKING: - from reflex.event import Event + from reflex_core.event import Event + + from reflex.istate.manager import StateManager @functools.lru_cache @@ -69,8 +71,8 @@ async def __call__( ... -@dataclasses.dataclass(frozen=True, kw_only=True) -class EventContext: +@dataclasses.dataclass(frozen=True, kw_only=True, slots=True, eq=False) +class EventContext(BaseContext): """The context for an event.""" # Identifies the client session. @@ -142,6 +144,3 @@ async def enqueue(self, *event: Event) -> None: event: The event to enqueue. """ await self.enqueue_impl(self.token, *event) - - -event_context: ContextVar[EventContext] = ContextVar("event_context") diff --git a/reflex/ievent/processor/base_state_processor.py b/reflex/ievent/processor/base_state_processor.py index fc09eb54cf6..ebccf93a6d6 100644 --- a/reflex/ievent/processor/base_state_processor.py +++ b/reflex/ievent/processor/base_state_processor.py @@ -9,7 +9,7 @@ from importlib.util import find_spec from typing import TYPE_CHECKING, Any -from reflex.ievent.context import event_context +from reflex.ievent.context import EventContext from reflex.ievent.processor import EventProcessor from reflex.ievent.processor.event_processor import ( EventQueueEntry, @@ -177,7 +177,7 @@ async def chain_updates( """ from reflex.event import Event - ctx = event_context.get() + ctx = EventContext.get() # Convert valid EventHandler and EventSpec into Event if fixed_events := Event.from_event_type( @@ -289,7 +289,7 @@ async def _process_event_queue_entry( """ # Set up the event context for this task. ctx = entry.ctx - event_context.set(ctx) + EventContext.set(ctx) event = entry.event router_data = event.router_data or {} # Get the state for the session exclusively. diff --git a/reflex/ievent/processor/event_processor.py b/reflex/ievent/processor/event_processor.py index 8008b7df64d..bd5523e3100 100644 --- a/reflex/ievent/processor/event_processor.py +++ b/reflex/ievent/processor/event_processor.py @@ -14,7 +14,7 @@ from reflex._internal.registry import RegisteredEventHandler, RegistrationContext from reflex.app_mixins.middleware import MiddlewareMixin -from reflex.ievent.context import EventContext, event_context +from reflex.ievent.context import EventContext from reflex.istate.manager import StateManager from reflex.utils import console @@ -207,7 +207,7 @@ async def start(self) -> None: if self._attached_root_context_token is not None: msg = "EventProcessor context cannot be nested." raise RuntimeError(msg) - self._attached_root_context_token = event_context.set(self._root_context) + self._attached_root_context_token = EventContext.set(self._root_context) self._queue = asyncio.Queue() self._ensure_queue_task() @@ -253,7 +253,7 @@ async def stop(self, graceful_shutdown_timeout: float | None = None) -> None: from reflex.utils import telemetry if self._attached_root_context_token is not None: - event_context.reset(self._attached_root_context_token) + EventContext.reset(self._attached_root_context_token) self._attached_root_context_token = None # Optional grace period for tasks to finish before cancellation. if graceful_shutdown_timeout is None: @@ -319,7 +319,7 @@ def _ensure_queue_task(self) -> asyncio.Queue[EventQueueEntry]: raise RuntimeError(msg) if self._queue_task is None: task_context = copy_context() - task_context.run(event_context.set, self._root_context) + task_context.run(EventContext.set, self._root_context) self._queue_task = task_context.run( asyncio.create_task, self._process_queue(), @@ -339,7 +339,7 @@ async def enqueue( """ if ev_ctx is None: try: - ev_ctx = event_context.get().fork(token=token) + ev_ctx = EventContext.get().fork(token=token) except LookupError as le: if self._root_context is not None: ev_ctx = self._root_context.fork(token=token) @@ -370,7 +370,7 @@ async def _process_event_queue_entry( """ # Set up the event context for this task. ctx = entry.ctx - event_context.set(ctx) + EventContext.set(ctx) event = entry.event result = registered_handler.handler.fn(**event.payload) if inspect.isawaitable(result): @@ -437,7 +437,7 @@ def _finish_task(self, task: asyncio.Task): """ from reflex.utils import telemetry - task_ctx = task.get_context().run(event_context.get) + task_ctx = task.get_context().run(EventContext.get) self._tasks.pop(task_ctx.txid, None) if task.done(): try: diff --git a/reflex/istate/manager/__init__.py b/reflex/istate/manager/__init__.py index e2007b4aeff..58aa688ffee 100644 --- a/reflex/istate/manager/__init__.py +++ b/reflex/istate/manager/__init__.py @@ -13,7 +13,6 @@ from typing_extensions import ReadOnly, Unpack from reflex.istate.manager.token import TOKEN_TYPE, StateToken -from reflex.state import BaseState from reflex.utils import console, prerequisites @@ -177,6 +176,6 @@ def get_state_manager() -> StateManager: Returns: The state manager. """ - from reflex.ievent.context import event_context + from reflex.ievent.context import EventContext - return event_context.get().state_manager + return EventContext.get().state_manager diff --git a/reflex/istate/proxy.py b/reflex/istate/proxy.py index 4d7b4eee2d3..430669f8b38 100644 --- a/reflex/istate/proxy.py +++ b/reflex/istate/proxy.py @@ -21,7 +21,7 @@ from reflex_core.vars.base import Var from typing_extensions import Self -from reflex.ievent.context import event_context +from reflex.ievent.context import EventContext from reflex.istate.manager.token import BaseStateToken if TYPE_CHECKING: @@ -78,7 +78,7 @@ def __init__( self._self_event = event self._self_substate_path = tuple(state_instance.get_full_name().split(".")) self._self_substate_token = BaseStateToken( - ident=event_context.get().token, + ident=EventContext.get().token, cls=state_instance.__class__, ) self._self_actx = None @@ -133,7 +133,7 @@ async def __aenter__(self) -> Self: msg = "The state is already mutable. Do not nest `async with self` blocks." raise ImmutableStateError(msg) - ctx = event_context.get() + ctx = EventContext.get() await self._self_actx_lock.acquire() try: @@ -170,7 +170,7 @@ async def __aexit__(self, *exc_info: Any) -> None: root_state._clean() # When the frontend vars are modified emit the delta to the frontend. if delta: - ctx = event_context.get() + ctx = EventContext.get() await ctx.emit_delta(delta) finally: try: diff --git a/reflex/state.py b/reflex/state.py index e844d64296d..9f4035e6caf 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -2213,7 +2213,7 @@ async def _get_state_from_redis(self, state_cls: type[T_STATE]) -> T_STATE: @event async def hydrate(self) -> None: """Send the full state to the frontend to synchronize it with the backend.""" - from reflex.ievent.context import event_context + from reflex.ievent.context import EventContext # Clear client storage, to respect clearing cookies self._reset_client_storage() @@ -2222,7 +2222,7 @@ async def hydrate(self) -> None: self.is_hydrated = False # Get the initial state if needed. - ctx = event_context.get() + ctx = EventContext.get() if ctx.emit_delta_impl is not None: await ctx.emit_delta(delta=await _resolve_delta(self.dict())) diff --git a/tests/units/conftest.py b/tests/units/conftest.py index 1cf4e99a47f..55b753d5b76 100644 --- a/tests/units/conftest.py +++ b/tests/units/conftest.py @@ -16,7 +16,7 @@ from reflex._internal.registry import RegistrationContext from reflex.app import App from reflex.experimental.memo import EXPERIMENTAL_MEMOS -from reflex.ievent.context import EventContext, event_context +from reflex.ievent.context import EventContext from reflex.ievent.processor import BaseStateEventProcessor, EventProcessor from reflex.istate.manager import StateManager from reflex.istate.manager.disk import StateManagerDisk @@ -433,10 +433,8 @@ def attached_mock_event_context( Yields: The forked EventContext. """ - ctx = mock_root_event_context.fork(token=token) - reset_token = event_context.set(ctx) - yield ctx - event_context.reset(reset_token) + with mock_root_event_context.fork(token=token) as ctx: + yield ctx @pytest_asyncio.fixture From 0c04bc34f9fc36af174bfea1ef5f9011fb871b67 Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Mon, 30 Mar 2026 08:58:27 -0700 Subject: [PATCH 18/81] additional fixups --- reflex/app.py | 26 ++++---------------------- reflex/istate/manager/__init__.py | 2 -- tests/units/test_app.py | 2 +- tests/units/test_state.py | 2 +- 4 files changed, 6 insertions(+), 26 deletions(-) diff --git a/reflex/app.py b/reflex/app.py index c96e15acfc5..d083faced67 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -15,19 +15,14 @@ import time import traceback import urllib.parse -from collections.abc import ( - AsyncIterator, - Callable, - Coroutine, - Mapping, - Sequence, -) +from collections.abc import AsyncIterator, Callable, Coroutine, Mapping, Sequence from datetime import datetime from itertools import chain from pathlib import Path from timeit import default_timer as timer from types import SimpleNamespace -from typing import TYPE_CHECKING, Any, ParamSpec +from typing import TYPE_CHECKING, Any, ParamSpec, overload +from warnings import deprecated from reflex_components_core.base.app_wrap import AppWrap from reflex_components_core.base.error_boundary import ErrorBoundary @@ -57,23 +52,11 @@ EventSpec, EventType, IndividualEventType, - get_hydrate_event, noop, ) from reflex_core.utils import console from reflex_core.utils.imports import ImportVar from reflex_core.utils.types import ASGIApp, Message, Receive, Scope, Send -from typing import ( - TYPE_CHECKING, - Any, - BinaryIO, - ParamSpec, - get_args, - get_type_hints, - overload, -) -from warnings import deprecated - from rich.progress import MofNCompleteColumn, Progress, TimeElapsedColumn from socketio import ASGIApp as EngineIOApp from socketio import AsyncNamespace, AsyncServer @@ -96,9 +79,8 @@ readable_name_from_component, ) from reflex.experimental.memo import EXPERIMENTAL_MEMOS -from reflex.istate.manager import StateModificationContext -from reflex.istate.proxy import StateProxy from reflex.ievent.processor import BaseStateEventProcessor, EventProcessor +from reflex.istate.manager import StateModificationContext from reflex.istate.manager.token import BaseStateToken from reflex.page import DECORATED_PAGES from reflex.route import ( diff --git a/reflex/istate/manager/__init__.py b/reflex/istate/manager/__init__.py index 58aa688ffee..b2191e5c57f 100644 --- a/reflex/istate/manager/__init__.py +++ b/reflex/istate/manager/__init__.py @@ -131,8 +131,6 @@ async def modify_state_with_links( Yields: The state for the token with linked states patched in. """ - from reflex.state import BaseState - async with self.modify_state(token, **context) as root_state: if ( isinstance(root_state, BaseState) diff --git a/tests/units/test_app.py b/tests/units/test_app.py index 7f16a3fd5e0..afdc3bad1f2 100644 --- a/tests/units/test_app.py +++ b/tests/units/test_app.py @@ -12,7 +12,7 @@ from importlib.util import find_spec from pathlib import Path from types import SimpleNamespace -from typing import TYPE_CHECKING, Any, ClassVar +from typing import TYPE_CHECKING, Any from unittest.mock import AsyncMock import pytest diff --git a/tests/units/test_state.py b/tests/units/test_state.py index 2fcbc8c0804..8d54b2e7be8 100644 --- a/tests/units/test_state.py +++ b/tests/units/test_state.py @@ -22,7 +22,7 @@ from pydantic import BaseModel as Base from pytest_mock import MockerFixture from reflex_core import constants -from reflex_core.constants import CompileVars, RouteVar, SocketEvent +from reflex_core.constants import CompileVars, RouteVar from reflex_core.constants.state import FIELD_MARKER from reflex_core.event import Event, EventHandler from reflex_core.utils import format, types From 5940c8df555de0284aa462951b9727dd99e77520 Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Mon, 30 Mar 2026 09:00:37 -0700 Subject: [PATCH 19/81] apply changes to migrated files separately --- .../src/reflex_core/components/component.py | 7 ++--- .../src/reflex_core/constants/state.py | 3 --- .../src/reflex_core/plugins/_screenshot.py | 10 ++++--- .../src/reflex_core/utils/format.py | 26 +++++++------------ .../src/reflex_core/utils/serializers.py | 11 ++++++++ 5 files changed, 30 insertions(+), 27 deletions(-) diff --git a/packages/reflex-core/src/reflex_core/components/component.py b/packages/reflex-core/src/reflex_core/components/component.py index 82d60203b22..0a0b8794264 100644 --- a/packages/reflex-core/src/reflex_core/components/component.py +++ b/packages/reflex-core/src/reflex_core/components/component.py @@ -36,7 +36,7 @@ PageNames, ) from reflex_core.constants.compiler import SpecialAttributes -from reflex_core.constants.state import CAMEL_CASE_MEMO_MARKER, FRONTEND_EVENT_STATE +from reflex_core.constants.state import CAMEL_CASE_MEMO_MARKER from reflex_core.event import ( EventCallback, EventChain, @@ -1523,10 +1523,7 @@ def _event_trigger_values_use_state(self) -> bool: if isinstance(event, EventCallback): continue if isinstance(event, EventSpec): - if ( - event.handler.state_full_name - and event.handler.state_full_name != FRONTEND_EVENT_STATE - ): + if event.handler.state is not None: return True else: if event._var_state: diff --git a/packages/reflex-core/src/reflex_core/constants/state.py b/packages/reflex-core/src/reflex_core/constants/state.py index 3f6ebec2f17..8742f76e185 100644 --- a/packages/reflex-core/src/reflex_core/constants/state.py +++ b/packages/reflex-core/src/reflex_core/constants/state.py @@ -11,9 +11,6 @@ class StateManagerMode(str, Enum): REDIS = "redis" -# Used for things like console_log, etc. -FRONTEND_EVENT_STATE = "__reflex_internal_frontend_event_state" - FIELD_MARKER = "_rx_state_" MEMO_MARKER = "_rx_memo_" CAMEL_CASE_MEMO_MARKER = "RxMemo" diff --git a/packages/reflex-core/src/reflex_core/plugins/_screenshot.py b/packages/reflex-core/src/reflex_core/plugins/_screenshot.py index 7b3b1d5e8b7..3834aa72c98 100644 --- a/packages/reflex-core/src/reflex_core/plugins/_screenshot.py +++ b/packages/reflex-core/src/reflex_core/plugins/_screenshot.py @@ -97,7 +97,8 @@ async def clone_state(request: "Request") -> "Response": from starlette.responses import JSONResponse - from reflex.state import _substate_key + from reflex.istate.manager.token import BaseStateToken + from reflex.state import State if not app.event_namespace: return JSONResponse({}) @@ -109,7 +110,9 @@ async def clone_state(request: "Request") -> "Response": {"error": "Token to clone must be a string."}, status_code=400 ) - old_state = await app.state_manager.get_state(token_to_clone) + old_state = await app.state_manager.get_state( + BaseStateToken(ident=token_to_clone, cls=State), + ) new_state = _deep_copy(old_state) @@ -132,7 +135,8 @@ async def clone_state(request: "Request") -> "Response": found_new = True await app.state_manager.set_state( - _substate_key(new_token, new_state), new_state + BaseStateToken(ident=new_token, cls=type(new_state)), + new_state, ) return JSONResponse(new_token) diff --git a/packages/reflex-core/src/reflex_core/utils/format.py b/packages/reflex-core/src/reflex_core/utils/format.py index 4e673e79af9..6b5876bb38c 100644 --- a/packages/reflex-core/src/reflex_core/utils/format.py +++ b/packages/reflex-core/src/reflex_core/utils/format.py @@ -9,7 +9,6 @@ from typing import TYPE_CHECKING, Any from reflex_core import constants -from reflex_core.constants.state import FRONTEND_EVENT_STATE from reflex_core.utils import exceptions if TYPE_CHECKING: @@ -454,25 +453,20 @@ def get_event_handler_parts(handler: EventHandler) -> tuple[str, str]: Returns: The state and function name. """ - # Get the class that defines the event handler. - parts = handler.fn.__qualname__.split(".") + # Get the name of the event function. + name = handler.fn.__qualname__ # Get the state full name - state_full_name = handler.state_full_name + state_full_name = handler.state.get_full_name() if handler.state else "" - # If there's no enclosing class, just return the function name. - if not state_full_name: - return ("", parts[-1]) + # If there's no enclosing state, just return the full name. + if handler.state is None: + return ("", name) - # Get the function name - name = parts[-1] + # Get the event name inside the state. + func_name = name.rpartition(".")[2] - from reflex.state import State - - if state_full_name == FRONTEND_EVENT_STATE and name not in State.__dict__: - return ("", to_snake_case(handler.fn.__qualname__)) - - return (state_full_name, name) + return (state_full_name, func_name) def format_event_handler(handler: EventHandler) -> str: @@ -606,7 +600,7 @@ def format_query_params(router_data: dict[str, Any]) -> dict[str, str]: Returns: The reformatted query params """ - params = router_data[constants.RouteVar.QUERY] + params = router_data.get(constants.RouteVar.QUERY, {}) return {k.replace("-", "_"): v for k, v in params.items()} diff --git a/packages/reflex-core/src/reflex_core/utils/serializers.py b/packages/reflex-core/src/reflex_core/utils/serializers.py index 49b3233c838..2ae3e5fbb68 100644 --- a/packages/reflex-core/src/reflex_core/utils/serializers.py +++ b/packages/reflex-core/src/reflex_core/utils/serializers.py @@ -8,6 +8,7 @@ import functools import inspect import json +import uuid import warnings from collections.abc import Callable, Mapping, Sequence from datetime import date, datetime, time, timedelta @@ -34,6 +35,16 @@ SERIALIZED_FUNCTION = TypeVar("SERIALIZED_FUNCTION", bound=Serializer) +deserializers = { + int: int, + float: float, + datetime: datetime.fromisoformat, + date: date.fromisoformat, + time: time.fromisoformat, + uuid.UUID: uuid.UUID, +} + + @overload def serializer( fn: None = None, From b13ad3791a1e7864fcb57f314516935079ba67d8 Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Mon, 30 Mar 2026 09:39:06 -0700 Subject: [PATCH 20/81] add missing import --- reflex/istate/manager/__init__.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/reflex/istate/manager/__init__.py b/reflex/istate/manager/__init__.py index b2191e5c57f..58aa688ffee 100644 --- a/reflex/istate/manager/__init__.py +++ b/reflex/istate/manager/__init__.py @@ -131,6 +131,8 @@ async def modify_state_with_links( Yields: The state for the token with linked states patched in. """ + from reflex.state import BaseState + async with self.modify_state(token, **context) as root_state: if ( isinstance(root_state, BaseState) From ada96c84ce38ed9d83417cf3d04f812b7ea43524 Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Mon, 30 Mar 2026 09:39:51 -0700 Subject: [PATCH 21/81] remove pyleak integration from base_state_processor --- reflex/ievent/processor/base_state_processor.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/reflex/ievent/processor/base_state_processor.py b/reflex/ievent/processor/base_state_processor.py index ebccf93a6d6..d641522933b 100644 --- a/reflex/ievent/processor/base_state_processor.py +++ b/reflex/ievent/processor/base_state_processor.py @@ -19,7 +19,6 @@ from reflex.istate.manager.token import BaseStateToken from reflex.istate.proxy import StateProxy from reflex.utils import console, types -from reflex.utils.monitoring import is_pyleak_enabled, monitor_loopblocks if TYPE_CHECKING: from reflex.event import EventHandler, EventSpec @@ -219,11 +218,7 @@ async def process_event( handler_name = handler.fn.__qualname__ # Get the function to process the event. - if is_pyleak_enabled(): - console.debug(f"Monitoring leaks for handler: {handler_name}") - fn = functools.partial(monitor_loopblocks(handler.fn), state) - else: - fn = functools.partial(handler.fn, state) + fn = functools.partial(handler.fn, state) try: type_hints = types.get_type_hints(handler.fn) From a1eb1803b55d051cca5fa4a0c3d4b966d2821d17 Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Mon, 30 Mar 2026 15:08:33 -0700 Subject: [PATCH 22/81] EventProcessor.enqueue_stream_delta and task Future EventProcessor.enqueue now returns a Future that tracks the completion of the event (and can be used to cancel the event) EventProcessor.enqueue_stream_delta overrides the default emit_delta implementation and instead yields deltas directly to the caller as the event is processing. --- reflex/ievent/context.py | 6 +- reflex/ievent/processor/event_processor.py | 112 ++++++++++++++++++++- 2 files changed, 111 insertions(+), 7 deletions(-) diff --git a/reflex/ievent/context.py b/reflex/ievent/context.py index 73801ba1af1..24d43363c44 100644 --- a/reflex/ievent/context.py +++ b/reflex/ievent/context.py @@ -31,7 +31,7 @@ def get_name(cls: type | Callable) -> str: class EnqueueProtocol(Protocol): """Protocol for the enqueue function in the event context.""" - async def __call__(self, token: str, *events: Event) -> None: + async def __call__(self, token: str, *events: Event) -> Any: """Enqueue an event handler to be executed. Args: @@ -44,7 +44,7 @@ async def __call__(self, token: str, *events: Event) -> None: class EmitEventProtocol(Protocol): """Protocol for the emit_event function in the event context.""" - async def __call__(self, token: str, *events: Event) -> None: + async def __call__(self, token: str, *events: Event) -> Any: """Emit an event to be processed immediately. Args: @@ -61,7 +61,7 @@ async def __call__( self, token: str, delta: Mapping[str, Mapping[str, Any]], - ) -> None: + ) -> Any: """Emit a delta to the frontend. Args: diff --git a/reflex/ievent/processor/event_processor.py b/reflex/ievent/processor/event_processor.py index bd5523e3100..ab3f9f1afed 100644 --- a/reflex/ievent/processor/event_processor.py +++ b/reflex/ievent/processor/event_processor.py @@ -6,7 +6,7 @@ import inspect import time import traceback -from collections.abc import Callable, Mapping +from collections.abc import AsyncGenerator, Callable, Mapping from contextvars import Token, copy_context from typing import TYPE_CHECKING, Any, Self @@ -105,6 +105,9 @@ class EventProcessor: _tasks: dict[str, asyncio.Task] = dataclasses.field( default_factory=dict, init=False ) + _futures: dict[str, asyncio.Future[Any]] = dataclasses.field( + default_factory=dict, init=False + ) def configure( self, @@ -290,6 +293,11 @@ async def stop(self, graceful_shutdown_timeout: float | None = None) -> None: ) ) self._queue_task = None + # Cancel any remaining unresolved futures. + for future in self._futures.values(): + if not future.done(): + future.cancel() + self._futures.clear() async def join(self, timeout: float | None = None) -> None: """Wait for the event processor to finish processing all events in the queue. @@ -329,13 +337,16 @@ def _ensure_queue_task(self) -> asyncio.Queue[EventQueueEntry]: async def enqueue( self, token: str, *events: Event, ev_ctx: EventContext | None = None - ) -> None: + ) -> asyncio.Future[Any]: """Enqueue an event to be processed. Args: token: The client token associated with the event. events: Remaining positional args are events to be enqueued. ev_ctx: The event context to use for these events. + + Returns: + A Future that resolves to the result of the associated task. """ if ev_ctx is None: try: @@ -347,8 +358,88 @@ async def enqueue( msg = "Event processor is not running, call .start(...) first." raise RuntimeError(msg) from le queue = self._ensure_queue_task() + future: asyncio.Future[Any] = asyncio.get_running_loop().create_future() + txid = ev_ctx.txid + self._futures[txid] = future + future.add_done_callback(lambda f: self._on_future_done(txid, f)) for event in events: await queue.put(EventQueueEntry(event=event, ctx=ev_ctx)) + return future + + async def enqueue_stream_delta( + self, + token: str, + event: Event, + ) -> AsyncGenerator[Mapping[str, Any]]: + """Enqueue an event to be processed and yield deltas emitted by the event handler. + + Events queued by this method will not emit deltas to their target token in the typical way, instead + they will be yielded from this generator until the event handler finishes processing. + Deltas emitted for other tokens will be handled normally. + + Any frontend events or chained events are handled normally and deltas from chained events + will not be yielded by this method. + + Args: + token: The client token associated with the event. + event: The event to be enqueued. + + Yields: + Deltas emitted by the event handler for the specified token. + """ + if self._root_context is None: + msg = "Event processor is not configured, call .configure(...) first." + raise RuntimeError(msg) + + deltas = asyncio.Queue() + + async def _emit_delta_impl( + delta_token: str, delta: Mapping[str, Mapping[str, Any]] + ) -> None: + if ( + delta_token != token + and self._root_context is not None + and self._root_context.emit_delta_impl is not None + ): + # Emit deltas for other tokens normally. + await self._root_context.emit_delta_impl(token, delta) + await deltas.put(delta) + + task_future = await self.enqueue( + token, + event, + ev_ctx=dataclasses.replace( + self._root_context, + token=token, + emit_delta_impl=_emit_delta_impl, + ), + ) + while not task_future.done() or not deltas.empty(): + with contextlib.suppress(asyncio.TimeoutError): + async for result in asyncio.as_completed( + [deltas.get(), *([task_future] if not task_future.done() else [])], + timeout=1, + ): + if result is task_future: + continue + yield await result + + def _on_future_done(self, txid: str, future: asyncio.Future) -> None: + """Callback invoked when an enqueued future completes. + + If the future was cancelled externally, cancel the running task + if one exists. If the task has not started yet, ``_process_queue`` + will check the future and skip it when the entry is dequeued. + + Args: + txid: The transaction id associated with the future. + future: The future that completed. + """ + if not future.cancelled(): + return + task = self._tasks.get(txid) + if task is not None: + task.cancel() async def _process_event_queue_entry( self, *, entry: EventQueueEntry, registered_handler: RegisteredEventHandler @@ -384,6 +475,12 @@ async def _process_queue(self): with contextlib.suppress(asyncio.QueueShutDown): while True: entry = await queue.get() + if ( + future := self._futures.get(entry.ctx.txid) + ) is not None and future.cancelled(): + self._futures.pop(entry.ctx.txid, None) + queue.task_done() + continue try: try: registered_handler = RegistrationContext.get().event_handlers[ @@ -439,12 +536,16 @@ def _finish_task(self, task: asyncio.Task): task_ctx = task.get_context().run(EventContext.get) self._tasks.pop(task_ctx.txid, None) + future = self._futures.pop(task_ctx.txid, None) if task.done(): try: - task.result() + result = task.result() except asyncio.CancelledError: - pass + if future is not None and not future.done(): + future.cancel() except Exception as ex: + if future is not None and not future.done(): + future.set_exception(ex) telemetry.send_error(ex, context="backend") if ( not task.get_name().startswith("reflex_backend_exception_handler|") @@ -463,6 +564,9 @@ def _finish_task(self, task: asyncio.Task): f"Error in {task.get_name()} [txid={task_ctx.txid}]:\n{traceback.format_exc()}" ) ) + else: + if future is not None and not future.done(): + future.set_result(result) __all__ = [ From a073b0de0acbb903f160f1e3c74ee89db7bf676b Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Mon, 30 Mar 2026 15:11:23 -0700 Subject: [PATCH 23/81] Adapt upload endpoint to new EventProcessor --- .../reflex_components_core/core/_upload.py | 97 ++++--------------- 1 file changed, 19 insertions(+), 78 deletions(-) diff --git a/packages/reflex-components-core/src/reflex_components_core/core/_upload.py b/packages/reflex-components-core/src/reflex_components_core/core/_upload.py index b096a03cd8a..20b1ef158b1 100644 --- a/packages/reflex-components-core/src/reflex_components_core/core/_upload.py +++ b/packages/reflex-components-core/src/reflex_components_core/core/_upload.py @@ -11,8 +11,8 @@ from typing import TYPE_CHECKING, Any, BinaryIO, cast from python_multipart.multipart import MultipartParser, parse_options_header -from reflex_core import constants from reflex_core.utils import exceptions +from reflex_core.utils.format import json_dumps from starlette.datastructures import Headers from starlette.datastructures import UploadFile as StarletteUploadFile from starlette.exceptions import HTTPException @@ -21,12 +21,13 @@ from starlette.responses import JSONResponse, Response, StreamingResponse from typing_extensions import Self +from reflex._internal.registry import RegistrationContext +from reflex.state import StateUpdate + if TYPE_CHECKING: - from reflex_core.event import EventHandler from reflex_core.utils.types import Receive, Scope, Send from reflex.app import App - from reflex.state import BaseState @dataclasses.dataclass(frozen=True) @@ -102,7 +103,7 @@ def __init__(self, *, maxsize: int = 8): self._condition = asyncio.Condition() self._closed = False self._error: Exception | None = None - self._consumer_task: asyncio.Task[Any] | None = None + self._consumer_task: asyncio.Future[Any] | None = None def __aiter__(self) -> Self: """Return the iterator itself. @@ -135,7 +136,7 @@ async def __anext__(self) -> UploadChunk: raise self._error raise StopAsyncIteration - def set_consumer_task(self, task: asyncio.Task[Any]) -> None: + def set_consumer_task(self, task: asyncio.Future[Any]) -> None: """Track the task consuming this iterator. Args: @@ -206,7 +207,7 @@ def _raise_if_consumer_finished(self) -> None: raise RuntimeError(msg) from task_exc raise RuntimeError(msg) - def _wake_waiters(self, task: asyncio.Task[Any]) -> None: + def _wake_waiters(self, task: asyncio.Future[Any]) -> None: """Wake any producers or consumers blocked on the iterator condition. Args: @@ -446,51 +447,6 @@ def _require_upload_headers(request: Request) -> tuple[str, str]: return token, handler -async def _get_upload_runtime_handler( - app: App, - token: str, - handler_name: str, -) -> tuple[BaseState, EventHandler]: - """Resolve the runtime state and event handler for an upload request. - - Args: - app: The Reflex app. - token: The client token. - handler_name: The fully qualified event handler name. - - Returns: - The root state instance and resolved event handler. - """ - from reflex.state import _substate_key - - substate_token = _substate_key(token, handler_name.rpartition(".")[0]) - state = await app.state_manager.get_state(substate_token) - _current_state, event_handler = state._get_event_handler(handler_name) - return state, event_handler - - -def _seed_upload_router_data(state: BaseState, token: str) -> None: - """Ensure upload-launched handlers have the client token in router state. - - Background upload handlers use ``StateProxy`` which derives its mutable-state - token from ``self.router.session.client_token``. Upload requests do not flow - through the normal websocket event pipeline, so we seed the token here. - - Args: - state: The root state instance. - token: The client token from the upload request. - """ - from reflex.state import RouterData - - router_data = dict(state.router_data) - if router_data.get(constants.RouteVar.CLIENT_TOKEN) == token: - return - - router_data[constants.RouteVar.CLIENT_TOKEN] = token - state.router_data = router_data - state.router = RouterData.from_router_data(router_data) - - async def _upload_buffered_file( request: Request, app: App, @@ -545,7 +501,6 @@ def _create_upload_event() -> Event: ) return Event( - token=token, name=handler_name, payload={handler_upload_param[0]: file_uploads}, ) @@ -567,12 +522,9 @@ async def _ndjson_updates(): Yields: Each state update as newline-delimited JSON. """ - async with app.state_manager.modify_state_with_links( - event.substate_token, event=event - ) as state: - async for update in state._process(event): - update = await app._postprocess(state, event, update) - yield update.json() + "\n" + # Enqueue the task on the main event loop, but emit deltas to the local queue. + async for delta in app.event_processor.enqueue_stream_delta(token, event): + yield json_dumps(StateUpdate(delta=delta)) + "\n" return _UploadStreamingResponse( _ndjson_updates(), @@ -583,10 +535,9 @@ async def _ndjson_updates(): def _background_upload_accepted_response() -> StreamingResponse: """Return a minimal ndjson response for background upload dispatch.""" - from reflex.state import StateUpdate def _accepted_updates(): - yield StateUpdate(final=True).json() + "\n" + yield "{}\n" return StreamingResponse( _accepted_updates(), @@ -613,23 +564,12 @@ async def _upload_chunk_file( chunk_iter = UploadChunkIterator(maxsize=8) event = Event( - token=token, name=handler_name, payload={handler_upload_param[0]: chunk_iter}, ) + task_future = await app.event_processor.enqueue(token, event) - async with app.state_manager.modify_state_with_links( - event.substate_token, - event=event, - ) as state: - _seed_upload_router_data(state, token) - task = app._process_background(state, event) - - if task is None: - msg = f"@rx.event(background=True) is required for upload_files_chunk handler `{handler_name}`." - return JSONResponse({"detail": msg}, status_code=400) - - chunk_iter.set_consumer_task(task) + chunk_iter.set_consumer_task(task_future) parser = _UploadChunkMultipartParser( headers=request.headers, @@ -640,9 +580,9 @@ async def _upload_chunk_file( try: await parser.parse() except ClientDisconnect: - task.cancel() + task_future.cancel() with contextlib.suppress(asyncio.CancelledError): - await task + await task_future return Response() except (MultiPartException, RuntimeError, ValueError) as err: await chunk_iter.fail(err) @@ -688,9 +628,10 @@ async def upload_file(request: Request): ) token, handler_name = _require_upload_headers(request) - _state, event_handler = await _get_upload_runtime_handler( - app, token, handler_name - ) + registered_event_handler = RegistrationContext.get().event_handlers[ + handler_name + ] + event_handler = registered_event_handler.handler if event_handler.is_background: try: From 700dc74a295fd155f622cf1ff06d0ffe6dd0b432 Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Mon, 30 Mar 2026 15:26:03 -0700 Subject: [PATCH 24/81] Fix test_expiration.py and other new state tests --- reflex/istate/manager/memory.py | 4 +++- tests/units/istate/manager/test_expiration.py | 19 ++++++++++--------- tests/units/test_state.py | 4 +--- 3 files changed, 14 insertions(+), 13 deletions(-) diff --git a/reflex/istate/manager/memory.py b/reflex/istate/manager/memory.py index 514d875e0fb..de7ef02f90e 100644 --- a/reflex/istate/manager/memory.py +++ b/reflex/istate/manager/memory.py @@ -92,7 +92,9 @@ def _purge_expired_tokens(self) -> float | None: ) is not None and state_lock.locked(): continue if expires_at <= now: - self._purge_token(token) + self._purge_token( + BaseStateToken(ident=token, cls=type(self.states[token])) + ) continue if next_expires_at is None or expires_at < next_expires_at: next_expires_at = expires_at diff --git a/tests/units/istate/manager/test_expiration.py b/tests/units/istate/manager/test_expiration.py index ff5c76de458..f20ea71b052 100644 --- a/tests/units/istate/manager/test_expiration.py +++ b/tests/units/istate/manager/test_expiration.py @@ -8,7 +8,8 @@ import pytest_asyncio from reflex.istate.manager.memory import StateManagerMemory -from reflex.state import BaseState, _substate_key +from reflex.istate.manager.token import BaseStateToken +from reflex.state import BaseState class ExpiringState(BaseState): @@ -45,7 +46,7 @@ async def state_manager_memory() -> AsyncGenerator[StateManagerMemory]: Yields: The memory state manager under test. """ - state_manager = StateManagerMemory(state=ExpiringState, token_expiration=1) + state_manager = StateManagerMemory(token_expiration=1) yield state_manager await state_manager.close() @@ -56,7 +57,7 @@ async def test_memory_state_manager_evicts_expired_state( token: str, ): """Expired states should be removed from the in-memory cache and locks.""" - state_token = _substate_key(token, ExpiringState) + state_token = BaseStateToken(ident=token, cls=ExpiringState) async with state_manager_memory.modify_state(state_token) as state: state.value = 42 @@ -80,7 +81,7 @@ async def test_memory_state_manager_get_state_refreshes_expiration( token: str, ): """Accessing a state should extend its expiration window.""" - state_token = _substate_key(token, ExpiringState) + state_token = BaseStateToken(ident=token, cls=ExpiringState) state = await state_manager_memory.get_state(state_token) assert isinstance(state, ExpiringState) state.value = 7 @@ -105,7 +106,7 @@ async def test_memory_state_manager_set_state_refreshes_expiration( token: str, ): """Persisting a state should extend its expiration window.""" - state_token = _substate_key(token, ExpiringState) + state_token = BaseStateToken(ident=token, cls=ExpiringState) state = await state_manager_memory.get_state(state_token) assert isinstance(state, ExpiringState) state.value = 17 @@ -130,7 +131,7 @@ async def test_memory_state_manager_multiple_accesses_extend_expiration( token: str, ): """Repeated accesses should keep the state alive until it goes idle.""" - state_token = _substate_key(token, ExpiringState) + state_token = BaseStateToken(ident=token, cls=ExpiringState) state = await state_manager_memory.get_state(state_token) assert isinstance(state, ExpiringState) expires_at = state_manager_memory._token_expires_at[token] @@ -154,7 +155,7 @@ async def test_memory_state_manager_returns_fresh_state_after_eviction( token: str, ): """A token should get a fresh state after the previous one expires.""" - state_token = _substate_key(token, ExpiringState) + state_token = BaseStateToken(ident=token, cls=ExpiringState) state = await state_manager_memory.get_state(state_token) assert isinstance(state, ExpiringState) state.value = 99 @@ -173,7 +174,7 @@ async def test_memory_state_manager_close_cancels_expiration_task( token: str, ): """Closing the manager should cancel the expiration task cleanly.""" - await state_manager_memory.get_state(_substate_key(token, ExpiringState)) + await state_manager_memory.get_state(BaseStateToken(ident=token, cls=ExpiringState)) expiration_task = state_manager_memory._expiration_task assert expiration_task is not None @@ -193,7 +194,7 @@ async def test_memory_state_manager_refreshes_expiration_after_locked_access( token: str, ): """Releasing a long-held state should start a fresh expiration window.""" - state_token = _substate_key(token, ExpiringState) + state_token = BaseStateToken(ident=token, cls=ExpiringState) async with state_manager_memory.modify_state(state_token) as state: state.value = 5 diff --git a/tests/units/test_state.py b/tests/units/test_state.py index 8d54b2e7be8..06c4fad66fd 100644 --- a/tests/units/test_state.py +++ b/tests/units/test_state.py @@ -3583,9 +3583,7 @@ def test_state_manager_create_respects_explicit_memory_mode_with_redis_url( with chdir(proj_root): reflex_core.config.get_config(reload=True) monkeypatch.setattr(prerequisites, "get_redis", mock_redis) - from reflex.state import State - - state_manager = StateManager.create(state=State) + state_manager = StateManager.create() assert isinstance(state_manager, StateManagerMemory) del sys.modules[constants.Config.MODULE] From aa15baabdb3c706682f29528765853ee8846d1cf Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Mon, 30 Mar 2026 16:38:43 -0700 Subject: [PATCH 25/81] Fix upload tests for new EventProcessor fixtures --- pyi_hashes.json | 2 +- reflex/ievent/processor/event_processor.py | 3 + tests/units/conftest.py | 14 +- tests/units/test_app.py | 266 ++++++++++----------- 4 files changed, 142 insertions(+), 143 deletions(-) diff --git a/pyi_hashes.json b/pyi_hashes.json index 62611f50a6d..c96d2327a65 100644 --- a/pyi_hashes.json +++ b/pyi_hashes.json @@ -118,7 +118,7 @@ "packages/reflex-components-recharts/src/reflex_components_recharts/polar.pyi": "1ce679c002336c7bdbdd6c8ff6f2413c", "packages/reflex-components-recharts/src/reflex_components_recharts/recharts.pyi": "1b92135de4ea79cb7d94eaaec55b9ab7", "packages/reflex-components-sonner/src/reflex_components_sonner/toast.pyi": "f09c503c4ab880c13c13d6fa67d708b8", - "reflex/__init__.pyi": "7696c38fd9c04a598518b49c5185c414", + "reflex/__init__.pyi": "7c55e3eb5bb9246d5cbb2622d27c51ef", "reflex/components/__init__.pyi": "55bb242d5e5428db329b88b4923c2ba5", "reflex/experimental/memo.pyi": "d16eccf33993c781e2f8bc2dd8bbd4d4" } diff --git a/reflex/ievent/processor/event_processor.py b/reflex/ievent/processor/event_processor.py index ab3f9f1afed..8764714760b 100644 --- a/reflex/ievent/processor/event_processor.py +++ b/reflex/ievent/processor/event_processor.py @@ -546,6 +546,9 @@ def _finish_task(self, task: asyncio.Task): except Exception as ex: if future is not None and not future.done(): future.set_exception(ex) + with contextlib.suppress(BaseException): + # Trigger the future to avoid warnings if the caller didn't wait. + future.result() telemetry.send_error(ex, context="backend") if ( not task.get_name().startswith("reflex_backend_exception_handler|") diff --git a/tests/units/conftest.py b/tests/units/conftest.py index 55b753d5b76..ae26868a3ac 100644 --- a/tests/units/conftest.py +++ b/tests/units/conftest.py @@ -314,12 +314,12 @@ def emitted_events() -> list[tuple[str, tuple[Event, ...]]]: return [] -@pytest.fixture -def mock_root_event_context( +@pytest_asyncio.fixture +async def mock_root_event_context( mock_base_state_event_processor_obj: BaseStateEventProcessor, emitted_deltas: list[tuple[str, Mapping[str, Mapping[str, Any]]]], emitted_events: list[tuple[str, tuple[Event, ...]]], -) -> EventContext: +) -> AsyncGenerator[EventContext]: """Create a mock event context. Args: @@ -327,7 +327,7 @@ def mock_root_event_context( emitted_deltas: The list to store emitted deltas. emitted_events: The list to store emitted events. - Returns: + Yields: A mock event context. """ @@ -351,13 +351,15 @@ async def emit_event_impl(token: str, *events: Event) -> None: # noqa: RUF029 """ emitted_events.append((token, events)) - return EventContext( + state_manager = StateManagerMemory() + yield EventContext( token="", - state_manager=StateManagerMemory(), + state_manager=state_manager, enqueue_impl=mock_base_state_event_processor_obj.enqueue, emit_delta_impl=emit_delta_impl, emit_event_impl=emit_event_impl, ) + await state_manager.close() @pytest.fixture diff --git a/tests/units/test_app.py b/tests/units/test_app.py index afdc3bad1f2..cef0f6ea6d6 100644 --- a/tests/units/test_app.py +++ b/tests/units/test_app.py @@ -11,9 +11,8 @@ from contextlib import nullcontext as does_not_raise from importlib.util import find_spec from pathlib import Path -from types import SimpleNamespace from typing import TYPE_CHECKING, Any -from unittest.mock import AsyncMock +from unittest.mock import AsyncMock, Mock import pytest from pytest_mock import MockerFixture @@ -961,14 +960,14 @@ async def test_dict_mutation_detection__plain_list( ), ], ) -@pytest.mark.skip("Waiting for upload PR") async def test_upload_file( tmp_path: Path, state, delta, token: str, mocker: MockerFixture, - app_module_mock: unittest.mock.Mock, + attached_mock_base_state_event_processor: BaseStateEventProcessor, + mock_root_event_context: EventContext, ): """Test that file upload works correctly. @@ -978,17 +977,18 @@ async def test_upload_file( delta: Expected delta after processing all files. token: a Token. mocker: pytest mocker object. - app_module_mock: The mock for the app module, used to patch the app instance. + attached_mock_base_state_event_processor: BaseStateEventProcessor Fixture attached to the app instance to capture emitted events. + mock_root_event_context: The mocked root event context, for accessing state_manager. """ mocker.patch( "reflex.state.State.class_subclasses", {state if state is FileUploadState else FileStateBase1}, ) - # The App state must be the "root" of the state tree - app = app_module_mock.app = App() - app.event_namespace.emit = AsyncMock() # pyright: ignore [reportOptionalMemberAccess] - async with app.modify_state(BaseStateToken(ident=token, cls=state)) as root_state: - root_state.get_substate(state.get_full_name().split("."))._tmp_path = tmp_path + app = Mock(event_processor=attached_mock_base_state_event_processor) + async with mock_root_event_context.state_manager.modify_state( + BaseStateToken(ident=token, cls=state) + ) as root_state: + (await root_state.get_state(state))._tmp_path = tmp_path data = b"This is binary data" request_mock = unittest.mock.Mock() @@ -1018,24 +1018,20 @@ async def form(): # noqa: RUF029 updates = [] async for state_update in streaming_response.body_iterator: updates.append(json.loads(str(state_update))) - # 2 intermediate yields + 1 final - assert len(updates) == 3 - assert all(not u["final"] for u in updates[:-1]) - assert updates[-1]["final"] + # 2 intermediate yields + assert len(updates) == 2 # The last intermediate update should contain the full cumulative delta. assert updates[1]["delta"] == delta - await app.state_manager.close() - @pytest.mark.asyncio -@pytest.mark.skip("Waiting for upload PR") async def test_upload_file_keeps_form_open_until_stream_completes( tmp_path: Path, token: str, mocker: MockerFixture, - app_module_mock: unittest.mock.Mock, + attached_mock_base_state_event_processor: BaseStateEventProcessor, + mock_root_event_context: EventContext, ): """Test that upload files are not eagerly copied into memory. @@ -1047,22 +1043,20 @@ async def test_upload_file_keeps_form_open_until_stream_completes( tmp_path: Temporary path. token: A token. mocker: pytest mocker object. - app_module_mock: The mock for the app module, used to patch the app instance. + attached_mock_base_state_event_processor: BaseStateEventProcessor Fixture attached to the app instance to capture emitted events. + mock_root_event_context: The mocked root event context, for accessing state_manager. """ mocker.patch( "reflex.state.State.class_subclasses", {FileUploadState}, ) - app = app_module_mock.app = App() - app.event_namespace.emit = AsyncMock() # pyright: ignore [reportOptionalMemberAccess] + app = Mock(event_processor=attached_mock_base_state_event_processor) # Set _tmp_path via modify_state instead of setting class attribute directly. - async with app.modify_state( + async with mock_root_event_context.state_manager.modify_state( BaseStateToken(ident=token, cls=FileUploadState) ) as root_state: - root_state.get_substate( - FileUploadState.get_full_name().split(".") - )._tmp_path = tmp_path + (await root_state.get_state(FileUploadState))._tmp_path = tmp_path request_mock = unittest.mock.Mock() request_mock.headers = { @@ -1121,26 +1115,25 @@ async def send(message): # noqa: RUF029 assert (tmp_path / "image1.jpg").read_bytes() == data1 assert (tmp_path / "image2.jpg").read_bytes() == data2 - await app.state_manager.close() - @pytest.mark.asyncio -@pytest.mark.skip("Waiting for upload PR") async def test_upload_empty_buffered_request_dispatches_alias_handler( token: str, mocker: MockerFixture, + attached_mock_base_state_event_processor: BaseStateEventProcessor, + mock_root_event_context: EventContext, ): """Test that empty uploads still dispatch buffered alias handlers.""" mocker.patch( "reflex.state.State.class_subclasses", {FileUploadState}, ) - app = App() - app.event_namespace.emit = AsyncMock() # pyright: ignore [reportOptionalMemberAccess] + app = Mock(event_processor=attached_mock_base_state_event_processor) - async with app.modify_state(_substate_key(token, FileUploadState)) as root_state: - substate = root_state.get_substate(FileUploadState.get_full_name().split(".")) - substate.img_list = [] + async with mock_root_event_context.state_manager.modify_state( + BaseStateToken(ident=token, cls=FileUploadState) + ) as root_state: + (await root_state.get_state(FileUploadState)).img_list = [] request_mock = unittest.mock.Mock() request_mock.headers = { @@ -1161,35 +1154,33 @@ async def form(): # noqa: RUF029 async for state_update in streaming_response.body_iterator: updates.append(json.loads(str(state_update))) - assert updates[-1]["final"] - + assert len(updates) == 1 + assert updates[0]["delta"] == { + FileUploadState.get_full_name(): {"img_list" + FIELD_MARKER: ["count:0"]} + } if environment.REFLEX_OPLOCK_ENABLED.get(): - await app.state_manager.close() + await mock_root_event_context.state_manager.close() - state = await app.state_manager.get_state(_substate_key(token, FileUploadState)) - substate = ( - state - if isinstance(state, FileUploadState) - else state.get_substate(FileUploadState.get_full_name().split(".")) + state = await mock_root_event_context.state_manager.get_state( + BaseStateToken(ident=token, cls=FileUploadState) ) + substate = await state.get_state(FileUploadState) assert isinstance(substate, FileUploadState) assert substate.img_list == ["count:0"] - await app.state_manager.close() - @pytest.mark.asyncio -@pytest.mark.skip("Waiting for upload PR") -async def test_upload_file_closes_form_on_event_creation_cancellation( +async def test_upload_file_closes_form_on_form_error( token: str, mocker: MockerFixture, + attached_mock_base_state_event_processor: BaseStateEventProcessor, ): """Test that cancellation before form parsing leaves form data untouched.""" mocker.patch( "reflex.state.State.class_subclasses", {FileUploadState}, ) - app = App() + app = Mock(event_processor=attached_mock_base_state_event_processor) request_mock = unittest.mock.Mock() request_mock.headers = { @@ -1203,15 +1194,11 @@ async def test_upload_file_closes_form_on_event_creation_cancellation( form_close = AsyncMock(side_effect=original_close) form_data.close = form_close - async def form(): # noqa: RUF029 - return form_data - - async def cancelled_get_state(*_args, **_kwargs): + async def cancelled_form(): await asyncio.sleep(0) raise asyncio.CancelledError - request_mock.form = form - mocker.patch.object(app.state_manager, "get_state", side_effect=cancelled_get_state) + request_mock.form = cancelled_form upload_fn = upload(app) with pytest.raises(asyncio.CancelledError): @@ -1220,31 +1207,70 @@ async def cancelled_get_state(*_args, **_kwargs): assert form_close.await_count == 0 assert not file1.file.closed - await app.state_manager.close() + +@pytest.mark.asyncio +async def test_upload_file_closes_form_on_event_creation_cancellation( + token: str, + mocker: MockerFixture, + attached_mock_base_state_event_processor: BaseStateEventProcessor, +): + """Test that cancellation during event creation closes form data.""" + mocker.patch( + "reflex.state.State.class_subclasses", + {FileUploadState}, + ) + app = Mock(event_processor=attached_mock_base_state_event_processor) + + request_mock = unittest.mock.Mock() + request_mock.headers = { + "reflex-client-token": token, + "reflex-event-handler": f"{FileUploadState.get_full_name()}.multi_handle_upload", + } + + bio = io.BytesIO(b"data") + file1 = UploadFile(filename="image1.jpg", file=bio) + form_data = FormData([("files", file1)]) + original_close = form_data.close + form_close = AsyncMock(side_effect=original_close) + form_data.close = form_close + + async def form(): # noqa: RUF029 + return form_data + + request_mock.form = form + + # Patch getlist on the form_data to raise CancelledError during event + # creation (after form is parsed, before streaming begins). + form_data.getlist = Mock(side_effect=asyncio.CancelledError) + + upload_fn = upload(app) + with pytest.raises(asyncio.CancelledError): + await upload_fn(request_mock) + + # Form was parsed, so it should be closed on failure. + assert form_close.await_count == 1 + assert bio.closed @pytest.mark.asyncio -@pytest.mark.skip("Waiting for upload PR") async def test_upload_file_closes_form_if_response_cancelled_before_stream_starts( tmp_path: Path, token: str, mocker: MockerFixture, - app_module_mock: unittest.mock.Mock, + attached_mock_base_state_event_processor: BaseStateEventProcessor, + mock_root_event_context: EventContext, ): """Test that response cancellation before iteration still closes form data.""" mocker.patch( "reflex.state.State.class_subclasses", {FileUploadState}, ) - app = app_module_mock.app = App() - app.event_namespace.emit = AsyncMock() # pyright: ignore [reportOptionalMemberAccess] + app = Mock(event_processor=attached_mock_base_state_event_processor) - async with app.modify_state( + async with mock_root_event_context.state_manager.modify_state( BaseStateToken(ident=token, cls=FileUploadState) ) as root_state: - root_state.get_substate( - FileUploadState.get_full_name().split(".") - )._tmp_path = tmp_path + (await root_state.get_state(FileUploadState))._tmp_path = tmp_path request_mock = unittest.mock.Mock() request_mock.headers = { @@ -1289,15 +1315,12 @@ async def send(_message): assert form_close.await_count == 1 assert bio.closed - await app.state_manager.close() - @pytest.mark.asyncio @pytest.mark.parametrize( "state", [FileUploadState, ChildFileUploadState, GrandChildFileUploadState], ) -@pytest.mark.skip("Waiting for upload PR") async def test_upload_file_without_annotation( state: FileUploadState | ChildFileUploadState | GrandChildFileUploadState, tmp_path: Path, @@ -1341,7 +1364,6 @@ async def form(): # noqa: RUF029 "state", [FileUploadState, ChildFileUploadState, GrandChildFileUploadState], ) -@pytest.mark.skip("Waiting for upload PR") async def test_upload_file_background( state: FileUploadState | ChildFileUploadState | GrandChildFileUploadState, tmp_path: Path, @@ -1440,41 +1462,25 @@ async def stream(): return request_mock -async def _drain_background_tasks(app: App): - """Wait for all background tasks associated with an app. - - Returns: - The gathered background task results. - """ - tasks = tuple(app._background_tasks) - results = await asyncio.gather(*tasks, return_exceptions=True) if tasks else [] - if environment.REFLEX_OPLOCK_ENABLED.get(): - # Redis oplocks can keep completed background-task writes in the local - # lease cache until the manager is closed. - await app.state_manager.close() - return results - - @pytest.mark.asyncio async def test_upload_dispatches_chunk_handlers_on_upload_endpoint( tmp_path, token: str, mocker: MockerFixture, + attached_mock_base_state_event_processor: BaseStateEventProcessor, + mock_root_event_context: EventContext, ): """Test that the standard upload endpoint dispatches chunk handlers.""" mocker.patch( "reflex.state.State.class_subclasses", {ChunkUploadState}, ) - app = App() - mocker.patch( - "reflex.utils.prerequisites.get_and_validate_app", - return_value=SimpleNamespace(app=app), - ) - app.event_namespace.emit_update = AsyncMock() # pyright: ignore [reportOptionalMemberAccess] + app = Mock(event_processor=attached_mock_base_state_event_processor) - async with app.modify_state(_substate_key(token, ChunkUploadState)) as root_state: - substate = root_state.get_substate(ChunkUploadState.get_full_name().split(".")) + async with mock_root_event_context.state_manager.modify_state( + BaseStateToken(ident=token, cls=ChunkUploadState) + ) as root_state: + substate = await root_state.get_state(ChunkUploadState) substate._tmp_path = tmp_path substate.chunk_records = [] substate.completed_files = [] @@ -1503,17 +1509,16 @@ async def test_upload_dispatches_chunk_handlers_on_upload_endpoint( updates = [] async for state_update in response.body_iterator: updates.append(json.loads(str(state_update))) - assert updates == [{"delta": {}, "events": [], "final": True}] + assert updates == [{}] - task_results = await _drain_background_tasks(app) - assert all(result is None for result in task_results) + await attached_mock_base_state_event_processor.join() + if environment.REFLEX_OPLOCK_ENABLED.get(): + await mock_root_event_context.state_manager.close() - state = await app.state_manager.get_state(_substate_key(token, ChunkUploadState)) - substate = ( - state - if isinstance(state, ChunkUploadState) - else state.get_substate(ChunkUploadState.get_full_name().split(".")) + state = await mock_root_event_context.state_manager.get_state( + BaseStateToken(ident=token, cls=ChunkUploadState) ) + substate = await state.get_state(ChunkUploadState) assert isinstance(substate, ChunkUploadState) parsed_chunk_records = [ (filename, int(offset), int(size), content_type) @@ -1550,31 +1555,26 @@ async def test_upload_dispatches_chunk_handlers_on_upload_endpoint( assert substate.completed_files == ["alpha.txt", "beta.txt"] assert (tmp_path / "alpha.txt").read_bytes() == b"abcde" assert (tmp_path / "beta.txt").read_bytes() == b"12345" - assert app.event_namespace.emit_update.await_count >= 1 # pyright: ignore [reportOptionalMemberAccess] - assert not app._background_tasks - - await app.state_manager.close() @pytest.mark.asyncio async def test_upload_empty_chunk_request_dispatches_alias_handler( token: str, mocker: MockerFixture, + attached_mock_base_state_event_processor: BaseStateEventProcessor, + mock_root_event_context: EventContext, ): """Test that empty uploads still dispatch chunk alias handlers.""" mocker.patch( "reflex.state.State.class_subclasses", {ChunkUploadState}, ) - app = App() - mocker.patch( - "reflex.utils.prerequisites.get_and_validate_app", - return_value=SimpleNamespace(app=app), - ) - app.event_namespace.emit_update = AsyncMock() # pyright: ignore [reportOptionalMemberAccess] + app = Mock(event_processor=attached_mock_base_state_event_processor) - async with app.modify_state(_substate_key(token, ChunkUploadState)) as root_state: - substate = root_state.get_substate(ChunkUploadState.get_full_name().split(".")) + async with mock_root_event_context.state_manager.modify_state( + BaseStateToken(ident=token, cls=ChunkUploadState) + ) as root_state: + substate = await root_state.get_state(ChunkUploadState) substate.chunk_records = [] substate.completed_files = [] @@ -1595,44 +1595,41 @@ async def test_upload_empty_chunk_request_dispatches_alias_handler( updates = [] async for state_update in response.body_iterator: updates.append(json.loads(str(state_update))) - assert updates == [{"delta": {}, "events": [], "final": True}] + assert updates == [{}] - task_results = await _drain_background_tasks(app) - assert all(result is None for result in task_results) + await attached_mock_base_state_event_processor.join() + if environment.REFLEX_OPLOCK_ENABLED.get(): + await mock_root_event_context.state_manager.close() - state = await app.state_manager.get_state(_substate_key(token, ChunkUploadState)) - substate = ( - state - if isinstance(state, ChunkUploadState) - else state.get_substate(ChunkUploadState.get_full_name().split(".")) + state = await mock_root_event_context.state_manager.get_state( + BaseStateToken(ident=token, cls=ChunkUploadState) ) + substate = await state.get_state(ChunkUploadState) assert isinstance(substate, ChunkUploadState) assert substate.chunk_records == [] assert substate.completed_files == ["chunks:0"] - assert not app._background_tasks - - await app.state_manager.close() @pytest.mark.asyncio async def test_upload_chunk_invalid_offset_returns_400( token: str, mocker: MockerFixture, + attached_mock_base_state_event_processor: BaseStateEventProcessor, + mock_root_event_context: EventContext, ): """Test that malformed chunk metadata fails the standard upload request.""" mocker.patch( "reflex.state.State.class_subclasses", {ChunkUploadState}, ) - app = App() - mocker.patch( - "reflex.utils.prerequisites.get_and_validate_app", - return_value=SimpleNamespace(app=app), - ) - app.event_namespace.emit_update = AsyncMock() # pyright: ignore [reportOptionalMemberAccess] + app = Mock(event_processor=attached_mock_base_state_event_processor) + # The background task is expected to fail with a parse error for malformed input. + attached_mock_base_state_event_processor.backend_exception_handler = None - async with app.modify_state(_substate_key(token, ChunkUploadState)) as root_state: - substate = root_state.get_substate(ChunkUploadState.get_full_name().split(".")) + async with mock_root_event_context.state_manager.modify_state( + BaseStateToken(ident=token, cls=ChunkUploadState) + ) as root_state: + substate = await root_state.get_state(ChunkUploadState) substate.chunk_records = [] substate.completed_files = [] @@ -1651,20 +1648,17 @@ async def test_upload_chunk_invalid_offset_returns_400( "detail": "Missing boundary in multipart." } - await _drain_background_tasks(app) + await attached_mock_base_state_event_processor.join() + if environment.REFLEX_OPLOCK_ENABLED.get(): + await mock_root_event_context.state_manager.close() - state = await app.state_manager.get_state(_substate_key(token, ChunkUploadState)) - substate = ( - state - if isinstance(state, ChunkUploadState) - else state.get_substate(ChunkUploadState.get_full_name().split(".")) + state = await mock_root_event_context.state_manager.get_state( + BaseStateToken(ident=token, cls=ChunkUploadState) ) + substate = await state.get_state(ChunkUploadState) assert isinstance(substate, ChunkUploadState) assert substate.chunk_records == [] assert substate.completed_files == [] - assert not app._background_tasks - - await app.state_manager.close() class DynamicState(State): From b709bee7b231bd5c547bcff53adb12625dc33ddc Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Mon, 30 Mar 2026 17:52:19 -0700 Subject: [PATCH 26/81] add OPLOCK_ENABLED state_manager.close to tests --- tests/units/test_state.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/tests/units/test_state.py b/tests/units/test_state.py index 06c4fad66fd..239969cd8dc 100644 --- a/tests/units/test_state.py +++ b/tests/units/test_state.py @@ -2385,6 +2385,9 @@ async def test_background_task_no_block( "private", ] + if environment.REFLEX_OPLOCK_ENABLED.get(): + await state_manager.close() + background_task_state = await state_manager.get_state( BaseStateToken(ident=token, cls=BackgroundTaskState) ) @@ -2416,6 +2419,9 @@ async def test_background_task_reset( ), ) + if environment.REFLEX_OPLOCK_ENABLED.get(): + await state_manager.close() + background_task_state = await state_manager.get_state( BaseStateToken(ident=token, cls=BackgroundTaskState) ) @@ -3425,6 +3431,9 @@ async def test_setvar( async with mock_base_state_event_processor as processor: await processor.enqueue(token, *events) + if environment.REFLEX_OPLOCK_ENABLED.get(): + await state_manager.close() + state = await state_manager.get_state(BaseStateToken(ident=token, cls=TestState)) assert isinstance(state, TestState) assert state.num1 == 42 @@ -3435,6 +3444,9 @@ async def test_setvar( async with mock_base_state_event_processor as processor: await processor.enqueue(token, *events) + if environment.REFLEX_OPLOCK_ENABLED.get(): + await state_manager.close() + state = await state_manager.get_state(BaseStateToken(ident=token, cls=TestState)) assert isinstance(state, TestState) assert state.array == [43] From 9e706ccf5fcd5dbc697918a91e9aeeaba09dd6fa Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Tue, 31 Mar 2026 13:20:56 -0700 Subject: [PATCH 27/81] state.js: pass around params as a ref The function () => params.current baked inside the ensureSocketConnected function was getting a stale reference and the early events (hydrate, on load, client state) were missing the query parameters in their router_data and thus on_load was not working correctly. --- .../reflex_core/.templates/web/utils/state.js | 26 +++++-------------- 1 file changed, 7 insertions(+), 19 deletions(-) diff --git a/packages/reflex-core/src/reflex_core/.templates/web/utils/state.js b/packages/reflex-core/src/reflex_core/.templates/web/utils/state.js index 07e1c5f0f61..cc8ce1c814b 100644 --- a/packages/reflex-core/src/reflex_core/.templates/web/utils/state.js +++ b/packages/reflex-core/src/reflex_core/.templates/web/utils/state.js @@ -391,7 +391,7 @@ export const applyEvent = async (event, socket, navigate, params) => { }; const query = { ...Object.fromEntries(new URLSearchParams(window.location.search)), - ...params(), + ...params.current, }; if (query && Object.keys(query).length > 0) { event.router_data.query = query; @@ -616,17 +616,11 @@ export const connect = async ( window.addEventListener("unload", disconnectTrigger); if (socket.current.rehydrate) { socket.current.rehydrate = false; - queueEvents( - initialEvents(), - socket, - true, - navigate, - () => params.current, - ); + queueEvents(initialEvents(), socket, true, navigate, params); } // Drain any initial events from the queue. while (event_queue.length > 0) { - await processEvent(socket.current, navigate, () => params.current); + await processEvent(socket.current, navigate, params); } }); @@ -920,7 +914,7 @@ export const useEventLoop = ( setConnectErrors, client_storage, navigate, - () => params.current, + params, ); } }, [ @@ -948,7 +942,7 @@ export const useEventLoop = ( } return applyEventActions( - () => queueEvents(_events, socket, false, navigate, () => params.current), + () => queueEvents(_events, socket, false, navigate, params), event_actions, args, _events.map((e) => e.name).join("+++"), @@ -959,13 +953,7 @@ export const useEventLoop = ( const sentHydrate = useRef(false); // Avoid double-hydrate due to React strict-mode useEffect(() => { if (!sentHydrate.current) { - queueEvents( - initial_events(), - socket, - true, - navigate, - () => params.current, - ); + queueEvents(initial_events(), socket, true, navigate, params); sentHydrate.current = true; } }, []); @@ -1031,7 +1019,7 @@ export const useEventLoop = ( // Process all outstanding events. while (event_queue.length > 0) { await ensureSocketConnected(); - await processEvent(socket.current, navigate, () => params.current); + await processEvent(socket.current, navigate, params); } })(); }); From fc311d3f66844d7a13f14f2b214d1793cd8bd6e3 Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Wed, 1 Apr 2026 00:30:21 -0700 Subject: [PATCH 28/81] registry: substate tracking and stateful component cache Store the state_full_name to substate mapping in RegistrationContext Make it easier to register / re-register select states and event handlers in a new RegistrationContext Store StatefulComponent cached components in RegistrationContext for easier resetting/dropping after compilation or for use in testing. --- .../reflex_components_core/core/_upload.py | 3 +- .../src/reflex_core/components/component.py | 10 +-- packages/reflex-core/src/reflex_core/event.py | 8 -- reflex/_internal/registry.py | 81 ++++++++++++++++++- reflex/app.py | 2 +- reflex/state.py | 39 +++++---- 6 files changed, 109 insertions(+), 34 deletions(-) diff --git a/packages/reflex-components-core/src/reflex_components_core/core/_upload.py b/packages/reflex-components-core/src/reflex_components_core/core/_upload.py index 20b1ef158b1..7dc9face947 100644 --- a/packages/reflex-components-core/src/reflex_components_core/core/_upload.py +++ b/packages/reflex-components-core/src/reflex_components_core/core/_upload.py @@ -21,7 +21,6 @@ from starlette.responses import JSONResponse, Response, StreamingResponse from typing_extensions import Self -from reflex._internal.registry import RegistrationContext from reflex.state import StateUpdate if TYPE_CHECKING: @@ -627,6 +626,8 @@ async def upload_file(request: Request): resolve_upload_handler_param, ) + from reflex._internal.registry import RegistrationContext + token, handler_name = _require_upload_headers(request) registered_event_handler = RegistrationContext.get().event_handlers[ handler_name diff --git a/packages/reflex-core/src/reflex_core/components/component.py b/packages/reflex-core/src/reflex_core/components/component.py index 0a0b8794264..79e00eba504 100644 --- a/packages/reflex-core/src/reflex_core/components/component.py +++ b/packages/reflex-core/src/reflex_core/components/component.py @@ -2386,9 +2386,6 @@ class StatefulComponent(BaseComponent): was created with. """ - # A lookup table to caching memoized component instances. - tag_to_stateful_component: ClassVar[dict[str, StatefulComponent]] = {} - # Reference to the original component that was memoized into this component. component: Component = field( default_factory=Component, is_javascript_property=False @@ -2422,6 +2419,8 @@ def create(cls, component: Component) -> StatefulComponent | None: """ from reflex_components_core.core.foreach import Foreach + from reflex._internal.registry import RegistrationContext + if component._memoization_mode.disposition == MemoizationDisposition.NEVER: # Never memoize this component. return None @@ -2466,11 +2465,12 @@ def create(cls, component: Component) -> StatefulComponent | None: return None # Look up the tag in the cache - stateful_component = cls.tag_to_stateful_component.get(tag_name) + ctx = RegistrationContext.get() + stateful_component = ctx.tag_to_stateful_component.get(tag_name) if stateful_component is None: memo_trigger_hooks = cls._fix_event_triggers(component) # Set the stateful component in the cache for the given tag. - stateful_component = cls.tag_to_stateful_component.setdefault( + stateful_component = ctx.tag_to_stateful_component.setdefault( tag_name, cls( children=component.children, diff --git a/packages/reflex-core/src/reflex_core/event.py b/packages/reflex-core/src/reflex_core/event.py index 59411629541..246799fc0b8 100644 --- a/packages/reflex-core/src/reflex_core/event.py +++ b/packages/reflex-core/src/reflex_core/event.py @@ -342,14 +342,6 @@ class EventHandler(EventActionsMixin): state: type[BaseState] | None = dataclasses.field(default=None, repr=False) - def __post_init__(self): - """Register the event handler.""" - from reflex._internal.registry import RegistrationContext - - RegistrationContext.register_event_handler( - self, states=(self.state,) if self.state else () - ) - @property def state_full_name(self) -> str: """Get the full name of the state class this event handler is attached to. diff --git a/reflex/_internal/registry.py b/reflex/_internal/registry.py index f3162fbdf20..7d22d9c26ca 100644 --- a/reflex/_internal/registry.py +++ b/reflex/_internal/registry.py @@ -3,9 +3,13 @@ import dataclasses from typing import TYPE_CHECKING, Self +from reflex_core.utils.exceptions import StateValueError + from reflex._internal.context.base import BaseContext if TYPE_CHECKING: + from reflex_core.components.component import StatefulComponent + from reflex.event import EventHandler from reflex.state import BaseState @@ -30,6 +34,14 @@ class RegistrationContext(BaseContext): default_factory=dict, repr=False, ) + base_state_substates: dict[str, set[type[BaseState]]] = dataclasses.field( + default_factory=dict, + repr=False, + ) + tag_to_stateful_component: dict[str, StatefulComponent] = dataclasses.field( + default_factory=dict, + repr=False, + ) @classmethod def ensure_context(cls) -> Self: @@ -50,13 +62,45 @@ def ensure_context(cls) -> Self: def register_base_state(cls, state_cls: type[BaseState]) -> type[BaseState]: """Register a base state class with its full name. + Also registers parent_state until finding one that is already registered. + + Args: + state_cls: The base state class to register. + + Returns: + The registered base state class. + """ + return cls.ensure_context()._register_base_state(state_cls) + + def _register_base_state(self, state_cls: type[BaseState]) -> type[BaseState]: + """Register a base state class with its full name. + + Also registers parent_state until finding one that is already registered. + Args: state_cls: The base state class to register. Returns: The registered base state class. """ - cls.ensure_context().base_states[state_cls.get_full_name()] = state_cls + self.base_states[state_cls.get_full_name()] = state_cls + for event_handler in state_cls.event_handlers.values(): + self._register_event_handler(event_handler, states=(state_cls,)) + if (parent_state := state_cls.get_parent_state()) is not None: + if parent_state.get_full_name() not in self.base_states: + self._register_base_state(parent_state) + parent_state_substates = self.base_state_substates.setdefault( + parent_state.get_full_name(), set() + ) + if state_cls in parent_state_substates: + msg = ( + f"State class {state_cls.get_full_name()} is already registered as a substate of " + f"{parent_state.get_full_name()}. This likely means there are multiple classes with the same name " + "in the same module, which causes a conflict in the registry. Please rename one of the classes to avoid " + "shadowing. Shadowing substate classes is not allowed." + ) + raise StateValueError(msg) + parent_state_substates.add(state_cls) return state_cls @classmethod @@ -65,6 +109,22 @@ def register_event_handler( ) -> EventHandler: """Register an event handler with its full name and associated states. + Args: + handler: The event handler to register. + states: The states associated with the event handler. + + Returns: + The registered event handler. + """ + return cls.ensure_context()._register_event_handler(handler, states=states) + + def _register_event_handler( + self, + handler: EventHandler, + states: tuple[type[BaseState], ...] = (), + ) -> EventHandler: + """Register an event handler with its full name and associated states. + Args: handler: The event handler to register. states: The states associated with the event handler. @@ -75,7 +135,24 @@ def register_event_handler( from reflex.utils.format import format_event_handler full_name = format_event_handler(handler) - cls.ensure_context().event_handlers[full_name] = RegisteredEventHandler( + self.event_handlers[full_name] = RegisteredEventHandler( handler=handler, states=states ) return handler + + def get_substates( + self, base_state_cls: type[BaseState] | str + ) -> set[type[BaseState]]: + """Get the substates for a base state class. + + Args: + base_state_cls: The base state class to get substates for. + + Returns: + A set of substate classes. + """ + if isinstance(base_state_cls, str): + return self.base_state_substates.setdefault(base_state_cls, set()) + return self.base_state_substates.setdefault( + base_state_cls.get_full_name(), set() + ) diff --git a/reflex/app.py b/reflex/app.py index 6915603aebc..301fd16bfc3 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -1124,7 +1124,7 @@ def _validate_var_dependencies(self, state: type[BaseState] | None = None) -> No msg = f"ComputedVar {var._name} on state {state.__name__} has an invalid dependency {state_name}.{dep}" raise exceptions.VarDependencyError(msg) - for substate in state.class_subclasses: + for substate in state.get_substates(): self._validate_var_dependencies(substate) def _compile( diff --git a/reflex/state.py b/reflex/state.py index bf5e25365fb..756a107d1c1 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -326,7 +326,6 @@ def _override_base_method(fn: Callable[PARAMS, RETURN]) -> Callable[PARAMS, RETU "backend_vars", "inherited_backend_vars", "event_handlers", - "class_subclasses", "_var_dependencies", "_always_dirty_computed_vars", "_always_dirty_substates", @@ -358,9 +357,6 @@ class BaseState(EvenMoreBasicBaseState): # The event handlers. event_handlers: ClassVar[builtins.dict[str, EventHandler]] = {} - # A set of subclassses of this class. - class_subclasses: ClassVar[set[type[BaseState]]] = set() - # Mapping of var name to set of (state_full_name, var_name) that depend on it. _var_dependencies: ClassVar[builtins.dict[str, set[tuple[str, str]]]] = {} @@ -522,9 +518,6 @@ def __init_subclass__(cls, mixin: bool = False, **kwargs): # Computed vars should not shadow builtin state props. cls._check_overridden_basevars() - # Reset subclass tracking for this class. - cls.class_subclasses = set() - # Reset dirty substate tracking for this class. cls._always_dirty_substates = set() cls._potentially_dirty_states = set() @@ -536,15 +529,13 @@ def __init_subclass__(cls, mixin: bool = False, **kwargs): cls.inherited_backend_vars = parent_state.backend_vars # Check if another substate class with the same name has already been defined. - if cls.get_name() in {c.get_name() for c in parent_state.class_subclasses}: + if cls.get_name() in {c.get_name() for c in parent_state.get_substates()}: # This should not happen, since we have added module prefix to state names in #3214 msg = ( f"The substate class '{cls.get_name()}' has been defined multiple times. " "Shadowing substate classes is not allowed." ) raise StateValueError(msg) - # Track this new subclass in the parent state's subclasses set. - parent_state.class_subclasses.add(cls) # Get computed vars. computed_vars = cls._get_computed_vars() @@ -628,12 +619,13 @@ def __init_subclass__(cls, mixin: bool = False, **kwargs): cls.event_handlers[name] = handler setattr(cls, name, handler) + RegistrationContext.register_base_state(cls) + # Initialize per-class var dependency tracking. cls._var_dependencies = {} cls._init_var_dependency_dicts() all_base_state_classes[cls.get_full_name()] = None - RegistrationContext.register_base_state(cls) @classmethod def _add_event_handler( @@ -969,7 +961,9 @@ def get_substates(cls) -> set[type[BaseState]]: Returns: The substates of the state. """ - return cls.class_subclasses + from reflex._internal.registry import RegistrationContext + + return RegistrationContext.get().get_substates(cls) @classmethod @functools.lru_cache @@ -1123,7 +1117,7 @@ def add_var(cls, name: str, type_: Any, default_value: Any = None): cls.vars.update({name: var}) # let substates know about the new variable - for substate_class in cls.class_subclasses: + for substate_class in cls.get_substates(): substate_class.vars.setdefault(name, var) # Reinitialize dependency tracking dicts. @@ -1152,10 +1146,17 @@ def _create_event_handler( Returns: The event handler. """ + from reflex._internal.registry import RegistrationContext + # Check if function has stored event_actions from decorator event_actions = getattr(fn, EVENT_ACTIONS_MARKER, {}) - return event_handler_cls(fn=fn, state=cls, event_actions=event_actions) + handler = event_handler_cls(fn=fn, state=cls, event_actions=event_actions) + if cls.get_full_name() in all_base_state_classes: + # Register handlers created after the class was registered. + reg_ctx = RegistrationContext.get() + reg_ctx.register_event_handler(handler, states=(cls,)) + return handler @classmethod def _create_setvar(cls): @@ -1259,7 +1260,7 @@ def _update_substate_inherited_vars(cls, vars_to_add: dict[str, Var]): Args: vars_to_add: names to Var instances to add to substates """ - for substate_class in cls.class_subclasses: + for substate_class in cls.get_substates(): for name, var in vars_to_add.items(): if types.is_backend_base_variable(name, cls): substate_class.backend_vars.setdefault(name, var) @@ -2569,6 +2570,8 @@ def reload_state_module( state: Recursive argument for the state class to reload. """ + from reflex._internal.registry import RegistrationContext + # Reset the _app_ref of OnLoadInternalState to avoid stale references. if state is OnLoadInternalState: state._app_ref = None @@ -2580,11 +2583,13 @@ def reload_state_module( and module is not None ): state._potentially_dirty_states.remove(pd_state) - for subclass in tuple(state.class_subclasses): + reg_ctx = RegistrationContext.get() + substates = reg_ctx.get_substates(state) + for subclass in tuple(substates): reload_state_module(module=module, state=subclass) if subclass.__module__ == module and module is not None: all_base_state_classes.pop(subclass.get_full_name(), None) - state.class_subclasses.remove(subclass) + substates.remove(subclass) state._always_dirty_substates.discard(subclass.get_name()) state._var_dependencies = {} state._init_var_dependency_dicts() From d9a996dfaae75ddd2a12d7ae816c0611738ddb80 Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Wed, 1 Apr 2026 00:32:17 -0700 Subject: [PATCH 29/81] close old locks in disk/memory state manager --- reflex/istate/manager/disk.py | 4 ++++ reflex/istate/manager/memory.py | 4 ++++ 2 files changed, 8 insertions(+) diff --git a/reflex/istate/manager/disk.py b/reflex/istate/manager/disk.py index 6fdbeb7ae43..504b0061343 100644 --- a/reflex/istate/manager/disk.py +++ b/reflex/istate/manager/disk.py @@ -381,3 +381,7 @@ async def close(self): with contextlib.suppress(asyncio.CancelledError): await self._write_queue_task self._write_queue_task = None + # Dump unlocked locks. + for token, lock in tuple(self._states_locks.items()): + if not lock.locked(): + self._states_locks.pop(token) diff --git a/reflex/istate/manager/memory.py b/reflex/istate/manager/memory.py index de7ef02f90e..456df5841b9 100644 --- a/reflex/istate/manager/memory.py +++ b/reflex/istate/manager/memory.py @@ -209,3 +209,7 @@ async def close(self): with contextlib.suppress(asyncio.CancelledError): await self._expiration_task self._expiration_task = None + # Dump unlocked locks. + for token, lock in tuple(self._states_locks.items()): + if not lock.locked(): + self._states_locks.pop(token) From 55e04498ccea65ec3f3ec35ad7a6542b068ec670 Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Wed, 1 Apr 2026 00:55:33 -0700 Subject: [PATCH 30/81] Remove state_manager from AppHarness Update all associated tests to make assertions using the browser/app and not attempting to fetch the backend state directly. This makes the tests more robust, reduces state_manager related hacks, and makes the tests easier to eventually migrate to an external process using granian where direct state_manager access will not be available. --- reflex/testing.py | 225 +++------------- tests/integration/test_client_storage.py | 93 ++----- tests/integration/test_component_state.py | 82 ++++-- tests/integration/test_computed_vars.py | 14 - tests/integration/test_connection_banner.py | 55 ++-- tests/integration/test_dynamic_routes.py | 88 ++----- tests/integration/test_event_actions.py | 65 ++--- tests/integration/test_event_chain.py | 139 +++++----- tests/integration/test_form_submit.py | 23 +- tests/integration/test_input.py | 28 +- .../test_memory_state_manager_expiration.py | 4 +- tests/integration/test_upload.py | 248 +++++++++--------- tests/integration/utils.py | 77 +++++- tests/units/conftest.py | 14 + .../middleware/test_hydrate_middleware.py | 10 +- tests/units/test_app.py | 46 +--- tests/units/test_state.py | 15 +- 17 files changed, 504 insertions(+), 722 deletions(-) diff --git a/reflex/testing.py b/reflex/testing.py index 30834925094..a7b7bccc6f2 100644 --- a/reflex/testing.py +++ b/reflex/testing.py @@ -4,6 +4,7 @@ import asyncio import contextlib +import contextvars import dataclasses import functools import inspect @@ -19,11 +20,12 @@ import threading import time import types -from collections.abc import AsyncIterator, Callable, Coroutine, Sequence +from collections.abc import Callable, Coroutine, Sequence +from copy import deepcopy from http.server import SimpleHTTPRequestHandler from importlib.util import find_spec from pathlib import Path -from typing import TYPE_CHECKING, Any, Literal, TypeVar +from typing import TYPE_CHECKING, Any, ClassVar, Literal, TypeVar import uvicorn from reflex_core.components.component import CUSTOM_COMPONENTS, CustomComponent @@ -38,13 +40,9 @@ import reflex.utils.format import reflex.utils.prerequisites import reflex.utils.processes +from reflex._internal.registry import RegistrationContext from reflex.experimental.memo import EXPERIMENTAL_MEMOS -from reflex.istate.manager import StateManager -from reflex.istate.manager.disk import StateManagerDisk -from reflex.istate.manager.memory import StateManagerMemory -from reflex.istate.manager.redis import StateManagerRedis -from reflex.istate.manager.token import BaseStateToken -from reflex.state import BaseState, State, _split_substate_key, reload_state_module +from reflex.state import reload_state_module from reflex.utils import console, js_runtimes from reflex.utils.export import export from reflex.utils.token_manager import TokenManager @@ -119,8 +117,9 @@ class AppHarness: frontend_output_thread: threading.Thread | None = None backend_thread: threading.Thread | None = None backend: uvicorn.Server | None = None - state_manager: StateManager | None = None _frontends: list[WebDriver] = dataclasses.field(default_factory=list) + _registry_token: contextvars.Token[RegistrationContext] | None = None + _base_registration_context: ClassVar[RegistrationContext] | None = None @classmethod def create( @@ -239,7 +238,6 @@ def _get_source_from_app_source(self, app_source: Any) -> str: def _initialize_app(self): # disable telemetry reporting for tests - os.environ["REFLEX_TELEMETRY_ENABLED"] = "false" # Reset global memo registries so previous AppHarness apps do not # leak compiled component definitions into the next test app. @@ -270,6 +268,14 @@ def _initialize_app(self): with chdir(self.app_path): reflex.utils.prerequisites.initialize_frontend_dependencies() with chdir(self.app_path): + # Use a new registration context for a new app. + if AppHarness._base_registration_context is None: + # Save the initial RegistrationContext for the app if we haven't already + AppHarness._base_registration_context = ( + RegistrationContext.ensure_context() + ) + new_registration_context = deepcopy(AppHarness._base_registration_context) + self._registry_token = RegistrationContext.set(new_registration_context) # ensure config and app are reloaded when testing different app config = get_config(reload=True) # Ensure the AppHarness test does not skip State assignment due to running via pytest @@ -286,19 +292,6 @@ def _initialize_app(self): ) ) self.app_asgi = self.app_instance() - if self.app_instance and self.app_instance._state_manager is not None: - if self.app_instance._state is None: - msg = "State is not set." - raise RuntimeError(msg) - if isinstance(self.app_instance._state_manager, StateManagerRedis): - # Create our own redis connection for testing. - self.state_manager = StateManagerRedis.create() - elif isinstance(self.app_instance._state_manager, StateManagerDisk): - self.state_manager = StateManagerDisk.create() - if self.state_manager is None: - self.state_manager = ( - self.app_instance._state_manager if self.app_instance else None - ) def _reload_state_module(self): """Reload the rx.State module to avoid conflict when reloading.""" @@ -350,53 +343,21 @@ def _start_backend(self, port: int = 0): ) ) self.backend.shutdown = self._get_backend_shutdown_handler() + + def _run_backend(context: contextvars.Context) -> None: + if self.backend is not None: + context.run(self.backend.run) + with chdir(self.app_path): print( # noqa: T201 "Creating backend in a new thread..." ) # for pytest diagnosis - self.backend_thread = threading.Thread(target=self.backend.run) + self.backend_thread = threading.Thread( + target=_run_backend, args=(contextvars.copy_context(),) + ) self.backend_thread.start() print("Backend started.") # for pytest diagnosis #noqa: T201 - async def _reset_backend_state_manager(self): - """Reset the StateManagerRedis event loop affinity. - - This is necessary when the backend is restarted and the state manager is a - StateManagerRedis instance. - - Raises: - RuntimeError: when the state manager cannot be reset - """ - if ( - self.app_instance is not None - and self.app_instance._state_manager is not None - ): - with contextlib.suppress(RuntimeError): - await self.app_instance._state_manager.close() - if ( - self.app_instance is not None - and isinstance( - self.app_instance._state_manager, - StateManagerRedis, - ) - and self.app_instance._state is not None - ): - self.app_instance._state_manager = StateManagerRedis.create() - if not isinstance(self.app_instance.state_manager, StateManagerRedis): - msg = "Failed to reset state manager." - raise RuntimeError(msg) - - # Also reset the TokenManager to avoid loop affinity issues - if ( - hasattr(self.app_instance, "event_namespace") - and self.app_instance.event_namespace is not None - and hasattr(self.app_instance.event_namespace, "_token_manager") - ): - # Import here to avoid circular imports - from reflex.utils.token_manager import TokenManager - - self.app_instance.event_namespace._token_manager = TokenManager.create() - def _start_frontend(self): # Set up the frontend. with chdir(self.app_path): @@ -503,6 +464,8 @@ def stop(self) -> None: driver.quit() self._reload_state_module() + if self._registry_token is not None: + RegistrationContext.reset(self._registry_token) if self.backend is not None: self.backend.should_exit = True @@ -715,104 +678,6 @@ def frontend( self._frontends.append(driver) return driver - async def get_state(self, token: str) -> BaseState: - """Get the state associated with the given token. - - Args: - token: The state token to look up. - - Returns: - The state instance associated with the given token - - Raises: - RuntimeError: when the app hasn't started running - """ - if self.state_manager is None: - msg = "state_manager is not set." - raise RuntimeError(msg) - if self.app_instance is not None and isinstance( - self.app_instance.state_manager, StateManagerDisk - ): - # Song and dance to convince the instance's state manager to flush - # (we can't directly await the _other_ loop's Future) - await self.app_instance.state_manager._flush_write_queue() - if isinstance(self.state_manager, StateManagerDisk): - # Force reload the latest state from disk. - client_token, _ = _split_substate_key(token) - self.state_manager.states.pop(client_token, None) - try: - return await self.state_manager.get_state( - BaseStateToken(ident=token, cls=State) - ) - finally: - await self.state_manager.close() - - async def set_state(self, token: str, **kwargs) -> None: - """Set the state associated with the given token. - - Args: - token: The state token to set. - kwargs: Attributes to set on the state. - - Raises: - RuntimeError: when the app hasn't started running - """ - if self.state_manager is None: - msg = "state_manager is not set." - raise RuntimeError(msg) - state = await self.get_state(token) - for key, value in kwargs.items(): - setattr(state, key, value) - try: - await self.state_manager.set_state( - BaseStateToken(ident=token, cls=type(state)), state - ) - finally: - if self.app_instance is not None and isinstance( - self.app_instance.state_manager, StateManagerDisk - ): - # Clear the token from the backend's cache so it will be reloaded. - client_token, _ = _split_substate_key(token) - self.app_instance.state_manager.states.pop(client_token, None) - await self.state_manager.close() - - @contextlib.asynccontextmanager - async def modify_state(self, token: str) -> AsyncIterator[BaseState]: - """Modify the state associated with the given token and send update to frontend. - - Args: - token: The state token to modify - - Yields: - The state instance associated with the given token - - Raises: - RuntimeError: when the app hasn't started running - """ - if self.state_manager is None: - msg = "state_manager is not set." - raise RuntimeError(msg) - if self.app_instance is None or self.app_instance._state is None: - msg = "App is not running." - raise RuntimeError(msg) - app_state_manager = self.app_instance.state_manager - if isinstance(self.state_manager, (StateManagerRedis, StateManagerDisk)): - # Temporarily replace the app's state manager with our own, since - # the redis/disk connection is on the backend_thread event loop - self.app_instance._state_manager = self.state_manager - try: - async with self.app_instance.modify_state( - BaseStateToken(ident=token, cls=self.app_instance._state) - ) as state: - yield state - finally: - if isinstance(app_state_manager, StateManagerDisk): - # Clear the token from the cache so it will be reloaded. - client_token, _ = _split_substate_key(token) - app_state_manager.states.pop(client_token, None) - await self.state_manager.close() - self.app_instance._state_manager = app_state_manager - def token_manager(self) -> TokenManager: """Get the token manager for the app instance. @@ -883,35 +748,6 @@ def poll_for_value( raise TimeoutError(msg) return element.get_attribute("value") - def poll_for_clients(self, timeout: TimeoutType = None) -> dict[str, BaseState]: - """Poll app state_manager for any connected clients. - - Args: - timeout: how long to wait for client states - - Returns: - active state instances when the polling loop exited - - Raises: - RuntimeError: when the app hasn't started running - TimeoutError: when the timeout expires before any states are seen - ValueError: when the state_manager is not a memory state manager - """ - if self.app_instance is None: - msg = "App is not running." - raise RuntimeError(msg) - state_manager = self.app_instance.state_manager - if not isinstance(state_manager, (StateManagerMemory, StateManagerDisk)): - msg = "Only works with memory or disk state manager" - raise ValueError(msg) - if not self._poll_for( - target=lambda: state_manager.states, - timeout=timeout, - ): - msg = "No states were observed while polling." - raise TimeoutError(msg) - return state_manager.states - @staticmethod def poll_for_or_raise_timeout( target: Callable[[], T], @@ -1123,10 +959,17 @@ def _start_backend(self): ), ) self.backend.shutdown = self._get_backend_shutdown_handler() + + def _run_backend(context: contextvars.Context) -> None: + if self.backend is not None: + context.run(self.backend.run) + print( # noqa: T201 "Creating backend in a new thread..." ) - self.backend_thread = threading.Thread(target=self.backend.run) + self.backend_thread = threading.Thread( + target=_run_backend, args=(contextvars.copy_context(),) + ) self.backend_thread.start() print("Backend started.") # for pytest diagnosis #noqa: T201 diff --git a/tests/integration/test_client_storage.py b/tests/integration/test_client_storage.py index dede34b8b17..807a52159d3 100644 --- a/tests/integration/test_client_storage.py +++ b/tests/integration/test_client_storage.py @@ -11,10 +11,6 @@ from selenium.webdriver.firefox.webdriver import WebDriver as Firefox from selenium.webdriver.remote.webdriver import WebDriver -from reflex.istate.manager.disk import StateManagerDisk -from reflex.istate.manager.memory import StateManagerMemory -from reflex.istate.manager.redis import StateManagerRedis -from reflex.state import State, _substate_key from reflex.testing import AppHarness from . import utils @@ -22,6 +18,8 @@ def ClientSide(): """App for testing client-side state.""" + import uuid + import reflex as rx class ClientSideState(rx.State): @@ -36,6 +34,12 @@ def set_state_var(self, value: str): def set_input_value(self, value: str): self.input_value = value + @rx.event + def reset_token_no_hydrate(self): + return rx.run_script( + f"{{token = '{uuid.uuid4()}'; window.sessionStorage.setItem('token', token);}}" + ) + class ClientSideSubState(ClientSideState): # cookies with default settings c1: str = rx.Cookie() @@ -90,6 +94,11 @@ def index(): read_only=True, id="token", ), + rx.button( + "New Token - No Hydrate", + id="new_token", + on_click=ClientSideState.reset_token_no_hydrate, + ), rx.input( placeholder="state var", value=ClientSideState.state_var, @@ -350,7 +359,6 @@ def set_sub_sub(var: str, value: str): set_sub_sub("l1s", "l1s value") set_sub_sub("s1s", "s1s value") - state_name = client_side.get_full_state_name(["_client_side_state"]) sub_state_name = client_side.get_full_state_name([ "_client_side_state", "_client_side_sub_state", @@ -534,9 +542,8 @@ def set_sub_sub(var: str, value: str): assert l1s.text == "l1s value" assert s1s.text == "s1s value" - # reset the backend state to force refresh from client storage - async with client_side.modify_state(f"{token}_{state_name}") as state: - state.reset() + # set a new token to force reloading the values from client + driver.execute_script("window.sessionStorage.setItem('token', '');") driver.refresh() # wait for the backend connection to send the token (again) @@ -640,39 +647,8 @@ def set_sub_sub(var: str, value: str): assert s3.text == "s3 value" # Simulate state expiration - if isinstance(client_side.state_manager, StateManagerRedis): - await client_side.state_manager.redis.delete( - _substate_key(token, State.get_full_name()) - ) - await client_side.state_manager.redis.delete(_substate_key(token, state_name)) - await client_side.state_manager.redis.delete( - _substate_key(token, sub_state_name) - ) - await client_side.state_manager.redis.delete( - _substate_key(token, sub_sub_state_name) - ) - elif isinstance(client_side.state_manager, (StateManagerMemory, StateManagerDisk)): - del client_side.state_manager.states[token] - if ( - client_side.app_instance is not None - and (app_state_manager := client_side.app_instance.state_manager) is not None - and isinstance(app_state_manager, StateManagerDisk) - ): - # Purge the backend's disk manager - app_state_manager.states.pop(token, None) - app_state_manager._write_queue.clear() - og_token_expiration = app_state_manager.token_expiration - app_state_manager.token_expiration = 0 - app_state_manager._purge_expired_states() - app_state_manager.token_expiration = og_token_expiration - - # Ensure the state is gone (not hydrated) - async def poll_for_not_hydrated(): - state = await client_side.get_state(_substate_key(token or "", state_name)) - assert isinstance(state, State) - return not state.is_hydrated - - assert await AppHarness._poll_for_async(poll_for_not_hydrated) + new_token_btn = driver.find_element(By.ID, "new_token") + new_token_btn.click() # Trigger event to get a new instance of the state since the old was expired. set_sub("c1", "c1 post expire") @@ -714,41 +690,6 @@ async def poll_for_not_hydrated(): assert l1s.text == "l1s value" assert s1s.text == "s1s value" - # Get the backend state and ensure the values are still set - async def get_sub_state(): - root_state = await client_side.get_state( - _substate_key(token or "", sub_state_name) - ) - state = root_state.substates[client_side.get_state_name("_client_side_state")] - return state.substates[client_side.get_state_name("_client_side_sub_state")] - - async def poll_for_c1_set(): - sub_state = await get_sub_state() - return sub_state.c1 == "c1 post expire" # pyright: ignore[reportAttributeAccessIssue] - - assert await AppHarness._poll_for_async(poll_for_c1_set) - sub_state = await get_sub_state() - assert sub_state.c1 == "c1 post expire" # pyright: ignore[reportAttributeAccessIssue] - assert sub_state.c2 == "c2 value" # pyright: ignore[reportAttributeAccessIssue] - assert sub_state.c3 == "" # pyright: ignore[reportAttributeAccessIssue] - assert sub_state.c4 == "c4 value" # pyright: ignore[reportAttributeAccessIssue] - assert sub_state.c5 == "c5 value" # pyright: ignore[reportAttributeAccessIssue] - assert sub_state.c6 == "c6 value" # pyright: ignore[reportAttributeAccessIssue] - assert sub_state.c7 == "c7 value" # pyright: ignore[reportAttributeAccessIssue] - assert sub_state.l1 == "l1 value" # pyright: ignore[reportAttributeAccessIssue] - assert sub_state.l2 == "l2 value" # pyright: ignore[reportAttributeAccessIssue] - assert sub_state.l3 == "l3 value" # pyright: ignore[reportAttributeAccessIssue] - assert sub_state.l4 == "l4 value" # pyright: ignore[reportAttributeAccessIssue] - assert sub_state.s1 == "s1 value" # pyright: ignore[reportAttributeAccessIssue] - assert sub_state.s2 == "s2 value" # pyright: ignore[reportAttributeAccessIssue] - assert sub_state.s3 == "s3 value" # pyright: ignore[reportAttributeAccessIssue] - sub_sub_state = sub_state.substates[ - client_side.get_state_name("_client_side_sub_sub_state") - ] - assert sub_sub_state.c1s == "c1s value" # pyright: ignore[reportAttributeAccessIssue] - assert sub_sub_state.l1s == "l1s value" # pyright: ignore[reportAttributeAccessIssue] - assert sub_sub_state.s1s == "s1s value" # pyright: ignore[reportAttributeAccessIssue] - # clear the cookie jar and local storage, ensure state reset to default driver.delete_all_cookies() local_storage.clear() diff --git a/tests/integration/test_component_state.py b/tests/integration/test_component_state.py index 0230fa129e4..0d674b55b5f 100644 --- a/tests/integration/test_component_state.py +++ b/tests/integration/test_component_state.py @@ -5,7 +5,6 @@ import pytest from selenium.webdriver.common.by import By -from reflex.state import State, _substate_key from reflex.testing import AppHarness from . import utils @@ -32,15 +31,58 @@ def increment(self): self.count += 1 self._be = self.count # pyright: ignore [reportAttributeAccessIssue] + @rx.event + def assert_be(self, value: E): + assert self._backend_vars != self.backend_vars + assert self._be == int(value) # pyright: ignore [reportAttributeAccessIssue, reportArgumentType] + + @rx.event + def assert_be_none(self): + assert self._backend_vars == self.backend_vars + assert self._be is None # pyright: ignore [reportAttributeAccessIssue] + + @rx.event + def assert_be_int(self, value: int): + assert self._be_int == value # pyright: ignore [reportAttributeAccessIssue] + + @rx.event + def assert_be_str(self, value: str): + assert self._be_str == value # pyright: ignore [reportAttributeAccessIssue] + @classmethod def get_component(cls, *children, **props): + eid = props.get("id", "default") return rx.vstack( *children, - rx.heading(cls.count, id=f"count-{props.get('id', 'default')}"), + rx.heading(cls.count, id=f"count-{eid}"), rx.button( "Increment", on_click=cls.increment, - id=f"button-{props.get('id', 'default')}", + id=f"button-{eid}", + ), + rx.form( + rx.input(id=f"{eid}-assert-be-value", name="be_value"), + rx.button( + "Assert _be", + id=f"{eid}-assert-be", + ), + on_submit=lambda fd: cls.assert_be(fd.to(dict)["be_value"]), # pyright: ignore [reportAttributeAccessIssue] + reset_on_submit=True, + ), + rx.button( + "Assert _be_none", + id=f"{eid}-assert-be-none", + on_click=cls.assert_be_none, + ), + rx.button( + "Assert _be_int == 0", + id=f"{eid}-assert-be-int", + on_click=cls.assert_be_int(0), + ), + rx.button( + "Assert _be_str == '42'", + id=f"{eid}-assert-be-str", + on_click=cls.assert_be_str("42"), ), **props, ) @@ -120,8 +162,7 @@ def component_state_app(tmp_path) -> Generator[AppHarness, None, None]: yield harness -@pytest.mark.asyncio -async def test_component_state_app(component_state_app: AppHarness): +def test_component_state_app(component_state_app: AppHarness): """Increment counters independently. Args: @@ -132,7 +173,6 @@ async def test_component_state_app(component_state_app: AppHarness): ss = utils.SessionStorage(driver) assert AppHarness._poll_for(lambda: ss.get("token") is not None), "token not found" - root_state_token = _substate_key(ss.get("token"), State) count_a = driver.find_element(By.ID, "count-a") count_b = driver.find_element(By.ID, "count-b") @@ -141,16 +181,9 @@ async def test_component_state_app(component_state_app: AppHarness): button_inc_a = driver.find_element(By.ID, "inc-a") # Check that backend vars in mixins are okay - a_state_name = driver.find_element(By.ID, "a_state_name").text - b_state_name = driver.find_element(By.ID, "b_state_name").text - root_state = await component_state_app.get_state(root_state_token) - a_state = root_state.substates[a_state_name] - b_state = root_state.substates[b_state_name] - assert a_state._backend_vars == a_state.backend_vars - assert a_state._backend_vars == b_state._backend_vars - assert a_state._backend_vars["_be"] is None - assert a_state._backend_vars["_be_int"] == 0 - assert a_state._backend_vars["_be_str"] == "42" + driver.find_element(By.ID, "a-assert-be-none").click() + driver.find_element(By.ID, "a-assert-be-int").click() + driver.find_element(By.ID, "a-assert-be-str").click() assert count_a.text == "0" @@ -163,13 +196,9 @@ async def test_component_state_app(component_state_app: AppHarness): button_inc_a.click() assert component_state_app.poll_for_content(count_a, exp_not_equal="2") == "3" - root_state = await component_state_app.get_state(root_state_token) - a_state = root_state.substates[a_state_name] - b_state = root_state.substates[b_state_name] - assert a_state._backend_vars != a_state.backend_vars - assert a_state._be == a_state._backend_vars["_be"] == 3 # pyright: ignore[reportAttributeAccessIssue] - assert b_state._be is None # pyright: ignore[reportAttributeAccessIssue] - assert b_state._backend_vars["_be"] is None + driver.find_element(By.ID, "a-assert-be-value").send_keys("3") + driver.find_element(By.ID, "a-assert-be").click() + driver.find_element(By.ID, "b-assert-be-none").click() assert count_b.text == "0" @@ -179,11 +208,8 @@ async def test_component_state_app(component_state_app: AppHarness): button_b.click() assert component_state_app.poll_for_content(count_b, exp_not_equal="1") == "2" - root_state = await component_state_app.get_state(root_state_token) - a_state = root_state.substates[a_state_name] - b_state = root_state.substates[b_state_name] - assert b_state._backend_vars != b_state.backend_vars - assert b_state._be == b_state._backend_vars["_be"] == 2 # pyright: ignore[reportAttributeAccessIssue] + driver.find_element(By.ID, "b-assert-be-value").send_keys("2") + driver.find_element(By.ID, "b-assert-be").click() # Check locally-defined substate style count_c = driver.find_element(By.ID, "count-c") diff --git a/tests/integration/test_computed_vars.py b/tests/integration/test_computed_vars.py index f4fb7a8d5f2..905b1594cb5 100644 --- a/tests/integration/test_computed_vars.py +++ b/tests/integration/test_computed_vars.py @@ -198,14 +198,6 @@ async def test_computed_vars( """ assert computed_vars.app_instance is not None - state_name = computed_vars.get_state_name("_state") - full_state_name = computed_vars.get_full_state_name(["_state"]) - token = f"{token}_{full_state_name}" - state = (await computed_vars.get_state(token)).substates[state_name] - assert state is not None - assert state.count1_backend == 0 # pyright: ignore[reportAttributeAccessIssue] - assert state._count1_backend == 0 # pyright: ignore[reportAttributeAccessIssue] - # test that backend var is not rendered count1_backend = driver.find_element(By.ID, "count1_backend") assert count1_backend @@ -257,12 +249,6 @@ async def test_computed_vars( computed_vars.poll_for_content(depends_on_count, timeout=2, exp_not_equal="0") == "1" ) - state = (await computed_vars.get_state(token)).substates[state_name] - assert state is not None - assert state.count1_backend == 1 # pyright: ignore[reportAttributeAccessIssue] - assert count1_backend.text == "" - assert state._count1_backend == 1 # pyright: ignore[reportAttributeAccessIssue] - assert count1_backend_.text == "" mark_dirty.click() with pytest.raises(TimeoutError): diff --git a/tests/integration/test_connection_banner.py b/tests/integration/test_connection_banner.py index 8ff1516005d..f165ae655cd 100644 --- a/tests/integration/test_connection_banner.py +++ b/tests/integration/test_connection_banner.py @@ -1,9 +1,11 @@ """Test case for displaying the connection banner when the websocket drops.""" import pickle -from collections.abc import Generator +from collections.abc import AsyncGenerator, Generator import pytest +import pytest_asyncio +from redis.asyncio import Redis from reflex_core import constants from selenium.common.exceptions import NoSuchElementException from selenium.webdriver.common.by import By @@ -147,12 +149,37 @@ def _assert_token(connection_banner, driver) -> str: return ss.get("token") +@pytest_asyncio.fixture +async def redis( + connection_banner: AppHarness, +) -> AsyncGenerator[Redis | None]: + """Get the Redis instance from the StateManagerRedis used in the connection_banner test. + + Args: + connection_banner: AppHarness instance. + + Yields: + A Redis instance or None if the StateManager is not Redis. + """ + from reflex.utils.prerequisites import get_redis + + redis = None + if (app := connection_banner.app_instance) is not None and isinstance( + app.state_manager, StateManagerRedis + ): + redis = get_redis() + yield redis + if redis is not None: + await redis.close() + + @pytest.mark.asyncio -async def test_connection_banner(connection_banner: AppHarness): +async def test_connection_banner(connection_banner: AppHarness, redis: Redis | None): """Test that the connection banner is displayed when the websocket drops. Args: connection_banner: AppHarness instance. + redis: Redis instance used by the app, or None if not using Redis. """ assert connection_banner.app_instance is not None assert connection_banner.backend is not None @@ -165,11 +192,9 @@ async def test_connection_banner(connection_banner: AppHarness): app_token_manager = connection_banner.token_manager() assert token in app_token_manager.token_to_sid sid_before = app_token_manager.token_to_sid[token] - if isinstance(connection_banner.state_manager, StateManagerRedis): + if redis is not None: assert isinstance(app_token_manager, RedisTokenManager) - assert await connection_banner.state_manager.redis.get( - app_token_manager._get_redis_key(token) - ) == pickle.dumps( + assert await redis.get(app_token_manager._get_redis_key(token)) == pickle.dumps( SocketRecord(instance_id=app_token_manager.instance_id, sid=sid_before) ) @@ -197,14 +222,9 @@ async def test_connection_banner(connection_banner: AppHarness): # The token association should have been removed when the server exited. assert token not in app_token_manager.token_to_sid - if isinstance(connection_banner.state_manager, StateManagerRedis): + if redis is not None: assert isinstance(app_token_manager, RedisTokenManager) - assert ( - await connection_banner.state_manager.redis.get( - app_token_manager._get_redis_key(token) - ) - is None - ) + assert await redis.get(app_token_manager._get_redis_key(token)) is None # Increment the counter with backend down increment_button.click() @@ -213,9 +233,6 @@ async def test_connection_banner(connection_banner: AppHarness): # Bring the backend back up connection_banner._start_backend(port=backend_port) - # Create a new StateManager to avoid async loop affinity issues w/ redis - await connection_banner._reset_backend_state_manager() - # Banner should be gone now AppHarness.expect(lambda: not has_error_modal(driver)) @@ -224,11 +241,9 @@ async def test_connection_banner(connection_banner: AppHarness): # Make sure the new connection has a different websocket sid. sid_after = app_token_manager.token_to_sid[token] assert sid_before != sid_after - if isinstance(connection_banner.state_manager, StateManagerRedis): + if redis is not None: assert isinstance(app_token_manager, RedisTokenManager) - assert await connection_banner.state_manager.redis.get( - app_token_manager._get_redis_key(token) - ) == pickle.dumps( + assert await redis.get(app_token_manager._get_redis_key(token)) == pickle.dumps( SocketRecord(instance_id=app_token_manager.instance_id, sid=sid_after) ) diff --git a/tests/integration/test_dynamic_routes.py b/tests/integration/test_dynamic_routes.py index 2f573f192c5..7c3e7641f86 100644 --- a/tests/integration/test_dynamic_routes.py +++ b/tests/integration/test_dynamic_routes.py @@ -3,7 +3,8 @@ from __future__ import annotations import asyncio -from collections.abc import Callable, Coroutine, Generator +import json +from collections.abc import Generator from urllib.parse import urlsplit import pytest @@ -11,7 +12,7 @@ from reflex.testing import AppHarness, WebDriver -from .utils import poll_for_navigation +from .utils import poll_assert_event_order, poll_for_navigation def DynamicRoute(): @@ -47,6 +48,10 @@ def next_page(self) -> str: except ValueError: return "0" + @rx.var + def params(self) -> dict[str, str | list[str]]: + return self.router.page.params + def index(): return rx.fragment( rx.input( @@ -68,12 +73,14 @@ def index(): id="link_page_next", ), rx.link("missing", href="/missing", id="link_missing"), - rx.list( # pyright: ignore [reportAttributeAccessIssue] + rx.vstack( rx.foreach( DynamicState.order, # pyright: ignore [reportAttributeAccessIssue] - lambda i: rx.list_item(rx.text(i)), + rx.text, ), + id="event_order", ), + rx.text(DynamicState.params.to_string(), id="params"), ) class ArgState(rx.State): @@ -215,46 +222,10 @@ def token(dynamic_route: AppHarness, driver: WebDriver) -> str: return token -@pytest.fixture -def poll_for_order( - dynamic_route: AppHarness, token: str -) -> Callable[[list[str]], Coroutine[None, None, None]]: - """Poll for the order list to match the expected order. - - Args: - dynamic_route: harness for DynamicRoute app. - token: The token visible in the driver browser. - - Returns: - An async function that polls for the order list to match the expected order. - """ - dynamic_state_name = dynamic_route.get_state_name("_dynamic_state") - dynamic_state_full_name = dynamic_route.get_full_state_name(["_dynamic_state"]) - - async def _poll_for_order(exp_order: list[str]): - async def _backend_state(): - return await dynamic_route.get_state(f"{token}_{dynamic_state_full_name}") - - async def _check(): - return (await _backend_state()).substates[ - dynamic_state_name - ].order == exp_order # pyright: ignore[reportAttributeAccessIssue] - - await AppHarness._poll_for_async(_check, timeout=10) - assert ( - list((await _backend_state()).substates[dynamic_state_name].order) # pyright: ignore[reportAttributeAccessIssue] - == exp_order - ) - - return _poll_for_order - - -@pytest.mark.asyncio -async def test_on_load_navigate( +def test_on_load_navigate( dynamic_route: AppHarness, driver: WebDriver, token: str, - poll_for_order: Callable[[list[str]], Coroutine[None, None, None]], ): """Click links to navigate between dynamic pages with on_load event. @@ -262,9 +233,7 @@ async def test_on_load_navigate( dynamic_route: harness for DynamicRoute app. driver: WebDriver instance. token: The token visible in the driver browser. - poll_for_order: function that polls for the order list to match the expected order. """ - dynamic_state_full_name = dynamic_route.get_full_state_name(["_dynamic_state"]) assert dynamic_route.app_instance is not None link = driver.find_element(By.ID, "link_page_next") assert link @@ -290,7 +259,7 @@ async def test_on_load_navigate( page_id_input, exp_not_equal=str(ix - 1) ) == str(ix) assert dynamic_route.poll_for_value(raw_path_input) == f"/page/{ix}" - await poll_for_order(exp_order) + poll_assert_event_order(driver, exp_order) frontend_url = dynamic_route.frontend_url assert frontend_url @@ -300,48 +269,46 @@ async def test_on_load_navigate( exp_order += ["/page/[page_id]-10"] with poll_for_navigation(driver): driver.get(f"{frontend_url}/page/10") - await poll_for_order(exp_order) + poll_assert_event_order(driver, exp_order) # make sure internal nav still hydrates after redirect exp_order += ["/page/[page_id]-11"] link = driver.find_element(By.ID, "link_page_next") with poll_for_navigation(driver): link.click() - await poll_for_order(exp_order) + poll_assert_event_order(driver, exp_order) # load same page with a query param and make sure it passes through exp_order += ["/page/[page_id]-11"] with poll_for_navigation(driver): driver.get(f"{driver.current_url}?foo=bar") - await poll_for_order(exp_order) - assert ( - await dynamic_route.get_state(f"{token}_{dynamic_state_full_name}") - ).router.page.params["foo"] == "bar" + poll_assert_event_order(driver, exp_order) + params_json = driver.find_element(By.ID, "params") + params = json.loads(params_json.text) + assert params == {"foo": "bar", "page_id": "11"} # hit a 404 and ensure we still hydrate exp_order += ["/404-no page id"] with poll_for_navigation(driver): driver.get(f"{frontend_url}/missing") - await poll_for_order(exp_order) # browser nav should still trigger hydration exp_order += ["/page/[page_id]-11"] with poll_for_navigation(driver): driver.back() - await poll_for_order(exp_order) + poll_assert_event_order(driver, exp_order) # next/link to a 404 and ensure we still hydrate exp_order += ["/404-no page id"] link = driver.find_element(By.ID, "link_missing") with poll_for_navigation(driver): link.click() - await poll_for_order(exp_order) # hit a page that redirects back to dynamic page exp_order += ["on_load_redir-{'foo': 'bar', 'page_id': '0'}", "/page/[page_id]-0"] with poll_for_navigation(driver): driver.get(f"{frontend_url}/redirect-page/0/?foo=bar") - await poll_for_order(exp_order) + poll_assert_event_order(driver, exp_order) # should have redirected back to page 0 assert urlsplit(driver.current_url).path.removesuffix("/") == "/page/0" @@ -349,21 +316,18 @@ async def test_on_load_navigate( exp_order += ["on-load-static"] with poll_for_navigation(driver): driver.get(f"{frontend_url}/page/static") - await poll_for_order(exp_order) + poll_assert_event_order(driver, exp_order) -@pytest.mark.asyncio -async def test_on_load_navigate_non_dynamic( +def test_on_load_navigate_non_dynamic( dynamic_route: AppHarness, driver: WebDriver, - poll_for_order: Callable[[list[str]], Coroutine[None, None, None]], ): """Click links to navigate between static pages with on_load event. Args: dynamic_route: harness for DynamicRoute app. driver: WebDriver instance. - poll_for_order: function that polls for the order list to match the expected order. """ assert dynamic_route.app_instance is not None link = driver.find_element(By.ID, "link_page_x") @@ -372,7 +336,7 @@ async def test_on_load_navigate_non_dynamic( with poll_for_navigation(driver): link.click() assert urlsplit(driver.current_url).path.removesuffix("/") == "/static/x" - await poll_for_order(["/static/x-no page id"]) + poll_assert_event_order(driver, ["/static/x-no page id"]) # go back to the index and navigate back to the static route link = driver.find_element(By.ID, "link_index") @@ -384,13 +348,13 @@ async def test_on_load_navigate_non_dynamic( with poll_for_navigation(driver): link.click() assert urlsplit(driver.current_url).path.removesuffix("/") == "/static/x" - await poll_for_order(["/static/x-no page id", "/static/x-no page id"]) + poll_assert_event_order(driver, ["/static/x-no page id", "/static/x-no page id"]) for _ in range(3): link = driver.find_element(By.ID, "link_page_x") link.click() assert urlsplit(driver.current_url).path.removesuffix("/") == "/static/x" - await poll_for_order(["/static/x-no page id"] * 5) + poll_assert_event_order(driver, ["/static/x-no page id"] * 5) @pytest.mark.asyncio diff --git a/tests/integration/test_event_actions.py b/tests/integration/test_event_actions.py index 801d1c24de9..f253c306fd1 100644 --- a/tests/integration/test_event_actions.py +++ b/tests/integration/test_event_actions.py @@ -4,7 +4,7 @@ import asyncio import time -from collections.abc import Callable, Coroutine, Generator +from collections.abc import Generator import pytest from selenium.webdriver.common.by import By @@ -12,8 +12,8 @@ from selenium.webdriver.support import expected_conditions as EC from selenium.webdriver.support.wait import WebDriverWait -from reflex.state import BaseState from reflex.testing import AppHarness, WebDriver +from tests.integration.utils import poll_assert_event_order def TestEventAction(): @@ -158,11 +158,12 @@ def index(): 200 ).stop_propagation, ), - rx.list( # pyright: ignore [reportAttributeAccessIssue] + rx.vstack( rx.foreach( EventActionState.order, - rx.list_item, + rx.text, ), + id="event_order", ), on_click=EventActionState.on_click("outer"), # pyright: ignore [reportCallIssue] ), rx.form( @@ -245,36 +246,6 @@ def token(event_action: AppHarness, driver: WebDriver) -> str: return token -async def _backend_state(app: AppHarness, token: str) -> BaseState: - state_name = app.get_state_name("_event_action_state") - state_full_name = app.get_full_state_name(["_event_action_state"]) - return (await app.get_state(f"{token}_{state_full_name}")).substates[state_name] - - -@pytest.fixture -def poll_for_order( - event_action: AppHarness, token: str -) -> Callable[[list[str]], Coroutine[None, None, None]]: - """Poll for the order list to match the expected order. - - Args: - event_action: harness for TestEventAction app. - token: The token visible in the driver browser. - - Returns: - An async function that polls for the order list to match the expected order. - """ - - async def _poll_for_order(exp_order: list[str]): - async def _check(): - return (await _backend_state(event_action, token)).order == exp_order # pyright: ignore[reportAttributeAccessIssue] - - await AppHarness._poll_for_async(_check) - assert (await _backend_state(event_action, token)).order == exp_order # pyright: ignore[reportAttributeAccessIssue] - - return _poll_for_order - - @pytest.mark.parametrize( ("element_id", "exp_order"), [ @@ -303,7 +274,6 @@ async def _check(): @pytest.mark.asyncio async def test_event_actions( driver: WebDriver, - poll_for_order: Callable[[list[str]], Coroutine[None, None, None]], element_id: str, exp_order: list[str], ): @@ -311,7 +281,6 @@ async def test_event_actions( Args: driver: WebDriver instance. - poll_for_order: function that polls for the order list to match the expected order. element_id: The id of the element to click. exp_order: The expected order of events. """ @@ -324,7 +293,7 @@ async def test_event_actions( if "on_click:outer" not in exp_order: # really make sure the outer event is not fired await asyncio.sleep(0.5) - await poll_for_order(exp_order) + poll_assert_event_order(driver, exp_order) if element_id.startswith("link") and "prevent-default" not in element_id: assert driver.current_url != prev_url @@ -332,8 +301,7 @@ async def test_event_actions( assert driver.current_url == prev_url -@pytest.mark.asyncio -async def test_event_actions_throttle_debounce( +def test_event_actions_throttle_debounce( event_action: AppHarness, driver: WebDriver, token: str, @@ -358,14 +326,16 @@ async def test_event_actions_throttle_debounce( btn_debounce.click() # Wait until the debounce event shows up - async def _debounce_received(): - state = await _backend_state(event_action, token) - return state.order and state.order[-1] == "on_click_debounce" # pyright: ignore[reportAttributeAccessIssue] + def _debounce_received(): + order = driver.find_elements(By.XPATH, '//*[@id="event_order"]/p') + return len(order) and order[-1].text == "on_click_debounce" - await AppHarness._poll_for_async(_debounce_received) + AppHarness._poll_for(_debounce_received) # This test is inherently racy, so ensure the `on_click_throttle` event is fired approximately the expected number of times. - final_event_order = (await _backend_state(event_action, token)).order # pyright: ignore[reportAttributeAccessIssue] + final_event_order = [ + elem.text for elem in driver.find_elements(By.XPATH, '//*[@id="event_order"]/p') + ] n_on_click_throttle_received = final_event_order.count("on_click_throttle") print( f"Expected ~{exp_events} on_click_throttle events, received {n_on_click_throttle_received}" @@ -377,16 +347,13 @@ async def _debounce_received(): @pytest.mark.usefixtures("token") -@pytest.mark.asyncio -async def test_event_actions_dialog_form_in_form( +def test_event_actions_dialog_form_in_form( driver: WebDriver, - poll_for_order: Callable[[list[str]], Coroutine[None, None, None]], ): """Click links and buttons and assert on fired events. Args: driver: WebDriver instance. - poll_for_order: function that polls for the order list to match the expected order. """ open_dialog_id = "btn-dialog" submit_button_id = "btn-submit" @@ -400,4 +367,4 @@ async def test_event_actions_dialog_form_in_form( btn_no_events = wait.until(EC.element_to_be_clickable((By.ID, "btn-no-events"))) btn_no_events.location_once_scrolled_into_view btn_no_events.click() - await poll_for_order(["on_submit", "on_click:outer"]) + poll_assert_event_order(driver, ["on_submit", "on_click:outer"]) diff --git a/tests/integration/test_event_chain.py b/tests/integration/test_event_chain.py index 289f12cbe80..eb65c2fe1d2 100644 --- a/tests/integration/test_event_chain.py +++ b/tests/integration/test_event_chain.py @@ -9,6 +9,7 @@ from selenium.webdriver.common.by import By from reflex.testing import AppHarness, WebDriver +from tests.integration.utils import poll_assert_event_order MANY_EVENTS = 50 @@ -146,14 +147,20 @@ def click_yield_interim_value(self): app = rx.App() - token_input = rx.input( - value=State.router.session.client_token, is_read_only=True, id="token" + common_elements = rx.vstack( + rx.input( + value=State.router.session.client_token, is_read_only=True, id="token" + ), + rx.vstack( + rx.foreach(State.event_order, lambda x: rx.text(x)), id="event_order" + ), + rx.input(value=State.is_hydrated, is_read_only=True, id="is_hydrated"), ) @app.add_page def index(): return rx.fragment( - token_input, + common_elements, rx.input(value=State.interim_value, is_read_only=True, id="interim_value"), rx.button( "Return Event", @@ -225,13 +232,13 @@ def index(): def on_load_return_chain(): return rx.fragment( rx.text("return"), - token_input, + common_elements, ) def on_load_yield_chain(): return rx.fragment( rx.text("yield"), - token_input, + common_elements, ) def on_mount_return_chain(): @@ -241,7 +248,7 @@ def on_mount_return_chain(): on_mount=State.on_load_return_chain, on_unmount=lambda: State.event_arg("unmount"), ), - token_input, + common_elements, rx.button("Unmount", on_click=rx.redirect("/"), id="unmount"), ) @@ -255,7 +262,7 @@ def on_mount_yield_chain(): ], on_unmount=State.event_no_args, ), - token_input, + common_elements, rx.button("Unmount", on_click=rx.redirect("/"), id="unmount"), ) @@ -315,6 +322,7 @@ def event_chain_strict(tmp_path_factory) -> Generator[AppHarness, None, None]: with AppHarness.create( root=tmp_path_factory.mktemp("event_chain_strict"), app_source=EventChain, + app_name="event_chain_strict", ) as harness: yield harness @@ -440,8 +448,7 @@ def assert_token(event_chain: AppHarness, driver: WebDriver) -> str: ), ], ) -@pytest.mark.asyncio -async def test_event_chain_click( +def test_event_chain_click( event_chain: AppHarness, driver: WebDriver, button_id: str, @@ -455,19 +462,11 @@ async def test_event_chain_click( button_id: the ID of the button to click exp_event_order: the expected events recorded in the State """ - token = assert_token(event_chain, driver) - state_name = event_chain.get_state_name("_state") + assert_token(event_chain, driver) btn = driver.find_element(By.ID, button_id) btn.click() - async def _has_all_events(): - return len( - (await event_chain.get_state(token)).substates[state_name].event_order # pyright: ignore[reportAttributeAccessIssue] - ) == len(exp_event_order) - - await AppHarness._poll_for_async(_has_all_events) - event_order = (await event_chain.get_state(token)).substates[state_name].event_order # pyright: ignore[reportAttributeAccessIssue] - assert event_order == exp_event_order + poll_assert_event_order(driver, exp_event_order) @pytest.mark.parametrize( @@ -493,8 +492,7 @@ async def _has_all_events(): ), ], ) -@pytest.mark.asyncio -async def test_event_chain_on_load( +def test_event_chain_on_load( event_chain: AppHarness, driver: WebDriver, uri: str, @@ -510,18 +508,15 @@ async def test_event_chain_on_load( """ assert event_chain.frontend_url is not None driver.get(event_chain.frontend_url.removesuffix("/") + uri) - token = assert_token(event_chain, driver) - state_name = event_chain.get_state_name("_state") - - async def _has_all_events(): - return len( - (await event_chain.get_state(token)).substates[state_name].event_order # pyright: ignore[reportAttributeAccessIssue] - ) == len(exp_event_order) + assert_token(event_chain, driver) - await AppHarness._poll_for_async(_has_all_events) - backend_state = (await event_chain.get_state(token)).substates[state_name] - assert backend_state.event_order == exp_event_order # pyright: ignore[reportAttributeAccessIssue] - assert backend_state.is_hydrated is True # pyright: ignore[reportAttributeAccessIssue] + poll_assert_event_order(driver, exp_event_order) + assert ( + event_chain.poll_for_value( + driver.find_element(By.ID, "is_hydrated"), exp_not_equal="false" + ) + == "true" + ) @pytest.mark.parametrize( @@ -541,21 +536,22 @@ async def _has_all_events(): "/on-mount-yield-chain", [ "on_load_yield_chain", - "event_arg:mount", - "event_arg:4", - "event_arg:5", - "event_arg:6", - "event_no_args", + { + "event_arg:4", + "event_arg:5", + "event_arg:6", + "event_arg:mount", + "event_no_args", + }, ], ), ], ) -@pytest.mark.asyncio -async def test_event_chain_on_mount( +def test_event_chain_on_mount( event_chain: AppHarness, driver: WebDriver, uri: str, - exp_event_order: list[str], + exp_event_order: list[str | set[str]], ): """Load the URI, assert that the events are handled in the correct order. @@ -576,18 +572,10 @@ async def test_event_chain_on_mount( unmount_button = AppHarness.poll_for_or_raise_timeout( lambda: driver.find_element(By.ID, "unmount") ) - token = assert_token(event_chain, driver) - state_name = event_chain.get_state_name("_state") + assert_token(event_chain, driver) unmount_button.click() - async def _has_all_events(): - return len( - (await event_chain.get_state(token)).substates[state_name].event_order # pyright: ignore[reportAttributeAccessIssue] - ) == len(exp_event_order) - - await AppHarness._poll_for_async(_has_all_events) - event_order = (await event_chain.get_state(token)).substates[state_name].event_order # pyright: ignore[reportAttributeAccessIssue] - assert list(event_order) == exp_event_order + poll_assert_event_order(driver, exp_event_order) @pytest.mark.parametrize( @@ -597,14 +585,18 @@ async def _has_all_events(): "/on-mount-return-chain", [ "on_load_return_chain", - "event_arg:unmount", - "on_load_return_chain", - "event_arg:1", - "event_arg:2", - "event_arg:3", - "event_arg:1", - "event_arg:2", - "event_arg:3", + { + "event_arg:unmount", + "on_load_return_chain", + "event_arg:1", + "event_arg:2", + "event_arg:3", + }, + { + "event_arg:1", + "event_arg:2", + "event_arg:3", + }, "event_arg:unmount", ], ), @@ -612,27 +604,30 @@ async def _has_all_events(): "/on-mount-yield-chain", [ "on_load_yield_chain", - "event_arg:mount", - "event_no_args", - "on_load_yield_chain", - "event_arg:mount", - "event_arg:4", - "event_arg:5", - "event_arg:6", - "event_arg:4", - "event_arg:5", - "event_arg:6", + { + "event_arg:4", + "event_arg:5", + "event_arg:6", + "event_arg:mount", + "event_no_args", + "on_load_yield_chain", + }, + { + "event_arg:mount", + "event_arg:4", + "event_arg:5", + "event_arg:6", + }, "event_no_args", ], ), ], ) -@pytest.mark.asyncio -async def test_event_chain_on_mount_strict( +def test_event_chain_on_mount_strict( event_chain_strict: AppHarness, driver_strict: WebDriver, uri: str, - exp_event_order: list[str], + exp_event_order: list[str | set[str]], ): """Run the test_event_chain_on_mount test with strict mode enabled. @@ -642,7 +637,7 @@ async def test_event_chain_on_mount_strict( uri: the page to load exp_event_order: the expected events recorded in the State """ - await test_event_chain_on_mount( + test_event_chain_on_mount( event_chain=event_chain_strict, driver=driver_strict, uri=uri, diff --git a/tests/integration/test_form_submit.py b/tests/integration/test_form_submit.py index efe5708758e..35aab1273a4 100644 --- a/tests/integration/test_form_submit.py +++ b/tests/integration/test_form_submit.py @@ -2,6 +2,7 @@ import asyncio import functools +import json from collections.abc import Generator import pytest @@ -21,7 +22,7 @@ def FormSubmit(form_component): import reflex as rx class FormState(rx.State): - form_data: dict = {} + form_data: rx.Field[dict] = rx.field(default_factory=dict) var_options: list[str] = ["option3", "option4"] @@ -65,6 +66,7 @@ def index(): on_submit=FormState.form_submit, custom_attrs={"action": "/invalid"}, ), + rx.text(FormState.form_data.to_string(), id="form-data"), rx.spacer(), height="100vh", ) @@ -79,7 +81,7 @@ def FormSubmitName(form_component): import reflex as rx class FormState(rx.State): - form_data: dict = {} + form_data: rx.Field[dict] = rx.field(default_factory=dict) val: str = "foo" options: list[str] = ["option1", "option2"] @@ -122,6 +124,7 @@ def index(): on_submit=FormState.form_submit, custom_attrs={"action": "/invalid"}, ), + rx.text(FormState.form_data.to_string(), id="form-data"), rx.spacer(), height="100vh", ) @@ -225,20 +228,12 @@ async def test_submit(driver, form_submit: AppHarness): submit_input = driver.find_element(By.CLASS_NAME, "rt-Button") submit_input.click() - state_name = form_submit.get_state_name("_form_state") - full_state_name = form_submit.get_full_state_name(["_form_state"]) - - async def get_form_data(): - return ( - (await form_submit.get_state(f"{token}_{full_state_name}")) - .substates[state_name] - .form_data # pyright: ignore[reportAttributeAccessIssue] - ) - # wait for the form data to arrive at the backend - form_data = await AppHarness._poll_for_async(get_form_data) + form_submit.poll_for_content( + driver.find_element(By.ID, "form-data"), exp_not_equal="{}" + ) + form_data = json.loads(driver.find_element(By.ID, "form-data").text) assert isinstance(form_data, dict) - form_data = format.collect_form_dict_names(form_data) print(form_data) diff --git a/tests/integration/test_input.py b/tests/integration/test_input.py index e4859a3a9de..fda2752c289 100644 --- a/tests/integration/test_input.py +++ b/tests/integration/test_input.py @@ -20,6 +20,10 @@ class State(rx.State): def set_text(self, text: str): self.text = text + @rx.event + def do_clear(self): + self.text = "" + app = rx.App() @app.add_page @@ -28,6 +32,11 @@ def index(): rx.input( value=State.router.session.client_token, is_read_only=True, id="token" ), + rx.button( + "Clear State", + on_click=State.do_clear, + id="clear-backend", + ), rx.input( id="debounce_input_input", on_change=State.set_text, @@ -72,8 +81,7 @@ def fully_controlled_input(tmp_path) -> Generator[AppHarness, None, None]: yield harness -@pytest.mark.asyncio -async def test_fully_controlled_input(fully_controlled_input: AppHarness): +def test_fully_controlled_input(fully_controlled_input: AppHarness): """Type text after moving cursor. Update text on backend. Args: @@ -91,13 +99,6 @@ async def test_fully_controlled_input(fully_controlled_input: AppHarness): token = fully_controlled_input.poll_for_value(token_input) assert token - state_name = fully_controlled_input.get_state_name("_state") - full_state_name = fully_controlled_input.get_full_state_name(["_state"]) - - async def get_state_text(): - state = await fully_controlled_input.get_state(f"{token}_{full_state_name}") - return state.substates[state_name].text # pyright: ignore[reportAttributeAccessIssue] - # ensure defaults are set correctly assert ( fully_controlled_input.poll_for_value( @@ -142,15 +143,10 @@ async def get_state_text(): lambda: fully_controlled_input.poll_for_value(value_input) == "ifoonitial" ) assert debounce_input.get_attribute("value") == "ifoonitial" - assert await get_state_text() == "ifoonitial" assert fully_controlled_input.poll_for_value(plain_value_input) == "ifoonitial" # clear the input on the backend - async with fully_controlled_input.modify_state( - f"{token}_{full_state_name}" - ) as state: - state.substates[state_name].text = "" - assert await get_state_text() == "" + driver.find_element(By.ID, "clear-backend").click() assert ( fully_controlled_input.poll_for_value( debounce_input, exp_not_equal="ifoonitial" @@ -166,7 +162,6 @@ async def get_state_text(): ) ) assert debounce_input.get_attribute("value") == "getting testing done" - assert await get_state_text() == "getting testing done" assert ( fully_controlled_input.poll_for_value(plain_value_input) == "getting testing done" @@ -181,7 +176,6 @@ async def get_state_text(): ) assert debounce_input.get_attribute("value") == "overwrite the state" assert on_change_input.get_attribute("value") == "overwrite the state" - assert await get_state_text() == "overwrite the state" assert ( fully_controlled_input.poll_for_value(plain_value_input) == "overwrite the state" diff --git a/tests/integration/test_memory_state_manager_expiration.py b/tests/integration/test_memory_state_manager_expiration.py index a6feda0f0ef..f4d3e88d7dc 100644 --- a/tests/integration/test_memory_state_manager_expiration.py +++ b/tests/integration/test_memory_state_manager_expiration.py @@ -57,7 +57,9 @@ def memory_expiration_app( app_name=f"memory_expiration_{app_harness_env.__name__.lower()}", app_source=MemoryExpirationApp, ) as harness: - assert isinstance(harness.state_manager, StateManagerMemory) + assert isinstance( + getattr(harness.app_instance, "state_manager", None), StateManagerMemory + ) yield harness diff --git a/tests/integration/test_upload.py b/tests/integration/test_upload.py index c698efb2b0c..c2b60d1c61f 100644 --- a/tests/integration/test_upload.py +++ b/tests/integration/test_upload.py @@ -3,10 +3,11 @@ from __future__ import annotations import asyncio +import json import time from collections.abc import Generator from pathlib import Path -from typing import Any, cast +from typing import Any from urllib.parse import urlsplit import pytest @@ -21,12 +22,14 @@ def UploadFile(): """App for testing dynamic routes.""" + import shutil + import reflex as rx LARGE_DATA = "DUMMY" * 1024 * 512 class UploadState(rx.State): - _file_data: dict[str, str] = {} + upload_done: rx.Field[bool] = rx.field(False) event_order: rx.Field[list[str]] = rx.field([]) progress_dicts: rx.Field[list[dict]] = rx.field([]) stream_progress_dicts: rx.Field[list[dict]] = rx.field([]) @@ -38,29 +41,42 @@ class UploadState(rx.State): @rx.event async def handle_upload(self, files: list[rx.UploadFile]): + self.upload_done = False for file in files: upload_data = await file.read() - self._file_data[file.name or ""] = upload_data.decode("utf-8") + if not file.name: + continue + local_file = rx.get_upload_dir() / file.name + local_file.parent.mkdir(parents=True, exist_ok=True) + local_file.write_bytes(upload_data) + self.upload_done = True @rx.event async def handle_upload_secondary(self, files: list[rx.UploadFile]): + self.upload_done = False for file in files: upload_data = await file.read() - self._file_data[file.name or ""] = upload_data.decode("utf-8") + if not file.name: + continue + local_file = rx.get_upload_dir() / file.name + local_file.parent.mkdir(parents=True, exist_ok=True) + local_file.write_bytes(upload_data) self.large_data = LARGE_DATA yield UploadState.chain_event @rx.event def upload_progress(self, progress): assert progress - self.event_order.append("upload_progress") + print(self.event_order) self.progress_dicts.append(progress) @rx.event def chain_event(self): assert self.large_data == LARGE_DATA self.large_data = "" + self.upload_done = True self.event_order.append("chain_event") + print(self.event_order) @rx.event def stream_upload_progress(self, progress): @@ -69,17 +85,23 @@ def stream_upload_progress(self, progress): @rx.event async def handle_upload_tertiary(self, files: list[rx.UploadFile]): + self.upload_done = False for file in files: (rx.get_upload_dir() / (file.name or "INVALID")).write_bytes( await file.read() ) + self.upload_done = True @rx.event async def handle_upload_quaternary(self, files: list[rx.UploadFile]): + self.upload_done = False self.quaternary_names = [file.name for file in files if file.name] + self.upload_done = True @rx.event(background=True) async def handle_upload_stream(self, chunk_iter: rx.UploadChunkIterator): + async with self: + self.upload_done = False upload_dir = rx.get_upload_dir() / "streaming" file_handles: dict[str, Any] = {} @@ -106,11 +128,17 @@ async def handle_upload_stream(self, chunk_iter: rx.UploadChunkIterator): async with self: self.stream_completed_files = sorted(file_handles) + self.upload_done = True @rx.event def do_download(self): return rx.download(rx.get_upload_url("test.txt")) + @rx.event + def clear_uploads(self): + shutil.rmtree(rx.get_upload_dir(), ignore_errors=True) + self.reset() + def index(): return rx.vstack( rx.input( @@ -118,6 +146,16 @@ def index(): read_only=True, id="token", ), + rx.input( + value=UploadState.upload_done.to_string(), + read_only=True, + id="upload_done", + ), + rx.button( + "Clear Uploaded Files", + id="clear_uploads", + on_click=UploadState.clear_uploads, + ), rx.heading("Default Upload"), rx.upload.root( rx.vstack( @@ -177,7 +215,8 @@ def index(): rx.foreach( UploadState.progress_dicts, lambda d: rx.text(d.to_string()), - ) + ), + id="progress_dicts", ), rx.button( "Cancel", @@ -265,6 +304,13 @@ def index(): UploadState.stream_completed_files.to_string(), id="stream_completed_files", ), + rx.vstack( + rx.foreach( + UploadState.stream_progress_dicts, + lambda d: rx.text(d.to_string()), + ), + id="stream_progress_dicts", + ), rx.text(UploadState.event_order.to_string(), id="event-order"), ) @@ -282,11 +328,18 @@ def upload_file(tmp_path_factory) -> Generator[AppHarness, None, None]: Yields: running AppHarness instance """ - with AppHarness.create( - root=tmp_path_factory.mktemp("upload_file"), - app_source=UploadFile, - ) as harness: - yield harness + monkeypatch = pytest.MonkeyPatch() + monkeypatch.setenv( + "REFLEX_UPLOADED_FILES_DIR", str(tmp_path_factory.mktemp("uploaded_files")) + ) + try: + with AppHarness.create( + root=tmp_path_factory.mktemp("upload_file"), + app_source=UploadFile, + ) as harness: + yield harness + finally: + monkeypatch.undo() @pytest.fixture @@ -327,8 +380,7 @@ def poll_for_token(driver: WebDriver, upload_file: AppHarness) -> str: @pytest.mark.parametrize("secondary", [False, True]) -@pytest.mark.asyncio -async def test_upload_file( +def test_upload_file( tmp_path, upload_file: AppHarness, driver: WebDriver, secondary: bool ): """Submit a file upload and check that it arrived on the backend. @@ -340,10 +392,9 @@ async def test_upload_file( secondary: whether to use the secondary upload form """ assert upload_file.app_instance is not None - token = poll_for_token(driver, upload_file) - full_state_name = upload_file.get_full_state_name(["_upload_state"]) - state_name = upload_file.get_state_name("_upload_state") - substate_token = f"{token}_{full_state_name}" + poll_for_token(driver, upload_file) + clear_btn = driver.find_element(By.ID, "clear_uploads") + clear_btn.click() suffix = "_secondary" if secondary else "" @@ -366,27 +417,20 @@ async def test_upload_file( selected_files = driver.find_element(By.ID, f"selected_files{suffix}") assert Path(selected_files.text).name == Path(exp_name).name + # Wait for the upload to complete. + upload_done = driver.find_element(By.ID, "upload_done") + assert upload_file.poll_for_value(upload_done, exp_not_equal="false") == "true" + if secondary: event_order_displayed = driver.find_element(By.ID, "event-order") AppHarness.expect(lambda: "chain_event" in event_order_displayed.text) - - state = await upload_file.get_state(substate_token) - # only the secondary form tracks progress and chain events - assert state.substates[state_name].event_order.count("upload_progress") == 1 # pyright: ignore[reportAttributeAccessIssue] - assert state.substates[state_name].event_order.count("chain_event") == 1 # pyright: ignore[reportAttributeAccessIssue] + progress_dicts = driver.find_elements(By.XPATH, "//*[@id='progress_dicts']/p") + assert len(progress_dicts) > 0 + assert json.loads(progress_dicts[-1].text)["progress"] == 1 # look up the backend state and assert on uploaded contents - async def get_file_data(): - return ( - (await upload_file.get_state(substate_token)) - .substates[state_name] - ._file_data # pyright: ignore[reportAttributeAccessIssue] - ) - - file_data = await AppHarness._poll_for_async(get_file_data) - assert isinstance(file_data, dict) - normalized_file_data = {Path(k).name: v for k, v in file_data.items()} - assert normalized_file_data[Path(exp_name).name] == exp_contents + actual_contents = (rx.get_upload_dir() / exp_name).read_text() + assert actual_contents == exp_contents @pytest.mark.asyncio @@ -399,10 +443,9 @@ async def test_upload_file_multiple(tmp_path, upload_file: AppHarness, driver): driver: WebDriver instance. """ assert upload_file.app_instance is not None - token = poll_for_token(driver, upload_file) - full_state_name = upload_file.get_full_state_name(["_upload_state"]) - state_name = upload_file.get_state_name("_upload_state") - substate_token = f"{token}_{full_state_name}" + poll_for_token(driver, upload_file) + clear_btn = driver.find_element(By.ID, "clear_uploads") + clear_btn.click() upload_box = driver.find_element(By.XPATH, "//input[@type='file']") assert upload_box @@ -430,19 +473,13 @@ async def test_upload_file_multiple(tmp_path, upload_file: AppHarness, driver): # do the upload upload_button.click() - # look up the backend state and assert on uploaded contents - async def get_file_data(): - return ( - (await upload_file.get_state(substate_token)) - .substates[state_name] - ._file_data # pyright: ignore[reportAttributeAccessIssue] - ) + # Wait for the upload to complete. + upload_done = driver.find_element(By.ID, "upload_done") + assert upload_file.poll_for_value(upload_done, exp_not_equal="false") == "true" - file_data = await AppHarness._poll_for_async(get_file_data) - assert isinstance(file_data, dict) - normalized_file_data = {Path(k).name: v for k, v in file_data.items()} - for exp_name, exp_contents in exp_files.items(): - assert normalized_file_data[Path(exp_name).name] == exp_contents + for exp_name, exp_content in exp_files.items(): + actual_contents = (rx.get_upload_dir() / exp_name).read_text() + assert actual_contents == exp_content @pytest.mark.parametrize("secondary", [False, True]) @@ -459,6 +496,8 @@ def test_clear_files( """ assert upload_file.app_instance is not None poll_for_token(driver, upload_file) + clear_btn = driver.find_element(By.ID, "clear_uploads") + clear_btn.click() suffix = "_secondary" if secondary else "" @@ -520,10 +559,7 @@ async def test_cancel_upload(tmp_path, upload_file: AppHarness, driver: WebDrive "latency": 200, # 200ms }, ) - token = poll_for_token(driver, upload_file) - state_name = upload_file.get_state_name("_upload_state") - state_full_name = upload_file.get_full_state_name(["_upload_state"]) - substate_token = f"{token}_{state_full_name}" + poll_for_token(driver, upload_file) upload_box = driver.find_elements(By.XPATH, "//input[@type='file']")[1] upload_button = driver.find_element(By.ID, "upload_button_secondary") @@ -543,23 +579,11 @@ async def test_cancel_upload(tmp_path, upload_file: AppHarness, driver: WebDrive # Wait a bit for the upload to get cancelled. await asyncio.sleep(12) - # Get interim progress dicts saved in the on_upload_progress handler. - async def _progress_dicts(): - state = await upload_file.get_state(substate_token) - return state.substates[state_name].progress_dicts # pyright: ignore[reportAttributeAccessIssue] - - # We should have _some_ progress - assert await AppHarness._poll_for_async(_progress_dicts) - # But there should never be a final progress record for a cancelled upload. - for p in await _progress_dicts(): - assert p["progress"] != 1 + for p in driver.find_elements(By.XPATH, "//*[@id='progress_dicts']/p"): + assert json.loads(p.text)["progress"] != 1 - state = await upload_file.get_state(substate_token) - file_data = state.substates[state_name]._file_data # pyright: ignore[reportAttributeAccessIssue] - assert isinstance(file_data, dict) - normalized_file_data = {Path(k).name: v for k, v in file_data.items()} - assert Path(exp_name).name not in normalized_file_data + assert not (rx.get_upload_dir() / exp_name).exists() target_file.unlink() @@ -568,10 +592,9 @@ async def _progress_dicts(): async def test_upload_chunk_file(tmp_path, upload_file: AppHarness, driver: WebDriver): """Submit a streaming upload and check that chunks are processed incrementally.""" assert upload_file.app_instance is not None - token = poll_for_token(driver, upload_file) - state_name = upload_file.get_state_name("_upload_state") - state_full_name = upload_file.get_full_state_name(["_upload_state"]) - substate_token = f"{token}_{state_full_name}" + poll_for_token(driver, upload_file) + clear_btn = driver.find_element(By.ID, "clear_uploads") + clear_btn.click() upload_box = driver.find_elements(By.XPATH, "//input[@type='file']")[4] upload_button = driver.find_element(By.ID, "upload_button_streaming") @@ -598,28 +621,6 @@ async def test_upload_chunk_file(tmp_path, upload_file: AppHarness, driver: WebD AppHarness.expect(lambda: "stream1.txt" in chunk_records_display.text) - async def _stream_completed(): - state = await upload_file.get_state(substate_token) - return ( - len( - state.substates[state_name].stream_completed_files # pyright: ignore[reportAttributeAccessIssue] - ) - == 2 - ) - - await AppHarness._poll_for_async(_stream_completed) - - state = await upload_file.get_state(substate_token) - substate = cast(Any, state.substates[state_name]) - chunk_records = substate.stream_chunk_records - - assert len(chunk_records) > 2 - assert {Path(record.split(":")[0]).name for record in chunk_records} == { - "stream1.txt", - "stream2.txt", - } - assert substate.stream_completed_files == ["stream1.txt", "stream2.txt"] - AppHarness.expect( lambda: ( "stream1.txt" in completed_files_display.text @@ -627,6 +628,10 @@ async def _stream_completed(): ) ) + # Wait for the upload to complete. + upload_done = driver.find_element(By.ID, "upload_done") + assert upload_file.poll_for_value(upload_done, exp_not_equal="false") == "true" + for exp_name, exp_contents in exp_files.items(): assert ( rx.get_upload_dir() / "streaming" / exp_name @@ -651,10 +656,7 @@ async def test_cancel_upload_chunk( "latency": 200, # 200ms }, ) - token = poll_for_token(driver, upload_file) - state_name = upload_file.get_state_name("_upload_state") - state_full_name = upload_file.get_full_state_name(["_upload_state"]) - substate_token = f"{token}_{state_full_name}" + poll_for_token(driver, upload_file) upload_box = driver.find_elements(By.XPATH, "//input[@type='file']")[4] upload_button = driver.find_element(By.ID, "upload_button_streaming") @@ -673,21 +675,11 @@ async def test_cancel_upload_chunk( await asyncio.sleep(12) - async def _stream_progress_dicts(): - state = await upload_file.get_state(substate_token) - return ( - state.substates[state_name].stream_progress_dicts # pyright: ignore[reportAttributeAccessIssue] - ) - - assert await AppHarness._poll_for_async(_stream_progress_dicts) - - for progress in await _stream_progress_dicts(): - assert progress["progress"] != 1 + # But there should never be a final progress record for a cancelled upload. + for p in driver.find_elements(By.XPATH, "//*[@id='stream_progress_dicts']/p"): + assert json.loads(p.text)["progress"] != 1 - state = await upload_file.get_state(substate_token) - substate = cast(Any, state.substates[state_name]) - assert substate.stream_completed_files == [] - assert substate.stream_chunk_records + assert not (rx.get_upload_dir() / exp_name).exists() partial_path = rx.get_upload_dir() / "streaming" / exp_name assert partial_path.exists() @@ -715,6 +707,8 @@ def test_upload_download_file( """ assert upload_file.app_instance is not None poll_for_token(driver, upload_file) + clear_btn = driver.find_element(By.ID, "clear_uploads") + clear_btn.click() upload_box = driver.find_elements(By.XPATH, "//input[@type='file']")[2] assert upload_box @@ -749,8 +743,7 @@ def test_upload_download_file( assert driver.find_element(by=By.TAG_NAME, value="body").text == exp_contents -@pytest.mark.asyncio -async def test_on_drop( +def test_on_drop( tmp_path, upload_file: AppHarness, driver: WebDriver, @@ -763,10 +756,9 @@ async def test_on_drop( driver: WebDriver instance. """ assert upload_file.app_instance is not None - token = poll_for_token(driver, upload_file) - full_state_name = upload_file.get_full_state_name(["_upload_state"]) - state_name = upload_file.get_state_name("_upload_state") - substate_token = f"{token}_{full_state_name}" + poll_for_token(driver, upload_file) + clear_btn = driver.find_element(By.ID, "clear_uploads") + clear_btn.click() upload_box = driver.find_elements(By.XPATH, "//input[@type='file']")[ 3 @@ -781,16 +773,18 @@ async def test_on_drop( # Simulate file drop by directly setting the file input upload_box.send_keys(str(target_file)) - # Wait for the on_drop event to be processed - await asyncio.sleep(0.5) + # Wait for the upload to complete. + upload_done = driver.find_element(By.ID, "upload_done") + assert upload_file.poll_for_value(upload_done, exp_not_equal="false") == "true" - async def exp_name_in_quaternary(): - state = await upload_file.get_state(substate_token) - return exp_name in state.substates[state_name].quaternary_names # pyright: ignore[reportAttributeAccessIssue] + def exp_name_in_quaternary(): + quaternary_files = driver.find_element(By.ID, "quaternary_files").text + if quaternary_files: + files = json.loads(quaternary_files) + return exp_name in files + return False # Poll until the file names appear in the display - await AppHarness._poll_for_async(exp_name_in_quaternary) + AppHarness._poll_for(exp_name_in_quaternary) - # Verify through state that the file names were captured correctly - state = await upload_file.get_state(substate_token) - assert exp_name in state.substates[state_name].quaternary_names # pyright: ignore[reportAttributeAccessIssue] + assert exp_name_in_quaternary() diff --git a/tests/integration/utils.py b/tests/integration/utils.py index d6b705551f9..5c851f94b65 100644 --- a/tests/integration/utils.py +++ b/tests/integration/utils.py @@ -2,9 +2,10 @@ from __future__ import annotations -from collections.abc import Generator, Iterator +from collections.abc import Generator, Iterator, Sequence from contextlib import contextmanager +from selenium.webdriver.common.by import By from selenium.webdriver.remote.webdriver import WebDriver from reflex.testing import AppHarness @@ -33,6 +34,80 @@ def poll_for_navigation( AppHarness.expect(lambda: prev_url != driver.current_url, timeout=timeout) +def n_expected_events(exp_event_order: Sequence[str | set[str]]) -> int: + """Calculate the number of expected events, accounting for sets in the expected order. + + Args: + exp_event_order: the expected events recorded in the State, where some entries may be sets of events that can occur in any order. + + Returns: + The total number of expected events. + """ + return sum( + len(events) if isinstance(events, set) else 1 for events in exp_event_order + ) + + +def assert_event_order( + actual_event_order: list[str], exp_event_order: Sequence[str | set[str]] +) -> None: + """Verify that the actual event order matches the expected event order, accounting for sets in the expected order. + + Args: + actual_event_order: the actual events recorded in the State. + exp_event_order: the expected events recorded in the State, where some entries may be sets of events that can occur in any order. + + Raises: + AssertionError: if the actual event order does not match the expected event order. + """ + actual_idx = 0 + for expected in exp_event_order: + if isinstance(expected, str): + assert actual_event_order[actual_idx] == expected, ( + f"Expected event '{expected}' at position {actual_idx}, but got '{actual_event_order[actual_idx]}'." + ) + actual_idx += 1 + else: # expected is a set of events that can occur in any order + expected_events = set(expected) + actual_events = set( + actual_event_order[actual_idx : actual_idx + len(expected_events)] + ) + assert actual_events == expected_events, ( + f"Expected events {expected_events} at positions {actual_idx} to {actual_idx + len(expected_events) - 1}, but got {actual_events}." + ) + actual_idx += len(expected_events) + assert actual_idx == len(actual_event_order), ( + f"Expected {actual_idx} events, but got {len(actual_event_order)}: {actual_event_order[actual_idx:]} remain." + ) + + +def poll_assert_event_order( + driver: WebDriver, + exp_event_order: Sequence[str | set[str]], + xpath: str = '//*[@id="event_order"]/p', +) -> None: + """Poll until the actual event order matches the expected event order, accounting for sets in the expected order. + + Args: + driver: WebDriver instance. + exp_event_order: the expected events recorded in the State, where some entries may be sets of events that can occur in any order. + xpath: The XPath to the event order elements. + + Raises: + AssertionError: if the actual event order does not match the expected event order after polling. + """ + n_exp_events = n_expected_events(exp_event_order) + + def _has_number_of_expected_events(): + event_elements = driver.find_elements(By.XPATH, xpath) + return len(event_elements) == n_exp_events + + AppHarness._poll_for(_has_number_of_expected_events) + + event_elements = driver.find_elements(By.XPATH, xpath) + assert_event_order([elem.text for elem in event_elements], exp_event_order) + + class LocalStorage: """Class to access local storage. diff --git a/tests/units/conftest.py b/tests/units/conftest.py index ae26868a3ac..3939a749eec 100644 --- a/tests/units/conftest.py +++ b/tests/units/conftest.py @@ -469,6 +469,20 @@ def forked_registration_context() -> Generator[RegistrationContext, None, None]: yield ctx +@pytest.fixture +def clean_registration_context() -> Generator[RegistrationContext, None, None]: + """Create and attach a clean registration context. + + Sets the new context as the current registration context for the duration + of the test, then resets it afterwards. + + Yields: + The clean RegistrationContext. + """ + with RegistrationContext() as ctx: + yield ctx + + @pytest.fixture def preserve_memo_registries(): """Save and restore global memo registries around a test. diff --git a/tests/units/middleware/test_hydrate_middleware.py b/tests/units/middleware/test_hydrate_middleware.py index 2ac1fadd139..9d54a16d31e 100644 --- a/tests/units/middleware/test_hydrate_middleware.py +++ b/tests/units/middleware/test_hydrate_middleware.py @@ -1,8 +1,8 @@ from __future__ import annotations import pytest -from pytest_mock import MockerFixture +from reflex._internal.registry import RegistrationContext from reflex.app import App from reflex.middleware.hydrate_middleware import HydrateMiddleware from reflex.state import State, StateUpdate @@ -31,15 +31,17 @@ def hydrate_middleware() -> HydrateMiddleware: @pytest.mark.asyncio -async def test_preprocess_no_events(hydrate_middleware, event1, mocker: MockerFixture): +async def test_preprocess_no_events( + hydrate_middleware, event1, clean_registration_context: RegistrationContext +): """Test that app without on_load is processed correctly. Args: hydrate_middleware: Instance of HydrateMiddleware event1: An Event. - mocker: pytest mock object. + clean_registration_context: The registration context fixture, which is cleared before each test. """ - mocker.patch("reflex.state.State.class_subclasses", {TestState}) + clean_registration_context.register_base_state(TestState) state = State() update = await hydrate_middleware.preprocess( app=App(_state=State), diff --git a/tests/units/test_app.py b/tests/units/test_app.py index cef0f6ea6d6..dc2d1cc986d 100644 --- a/tests/units/test_app.py +++ b/tests/units/test_app.py @@ -34,6 +34,7 @@ import reflex as rx from reflex import AdminDash, constants +from reflex._internal.registry import RegistrationContext from reflex.app import App, ComponentCallable, default_overlay_component, upload from reflex.environment import environment from reflex.ievent.context import EventContext @@ -56,7 +57,6 @@ from .states.upload import ( ChildFileUploadState, ChunkUploadState, - FileStateBase1, FileUploadState, GrandChildFileUploadState, ) @@ -508,6 +508,7 @@ async def test_dynamic_var_event( mock_base_state_event_processor: BaseStateEventProcessor, emitted_deltas: list[tuple[str, dict[str, dict[str, Any]]]], token: str, + clean_registration_context: RegistrationContext, ): """Test that the default handler of a dynamic generated var works as expected. @@ -517,7 +518,9 @@ async def test_dynamic_var_event( mock_base_state_event_processor: BaseStateEventProcessor Fixture. emitted_deltas: List to store emitted deltas. token: a Token. + clean_registration_context: The registration context fixture, which is cleared before each test. """ + clean_registration_context.register_base_state(test_state) state = test_state() # pyright: ignore [reportCallIssue] state.add_var("int_val", int, 0) async with mock_base_state_event_processor as processor: @@ -968,6 +971,7 @@ async def test_upload_file( mocker: MockerFixture, attached_mock_base_state_event_processor: BaseStateEventProcessor, mock_root_event_context: EventContext, + clean_registration_context: RegistrationContext, ): """Test that file upload works correctly. @@ -979,12 +983,12 @@ async def test_upload_file( mocker: pytest mocker object. attached_mock_base_state_event_processor: BaseStateEventProcessor Fixture attached to the app instance to capture emitted events. mock_root_event_context: The mocked root event context, for accessing state_manager. + clean_registration_context: Fixture to ensure clean registration context for each test, preventing cross-test contamination of state subclasses. """ - mocker.patch( - "reflex.state.State.class_subclasses", - {state if state is FileUploadState else FileStateBase1}, + clean_registration_context.register_base_state(state) + app = Mock( + event_processor=attached_mock_base_state_event_processor, ) - app = Mock(event_processor=attached_mock_base_state_event_processor) async with mock_root_event_context.state_manager.modify_state( BaseStateToken(ident=token, cls=state) ) as root_state: @@ -1046,10 +1050,6 @@ async def test_upload_file_keeps_form_open_until_stream_completes( attached_mock_base_state_event_processor: BaseStateEventProcessor Fixture attached to the app instance to capture emitted events. mock_root_event_context: The mocked root event context, for accessing state_manager. """ - mocker.patch( - "reflex.state.State.class_subclasses", - {FileUploadState}, - ) app = Mock(event_processor=attached_mock_base_state_event_processor) # Set _tmp_path via modify_state instead of setting class attribute directly. @@ -1124,10 +1124,6 @@ async def test_upload_empty_buffered_request_dispatches_alias_handler( mock_root_event_context: EventContext, ): """Test that empty uploads still dispatch buffered alias handlers.""" - mocker.patch( - "reflex.state.State.class_subclasses", - {FileUploadState}, - ) app = Mock(event_processor=attached_mock_base_state_event_processor) async with mock_root_event_context.state_manager.modify_state( @@ -1176,10 +1172,6 @@ async def test_upload_file_closes_form_on_form_error( attached_mock_base_state_event_processor: BaseStateEventProcessor, ): """Test that cancellation before form parsing leaves form data untouched.""" - mocker.patch( - "reflex.state.State.class_subclasses", - {FileUploadState}, - ) app = Mock(event_processor=attached_mock_base_state_event_processor) request_mock = unittest.mock.Mock() @@ -1215,10 +1207,6 @@ async def test_upload_file_closes_form_on_event_creation_cancellation( attached_mock_base_state_event_processor: BaseStateEventProcessor, ): """Test that cancellation during event creation closes form data.""" - mocker.patch( - "reflex.state.State.class_subclasses", - {FileUploadState}, - ) app = Mock(event_processor=attached_mock_base_state_event_processor) request_mock = unittest.mock.Mock() @@ -1261,10 +1249,6 @@ async def test_upload_file_closes_form_if_response_cancelled_before_stream_start mock_root_event_context: EventContext, ): """Test that response cancellation before iteration still closes form data.""" - mocker.patch( - "reflex.state.State.class_subclasses", - {FileUploadState}, - ) app = Mock(event_processor=attached_mock_base_state_event_processor) async with mock_root_event_context.state_manager.modify_state( @@ -1471,10 +1455,6 @@ async def test_upload_dispatches_chunk_handlers_on_upload_endpoint( mock_root_event_context: EventContext, ): """Test that the standard upload endpoint dispatches chunk handlers.""" - mocker.patch( - "reflex.state.State.class_subclasses", - {ChunkUploadState}, - ) app = Mock(event_processor=attached_mock_base_state_event_processor) async with mock_root_event_context.state_manager.modify_state( @@ -1565,10 +1545,6 @@ async def test_upload_empty_chunk_request_dispatches_alias_handler( mock_root_event_context: EventContext, ): """Test that empty uploads still dispatch chunk alias handlers.""" - mocker.patch( - "reflex.state.State.class_subclasses", - {ChunkUploadState}, - ) app = Mock(event_processor=attached_mock_base_state_event_processor) async with mock_root_event_context.state_manager.modify_state( @@ -1618,10 +1594,6 @@ async def test_upload_chunk_invalid_offset_returns_400( mock_root_event_context: EventContext, ): """Test that malformed chunk metadata fails the standard upload request.""" - mocker.patch( - "reflex.state.State.class_subclasses", - {ChunkUploadState}, - ) app = Mock(event_processor=attached_mock_base_state_event_processor) # The background task is expected to fail with a parse error for malformed input. attached_mock_base_state_event_processor.backend_exception_handler = None diff --git a/tests/units/test_state.py b/tests/units/test_state.py index 239969cd8dc..28691a2fe54 100644 --- a/tests/units/test_state.py +++ b/tests/units/test_state.py @@ -1747,7 +1747,8 @@ async def test_state_manager_modify_state( if isinstance(state_manager, StateManagerRedis): assert (await state_manager.redis.get(f"{token}_lock")) is None elif isinstance(state_manager, (StateManagerMemory, StateManagerDisk)): - assert not state_manager._states_locks[token].locked() + lock = state_manager._states_locks.get(token) + assert lock is None or not lock.locked() # separate instances should NOT share locks sm2 = type(state_manager)() @@ -1797,8 +1798,8 @@ async def _coro(): if isinstance(state_manager, StateManagerRedis): assert (await state_manager.redis.get(f"{token}_lock")) is None elif isinstance(state_manager, (StateManagerMemory, StateManagerDisk)): - assert token in state_manager._states_locks - assert not state_manager._states_locks[token].locked() + lock = state_manager._states_locks.get(token) + assert lock is None or not lock.locked() @pytest_asyncio.fixture(loop_scope="function", scope="function") @@ -1848,6 +1849,7 @@ async def test_state_manager_lock_expire( """ state_manager_redis.lock_expiration = LOCK_EXPIRATION state_manager_redis.lock_warning_threshold = LOCK_WARNING_THRESHOLD + state_manager_redis.oplock_hold_time_ms = LOCK_EXPIRATION // 2 loop_exception = None @@ -1898,6 +1900,7 @@ async def test_state_manager_lock_expire_contend( state_manager_redis.lock_expiration = LOCK_EXPIRATION state_manager_redis.lock_warning_threshold = LOCK_WARNING_THRESHOLD + state_manager_redis.oplock_hold_time_ms = LOCK_EXPIRATION // 2 loop_exception = None @@ -2994,9 +2997,6 @@ async def test_preprocess( emitted_deltas: List to capture emitted deltas. """ OnLoadInternalState._app_ref = None - mocker.patch( - "reflex.state.State.class_subclasses", {test_state, OnLoadInternalState} - ) app = app_module_mock.app = App(_state=State) app._state_manager = mock_root_event_context.state_manager @@ -3077,9 +3077,6 @@ async def test_preprocess_multiple_load_events( emitted_deltas: List to capture emitted deltas. """ OnLoadInternalState._app_ref = None - mocker.patch( - "reflex.state.State.class_subclasses", {OnLoadState, OnLoadInternalState} - ) app = app_module_mock.app = App(_state=State) app._state_manager = mock_root_event_context.state_manager From d914ea0a168ee7a445901a8afe76714131de0f3c Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Wed, 1 Apr 2026 01:11:19 -0700 Subject: [PATCH 31/81] Move reflex._internal to reflex_core._internal --- .../src/reflex_components_core/core/_upload.py | 3 +-- .../reflex-core/src/reflex_core}/_internal/__init__.py | 0 .../src/reflex_core}/_internal/context/__init__.py | 0 .../src/reflex_core}/_internal/context/base.py | 0 .../reflex-core/src/reflex_core}/_internal/registry.py | 8 +++----- .../reflex-core/src/reflex_core/components/component.py | 2 +- packages/reflex-core/src/reflex_core/event.py | 2 +- reflex/ievent/context.py | 3 +-- reflex/ievent/processor/event_processor.py | 2 +- reflex/state.py | 9 ++++----- reflex/testing.py | 2 +- tests/units/conftest.py | 2 +- tests/units/middleware/test_hydrate_middleware.py | 2 +- tests/units/test_app.py | 2 +- 14 files changed, 16 insertions(+), 21 deletions(-) rename {reflex => packages/reflex-core/src/reflex_core}/_internal/__init__.py (100%) rename {reflex => packages/reflex-core/src/reflex_core}/_internal/context/__init__.py (100%) rename {reflex => packages/reflex-core/src/reflex_core}/_internal/context/base.py (100%) rename {reflex => packages/reflex-core/src/reflex_core}/_internal/registry.py (98%) diff --git a/packages/reflex-components-core/src/reflex_components_core/core/_upload.py b/packages/reflex-components-core/src/reflex_components_core/core/_upload.py index 7dc9face947..426c63f1fc1 100644 --- a/packages/reflex-components-core/src/reflex_components_core/core/_upload.py +++ b/packages/reflex-components-core/src/reflex_components_core/core/_upload.py @@ -621,13 +621,12 @@ async def upload_file(request: Request): UploadTypeError: If a non-streaming upload is wired to a background task. HTTPException: when the request does not include token / handler headers. """ + from reflex_core._internal.registry import RegistrationContext from reflex_core.event import ( resolve_upload_chunk_handler_param, resolve_upload_handler_param, ) - from reflex._internal.registry import RegistrationContext - token, handler_name = _require_upload_headers(request) registered_event_handler = RegistrationContext.get().event_handlers[ handler_name diff --git a/reflex/_internal/__init__.py b/packages/reflex-core/src/reflex_core/_internal/__init__.py similarity index 100% rename from reflex/_internal/__init__.py rename to packages/reflex-core/src/reflex_core/_internal/__init__.py diff --git a/reflex/_internal/context/__init__.py b/packages/reflex-core/src/reflex_core/_internal/context/__init__.py similarity index 100% rename from reflex/_internal/context/__init__.py rename to packages/reflex-core/src/reflex_core/_internal/context/__init__.py diff --git a/reflex/_internal/context/base.py b/packages/reflex-core/src/reflex_core/_internal/context/base.py similarity index 100% rename from reflex/_internal/context/base.py rename to packages/reflex-core/src/reflex_core/_internal/context/base.py diff --git a/reflex/_internal/registry.py b/packages/reflex-core/src/reflex_core/_internal/registry.py similarity index 98% rename from reflex/_internal/registry.py rename to packages/reflex-core/src/reflex_core/_internal/registry.py index 7d22d9c26ca..ebe881718d7 100644 --- a/reflex/_internal/registry.py +++ b/packages/reflex-core/src/reflex_core/_internal/registry.py @@ -3,15 +3,13 @@ import dataclasses from typing import TYPE_CHECKING, Self +from reflex_core._internal.context.base import BaseContext from reflex_core.utils.exceptions import StateValueError -from reflex._internal.context.base import BaseContext - if TYPE_CHECKING: - from reflex_core.components.component import StatefulComponent - - from reflex.event import EventHandler from reflex.state import BaseState + from reflex_core.components.component import StatefulComponent + from reflex_core.event import EventHandler @dataclasses.dataclass(frozen=True, kw_only=True, slots=True) diff --git a/packages/reflex-core/src/reflex_core/components/component.py b/packages/reflex-core/src/reflex_core/components/component.py index 79e00eba504..2eb0543ebdb 100644 --- a/packages/reflex-core/src/reflex_core/components/component.py +++ b/packages/reflex-core/src/reflex_core/components/component.py @@ -2419,7 +2419,7 @@ def create(cls, component: Component) -> StatefulComponent | None: """ from reflex_components_core.core.foreach import Foreach - from reflex._internal.registry import RegistrationContext + from reflex_core._internal.registry import RegistrationContext if component._memoization_mode.disposition == MemoizationDisposition.NEVER: # Never memoize this component. diff --git a/packages/reflex-core/src/reflex_core/event.py b/packages/reflex-core/src/reflex_core/event.py index 246799fc0b8..cbb2edd56b7 100644 --- a/packages/reflex-core/src/reflex_core/event.py +++ b/packages/reflex-core/src/reflex_core/event.py @@ -81,7 +81,7 @@ class Event: @property def state_cls(self) -> type[BaseState]: """The state class for the event.""" - from reflex._internal.registry import RegistrationContext + from reflex_core._internal.registry import RegistrationContext substate_name = self.name.rpartition(".")[0] return RegistrationContext.get().base_states[substate_name] diff --git a/reflex/ievent/context.py b/reflex/ievent/context.py index 24d43363c44..a8c120150af 100644 --- a/reflex/ievent/context.py +++ b/reflex/ievent/context.py @@ -6,10 +6,9 @@ from collections.abc import Callable, Mapping from typing import TYPE_CHECKING, Any, Protocol +from reflex_core._internal.context.base import BaseContext from reflex_core.utils.format import to_snake_case -from reflex._internal.context.base import BaseContext - if TYPE_CHECKING: from reflex_core.event import Event diff --git a/reflex/ievent/processor/event_processor.py b/reflex/ievent/processor/event_processor.py index 8764714760b..d327091b911 100644 --- a/reflex/ievent/processor/event_processor.py +++ b/reflex/ievent/processor/event_processor.py @@ -11,8 +11,8 @@ from typing import TYPE_CHECKING, Any, Self import rich.markup +from reflex_core._internal.registry import RegisteredEventHandler, RegistrationContext -from reflex._internal.registry import RegisteredEventHandler, RegistrationContext from reflex.app_mixins.middleware import MiddlewareMixin from reflex.ievent.context import EventContext from reflex.istate.manager import StateManager diff --git a/reflex/state.py b/reflex/state.py index 756a107d1c1..a97e9f6e977 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -496,10 +496,9 @@ def __init_subclass__(cls, mixin: bool = False, **kwargs): Raises: StateValueError: If a substate class shadows another. """ + from reflex_core._internal.registry import RegistrationContext from reflex_core.utils.exceptions import StateValueError - from reflex._internal.registry import RegistrationContext - super().__init_subclass__(**kwargs) if cls._mixin: @@ -961,7 +960,7 @@ def get_substates(cls) -> set[type[BaseState]]: Returns: The substates of the state. """ - from reflex._internal.registry import RegistrationContext + from reflex_core._internal.registry import RegistrationContext return RegistrationContext.get().get_substates(cls) @@ -1146,7 +1145,7 @@ def _create_event_handler( Returns: The event handler. """ - from reflex._internal.registry import RegistrationContext + from reflex_core._internal.registry import RegistrationContext # Check if function has stored event_actions from decorator event_actions = getattr(fn, EVENT_ACTIONS_MARKER, {}) @@ -2570,7 +2569,7 @@ def reload_state_module( state: Recursive argument for the state class to reload. """ - from reflex._internal.registry import RegistrationContext + from reflex_core._internal.registry import RegistrationContext # Reset the _app_ref of OnLoadInternalState to avoid stale references. if state is OnLoadInternalState: diff --git a/reflex/testing.py b/reflex/testing.py index a7b7bccc6f2..4c693b8659e 100644 --- a/reflex/testing.py +++ b/reflex/testing.py @@ -28,6 +28,7 @@ from typing import TYPE_CHECKING, Any, ClassVar, Literal, TypeVar import uvicorn +from reflex_core._internal.registry import RegistrationContext from reflex_core.components.component import CUSTOM_COMPONENTS, CustomComponent from reflex_core.config import get_config from reflex_core.environment import environment @@ -40,7 +41,6 @@ import reflex.utils.format import reflex.utils.prerequisites import reflex.utils.processes -from reflex._internal.registry import RegistrationContext from reflex.experimental.memo import EXPERIMENTAL_MEMOS from reflex.state import reload_state_module from reflex.utils import console, js_runtimes diff --git a/tests/units/conftest.py b/tests/units/conftest.py index 3939a749eec..4c0cec711bc 100644 --- a/tests/units/conftest.py +++ b/tests/units/conftest.py @@ -10,10 +10,10 @@ import pytest import pytest_asyncio +from reflex_core._internal.registry import RegistrationContext from reflex_core.components.component import CUSTOM_COMPONENTS from reflex_core.event import Event, EventSpec -from reflex._internal.registry import RegistrationContext from reflex.app import App from reflex.experimental.memo import EXPERIMENTAL_MEMOS from reflex.ievent.context import EventContext diff --git a/tests/units/middleware/test_hydrate_middleware.py b/tests/units/middleware/test_hydrate_middleware.py index 9d54a16d31e..489bf23793d 100644 --- a/tests/units/middleware/test_hydrate_middleware.py +++ b/tests/units/middleware/test_hydrate_middleware.py @@ -1,8 +1,8 @@ from __future__ import annotations import pytest +from reflex_core._internal.registry import RegistrationContext -from reflex._internal.registry import RegistrationContext from reflex.app import App from reflex.middleware.hydrate_middleware import HydrateMiddleware from reflex.state import State, StateUpdate diff --git a/tests/units/test_app.py b/tests/units/test_app.py index dc2d1cc986d..e3fbd7cc5b7 100644 --- a/tests/units/test_app.py +++ b/tests/units/test_app.py @@ -20,6 +20,7 @@ from reflex_components_core.base.fragment import Fragment from reflex_components_core.core.cond import Cond from reflex_components_radix.themes.typography.text import Text +from reflex_core._internal.registry import RegistrationContext from reflex_core.components.component import Component from reflex_core.constants.state import FIELD_MARKER from reflex_core.event import Event @@ -34,7 +35,6 @@ import reflex as rx from reflex import AdminDash, constants -from reflex._internal.registry import RegistrationContext from reflex.app import App, ComponentCallable, default_overlay_component, upload from reflex.environment import environment from reflex.ievent.context import EventContext From c50d2b2890cc0fcb2f3accc529cf2d1cac4c8e64 Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Wed, 1 Apr 2026 01:14:57 -0700 Subject: [PATCH 32/81] move reflex.ievent to reflex_core._internal.event --- .../src/reflex_core/_internal/event}/__init__.py | 0 .../src/reflex_core/_internal/event}/context.py | 3 +-- .../_internal/event/processor/__init__.py | 15 +++++++++++++++ .../event}/processor/base_state_processor.py | 12 ++++++------ .../_internal/event}/processor/event_processor.py | 4 ++-- reflex/app.py | 5 ++++- reflex/ievent/processor/__init__.py | 10 ---------- reflex/istate/manager/__init__.py | 2 +- reflex/istate/proxy.py | 2 +- reflex/state.py | 2 +- tests/units/conftest.py | 7 +++++-- tests/units/istate/test_proxy.py | 2 +- tests/units/test_app.py | 4 ++-- tests/units/test_state.py | 4 ++-- 14 files changed, 41 insertions(+), 31 deletions(-) rename {reflex/ievent => packages/reflex-core/src/reflex_core/_internal/event}/__init__.py (100%) rename {reflex/ievent => packages/reflex-core/src/reflex_core/_internal/event}/context.py (99%) create mode 100644 packages/reflex-core/src/reflex_core/_internal/event/processor/__init__.py rename {reflex/ievent => packages/reflex-core/src/reflex_core/_internal/event}/processor/base_state_processor.py (98%) rename {reflex/ievent => packages/reflex-core/src/reflex_core/_internal/event}/processor/event_processor.py (99%) delete mode 100644 reflex/ievent/processor/__init__.py diff --git a/reflex/ievent/__init__.py b/packages/reflex-core/src/reflex_core/_internal/event/__init__.py similarity index 100% rename from reflex/ievent/__init__.py rename to packages/reflex-core/src/reflex_core/_internal/event/__init__.py diff --git a/reflex/ievent/context.py b/packages/reflex-core/src/reflex_core/_internal/event/context.py similarity index 99% rename from reflex/ievent/context.py rename to packages/reflex-core/src/reflex_core/_internal/event/context.py index a8c120150af..1c964b44dfb 100644 --- a/reflex/ievent/context.py +++ b/packages/reflex-core/src/reflex_core/_internal/event/context.py @@ -10,9 +10,8 @@ from reflex_core.utils.format import to_snake_case if TYPE_CHECKING: - from reflex_core.event import Event - from reflex.istate.manager import StateManager + from reflex_core.event import Event @functools.lru_cache diff --git a/packages/reflex-core/src/reflex_core/_internal/event/processor/__init__.py b/packages/reflex-core/src/reflex_core/_internal/event/processor/__init__.py new file mode 100644 index 00000000000..df0d957c186 --- /dev/null +++ b/packages/reflex-core/src/reflex_core/_internal/event/processor/__init__.py @@ -0,0 +1,15 @@ +"""Procedures for handling events.""" + +from reflex_core._internal.event.processor.base_state_processor import ( + BaseStateEventProcessor, +) +from reflex_core._internal.event.processor.event_processor import ( + EventProcessor, + EventQueueEntry, +) + +__all__ = [ + "BaseStateEventProcessor", + "EventProcessor", + "EventQueueEntry", +] diff --git a/reflex/ievent/processor/base_state_processor.py b/packages/reflex-core/src/reflex_core/_internal/event/processor/base_state_processor.py similarity index 98% rename from reflex/ievent/processor/base_state_processor.py rename to packages/reflex-core/src/reflex_core/_internal/event/processor/base_state_processor.py index d641522933b..988330d42bf 100644 --- a/reflex/ievent/processor/base_state_processor.py +++ b/packages/reflex-core/src/reflex_core/_internal/event/processor/base_state_processor.py @@ -9,16 +9,16 @@ from importlib.util import find_spec from typing import TYPE_CHECKING, Any -from reflex.ievent.context import EventContext -from reflex.ievent.processor import EventProcessor -from reflex.ievent.processor.event_processor import ( - EventQueueEntry, - RegisteredEventHandler, -) from reflex.istate.data import RouterData from reflex.istate.manager.token import BaseStateToken from reflex.istate.proxy import StateProxy from reflex.utils import console, types +from reflex_core._internal.event.context import EventContext +from reflex_core._internal.event.processor.event_processor import ( + EventProcessor, + EventQueueEntry, + RegisteredEventHandler, +) if TYPE_CHECKING: from reflex.event import EventHandler, EventSpec diff --git a/reflex/ievent/processor/event_processor.py b/packages/reflex-core/src/reflex_core/_internal/event/processor/event_processor.py similarity index 99% rename from reflex/ievent/processor/event_processor.py rename to packages/reflex-core/src/reflex_core/_internal/event/processor/event_processor.py index d327091b911..a1f34c903a1 100644 --- a/reflex/ievent/processor/event_processor.py +++ b/packages/reflex-core/src/reflex_core/_internal/event/processor/event_processor.py @@ -11,12 +11,12 @@ from typing import TYPE_CHECKING, Any, Self import rich.markup -from reflex_core._internal.registry import RegisteredEventHandler, RegistrationContext from reflex.app_mixins.middleware import MiddlewareMixin -from reflex.ievent.context import EventContext from reflex.istate.manager import StateManager from reflex.utils import console +from reflex_core._internal.event.context import EventContext +from reflex_core._internal.registry import RegisteredEventHandler, RegistrationContext if TYPE_CHECKING: from reflex.app import EventNamespace diff --git a/reflex/app.py b/reflex/app.py index 301fd16bfc3..92e030194a2 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -38,6 +38,10 @@ from reflex_components_radix import themes from reflex_components_sonner.toast import toast from reflex_core import constants +from reflex_core._internal.event.processor import ( + BaseStateEventProcessor, + EventProcessor, +) from reflex_core.components.component import ( CUSTOM_COMPONENTS, Component, @@ -79,7 +83,6 @@ readable_name_from_component, ) from reflex.experimental.memo import EXPERIMENTAL_MEMOS -from reflex.ievent.processor import BaseStateEventProcessor, EventProcessor from reflex.istate.manager import StateManager, StateModificationContext from reflex.istate.manager.token import BaseStateToken from reflex.page import DECORATED_PAGES diff --git a/reflex/ievent/processor/__init__.py b/reflex/ievent/processor/__init__.py deleted file mode 100644 index fe905a50a81..00000000000 --- a/reflex/ievent/processor/__init__.py +++ /dev/null @@ -1,10 +0,0 @@ -"""Procedures for handling events.""" - -from reflex.ievent.processor.event_processor import EventProcessor, EventQueueEntry # noqa: I001 -from reflex.ievent.processor.base_state_processor import BaseStateEventProcessor - -__all__ = [ - "BaseStateEventProcessor", - "EventProcessor", - "EventQueueEntry", -] diff --git a/reflex/istate/manager/__init__.py b/reflex/istate/manager/__init__.py index 58aa688ffee..95e34459513 100644 --- a/reflex/istate/manager/__init__.py +++ b/reflex/istate/manager/__init__.py @@ -176,6 +176,6 @@ def get_state_manager() -> StateManager: Returns: The state manager. """ - from reflex.ievent.context import EventContext + from reflex_core._internal.event.context import EventContext return EventContext.get().state_manager diff --git a/reflex/istate/proxy.py b/reflex/istate/proxy.py index 430669f8b38..4108a254819 100644 --- a/reflex/istate/proxy.py +++ b/reflex/istate/proxy.py @@ -15,13 +15,13 @@ from typing import TYPE_CHECKING, Any, SupportsIndex, TypeVar import wrapt +from reflex_core._internal.event.context import EventContext from reflex_core.event import Event from reflex_core.utils.exceptions import ImmutableStateError from reflex_core.utils.serializers import can_serialize, serialize, serializer from reflex_core.vars.base import Var from typing_extensions import Self -from reflex.ievent.context import EventContext from reflex.istate.manager.token import BaseStateToken if TYPE_CHECKING: diff --git a/reflex/state.py b/reflex/state.py index a97e9f6e977..e4a5a88c296 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -2214,7 +2214,7 @@ async def _get_state_from_redis(self, state_cls: type[T_STATE]) -> T_STATE: @event async def hydrate(self) -> None: """Send the full state to the frontend to synchronize it with the backend.""" - from reflex.ievent.context import EventContext + from reflex_core._internal.event.context import EventContext # Clear client storage, to respect clearing cookies self._reset_client_storage() diff --git a/tests/units/conftest.py b/tests/units/conftest.py index 4c0cec711bc..521d298df3e 100644 --- a/tests/units/conftest.py +++ b/tests/units/conftest.py @@ -10,14 +10,17 @@ import pytest import pytest_asyncio +from reflex_core._internal.event.context import EventContext +from reflex_core._internal.event.processor import ( + BaseStateEventProcessor, + EventProcessor, +) from reflex_core._internal.registry import RegistrationContext from reflex_core.components.component import CUSTOM_COMPONENTS from reflex_core.event import Event, EventSpec from reflex.app import App from reflex.experimental.memo import EXPERIMENTAL_MEMOS -from reflex.ievent.context import EventContext -from reflex.ievent.processor import BaseStateEventProcessor, EventProcessor from reflex.istate.manager import StateManager from reflex.istate.manager.disk import StateManagerDisk from reflex.istate.manager.memory import StateManagerMemory diff --git a/tests/units/istate/test_proxy.py b/tests/units/istate/test_proxy.py index 1e74d99bdb5..13d042d0f33 100644 --- a/tests/units/istate/test_proxy.py +++ b/tests/units/istate/test_proxy.py @@ -6,9 +6,9 @@ from contextlib import asynccontextmanager import pytest +from reflex_core._internal.event.context import EventContext import reflex as rx -from reflex.ievent.context import EventContext from reflex.istate.proxy import MutableProxy, StateProxy diff --git a/tests/units/test_app.py b/tests/units/test_app.py index e3fbd7cc5b7..b1c31b4641a 100644 --- a/tests/units/test_app.py +++ b/tests/units/test_app.py @@ -20,6 +20,8 @@ from reflex_components_core.base.fragment import Fragment from reflex_components_core.core.cond import Cond from reflex_components_radix.themes.typography.text import Text +from reflex_core._internal.event.context import EventContext +from reflex_core._internal.event.processor import BaseStateEventProcessor from reflex_core._internal.registry import RegistrationContext from reflex_core.components.component import Component from reflex_core.constants.state import FIELD_MARKER @@ -37,8 +39,6 @@ from reflex import AdminDash, constants from reflex.app import App, ComponentCallable, default_overlay_component, upload from reflex.environment import environment -from reflex.ievent.context import EventContext -from reflex.ievent.processor import BaseStateEventProcessor from reflex.istate.manager.disk import StateManagerDisk from reflex.istate.manager.memory import StateManagerMemory from reflex.istate.manager.redis import StateManagerRedis diff --git a/tests/units/test_state.py b/tests/units/test_state.py index 28691a2fe54..53079075f93 100644 --- a/tests/units/test_state.py +++ b/tests/units/test_state.py @@ -22,6 +22,8 @@ from pydantic import BaseModel as Base from pytest_mock import MockerFixture from reflex_core import constants +from reflex_core._internal.event.context import EventContext +from reflex_core._internal.event.processor import BaseStateEventProcessor from reflex_core.constants import CompileVars, RouteVar from reflex_core.constants.state import FIELD_MARKER from reflex_core.event import Event, EventHandler @@ -40,8 +42,6 @@ import reflex as rx from reflex.app import App from reflex.environment import environment -from reflex.ievent.context import EventContext -from reflex.ievent.processor import BaseStateEventProcessor from reflex.istate.data import HeaderData, _FrozenDictStrStr from reflex.istate.manager import StateManager from reflex.istate.manager.disk import StateManagerDisk From 9650b0fd692d9d7ecf91ff8ce8553011828920e6 Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Wed, 1 Apr 2026 02:00:17 -0700 Subject: [PATCH 33/81] replace "reload" functionality with internal rehydration instead of telling the frontend to reload, just hydrate and run on_load internally before processing the user's requested event. --- .../reflex_core/.templates/web/utils/state.js | 4 -- .../event/processor/base_state_processor.py | 43 ++++++++++++++++++- tests/units/conftest.py | 8 +++- 3 files changed, 49 insertions(+), 6 deletions(-) diff --git a/packages/reflex-core/src/reflex_core/.templates/web/utils/state.js b/packages/reflex-core/src/reflex_core/.templates/web/utils/state.js index cc8ce1c814b..426fe116f61 100644 --- a/packages/reflex-core/src/reflex_core/.templates/web/utils/state.js +++ b/packages/reflex-core/src/reflex_core/.templates/web/utils/state.js @@ -672,10 +672,6 @@ export const connect = async ( queueEvents(update.events, socket, false, navigate, params); } }); - socket.current.on("reload", async (event) => { - on_hydrated_queue.push(event); - queueEvents(initialEvents(), socket, true, navigate, params); - }); socket.current.on("new_token", async (new_token) => { token = new_token; window.sessionStorage.setItem(TOKEN_KEY, new_token); diff --git a/packages/reflex-core/src/reflex_core/_internal/event/processor/base_state_processor.py b/packages/reflex-core/src/reflex_core/_internal/event/processor/base_state_processor.py index 988330d42bf..02aed3a1a99 100644 --- a/packages/reflex-core/src/reflex_core/_internal/event/processor/base_state_processor.py +++ b/packages/reflex-core/src/reflex_core/_internal/event/processor/base_state_processor.py @@ -19,12 +19,20 @@ EventQueueEntry, RegisteredEventHandler, ) +from reflex_core.utils.format import format_event_handler if TYPE_CHECKING: from reflex.event import EventHandler, EventSpec from reflex.state import BaseState +@functools.lru_cache(maxsize=1) +def _hydrate_event_name(): + from reflex.state import State + + return format_event_handler(State.event_handlers["hydrate"]) + + def _check_valid_yield(events: Any, handler_name: str = "unknown") -> Any: """Check if the events yielded are valid. They must be EventHandlers or EventSpecs. @@ -271,6 +279,33 @@ class BaseStateEventProcessor(EventProcessor): frontend. """ + async def _rehydrate(self, root_state: BaseState): + """Rehydrate the state by calling the hydrate event handler. + + Args: + root_state: The root state to rehydrate. + """ + from reflex.state import OnLoadInternalState, State + + if ( + type(root_state) is not State + or OnLoadInternalState.get_name() not in root_state.substates + ): + return + + await process_event( + handler=State.event_handlers["hydrate"], + payload={}, + state=root_state, + root_state=root_state, + ) + await process_event( + handler=OnLoadInternalState.event_handlers["on_load_internal"], + payload={}, + state=await root_state.get_state(OnLoadInternalState), + root_state=root_state, + ) + async def _process_event_queue_entry( self, *, entry: EventQueueEntry, registered_handler: RegisteredEventHandler ) -> None: @@ -295,7 +330,10 @@ async def _process_event_queue_entry( ), event=entry.event, ) as state: - # TODO: handle "reload" trigger of brand new state instances + # Compatibility hack rehydrate the state before processing this event. + needs_to_rehydrate = bool( + not state.router_data and event.name != _hydrate_event_name() + ) # re-assign only when the value is set and different if router_data and state.router_data != router_data: @@ -322,6 +360,9 @@ async def _process_event_queue_entry( substate = await state.get_state(event.state_cls) root_state = state._get_root_state() + if needs_to_rehydrate: + await self._rehydrate(root_state) + # Process non-background events while holding the lock. if not registered_handler.handler.is_background: await process_event( diff --git a/tests/units/conftest.py b/tests/units/conftest.py index 521d298df3e..7a5f431e86b 100644 --- a/tests/units/conftest.py +++ b/tests/units/conftest.py @@ -281,12 +281,18 @@ def handle_backend_exception(ex: Exception) -> None: @pytest.fixture -def mock_base_state_event_processor_obj() -> BaseStateEventProcessor: +def mock_base_state_event_processor_obj( + monkeypatch: pytest.MonkeyPatch, +) -> BaseStateEventProcessor: """Create a BaseState event processor. + Args: + monkeypatch: pytest monkeypatch fixture. + Returns: A fresh BaseState event processor. """ + monkeypatch.setattr(BaseStateEventProcessor, "_rehydrate", mock.AsyncMock()) def handle_backend_exception(ex: Exception) -> None: formatted_exc = "\n".join(traceback.format_exception(ex)) From 465e6d0dcb750939b7c1fee0763eb0b93ec74bb4 Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Wed, 1 Apr 2026 02:05:33 -0700 Subject: [PATCH 34/81] incldue coverage from subpackages raise coverage bar back up to 72 at least --- pyproject.toml | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 4bfb76a3608..57b12043124 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -234,7 +234,23 @@ ignore-words-list = "te, TreeE, selectin" [tool.coverage.run] -source = ["reflex"] +source = [ + "reflex", + "reflex_components_code", + "reflex_components_core", + "reflex_components_dataeditor", + "reflex_components_gridjs", + "reflex_components_lucide", + "reflex_components_markdown", + "reflex_components_moment", + "reflex_components_plotly", + "reflex_components_radix", + "reflex_components_react_player", + "reflex_components_recharts", + "reflex_components_sonner", + "reflex_core", + "reflex_docgen", +] branch = true omit = [ "*/pyi_generator.py", @@ -247,7 +263,7 @@ omit = [ [tool.coverage.report] show_missing = true # TODO bump back to 79 -fail_under = 50 +fail_under = 72 precision = 2 ignore_errors = true From 06270b36dcda1e026b863c7b76dc3491195b3e23 Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Wed, 1 Apr 2026 02:20:49 -0700 Subject: [PATCH 35/81] remove simulated pre-hydrated states --- tests/units/test_app.py | 7 ------- tests/units/test_state.py | 19 +++---------------- 2 files changed, 3 insertions(+), 23 deletions(-) diff --git a/tests/units/test_app.py b/tests/units/test_app.py index b1c31b4641a..10fc957a438 100644 --- a/tests/units/test_app.py +++ b/tests/units/test_app.py @@ -1771,9 +1771,6 @@ async def test_dynamic_route_var_route_change_completed_on_load( assert constants.ROUTER in app._state()._var_dependencies substate_token = BaseStateToken(ident=token, cls=DynamicState) - async with app.state_manager.modify_state(substate_token) as state: - state.router_data = {"simulate": "hydrated"} - assert state.dynamic == "" # pyright: ignore[reportAttributeAccessIssue] exp_vals = ["foo", "foobar", "baz"] def _event(name, val, **kwargs): @@ -1915,10 +1912,6 @@ async def test_process_events( payload={"c": 5}, router_data={}, ) - async with mock_root_event_context.state_manager.modify_state( - BaseStateToken(ident=token, cls=GenState), - ) as state: - state.router_data = {"simulate": "hydrated"} async with mock_base_state_event_processor as processor: await processor.enqueue( diff --git a/tests/units/test_state.py b/tests/units/test_state.py index 53079075f93..011ab83f533 100644 --- a/tests/units/test_state.py +++ b/tests/units/test_state.py @@ -2242,10 +2242,6 @@ class BackgroundTaskState(BaseState): dict_list: dict[str, list[int]] = {"foo": [1, 2, 3]} dc: ModelDC = ModelDC() - def __init__(self, **kwargs): # noqa: D107 - super().__init__(**kwargs) - self.router_data = {"simulate": "hydrate"} - @rx.var(cache=False) def computed_order(self) -> list[str]: """Get the order as a computed var. @@ -2909,7 +2905,7 @@ class BaseFieldSetterState(BaseState): assert "c2" in bfss.dirty_vars -def exp_is_hydrated(state: BaseState, is_hydrated: bool = True) -> dict[str, Any]: +def exp_is_hydrated(state: type[BaseState], is_hydrated: bool = True) -> dict[str, Any]: """Expected IS_HYDRATED delta that would be emitted by HydrateMiddleware. Args: @@ -3006,11 +3002,6 @@ def index(): app.add_page(index, on_load=test_state.test_handler) app._compile_page("index") - async with mock_root_event_context.state_manager.modify_state( - BaseStateToken(ident=token, cls=State) - ) as state: - state.router_data = {"simulate": "hydrate"} - on_load_internal_name = format.format_event_handler( OnLoadInternalState.on_load_internal # pyright: ignore[reportArgumentType] ) @@ -3035,7 +3026,7 @@ def index(): assert len(emitted_deltas) >= 2 first_delta = emitted_deltas[0][1] assert first_delta[State.get_full_name()].pop("router" + FIELD_MARKER) is not None - assert first_delta == exp_is_hydrated(state, False) + assert first_delta == exp_is_hydrated(State, False) # Find the delta containing the test handler's state change handler_deltas = [ @@ -3085,10 +3076,6 @@ def index(): app.add_page(index, on_load=[OnLoadState.test_handler, OnLoadState.test_handler]) app._compile_page("index") - async with mock_root_event_context.state_manager.modify_state( - BaseStateToken(ident=token, cls=State) - ) as state: - state.router_data = {"simulate": "hydrate"} on_load_internal_name = format.format_event_handler( OnLoadInternalState.on_load_internal # pyright: ignore[reportArgumentType] @@ -3112,7 +3099,7 @@ def index(): assert len(emitted_deltas) >= 2 first_delta = emitted_deltas[0][1] assert first_delta[State.get_full_name()].pop("router" + FIELD_MARKER) is not None - assert first_delta == exp_is_hydrated(state, False) + assert first_delta == exp_is_hydrated(State, False) # Find deltas containing the test handler's state change (num incremented twice) handler_deltas = [ From 5ed431fe074be5cbbe25956958d6c86874e4a5dc Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Wed, 1 Apr 2026 03:09:18 -0700 Subject: [PATCH 36/81] Add unit test cases for new registry/context/processor modules Add unit test cases for reflex.istate.manager.token --- tests/units/istate/manager/test_token.py | 178 +++++++ tests/units/reflex_core/__init__.py | 0 tests/units/reflex_core/_internal/__init__.py | 0 .../reflex_core/_internal/context/__init__.py | 0 .../_internal/context/test_base.py | 94 ++++ .../event/test_base_state_processor.py | 152 ++++++ .../_internal/event/test_context.py | 84 +++ .../_internal/event/test_event_processor.py | 489 ++++++++++++++++++ .../reflex_core/_internal/test_registry.py | 133 +++++ tests/units/test_app.py | 1 + 10 files changed, 1131 insertions(+) create mode 100644 tests/units/istate/manager/test_token.py create mode 100644 tests/units/reflex_core/__init__.py create mode 100644 tests/units/reflex_core/_internal/__init__.py create mode 100644 tests/units/reflex_core/_internal/context/__init__.py create mode 100644 tests/units/reflex_core/_internal/context/test_base.py create mode 100644 tests/units/reflex_core/_internal/event/test_base_state_processor.py create mode 100644 tests/units/reflex_core/_internal/event/test_context.py create mode 100644 tests/units/reflex_core/_internal/event/test_event_processor.py create mode 100644 tests/units/reflex_core/_internal/test_registry.py diff --git a/tests/units/istate/manager/test_token.py b/tests/units/istate/manager/test_token.py new file mode 100644 index 00000000000..acb67fe60dd --- /dev/null +++ b/tests/units/istate/manager/test_token.py @@ -0,0 +1,178 @@ +"""Tests for StateToken, BaseStateToken, and from_legacy_token.""" + +import io +import pickle + +import pytest + +from reflex.istate.manager.token import BaseStateToken, StateToken + + +def test_state_token_str(): + """__str__ encodes ident and cls into 'ident/module.Class' format.""" + token = StateToken(ident="abc-123", cls=int) + assert str(token) == "abc-123/builtins.int" + + +def test_state_token_str_escapes_slashes(): + """Slashes in ident or cls name are percent-encoded.""" + token = StateToken(ident="a/b", cls=int) + result = str(token) + assert "%2F" in result + assert "/" in result + + +def test_state_token_with_cls(): + """with_cls returns a new token with updated cls, leaving the original unchanged.""" + token = StateToken(ident="tok", cls=int) + new = token.with_cls(bool) + assert new.cls is bool + assert new.ident == "tok" + assert token.cls is int + + +def test_state_token_serialize_deserialize_roundtrip(): + """serialize/deserialize with data= round-trips through pickle.""" + value = {"key": [1, 2, 3]} + data = StateToken.serialize(value) + assert isinstance(data, bytes) + assert StateToken.deserialize(data=data) == value + + +def test_state_token_deserialize_from_fp(): + """Deserialize with fp= reads from a file-like object.""" + value = "hello" + buf = io.BytesIO(pickle.dumps(value)) + assert StateToken.deserialize(fp=buf) == value + + +def test_state_token_deserialize_neither_raises(): + """Deserialize with neither data nor fp raises ValueError.""" + with pytest.raises(ValueError, match="Only one"): + StateToken.deserialize() + + +def test_state_token_get_and_reset_touched_state(): + """Default implementation always returns True.""" + assert StateToken.get_and_reset_touched_state("anything") is True + + +def test_base_state_token_str(clean_registration_context): + """__str__ uses 'ident_full_name' format. + + Args: + clean_registration_context: A fresh, empty registration context. + """ + from reflex.state import BaseState + + class TokState(BaseState): + pass + + token = BaseStateToken(ident="client-abc", cls=TokState) + result = str(token) + assert result.startswith("client-abc_") + assert TokState.get_full_name() in result + + +def test_base_state_token_with_cls(clean_registration_context): + """with_cls returns a BaseStateToken (not a plain StateToken). + + Args: + clean_registration_context: A fresh, empty registration context. + """ + from reflex.state import BaseState + + class A(BaseState): + pass + + class B(BaseState): + pass + + token = BaseStateToken(ident="tok", cls=A) + new = token.with_cls(B) + assert isinstance(new, BaseStateToken) + assert new.cls is B + + +def test_base_state_token_serialize_deserialize(clean_registration_context): + """BaseStateToken serialization uses BaseState._serialize/_deserialize. + + Args: + clean_registration_context: A fresh, empty registration context. + """ + from reflex.state import BaseState + + class SerState(BaseState): + x: int = 42 + + state = SerState() + data = BaseStateToken.serialize(state) + assert isinstance(data, bytes) + restored = BaseStateToken.deserialize(data=data) + assert isinstance(restored, SerState) + assert restored.x == 42 + + +def test_base_state_token_get_and_reset_touched(clean_registration_context): + """get_and_reset_touched_state returns the touched flag and resets it. + + Args: + clean_registration_context: A fresh, empty registration context. + """ + from reflex.state import BaseState + + class TouchState(BaseState): + x: int = 0 + + state = TouchState() + state._was_touched = True + assert BaseStateToken.get_and_reset_touched_state(state) is True + assert state._was_touched is False + assert BaseStateToken.get_and_reset_touched_state(state) is False + + +def test_from_legacy_token(clean_registration_context): + """from_legacy_token parses 'ident_state.path' into a BaseStateToken. + + Args: + clean_registration_context: A fresh, empty registration context. + """ + from reflex.state import BaseState + + class LegacyRoot(BaseState): + pass + + full_name = LegacyRoot.get_full_name() + legacy_str = f"my-client-token_{full_name}" + + token = BaseStateToken.from_legacy_token(legacy_str, root_state=LegacyRoot) + assert token.ident == "my-client-token" + assert token.cls is LegacyRoot + + +def test_from_legacy_token_substate(clean_registration_context): + """from_legacy_token resolves a substate path. + + Args: + clean_registration_context: A fresh, empty registration context. + """ + from reflex.state import BaseState + + class LegRoot(BaseState): + pass + + class LegChild(LegRoot): + pass + + full_name = LegChild.get_full_name() + legacy_str = f"tok_{full_name}" + + token = BaseStateToken.from_legacy_token(legacy_str, root_state=LegRoot) + assert token.ident == "tok" + assert token.cls is LegChild + + +def test_from_legacy_token_none_root_raises(): + """from_legacy_token with root_state=None raises ValueError.""" + with pytest.raises(ValueError, match="Root state must be provided"): + BaseStateToken.from_legacy_token("tok_some.state", root_state=None) diff --git a/tests/units/reflex_core/__init__.py b/tests/units/reflex_core/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tests/units/reflex_core/_internal/__init__.py b/tests/units/reflex_core/_internal/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tests/units/reflex_core/_internal/context/__init__.py b/tests/units/reflex_core/_internal/context/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tests/units/reflex_core/_internal/context/test_base.py b/tests/units/reflex_core/_internal/context/test_base.py new file mode 100644 index 00000000000..9e95d552eb1 --- /dev/null +++ b/tests/units/reflex_core/_internal/context/test_base.py @@ -0,0 +1,94 @@ +"""Tests for BaseContext.""" + +import dataclasses + +import pytest +from reflex_core._internal.context.base import BaseContext + + +@dataclasses.dataclass(frozen=True, kw_only=True, slots=True) +class _TestContext(BaseContext): + """Minimal BaseContext subclass for unit testing.""" + + label: str = "test" + + +def test_get_without_set_raises(): + """get() raises LookupError when no context is set.""" + with pytest.raises(LookupError): + _TestContext.get() + + +def test_set_and_get(): + """set() makes the context retrievable via get().""" + ctx = _TestContext(label="a") + token = _TestContext.set(ctx) + try: + assert _TestContext.get() is ctx + finally: + _TestContext.reset(token) + + +def test_reset_restores_previous(): + """reset() restores the previously active context.""" + outer = _TestContext(label="outer") + outer_tok = _TestContext.set(outer) + try: + inner = _TestContext(label="inner") + inner_tok = _TestContext.set(inner) + assert _TestContext.get() is inner + _TestContext.reset(inner_tok) + assert _TestContext.get() is outer + finally: + _TestContext.reset(outer_tok) + + +def test_context_manager_enter_exit(): + """__enter__ sets the context and __exit__ removes it.""" + ctx = _TestContext(label="cm") + with ctx as entered: + assert entered is ctx + assert _TestContext.get() is ctx + with pytest.raises(LookupError): + _TestContext.get() + + +def test_context_manager_nesting(): + """Nested context managers restore the outer context on inner exit.""" + outer = _TestContext(label="outer") + inner = _TestContext(label="inner") + with outer: + assert _TestContext.get().label == "outer" + with inner: + assert _TestContext.get().label == "inner" + assert _TestContext.get().label == "outer" + + +def test_double_enter_raises(): + """Entering the same context instance twice raises RuntimeError.""" + ctx = _TestContext(label="double") + with ctx, pytest.raises(RuntimeError, match="already attached"): + ctx.__enter__() + + +def test_ensure_context_attached(): + """ensure_context_attached raises when not entered, succeeds when entered.""" + ctx = _TestContext(label="ensure") + with pytest.raises(RuntimeError, match="must be entered"): + ctx.ensure_context_attached() + with ctx: + ctx.ensure_context_attached() + + +def test_subclasses_have_independent_context_vars(): + """Two BaseContext subclasses do not share their ContextVar.""" + + @dataclasses.dataclass(frozen=True, kw_only=True, slots=True) + class _OtherContext(BaseContext): + value: int = 0 + + ctx_a = _TestContext(label="a") + ctx_b = _OtherContext(value=42) + with ctx_a, ctx_b: + assert _TestContext.get().label == "a" + assert _OtherContext.get().value == 42 diff --git a/tests/units/reflex_core/_internal/event/test_base_state_processor.py b/tests/units/reflex_core/_internal/event/test_base_state_processor.py new file mode 100644 index 00000000000..2504598a1d6 --- /dev/null +++ b/tests/units/reflex_core/_internal/event/test_base_state_processor.py @@ -0,0 +1,152 @@ +"""Tests for BaseStateEventProcessor, specifically the _rehydrate path.""" + +import traceback +from collections.abc import Mapping +from typing import Any + +import pytest +import pytest_asyncio +from reflex_core._internal.event.context import EventContext +from reflex_core._internal.event.processor import BaseStateEventProcessor +from reflex_core._internal.registry import RegistrationContext +from reflex_core.constants import CompileVars +from reflex_core.constants.state import FIELD_MARKER + +from reflex import event +from reflex.app import App +from reflex.event import Event +from reflex.istate.manager.memory import StateManagerMemory +from reflex.state import OnLoadInternalState, State + + +@pytest.fixture +def _real_base_state_processor_obj() -> BaseStateEventProcessor: + """A BaseStateEventProcessor with real (unmocked) _rehydrate. + + Returns: + A fresh BaseStateEventProcessor instance. + """ + + def handle_backend_exception(ex: Exception) -> None: + formatted_exc = "\n".join(traceback.format_exception(ex)) + pytest.fail(f"Event processor raised an unexpected exception:\n{formatted_exc}") + + return BaseStateEventProcessor( + backend_exception_handler=handle_backend_exception, + graceful_shutdown_timeout=2, + ) + + +@pytest.fixture +def emitted_deltas() -> list[tuple[str, Mapping[str, Mapping[str, Any]]]]: + """List to capture emitted deltas. + + Returns: + An empty list for collecting deltas. + """ + return [] + + +@pytest.fixture +def emitted_events() -> list[tuple[str, tuple[Event, ...]]]: + """List to capture emitted events. + + Returns: + An empty list for collecting events. + """ + return [] + + +@pytest_asyncio.fixture +async def real_base_state_processor( + _real_base_state_processor_obj: BaseStateEventProcessor, + emitted_deltas: list, + emitted_events: list, + clean_registration_context: RegistrationContext, +): + """A fully wired BaseStateEventProcessor with real _rehydrate. + + Yields the processor (not yet started). The test must use ``async with processor`` to + control the lifecycle and assert on emitted deltas after stop. + + Args: + _real_base_state_processor_obj: The unmocked processor instance. + emitted_deltas: List to capture emitted deltas. + emitted_events: List to capture emitted events. + clean_registration_context: Isolated registration context for the test. + + Yields: + The configured but not-yet-started BaseStateEventProcessor. + """ + clean_registration_context.register_base_state(OnLoadInternalState) + state_manager = StateManagerMemory() + + async def emit_delta_impl( # noqa: RUF029 + token: str, delta: Mapping[str, Mapping[str, Any]] + ) -> None: + emitted_deltas.append((token, delta)) + + async def emit_event_impl(token: str, *events: Event) -> None: # noqa: RUF029 + emitted_events.append((token, events)) + + root_ctx = EventContext( + token="", + state_manager=state_manager, + enqueue_impl=_real_base_state_processor_obj.enqueue, + emit_delta_impl=emit_delta_impl, + emit_event_impl=emit_event_impl, + ) + _real_base_state_processor_obj._root_context = root_ctx + + yield _real_base_state_processor_obj + + await state_manager.close() + + +async def test_rehydrate_sets_is_hydrated_on_fresh_token( + app_module_mock, + real_base_state_processor: BaseStateEventProcessor, + emitted_deltas: list[tuple[str, Mapping[str, Mapping[str, Any]]]], + token: str, +): + """A non-hydrate event against a fresh token triggers _rehydrate, emitting is_hydrated=True. + + When a token has never been seen before (no router_data on the state), + and the event is not itself the hydrate event, the processor calls + _rehydrate which runs State.hydrate. With no on_load events defined, + hydrate sets is_hydrated=True directly. + + Args: + app_module_mock: The mock app module fixture. + real_base_state_processor: The unmocked BaseStateEventProcessor. + emitted_deltas: List to capture emitted deltas. + token: The client token. + """ + + class MyState(State): + @event + def noop(self): + pass + + OnLoadInternalState._app_ref = None + app = app_module_mock.app = App() + assert real_base_state_processor._root_context is not None + app._state_manager = real_base_state_processor._root_context.state_manager + + async with real_base_state_processor as processor: + await processor.enqueue( + token, + *Event.from_event_type(MyState.noop()), + ) + await processor.join(1) + + state_name = State.get_full_name() + is_hydrated_key = CompileVars.IS_HYDRATED + FIELD_MARKER + hydrated_deltas = [ + d + for _, d in emitted_deltas + if state_name in d and d[state_name].get(is_hydrated_key) is True + ] + assert len(hydrated_deltas) >= 1, ( + f"Expected at least one delta with is_hydrated=True, got deltas: {emitted_deltas}" + ) diff --git a/tests/units/reflex_core/_internal/event/test_context.py b/tests/units/reflex_core/_internal/event/test_context.py new file mode 100644 index 00000000000..3d3f0a48cfb --- /dev/null +++ b/tests/units/reflex_core/_internal/event/test_context.py @@ -0,0 +1,84 @@ +"""Tests for EventContext.""" + +from unittest import mock + +from reflex_core._internal.event.context import EventContext + + +def test_fork_creates_child(mock_root_event_context: EventContext): + """fork() creates a child context with a new txid and shared impls. + + Args: + mock_root_event_context: The root event context fixture. + """ + child = mock_root_event_context.fork(token="child-tok") + assert child.token == "child-tok" + assert child.parent_txid == mock_root_event_context.txid + assert child.txid != mock_root_event_context.txid + assert child.state_manager is mock_root_event_context.state_manager + assert child.enqueue_impl is mock_root_event_context.enqueue_impl + + +def test_fork_inherits_token(mock_root_event_context: EventContext): + """fork() without token= inherits the parent's token. + + Args: + mock_root_event_context: The root event context fixture. + """ + child = mock_root_event_context.fork() + assert child.token == mock_root_event_context.token + + +async def test_emit_delta(mock_root_event_context: EventContext, emitted_deltas: list): + """emit_delta records the delta via emit_delta_impl. + + Args: + mock_root_event_context: The root event context fixture. + emitted_deltas: List to capture emitted deltas. + """ + ctx = mock_root_event_context.fork(token="tok") + delta = {"state": {"x": 1}} + await ctx.emit_delta(delta) + assert emitted_deltas == [("tok", delta)] + + +async def test_emit_event(mock_root_event_context: EventContext, emitted_events: list): + """emit_event records the event via emit_event_impl. + + Args: + mock_root_event_context: The root event context fixture. + emitted_events: List to capture emitted events. + """ + from reflex.event import Event + + ctx = mock_root_event_context.fork(token="tok") + ev = Event(name="test", payload={}) + await ctx.emit_event(ev) + assert len(emitted_events) == 1 + assert emitted_events[0][0] == "tok" + + +async def test_emit_delta_noop_when_no_impl(): + """emit_delta is a no-op when emit_delta_impl is None.""" + from reflex.istate.manager.memory import StateManagerMemory + + ctx = EventContext( + token="t", + state_manager=StateManagerMemory(), + enqueue_impl=mock.AsyncMock(), + emit_delta_impl=None, + ) + await ctx.emit_delta({"s": {"k": "v"}}) + + +async def test_emit_event_noop_when_no_impl(): + """emit_event is a no-op when emit_event_impl is None.""" + from reflex.istate.manager.memory import StateManagerMemory + + ctx = EventContext( + token="t", + state_manager=StateManagerMemory(), + enqueue_impl=mock.AsyncMock(), + emit_event_impl=None, + ) + await ctx.emit_event() diff --git a/tests/units/reflex_core/_internal/event/test_event_processor.py b/tests/units/reflex_core/_internal/event/test_event_processor.py new file mode 100644 index 00000000000..0e21e4b8e1c --- /dev/null +++ b/tests/units/reflex_core/_internal/event/test_event_processor.py @@ -0,0 +1,489 @@ +"""Tests for EventProcessor lifecycle, task management, and error handling.""" + +import asyncio +import time +from typing import Any + +import pytest +from reflex_core._internal.event.context import EventContext +from reflex_core._internal.event.processor.event_processor import ( + DrainTimeoutManager, + EventProcessor, +) +from reflex_core._internal.registry import RegistrationContext + +from reflex.event import Event, EventHandler + +# Module-level log so event handlers can record what happened. +_CALL_LOG: list[dict[str, Any]] = [] + + +async def _noop_handler(): + """A handler that does nothing.""" + + +async def _slow_handler(delay: float = 0.5): + """A handler that sleeps for *delay* seconds. + + Args: + delay: How long to sleep in seconds. + """ + await asyncio.sleep(delay) + + +async def _error_handler(): # noqa: RUF029 + """A handler that always raises.""" + raise RuntimeError("boom") # noqa: EM101 + + +async def _logging_handler(value: str = "default"): # noqa: RUF029 + """A handler that records its invocation. + + Args: + value: The value to log. + """ + _CALL_LOG.append({"value": value}) + + +async def _chaining_handler(): + """A handler that enqueues a logging event via the current EventContext.""" + ctx = EventContext.get() + await ctx.enqueue( + *Event.from_event_type(logging_event("chained")), + ) + + +async def _delta_handler(): + """A handler that emits a single delta.""" + ctx = EventContext.get() + await ctx.emit_delta({"state": {"x": 1}}) + + +async def _multi_delta_handler(): + """A handler that emits multiple deltas with a small pause between them.""" + ctx = EventContext.get() + for i in range(3): + await ctx.emit_delta({"state": {"i": i}}) + await asyncio.sleep(0.01) + + +noop_event = EventHandler(fn=_noop_handler) +slow_event = EventHandler(fn=_slow_handler) +error_event = EventHandler(fn=_error_handler) +logging_event = EventHandler(fn=_logging_handler) +chaining_event = EventHandler(fn=_chaining_handler) +delta_event = EventHandler(fn=_delta_handler) +multi_delta_event = EventHandler(fn=_multi_delta_handler) + + +@pytest.fixture(autouse=True) +def _register_handlers(forked_registration_context: RegistrationContext): + """Register all test event handlers and clear the call log. + + Args: + forked_registration_context: Isolated registration context for the test. + """ + _CALL_LOG.clear() + for handler in ( + noop_event, + slow_event, + error_event, + logging_event, + chaining_event, + delta_event, + multi_delta_event, + ): + RegistrationContext.register_event_handler(handler) + + +@pytest.fixture +def processor() -> EventProcessor: + """A bare EventProcessor with no backend_exception_handler. + + Returns: + A fresh EventProcessor instance. + """ + return EventProcessor(graceful_shutdown_timeout=2) + + +def test_configure_once(processor: EventProcessor): + """Calling configure() twice raises RuntimeError. + + Args: + processor: The event processor fixture. + """ + processor.configure() + with pytest.raises(RuntimeError, match="already configured"): + processor.configure() + + +async def test_start_before_configure(processor: EventProcessor): + """Starting before configure raises RuntimeError. + + Args: + processor: The event processor fixture. + """ + with pytest.raises(RuntimeError, match="not configured"): + await processor.start() + + +async def test_start_twice(processor: EventProcessor): + """Starting a second time raises RuntimeError. + + Args: + processor: The event processor fixture. + """ + processor.configure() + await processor.start() + try: + with pytest.raises(RuntimeError, match="already started"): + await processor.start() + finally: + await processor.stop() + + +async def test_stop_idempotent(processor: EventProcessor): + """Stopping an already-stopped processor does not error. + + Args: + processor: The event processor fixture. + """ + processor.configure() + await processor.start() + await processor.stop() + await processor.stop() + + +async def test_async_context_manager(processor: EventProcessor): + """Entering/exiting via ``async with`` starts and stops the processor. + + Args: + processor: The event processor fixture. + """ + processor.configure() + async with processor as ep: + assert ep._queue is not None + assert ep._queue is None + assert ep._queue_task is None + + +async def test_enqueue_after_stop_raises(processor: EventProcessor): + """Enqueueing after stop raises because the queue is gone. + + Args: + processor: The event processor fixture. + """ + processor.configure() + async with processor: + pass + with pytest.raises(RuntimeError, match="not running"): + await processor.enqueue("tok", *Event.from_event_type(noop_event())) + + +async def test_enqueue_before_start_raises(processor: EventProcessor): + """Enqueueing before start raises because the queue doesn't exist. + + Args: + processor: The event processor fixture. + """ + processor.configure() + with pytest.raises(RuntimeError, match="not running"): + await processor.enqueue("tok", *Event.from_event_type(noop_event())) + + +async def test_events_are_processed( + mock_event_processor: EventProcessor, + emitted_deltas: list, + token: str, +): + """Events enqueued are actually processed. + + Args: + mock_event_processor: The event processor with mock root context. + emitted_deltas: List to capture emitted deltas. + token: The client token. + """ + async with mock_event_processor as ep: + await ep.enqueue(token, *Event.from_event_type(logging_event("hello"))) + assert _CALL_LOG == [{"value": "hello"}] + + +async def test_enqueue_returns_future( + mock_event_processor: EventProcessor, + token: str, +): + """enqueue() returns a Future that resolves when the task finishes. + + Args: + mock_event_processor: The event processor with mock root context. + token: The client token. + """ + async with mock_event_processor as ep: + future = await ep.enqueue(token, *Event.from_event_type(noop_event())) + assert isinstance(future, asyncio.Future) + assert future.done() + + +async def test_tasks_cleared_after_stop( + mock_event_processor: EventProcessor, + token: str, +): + """After stop(), the internal _tasks dict is empty. + + Args: + mock_event_processor: The event processor with mock root context. + token: The client token. + """ + async with mock_event_processor as ep: + await ep.enqueue(token, *Event.from_event_type(noop_event())) + assert ep._tasks == {} + + +async def test_futures_cleared_after_stop( + mock_event_processor: EventProcessor, + token: str, +): + """After stop(), the internal _futures dict is empty. + + Args: + mock_event_processor: The event processor with mock root context. + token: The client token. + """ + async with mock_event_processor as ep: + await ep.enqueue(token, *Event.from_event_type(noop_event())) + assert ep._futures == {} + + +async def test_slow_tasks_cancelled_on_stop(processor: EventProcessor): + """Tasks that haven't finished by the graceful timeout are cancelled. + + Args: + processor: The event processor fixture. + """ + processor.graceful_shutdown_timeout = 0 + processor.configure() + async with processor as ep: + future = await ep.enqueue("tok", *Event.from_event_type(slow_event(10.0))) + assert future.cancelled() + assert ep._tasks == {} + + +async def test_multiple_futures_cancelled_on_stop(processor: EventProcessor): + """Unresolved futures are cancelled during stop. + + Args: + processor: The event processor fixture. + """ + processor.graceful_shutdown_timeout = 0 + processor.configure() + async with processor as ep: + f1 = await ep.enqueue("t1", *Event.from_event_type(slow_event(10.0))) + f2 = await ep.enqueue("t2", *Event.from_event_type(slow_event(10.0))) + for f in (f1, f2): + assert f.done() + assert ep._futures == {} + + +async def test_cancel_future_before_task_starts( + mock_event_processor: EventProcessor, + token: str, +): + """Cancelling the future before the task starts skips processing. + + Args: + mock_event_processor: The event processor with mock root context. + token: The client token. + """ + async with mock_event_processor as ep: + future = await ep.enqueue(token, *Event.from_event_type(slow_event(10.0))) + future.cancel() + await asyncio.sleep(0.05) + assert ep._tasks == {} + + +async def test_cancel_future_cancels_running_task( + mock_event_processor: EventProcessor, + token: str, +): + """Cancelling the future cancels an already-running task. + + Args: + mock_event_processor: The event processor with mock root context. + token: The client token. + """ + async with mock_event_processor as ep: + future = await ep.enqueue(token, *Event.from_event_type(slow_event(10.0))) + await asyncio.sleep(0.05) + future.cancel() + await asyncio.sleep(0.05) + assert ep._tasks == {} + + +async def test_exception_propagated_to_future( + processor: EventProcessor, + token: str, +): + """An exception in the handler is set on the future. + + Args: + processor: The event processor fixture. + token: The client token. + """ + processor.configure() + async with processor as ep: + future = await ep.enqueue(token, *Event.from_event_type(error_event())) + assert future.done() + with pytest.raises(RuntimeError, match="boom"): + future.result() + + +async def test_backend_exception_handler_called(token: str): + """The backend_exception_handler receives the exception. + + Args: + token: The client token. + """ + caught: list[Exception] = [] + + def _catch(ex: Exception) -> None: + caught.append(ex) + + ep = EventProcessor(backend_exception_handler=_catch, graceful_shutdown_timeout=2) + ep.configure() + async with ep: + await ep.enqueue(token, *Event.from_event_type(error_event())) + assert len(caught) == 1 + assert isinstance(caught[0], RuntimeError) + + +async def test_error_does_not_stop_queue( + processor: EventProcessor, + token: str, +): + """A failing event does not prevent subsequent events from processing. + + Args: + processor: The event processor fixture. + token: The client token. + """ + processor.configure() + async with processor as ep: + await ep.enqueue(token, *Event.from_event_type(error_event())) + await ep.enqueue(token, *Event.from_event_type(logging_event("after_error"))) + assert _CALL_LOG == [{"value": "after_error"}] + + +def test_drain_timeout_no_timeout(): + """DrainTimeoutManager with no timeout returns 0.""" + dtm = DrainTimeoutManager.with_timeout(None) + with dtm as remaining: + assert remaining == 0 + + +def test_drain_timeout_decreases(): + """DrainTimeoutManager remaining time decreases across re-entries.""" + dtm = DrainTimeoutManager.with_timeout(10.0) + with dtm as first: + assert 9.5 < first <= 10.0 + time.sleep(0.1) + with dtm as second: + assert second < first + + +def test_drain_timeout_expired_returns_zero(): + """DrainTimeoutManager with an already-expired timeout returns 0.""" + dtm = DrainTimeoutManager.with_timeout(0) + with dtm as remaining: + assert remaining == 0 + + +async def test_chained_event_processed(token: str): + """An event handler that enqueues another event via ctx.enqueue succeeds. + + Args: + token: The client token. + """ + ep = EventProcessor(graceful_shutdown_timeout=2) + ep.configure() + async with ep: + await ep.enqueue(token, *Event.from_event_type(chaining_event())) + assert _CALL_LOG == [{"value": "chained"}] + + +async def test_join_when_not_started(processor: EventProcessor): + """join() when not started is a no-op (queue is None). + + Args: + processor: The event processor fixture. + """ + processor.configure() + await processor.join(timeout=1) + + +async def test_join_completes_after_processing( + mock_event_processor: EventProcessor, + token: str, +): + """join() returns once all queued entries have been dequeued. + + Args: + mock_event_processor: The event processor with mock root context. + token: The client token. + """ + async with mock_event_processor as ep: + await ep.enqueue(token, *Event.from_event_type(noop_event())) + await ep.join(timeout=5) + + +async def test_stream_delta_yields_single_delta(token: str): + """enqueue_stream_delta yields a delta emitted by the handler. + + Args: + token: The client token. + """ + ep = EventProcessor(graceful_shutdown_timeout=2) + ep.configure() + async with ep: + event = Event.from_event_type(delta_event())[0] + collected = [d async for d in ep.enqueue_stream_delta(token, event)] + assert collected == [{"state": {"x": 1}}] + + +async def test_stream_delta_yields_multiple_deltas(token: str): + """enqueue_stream_delta yields all deltas in order. + + Args: + token: The client token. + """ + ep = EventProcessor(graceful_shutdown_timeout=2) + ep.configure() + async with ep: + event = Event.from_event_type(multi_delta_event())[0] + collected = [d async for d in ep.enqueue_stream_delta(token, event)] + assert collected == [ + {"state": {"i": 0}}, + {"state": {"i": 1}}, + {"state": {"i": 2}}, + ] + + +async def test_stream_delta_noop_handler_yields_nothing(token: str): + """enqueue_stream_delta with a handler that emits no deltas yields nothing. + + Args: + token: The client token. + """ + ep = EventProcessor(graceful_shutdown_timeout=2) + ep.configure() + async with ep: + event = Event.from_event_type(noop_event())[0] + collected = [d async for d in ep.enqueue_stream_delta(token, event)] + assert collected == [] + + +async def test_stream_delta_not_configured_raises(): + """enqueue_stream_delta raises RuntimeError if processor is not configured.""" + ep = EventProcessor() + with pytest.raises(RuntimeError, match="not configured"): + async for _ in ep.enqueue_stream_delta("tok", Event(name="x", payload={})): + pass diff --git a/tests/units/reflex_core/_internal/test_registry.py b/tests/units/reflex_core/_internal/test_registry.py new file mode 100644 index 00000000000..d10f79f7c65 --- /dev/null +++ b/tests/units/reflex_core/_internal/test_registry.py @@ -0,0 +1,133 @@ +"""Tests for RegistrationContext.""" + +import pytest +from reflex_core._internal.registry import RegisteredEventHandler, RegistrationContext +from reflex_core.utils.exceptions import StateValueError + + +def test_ensure_context_creates_if_missing(): + """ensure_context() returns existing context or creates a new one.""" + try: + existing = RegistrationContext._context_var.get() + assert RegistrationContext.ensure_context() is existing + except LookupError: + ctx = RegistrationContext.ensure_context() + assert isinstance(ctx, RegistrationContext) + assert RegistrationContext.get() is ctx + + +def test_clean_context_is_empty(clean_registration_context: RegistrationContext): + """A clean context starts with no handlers or states. + + Args: + clean_registration_context: A fresh, empty registration context. + """ + assert clean_registration_context.event_handlers == {} + assert clean_registration_context.base_states == {} + assert clean_registration_context.base_state_substates == {} + + +def test_register_event_handler(clean_registration_context: RegistrationContext): + """register_event_handler stores the handler keyed by its full name. + + Args: + clean_registration_context: A fresh, empty registration context. + """ + from reflex.event import EventHandler + + async def my_fn(): + pass + + handler = EventHandler(fn=my_fn) + RegistrationContext.register_event_handler(handler) + assert len(clean_registration_context.event_handlers) == 1 + registered = next(iter(clean_registration_context.event_handlers.values())) + assert isinstance(registered, RegisteredEventHandler) + assert registered.handler is handler + + +def test_register_base_state(clean_registration_context: RegistrationContext): + """BaseState metaclass auto-registers during class definition into the active context. + + Args: + clean_registration_context: A fresh, empty registration context. + """ + from reflex.state import BaseState + + class AutoRegistered(BaseState): + x: int = 0 + + assert AutoRegistered.get_full_name() in clean_registration_context.base_states + + +def test_duplicate_substate_raises(clean_registration_context: RegistrationContext): + """Registering the same substate twice raises StateValueError. + + Args: + clean_registration_context: A fresh, empty registration context. + """ + from reflex.state import BaseState + + class DupParent(BaseState): + pass + + class DupChild(DupParent): + pass + + with pytest.raises(StateValueError, match="already registered"): + clean_registration_context._register_base_state(DupChild) + + +def test_get_substates(clean_registration_context: RegistrationContext): + """get_substates returns registered children of a parent. + + Args: + clean_registration_context: A fresh, empty registration context. + """ + from reflex.state import BaseState + + class GetSubRoot(BaseState): + pass + + class GetSub1(GetSubRoot): + pass + + class GetSub2(GetSubRoot): + pass + + substates = clean_registration_context.get_substates(GetSubRoot) + assert GetSub1 in substates + assert GetSub2 in substates + + +def test_get_substates_by_name(clean_registration_context: RegistrationContext): + """get_substates also works when passed a string full name. + + Args: + clean_registration_context: A fresh, empty registration context. + """ + from reflex.state import BaseState + + class NamedState(BaseState): + pass + + result = clean_registration_context.get_substates(NamedState.get_full_name()) + assert isinstance(result, set) + + +def test_forked_context_is_independent( + forked_registration_context: RegistrationContext, +): + """Changes to a forked context do not affect the original. + + Args: + forked_registration_context: A deep copy of the current registration context. + """ + from reflex.event import EventHandler + + async def _tmp(): + pass + + handler = EventHandler(fn=_tmp) + RegistrationContext.register_event_handler(handler) + assert len(forked_registration_context.event_handlers) > 0 diff --git a/tests/units/test_app.py b/tests/units/test_app.py index 10fc957a438..8c3661764c4 100644 --- a/tests/units/test_app.py +++ b/tests/units/test_app.py @@ -1755,6 +1755,7 @@ async def test_dynamic_route_var_route_change_completed_on_load( emitted_deltas: List to store emitted deltas. emitted_events: List to store emitted events. """ + OnLoadInternalState._app_ref = None arg_name = "dynamic" route = f"test/[{arg_name}]" app = app_module_mock.app = App() From 7bdc7dfc437d2e63b814b9210f2ac04500420bf1 Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Wed, 1 Apr 2026 09:48:10 -0700 Subject: [PATCH 37/81] Use correct token in enqueue_stream_delta Use the delta's token when emitting to the processor queue. Return after emitting to the processor queue so the caller does not get deltas from unrelated tokens. Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> --- .../reflex_core/_internal/event/processor/event_processor.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/packages/reflex-core/src/reflex_core/_internal/event/processor/event_processor.py b/packages/reflex-core/src/reflex_core/_internal/event/processor/event_processor.py index a1f34c903a1..aa36afd6a3b 100644 --- a/packages/reflex-core/src/reflex_core/_internal/event/processor/event_processor.py +++ b/packages/reflex-core/src/reflex_core/_internal/event/processor/event_processor.py @@ -402,7 +402,8 @@ async def _emit_delta_impl( and self._root_context.emit_delta_impl is not None ): # Emit deltas for other tokens normally. - await self._root_context.emit_delta_impl(token, delta) + await self._root_context.emit_delta_impl(delta_token, delta) + return await deltas.put(delta) task_future = await self.enqueue( From c0466138791210533147a3117b56c584d0e505e1 Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Wed, 1 Apr 2026 09:49:13 -0700 Subject: [PATCH 38/81] Fix StateToken.deserialize implementation Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> --- reflex/istate/manager/token.py | 17 +++++------------ 1 file changed, 5 insertions(+), 12 deletions(-) diff --git a/reflex/istate/manager/token.py b/reflex/istate/manager/token.py index b9b4b75cfd9..9abe78018c5 100644 --- a/reflex/istate/manager/token.py +++ b/reflex/istate/manager/token.py @@ -69,22 +69,15 @@ def deserialize( Args: data: The serialized state data. fp: The file pointer to the serialized state data. - - Returns: - The raw deserialized state ("should match the token type"). - - Raises: - ValueError: If both data and fp are provided, or neither are provided. - """ - if data is not None and fp is None: + if data is not None and fp is not None: + msg = "Only one of `data` or `fp` may be provided, not both." + raise ValueError(msg) + if data is not None: return pickle.loads(data) if fp is not None: return pickle.load(fp) - msg = "Only one of `data` or `fp` must be provided" + msg = "At least one of `data` or `fp` must be provided." raise ValueError(msg) - - @classmethod - def get_and_reset_touched_state(cls, state: TOKEN_TYPE) -> bool: """Get the touched state and reset the touched flag. This is used to determine if a state has been modified since it was last serialized. From 945662b42118f978e700a96cd34a48af7133d581 Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Wed, 1 Apr 2026 09:49:51 -0700 Subject: [PATCH 39/81] Update packages/reflex-core/src/reflex_core/_internal/event/processor/base_state_processor.py Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> --- .../_internal/event/processor/base_state_processor.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/packages/reflex-core/src/reflex_core/_internal/event/processor/base_state_processor.py b/packages/reflex-core/src/reflex_core/_internal/event/processor/base_state_processor.py index 02aed3a1a99..2d679031973 100644 --- a/packages/reflex-core/src/reflex_core/_internal/event/processor/base_state_processor.py +++ b/packages/reflex-core/src/reflex_core/_internal/event/processor/base_state_processor.py @@ -122,9 +122,9 @@ def _transform_event_arg(value: Any, hinted_args: Any) -> Any: return hinted_args.parse_obj(value) if issubclass(hinted_args, BaseModelV2): return hinted_args.model_validate(value) - if isinstance(value, list) and (hinted_args is set or hinted_args is set): + if isinstance(value, list) and (hinted_args is set or hinted_args is frozenset): return set(value) - if isinstance(value, list) and (hinted_args is tuple or hinted_args is tuple): + if isinstance(value, list) and hinted_args is tuple: return tuple(value) if isinstance(hinted_args, type) and issubclass(hinted_args, Enum): try: From 28b105fd2fb65cdfd412d08b73c39cee14a94a02 Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Wed, 1 Apr 2026 09:56:58 -0700 Subject: [PATCH 40/81] Fix StateToken mismerge (Thanks greptile) Remove TODO now that issue is created in repo. --- packages/reflex-core/src/reflex_core/event.py | 1 - reflex/istate/manager/token.py | 7 +++++++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/packages/reflex-core/src/reflex_core/event.py b/packages/reflex-core/src/reflex_core/event.py index cbb2edd56b7..2853e2ce0ff 100644 --- a/packages/reflex-core/src/reflex_core/event.py +++ b/packages/reflex-core/src/reflex_core/event.py @@ -132,7 +132,6 @@ def from_event_type( msg = f"Unexpected event type, {type(e)}." raise ValueError(msg) name = format.format_event_handler(e.handler) - # TODO: allow real python types to be passed through the backend queue. payload = {k._js_expr: v._decode() for k, v in e.args} # Create an event and append it to the list. diff --git a/reflex/istate/manager/token.py b/reflex/istate/manager/token.py index 9abe78018c5..928cb69b048 100644 --- a/reflex/istate/manager/token.py +++ b/reflex/istate/manager/token.py @@ -69,6 +69,10 @@ def deserialize( Args: data: The serialized state data. fp: The file pointer to the serialized state data. + + Returns: + The deserialized state instance. + """ if data is not None and fp is not None: msg = "Only one of `data` or `fp` may be provided, not both." raise ValueError(msg) @@ -78,6 +82,9 @@ def deserialize( return pickle.load(fp) msg = "At least one of `data` or `fp` must be provided." raise ValueError(msg) + + @classmethod + def get_and_reset_touched_state(cls, state: TOKEN_TYPE) -> bool: """Get the touched state and reset the touched flag. This is used to determine if a state has been modified since it was last serialized. From 50214536d07e6396bb58875745a206b6208f896c Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Wed, 1 Apr 2026 10:02:39 -0700 Subject: [PATCH 41/81] move EventChain import to avoid circular dep --- packages/reflex-core/src/reflex_core/components/tags/tag.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/packages/reflex-core/src/reflex_core/components/tags/tag.py b/packages/reflex-core/src/reflex_core/components/tags/tag.py index 3d9312c6ad4..711c8d56cb3 100644 --- a/packages/reflex-core/src/reflex_core/components/tags/tag.py +++ b/packages/reflex-core/src/reflex_core/components/tags/tag.py @@ -6,7 +6,6 @@ from collections.abc import Iterator, Mapping, Sequence from typing import Any -from reflex_core.event import EventChain from reflex_core.utils import format from reflex_core.vars.base import LiteralVar, Var @@ -89,6 +88,8 @@ def add_props(self, **kwargs: Any | None) -> Tag: Returns: The tag with the props added. """ + from reflex_core.event import EventChain + return dataclasses.replace( self, props={ From d54e178f763c873c6bc6885c0122d59b910110d6 Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Wed, 1 Apr 2026 10:03:08 -0700 Subject: [PATCH 42/81] fix StateToken deserialize tests --- tests/units/istate/manager/test_token.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/tests/units/istate/manager/test_token.py b/tests/units/istate/manager/test_token.py index acb67fe60dd..d6f9411b608 100644 --- a/tests/units/istate/manager/test_token.py +++ b/tests/units/istate/manager/test_token.py @@ -48,10 +48,16 @@ def test_state_token_deserialize_from_fp(): def test_state_token_deserialize_neither_raises(): """Deserialize with neither data nor fp raises ValueError.""" - with pytest.raises(ValueError, match="Only one"): + with pytest.raises(ValueError, match="At least one"): StateToken.deserialize() +def test_state_token_deserialize_both_raises(): + """Deserialize with both data and fp raises ValueError.""" + with pytest.raises(ValueError, match="Only one"): + StateToken.deserialize(data=b"data", fp=io.BytesIO()) + + def test_state_token_get_and_reset_touched_state(): """Default implementation always returns True.""" assert StateToken.get_and_reset_touched_state("anything") is True From aece85b5133e5e6d1ac6df8cac25e8dac6941322 Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Wed, 1 Apr 2026 11:14:10 -0700 Subject: [PATCH 43/81] Python 3.11 and 3.12 compatibility * Vendor the new async iterator `asyncio.as_completed` for < 3.13 * Alternative Queue.shutdown mechanism for < 3.13 * Alternative to `Task.get_context` for < 3.12 * from __future__ import annotations for new modules (so TYPE_CHECKING imports work) * bug with cls super() call on dataclass with slots=True * typing_extensions.deprecated --- .../src/reflex_core/_internal/context/base.py | 4 +- .../reflex_core/_internal/event/context.py | 4 +- .../event/processor/base_state_processor.py | 23 +++-- .../_internal/event/processor/compat.py | 87 +++++++++++++++++++ .../event/processor/event_processor.py | 76 +++++++++++----- .../src/reflex_core/_internal/registry.py | 2 + packages/reflex-core/src/reflex_core/event.py | 4 +- reflex/app.py | 6 +- reflex/istate/manager/token.py | 6 +- .../_internal/event/test_event_processor.py | 5 +- 10 files changed, 178 insertions(+), 39 deletions(-) create mode 100644 packages/reflex-core/src/reflex_core/_internal/event/processor/compat.py diff --git a/packages/reflex-core/src/reflex_core/_internal/context/base.py b/packages/reflex-core/src/reflex_core/_internal/context/base.py index 790d0b8db26..23367de66e9 100644 --- a/packages/reflex-core/src/reflex_core/_internal/context/base.py +++ b/packages/reflex-core/src/reflex_core/_internal/context/base.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import dataclasses from contextvars import ContextVar, Token from typing import ClassVar, Self @@ -13,7 +15,7 @@ class BaseContext: @classmethod def __init_subclass__(cls, **kwargs): """Initialize the context variable for the subclass.""" - super().__init_subclass__(**kwargs) + super(BaseContext, cls).__init_subclass__(**kwargs) cls._context_var = ContextVar(cls.__name__) cls._attached_context_token = {} diff --git a/packages/reflex-core/src/reflex_core/_internal/event/context.py b/packages/reflex-core/src/reflex_core/_internal/event/context.py index 1c964b44dfb..e9aef7d2f3b 100644 --- a/packages/reflex-core/src/reflex_core/_internal/event/context.py +++ b/packages/reflex-core/src/reflex_core/_internal/event/context.py @@ -1,5 +1,7 @@ """The context and associated metadata for handling an event.""" +from __future__ import annotations + import dataclasses import functools import uuid @@ -97,7 +99,7 @@ class EventContext(BaseContext): default_factory=dict, init=False, repr=False ) - def fork(self, token: str | None = None) -> "EventContext": + def fork(self, token: str | None = None) -> EventContext: """Return a new EventContext with the specified fields replaced. Args: diff --git a/packages/reflex-core/src/reflex_core/_internal/event/processor/base_state_processor.py b/packages/reflex-core/src/reflex_core/_internal/event/processor/base_state_processor.py index 2d679031973..4b8cf449ed4 100644 --- a/packages/reflex-core/src/reflex_core/_internal/event/processor/base_state_processor.py +++ b/packages/reflex-core/src/reflex_core/_internal/event/processor/base_state_processor.py @@ -1,5 +1,7 @@ """Functions for processing BaseState-derived event handlers.""" +from __future__ import annotations + import dataclasses import functools import inspect @@ -380,19 +382,24 @@ async def _process_event_queue_entry( root_state=root_state, ) - async def _handle_backend_exception(self, ex: Exception): + async def _handle_backend_exception( + self, ex: Exception, ev_ctx: EventContext | None = None + ) -> None: """Handle an exception raised during event processing by calling the backend exception handler if it exists. Args: ex: The exception that was raised. + ev_ctx: The event context for the exception. """ - if self.backend_exception_handler is not None and ( - events := self.backend_exception_handler(ex) - ): - await chain_updates( - events=events, - handler_name=self.backend_exception_handler.__qualname__, - ) + if self.backend_exception_handler is not None: + if ev_ctx is not None: + # Ensure the event context is set for the exception handler. + EventContext.set(ev_ctx) + if events := self.backend_exception_handler(ex): + await chain_updates( + events=events, + handler_name=self.backend_exception_handler.__qualname__, + ) __all__ = ["BaseStateEventProcessor", "chain_updates", "process_event"] diff --git a/packages/reflex-core/src/reflex_core/_internal/event/processor/compat.py b/packages/reflex-core/src/reflex_core/_internal/event/processor/compat.py new file mode 100644 index 00000000000..040e1aff475 --- /dev/null +++ b/packages/reflex-core/src/reflex_core/_internal/event/processor/compat.py @@ -0,0 +1,87 @@ +"""Compatibility shims since asyncio changes quite a bit from 3.11 to 3.14.""" + +import asyncio +import sys + +if sys.version_info >= (3, 13): + from asyncio import as_completed as as_completed +else: + # The following implementation of as_completed is adapted from Python 3.14 + # python/cpython@9e1f1644cd7b7661f0748bb37351836e8d6f37e2 + + class _AsCompletedIterator: + """Iterator of awaitables representing tasks of asyncio.as_completed. + + As an asynchronous iterator, iteration yields futures as they finish. As a + plain iterator, new coroutines are yielded that will return or raise the + result of the next underlying future to complete. + """ + + def __init__(self, aws, timeout): # noqa: ANN001 + self._done = asyncio.Queue() + self._timeout_handle = None + + loop = asyncio.get_event_loop() + todo = {asyncio.ensure_future(aw, loop=loop) for aw in set(aws)} + for f in todo: + f.add_done_callback(self._handle_completion) + if todo and timeout is not None: + self._timeout_handle = loop.call_later(timeout, self._handle_timeout) + self._todo = todo + self._todo_left = len(todo) + + def __aiter__(self): + return self + + def __iter__(self): + return self + + async def __anext__(self): + if not self._todo_left: + raise StopAsyncIteration + assert self._todo_left > 0 + self._todo_left -= 1 + return await self._wait_for_one() + + def __next__(self): + if not self._todo_left: + raise StopIteration + assert self._todo_left > 0 + self._todo_left -= 1 + return self._wait_for_one(resolve=True) + + def _handle_timeout(self): + for f in self._todo: + f.remove_done_callback(self._handle_completion) + self._done.put_nowait(None) # Sentinel for _wait_for_one(). + self._todo.clear() # Can't do todo.remove(f) in the loop. + + def _handle_completion(self, f): # noqa: ANN001 + if not self._todo: + return # _handle_timeout() was here first. + self._todo.remove(f) + self._done.put_nowait(f) + if not self._todo and self._timeout_handle is not None: + self._timeout_handle.cancel() + + async def _wait_for_one(self, resolve=False): # noqa: ANN001 + # Wait for the next future to be done and return it unless resolve is + # set, in which case return either the result of the future or raise + # an exception. + f = await self._done.get() + if f is None: + # Dummy value from _handle_timeout(). + raise asyncio.TimeoutError + return f.result() if resolve else f + + def as_completed(aws, *, timeout=None): # noqa: ANN001 + """Return an iterator of coroutines that yield the results of the given awaitables. + + The coroutines are ordered in the order in which the given awaitables complete. + If a given awaitable raises an exception, the corresponding coroutine raises the same exception. + + Args: + aws: An iterable of awaitables. + timeout: If provided, the maximum number of seconds to wait for the next awaitable to complete before raising a TimeoutError. + """ + return _AsCompletedIterator(aws, timeout) diff --git a/packages/reflex-core/src/reflex_core/_internal/event/processor/event_processor.py b/packages/reflex-core/src/reflex_core/_internal/event/processor/event_processor.py index aa36afd6a3b..aaa10d79df9 100644 --- a/packages/reflex-core/src/reflex_core/_internal/event/processor/event_processor.py +++ b/packages/reflex-core/src/reflex_core/_internal/event/processor/event_processor.py @@ -1,9 +1,12 @@ """Base EventProcessor class for handling backend event queue.""" +from __future__ import annotations + import asyncio import contextlib import dataclasses import inspect +import sys import time import traceback from collections.abc import AsyncGenerator, Callable, Mapping @@ -16,12 +19,23 @@ from reflex.istate.manager import StateManager from reflex.utils import console from reflex_core._internal.event.context import EventContext +from reflex_core._internal.event.processor.compat import as_completed from reflex_core._internal.registry import RegisteredEventHandler, RegistrationContext if TYPE_CHECKING: from reflex.app import EventNamespace from reflex.event import Event, EventSpec +if hasattr(asyncio, "QueueShutDown"): + + class QueueShutDown(asyncio.QueueShutDown): # pyright: ignore[reportRedeclaration] + """Exception raised when trying to put an item into a shut down queue.""" + +else: + + class QueueShutDown(Exception): # noqa: N818 + """Exception raised when trying to put an item into a shut down queue.""" + @dataclasses.dataclass(kw_only=True, slots=True) class DrainTimeoutManager: @@ -36,7 +50,7 @@ class DrainTimeoutManager: drain_deadline: float | None = None @classmethod - def with_timeout(cls, timeout: float | None) -> "DrainTimeoutManager": + def with_timeout(cls, timeout: float | None) -> DrainTimeoutManager: """Create a DrainTimeoutManager with a specified timeout. Args: @@ -226,9 +240,7 @@ async def _stop_tasks(self, timeout: float | None = None) -> None: # Graceful drain time, wait for tasks to finish and handle any exceptions. if timeout is not None and self._tasks: with contextlib.suppress(asyncio.TimeoutError): - async for task in asyncio.as_completed( - self._tasks.values(), timeout=timeout - ): + async for task in as_completed(self._tasks.values(), timeout=timeout): # Exceptions are handled in _finish_task and ignored here. with contextlib.suppress(Exception): await task @@ -271,14 +283,15 @@ async def stop(self, graceful_shutdown_timeout: float | None = None) -> None: await self._stop_tasks(timeout=remaining_time) # Cancel queue processing now that all tasks have been cancelled. if self._queue is not None: - self._queue.shutdown() + if sys.version_info >= (3, 13): + self._queue.shutdown() + self._queue = None with drain_timeout as remaining_time, contextlib.suppress(asyncio.TimeoutError): if remaining_time > 0: await self.join(timeout=remaining_time) with drain_timeout as remaining_time, contextlib.suppress(asyncio.TimeoutError): # Stop all tasks again now that the queue is shut down, no additional events can be queued. await self._stop_tasks(timeout=remaining_time) - self._queue = None if self._queue_task is not None: self._queue_task.cancel() try: @@ -324,7 +337,7 @@ def _ensure_queue_task(self) -> asyncio.Queue[EventQueueEntry]: raise RuntimeError(msg) if self._queue is None: msg = "Event processor is not running, call .start(...) first." - raise RuntimeError(msg) + raise QueueShutDown(msg) if self._queue_task is None: task_context = copy_context() task_context.run(EventContext.set, self._root_context) @@ -415,15 +428,24 @@ async def _emit_delta_impl( emit_delta_impl=_emit_delta_impl, ), ) - while not task_future.done() or not deltas.empty(): - with contextlib.suppress(asyncio.TimeoutError): - async for result in asyncio.as_completed( - [deltas.get(), *([task_future] if not task_future.done() else [])], - timeout=1, - ): - if result is task_future: - continue - yield await result + waiting_for = {task_future, asyncio.create_task(deltas.get())} + try: + while not task_future.done() or not deltas.empty(): + with contextlib.suppress(asyncio.TimeoutError): + async for result in as_completed( + waiting_for, + timeout=1, + ): + waiting_for.remove(result) + if result is not task_future: + yield await result + waiting_for.add(asyncio.create_task(deltas.get())) + break + finally: + for future in waiting_for: + future.cancel() + # Raise any exceptions for the caller. + await task_future def _on_future_done(self, txid: str, future: asyncio.Future) -> None: """Callback invoked when an enqueued future completes. @@ -473,7 +495,7 @@ async def _process_queue(self): if (queue := self._queue) is None: msg = "Event processor is not running, call .start(...) first." raise RuntimeError(msg) - with contextlib.suppress(asyncio.QueueShutDown): + with contextlib.suppress(QueueShutDown): while True: entry = await queue.get() if ( @@ -501,6 +523,8 @@ async def _process_queue(self): f"reflex_event|{entry.event.name}|{entry.ctx.token}|{time.time()}" ), ) + if sys.version_info < (3, 12): + task._event_ctx = entry.ctx # pyright: ignore[reportAttributeAccessIssue] self._tasks[entry.ctx.txid] = task task.add_done_callback(self._finish_task) except Exception: @@ -514,13 +538,18 @@ async def _process_queue(self): if self._queue_task is asyncio.current_task(): self._queue_task = None - async def _handle_backend_exception(self, ex: Exception): + async def _handle_backend_exception( + self, ex: Exception, ev_ctx: EventContext | None = None + ) -> None: """Handle an exception raised during event processing by calling the backend exception handler if it exists. Args: ex: The exception that was raised. + ev_ctx: The event context for the exception, if available. This will be set in the context variable when calling the exception handler. """ if self.backend_exception_handler is not None: + if ev_ctx is not None: + EventContext.set(ev_ctx) self.backend_exception_handler(ex) def _finish_task(self, task: asyncio.Task): @@ -535,7 +564,11 @@ def _finish_task(self, task: asyncio.Task): """ from reflex.utils import telemetry - task_ctx = task.get_context().run(EventContext.get) + if sys.version_info < (3, 12): + # py3.11 compat + task_ctx = task._event_ctx # type: ignore[attr-defined] + else: + task_ctx = task.get_context().run(EventContext.get) self._tasks.pop(task_ctx.txid, None) future = self._futures.pop(task_ctx.txid, None) if task.done(): @@ -556,9 +589,8 @@ def _finish_task(self, task: asyncio.Task): and self.backend_exception_handler is not None ): # Create a new task in the same context to invoke the exception handler. - t = self._tasks[task_ctx.txid] = task.get_context().run( - asyncio.create_task, - self._handle_backend_exception(ex), + t = self._tasks[task_ctx.txid] = asyncio.create_task( + self._handle_backend_exception(ex, ev_ctx=task_ctx), name=f"reflex_backend_exception_handler|task=[{task.get_name()}]|{time.time()}", ) t.add_done_callback(self._finish_task) diff --git a/packages/reflex-core/src/reflex_core/_internal/registry.py b/packages/reflex-core/src/reflex_core/_internal/registry.py index ebe881718d7..c43177a9e89 100644 --- a/packages/reflex-core/src/reflex_core/_internal/registry.py +++ b/packages/reflex-core/src/reflex_core/_internal/registry.py @@ -1,5 +1,7 @@ """A contextual registry for state and event handlers.""" +from __future__ import annotations + import dataclasses from typing import TYPE_CHECKING, Self diff --git a/packages/reflex-core/src/reflex_core/event.py b/packages/reflex-core/src/reflex_core/event.py index 2853e2ce0ff..18b9868d61f 100644 --- a/packages/reflex-core/src/reflex_core/event.py +++ b/packages/reflex-core/src/reflex_core/event.py @@ -79,7 +79,7 @@ class Event: payload: dict[str, Any] = dataclasses.field(default_factory=dict) @property - def state_cls(self) -> type[BaseState]: + def state_cls(self) -> "type[BaseState]": """The state class for the event.""" from reflex_core._internal.registry import RegistrationContext @@ -339,7 +339,7 @@ class EventHandler(EventActionsMixin): fn: Any = dataclasses.field(default=None) - state: type[BaseState] | None = dataclasses.field(default=None, repr=False) + state: "type[BaseState] | None" = dataclasses.field(default=None, repr=False) @property def state_full_name(self) -> str: diff --git a/reflex/app.py b/reflex/app.py index 92e030194a2..25c80567924 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -22,7 +22,6 @@ from timeit import default_timer as timer from types import SimpleNamespace from typing import TYPE_CHECKING, Any, ParamSpec, overload -from warnings import deprecated from reflex_components_core.base.app_wrap import AppWrap from reflex_components_core.base.error_boundary import ErrorBoundary @@ -117,6 +116,11 @@ from reflex.utils.misc import run_in_thread from reflex.utils.token_manager import RedisTokenManager, TokenManager +if sys.version_info < (3, 13): + from typing_extensions import deprecated +else: + from warnings import deprecated + if TYPE_CHECKING: from reflex_core.vars import Var diff --git a/reflex/istate/manager/token.py b/reflex/istate/manager/token.py index 928cb69b048..fcf1c07de23 100644 --- a/reflex/istate/manager/token.py +++ b/reflex/istate/manager/token.py @@ -1,5 +1,7 @@ """Representation of a StateManager token.""" +from __future__ import annotations + import dataclasses import pickle from typing import TYPE_CHECKING, BinaryIO, Generic, Self, TypeVar @@ -105,7 +107,7 @@ class BaseStateToken(StateToken["BaseState"]): This token type implies subtree hierarchy population and other semantic checks. """ - def with_cls(self, cls: type["BaseState"]) -> Self: + def with_cls(self, cls: type[BaseState]) -> Self: """Return a new token with the cls field updated to the provided class. Args: @@ -174,7 +176,7 @@ def get_and_reset_touched_state(cls, state: BaseState) -> bool: @classmethod def from_legacy_token( - cls, legacy_token: str, root_state: "type[BaseState] | None" + cls, legacy_token: str, root_state: type[BaseState] | None ) -> Self: """Create a BaseStateToken from a legacy token string. diff --git a/tests/units/reflex_core/_internal/event/test_event_processor.py b/tests/units/reflex_core/_internal/event/test_event_processor.py index 0e21e4b8e1c..5767062f157 100644 --- a/tests/units/reflex_core/_internal/event/test_event_processor.py +++ b/tests/units/reflex_core/_internal/event/test_event_processor.py @@ -9,6 +9,7 @@ from reflex_core._internal.event.processor.event_processor import ( DrainTimeoutManager, EventProcessor, + QueueShutDown, ) from reflex_core._internal.registry import RegistrationContext @@ -176,7 +177,7 @@ async def test_enqueue_after_stop_raises(processor: EventProcessor): processor.configure() async with processor: pass - with pytest.raises(RuntimeError, match="not running"): + with pytest.raises(QueueShutDown, match="not running"): await processor.enqueue("tok", *Event.from_event_type(noop_event())) @@ -187,7 +188,7 @@ async def test_enqueue_before_start_raises(processor: EventProcessor): processor: The event processor fixture. """ processor.configure() - with pytest.raises(RuntimeError, match="not running"): + with pytest.raises(QueueShutDown, match="not running"): await processor.enqueue("tok", *Event.from_event_type(noop_event())) From 0fff344a4d2c305a6bf7ff1b4c4391fa3a5d935b Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Wed, 1 Apr 2026 11:51:02 -0700 Subject: [PATCH 44/81] py3.10 Self compat --- .../reflex-core/src/reflex_core/_internal/context/base.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/packages/reflex-core/src/reflex_core/_internal/context/base.py b/packages/reflex-core/src/reflex_core/_internal/context/base.py index 23367de66e9..9d75b05933b 100644 --- a/packages/reflex-core/src/reflex_core/_internal/context/base.py +++ b/packages/reflex-core/src/reflex_core/_internal/context/base.py @@ -2,7 +2,9 @@ import dataclasses from contextvars import ContextVar, Token -from typing import ClassVar, Self +from typing import ClassVar + +from typing_extensions import Self @dataclasses.dataclass(frozen=True, slots=True, kw_only=True) From 8729b596a4429c1faf157be3bba74deaa5671ffe Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Wed, 1 Apr 2026 12:25:59 -0700 Subject: [PATCH 45/81] py3.10: typing_extensions Self --- .../reflex_core/_internal/event/processor/event_processor.py | 3 ++- packages/reflex-core/src/reflex_core/_internal/registry.py | 4 +++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/packages/reflex-core/src/reflex_core/_internal/event/processor/event_processor.py b/packages/reflex-core/src/reflex_core/_internal/event/processor/event_processor.py index aaa10d79df9..e180145901c 100644 --- a/packages/reflex-core/src/reflex_core/_internal/event/processor/event_processor.py +++ b/packages/reflex-core/src/reflex_core/_internal/event/processor/event_processor.py @@ -11,9 +11,10 @@ import traceback from collections.abc import AsyncGenerator, Callable, Mapping from contextvars import Token, copy_context -from typing import TYPE_CHECKING, Any, Self +from typing import TYPE_CHECKING, Any import rich.markup +from typing_extensions import Self from reflex.app_mixins.middleware import MiddlewareMixin from reflex.istate.manager import StateManager diff --git a/packages/reflex-core/src/reflex_core/_internal/registry.py b/packages/reflex-core/src/reflex_core/_internal/registry.py index c43177a9e89..686e64785d0 100644 --- a/packages/reflex-core/src/reflex_core/_internal/registry.py +++ b/packages/reflex-core/src/reflex_core/_internal/registry.py @@ -3,7 +3,9 @@ from __future__ import annotations import dataclasses -from typing import TYPE_CHECKING, Self +from typing import TYPE_CHECKING + +from typing_extensions import Self from reflex_core._internal.context.base import BaseContext from reflex_core.utils.exceptions import StateValueError From b1ad918437371854f1d3df390dbf84b7dcfd7f43 Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Wed, 1 Apr 2026 12:30:25 -0700 Subject: [PATCH 46/81] ugh more py3.10 Self compat --- reflex/istate/manager/token.py | 4 +++- reflex/istate/shared.py | 3 ++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/reflex/istate/manager/token.py b/reflex/istate/manager/token.py index fcf1c07de23..31b3a8a80db 100644 --- a/reflex/istate/manager/token.py +++ b/reflex/istate/manager/token.py @@ -4,7 +4,9 @@ import dataclasses import pickle -from typing import TYPE_CHECKING, BinaryIO, Generic, Self, TypeVar +from typing import TYPE_CHECKING, BinaryIO, Generic, TypeVar + +from typing_extensions import Self from reflex.utils import console diff --git a/reflex/istate/shared.py b/reflex/istate/shared.py index df238672007..e55fbba7f14 100644 --- a/reflex/istate/shared.py +++ b/reflex/istate/shared.py @@ -3,12 +3,13 @@ import asyncio import contextlib from collections.abc import AsyncIterator -from typing import Self, TypeVar +from typing import TypeVar from reflex_core.constants import ROUTER_DATA from reflex_core.event import Event, get_hydrate_event from reflex_core.utils import console from reflex_core.utils.exceptions import ReflexRuntimeError +from typing_extensions import Self from reflex.istate.manager.token import BaseStateToken from reflex.state import BaseState, State, _override_base_method From ed82a6c17efcee14c4f37a4a804c3aa885ce2d0b Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Wed, 1 Apr 2026 12:34:45 -0700 Subject: [PATCH 47/81] fix reflex_core -> reflex import --- .../src/reflex_components_core/core/_upload.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/packages/reflex-components-core/src/reflex_components_core/core/_upload.py b/packages/reflex-components-core/src/reflex_components_core/core/_upload.py index 426c63f1fc1..412ac2ecc41 100644 --- a/packages/reflex-components-core/src/reflex_components_core/core/_upload.py +++ b/packages/reflex-components-core/src/reflex_components_core/core/_upload.py @@ -21,8 +21,6 @@ from starlette.responses import JSONResponse, Response, StreamingResponse from typing_extensions import Self -from reflex.state import StateUpdate - if TYPE_CHECKING: from reflex_core.utils.types import Receive, Scope, Send @@ -462,6 +460,8 @@ async def _upload_buffered_file( from reflex_core.event import Event from reflex_core.utils.exceptions import UploadValueError + from reflex.state import StateUpdate + try: form_data = await request.form() except ClientDisconnect: From 0f05b66d3ba2bbcbb7f63604952b156456845480 Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Wed, 1 Apr 2026 12:58:42 -0700 Subject: [PATCH 48/81] Handle py3.11 compatible queue shutdown better Join the queue that we None'd out to make sure all the tasks have flushed. Fix some import issues being reported on my branch. --- .../_internal/event/processor/event_processor.py | 16 ++++++++++++---- .../src/reflex_core/components/tags/__init__.py | 2 ++ .../src/reflex_core/constants/__init__.py | 1 + tests/units/test_config.py | 2 +- tests/units/test_state.py | 1 + 5 files changed, 17 insertions(+), 5 deletions(-) diff --git a/packages/reflex-core/src/reflex_core/_internal/event/processor/event_processor.py b/packages/reflex-core/src/reflex_core/_internal/event/processor/event_processor.py index e180145901c..bb1414c0918 100644 --- a/packages/reflex-core/src/reflex_core/_internal/event/processor/event_processor.py +++ b/packages/reflex-core/src/reflex_core/_internal/event/processor/event_processor.py @@ -283,13 +283,14 @@ async def stop(self, graceful_shutdown_timeout: float | None = None) -> None: with drain_timeout as remaining_time, contextlib.suppress(asyncio.TimeoutError): await self._stop_tasks(timeout=remaining_time) # Cancel queue processing now that all tasks have been cancelled. + queue = self._queue if self._queue is not None: if sys.version_info >= (3, 13): self._queue.shutdown() self._queue = None with drain_timeout as remaining_time, contextlib.suppress(asyncio.TimeoutError): if remaining_time > 0: - await self.join(timeout=remaining_time) + await self.join(timeout=remaining_time, queue=queue) with drain_timeout as remaining_time, contextlib.suppress(asyncio.TimeoutError): # Stop all tasks again now that the queue is shut down, no additional events can be queued. await self._stop_tasks(timeout=remaining_time) @@ -313,16 +314,23 @@ async def stop(self, graceful_shutdown_timeout: float | None = None) -> None: future.cancel() self._futures.clear() - async def join(self, timeout: float | None = None) -> None: + async def join( + self, timeout: float | None = None, queue: asyncio.Queue | None = None + ) -> None: """Wait for the event processor to finish processing all events in the queue. Args: timeout: An optional amount of time in seconds to wait for the queue to drain before returning. If None, this method will wait indefinitely until the queue is fully drained. + queue: An optional queue to wait for instead of the processor's main + queue. This can be used to wait for a specific queue to drain, such + as when using a separate queue for testing. """ - if self._queue is not None: - await asyncio.wait_for(self._queue.join(), timeout=timeout) + if queue is None: + queue = self._queue + if queue is not None: + await asyncio.wait_for(queue.join(), timeout=timeout) def _ensure_queue_task(self) -> asyncio.Queue[EventQueueEntry]: """Ensure the queue processing task is running. diff --git a/packages/reflex-core/src/reflex_core/components/tags/__init__.py b/packages/reflex-core/src/reflex_core/components/tags/__init__.py index 993da11fe69..c5003ff4aab 100644 --- a/packages/reflex-core/src/reflex_core/components/tags/__init__.py +++ b/packages/reflex-core/src/reflex_core/components/tags/__init__.py @@ -4,3 +4,5 @@ from .iter_tag import IterTag from .match_tag import MatchTag from .tag import Tag + +__all__ = ["CondTag", "IterTag", "MatchTag", "Tag"] diff --git a/packages/reflex-core/src/reflex_core/constants/__init__.py b/packages/reflex-core/src/reflex_core/constants/__init__.py index 69f79271d9e..41984e895e0 100644 --- a/packages/reflex-core/src/reflex_core/constants/__init__.py +++ b/packages/reflex-core/src/reflex_core/constants/__init__.py @@ -108,6 +108,7 @@ "PyprojectToml", "ReactRouter", "Reflex", + "ReflexHostingCLI", "RequirementsTxt", "RouteArgType", "RouteRegex", diff --git a/tests/units/test_config.py b/tests/units/test_config.py index ee64d6e293a..e72da73d48c 100644 --- a/tests/units/test_config.py +++ b/tests/units/test_config.py @@ -246,7 +246,7 @@ def test_replace_defaults( exp_config_values: The expected config values. """ mock_os_env = os.environ.copy() - monkeypatch.setattr(reflex_core.config.os, "environ", mock_os_env) + monkeypatch.setattr(reflex_core.config.os, "environ", mock_os_env) # pyright: ignore[reportPrivateImportUsage] mock_os_env.update({k: str(v) for k, v in env_vars.items()}) c = rx.Config(app_name="a", **config_kwargs) c._set_persistent(**set_persistent_vars) diff --git a/tests/units/test_state.py b/tests/units/test_state.py index 011ab83f533..83260b6dfb8 100644 --- a/tests/units/test_state.py +++ b/tests/units/test_state.py @@ -3414,6 +3414,7 @@ async def test_setvar( ]) async with mock_base_state_event_processor as processor: await processor.enqueue(token, *events) + await processor.join(1) if environment.REFLEX_OPLOCK_ENABLED.get(): await state_manager.close() From 75e160aee0cab8655f06220997bde5fc04cf79fc Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Wed, 1 Apr 2026 13:31:42 -0700 Subject: [PATCH 49/81] state.js: pump the queue in processEvent If there is another event to process, chain to processEvent --- .../reflex_core/.templates/web/utils/state.js | 50 +++++++------------ 1 file changed, 19 insertions(+), 31 deletions(-) diff --git a/packages/reflex-core/src/reflex_core/.templates/web/utils/state.js b/packages/reflex-core/src/reflex_core/.templates/web/utils/state.js index 426fe116f61..fd2b5e5741b 100644 --- a/packages/reflex-core/src/reflex_core/.templates/web/utils/state.js +++ b/packages/reflex-core/src/reflex_core/.templates/web/utils/state.js @@ -201,14 +201,12 @@ function urlFrom(string) { * @param socket The socket object to send the event on. * @param navigate The navigate function from useNavigate * @param params The params object from useParams - * - * @returns True if the event was sent, false if it was handled locally. */ export const applyEvent = async (event, socket, navigate, params) => { // Handle special events if (event.name == "_redirect") { if ((event.payload.path ?? undefined) === undefined) { - return false; + return; } if (event.payload.external) { window.open( @@ -216,7 +214,7 @@ export const applyEvent = async (event, socket, navigate, params) => { "_blank", "noopener" + (event.payload.popup ? ",popup" : ""), ); - return false; + return; } const url = urlFrom(event.payload.path); let pathname = event.payload.path; @@ -224,7 +222,7 @@ export const applyEvent = async (event, socket, navigate, params) => { if (url.host !== window.location.host) { // External URL window.location.assign(event.payload.path); - return false; + return; } else { pathname = url.pathname + url.search + url.hash; } @@ -234,37 +232,37 @@ export const applyEvent = async (event, socket, navigate, params) => { } else { navigate(pathname); } - return false; + return; } if (event.name == "_remove_cookie") { cookies.remove(event.payload.key, { ...event.payload.options }); queueEventIfSocketExists(initialEvents(), socket, navigate, params); - return false; + return; } if (event.name == "_clear_local_storage") { localStorage.clear(); queueEventIfSocketExists(initialEvents(), socket, navigate, params); - return false; + return; } if (event.name == "_remove_local_storage") { localStorage.removeItem(event.payload.key); queueEventIfSocketExists(initialEvents(), socket, navigate, params); - return false; + return; } if (event.name == "_clear_session_storage") { sessionStorage.clear(); queueEventIfSocketExists(initialEvents(), socket, navigate, params); - return false; + return; } if (event.name == "_remove_session_storage") { sessionStorage.removeItem(event.payload.key); queueEventIfSocketExists(initialEvents(), socket, navigate, params); - return false; + return; } if (event.name == "_download") { @@ -283,7 +281,7 @@ export const applyEvent = async (event, socket, navigate, params) => { a.download = event.payload.filename; a.click(); a.remove(); - return false; + return; } if (event.name == "_set_focus") { @@ -297,7 +295,7 @@ export const applyEvent = async (event, socket, navigate, params) => { } else { current.focus(); } - return false; + return; } if (event.name == "_blur_focus") { @@ -311,7 +309,7 @@ export const applyEvent = async (event, socket, navigate, params) => { } else { current.blur(); } - return false; + return; } if (event.name == "_set_value") { @@ -320,7 +318,7 @@ export const applyEvent = async (event, socket, navigate, params) => { if (ref.current) { ref.current.value = event.payload.value; } - return false; + return; } if ( @@ -346,7 +344,7 @@ export const applyEvent = async (event, socket, navigate, params) => { window.onerror(e.message, null, null, null, e); } } - return false; + return; } if (event.name == "_call_script" || event.name == "_call_function") { @@ -373,7 +371,7 @@ export const applyEvent = async (event, socket, navigate, params) => { window.onerror(e.message, null, null, null, e); } } - return false; + return; } // Update token and router data (if missing). @@ -401,10 +399,7 @@ export const applyEvent = async (event, socket, navigate, params) => { // Send the event to the server. if (socket) { socket.emit("event", event); - return true; } - - return false; }; /** @@ -413,11 +408,8 @@ export const applyEvent = async (event, socket, navigate, params) => { * @param socket The socket object to send the response event(s) on. * @param navigate The navigate function from React Router * @param params The params object from React Router - * - * @returns Whether the event was sent. */ export const applyRestEvent = async (event, socket, navigate, params) => { - let eventSent = false; if (event.handler === "uploadFiles") { // Start upload, but do not wait for it, which would block other events. uploadFiles( @@ -431,9 +423,7 @@ export const applyRestEvent = async (event, socket, navigate, params) => { getBackendURL, getToken, ); - return false; } - return eventSent; }; /** @@ -494,16 +484,14 @@ export const processEvent = async (socket, navigate, params) => { // Apply the next event in the queue. const event = event_queue.shift(); - let eventSent = false; // Process events with handlers via REST and all others via websockets. if (event.handler) { - eventSent = await applyRestEvent(event, socket, navigate, params); + await applyRestEvent(event, socket, navigate, params); } else { - eventSent = await applyEvent(event, socket, navigate, params); + await applyEvent(event, socket, navigate, params); } - if (!eventSent) { - // recursively call processEvent to drain the queue, since there is - // no state update to trigger the useEffect event loop. + // Process any remaining events. + if (event_queue.length > 0) { await processEvent(socket, navigate, params); } }; From f60c4bdfdb3f26181b97f1fc66ebae10cacb6059 Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Wed, 1 Apr 2026 13:33:23 -0700 Subject: [PATCH 50/81] AppHarness: pre-register SharedState so it's available in the base RegistrationContext --- reflex/testing.py | 1 + 1 file changed, 1 insertion(+) diff --git a/reflex/testing.py b/reflex/testing.py index 4c693b8659e..6fae052a675 100644 --- a/reflex/testing.py +++ b/reflex/testing.py @@ -42,6 +42,7 @@ import reflex.utils.prerequisites import reflex.utils.processes from reflex.experimental.memo import EXPERIMENTAL_MEMOS +from reflex.istate.shared import SharedState as SharedState # To register it. from reflex.state import reload_state_module from reflex.utils import console, js_runtimes from reflex.utils.export import export From 43f97ab240f1559788937e5444f37bb0daa64770 Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Wed, 1 Apr 2026 13:35:04 -0700 Subject: [PATCH 51/81] BaseStateEventProcessor: emit deltas before enqueuing events --- .../event/processor/base_state_processor.py | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/packages/reflex-core/src/reflex_core/_internal/event/processor/base_state_processor.py b/packages/reflex-core/src/reflex_core/_internal/event/processor/base_state_processor.py index 4b8cf449ed4..1bab71a481c 100644 --- a/packages/reflex-core/src/reflex_core/_internal/event/processor/base_state_processor.py +++ b/packages/reflex-core/src/reflex_core/_internal/event/processor/base_state_processor.py @@ -19,8 +19,8 @@ from reflex_core._internal.event.processor.event_processor import ( EventProcessor, EventQueueEntry, - RegisteredEventHandler, ) +from reflex_core._internal.registry import RegisteredEventHandler from reflex_core.utils.format import format_event_handler if TYPE_CHECKING: @@ -188,6 +188,15 @@ async def chain_updates( ctx = EventContext.get() + if root_state is not None: + # Emit deltas first, so any frontend events are processed with the latest state. + try: + delta = await root_state._get_resolved_delta() + if delta: + await ctx.emit_delta(delta) + finally: + root_state._clean() + # Convert valid EventHandler and EventSpec into Event if fixed_events := Event.from_event_type( _check_valid_yield(events, handler_name=handler_name), @@ -198,15 +207,6 @@ async def chain_updates( # Backend events. await ctx.enqueue(*(e for e in fixed_events if not e.name.startswith("_"))) - if root_state is not None: - # Get the delta after processing the event. - try: - delta = await root_state._get_resolved_delta() - if delta: - await ctx.emit_delta(delta) - finally: - root_state._clean() - async def process_event( handler: EventHandler, From 8987a7e5a7e3db13e94ea8c2aaa3a613251fafbe Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Wed, 1 Apr 2026 16:18:33 -0700 Subject: [PATCH 52/81] set RegistrationContext in ASGI middleware Ensure that the RegistrationContext associated with the App is set in the task that handles an ASGI request. --- reflex/app.py | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/reflex/app.py b/reflex/app.py index 25c80567924..e3a97c5b014 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -41,6 +41,7 @@ BaseStateEventProcessor, EventProcessor, ) +from reflex_core._internal.registry import RegistrationContext from reflex_core.components.component import ( CUSTOM_COMPONENTS, Component, @@ -401,6 +402,11 @@ class App(MiddlewareMixin, LifespanMixin): # The processor queue for handling events. _event_processor: EventProcessor | None = None + # Store the RegistrationContext to apply inside the ASGI callable task. + _registration_context: RegistrationContext = dataclasses.field( + default_factory=RegistrationContext.ensure_context + ) + frontend_exception_handler: Callable[[Exception], None] = ( default_frontend_exception_handler ) @@ -575,8 +581,32 @@ async def modified_send(message: Message): # Ensure the event processor starts and stops with the server. self.register_lifespan_task(self._setup_event_processor) + def _registration_context_middleware(self, app: ASGIApp) -> ASGIApp: + """Ensure the RegistrationContext is attached to the ASGI app. + + Args: + app: The ASGI app to attach the middleware to. + + Returns: + The ASGI app with the middleware attached. + """ + + async def registration_context_middleware( + scope: Scope, receive: Receive, send: Send + ): + if self._registration_context is not None: + RegistrationContext.set(self._registration_context) + await app(scope, receive, send) + + return registration_context_middleware + @contextlib.asynccontextmanager async def _setup_event_processor(self) -> AsyncIterator[None]: + # Make sure the RegistrationContext is attached. + if self._api is not None: + self._api.add_middleware( + self._registration_context_middleware, + ) # Create the event processor. self._event_processor = BaseStateEventProcessor( middleware=self, backend_exception_handler=self.backend_exception_handler From 5d609e461aaf237d63b24837df84ee5af3f56f2e Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Wed, 1 Apr 2026 21:35:01 -0700 Subject: [PATCH 53/81] EventFuture: tracks execution of chained events roll up chained events for use with enqueue_stream_deltas for HTTP invocation of events. --- .../_internal/event/processor/__init__.py | 4 + .../event/processor/event_processor.py | 84 ++--- .../_internal/event/processor/future.py | 99 ++++++ .../_internal/event/processor/timeout.py | 52 +++ .../test_base_state_processor.py | 0 .../{ => processor}/test_event_processor.py | 26 -- .../_internal/event/processor/test_future.py | 301 ++++++++++++++++++ .../_internal/event/processor/test_timeout.py | 29 ++ 8 files changed, 514 insertions(+), 81 deletions(-) create mode 100644 packages/reflex-core/src/reflex_core/_internal/event/processor/future.py create mode 100644 packages/reflex-core/src/reflex_core/_internal/event/processor/timeout.py rename tests/units/reflex_core/_internal/event/{ => processor}/test_base_state_processor.py (100%) rename tests/units/reflex_core/_internal/event/{ => processor}/test_event_processor.py (94%) create mode 100644 tests/units/reflex_core/_internal/event/processor/test_future.py create mode 100644 tests/units/reflex_core/_internal/event/processor/test_timeout.py diff --git a/packages/reflex-core/src/reflex_core/_internal/event/processor/__init__.py b/packages/reflex-core/src/reflex_core/_internal/event/processor/__init__.py index df0d957c186..b3463d4f472 100644 --- a/packages/reflex-core/src/reflex_core/_internal/event/processor/__init__.py +++ b/packages/reflex-core/src/reflex_core/_internal/event/processor/__init__.py @@ -7,9 +7,13 @@ EventProcessor, EventQueueEntry, ) +from reflex_core._internal.event.processor.future import EventFuture +from reflex_core._internal.event.processor.timeout import DrainTimeoutManager __all__ = [ "BaseStateEventProcessor", + "DrainTimeoutManager", + "EventFuture", "EventProcessor", "EventQueueEntry", ] diff --git a/packages/reflex-core/src/reflex_core/_internal/event/processor/event_processor.py b/packages/reflex-core/src/reflex_core/_internal/event/processor/event_processor.py index bb1414c0918..f6a32d5cc3f 100644 --- a/packages/reflex-core/src/reflex_core/_internal/event/processor/event_processor.py +++ b/packages/reflex-core/src/reflex_core/_internal/event/processor/event_processor.py @@ -21,6 +21,8 @@ from reflex.utils import console from reflex_core._internal.event.context import EventContext from reflex_core._internal.event.processor.compat import as_completed +from reflex_core._internal.event.processor.future import EventFuture +from reflex_core._internal.event.processor.timeout import DrainTimeoutManager from reflex_core._internal.registry import RegisteredEventHandler, RegistrationContext if TYPE_CHECKING: @@ -38,47 +40,6 @@ class QueueShutDown(Exception): # noqa: N818 """Exception raised when trying to put an item into a shut down queue.""" -@dataclasses.dataclass(kw_only=True, slots=True) -class DrainTimeoutManager: - """Manages an optional combined timeout over multiple calls. - - Each time the context is entered, yield the remaining time until the - overall timeout is reached, or 0 if the timeout has already been reached. - This allows multiple operations to share a single overall timeout, even if - they are not executed sequentially. - """ - - drain_deadline: float | None = None - - @classmethod - def with_timeout(cls, timeout: float | None) -> DrainTimeoutManager: - """Create a DrainTimeoutManager with a specified timeout. - - Args: - timeout: The overall amount of time in seconds to wait. - - Returns: - A DrainTimeoutManager instance with the drain deadline set. - """ - if timeout is None: - return cls(drain_deadline=None) - return cls(drain_deadline=time.time() + timeout) - - def __enter__(self) -> float: - """Enter the context and yield the remaining time. - - Returns: - The remaining time in seconds until the overall timeout is reached, or 0 if the timeout - has already been reached. - """ - if self.drain_deadline is not None: - return max(0, self.drain_deadline - time.time()) - return 0 - - def __exit__(self, *exc_info) -> None: - """Exit the context. No cleanup necessary.""" - - @dataclasses.dataclass(frozen=True, kw_only=True, slots=True) class EventQueueEntry: """An entry in the event queue.""" @@ -120,7 +81,7 @@ class EventProcessor: _tasks: dict[str, asyncio.Task] = dataclasses.field( default_factory=dict, init=False ) - _futures: dict[str, asyncio.Future[Any]] = dataclasses.field( + _futures: dict[str, EventFuture] = dataclasses.field( default_factory=dict, init=False ) @@ -359,7 +320,7 @@ def _ensure_queue_task(self) -> asyncio.Queue[EventQueueEntry]: async def enqueue( self, token: str, *events: Event, ev_ctx: EventContext | None = None - ) -> asyncio.Future[Any]: + ) -> EventFuture: """Enqueue an event to be processed. Args: @@ -368,7 +329,7 @@ async def enqueue( ev_ctx: The event context to use for these events. Returns: - A Future that resolves to the result of the associated task. + An EventFuture that resolves to the result of the associated task. """ if ev_ctx is None: try: @@ -380,13 +341,18 @@ async def enqueue( msg = "Event processor is not running, call .start(...) first." raise RuntimeError(msg) from le queue = self._ensure_queue_task() - future: asyncio.Future[Any] = asyncio.get_running_loop().create_future() + tracked = EventFuture.create() txid = ev_ctx.txid - self._futures[txid] = future - future.add_done_callback(lambda f: self._on_future_done(txid, f)) + self._futures[txid] = tracked + tracked.add_done_callback(lambda f: self._on_future_done(txid, f)) + # If this context has a parent, register as a child of the parent's future. + if ev_ctx.parent_txid is not None: + parent_tracked = self._futures.get(ev_ctx.parent_txid) + if parent_tracked is not None: + parent_tracked.add_child(tracked) for event in events: await queue.put(EventQueueEntry(event=event, ctx=ev_ctx)) - return future + return tracked async def enqueue_stream_delta( self, @@ -437,31 +403,33 @@ async def _emit_delta_impl( emit_delta_impl=_emit_delta_impl, ), ) - waiting_for = {task_future, asyncio.create_task(deltas.get())} + all_task_futures = asyncio.create_task(task_future.wait_all()) + waiting_for = {all_task_futures, asyncio.create_task(deltas.get())} try: - while not task_future.done() or not deltas.empty(): + while not all_task_futures.done() or not deltas.empty(): with contextlib.suppress(asyncio.TimeoutError): async for result in as_completed( waiting_for, timeout=1, ): waiting_for.remove(result) - if result is not task_future: + if result is not all_task_futures: yield await result waiting_for.add(asyncio.create_task(deltas.get())) break finally: for future in waiting_for: future.cancel() - # Raise any exceptions for the caller. - await task_future + # Raise any exceptions for the caller, waiting for all chained events. + await task_future.wait_all() def _on_future_done(self, txid: str, future: asyncio.Future) -> None: """Callback invoked when an enqueued future completes. If the future was cancelled externally, cancel the running task - if one exists. If the task has not started yet, ``_process_queue`` - will check the future and skip it when the entry is dequeued. + and all child futures. If the task has not started yet, + ``_process_queue`` will check the future and skip it when the + entry is dequeued. Args: txid: The transaction id associated with the future. @@ -469,6 +437,11 @@ def _on_future_done(self, txid: str, future: asyncio.Future) -> None: """ if not future.cancelled(): return + # Cascade cancellation to all child futures. + tracked = self._futures.get(txid) + if tracked is not None: + for child in tracked.children: + child.cancel() task = self._tasks.get(txid) if task is not None: task.cancel() @@ -615,6 +588,7 @@ def _finish_task(self, task: asyncio.Task): __all__ = [ + "EventFuture", "EventProcessor", "EventQueueEntry", ] diff --git a/packages/reflex-core/src/reflex_core/_internal/event/processor/future.py b/packages/reflex-core/src/reflex_core/_internal/event/processor/future.py new file mode 100644 index 00000000000..fc4ee6e84e8 --- /dev/null +++ b/packages/reflex-core/src/reflex_core/_internal/event/processor/future.py @@ -0,0 +1,99 @@ +"""EventFuture: a future that tracks child futures for hierarchical event processing.""" + +from __future__ import annotations + +import asyncio +import contextlib +from typing import Any + + +class EventFuture(asyncio.Future): + """A future that tracks child futures for hierarchical event processing. + + When events are chained (a handler enqueues additional events), the child + futures are tracked so callers can wait for the entire chain to complete. + """ + + children: list[EventFuture] + + def __init__(self, *, loop: asyncio.AbstractEventLoop | None = None) -> None: + super().__init__(loop=loop) + self.children = [] + + @classmethod + def create(cls, loop: asyncio.AbstractEventLoop | None = None) -> EventFuture: + """Create a new EventFuture on the given or running event loop. + + Args: + loop: The event loop to use. Defaults to the running loop. + + Returns: + A new EventFuture instance. + """ + if loop is None: + loop = asyncio.get_running_loop() + return cls(loop=loop) + + def add_child(self, child: EventFuture) -> None: + """Add a child future to this tracked future. + + Args: + child: The child EventFuture to add. + + Raises: + RuntimeError: If this future is already done. + """ + if self.done(): + msg = "Cannot add a child to an EventFuture that is already done." + raise RuntimeError(msg) + self.children.append(child) + + def all_done(self) -> bool: + """Check if this future and all descendant futures are done. + + Returns: + True if this future and all descendants have completed. + """ + if not self.done(): + return False + return all(child.all_done() for child in self.children) + + async def wait_all(self) -> Any: + """Wait for this future and all descendant futures to complete. + + Walks the children list by index so that children added after + iteration begins are still awaited. + + Child exceptions are suppressed since they are handled independently + by the event processor's _finish_task callback. + + Returns: + The result of this future. + """ + result = await self + i = 0 + while i < len(self.children): + child = self.children[i] + with contextlib.suppress(Exception, asyncio.CancelledError): + await child.wait_all() + i += 1 + return result + + def cancel(self, msg: object = None) -> bool: + """Cancel this future and all descendant futures. + + Args: + msg: Optional cancellation message. + + Returns: + True if the future was successfully cancelled. + """ + result = super().cancel(msg) + for child in self.children: + child.cancel(msg) + return result + + +__all__ = [ + "EventFuture", +] diff --git a/packages/reflex-core/src/reflex_core/_internal/event/processor/timeout.py b/packages/reflex-core/src/reflex_core/_internal/event/processor/timeout.py new file mode 100644 index 00000000000..a527221ad95 --- /dev/null +++ b/packages/reflex-core/src/reflex_core/_internal/event/processor/timeout.py @@ -0,0 +1,52 @@ +"""DrainTimeoutManager: manages an optional combined timeout over multiple calls.""" + +from __future__ import annotations + +import dataclasses +import time + + +@dataclasses.dataclass(kw_only=True, slots=True) +class DrainTimeoutManager: + """Manages an optional combined timeout over multiple calls. + + Each time the context is entered, yield the remaining time until the + overall timeout is reached, or 0 if the timeout has already been reached. + This allows multiple operations to share a single overall timeout, even if + they are not executed sequentially. + """ + + drain_deadline: float | None = None + + @classmethod + def with_timeout(cls, timeout: float | None) -> DrainTimeoutManager: + """Create a DrainTimeoutManager with a specified timeout. + + Args: + timeout: The overall amount of time in seconds to wait. + + Returns: + A DrainTimeoutManager instance with the drain deadline set. + """ + if timeout is None: + return cls(drain_deadline=None) + return cls(drain_deadline=time.time() + timeout) + + def __enter__(self) -> float: + """Enter the context and yield the remaining time. + + Returns: + The remaining time in seconds until the overall timeout is reached, or 0 if the timeout + has already been reached. + """ + if self.drain_deadline is not None: + return max(0, self.drain_deadline - time.time()) + return 0 + + def __exit__(self, *exc_info) -> None: + """Exit the context. No cleanup necessary.""" + + +__all__ = [ + "DrainTimeoutManager", +] diff --git a/tests/units/reflex_core/_internal/event/test_base_state_processor.py b/tests/units/reflex_core/_internal/event/processor/test_base_state_processor.py similarity index 100% rename from tests/units/reflex_core/_internal/event/test_base_state_processor.py rename to tests/units/reflex_core/_internal/event/processor/test_base_state_processor.py diff --git a/tests/units/reflex_core/_internal/event/test_event_processor.py b/tests/units/reflex_core/_internal/event/processor/test_event_processor.py similarity index 94% rename from tests/units/reflex_core/_internal/event/test_event_processor.py rename to tests/units/reflex_core/_internal/event/processor/test_event_processor.py index 5767062f157..f55ca5b1170 100644 --- a/tests/units/reflex_core/_internal/event/test_event_processor.py +++ b/tests/units/reflex_core/_internal/event/processor/test_event_processor.py @@ -1,13 +1,11 @@ """Tests for EventProcessor lifecycle, task management, and error handling.""" import asyncio -import time from typing import Any import pytest from reflex_core._internal.event.context import EventContext from reflex_core._internal.event.processor.event_processor import ( - DrainTimeoutManager, EventProcessor, QueueShutDown, ) @@ -374,30 +372,6 @@ async def test_error_does_not_stop_queue( assert _CALL_LOG == [{"value": "after_error"}] -def test_drain_timeout_no_timeout(): - """DrainTimeoutManager with no timeout returns 0.""" - dtm = DrainTimeoutManager.with_timeout(None) - with dtm as remaining: - assert remaining == 0 - - -def test_drain_timeout_decreases(): - """DrainTimeoutManager remaining time decreases across re-entries.""" - dtm = DrainTimeoutManager.with_timeout(10.0) - with dtm as first: - assert 9.5 < first <= 10.0 - time.sleep(0.1) - with dtm as second: - assert second < first - - -def test_drain_timeout_expired_returns_zero(): - """DrainTimeoutManager with an already-expired timeout returns 0.""" - dtm = DrainTimeoutManager.with_timeout(0) - with dtm as remaining: - assert remaining == 0 - - async def test_chained_event_processed(token: str): """An event handler that enqueues another event via ctx.enqueue succeeds. diff --git a/tests/units/reflex_core/_internal/event/processor/test_future.py b/tests/units/reflex_core/_internal/event/processor/test_future.py new file mode 100644 index 00000000000..6597b8c5673 --- /dev/null +++ b/tests/units/reflex_core/_internal/event/processor/test_future.py @@ -0,0 +1,301 @@ +"""Tests for EventFuture.""" + +import asyncio + +import pytest +from reflex_core._internal.event.processor.future import EventFuture + + +@pytest.mark.asyncio +async def test_create_uses_running_loop(): # noqa: RUF029 + """EventFuture.create() defaults to the running event loop.""" + running_loop = asyncio.get_running_loop() + f = EventFuture.create() + assert isinstance(f, EventFuture) + assert f.get_loop() is running_loop + assert f.children == [] + assert not f.done() + + +@pytest.mark.asyncio +async def test_create_with_explicit_loop(): # noqa: RUF029 + """EventFuture.create(loop=...) uses the given (non-default) loop.""" + other_loop = asyncio.new_event_loop() + try: + f = EventFuture.create(loop=other_loop) + assert isinstance(f, EventFuture) + assert f.get_loop() is other_loop + assert f.get_loop() is not asyncio.get_running_loop() + finally: + other_loop.close() + + +@pytest.mark.asyncio +async def test_add_child_multiple(): # noqa: RUF029 + """add_child can be called multiple times.""" + parent = EventFuture.create() + children = [EventFuture.create() for _ in range(3)] + for c in children: + parent.add_child(c) + assert parent.children == children + + +@pytest.mark.asyncio +async def test_add_child_to_done_future_raises(): # noqa: RUF029 + """add_child raises RuntimeError if the parent future is already done.""" + parent = EventFuture.create() + parent.set_result(None) + child = EventFuture.create() + with pytest.raises(RuntimeError, match="already done"): + parent.add_child(child) + + +@pytest.mark.asyncio +async def test_add_child_to_cancelled_future_raises(): # noqa: RUF029 + """add_child raises RuntimeError if the parent future is cancelled.""" + parent = EventFuture.create() + parent.cancel() + child = EventFuture.create() + with pytest.raises(RuntimeError, match="already done"): + parent.add_child(child) + + +@pytest.mark.asyncio +async def test_all_done_no_children(): # noqa: RUF029 + """all_done is True when the future is resolved and has no children.""" + f = EventFuture.create() + assert not f.all_done() + f.set_result(42) + assert f.all_done() + + +@pytest.mark.asyncio +async def test_all_done_with_pending_child(): # noqa: RUF029 + """all_done is False when a child is still pending.""" + parent = EventFuture.create() + child = EventFuture.create() + parent.add_child(child) + parent.set_result(None) + assert not parent.all_done() + child.set_result(None) + assert parent.all_done() + + +@pytest.mark.asyncio +async def test_all_done_nested(): # noqa: RUF029 + """all_done checks the full descendant tree.""" + root = EventFuture.create() + child = EventFuture.create() + grandchild = EventFuture.create() + root.add_child(child) + child.add_child(grandchild) + + root.set_result(None) + child.set_result(None) + # grandchild still pending + assert not root.all_done() + + grandchild.set_result(None) + assert root.all_done() + + +@pytest.mark.asyncio +async def test_all_done_with_cancelled_child(): # noqa: RUF029 + """all_done is True when all children are cancelled (done).""" + parent = EventFuture.create() + child = EventFuture.create() + parent.add_child(child) + parent.set_result(None) + child.cancel() + assert parent.all_done() + + +@pytest.mark.asyncio +async def test_all_done_with_exception_child(): # noqa: RUF029 + """all_done is True when a child has an exception (still done).""" + parent = EventFuture.create() + child = EventFuture.create() + parent.add_child(child) + parent.set_result(None) + child.set_exception(ValueError("boom")) + assert parent.all_done() + + +@pytest.mark.asyncio +async def test_wait_all_returns_result(): + """wait_all returns the result of the root future.""" + f = EventFuture.create() + f.set_result(42) + result = await f.wait_all() + assert result == 42 + + +@pytest.mark.asyncio +async def test_wait_all_waits_for_children(): + """wait_all waits for all children to complete.""" + parent = EventFuture.create() + child = EventFuture.create() + parent.add_child(child) + + async def resolve_later(): + await asyncio.sleep(0.01) + child.set_result("child_done") + + parent.set_result("parent_done") + task = asyncio.create_task(resolve_later()) + result = await parent.wait_all() + assert result == "parent_done" + assert child.done() + await task + + +@pytest.mark.asyncio +async def test_wait_all_waits_for_nested_children(): + """wait_all waits for grandchildren too.""" + root = EventFuture.create() + child = EventFuture.create() + grandchild = EventFuture.create() + root.add_child(child) + child.add_child(grandchild) + + async def resolve_chain(): + await asyncio.sleep(0.01) + child.set_result(None) + await asyncio.sleep(0.01) + grandchild.set_result(None) + + root.set_result("root") + task = asyncio.create_task(resolve_chain()) + result = await root.wait_all() + assert result == "root" + assert grandchild.done() + await task + + +@pytest.mark.asyncio +async def test_wait_all_suppresses_child_exceptions(): + """wait_all suppresses exceptions from children.""" + parent = EventFuture.create() + child = EventFuture.create() + parent.add_child(child) + + parent.set_result("ok") + child.set_exception(ValueError("child error")) + + # Should not raise + result = await parent.wait_all() + assert result == "ok" + + +@pytest.mark.asyncio +async def test_wait_all_suppresses_child_cancellation(): + """wait_all suppresses CancelledError from children.""" + parent = EventFuture.create() + child = EventFuture.create() + parent.add_child(child) + + parent.set_result("ok") + child.cancel() + + result = await parent.wait_all() + assert result == "ok" + + +@pytest.mark.asyncio +async def test_wait_all_children_added_during_iteration(): + """wait_all picks up children added while iterating (index-based walk).""" + parent = EventFuture.create() + child1 = EventFuture.create() + parent.add_child(child1) + parent.set_result("done") + + # child2 will be added to child1 after child1 resolves, + # simulating a chained event that enqueues more events. + child2 = EventFuture.create() + + async def resolve_and_chain(): + await asyncio.sleep(0.01) + child1.add_child(child2) + child1.set_result(None) + await asyncio.sleep(0.01) + child2.set_result(None) + + task = asyncio.create_task(resolve_and_chain()) + await parent.wait_all() + assert child2.done() + await task + + +@pytest.mark.asyncio +async def test_cancel_no_children(): # noqa: RUF029 + """Cancel cancels the future itself.""" + f = EventFuture.create() + assert f.cancel() + assert f.cancelled() + + +@pytest.mark.asyncio +async def test_cancel_cascades_to_children(): # noqa: RUF029 + """Cancel propagates to all children.""" + parent = EventFuture.create() + child1 = EventFuture.create() + child2 = EventFuture.create() + parent.add_child(child1) + parent.add_child(child2) + + parent.cancel() + assert parent.cancelled() + assert child1.cancelled() + assert child2.cancelled() + + +@pytest.mark.asyncio +async def test_cancel_cascades_to_grandchildren(): # noqa: RUF029 + """Cancel propagates through the full descendant tree.""" + root = EventFuture.create() + child = EventFuture.create() + grandchild = EventFuture.create() + root.add_child(child) + child.add_child(grandchild) + + root.cancel() + assert grandchild.cancelled() + + +@pytest.mark.asyncio +async def test_cancel_with_message(): # noqa: RUF029 + """Cancel passes the message to children.""" + parent = EventFuture.create() + child = EventFuture.create() + parent.add_child(child) + + parent.cancel("shutting down") + assert parent.cancelled() + assert child.cancelled() + with pytest.raises(asyncio.CancelledError, match="shutting down"): + parent.result() + with pytest.raises(asyncio.CancelledError, match="shutting down"): + child.result() + + +@pytest.mark.asyncio +async def test_cancel_already_done_child(): # noqa: RUF029 + """Cancel on a parent does not fail if a child is already resolved.""" + parent = EventFuture.create() + child = EventFuture.create() + parent.add_child(child) + child.set_result("already done") + + parent.cancel() + assert parent.cancelled() + # child was already done, cancel returns False but doesn't raise + assert not child.cancelled() + assert child.result() == "already done" + + +@pytest.mark.asyncio +async def test_cancel_already_done_parent_returns_false(): # noqa: RUF029 + """Cancel returns False if the parent is already resolved.""" + f = EventFuture.create() + f.set_result(None) + assert not f.cancel() diff --git a/tests/units/reflex_core/_internal/event/processor/test_timeout.py b/tests/units/reflex_core/_internal/event/processor/test_timeout.py new file mode 100644 index 00000000000..77472f4fd60 --- /dev/null +++ b/tests/units/reflex_core/_internal/event/processor/test_timeout.py @@ -0,0 +1,29 @@ +"""Tests for DrainTimeoutManager.""" + +import time + +from reflex_core._internal.event.processor.timeout import DrainTimeoutManager + + +def test_drain_timeout_no_timeout(): + """DrainTimeoutManager with no timeout returns 0.""" + dtm = DrainTimeoutManager.with_timeout(None) + with dtm as remaining: + assert remaining == 0 + + +def test_drain_timeout_decreases(): + """DrainTimeoutManager remaining time decreases across re-entries.""" + dtm = DrainTimeoutManager.with_timeout(10.0) + with dtm as first: + assert 9.5 < first <= 10.0 + time.sleep(0.1) + with dtm as second: + assert second < first + + +def test_drain_timeout_expired_returns_zero(): + """DrainTimeoutManager with an already-expired timeout returns 0.""" + dtm = DrainTimeoutManager.with_timeout(0) + with dtm as remaining: + assert remaining == 0 From be1a18d9d4c09580b668bbc275a5cafb25d05395 Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Wed, 1 Apr 2026 23:05:00 -0700 Subject: [PATCH 54/81] Add BaseState to reflex_core.event namespace for docgen EventHandler has a BaseState field, so it needs this reference to generate the proper documentation (and basically not crash) --- packages/reflex-core/src/reflex_core/event.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/packages/reflex-core/src/reflex_core/event.py b/packages/reflex-core/src/reflex_core/event.py index 18b9868d61f..c89008c6901 100644 --- a/packages/reflex-core/src/reflex_core/event.py +++ b/packages/reflex-core/src/reflex_core/event.py @@ -2816,3 +2816,9 @@ def wrapper( event = EventNamespace event.event = event # pyright: ignore[reportAttributeAccessIssue] sys.modules[__name__] = event # pyright: ignore[reportArgumentType] + +# A reference to BaseState is needed for doc generation when resolving type +# hints, so add it to the namespace late to avoid circular import issues. +from reflex.state import BaseState # noqa: E402 + +event.BaseState = BaseState # pyright: ignore[reportAttributeAccessIssue] From 7f3271ee7d03530ce703fa15cfeac026e249760f Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Thu, 2 Apr 2026 00:07:03 -0700 Subject: [PATCH 55/81] EventProcessor.enqueue only accepts a single Event Add new `enqueue_many` function for queuing multiple events. This allows each enqueued event to get its own EventContext/txid and Future/Task tracking. Update tests to wait for Future returned by `enqueue` to ensure processing has completed before making any assertions. --- .../event/processor/event_processor.py | 27 ++++-- tests/units/conftest.py | 2 +- .../processor/test_base_state_processor.py | 4 +- .../event/processor/test_event_processor.py | 36 ++++---- tests/units/test_model.py | 2 +- tests/units/test_state.py | 84 +++++++++++-------- 6 files changed, 89 insertions(+), 66 deletions(-) diff --git a/packages/reflex-core/src/reflex_core/_internal/event/processor/event_processor.py b/packages/reflex-core/src/reflex_core/_internal/event/processor/event_processor.py index f6a32d5cc3f..28175ce188b 100644 --- a/packages/reflex-core/src/reflex_core/_internal/event/processor/event_processor.py +++ b/packages/reflex-core/src/reflex_core/_internal/event/processor/event_processor.py @@ -9,7 +9,7 @@ import sys import time import traceback -from collections.abc import AsyncGenerator, Callable, Mapping +from collections.abc import AsyncGenerator, Callable, Mapping, Sequence from contextvars import Token, copy_context from typing import TYPE_CHECKING, Any @@ -156,7 +156,7 @@ async def emit_event(token: str, *events: Event) -> None: token="", parent_txid=None, state_manager=state_manager, - enqueue_impl=self.enqueue, + enqueue_impl=self.enqueue_many, emit_delta_impl=emit_delta_impl, emit_event_impl=emit_event_impl, ) @@ -319,14 +319,14 @@ def _ensure_queue_task(self) -> asyncio.Queue[EventQueueEntry]: return self._queue async def enqueue( - self, token: str, *events: Event, ev_ctx: EventContext | None = None + self, token: str, event: Event, ev_ctx: EventContext | None = None ) -> EventFuture: """Enqueue an event to be processed. Args: token: The client token associated with the event. - events: Remaining positional args are events to be enqueued. - ev_ctx: The event context to use for these events. + event: The event to be enqueued. + ev_ctx: The event context to use for this event. Returns: An EventFuture that resolves to the result of the associated task. @@ -350,10 +350,21 @@ async def enqueue( parent_tracked = self._futures.get(ev_ctx.parent_txid) if parent_tracked is not None: parent_tracked.add_child(tracked) - for event in events: - await queue.put(EventQueueEntry(event=event, ctx=ev_ctx)) + await queue.put(EventQueueEntry(event=event, ctx=ev_ctx)) return tracked + async def enqueue_many(self, token: str, *events: Event) -> Sequence[EventFuture]: + """Enqueue multiple events to be processed. + + Args: + token: The client token associated with the events. + events: Remaining positional args are events to be enqueued. + + Returns: + A list of EventFutures corresponding to each enqueued event. + """ + return [await self.enqueue(token, event) for event in events] + async def enqueue_stream_delta( self, token: str, @@ -575,6 +586,8 @@ def _finish_task(self, task: asyncio.Task): self._handle_backend_exception(ex, ev_ctx=task_ctx), name=f"reflex_backend_exception_handler|task=[{task.get_name()}]|{time.time()}", ) + if sys.version_info < (3, 12): + t._event_ctx = task_ctx # pyright: ignore[reportAttributeAccessIssue] t.add_done_callback(self._finish_task) return console.error( diff --git a/tests/units/conftest.py b/tests/units/conftest.py index 7a5f431e86b..cf750b8e7fd 100644 --- a/tests/units/conftest.py +++ b/tests/units/conftest.py @@ -364,7 +364,7 @@ async def emit_event_impl(token: str, *events: Event) -> None: # noqa: RUF029 yield EventContext( token="", state_manager=state_manager, - enqueue_impl=mock_base_state_event_processor_obj.enqueue, + enqueue_impl=mock_base_state_event_processor_obj.enqueue_many, emit_delta_impl=emit_delta_impl, emit_event_impl=emit_event_impl, ) diff --git a/tests/units/reflex_core/_internal/event/processor/test_base_state_processor.py b/tests/units/reflex_core/_internal/event/processor/test_base_state_processor.py index 2504598a1d6..d5f20f440c2 100644 --- a/tests/units/reflex_core/_internal/event/processor/test_base_state_processor.py +++ b/tests/units/reflex_core/_internal/event/processor/test_base_state_processor.py @@ -92,7 +92,7 @@ async def emit_event_impl(token: str, *events: Event) -> None: # noqa: RUF029 root_ctx = EventContext( token="", state_manager=state_manager, - enqueue_impl=_real_base_state_processor_obj.enqueue, + enqueue_impl=_real_base_state_processor_obj.enqueue_many, emit_delta_impl=emit_delta_impl, emit_event_impl=emit_event_impl, ) @@ -136,7 +136,7 @@ def noop(self): async with real_base_state_processor as processor: await processor.enqueue( token, - *Event.from_event_type(MyState.noop()), + Event.from_event_type(MyState.noop())[0], ) await processor.join(1) diff --git a/tests/units/reflex_core/_internal/event/processor/test_event_processor.py b/tests/units/reflex_core/_internal/event/processor/test_event_processor.py index f55ca5b1170..6b2b5dbbf47 100644 --- a/tests/units/reflex_core/_internal/event/processor/test_event_processor.py +++ b/tests/units/reflex_core/_internal/event/processor/test_event_processor.py @@ -48,7 +48,7 @@ async def _chaining_handler(): """A handler that enqueues a logging event via the current EventContext.""" ctx = EventContext.get() await ctx.enqueue( - *Event.from_event_type(logging_event("chained")), + Event.from_event_type(logging_event("chained"))[0], ) @@ -176,7 +176,7 @@ async def test_enqueue_after_stop_raises(processor: EventProcessor): async with processor: pass with pytest.raises(QueueShutDown, match="not running"): - await processor.enqueue("tok", *Event.from_event_type(noop_event())) + await processor.enqueue("tok", Event.from_event_type(noop_event())[0]) async def test_enqueue_before_start_raises(processor: EventProcessor): @@ -187,7 +187,7 @@ async def test_enqueue_before_start_raises(processor: EventProcessor): """ processor.configure() with pytest.raises(QueueShutDown, match="not running"): - await processor.enqueue("tok", *Event.from_event_type(noop_event())) + await processor.enqueue("tok", Event.from_event_type(noop_event())[0]) async def test_events_are_processed( @@ -203,7 +203,7 @@ async def test_events_are_processed( token: The client token. """ async with mock_event_processor as ep: - await ep.enqueue(token, *Event.from_event_type(logging_event("hello"))) + await ep.enqueue(token, Event.from_event_type(logging_event("hello"))[0]) assert _CALL_LOG == [{"value": "hello"}] @@ -218,7 +218,7 @@ async def test_enqueue_returns_future( token: The client token. """ async with mock_event_processor as ep: - future = await ep.enqueue(token, *Event.from_event_type(noop_event())) + future = await ep.enqueue(token, Event.from_event_type(noop_event())[0]) assert isinstance(future, asyncio.Future) assert future.done() @@ -234,7 +234,7 @@ async def test_tasks_cleared_after_stop( token: The client token. """ async with mock_event_processor as ep: - await ep.enqueue(token, *Event.from_event_type(noop_event())) + await ep.enqueue(token, Event.from_event_type(noop_event())[0]) assert ep._tasks == {} @@ -249,7 +249,7 @@ async def test_futures_cleared_after_stop( token: The client token. """ async with mock_event_processor as ep: - await ep.enqueue(token, *Event.from_event_type(noop_event())) + await ep.enqueue(token, Event.from_event_type(noop_event())[0]) assert ep._futures == {} @@ -262,7 +262,7 @@ async def test_slow_tasks_cancelled_on_stop(processor: EventProcessor): processor.graceful_shutdown_timeout = 0 processor.configure() async with processor as ep: - future = await ep.enqueue("tok", *Event.from_event_type(slow_event(10.0))) + future = await ep.enqueue("tok", Event.from_event_type(slow_event(10.0))[0]) assert future.cancelled() assert ep._tasks == {} @@ -276,8 +276,8 @@ async def test_multiple_futures_cancelled_on_stop(processor: EventProcessor): processor.graceful_shutdown_timeout = 0 processor.configure() async with processor as ep: - f1 = await ep.enqueue("t1", *Event.from_event_type(slow_event(10.0))) - f2 = await ep.enqueue("t2", *Event.from_event_type(slow_event(10.0))) + f1 = await ep.enqueue("t1", Event.from_event_type(slow_event(10.0))[0]) + f2 = await ep.enqueue("t2", Event.from_event_type(slow_event(10.0))[0]) for f in (f1, f2): assert f.done() assert ep._futures == {} @@ -294,7 +294,7 @@ async def test_cancel_future_before_task_starts( token: The client token. """ async with mock_event_processor as ep: - future = await ep.enqueue(token, *Event.from_event_type(slow_event(10.0))) + future = await ep.enqueue(token, Event.from_event_type(slow_event(10.0))[0]) future.cancel() await asyncio.sleep(0.05) assert ep._tasks == {} @@ -311,7 +311,7 @@ async def test_cancel_future_cancels_running_task( token: The client token. """ async with mock_event_processor as ep: - future = await ep.enqueue(token, *Event.from_event_type(slow_event(10.0))) + future = await ep.enqueue(token, Event.from_event_type(slow_event(10.0))[0]) await asyncio.sleep(0.05) future.cancel() await asyncio.sleep(0.05) @@ -330,7 +330,7 @@ async def test_exception_propagated_to_future( """ processor.configure() async with processor as ep: - future = await ep.enqueue(token, *Event.from_event_type(error_event())) + future = await ep.enqueue(token, Event.from_event_type(error_event())[0]) assert future.done() with pytest.raises(RuntimeError, match="boom"): future.result() @@ -350,7 +350,7 @@ def _catch(ex: Exception) -> None: ep = EventProcessor(backend_exception_handler=_catch, graceful_shutdown_timeout=2) ep.configure() async with ep: - await ep.enqueue(token, *Event.from_event_type(error_event())) + await ep.enqueue(token, Event.from_event_type(error_event())[0]) assert len(caught) == 1 assert isinstance(caught[0], RuntimeError) @@ -367,8 +367,8 @@ async def test_error_does_not_stop_queue( """ processor.configure() async with processor as ep: - await ep.enqueue(token, *Event.from_event_type(error_event())) - await ep.enqueue(token, *Event.from_event_type(logging_event("after_error"))) + await ep.enqueue(token, Event.from_event_type(error_event())[0]) + await ep.enqueue(token, Event.from_event_type(logging_event("after_error"))[0]) assert _CALL_LOG == [{"value": "after_error"}] @@ -381,7 +381,7 @@ async def test_chained_event_processed(token: str): ep = EventProcessor(graceful_shutdown_timeout=2) ep.configure() async with ep: - await ep.enqueue(token, *Event.from_event_type(chaining_event())) + await ep.enqueue(token, Event.from_event_type(chaining_event())[0]) assert _CALL_LOG == [{"value": "chained"}] @@ -406,7 +406,7 @@ async def test_join_completes_after_processing( token: The client token. """ async with mock_event_processor as ep: - await ep.enqueue(token, *Event.from_event_type(noop_event())) + await ep.enqueue(token, Event.from_event_type(noop_event())[0]) await ep.join(timeout=5) diff --git a/tests/units/test_model.py b/tests/units/test_model.py index 1c9886957cf..54508f4fc41 100644 --- a/tests/units/test_model.py +++ b/tests/units/test_model.py @@ -241,7 +241,7 @@ async def test_upcast_event_handler_arg( """ async with mock_base_state_event_processor as processor: await processor.enqueue( - "test_token", *Event.from_event_type(handler(**payload)) + "test_token", Event.from_event_type(handler(**payload))[0] ) assert emitted_deltas == [ ( diff --git a/tests/units/test_state.py b/tests/units/test_state.py index 83260b6dfb8..14a2447687c 100644 --- a/tests/units/test_state.py +++ b/tests/units/test_state.py @@ -2944,7 +2944,7 @@ def test_handler(self): Chain of EventHandlers """ self.num += 1 - return self.change_name + return type(self).change_name def change_name(self): """Test handler to change name.""" @@ -2965,9 +2965,28 @@ async def test_handler(self): @pytest.mark.parametrize( ("test_state", "expected"), [ - (OnLoadState, {"on_load_state": {"num": 1}}), - (OnLoadState2, {"on_load_state2": {"num": 1}}), - (OnLoadState3, {"on_load_state3": {"num": 1}}), + ( + OnLoadState, + [ + {OnLoadState.get_full_name(): {"num" + FIELD_MARKER: 1}}, + exp_is_hydrated(State, True), + ], + ), + ( + OnLoadState2, + [ + {OnLoadState2.get_full_name(): {"num" + FIELD_MARKER: 1}}, + exp_is_hydrated(State, True), + {OnLoadState2.get_full_name(): {"name" + FIELD_MARKER: "random"}}, + ], + ), + ( + OnLoadState3, + [ + {OnLoadState3.get_full_name(): {"num" + FIELD_MARKER: 1}}, + exp_is_hydrated(State, True), + ], + ), ], ) async def test_preprocess( @@ -3007,45 +3026,35 @@ def index(): ) async with mock_base_state_event_processor as processor: - await processor.enqueue( - token, - Event( - name=on_load_internal_name, - router_data={ - RouteVar.PATH: "/", - RouteVar.ORIGIN: "/", - RouteVar.QUERY: {}, - }, - ), + await ( + await processor.enqueue( + token, + Event( + name=on_load_internal_name, + router_data={ + RouteVar.PATH: "/", + RouteVar.ORIGIN: "/", + RouteVar.QUERY: {}, + }, + ), + ) ) - await processor.join() # The processor chains all events: on_load_internal sets is_hydrated=False, # then the on_load handler runs, then set_is_hydrated(True) runs. # First delta: router + is_hydrated=False - assert len(emitted_deltas) >= 2 - first_delta = emitted_deltas[0][1] + assert len(emitted_deltas) == 1 + len(expected) + first_token, first_delta = emitted_deltas[0] + assert first_token == token assert first_delta[State.get_full_name()].pop("router" + FIELD_MARKER) is not None assert first_delta == exp_is_hydrated(State, False) - # Find the delta containing the test handler's state change - handler_deltas = [ - d - for _, d in emitted_deltas - if test_state.get_full_name() in d - and "num" + FIELD_MARKER in d[test_state.get_full_name()] - ] - assert len(handler_deltas) >= 1 - assert handler_deltas[0][test_state.get_full_name()]["num" + FIELD_MARKER] == 1 - - # Find the delta that sets is_hydrated back to True - hydrated_deltas = [ - d - for _, d in emitted_deltas - if State.get_full_name() in d - and d[State.get_full_name()].get(CompileVars.IS_HYDRATED + FIELD_MARKER) is True - ] - assert len(hydrated_deltas) == 1 + # Find the deltas containing the test handler's state change + for (delta_token, actual_delta), expected_delta in zip( + emitted_deltas[1:], expected, strict=True + ): + assert delta_token == token + assert actual_delta == expected_delta @pytest.mark.asyncio @@ -3413,7 +3422,8 @@ async def test_setvar( TestState.setvar("num2", "4.2"), ]) async with mock_base_state_event_processor as processor: - await processor.enqueue(token, *events) + for fut in asyncio.as_completed(await processor.enqueue_many(token, *events)): + await fut await processor.join(1) if environment.REFLEX_OPLOCK_ENABLED.get(): @@ -3427,7 +3437,7 @@ async def test_setvar( # Set Var in parent state events = Event.from_event_type([GrandchildState.setvar("array", [43])]) async with mock_base_state_event_processor as processor: - await processor.enqueue(token, *events) + await (await processor.enqueue(token, events[0])) if environment.REFLEX_OPLOCK_ENABLED.get(): await state_manager.close() From b3c81789ae6de377d00c1f9c8697c501956894cc Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Thu, 2 Apr 2026 00:13:16 -0700 Subject: [PATCH 56/81] attach the registration_context_middleware in App.__call__ We don't want to set it from a lifespan task, because once the middleware is added, it cannot be readded again. So if the same ASGI gets started and stopped, it will throw an error. Does this happen much in real life? No. But one AppHarness test was hitting it, and this is technically more correct. --- reflex/app.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/reflex/app.py b/reflex/app.py index e3a97c5b014..a5c3826973a 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -602,11 +602,6 @@ async def registration_context_middleware( @contextlib.asynccontextmanager async def _setup_event_processor(self) -> AsyncIterator[None]: - # Make sure the RegistrationContext is attached. - if self._api is not None: - self._api.add_middleware( - self._registration_context_middleware, - ) # Create the event processor. self._event_processor = BaseStateEventProcessor( middleware=self, backend_exception_handler=self.backend_exception_handler @@ -657,6 +652,8 @@ def __call__(self) -> ASGIApp: raise ValueError(msg) asgi_app = self._api + # Make sure the RegistrationContext is attached. + asgi_app.add_middleware(self._registration_context_middleware) if environment.REFLEX_MOUNT_FRONTEND_COMPILED_APP.get(): asgi_app.mount( From 83ba6eea2f1ff717167f26c31f11337bd2a4b54c Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Thu, 2 Apr 2026 10:25:51 -0700 Subject: [PATCH 57/81] test_connection_banner: use CDP to simulate network offline Instead of killing and restarting the server, which is slower and more annoying, just temporarily restrict network traffic in the browser. --- tests/integration/test_connection_banner.py | 83 ++++++++++++++------- 1 file changed, 57 insertions(+), 26 deletions(-) diff --git a/tests/integration/test_connection_banner.py b/tests/integration/test_connection_banner.py index f165ae655cd..b2264315837 100644 --- a/tests/integration/test_connection_banner.py +++ b/tests/integration/test_connection_banner.py @@ -1,7 +1,9 @@ """Test case for displaying the connection banner when the websocket drops.""" +import asyncio +import contextlib import pickle -from collections.abc import AsyncGenerator, Generator +from collections.abc import AsyncGenerator, Generator, Iterator import pytest import pytest_asyncio @@ -95,6 +97,40 @@ def connection_banner( yield harness +@contextlib.contextmanager +def browser_offline(driver: WebDriver) -> Iterator[None]: + """Context manager that takes the browser offline via CDP and restores it on exit. + + Args: + driver: Selenium WebDriver instance (must support execute_cdp_cmd). + + Yields: + None + """ + driver.execute_cdp_cmd("Network.enable", {}) + driver.execute_cdp_cmd( + "Network.emulateNetworkConditions", + { + "offline": True, + "downloadThroughput": -1, + "uploadThroughput": -1, + "latency": 0, + }, + ) + try: + yield + finally: + driver.execute_cdp_cmd( + "Network.emulateNetworkConditions", + { + "offline": False, + "downloadThroughput": -1, + "uploadThroughput": -1, + "latency": 0, + }, + ) + + CONNECTION_ERROR_XPATH = "//*[ contains(text(), 'Cannot connect to server') ]" @@ -170,7 +206,8 @@ async def redis( redis = get_redis() yield redis if redis is not None: - await redis.close() + with contextlib.suppress(Exception, asyncio.CancelledError): + await redis.aclose() @pytest.mark.asyncio @@ -206,34 +243,28 @@ async def test_connection_banner(connection_banner: AppHarness, redis: Redis | N increment_button.click() assert connection_banner.poll_for_value(counter_element, exp_not_equal="0") == "1" - # Start an long event before killing the backend, to mark event_processing=true + # Start a long event before blocking the network, to mark event_processing=true delay_button.click() - # Get the backend port - backend_port = connection_banner._poll_for_servers().getsockname()[1] + with browser_offline(driver): + # Error modal should now be displayed + AppHarness.expect(lambda: has_error_modal(driver)) - # Kill the backend - connection_banner.backend.should_exit = True - if connection_banner.backend_thread is not None: - connection_banner.backend_thread.join() - - # Error modal should now be displayed - AppHarness.expect(lambda: has_error_modal(driver)) - - # The token association should have been removed when the server exited. - assert token not in app_token_manager.token_to_sid - if redis is not None: - assert isinstance(app_token_manager, RedisTokenManager) - assert await redis.get(app_token_manager._get_redis_key(token)) is None - - # Increment the counter with backend down - increment_button.click() - assert connection_banner.poll_for_value(counter_element, exp_not_equal="0") == "1" - - # Bring the backend back up - connection_banner._start_backend(port=backend_port) + # The token association should be removed once the websocket closes on the server. + assert connection_banner._poll_for( + lambda: token not in app_token_manager.token_to_sid + ) + if redis is not None: + assert isinstance(app_token_manager, RedisTokenManager) + assert await redis.get(app_token_manager._get_redis_key(token)) is None + + # Increment the counter while disconnected + increment_button.click() + assert ( + connection_banner.poll_for_value(counter_element, exp_not_equal="0") == "1" + ) - # Banner should be gone now + # Banner should be gone now (network restored on context manager exit) AppHarness.expect(lambda: not has_error_modal(driver)) # After reconnecting, the token association should be re-established. From 5b36147de046bfa49201a85cb07e031ad24544bb Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Thu, 2 Apr 2026 10:41:44 -0700 Subject: [PATCH 58/81] reflex_core.event: provide BaseState as a namespace property Avoid other weird circular import issues that occur when the `sys.modules` record for the module is not actually the module namespace. --- packages/reflex-core/src/reflex_core/event.py | 21 +++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/packages/reflex-core/src/reflex_core/event.py b/packages/reflex-core/src/reflex_core/event.py index c89008c6901..13a47865bd5 100644 --- a/packages/reflex-core/src/reflex_core/event.py +++ b/packages/reflex-core/src/reflex_core/event.py @@ -2812,13 +2812,22 @@ def wrapper( run_script = staticmethod(run_script) __file__ = __file__ + @property + def BaseState(self) -> "type[BaseState]": # noqa: N802 + """Get the BaseState class. + + A reference to BaseState is needed for doc generation when resolving + type hints, so add it to the namespace late to avoid circular import + issues. + + Returns: + The BaseState class. + """ + from reflex.state import BaseState + + return BaseState + event = EventNamespace event.event = event # pyright: ignore[reportAttributeAccessIssue] sys.modules[__name__] = event # pyright: ignore[reportArgumentType] - -# A reference to BaseState is needed for doc generation when resolving type -# hints, so add it to the namespace late to avoid circular import issues. -from reflex.state import BaseState # noqa: E402 - -event.BaseState = BaseState # pyright: ignore[reportAttributeAccessIssue] From 26160e4f964b11b8fac57428accd860f6f414307 Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Thu, 2 Apr 2026 16:44:04 -0700 Subject: [PATCH 59/81] Track EventFuture children Ensure yielded event execution ordering, by waiting for sibling EventFuture for non-background tasks before processing. --- .../event/processor/base_state_processor.py | 9 +- .../event/processor/event_processor.py | 117 +++++++--- .../_internal/event/processor/future.py | 56 +++-- .../event/processor/test_event_processor.py | 162 ++++++++++++++ .../_internal/event/processor/test_future.py | 204 ++++++++++++++---- tests/units/test_state.py | 39 ++-- 6 files changed, 477 insertions(+), 110 deletions(-) diff --git a/packages/reflex-core/src/reflex_core/_internal/event/processor/base_state_processor.py b/packages/reflex-core/src/reflex_core/_internal/event/processor/base_state_processor.py index 1bab71a481c..51f0571e36f 100644 --- a/packages/reflex-core/src/reflex_core/_internal/event/processor/base_state_processor.py +++ b/packages/reflex-core/src/reflex_core/_internal/event/processor/base_state_processor.py @@ -308,20 +308,19 @@ async def _rehydrate(self, root_state: BaseState): root_state=root_state, ) - async def _process_event_queue_entry( + async def _execute_event( self, *, entry: EventQueueEntry, registered_handler: RegisteredEventHandler ) -> None: - """Process a single event queue entry. + """Execute the handler for a single event queue entry with full state management. - This function runs in a new task for each event. + The ``EventContext`` has already been set by ``_process_event_queue_entry`` + before this method is called. Args: entry: The event queue entry to process. registered_handler: The registered handler for the event. """ - # Set up the event context for this task. ctx = entry.ctx - EventContext.set(ctx) event = entry.event router_data = event.router_data or {} # Get the state for the session exclusively. diff --git a/packages/reflex-core/src/reflex_core/_internal/event/processor/event_processor.py b/packages/reflex-core/src/reflex_core/_internal/event/processor/event_processor.py index 28175ce188b..0220bff9615 100644 --- a/packages/reflex-core/src/reflex_core/_internal/event/processor/event_processor.py +++ b/packages/reflex-core/src/reflex_core/_internal/event/processor/event_processor.py @@ -341,15 +341,29 @@ async def enqueue( msg = "Event processor is not running, call .start(...) first." raise RuntimeError(msg) from le queue = self._ensure_queue_task() - tracked = EventFuture.create() + # Determine whether sequential ordering is required for this event. + is_background = False + try: + registered = RegistrationContext.get().event_handlers.get(event.name) + if registered is not None and registered.handler.is_background: + is_background = True + except LookupError: + pass txid = ev_ctx.txid + parent_future = ( + self._futures.get(ev_ctx.parent_txid) + if ev_ctx.parent_txid is not None + else None + ) + tracked = EventFuture( + sequential=not is_background, parent=parent_future, txid=txid + ) self._futures[txid] = tracked - tracked.add_done_callback(lambda f: self._on_future_done(txid, f)) + tracked.add_done_callback(self._try_clean_future) + tracked.add_done_callback(self._on_future_done) # If this context has a parent, register as a child of the parent's future. - if ev_ctx.parent_txid is not None: - parent_tracked = self._futures.get(ev_ctx.parent_txid) - if parent_tracked is not None: - parent_tracked.add_child(tracked) + if parent_future is not None: + parent_future.add_child(tracked) await queue.put(EventQueueEntry(event=event, ctx=ev_ctx)) return tracked @@ -434,7 +448,29 @@ async def _emit_delta_impl( # Raise any exceptions for the caller, waiting for all chained events. await task_future.wait_all() - def _on_future_done(self, txid: str, future: asyncio.Future) -> None: + def _try_clean_future(self, future: EventFuture) -> None: # type: ignore[override] + """Pop a future from _futures when it and all immediate children are done. + + After popping, cascade the check upward: if the parent future is also + done and all its immediate children are done, pop the parent as well. + + This keeps parent futures alive in ``_futures`` while any child still + needs to look up its siblings for sequential ordering. + + Args: + future: The EventFuture to check. + """ + if not future.done(): + return + # Not checking future.all_done() to avoid waiting for grandchildren here. + if not all(c.done() for c in future.children): + return + parent = future.parent + self._futures.pop(future.txid, None) + if parent is not None and parent.txid: + self._try_clean_future(parent) + + def _on_future_done(self, future: EventFuture) -> None: # type: ignore[override] """Callback invoked when an enqueued future completes. If the future was cancelled externally, cancel the running task @@ -443,33 +479,53 @@ def _on_future_done(self, txid: str, future: asyncio.Future) -> None: entry is dequeued. Args: - txid: The transaction id associated with the future. - future: The future that completed. + future: The EventFuture that completed. """ if not future.cancelled(): return # Cascade cancellation to all child futures. - tracked = self._futures.get(txid) - if tracked is not None: - for child in tracked.children: - child.cancel() - task = self._tasks.get(txid) + for child in future.children: + child.cancel() + task = self._tasks.get(future.txid) if task is not None: task.cancel() + async def _execute_event( + self, *, entry: EventQueueEntry, registered_handler: RegisteredEventHandler + ) -> None: + """Execute the handler for a single event queue entry. + + This method contains the actual event-processing logic. The base + implementation simply invokes the registered handler function with the + event payload. Subclasses (e.g. ``BaseStateEventProcessor``) override + this method to add state management, delta emission, and middleware. + + ``_process_event_queue_entry`` is responsible for setting up the + ``EventContext`` and ensuring sequential ordering *before* calling this + method. + + Args: + entry: The event queue entry to process. + registered_handler: The registered handler for the event. + """ + event = entry.event + result = registered_handler.handler.fn(**event.payload) + if inspect.isawaitable(result): + await result + async def _process_event_queue_entry( self, *, entry: EventQueueEntry, registered_handler: RegisteredEventHandler ) -> None: """Process a single event queue entry. - This function runs in a new task for each event. - - The default implementation just calls the registered handler function - with the event payload as keyword arguments. + This function runs in a new task for each event. It sets up the + ``EventContext``, enforces sequential ordering for non-background + events, and then delegates to ``_execute_event`` for the actual + handler invocation. - Subclasses, such as BaseStateEventProcessor, can override this function - to provide additional functionality such as state management, event - chaining, and delta calculation. + Subclasses should override ``_execute_event`` rather than this method + so that the shared context setup and sequential-ordering logic is + always applied. Args: entry: The event queue entry to process. @@ -478,10 +534,14 @@ async def _process_event_queue_entry( # Set up the event context for this task. ctx = entry.ctx EventContext.set(ctx) - event = entry.event - result = registered_handler.handler.fn(**event.payload) - if inspect.isawaitable(result): - await result + # For sequential (non-background) events, wait for the previous sibling + # to finish before proceeding so that chained events run in order. + if current_future := self._futures.get(ctx.txid): + await current_future.wait_for_predecessor() + print( + f"Processing event {entry.event} [txid={ctx.txid}] with handler {registered_handler.handler.fn.__name__}" + ) + await self._execute_event(entry=entry, registered_handler=registered_handler) async def _process_queue(self): """Process events from the queue in a task.""" @@ -494,7 +554,7 @@ async def _process_queue(self): if ( future := self._futures.get(entry.ctx.txid) ) is not None and future.cancelled(): - self._futures.pop(entry.ctx.txid, None) + self._try_clean_future(future) queue.task_done() continue try: @@ -520,6 +580,9 @@ async def _process_queue(self): task._event_ctx = entry.ctx # pyright: ignore[reportAttributeAccessIssue] self._tasks[entry.ctx.txid] = task task.add_done_callback(self._finish_task) + print( + f"Enqueued task {task.get_name()} for event {entry.event.name} [txid={entry.ctx.txid}]" + ) except Exception: # Log the error and continue processing the next events. console.error( @@ -563,7 +626,7 @@ def _finish_task(self, task: asyncio.Task): else: task_ctx = task.get_context().run(EventContext.get) self._tasks.pop(task_ctx.txid, None) - future = self._futures.pop(task_ctx.txid, None) + future = self._futures.get(task_ctx.txid) if task.done(): try: result = task.result() diff --git a/packages/reflex-core/src/reflex_core/_internal/event/processor/future.py b/packages/reflex-core/src/reflex_core/_internal/event/processor/future.py index fc4ee6e84e8..974f9491843 100644 --- a/packages/reflex-core/src/reflex_core/_internal/event/processor/future.py +++ b/packages/reflex-core/src/reflex_core/_internal/event/processor/future.py @@ -4,9 +4,11 @@ import asyncio import contextlib +import dataclasses from typing import Any +@dataclasses.dataclass(kw_only=True, slots=True, eq=False) class EventFuture(asyncio.Future): """A future that tracks child futures for hierarchical event processing. @@ -14,25 +16,28 @@ class EventFuture(asyncio.Future): futures are tracked so callers can wait for the entire chain to complete. """ - children: list[EventFuture] + # The transaction id associated with this future. + txid: str - def __init__(self, *, loop: asyncio.AbstractEventLoop | None = None) -> None: - super().__init__(loop=loop) - self.children = [] + # If sequential is True, sibling events will be processed sequentially. + # Background events should set sequential=False to run immediately and + # without affecting non-background event ordering. + sequential: bool = True - @classmethod - def create(cls, loop: asyncio.AbstractEventLoop | None = None) -> EventFuture: - """Create a new EventFuture on the given or running event loop. + # Child futures spawned by this future, if any. + children: list[EventFuture] = dataclasses.field(default_factory=list) - Args: - loop: The event loop to use. Defaults to the running loop. + # The parent future that spawned this one, or None if this future was + # enqueued directly from the queue rather than chained from another event. + parent: EventFuture | None = dataclasses.field(default=None, repr=False) - Returns: - A new EventFuture instance. - """ - if loop is None: - loop = asyncio.get_running_loop() - return cls(loop=loop) + # The event loop that this future is running on. + loop: asyncio.AbstractEventLoop = dataclasses.field( + default_factory=asyncio.get_running_loop, repr=False + ) + + def __post_init__(self) -> None: + super().__init__(loop=self.loop) def add_child(self, child: EventFuture) -> None: """Add a child future to this tracked future. @@ -93,6 +98,27 @@ def cancel(self, msg: object = None) -> bool: child.cancel(msg) return result + async def wait_for_predecessor(self) -> None: + """Wait for the immediately preceding sequential sibling to complete. + + If this future is not sequential, has no parent, is the first child, + or cannot be found in the parent's children list, this is a no-op. + """ + if not self.sequential or self.parent is None: + return + children = self.parent.children + try: + idx = children.index(self) + except ValueError: + return + if idx == 0: + return # First child: no predecessor to wait for. + with contextlib.suppress(Exception, asyncio.CancelledError): + print( + f"EventFuture {self.txid} waiting for predecessor {children[idx - 1].txid}" + ) + await children[idx - 1] + __all__ = [ "EventFuture", diff --git a/tests/units/reflex_core/_internal/event/processor/test_event_processor.py b/tests/units/reflex_core/_internal/event/processor/test_event_processor.py index 6b2b5dbbf47..2d9adefbf7c 100644 --- a/tests/units/reflex_core/_internal/event/processor/test_event_processor.py +++ b/tests/units/reflex_core/_internal/event/processor/test_event_processor.py @@ -1,6 +1,7 @@ """Tests for EventProcessor lifecycle, task management, and error handling.""" import asyncio +import contextlib from typing import Any import pytest @@ -9,6 +10,7 @@ EventProcessor, QueueShutDown, ) +from reflex_core._internal.event.processor.future import EventFuture from reflex_core._internal.registry import RegistrationContext from reflex.event import Event, EventHandler @@ -66,6 +68,52 @@ async def _multi_delta_handler(): await asyncio.sleep(0.01) +async def _slow_logging_handler(value: str = "default"): + """A slow logging handler that pauses before recording. + + Args: + value: The value to log. + """ + await asyncio.sleep(0.05) + _CALL_LOG.append({"value": value}) + + +async def _multi_chaining_handler(): + """A handler that enqueues three slow logging events in sequence.""" + ctx = EventContext.get() + for label in ("first", "second", "third"): + await ctx.enqueue( + Event.from_event_type(slow_logging_event(label))[0], + ) + + +async def _background_then_normal_handler(): + """A handler that enqueues a background event followed by a normal slow event.""" + ctx = EventContext.get() + await ctx.enqueue(Event.from_event_type(background_slow_logging_event("bg"))[0]) + await ctx.enqueue(Event.from_event_type(slow_logging_event("normal"))[0]) + + +async def _error_then_logging_handler(): + """A handler that enqueues an error event followed by a logging event.""" + ctx = EventContext.get() + await ctx.enqueue(Event.from_event_type(error_event())[0]) + await ctx.enqueue(Event.from_event_type(logging_event("after_chain_error"))[0]) + + +async def _background_slow_logging_handler(value: str = "default"): + """A background version of the slow logging handler. + + Args: + value: The value to log. + """ + await asyncio.sleep(0.05) + _CALL_LOG.append({"value": value}) + + +_background_slow_logging_handler._reflex_background_task = True # type: ignore[attr-defined] + + noop_event = EventHandler(fn=_noop_handler) slow_event = EventHandler(fn=_slow_handler) error_event = EventHandler(fn=_error_handler) @@ -73,6 +121,11 @@ async def _multi_delta_handler(): chaining_event = EventHandler(fn=_chaining_handler) delta_event = EventHandler(fn=_delta_handler) multi_delta_event = EventHandler(fn=_multi_delta_handler) +slow_logging_event = EventHandler(fn=_slow_logging_handler) +multi_chaining_event = EventHandler(fn=_multi_chaining_handler) +background_slow_logging_event = EventHandler(fn=_background_slow_logging_handler) +background_then_normal_event = EventHandler(fn=_background_then_normal_handler) +error_then_logging_event = EventHandler(fn=_error_then_logging_handler) @pytest.fixture(autouse=True) @@ -91,6 +144,11 @@ def _register_handlers(forked_registration_context: RegistrationContext): chaining_event, delta_event, multi_delta_event, + slow_logging_event, + multi_chaining_event, + background_slow_logging_event, + background_then_normal_event, + error_then_logging_event, ): RegistrationContext.register_event_handler(handler) @@ -462,3 +520,107 @@ async def test_stream_delta_not_configured_raises(): with pytest.raises(RuntimeError, match="not configured"): async for _ in ep.enqueue_stream_delta("tok", Event(name="x", payload={})): pass + + +async def test_sequential_chained_events_run_in_order(token: str): + """Chained events enqueued by a handler run in the order they were enqueued. + + Args: + token: The client token. + """ + ep = EventProcessor(graceful_shutdown_timeout=2) + ep.configure() + async with ep: + future = await ep.enqueue( + token, Event.from_event_type(multi_chaining_event())[0] + ) + await future.wait_all() + assert [entry["value"] for entry in _CALL_LOG] == ["first", "second", "third"] + + +async def test_sequential_chained_futures_are_sequential(token: str): + """EventFutures created for normal (non-background) events have sequential=True. + + Args: + token: The client token. + """ + ep = EventProcessor(graceful_shutdown_timeout=2) + ep.configure() + async with ep: + future = await ep.enqueue(token, Event.from_event_type(logging_event())[0]) + assert isinstance(future, EventFuture) + assert future.sequential is True + + +async def test_background_event_future_is_not_sequential(token: str): + """EventFutures created for background events have sequential=False. + + Args: + token: The client token. + """ + ep = EventProcessor(graceful_shutdown_timeout=2) + ep.configure() + async with ep: + future = await ep.enqueue( + token, Event.from_event_type(background_slow_logging_event())[0] + ) + assert isinstance(future, EventFuture) + assert future.sequential is False + + +async def test_futures_cleaned_up_after_chained_events(token: str): + """All futures are removed from _futures after chained events complete. + + Args: + token: The client token. + """ + ep = EventProcessor(graceful_shutdown_timeout=2) + ep.configure() + async with ep: + future = await ep.enqueue( + token, Event.from_event_type(multi_chaining_event())[0] + ) + await future.wait_all() + assert ep._futures == {} + + +async def test_background_event_does_not_block_sequential_sibling(token: str): + """A background event enqueued before a sequential sibling does not delay it. + + The background event (sequential=False) should execute concurrently while + the normal sibling is free to start without waiting for the background + event to finish first. + + Args: + token: The client token. + """ + ep = EventProcessor(graceful_shutdown_timeout=2) + ep.configure() + async with ep: + future = await ep.enqueue( + token, Event.from_event_type(background_then_normal_event())[0] + ) + await future.wait_all() + # Both events should have been processed regardless of order. + values = {entry["value"] for entry in _CALL_LOG} + assert values == {"bg", "normal"} + + +async def test_sequential_chain_continues_after_error(token: str): + """A sequential chained event still runs when the preceding sibling raised an exception. + + The error in the first chained event must not block the second chained + event from executing. + + Args: + token: The client token. + """ + ep = EventProcessor(graceful_shutdown_timeout=2) + ep.configure() + async with ep: + future = await ep.enqueue( + token, Event.from_event_type(error_then_logging_event())[0] + ) + with contextlib.suppress(Exception): + await future.wait_all() + assert _CALL_LOG == [{"value": "after_chain_error"}] diff --git a/tests/units/reflex_core/_internal/event/processor/test_future.py b/tests/units/reflex_core/_internal/event/processor/test_future.py index 6597b8c5673..2cde70feef5 100644 --- a/tests/units/reflex_core/_internal/event/processor/test_future.py +++ b/tests/units/reflex_core/_internal/event/processor/test_future.py @@ -1,16 +1,38 @@ """Tests for EventFuture.""" import asyncio +import contextlib +from collections.abc import AsyncGenerator import pytest from reflex_core._internal.event.processor.future import EventFuture +@contextlib.asynccontextmanager +async def assert_no_loop_yield() -> AsyncGenerator[None, None]: + """Assert the body never yields control to the event loop. + + Schedules a sentinel task before the body runs and asserts it has not + executed by the time the body returns. Because asyncio is cooperative, + the sentinel can only run if the body awaited something that suspended it. + """ + sentinel_ran = False + + async def _sentinel() -> None: # noqa: RUF029 + nonlocal sentinel_ran + sentinel_ran = True + + task = asyncio.create_task(_sentinel()) + yield + assert not sentinel_ran, "Event loop was unexpectedly yielded to" + await task + + @pytest.mark.asyncio async def test_create_uses_running_loop(): # noqa: RUF029 - """EventFuture.create() defaults to the running event loop.""" + """EventFuture() defaults to the running event loop.""" running_loop = asyncio.get_running_loop() - f = EventFuture.create() + f = EventFuture(txid="f") assert isinstance(f, EventFuture) assert f.get_loop() is running_loop assert f.children == [] @@ -19,10 +41,10 @@ async def test_create_uses_running_loop(): # noqa: RUF029 @pytest.mark.asyncio async def test_create_with_explicit_loop(): # noqa: RUF029 - """EventFuture.create(loop=...) uses the given (non-default) loop.""" + """EventFuture(loop=...) uses the given (non-default) loop.""" other_loop = asyncio.new_event_loop() try: - f = EventFuture.create(loop=other_loop) + f = EventFuture(txid="f", loop=other_loop) assert isinstance(f, EventFuture) assert f.get_loop() is other_loop assert f.get_loop() is not asyncio.get_running_loop() @@ -33,8 +55,8 @@ async def test_create_with_explicit_loop(): # noqa: RUF029 @pytest.mark.asyncio async def test_add_child_multiple(): # noqa: RUF029 """add_child can be called multiple times.""" - parent = EventFuture.create() - children = [EventFuture.create() for _ in range(3)] + parent = EventFuture(txid="parent") + children = [EventFuture(txid=f"c{i}") for i in range(3)] for c in children: parent.add_child(c) assert parent.children == children @@ -43,9 +65,9 @@ async def test_add_child_multiple(): # noqa: RUF029 @pytest.mark.asyncio async def test_add_child_to_done_future_raises(): # noqa: RUF029 """add_child raises RuntimeError if the parent future is already done.""" - parent = EventFuture.create() + parent = EventFuture(txid="parent") parent.set_result(None) - child = EventFuture.create() + child = EventFuture(txid="child") with pytest.raises(RuntimeError, match="already done"): parent.add_child(child) @@ -53,9 +75,9 @@ async def test_add_child_to_done_future_raises(): # noqa: RUF029 @pytest.mark.asyncio async def test_add_child_to_cancelled_future_raises(): # noqa: RUF029 """add_child raises RuntimeError if the parent future is cancelled.""" - parent = EventFuture.create() + parent = EventFuture(txid="parent") parent.cancel() - child = EventFuture.create() + child = EventFuture(txid="child") with pytest.raises(RuntimeError, match="already done"): parent.add_child(child) @@ -63,7 +85,7 @@ async def test_add_child_to_cancelled_future_raises(): # noqa: RUF029 @pytest.mark.asyncio async def test_all_done_no_children(): # noqa: RUF029 """all_done is True when the future is resolved and has no children.""" - f = EventFuture.create() + f = EventFuture(txid="f") assert not f.all_done() f.set_result(42) assert f.all_done() @@ -72,8 +94,8 @@ async def test_all_done_no_children(): # noqa: RUF029 @pytest.mark.asyncio async def test_all_done_with_pending_child(): # noqa: RUF029 """all_done is False when a child is still pending.""" - parent = EventFuture.create() - child = EventFuture.create() + parent = EventFuture(txid="parent") + child = EventFuture(txid="child") parent.add_child(child) parent.set_result(None) assert not parent.all_done() @@ -84,9 +106,9 @@ async def test_all_done_with_pending_child(): # noqa: RUF029 @pytest.mark.asyncio async def test_all_done_nested(): # noqa: RUF029 """all_done checks the full descendant tree.""" - root = EventFuture.create() - child = EventFuture.create() - grandchild = EventFuture.create() + root = EventFuture(txid="root") + child = EventFuture(txid="child") + grandchild = EventFuture(txid="grandchild") root.add_child(child) child.add_child(grandchild) @@ -102,8 +124,8 @@ async def test_all_done_nested(): # noqa: RUF029 @pytest.mark.asyncio async def test_all_done_with_cancelled_child(): # noqa: RUF029 """all_done is True when all children are cancelled (done).""" - parent = EventFuture.create() - child = EventFuture.create() + parent = EventFuture(txid="parent") + child = EventFuture(txid="child") parent.add_child(child) parent.set_result(None) child.cancel() @@ -113,8 +135,8 @@ async def test_all_done_with_cancelled_child(): # noqa: RUF029 @pytest.mark.asyncio async def test_all_done_with_exception_child(): # noqa: RUF029 """all_done is True when a child has an exception (still done).""" - parent = EventFuture.create() - child = EventFuture.create() + parent = EventFuture(txid="parent") + child = EventFuture(txid="child") parent.add_child(child) parent.set_result(None) child.set_exception(ValueError("boom")) @@ -124,7 +146,7 @@ async def test_all_done_with_exception_child(): # noqa: RUF029 @pytest.mark.asyncio async def test_wait_all_returns_result(): """wait_all returns the result of the root future.""" - f = EventFuture.create() + f = EventFuture(txid="f") f.set_result(42) result = await f.wait_all() assert result == 42 @@ -133,8 +155,8 @@ async def test_wait_all_returns_result(): @pytest.mark.asyncio async def test_wait_all_waits_for_children(): """wait_all waits for all children to complete.""" - parent = EventFuture.create() - child = EventFuture.create() + parent = EventFuture(txid="parent") + child = EventFuture(txid="child") parent.add_child(child) async def resolve_later(): @@ -152,9 +174,9 @@ async def resolve_later(): @pytest.mark.asyncio async def test_wait_all_waits_for_nested_children(): """wait_all waits for grandchildren too.""" - root = EventFuture.create() - child = EventFuture.create() - grandchild = EventFuture.create() + root = EventFuture(txid="root") + child = EventFuture(txid="child") + grandchild = EventFuture(txid="grandchild") root.add_child(child) child.add_child(grandchild) @@ -175,8 +197,8 @@ async def resolve_chain(): @pytest.mark.asyncio async def test_wait_all_suppresses_child_exceptions(): """wait_all suppresses exceptions from children.""" - parent = EventFuture.create() - child = EventFuture.create() + parent = EventFuture(txid="parent") + child = EventFuture(txid="child") parent.add_child(child) parent.set_result("ok") @@ -190,8 +212,8 @@ async def test_wait_all_suppresses_child_exceptions(): @pytest.mark.asyncio async def test_wait_all_suppresses_child_cancellation(): """wait_all suppresses CancelledError from children.""" - parent = EventFuture.create() - child = EventFuture.create() + parent = EventFuture(txid="parent") + child = EventFuture(txid="child") parent.add_child(child) parent.set_result("ok") @@ -204,14 +226,14 @@ async def test_wait_all_suppresses_child_cancellation(): @pytest.mark.asyncio async def test_wait_all_children_added_during_iteration(): """wait_all picks up children added while iterating (index-based walk).""" - parent = EventFuture.create() - child1 = EventFuture.create() + parent = EventFuture(txid="parent") + child1 = EventFuture(txid="child1") parent.add_child(child1) parent.set_result("done") # child2 will be added to child1 after child1 resolves, # simulating a chained event that enqueues more events. - child2 = EventFuture.create() + child2 = EventFuture(txid="child2") async def resolve_and_chain(): await asyncio.sleep(0.01) @@ -229,7 +251,7 @@ async def resolve_and_chain(): @pytest.mark.asyncio async def test_cancel_no_children(): # noqa: RUF029 """Cancel cancels the future itself.""" - f = EventFuture.create() + f = EventFuture(txid="f") assert f.cancel() assert f.cancelled() @@ -237,9 +259,9 @@ async def test_cancel_no_children(): # noqa: RUF029 @pytest.mark.asyncio async def test_cancel_cascades_to_children(): # noqa: RUF029 """Cancel propagates to all children.""" - parent = EventFuture.create() - child1 = EventFuture.create() - child2 = EventFuture.create() + parent = EventFuture(txid="parent") + child1 = EventFuture(txid="child1") + child2 = EventFuture(txid="child2") parent.add_child(child1) parent.add_child(child2) @@ -252,9 +274,9 @@ async def test_cancel_cascades_to_children(): # noqa: RUF029 @pytest.mark.asyncio async def test_cancel_cascades_to_grandchildren(): # noqa: RUF029 """Cancel propagates through the full descendant tree.""" - root = EventFuture.create() - child = EventFuture.create() - grandchild = EventFuture.create() + root = EventFuture(txid="root") + child = EventFuture(txid="child") + grandchild = EventFuture(txid="grandchild") root.add_child(child) child.add_child(grandchild) @@ -265,8 +287,8 @@ async def test_cancel_cascades_to_grandchildren(): # noqa: RUF029 @pytest.mark.asyncio async def test_cancel_with_message(): # noqa: RUF029 """Cancel passes the message to children.""" - parent = EventFuture.create() - child = EventFuture.create() + parent = EventFuture(txid="parent") + child = EventFuture(txid="child") parent.add_child(child) parent.cancel("shutting down") @@ -278,11 +300,103 @@ async def test_cancel_with_message(): # noqa: RUF029 child.result() +@pytest.mark.asyncio +async def test_wait_for_predecessor_first_child_is_noop(): + """wait_for_predecessor is a no-op for the first child in the parent.""" + parent = EventFuture(txid="parent") + child = EventFuture(txid="c0", parent=parent) + parent.add_child(child) + + async with assert_no_loop_yield(): + await child.wait_for_predecessor() + + +@pytest.mark.asyncio +async def test_wait_for_predecessor_no_parent_is_noop(): + """wait_for_predecessor is a no-op when the future has no parent.""" + f = EventFuture(txid="root") + + async with assert_no_loop_yield(): + await f.wait_for_predecessor() + + +@pytest.mark.asyncio +async def test_wait_for_predecessor_not_sequential_is_noop(): + """wait_for_predecessor is a no-op for non-sequential (background) futures.""" + parent = EventFuture(txid="parent") + sib = EventFuture(txid="sib", parent=parent) + bg = EventFuture(txid="bg", parent=parent, sequential=False) + parent.add_child(sib) + parent.add_child(bg) + + # bg is not sequential so it should not wait for sib. + async with assert_no_loop_yield(): + await bg.wait_for_predecessor() + assert not sib.done() # sib was never resolved; bg did not wait for it + + +@pytest.mark.asyncio +async def test_wait_for_predecessor_waits_for_previous_sibling(): + """wait_for_predecessor waits until the preceding sibling is done.""" + parent = EventFuture(txid="parent") + first = EventFuture(txid="first", parent=parent) + second = EventFuture(txid="second", parent=parent) + parent.add_child(first) + parent.add_child(second) + + resolved = [] + + async def _run_second(): + await second.wait_for_predecessor() + resolved.append("second") + + task = asyncio.create_task(_run_second()) + await asyncio.sleep(0) # let the task start and block + assert resolved == [] # second is still waiting + + first.set_result(None) + await task + assert resolved == ["second"] + + +@pytest.mark.asyncio +async def test_wait_for_predecessor_continues_after_sibling_exception(): + """wait_for_predecessor continues even if the preceding sibling raised.""" + parent = EventFuture(txid="parent") + first = EventFuture(txid="first", parent=parent) + second = EventFuture(txid="second", parent=parent) + parent.add_child(first) + parent.add_child(second) + + first.set_exception(RuntimeError("boom")) + + # Should not raise; exception is suppressed. + # first was already done, so wait_for_predecessor returned without suspending. + async with assert_no_loop_yield(): + await second.wait_for_predecessor() + + +@pytest.mark.asyncio +async def test_wait_for_predecessor_continues_after_sibling_cancel(): + """wait_for_predecessor continues even if the preceding sibling was cancelled.""" + parent = EventFuture(txid="parent") + first = EventFuture(txid="first", parent=parent) + second = EventFuture(txid="second", parent=parent) + parent.add_child(first) + parent.add_child(second) + + first.cancel() + + # first was already done (cancelled), so wait_for_predecessor returned without suspending. + async with assert_no_loop_yield(): + await second.wait_for_predecessor() + + @pytest.mark.asyncio async def test_cancel_already_done_child(): # noqa: RUF029 """Cancel on a parent does not fail if a child is already resolved.""" - parent = EventFuture.create() - child = EventFuture.create() + parent = EventFuture(txid="parent") + child = EventFuture(txid="child") parent.add_child(child) child.set_result("already done") @@ -296,6 +410,6 @@ async def test_cancel_already_done_child(): # noqa: RUF029 @pytest.mark.asyncio async def test_cancel_already_done_parent_returns_false(): # noqa: RUF029 """Cancel returns False if the parent is already resolved.""" - f = EventFuture.create() + f = EventFuture(txid="f") f.set_result(None) assert not f.cancel() diff --git a/tests/units/test_state.py b/tests/units/test_state.py index 14a2447687c..35124c9770d 100644 --- a/tests/units/test_state.py +++ b/tests/units/test_state.py @@ -2937,18 +2937,21 @@ class OnLoadState2(State): num: int = 0 name: str + @rx.event def test_handler(self): """Test handler that calls another handler. - Returns: - Chain of EventHandlers + Yields: + EventHandler to change name. """ self.num += 1 - return type(self).change_name + yield type(self).change_name + yield type(self).change_name("other") - def change_name(self): + @rx.event + def change_name(self, name: str = "default"): """Test handler to change name.""" - self.name = "random" + self.name = name class OnLoadState3(State): @@ -2976,8 +2979,9 @@ async def test_handler(self): OnLoadState2, [ {OnLoadState2.get_full_name(): {"num" + FIELD_MARKER: 1}}, + {OnLoadState2.get_full_name(): {"name" + FIELD_MARKER: "default"}}, exp_is_hydrated(State, True), - {OnLoadState2.get_full_name(): {"name" + FIELD_MARKER: "random"}}, + {OnLoadState2.get_full_name(): {"name" + FIELD_MARKER: "other"}}, ], ), ( @@ -3026,19 +3030,18 @@ def index(): ) async with mock_base_state_event_processor as processor: - await ( - await processor.enqueue( - token, - Event( - name=on_load_internal_name, - router_data={ - RouteVar.PATH: "/", - RouteVar.ORIGIN: "/", - RouteVar.QUERY: {}, - }, - ), - ) + on_load_future = await processor.enqueue( + token, + Event( + name=on_load_internal_name, + router_data={ + RouteVar.PATH: "/", + RouteVar.ORIGIN: "/", + RouteVar.QUERY: {}, + }, + ), ) + await on_load_future.wait_all() # The processor chains all events: on_load_internal sets is_hydrated=False, # then the on_load handler runs, then set_is_hydrated(True) runs. From 6fc9d6ad4024b76e4e9608473eab976181cfd386 Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Thu, 2 Apr 2026 17:26:43 -0700 Subject: [PATCH 60/81] use py3.11 compatible super() for dataclasses with slots --- .../src/reflex_core/_internal/event/processor/future.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/reflex-core/src/reflex_core/_internal/event/processor/future.py b/packages/reflex-core/src/reflex_core/_internal/event/processor/future.py index 974f9491843..b54ac2fdb25 100644 --- a/packages/reflex-core/src/reflex_core/_internal/event/processor/future.py +++ b/packages/reflex-core/src/reflex_core/_internal/event/processor/future.py @@ -37,7 +37,7 @@ class EventFuture(asyncio.Future): ) def __post_init__(self) -> None: - super().__init__(loop=self.loop) + super(EventFuture, self).__init__(loop=self.loop) def add_child(self, child: EventFuture) -> None: """Add a child future to this tracked future. From 1a04463734dab6b415f3612e7bb7bb58d2dae049 Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Fri, 3 Apr 2026 09:11:38 -0700 Subject: [PATCH 61/81] Only process one non-backend event per token --- .../event/processor/event_processor.py | 136 +++++++++++++----- .../_internal/event/processor/future.py | 26 ---- tests/integration/test_event_chain.py | 56 +++----- .../event/processor/test_event_processor.py | 31 ---- .../_internal/event/processor/test_future.py | 114 --------------- tests/units/test_state.py | 2 +- 6 files changed, 124 insertions(+), 241 deletions(-) diff --git a/packages/reflex-core/src/reflex_core/_internal/event/processor/event_processor.py b/packages/reflex-core/src/reflex_core/_internal/event/processor/event_processor.py index 0220bff9615..40998bfa5b8 100644 --- a/packages/reflex-core/src/reflex_core/_internal/event/processor/event_processor.py +++ b/packages/reflex-core/src/reflex_core/_internal/event/processor/event_processor.py @@ -3,6 +3,7 @@ from __future__ import annotations import asyncio +import collections import contextlib import dataclasses import inspect @@ -84,6 +85,10 @@ class EventProcessor: _futures: dict[str, EventFuture] = dataclasses.field( default_factory=dict, init=False ) + _token_queues: dict[ + str, + collections.deque[tuple[EventQueueEntry, RegisteredEventHandler]], + ] = dataclasses.field(default_factory=dict, init=False) def configure( self, @@ -269,6 +274,8 @@ async def stop(self, graceful_shutdown_timeout: float | None = None) -> None: ) ) self._queue_task = None + # Discard any pending per-token queue entries. + self._token_queues.clear() # Cancel any remaining unresolved futures. for future in self._futures.values(): if not future.done(): @@ -341,23 +348,13 @@ async def enqueue( msg = "Event processor is not running, call .start(...) first." raise RuntimeError(msg) from le queue = self._ensure_queue_task() - # Determine whether sequential ordering is required for this event. - is_background = False - try: - registered = RegistrationContext.get().event_handlers.get(event.name) - if registered is not None and registered.handler.is_background: - is_background = True - except LookupError: - pass txid = ev_ctx.txid parent_future = ( self._futures.get(ev_ctx.parent_txid) if ev_ctx.parent_txid is not None else None ) - tracked = EventFuture( - sequential=not is_background, parent=parent_future, txid=txid - ) + tracked = EventFuture(parent=parent_future, txid=txid) self._futures[txid] = tracked tracked.add_done_callback(self._try_clean_future) tracked.add_done_callback(self._on_future_done) @@ -455,7 +452,7 @@ def _try_clean_future(self, future: EventFuture) -> None: # type: ignore[overri done and all its immediate children are done, pop the parent as well. This keeps parent futures alive in ``_futures`` while any child still - needs to look up its siblings for sequential ordering. + needs them for ``wait_all`` and cleanup. Args: future: The EventFuture to check. @@ -532,17 +529,82 @@ async def _process_event_queue_entry( registered_handler: The registered handler for the event. """ # Set up the event context for this task. - ctx = entry.ctx - EventContext.set(ctx) - # For sequential (non-background) events, wait for the previous sibling - # to finish before proceeding so that chained events run in order. - if current_future := self._futures.get(ctx.txid): - await current_future.wait_for_predecessor() - print( - f"Processing event {entry.event} [txid={ctx.txid}] with handler {registered_handler.handler.fn.__name__}" - ) + EventContext.set(entry.ctx) await self._execute_event(entry=entry, registered_handler=registered_handler) + def _create_event_task( + self, + *, + entry: EventQueueEntry, + registered_handler: RegisteredEventHandler, + ) -> asyncio.Task: + """Create and register an asyncio task for processing a single event. + + Args: + entry: The event queue entry to process. + registered_handler: The registered handler for the event. + + Returns: + The created asyncio.Task. + """ + task = asyncio.create_task( + self._process_event_queue_entry( + entry=entry, registered_handler=registered_handler + ), + name=f"reflex_event|{entry.event.name}|{entry.ctx.token}|{time.time()}", + ) + if sys.version_info < (3, 12): + task._event_ctx = entry.ctx # pyright: ignore[reportAttributeAccessIssue] + self._tasks[entry.ctx.txid] = task + task.add_done_callback(self._finish_task) + return task + + def _enqueue_for_token( + self, + *, + entry: EventQueueEntry, + registered_handler: RegisteredEventHandler, + ) -> None: + """Append an event to the per-token queue and dispatch if idle. + + If no queue exists for the token yet, one is created. If this is + the first (and therefore only) entry, a task is dispatched + immediately. + + Args: + entry: The event queue entry to enqueue. + registered_handler: The registered handler for the event. + """ + token = entry.ctx.token + token_queue = self._token_queues.get(token) + if token_queue is None: + token_queue = self._token_queues[token] = collections.deque() + token_queue.append((entry, registered_handler)) + if len(token_queue) == 1: + self._dispatch_next_for_token(token) + + def _dispatch_next_for_token(self, token: str) -> None: + """Create a task for the front entry in the per-token queue. + + Args: + token: The client token whose queue to dispatch from. + """ + token_queue = self._token_queues.get(token) + if not token_queue: + return + entry, registered_handler = token_queue[0] + # Skip cancelled futures. + future = self._futures.get(entry.ctx.txid) + if future is not None and future.cancelled(): + self._try_clean_future(future) + token_queue.popleft() + if token_queue: + self._dispatch_next_for_token(token) + else: + del self._token_queues[token] + return + self._create_event_task(entry=entry, registered_handler=registered_handler) + async def _process_queue(self): """Process events from the queue in a task.""" if (queue := self._queue) is None: @@ -567,22 +629,16 @@ async def _process_queue(self): f"No registered handler found for event: {entry.event.name}" ) raise KeyError(msg) from ke - # Create a new task to handle this event. - task = asyncio.create_task( - self._process_event_queue_entry( + if registered_handler.handler.is_background: + # Background events run immediately, bypassing per-token ordering. + self._create_event_task( entry=entry, registered_handler=registered_handler - ), - name=( - f"reflex_event|{entry.event.name}|{entry.ctx.token}|{time.time()}" - ), - ) - if sys.version_info < (3, 12): - task._event_ctx = entry.ctx # pyright: ignore[reportAttributeAccessIssue] - self._tasks[entry.ctx.txid] = task - task.add_done_callback(self._finish_task) - print( - f"Enqueued task {task.get_name()} for event {entry.event.name} [txid={entry.ctx.txid}]" - ) + ) + else: + # Sequential events go through the per-token queue. + self._enqueue_for_token( + entry=entry, registered_handler=registered_handler + ) except Exception: # Log the error and continue processing the next events. console.error( @@ -626,6 +682,14 @@ def _finish_task(self, task: asyncio.Task): else: task_ctx = task.get_context().run(EventContext.get) self._tasks.pop(task_ctx.txid, None) + # Chain the next sequential event for this token if applicable. + token_queue = self._token_queues.get(task_ctx.token) + if token_queue and token_queue[0][0].ctx.txid == task_ctx.txid: + token_queue.popleft() + if token_queue: + self._dispatch_next_for_token(task_ctx.token) + else: + del self._token_queues[task_ctx.token] future = self._futures.get(task_ctx.txid) if task.done(): try: diff --git a/packages/reflex-core/src/reflex_core/_internal/event/processor/future.py b/packages/reflex-core/src/reflex_core/_internal/event/processor/future.py index b54ac2fdb25..cc3a57a302f 100644 --- a/packages/reflex-core/src/reflex_core/_internal/event/processor/future.py +++ b/packages/reflex-core/src/reflex_core/_internal/event/processor/future.py @@ -19,11 +19,6 @@ class EventFuture(asyncio.Future): # The transaction id associated with this future. txid: str - # If sequential is True, sibling events will be processed sequentially. - # Background events should set sequential=False to run immediately and - # without affecting non-background event ordering. - sequential: bool = True - # Child futures spawned by this future, if any. children: list[EventFuture] = dataclasses.field(default_factory=list) @@ -98,27 +93,6 @@ def cancel(self, msg: object = None) -> bool: child.cancel(msg) return result - async def wait_for_predecessor(self) -> None: - """Wait for the immediately preceding sequential sibling to complete. - - If this future is not sequential, has no parent, is the first child, - or cannot be found in the parent's children list, this is a no-op. - """ - if not self.sequential or self.parent is None: - return - children = self.parent.children - try: - idx = children.index(self) - except ValueError: - return - if idx == 0: - return # First child: no predecessor to wait for. - with contextlib.suppress(Exception, asyncio.CancelledError): - print( - f"EventFuture {self.txid} waiting for predecessor {children[idx - 1].txid}" - ) - await children[idx - 1] - __all__ = [ "EventFuture", diff --git a/tests/integration/test_event_chain.py b/tests/integration/test_event_chain.py index eb65c2fe1d2..4dcd6c322e6 100644 --- a/tests/integration/test_event_chain.py +++ b/tests/integration/test_event_chain.py @@ -536,13 +536,11 @@ def test_event_chain_on_load( "/on-mount-yield-chain", [ "on_load_yield_chain", - { - "event_arg:4", - "event_arg:5", - "event_arg:6", - "event_arg:mount", - "event_no_args", - }, + "event_arg:mount", + "event_arg:4", + "event_arg:5", + "event_arg:6", + "event_no_args", ], ), ], @@ -585,18 +583,14 @@ def test_event_chain_on_mount( "/on-mount-return-chain", [ "on_load_return_chain", - { - "event_arg:unmount", - "on_load_return_chain", - "event_arg:1", - "event_arg:2", - "event_arg:3", - }, - { - "event_arg:1", - "event_arg:2", - "event_arg:3", - }, + "event_arg:unmount", + "on_load_return_chain", + "event_arg:1", + "event_arg:2", + "event_arg:3", + "event_arg:1", + "event_arg:2", + "event_arg:3", "event_arg:unmount", ], ), @@ -604,20 +598,16 @@ def test_event_chain_on_mount( "/on-mount-yield-chain", [ "on_load_yield_chain", - { - "event_arg:4", - "event_arg:5", - "event_arg:6", - "event_arg:mount", - "event_no_args", - "on_load_yield_chain", - }, - { - "event_arg:mount", - "event_arg:4", - "event_arg:5", - "event_arg:6", - }, + "event_arg:mount", + "event_no_args", + "on_load_yield_chain", + "event_arg:mount", + "event_arg:4", + "event_arg:5", + "event_arg:6", + "event_arg:4", + "event_arg:5", + "event_arg:6", "event_no_args", ], ), diff --git a/tests/units/reflex_core/_internal/event/processor/test_event_processor.py b/tests/units/reflex_core/_internal/event/processor/test_event_processor.py index 2d9adefbf7c..d3a158952fa 100644 --- a/tests/units/reflex_core/_internal/event/processor/test_event_processor.py +++ b/tests/units/reflex_core/_internal/event/processor/test_event_processor.py @@ -10,7 +10,6 @@ EventProcessor, QueueShutDown, ) -from reflex_core._internal.event.processor.future import EventFuture from reflex_core._internal.registry import RegistrationContext from reflex.event import Event, EventHandler @@ -538,36 +537,6 @@ async def test_sequential_chained_events_run_in_order(token: str): assert [entry["value"] for entry in _CALL_LOG] == ["first", "second", "third"] -async def test_sequential_chained_futures_are_sequential(token: str): - """EventFutures created for normal (non-background) events have sequential=True. - - Args: - token: The client token. - """ - ep = EventProcessor(graceful_shutdown_timeout=2) - ep.configure() - async with ep: - future = await ep.enqueue(token, Event.from_event_type(logging_event())[0]) - assert isinstance(future, EventFuture) - assert future.sequential is True - - -async def test_background_event_future_is_not_sequential(token: str): - """EventFutures created for background events have sequential=False. - - Args: - token: The client token. - """ - ep = EventProcessor(graceful_shutdown_timeout=2) - ep.configure() - async with ep: - future = await ep.enqueue( - token, Event.from_event_type(background_slow_logging_event())[0] - ) - assert isinstance(future, EventFuture) - assert future.sequential is False - - async def test_futures_cleaned_up_after_chained_events(token: str): """All futures are removed from _futures after chained events complete. diff --git a/tests/units/reflex_core/_internal/event/processor/test_future.py b/tests/units/reflex_core/_internal/event/processor/test_future.py index 2cde70feef5..225baab126f 100644 --- a/tests/units/reflex_core/_internal/event/processor/test_future.py +++ b/tests/units/reflex_core/_internal/event/processor/test_future.py @@ -1,33 +1,11 @@ """Tests for EventFuture.""" import asyncio -import contextlib -from collections.abc import AsyncGenerator import pytest from reflex_core._internal.event.processor.future import EventFuture -@contextlib.asynccontextmanager -async def assert_no_loop_yield() -> AsyncGenerator[None, None]: - """Assert the body never yields control to the event loop. - - Schedules a sentinel task before the body runs and asserts it has not - executed by the time the body returns. Because asyncio is cooperative, - the sentinel can only run if the body awaited something that suspended it. - """ - sentinel_ran = False - - async def _sentinel() -> None: # noqa: RUF029 - nonlocal sentinel_ran - sentinel_ran = True - - task = asyncio.create_task(_sentinel()) - yield - assert not sentinel_ran, "Event loop was unexpectedly yielded to" - await task - - @pytest.mark.asyncio async def test_create_uses_running_loop(): # noqa: RUF029 """EventFuture() defaults to the running event loop.""" @@ -300,98 +278,6 @@ async def test_cancel_with_message(): # noqa: RUF029 child.result() -@pytest.mark.asyncio -async def test_wait_for_predecessor_first_child_is_noop(): - """wait_for_predecessor is a no-op for the first child in the parent.""" - parent = EventFuture(txid="parent") - child = EventFuture(txid="c0", parent=parent) - parent.add_child(child) - - async with assert_no_loop_yield(): - await child.wait_for_predecessor() - - -@pytest.mark.asyncio -async def test_wait_for_predecessor_no_parent_is_noop(): - """wait_for_predecessor is a no-op when the future has no parent.""" - f = EventFuture(txid="root") - - async with assert_no_loop_yield(): - await f.wait_for_predecessor() - - -@pytest.mark.asyncio -async def test_wait_for_predecessor_not_sequential_is_noop(): - """wait_for_predecessor is a no-op for non-sequential (background) futures.""" - parent = EventFuture(txid="parent") - sib = EventFuture(txid="sib", parent=parent) - bg = EventFuture(txid="bg", parent=parent, sequential=False) - parent.add_child(sib) - parent.add_child(bg) - - # bg is not sequential so it should not wait for sib. - async with assert_no_loop_yield(): - await bg.wait_for_predecessor() - assert not sib.done() # sib was never resolved; bg did not wait for it - - -@pytest.mark.asyncio -async def test_wait_for_predecessor_waits_for_previous_sibling(): - """wait_for_predecessor waits until the preceding sibling is done.""" - parent = EventFuture(txid="parent") - first = EventFuture(txid="first", parent=parent) - second = EventFuture(txid="second", parent=parent) - parent.add_child(first) - parent.add_child(second) - - resolved = [] - - async def _run_second(): - await second.wait_for_predecessor() - resolved.append("second") - - task = asyncio.create_task(_run_second()) - await asyncio.sleep(0) # let the task start and block - assert resolved == [] # second is still waiting - - first.set_result(None) - await task - assert resolved == ["second"] - - -@pytest.mark.asyncio -async def test_wait_for_predecessor_continues_after_sibling_exception(): - """wait_for_predecessor continues even if the preceding sibling raised.""" - parent = EventFuture(txid="parent") - first = EventFuture(txid="first", parent=parent) - second = EventFuture(txid="second", parent=parent) - parent.add_child(first) - parent.add_child(second) - - first.set_exception(RuntimeError("boom")) - - # Should not raise; exception is suppressed. - # first was already done, so wait_for_predecessor returned without suspending. - async with assert_no_loop_yield(): - await second.wait_for_predecessor() - - -@pytest.mark.asyncio -async def test_wait_for_predecessor_continues_after_sibling_cancel(): - """wait_for_predecessor continues even if the preceding sibling was cancelled.""" - parent = EventFuture(txid="parent") - first = EventFuture(txid="first", parent=parent) - second = EventFuture(txid="second", parent=parent) - parent.add_child(first) - parent.add_child(second) - - first.cancel() - - # first was already done (cancelled), so wait_for_predecessor returned without suspending. - async with assert_no_loop_yield(): - await second.wait_for_predecessor() - - @pytest.mark.asyncio async def test_cancel_already_done_child(): # noqa: RUF029 """Cancel on a parent does not fail if a child is already resolved.""" diff --git a/tests/units/test_state.py b/tests/units/test_state.py index 35124c9770d..5382b1f6a35 100644 --- a/tests/units/test_state.py +++ b/tests/units/test_state.py @@ -2979,8 +2979,8 @@ async def test_handler(self): OnLoadState2, [ {OnLoadState2.get_full_name(): {"num" + FIELD_MARKER: 1}}, - {OnLoadState2.get_full_name(): {"name" + FIELD_MARKER: "default"}}, exp_is_hydrated(State, True), + {OnLoadState2.get_full_name(): {"name" + FIELD_MARKER: "default"}}, {OnLoadState2.get_full_name(): {"name" + FIELD_MARKER: "other"}}, ], ), From 73e2bbed4461320dd283feb82dab780c57a66b5f Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Fri, 3 Apr 2026 09:44:18 -0700 Subject: [PATCH 62/81] Fix event order assertions in test_event_chain --- tests/integration/test_event_chain.py | 135 ++++++++++++++++++-------- tests/integration/utils.py | 80 +++++++++++++++ 2 files changed, 173 insertions(+), 42 deletions(-) diff --git a/tests/integration/test_event_chain.py b/tests/integration/test_event_chain.py index 4dcd6c322e6..942b70994e6 100644 --- a/tests/integration/test_event_chain.py +++ b/tests/integration/test_event_chain.py @@ -9,7 +9,10 @@ from selenium.webdriver.common.by import By from reflex.testing import AppHarness, WebDriver -from tests.integration.utils import poll_assert_event_order +from tests.integration.utils import ( + poll_assert_event_order, + poll_assert_relative_event_order, +) MANY_EVENTS = 50 @@ -520,27 +523,45 @@ def test_event_chain_on_load( @pytest.mark.parametrize( - ("uri", "exp_event_order"), + ("uri", "expected_counts", "ordering_rules"), [ ( "/on-mount-return-chain", + { + "on_load_return_chain": 1, + "event_arg:1": 1, + "event_arg:2": 1, + "event_arg:3": 1, + "event_arg:unmount": 1, + }, [ - "on_load_return_chain", - "event_arg:1", - "event_arg:2", - "event_arg:3", - "event_arg:unmount", + # on_load before chain and unmount + (("on_load_return_chain", 0), ("event_arg:1", 0)), + (("on_load_return_chain", 0), ("event_arg:unmount", 0)), + # Chain in order + (("event_arg:1", 0), ("event_arg:2", 0)), + (("event_arg:2", 0), ("event_arg:3", 0)), ], ), ( "/on-mount-yield-chain", + { + "on_load_yield_chain": 1, + "event_arg:4": 1, + "event_arg:5": 1, + "event_arg:6": 1, + "event_arg:mount": 1, + "event_no_args": 1, + }, [ - "on_load_yield_chain", - "event_arg:mount", - "event_arg:4", - "event_arg:5", - "event_arg:6", - "event_no_args", + # on_load before chain and mount + (("on_load_yield_chain", 0), ("event_arg:4", 0)), + (("on_load_yield_chain", 0), ("event_arg:mount", 0)), + # Chain in order + (("event_arg:4", 0), ("event_arg:5", 0)), + (("event_arg:5", 0), ("event_arg:6", 0)), + # mount before event_no_args + (("event_arg:mount", 0), ("event_no_args", 0)), ], ), ], @@ -549,7 +570,8 @@ def test_event_chain_on_mount( event_chain: AppHarness, driver: WebDriver, uri: str, - exp_event_order: list[str | set[str]], + expected_counts: dict[str, int], + ordering_rules: list, ): """Load the URI, assert that the events are handled in the correct order. @@ -562,7 +584,8 @@ def test_event_chain_on_mount( event_chain: AppHarness for the event_chain app driver: selenium WebDriver open to the app uri: the page to load - exp_event_order: the expected events recorded in the State + expected_counts: mapping of event name to expected occurrence count + ordering_rules: relative ordering constraints between event occurrences """ assert event_chain.frontend_url is not None driver.get(event_chain.frontend_url.removesuffix("/") + uri) @@ -573,42 +596,67 @@ def test_event_chain_on_mount( assert_token(event_chain, driver) unmount_button.click() - poll_assert_event_order(driver, exp_event_order) + poll_assert_relative_event_order(driver, expected_counts, ordering_rules) @pytest.mark.parametrize( - ("uri", "exp_event_order"), + ("uri", "expected_counts", "ordering_rules"), [ ( "/on-mount-return-chain", + { + "on_load_return_chain": 2, + "event_arg:1": 2, + "event_arg:2": 2, + "event_arg:3": 2, + "event_arg:unmount": 2, + }, [ - "on_load_return_chain", - "event_arg:unmount", - "on_load_return_chain", - "event_arg:1", - "event_arg:2", - "event_arg:3", - "event_arg:1", - "event_arg:2", - "event_arg:3", - "event_arg:unmount", + # First on_load before first chain and first unmount + (("on_load_return_chain", 0), ("event_arg:1", 0)), + (("on_load_return_chain", 0), ("event_arg:unmount", 0)), + # First chain in order + (("event_arg:1", 0), ("event_arg:2", 0)), + (("event_arg:2", 0), ("event_arg:3", 0)), + # First unmount before second on_load + (("event_arg:unmount", 0), ("on_load_return_chain", 1)), + # Second on_load before second chain and second unmount + (("on_load_return_chain", 1), ("event_arg:1", 1)), + (("on_load_return_chain", 1), ("event_arg:unmount", 1)), + # Second chain in order + (("event_arg:1", 1), ("event_arg:2", 1)), + (("event_arg:2", 1), ("event_arg:3", 1)), ], ), ( "/on-mount-yield-chain", + { + "on_load_yield_chain": 2, + "event_arg:4": 2, + "event_arg:5": 2, + "event_arg:6": 2, + "event_arg:mount": 2, + "event_no_args": 2, + }, [ - "on_load_yield_chain", - "event_arg:mount", - "event_no_args", - "on_load_yield_chain", - "event_arg:mount", - "event_arg:4", - "event_arg:5", - "event_arg:6", - "event_arg:4", - "event_arg:5", - "event_arg:6", - "event_no_args", + # First on_load before first chain and first mount + (("on_load_yield_chain", 0), ("event_arg:4", 0)), + (("on_load_yield_chain", 0), ("event_arg:mount", 0)), + # First chain in order + (("event_arg:4", 0), ("event_arg:5", 0)), + (("event_arg:5", 0), ("event_arg:6", 0)), + # First mount before first event_no_args + (("event_arg:mount", 0), ("event_no_args", 0)), + # First event_no_args before second on_load + (("event_no_args", 0), ("on_load_yield_chain", 1)), + # Second on_load before second chain and second mount + (("on_load_yield_chain", 1), ("event_arg:4", 1)), + (("on_load_yield_chain", 1), ("event_arg:mount", 1)), + # Second chain in order + (("event_arg:4", 1), ("event_arg:5", 1)), + (("event_arg:5", 1), ("event_arg:6", 1)), + # Second mount before second event_no_args + (("event_arg:mount", 1), ("event_no_args", 1)), ], ), ], @@ -617,7 +665,8 @@ def test_event_chain_on_mount_strict( event_chain_strict: AppHarness, driver_strict: WebDriver, uri: str, - exp_event_order: list[str | set[str]], + expected_counts: dict[str, int], + ordering_rules: list, ): """Run the test_event_chain_on_mount test with strict mode enabled. @@ -625,13 +674,15 @@ def test_event_chain_on_mount_strict( event_chain_strict: AppHarness for the event_chain app with strict mode enabled driver_strict: selenium WebDriver open to the app with strict mode enabled uri: the page to load - exp_event_order: the expected events recorded in the State + expected_counts: mapping of event name to expected occurrence count + ordering_rules: relative ordering constraints between event occurrences """ test_event_chain_on_mount( event_chain=event_chain_strict, driver=driver_strict, uri=uri, - exp_event_order=exp_event_order, + expected_counts=expected_counts, + ordering_rules=ordering_rules, ) diff --git a/tests/integration/utils.py b/tests/integration/utils.py index 5c851f94b65..dd70cb4ad71 100644 --- a/tests/integration/utils.py +++ b/tests/integration/utils.py @@ -108,6 +108,86 @@ def _has_number_of_expected_events(): assert_event_order([elem.text for elem in event_elements], exp_event_order) +# Type alias for an ordering rule: ((event_a, occurrence_a), (event_b, occurrence_b)). +OrderingRule = tuple[tuple[str, int], tuple[str, int]] + + +def assert_relative_event_order( + actual: list[str], + expected_counts: dict[str, int], + ordering_rules: list[OrderingRule], +) -> None: + """Assert that events satisfy relative ordering constraints. + + Instead of requiring an exact event sequence, this checks that: + 1. Each event appears the expected number of times. + 2. Specific occurrences of events appear before other specific occurrences. + + Args: + actual: the actual events recorded. + expected_counts: mapping of event name to expected occurrence count. + ordering_rules: list of ((event_a, occ_a), (event_b, occ_b)) meaning + the occ_a-th occurrence (0-indexed) of event_a must appear before + the occ_b-th occurrence (0-indexed) of event_b in the actual list. + + Raises: + AssertionError: if any constraint is violated. + """ + from collections import Counter + + actual_counts = Counter(actual) + for event, count in expected_counts.items(): + assert actual_counts[event] == count, ( + f"Expected {count} occurrences of '{event}', got {actual_counts[event]}. Actual: {actual}" + ) + assert sum(expected_counts.values()) == len(actual), ( + f"Expected {sum(expected_counts.values())} total events, got {len(actual)}. Actual: {actual}" + ) + + # Build occurrence index: (event, occ) -> position in actual list + occurrence_indices: dict[tuple[str, int], int] = {} + event_counters: dict[str, int] = {} + for i, event in enumerate(actual): + occ = event_counters.get(event, 0) + occurrence_indices[event, occ] = i + event_counters[event] = occ + 1 + + for (event_a, occ_a), (event_b, occ_b) in ordering_rules: + idx_a = occurrence_indices[event_a, occ_a] + idx_b = occurrence_indices[event_b, occ_b] + assert idx_a < idx_b, ( + f"Expected '{event_a}'[{occ_a}] (pos {idx_a}) before " + f"'{event_b}'[{occ_b}] (pos {idx_b}). Actual: {actual}" + ) + + +def poll_assert_relative_event_order( + driver: WebDriver, + expected_counts: dict[str, int], + ordering_rules: list[OrderingRule], + xpath: str = '//*[@id="event_order"]/p', +) -> None: + """Poll until the expected number of events appear, then assert relative ordering. + + Args: + driver: WebDriver instance. + expected_counts: mapping of event name to expected occurrence count. + ordering_rules: ordering constraints (see assert_relative_event_order). + xpath: The XPath to the event order elements. + """ + n_exp = sum(expected_counts.values()) + + def _has_number_of_expected_events(): + return len(driver.find_elements(By.XPATH, xpath)) == n_exp + + AppHarness._poll_for(_has_number_of_expected_events) + + event_elements = driver.find_elements(By.XPATH, xpath) + assert_relative_event_order( + [elem.text for elem in event_elements], expected_counts, ordering_rules + ) + + class LocalStorage: """Class to access local storage. From 5b056ad6af84ad76b805136608281b51771fee26 Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Fri, 3 Apr 2026 09:48:28 -0700 Subject: [PATCH 63/81] py3.11 super() fix again --- .../src/reflex_core/_internal/event/processor/future.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/reflex-core/src/reflex_core/_internal/event/processor/future.py b/packages/reflex-core/src/reflex_core/_internal/event/processor/future.py index cc3a57a302f..dd76f485800 100644 --- a/packages/reflex-core/src/reflex_core/_internal/event/processor/future.py +++ b/packages/reflex-core/src/reflex_core/_internal/event/processor/future.py @@ -88,7 +88,7 @@ def cancel(self, msg: object = None) -> bool: Returns: True if the future was successfully cancelled. """ - result = super().cancel(msg) + result = super(EventFuture, self).cancel(msg) for child in self.children: child.cancel(msg) return result From fe0bc5b179c916b9904fc2e357ae7d95a40a5914 Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Fri, 3 Apr 2026 14:48:17 -0700 Subject: [PATCH 64/81] Move _registration_context_middleware to top of asgi stack And don't register it as middleware, because it's not middleware technically (even though it worked the way it was). --- reflex/app.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/reflex/app.py b/reflex/app.py index a5c3826973a..e1d6958324e 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -652,8 +652,6 @@ def __call__(self) -> ASGIApp: raise ValueError(msg) asgi_app = self._api - # Make sure the RegistrationContext is attached. - asgi_app.add_middleware(self._registration_context_middleware) if environment.REFLEX_MOUNT_FRONTEND_COMPILED_APP.get(): asgi_app.mount( @@ -687,8 +685,8 @@ def __call__(self) -> ASGIApp: top_asgi_app = Starlette(lifespan=self._run_lifespan_tasks) top_asgi_app.mount("", asgi_app) App._add_cors(top_asgi_app) - - return top_asgi_app + # Make sure the RegistrationContext is attached. + return self._registration_context_middleware(top_asgi_app) def _add_default_endpoints(self): """Add default api endpoints (ping).""" From f6dc0026c00e9f5f4106c2c480ed34dc5944a4ec Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Fri, 3 Apr 2026 15:19:03 -0700 Subject: [PATCH 65/81] Move registration context middle to not quite the top level app. The returned top level app should always continue to be a Starlette instance for compatibility. --- reflex/app.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/reflex/app.py b/reflex/app.py index e1d6958324e..31710d242f9 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -683,10 +683,13 @@ def __call__(self) -> ASGIApp: asgi_app = api_transformer(asgi_app) top_asgi_app = Starlette(lifespan=self._run_lifespan_tasks) - top_asgi_app.mount("", asgi_app) - App._add_cors(top_asgi_app) # Make sure the RegistrationContext is attached. - return self._registration_context_middleware(top_asgi_app) + top_asgi_app.mount( + "", + self._registration_context_middleware(asgi_app), + ) + App._add_cors(top_asgi_app) + return top_asgi_app def _add_default_endpoints(self): """Add default api endpoints (ping).""" From 3a0f532d9eb39877eb21187781f17b2b5e906bc4 Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Fri, 3 Apr 2026 15:19:57 -0700 Subject: [PATCH 66/81] Add cache_key and lock_key attributes to StateToken --- reflex/istate/manager/memory.py | 41 ++++++++++++++++++--------------- reflex/istate/manager/token.py | 30 ++++++++++++++++++++++++ 2 files changed, 52 insertions(+), 19 deletions(-) diff --git a/reflex/istate/manager/memory.py b/reflex/istate/manager/memory.py index 456df5841b9..f3dcb50ba53 100644 --- a/reflex/istate/manager/memory.py +++ b/reflex/istate/manager/memory.py @@ -36,8 +36,8 @@ class StateManagerMemory(StateManager): init=False, ) - # The latest expiration deadline for each token. - _token_expires_at: dict[str, float] = dataclasses.field( + # The latest expiration deadline and token for each cache key. + _token_expires_at: dict[str, tuple[float, StateToken]] = dataclasses.field( default_factory=dict, init=False, ) @@ -53,7 +53,7 @@ def _get_or_create_state(self, token: StateToken[TOKEN_TYPE]) -> TOKEN_TYPE: Returns: The state for the token. """ - key = token.ident if isinstance(token, BaseStateToken) else str(token) + key = token.cache_key if key not in self.states: if isinstance(token, BaseStateToken): self.states[key] = token.cls.get_root_state()( @@ -65,15 +65,21 @@ def _get_or_create_state(self, token: StateToken[TOKEN_TYPE]) -> TOKEN_TYPE: def _track_token(self, token: StateToken): """Refresh the expiration deadline for an active token.""" - self._token_expires_at[token.ident] = time.time() + self.token_expiration + self._token_expires_at[token.cache_key] = ( + time.time() + self.token_expiration, + token, + ) self._ensure_expiration_task() def _purge_token(self, token: StateToken): - """Remove a token from in-memory state bookkeeping.""" - key = token.ident if isinstance(token, BaseStateToken) else str(token) - self._token_expires_at.pop(token.ident, None) - self.states.pop(key, None) - self._states_locks.pop(token.ident, None) + """Remove a token from in-memory state bookkeeping. + + Args: + token: The token to purge. + """ + self._token_expires_at.pop(token.cache_key, None) + self._states_locks.pop(token.lock_key, None) + self.states.pop(token.cache_key, None) def _purge_expired_tokens(self) -> float | None: """Purge expired in-memory state entries and return the next deadline. @@ -86,15 +92,13 @@ def _purge_expired_tokens(self) -> float | None: token_expires_at = self._token_expires_at state_locks = self._states_locks - for token, expires_at in list(token_expires_at.items()): + for _cache_key, (expires_at, token) in list(token_expires_at.items()): if ( - state_lock := state_locks.get(token) + state_lock := state_locks.get(token.lock_key) ) is not None and state_lock.locked(): continue if expires_at <= now: - self._purge_token( - BaseStateToken(ident=token, cls=type(self.states[token])) - ) + self._purge_token(token) continue if next_expires_at is None or expires_at < next_expires_at: next_expires_at = expires_at @@ -110,12 +114,12 @@ async def _get_state_lock(self, token: StateToken) -> asyncio.Lock: Returns: The lock protecting the token's state. """ - state_lock = self._states_locks.get(token.ident) + state_lock = self._states_locks.get(token.lock_key) if state_lock is None: async with self._state_manager_lock: - state_lock = self._states_locks.get(token.ident) + state_lock = self._states_locks.get(token.lock_key) if state_lock is None: - state_lock = self._states_locks[token.ident] = asyncio.Lock() + state_lock = self._states_locks[token.lock_key] = asyncio.Lock() return state_lock async def _expire_states(self): @@ -166,8 +170,7 @@ async def set_state( state: The state to set. context: The state modification context. """ - key = token.ident if isinstance(token, BaseStateToken) else str(token) - self.states[key] = state + self.states[token.cache_key] = state self._track_token(token) @override diff --git a/reflex/istate/manager/token.py b/reflex/istate/manager/token.py index 31b3a8a80db..b376f76e6d9 100644 --- a/reflex/istate/manager/token.py +++ b/reflex/istate/manager/token.py @@ -37,6 +37,24 @@ def with_cls(self, cls: type[TOKEN_TYPE]) -> Self: """ return dataclasses.replace(self, cls=cls) + @property + def cache_key(self) -> str: + """The key used for caching state instances in the StateManager. + + Returns: + A string key combining ident and class path. + """ + return str(self) + + @property + def lock_key(self) -> str: + """The key used for locking and session-level bookkeeping. + + Returns: + The token ident. + """ + return self.ident + def __str__(self) -> str: """The key used in the underlying StateManager store. @@ -109,6 +127,18 @@ class BaseStateToken(StateToken["BaseState"]): This token type implies subtree hierarchy population and other semantic checks. """ + @property + def cache_key(self) -> str: + """The key used for caching state instances in the StateManager. + + BaseState tokens use just the ident because the entire state hierarchy + lives under a single root state instance per session. + + Returns: + The token ident. + """ + return self.ident + def with_cls(self, cls: type[BaseState]) -> Self: """Return a new token with the cls field updated to the provided class. From 46232ac279d6df30dbd5d7a58b5af56f57b11901 Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Fri, 3 Apr 2026 16:08:39 -0700 Subject: [PATCH 67/81] Use cache_key and lock_key in StateManagerDisk and StateManagerRedis --- reflex/istate/manager/disk.py | 27 +++++------- reflex/istate/manager/redis.py | 80 ++++++++++++++++------------------ 2 files changed, 50 insertions(+), 57 deletions(-) diff --git a/reflex/istate/manager/disk.py b/reflex/istate/manager/disk.py index 504b0061343..3f6c50573e2 100644 --- a/reflex/istate/manager/disk.py +++ b/reflex/istate/manager/disk.py @@ -168,12 +168,8 @@ async def get_state( Returns: The state for the token. """ - if isinstance(token, BaseStateToken): - root_state = self.states.get(token.ident) - self._token_last_touched[token.ident] = time.time() - else: - root_state = self.states.get(str(token)) - self._token_last_touched[str(token)] = time.time() + root_state = self.states.get(token.cache_key) + self._token_last_touched[token.cache_key] = time.time() if root_state is not None: # Retrieved state from memory. return root_state @@ -194,13 +190,13 @@ async def get_state( # Ensure all substates exist, even if they were not serialized previously. root_state.substates = fresh_root_state.substates await self.populate_substates(token, root_state, root_state) - self.states[token.ident] = root_state + self.states[token.cache_key] = root_state return cast(TOKEN_TYPE, root_state) # For non-BaseState tokens, if the deserialized state is None, we create a new instance using the token's cls. state = await self.load_state(token) if state is None: state = token.cls() - self.states[str(token)] = state + self.states[token.cache_key] = state return cast(TOKEN_TYPE, state) async def set_state_for_substate( @@ -275,10 +271,10 @@ async def _process_write_queue(self): token, self._write_queue.pop(token).state ) # Check for expired states to purge. - for token_ident, last_touched in list(self._token_last_touched.items()): + for cache_key, last_touched in list(self._token_last_touched.items()): if now - last_touched > self.token_expiration: - self._token_last_touched.pop(token_ident) - self.states.pop(token_ident, None) + self._token_last_touched.pop(cache_key) + self.states.pop(cache_key, None) await run_in_thread(self._purge_expired_states) await self._process_write_queue_delay() except asyncio.CancelledError: # noqa: PERF203 @@ -363,12 +359,13 @@ async def modify_state( The state for the token. """ # Disk state manager ignores the substate suffix and always returns the top-level state. - if token.ident not in self._states_locks: + lock_key = token.lock_key + if lock_key not in self._states_locks: async with self._state_manager_lock: - if token.ident not in self._states_locks: - self._states_locks[token.ident] = asyncio.Lock() + if lock_key not in self._states_locks: + self._states_locks[lock_key] = asyncio.Lock() - async with self._states_locks[token.ident]: + async with self._states_locks[lock_key]: state = await self.get_state(token) yield state await self.set_state(token, state, **context) diff --git a/reflex/istate/manager/redis.py b/reflex/istate/manager/redis.py index 8083525b0f8..3d739f92f46 100644 --- a/reflex/istate/manager/redis.py +++ b/reflex/istate/manager/redis.py @@ -362,7 +362,7 @@ async def set_state( Args: token: The token to set the state for. state: The state to set. - lock_id: If provided, the lock_key must be set to this value to set the state. + lock_id: If provided, the lock must be held with this value to set the state. context: The event context. Raises: @@ -397,9 +397,9 @@ async def set_state( base_state = cast(BaseState, state) - client_token = token.ident + lock_key = token.lock_key - if lock_id is not None and client_token not in self._local_leases: + if lock_id is not None and lock_key not in self._local_leases: time_taken = ( self.lock_expiration - (await self.redis.pttl(self._lock_key(token))) ) / 1000 @@ -424,7 +424,7 @@ async def set_state( lock_id=lock_id, **context, ), - name=f"reflex_set_state|{client_token}|{substate.get_full_name()}", + name=f"reflex_set_state|{lock_key}|{substate.get_full_name()}", ) for substate in base_state.substates.values() ] @@ -472,7 +472,7 @@ async def _try_modify_state( return # Opportunistic locking is enabled, so try to hold the lock across multiple calls. - client_token = token.ident + lock_key = token.lock_key lock_held_ctx = contextlib.AsyncExitStack() try: lock_id = await lock_held_ctx.enter_async_context( @@ -484,12 +484,12 @@ async def _try_modify_state( else: # Do not create a lease break task when multiple instances are waiting. if ( - not await self._get_local_lease(client_token) + not await self._get_local_lease(lock_key) and await self._n_lock_contenders(self._lock_key(token)) > 0 ): if self._debug_enabled: console.debug( - f"{SMR} [{time.monotonic() - start:.3f}] {client_token} has contention, not leasing" + f"{SMR} [{time.monotonic() - start:.3f}] {lock_key} has contention, not leasing" ) async with lock_held_ctx: state = await self.get_state(token) @@ -503,11 +503,11 @@ async def _try_modify_state( token, lock_id, cleanup_ctx=lock_held_ctx, **context ) ) is ( - current_lease_task := await self._get_local_lease(client_token) + current_lease_task := await self._get_local_lease(lock_key) ) and new_lease_task is not None: if self._debug_enabled: console.debug( - f"{SMR} [{time.monotonic() - start:.3f}] {client_token} obtained lock {lock_id.decode()}." + f"{SMR} [{time.monotonic() - start:.3f}] {lock_key} obtained lock {lock_id.decode()}." ) elif current_lease_task is None: # Check if we still have the redis lock, then just try to send this one update and release it. @@ -515,7 +515,7 @@ async def _try_modify_state( if await self.redis.get(self._lock_key(token)) == lock_id: if self._debug_enabled: console.debug( - f"{SMR} [{time.monotonic() - start:.3f}] {client_token} holding lock {lock_id.decode()}, {new_lease_task=} already exited, doing single update..." + f"{SMR} [{time.monotonic() - start:.3f}] {lock_key} holding lock {lock_id.decode()}, {new_lease_task=} already exited, doing single update..." ) async with lock_held_ctx: state = await self.get_state(token) @@ -524,7 +524,7 @@ async def _try_modify_state( return elif self._debug_enabled: console.debug( - f"{SMR} [{time.monotonic() - start:.3f}] {client_token} lock {lock_id.decode()} expired while waiting for lease task to exit..." + f"{SMR} [{time.monotonic() - start:.3f}] {lock_key} lock {lock_id.decode()} expired while waiting for lease task to exit..." ) # Have to retry getting the state, but now it's probably cached. yield None @@ -561,17 +561,15 @@ async def _get_state_cached( Yields: The cached state for the token, or None if not cached/uncachable. """ - client_token = token.ident + lock_key = token.lock_key # Opportunistically reuse existing lock. if ( - client_token in self._local_leases - and (state_lock := self._cached_states_locks.get(client_token)) is not None + lock_key in self._local_leases + and (state_lock := self._cached_states_locks.get(lock_key)) is not None ): async with state_lock: - if await self._get_local_lease(client_token) is not None: - if ( - cached_state := self._cached_states.get(client_token) - ) is not None: + if await self._get_local_lease(lock_key) is not None: + if (cached_state := self._cached_states.get(lock_key)) is not None: if isinstance(token, BaseStateToken): # Make sure we have the substate cached (or fetch it from redis). state_path = token.cls.get_full_name() @@ -592,11 +590,11 @@ async def _get_state_cached( return elif self._debug_enabled: console.debug( - f"{SMR} [{time.monotonic() - start:.3f}] {client_token} lease task found, lock held, but no cached state" + f"{SMR} [{time.monotonic() - start:.3f}] {lock_key} lease task found, lock held, but no cached state" ) elif self._debug_enabled: console.debug( - f"{SMR} [{time.monotonic() - start:.3f}] {client_token} no active lease task found" + f"{SMR} [{time.monotonic() - start:.3f}] {lock_key} no active lease task found" ) yield None @@ -636,32 +634,32 @@ async def _create_lease_break_task( """ self._ensure_lock_task() - client_token = token.ident + lock_key = token.lock_key async def do_flush() -> None: - if (state_lock := self._cached_states_locks.get(client_token)) is None: + if (state_lock := self._cached_states_locks.get(lock_key)) is None: # If we lost the lock, we can't write the state, something went wrong. console.warn( - f"State lock for {client_token} missing while finalizing lease." + f"State lock for {lock_key} missing while finalizing lease." ) return async with state_lock: # Write the state to redis while no one else can modify the cached copy. - state = self._cached_states.pop(client_token, None) + state = self._cached_states.pop(lock_key, None) try: if state: if self._debug_enabled: console.debug( - f"{SMR} [{time.monotonic() - start:.3f}] {client_token} lease breaker {lock_id.decode()} flushing state" + f"{SMR} [{time.monotonic() - start:.3f}] {lock_key} lease breaker {lock_id.decode()} flushing state" ) await self.set_state(token, state, lock_id=lock_id, **context) finally: - if (current_lease := self._local_leases.get(client_token)) is task: - self._local_leases.pop(client_token, None) + if (current_lease := self._local_leases.get(lock_key)) is task: + self._local_leases.pop(lock_key, None) # TODO: clean up the cached states locks periodically elif self._debug_enabled: console.debug( - f"{SMR} [{time.monotonic() - start:.3f}] {client_token} lease breaker {lock_id.decode()} cleanup of {task=} found different task in _local_leases {current_lease=}." + f"{SMR} [{time.monotonic() - start:.3f}] {lock_key} lease breaker {lock_id.decode()} cleanup of {task=} found different task in _local_leases {current_lease=}." ) async def lease_breaker(): @@ -670,7 +668,7 @@ async def lease_breaker(): lease_break_time = self.oplock_hold_time_ms / 1000 if self._debug_enabled: console.debug( - f"{SMR} [{time.monotonic() - start:.3f}] {client_token} lease breaker {lock_id.decode()} started, sleeping for {lease_break_time}s" + f"{SMR} [{time.monotonic() - start:.3f}] {lock_key} lease breaker {lock_id.decode()} started, sleeping for {lease_break_time}s" ) try: await asyncio.sleep(lease_break_time) @@ -679,7 +677,7 @@ async def lease_breaker(): # We got cancelled so if someone is holding the lock, # extend the timeout so they get the full time to complete. if ( - state_lock := self._cached_states_locks[client_token] + state_lock := self._cached_states_locks[lock_key] ) is not None and state_lock.locked(): await self._try_extend_lock(self._lock_key(token)) try: @@ -698,10 +696,10 @@ async def lease_breaker(): if cancelled_error is not None: raise cancelled_error - if (state_lock := self._cached_states_locks.get(client_token)) is not None: + if (state_lock := self._cached_states_locks.get(lock_key)) is not None: # We have an existing lock, so lets see if we have an existing lease to cancel. async with state_lock: - if (existing_task := self._local_leases.get(client_token)) is not None: + if (existing_task := self._local_leases.get(lock_key)) is not None: # There's already a lease break task, so cancel it to clear it out. existing_task.cancel() if existing_task is not None: @@ -709,25 +707,23 @@ async def lease_breaker(): await existing_task # Now we might need to create a new lock. - if (state_lock := self._cached_states_locks.get(client_token)) is None: + if (state_lock := self._cached_states_locks.get(lock_key)) is None: async with self._state_manager_lock: - if (state_lock := self._cached_states_locks.get(client_token)) is None: - state_lock = self._cached_states_locks[client_token] = ( - asyncio.Lock() - ) + if (state_lock := self._cached_states_locks.get(lock_key)) is None: + state_lock = self._cached_states_locks[lock_key] = asyncio.Lock() async with state_lock: # Create the task now if one didn't sneak past us. if ( - client_token not in self._local_leases + lock_key not in self._local_leases and await self._n_lock_contenders(self._lock_key(token)) == 0 ): - self._local_leases[client_token] = task = asyncio.create_task( + self._local_leases[lock_key] = task = asyncio.create_task( lease_breaker(), - name=f"reflex_lease_breaker|{client_token}|{lock_id.decode()}", + name=f"reflex_lease_breaker|{lock_key}|{lock_id.decode()}", ) # Fetch the requested state into the cache. - self._cached_states[client_token] = await self.get_state(token) + self._cached_states[lock_key] = await self.get_state(token) return task return None @@ -741,7 +737,7 @@ def _lock_key(token: StateToken[Any]) -> bytes: Returns: The redis lock key for the token. """ - return f"{token.ident}_lock".encode() + return f"{token.lock_key}_lock".encode() async def _try_extend_lock(self, lock_key: bytes) -> bool | None: """Extends the current lock for another lock_expiration period. From b7afcef1f7cd19871a25c8443feff7ee875d4090 Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Mon, 6 Apr 2026 09:26:59 -0700 Subject: [PATCH 68/81] update pyi_hashes --- pyi_hashes.json | 119 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 119 insertions(+) diff --git a/pyi_hashes.json b/pyi_hashes.json index 7b2a6e0adec..f6db8e53593 100644 --- a/pyi_hashes.json +++ b/pyi_hashes.json @@ -1,4 +1,123 @@ { + "packages/reflex-components-code/src/reflex_components_code/code.pyi": "2797061144c4199f57848f6673a05a7f", + "packages/reflex-components-code/src/reflex_components_code/shiki_code_block.pyi": "db0de2879d57870831a030a69b5282b7", + "packages/reflex-components-core/src/reflex_components_core/__init__.pyi": "82b29d23f2490161d42fd21021bd39c3", + "packages/reflex-components-core/src/reflex_components_core/base/__init__.pyi": "7009187aaaf191814d031e5462c48318", + "packages/reflex-components-core/src/reflex_components_core/base/app_wrap.pyi": "e7dfa98f5df5e30cb6d01d61b6974bef", + "packages/reflex-components-core/src/reflex_components_core/base/body.pyi": "0f98a7c1247e35059b76ae2985b7c81b", + "packages/reflex-components-core/src/reflex_components_core/base/document.pyi": "80a3090e5b7a46de6daa8e97e68e8638", + "packages/reflex-components-core/src/reflex_components_core/base/error_boundary.pyi": "f36f27e580041af842d348adbddcd600", + "packages/reflex-components-core/src/reflex_components_core/base/fragment.pyi": "39abed241f2def793dd0c59328bb0470", + "packages/reflex-components-core/src/reflex_components_core/base/link.pyi": "05d96de8a1d5f7be08de831b99663e67", + "packages/reflex-components-core/src/reflex_components_core/base/meta.pyi": "b83e94900f988ef5d2fdf121b01be7fa", + "packages/reflex-components-core/src/reflex_components_core/base/script.pyi": "cfb0d5bcfe67f7c2b40868cdf3a5f7c1", + "packages/reflex-components-core/src/reflex_components_core/base/strict_mode.pyi": "8a69093c8d40b10b1f0b1c4e851e9d53", + "packages/reflex-components-core/src/reflex_components_core/core/__init__.pyi": "dd5142b3c9087bf2bf22651adf6f2724", + "packages/reflex-components-core/src/reflex_components_core/core/auto_scroll.pyi": "29f5c106b98ddac94cf7c1244a02cfb1", + "packages/reflex-components-core/src/reflex_components_core/core/banner.pyi": "9af2721b01868b24a48c7899ad6b1c69", + "packages/reflex-components-core/src/reflex_components_core/core/clipboard.pyi": "20a3f4f500d44ac4365b6d831c6816ff", + "packages/reflex-components-core/src/reflex_components_core/core/debounce.pyi": "eb606cf8151e6769df7f2443ece739cd", + "packages/reflex-components-core/src/reflex_components_core/core/helmet.pyi": "5e28d554d2b4d7fae1ba35809c24f4fc", + "packages/reflex-components-core/src/reflex_components_core/core/html.pyi": "28bd59898f0402b33c34e14f3eef1282", + "packages/reflex-components-core/src/reflex_components_core/core/sticky.pyi": "4b34eca0e7338ec80ac5985345717bc9", + "packages/reflex-components-core/src/reflex_components_core/core/upload.pyi": "6f3cdef9956dbe5c917edeefdffd1b0e", + "packages/reflex-components-core/src/reflex_components_core/core/window_events.pyi": "28e901ee970bec806ee766d0d126d739", + "packages/reflex-components-core/src/reflex_components_core/datadisplay/__init__.pyi": "c96fed4da42a13576d64f84e3c7cb25c", + "packages/reflex-components-core/src/reflex_components_core/el/__init__.pyi": "f09129ddefb57ab4c7769c86dc9a3153", + "packages/reflex-components-core/src/reflex_components_core/el/element.pyi": "1a8824cdd243efc876157b97f9f1b714", + "packages/reflex-components-core/src/reflex_components_core/el/elements/__init__.pyi": "e6c845f2f29eb079697a2e31b0c2f23a", + "packages/reflex-components-core/src/reflex_components_core/el/elements/base.pyi": "7c74980207dc1a5cac14083f2edd31ba", + "packages/reflex-components-core/src/reflex_components_core/el/elements/forms.pyi": "da7ef00fd67699eeeb55e33279c2eb8d", + "packages/reflex-components-core/src/reflex_components_core/el/elements/inline.pyi": "0ea0058ea7b6ae03138c7c85df963c32", + "packages/reflex-components-core/src/reflex_components_core/el/elements/media.pyi": "97f7f6c66533bb3947a43ceefe160d49", + "packages/reflex-components-core/src/reflex_components_core/el/elements/metadata.pyi": "7ea09671a42d75234a0464fc3601577c", + "packages/reflex-components-core/src/reflex_components_core/el/elements/other.pyi": "869dca86b783149f9c59e1ae0d2900c1", + "packages/reflex-components-core/src/reflex_components_core/el/elements/scripts.pyi": "c3a5a4f2d0594414a160fe59b13ccc26", + "packages/reflex-components-core/src/reflex_components_core/el/elements/sectioning.pyi": "b2acdc964feabe78154be141dc978555", + "packages/reflex-components-core/src/reflex_components_core/el/elements/tables.pyi": "e75fbe0454df06abf462ab579b698897", + "packages/reflex-components-core/src/reflex_components_core/el/elements/typography.pyi": "f88089a2f4270b981a28e385d07460b5", + "packages/reflex-components-core/src/reflex_components_core/react_router/dom.pyi": "c5ac8ba14fdce557063a832a79f43f68", + "packages/reflex-components-dataeditor/src/reflex_components_dataeditor/dataeditor.pyi": "e10210239ce7dc18980e70eec19b9353", + "packages/reflex-components-gridjs/src/reflex_components_gridjs/datatable.pyi": "2a93782c63e82a6939411273fe2486d9", + "packages/reflex-components-lucide/src/reflex_components_lucide/icon.pyi": "f654cc9cb305712b485fcd676935c0c1", + "packages/reflex-components-markdown/src/reflex_components_markdown/markdown.pyi": "9c11bca2c4c5b722f55aba969f383e74", + "packages/reflex-components-moment/src/reflex_components_moment/moment.pyi": "ad4b084d94e50311f761d69b3173e357", + "packages/reflex-components-plotly/src/reflex_components_plotly/plotly.pyi": "241b80584f3e029145e6e003d1c476f2", + "packages/reflex-components-radix/src/reflex_components_radix/__init__.pyi": "b2f485bfde4978047b7b944cf15d92cb", + "packages/reflex-components-radix/src/reflex_components_radix/primitives/__init__.pyi": "5404a8da97e8b5129133d7f300e3f642", + "packages/reflex-components-radix/src/reflex_components_radix/primitives/accordion.pyi": "18ed34323f671fcf655639dc78d7c549", + "packages/reflex-components-radix/src/reflex_components_radix/primitives/base.pyi": "9c80e740d177b4a805dee3038d580941", + "packages/reflex-components-radix/src/reflex_components_radix/primitives/dialog.pyi": "b47313aefc9a740851ee332656446afd", + "packages/reflex-components-radix/src/reflex_components_radix/primitives/drawer.pyi": "d6a4f88f2988fa50fbed8a9026f5ef8b", + "packages/reflex-components-radix/src/reflex_components_radix/primitives/form.pyi": "00c0e0b6c8190f2db7fd847a25b5c03d", + "packages/reflex-components-radix/src/reflex_components_radix/primitives/progress.pyi": "577ec9714a4d8bc9f7dd7eca22fe5252", + "packages/reflex-components-radix/src/reflex_components_radix/primitives/slider.pyi": "bc69b9443d04ae7856c0a411a90755a9", + "packages/reflex-components-radix/src/reflex_components_radix/themes/__init__.pyi": "b433b9a099dc5b0ab008d02c85d38059", + "packages/reflex-components-radix/src/reflex_components_radix/themes/base.pyi": "90a182a1444b73c006e52ea67c2b3db1", + "packages/reflex-components-radix/src/reflex_components_radix/themes/color_mode.pyi": "3a419f78071b0dd6be55dc55e7334a1b", + "packages/reflex-components-radix/src/reflex_components_radix/themes/components/__init__.pyi": "f10f0169f81c78290333da831915762f", + "packages/reflex-components-radix/src/reflex_components_radix/themes/components/alert_dialog.pyi": "2b8c68239c9e9646e71ef8e81d7b5f69", + "packages/reflex-components-radix/src/reflex_components_radix/themes/components/aspect_ratio.pyi": "0f981ee0589f5501ab3c57e0aec01316", + "packages/reflex-components-radix/src/reflex_components_radix/themes/components/avatar.pyi": "d30f1bfb42198177ea08d7d358e99339", + "packages/reflex-components-radix/src/reflex_components_radix/themes/components/badge.pyi": "c3bb335b309177ff03d2cadcaf623744", + "packages/reflex-components-radix/src/reflex_components_radix/themes/components/button.pyi": "6a01812d601e8bf3dcd30dcccc75cb79", + "packages/reflex-components-radix/src/reflex_components_radix/themes/components/callout.pyi": "9b853e851805addacc2fcd995119f857", + "packages/reflex-components-radix/src/reflex_components_radix/themes/components/card.pyi": "67a71ec6ed4945a9ce270bd51d40b94e", + "packages/reflex-components-radix/src/reflex_components_radix/themes/components/checkbox.pyi": "0c975a4812efc267c87119f10880e1a9", + "packages/reflex-components-radix/src/reflex_components_radix/themes/components/checkbox_cards.pyi": "6425aae44ffe78f48699910906d16285", + "packages/reflex-components-radix/src/reflex_components_radix/themes/components/checkbox_group.pyi": "d0029ee04a971d8a51be0c99e414a139", + "packages/reflex-components-radix/src/reflex_components_radix/themes/components/context_menu.pyi": "1ee25c7dd27fece9881800226e322d6b", + "packages/reflex-components-radix/src/reflex_components_radix/themes/components/data_list.pyi": "924addbc155a178709f5fd38af4eb547", + "packages/reflex-components-radix/src/reflex_components_radix/themes/components/dialog.pyi": "e315e9779663f2f2fc9c2ca322a5645f", + "packages/reflex-components-radix/src/reflex_components_radix/themes/components/dropdown_menu.pyi": "ec6cb8830971b2a04bebe7459c059b15", + "packages/reflex-components-radix/src/reflex_components_radix/themes/components/hover_card.pyi": "28384945a53620ad6075797f8ada7354", + "packages/reflex-components-radix/src/reflex_components_radix/themes/components/icon_button.pyi": "6a3a37bdc9136f8c19fb3a7f55e76d64", + "packages/reflex-components-radix/src/reflex_components_radix/themes/components/inset.pyi": "05cfece835e2660bbc1b096529dfdec0", + "packages/reflex-components-radix/src/reflex_components_radix/themes/components/popover.pyi": "3033070773e8e32de283ad917367b386", + "packages/reflex-components-radix/src/reflex_components_radix/themes/components/progress.pyi": "798eadec25895a56e36d23203a4e0444", + "packages/reflex-components-radix/src/reflex_components_radix/themes/components/radio.pyi": "f6140dbf7ad4c25595c6983dcacc2a60", + "packages/reflex-components-radix/src/reflex_components_radix/themes/components/radio_cards.pyi": "e16ca79a2ad4c2919f56efb54830c1ef", + "packages/reflex-components-radix/src/reflex_components_radix/themes/components/radio_group.pyi": "473703616ed18d983dda3600899710a5", + "packages/reflex-components-radix/src/reflex_components_radix/themes/components/scroll_area.pyi": "12eb86d24886764bf1a5815e87ea519c", + "packages/reflex-components-radix/src/reflex_components_radix/themes/components/segmented_control.pyi": "6319f89d046b0fce8e9efb51e50dda9f", + "packages/reflex-components-radix/src/reflex_components_radix/themes/components/select.pyi": "c6da1db236da70dc40815a404d2e29b3", + "packages/reflex-components-radix/src/reflex_components_radix/themes/components/separator.pyi": "d2dabb895d7fc63a556d3c3220e38b4d", + "packages/reflex-components-radix/src/reflex_components_radix/themes/components/skeleton.pyi": "55b003f62cc3e5c85c90c82f8f595bc6", + "packages/reflex-components-radix/src/reflex_components_radix/themes/components/slider.pyi": "c204f30612bfa35a62cb9f525a913f77", + "packages/reflex-components-radix/src/reflex_components_radix/themes/components/spinner.pyi": "faeddfd0e3dc0e3bbcfdeaa6e42cb755", + "packages/reflex-components-radix/src/reflex_components_radix/themes/components/switch.pyi": "70f1d8fc55398d3cbb01f157c768419e", + "packages/reflex-components-radix/src/reflex_components_radix/themes/components/table.pyi": "a4c3052bc449924a630dad911f975e26", + "packages/reflex-components-radix/src/reflex_components_radix/themes/components/tabs.pyi": "ec4e4ed03bd892c6f7d50ae4b490adb9", + "packages/reflex-components-radix/src/reflex_components_radix/themes/components/text_area.pyi": "06549c800759ae541cc3c3a74240af59", + "packages/reflex-components-radix/src/reflex_components_radix/themes/components/text_field.pyi": "dcb6a8ff4668082fc9406579098abf87", + "packages/reflex-components-radix/src/reflex_components_radix/themes/components/tooltip.pyi": "69e4ce4eeaa60ac90ef120331cb8601c", + "packages/reflex-components-radix/src/reflex_components_radix/themes/layout/__init__.pyi": "dcbb1dc8e860379188924c15dd21605b", + "packages/reflex-components-radix/src/reflex_components_radix/themes/layout/base.pyi": "28e6cd3869c9cbad886b69b339e3ecf6", + "packages/reflex-components-radix/src/reflex_components_radix/themes/layout/box.pyi": "004cae8160c3a91ae6c12b54205f5112", + "packages/reflex-components-radix/src/reflex_components_radix/themes/layout/center.pyi": "9dbe595eddc2ec731beeb3a98743be36", + "packages/reflex-components-radix/src/reflex_components_radix/themes/layout/container.pyi": "1fb9d0ce37de9c64f681ad70375b9e42", + "packages/reflex-components-radix/src/reflex_components_radix/themes/layout/flex.pyi": "a729044bfe2d82404de07c4570262b55", + "packages/reflex-components-radix/src/reflex_components_radix/themes/layout/grid.pyi": "74b017b63728ce328e110bc64f20a205", + "packages/reflex-components-radix/src/reflex_components_radix/themes/layout/list.pyi": "3a595ec7faf95645ab52bdad1bf9dc4a", + "packages/reflex-components-radix/src/reflex_components_radix/themes/layout/section.pyi": "f3e44e291f3d96d06850d262de5d43a8", + "packages/reflex-components-radix/src/reflex_components_radix/themes/layout/spacer.pyi": "a0a59ca93ea1e3a0e5136b9692a68d18", + "packages/reflex-components-radix/src/reflex_components_radix/themes/layout/stack.pyi": "6ab750e790f0687b735d7464fa289c1f", + "packages/reflex-components-radix/src/reflex_components_radix/themes/typography/__init__.pyi": "de7ee994f66a4c1d1a6ac2ad3370c30e", + "packages/reflex-components-radix/src/reflex_components_radix/themes/typography/blockquote.pyi": "3dd8bc1d7117b4e2b3b38438b4d6631a", + "packages/reflex-components-radix/src/reflex_components_radix/themes/typography/code.pyi": "a71f56a8c51e9b00f953d87b16724bdb", + "packages/reflex-components-radix/src/reflex_components_radix/themes/typography/heading.pyi": "47a5f03dc4c85c473026069d23b6c531", + "packages/reflex-components-radix/src/reflex_components_radix/themes/typography/link.pyi": "ced137b2820a5e156cd1846ff113cfc9", + "packages/reflex-components-radix/src/reflex_components_radix/themes/typography/text.pyi": "014444973b21272cf8c572b2913dfdf5", + "packages/reflex-components-react-player/src/reflex_components_react_player/audio.pyi": "2c3c398ec0cc1476995f316cf8d0d271", + "packages/reflex-components-react-player/src/reflex_components_react_player/react_player.pyi": "9f8631e66d64f8bed90cbfd63615a97a", + "packages/reflex-components-react-player/src/reflex_components_react_player/video.pyi": "d0efeacb8b4162e9ace79f99c03e4368", + "packages/reflex-components-recharts/src/reflex_components_recharts/__init__.pyi": "7b8b69840a3637c1f1cac45ba815cccf", + "packages/reflex-components-recharts/src/reflex_components_recharts/cartesian.pyi": "9e99f951112c86ec7991bc80985a76b1", + "packages/reflex-components-recharts/src/reflex_components_recharts/charts.pyi": "5730b770af97f8c67d6d2d50e84fe14d", + "packages/reflex-components-recharts/src/reflex_components_recharts/general.pyi": "4097350ca05011733ce998898c6aefe7", + "packages/reflex-components-recharts/src/reflex_components_recharts/polar.pyi": "db5298160144f23ae7abcaac68e845c7", + "packages/reflex-components-recharts/src/reflex_components_recharts/recharts.pyi": "75150b01510bdacf2c97fca347c86c59", + "packages/reflex-components-sonner/src/reflex_components_sonner/toast.pyi": "dc43e142b089b1158588e999505444f6", "reflex/__init__.pyi": "5de3d4af8ea86e9755f622510b868196", "reflex/components/__init__.pyi": "f39a2af77f438fa243c58c965f19d42e", "reflex/experimental/memo.pyi": "c10cbc554fe2ffdb3a008b59bc503936" From db87434b2da77cf1e51049d15d3c200a39331138 Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Mon, 6 Apr 2026 09:34:20 -0700 Subject: [PATCH 69/81] make reflex_base.event a package --- .../reflex-base/src/reflex_base/{event.py => event/__init__.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename packages/reflex-base/src/reflex_base/{event.py => event/__init__.py} (100%) diff --git a/packages/reflex-base/src/reflex_base/event.py b/packages/reflex-base/src/reflex_base/event/__init__.py similarity index 100% rename from packages/reflex-base/src/reflex_base/event.py rename to packages/reflex-base/src/reflex_base/event/__init__.py From 0fa3f7e611561f63b40be9409939d00e00231e84 Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Mon, 6 Apr 2026 09:41:39 -0700 Subject: [PATCH 70/81] Get rid of reflex_base._internal namespace --- .../src/reflex_base/_internal/__init__.py | 1 - .../reflex_base/_internal/event/__init__.py | 1 - .../_internal/event/processor/__init__.py | 19 ------------------- .../src/reflex_base/components/component.py | 2 +- .../{_internal => }/context/__init__.py | 0 .../{_internal => }/context/base.py | 2 ++ .../src/reflex_base/event/__init__.py | 6 +++++- .../{_internal => }/event/context.py | 2 +- .../reflex_base/event/processor/__init__.py | 14 ++++++++++++++ .../event/processor/base_state_processor.py | 9 +++------ .../{_internal => }/event/processor/compat.py | 0 .../event/processor/event_processor.py | 10 +++++----- .../{_internal => }/event/processor/future.py | 1 + .../event/processor/timeout.py | 0 .../reflex_base/{_internal => }/registry.py | 2 +- .../reflex_components_core/core/_upload.py | 2 +- reflex/app.py | 7 ++----- reflex/istate/manager/__init__.py | 2 +- reflex/istate/proxy.py | 2 +- reflex/state.py | 10 +++++----- reflex/testing.py | 2 +- tests/benchmarks/test_event_processing.py | 4 ++-- tests/units/conftest.py | 9 +++------ tests/units/istate/test_proxy.py | 2 +- .../middleware/test_hydrate_middleware.py | 2 +- .../{_internal => context}/__init__.py | 0 .../{_internal => }/context/test_base.py | 2 +- .../{_internal/context => event}/__init__.py | 0 .../processor/test_base_state_processor.py | 6 +++--- .../event/processor/test_event_processor.py | 9 +++------ .../event/processor/test_future.py | 2 +- .../event/processor/test_timeout.py | 2 +- .../{_internal => }/event/test_context.py | 2 +- .../{_internal => }/test_registry.py | 2 +- tests/units/test_app.py | 6 +++--- tests/units/test_state.py | 4 ++-- 36 files changed, 67 insertions(+), 79 deletions(-) delete mode 100644 packages/reflex-base/src/reflex_base/_internal/__init__.py delete mode 100644 packages/reflex-base/src/reflex_base/_internal/event/__init__.py delete mode 100644 packages/reflex-base/src/reflex_base/_internal/event/processor/__init__.py rename packages/reflex-base/src/reflex_base/{_internal => }/context/__init__.py (100%) rename packages/reflex-base/src/reflex_base/{_internal => }/context/base.py (97%) rename packages/reflex-base/src/reflex_base/{_internal => }/event/context.py (98%) create mode 100644 packages/reflex-base/src/reflex_base/event/processor/__init__.py rename packages/reflex-base/src/reflex_base/{_internal => }/event/processor/base_state_processor.py (98%) rename packages/reflex-base/src/reflex_base/{_internal => }/event/processor/compat.py (100%) rename packages/reflex-base/src/reflex_base/{_internal => }/event/processor/event_processor.py (98%) rename packages/reflex-base/src/reflex_base/{_internal => }/event/processor/future.py (98%) rename packages/reflex-base/src/reflex_base/{_internal => }/event/processor/timeout.py (100%) rename packages/reflex-base/src/reflex_base/{_internal => }/registry.py (98%) rename tests/units/reflex_base/{_internal => context}/__init__.py (100%) rename tests/units/reflex_base/{_internal => }/context/test_base.py (97%) rename tests/units/reflex_base/{_internal/context => event}/__init__.py (100%) rename tests/units/reflex_base/{_internal => }/event/processor/test_base_state_processor.py (96%) rename tests/units/reflex_base/{_internal => }/event/processor/test_event_processor.py (98%) rename tests/units/reflex_base/{_internal => }/event/processor/test_future.py (99%) rename tests/units/reflex_base/{_internal => }/event/processor/test_timeout.py (90%) rename tests/units/reflex_base/{_internal => }/event/test_context.py (97%) rename tests/units/reflex_base/{_internal => }/test_registry.py (97%) diff --git a/packages/reflex-base/src/reflex_base/_internal/__init__.py b/packages/reflex-base/src/reflex_base/_internal/__init__.py deleted file mode 100644 index af1fcdf80fb..00000000000 --- a/packages/reflex-base/src/reflex_base/_internal/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Reflex internals: subject to change 🐉.""" diff --git a/packages/reflex-base/src/reflex_base/_internal/event/__init__.py b/packages/reflex-base/src/reflex_base/_internal/event/__init__.py deleted file mode 100644 index b2e670ff4b3..00000000000 --- a/packages/reflex-base/src/reflex_base/_internal/event/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Internal event processing.""" diff --git a/packages/reflex-base/src/reflex_base/_internal/event/processor/__init__.py b/packages/reflex-base/src/reflex_base/_internal/event/processor/__init__.py deleted file mode 100644 index 1ebe5b4e130..00000000000 --- a/packages/reflex-base/src/reflex_base/_internal/event/processor/__init__.py +++ /dev/null @@ -1,19 +0,0 @@ -"""Procedures for handling events.""" - -from reflex_base._internal.event.processor.base_state_processor import ( - BaseStateEventProcessor, -) -from reflex_base._internal.event.processor.event_processor import ( - EventProcessor, - EventQueueEntry, -) -from reflex_base._internal.event.processor.future import EventFuture -from reflex_base._internal.event.processor.timeout import DrainTimeoutManager - -__all__ = [ - "BaseStateEventProcessor", - "DrainTimeoutManager", - "EventFuture", - "EventProcessor", - "EventQueueEntry", -] diff --git a/packages/reflex-base/src/reflex_base/components/component.py b/packages/reflex-base/src/reflex_base/components/component.py index 0950477f900..d7903c5c9e2 100644 --- a/packages/reflex-base/src/reflex_base/components/component.py +++ b/packages/reflex-base/src/reflex_base/components/component.py @@ -2419,7 +2419,7 @@ def create(cls, component: Component) -> StatefulComponent | None: """ from reflex_components_core.core.foreach import Foreach - from reflex_base._internal.registry import RegistrationContext + from reflex_base.registry import RegistrationContext if component._memoization_mode.disposition == MemoizationDisposition.NEVER: # Never memoize this component. diff --git a/packages/reflex-base/src/reflex_base/_internal/context/__init__.py b/packages/reflex-base/src/reflex_base/context/__init__.py similarity index 100% rename from packages/reflex-base/src/reflex_base/_internal/context/__init__.py rename to packages/reflex-base/src/reflex_base/context/__init__.py diff --git a/packages/reflex-base/src/reflex_base/_internal/context/base.py b/packages/reflex-base/src/reflex_base/context/base.py similarity index 97% rename from packages/reflex-base/src/reflex_base/_internal/context/base.py rename to packages/reflex-base/src/reflex_base/context/base.py index 9d75b05933b..7bb28d4864c 100644 --- a/packages/reflex-base/src/reflex_base/_internal/context/base.py +++ b/packages/reflex-base/src/reflex_base/context/base.py @@ -1,3 +1,5 @@ +"""Shared contextvars wrapper for contextual globals.""" + from __future__ import annotations import dataclasses diff --git a/packages/reflex-base/src/reflex_base/event/__init__.py b/packages/reflex-base/src/reflex_base/event/__init__.py index dd82e2a7b3f..9bb63f6d9a5 100644 --- a/packages/reflex-base/src/reflex_base/event/__init__.py +++ b/packages/reflex-base/src/reflex_base/event/__init__.py @@ -81,7 +81,7 @@ class Event: @property def state_cls(self) -> "type[BaseState]": """The state class for the event.""" - from reflex_base._internal.registry import RegistrationContext + from reflex_base.registry import RegistrationContext substate_name = self.name.rpartition(".")[0] return RegistrationContext.get().base_states[substate_name] @@ -2830,4 +2830,8 @@ def BaseState(self) -> "type[BaseState]": # noqa: N802 event = EventNamespace event.event = event # pyright: ignore[reportAttributeAccessIssue] +_this = sys.modules[__name__] +event.__path__ = _this.__path__ # pyright: ignore[reportAttributeAccessIssue] +event.__spec__ = _this.__spec__ # pyright: ignore[reportAttributeAccessIssue] +event.__package__ = _this.__package__ # pyright: ignore[reportAttributeAccessIssue] sys.modules[__name__] = event # pyright: ignore[reportArgumentType] diff --git a/packages/reflex-base/src/reflex_base/_internal/event/context.py b/packages/reflex-base/src/reflex_base/event/context.py similarity index 98% rename from packages/reflex-base/src/reflex_base/_internal/event/context.py rename to packages/reflex-base/src/reflex_base/event/context.py index 17afaeac7fa..df3f0200e62 100644 --- a/packages/reflex-base/src/reflex_base/_internal/event/context.py +++ b/packages/reflex-base/src/reflex_base/event/context.py @@ -8,7 +8,7 @@ from collections.abc import Callable, Mapping from typing import TYPE_CHECKING, Any, Protocol -from reflex_base._internal.context.base import BaseContext +from reflex_base.context.base import BaseContext from reflex_base.utils.format import to_snake_case if TYPE_CHECKING: diff --git a/packages/reflex-base/src/reflex_base/event/processor/__init__.py b/packages/reflex-base/src/reflex_base/event/processor/__init__.py new file mode 100644 index 00000000000..f72058483d2 --- /dev/null +++ b/packages/reflex-base/src/reflex_base/event/processor/__init__.py @@ -0,0 +1,14 @@ +"""Procedures for handling events.""" + +from reflex_base.event.processor.base_state_processor import BaseStateEventProcessor +from reflex_base.event.processor.event_processor import EventProcessor, EventQueueEntry +from reflex_base.event.processor.future import EventFuture +from reflex_base.event.processor.timeout import DrainTimeoutManager + +__all__ = [ + "BaseStateEventProcessor", + "DrainTimeoutManager", + "EventFuture", + "EventProcessor", + "EventQueueEntry", +] diff --git a/packages/reflex-base/src/reflex_base/_internal/event/processor/base_state_processor.py b/packages/reflex-base/src/reflex_base/event/processor/base_state_processor.py similarity index 98% rename from packages/reflex-base/src/reflex_base/_internal/event/processor/base_state_processor.py rename to packages/reflex-base/src/reflex_base/event/processor/base_state_processor.py index b67bb48531a..30e7c9fe315 100644 --- a/packages/reflex-base/src/reflex_base/_internal/event/processor/base_state_processor.py +++ b/packages/reflex-base/src/reflex_base/event/processor/base_state_processor.py @@ -15,12 +15,9 @@ from reflex.istate.manager.token import BaseStateToken from reflex.istate.proxy import StateProxy from reflex.utils import console, types -from reflex_base._internal.event.context import EventContext -from reflex_base._internal.event.processor.event_processor import ( - EventProcessor, - EventQueueEntry, -) -from reflex_base._internal.registry import RegisteredEventHandler +from reflex_base.event.context import EventContext +from reflex_base.event.processor.event_processor import EventProcessor, EventQueueEntry +from reflex_base.registry import RegisteredEventHandler from reflex_base.utils.format import format_event_handler if TYPE_CHECKING: diff --git a/packages/reflex-base/src/reflex_base/_internal/event/processor/compat.py b/packages/reflex-base/src/reflex_base/event/processor/compat.py similarity index 100% rename from packages/reflex-base/src/reflex_base/_internal/event/processor/compat.py rename to packages/reflex-base/src/reflex_base/event/processor/compat.py diff --git a/packages/reflex-base/src/reflex_base/_internal/event/processor/event_processor.py b/packages/reflex-base/src/reflex_base/event/processor/event_processor.py similarity index 98% rename from packages/reflex-base/src/reflex_base/_internal/event/processor/event_processor.py rename to packages/reflex-base/src/reflex_base/event/processor/event_processor.py index a5ce1423295..776c4b5dce1 100644 --- a/packages/reflex-base/src/reflex_base/_internal/event/processor/event_processor.py +++ b/packages/reflex-base/src/reflex_base/event/processor/event_processor.py @@ -20,11 +20,11 @@ from reflex.app_mixins.middleware import MiddlewareMixin from reflex.istate.manager import StateManager from reflex.utils import console -from reflex_base._internal.event.context import EventContext -from reflex_base._internal.event.processor.compat import as_completed -from reflex_base._internal.event.processor.future import EventFuture -from reflex_base._internal.event.processor.timeout import DrainTimeoutManager -from reflex_base._internal.registry import RegisteredEventHandler, RegistrationContext +from reflex_base.event.context import EventContext +from reflex_base.event.processor.compat import as_completed +from reflex_base.event.processor.future import EventFuture +from reflex_base.event.processor.timeout import DrainTimeoutManager +from reflex_base.registry import RegisteredEventHandler, RegistrationContext if TYPE_CHECKING: from reflex.app import EventNamespace diff --git a/packages/reflex-base/src/reflex_base/_internal/event/processor/future.py b/packages/reflex-base/src/reflex_base/event/processor/future.py similarity index 98% rename from packages/reflex-base/src/reflex_base/_internal/event/processor/future.py rename to packages/reflex-base/src/reflex_base/event/processor/future.py index dd76f485800..01d27fbdcef 100644 --- a/packages/reflex-base/src/reflex_base/_internal/event/processor/future.py +++ b/packages/reflex-base/src/reflex_base/event/processor/future.py @@ -32,6 +32,7 @@ class EventFuture(asyncio.Future): ) def __post_init__(self) -> None: + """Call Future.__init__ for the EventFuture.""" super(EventFuture, self).__init__(loop=self.loop) def add_child(self, child: EventFuture) -> None: diff --git a/packages/reflex-base/src/reflex_base/_internal/event/processor/timeout.py b/packages/reflex-base/src/reflex_base/event/processor/timeout.py similarity index 100% rename from packages/reflex-base/src/reflex_base/_internal/event/processor/timeout.py rename to packages/reflex-base/src/reflex_base/event/processor/timeout.py diff --git a/packages/reflex-base/src/reflex_base/_internal/registry.py b/packages/reflex-base/src/reflex_base/registry.py similarity index 98% rename from packages/reflex-base/src/reflex_base/_internal/registry.py rename to packages/reflex-base/src/reflex_base/registry.py index ab787dd4191..8caa1d2b2c3 100644 --- a/packages/reflex-base/src/reflex_base/_internal/registry.py +++ b/packages/reflex-base/src/reflex_base/registry.py @@ -7,7 +7,7 @@ from typing_extensions import Self -from reflex_base._internal.context.base import BaseContext +from reflex_base.context.base import BaseContext from reflex_base.utils.exceptions import StateValueError if TYPE_CHECKING: diff --git a/packages/reflex-components-core/src/reflex_components_core/core/_upload.py b/packages/reflex-components-core/src/reflex_components_core/core/_upload.py index 08651c6eeb5..2e8d80b1052 100644 --- a/packages/reflex-components-core/src/reflex_components_core/core/_upload.py +++ b/packages/reflex-components-core/src/reflex_components_core/core/_upload.py @@ -621,11 +621,11 @@ async def upload_file(request: Request): UploadTypeError: If a non-streaming upload is wired to a background task. HTTPException: when the request does not include token / handler headers. """ - from reflex_base._internal.registry import RegistrationContext from reflex_base.event import ( resolve_upload_chunk_handler_param, resolve_upload_handler_param, ) + from reflex_base.registry import RegistrationContext token, handler_name = _require_upload_headers(request) registered_event_handler = RegistrationContext.get().event_handlers[ diff --git a/reflex/app.py b/reflex/app.py index 933c3ef6527..e3e66aec45d 100644 --- a/reflex/app.py +++ b/reflex/app.py @@ -24,11 +24,6 @@ from typing import TYPE_CHECKING, Any, ParamSpec, overload from reflex_base import constants -from reflex_base._internal.event.processor import ( - BaseStateEventProcessor, - EventProcessor, -) -from reflex_base._internal.registry import RegistrationContext from reflex_base.components.component import ( CUSTOM_COMPONENTS, Component, @@ -45,6 +40,8 @@ IndividualEventType, noop, ) +from reflex_base.event.processor import BaseStateEventProcessor, EventProcessor +from reflex_base.registry import RegistrationContext from reflex_base.utils import console from reflex_base.utils.imports import ImportVar from reflex_base.utils.types import ASGIApp, Message, Receive, Scope, Send diff --git a/reflex/istate/manager/__init__.py b/reflex/istate/manager/__init__.py index e0b1cfbd12f..c2cf5e44b2e 100644 --- a/reflex/istate/manager/__init__.py +++ b/reflex/istate/manager/__init__.py @@ -176,6 +176,6 @@ def get_state_manager() -> StateManager: Returns: The state manager. """ - from reflex_base._internal.event.context import EventContext + from reflex_base.event.context import EventContext return EventContext.get().state_manager diff --git a/reflex/istate/proxy.py b/reflex/istate/proxy.py index 8f0d01e88b0..ce9aa3c8618 100644 --- a/reflex/istate/proxy.py +++ b/reflex/istate/proxy.py @@ -15,8 +15,8 @@ from typing import TYPE_CHECKING, Any, SupportsIndex, TypeVar import wrapt -from reflex_base._internal.event.context import EventContext from reflex_base.event import Event +from reflex_base.event.context import EventContext from reflex_base.utils.exceptions import ImmutableStateError from reflex_base.utils.serializers import can_serialize, serialize, serializer from reflex_base.vars.base import Var diff --git a/reflex/state.py b/reflex/state.py index f58fe0bf7ed..413e20c0978 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -496,7 +496,7 @@ def __init_subclass__(cls, mixin: bool = False, **kwargs): Raises: StateValueError: If a substate class shadows another. """ - from reflex_base._internal.registry import RegistrationContext + from reflex_base.registry import RegistrationContext from reflex_base.utils.exceptions import StateValueError super().__init_subclass__(**kwargs) @@ -960,7 +960,7 @@ def get_substates(cls) -> set[type[BaseState]]: Returns: The substates of the state. """ - from reflex_base._internal.registry import RegistrationContext + from reflex_base.registry import RegistrationContext return RegistrationContext.get().get_substates(cls) @@ -1145,7 +1145,7 @@ def _create_event_handler( Returns: The event handler. """ - from reflex_base._internal.registry import RegistrationContext + from reflex_base.registry import RegistrationContext # Check if function has stored event_actions from decorator event_actions = getattr(fn, EVENT_ACTIONS_MARKER, {}) @@ -2214,7 +2214,7 @@ async def _get_state_from_redis(self, state_cls: type[T_STATE]) -> T_STATE: @event async def hydrate(self) -> None: """Send the full state to the frontend to synchronize it with the backend.""" - from reflex_base._internal.event.context import EventContext + from reflex_base.event.context import EventContext # Clear client storage, to respect clearing cookies self._reset_client_storage() @@ -2569,7 +2569,7 @@ def reload_state_module( state: Recursive argument for the state class to reload. """ - from reflex_base._internal.registry import RegistrationContext + from reflex_base.registry import RegistrationContext # Reset the _app_ref of OnLoadInternalState to avoid stale references. if state is OnLoadInternalState: diff --git a/reflex/testing.py b/reflex/testing.py index 5ac37a5e0b4..c144b22d7c6 100644 --- a/reflex/testing.py +++ b/reflex/testing.py @@ -28,10 +28,10 @@ from typing import TYPE_CHECKING, Any, ClassVar, Literal, TypeVar import uvicorn -from reflex_base._internal.registry import RegistrationContext from reflex_base.components.component import CUSTOM_COMPONENTS, CustomComponent from reflex_base.config import get_config from reflex_base.environment import environment +from reflex_base.registry import RegistrationContext from reflex_base.utils.types import ASGIApp from typing_extensions import Self diff --git a/tests/benchmarks/test_event_processing.py b/tests/benchmarks/test_event_processing.py index 24a5832650d..6990482cbdb 100644 --- a/tests/benchmarks/test_event_processing.py +++ b/tests/benchmarks/test_event_processing.py @@ -14,9 +14,9 @@ import pytest import pytest_asyncio from pytest_codspeed import BenchmarkFixture -from reflex_base._internal.event.context import EventContext -from reflex_base._internal.event.processor import BaseStateEventProcessor from reflex_base.event import Event +from reflex_base.event.context import EventContext +from reflex_base.event.processor import BaseStateEventProcessor from reflex_base.utils.format import format_event_handler from reflex.istate.manager.memory import StateManagerMemory diff --git a/tests/units/conftest.py b/tests/units/conftest.py index c88beb53eba..36baee0ec8e 100644 --- a/tests/units/conftest.py +++ b/tests/units/conftest.py @@ -10,14 +10,11 @@ import pytest import pytest_asyncio -from reflex_base._internal.event.context import EventContext -from reflex_base._internal.event.processor import ( - BaseStateEventProcessor, - EventProcessor, -) -from reflex_base._internal.registry import RegistrationContext from reflex_base.components.component import CUSTOM_COMPONENTS from reflex_base.event import Event, EventSpec +from reflex_base.event.context import EventContext +from reflex_base.event.processor import BaseStateEventProcessor, EventProcessor +from reflex_base.registry import RegistrationContext from reflex.app import App from reflex.experimental.memo import EXPERIMENTAL_MEMOS diff --git a/tests/units/istate/test_proxy.py b/tests/units/istate/test_proxy.py index 5946df37199..ec9a923b52a 100644 --- a/tests/units/istate/test_proxy.py +++ b/tests/units/istate/test_proxy.py @@ -6,7 +6,7 @@ from contextlib import asynccontextmanager import pytest -from reflex_base._internal.event.context import EventContext +from reflex_base.event.context import EventContext import reflex as rx from reflex.istate.proxy import MutableProxy, StateProxy diff --git a/tests/units/middleware/test_hydrate_middleware.py b/tests/units/middleware/test_hydrate_middleware.py index 435c4263d57..ed437e7e4e0 100644 --- a/tests/units/middleware/test_hydrate_middleware.py +++ b/tests/units/middleware/test_hydrate_middleware.py @@ -1,7 +1,7 @@ from __future__ import annotations import pytest -from reflex_base._internal.registry import RegistrationContext +from reflex_base.registry import RegistrationContext from reflex.app import App from reflex.middleware.hydrate_middleware import HydrateMiddleware diff --git a/tests/units/reflex_base/_internal/__init__.py b/tests/units/reflex_base/context/__init__.py similarity index 100% rename from tests/units/reflex_base/_internal/__init__.py rename to tests/units/reflex_base/context/__init__.py diff --git a/tests/units/reflex_base/_internal/context/test_base.py b/tests/units/reflex_base/context/test_base.py similarity index 97% rename from tests/units/reflex_base/_internal/context/test_base.py rename to tests/units/reflex_base/context/test_base.py index bd7211d2b43..11db7963159 100644 --- a/tests/units/reflex_base/_internal/context/test_base.py +++ b/tests/units/reflex_base/context/test_base.py @@ -3,7 +3,7 @@ import dataclasses import pytest -from reflex_base._internal.context.base import BaseContext +from reflex_base.context.base import BaseContext @dataclasses.dataclass(frozen=True, kw_only=True, slots=True) diff --git a/tests/units/reflex_base/_internal/context/__init__.py b/tests/units/reflex_base/event/__init__.py similarity index 100% rename from tests/units/reflex_base/_internal/context/__init__.py rename to tests/units/reflex_base/event/__init__.py diff --git a/tests/units/reflex_base/_internal/event/processor/test_base_state_processor.py b/tests/units/reflex_base/event/processor/test_base_state_processor.py similarity index 96% rename from tests/units/reflex_base/_internal/event/processor/test_base_state_processor.py rename to tests/units/reflex_base/event/processor/test_base_state_processor.py index 32e155ec744..e414369a40e 100644 --- a/tests/units/reflex_base/_internal/event/processor/test_base_state_processor.py +++ b/tests/units/reflex_base/event/processor/test_base_state_processor.py @@ -6,11 +6,11 @@ import pytest import pytest_asyncio -from reflex_base._internal.event.context import EventContext -from reflex_base._internal.event.processor import BaseStateEventProcessor -from reflex_base._internal.registry import RegistrationContext from reflex_base.constants import CompileVars from reflex_base.constants.state import FIELD_MARKER +from reflex_base.event.context import EventContext +from reflex_base.event.processor import BaseStateEventProcessor +from reflex_base.registry import RegistrationContext from reflex import event from reflex.app import App diff --git a/tests/units/reflex_base/_internal/event/processor/test_event_processor.py b/tests/units/reflex_base/event/processor/test_event_processor.py similarity index 98% rename from tests/units/reflex_base/_internal/event/processor/test_event_processor.py rename to tests/units/reflex_base/event/processor/test_event_processor.py index 2e9706c1219..de1ea4dcb23 100644 --- a/tests/units/reflex_base/_internal/event/processor/test_event_processor.py +++ b/tests/units/reflex_base/event/processor/test_event_processor.py @@ -5,12 +5,9 @@ from typing import Any import pytest -from reflex_base._internal.event.context import EventContext -from reflex_base._internal.event.processor.event_processor import ( - EventProcessor, - QueueShutDown, -) -from reflex_base._internal.registry import RegistrationContext +from reflex_base.event.context import EventContext +from reflex_base.event.processor.event_processor import EventProcessor, QueueShutDown +from reflex_base.registry import RegistrationContext from reflex.event import Event, EventHandler diff --git a/tests/units/reflex_base/_internal/event/processor/test_future.py b/tests/units/reflex_base/event/processor/test_future.py similarity index 99% rename from tests/units/reflex_base/_internal/event/processor/test_future.py rename to tests/units/reflex_base/event/processor/test_future.py index 925304d59cf..dec13eb7cd8 100644 --- a/tests/units/reflex_base/_internal/event/processor/test_future.py +++ b/tests/units/reflex_base/event/processor/test_future.py @@ -3,7 +3,7 @@ import asyncio import pytest -from reflex_base._internal.event.processor.future import EventFuture +from reflex_base.event.processor.future import EventFuture @pytest.mark.asyncio diff --git a/tests/units/reflex_base/_internal/event/processor/test_timeout.py b/tests/units/reflex_base/event/processor/test_timeout.py similarity index 90% rename from tests/units/reflex_base/_internal/event/processor/test_timeout.py rename to tests/units/reflex_base/event/processor/test_timeout.py index 530c405c834..54f0827ea4a 100644 --- a/tests/units/reflex_base/_internal/event/processor/test_timeout.py +++ b/tests/units/reflex_base/event/processor/test_timeout.py @@ -2,7 +2,7 @@ import time -from reflex_base._internal.event.processor.timeout import DrainTimeoutManager +from reflex_base.event.processor.timeout import DrainTimeoutManager def test_drain_timeout_no_timeout(): diff --git a/tests/units/reflex_base/_internal/event/test_context.py b/tests/units/reflex_base/event/test_context.py similarity index 97% rename from tests/units/reflex_base/_internal/event/test_context.py rename to tests/units/reflex_base/event/test_context.py index 90c1856a46d..484ec696e21 100644 --- a/tests/units/reflex_base/_internal/event/test_context.py +++ b/tests/units/reflex_base/event/test_context.py @@ -2,7 +2,7 @@ from unittest import mock -from reflex_base._internal.event.context import EventContext +from reflex_base.event.context import EventContext def test_fork_creates_child(mock_root_event_context: EventContext): diff --git a/tests/units/reflex_base/_internal/test_registry.py b/tests/units/reflex_base/test_registry.py similarity index 97% rename from tests/units/reflex_base/_internal/test_registry.py rename to tests/units/reflex_base/test_registry.py index a457269fedf..474acf874c8 100644 --- a/tests/units/reflex_base/_internal/test_registry.py +++ b/tests/units/reflex_base/test_registry.py @@ -1,7 +1,7 @@ """Tests for RegistrationContext.""" import pytest -from reflex_base._internal.registry import RegisteredEventHandler, RegistrationContext +from reflex_base.registry import RegisteredEventHandler, RegistrationContext from reflex_base.utils.exceptions import StateValueError diff --git a/tests/units/test_app.py b/tests/units/test_app.py index 65d86eaeb09..f331dd6c8b9 100644 --- a/tests/units/test_app.py +++ b/tests/units/test_app.py @@ -16,12 +16,12 @@ import pytest from pytest_mock import MockerFixture -from reflex_base._internal.event.context import EventContext -from reflex_base._internal.event.processor import BaseStateEventProcessor -from reflex_base._internal.registry import RegistrationContext from reflex_base.components.component import Component from reflex_base.constants.state import FIELD_MARKER from reflex_base.event import Event +from reflex_base.event.context import EventContext +from reflex_base.event.processor import BaseStateEventProcessor +from reflex_base.registry import RegistrationContext from reflex_base.style import Style from reflex_base.utils import console, exceptions, format from reflex_base.vars.base import computed_var diff --git a/tests/units/test_state.py b/tests/units/test_state.py index 4dec98248e2..900dad6ee29 100644 --- a/tests/units/test_state.py +++ b/tests/units/test_state.py @@ -22,11 +22,11 @@ from pydantic import BaseModel as Base from pytest_mock import MockerFixture from reflex_base import constants -from reflex_base._internal.event.context import EventContext -from reflex_base._internal.event.processor import BaseStateEventProcessor from reflex_base.constants import CompileVars, RouteVar from reflex_base.constants.state import FIELD_MARKER from reflex_base.event import Event, EventHandler +from reflex_base.event.context import EventContext +from reflex_base.event.processor import BaseStateEventProcessor from reflex_base.utils import format, types from reflex_base.utils.exceptions import ( InvalidLockWarningThresholdError, From e2c1ed15255f87d0833d4dffe5f1680df42b2899 Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Mon, 6 Apr 2026 11:17:26 -0700 Subject: [PATCH 71/81] test_upload: extend sleep before cancellation increase chances of actually getting a single chunk processed before the cancel takes place. --- tests/integration/test_upload.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/integration/test_upload.py b/tests/integration/test_upload.py index 5957dd9cc10..4c7f8e26997 100644 --- a/tests/integration/test_upload.py +++ b/tests/integration/test_upload.py @@ -670,10 +670,10 @@ async def test_cancel_upload_chunk( upload_box.send_keys(str(target_file)) upload_button.click() - await asyncio.sleep(1) + await asyncio.sleep(2) cancel_button.click() - await asyncio.sleep(12) + await asyncio.sleep(11) # But there should never be a final progress record for a cancelled upload. for p in driver.find_elements(By.XPATH, "//*[@id='stream_progress_dicts']/p"): From 4ae61e24d7c9e8e9b3b863b6abc9653fa892e0b3 Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Mon, 6 Apr 2026 15:22:20 -0700 Subject: [PATCH 72/81] re-add fix_events token param --- packages/reflex-base/src/reflex_base/event/__init__.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/packages/reflex-base/src/reflex_base/event/__init__.py b/packages/reflex-base/src/reflex_base/event/__init__.py index 9bb63f6d9a5..ca0347745f9 100644 --- a/packages/reflex-base/src/reflex_base/event/__init__.py +++ b/packages/reflex-base/src/reflex_base/event/__init__.py @@ -2102,6 +2102,7 @@ def get_handler_args( def fix_events( events: list[EventSpec | EventHandler] | None, + token: str | None = None, router_data: dict[str, Any] | None = None, ) -> list[Event]: """Fix a list of events returned by an event handler. @@ -2110,6 +2111,7 @@ def fix_events( Args: events: The events to fix. + token: Deprecated, ignored. Kept for backward compatibility. router_data: The optional router data to set in the event. Returns: From 80542a3f8c3c85edfe5b2729a46261051c407d77 Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Mon, 6 Apr 2026 15:32:16 -0700 Subject: [PATCH 73/81] Add StateManager.state property as a compat shim --- reflex/istate/manager/__init__.py | 24 ++++++++++++++++++++---- 1 file changed, 20 insertions(+), 4 deletions(-) diff --git a/reflex/istate/manager/__init__.py b/reflex/istate/manager/__init__.py index c2cf5e44b2e..eed0be7f642 100644 --- a/reflex/istate/manager/__init__.py +++ b/reflex/istate/manager/__init__.py @@ -27,11 +27,27 @@ class StateModificationContext(TypedDict, total=False): @dataclasses.dataclass class StateManager(ABC): - """A class to manage many client states. + """A class to manage many client states.""" - Attributes: - state: The state class to use. - """ + @property + def state(self): + """Get the state class. + + Deprecated: the state manager no longer holds a reference to the state class. + + Returns: + The State class. + """ + console.deprecate( + feature_name="StateManager.state", + reason="The state manager no longer holds a reference to the state class. " + "Use reflex.state.State directly instead.", + deprecation_version="0.9.0", + removal_version="1.0", + ) + from reflex.state import State + + return State @classmethod def create(cls): From e8644effe378ce2fdd151ac07187804b7d9eccf9 Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Mon, 6 Apr 2026 15:37:25 -0700 Subject: [PATCH 74/81] deprecate StateUpdate.final (instead of removal) --- reflex/state.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/reflex/state.py b/reflex/state.py index 413e20c0978..463a1057e50 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -2530,6 +2530,19 @@ class StateUpdate: # Events to be added to the event queue. events: list[Event] = dataclasses.field(default_factory=list) + # Deprecated: previously indicated whether the event processing is complete. + final: bool | None = dataclasses.field(default=None, repr=False) + + def __post_init__(self): + """Warn if the deprecated `final` attribute is supplied.""" + if self.final is not None: + console.deprecate( + feature_name="StateUpdate.final", + reason="The final attribute is no longer used.", + deprecation_version="0.9.0", + removal_version="1.0", + ) + @serializer(to=dict) def serialize_state_update(update: StateUpdate) -> dict: From 27f09a65de772635fd96b94cb410cf8fb1c6d614 Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Mon, 6 Apr 2026 16:05:39 -0700 Subject: [PATCH 75/81] Support legacy token format in StateManager implementations --- reflex/istate/manager/__init__.py | 18 ++++++++++++++++++ reflex/istate/manager/disk.py | 3 +++ reflex/istate/manager/memory.py | 3 +++ reflex/istate/manager/redis.py | 3 +++ 4 files changed, 27 insertions(+) diff --git a/reflex/istate/manager/__init__.py b/reflex/istate/manager/__init__.py index eed0be7f642..febaa09a237 100644 --- a/reflex/istate/manager/__init__.py +++ b/reflex/istate/manager/__init__.py @@ -88,6 +88,23 @@ def create(cls): msg = f"Expected one of: DISK, MEMORY, REDIS, got {config.state_manager_mode}" raise InvalidStateManagerModeError(msg) + @staticmethod + def _coerce_token(token: StateToken[TOKEN_TYPE] | str) -> StateToken[TOKEN_TYPE]: + """Convert a legacy string token to a StateToken if needed. + + Args: + token: The token, either a StateToken or legacy string. + + Returns: + The coerced StateToken. + """ + if isinstance(token, str): + from reflex.istate.manager.token import BaseStateToken + from reflex.state import State + + return BaseStateToken.from_legacy_token(token, root_state=State) # type: ignore[return-value] + return token + @abstractmethod async def get_state(self, token: StateToken[TOKEN_TYPE]) -> TOKEN_TYPE: """Get the state for a token. @@ -149,6 +166,7 @@ async def modify_state_with_links( """ from reflex.state import BaseState + token = self._coerce_token(token) async with self.modify_state(token, **context) as root_state: if ( isinstance(root_state, BaseState) diff --git a/reflex/istate/manager/disk.py b/reflex/istate/manager/disk.py index d61ef5e758a..9b75533ad93 100644 --- a/reflex/istate/manager/disk.py +++ b/reflex/istate/manager/disk.py @@ -168,6 +168,7 @@ async def get_state( Returns: The state for the token. """ + token = self._coerce_token(token) root_state = self.states.get(token.cache_key) self._token_last_touched[token.cache_key] = time.time() if root_state is not None: @@ -330,6 +331,7 @@ async def set_state( state: The state to set. context: The state modification context. """ + token = self._coerce_token(token) if self._write_debounce_seconds > 0: # Deferred write to reduce disk IO overhead. if token not in self._write_queue: @@ -358,6 +360,7 @@ async def modify_state( Yields: The state for the token. """ + token = self._coerce_token(token) # Disk state manager ignores the substate suffix and always returns the top-level state. lock_key = token.lock_key if lock_key not in self._states_locks: diff --git a/reflex/istate/manager/memory.py b/reflex/istate/manager/memory.py index f3dcb50ba53..07d4dc27926 100644 --- a/reflex/istate/manager/memory.py +++ b/reflex/istate/manager/memory.py @@ -152,6 +152,7 @@ async def get_state(self, token: StateToken[TOKEN_TYPE]) -> TOKEN_TYPE: Returns: The state for the token. """ + token = self._coerce_token(token) state = self._get_or_create_state(token) self._track_token(token) return state @@ -170,6 +171,7 @@ async def set_state( state: The state to set. context: The state modification context. """ + token = self._coerce_token(token) self.states[token.cache_key] = state self._track_token(token) @@ -187,6 +189,7 @@ async def modify_state( Yields: The state for the token. """ + token = self._coerce_token(token) state_lock = await self._get_state_lock(token) try: diff --git a/reflex/istate/manager/redis.py b/reflex/istate/manager/redis.py index 3bd708e7eb5..3a98e16c0d4 100644 --- a/reflex/istate/manager/redis.py +++ b/reflex/istate/manager/redis.py @@ -281,6 +281,7 @@ async def get_state( Raises: RuntimeError: when the parent state for a requested state was not fetched. """ + token = self._coerce_token(token) if not isinstance(token, BaseStateToken): # Non-BaseState token: simple single-key fetch. redis_data = await self.redis.get(str(token)) @@ -369,6 +370,7 @@ async def set_state( LockExpiredError: If lock_id is provided and the lock for the token is not held by that ID. RuntimeError: If the state instance doesn't match the state name in the token. """ + token = self._coerce_token(token) # Check that we're holding the lock. if ( lock_id is not None @@ -543,6 +545,7 @@ async def modify_state( Yields: The state for the token. """ + token = self._coerce_token(token) while True: async with self._try_modify_state(token, **context) as state_instance: if state_instance is not None: From b4a8dad4bca96751aebd032b8ac523fb3232e4a2 Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Mon, 6 Apr 2026 16:28:20 -0700 Subject: [PATCH 76/81] add test case for StateManager legacy str tokens --- tests/units/test_state.py | 69 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 69 insertions(+) diff --git a/tests/units/test_state.py b/tests/units/test_state.py index 900dad6ee29..8627d4127cf 100644 --- a/tests/units/test_state.py +++ b/tests/units/test_state.py @@ -1802,6 +1802,75 @@ async def _coro(): assert lock is None or not lock.locked() +@pytest.mark.asyncio +async def test_state_manager_legacy_token(state_manager: StateManager, token: str): + """Test that passing a legacy string token to the state manager works with a deprecation warning. + + Args: + state_manager: A state manager instance. + token: A token. + """ + from unittest.mock import patch + + import reflex_base.utils.console as _base_console + + from reflex.istate.manager import token as _token_mod + + console = _token_mod.console + + from reflex.state import State + + legacy_token = f"{token}_{OnLoadState.get_full_name()}" + + def _clear_dedupe(): + _base_console._EMITTED_DEPRECATION_WARNINGS -= { + k + for k in _base_console._EMITTED_DEPRECATION_WARNINGS + if "Passing a string to modify_state" in k + } + + _clear_dedupe() + + with patch.object(console, "deprecate", wraps=console.deprecate) as mock_deprecate: + # modify_state should accept a legacy string token and emit a deprecation warning. + async with state_manager.modify_state(legacy_token) as state: # pyright: ignore [reportArgumentType] + assert isinstance(state, State) + # The substate targeted by the token should be prepopulated. + assert OnLoadState.get_name() in state.substates + mock_deprecate.assert_called() + assert ( + mock_deprecate.call_args.kwargs["feature_name"] + == "Passing a string to modify_state" + ) + mock_deprecate.reset_mock() + + _clear_dedupe() + + with patch.object(console, "deprecate", wraps=console.deprecate) as mock_deprecate: + # get_state should also accept a legacy string token. + retrieved = await state_manager.get_state(legacy_token) # pyright: ignore [reportArgumentType] + assert isinstance(retrieved, State) + assert OnLoadState.get_name() in retrieved.substates + mock_deprecate.assert_called() + mock_deprecate.reset_mock() + + _clear_dedupe() + + with patch.object(console, "deprecate", wraps=console.deprecate) as mock_deprecate: + # set_state should also accept a legacy string token. + await state_manager.set_state(legacy_token, retrieved) # pyright: ignore [reportArgumentType] + mock_deprecate.assert_called() + mock_deprecate.reset_mock() + + _clear_dedupe() + + with patch.object(console, "deprecate", wraps=console.deprecate) as mock_deprecate: + final = await state_manager.get_state(legacy_token) # pyright: ignore [reportArgumentType] + assert isinstance(final, State) + assert OnLoadState.get_name() in final.substates + mock_deprecate.assert_called() + + @pytest_asyncio.fixture(loop_scope="function", scope="function") async def state_manager_redis() -> AsyncGenerator[StateManager, None]: """Instance of state manager for redis only. From 0d5ed8560f7b46e91329c92d635b2f507ae8934e Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Mon, 6 Apr 2026 16:29:05 -0700 Subject: [PATCH 77/81] ignore QueueShutDown when shutting down queue --- .../src/reflex_base/event/processor/event_processor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/reflex-base/src/reflex_base/event/processor/event_processor.py b/packages/reflex-base/src/reflex_base/event/processor/event_processor.py index 776c4b5dce1..4fc3f47dd8f 100644 --- a/packages/reflex-base/src/reflex_base/event/processor/event_processor.py +++ b/packages/reflex-base/src/reflex_base/event/processor/event_processor.py @@ -264,7 +264,7 @@ async def stop(self, graceful_shutdown_timeout: float | None = None) -> None: self._queue_task.cancel() try: await self._queue_task - except (asyncio.CancelledError, RuntimeError): + except (asyncio.CancelledError, asyncio.QueueShutDown, RuntimeError): pass except Exception as ex: telemetry.send_error(ex, context="backend") From 78d7040717ad065f94af1868bc4e9ea7ef175265 Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Mon, 6 Apr 2026 16:29:35 -0700 Subject: [PATCH 78/81] optimize test_event_processing benchmark avoid overhead in stop() and join() paths --- tests/benchmarks/test_event_processing.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/tests/benchmarks/test_event_processing.py b/tests/benchmarks/test_event_processing.py index 6990482cbdb..15acf8094d4 100644 --- a/tests/benchmarks/test_event_processing.py +++ b/tests/benchmarks/test_event_processing.py @@ -88,11 +88,10 @@ async def run_events(num_events: int, num_expected_deltas: int) -> None: emitted_deltas.clear() async with processor as p: - for _ in range(num_events): - await p.enqueue(token, event) - # Wait for the processor to drain all events. - await p.join(timeout=10) - + async for _ in asyncio.as_completed([ + await p.enqueue(token, event) for _ in range(num_events) + ]): + pass assert len(emitted_deltas) == num_expected_deltas yield run_events From a34eab4335852f279ff3e4b15725974e97a77f3b Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Mon, 6 Apr 2026 16:35:24 -0700 Subject: [PATCH 79/81] asyncio.QueueShutDown was only added in 3.13+ --- .../src/reflex_base/event/processor/event_processor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/reflex-base/src/reflex_base/event/processor/event_processor.py b/packages/reflex-base/src/reflex_base/event/processor/event_processor.py index 4fc3f47dd8f..f22c4fb78ba 100644 --- a/packages/reflex-base/src/reflex_base/event/processor/event_processor.py +++ b/packages/reflex-base/src/reflex_base/event/processor/event_processor.py @@ -264,7 +264,7 @@ async def stop(self, graceful_shutdown_timeout: float | None = None) -> None: self._queue_task.cancel() try: await self._queue_task - except (asyncio.CancelledError, asyncio.QueueShutDown, RuntimeError): + except (asyncio.CancelledError, QueueShutDown, RuntimeError): pass except Exception as ex: telemetry.send_error(ex, context="backend") From 98241e0ef7c5ba464b121337809c8fd38b9ff747 Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Mon, 6 Apr 2026 16:47:57 -0700 Subject: [PATCH 80/81] Add deprecated typing hints for passing str token to StateManager --- reflex/istate/manager/__init__.py | 70 ++++++++++++++++++++++++++++--- tests/units/test_state.py | 8 ++-- 2 files changed, 68 insertions(+), 10 deletions(-) diff --git a/reflex/istate/manager/__init__.py b/reflex/istate/manager/__init__.py index febaa09a237..46e64382d2f 100644 --- a/reflex/istate/manager/__init__.py +++ b/reflex/istate/manager/__init__.py @@ -4,13 +4,13 @@ import dataclasses from abc import ABC, abstractmethod from collections.abc import AsyncIterator -from typing import TypedDict +from typing import TYPE_CHECKING, TypedDict, overload from reflex_base import constants from reflex_base.config import get_config from reflex_base.event import Event from reflex_base.utils.exceptions import InvalidStateManagerModeError -from typing_extensions import ReadOnly, Unpack +from typing_extensions import ReadOnly, Unpack, deprecated from reflex.istate.manager.token import TOKEN_TYPE, StateToken from reflex.utils import console, prerequisites @@ -105,8 +105,64 @@ def _coerce_token(token: StateToken[TOKEN_TYPE] | str) -> StateToken[TOKEN_TYPE] return BaseStateToken.from_legacy_token(token, root_state=State) # type: ignore[return-value] return token + if TYPE_CHECKING: + + @overload + @deprecated("pass token as rx.BaseStateToken instead of str") + async def get_state(self, token: str) -> TOKEN_TYPE: ... + + @overload + async def get_state(self, token: StateToken[TOKEN_TYPE]) -> TOKEN_TYPE: ... + + @overload + @deprecated("pass token as rx.BaseStateToken instead of str") + async def set_state( + self, + token: str, + state: TOKEN_TYPE, + **context: Unpack[StateModificationContext], + ) -> None: ... + + @overload + async def set_state( + self, + token: StateToken[TOKEN_TYPE], + state: TOKEN_TYPE, + **context: Unpack[StateModificationContext], + ) -> None: ... + + @overload + @deprecated("pass token as rx.BaseStateToken instead of str") + def modify_state( + self, token: str, **context: Unpack[StateModificationContext] + ) -> contextlib.AbstractAsyncContextManager[TOKEN_TYPE]: ... + + @overload + def modify_state( + self, + token: StateToken[TOKEN_TYPE], + **context: Unpack[StateModificationContext], + ) -> contextlib.AbstractAsyncContextManager[TOKEN_TYPE]: ... + + @overload + @deprecated("pass token as rx.BaseStateToken instead of str") + def modify_state_with_links( + self, + token: str, + previous_dirty_vars: dict[str, set[str]] | None = None, + **context: Unpack[StateModificationContext], + ) -> contextlib.AbstractAsyncContextManager[TOKEN_TYPE]: ... + + @overload + def modify_state_with_links( + self, + token: StateToken[TOKEN_TYPE], + previous_dirty_vars: dict[str, set[str]] | None = None, + **context: Unpack[StateModificationContext], + ) -> contextlib.AbstractAsyncContextManager[TOKEN_TYPE]: ... + @abstractmethod - async def get_state(self, token: StateToken[TOKEN_TYPE]) -> TOKEN_TYPE: + async def get_state(self, token: StateToken[TOKEN_TYPE] | str) -> TOKEN_TYPE: """Get the state for a token. Args: @@ -119,7 +175,7 @@ async def get_state(self, token: StateToken[TOKEN_TYPE]) -> TOKEN_TYPE: @abstractmethod async def set_state( self, - token: StateToken[TOKEN_TYPE], + token: StateToken[TOKEN_TYPE] | str, state: TOKEN_TYPE, **context: Unpack[StateModificationContext], ): @@ -134,7 +190,9 @@ async def set_state( @abstractmethod @contextlib.asynccontextmanager async def modify_state( - self, token: StateToken[TOKEN_TYPE], **context: Unpack[StateModificationContext] + self, + token: StateToken[TOKEN_TYPE] | str, + **context: Unpack[StateModificationContext], ) -> AsyncIterator[TOKEN_TYPE]: """Modify the state for a token while holding exclusive lock. @@ -150,7 +208,7 @@ async def modify_state( @contextlib.asynccontextmanager async def modify_state_with_links( self, - token: StateToken[TOKEN_TYPE], + token: StateToken[TOKEN_TYPE] | str, previous_dirty_vars: dict[str, set[str]] | None = None, **context: Unpack[StateModificationContext], ) -> AsyncIterator[TOKEN_TYPE]: diff --git a/tests/units/test_state.py b/tests/units/test_state.py index 8627d4127cf..a61add0a435 100644 --- a/tests/units/test_state.py +++ b/tests/units/test_state.py @@ -1833,7 +1833,7 @@ def _clear_dedupe(): with patch.object(console, "deprecate", wraps=console.deprecate) as mock_deprecate: # modify_state should accept a legacy string token and emit a deprecation warning. - async with state_manager.modify_state(legacy_token) as state: # pyright: ignore [reportArgumentType] + async with state_manager.modify_state(legacy_token) as state: assert isinstance(state, State) # The substate targeted by the token should be prepopulated. assert OnLoadState.get_name() in state.substates @@ -1848,7 +1848,7 @@ def _clear_dedupe(): with patch.object(console, "deprecate", wraps=console.deprecate) as mock_deprecate: # get_state should also accept a legacy string token. - retrieved = await state_manager.get_state(legacy_token) # pyright: ignore [reportArgumentType] + retrieved = await state_manager.get_state(legacy_token) assert isinstance(retrieved, State) assert OnLoadState.get_name() in retrieved.substates mock_deprecate.assert_called() @@ -1858,14 +1858,14 @@ def _clear_dedupe(): with patch.object(console, "deprecate", wraps=console.deprecate) as mock_deprecate: # set_state should also accept a legacy string token. - await state_manager.set_state(legacy_token, retrieved) # pyright: ignore [reportArgumentType] + await state_manager.set_state(legacy_token, retrieved) mock_deprecate.assert_called() mock_deprecate.reset_mock() _clear_dedupe() with patch.object(console, "deprecate", wraps=console.deprecate) as mock_deprecate: - final = await state_manager.get_state(legacy_token) # pyright: ignore [reportArgumentType] + final = await state_manager.get_state(legacy_token) assert isinstance(final, State) assert OnLoadState.get_name() in final.substates mock_deprecate.assert_called() From 62deadf8136050d79597fe0ae05340f7173f0d91 Mon Sep 17 00:00:00 2001 From: Masen Furer Date: Mon, 6 Apr 2026 16:54:46 -0700 Subject: [PATCH 81/81] Move overloads outside of if TYPE_CHECKING --- reflex/istate/manager/__init__.py | 111 +++++++++++++++--------------- 1 file changed, 56 insertions(+), 55 deletions(-) diff --git a/reflex/istate/manager/__init__.py b/reflex/istate/manager/__init__.py index 46e64382d2f..5e3d71ee0e6 100644 --- a/reflex/istate/manager/__init__.py +++ b/reflex/istate/manager/__init__.py @@ -15,6 +15,9 @@ from reflex.istate.manager.token import TOKEN_TYPE, StateToken from reflex.utils import console, prerequisites +if TYPE_CHECKING: + from reflex.state import BaseState + class StateModificationContext(TypedDict, total=False): """The context for modifying state.""" @@ -105,61 +108,59 @@ def _coerce_token(token: StateToken[TOKEN_TYPE] | str) -> StateToken[TOKEN_TYPE] return BaseStateToken.from_legacy_token(token, root_state=State) # type: ignore[return-value] return token - if TYPE_CHECKING: - - @overload - @deprecated("pass token as rx.BaseStateToken instead of str") - async def get_state(self, token: str) -> TOKEN_TYPE: ... - - @overload - async def get_state(self, token: StateToken[TOKEN_TYPE]) -> TOKEN_TYPE: ... - - @overload - @deprecated("pass token as rx.BaseStateToken instead of str") - async def set_state( - self, - token: str, - state: TOKEN_TYPE, - **context: Unpack[StateModificationContext], - ) -> None: ... - - @overload - async def set_state( - self, - token: StateToken[TOKEN_TYPE], - state: TOKEN_TYPE, - **context: Unpack[StateModificationContext], - ) -> None: ... - - @overload - @deprecated("pass token as rx.BaseStateToken instead of str") - def modify_state( - self, token: str, **context: Unpack[StateModificationContext] - ) -> contextlib.AbstractAsyncContextManager[TOKEN_TYPE]: ... - - @overload - def modify_state( - self, - token: StateToken[TOKEN_TYPE], - **context: Unpack[StateModificationContext], - ) -> contextlib.AbstractAsyncContextManager[TOKEN_TYPE]: ... - - @overload - @deprecated("pass token as rx.BaseStateToken instead of str") - def modify_state_with_links( - self, - token: str, - previous_dirty_vars: dict[str, set[str]] | None = None, - **context: Unpack[StateModificationContext], - ) -> contextlib.AbstractAsyncContextManager[TOKEN_TYPE]: ... - - @overload - def modify_state_with_links( - self, - token: StateToken[TOKEN_TYPE], - previous_dirty_vars: dict[str, set[str]] | None = None, - **context: Unpack[StateModificationContext], - ) -> contextlib.AbstractAsyncContextManager[TOKEN_TYPE]: ... + @overload + @deprecated("pass token as rx.BaseStateToken instead of str") + async def get_state(self, token: str) -> "BaseState": ... + + @overload + async def get_state(self, token: StateToken[TOKEN_TYPE]) -> TOKEN_TYPE: ... + + @overload + @deprecated("pass token as rx.BaseStateToken instead of str") + async def set_state( + self, + token: str, + state: "BaseState", + **context: Unpack[StateModificationContext], + ) -> None: ... + + @overload + async def set_state( + self, + token: StateToken[TOKEN_TYPE], + state: TOKEN_TYPE, + **context: Unpack[StateModificationContext], + ) -> None: ... + + @overload + @deprecated("pass token as rx.BaseStateToken instead of str") + def modify_state( + self, token: str, **context: Unpack[StateModificationContext] + ) -> contextlib.AbstractAsyncContextManager["BaseState"]: ... + + @overload + def modify_state( + self, + token: StateToken[TOKEN_TYPE], + **context: Unpack[StateModificationContext], + ) -> contextlib.AbstractAsyncContextManager[TOKEN_TYPE]: ... + + @overload + @deprecated("pass token as rx.BaseStateToken instead of str") + def modify_state_with_links( + self, + token: str, + previous_dirty_vars: dict[str, set[str]] | None = None, + **context: Unpack[StateModificationContext], + ) -> contextlib.AbstractAsyncContextManager["BaseState"]: ... + + @overload + def modify_state_with_links( + self, + token: StateToken[TOKEN_TYPE], + previous_dirty_vars: dict[str, set[str]] | None = None, + **context: Unpack[StateModificationContext], + ) -> contextlib.AbstractAsyncContextManager[TOKEN_TYPE]: ... @abstractmethod async def get_state(self, token: StateToken[TOKEN_TYPE] | str) -> TOKEN_TYPE: