Skip to content

Commit 00c618c

Browse files
author
tretyak-rd
committed
Fix type hints and tests behaviour
1 parent 6307e0b commit 00c618c

File tree

6 files changed

+58
-43
lines changed

6 files changed

+58
-43
lines changed

test/test_core.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,8 @@ def test_sa_text(self, connection):
2323
DECLARE :data AS List<Struct<x:Int64, y:Int64>>;
2424
SELECT x, y FROM AS_TABLE(:data)
2525
"""
26-
), [{"data": [{"x": 2, "y": 1}, {"x": 3, "y": 2}]}]
26+
),
27+
[{"data": [{"x": 2, "y": 1}, {"x": 3, "y": 2}]}],
2728
)
2829
assert set(rs.fetchall()) == {(2, 1), (3, 2)}
2930

test/test_suite.py

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from sqlalchemy.testing.suite import * # noqa: F401, F403
66

77
from sqlalchemy.testing import is_true, is_false
8-
from sqlalchemy.testing.suite import eq_, testing, inspect, provide_metadata, config, requirements
8+
from sqlalchemy.testing.suite import eq_, testing, inspect, provide_metadata, config, requirements, fixtures
99
from sqlalchemy.testing.suite import func, column, literal_column, select, exists
1010
from sqlalchemy.testing.suite import MetaData, Column, Table, Integer, String
1111

@@ -21,7 +21,6 @@
2121
CompositeKeyReflectionTest as _CompositeKeyReflectionTest,
2222
ComponentReflectionTestExtra as _ComponentReflectionTestExtra,
2323
QuotedNameArgumentTest as _QuotedNameArgumentTest,
24-
BizarroCharacterFKResolutionTest as _BizarroCharacterFKResolutionTest,
2524
)
2625
from sqlalchemy.testing.suite.test_types import (
2726
IntegerTest as _IntegerTest,
@@ -33,6 +32,10 @@
3332
NativeUUIDTest as _NativeUUIDTest,
3433
TimeMicrosecondsTest as _TimeMicrosecondsTest,
3534
DateTimeCoercedToDateTimeTest as _DateTimeCoercedToDateTimeTest,
35+
DateTest as _DateTest,
36+
DateTimeMicrosecondsTest as _DateTimeMicrosecondsTest,
37+
DateTimeTest as _DateTimeTest,
38+
TimestampMicrosecondsTest as _TimestampMicrosecondsTest,
3639
)
3740
from sqlalchemy.testing.suite.test_dialect import (
3841
EscapingTest as _EscapingTest,
@@ -64,11 +67,6 @@ def column_getter(*args, **kwargs):
6467
test_types_suite.Column = column_getter
6568

6669

67-
@pytest.mark.skip("foreign keys unsupported")
68-
class BizarroCharacterFKResolutionTest(_BizarroCharacterFKResolutionTest):
69-
pass
70-
71-
7270
class ComponentReflectionTest(_ComponentReflectionTest):
7371
def _check_list(self, result, exp, req_keys=None, msg=None):
7472
try:
@@ -121,7 +119,6 @@ def define_reflected_tables(cls, metadata, schema):
121119
Column("id", sa.Integer, primary_key=True, comment="id comment"),
122120
Column("data", sa.String(20), comment="data % comment"),
123121
Column("d2", sa.String(20), comment=r"""Comment types type speedily ' " \ '' Fun!"""),
124-
Column("d3", sa.String(42), comment="Comment\nwith\rescapes"),
125122
schema=schema,
126123
comment=r"""the test % ' " \ table comment""",
127124
)
@@ -422,6 +419,22 @@ def test_no_results_for_non_returning_insert(self, connection):
422419
pass
423420

424421

422+
class DateTest(_DateTest):
423+
run_dispose_bind = "once"
424+
425+
426+
class DateTimeMicrosecondsTest(_DateTimeMicrosecondsTest):
427+
run_dispose_bind = "once"
428+
429+
430+
class DateTimeTest(_DateTimeTest):
431+
run_dispose_bind = "once"
432+
433+
434+
class TimestampMicrosecondsTest(_TimestampMicrosecondsTest):
435+
run_dispose_bind = "once"
436+
437+
425438
@pytest.mark.skip("unsupported Time data type")
426439
class TimeTest(_TimeTest):
427440
pass

test_dbapi/test_dbapi.py

Lines changed: 16 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -12,21 +12,21 @@ def test_connection(connection):
1212

1313
cur = connection.cursor()
1414
with suppress(dbapi.DatabaseError):
15-
cur.execute(dbapi.YdbOperation("DROP TABLE foo", is_ddl=True))
15+
cur.execute(dbapi.YdbQuery("DROP TABLE foo", is_ddl=True))
1616

1717
assert not connection.check_exists("/local/foo")
1818
with pytest.raises(dbapi.ProgrammingError):
1919
connection.describe("/local/foo")
2020

21-
cur.execute(dbapi.YdbOperation("CREATE TABLE foo(id Int64 NOT NULL, PRIMARY KEY (id))", is_ddl=True))
21+
cur.execute(dbapi.YdbQuery("CREATE TABLE foo(id Int64 NOT NULL, PRIMARY KEY (id))", is_ddl=True))
2222

2323
assert connection.check_exists("/local/foo")
2424

2525
col = connection.describe("/local/foo").columns[0]
2626
assert col.name == "id"
2727
assert col.type == ydb.PrimitiveType.Int64
2828

29-
cur.execute(dbapi.YdbOperation("DROP TABLE foo", is_ddl=True))
29+
cur.execute(dbapi.YdbQuery("DROP TABLE foo", is_ddl=True))
3030
cur.close()
3131

3232

@@ -35,18 +35,17 @@ def test_cursor_raw_query(connection):
3535
assert cur
3636

3737
with suppress(dbapi.DatabaseError):
38-
cur.execute(dbapi.YdbOperation("DROP TABLE test", is_ddl=True))
38+
cur.execute(dbapi.YdbQuery("DROP TABLE test", is_ddl=True))
3939

40-
cur.execute(dbapi.YdbOperation("CREATE TABLE test(id Int64 NOT NULL, text Utf8, PRIMARY KEY (id))", is_ddl=True))
40+
cur.execute(dbapi.YdbQuery("CREATE TABLE test(id Int64 NOT NULL, text Utf8, PRIMARY KEY (id))", is_ddl=True))
4141

4242
cur.execute(
43-
dbapi.YdbOperation(
43+
dbapi.YdbQuery(
4444
"""
4545
DECLARE $data AS List<Struct<id:Int64, text: Utf8>>;
4646
4747
INSERT INTO test SELECT id, text FROM AS_TABLE($data);
4848
""",
49-
is_ddl=False,
5049
parameters_types={
5150
"$data": ydb.ListType(
5251
ydb.StructType()
@@ -63,7 +62,7 @@ def test_cursor_raw_query(connection):
6362
},
6463
)
6564

66-
cur.execute(dbapi.YdbOperation("DROP TABLE test", is_ddl=True))
65+
cur.execute(dbapi.YdbQuery("DROP TABLE test", is_ddl=True))
6766

6867
cur.close()
6968

@@ -75,25 +74,25 @@ def test_errors(connection):
7574
cur = connection.cursor()
7675

7776
with suppress(dbapi.DatabaseError):
78-
cur.execute(dbapi.YdbOperation("DROP TABLE test", is_ddl=True))
77+
cur.execute(dbapi.YdbQuery("DROP TABLE test", is_ddl=True))
7978

8079
with pytest.raises(dbapi.DataError):
81-
cur.execute(dbapi.YdbOperation("SELECT 18446744073709551616", is_ddl=False))
80+
cur.execute(dbapi.YdbQuery("SELECT 18446744073709551616"))
8281

8382
with pytest.raises(dbapi.DataError):
84-
cur.execute(dbapi.YdbOperation("SELECT * FROM 拉屎", is_ddl=False))
83+
cur.execute(dbapi.YdbQuery("SELECT * FROM 拉屎"))
8584

8685
with pytest.raises(dbapi.DataError):
87-
cur.execute(dbapi.YdbOperation("SELECT floor(5 / 2)", is_ddl=False))
86+
cur.execute(dbapi.YdbQuery("SELECT floor(5 / 2)"))
8887

8988
with pytest.raises(dbapi.ProgrammingError):
90-
cur.execute(dbapi.YdbOperation("SELECT * FROM test", is_ddl=False))
89+
cur.execute(dbapi.YdbQuery("SELECT * FROM test"))
9190

92-
cur.execute(dbapi.YdbOperation("CREATE TABLE test(id Int64, PRIMARY KEY (id))", is_ddl=True))
91+
cur.execute(dbapi.YdbQuery("CREATE TABLE test(id Int64, PRIMARY KEY (id))", is_ddl=True))
9392

94-
cur.execute(dbapi.YdbOperation("INSERT INTO test(id) VALUES(1)", is_ddl=False))
93+
cur.execute(dbapi.YdbQuery("INSERT INTO test(id) VALUES(1)"))
9594
with pytest.raises(dbapi.IntegrityError):
96-
cur.execute(dbapi.YdbOperation("INSERT INTO test(id) VALUES(1)", is_ddl=False))
95+
cur.execute(dbapi.YdbQuery("INSERT INTO test(id) VALUES(1)"))
9796

98-
cur.execute(dbapi.YdbOperation("DROP TABLE test", is_ddl=True))
97+
cur.execute(dbapi.YdbQuery("DROP TABLE test", is_ddl=True))
9998
cur.close()

ydb_sqlalchemy/dbapi/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from .connection import Connection
2-
from .cursor import YdbOperation, Cursor
2+
from .cursor import Cursor, YdbQuery # noqa: F401
33
from .errors import (
44
Warning,
55
Error,

ydb_sqlalchemy/dbapi/cursor.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import itertools
33
import logging
44

5-
from typing import Any, Mapping, Optional, Sequence, Union
5+
from typing import Any, Mapping, Optional, Sequence, Union, Dict
66

77
import ydb
88
from .errors import (
@@ -24,10 +24,12 @@ def get_column_type(type_obj: Any) -> str:
2424

2525

2626
@dataclasses.dataclass
27-
class YdbOperation:
27+
class YdbQuery:
2828
yql_text: str
29-
is_ddl: bool
30-
parameters_types: dict[str, Union[ydb.PrimitiveType, ydb.AbstractTypeBuilder]] = dataclasses.field(default_factory=dict)
29+
parameters_types: Dict[str, Union[ydb.PrimitiveType, ydb.AbstractTypeBuilder]] = dataclasses.field(
30+
default_factory=dict
31+
)
32+
is_ddl: bool = False
3133

3234

3335
class Cursor(object):
@@ -38,7 +40,7 @@ def __init__(self, connection):
3840
self.rows = None
3941
self._rows_prefetched = None
4042

41-
def execute(self, operation: YdbOperation, parameters: Optional[Mapping[str, Any]] = None):
43+
def execute(self, operation: YdbQuery, parameters: Optional[Mapping[str, Any]] = None):
4244
self.description = None
4345

4446
if operation.is_ddl or not operation.parameters_types:
@@ -129,7 +131,7 @@ def _ensure_prefetched(self):
129131
self.rows = iter(self._rows_prefetched)
130132
return self._rows_prefetched
131133

132-
def executemany(self, operation: YdbOperation, seq_of_parameters: Optional[Sequence[Mapping[str, Any]]]):
134+
def executemany(self, operation: YdbQuery, seq_of_parameters: Optional[Sequence[Mapping[str, Any]]]):
133135
for parameters in seq_of_parameters:
134136
self.execute(operation, parameters)
135137

ydb_sqlalchemy/sqlalchemy/__init__.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,9 @@
2222
from sqlalchemy.engine.default import StrCompileDialect, DefaultExecutionContext
2323
from sqlalchemy.util.compat import inspect_getfullargspec
2424

25-
from typing import Any, Union, Mapping, Sequence, Optional, Tuple
25+
from typing import Any, Union, Mapping, Sequence, Optional, Tuple, List, Dict
2626

27-
from .types import UInt32, UInt64
27+
from . import types
2828

2929
STR_QUOTE_MAP = {
3030
"'": "\\'",
@@ -136,7 +136,7 @@ def get_ydb_type(
136136
ydb_type = ydb.PrimitiveType.Int64
137137
elif isinstance(type_, sa.JSON):
138138
ydb_type = ydb.PrimitiveType.Json
139-
elif isinstance(type_, (sa.DateTime, sa.TIMESTAMP)):
139+
elif isinstance(type_, sa.DateTime):
140140
ydb_type = ydb.PrimitiveType.Timestamp
141141
elif isinstance(type_, sa.Date):
142142
ydb_type = ydb.PrimitiveType.Date
@@ -288,7 +288,7 @@ def _guess_bound_variable_type_by_parameters(
288288

289289
return bind_type
290290

291-
def _get_expanding_bind_names(self, bind_name: str, parameters_values: Mapping[str, list[Any]]) -> list[Any]:
291+
def _get_expanding_bind_names(self, bind_name: str, parameters_values: Mapping[str, List[Any]]) -> List[Any]:
292292
expanding_bind_names = []
293293
for parameter_name in parameters_values:
294294
parameter_bind_name = "_".join(parameter_name.split("_")[:-1])
@@ -298,7 +298,7 @@ def _get_expanding_bind_names(self, bind_name: str, parameters_values: Mapping[s
298298

299299
def get_bind_types(
300300
self, post_compile_parameters: Optional[Union[Sequence[Mapping[str, Any]], Mapping[str, Any]]]
301-
) -> dict[str, Union[ydb.PrimitiveType, ydb.AbstractTypeBuilder]]:
301+
) -> Dict[str, Union[ydb.PrimitiveType, ydb.AbstractTypeBuilder]]:
302302
"""
303303
This method extracts information about bound variables from the table definition and parameters.
304304
"""
@@ -357,8 +357,8 @@ def upsert(table):
357357
ydb.PrimitiveType.Int64: sa.INTEGER,
358358
ydb.PrimitiveType.Uint8: sa.INTEGER,
359359
ydb.PrimitiveType.Uint16: sa.INTEGER,
360-
ydb.PrimitiveType.Uint32: UInt32,
361-
ydb.PrimitiveType.Uint64: UInt64,
360+
ydb.PrimitiveType.Uint32: types.UInt32,
361+
ydb.PrimitiveType.Uint64: types.UInt64,
362362
ydb.PrimitiveType.Float: sa.FLOAT,
363363
ydb.PrimitiveType.Double: sa.FLOAT,
364364
ydb.PrimitiveType.String: sa.BINARY,
@@ -528,17 +528,17 @@ def _make_ydb_operation(
528528
context: Optional[DefaultExecutionContext] = None,
529529
parameters: Optional[Union[Sequence[Mapping[str, Any]], Mapping[str, Any]]] = None,
530530
execute_many: bool = False,
531-
) -> Tuple[dbapi.YdbOperation, Optional[Union[Sequence[Mapping[str, Any]], Mapping[str, Any]]]]:
531+
) -> Tuple[dbapi.YdbQuery, Optional[Union[Sequence[Mapping[str, Any]], Mapping[str, Any]]]]:
532532
is_ddl = context.isddl if context is not None else False
533533

534534
if not is_ddl and parameters:
535535
parameters_types = context.compiled.get_bind_types(parameters)
536536
parameters_types = {f"${k}": v for k, v in parameters_types.items()}
537537
statement, parameters = self._format_variables(statement, parameters, execute_many)
538-
return dbapi.YdbOperation(is_ddl=is_ddl, yql_text=statement, parameters_types=parameters_types), parameters
538+
return dbapi.YdbQuery(yql_text=statement, parameters_types=parameters_types, is_ddl=is_ddl), parameters
539539

540540
statement, parameters = self._format_variables(statement, parameters, execute_many)
541-
return dbapi.YdbOperation(is_ddl=is_ddl, yql_text=statement), parameters
541+
return dbapi.YdbQuery(yql_text=statement, is_ddl=is_ddl), parameters
542542

543543
def do_ping(self, dbapi_connection: dbapi.Connection) -> bool:
544544
cursor = dbapi_connection.cursor()

0 commit comments

Comments
 (0)