Skip to content

Commit fe0c91a

Browse files
authored
Merge branch 'main' into feat/ydb-secondary-indicies
2 parents e7ee419 + 7926d85 commit fe0c91a

File tree

5 files changed

+75
-5
lines changed

5 files changed

+75
-5
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
## 0.0.1b13 ##
12
* Added declare for yql statement variables (opt in) - temporary flag
23

34
## 0.0.1b12 ##

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414
setuptools.setup(
1515
name="ydb-sqlalchemy",
16-
version="0.0.1b12", # AUTOVERSION
16+
version="0.0.1b13", # AUTOVERSION
1717
description="YDB Dialect for SQLAlchemy",
1818
author="Yandex LLC",
1919
author_email="ydb@yandex-team.ru",

test/test_orm.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
import pytest
2+
import sqlalchemy as sa
3+
from types import MethodType
4+
from sqlalchemy import Column, Integer, Unicode
5+
from sqlalchemy.orm import declarative_base, sessionmaker
6+
from sqlalchemy.testing.fixtures import TablesTest, config
7+
8+
9+
class TestDirectories(TablesTest):
10+
__backend__ = True
11+
12+
def prepare_table(self, engine):
13+
base = declarative_base()
14+
15+
class Table(base):
16+
__tablename__ = "dir/test"
17+
id = Column(Integer, primary_key=True)
18+
text = Column(Unicode)
19+
20+
base.metadata.create_all(engine)
21+
session = sessionmaker(bind=engine)()
22+
session.add(Table(id=2, text="foo"))
23+
session.commit()
24+
return base, Table, session
25+
26+
def try_update(self, session, Table):
27+
row = session.query(Table).first()
28+
row.text = "bar"
29+
session.commit()
30+
return row
31+
32+
def drop_table(self, base, engine):
33+
base.metadata.drop_all(engine)
34+
35+
def bind_old_method_to_dialect(self, dialect):
36+
def _handle_column_name(self, variable):
37+
return variable
38+
39+
dialect._handle_column_name = MethodType(_handle_column_name, dialect)
40+
41+
def test_directories(self):
42+
engine_good = sa.create_engine(config.db_url)
43+
base, Table, session = self.prepare_table(engine_good)
44+
row = self.try_update(session, Table)
45+
assert row.id == 2
46+
assert row.text == "bar"
47+
self.drop_table(base, engine_good)
48+
49+
engine_bad = sa.create_engine(config.db_url)
50+
self.bind_old_method_to_dialect(engine_bad.dialect)
51+
base, Table, session = self.prepare_table(engine_bad)
52+
with pytest.raises(Exception) as excinfo:
53+
self.try_update(session, Table)
54+
assert "Unknown name: $dir" in str(excinfo.value)
55+
self.drop_table(base, engine_bad)

ydb_sqlalchemy/_version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
VERSION = "0.0.1b12"
1+
VERSION = "0.0.1b13"

ydb_sqlalchemy/sqlalchemy/__init__.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -625,7 +625,13 @@ class YqlDialect(StrCompileDialect):
625625
def import_dbapi(cls: Any):
626626
return dbapi.YdbDBApi()
627627

628-
def __init__(self, json_serializer=None, json_deserializer=None, _add_declare_for_yql_stmt_vars=False, **kwargs):
628+
def __init__(
629+
self,
630+
json_serializer=None,
631+
json_deserializer=None,
632+
_add_declare_for_yql_stmt_vars=False,
633+
**kwargs,
634+
):
629635
super().__init__(**kwargs)
630636

631637
self._json_deserializer = json_deserializer
@@ -728,6 +734,9 @@ def do_rollback(self, dbapi_connection: dbapi.Connection) -> None:
728734
def do_commit(self, dbapi_connection: dbapi.Connection) -> None:
729735
dbapi_connection.commit()
730736

737+
def _handle_column_name(self, variable):
738+
return "`" + variable + "`"
739+
731740
def _format_variables(
732741
self,
733742
statement: str,
@@ -749,15 +758,20 @@ def _format_variables(
749758
variable_names = set(parameters.keys())
750759
formatted_parameters = {f"${k}": v for k, v in parameters.items()}
751760

752-
formatted_variable_names = {variable_name: f"${variable_name}" for variable_name in variable_names}
761+
formatted_variable_names = {
762+
variable_name: f"${self._handle_column_name(variable_name)}" for variable_name in variable_names
763+
}
753764
formatted_statement = formatted_statement % formatted_variable_names
754765

755766
formatted_statement = formatted_statement.replace("%%", "%")
756767
return formatted_statement, formatted_parameters
757768

758769
def _add_declare_for_yql_stmt_vars_impl(self, statement, parameters_types):
759770
declarations = "\n".join(
760-
[f"DECLARE {param_name} as {str(param_type)};" for param_name, param_type in parameters_types.items()]
771+
[
772+
f"DECLARE $`{param_name[1:]}` as {str(param_type)};"
773+
for param_name, param_type in parameters_types.items()
774+
]
761775
)
762776
return f"{declarations}\n{statement}"
763777

0 commit comments

Comments
 (0)