Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,11 @@ def is_dirty(self) -> bool:
def get_dirty(self) -> dict:
return {key: value for key, value in self.get_attributes().items() if not self.original_is_equivalent(key)}

def delete_attribute(self, key: str):
"""Remove a transient attribute that was added during query processing."""
self._attributes.pop(key, None)
self._dirty_attributes.pop(key, None)

def get_attributes_for_insert(self) -> dict:
# _dirty_attributes already went through set_attribute (casts applied on assignment).
# _attributes is set raw via new_model_instance, so apply set casts here.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,10 @@ def with_(self, *eagers) -> "QueryBuilder":
def get_table_name(self) -> str:
return self._table

def table(self, table: str) -> "QueryBuilder":
self._table = table
return self

def where_in(self, column: str, values) -> "QueryBuilder":
if hasattr(values, "_items"):
values = values._items
Expand Down Expand Up @@ -136,6 +140,10 @@ def run_scopes(self) -> "QueryBuilder":
scope(self)
return self

def without_global_scopes(self) -> "QueryBuilder":
self._global_scopes = {}
return self

def get_grammar(self):
return self.grammar(
columns=self._columns,
Expand Down
17 changes: 13 additions & 4 deletions fastapi_startkit/src/fastapi_startkit/masoniteorm/models/model.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
from __future__ import annotations
import inflection

from typing import TYPE_CHECKING

import inflection

from fastapi_startkit.carbon import Carbon
from fastapi_startkit.masoniteorm.collection import Collection
from fastapi_startkit.masoniteorm.models.fields import CreatedAtField, UpdatedAtField
from fastapi_startkit.masoniteorm.models.registry import Registry
from fastapi_startkit.masoniteorm.observers import ObservesEvents
from fastapi_startkit.masoniteorm.connections.manager import DatabaseManager
from fastapi_startkit.masoniteorm.models.attribute import Attribute
from fastapi_startkit.masoniteorm.models.fields import CreatedAtField, UpdatedAtField
from fastapi_startkit.masoniteorm.models.registry import Registry
from fastapi_startkit.masoniteorm.models.relationship import Relationship
from fastapi_startkit.masoniteorm.observers import ObservesEvents

if TYPE_CHECKING:
from fastapi_startkit.masoniteorm.models.builder import QueryBuilder
Expand Down Expand Up @@ -131,6 +132,14 @@ async def all(cls):
async def count(cls, column: str = "*"):
return await cls.query().count(column)

def table(self, table: str):
self.__table__ = table
return self

def timestamps(self, timestamps: bool = True):
self.__timestamps__ = timestamps
return self

def set_connection(self, connection: str):
self.connection = connection

Expand Down
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
import pendulum
from inflection import singularize

from fastapi_startkit.masoniteorm.models import registry
from .BaseRelationship import BaseRelationship
from ..collection import Collection
from ..models.pivot import Pivot
from .BaseRelationship import BaseRelationship
from fastapi_startkit.masoniteorm.models import registry


class BelongsToMany(BaseRelationship):
"""Has Many Relationship Class."""

def __init__(
self,
fn=None,
fn: str,
local_foreign_key=None,
other_foreign_key=None,
local_owner_key=None,
Expand All @@ -23,8 +23,7 @@ def __init__(
attribute="pivot",
with_fields=[],
):
if isinstance(fn, str):
self.fn = self.fn = lambda x: registry.Registry.resolve(fn)
self.fn = lambda: registry.Registry.resolve(fn)

self.local_key = local_foreign_key
self.foreign_key = other_foreign_key
Expand Down Expand Up @@ -134,15 +133,15 @@ async def apply_query(self, query, owner):
pivot_data.update({field: getattr(model, field)})
model.delete_attribute(field)

model.__original_attributes__.update(
{
self._as: (
Pivot.on(query.connection)
.table(self._table)
.hydrate(pivot_data)
.activate_timestamps(self.with_timestamps)
)
}
model.set_attribute(
self._as,
(
Pivot()
.on(query.connection)
.table(self._table)
.set_raw_attributes(pivot_data, True)
.timestamps(self.with_timestamps)
),
)

return result
Expand All @@ -151,11 +150,6 @@ def table(self, table):
self._table = table
return self

def make_builder(self, eagers=None):
builder = self.get_builder().with_(eagers)

return builder

async def make_query(self, query, relation, eagers=None, callback=None):
"""Used during eager loading a relationship

Expand Down Expand Up @@ -239,7 +233,6 @@ async def make_query(self, query, relation, eagers=None, callback=None):

async def get_related(self, query, relation, eagers=None, callback=None):
final_result = await self.make_query(query, relation, eagers=eagers, callback=callback)
builder = self.make_builder(eagers)

for model in final_result:
pivot_data = {
Expand All @@ -266,15 +259,15 @@ async def get_related(self, query, relation, eagers=None, callback=None):
pivot_data.update({field: getattr(model, field)})
model.delete_attribute(field)

model.__original_attributes__.update(
{
self._as: (
Pivot.on(builder.connection_name)
.table(self._table)
.hydrate(pivot_data)
.activate_timestamps(self.with_timestamps)
)
}
model.set_attribute(
self._as,
(
Pivot()
.on(query.connection)
.table(self._table)
.set_raw_attributes(pivot_data, True)
.timestamps(self.with_timestamps)
),
)

return final_result
Expand Down Expand Up @@ -497,7 +490,7 @@ def attach(self, current_model, related_record):
}
)

return Pivot.on(current_model.get_builder().connection).table(self._table).without_global_scopes().create(data)
return current_model.get_builder().connection.query().table(self._table).insert(data)

def detach(self, current_model, related_record):
data = {
Expand All @@ -508,9 +501,10 @@ def detach(self, current_model, related_record):
self._table = self._table or self.get_pivot_table_name(current_model, related_record)

return (
Pivot.on(current_model.get_builder().connection)
.table(self._table)
current_model.get_builder()
.connection.query()
.without_global_scopes()
.table(self._table)
.where(data)
.delete()
)
Expand All @@ -531,12 +525,7 @@ def attach_related(self, current_model, related_record):
}
)

return (
Pivot.table(self._table)
.on(current_model.get_builder().connection_name)
.without_global_scopes()
.create(data)
)
return current_model.get_builder().connection.query().table(self._table).insert(data)

def detach_related(self, current_model, related_record):
data = {
Expand All @@ -554,10 +543,4 @@ def detach_related(self, current_model, related_record):
}
)

return (
Pivot.on(current_model.get_builder().connection_name)
.table(self._table)
.without_global_scopes()
.where(data)
.delete()
)
return current_model.get_builder().connection.query().table(self._table).where(data).delete()
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
from ...fixtures.model import Product, Store
from ..test_case import TestCase


class TestBelongsToManyRelationship(TestCase):
async def asyncSetUp(self):
await super().asyncSetUp()
self.store = await Store.create({"name": "Test Store"})
self.product1 = await Product.create({"name": "Widget"})
self.product2 = await Product.create({"name": "Gadget"})

async def test_attach_creates_pivot_record(self):
await Store.products.attach(self.store, self.product1)
store = await Store.where("id", self.store.id).first()
products = await store.products
self.assertEqual(len(products), 1)
self.assertEqual(products[0].name, "Widget")

async def test_attach_multiple_products(self):
await Store.products.attach(self.store, self.product1)
await Store.products.attach(self.store, self.product2)
store = await Store.where("id", self.store.id).first()
products = await store.products
self.assertEqual(len(products), 2)

async def test_detach_removes_pivot_record(self):
await Store.products.attach(self.store, self.product1)
await Store.products.attach(self.store, self.product2)
await Store.products.detach(self.store, self.product1)
store = await Store.where("id", self.store.id).first()
products = await store.products
self.assertEqual(len(products), 1)
self.assertEqual(products[0].name, "Gadget")

async def test_eager_load_belongs_to_many(self):
await Store.products.attach(self.store, self.product1)
stores = await Store.with_("products").get()
store = stores.where("id", self.store.id).first()
self.assertIsNotNone(store)
self.assertEqual(len(store.products), 1)

async def test_eager_load_empty_relationship(self):
stores = await Store.with_("products").get()
store = stores.where("id", self.store.id).first()
self.assertIsNotNone(store)
# Empty BelongsToMany eager load returns None (consistent with other relationships)
self.assertIsNone(store.products)

async def test_pivot_access_after_eager_load(self):
await Store.products.attach(self.store, self.product1)
stores = await Store.with_("products").get()
store = stores.where("id", self.store.id).first()
product = store.products[0]
pivot = product.pivot
self.assertIsNotNone(pivot)

async def test_with_timestamps_in_pivot(self):
# Store.products uses with_timestamps=True
await Store.products.attach(self.store, self.product1)
stores = await Store.with_("products").get()
store = stores.where("id", self.store.id).first()
product = store.products[0]
self.assertIsNotNone(product.pivot)

async def test_explicit_table_relationship(self):
# Store.products_table uses table="product_table"
await Store.products_table.attach(self.store, self.product1)
store = await Store.where("id", self.store.id).first()
products = await store.products_table
self.assertEqual(len(products), 1)

async def test_attach_related_creates_pivot_record(self):
await Store.products.attach_related(self.store, self.product1)
store = await Store.where("id", self.store.id).first()
products = await store.products
self.assertEqual(len(products), 1)

async def test_detach_related_removes_pivot_record(self):
await Store.products.attach_related(self.store, self.product1)
await Store.products.detach_related(self.store, self.product1)
store = await Store.where("id", self.store.id).first()
products = await store.products
self.assertEqual(len(products), 0)

async def test_get_pivot_table_name(self):
# Test the helper method directly using builder proxies
rel = Store.products
# Manually set the pivot table name to test the method
rel._table = None
name = rel.get_pivot_table_name(self.store.get_builder(), self.product1.get_builder())
self.assertEqual(name, "product_store")

async def test_map_related_returns_result(self):
results = [self.product1, self.product2]
rel = Store.products
mapped = rel.map_related(results)
self.assertEqual(mapped, results)

async def test_register_related_groups_by_owner_key(self):
await Store.products.attach(self.store, self.product1)
stores = await Store.with_("products").get()
store = stores.where("id", self.store.id).first()
# If register_related works, the products collection is populated
self.assertEqual(len(store.products), 1)

async def test_query_has_filters_stores_with_products(self):
store2 = await Store.create({"name": "Empty Store"})
await Store.products.attach(self.store, self.product1)

stores_with_products = await Store.where_has("products").get()
store_ids = [s.id for s in stores_with_products]

self.assertIn(self.store.id, store_ids)
self.assertNotIn(store2.id, store_ids)

async def test_query_has_returns_builder(self):
# query_has should add a where_exists clause to the builder
builder = Store.query()
result = Store.products.query_has(builder, method="where_exists")
# The builder has a where clause appended (we just verify no error is raised)
self.assertIsNotNone(builder)
Loading
Loading