Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,11 @@ def is_dirty(self) -> bool:
def get_dirty(self) -> dict:
return {key: value for key, value in self.get_attributes().items() if not self.original_is_equivalent(key)}

def delete_attribute(self, key: str):
"""Remove a transient attribute that was added during query processing."""
self._attributes.pop(key, None)
self._dirty_attributes.pop(key, None)

def get_attributes_for_insert(self) -> dict:
# _dirty_attributes already went through set_attribute (casts applied on assignment).
# _attributes is set raw via new_model_instance, so apply set casts here.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,10 @@ def with_(self, *eagers) -> "QueryBuilder":
def get_table_name(self) -> str:
return self._table

def table(self, table: str) -> "QueryBuilder":
self._table = table
return self

def where_in(self, column: str, values) -> "QueryBuilder":
if hasattr(values, "_items"):
values = values._items
Expand Down Expand Up @@ -136,6 +140,10 @@ def run_scopes(self) -> "QueryBuilder":
scope(self)
return self

def without_global_scopes(self) -> "QueryBuilder":
self._global_scopes = {}
return self

def get_grammar(self):
return self.grammar(
columns=self._columns,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,9 @@ def __init__(
attribute="pivot",
with_fields=[],
):
fn_str = fn
if isinstance(fn, str):
self.fn = self.fn = lambda x: registry.Registry.resolve(fn)
self.fn = lambda: registry.Registry.resolve(fn_str)

self.local_key = local_foreign_key
self.foreign_key = other_foreign_key
Expand Down Expand Up @@ -134,16 +135,11 @@ async def apply_query(self, query, owner):
pivot_data.update({field: getattr(model, field)})
model.delete_attribute(field)

model.__original_attributes__.update(
{
self._as: (
Pivot.on(query.connection)
.table(self._table)
.hydrate(pivot_data)
.activate_timestamps(self.with_timestamps)
)
}
)
pivot_model = Pivot()
pivot_model.__table__ = self._table
pivot_model.__timestamps__ = self.with_timestamps
pivot_model.set_raw_attributes(pivot_data, True)
model._attributes[self._as] = pivot_model

return result

Expand Down Expand Up @@ -266,16 +262,11 @@ async def get_related(self, query, relation, eagers=None, callback=None):
pivot_data.update({field: getattr(model, field)})
model.delete_attribute(field)

model.__original_attributes__.update(
{
self._as: (
Pivot.on(builder.connection_name)
.table(self._table)
.hydrate(pivot_data)
.activate_timestamps(self.with_timestamps)
)
}
)
pivot_model = Pivot()
pivot_model.__table__ = self._table
pivot_model.__timestamps__ = self.with_timestamps
pivot_model.set_raw_attributes(pivot_data, True)
model._attributes[self._as] = pivot_model

return final_result

Expand Down Expand Up @@ -487,7 +478,9 @@ def attach(self, current_model, related_record):
self.foreign_key: getattr(related_record, self.other_owner_key),
}

self._table = self._table or self.get_pivot_table_name(current_model, related_record)
self._table = self._table or self.get_pivot_table_name(
current_model.get_builder(), related_record.get_builder()
)

if self.with_timestamps:
data.update(
Expand All @@ -497,20 +490,22 @@ def attach(self, current_model, related_record):
}
)

return Pivot.on(current_model.get_builder().connection).table(self._table).without_global_scopes().create(data)
return current_model.get_builder().connection.query().table(self._table).insert(data)

def detach(self, current_model, related_record):
data = {
self.local_key: getattr(current_model, self.local_owner_key),
self.foreign_key: getattr(related_record, self.other_owner_key),
}

self._table = self._table or self.get_pivot_table_name(current_model, related_record)
self._table = self._table or self.get_pivot_table_name(
current_model.get_builder(), related_record.get_builder()
)

return (
Pivot.on(current_model.get_builder().connection)
current_model.get_builder()
.connection.query()
.table(self._table)
.without_global_scopes()
.where(data)
.delete()
)
Expand All @@ -521,7 +516,9 @@ def attach_related(self, current_model, related_record):
self.foreign_key: getattr(related_record, self.other_owner_key),
}

self._table = self._table or self.get_pivot_table_name(current_model, related_record)
self._table = self._table or self.get_pivot_table_name(
current_model.get_builder(), related_record.get_builder()
)

if self.with_timestamps:
data.update(
Expand All @@ -531,20 +528,17 @@ def attach_related(self, current_model, related_record):
}
)

return (
Pivot.table(self._table)
.on(current_model.get_builder().connection_name)
.without_global_scopes()
.create(data)
)
return current_model.get_builder().connection.query().table(self._table).insert(data)

def detach_related(self, current_model, related_record):
data = {
self.local_key: getattr(current_model, self.local_owner_key),
self.foreign_key: getattr(related_record, self.other_owner_key),
}

self._table = self._table or self.get_pivot_table_name(current_model, related_record)
self._table = self._table or self.get_pivot_table_name(
current_model.get_builder(), related_record.get_builder()
)

if self.with_timestamps:
data.update(
Expand All @@ -555,9 +549,9 @@ def detach_related(self, current_model, related_record):
)

return (
Pivot.on(current_model.get_builder().connection_name)
current_model.get_builder()
.connection.query()
.table(self._table)
.without_global_scopes()
.where(data)
.delete()
)
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,12 @@ def set_keys(self, owner, attribute):
return self

def __get__(self, instance, owner):
if instance is None:
return self

attribute = self.fn.__name__
self._related_builder = instance.builder
self.polymorphic_builder = self.fn(self)()
self._related_builder = instance.get_builder()
self.polymorphic_builder = self.fn(self).query()
self.set_keys(owner, self.fn)

if not instance.is_loaded():
Expand All @@ -32,8 +35,7 @@ def __get__(self, instance, owner):
return self.apply_query(self._related_builder, instance)

def __getattr__(self, attribute):
relationship = self.fn(self)()
return getattr(relationship.builder, attribute)
return getattr(self.fn(self).query(), attribute)

def apply_query(self, builder, instance):
"""Apply the query and return a dictionary to be hydrated
Expand All @@ -45,7 +47,7 @@ def apply_query(self, builder, instance):
Returns:
dict -- A dictionary of data which will be hydrated.
"""
polymorphic_key = self.get_record_key_lookup(builder._model)
polymorphic_key = self.get_record_key_lookup(instance)
polymorphic_builder = self.polymorphic_builder
return (
polymorphic_builder.where(self.morph_key, polymorphic_key)
Expand Down Expand Up @@ -115,13 +117,7 @@ def morph_map(self):
return registry.Registry.get_morph_map()

def get_record_key_lookup(self, relation):
record_type = None
for record_type_loop, model in self.morph_map().items():
if model == relation.__class__:
record_type = record_type_loop
break

if not record_type:
morph_name = registry.Registry._reverse_map.get(relation.__class__)
if morph_name is None:
raise ValueError(f"Could not find the record type key for the {relation} class")

return record_type
return morph_name
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,9 @@ def register_related(self, key, model, collection):
model.add_relation({key: related})

def morph_map(self):
return load_config().DB._morph_map
from fastapi_startkit.masoniteorm.models import registry

return registry.Registry.get_morph_map()

def get_record_key_lookup(self, relation):
record_type = None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,4 +100,6 @@ def register_related(self, key, model, collection):
model.add_relation({key: related})

def morph_map(self):
return load_config().DB._morph_map
from fastapi_startkit.masoniteorm.models import registry

return registry.Registry.get_morph_map()
Empty file.
Loading
Loading