1010from .session_timeseries import TimeSeries
1111from .session_counters import DocumentCounters
1212from typing import Dict , List
13+ from collections import MutableSet
1314
1415
1516class _SaveChangesData (object ):
@@ -19,6 +20,129 @@ def __init__(self, commands, deferred_command_count, entities=None):
1920 self .deferred_command_count = deferred_command_count
2021
2122
23+ class _RefEq :
24+ def __init__ (self , ref ):
25+ if isinstance (ref , _RefEq ):
26+ self .ref = ref .ref
27+ return
28+ self .ref = ref
29+
30+ # As we split the hashable and unhashable items into separate collections, we only compare _RefEq to other _RefEq
31+ def __eq__ (self , other ):
32+ if isinstance (other , _RefEq ):
33+ return id (self .ref ) == id (other .ref )
34+ raise TypeError ("Expected _RefEq type object" )
35+
36+ def __hash__ (self ):
37+ return id (self .ref )
38+
39+
40+ class _RefEqEntityHolder (object ):
41+ def __init__ (self ):
42+ self .unhashable_items = dict ()
43+
44+ def __len__ (self ):
45+ return len (self .unhashable_items )
46+
47+ def __contains__ (self , item ):
48+ return _RefEq (item ) in self .unhashable_items
49+
50+ def __delitem__ (self , key ):
51+ del self .unhashable_items [_RefEq (key )]
52+
53+ def __setitem__ (self , key , value ):
54+ self .unhashable_items [_RefEq (key )] = value
55+
56+ def __getitem__ (self , key ):
57+ return self .unhashable_items [_RefEq (key )]
58+
59+ def __getattribute__ (self , item ):
60+ if item == "unhashable_items" :
61+ return super ().__getattribute__ (item )
62+ return self .unhashable_items .__getattribute__ (item )
63+
64+
65+ class _DocumentsByEntityHolder (object ):
66+ def __init__ (self ):
67+ self ._hashable_items = dict ()
68+ self ._unhashable_items = _RefEqEntityHolder ()
69+
70+ def __repr__ (self ):
71+ return f"{ self .__class__ .__name__ } : { [item for item in self .__iter__ ()]} "
72+
73+ def __len__ (self ):
74+ return len (self ._hashable_items ) + len (self ._unhashable_items )
75+
76+ def __contains__ (self , item ):
77+ try :
78+ return item in self ._hashable_items
79+ except TypeError as e :
80+ if str (e .args [0 ]).startswith ("unhashable type" ):
81+ return item in self ._unhashable_items
82+ raise e
83+
84+ def __setitem__ (self , key , value ):
85+ try :
86+ self ._hashable_items [key ] = value
87+ except TypeError as e :
88+ if str (e .args [0 ]).startswith ("unhashable type" ):
89+ self ._unhashable_items [key ] = value
90+ return
91+ raise e
92+
93+ def __getitem__ (self , key ):
94+ try :
95+ return self ._hashable_items [key ]
96+ except (TypeError , KeyError ):
97+ return self ._unhashable_items [key ]
98+
99+ def __iter__ (self ):
100+ d = list (map (lambda x : x .ref , self ._unhashable_items .keys ()))
101+ if len (self ._hashable_items ) > 0 :
102+ d .extend (self ._hashable_items .keys ())
103+ return (item for item in d )
104+
105+ def get (self , key , default = None ):
106+ return self [key ] if key in self else default
107+
108+ def pop (self , key , default_value = None ):
109+ result = self ._hashable_items .pop (key , None )
110+ if result is not None :
111+ return result
112+ return self ._unhashable_items .pop (_RefEq (key ), default_value )
113+
114+ def clear (self ):
115+ self ._hashable_items .clear ()
116+ self ._unhashable_items .clear ()
117+
118+
119+ class _DeletedEntitiesHolder (MutableSet ):
120+ def __init__ (self , items = None ):
121+ if items is None :
122+ items = []
123+ self .items = set (map (_RefEq , items ))
124+
125+ def __getattribute__ (self , item ):
126+ if item in ["add" , "discard" , "items" ]:
127+ return super ().__getattribute__ (item )
128+ return self .items .__getattribute__ (item )
129+
130+ def __contains__ (self , item : object ) -> bool :
131+ return _RefEq (item ) in self .items
132+
133+ def __len__ (self ) -> int :
134+ return len (self .items )
135+
136+ def __iter__ (self ):
137+ return (item .ref for item in self .items )
138+
139+ def add (self , element : object ) -> None :
140+ return self .items .add (_RefEq (element ))
141+
142+ def discard (self , element : object ) -> None :
143+ return self .items .discard (_RefEq (element ))
144+
145+
22146class DocumentSession (object ):
23147 def __init__ (self , database , document_store , requests_executor , session_id , ** kwargs ):
24148 """
@@ -33,8 +157,8 @@ def __init__(self, database, document_store, requests_executor, session_id, **kw
33157 self ._requests_executor = requests_executor
34158 self ._documents_by_id = {}
35159 self ._included_documents_by_id = {}
36- self ._deleted_entities = set ()
37- self ._documents_by_entity = {}
160+ self ._deleted_entities = _DeletedEntitiesHolder ()
161+ self ._documents_by_entity = _DocumentsByEntityHolder ()
38162 self ._timeseries_defer_commands = {}
39163 self ._time_series_by_document_id = {}
40164 self ._counters_defer_commands = {}
0 commit comments