From 226f2771d79167db77fd80c0a7a299bdbc06f6a2 Mon Sep 17 00:00:00 2001 From: Kaxil Naik Date: Sat, 11 Apr 2026 01:56:26 +0100 Subject: [PATCH 1/3] Fix SnowflakeHook transaction support: multi-statement SQL and AUTOCOMMIT When split_statements=False, pass num_statements=0 to cursor.execute() so Snowflake accepts multi-statement SQL blocks (BEGIN/INSERT/COMMIT). Previously this failed with "Actual statement count N did not match the desired statement count 1". Also respect AUTOCOMMIT in session_parameters instead of unconditionally overriding it with set_autocommit(conn, False). Closes: #48233 Closes: #30236 --- .../providers/snowflake/hooks/snowflake.py | 93 +++++++++++- .../unit/snowflake/hooks/test_snowflake.py | 137 ++++++++++++++++++ 2 files changed, 225 insertions(+), 5 deletions(-) diff --git a/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake.py b/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake.py index 3a9205fd8f519..a374dd24b9213 100644 --- a/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake.py +++ b/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake.py @@ -42,7 +42,9 @@ Connection, conf, ) +from airflow.providers.common.sql.hooks import handlers as sql_handlers from airflow.providers.common.sql.hooks.handlers import return_single_query_results +from airflow.providers.common.sql.hooks.lineage import send_sql_hook_lineage from airflow.providers.common.sql.hooks.sql import DbApiHook from airflow.providers.snowflake.utils.openlineage import fix_snowflake_sqlalchemy_uri from airflow.utils import timezone @@ -697,6 +699,63 @@ def set_autocommit(self, conn, autocommit: Any) -> None: def get_autocommit(self, conn): return getattr(conn, "autocommit_mode", False) + @staticmethod + def _session_params_has_autocommit(session_params: Any) -> bool: + """Check if AUTOCOMMIT is present in a session_parameters dict (case-insensitive).""" + if not isinstance(session_params, dict): + return False + return any(k.upper() == "AUTOCOMMIT" for k in session_params) + + def _has_autocommit_session_parameter(self) -> bool: + """Check if AUTOCOMMIT is configured in session_parameters.""" + # Check hook-level session_parameters first (avoids connection lookup) + if isinstance(self.session_parameters, dict): + return self._session_params_has_autocommit(self.session_parameters) + # Fall back to connection-level session_parameters using the cached + # static config to avoid triggering OAuth token refresh. + try: + static_config = self._get_static_conn_params + except Exception: + self.log.debug("Could not read connection params to check AUTOCOMMIT session parameter") + return False + session_params = static_config.get("session_parameters") or {} + return self._session_params_has_autocommit(session_params) + + def _run_command(self, cur, sql_statement, parameters, *, num_statements=None): + """ + Run a statement using an already open cursor. + + Extends the base implementation to support Snowflake's ``num_statements`` + parameter for multi-statement execution. + + :param cur: The database cursor. + :param sql_statement: The SQL statement to execute. + :param parameters: The parameters to bind to the SQL statement. + :param num_statements: Number of statements for Snowflake multi-statement + execution. Set to 0 to auto-detect. None means single-statement mode. + """ + if self.log_sql: + self.log.info("Running statement: %s, parameters: %s", sql_statement, parameters) + + execute_kwargs: dict[str, Any] = {} + if num_statements is not None: + execute_kwargs["num_statements"] = num_statements + + if parameters: + cur.execute(sql_statement, parameters, **execute_kwargs) + else: + cur.execute(sql_statement, **execute_kwargs) + + send_sql_hook_lineage( + context=self, + sql=sql_statement, + sql_parameters=parameters, + cur=cur, + ) + + if (row_count := sql_handlers.get_row_count(cur)) is not None: + self.log.info("Rows affected: %s", row_count) + @overload def run( self, @@ -746,7 +805,13 @@ def run( :param handler: The result handler which is called with the result of each statement. :param split_statements: Whether to split a single SQL string into - statements and run separately + statements and run separately. When False and sql is a string, + the entire SQL block is sent to Snowflake in a single execute() + call with ``num_statements=0`` (auto-detect), enabling + multi-statement execution (e.g., ``BEGIN; INSERT ...; COMMIT;`` + transaction blocks). Note that the handler only receives the + first result set, and a single query ID is recorded for the + entire block. :param return_last: Whether to return result for only last statement or for all after split. :param return_dictionaries: Whether to return dictionaries rather than @@ -775,14 +840,33 @@ def run( else: raise ValueError("List of SQL statements is empty") + # When split_statements=False and sql is a string, the entire SQL + # block is sent as one cursor.execute() call. Snowflake requires + # num_statements to be set for multi-statement execution. + # See: https://github.com/apache/airflow/issues/48233 + is_multi_statement = isinstance(sql, str) and not split_statements + with closing(self.get_conn()) as conn: - self.set_autocommit(conn, autocommit) + # Respect AUTOCOMMIT in session_parameters when autocommit is + # False (the default). When autocommit=True, always override. + # See: https://github.com/apache/airflow/issues/30236 + if autocommit or not self._has_autocommit_session_parameter(): + self.set_autocommit(conn, autocommit) + else: + # AUTOCOMMIT is set in session_parameters and was applied + # during connect(). Record the mode so get_autocommit() + # returns True and we skip the redundant conn.commit(). + conn.autocommit_mode = True with self._get_cursor(conn, return_dictionaries) as cur: results = [] for sql_statement in sql_list: - self.log.info("Running statement: %s, parameters: %s", sql_statement, parameters) - self._run_command(cur, sql_statement, parameters) + self._run_command( + cur, + sql_statement, + parameters, + num_statements=0 if is_multi_statement else None, + ) if handler is not None: result = self._make_common_data_structure(handler(cur)) @@ -794,7 +878,6 @@ def run( self.descriptions.append(cur.description) query_id = cur.sfqid - self.log.info("Rows affected: %s", cur.rowcount) self.log.info("Snowflake query id: %s", query_id) self.query_ids.append(query_id) diff --git a/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake.py b/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake.py index 901ac925c0700..5e59a282b3767 100644 --- a/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake.py +++ b/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake.py @@ -1043,6 +1043,143 @@ def test_empty_sql_parameter(self): with pytest.raises(ValueError, match="List of SQL statements is empty"): hook.run(sql=empty_statement) + @mock.patch("airflow.providers.snowflake.hooks.snowflake.SnowflakeHook.get_conn") + def test_run_multi_statement_with_split_statements_false(self, mock_conn): + """When split_statements=False, cursor.execute() receives num_statements=0.""" + hook = SnowflakeHook() + conn = mock_conn.return_value + cur = mock.MagicMock(rowcount=0) + conn.cursor.return_value = cur + type(cur).sfqid = mock.PropertyMock(return_value="multi_query_id") + + sql = "BEGIN; CREATE TABLE t(id INT); INSERT INTO t VALUES(1); COMMIT;" + hook.run(sql, split_statements=False) + + # Entire SQL block sent as one execute with num_statements=0 + cur.execute.assert_called_once_with( + "BEGIN; CREATE TABLE t(id INT); INSERT INTO t VALUES(1); COMMIT", + num_statements=0, + ) + assert hook.query_ids == ["multi_query_id"] + + @mock.patch("airflow.providers.snowflake.hooks.snowflake.SnowflakeHook.get_conn") + def test_run_split_statements_true_does_not_pass_num_statements(self, mock_conn): + """When split_statements=True, cursor.execute() does not receive num_statements.""" + hook = SnowflakeHook() + conn = mock_conn.return_value + cur = mock.MagicMock(rowcount=0) + conn.cursor.return_value = cur + type(cur).sfqid = mock.PropertyMock(side_effect=["id1", "id2"]) + + hook.run("SELECT 1; SELECT 2", split_statements=True) + + assert cur.execute.call_count == 2 + for call in cur.execute.call_args_list: + assert "num_statements" not in call.kwargs + + @mock.patch("airflow.providers.snowflake.hooks.snowflake.SnowflakeHook.get_conn") + def test_run_sql_list_does_not_pass_num_statements(self, mock_conn): + """When sql is a list, cursor.execute() does not receive num_statements.""" + hook = SnowflakeHook() + conn = mock_conn.return_value + cur = mock.MagicMock(rowcount=0) + conn.cursor.return_value = cur + type(cur).sfqid = mock.PropertyMock(side_effect=["id1", "id2"]) + + hook.run(["SELECT 1;", "SELECT 2;"]) + + assert cur.execute.call_count == 2 + for call in cur.execute.call_args_list: + assert "num_statements" not in call.kwargs + + @mock.patch("airflow.providers.snowflake.hooks.snowflake.SnowflakeHook.get_conn") + def test_run_respects_autocommit_session_parameter(self, mock_conn): + """When session_parameters has AUTOCOMMIT, set_autocommit is skipped and no commit.""" + connection_kwargs = deepcopy(BASE_CONNECTION_KWARGS) + connection_kwargs["extra"]["session_parameters"] = {"AUTOCOMMIT": True} + with mock.patch.dict( + "os.environ", + AIRFLOW_CONN_SNOWFLAKE_DEFAULT=Connection(**connection_kwargs).get_uri(), + ): + hook = SnowflakeHook() + conn = mock_conn.return_value + cur = mock.MagicMock(rowcount=0) + conn.cursor.return_value = cur + type(cur).sfqid = mock.PropertyMock(return_value="qid") + + hook.run("SELECT 1", autocommit=False) + + # set_autocommit should NOT have been called + conn.autocommit.assert_not_called() + # No manual commit since AUTOCOMMIT session param is in effect + conn.commit.assert_not_called() + + @mock.patch("airflow.providers.snowflake.hooks.snowflake.SnowflakeHook.get_conn") + def test_run_respects_autocommit_session_parameter_case_insensitive(self, mock_conn): + """AUTOCOMMIT check is case-insensitive (Snowflake params are case-insensitive).""" + connection_kwargs = deepcopy(BASE_CONNECTION_KWARGS) + connection_kwargs["extra"]["session_parameters"] = {"autocommit": True} + with mock.patch.dict( + "os.environ", + AIRFLOW_CONN_SNOWFLAKE_DEFAULT=Connection(**connection_kwargs).get_uri(), + ): + hook = SnowflakeHook() + conn = mock_conn.return_value + cur = mock.MagicMock(rowcount=0) + conn.cursor.return_value = cur + type(cur).sfqid = mock.PropertyMock(return_value="qid") + + hook.run("SELECT 1", autocommit=False) + + conn.autocommit.assert_not_called() + conn.commit.assert_not_called() + + @mock.patch("airflow.providers.snowflake.hooks.snowflake.SnowflakeHook.get_conn") + def test_run_respects_autocommit_from_hook_session_parameters(self, mock_conn): + """AUTOCOMMIT from hook constructor session_parameters is respected.""" + hook = SnowflakeHook(session_parameters={"AUTOCOMMIT": True}) + conn = mock_conn.return_value + cur = mock.MagicMock(rowcount=0) + conn.cursor.return_value = cur + type(cur).sfqid = mock.PropertyMock(return_value="qid") + + hook.run("SELECT 1", autocommit=False) + + conn.autocommit.assert_not_called() + conn.commit.assert_not_called() + + @mock.patch("airflow.providers.snowflake.hooks.snowflake.SnowflakeHook.get_conn") + def test_run_explicit_autocommit_true_overrides_session_parameter(self, mock_conn): + """When autocommit=True is explicit, it overrides session_parameters.""" + connection_kwargs = deepcopy(BASE_CONNECTION_KWARGS) + connection_kwargs["extra"]["session_parameters"] = {"AUTOCOMMIT": False} + with mock.patch.dict( + "os.environ", + AIRFLOW_CONN_SNOWFLAKE_DEFAULT=Connection(**connection_kwargs).get_uri(), + ): + hook = SnowflakeHook() + conn = mock_conn.return_value + cur = mock.MagicMock(rowcount=0) + conn.cursor.return_value = cur + type(cur).sfqid = mock.PropertyMock(return_value="qid") + + hook.run("SELECT 1", autocommit=True) + + conn.autocommit.assert_called_once_with(True) + + @mock.patch("airflow.providers.snowflake.hooks.snowflake.SnowflakeHook.get_conn") + def test_run_default_autocommit_without_session_parameter(self, mock_conn): + """Without AUTOCOMMIT in session_parameters, default (False) is applied.""" + hook = SnowflakeHook() + conn = mock_conn.return_value + cur = mock.MagicMock(rowcount=0) + conn.cursor.return_value = cur + type(cur).sfqid = mock.PropertyMock(return_value="qid") + + hook.run("SELECT 1") + + conn.autocommit.assert_called_once_with(False) + def test_get_openlineage_default_schema_with_no_schema_set(self): connection_kwargs = { **BASE_CONNECTION_KWARGS, From e50badcf26d101cc1bb29fdc680cb858b094f089 Mon Sep 17 00:00:00 2001 From: Kaxil Naik Date: Sat, 11 Apr 2026 03:38:12 +0100 Subject: [PATCH 2/3] Fix mypy errors in Snowflake hook --- .../src/airflow/providers/snowflake/hooks/snowflake.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake.py b/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake.py index a374dd24b9213..fbfa594647880 100644 --- a/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake.py +++ b/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake.py @@ -718,7 +718,7 @@ def _has_autocommit_session_parameter(self) -> bool: except Exception: self.log.debug("Could not read connection params to check AUTOCOMMIT session parameter") return False - session_params = static_config.get("session_parameters") or {} + session_params: dict = static_config.get("session_parameters") or {} return self._session_params_has_autocommit(session_params) def _run_command(self, cur, sql_statement, parameters, *, num_statements=None): @@ -856,7 +856,7 @@ def run( # AUTOCOMMIT is set in session_parameters and was applied # during connect(). Record the mode so get_autocommit() # returns True and we skip the redundant conn.commit(). - conn.autocommit_mode = True + setattr(conn, "autocommit_mode", True) with self._get_cursor(conn, return_dictionaries) as cur: results = [] From cd0ffec853860c8eb422203a2855f1ef0b5d37fc Mon Sep 17 00:00:00 2001 From: Kaxil Naik Date: Sat, 11 Apr 2026 04:36:24 +0100 Subject: [PATCH 3/3] Fix mypy error: pass session_parameters directly to type-checking helper _session_params_has_autocommit already handles non-dict inputs, so the intermediate variable with a dict annotation is unnecessary and causes a mypy assignment error. --- .../src/airflow/providers/snowflake/hooks/snowflake.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake.py b/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake.py index fbfa594647880..c69f243558133 100644 --- a/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake.py +++ b/providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake.py @@ -718,8 +718,7 @@ def _has_autocommit_session_parameter(self) -> bool: except Exception: self.log.debug("Could not read connection params to check AUTOCOMMIT session parameter") return False - session_params: dict = static_config.get("session_parameters") or {} - return self._session_params_has_autocommit(session_params) + return self._session_params_has_autocommit(static_config.get("session_parameters")) def _run_command(self, cur, sql_statement, parameters, *, num_statements=None): """