Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,56 @@ def __init__(self, *args, **kwargs):
if not self._is_null:
super(JsonObject, self).__init__(*args, **kwargs)

def __len__(self):
if self._is_null:
return 0
if self._is_array:
return len(self._array_value)
if self._is_scalar_value:
return 1
return super(JsonObject, self).__len__()

def __bool__(self):
if self._is_null:
return False
if self._is_array:
return bool(self._array_value)
if self._is_scalar_value:
return True
return len(self) > 0

def __iter__(self):
if self._is_array:
return iter(self._array_value)
if self._is_scalar_value:
raise TypeError(f"'{type(self._simple_value).__name__}' object is not iterable")
return super(JsonObject, self).__iter__()

def __getitem__(self, key):
if self._is_array:
return self._array_value[key]
if self._is_scalar_value:
raise TypeError(f"'{type(self._simple_value).__name__}' object is not subscriptable")
return super(JsonObject, self).__getitem__(key)

def __contains__(self, item):
if self._is_array:
return item in self._array_value
if self._is_scalar_value:
raise TypeError(f"argument of type '{type(self._simple_value).__name__}' is not iterable")
return super(JsonObject, self).__contains__(item)

def __eq__(self, other):
if isinstance(other, JsonObject):
return self.serialize() == other.serialize()
if self._is_array:
return self._array_value == other
if self._is_scalar_value:
return self._simple_value == other
if self._is_null:
return other is None or (isinstance(other, dict) and len(other) == 0)
return super(JsonObject, self).__eq__(other)
Comment on lines +99 to +108
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The current implementation of __eq__ has several issues:

  1. Inconsistency: JsonObject(None) == {} returns True (line 107), but JsonObject(None) == JsonObject({}) returns False (line 101) because their serializations differ ("null" vs "{}"). This violates the transitivity of equality. A JSON null should not be equal to an empty object {}.
  2. Efficiency: Comparing two JsonObject instances by serializing them to strings is very expensive ($O(N)$ time and memory). It is much more efficient to compare their internal values directly.
  3. Redundancy: The check for isinstance(other, dict) in line 107 is partially redundant with the final super().__eq__ call, but it currently leads to the incorrect equality with empty dicts for null values.

I recommend refactoring the method to handle each variant explicitly and optimize the comparison between JsonObject instances.

Suggested change
def __eq__(self, other):
if isinstance(other, JsonObject):
return self.serialize() == other.serialize()
if self._is_array:
return self._array_value == other
if self._is_scalar_value:
return self._simple_value == other
if self._is_null:
return other is None or (isinstance(other, dict) and len(other) == 0)
return super(JsonObject, self).__eq__(other)
def __eq__(self, other):
if isinstance(other, JsonObject):
if self._is_null:
return other._is_null
if self._is_array:
return other._is_array and self._array_value == other._array_value
if self._is_scalar_value:
return other._is_scalar_value and self._simple_value == other._simple_value
# Both are dict variants
if other._is_null or other._is_array or other._is_scalar_value:
return False
return super(JsonObject, self).__eq__(other)
if self._is_array:
return self._array_value == other
if self._is_scalar_value:
return self._simple_value == other
if self._is_null:
return other is None
return super(JsonObject, self).__eq__(other)


def __repr__(self):
if self._is_array:
return str(self._array_value)
Expand Down
87 changes: 87 additions & 0 deletions packages/google-cloud-spanner/tests/unit/test_datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,3 +96,90 @@ def test_w_JsonObject_of_list_of_simple_JsonData(self):
expected = json.dumps(data, sort_keys=True, separators=(",", ":"))
data_jsonobject = JsonObject(JsonObject(data))
self.assertEqual(data_jsonobject.serialize(), expected)


class Test_JsonObject_dict_protocol(unittest.TestCase):
"""Verify that JsonObject behaves correctly with standard Python
operations (len, bool, iteration, indexing) for all JSON variants."""

def test_array_len(self):
obj = JsonObject([{"id": 1}, {"id": 2}])
self.assertEqual(len(obj), 2)

def test_array_bool_truthy(self):
obj = JsonObject([{"id": 1}])
self.assertTrue(obj)

def test_array_bool_empty(self):
obj = JsonObject([])
self.assertFalse(obj)

def test_array_iter(self):
data = [{"a": 1}, {"b": 2}]
obj = JsonObject(data)
self.assertEqual(list(obj), data)

def test_array_getitem(self):
data = [{"a": 1}, {"b": 2}]
obj = JsonObject(data)
self.assertEqual(obj[0], {"a": 1})
self.assertEqual(obj[1], {"b": 2})

def test_array_contains(self):
data = [1, 2, 3]
obj = JsonObject(data)
self.assertIn(2, obj)
self.assertNotIn(4, obj)

def test_array_eq(self):
data = [{"id": 1}]
obj = JsonObject(data)
self.assertEqual(obj, data)

def test_array_json_dumps(self):
data = [{"id": "m1", "content": "hello"}]
obj = JsonObject(data)
result = json.loads(json.dumps(list(obj)))
self.assertEqual(result, data)

def test_dict_len(self):
obj = JsonObject({"a": 1, "b": 2})
self.assertEqual(len(obj), 2)

def test_dict_bool(self):
obj = JsonObject({"a": 1})
self.assertTrue(obj)

def test_dict_iter(self):
obj = JsonObject({"a": 1, "b": 2})
self.assertEqual(sorted(obj), ["a", "b"])

def test_dict_getitem(self):
obj = JsonObject({"key": "value"})
self.assertEqual(obj["key"], "value")

def test_null_len(self):
obj = JsonObject(None)
self.assertEqual(len(obj), 0)

def test_null_bool(self):
obj = JsonObject(None)
self.assertFalse(obj)

def test_scalar_len(self):
obj = JsonObject(42)
self.assertEqual(len(obj), 1)

def test_scalar_bool(self):
obj = JsonObject(42)
self.assertTrue(obj)

def test_scalar_not_iterable(self):
obj = JsonObject(42)
with self.assertRaises(TypeError):
iter(obj)

def test_scalar_not_subscriptable(self):
obj = JsonObject(42)
with self.assertRaises(TypeError):
obj[0]
Loading