|
24 | 24 | from sqlalchemy.engine.default import StrCompileDialect, DefaultExecutionContext |
25 | 25 | from sqlalchemy.util.compat import inspect_getfullargspec |
26 | 26 |
|
27 | | -from typing import Any, Union, Mapping, Sequence, Optional, Tuple, List, Dict |
| 27 | +from typing import Any, Union, Mapping, Sequence, Optional, Tuple, List, Dict, Type |
28 | 28 |
|
29 | 29 | from . import types |
30 | 30 |
|
@@ -87,15 +87,30 @@ def visit_FLOAT(self, type_: sa.FLOAT, **kw): |
87 | 87 | def visit_BOOLEAN(self, type_: sa.BOOLEAN, **kw): |
88 | 88 | return "BOOL" |
89 | 89 |
|
| 90 | + def visit_uint64(self, type_: types.UInt64, **kw): |
| 91 | + return "UInt64" |
| 92 | + |
90 | 93 | def visit_uint32(self, type_: types.UInt32, **kw): |
91 | 94 | return "UInt32" |
92 | 95 |
|
93 | | - def visit_uint64(self, type_: types.UInt64, **kw): |
94 | | - return "UInt64" |
| 96 | + def visit_uint16(self, type_: types.UInt16, **kw): |
| 97 | + return "UInt16" |
95 | 98 |
|
96 | 99 | def visit_uint8(self, type_: types.UInt8, **kw): |
97 | 100 | return "UInt8" |
98 | 101 |
|
| 102 | + def visit_int64(self, type_: types.UInt64, **kw): |
| 103 | + return "Int64" |
| 104 | + |
| 105 | + def visit_int32(self, type_: types.UInt32, **kw): |
| 106 | + return "Int32" |
| 107 | + |
| 108 | + def visit_int16(self, type_: types.UInt16, **kw): |
| 109 | + return "Int16" |
| 110 | + |
| 111 | + def visit_int8(self, type_: types.UInt8, **kw): |
| 112 | + return "Int8" |
| 113 | + |
99 | 114 | def visit_INTEGER(self, type_: sa.INTEGER, **kw): |
100 | 115 | return "Int64" |
101 | 116 |
|
@@ -134,8 +149,28 @@ def get_ydb_type( |
134 | 149 |
|
135 | 150 | if isinstance(type_, (sa.Text, sa.String, sa.Uuid)): |
136 | 151 | ydb_type = ydb.PrimitiveType.Utf8 |
| 152 | + |
| 153 | + # Integers |
| 154 | + elif isinstance(type_, types.UInt64): |
| 155 | + ydb_type = ydb.PrimitiveType.Uint64 |
| 156 | + elif isinstance(type_, types.UInt32): |
| 157 | + ydb_type = ydb.PrimitiveType.Uint32 |
| 158 | + elif isinstance(type_, types.UInt16): |
| 159 | + ydb_type = ydb.PrimitiveType.Uint16 |
| 160 | + elif isinstance(type_, types.UInt8): |
| 161 | + ydb_type = ydb.PrimitiveType.Uint8 |
| 162 | + elif isinstance(type_, types.Int64): |
| 163 | + ydb_type = ydb.PrimitiveType.Int64 |
| 164 | + elif isinstance(type_, types.Int32): |
| 165 | + ydb_type = ydb.PrimitiveType.Int32 |
| 166 | + elif isinstance(type_, types.Int16): |
| 167 | + ydb_type = ydb.PrimitiveType.Int16 |
| 168 | + elif isinstance(type_, types.Int8): |
| 169 | + ydb_type = ydb.PrimitiveType.Int8 |
137 | 170 | elif isinstance(type_, sa.Integer): |
138 | 171 | ydb_type = ydb.PrimitiveType.Int64 |
| 172 | + # Integers |
| 173 | + |
139 | 174 | elif isinstance(type_, sa.JSON): |
140 | 175 | ydb_type = ydb.PrimitiveType.Json |
141 | 176 | elif isinstance(type_, sa.DateTime): |
@@ -188,6 +223,32 @@ def group_by_clause(self, select, **kw): |
188 | 223 | kw.update(within_columns_clause=True) |
189 | 224 | return super(YqlCompiler, self).group_by_clause(select, **kw) |
190 | 225 |
|
| 226 | + def limit_clause(self, select, **kw): |
| 227 | + text = "" |
| 228 | + if select._limit_clause is not None: |
| 229 | + limit_clause = self._maybe_cast( |
| 230 | + select._limit_clause, types.UInt64, skip_types=(types.UInt64, types.UInt32, types.UInt16, types.UInt8) |
| 231 | + ) |
| 232 | + text += "\n LIMIT " + self.process(limit_clause, **kw) |
| 233 | + if select._offset_clause is not None: |
| 234 | + offset_clause = self._maybe_cast( |
| 235 | + select._offset_clause, types.UInt64, skip_types=(types.UInt64, types.UInt32, types.UInt16, types.UInt8) |
| 236 | + ) |
| 237 | + if select._limit_clause is None: |
| 238 | + text += "\n LIMIT NULL" |
| 239 | + text += " OFFSET " + self.process(offset_clause, **kw) |
| 240 | + return text |
| 241 | + |
| 242 | + def _maybe_cast( |
| 243 | + self, |
| 244 | + element: Any, |
| 245 | + cast_to: Type[sa.types.TypeEngine], |
| 246 | + skip_types: Optional[Tuple[Type[sa.types.TypeEngine], ...]] = None, |
| 247 | + ) -> Any: |
| 248 | + if not hasattr(element, "type") or not isinstance(element.type, skip_types): |
| 249 | + return sa.Cast(element, cast_to) |
| 250 | + return element |
| 251 | + |
191 | 252 | def render_literal_value(self, value, type_): |
192 | 253 | if isinstance(value, str): |
193 | 254 | value = "".join(STR_QUOTE_MAP.get(x, x) for x in value) |
@@ -277,16 +338,14 @@ def _is_bound_to_nullable_column(self, bind_name: str) -> bool: |
277 | 338 | def _guess_bound_variable_type_by_parameters( |
278 | 339 | self, bind: sa.BindParameter, post_compile_bind_values: list |
279 | 340 | ) -> Optional[sa.types.TypeEngine]: |
280 | | - if not bind.expanding: |
281 | | - if isinstance(bind.type, sa.types.NullType): |
282 | | - return None |
283 | | - bind_type = bind.type |
284 | | - else: |
| 341 | + bind_type = bind.type |
| 342 | + if bind.expanding or (isinstance(bind.type, sa.types.NullType) and post_compile_bind_values): |
285 | 343 | not_null_values = [v for v in post_compile_bind_values if v is not None] |
286 | 344 | if not_null_values: |
287 | 345 | bind_type = sa.BindParameter("", not_null_values[0]).type |
288 | | - else: |
289 | | - return None |
| 346 | + |
| 347 | + if isinstance(bind_type, sa.types.NullType): |
| 348 | + return None |
290 | 349 |
|
291 | 350 | return bind_type |
292 | 351 |
|
|
0 commit comments