Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,9 @@
Connection,
conf,
)
from airflow.providers.common.sql.hooks import handlers as sql_handlers
from airflow.providers.common.sql.hooks.handlers import return_single_query_results
from airflow.providers.common.sql.hooks.lineage import send_sql_hook_lineage
from airflow.providers.common.sql.hooks.sql import DbApiHook
from airflow.providers.snowflake.utils.openlineage import fix_snowflake_sqlalchemy_uri
from airflow.utils import timezone
Expand Down Expand Up @@ -697,6 +699,62 @@ def set_autocommit(self, conn, autocommit: Any) -> None:
def get_autocommit(self, conn):
return getattr(conn, "autocommit_mode", False)

@staticmethod
def _session_params_has_autocommit(session_params: Any) -> bool:
"""Check if AUTOCOMMIT is present in a session_parameters dict (case-insensitive)."""
if not isinstance(session_params, dict):
return False
return any(k.upper() == "AUTOCOMMIT" for k in session_params)

def _has_autocommit_session_parameter(self) -> bool:
"""Check if AUTOCOMMIT is configured in session_parameters."""
# Check hook-level session_parameters first (avoids connection lookup)
if isinstance(self.session_parameters, dict):
return self._session_params_has_autocommit(self.session_parameters)
# Fall back to connection-level session_parameters using the cached
# static config to avoid triggering OAuth token refresh.
try:
static_config = self._get_static_conn_params
except Exception:
self.log.debug("Could not read connection params to check AUTOCOMMIT session parameter")
return False
return self._session_params_has_autocommit(static_config.get("session_parameters"))

def _run_command(self, cur, sql_statement, parameters, *, num_statements=None):
    """
    Execute one statement on an already open cursor.

    Extends the base implementation so that Snowflake's ``num_statements``
    argument can be forwarded for multi-statement execution. After executing,
    SQL hook lineage is sent and the row count is logged when one is available.

    :param cur: The database cursor.
    :param sql_statement: The SQL statement to execute.
    :param parameters: The parameters to bind to the SQL statement.
    :param num_statements: Number of statements for Snowflake multi-statement
        execution. Set to 0 to auto-detect. None means single-statement mode.
    """
    if self.log_sql:
        self.log.info("Running statement: %s, parameters: %s", sql_statement, parameters)

    # ``num_statements`` is only forwarded when the caller supplied it.
    extra_kwargs: dict[str, Any] = {} if num_statements is None else {"num_statements": num_statements}

    # Bind parameters positionally only when there are any to bind.
    positional_args = (sql_statement, parameters) if parameters else (sql_statement,)
    cur.execute(*positional_args, **extra_kwargs)

    send_sql_hook_lineage(context=self, sql=sql_statement, sql_parameters=parameters, cur=cur)

    affected_rows = sql_handlers.get_row_count(cur)
    if affected_rows is not None:
        self.log.info("Rows affected: %s", affected_rows)

@overload
def run(
self,
Expand Down Expand Up @@ -746,7 +804,13 @@ def run(
:param handler: The result handler which is called with the result of
each statement.
:param split_statements: Whether to split a single SQL string into
statements and run separately
statements and run separately. When False and sql is a string,
the entire SQL block is sent to Snowflake in a single execute()
call with ``num_statements=0`` (auto-detect), enabling
multi-statement execution (e.g., ``BEGIN; INSERT ...; COMMIT;``
transaction blocks). Note that the handler only receives the
first result set, and a single query ID is recorded for the
entire block.
:param return_last: Whether to return result for only last statement or
for all after split.
:param return_dictionaries: Whether to return dictionaries rather than
Expand Down Expand Up @@ -775,14 +839,33 @@ def run(
else:
raise ValueError("List of SQL statements is empty")

# When split_statements=False and sql is a string, the entire SQL
# block is sent as one cursor.execute() call. Snowflake requires
# num_statements to be set for multi-statement execution.
# See: https://github.com/apache/airflow/issues/48233
is_multi_statement = isinstance(sql, str) and not split_statements

with closing(self.get_conn()) as conn:
self.set_autocommit(conn, autocommit)
# Respect AUTOCOMMIT in session_parameters when autocommit is
# False (the default). When autocommit=True, always override.
# See: https://github.com/apache/airflow/issues/30236
if autocommit or not self._has_autocommit_session_parameter():
self.set_autocommit(conn, autocommit)
else:
# AUTOCOMMIT is set in session_parameters and was applied
# during connect(). Record the mode so get_autocommit()
# returns True and we skip the redundant conn.commit().
setattr(conn, "autocommit_mode", True)

with self._get_cursor(conn, return_dictionaries) as cur:
results = []
for sql_statement in sql_list:
self.log.info("Running statement: %s, parameters: %s", sql_statement, parameters)
self._run_command(cur, sql_statement, parameters)
self._run_command(
cur,
sql_statement,
parameters,
num_statements=0 if is_multi_statement else None,
)

if handler is not None:
result = self._make_common_data_structure(handler(cur))
Expand All @@ -794,7 +877,6 @@ def run(
self.descriptions.append(cur.description)

query_id = cur.sfqid
self.log.info("Rows affected: %s", cur.rowcount)
self.log.info("Snowflake query id: %s", query_id)
self.query_ids.append(query_id)

Expand Down
137 changes: 137 additions & 0 deletions providers/snowflake/tests/unit/snowflake/hooks/test_snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -1043,6 +1043,143 @@ def test_empty_sql_parameter(self):
with pytest.raises(ValueError, match="List of SQL statements is empty"):
hook.run(sql=empty_statement)

@mock.patch("airflow.providers.snowflake.hooks.snowflake.SnowflakeHook.get_conn")
def test_run_multi_statement_with_split_statements_false(self, mock_conn):
    """With split_statements=False the SQL block goes through one execute() with num_statements=0."""
    snowflake_hook = SnowflakeHook()
    cursor = mock.MagicMock(rowcount=0)
    type(cursor).sfqid = mock.PropertyMock(return_value="multi_query_id")
    mock_conn.return_value.cursor.return_value = cursor

    multi_statement_sql = "BEGIN; CREATE TABLE t(id INT); INSERT INTO t VALUES(1); COMMIT;"
    snowflake_hook.run(multi_statement_sql, split_statements=False)

    # One execute() call carries the whole block; the expected text has no
    # trailing semicolon.
    cursor.execute.assert_called_once_with(
        "BEGIN; CREATE TABLE t(id INT); INSERT INTO t VALUES(1); COMMIT",
        num_statements=0,
    )
    # A single query id is recorded for the entire block.
    assert snowflake_hook.query_ids == ["multi_query_id"]

@mock.patch("airflow.providers.snowflake.hooks.snowflake.SnowflakeHook.get_conn")
def test_run_split_statements_true_does_not_pass_num_statements(self, mock_conn):
    """With split_statements=True each statement is executed without a num_statements kwarg."""
    snowflake_hook = SnowflakeHook()
    cursor = mock.MagicMock(rowcount=0)
    # Each read of ``sfqid`` yields the next id in the sequence.
    type(cursor).sfqid = mock.PropertyMock(side_effect=["id1", "id2"])
    mock_conn.return_value.cursor.return_value = cursor

    snowflake_hook.run("SELECT 1; SELECT 2", split_statements=True)

    execute_calls = cursor.execute.call_args_list
    assert cursor.execute.call_count == 2
    assert all("num_statements" not in recorded.kwargs for recorded in execute_calls)

@mock.patch("airflow.providers.snowflake.hooks.snowflake.SnowflakeHook.get_conn")
def test_run_sql_list_does_not_pass_num_statements(self, mock_conn):
    """When sql is given as a list, no execute() call receives a num_statements kwarg."""
    snowflake_hook = SnowflakeHook()
    cursor = mock.MagicMock(rowcount=0)
    # Each read of ``sfqid`` yields the next id in the sequence.
    type(cursor).sfqid = mock.PropertyMock(side_effect=["id1", "id2"])
    mock_conn.return_value.cursor.return_value = cursor

    statements = ["SELECT 1;", "SELECT 2;"]
    snowflake_hook.run(statements)

    execute_calls = cursor.execute.call_args_list
    assert cursor.execute.call_count == 2
    assert all("num_statements" not in recorded.kwargs for recorded in execute_calls)

@mock.patch("airflow.providers.snowflake.hooks.snowflake.SnowflakeHook.get_conn")
def test_run_respects_autocommit_session_parameter(self, mock_conn):
    """With AUTOCOMMIT in the connection's session_parameters, run() neither sets autocommit nor commits."""
    connection_kwargs = deepcopy(BASE_CONNECTION_KWARGS)
    connection_kwargs["extra"]["session_parameters"] = {"AUTOCOMMIT": True}
    connection_uri = Connection(**connection_kwargs).get_uri()

    # The connection is exposed through the environment for the duration of the run.
    with mock.patch.dict("os.environ", AIRFLOW_CONN_SNOWFLAKE_DEFAULT=connection_uri):
        snowflake_hook = SnowflakeHook()
        connection = mock_conn.return_value
        cursor = mock.MagicMock(rowcount=0)
        type(cursor).sfqid = mock.PropertyMock(return_value="qid")
        connection.cursor.return_value = cursor

        snowflake_hook.run("SELECT 1", autocommit=False)

        # The hook must not have touched the connection's autocommit setting ...
        connection.autocommit.assert_not_called()
        # ... and must not issue a manual commit while the session parameter is in effect.
        connection.commit.assert_not_called()

@mock.patch("airflow.providers.snowflake.hooks.snowflake.SnowflakeHook.get_conn")
def test_run_respects_autocommit_session_parameter_case_insensitive(self, mock_conn):
    """A lowercase ``autocommit`` session parameter key is honoured just like the uppercase form."""
    connection_kwargs = deepcopy(BASE_CONNECTION_KWARGS)
    connection_kwargs["extra"]["session_parameters"] = {"autocommit": True}
    connection_uri = Connection(**connection_kwargs).get_uri()

    # The connection is exposed through the environment for the duration of the run.
    with mock.patch.dict("os.environ", AIRFLOW_CONN_SNOWFLAKE_DEFAULT=connection_uri):
        snowflake_hook = SnowflakeHook()
        connection = mock_conn.return_value
        cursor = mock.MagicMock(rowcount=0)
        type(cursor).sfqid = mock.PropertyMock(return_value="qid")
        connection.cursor.return_value = cursor

        snowflake_hook.run("SELECT 1", autocommit=False)

        connection.autocommit.assert_not_called()
        connection.commit.assert_not_called()

@mock.patch("airflow.providers.snowflake.hooks.snowflake.SnowflakeHook.get_conn")
def test_run_respects_autocommit_from_hook_session_parameters(self, mock_conn):
    """AUTOCOMMIT supplied through the hook constructor's session_parameters is honoured."""
    snowflake_hook = SnowflakeHook(session_parameters={"AUTOCOMMIT": True})
    connection = mock_conn.return_value
    cursor = mock.MagicMock(rowcount=0)
    type(cursor).sfqid = mock.PropertyMock(return_value="qid")
    connection.cursor.return_value = cursor

    snowflake_hook.run("SELECT 1", autocommit=False)

    # Neither the autocommit setter nor a manual commit may be invoked.
    connection.autocommit.assert_not_called()
    connection.commit.assert_not_called()

@mock.patch("airflow.providers.snowflake.hooks.snowflake.SnowflakeHook.get_conn")
def test_run_explicit_autocommit_true_overrides_session_parameter(self, mock_conn):
    """An explicit autocommit=True is applied even when session_parameters carries AUTOCOMMIT."""
    connection_kwargs = deepcopy(BASE_CONNECTION_KWARGS)
    connection_kwargs["extra"]["session_parameters"] = {"AUTOCOMMIT": False}
    connection_uri = Connection(**connection_kwargs).get_uri()

    # The connection is exposed through the environment for the duration of the run.
    with mock.patch.dict("os.environ", AIRFLOW_CONN_SNOWFLAKE_DEFAULT=connection_uri):
        snowflake_hook = SnowflakeHook()
        connection = mock_conn.return_value
        cursor = mock.MagicMock(rowcount=0)
        type(cursor).sfqid = mock.PropertyMock(return_value="qid")
        connection.cursor.return_value = cursor

        snowflake_hook.run("SELECT 1", autocommit=True)

        # The caller's explicit choice reaches the connection exactly once.
        connection.autocommit.assert_called_once_with(True)

@mock.patch("airflow.providers.snowflake.hooks.snowflake.SnowflakeHook.get_conn")
def test_run_default_autocommit_without_session_parameter(self, mock_conn):
    """With no AUTOCOMMIT session parameter, run() applies its default autocommit of False."""
    snowflake_hook = SnowflakeHook()
    connection = mock_conn.return_value
    cursor = mock.MagicMock(rowcount=0)
    type(cursor).sfqid = mock.PropertyMock(return_value="qid")
    connection.cursor.return_value = cursor

    snowflake_hook.run("SELECT 1")

    # The default value is pushed to the connection exactly once.
    connection.autocommit.assert_called_once_with(False)

def test_get_openlineage_default_schema_with_no_schema_set(self):
connection_kwargs = {
**BASE_CONNECTION_KWARGS,
Expand Down
Loading