Skip to content

Commit 6267ca9

Browse files
authored
Merge pull request #23 from ilchuk96/add-upsert-support Add upsert support
2 parents dc7381a + 839dcac commit 6267ca9

File tree

5 files changed

+157
-3
lines changed

5 files changed

+157
-3
lines changed

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
"sqlalchemy.dialects": [
4141
"yql.ydb=ydb_sqlalchemy.sqlalchemy:YqlDialect",
4242
"ydb=ydb_sqlalchemy.sqlalchemy:YqlDialect",
43+
"yql=ydb_sqlalchemy.sqlalchemy:YqlDialect",
4344
]
4445
},
4546
)

test/conftest.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
registry.register("yql.ydb", "ydb_sqlalchemy.sqlalchemy", "YqlDialect")
55
registry.register("ydb", "ydb_sqlalchemy.sqlalchemy", "YqlDialect")
6+
registry.register("yql", "ydb_sqlalchemy.sqlalchemy", "YqlDialect")
67
pytest.register_assert_rewrite("sqlalchemy.testing.assertions")
78

89
from sqlalchemy.testing.plugin.pytestplugin import * # noqa: E402, F401, F403

test/test_core.py

Lines changed: 138 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,17 @@
33
from typing import NamedTuple
44

55
import pytest
6+
67
import sqlalchemy as sa
7-
import ydb
8-
from sqlalchemy import Table, Column, Integer, Unicode
8+
from sqlalchemy import Table, Column, Integer, Unicode, String
99
from sqlalchemy.testing.fixtures import TestBase, TablesTest, config
10+
11+
import ydb
1012
from ydb._grpc.v4.protos import ydb_common_pb2
1113

1214
from ydb_sqlalchemy import dbapi, IsolationLevel
1315
from ydb_sqlalchemy.sqlalchemy import types
16+
from ydb_sqlalchemy import sqlalchemy as ydb_sa
1417

1518

1619
def clear_sql(stm):
@@ -539,3 +542,136 @@ def test_sa_null_pool_with_ydb_shared_session_pool(self, ydb_driver, ydb_pool):
539542
engine1.dispose()
540543
engine2.dispose()
541544
assert not ydb_driver._stopped
545+
546+
547+
class TestUpsert(TablesTest):
548+
@classmethod
549+
def define_tables(cls, metadata):
550+
Table(
551+
"test_upsert",
552+
metadata,
553+
Column("id", Integer, primary_key=True),
554+
Column("val", Integer),
555+
)
556+
557+
def test_string(self, connection):
558+
tb = self.tables.test_upsert
559+
stm = ydb_sa.upsert(tb).values(id=0, val=5)
560+
561+
assert str(stm) == "UPSERT INTO test_upsert (id, val) VALUES (?, ?)"
562+
563+
def test_upsert_new_id(self, connection):
564+
tb = self.tables.test_upsert
565+
stm = ydb_sa.upsert(tb).values(id=0, val=1)
566+
connection.execute(stm)
567+
row = connection.execute(sa.select(tb)).fetchall()
568+
assert row == [(0, 1)]
569+
570+
stm = ydb_sa.upsert(tb).values(id=1, val=2)
571+
connection.execute(stm)
572+
row = connection.execute(sa.select(tb)).fetchall()
573+
assert row == [(0, 1), (1, 2)]
574+
575+
def test_upsert_existing_id(self, connection):
576+
tb = self.tables.test_upsert
577+
stm = ydb_sa.upsert(tb).values(id=0, val=5)
578+
connection.execute(stm)
579+
row = connection.execute(sa.select(tb)).fetchall()
580+
581+
assert row == [(0, 5)]
582+
583+
stm = ydb_sa.upsert(tb).values(id=0, val=6)
584+
connection.execute(stm)
585+
row = connection.execute(sa.select(tb)).fetchall()
586+
587+
assert row == [(0, 6)]
588+
589+
def test_upsert_several_diff_id(self, connection):
590+
tb = self.tables.test_upsert
591+
stm = ydb_sa.upsert(tb).values(
592+
[
593+
{"id": 0, "val": 4},
594+
{"id": 1, "val": 5},
595+
{"id": 2, "val": 6},
596+
]
597+
)
598+
connection.execute(stm)
599+
row = connection.execute(sa.select(tb)).fetchall()
600+
601+
assert row == [(0, 4), (1, 5), (2, 6)]
602+
603+
def test_upsert_several_same_id(self, connection):
604+
tb = self.tables.test_upsert
605+
stm = ydb_sa.upsert(tb).values(
606+
[
607+
{"id": 0, "val": 4},
608+
{"id": 0, "val": 5},
609+
{"id": 0, "val": 6},
610+
]
611+
)
612+
connection.execute(stm)
613+
row = connection.execute(sa.select(tb)).fetchall()
614+
615+
assert row == [(0, 6)]
616+
617+
def test_upsert_from_select(self, connection, metadata):
618+
table_to_select_from = Table(
619+
"table_to_select_from",
620+
metadata,
621+
Column("id", Integer, primary_key=True),
622+
Column("val", Integer),
623+
)
624+
table_to_select_from.create(connection)
625+
stm = sa.insert(table_to_select_from).values(
626+
[
627+
{"id": 100, "val": 0},
628+
{"id": 110, "val": 1},
629+
{"id": 120, "val": 2},
630+
{"id": 130, "val": 3},
631+
]
632+
)
633+
connection.execute(stm)
634+
635+
tb = self.tables.test_upsert
636+
select_stm = sa.select(table_to_select_from.c.id, table_to_select_from.c.val).where(
637+
table_to_select_from.c.id > 115,
638+
)
639+
upsert_stm = ydb_sa.upsert(tb).from_select(["id", "val"], select_stm)
640+
connection.execute(upsert_stm)
641+
row = connection.execute(sa.select(tb)).fetchall()
642+
643+
assert row == [(120, 2), (130, 3)]
644+
645+
646+
class TestUpsertDoesNotReplaceInsert(TablesTest):
647+
@classmethod
648+
def define_tables(cls, metadata):
649+
Table(
650+
"test_upsert_does_not_replace_insert",
651+
metadata,
652+
Column("id", Integer, primary_key=True),
653+
Column("VALUE_TO_INSERT", String),
654+
)
655+
656+
def test_string(self, connection):
657+
tb = self.tables.test_upsert_does_not_replace_insert
658+
659+
stm = ydb_sa.upsert(tb).values(id=0, VALUE_TO_INSERT="5")
660+
661+
assert str(stm) == "UPSERT INTO test_upsert_does_not_replace_insert (id, `VALUE_TO_INSERT`) VALUES (?, ?)"
662+
663+
def test_insert_in_name(self, connection):
664+
tb = self.tables.test_upsert_does_not_replace_insert
665+
stm = ydb_sa.upsert(tb).values(id=1, VALUE_TO_INSERT="5")
666+
connection.execute(stm)
667+
row = connection.execute(sa.select(tb).where(tb.c.id == 1)).fetchone()
668+
669+
assert row == (1, "5")
670+
671+
def test_insert_in_name_and_field(self, connection):
672+
tb = self.tables.test_upsert_does_not_replace_insert
673+
stm = ydb_sa.upsert(tb).values(id=2, VALUE_TO_INSERT="INSERT is my favourite operation")
674+
connection.execute(stm)
675+
row = connection.execute(sa.select(tb).where(tb.c.id == 2)).fetchone()
676+
677+
assert row == (2, "INSERT is my favourite operation")

ydb_sqlalchemy/sqlalchemy/__init__.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import ydb
77
import ydb_sqlalchemy.dbapi as dbapi
88
from ydb_sqlalchemy.dbapi.constants import YDB_KEYWORDS
9+
from ydb_sqlalchemy.sqlalchemy.dml import Upsert
910

1011
import sqlalchemy as sa
1112
from sqlalchemy.exc import CompileError, NoSuchTableError
@@ -341,6 +342,9 @@ def get_bind_types(
341342

342343
return parameter_types
343344

345+
def visit_upsert(self, insert_stmt, visited_bindparam=None, **kw):
346+
return self.visit_insert(insert_stmt, visited_bindparam, **kw).replace("INSERT", "UPSERT", 1)
347+
344348

345349
class YqlDDLCompiler(DDLCompiler):
346350
def post_create_table(self, table: sa.Table) -> str:
@@ -379,7 +383,7 @@ def _render_table_partitioning_settings(self, ydb_opts: Dict[str, Any]) -> List[
379383

380384

381385
def upsert(table):
382-
return sa.sql.Insert(table)
386+
return Upsert(table)
383387

384388

385389
COLUMN_TYPES = {

ydb_sqlalchemy/sqlalchemy/dml.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
import sqlalchemy as sa
2+
3+
4+
class Upsert(sa.sql.Insert):
5+
__visit_name__ = "upsert"
6+
_propagate_attrs = {"compile_state_plugin": "yql"}
7+
stringify_dialect = "yql"
8+
9+
10+
@sa.sql.base.CompileState.plugin_for("yql", "upsert")
11+
class UpsertDMLState(sa.sql.dml.InsertDMLState):
12+
pass

0 commit comments

Comments
 (0)