Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from typing import TYPE_CHECKING, Any

from airflow.providers.common.compat.notifier import BaseNotifier
from airflow.providers.common.compat.sdk import conf
from airflow.providers.smtp.hooks.smtp import SmtpHook
from airflow.providers.smtp.version_compat import AIRFLOW_V_3_1_PLUS

Expand Down Expand Up @@ -80,7 +81,7 @@ def __init__(
mime_subtype: str = "mixed",
mime_charset: str = "utf-8",
custom_headers: dict[str, Any] | None = None,
smtp_conn_id: str = SmtpHook.default_conn_name,
smtp_conn_id: str = conf.get("email", "email_conn_id", fallback=SmtpHook.default_conn_name),
auth_type: str = "basic",
*,
template: str | None = None,
Expand Down
34 changes: 34 additions & 0 deletions providers/smtp/tests/unit/smtp/notifications/test_smtp.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,40 @@ def test_notifier_with_nondefault_connection_extra(
**DEFAULT_EMAIL_PARAMS,
)

@mock.patch("airflow.providers.smtp.notifications.smtp.SmtpHook")
def test_notifier_with_custom_smtp_conn_id(self, mock_smtphook_hook, create_dag_without_db):
"""Test that a custom smtp_conn_id is correctly passed to SmtpHook."""
custom_conn_id = "my_custom_smtp"
notifier = SmtpNotifier(**NOTIFIER_DEFAULT_PARAMS, smtp_conn_id=custom_conn_id)
notifier({"dag": create_dag_without_db(TEST_DAG_ID)})
mock_smtphook_hook.assert_called_once_with(smtp_conn_id=custom_conn_id, auth_type="basic")
mock_smtphook_hook.return_value.__enter__().send_email_smtp.assert_called_once_with(
**NOTIFIER_DEFAULT_PARAMS,
smtp_conn_id=custom_conn_id,
**DEFAULT_EMAIL_PARAMS,
)

def test_notifier_default_smtp_conn_id_from_config(self):
"""Test that smtp_conn_id defaults to email.email_conn_id from config."""
import importlib

import airflow.providers.smtp.notifications.smtp as smtp_mod

from tests_common.test_utils.config import conf_vars

with conf_vars({("email", "email_conn_id"): "config_smtp_conn"}):
importlib.reload(smtp_mod)
try:
notifier = smtp_mod.SmtpNotifier(
to=TEST_RECEIVER,
from_email=TEST_SENDER,
subject=TEST_SUBJECT,
html_content=TEST_BODY,
)
assert notifier.smtp_conn_id == "config_smtp_conn"
finally:
importlib.reload(smtp_mod)

@mock.patch("airflow.providers.smtp.notifications.smtp.SmtpHook")
def test_notifier_oauth2_passes_auth_type(self, mock_smtphook_hook, create_dag_without_db):
notifier = SmtpNotifier(**NOTIFIER_DEFAULT_PARAMS, auth_type=SMTP_AUTH_TYPE)
Expand Down
2 changes: 2 additions & 0 deletions task-sdk/src/airflow/sdk/execution_time/task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -1565,6 +1565,7 @@ def _send_error_email_notification(
) -> None:
"""Send email notification for task errors using SmtpNotifier."""
try:
from airflow.providers.smtp.hooks.smtp import SmtpHook
from airflow.providers.smtp.notifications.smtp import SmtpNotifier
except ImportError:
log.error(
Expand Down Expand Up @@ -1624,6 +1625,7 @@ def _send_error_email_notification(
subject=subject,
html_content=html_content,
from_email=conf.get("email", "from_email", fallback="airflow@airflow"),
smtp_conn_id=conf.get("email", "email_conn_id", fallback=SmtpHook.default_conn_name),
)
notifier(email_context)
except Exception:
Expand Down
41 changes: 41 additions & 0 deletions task-sdk/tests/task_sdk/execution_time/test_task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -3456,6 +3456,47 @@ def execute(self, context):
)
assert kwargs["from_email"] == self.FROM

@pytest.mark.parametrize(
("email_conn_id", "expected_conn_id"),
[
pytest.param("custom_smtp_conn", "custom_smtp_conn", id="custom-conn-id"),
pytest.param(None, "smtp_default", id="default-conn-id"),
],
)
def test_email_smtp_conn_id(
self, email_conn_id, expected_conn_id, create_runtime_ti, mock_supervisor_comms
):
"""Test that smtp_conn_id is passed to SmtpNotifier from email config."""
from airflow.sdk.exceptions import AirflowFailException
from airflow.sdk.execution_time.task_runner import finalize, run

class FailingOperator(BaseOperator):
def execute(self, context):
raise AirflowFailException("Task failed for conn_id test")

task = FailingOperator(
task_id="conn_id_test_task",
email="test@example.com",
email_on_failure=True,
)

runtime_ti = create_runtime_ti(task=task)
context = runtime_ti.get_template_context()
log = mock.MagicMock()

conf_overrides = {("email", "from_email"): self.FROM}
if email_conn_id is not None:
conf_overrides[("email", "email_conn_id")] = email_conn_id

with conf_vars(conf_overrides):
with mock.patch("airflow.providers.smtp.notifications.smtp.SmtpNotifier") as mock_smtp_notifier:
state, _, error = run(runtime_ti, context, log)
finalize(runtime_ti, state, context, log, error)

mock_smtp_notifier.assert_called_once()
kwargs = mock_smtp_notifier.call_args.kwargs
assert kwargs["smtp_conn_id"] == expected_conn_id

@pytest.mark.enable_redact
def test_rendered_templates_mask_secrets(self, create_runtime_ti, mock_supervisor_comms):
"""Test that secrets registered with mask_secret() are redacted in rendered template fields."""
Expand Down
Loading