Skip to content

Commit 72c59b3

Browse files
committed
refactor: use fields name from dataclass in mutation callback for improved reliability
1 parent 38e12f6 commit 72c59b3

File tree

1 file changed

+10
-5
lines changed

1 file changed

+10
-5
lines changed

src/_algopy_testing/primitives/array.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@
1414
parameterize_type,
1515
)
1616

17+
if typing.TYPE_CHECKING:
18+
from _typeshed import DataclassInstance
19+
1720
_TArrayItem = typing.TypeVar("_TArrayItem")
1821
_TArrayLength = typing.TypeVar("_TArrayLength", bound=int)
1922
_T = typing.TypeVar("_T")
@@ -509,10 +512,15 @@ def from_bytes(cls, value: bytes, /) -> typing.Self:
509512
class Struct(Serializable, MutableBytes):
510513
"""Base class for Struct types."""
511514

515+
_field_names: typing.ClassVar[list[str]]
516+
512517
def __init_subclass__(cls, *args: typing.Any, **kwargs: dict[str, typing.Any]) -> None:
513518
# make implementation not frozen, so we can conditionally control behaviour
514519
dataclasses.dataclass(cls, *args, **{**kwargs, "frozen": False})
515520
frozen = kwargs.get("frozen", False)
521+
cls._field_names = [
522+
f.name for f in dataclasses.fields(typing.cast("type[DataclassInstance]", cls))
523+
]
516524
assert isinstance(frozen, bool)
517525

518526
def __post_init__(self) -> None:
@@ -527,11 +535,8 @@ def __getattribute__(self, name: str) -> typing.Any:
527535
def __setattr__(self, key: str, value: typing.Any) -> None:
528536
super().__setattr__(key, value)
529537
# don't update backing value until base class has been init'd
530-
if hasattr(self, "_on_mutate") and key not in {
531-
"_MutableBytes__value",
532-
"_on_mutate",
533-
"_value",
534-
}:
538+
539+
if hasattr(self, "_on_mutate") and key in self._field_names:
535540
self._update_backing_value()
536541

537542
def copy(self) -> typing.Self:

0 commit comments

Comments
 (0)