Skip to content

Commit bea552b

Browse files
committed
Add optional type casting to limit-offset clause
1 parent e2d74aa commit bea552b

File tree

3 files changed

+106
-70
lines changed

3 files changed

+106
-70
lines changed

test/test_suite.py

Lines changed: 15 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,12 @@
66

77
from sqlalchemy.testing import is_true, is_false
88
from sqlalchemy.testing.suite import eq_, testing, inspect, provide_metadata, config, requirements, fixtures
9-
from sqlalchemy.testing.suite import func, column, literal_column, select, exists
9+
from sqlalchemy.testing.suite import func, column, literal_column, select, exists, union
1010
from sqlalchemy.testing.suite import MetaData, Column, Table, Integer, String
1111

1212
from sqlalchemy.testing.suite.test_select import (
1313
ExistsTest as _ExistsTest,
1414
LikeFunctionsTest as _LikeFunctionsTest,
15-
CompoundSelectTest as _CompoundSelectTest,
1615
)
1716
from sqlalchemy.testing.suite.test_reflection import (
1817
HasTableTest as _HasTableTest,
@@ -49,7 +48,6 @@
4948
from sqlalchemy.testing.suite.test_insert import InsertBehaviorTest as _InsertBehaviorTest
5049
from sqlalchemy.testing.suite.test_ddl import LongNameBlowoutTest as _LongNameBlowoutTest
5150
from sqlalchemy.testing.suite.test_results import RowFetchTest as _RowFetchTest
52-
from sqlalchemy.testing.suite.test_deprecations import DeprecatedCompoundSelectTest as _DeprecatedCompoundSelectTest
5351

5452
from ydb_sqlalchemy.sqlalchemy import types as ydb_sa_types
5553

@@ -294,20 +292,6 @@ def test_not_regexp_match(self):
294292
self._test(~col.regexp_match("a.cde"), {2, 3, 4, 7, 8, 10, 11})
295293

296294

297-
class CompoundSelectTest(_CompoundSelectTest):
298-
@pytest.mark.skip("limit don't work")
299-
def test_distinct_selectable_in_unions(self):
300-
pass
301-
302-
@pytest.mark.skip("limit don't work")
303-
def test_limit_offset_in_unions_from_alias(self):
304-
pass
305-
306-
@pytest.mark.skip("limit don't work")
307-
def test_limit_offset_aliased_selectable_in_unions(self):
308-
pass
309-
310-
311295
class EscapingTest(_EscapingTest):
312296
@provide_metadata
313297
def test_percent_sign_round_trip(self):
@@ -364,45 +348,23 @@ def test_group_by_composed(self):
364348

365349

366350
class FetchLimitOffsetTest(_FetchLimitOffsetTest):
367-
@pytest.mark.skip("Failed to convert type: Int64 to Uint64")
368-
def test_bound_limit(self, connection):
369-
pass
370-
371-
@pytest.mark.skip("Failed to convert type: Int64 to Uint64")
372-
def test_bound_limit_offset(self, connection):
373-
pass
374-
375-
@pytest.mark.skip("Failed to convert type: Int64 to Uint64")
376-
def test_bound_offset(self, connection):
377-
pass
378-
379-
@pytest.mark.skip("Failed to convert type: Int64 to Uint64")
380-
def test_expr_limit_simple_offset(self, connection):
381-
pass
382-
383-
@pytest.mark.skip("Failed to convert type: Int64 to Uint64")
384351
def test_limit_render_multiple_times(self, connection):
385-
pass
386-
387-
@pytest.mark.skip("Failed to convert type: Int64 to Uint64")
388-
def test_simple_limit(self, connection):
389-
pass
390-
391-
@pytest.mark.skip("Failed to convert type: Int64 to Uint64")
392-
def test_simple_limit_offset(self, connection):
393-
pass
394-
395-
@pytest.mark.skip("Failed to convert type: Int64 to Uint64")
396-
def test_simple_offset(self, connection):
397-
pass
352+
"""
353+
YQL does not support scalar subquery, so test was refiled with simple subquery
354+
"""
355+
table = self.tables.some_table
356+
stmt = select(table.c.id).limit(1).subquery()
398357

399-
@pytest.mark.skip("Failed to convert type: Int64 to Uint64")
400-
def test_simple_offset_zero(self, connection):
401-
pass
358+
u = union(select(stmt), select(stmt)).subquery().select()
402359

403-
@pytest.mark.skip("Failed to convert type: Int64 to Uint64")
404-
def test_simple_limit_expr_offset(self, connection):
405-
pass
360+
self._assert_result(
361+
connection,
362+
u,
363+
[
364+
(1,),
365+
(1,),
366+
],
367+
)
406368

407369

408370
class InsertBehaviorTest(_InsertBehaviorTest):
@@ -539,8 +501,3 @@ class RowFetchTest(_RowFetchTest):
539501
@pytest.mark.skip("scalar subquery unsupported")
540502
def test_row_w_scalar_select(self, connection):
541503
pass
542-
543-
544-
@pytest.mark.skip("TODO: try it after limit/offset tests would fixed")
545-
class DeprecatedCompoundSelectTest(_DeprecatedCompoundSelectTest):
546-
pass

ydb_sqlalchemy/sqlalchemy/__init__.py

Lines changed: 69 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from sqlalchemy.engine.default import StrCompileDialect, DefaultExecutionContext
2525
from sqlalchemy.util.compat import inspect_getfullargspec
2626

27-
from typing import Any, Union, Mapping, Sequence, Optional, Tuple, List, Dict
27+
from typing import Any, Union, Mapping, Sequence, Optional, Tuple, List, Dict, Type
2828

2929
from . import types
3030

@@ -87,15 +87,30 @@ def visit_FLOAT(self, type_: sa.FLOAT, **kw):
8787
def visit_BOOLEAN(self, type_: sa.BOOLEAN, **kw):
8888
return "BOOL"
8989

90+
def visit_uint64(self, type_: types.UInt64, **kw):
91+
return "UInt64"
92+
9093
def visit_uint32(self, type_: types.UInt32, **kw):
9194
return "UInt32"
9295

93-
def visit_uint64(self, type_: types.UInt64, **kw):
94-
return "UInt64"
96+
def visit_uint16(self, type_: types.UInt16, **kw):
97+
return "UInt16"
9598

9699
def visit_uint8(self, type_: types.UInt8, **kw):
97100
return "UInt8"
98101

102+
def visit_int64(self, type_: types.UInt64, **kw):
103+
return "Int64"
104+
105+
def visit_int32(self, type_: types.UInt32, **kw):
106+
return "Int32"
107+
108+
def visit_int16(self, type_: types.UInt16, **kw):
109+
return "Int16"
110+
111+
def visit_int8(self, type_: types.UInt8, **kw):
112+
return "Int8"
113+
99114
def visit_INTEGER(self, type_: sa.INTEGER, **kw):
100115
return "Int64"
101116

@@ -134,8 +149,28 @@ def get_ydb_type(
134149

135150
if isinstance(type_, (sa.Text, sa.String, sa.Uuid)):
136151
ydb_type = ydb.PrimitiveType.Utf8
152+
153+
# Integers
154+
elif isinstance(type_, types.UInt64):
155+
ydb_type = ydb.PrimitiveType.Uint64
156+
elif isinstance(type_, types.UInt32):
157+
ydb_type = ydb.PrimitiveType.Uint32
158+
elif isinstance(type_, types.UInt16):
159+
ydb_type = ydb.PrimitiveType.Uint16
160+
elif isinstance(type_, types.UInt8):
161+
ydb_type = ydb.PrimitiveType.Uint8
162+
elif isinstance(type_, types.Int64):
163+
ydb_type = ydb.PrimitiveType.Int64
164+
elif isinstance(type_, types.Int32):
165+
ydb_type = ydb.PrimitiveType.Int32
166+
elif isinstance(type_, types.Int16):
167+
ydb_type = ydb.PrimitiveType.Int16
168+
elif isinstance(type_, types.Int8):
169+
ydb_type = ydb.PrimitiveType.Int8
137170
elif isinstance(type_, sa.Integer):
138171
ydb_type = ydb.PrimitiveType.Int64
172+
# Integers
173+
139174
elif isinstance(type_, sa.JSON):
140175
ydb_type = ydb.PrimitiveType.Json
141176
elif isinstance(type_, sa.DateTime):
@@ -188,6 +223,32 @@ def group_by_clause(self, select, **kw):
188223
kw.update(within_columns_clause=True)
189224
return super(YqlCompiler, self).group_by_clause(select, **kw)
190225

226+
def limit_clause(self, select, **kw):
227+
text = ""
228+
if select._limit_clause is not None:
229+
limit_clause = self._maybe_cast(
230+
select._limit_clause, types.UInt64, skip_types=(types.UInt64, types.UInt32, types.UInt16, types.UInt8)
231+
)
232+
text += "\n LIMIT " + self.process(limit_clause, **kw)
233+
if select._offset_clause is not None:
234+
offset_clause = self._maybe_cast(
235+
select._offset_clause, types.UInt64, skip_types=(types.UInt64, types.UInt32, types.UInt16, types.UInt8)
236+
)
237+
if select._limit_clause is None:
238+
text += "\n LIMIT NULL"
239+
text += " OFFSET " + self.process(offset_clause, **kw)
240+
return text
241+
242+
def _maybe_cast(
243+
self,
244+
element: Any,
245+
cast_to: Type[sa.types.TypeEngine],
246+
skip_types: Optional[Tuple[Type[sa.types.TypeEngine], ...]] = None,
247+
) -> Any:
248+
if not hasattr(element, "type") or not isinstance(element.type, skip_types):
249+
return sa.Cast(element, cast_to)
250+
return element
251+
191252
def render_literal_value(self, value, type_):
192253
if isinstance(value, str):
193254
value = "".join(STR_QUOTE_MAP.get(x, x) for x in value)
@@ -277,16 +338,14 @@ def _is_bound_to_nullable_column(self, bind_name: str) -> bool:
277338
def _guess_bound_variable_type_by_parameters(
278339
self, bind: sa.BindParameter, post_compile_bind_values: list
279340
) -> Optional[sa.types.TypeEngine]:
280-
if not bind.expanding:
281-
if isinstance(bind.type, sa.types.NullType):
282-
return None
283-
bind_type = bind.type
284-
else:
341+
bind_type = bind.type
342+
if bind.expanding or (isinstance(bind.type, sa.types.NullType) and post_compile_bind_values):
285343
not_null_values = [v for v in post_compile_bind_values if v is not None]
286344
if not_null_values:
287345
bind_type = sa.BindParameter("", not_null_values[0]).type
288-
else:
289-
return None
346+
347+
if isinstance(bind_type, sa.types.NullType):
348+
return None
290349

291350
return bind_type
292351

ydb_sqlalchemy/sqlalchemy/types.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,18 +3,38 @@
33
from typing import Mapping, Any, Union, Type
44

55

6+
class UInt64(types.Integer):
7+
__visit_name__ = "uint64"
8+
9+
610
class UInt32(types.Integer):
711
__visit_name__ = "uint32"
812

913

10-
class UInt64(types.Integer):
11-
__visit_name__ = "uint64"
14+
class UInt16(types.Integer):
15+
__visit_name__ = "uint32"
1216

1317

1418
class UInt8(types.Integer):
1519
__visit_name__ = "uint8"
1620

1721

22+
class Int64(types.Integer):
23+
__visit_name__ = "int64"
24+
25+
26+
class Int32(types.Integer):
27+
__visit_name__ = "int32"
28+
29+
30+
class Int16(types.Integer):
31+
__visit_name__ = "int32"
32+
33+
34+
class Int8(types.Integer):
35+
__visit_name__ = "int8"
36+
37+
1838
class ListType(ARRAY):
1939
__visit_name__ = "list_type"
2040

0 commit comments

Comments
 (0)