diff --git a/fastapi_startkit/src/fastapi_startkit/masoniteorm/models/attribute.py b/fastapi_startkit/src/fastapi_startkit/masoniteorm/models/attribute.py index a0249f1e..dd280a57 100644 --- a/fastapi_startkit/src/fastapi_startkit/masoniteorm/models/attribute.py +++ b/fastapi_startkit/src/fastapi_startkit/masoniteorm/models/attribute.py @@ -93,6 +93,11 @@ def is_dirty(self) -> bool: def get_dirty(self) -> dict: return {key: value for key, value in self.get_attributes().items() if not self.original_is_equivalent(key)} + def delete_attribute(self, key: str): + """Remove a transient attribute that was added during query processing.""" + self._attributes.pop(key, None) + self._dirty_attributes.pop(key, None) + def get_attributes_for_insert(self) -> dict: # _dirty_attributes already went through set_attribute (casts applied on assignment). # _attributes is set raw via new_model_instance, so apply set casts here. diff --git a/fastapi_startkit/src/fastapi_startkit/masoniteorm/models/builder.py b/fastapi_startkit/src/fastapi_startkit/masoniteorm/models/builder.py index 90b74497..01ba4636 100644 --- a/fastapi_startkit/src/fastapi_startkit/masoniteorm/models/builder.py +++ b/fastapi_startkit/src/fastapi_startkit/masoniteorm/models/builder.py @@ -63,6 +63,10 @@ def with_(self, *eagers) -> "QueryBuilder": def get_table_name(self) -> str: return self._table + def table(self, table: str) -> "QueryBuilder": + self._table = table + return self + def where_in(self, column: str, values) -> "QueryBuilder": if hasattr(values, "_items"): values = values._items @@ -136,6 +140,10 @@ def run_scopes(self) -> "QueryBuilder": scope(self) return self + def without_global_scopes(self) -> "QueryBuilder": + self._global_scopes = {} + return self + def get_grammar(self): return self.grammar( columns=self._columns, diff --git a/fastapi_startkit/src/fastapi_startkit/masoniteorm/relationships/BelongsToMany.py b/fastapi_startkit/src/fastapi_startkit/masoniteorm/relationships/BelongsToMany.py index 61c6a3c9..57a6054d 100644 --- a/fastapi_startkit/src/fastapi_startkit/masoniteorm/relationships/BelongsToMany.py +++ b/fastapi_startkit/src/fastapi_startkit/masoniteorm/relationships/BelongsToMany.py @@ -23,8 +23,9 @@ def __init__( attribute="pivot", with_fields=[], ): + fn_str = fn if isinstance(fn, str): - self.fn = self.fn = lambda x: registry.Registry.resolve(fn) + self.fn = lambda: registry.Registry.resolve(fn_str) self.local_key = local_foreign_key self.foreign_key = other_foreign_key @@ -134,16 +135,11 @@ async def apply_query(self, query, owner): pivot_data.update({field: getattr(model, field)}) model.delete_attribute(field) - model.__original_attributes__.update( - { - self._as: ( - Pivot.on(query.connection) - .table(self._table) - .hydrate(pivot_data) - .activate_timestamps(self.with_timestamps) - ) - } - ) + pivot_model = Pivot() + pivot_model.__table__ = self._table + pivot_model.__timestamps__ = self.with_timestamps + pivot_model.set_raw_attributes(pivot_data, True) + model._attributes[self._as] = pivot_model return result @@ -266,16 +262,11 @@ async def get_related(self, query, relation, eagers=None, callback=None): pivot_data.update({field: getattr(model, field)}) model.delete_attribute(field) - model.__original_attributes__.update( - { - self._as: ( - Pivot.on(builder.connection_name) - .table(self._table) - .hydrate(pivot_data) - .activate_timestamps(self.with_timestamps) - ) - } - ) + pivot_model = Pivot() + pivot_model.__table__ = self._table + pivot_model.__timestamps__ = self.with_timestamps + pivot_model.set_raw_attributes(pivot_data, True) + model._attributes[self._as] = pivot_model return final_result @@ -487,7 +478,9 @@ def attach(self, current_model, related_record): self.foreign_key: getattr(related_record, self.other_owner_key), } - self._table = self._table or self.get_pivot_table_name(current_model, related_record) + self._table = self._table or self.get_pivot_table_name( + current_model.get_builder(), related_record.get_builder() + ) if self.with_timestamps: data.update( @@ -497,7 +490,7 @@ def attach(self, current_model, related_record): } ) - return Pivot.on(current_model.get_builder().connection).table(self._table).without_global_scopes().create(data) + return current_model.get_builder().connection.query().table(self._table).insert(data) def detach(self, current_model, related_record): data = { @@ -505,12 +498,14 @@ def detach(self, current_model, related_record): self.foreign_key: getattr(related_record, self.other_owner_key), } - self._table = self._table or self.get_pivot_table_name(current_model, related_record) + self._table = self._table or self.get_pivot_table_name( + current_model.get_builder(), related_record.get_builder() + ) return ( - Pivot.on(current_model.get_builder().connection) + current_model.get_builder() + .connection.query() .table(self._table) - .without_global_scopes() .where(data) .delete() ) @@ -521,7 +516,9 @@ def attach_related(self, current_model, related_record): self.foreign_key: getattr(related_record, self.other_owner_key), } - self._table = self._table or self.get_pivot_table_name(current_model, related_record) + self._table = self._table or self.get_pivot_table_name( + current_model.get_builder(), related_record.get_builder() + ) if self.with_timestamps: data.update( @@ -531,12 +528,7 @@ def attach_related(self, current_model, related_record): } ) - return ( - Pivot.table(self._table) - .on(current_model.get_builder().connection_name) - .without_global_scopes() - .create(data) - ) + return current_model.get_builder().connection.query().table(self._table).insert(data) def detach_related(self, current_model, related_record): data = { @@ -544,7 +536,9 @@ def detach_related(self, current_model, related_record): self.foreign_key: getattr(related_record, self.other_owner_key), } - self._table = self._table or self.get_pivot_table_name(current_model, related_record) + self._table = self._table or self.get_pivot_table_name( + current_model.get_builder(), related_record.get_builder() + ) if self.with_timestamps: data.update( @@ -555,9 +549,9 @@ def detach_related(self, current_model, related_record): ) return ( - Pivot.on(current_model.get_builder().connection_name) + current_model.get_builder() + .connection.query() .table(self._table) - .without_global_scopes() .where(data) .delete() ) diff --git a/fastapi_startkit/src/fastapi_startkit/masoniteorm/relationships/MorphMany.py b/fastapi_startkit/src/fastapi_startkit/masoniteorm/relationships/MorphMany.py index af655ed5..985d455c 100644 --- a/fastapi_startkit/src/fastapi_startkit/masoniteorm/relationships/MorphMany.py +++ b/fastapi_startkit/src/fastapi_startkit/masoniteorm/relationships/MorphMany.py @@ -18,9 +18,12 @@ def set_keys(self, owner, attribute): return self def __get__(self, instance, owner): + if instance is None: + return self + attribute = self.fn.__name__ - self._related_builder = instance.builder - self.polymorphic_builder = self.fn(self)() + self._related_builder = instance.get_builder() + self.polymorphic_builder = self.fn(self).query() self.set_keys(owner, self.fn) if not instance.is_loaded(): @@ -32,8 +35,7 @@ def __get__(self, instance, owner): return self.apply_query(self._related_builder, instance) def __getattr__(self, attribute): - relationship = self.fn(self)() - return getattr(relationship.builder, attribute) + return getattr(self.fn(self).query(), attribute) def apply_query(self, builder, instance): """Apply the query and return a dictionary to be hydrated @@ -45,7 +47,7 @@ def apply_query(self, builder, instance): Returns: dict -- A dictionary of data which will be hydrated. """ - polymorphic_key = self.get_record_key_lookup(builder._model) + polymorphic_key = self.get_record_key_lookup(instance) polymorphic_builder = self.polymorphic_builder return ( polymorphic_builder.where(self.morph_key, polymorphic_key) @@ -115,13 +117,7 @@ def morph_map(self): return registry.Registry.get_morph_map() def get_record_key_lookup(self, relation): - record_type = None - for record_type_loop, model in self.morph_map().items(): - if model == relation.__class__: - record_type = record_type_loop - break - - if not record_type: + morph_name = registry.Registry._reverse_map.get(relation.__class__) + if morph_name is None: raise ValueError(f"Could not find the record type key for the {relation} class") - - return record_type + return morph_name diff --git a/fastapi_startkit/src/fastapi_startkit/masoniteorm/relationships/MorphOne.py b/fastapi_startkit/src/fastapi_startkit/masoniteorm/relationships/MorphOne.py index 72ce6b48..abac605f 100644 --- a/fastapi_startkit/src/fastapi_startkit/masoniteorm/relationships/MorphOne.py +++ b/fastapi_startkit/src/fastapi_startkit/masoniteorm/relationships/MorphOne.py @@ -131,7 +131,9 @@ def register_related(self, key, model, collection): model.add_relation({key: related}) def morph_map(self): - return load_config().DB._morph_map + from fastapi_startkit.masoniteorm.models import registry + + return registry.Registry.get_morph_map() def get_record_key_lookup(self, relation): record_type = None diff --git a/fastapi_startkit/src/fastapi_startkit/masoniteorm/relationships/MorphToMany.py b/fastapi_startkit/src/fastapi_startkit/masoniteorm/relationships/MorphToMany.py index f647f250..dd1ebefd 100644 --- a/fastapi_startkit/src/fastapi_startkit/masoniteorm/relationships/MorphToMany.py +++ b/fastapi_startkit/src/fastapi_startkit/masoniteorm/relationships/MorphToMany.py @@ -100,4 +100,6 @@ def register_related(self, key, model, collection): model.add_relation({key: related}) def morph_map(self): - return load_config().DB._morph_map + from fastapi_startkit.masoniteorm.models import registry + + return registry.Registry.get_morph_map() diff --git a/fastapi_startkit/tests/masoniteorm/processors/__init__.py b/fastapi_startkit/tests/masoniteorm/processors/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/fastapi_startkit/tests/masoniteorm/processors/test_post_processors.py b/fastapi_startkit/tests/masoniteorm/processors/test_post_processors.py new file mode 100644 index 00000000..493cf117 --- /dev/null +++ b/fastapi_startkit/tests/masoniteorm/processors/test_post_processors.py @@ -0,0 +1,211 @@ +from unittest import TestCase +from unittest.mock import MagicMock + +from fastapi_startkit.masoniteorm.query.processors.MSSQLPostProcessor import MSSQLPostProcessor +from fastapi_startkit.masoniteorm.query.processors.MySQLPostProcessor import MySQLPostProcessor +from fastapi_startkit.masoniteorm.query.processors.PostgresPostProcessor import PostgresPostProcessor +from fastapi_startkit.masoniteorm.query.processors.SQLitePostProcessor import SQLitePostProcessor + + +class TestMySQLPostProcessor(TestCase): + def setUp(self): + self.processor = MySQLPostProcessor() + + def test_process_insert_get_id_when_id_already_in_results(self): + builder = MagicMock() + results = {"id": 5, "name": "test"} + result = self.processor.process_insert_get_id(builder, results, "id") + self.assertEqual(result["id"], 5) + builder._connection.get_cursor.assert_not_called() + + def test_process_insert_get_id_when_id_not_in_results(self): + builder = MagicMock() + builder._connection.get_cursor.return_value.lastrowid = 42 + results = {"name": "test"} + result = self.processor.process_insert_get_id(builder, results, "id") + self.assertEqual(result["id"], 42) + builder._connection.get_cursor.assert_called_once() + + def test_process_insert_get_id_custom_key(self): + builder = MagicMock() + builder._connection.get_cursor.return_value.lastrowid = 99 + results = {"email": "test@example.com"} + result = self.processor.process_insert_get_id(builder, results, "user_id") + self.assertEqual(result["user_id"], 99) + + def test_get_column_value_with_id_key_and_value(self): + builder = MagicMock() + mock_new_builder = MagicMock() + builder.select.return_value = mock_new_builder + mock_new_builder.first.return_value = {"name": "Alice"} + + result = self.processor.get_column_value(builder, "name", {}, "id", 1) + self.assertEqual(result, "Alice") + builder.select.assert_called_once_with("name") + mock_new_builder.where.assert_called_once_with("id", 1) + + def test_get_column_value_without_id_key(self): + builder = MagicMock() + result = self.processor.get_column_value(builder, "name", {}, None, None) + self.assertEqual(result, {}) + + def test_get_column_value_without_id_value(self): + builder = MagicMock() + result = self.processor.get_column_value(builder, "name", {}, "id", None) + self.assertEqual(result, {}) + + +class TestSQLitePostProcessor(TestCase): + def setUp(self): + self.processor = SQLitePostProcessor() + + def test_process_insert_get_id_when_id_already_in_results(self): + builder = MagicMock() + results = {"id": 7, "name": "test"} + result = self.processor.process_insert_get_id(builder, results, "id") + self.assertEqual(result["id"], 7) + builder.get_connection.assert_not_called() + + def test_process_insert_get_id_when_id_not_in_results(self): + builder = MagicMock() + builder.get_connection.return_value.get_last_row_id.return_value = 10 + results = {"name": "test"} + result = self.processor.process_insert_get_id(builder, results, "id") + self.assertEqual(result["id"], 10) + builder.get_connection.assert_called_once() + + def test_process_insert_get_id_default_key(self): + builder = MagicMock() + builder.get_connection.return_value.get_last_row_id.return_value = 3 + results = {} + result = self.processor.process_insert_get_id(builder, results) + self.assertEqual(result["id"], 3) + + def test_get_column_value_with_id_key_and_value(self): + builder = MagicMock() + mock_new_builder = MagicMock() + builder.select.return_value = mock_new_builder + mock_new_builder.first.return_value = {"email": "bob@example.com"} + + result = self.processor.get_column_value(builder, "email", {}, "id", 2) + self.assertEqual(result, "bob@example.com") + + def test_get_column_value_without_id(self): + builder = MagicMock() + result = self.processor.get_column_value(builder, "email", {}, None, None) + self.assertEqual(result, {}) + + def test_get_column_value_without_id_value(self): + builder = MagicMock() + result = self.processor.get_column_value(builder, "email", {}, "id", None) + self.assertEqual(result, {}) + + +class TestPostgresPostProcessor(TestCase): + def setUp(self): + self.processor = PostgresPostProcessor() + + def test_process_insert_get_id_with_lastval_result(self): + builder = MagicMock() + results = {"lastval": 99} + result = self.processor.process_insert_get_id(builder, results, "id") + self.assertEqual(result, {"id": 99}) + + def test_process_insert_get_id_with_regular_results(self): + builder = MagicMock() + results = {"id": 5, "name": "test"} + result = self.processor.process_insert_get_id(builder, results, "id") + self.assertEqual(result, {"id": 5, "name": "test"}) + + def test_process_insert_get_id_with_non_dict_result(self): + builder = MagicMock() + results = {"id": 5, "extra": "field"} + result = self.processor.process_insert_get_id(builder, results, "id") + self.assertEqual(result["id"], 5) + + def test_process_insert_get_id_lastval_uses_correct_key(self): + builder = MagicMock() + results = {"lastval": 77} + result = self.processor.process_insert_get_id(builder, results, "user_id") + self.assertEqual(result, {"user_id": 77}) + + def test_get_column_value_when_column_in_results(self): + builder = MagicMock() + results = {"name": "Charlie", "email": "charlie@example.com"} + result = self.processor.get_column_value(builder, "name", results, "id", 1) + self.assertEqual(result, "Charlie") + builder.select.assert_not_called() + + def test_get_column_value_when_column_not_in_results_with_id(self): + builder = MagicMock() + mock_new_builder = MagicMock() + builder.select.return_value = mock_new_builder + mock_new_builder.first.return_value = {"name": "Dave"} + + result = self.processor.get_column_value(builder, "name", {}, "id", 3) + self.assertEqual(result, "Dave") + + def test_get_column_value_without_id(self): + builder = MagicMock() + result = self.processor.get_column_value(builder, "name", {}, None, None) + self.assertEqual(result, {}) + + def test_get_column_value_without_id_value(self): + builder = MagicMock() + result = self.processor.get_column_value(builder, "name", {}, "id", None) + self.assertEqual(result, {}) + + +class TestMSSQLPostProcessor(TestCase): + def setUp(self): + self.processor = MSSQLPostProcessor() + + def test_process_insert_get_id_with_integer_id(self): + builder = MagicMock() + builder.new_connection.return_value.query.return_value = {"id": "42"} + results = {"name": "test"} + result = self.processor.process_insert_get_id(builder, results, "id") + self.assertEqual(result["id"], 42) + self.assertIsInstance(result["id"], int) + + def test_process_insert_get_id_with_string_id(self): + builder = MagicMock() + builder.new_connection.return_value.query.return_value = {"id": "abc-123"} + results = {"name": "test"} + result = self.processor.process_insert_get_id(builder, results, "id") + self.assertEqual(result["id"], "abc-123") + self.assertIsInstance(result["id"], str) + + def test_process_insert_get_id_custom_key(self): + builder = MagicMock() + builder.new_connection.return_value.query.return_value = {"id": "7"} + results = {} + result = self.processor.process_insert_get_id(builder, results, "record_id") + self.assertEqual(result["record_id"], 7) + + def test_process_insert_get_id_calls_select_identity(self): + builder = MagicMock() + builder.new_connection.return_value.query.return_value = {"id": "1"} + self.processor.process_insert_get_id(builder, {}, "id") + builder.new_connection.return_value.query.assert_called_once_with( + "SELECT @@Identity as [id]", results=1 + ) + + def test_get_column_value_with_id_key_and_value(self): + builder = MagicMock() + mock_new_builder = MagicMock() + builder.select.return_value = mock_new_builder + mock_new_builder.first.return_value = {"score": 95} + + result = self.processor.get_column_value(builder, "score", {}, "id", 1) + self.assertEqual(result, 95) + + def test_get_column_value_without_id(self): + builder = MagicMock() + result = self.processor.get_column_value(builder, "score", {}, None, None) + self.assertEqual(result, {}) + + def test_get_column_value_without_id_value(self): + builder = MagicMock() + result = self.processor.get_column_value(builder, "score", {}, "id", None) + self.assertEqual(result, {}) diff --git a/fastapi_startkit/tests/masoniteorm/sqlite/relationships/test_belongs_to_many.py b/fastapi_startkit/tests/masoniteorm/sqlite/relationships/test_belongs_to_many.py new file mode 100644 index 00000000..48b44ea3 --- /dev/null +++ b/fastapi_startkit/tests/masoniteorm/sqlite/relationships/test_belongs_to_many.py @@ -0,0 +1,121 @@ +from ...fixtures.model import Product, Store +from ..test_case import TestCase + + +class TestBelongsToManyRelationship(TestCase): + async def asyncSetUp(self): + await super().asyncSetUp() + self.store = await Store.create({"name": "Test Store"}) + self.product1 = await Product.create({"name": "Widget"}) + self.product2 = await Product.create({"name": "Gadget"}) + + async def test_attach_creates_pivot_record(self): + await Store.products.attach(self.store, self.product1) + store = await Store.where("id", self.store.id).first() + products = await store.products + self.assertEqual(len(products), 1) + self.assertEqual(products[0].name, "Widget") + + async def test_attach_multiple_products(self): + await Store.products.attach(self.store, self.product1) + await Store.products.attach(self.store, self.product2) + store = await Store.where("id", self.store.id).first() + products = await store.products + self.assertEqual(len(products), 2) + + async def test_detach_removes_pivot_record(self): + await Store.products.attach(self.store, self.product1) + await Store.products.attach(self.store, self.product2) + await Store.products.detach(self.store, self.product1) + store = await Store.where("id", self.store.id).first() + products = await store.products + self.assertEqual(len(products), 1) + self.assertEqual(products[0].name, "Gadget") + + async def test_eager_load_belongs_to_many(self): + await Store.products.attach(self.store, self.product1) + stores = await Store.with_("products").get() + store = stores.where("id", self.store.id).first() + self.assertIsNotNone(store) + self.assertEqual(len(store.products), 1) + + async def test_eager_load_empty_relationship(self): + stores = await Store.with_("products").get() + store = stores.where("id", self.store.id).first() + self.assertIsNotNone(store) + # Empty BelongsToMany eager load returns None (consistent with other relationships) + self.assertIsNone(store.products) + + async def test_pivot_access_after_eager_load(self): + await Store.products.attach(self.store, self.product1) + stores = await Store.with_("products").get() + store = stores.where("id", self.store.id).first() + product = store.products[0] + pivot = product.pivot + self.assertIsNotNone(pivot) + + async def test_with_timestamps_in_pivot(self): + # Store.products uses with_timestamps=True + await Store.products.attach(self.store, self.product1) + stores = await Store.with_("products").get() + store = stores.where("id", self.store.id).first() + product = store.products[0] + self.assertIsNotNone(product.pivot) + + async def test_explicit_table_relationship(self): + # Store.products_table uses table="product_table" + await Store.products_table.attach(self.store, self.product1) + store = await Store.where("id", self.store.id).first() + products = await store.products_table + self.assertEqual(len(products), 1) + + async def test_attach_related_creates_pivot_record(self): + await Store.products.attach_related(self.store, self.product1) + store = await Store.where("id", self.store.id).first() + products = await store.products + self.assertEqual(len(products), 1) + + async def test_detach_related_removes_pivot_record(self): + await Store.products.attach_related(self.store, self.product1) + await Store.products.detach_related(self.store, self.product1) + store = await Store.where("id", self.store.id).first() + products = await store.products + self.assertEqual(len(products), 0) + + async def test_get_pivot_table_name(self): + # Test the helper method directly using builder proxies + rel = Store.products + # Manually set the pivot table name to test the method + rel._table = None + name = rel.get_pivot_table_name(self.store.get_builder(), self.product1.get_builder()) + self.assertEqual(name, "product_store") + + async def test_map_related_returns_result(self): + results = [self.product1, self.product2] + rel = Store.products + mapped = rel.map_related(results) + self.assertEqual(mapped, results) + + async def test_register_related_groups_by_owner_key(self): + await Store.products.attach(self.store, self.product1) + stores = await Store.with_("products").get() + store = stores.where("id", self.store.id).first() + # If register_related works, the products collection is populated + self.assertEqual(len(store.products), 1) + + async def test_query_has_filters_stores_with_products(self): + store2 = await Store.create({"name": "Empty Store"}) + await Store.products.attach(self.store, self.product1) + + stores_with_products = await Store.where_has("products").get() + store_ids = [s.id for s in stores_with_products] + + self.assertIn(self.store.id, store_ids) + self.assertNotIn(store2.id, store_ids) + + async def test_query_has_returns_builder(self): + # query_has should add a where_exists clause to the builder + builder = Store.query() + result = Store.products.query_has(builder, method="where_exists") + # The builder has a where clause appended (we just verify no error is raised) + self.assertIsNotNone(builder) diff --git a/fastapi_startkit/tests/masoniteorm/sqlite/relationships/test_morph_many.py b/fastapi_startkit/tests/masoniteorm/sqlite/relationships/test_morph_many.py new file mode 100644 index 00000000..ddda9e7c --- /dev/null +++ b/fastapi_startkit/tests/masoniteorm/sqlite/relationships/test_morph_many.py @@ -0,0 +1,131 @@ +from fastapi_startkit.masoniteorm import Model +from fastapi_startkit.masoniteorm.models.registry import Registry +from fastapi_startkit.masoniteorm.relationships import MorphMany +from ...fixtures.model import Like, Product +from ..test_case import TestCase + + +def likes(self): + return Like + + +class ArticleModel(Model): + __table__ = "articles" + likes = MorphMany(likes, morph_key="likeable_type", morph_id="likeable_id") + + +Registry.morph_map({"article": ArticleModel, "product": Product}) + + +class TestMorphManyRelationship(TestCase): + async def asyncSetUp(self): + await super().asyncSetUp() + self.article = await ArticleModel.create( + { + "title": "Test Article", + "user_id": 1, + "published_date": "2024-01-01 00:00:00", + } + ) + + async def test_morph_many_init_stores_keys(self): + rel = MorphMany(likes, morph_key="likeable_type", morph_id="likeable_id") + self.assertEqual(rel.morph_key, "likeable_type") + self.assertEqual(rel.morph_id, "likeable_id") + self.assertEqual(rel.fn, likes) + + async def test_morph_many_set_keys_defaults(self): + rel = MorphMany(likes) + rel.set_keys(None, None) + self.assertEqual(rel.morph_key, "record_type") + self.assertEqual(rel.morph_id, "record_id") + + async def test_morph_many_set_keys_keeps_existing(self): + rel = MorphMany(likes, morph_key="likeable_type", morph_id="likeable_id") + rel.set_keys(None, None) + self.assertEqual(rel.morph_key, "likeable_type") + self.assertEqual(rel.morph_id, "likeable_id") + + async def test_morph_map_returns_registry(self): + rel = MorphMany(likes, morph_key="likeable_type", morph_id="likeable_id") + morph_map = rel.morph_map() + self.assertIn("article", morph_map) + self.assertIn("product", morph_map) + + async def test_get_record_key_lookup_returns_key(self): + rel = MorphMany(likes, morph_key="likeable_type", morph_id="likeable_id") + key = rel.get_record_key_lookup(self.article) + self.assertEqual(key, "article") + + async def test_apply_query_returns_likes_for_article(self): + await Like.create({"likeable_type": "article", "likeable_id": self.article.id}) + article = await ArticleModel.where("id", self.article.id).first() + likes_result = await article.likes + self.assertEqual(len(likes_result), 1) + + async def test_apply_query_returns_empty_for_article_without_likes(self): + # Don't create any likes for this article + article = await ArticleModel.where("id", self.article.id).first() + likes_result = await article.likes + self.assertEqual(len(likes_result), 0) + + async def test_eager_load_morph_many(self): + await Like.create({"likeable_type": "article", "likeable_id": self.article.id}) + articles = await ArticleModel.with_("likes").get() + article = articles.where("id", self.article.id).first() + self.assertIsNotNone(article) + self.assertEqual(len(article.likes), 1) + + async def test_eager_load_morph_many_multiple(self): + await Like.create({"likeable_type": "article", "likeable_id": self.article.id}) + await Like.create({"likeable_type": "article", "likeable_id": self.article.id}) + articles = await ArticleModel.with_("likes").get() + article = articles.where("id", self.article.id).first() + self.assertEqual(len(article.likes), 2) + + async def test_get_related_with_single_model(self): + await Like.create({"likeable_type": "article", "likeable_id": self.article.id}) + rel = ArticleModel.likes + rel._related_builder = self.article.get_builder() + rel.polymorphic_builder = Like.query() + + result = rel.get_related(None, self.article) + likes_result = await result + self.assertEqual(len(likes_result), 1) + + async def test_get_related_with_collection(self): + await Like.create({"likeable_type": "article", "likeable_id": self.article.id}) + articles = await ArticleModel.get() + rel = ArticleModel.likes + rel._related_builder = self.article.get_builder() + rel.polymorphic_builder = Like.query() + + result = rel.get_related(None, articles) + likes_result = await result + self.assertGreaterEqual(len(likes_result), 1) + + async def test_get_related_with_callback(self): + await Like.create({"likeable_type": "article", "likeable_id": self.article.id}) + rel = ArticleModel.likes + rel._related_builder = self.article.get_builder() + rel.polymorphic_builder = Like.query() + + result = rel.get_related(None, self.article, callback=lambda q: q) + likes_result = await result + self.assertEqual(len(likes_result), 1) + + async def test_register_related_adds_relation(self): + await Like.create({"likeable_type": "article", "likeable_id": self.article.id}) + await Like.create({"likeable_type": "product", "likeable_id": 999}) + + all_likes = await Like.get() + article = await ArticleModel.where("id", self.article.id).first() + + rel = ArticleModel.likes + rel.register_related("likes", article, all_likes) + + self.assertIn("likes", article._relationships) + article_likes = article._relationships["likes"] + # Only likes for this article type are included + for like in article_likes: + self.assertEqual(like.likeable_type, "article") diff --git a/fastapi_startkit/tests/masoniteorm/sqlite/relationships/test_morph_one.py b/fastapi_startkit/tests/masoniteorm/sqlite/relationships/test_morph_one.py new file mode 100644 index 00000000..ca06e36b --- /dev/null +++ b/fastapi_startkit/tests/masoniteorm/sqlite/relationships/test_morph_one.py @@ -0,0 +1,139 @@ +from fastapi_startkit.masoniteorm import Model +from fastapi_startkit.masoniteorm.models.registry import Registry +from fastapi_startkit.masoniteorm.relationships import MorphOne +from ...fixtures.model import Like, Product +from ..test_case import TestCase + + +def first_like(self): + return Like + + +class ArticleModelMorphOne(Model): + __table__ = "articles" + first_like = MorphOne(first_like, morph_key="likeable_type", morph_id="likeable_id") + + +Registry.morph_map({"article_one": ArticleModelMorphOne, "product": Product}) + + +class TestMorphOneRelationship(TestCase): + async def asyncSetUp(self): + await super().asyncSetUp() + self.article = await ArticleModelMorphOne.create( + { + "title": "Test Article", + "user_id": 1, + "published_date": "2024-01-01 00:00:00", + } + ) + + async def test_morph_one_init_with_function(self): + rel = MorphOne(first_like, morph_key="likeable_type", morph_id="likeable_id") + self.assertEqual(rel.morph_key, "likeable_type") + self.assertEqual(rel.morph_id, "likeable_id") + self.assertEqual(rel.fn, first_like) + + async def test_morph_one_init_with_string(self): + # When a string is passed, it's used as morph_key and second arg as morph_id + rel = MorphOne("likeable_type", "likeable_id") + self.assertIsNone(rel.fn) + self.assertEqual(rel.morph_key, "likeable_type") + self.assertEqual(rel.morph_id, "likeable_id") + + async def test_morph_one_set_keys_defaults(self): + rel = MorphOne(first_like) + rel.set_keys(None, None) + self.assertEqual(rel.morph_key, "record_type") + self.assertEqual(rel.morph_id, "record_id") + + async def test_morph_one_set_keys_keeps_existing(self): + rel = MorphOne(first_like, morph_key="likeable_type", morph_id="likeable_id") + rel.set_keys(None, None) + self.assertEqual(rel.morph_key, "likeable_type") + self.assertEqual(rel.morph_id, "likeable_id") + + async def test_morph_map_uses_registry(self): + rel = MorphOne(first_like, morph_key="likeable_type", morph_id="likeable_id") + morph_map = rel.morph_map() + self.assertIsInstance(morph_map, dict) + self.assertIn("article_one", morph_map) + + async def test_get_record_key_lookup_returns_key(self): + rel = MorphOne(first_like, morph_key="likeable_type", morph_id="likeable_id") + key = rel.get_record_key_lookup(self.article) + self.assertEqual(key, "article_one") + + async def test_get_record_key_lookup_raises_for_unknown(self): + rel = MorphOne(first_like, morph_key="likeable_type", morph_id="likeable_id") + + class UnknownModel(Model): + __table__ = "articles" + + unknown = UnknownModel() + with self.assertRaises(ValueError): + rel.get_record_key_lookup(unknown) + + async def test_apply_query_returns_single_like(self): + await Like.create({"likeable_type": "article_one", "likeable_id": self.article.id}) + await Like.create({"likeable_type": "article_one", "likeable_id": self.article.id}) + article = await ArticleModelMorphOne.where("id", self.article.id).first() + result = await article.first_like + # MorphOne returns a single record (first()) + self.assertIsNotNone(result) + self.assertIsInstance(result, Like) + + async def test_apply_query_returns_none_when_no_likes(self): + article = await ArticleModelMorphOne.where("id", self.article.id).first() + result = await article.first_like + self.assertIsNone(result) + + async def test_eager_load_morph_one(self): + await Like.create({"likeable_type": "article_one", "likeable_id": self.article.id}) + articles = await ArticleModelMorphOne.with_("first_like").get() + article = articles.where("id", self.article.id).first() + self.assertIsNotNone(article) + self.assertIsInstance(article.first_like, Like) + + async def test_get_related_with_single_model(self): + await Like.create({"likeable_type": "article_one", "likeable_id": self.article.id}) + rel = ArticleModelMorphOne.first_like + rel._related_builder = self.article.builder + rel.polymorphic_builder = first_like(rel)() + + result = rel.get_related(None, self.article) + like = await result + self.assertIsNotNone(like) + self.assertIsInstance(like, Like) + + async def test_get_related_with_collection(self): + await Like.create({"likeable_type": "article_one", "likeable_id": self.article.id}) + articles = await ArticleModelMorphOne.get() + rel = ArticleModelMorphOne.first_like + rel._related_builder = self.article.builder + rel.polymorphic_builder = first_like(rel)() + + result = rel.get_related(None, articles) + # With Collection, returns .get() (a Collection) + likes_result = await result + self.assertGreaterEqual(len(likes_result), 1) + + async def test_get_related_with_callback(self): + await Like.create({"likeable_type": "article_one", "likeable_id": self.article.id}) + rel = ArticleModelMorphOne.first_like + rel._related_builder = self.article.builder + rel.polymorphic_builder = first_like(rel)() + + result = rel.get_related(None, self.article, callback=lambda q: q) + like = await result + self.assertIsNotNone(like) + + async def test_register_related_adds_first(self): + await Like.create({"likeable_type": "article_one", "likeable_id": self.article.id}) + all_likes = await Like.get() + article = await ArticleModelMorphOne.where("id", self.article.id).first() + + rel = ArticleModelMorphOne.first_like + rel.register_related("first_like", article, all_likes) + + self.assertIn("first_like", article._relationships) diff --git a/fastapi_startkit/tests/masoniteorm/sqlite/relationships/test_morph_to_many.py b/fastapi_startkit/tests/masoniteorm/sqlite/relationships/test_morph_to_many.py new file mode 100644 index 00000000..3aa0a126 --- /dev/null +++ b/fastapi_startkit/tests/masoniteorm/sqlite/relationships/test_morph_to_many.py @@ -0,0 +1,134 @@ +from fastapi_startkit.masoniteorm import Model +from fastapi_startkit.masoniteorm.collection import Collection +from fastapi_startkit.masoniteorm.models.registry import Registry +from fastapi_startkit.masoniteorm.relationships import MorphToMany +from ...fixtures.model import Articles, Product +from ..test_case import TestCase + + +def record(self): + return None + + +class LikeModelMorphToMany(Model): + __table__ = "likes" + record = MorphToMany(record, morph_key="likeable_type", morph_id="likeable_id") + + +# Register article/product so morph_map resolves them +Registry.morph_map({"article_m2m": Articles, "product_m2m": Product}) + + +class TestMorphToManyRelationship(TestCase): + async def asyncSetUp(self): + await super().asyncSetUp() + self.article = await Articles.create( + { + "title": "M2M Article", + "user_id": 1, + "published_date": "2024-01-01 00:00:00", + } + ) + self.product = await Product.create({"name": "M2M Product"}) + + async def test_morph_to_many_init_with_function(self): + rel = MorphToMany(record, morph_key="likeable_type", morph_id="likeable_id") + self.assertEqual(rel.morph_key, "likeable_type") + self.assertEqual(rel.morph_id, "likeable_id") + self.assertEqual(rel.fn, record) + + async def test_morph_to_many_init_with_string(self): + rel = MorphToMany("likeable_type", "likeable_id") + self.assertIsNone(rel.fn) + self.assertEqual(rel.morph_key, "likeable_type") + self.assertEqual(rel.morph_id, "likeable_id") + + async def test_morph_to_many_set_keys_defaults(self): + rel = MorphToMany(record) + rel.set_keys(None, None) + self.assertEqual(rel.morph_key, "record_type") + self.assertEqual(rel.morph_id, "record_id") + + async def test_morph_to_many_set_keys_keeps_existing(self): + rel = MorphToMany(record, morph_key="likeable_type", morph_id="likeable_id") + rel.set_keys(None, None) + self.assertEqual(rel.morph_key, "likeable_type") + self.assertEqual(rel.morph_id, "likeable_id") + + async def test_morph_map_uses_registry(self): + rel = MorphToMany(record, morph_key="likeable_type", morph_id="likeable_id") + morph_map = rel.morph_map() + self.assertIsInstance(morph_map, dict) + self.assertIn("article_m2m", morph_map) + self.assertIn("product_m2m", morph_map) + + async def test_apply_query_resolves_article(self): + from fastapi_startkit.masoniteorm.models.model import Model as BaseModel + + like = await LikeModelMorphToMany.create( + {"likeable_type": "article_m2m", "likeable_id": self.article.id} + ) + like_loaded = await LikeModelMorphToMany.where("id", like.id).first() + + rel = LikeModelMorphToMany.record + rel.set_keys(LikeModelMorphToMany, rel.fn) + result = rel.apply_query(like_loaded.builder, like_loaded) + resolved = await result + self.assertIsNotNone(resolved) + self.assertIsInstance(resolved, Articles) + + async def test_apply_query_resolves_product(self): + like = await LikeModelMorphToMany.create( + {"likeable_type": "product_m2m", "likeable_id": self.product.id} + ) + like_loaded = await LikeModelMorphToMany.where("id", like.id).first() + + rel = LikeModelMorphToMany.record + rel.set_keys(LikeModelMorphToMany, rel.fn) + result = rel.apply_query(like_loaded.builder, like_loaded) + resolved = await result + self.assertIsNotNone(resolved) + self.assertIsInstance(resolved, Product) + + async def test_get_related_with_collection(self): + await LikeModelMorphToMany.create( + {"likeable_type": "article_m2m", "likeable_id": self.article.id} + ) + await LikeModelMorphToMany.create( + {"likeable_type": "product_m2m", "likeable_id": self.product.id} + ) + + likes = await LikeModelMorphToMany.get() + rel = LikeModelMorphToMany.record + rel.set_keys(LikeModelMorphToMany, rel.fn) + + resolved = await rel.get_related(None, likes) + self.assertIsInstance(resolved, Collection) + self.assertGreaterEqual(resolved.count(), 1) + + async def test_get_related_with_single_model_no_match(self): + like = await LikeModelMorphToMany.create( + {"likeable_type": "unknown_type", "likeable_id": 999} + ) + like_loaded = await LikeModelMorphToMany.where("id", like.id).first() + + rel = LikeModelMorphToMany.record + rel.set_keys(LikeModelMorphToMany, rel.fn) + result = await rel.get_related(None, like_loaded) + self.assertIsNone(result) + + async def test_register_related_maps_to_model(self): + await LikeModelMorphToMany.create( + {"likeable_type": "article_m2m", "likeable_id": self.article.id} + ) + await LikeModelMorphToMany.create( + {"likeable_type": "product_m2m", "likeable_id": self.product.id} + ) + + all_articles = await Articles.get() + like_article = await LikeModelMorphToMany.where("likeable_type", "article_m2m").first() + + rel = LikeModelMorphToMany.record + rel.register_related("record", like_article, all_articles) + + self.assertIn("record", like_article._relationships) diff --git a/fastapi_startkit/tests/masoniteorm/testing/__init__.py b/fastapi_startkit/tests/masoniteorm/testing/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/fastapi_startkit/tests/masoniteorm/testing/test_transaction.py b/fastapi_startkit/tests/masoniteorm/testing/test_transaction.py new file mode 100644 index 00000000..1921d6fb --- /dev/null +++ b/fastapi_startkit/tests/masoniteorm/testing/test_transaction.py @@ -0,0 +1,114 @@ +from unittest import IsolatedAsyncioTestCase +from unittest.mock import AsyncMock, MagicMock, patch + +from fastapi_startkit.masoniteorm.testing.transaction import DatabaseTransaction, RefreshDatabase + + +class TestDatabaseTransaction(IsolatedAsyncioTestCase): + async def test_asyncStartTestRun_begins_transaction(self): + tx = DatabaseTransaction() + mock_db_manager = MagicMock() + mock_connection = AsyncMock() + mock_db_manager.connection.return_value = mock_connection + + # Model is imported locally inside asyncStartTestRun, so patch the source module + with patch("fastapi_startkit.masoniteorm.models.Model") as mock_model_cls: + mock_model_cls.db_manager = mock_db_manager + await tx.asyncStartTestRun() + + mock_db_manager.connection.assert_called_once_with(None) + mock_connection.begin_transaction.assert_awaited_once() + self.assertIs(tx.connection, mock_connection) + + async def test_asyncStopTestRun_rolls_back_transaction(self): + tx = DatabaseTransaction() + mock_connection = AsyncMock() + tx.connection = mock_connection + + await tx.asyncStopTestRun() + + mock_connection.rollback.assert_awaited_once() + + async def test_start_then_stop_sequence(self): + tx = DatabaseTransaction() + mock_db_manager = MagicMock() + mock_connection = AsyncMock() + mock_db_manager.connection.return_value = mock_connection + + with patch("fastapi_startkit.masoniteorm.models.Model") as mock_model_cls: + mock_model_cls.db_manager = mock_db_manager + await tx.asyncStartTestRun() + await tx.asyncStopTestRun() + + mock_connection.begin_transaction.assert_awaited_once() + mock_connection.rollback.assert_awaited_once() + + +class TestRefreshDatabase(IsolatedAsyncioTestCase): + async def asyncSetUp(self): + # Reset the class-level flag before each test + RefreshDatabase.migrated = False + + async def asyncTearDown(self): + # Reset again to avoid polluting other test suites + RefreshDatabase.migrated = False + + async def test_migrate_database_runs_migrator_on_first_call(self): + mock_migrator = AsyncMock() + + with ( + patch( + "fastapi_startkit.masoniteorm.migrations.Migrator", + return_value=mock_migrator, + ) as MockMigrator, + patch("fastapi_startkit.application.app") as mock_app_fn, + ): + mock_app_fn.return_value.use_base_path.return_value = "/fake/migrations" + + await RefreshDatabase.migrate_database() + + MockMigrator.assert_called_once() + mock_migrator.fresh.assert_awaited_once_with(ignore_fk=True) + self.assertTrue(RefreshDatabase.migrated) + + async def test_migrate_database_skips_on_second_call(self): + mock_migrator = AsyncMock() + + with ( + patch( + "fastapi_startkit.masoniteorm.migrations.Migrator", + return_value=mock_migrator, + ) as MockMigrator, + patch("fastapi_startkit.application.app") as mock_app_fn, + ): + mock_app_fn.return_value.use_base_path.return_value = "/fake/migrations" + + await RefreshDatabase.migrate_database() + await RefreshDatabase.migrate_database() # second call — should be a no-op + + # Migrator should only be instantiated once + MockMigrator.assert_called_once() + mock_migrator.fresh.assert_awaited_once() + + async def test_asyncStartTestRun_migrates_then_begins_transaction(self): + db = RefreshDatabase() + mock_db_manager = MagicMock() + mock_connection = AsyncMock() + mock_db_manager.connection.return_value = mock_connection + mock_migrator = AsyncMock() + + with ( + patch("fastapi_startkit.masoniteorm.models.Model") as mock_model_cls, + patch( + "fastapi_startkit.masoniteorm.migrations.Migrator", + return_value=mock_migrator, + ), + patch("fastapi_startkit.application.app") as mock_app_fn, + ): + mock_model_cls.db_manager = mock_db_manager + mock_app_fn.return_value.use_base_path.return_value = "/fake/migrations" + await db.asyncStartTestRun() + + mock_migrator.fresh.assert_awaited_once_with(ignore_fk=True) + mock_connection.begin_transaction.assert_awaited_once() + self.assertTrue(RefreshDatabase.migrated)