Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 37 additions & 5 deletions src/pyoptinterface/_src/tupledict.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,44 @@ def __init__(self, *args, **kwargs):
self.__select_cache = None
self.__scalar = False

def __invalidate_cache(self):
self.__select_cache = None

def __setitem__(self, key, value):
super().__setitem__(key, value)
self.__select_cache = None
self.__invalidate_cache()

def __delitem__(self, key):
super().__delitem__(key)
self.__select_cache = None
self.__invalidate_cache()

def update(self, *args, **kwargs):
super().update(*args, **kwargs)
self.__invalidate_cache()

def pop(self, *args):
result = super().pop(*args)
self.__invalidate_cache()
return result

def popitem(self):
result = super().popitem()
self.__invalidate_cache()
return result

def clear(self):
super().clear()
self.__invalidate_cache()

def setdefault(self, key, default=None):
if key not in self:
self.__invalidate_cache()
return super().setdefault(key, default)

def __ior__(self, other):
result = super().__ior__(other)
self.__invalidate_cache()
return result

def __check_key_length(self):
if len(self) == 0:
Expand All @@ -40,6 +71,10 @@ def select(self, *keys, with_key=False):
if len(keys) == 0:
yield from ()
return
# Ensure key type is detected before branching
if self.__select_cache is None:
self.__check_key_length()
self.__select_cache = dict()
if self.__scalar:
if len(keys) != 1:
raise ValueError(
Expand All @@ -60,9 +95,6 @@ def select(self, *keys, with_key=False):
else:
yield from ()
else:
if self.__select_cache is None:
self.__check_key_length()
self.__select_cache = dict()
key_len = self.__key_len
if len(keys) > key_len:
raise ValueError(
Expand Down
30 changes: 30 additions & 0 deletions tests/test_tupledict.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,3 +68,33 @@ def test_tupledict_map():
assert list(td_m.values()) == [i**2 for i in range(10)]

assert list(td_m.keys()) == list(td.keys())


def test_tupledict_cache_invalidation():
# update() should invalidate select cache
td = tupledict({(1, 2): "a", (1, 3): "b", (2, 2): "c"})
assert list(td.select(1, WILDCARD)) == ["a", "b"]
td.update({(3, 2): "e"})
assert list(td.select(3, WILDCARD)) == ["e"]

# pop() should invalidate select cache
td2 = tupledict({(1, 2): "a", (1, 3): "b"})
assert list(td2.select(1, WILDCARD)) == ["a", "b"]
td2.pop((1, 3))
assert list(td2.select(1, WILDCARD)) == ["a"]

# clear() should invalidate select cache
td3 = tupledict({(1, 2): "a"})
list(td3.select(1, WILDCARD)) # build cache
td3.clear()
td3[(1, 2)] = "x"
assert list(td3.select(1, WILDCARD)) == ["x"]


def test_tupledict_scalar_key_select():
# tupledict created via __init__ with scalar keys should work with select()
td = tupledict({1: "a", 2: "b", 3: "c"})
assert list(td.select(1)) == ["a"]
assert sorted(td.select(WILDCARD)) == ["a", "b", "c"]
assert list(td.select(1, with_key=True)) == [(1, "a")]
assert list(td.select(99)) == []
Loading