|
4 | 4 | """ |
5 | 5 | import collections |
6 | 6 | import collections.abc |
7 | | -from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, Union |
| 7 | +from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, Type, Union |
8 | 8 |
|
9 | 9 | import sqlalchemy as sa |
10 | 10 | import ydb |
@@ -87,15 +87,30 @@ def visit_FLOAT(self, type_: sa.FLOAT, **kw): |
87 | 87 | def visit_BOOLEAN(self, type_: sa.BOOLEAN, **kw): |
88 | 88 | return "BOOL" |
89 | 89 |
|
| 90 | + def visit_uint64(self, type_: types.UInt64, **kw): |
| 91 | + return "UInt64" |
| 92 | + |
90 | 93 | def visit_uint32(self, type_: types.UInt32, **kw): |
91 | 94 | return "UInt32" |
92 | 95 |
|
93 | | - def visit_uint64(self, type_: types.UInt64, **kw): |
94 | | - return "UInt64" |
| 96 | + def visit_uint16(self, type_: types.UInt16, **kw): |
| 97 | + return "UInt16" |
95 | 98 |
|
96 | 99 | def visit_uint8(self, type_: types.UInt8, **kw): |
97 | 100 | return "UInt8" |
98 | 101 |
|
| 102 | + def visit_int64(self, type_: types.Int64, **kw): |
| 103 | + return "Int64" |
| 104 | + |
| 105 | + def visit_int32(self, type_: types.Int32, **kw): |
| 106 | + return "Int32" |
| 107 | + |
| 108 | + def visit_int16(self, type_: types.Int16, **kw): |
| 109 | + return "Int16" |
| 110 | + |
| 111 | + def visit_int8(self, type_: types.Int8, **kw): |
| 112 | + return "Int8" |
| 113 | + |
99 | 114 | def visit_INTEGER(self, type_: sa.INTEGER, **kw): |
100 | 115 | return "Int64" |
101 | 116 |
|
@@ -134,8 +149,28 @@ def get_ydb_type( |
134 | 149 |
|
135 | 150 | if isinstance(type_, (sa.Text, sa.String, sa.Uuid)): |
136 | 151 | ydb_type = ydb.PrimitiveType.Utf8 |
| 152 | + |
| 153 | + # Integers |
| 154 | + elif isinstance(type_, types.UInt64): |
| 155 | + ydb_type = ydb.PrimitiveType.Uint64 |
| 156 | + elif isinstance(type_, types.UInt32): |
| 157 | + ydb_type = ydb.PrimitiveType.Uint32 |
| 158 | + elif isinstance(type_, types.UInt16): |
| 159 | + ydb_type = ydb.PrimitiveType.Uint16 |
| 160 | + elif isinstance(type_, types.UInt8): |
| 161 | + ydb_type = ydb.PrimitiveType.Uint8 |
| 162 | + elif isinstance(type_, types.Int64): |
| 163 | + ydb_type = ydb.PrimitiveType.Int64 |
| 164 | + elif isinstance(type_, types.Int32): |
| 165 | + ydb_type = ydb.PrimitiveType.Int32 |
| 166 | + elif isinstance(type_, types.Int16): |
| 167 | + ydb_type = ydb.PrimitiveType.Int16 |
| 168 | + elif isinstance(type_, types.Int8): |
| 169 | + ydb_type = ydb.PrimitiveType.Int8 |
137 | 170 | elif isinstance(type_, sa.Integer): |
138 | 171 | ydb_type = ydb.PrimitiveType.Int64 |
| 172 | + # Integers |
| 173 | + |
139 | 174 | elif isinstance(type_, sa.JSON): |
140 | 175 | ydb_type = ydb.PrimitiveType.Json |
141 | 176 | elif isinstance(type_, sa.DateTime): |
@@ -188,6 +223,36 @@ def group_by_clause(self, select, **kw): |
188 | 223 | kw.update(within_columns_clause=True) |
189 | 224 | return super(YqlCompiler, self).group_by_clause(select, **kw) |
190 | 225 |
|
| 226 | + def limit_clause(self, select, **kw): |
| 227 | + text = "" |
| 228 | + if select._limit_clause is not None: |
| 229 | + limit_clause = self._maybe_cast( |
| 230 | + select._limit_clause, types.UInt64, skip_types=(types.UInt64, types.UInt32, types.UInt16, types.UInt8) |
| 231 | + ) |
| 232 | + text += "\n LIMIT " + self.process(limit_clause, **kw) |
| 233 | + if select._offset_clause is not None: |
| 234 | + offset_clause = self._maybe_cast( |
| 235 | + select._offset_clause, types.UInt64, skip_types=(types.UInt64, types.UInt32, types.UInt16, types.UInt8) |
| 236 | + ) |
| 237 | + if select._limit_clause is None: |
| 238 | + text += "\n LIMIT NULL" |
| 239 | + text += " OFFSET " + self.process(offset_clause, **kw) |
| 240 | + return text |
| 241 | + |
| 242 | + def _maybe_cast( |
| 243 | + self, |
| 244 | + element: Any, |
| 245 | + cast_to: Type[sa.types.TypeEngine], |
| 246 | + skip_types: Optional[Tuple[Type[sa.types.TypeEngine], ...]] = None, |
| 247 | + ) -> Any: |
| 248 | + if not skip_types: |
| 249 | + skip_types = (cast_to,) |
| 250 | + if cast_to not in skip_types: |
| 251 | + skip_types = (*skip_types, cast_to) |
| 252 | + if not hasattr(element, "type") or not isinstance(element.type, skip_types): |
| 253 | + return sa.Cast(element, cast_to) |
| 254 | + return element |
| 255 | + |
191 | 256 | def render_literal_value(self, value, type_): |
192 | 257 | if isinstance(value, str): |
193 | 258 | value = "".join(STR_QUOTE_MAP.get(x, x) for x in value) |
@@ -277,16 +342,14 @@ def _is_bound_to_nullable_column(self, bind_name: str) -> bool: |
277 | 342 | def _guess_bound_variable_type_by_parameters( |
278 | 343 | self, bind: sa.BindParameter, post_compile_bind_values: list |
279 | 344 | ) -> Optional[sa.types.TypeEngine]: |
280 | | - if not bind.expanding: |
281 | | - if isinstance(bind.type, sa.types.NullType): |
282 | | - return None |
283 | | - bind_type = bind.type |
284 | | - else: |
| 345 | + bind_type = bind.type |
| 346 | + if bind.expanding or (isinstance(bind.type, sa.types.NullType) and post_compile_bind_values): |
285 | 347 | not_null_values = [v for v in post_compile_bind_values if v is not None] |
286 | 348 | if not_null_values: |
287 | 349 | bind_type = sa.BindParameter("", not_null_values[0]).type |
288 | | - else: |
289 | | - return None |
| 350 | + |
| 351 | + if isinstance(bind_type, sa.types.NullType): |
| 352 | + return None |
290 | 353 |
|
291 | 354 | return bind_type |
292 | 355 |
|
|
0 commit comments