Skip to content

Commit b0ff350

Browse files
refactor filter parser
1 parent 0d9e35a commit b0ff350

File tree

10 files changed

+234
-166
lines changed

10 files changed

+234
-166
lines changed

README.md

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,7 @@ v 0.1.5
3131
```shell
3232
from fastapi import FastAPI
3333
from fastapi.params import Query
34-
from fastapi_sa_orm_filter.main import FilterCore
35-
from fastapi_sa_orm_filter.operators import Operators as ops
34+
from fastapi_sa_orm_filter import FilterCore, ops
3635

3736
from db.base import get_session
3837
from db.models import MyModel

fastapi_sa_orm_filter/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
11
"""FastAPI-SQLAlchemy filter, transform request query string to SQLAlchemy orm query"""
2+
from fastapi_sa_orm_filter.main import FilterCore # noqa
3+
from fastapi_sa_orm_filter.operators import Operators as ops # noqa
24

35
__version__ = "0.2.2"

fastapi_sa_orm_filter/dto.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
from dataclasses import dataclass
2+
3+
4+
@dataclass
5+
class ParsedFilter:
6+
field_name: str
7+
operator: str
8+
value: str
9+
relation: str | None
10+
11+
@property
12+
def has_relation(self) -> bool:
13+
return bool(self.relation)

fastapi_sa_orm_filter/main.py

Lines changed: 10 additions & 103 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,17 @@
1-
import json
2-
31
from typing import Any, Type
42

5-
import pydantic
63
from fastapi import HTTPException
7-
from pydantic import create_model
8-
from pydantic._internal._model_construction import ModelMetaclass
9-
from sqlalchemy_to_pydantic import sqlalchemy_to_pydantic
10-
from sqlalchemy import select, inspect
11-
from sqlalchemy.orm import InstrumentedAttribute, DeclarativeBase
4+
from sqlalchemy import select
5+
from sqlalchemy.orm import DeclarativeBase
126
from sqlalchemy.sql.elements import BinaryExpression, UnaryExpression
13-
from sqlalchemy.sql.expression import and_, or_
7+
from sqlalchemy.sql.expression import or_
148
from starlette import status
159
from sqlalchemy.sql import Select
1610

1711
from fastapi_sa_orm_filter.exceptions import SAFilterOrmException
1812
from fastapi_sa_orm_filter.operators import Operators as ops
19-
from fastapi_sa_orm_filter.parsers import _FilterQueryParser, _OrderByQueryParser
13+
from fastapi_sa_orm_filter.parsers import FilterQueryParser, OrderByQueryParser
14+
from fastapi_sa_orm_filter.sa_expression_builder import SAFilterExpressionBuilder
2015

2116

2217
class FilterCore:
@@ -43,9 +38,7 @@ def __init__(
4338
}
4439
"""
4540
self.model = model
46-
self.relationships = inspect(self.model).relationships.items()
4741
self._allowed_filters = allowed_filters
48-
self._model_serializers = self._create_pydantic_serializers()
4942
self.select_query_part = select_query_part
5043

5144
def get_query(self, custom_filter: str) -> Select[Any]:
@@ -104,108 +97,22 @@ def get_group_by_query_part(self) -> list:
10497
return []
10598

10699
def get_order_by_query_part(self, order_by_query_str: str) -> list[UnaryExpression]:
107-
order_by_parser = _OrderByQueryParser(self.model)
100+
order_by_parser = OrderByQueryParser(self.model)
108101
return order_by_parser.get_order_by_query(order_by_query_str)
109102

110103
def _get_filter_query(self, custom_filter: str) -> list[BinaryExpression]:
111104
filter_conditions = []
112105
if custom_filter == "":
113106
return filter_conditions
114-
query_parser = _FilterQueryParser(custom_filter, self.model, self._allowed_filters)
115-
116-
for and_expressions in query_parser.get_parsed_query():
117-
and_condition = []
118-
for expression in and_expressions:
119-
table, column, operator, value = expression
120-
serialized_dict = self._format_expression(table, column, operator, value)
121-
value = serialized_dict[column.name]
122-
param = self._get_orm_for_field(column, operator, value)
123-
and_condition.append(param)
124-
filter_conditions.append(and_(*and_condition))
125-
return filter_conditions
126-
127-
def _create_pydantic_serializers(self) -> dict[str, dict[str, ModelMetaclass]]:
128-
"""
129-
Create two pydantic models (optional and list field types)
130-
for value: str serialization
131-
132-
:return: {
133-
'optional_model':
134-
class model.__name__(BaseModel):
135-
field: Optional[type]
136-
'list_model':
137-
class model.__name__(BaseModel):
138-
field: Optional[List[type]]
139-
}
140-
"""
141-
142-
models = [self.model]
143-
models.extend(self._get_relations())
144-
145-
serializers = {}
146-
147-
for model in models:
148-
pydantic_serializer = sqlalchemy_to_pydantic(model)
149-
optional_model = self._get_optional_pydantic_model(model, pydantic_serializer)
150-
optional_list_model = self._get_optional_pydantic_model(model, pydantic_serializer, is_list=True)
151-
152-
serializers[model.__tablename__] = {
153-
"optional_model": optional_model, "optional_list_model": optional_list_model
154-
}
155-
156-
return serializers
157107

158-
def _get_relations(self) -> list:
159-
return [relation[1].mapper.class_ for relation in self.relationships]
160-
161-
def _get_orm_for_field(
162-
self, column: InstrumentedAttribute, operator: str, value: Any
163-
) -> BinaryExpression:
164-
"""
165-
Create SQLAlchemy orm expression for the field
166-
"""
167-
if operator in [ops.between]:
168-
param = getattr(column, ops[operator].value)(*value)
169-
else:
170-
param = getattr(column, ops[operator].value)(value)
171-
return param
172-
173-
def _format_expression(
174-
self, table: str, column: InstrumentedAttribute, operator: str, value: str
175-
) -> dict[str, Any]:
176-
"""
177-
Serialize expression value from string to python type value,
178-
according to db model types
179-
180-
:return: {'field_name': [value, value]}
181-
"""
182-
value = value.split(",")
183-
try:
184-
if operator not in [ops.between, ops.in_]:
185-
value = value[0]
186-
model_serializer = self._model_serializers[table]["optional_model"]
187-
else:
188-
model_serializer = self._model_serializers[table]["optional_list_model"]
189-
return model_serializer(**{column.name: value}).model_dump(exclude_none=True)
190-
except pydantic.ValidationError as e:
191-
raise SAFilterOrmException(json.loads(e.json()))
192-
except ValueError:
193-
raise SAFilterOrmException(f"Incorrect filter value '{value}'")
108+
parser = FilterQueryParser(custom_filter, self._allowed_filters)
109+
parsed_filters = parser.get_parsed_query()
110+
sa_builder = SAFilterExpressionBuilder(self.model)
111+
return sa_builder.get_expressions(parsed_filters)
194112

195113
@staticmethod
196114
def _split_by_order_by(query) -> list:
197115
split_query = [query_part.strip("&") for query_part in query.split("order_by=")]
198116
if len(split_query) > 2:
199117
raise SAFilterOrmException("Use only one order_by directive")
200118
return split_query
201-
202-
def _get_optional_pydantic_model(self, model, pydantic_serializer, is_list: bool = False):
203-
fields = {}
204-
for k, v in pydantic_serializer.model_fields.items():
205-
origin_annotation = getattr(v, 'annotation')
206-
if is_list:
207-
fields[k] = (list[origin_annotation], None)
208-
else:
209-
fields[k] = (origin_annotation, None)
210-
pydantic_model = create_model(model.__name__, **fields)
211-
return pydantic_model

fastapi_sa_orm_filter/parsers.py

Lines changed: 22 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,20 @@
1-
from typing import Optional, Tuple, Union, List, Any, Type
2-
3-
from fastapi import HTTPException
4-
from sqlalchemy import inspect
5-
from sqlalchemy.orm import InstrumentedAttribute, DeclarativeBase
1+
from sqlalchemy.orm import DeclarativeBase
62
from sqlalchemy.sql.elements import UnaryExpression
73

4+
from fastapi_sa_orm_filter.dto import ParsedFilter
85
from fastapi_sa_orm_filter.exceptions import SAFilterOrmException
96
from fastapi_sa_orm_filter.operators import Operators as ops
107
from fastapi_sa_orm_filter.operators import Sequence
118

129

13-
class _OrderByQueryParser:
10+
class OrderByQueryParser:
1411
"""
1512
Class parse order by part of request query string.
1613
"""
17-
def __init__(self, model: Type[DeclarativeBase]) -> None:
14+
def __init__(self, model: type[DeclarativeBase]) -> None:
1815
self._model = model
1916

20-
def get_order_by_query(self, order_by_query_str: str) -> List[UnaryExpression]:
17+
def get_order_by_query(self, order_by_query_str: str) -> list[UnaryExpression]:
2118
order_by_fields = self._validate_order_by_fields(order_by_query_str)
2219
order_by_query = []
2320
for field in order_by_fields:
@@ -47,33 +44,33 @@ def _validate_order_by_fields(self, order_by_query_str: str) -> list[str]:
4744
return order_by_fields
4845

4946

50-
class _FilterQueryParser:
47+
class FilterQueryParser:
5148
"""
5249
Class parse filter part of request query string.
5350
"""
5451

55-
def __init__(self, query: str, model: Type[DeclarativeBase], allowed_filters: dict[str, list[ops]]) -> None:
52+
def __init__(
53+
self, query: str,
54+
allowed_filters: dict[str, list[ops]]
55+
) -> None:
5656
self._query = query
57-
self._model = model
58-
self._relationships = inspect(model).relationships.items()
5957
self._allowed_filters = allowed_filters
6058

61-
def get_parsed_query(self) -> list[list[Any]]:
59+
def get_parsed_query(self) -> list[list[ParsedFilter]]:
6260
"""
6361
:return:
6462
[
65-
[[column, operator, value], [column, operator, value]],
66-
[[column, operator, value]]
63+
[ParsedFilter, ParsedFilter, ParsedFilter]
6764
]
6865
"""
6966
and_blocks = self._parse_by_conjunctions()
7067
parsed_query = []
7168
for and_block in and_blocks:
7269
parsed_and_blocks = []
7370
for expression in and_block:
74-
table, column, operator, value = self._parse_expression(expression)
75-
self._validate_query_params(column.name, operator)
76-
parsed_and_blocks.append([table, column, operator, value])
71+
parsed_filter = self._parse_expression(expression)
72+
self._validate_query_params(parsed_filter.field_name, parsed_filter.operator)
73+
parsed_and_blocks.append(parsed_filter)
7774
parsed_query.append(parsed_and_blocks)
7875
return parsed_query
7976

@@ -92,13 +89,12 @@ def _parse_by_conjunctions(self) -> list[list[str]]:
9289

9390
def _parse_expression(
9491
self, expression: str
95-
) -> Union[Tuple[str, InstrumentedAttribute, str, str], HTTPException]:
96-
model = self._model
97-
table = self._model.__tablename__
92+
) -> ParsedFilter:
93+
relation = None
9894
try:
9995
field_name, condition = expression.split("__")
10096
if "." in field_name:
101-
model, table, field_name = self._get_relation_model(field_name)
97+
relation, field_name = self._get_relation_model(field_name)
10298
operator, value = condition.split("=")
10399
except ValueError:
104100
raise SAFilterOrmException(
@@ -108,23 +104,14 @@ def _parse_expression(
108104
"or '{relation}.{field_name}__{condition}={value}{conjunction}'",
109105
)
110106

111-
column = getattr(model, field_name, None)
112-
113-
if not column:
114-
raise SAFilterOrmException(f"DB model {model.__name__} doesn't have field '{field_name}'")
115-
return table, column, operator, value
107+
return ParsedFilter(field_name=field_name, operator=operator, value=value, relation=relation)
116108

117-
def _get_relation_model(self, field_name: str) -> tuple[DeclarativeBase, str, str]:
118-
relation, field_name = field_name.split(".")
119-
for relationship in self._relationships:
120-
if relationship[0] == relation:
121-
model = relationship[1].mapper.class_
122-
return model, model.__tablename__, field_name
123-
raise SAFilterOrmException(f"Can not find relation {relation} in {self._model.__name__} model")
109+
def _get_relation_model(self, field_name: str) -> list[str]:
110+
return field_name.split(".")
124111

125112
def _validate_query_params(
126113
self, field_name: str, operator: str
127-
) -> Optional[HTTPException]:
114+
) -> None:
128115
"""
129116
Check expression on valid and allowed field_name and operator
130117
"""

0 commit comments

Comments
 (0)