Skip to content

Commit a69f4b9

Browse files
refactor code base, update dependencies
1 parent 8204abb commit a69f4b9

File tree

8 files changed

+593
-539
lines changed

8 files changed

+593
-539
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ v 0.1.5
3131
```shell
3232
from fastapi import FastAPI
3333
from fastapi.params import Query
34-
from fastapi_sa_orm_filter import FilterCore, ops
34+
from fastapi_sa_orm_filter import FilterCore, Operators as ops
3535

3636
from db.base import get_session
3737
from db.models import MyModel

fastapi_sa_orm_filter/__init__.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,7 @@
22
from fastapi_sa_orm_filter.main import FilterCore # noqa
33
from fastapi_sa_orm_filter.operators import Operators as ops # noqa
44

5-
__version__ = "0.2.2"
5+
__version__ = "0.2.3"
6+
7+
from .main import FilterCore as FilterCore # noqa
8+
from .operators import Operators as Operators # noqa
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
from abc import ABC, abstractmethod
2+
3+
from fastapi_sa_orm_filter.dto import ParsedFilter
4+
from fastapi_sa_orm_filter.operators import Operators as ops
5+
6+
7+
class QueryParser(ABC):
8+
9+
@abstractmethod
10+
def __init__(self, custom_filter: str, allowed_filters: dict[str, list[ops]]) -> None:
11+
self.custom_filter = custom_filter
12+
self.allowed_filters = allowed_filters
13+
14+
@abstractmethod
15+
def get_parsed_filter(self) -> tuple[list[list[ParsedFilter]], list[str]] | tuple[list, list]:
16+
pass

fastapi_sa_orm_filter/main.py

Lines changed: 38 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,12 @@
88
from starlette import status
99
from sqlalchemy.sql import Select
1010

11+
from fastapi_sa_orm_filter.dto import ParsedFilter
1112
from fastapi_sa_orm_filter.exceptions import SAFilterOrmException
13+
from fastapi_sa_orm_filter.interfaces import QueryParser
1214
from fastapi_sa_orm_filter.operators import Operators as ops
13-
from fastapi_sa_orm_filter.parsers import FilterQueryParser, OrderByQueryParser
14-
from fastapi_sa_orm_filter.sa_expression_builder import SAFilterExpressionBuilder
15+
from fastapi_sa_orm_filter.parsers import StringQueryParser
16+
from fastapi_sa_orm_filter.sa_expression_builder import SAFilterExpressionBuilder, SAOrderByExpressionBuilder
1517

1618

1719
class FilterCore:
@@ -36,10 +38,11 @@ def __init__(
3638
'field_name': [startswith, eq, in_],
3739
'field_name': [contains, like]
3840
}
41+
:param select_query_part: custom select query part (select(model).join(model1))
3942
"""
4043
self.model = model
4144
self._allowed_filters = allowed_filters
42-
self.select_query_part = select_query_part
45+
self.select_sql_query = select_query_part
4346

4447
def get_query(self, custom_filter: str) -> Select[Any]:
4548
"""
@@ -50,7 +53,6 @@ def get_query(self, custom_filter: str) -> Select[Any]:
5053
created_at__between=2023-05-01,2023-05-05|
5154
category__eq=Medicine&
5255
order_by=-id
53-
:param select_query_part: custom select query part (select(model).join(model1))
5456
5557
:return:
5658
select(model)
@@ -63,56 +65,49 @@ def get_query(self, custom_filter: str) -> Select[Any]:
6365
model.category == 'Medicine'
6466
).order_by(model.id.desc())
6567
"""
66-
split_query = self._split_by_order_by(custom_filter)
6768
try:
68-
complete_query = self._get_complete_query(*split_query)
69+
query_parser = self._get_query_parser(custom_filter)
70+
filter_query, order_by_query = query_parser.get_parsed_filter()
71+
complete_query = self._get_complete_query(filter_query, order_by_query)
6972
except SAFilterOrmException as e:
7073
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=e.args[0])
7174
return complete_query
7275

73-
def _get_complete_query(self, filter_query_str: str, order_by_query_str: str | None = None) -> Select[Any]:
74-
select_query_part = self.get_select_query_part()
75-
filter_query_part = self._get_filter_query_part(filter_query_str)
76-
complete_query = select_query_part.filter(*filter_query_part)
77-
group_query_part = self.get_group_by_query_part()
78-
if group_query_part:
79-
complete_query = complete_query.group_by(*group_query_part)
80-
if order_by_query_str is not None:
81-
order_by_query = self.get_order_by_query_part(order_by_query_str)
82-
complete_query = complete_query.order_by(*order_by_query)
83-
return complete_query
76+
def _get_complete_query(
77+
self, filter_query: list[list[ParsedFilter]] | list, order_by_query: list[str] | list
78+
) -> Select[Any]:
79+
select_sa_query = self.get_select_query_part()
80+
filter_sa_query = self._get_filter_sa_query(filter_query)
81+
group_by_sa_query = self._get_group_by_sa_query()
82+
order_by_sa_query = self._get_order_by_sa_query(order_by_query)
83+
return select_sa_query.filter(*filter_sa_query).group_by(*group_by_sa_query).order_by(*order_by_sa_query)
8484

8585
def get_select_query_part(self) -> Select[Any]:
86-
if self.select_query_part is not None:
87-
return self.select_query_part
86+
if self.select_sql_query is not None:
87+
return self.select_sql_query
8888
return select(self.model)
8989

90-
def _get_filter_query_part(self, filter_query_str: str) -> list[Any]:
91-
conditions = self._get_filter_query(filter_query_str)
92-
if len(conditions) == 0:
93-
return conditions
90+
def _get_filter_sa_query(self, filter_query: list[list[ParsedFilter]] | list) -> list[BinaryExpression] | list:
91+
if len(filter_query) == 0:
92+
return []
93+
sa_builder = SAFilterExpressionBuilder(self.model)
94+
conditions = sa_builder.get_expressions(filter_query)
9495
return [or_(*conditions)]
9596

96-
def get_group_by_query_part(self) -> list:
97-
return []
97+
def _get_order_by_sa_query(self, order_by_query: list[str] | list) -> list[UnaryExpression]:
98+
if len(order_by_query) == 0:
99+
return []
100+
sa_builder = SAOrderByExpressionBuilder(self.model)
101+
return sa_builder.get_order_by_query(order_by_query)
98102

99-
def get_order_by_query_part(self, order_by_query_str: str) -> list[UnaryExpression]:
100-
order_by_parser = OrderByQueryParser(self.model)
101-
return order_by_parser.get_order_by_query(order_by_query_str)
103+
def _get_group_by_sa_query(self) -> list[BinaryExpression] | list:
104+
group_query_part = self.get_group_by_query_part()
105+
if len(group_query_part) == 0:
106+
return []
107+
return group_query_part
102108

103-
def _get_filter_query(self, custom_filter: str) -> list[BinaryExpression]:
104-
filter_conditions = []
105-
if custom_filter == "":
106-
return filter_conditions
109+
def get_group_by_query_part(self) -> list:
110+
return []
107111

108-
parser = FilterQueryParser(custom_filter, self._allowed_filters)
109-
parsed_filters = parser.get_parsed_query()
110-
sa_builder = SAFilterExpressionBuilder(self.model)
111-
return sa_builder.get_expressions(parsed_filters)
112-
113-
@staticmethod
114-
def _split_by_order_by(query) -> list:
115-
split_query = [query_part.strip("&") for query_part in query.split("order_by=")]
116-
if len(split_query) > 2:
117-
raise SAFilterOrmException("Use only one order_by directive")
118-
return split_query
112+
def _get_query_parser(self, custom_filter: str) -> QueryParser:
113+
return StringQueryParser(custom_filter, self._allowed_filters)

fastapi_sa_orm_filter/parsers.py

Lines changed: 48 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,69 +1,73 @@
1-
from sqlalchemy.orm import DeclarativeBase
2-
from sqlalchemy.sql.elements import UnaryExpression
3-
41
from fastapi_sa_orm_filter.dto import ParsedFilter
52
from fastapi_sa_orm_filter.exceptions import SAFilterOrmException
3+
from fastapi_sa_orm_filter.interfaces import QueryParser
64
from fastapi_sa_orm_filter.operators import Operators as ops
7-
from fastapi_sa_orm_filter.operators import OrderSequence
85

96

10-
class OrderByQueryParser:
7+
class StringQueryParser(QueryParser):
8+
9+
def __init__(self, custom_filter: str, allowed_filters: dict[str, list[ops]]) -> None:
10+
self.custom_filter = custom_filter
11+
self.allowed_filters = allowed_filters
12+
13+
def get_parsed_filter(self) -> tuple[list[list[ParsedFilter]], list[str]] | tuple[list, list]:
14+
parsed_filter = []
15+
parsed_order_by = []
16+
17+
if self.custom_filter == "":
18+
return parsed_filter, parsed_order_by
19+
20+
split_query = [query_part.strip("&") for query_part in self.custom_filter.split("order_by=")]
21+
22+
if len(split_query) > 2:
23+
raise SAFilterOrmException("Use only one order_by directive")
24+
25+
parsed_filter = self._get_filter_query_part(split_query[0])
26+
27+
if len(split_query) == 2:
28+
parsed_order_by = self._get_order_by_query_part(split_query[1])
29+
30+
return parsed_filter, parsed_order_by
31+
32+
def _get_filter_query_part(self, filter_query_str: str) -> list[list[ParsedFilter]] | list:
33+
if filter_query_str == "":
34+
return []
35+
filter_parser = StringFilterQueryParser(self.allowed_filters)
36+
return filter_parser.get_parsed_query(filter_query_str)
37+
38+
def _get_order_by_query_part(self, order_by_query_str: str) -> list[str] | list:
39+
if order_by_query_str == "":
40+
return []
41+
order_by_parser = StringOrderByQueryParser()
42+
return order_by_parser.get_order_by_query(order_by_query_str)
43+
44+
45+
class StringOrderByQueryParser:
1146
"""
1247
Class parse order by part of request query string.
1348
"""
14-
def __init__(self, model: type[DeclarativeBase]) -> None:
15-
self._model = model
16-
17-
def get_order_by_query(self, order_by_query_str: str) -> list[UnaryExpression]:
18-
order_by_fields = self._validate_order_by_fields(order_by_query_str)
19-
order_by_query = []
20-
for field in order_by_fields:
21-
if '-' in field:
22-
column = getattr(self._model, field.strip('-'))
23-
order_by_query.append(getattr(column, OrderSequence.desc)())
24-
else:
25-
column = getattr(self._model, field.strip('+'))
26-
order_by_query.append(getattr(column, OrderSequence.asc)())
27-
return order_by_query
28-
29-
def _validate_order_by_fields(self, order_by_query_str: str) -> list[str]:
30-
"""
31-
:return:
32-
[
33-
+field_name,
34-
-field_name
35-
]
36-
"""
37-
order_by_fields = order_by_query_str.split(",")
38-
model_fields = self._model.__table__.columns.keys()
39-
for field in order_by_fields:
40-
field = field.strip('+').strip('-')
41-
if field in model_fields:
42-
continue
43-
raise SAFilterOrmException(f"Incorrect order_by field name {field} for model {self._model.__name__}")
44-
return order_by_fields
49+
def get_order_by_query(self, order_by_query_str: str) -> list[str]:
50+
return order_by_query_str.split(",")
4551

4652

47-
class FilterQueryParser:
53+
class StringFilterQueryParser:
4854
"""
4955
Class parse filter part of request query string.
5056
"""
5157

5258
def __init__(
53-
self, query: str,
54-
allowed_filters: dict[str, list[ops]]
59+
self, allowed_filters: dict[str, list[ops]]
5560
) -> None:
56-
self._query = query
5761
self._allowed_filters = allowed_filters
5862

59-
def get_parsed_query(self) -> list[list[ParsedFilter]]:
63+
def get_parsed_query(self, filter_query_str: str) -> list[list[ParsedFilter]]:
6064
"""
6165
:return:
6266
[
6367
[ParsedFilter, ParsedFilter, ParsedFilter]
6468
]
6569
"""
66-
and_blocks = self._parse_by_conjunctions()
70+
and_blocks = self._parse_by_conjunctions(filter_query_str)
6771
parsed_query = []
6872
for and_block in and_blocks:
6973
parsed_and_blocks = []
@@ -74,7 +78,7 @@ def get_parsed_query(self) -> list[list[ParsedFilter]]:
7478
parsed_query.append(parsed_and_blocks)
7579
return parsed_query
7680

77-
def _parse_by_conjunctions(self) -> list[list[str]]:
81+
def _parse_by_conjunctions(self, filter_query_str: str) -> list[list[str]]:
7882
"""
7983
Split request query string by 'OR' and 'AND' conjunctions
8084
to divide query string to field's conditions
@@ -84,7 +88,7 @@ def _parse_by_conjunctions(self) -> list[list[str]]:
8488
['field_name__operator=value']
8589
]
8690
"""
87-
and_blocks = [block.split("&") for block in self._query.split("|")]
91+
and_blocks = [block.split("&") for block in filter_query_str.split("|")]
8892
return and_blocks
8993

9094
def _parse_expression(

fastapi_sa_orm_filter/sa_expression_builder.py

Lines changed: 45 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,14 @@
22
from typing import Any
33

44
import pydantic
5-
from pydantic import create_model
5+
from pydantic import create_model, BaseModel
66
from pydantic._internal._model_construction import ModelMetaclass
7-
from sqlalchemy import inspect, BinaryExpression, and_
7+
from sqlalchemy import inspect, BinaryExpression, and_, UnaryExpression
88
from sqlalchemy.orm import DeclarativeBase, InstrumentedAttribute
99
from sqlalchemy_to_pydantic import sqlalchemy_to_pydantic
1010

1111
from fastapi_sa_orm_filter.exceptions import SAFilterOrmException
12-
from fastapi_sa_orm_filter.operators import Operators as ops
12+
from fastapi_sa_orm_filter.operators import Operators as ops, OrderSequence
1313

1414

1515
class SAFilterExpressionBuilder:
@@ -32,8 +32,10 @@ def get_expressions(self, parsed_filters) -> list[BinaryExpression]:
3232
model = self.get_relation_model(and_filter.relation)
3333
table = model.__tablename__
3434
column = self.get_column(model, and_filter.field_name)
35-
serialized_dict = self.serialize_expression_value(table, column, and_filter.operator, and_filter.value)
36-
value = serialized_dict[column.name]
35+
serialized_dict = self.serialize_expression_value(
36+
table, and_filter.field_name, and_filter.operator, and_filter.value
37+
)
38+
value = serialized_dict[and_filter.field_name]
3739
expr = self.get_orm_for_field(column, and_filter.operator, value)
3840
and_expr.append(expr)
3941
or_expr.append(and_(*and_expr))
@@ -83,7 +85,7 @@ class model.__name__(BaseModel):
8385

8486
return serializers
8587

86-
def get_relations_classes(self) -> list:
88+
def get_relations_classes(self) -> list[type[DeclarativeBase]]:
8789
return [relation[1].mapper.class_ for relation in self._relationships]
8890

8991
def get_orm_for_field(
@@ -97,7 +99,7 @@ def get_orm_for_field(
9799
return getattr(column, ops[operator].value)(value)
98100

99101
def serialize_expression_value(
100-
self, table: str, column: InstrumentedAttribute, operator: str, value: str
102+
self, table: str, field_name: str, operator: str, value: str
101103
) -> dict[str, Any]:
102104
"""
103105
Serialize expression value from string to python type value,
@@ -112,14 +114,14 @@ def serialize_expression_value(
112114
model_serializer = self._model_serializers[table]["optional_model"]
113115
else:
114116
model_serializer = self._model_serializers[table]["optional_list_model"]
115-
return model_serializer(**{column.name: value}).model_dump(exclude_none=True)
117+
return model_serializer(**{field_name: value}).model_dump(exclude_none=True)
116118
except pydantic.ValidationError as e:
117119
raise SAFilterOrmException(json.loads(e.json()))
118120
except ValueError:
119121
raise SAFilterOrmException(f"Incorrect filter value '{value}'")
120122

121123
@staticmethod
122-
def get_optional_pydantic_model(model, pydantic_serializer, is_list: bool = False):
124+
def get_optional_pydantic_model(model, pydantic_serializer, is_list: bool = False) -> BaseModel:
123125
fields = {}
124126
for k, v in pydantic_serializer.model_fields.items():
125127
origin_annotation = getattr(v, 'annotation')
@@ -129,3 +131,37 @@ def get_optional_pydantic_model(model, pydantic_serializer, is_list: bool = Fals
129131
fields[k] = (origin_annotation, None)
130132
pydantic_model = create_model(model.__name__, **fields)
131133
return pydantic_model
134+
135+
136+
class SAOrderByExpressionBuilder:
137+
138+
def __init__(self, model: type[DeclarativeBase]) -> None:
139+
self._model = model
140+
141+
def get_order_by_query(self, order_by_query: list[str]) -> list[UnaryExpression]:
142+
order_by_fields = self._validate_order_by_fields(order_by_query)
143+
order_by_sql_query = []
144+
for field in order_by_fields:
145+
if '-' in field:
146+
column = getattr(self._model, field.strip('-'))
147+
order_by_sql_query.append(getattr(column, OrderSequence.desc)())
148+
else:
149+
column = getattr(self._model, field.strip('+'))
150+
order_by_sql_query.append(getattr(column, OrderSequence.asc)())
151+
return order_by_sql_query
152+
153+
def _validate_order_by_fields(self, order_by_fields: list[str]) -> list[str]:
154+
"""
155+
:return:
156+
[
157+
+field_name,
158+
-field_name
159+
]
160+
"""
161+
model_fields = self._model.__table__.columns.keys()
162+
for field in order_by_fields:
163+
field = field.strip('+').strip('-')
164+
if field in model_fields:
165+
continue
166+
raise SAFilterOrmException(f"Incorrect order_by field name {field} for model {self._model.__name__}")
167+
return order_by_fields

0 commit comments

Comments
 (0)