11import json
2- from typing import Any , Dict , List , Type
2+
3+ from typing import Any , Type
34
45import pydantic
56from fastapi import HTTPException
67from pydantic import create_model
78from pydantic ._internal ._model_construction import ModelMetaclass
89from sqlalchemy_to_pydantic import sqlalchemy_to_pydantic
9- from sqlalchemy import select
10+ from sqlalchemy import select , inspect
1011from sqlalchemy .orm import InstrumentedAttribute , DeclarativeMeta
1112from sqlalchemy .sql .elements import BinaryExpression , UnaryExpression
1213from sqlalchemy .sql .expression import and_ , or_
1314from starlette import status
1415from sqlalchemy .sql import Select
1516
17+ from fastapi_sa_orm_filter .exceptions import SAFilterOrmException
1618from fastapi_sa_orm_filter .operators import Operators as ops
1719from fastapi_sa_orm_filter .parsers import _FilterQueryParser , _OrderByQueryParser
1820
@@ -24,7 +26,10 @@ class FilterCore:
2426 """
2527
2628 def __init__ (
27- self , model : Type [DeclarativeMeta ], allowed_filters : Dict [str , List [ops ]]
29+ self ,
30+ model : Type [DeclarativeMeta ],
31+ allowed_filters : dict [str , list [ops ]],
32+ select_query_part : Select [Any ] = None
2833 ) -> None :
2934 """
3035 Produce a class:`FilterCore` object against a function
@@ -38,8 +43,10 @@ def __init__(
3843 }
3944 """
4045 self .model = model
46+ self .relationships = inspect (self .model ).relationships .items ()
4147 self ._allowed_filters = allowed_filters
42- self ._model_serializer = self ._create_pydantic_serializer ()
48+ self ._model_serializers = self ._create_pydantic_serializers ()
49+ self .select_query_part = select_query_part
4350
4451 def get_query (self , custom_filter : str ) -> Select [Any ]:
4552 """
@@ -63,58 +70,60 @@ def get_query(self, custom_filter: str) -> Select[Any]:
6370 ).order_by(model.id.desc())
6471 """
6572 split_query = self .split_by_order_by (custom_filter )
66- if len (split_query ) == 1 :
67- complete_query = self .get_complete_query (split_query [0 ])
68- return complete_query
69- filter_query_str , order_by_query_str = split_query
70- complete_query = self .get_complete_query (filter_query_str , order_by_query_str )
73+ try :
74+ complete_query = self .get_complete_query (* split_query )
75+ except SAFilterOrmException as e :
76+ raise HTTPException (status_code = status .HTTP_400_BAD_REQUEST , detail = e .args [0 ])
7177 return complete_query
7278
7379 def get_complete_query (self , filter_query_str : str , order_by_query_str : str | None = None ) -> Select [Any ]:
7480 select_query_part = self .get_select_query_part ()
7581 filter_query_part = self .get_filter_query_part (filter_query_str )
7682 complete_query = select_query_part .filter (* filter_query_part )
7783 group_query_part = self .get_group_by_query_part ()
78- if group_query_part != [] :
84+ if group_query_part :
7985 complete_query = complete_query .group_by (* group_query_part )
8086 if order_by_query_str is not None :
8187 order_by_query = self .get_order_by_query_part (order_by_query_str )
8288 complete_query = complete_query .order_by (* order_by_query )
8389 return complete_query
8490
8591 def get_select_query_part (self ) -> Select [Any ]:
92+ if self .select_query_part :
93+ return self .select_query_part
8694 return select (self .model )
8795
88- def get_filter_query_part (self , filter_query_str : str ) -> List [Any ]:
96+ def get_filter_query_part (self , filter_query_str : str ) -> list [Any ]:
8997 conditions = self ._get_filter_query (filter_query_str )
90- if conditions == [] :
98+ if len ( conditions ) == 0 :
9199 return conditions
92100 return [or_ (* conditions )]
93101
94- def get_group_by_query_part (self ):
102+ def get_group_by_query_part (self ) -> list :
95103 return []
96104
97- def get_order_by_query_part (self , order_by_query_str : str ) -> List [UnaryExpression ]:
105+ def get_order_by_query_part (self , order_by_query_str : str ) -> list [UnaryExpression ]:
98106 order_by_parser = _OrderByQueryParser (self .model )
99107 return order_by_parser .get_order_by_query (order_by_query_str )
100108
101- def _get_filter_query (self , custom_filter : str ) -> List [BinaryExpression ]:
109+ def _get_filter_query (self , custom_filter : str ) -> list [BinaryExpression ]:
102110 filter_conditions = []
103- if custom_filter == '' :
111+ if custom_filter == "" :
104112 return filter_conditions
105113 query_parser = _FilterQueryParser (custom_filter , self .model , self ._allowed_filters )
114+
106115 for and_expressions in query_parser .get_parsed_query ():
107116 and_condition = []
108117 for expression in and_expressions :
109- column , operator , value = expression
110- serialized_dict = self ._format_expression (column , operator , value )
118+ table , column , operator , value = expression
119+ serialized_dict = self ._format_expression (table , column , operator , value )
111120 value = serialized_dict [column .name ]
112121 param = self ._get_orm_for_field (column , operator , value )
113122 and_condition .append (param )
114123 filter_conditions .append (and_ (* and_condition ))
115124 return filter_conditions
116125
117- def _create_pydantic_serializer (self ) -> Dict [str , ModelMetaclass ]:
126+ def _create_pydantic_serializers (self ) -> dict [str , ModelMetaclass ]:
118127 """
119128 Create two pydantic models (optional and list field types)
120129 for value: str serialization
@@ -128,10 +137,23 @@ class model.__name__(BaseModel):
128137 field: Optional[List[type]]
129138 }
130139 """
131- pydantic_serializer = sqlalchemy_to_pydantic (self .model )
132- optional_model = self ._get_optional_pydantic_model (pydantic_serializer )
133- optional_list_model = self ._get_optional_pydantic_model (pydantic_serializer , is_list = True )
134- return {"optional_model" : optional_model , "optional_list_model" : optional_list_model }
140+
141+ models = [self .model ]
142+ models .extend (self .get_relations ())
143+
144+ serializers = {}
145+
146+ for model in models :
147+ pydantic_serializer = sqlalchemy_to_pydantic (model )
148+ optional_model = self ._get_optional_pydantic_model (model , pydantic_serializer )
149+ optional_list_model = self ._get_optional_pydantic_model (model , pydantic_serializer , is_list = True )
150+
151+ serializers [model .__tablename__ ] = {"optional_model" : optional_model , "optional_list_model" :optional_list_model }
152+
153+ return serializers
154+
155+ def get_relations (self ) -> list :
156+ return [relation [1 ].mapper .class_ for relation in self .relationships ]
135157
136158 def _get_orm_for_field (
137159 self , column : InstrumentedAttribute , operator : str , value : Any
@@ -146,7 +168,7 @@ def _get_orm_for_field(
146168 return param
147169
148170 def _format_expression (
149- self , column : InstrumentedAttribute , operator : str , value : str
171+ self , table : str , column : InstrumentedAttribute , operator : str , value : str
150172 ) -> dict [str , Any ]:
151173 """
152174 Serialize expression value from string to python type value,
@@ -158,41 +180,29 @@ def _format_expression(
158180 try :
159181 if operator not in [ops .between , ops .in_ ]:
160182 value = value [0 ]
161- serialized_dict = self ._model_serializer ["optional_model" ](
162- ** {column .name : value }
163- ).model_dump (exclude_none = True )
164- return serialized_dict
165- serialized_dict = self ._model_serializer ["optional_list_model" ](
166- ** {column .name : value }
167- ).model_dump (exclude_none = True )
168- return serialized_dict
183+ model_serializer = self ._model_serializers [table ]["optional_model" ]
184+ else :
185+ model_serializer = self ._model_serializers [table ]["optional_list_model" ]
186+ return model_serializer (** {column .name : value }).model_dump (exclude_none = True )
169187 except pydantic .ValidationError as e :
170- raise HTTPException (
171- status_code = status .HTTP_400_BAD_REQUEST , detail = json .loads (e .json ())
172- )
188+ raise SAFilterOrmException (json .loads (e .json ()))
173189 except ValueError :
174- raise HTTPException (
175- status_code = status .HTTP_400_BAD_REQUEST ,
176- detail = f"Incorrect filter value '{ value } '" ,
177- )
190+ raise SAFilterOrmException (f"Incorrect filter value '{ value } '" )
178191
179192 @staticmethod
180- def split_by_order_by (query ):
193+ def split_by_order_by (query ) -> list :
181194 split_query = [query_part .strip ("&" ) for query_part in query .split ("order_by=" )]
182195 if len (split_query ) > 2 :
183- raise HTTPException (
184- status_code = status .HTTP_400_BAD_REQUEST ,
185- detail = "Use only one order_by directive" ,
186- )
196+ raise SAFilterOrmException ("Use only one order_by directive" )
187197 return split_query
188198
189- def _get_optional_pydantic_model (self , pydantic_serializer , is_list : bool = False ):
199+ def _get_optional_pydantic_model (self , model , pydantic_serializer , is_list : bool = False ):
190200 fields = {}
191201 for k , v in pydantic_serializer .model_fields .items ():
192202 origin_annotation = getattr (v , 'annotation' )
193203 if is_list :
194- fields [k ] = (List [origin_annotation ], None )
204+ fields [k ] = (list [origin_annotation ], None )
195205 else :
196206 fields [k ] = (origin_annotation , None )
197- pydantic_model = create_model (self . model .__name__ , ** fields )
207+ pydantic_model = create_model (model .__name__ , ** fields )
198208 return pydantic_model
0 commit comments