Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions news/6689.feature.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Support using mutable state proxies as async context managers.
114 changes: 107 additions & 7 deletions reflex/istate/proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from collections.abc import Callable, Sequence
from importlib.util import find_spec
from types import MethodType
from typing import TYPE_CHECKING, Any, SupportsIndex, TypeVar
from typing import TYPE_CHECKING, Any, Literal, NoReturn, SupportsIndex, TypeVar

import wrapt
from reflex_base.event import Event
Expand All @@ -29,6 +29,13 @@

T_STATE = TypeVar("T_STATE", bound="BaseState")
T = TypeVar("T")
AccessSpec = (
tuple[Literal["attr"], str]
| tuple[Literal["item"], Any]
| tuple[Literal["iter"], None]
)
_ITER_ACCESS_SPEC: AccessSpec = ("iter", None)
_ROOT_ITER_PATH: tuple[AccessSpec, ...] = (_ITER_ACCESS_SPEC,)

# Cached filename of the dataclasses module, used to detect reads originating
# from `dataclasses.asdict`/`astuple` internals on the proxy read hot-path.
Expand Down Expand Up @@ -446,18 +453,27 @@ def __new__(cls, wrapped: Any, *args, **kwargs) -> MutableProxy:
cls = cls.__dataclass_proxies__[wrapper_cls_name]
return super().__new__(cls) # pyright: ignore[reportArgumentType]

def __init__(self, wrapped: Any, state: BaseState, field_name: str):
def __init__(
self,
wrapped: Any,
state: BaseState,
field_name: str,
path: tuple[AccessSpec, ...] | None = None,
):
"""Create a proxy for a mutable object that tracks changes.

Args:
wrapped: The object to proxy.
state: The state to mark dirty when the object is changed.
field_name: The name of the field on the state associated with the
wrapped object.
path: Access path from the state field to this wrapped object.
"""
super().__init__(wrapped)
self._self_state = state
self._self_field_name = field_name
self._self_path = path or ()
self._self_actx_state = None

def __repr__(self) -> str:
"""Get the representation of the wrapped object.
Expand All @@ -467,6 +483,73 @@ def __repr__(self) -> str:
"""
return f"{type(self).__name__}({self.__wrapped__})"

async def __aenter__(self) -> Self:
"""Enter the async context manager protocol through the bound state.

Returns:
This proxy refreshed from the current state field.
"""
if self._self_actx_state is not None:
msg = (
"Mutable proxy is already mutable. Do not nest `async with proxy` "
"blocks."
)
raise RuntimeError(msg)
context_state = self._self_state
self._self_actx_state = context_state
Comment thread
harsh21234i marked this conversation as resolved.
aenter_ok = False
try:
state = await context_state.__aenter__()
aenter_ok = True
refreshed_value = getattr(state, self._self_field_name)
for access_spec in self._self_path:
match access_spec:
case ("attr", access_key):
refreshed_value = getattr(refreshed_value, access_key)
case ("item", access_key):
refreshed_value = refreshed_value[access_key]
case _:
self._raise_refresh_error()
if (
isinstance(refreshed_value, MutableProxy)
and self._self_field_name == refreshed_value._self_field_name
):
super().__setattr__("__wrapped__", refreshed_value.__wrapped__)
self._self_state = refreshed_value._self_state
Comment thread
harsh21234i marked this conversation as resolved.
self._self_path = refreshed_value._self_path
else:
self._raise_refresh_error()
except (Exception, asyncio.CancelledError):
try:
if aenter_ok:
await context_state.__aexit__(*sys.exc_info())
finally:
self._self_actx_state = None
raise
Comment thread
greptile-apps[bot] marked this conversation as resolved.
return self
Comment thread
greptile-apps[bot] marked this conversation as resolved.

def _raise_refresh_error(self) -> NoReturn:
"""Raise when this proxy cannot be refreshed from its state field."""
msg = (
"Unable to refresh mutable proxy from state field "
f"`{self._self_field_name}`."
)
raise RuntimeError(msg)

async def __aexit__(self, *exc_info: Any) -> None:
"""Exit the async context manager protocol through the bound state.

Args:
exc_info: The exception info tuple.
"""
context_state = self._self_actx_state
if context_state is None:
return
try:
await context_state.__aexit__(*exc_info)
finally:
self._self_actx_state = None

def _mark_dirty(
self,
wrapped: Callable | None = None,
Expand Down Expand Up @@ -513,11 +596,14 @@ def _is_called_from_dataclasses_internal() -> bool:
return True
return False

def _wrap_recursive(self, value: Any) -> Any:
def _wrap_recursive(
self, value: Any, new_path_segment: AccessSpec | None = None
) -> Any:
"""Wrap a value recursively if it is mutable.

Args:
value: The value to wrap.
new_path_segment: Access path segment from this proxy to the value.

Returns:
The wrapped value.
Expand All @@ -532,10 +618,18 @@ def _wrap_recursive(self, value: Any) -> Any:
# Recursively wrap mutable types.
if is_mutable_type(type(value)):
base_cls = globals()[self.__base_proxy__]
path = self._self_path
if new_path_segment is not None:
path = (
_ROOT_ITER_PATH
if new_path_segment is _ITER_ACCESS_SPEC and not path
else (*path, new_path_segment)
)
return base_cls(
wrapped=value,
state=self._self_state,
field_name=self._self_field_name,
path=path,
)
return value

Expand Down Expand Up @@ -592,10 +686,12 @@ def __getattr__(self, __name: str) -> Any:
if is_mutable_type(type(value)) and __name not in (
"__wrapped__",
"_self_state",
"_self_path",
Comment thread
harsh21234i marked this conversation as resolved.
"_self_actx_state",
"__dict__",
):
# Recursively wrap mutable attribute values retrieved through this proxy.
return self._wrap_recursive(value)
return self._wrap_recursive(value, ("attr", __name))

return value

Expand All @@ -610,9 +706,13 @@ def __getitem__(self, key: Any) -> Any:
"""
value = super().__getitem__(key) # pyright: ignore[reportAttributeAccessIssue]
if isinstance(key, slice) and isinstance(value, list):
return [self._wrap_recursive(item) for item in value]
indices = range(len(self.__wrapped__))[key]
return [
self._wrap_recursive(item, ("item", index))
for item, index in zip(value, indices, strict=False)
]
# Recursively wrap mutable items retrieved through this proxy.
return self._wrap_recursive(value)
return self._wrap_recursive(value, ("item", key))

def __iter__(self) -> Any:
"""Iterate over the proxied object and return a proxy if mutable.
Expand All @@ -622,7 +722,7 @@ def __iter__(self) -> Any:
"""
for value in super().__iter__(): # pyright: ignore[reportAttributeAccessIssue]
# Recursively wrap mutable items retrieved through this proxy.
yield self._wrap_recursive(value)
yield self._wrap_recursive(value, _ITER_ACCESS_SPEC)

def __delattr__(self, name: str):
"""Delete the attribute on the proxied object and mark state dirty.
Expand Down
179 changes: 179 additions & 0 deletions tests/units/test_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -4954,6 +4954,12 @@ class MutableProxyState(BaseState):
data: dict[str, list[int]] = {"a": [1], "b": [2]}


class IterableMutableProxyState(BaseState):
"""A test state with mutable values that can be accessed by iteration."""

data: list[list[int]] = [[1]]


@pytest.mark.asyncio
async def test_rebind_mutable_proxy(
token: str, attached_mock_event_context: EventContext
Expand Down Expand Up @@ -5006,6 +5012,179 @@ async def test_rebind_mutable_proxy(
assert state.data["b"] == [2, 3]


@pytest.mark.asyncio
async def test_immutable_mutable_proxy_async_context_manager(
token: str, attached_mock_event_context: EventContext
) -> None:
"""Mutable state proxies can enter the owning StateProxy context."""
state_manager = attached_mock_event_context.state_manager

async with state_manager.modify_state(
BaseStateToken(ident=token, cls=MutableProxyState)
) as state:
state.router = RouterData.from_router_data({
"query": {},
"token": token,
"sid": "test_sid",
})
state_proxy = StateProxy(state)
data_proxy = state_proxy.data
items_proxy = data_proxy["a"]

assert isinstance(data_proxy, ImmutableMutableProxy)
assert isinstance(items_proxy, ImmutableMutableProxy)
await data_proxy.__aexit__(None, None, None)
with pytest.raises(ImmutableStateError):
data_proxy["a"].append(3)
with pytest.raises(ImmutableStateError):
items_proxy.append(3)

async with state_manager.modify_state(
BaseStateToken(ident=token, cls=MutableProxyState)
) as state:
assert isinstance(state, MutableProxyState)
state.data["a"].append(2)

async with data_proxy as mutable_data:
assert mutable_data is data_proxy
with pytest.raises(RuntimeError, match="already mutable"):
async with data_proxy:
pass
data_proxy["a"].append(3)
mutable_data["b"].append(4)

async with items_proxy as mutable_items:
assert mutable_items is items_proxy
mutable_items.append(5)
assert items_proxy.__wrapped__ == [1, 2, 3, 5]

with pytest.raises(ImmutableStateError):
data_proxy["a"].append(6)
with pytest.raises(ImmutableStateError):
items_proxy.append(6)

async with state_manager.modify_state(
BaseStateToken(ident=token, cls=MutableProxyState)
) as state:
assert isinstance(state, MutableProxyState)
assert state.data["a"] == [1, 2, 3, 5]
assert state.data["b"] == [2, 4]


@pytest.mark.asyncio
async def test_immutable_mutable_proxy_async_context_rejects_iter_proxy(
token: str, attached_mock_event_context: EventContext
) -> None:
"""Iteration-sourced mutable proxies fail clearly as async context managers."""
state_manager = attached_mock_event_context.state_manager

async with state_manager.modify_state(
BaseStateToken(ident=token, cls=IterableMutableProxyState)
) as state:
state.router = RouterData.from_router_data({
"query": {},
"token": token,
"sid": "test_sid",
})
state_proxy = StateProxy(state)
[items_proxy] = state_proxy.data

assert isinstance(items_proxy, ImmutableMutableProxy)
with pytest.raises(RuntimeError, match="Unable to refresh mutable proxy"):
async with items_proxy:
pass

async with state_proxy:
state_proxy.data[0].append(2)

async with state_manager.modify_state(
BaseStateToken(ident=token, cls=IterableMutableProxyState)
) as state:
assert isinstance(state, IterableMutableProxyState)
assert state.data == [[1, 2]]


@pytest.mark.asyncio
async def test_immutable_mutable_proxy_async_context_recovers_from_enter_failure(
token: str,
attached_mock_event_context: EventContext,
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""A failed owning StateProxy enter does not permanently block the proxy."""
state_manager = attached_mock_event_context.state_manager

async with state_manager.modify_state(
BaseStateToken(ident=token, cls=MutableProxyState)
) as state:
state.router = RouterData.from_router_data({
"query": {},
"token": token,
"sid": "test_sid",
})
state_proxy = StateProxy(state)
data_proxy = state_proxy.data

original_aenter = StateProxy.__aenter__
fail_once = True

async def fail_first_enter(self: StateProxy) -> StateProxy:
nonlocal fail_once
if fail_once:
fail_once = False
raise asyncio.CancelledError
return await original_aenter(self)

monkeypatch.setattr(StateProxy, "__aenter__", fail_first_enter)

with pytest.raises(asyncio.CancelledError):
async with data_proxy:
pass

assert data_proxy._self_actx_state is None
async with data_proxy as mutable_data:
mutable_data["a"].append(2)

async with state_manager.modify_state(
BaseStateToken(ident=token, cls=MutableProxyState)
) as state:
assert isinstance(state, MutableProxyState)
assert state.data["a"] == [1, 2]


@pytest.mark.asyncio
async def test_immutable_mutable_proxy_async_context_clears_state_when_cleanup_fails(
token: str,
attached_mock_event_context: EventContext,
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""A failed cleanup while entering does not leave the proxy marked mutable."""
state_manager = attached_mock_event_context.state_manager

async with state_manager.modify_state(
BaseStateToken(ident=token, cls=IterableMutableProxyState)
) as state:
state.router = RouterData.from_router_data({
"query": {},
"token": token,
"sid": "test_sid",
})
state_proxy = StateProxy(state)
[items_proxy] = state_proxy.data

async def fail_exit(self: StateProxy, *exc_info: Any) -> None:
await asyncio.sleep(0)
msg = "cleanup failed"
raise RuntimeError(msg)

monkeypatch.setattr(StateProxy, "__aexit__", fail_exit)

with pytest.raises(RuntimeError, match="cleanup failed"):
async with items_proxy:
pass

assert items_proxy._self_actx_state is None


def test_override_base_method_skips_event_handler_wrapping():
"""A method marked with __override_base_method__ should not be wrapped as an EventHandler."""
from reflex.state import _override_base_method
Expand Down