Skip to content

Commit cf7a7aa

Browse files
committed
update object mapper
1 parent 61bc6e4 commit cf7a7aa

File tree

1 file changed

+19
-15
lines changed

1 file changed

+19
-15
lines changed

dictdatabase/object_mapper.py

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,7 @@ def fill_object_from_dict_using_type_hints(obj, cls, data: dict):
3535
if var_name not in data:
3636
nullable = var_type_origin is UnionType and NoneType in var_type_args
3737
if not nullable:
38-
raise RuntimeError(f"Missing variable '{var_name}' in {cls.__name__}.")
39-
continue
38+
raise KeyError(f"Missing variable '{var_name}' in {cls.__name__}.")
4039
# When it is a list, fill the list with the items
4140
if var_type_origin is list and len(var_type_args) == 1:
4241
item_type = var_type_args[0]
@@ -68,7 +67,15 @@ class FileDictModel(ABC, Generic[T]):
6867
"""
6968

7069
__file__ = None
71-
__item_model__: Type[T]
70+
71+
@classmethod
72+
def _get_item_model(cls):
73+
for base in cls.__orig_bases__:
74+
for type_args in get_args(base):
75+
if issubclass(type_args, FileDictItemModel):
76+
return type_args
77+
raise AttributeError("FileDictModel must specify a FileDictItemModel")
78+
7279

7380
@classmethod
7481
def get_at_key(cls, key) -> T:
@@ -77,20 +84,17 @@ def get_at_key(cls, key) -> T:
7784
The data is partially read from the __file__.
7885
"""
7986
data = io_safe.partial_read(cls.__file__, key)
80-
res: T = cls.__item_model__.from_key_value(key, data)
87+
res: T = cls._get_item_model().from_key_value(key, data)
8188
return res
8289

8390
@classmethod
8491
def session_at_key(cls, key):
85-
return cls.__item_model__.session(key)
92+
return cls._get_item_model().session(key)
8693

8794
@classmethod
88-
def items(cls) -> list[Tuple[str, T]]:
89-
"""
90-
Gets all items as a list of tuples (key, ORM model of value).
91-
"""
95+
def get_all(cls) -> dict[str, T]:
9296
data = io_safe.read(cls.__file__)
93-
return [(k, cls.__item_model__.from_key_value(k, v)) for k, v in data.items()]
97+
return {k: cls._get_item_model().from_key_value(k, v) for k, v in data.items()}
9498

9599
@classmethod
96100
def session(cls):
@@ -101,18 +105,18 @@ def session(cls):
101105
def make_session_obj_from_dict(data):
102106
sess_obj = {}
103107
for k, v in data.items():
104-
sess_obj[k] = cls.__item_model__.from_key_value(k, v)
108+
sess_obj[k] = cls._get_item_model().from_key_value(k, v)
105109
return sess_obj
106110
return SessionFileFull(cls.__file__, make_session_obj_from_dict)
107111

108112

109113
@classmethod
110-
def get_where(cls, where: callable) -> list[Tuple[str, T]]:
114+
def get_where(cls, where: callable[str, T]) -> dict[str, T]:
111115
"""
112-
Gets all items where the where function returns True.
113-
The where function takes an object of type __item_model__.
116+
Return a dictionary of all the items for which the where function returns True.
117+
Where takes the key and the value's model object as arguments.
114118
"""
115-
return [(k, v) for k, v in cls.items() if where(v)]
119+
return {k: v for k, v in cls.get_all().items() if where(k, v)}
116120

117121

118122

0 commit comments

Comments
 (0)