Skip to content

Commit db10250

Browse files
authored
Merge pull request #27 Specific integer types and limit-offset support from LuckySting/limit-clause-implementation
2 parents 755e968 + 4064f82 commit db10250

File tree

5 files changed

+126
-74
lines changed

5 files changed

+126
-74
lines changed

test/test_core.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,21 @@ def test_select_types(self, connection):
216216
row = connection.execute(sa.select(tb)).fetchone()
217217
assert row == (1, "Hello World!", 3.5, True, now, today)
218218

219+
def test_integer_types(self, connection):
220+
stmt = sa.Select(
221+
sa.func.FormatType(sa.func.TypeOf(sa.bindparam("p_uint8", 8, types.UInt8))),
222+
sa.func.FormatType(sa.func.TypeOf(sa.bindparam("p_uint16", 16, types.UInt16))),
223+
sa.func.FormatType(sa.func.TypeOf(sa.bindparam("p_uint32", 32, types.UInt32))),
224+
sa.func.FormatType(sa.func.TypeOf(sa.bindparam("p_uint64", 64, types.UInt64))),
225+
sa.func.FormatType(sa.func.TypeOf(sa.bindparam("p_int8", -8, types.Int8))),
226+
sa.func.FormatType(sa.func.TypeOf(sa.bindparam("p_int16", -16, types.Int16))),
227+
sa.func.FormatType(sa.func.TypeOf(sa.bindparam("p_int32", -32, types.Int32))),
228+
sa.func.FormatType(sa.func.TypeOf(sa.bindparam("p_int64", -64, types.Int64))),
229+
)
230+
231+
result = connection.execute(stmt).fetchone()
232+
assert result == (b"Uint8", b"Uint16", b"Uint32", b"Uint64", b"Int8", b"Int16", b"Int32", b"Int64")
233+
219234

220235
class TestWithClause(TablesTest):
221236
__backend__ = True

test/test_suite.py

Lines changed: 15 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,11 @@
2121
requirements,
2222
select,
2323
testing,
24+
union,
2425
)
2526
from sqlalchemy.testing.suite.test_ddl import (
2627
LongNameBlowoutTest as _LongNameBlowoutTest,
2728
)
28-
from sqlalchemy.testing.suite.test_deprecations import (
29-
DeprecatedCompoundSelectTest as _DeprecatedCompoundSelectTest,
30-
)
3129
from sqlalchemy.testing.suite.test_dialect import (
3230
DifficultParametersTest as _DifficultParametersTest,
3331
)
@@ -50,9 +48,6 @@
5048
QuotedNameArgumentTest as _QuotedNameArgumentTest,
5149
)
5250
from sqlalchemy.testing.suite.test_results import RowFetchTest as _RowFetchTest
53-
from sqlalchemy.testing.suite.test_select import (
54-
CompoundSelectTest as _CompoundSelectTest,
55-
)
5651
from sqlalchemy.testing.suite.test_select import ExistsTest as _ExistsTest
5752
from sqlalchemy.testing.suite.test_select import (
5853
FetchLimitOffsetTest as _FetchLimitOffsetTest,
@@ -325,20 +320,6 @@ def test_not_regexp_match(self):
325320
self._test(~col.regexp_match("a.cde"), {2, 3, 4, 7, 8, 10, 11})
326321

327322

328-
class CompoundSelectTest(_CompoundSelectTest):
329-
@pytest.mark.skip("limit don't work")
330-
def test_distinct_selectable_in_unions(self):
331-
pass
332-
333-
@pytest.mark.skip("limit don't work")
334-
def test_limit_offset_in_unions_from_alias(self):
335-
pass
336-
337-
@pytest.mark.skip("limit don't work")
338-
def test_limit_offset_aliased_selectable_in_unions(self):
339-
pass
340-
341-
342323
class EscapingTest(_EscapingTest):
343324
@provide_metadata
344325
def test_percent_sign_round_trip(self):
@@ -395,45 +376,23 @@ def test_group_by_composed(self):
395376

396377

397378
class FetchLimitOffsetTest(_FetchLimitOffsetTest):
398-
@pytest.mark.skip("Failed to convert type: Int64 to Uint64")
399-
def test_bound_limit(self, connection):
400-
pass
401-
402-
@pytest.mark.skip("Failed to convert type: Int64 to Uint64")
403-
def test_bound_limit_offset(self, connection):
404-
pass
405-
406-
@pytest.mark.skip("Failed to convert type: Int64 to Uint64")
407-
def test_bound_offset(self, connection):
408-
pass
409-
410-
@pytest.mark.skip("Failed to convert type: Int64 to Uint64")
411-
def test_expr_limit_simple_offset(self, connection):
412-
pass
413-
414-
@pytest.mark.skip("Failed to convert type: Int64 to Uint64")
415379
def test_limit_render_multiple_times(self, connection):
416-
pass
417-
418-
@pytest.mark.skip("Failed to convert type: Int64 to Uint64")
419-
def test_simple_limit(self, connection):
420-
pass
421-
422-
@pytest.mark.skip("Failed to convert type: Int64 to Uint64")
423-
def test_simple_limit_offset(self, connection):
424-
pass
425-
426-
@pytest.mark.skip("Failed to convert type: Int64 to Uint64")
427-
def test_simple_offset(self, connection):
428-
pass
380+
"""
381+
YQL does not support scalar subquery, so test was refiled with simple subquery
382+
"""
383+
table = self.tables.some_table
384+
stmt = select(table.c.id).limit(1).subquery()
429385

430-
@pytest.mark.skip("Failed to convert type: Int64 to Uint64")
431-
def test_simple_offset_zero(self, connection):
432-
pass
386+
u = union(select(stmt), select(stmt)).subquery().select()
433387

434-
@pytest.mark.skip("Failed to convert type: Int64 to Uint64")
435-
def test_simple_limit_expr_offset(self, connection):
436-
pass
388+
self._assert_result(
389+
connection,
390+
u,
391+
[
392+
(1,),
393+
(1,),
394+
],
395+
)
437396

438397

439398
class InsertBehaviorTest(_InsertBehaviorTest):
@@ -570,8 +529,3 @@ class RowFetchTest(_RowFetchTest):
570529
@pytest.mark.skip("scalar subquery unsupported")
571530
def test_row_w_scalar_select(self, connection):
572531
pass
573-
574-
575-
@pytest.mark.skip("TODO: try it after limit/offset tests would fixed")
576-
class DeprecatedCompoundSelectTest(_DeprecatedCompoundSelectTest):
577-
pass

ydb_sqlalchemy/dbapi/errors.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Optional, List
1+
from typing import List, Optional
22

33
import ydb
44
from google.protobuf.message import Message

ydb_sqlalchemy/sqlalchemy/__init__.py

Lines changed: 73 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
"""
55
import collections
66
import collections.abc
7-
from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, Union
7+
from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, Type, Union
88

99
import sqlalchemy as sa
1010
import ydb
@@ -87,15 +87,30 @@ def visit_FLOAT(self, type_: sa.FLOAT, **kw):
8787
def visit_BOOLEAN(self, type_: sa.BOOLEAN, **kw):
8888
return "BOOL"
8989

90+
def visit_uint64(self, type_: types.UInt64, **kw):
91+
return "UInt64"
92+
9093
def visit_uint32(self, type_: types.UInt32, **kw):
9194
return "UInt32"
9295

93-
def visit_uint64(self, type_: types.UInt64, **kw):
94-
return "UInt64"
96+
def visit_uint16(self, type_: types.UInt16, **kw):
97+
return "UInt16"
9598

9699
def visit_uint8(self, type_: types.UInt8, **kw):
97100
return "UInt8"
98101

102+
def visit_int64(self, type_: types.Int64, **kw):
103+
return "Int64"
104+
105+
def visit_int32(self, type_: types.Int32, **kw):
106+
return "Int32"
107+
108+
def visit_int16(self, type_: types.Int16, **kw):
109+
return "Int16"
110+
111+
def visit_int8(self, type_: types.Int8, **kw):
112+
return "Int8"
113+
99114
def visit_INTEGER(self, type_: sa.INTEGER, **kw):
100115
return "Int64"
101116

@@ -134,8 +149,28 @@ def get_ydb_type(
134149

135150
if isinstance(type_, (sa.Text, sa.String, sa.Uuid)):
136151
ydb_type = ydb.PrimitiveType.Utf8
152+
153+
# Integers
154+
elif isinstance(type_, types.UInt64):
155+
ydb_type = ydb.PrimitiveType.Uint64
156+
elif isinstance(type_, types.UInt32):
157+
ydb_type = ydb.PrimitiveType.Uint32
158+
elif isinstance(type_, types.UInt16):
159+
ydb_type = ydb.PrimitiveType.Uint16
160+
elif isinstance(type_, types.UInt8):
161+
ydb_type = ydb.PrimitiveType.Uint8
162+
elif isinstance(type_, types.Int64):
163+
ydb_type = ydb.PrimitiveType.Int64
164+
elif isinstance(type_, types.Int32):
165+
ydb_type = ydb.PrimitiveType.Int32
166+
elif isinstance(type_, types.Int16):
167+
ydb_type = ydb.PrimitiveType.Int16
168+
elif isinstance(type_, types.Int8):
169+
ydb_type = ydb.PrimitiveType.Int8
137170
elif isinstance(type_, sa.Integer):
138171
ydb_type = ydb.PrimitiveType.Int64
172+
# Integers
173+
139174
elif isinstance(type_, sa.JSON):
140175
ydb_type = ydb.PrimitiveType.Json
141176
elif isinstance(type_, sa.DateTime):
@@ -188,6 +223,36 @@ def group_by_clause(self, select, **kw):
188223
kw.update(within_columns_clause=True)
189224
return super(YqlCompiler, self).group_by_clause(select, **kw)
190225

226+
def limit_clause(self, select, **kw):
227+
text = ""
228+
if select._limit_clause is not None:
229+
limit_clause = self._maybe_cast(
230+
select._limit_clause, types.UInt64, skip_types=(types.UInt64, types.UInt32, types.UInt16, types.UInt8)
231+
)
232+
text += "\n LIMIT " + self.process(limit_clause, **kw)
233+
if select._offset_clause is not None:
234+
offset_clause = self._maybe_cast(
235+
select._offset_clause, types.UInt64, skip_types=(types.UInt64, types.UInt32, types.UInt16, types.UInt8)
236+
)
237+
if select._limit_clause is None:
238+
text += "\n LIMIT NULL"
239+
text += " OFFSET " + self.process(offset_clause, **kw)
240+
return text
241+
242+
def _maybe_cast(
243+
self,
244+
element: Any,
245+
cast_to: Type[sa.types.TypeEngine],
246+
skip_types: Optional[Tuple[Type[sa.types.TypeEngine], ...]] = None,
247+
) -> Any:
248+
if not skip_types:
249+
skip_types = (cast_to,)
250+
if cast_to not in skip_types:
251+
skip_types = (*skip_types, cast_to)
252+
if not hasattr(element, "type") or not isinstance(element.type, skip_types):
253+
return sa.Cast(element, cast_to)
254+
return element
255+
191256
def render_literal_value(self, value, type_):
192257
if isinstance(value, str):
193258
value = "".join(STR_QUOTE_MAP.get(x, x) for x in value)
@@ -277,16 +342,14 @@ def _is_bound_to_nullable_column(self, bind_name: str) -> bool:
277342
def _guess_bound_variable_type_by_parameters(
278343
self, bind: sa.BindParameter, post_compile_bind_values: list
279344
) -> Optional[sa.types.TypeEngine]:
280-
if not bind.expanding:
281-
if isinstance(bind.type, sa.types.NullType):
282-
return None
283-
bind_type = bind.type
284-
else:
345+
bind_type = bind.type
346+
if bind.expanding or (isinstance(bind.type, sa.types.NullType) and post_compile_bind_values):
285347
not_null_values = [v for v in post_compile_bind_values if v is not None]
286348
if not_null_values:
287349
bind_type = sa.BindParameter("", not_null_values[0]).type
288-
else:
289-
return None
350+
351+
if isinstance(bind_type, sa.types.NullType):
352+
return None
290353

291354
return bind_type
292355

ydb_sqlalchemy/sqlalchemy/types.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,38 @@
44
from sqlalchemy.sql import type_api
55

66

7+
class UInt64(types.Integer):
8+
__visit_name__ = "uint64"
9+
10+
711
class UInt32(types.Integer):
812
__visit_name__ = "uint32"
913

1014

11-
class UInt64(types.Integer):
12-
__visit_name__ = "uint64"
15+
class UInt16(types.Integer):
16+
__visit_name__ = "uint16"
1317

1418

1519
class UInt8(types.Integer):
1620
__visit_name__ = "uint8"
1721

1822

23+
class Int64(types.Integer):
24+
__visit_name__ = "int64"
25+
26+
27+
class Int32(types.Integer):
28+
__visit_name__ = "int32"
29+
30+
31+
class Int16(types.Integer):
32+
__visit_name__ = "int32"
33+
34+
35+
class Int8(types.Integer):
36+
__visit_name__ = "int8"
37+
38+
1939
class ListType(ARRAY):
2040
__visit_name__ = "list_type"
2141

0 commit comments

Comments
 (0)