From 2c04360d794485630d11826b8b83bd31de963a40 Mon Sep 17 00:00:00 2001 From: CallumFu Date: Sat, 25 Apr 2026 11:29:21 +0800 Subject: [PATCH] Improve tupledict cache invalidation --- src/pyoptinterface/_src/tupledict.py | 42 ++++++++++++++++++++++++---- tests/test_tupledict.py | 30 ++++++++++++++++++++ 2 files changed, 67 insertions(+), 5 deletions(-) diff --git a/src/pyoptinterface/_src/tupledict.py b/src/pyoptinterface/_src/tupledict.py index ecf5fb69..77155efc 100644 --- a/src/pyoptinterface/_src/tupledict.py +++ b/src/pyoptinterface/_src/tupledict.py @@ -10,13 +10,44 @@ def __init__(self, *args, **kwargs): self.__select_cache = None self.__scalar = False + def __invalidate_cache(self): + self.__select_cache = None + def __setitem__(self, key, value): super().__setitem__(key, value) - self.__select_cache = None + self.__invalidate_cache() def __delitem__(self, key): super().__delitem__(key) - self.__select_cache = None + self.__invalidate_cache() + + def update(self, *args, **kwargs): + super().update(*args, **kwargs) + self.__invalidate_cache() + + def pop(self, *args): + result = super().pop(*args) + self.__invalidate_cache() + return result + + def popitem(self): + result = super().popitem() + self.__invalidate_cache() + return result + + def clear(self): + super().clear() + self.__invalidate_cache() + + def setdefault(self, key, default=None): + if key not in self: + self.__invalidate_cache() + return super().setdefault(key, default) + + def __ior__(self, other): + result = super().__ior__(other) + self.__invalidate_cache() + return result def __check_key_length(self): if len(self) == 0: @@ -40,6 +71,10 @@ def select(self, *keys, with_key=False): if len(keys) == 0: yield from () return + # Ensure key type is detected before branching + if self.__select_cache is None: + self.__check_key_length() + self.__select_cache = dict() if self.__scalar: if len(keys) != 1: raise ValueError( @@ -60,9 +95,6 @@ def select(self, *keys, with_key=False): else: yield from () else: - if self.__select_cache is None: - self.__check_key_length() - self.__select_cache = dict() key_len = self.__key_len if len(keys) > key_len: raise ValueError( diff --git a/tests/test_tupledict.py b/tests/test_tupledict.py index 4b5cd579..d4af466e 100644 --- a/tests/test_tupledict.py +++ b/tests/test_tupledict.py @@ -68,3 +68,33 @@ def test_tupledict_map(): assert list(td_m.values()) == [i**2 for i in range(10)] assert list(td_m.keys()) == list(td.keys()) + + +def test_tupledict_cache_invalidation(): + # update() should invalidate select cache + td = tupledict({(1, 2): "a", (1, 3): "b", (2, 2): "c"}) + assert list(td.select(1, WILDCARD)) == ["a", "b"] + td.update({(3, 2): "e"}) + assert list(td.select(3, WILDCARD)) == ["e"] + + # pop() should invalidate select cache + td2 = tupledict({(1, 2): "a", (1, 3): "b"}) + assert list(td2.select(1, WILDCARD)) == ["a", "b"] + td2.pop((1, 3)) + assert list(td2.select(1, WILDCARD)) == ["a"] + + # clear() should invalidate select cache + td3 = tupledict({(1, 2): "a"}) + list(td3.select(1, WILDCARD)) # build cache + td3.clear() + td3[(1, 2)] = "x" + assert list(td3.select(1, WILDCARD)) == ["x"] + + +def test_tupledict_scalar_key_select(): + # tupledict created via __init__ with scalar keys should work with select() + td = tupledict({1: "a", 2: "b", 3: "c"}) + assert list(td.select(1)) == ["a"] + assert sorted(td.select(WILDCARD)) == ["a", "b", "c"] + assert list(td.select(1, with_key=True)) == [(1, "a")] + assert list(td.select(99)) == []