Skip to content

Commit 52be132

Browse files
added join filter
1 parent f9f191f commit 52be132

File tree

11 files changed

+1133
-181
lines changed

11 files changed

+1133
-181
lines changed

.gitignore

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
.idea
2+
__pychache__/
3+
.env
4+
dist

README.md

Lines changed: 44 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ Supported operators, datatypes and example of work you can find below.
1212
pip install fastapi-sa-orm-filter
1313
```
1414
### Compatibility
15-
v 0.2.0
15+
v 0.2.1
1616
- Python: >= 3.8
1717
- Fastapi: >= 0.100
1818
- Pydantic: >= 2.0.0
@@ -57,7 +57,7 @@ async def get_filtered_objects(
5757
return res.scalars().all()
5858
```
5959

60-
### Example of work
60+
### Examples of usage
6161

6262
```shell
6363

@@ -83,6 +83,47 @@ select(model)
8383
).order_by(model.id.desc(), model.category.asc())
8484
```
8585
86+
```shell
87+
# Filter by joined model
88+
89+
# Input query string
90+
'''vacancies.salary_from__gte=100'''
91+
92+
allowed_filter_fields = {
93+
"id": [ops.eq],
94+
"title": [ops.startswith, ops.eq, ops.contains],
95+
"salary_from": [ops.eq, ops.gt, ops.lte, ops.gte]
96+
}
97+
98+
company_filter = FilterCore(
99+
Company,
100+
allowed_filter_fields,
101+
select(Company).join(Vacancy).options(joinedload(Company.vacancies))
102+
)
103+
104+
@app.get("/")
105+
async def get_filtered_company(
106+
filter_query: str = "title__eq=MyCompany&vacancies.salary_from__gte=100",
107+
db: AsyncSession = Depends(get_session)
108+
) -> List[Company]:
109+
110+
query = company_filter.get_query(filter_query)
111+
res = await db.execute(query)
112+
return res.scalars().all()
113+
114+
# Returned SQLAlchemy query
115+
select(Company)
116+
.join(Vacancy)
117+
.options(joinedload(Company.vacancies))
118+
.where(
119+
and_(
120+
Company.title == "MyCompany",
121+
Vacancy.salary_from >= 100
122+
)
123+
)
124+
125+
```
126+
86127
### Supported query string format
87128
88129
* field_name__eq=value
@@ -93,7 +134,7 @@ select(model)
93134
### Modify query for custom selection
94135
```shell
95136
# Create a class inherited from FilterCore and rewrite 'get_unordered_query' method.
96-
# 0.2.0 Version
137+
# ^0.2.0 Version
97138

98139
class CustomFilter(FilterCore):
99140

fastapi_sa_orm_filter/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
"""FastAPI-SQLAlchemy filter, transform request query string to SQLAlchemy orm query"""
22

3-
__version__ = "0.2.0"
3+
__version__ = "0.2.1"
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
class SAFilterOrmException(Exception):
2+
pass

fastapi_sa_orm_filter/main.py

Lines changed: 57 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,20 @@
11
import json
2-
from typing import Any, Dict, List, Type
2+
3+
from typing import Any, Type
34

45
import pydantic
56
from fastapi import HTTPException
67
from pydantic import create_model
78
from pydantic._internal._model_construction import ModelMetaclass
89
from sqlalchemy_to_pydantic import sqlalchemy_to_pydantic
9-
from sqlalchemy import select
10+
from sqlalchemy import select, inspect
1011
from sqlalchemy.orm import InstrumentedAttribute, DeclarativeMeta
1112
from sqlalchemy.sql.elements import BinaryExpression, UnaryExpression
1213
from sqlalchemy.sql.expression import and_, or_
1314
from starlette import status
1415
from sqlalchemy.sql import Select
1516

17+
from fastapi_sa_orm_filter.exceptions import SAFilterOrmException
1618
from fastapi_sa_orm_filter.operators import Operators as ops
1719
from fastapi_sa_orm_filter.parsers import _FilterQueryParser, _OrderByQueryParser
1820

@@ -24,7 +26,10 @@ class FilterCore:
2426
"""
2527

2628
def __init__(
27-
self, model: Type[DeclarativeMeta], allowed_filters: Dict[str, List[ops]]
29+
self,
30+
model: Type[DeclarativeMeta],
31+
allowed_filters: dict[str, list[ops]],
32+
select_query_part: Select[Any] = None
2833
) -> None:
2934
"""
3035
Produce a class:`FilterCore` object against a function
@@ -38,8 +43,10 @@ def __init__(
3843
}
3944
"""
4045
self.model = model
46+
self.relationships = inspect(self.model).relationships.items()
4147
self._allowed_filters = allowed_filters
42-
self._model_serializer = self._create_pydantic_serializer()
48+
self._model_serializers = self._create_pydantic_serializers()
49+
self.select_query_part = select_query_part
4350

4451
def get_query(self, custom_filter: str) -> Select[Any]:
4552
"""
@@ -63,58 +70,60 @@ def get_query(self, custom_filter: str) -> Select[Any]:
6370
).order_by(model.id.desc())
6471
"""
6572
split_query = self.split_by_order_by(custom_filter)
66-
if len(split_query) == 1:
67-
complete_query = self.get_complete_query(split_query[0])
68-
return complete_query
69-
filter_query_str, order_by_query_str = split_query
70-
complete_query = self.get_complete_query(filter_query_str, order_by_query_str)
73+
try:
74+
complete_query = self.get_complete_query(*split_query)
75+
except SAFilterOrmException as e:
76+
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=e.args[0])
7177
return complete_query
7278

7379
def get_complete_query(self, filter_query_str: str, order_by_query_str: str | None = None) -> Select[Any]:
7480
select_query_part = self.get_select_query_part()
7581
filter_query_part = self.get_filter_query_part(filter_query_str)
7682
complete_query = select_query_part.filter(*filter_query_part)
7783
group_query_part = self.get_group_by_query_part()
78-
if group_query_part != []:
84+
if group_query_part:
7985
complete_query = complete_query.group_by(*group_query_part)
8086
if order_by_query_str is not None:
8187
order_by_query = self.get_order_by_query_part(order_by_query_str)
8288
complete_query = complete_query.order_by(*order_by_query)
8389
return complete_query
8490

8591
def get_select_query_part(self) -> Select[Any]:
92+
if self.select_query_part:
93+
return self.select_query_part
8694
return select(self.model)
8795

88-
def get_filter_query_part(self, filter_query_str: str) -> List[Any]:
96+
def get_filter_query_part(self, filter_query_str: str) -> list[Any]:
8997
conditions = self._get_filter_query(filter_query_str)
90-
if conditions == []:
98+
if len(conditions) == 0:
9199
return conditions
92100
return [or_(*conditions)]
93101

94-
def get_group_by_query_part(self):
102+
def get_group_by_query_part(self) -> list:
95103
return []
96104

97-
def get_order_by_query_part(self, order_by_query_str: str) -> List[UnaryExpression]:
105+
def get_order_by_query_part(self, order_by_query_str: str) -> list[UnaryExpression]:
98106
order_by_parser = _OrderByQueryParser(self.model)
99107
return order_by_parser.get_order_by_query(order_by_query_str)
100108

101-
def _get_filter_query(self, custom_filter: str) -> List[BinaryExpression]:
109+
def _get_filter_query(self, custom_filter: str) -> list[BinaryExpression]:
102110
filter_conditions = []
103-
if custom_filter == '':
111+
if custom_filter == "":
104112
return filter_conditions
105113
query_parser = _FilterQueryParser(custom_filter, self.model, self._allowed_filters)
114+
106115
for and_expressions in query_parser.get_parsed_query():
107116
and_condition = []
108117
for expression in and_expressions:
109-
column, operator, value = expression
110-
serialized_dict = self._format_expression(column, operator, value)
118+
table, column, operator, value = expression
119+
serialized_dict = self._format_expression(table, column, operator, value)
111120
value = serialized_dict[column.name]
112121
param = self._get_orm_for_field(column, operator, value)
113122
and_condition.append(param)
114123
filter_conditions.append(and_(*and_condition))
115124
return filter_conditions
116125

117-
def _create_pydantic_serializer(self) -> Dict[str, ModelMetaclass]:
126+
def _create_pydantic_serializers(self) -> dict[str, ModelMetaclass]:
118127
"""
119128
Create two pydantic models (optional and list field types)
120129
for value: str serialization
@@ -128,10 +137,23 @@ class model.__name__(BaseModel):
128137
field: Optional[List[type]]
129138
}
130139
"""
131-
pydantic_serializer = sqlalchemy_to_pydantic(self.model)
132-
optional_model = self._get_optional_pydantic_model(pydantic_serializer)
133-
optional_list_model = self._get_optional_pydantic_model(pydantic_serializer, is_list=True)
134-
return {"optional_model": optional_model, "optional_list_model": optional_list_model}
140+
141+
models = [self.model]
142+
models.extend(self.get_relations())
143+
144+
serializers = {}
145+
146+
for model in models:
147+
pydantic_serializer = sqlalchemy_to_pydantic(model)
148+
optional_model = self._get_optional_pydantic_model(model, pydantic_serializer)
149+
optional_list_model = self._get_optional_pydantic_model(model, pydantic_serializer, is_list=True)
150+
151+
serializers[model.__tablename__] = {"optional_model": optional_model, "optional_list_model":optional_list_model}
152+
153+
return serializers
154+
155+
def get_relations(self) -> list:
156+
return [relation[1].mapper.class_ for relation in self.relationships]
135157

136158
def _get_orm_for_field(
137159
self, column: InstrumentedAttribute, operator: str, value: Any
@@ -146,7 +168,7 @@ def _get_orm_for_field(
146168
return param
147169

148170
def _format_expression(
149-
self, column: InstrumentedAttribute, operator: str, value: str
171+
self, table: str, column: InstrumentedAttribute, operator: str, value: str
150172
) -> dict[str, Any]:
151173
"""
152174
Serialize expression value from string to python type value,
@@ -158,41 +180,29 @@ def _format_expression(
158180
try:
159181
if operator not in [ops.between, ops.in_]:
160182
value = value[0]
161-
serialized_dict = self._model_serializer["optional_model"](
162-
**{column.name: value}
163-
).model_dump(exclude_none=True)
164-
return serialized_dict
165-
serialized_dict = self._model_serializer["optional_list_model"](
166-
**{column.name: value}
167-
).model_dump(exclude_none=True)
168-
return serialized_dict
183+
model_serializer = self._model_serializers[table]["optional_model"]
184+
else:
185+
model_serializer = self._model_serializers[table]["optional_list_model"]
186+
return model_serializer(**{column.name: value}).model_dump(exclude_none=True)
169187
except pydantic.ValidationError as e:
170-
raise HTTPException(
171-
status_code=status.HTTP_400_BAD_REQUEST, detail=json.loads(e.json())
172-
)
188+
raise SAFilterOrmException(json.loads(e.json()))
173189
except ValueError:
174-
raise HTTPException(
175-
status_code=status.HTTP_400_BAD_REQUEST,
176-
detail=f"Incorrect filter value '{value}'",
177-
)
190+
raise SAFilterOrmException(f"Incorrect filter value '{value}'")
178191

179192
@staticmethod
180-
def split_by_order_by(query):
193+
def split_by_order_by(query) -> list:
181194
split_query = [query_part.strip("&") for query_part in query.split("order_by=")]
182195
if len(split_query) > 2:
183-
raise HTTPException(
184-
status_code=status.HTTP_400_BAD_REQUEST,
185-
detail="Use only one order_by directive",
186-
)
196+
raise SAFilterOrmException("Use only one order_by directive")
187197
return split_query
188198

189-
def _get_optional_pydantic_model(self, pydantic_serializer, is_list: bool = False):
199+
def _get_optional_pydantic_model(self, model, pydantic_serializer, is_list: bool = False):
190200
fields = {}
191201
for k, v in pydantic_serializer.model_fields.items():
192202
origin_annotation = getattr(v, 'annotation')
193203
if is_list:
194-
fields[k] = (List[origin_annotation], None)
204+
fields[k] = (list[origin_annotation], None)
195205
else:
196206
fields[k] = (origin_annotation, None)
197-
pydantic_model = create_model(self.model.__name__, **fields)
207+
pydantic_model = create_model(model.__name__, **fields)
198208
return pydantic_model

0 commit comments

Comments
 (0)