Skip to content

Commit 226f277

Browse files
committed
Fix SnowflakeHook transaction support: multi-statement SQL and AUTOCOMMIT
When split_statements=False, pass num_statements=0 to cursor.execute() so Snowflake accepts multi-statement SQL blocks (BEGIN/INSERT/COMMIT). Previously this failed with "Actual statement count N did not match the desired statement count 1". Also respect AUTOCOMMIT in session_parameters instead of unconditionally overriding it with set_autocommit(conn, False). Closes: #48233 Closes: #30236
1 parent 34497e5 commit 226f277

2 files changed

Lines changed: 225 additions & 5 deletions

File tree

providers/snowflake/src/airflow/providers/snowflake/hooks/snowflake.py

Lines changed: 88 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,9 @@
4242
Connection,
4343
conf,
4444
)
45+
from airflow.providers.common.sql.hooks import handlers as sql_handlers
4546
from airflow.providers.common.sql.hooks.handlers import return_single_query_results
47+
from airflow.providers.common.sql.hooks.lineage import send_sql_hook_lineage
4648
from airflow.providers.common.sql.hooks.sql import DbApiHook
4749
from airflow.providers.snowflake.utils.openlineage import fix_snowflake_sqlalchemy_uri
4850
from airflow.utils import timezone
@@ -697,6 +699,63 @@ def set_autocommit(self, conn, autocommit: Any) -> None:
697699
def get_autocommit(self, conn):
    """
    Return the autocommit mode recorded on the given connection.

    :param conn: Connection object that may carry an ``autocommit_mode`` attribute.
    :return: The value of ``conn.autocommit_mode``, or ``False`` when the
        attribute is absent.
    """
    try:
        return conn.autocommit_mode
    except AttributeError:
        # No mode has been recorded on this connection.
        return False
699701

702+
@staticmethod
703+
def _session_params_has_autocommit(session_params: Any) -> bool:
704+
"""Check if AUTOCOMMIT is present in a session_parameters dict (case-insensitive)."""
705+
if not isinstance(session_params, dict):
706+
return False
707+
return any(k.upper() == "AUTOCOMMIT" for k in session_params)
708+
709+
def _has_autocommit_session_parameter(self) -> bool:
710+
"""Check if AUTOCOMMIT is configured in session_parameters."""
711+
# Check hook-level session_parameters first (avoids connection lookup)
712+
if isinstance(self.session_parameters, dict):
713+
return self._session_params_has_autocommit(self.session_parameters)
714+
# Fall back to connection-level session_parameters using the cached
715+
# static config to avoid triggering OAuth token refresh.
716+
try:
717+
static_config = self._get_static_conn_params
718+
except Exception:
719+
self.log.debug("Could not read connection params to check AUTOCOMMIT session parameter")
720+
return False
721+
session_params = static_config.get("session_parameters") or {}
722+
return self._session_params_has_autocommit(session_params)
723+
724+
def _run_command(self, cur, sql_statement, parameters, *, num_statements=None):
    """
    Execute SQL on an already open cursor, optionally in multi-statement mode.

    Forwards Snowflake's ``num_statements`` keyword to ``cur.execute()`` when
    the caller supplies one, reports SQL hook lineage, and logs the affected
    row count when one is available.

    :param cur: The database cursor.
    :param sql_statement: The SQL statement to execute.
    :param parameters: The parameters to bind to the SQL statement; only
        passed to ``execute()`` when truthy.
    :param num_statements: Number of statements for Snowflake multi-statement
        execution. Set to 0 to auto-detect. None means single-statement mode
        (the keyword is not passed at all).
    """
    if self.log_sql:
        self.log.info("Running statement: %s, parameters: %s", sql_statement, parameters)

    # Bind parameters are forwarded positionally, and only when non-empty.
    positional_args = (sql_statement, parameters) if parameters else (sql_statement,)
    # The keyword is omitted entirely unless a value was requested.
    keyword_args: dict[str, Any] = {} if num_statements is None else {"num_statements": num_statements}
    cur.execute(*positional_args, **keyword_args)

    send_sql_hook_lineage(context=self, sql=sql_statement, sql_parameters=parameters, cur=cur)

    affected_rows = sql_handlers.get_row_count(cur)
    if affected_rows is not None:
        self.log.info("Rows affected: %s", affected_rows)
758+
700759
@overload
701760
def run(
702761
self,
@@ -746,7 +805,13 @@ def run(
746805
:param handler: The result handler which is called with the result of
747806
each statement.
748807
:param split_statements: Whether to split a single SQL string into
749-
statements and run separately
808+
statements and run separately. When False and sql is a string,
809+
the entire SQL block is sent to Snowflake in a single execute()
810+
call with ``num_statements=0`` (auto-detect), enabling
811+
multi-statement execution (e.g., ``BEGIN; INSERT ...; COMMIT;``
812+
transaction blocks). Note that the handler only receives the
813+
first result set, and a single query ID is recorded for the
814+
entire block.
750815
:param return_last: Whether to return result for only last statement or
751816
for all after split.
752817
:param return_dictionaries: Whether to return dictionaries rather than
@@ -775,14 +840,33 @@ def run(
775840
else:
776841
raise ValueError("List of SQL statements is empty")
777842

843+
# When split_statements=False and sql is a string, the entire SQL
844+
# block is sent as one cursor.execute() call. Snowflake requires
845+
# num_statements to be set for multi-statement execution.
846+
# See: https://github.com/apache/airflow/issues/48233
847+
is_multi_statement = isinstance(sql, str) and not split_statements
848+
778849
with closing(self.get_conn()) as conn:
779-
self.set_autocommit(conn, autocommit)
850+
# Respect AUTOCOMMIT in session_parameters when autocommit is
851+
# False (the default). When autocommit=True, always override.
852+
# See: https://github.com/apache/airflow/issues/30236
853+
if autocommit or not self._has_autocommit_session_parameter():
854+
self.set_autocommit(conn, autocommit)
855+
else:
856+
# AUTOCOMMIT is set in session_parameters and was applied
857+
# during connect(). Record the mode so get_autocommit()
858+
# returns True and we skip the redundant conn.commit().
859+
conn.autocommit_mode = True
780860

781861
with self._get_cursor(conn, return_dictionaries) as cur:
782862
results = []
783863
for sql_statement in sql_list:
784-
self.log.info("Running statement: %s, parameters: %s", sql_statement, parameters)
785-
self._run_command(cur, sql_statement, parameters)
864+
self._run_command(
865+
cur,
866+
sql_statement,
867+
parameters,
868+
num_statements=0 if is_multi_statement else None,
869+
)
786870

787871
if handler is not None:
788872
result = self._make_common_data_structure(handler(cur))
@@ -794,7 +878,6 @@ def run(
794878
self.descriptions.append(cur.description)
795879

796880
query_id = cur.sfqid
797-
self.log.info("Rows affected: %s", cur.rowcount)
798881
self.log.info("Snowflake query id: %s", query_id)
799882
self.query_ids.append(query_id)
800883

providers/snowflake/tests/unit/snowflake/hooks/test_snowflake.py

Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1043,6 +1043,143 @@ def test_empty_sql_parameter(self):
10431043
with pytest.raises(ValueError, match="List of SQL statements is empty"):
10441044
hook.run(sql=empty_statement)
10451045

1046+
@mock.patch("airflow.providers.snowflake.hooks.snowflake.SnowflakeHook.get_conn")
def test_run_multi_statement_with_split_statements_false(self, mock_conn):
    """With split_statements=False the SQL block reaches execute() once, with num_statements=0."""
    snowflake_hook = SnowflakeHook()
    cursor = mock.MagicMock(rowcount=0)
    type(cursor).sfqid = mock.PropertyMock(return_value="multi_query_id")
    mock_conn.return_value.cursor.return_value = cursor

    snowflake_hook.run(
        "BEGIN; CREATE TABLE t(id INT); INSERT INTO t VALUES(1); COMMIT;",
        split_statements=False,
    )

    # One execute() call carries the whole block (without the trailing
    # semicolon) and asks Snowflake to auto-detect the statement count.
    cursor.execute.assert_called_once_with(
        "BEGIN; CREATE TABLE t(id INT); INSERT INTO t VALUES(1); COMMIT",
        num_statements=0,
    )
    assert snowflake_hook.query_ids == ["multi_query_id"]
1064+
1065+
@mock.patch("airflow.providers.snowflake.hooks.snowflake.SnowflakeHook.get_conn")
def test_run_split_statements_true_does_not_pass_num_statements(self, mock_conn):
    """With split_statements=True each statement is executed without num_statements."""
    snowflake_hook = SnowflakeHook()
    cursor = mock.MagicMock(rowcount=0)
    type(cursor).sfqid = mock.PropertyMock(side_effect=["id1", "id2"])
    mock_conn.return_value.cursor.return_value = cursor

    snowflake_hook.run("SELECT 1; SELECT 2", split_statements=True)

    execute_calls = cursor.execute.call_args_list
    assert cursor.execute.call_count == 2
    # Neither of the two execute() calls may carry the keyword.
    assert all("num_statements" not in call.kwargs for call in execute_calls)
1079+
1080+
@mock.patch("airflow.providers.snowflake.hooks.snowflake.SnowflakeHook.get_conn")
def test_run_sql_list_does_not_pass_num_statements(self, mock_conn):
    """A list of SQL strings is executed one by one, never with num_statements."""
    snowflake_hook = SnowflakeHook()
    cursor = mock.MagicMock(rowcount=0)
    type(cursor).sfqid = mock.PropertyMock(side_effect=["id1", "id2"])
    mock_conn.return_value.cursor.return_value = cursor

    snowflake_hook.run(["SELECT 1;", "SELECT 2;"])

    execute_calls = cursor.execute.call_args_list
    assert cursor.execute.call_count == 2
    # Neither of the two execute() calls may carry the keyword.
    assert all("num_statements" not in call.kwargs for call in execute_calls)
1094+
1095+
@mock.patch("airflow.providers.snowflake.hooks.snowflake.SnowflakeHook.get_conn")
def test_run_respects_autocommit_session_parameter(self, mock_conn):
    """AUTOCOMMIT in the connection's session_parameters: no set_autocommit, no commit."""
    conn_kwargs = deepcopy(BASE_CONNECTION_KWARGS)
    conn_kwargs["extra"]["session_parameters"] = {"AUTOCOMMIT": True}
    conn_uri = Connection(**conn_kwargs).get_uri()

    connection = mock_conn.return_value
    cursor = mock.MagicMock(rowcount=0)
    type(cursor).sfqid = mock.PropertyMock(return_value="qid")
    connection.cursor.return_value = cursor

    with mock.patch.dict("os.environ", AIRFLOW_CONN_SNOWFLAKE_DEFAULT=conn_uri):
        snowflake_hook = SnowflakeHook()
        snowflake_hook.run("SELECT 1", autocommit=False)

    # set_autocommit should NOT have been called
    connection.autocommit.assert_not_called()
    # No manual commit since AUTOCOMMIT session param is in effect
    connection.commit.assert_not_called()
1116+
1117+
@mock.patch("airflow.providers.snowflake.hooks.snowflake.SnowflakeHook.get_conn")
def test_run_respects_autocommit_session_parameter_case_insensitive(self, mock_conn):
    """A lowercase ``autocommit`` session parameter key is honoured like the uppercase one."""
    conn_kwargs = deepcopy(BASE_CONNECTION_KWARGS)
    conn_kwargs["extra"]["session_parameters"] = {"autocommit": True}
    conn_uri = Connection(**conn_kwargs).get_uri()

    connection = mock_conn.return_value
    cursor = mock.MagicMock(rowcount=0)
    type(cursor).sfqid = mock.PropertyMock(return_value="qid")
    connection.cursor.return_value = cursor

    with mock.patch.dict("os.environ", AIRFLOW_CONN_SNOWFLAKE_DEFAULT=conn_uri):
        snowflake_hook = SnowflakeHook()
        snowflake_hook.run("SELECT 1", autocommit=False)

    # Neither the autocommit setter nor a manual commit is expected.
    connection.autocommit.assert_not_called()
    connection.commit.assert_not_called()
1136+
1137+
@mock.patch("airflow.providers.snowflake.hooks.snowflake.SnowflakeHook.get_conn")
def test_run_respects_autocommit_from_hook_session_parameters(self, mock_conn):
    """AUTOCOMMIT passed to the hook constructor's session_parameters is respected."""
    connection = mock_conn.return_value
    cursor = mock.MagicMock(rowcount=0)
    type(cursor).sfqid = mock.PropertyMock(return_value="qid")
    connection.cursor.return_value = cursor

    snowflake_hook = SnowflakeHook(session_parameters={"AUTOCOMMIT": True})
    snowflake_hook.run("SELECT 1", autocommit=False)

    # Neither the autocommit setter nor a manual commit is expected.
    connection.autocommit.assert_not_called()
    connection.commit.assert_not_called()
1150+
1151+
@mock.patch("airflow.providers.snowflake.hooks.snowflake.SnowflakeHook.get_conn")
def test_run_explicit_autocommit_true_overrides_session_parameter(self, mock_conn):
    """An explicit autocommit=True is applied even when session_parameters sets AUTOCOMMIT."""
    conn_kwargs = deepcopy(BASE_CONNECTION_KWARGS)
    conn_kwargs["extra"]["session_parameters"] = {"AUTOCOMMIT": False}
    conn_uri = Connection(**conn_kwargs).get_uri()

    connection = mock_conn.return_value
    cursor = mock.MagicMock(rowcount=0)
    type(cursor).sfqid = mock.PropertyMock(return_value="qid")
    connection.cursor.return_value = cursor

    with mock.patch.dict("os.environ", AIRFLOW_CONN_SNOWFLAKE_DEFAULT=conn_uri):
        snowflake_hook = SnowflakeHook()
        snowflake_hook.run("SELECT 1", autocommit=True)

    # The explicit argument wins: the connection is switched to autocommit.
    connection.autocommit.assert_called_once_with(True)
1169+
1170+
@mock.patch("airflow.providers.snowflake.hooks.snowflake.SnowflakeHook.get_conn")
def test_run_default_autocommit_without_session_parameter(self, mock_conn):
    """With no AUTOCOMMIT session parameter, run() applies the default autocommit=False."""
    connection = mock_conn.return_value
    cursor = mock.MagicMock(rowcount=0)
    type(cursor).sfqid = mock.PropertyMock(return_value="qid")
    connection.cursor.return_value = cursor

    snowflake_hook = SnowflakeHook()
    snowflake_hook.run("SELECT 1")

    # The default mode is pushed to the connection exactly once.
    connection.autocommit.assert_called_once_with(False)
1182+
10461183
def test_get_openlineage_default_schema_with_no_schema_set(self):
10471184
connection_kwargs = {
10481185
**BASE_CONNECTION_KWARGS,

0 commit comments

Comments
 (0)