Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ class TriggerDAGRunPayload(StrictBaseModel):
"""Schema for Trigger DAG Run API request."""

logical_date: UtcDateTime | None = None
run_after: UtcDateTime | None = None
conf: dict = Field(default_factory=dict)
reset_dag_run: bool = False
partition_key: str | None = None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ def trigger_dag_run(
run_type=DagRunType.OPERATOR_TRIGGERED,
conf=payload.conf,
logical_date=payload.logical_date,
run_after=payload.run_after,
triggered_by=DagRunTriggeredByType.OPERATOR,
replace_microseconds=False,
partition_key=payload.partition_key,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
AddDagRunDetailEndpoint,
AddNoteField,
AddPartitionKeyField,
AddRunAfterField,
MakeDagRunStartDateNullable,
ModifyDeferredTaskKwargsToJsonValue,
MovePreviousRunEndpoint,
Expand All @@ -50,6 +51,7 @@
ModifyDeferredTaskKwargsToJsonValue,
RemoveUpstreamMapIndexesField,
AddNoteField,
AddRunAfterField,
AddDagEndpoint,
),
Version("2025-11-05", AddTriggeringUserNameField),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,10 @@ class AddNoteField(VersionChange):

description = __doc__

instructions_to_migrate_to_previous_version = (schema(DagRun).field("note").didnt_exist,)
instructions_to_migrate_to_previous_version = (
schema(DagRun).field("note").didnt_exist,
schema(TriggerDAGRunPayload).field("note").didnt_exist,
)

@convert_response_to_previous_version_for(TIRunContext) # type: ignore[arg-type]
def remove_note_field(response: ResponseInfo) -> None: # type: ignore[misc]
Expand All @@ -184,3 +187,13 @@ class AddDagEndpoint(VersionChange):
description = __doc__

instructions_to_migrate_to_previous_version = (endpoint("/dags/{dag_id}", ["GET"]).didnt_exist,)


class AddRunAfterField(VersionChange):
    """Add run_after parameter to TriggerDAGRunPayload Model."""

    # The versioning framework surfaces the class docstring as the
    # human-readable description of this API change.
    description = __doc__

    # When converting a request down to a pre-change API version, drop the
    # ``run_after`` field so older servers never receive it.
    instructions_to_migrate_to_previous_version = (
        schema(TriggerDAGRunPayload).field("run_after").didnt_exist,
    )
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import time
from collections.abc import Sequence
from json import JSONDecodeError
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, cast, overload

from sqlalchemy import select
from sqlalchemy.orm.exc import NoResultFound
Expand All @@ -43,7 +43,7 @@
)
from airflow.providers.standard.triggers.external_task import DagStateTrigger
from airflow.providers.standard.utils.openlineage import safe_inject_openlineage_properties_into_dagrun_conf
from airflow.providers.standard.version_compat import AIRFLOW_V_3_0_PLUS, BaseOperator
from airflow.providers.standard.version_compat import AIRFLOW_V_3_0_PLUS, BaseOperator, is_arg_set
from airflow.utils.state import DagRunState
from airflow.utils.types import DagRunType

Expand Down Expand Up @@ -127,6 +127,7 @@ class TriggerDagRunOperator(BaseOperator):
If not provided, a run ID will be automatically generated.
:param conf: Configuration for the DAG run (templated).
:param logical_date: Logical date for the triggered DAG (templated).
:param run_after: The date before which the triggered DAG should not run.
:param reset_dag_run: Whether clear existing DAG run if already exists.
This is useful when backfill or rerun an existing DAG run.
This only resets (not recreates) the DAG run.
Expand Down Expand Up @@ -162,6 +163,11 @@ class TriggerDagRunOperator(BaseOperator):
"wait_for_completion",
"skip_when_already_exists",
)

attributes_not_supported_in_airflow_2 = {
"run_after": NOTSET,
"note": None,
}
template_fields_renderers = {"conf": "py"}
ui_color = "#ffefeb"
operator_extra_links = [TriggerDagRunLink()]
Expand All @@ -173,6 +179,7 @@ def __init__(
trigger_run_id: str | None = None,
conf: dict | None = None,
logical_date: str | datetime.datetime | None | ArgNotSet = NOTSET,
run_after: str | datetime.datetime | None | ArgNotSet = NOTSET,
reset_dag_run: bool = False,
wait_for_completion: bool = False,
poke_interval: int = 60,
Expand Down Expand Up @@ -205,27 +212,29 @@ def __init__(
self.openlineage_inject_parent_info = openlineage_inject_parent_info
self.note = note
self.deferrable = deferrable
logical_date = _validate_datetime_param("logical_date", logical_date)
run_after = _validate_datetime_param("run_after", run_after)
self.logical_date = logical_date
if logical_date is NOTSET:
self.logical_date = NOTSET
elif logical_date is None or isinstance(logical_date, (str, datetime.datetime)):
self.logical_date = logical_date
else:
raise TypeError(
f"Expected str, datetime.datetime, or None for parameter 'logical_date'. Got {type(logical_date).__name__}"
)

self.run_after = run_after
if fail_when_dag_is_paused and AIRFLOW_V_3_0_PLUS:
raise NotImplementedError("Setting `fail_when_dag_is_paused` not yet supported for Airflow 3.x")

def execute(self, context: Context):
if self.logical_date is NOTSET:
# If no logical_date is provided we will set utcnow()
parsed_logical_date = timezone.utcnow()
elif self.logical_date is None or isinstance(self.logical_date, datetime.datetime):
parsed_logical_date = self.logical_date # type: ignore
elif isinstance(self.logical_date, str):
parsed_logical_date = timezone.parse(self.logical_date)
if self.run_after is not NOTSET:
parsed_logical_date = None
else:
# If no logical_date is provided we will set utcnow()
parsed_logical_date = timezone.utcnow()
else:
logical_date = cast("str | datetime.datetime | None", self.logical_date)
parsed_logical_date = _parse_datetime_param(logical_date)

if self.run_after is NOTSET:
parsed_run_after = parsed_logical_date
else:
run_after = cast("str | datetime.datetime | None", self.run_after)
parsed_run_after = _parse_datetime_param(run_after)

try:
if self.conf and isinstance(self.conf, str):
Expand All @@ -247,7 +256,7 @@ def execute(self, context: Context):
run_id = DagRun.generate_run_id(
run_type=DagRunType.MANUAL,
logical_date=parsed_logical_date,
run_after=parsed_logical_date or timezone.utcnow(),
run_after=parsed_run_after or timezone.utcnow(),
)
else:
run_id = DagRun.generate_run_id(DagRunType.MANUAL, parsed_logical_date or timezone.utcnow()) # type: ignore[misc,call-arg]
Expand All @@ -267,14 +276,17 @@ def execute(self, context: Context):

if AIRFLOW_V_3_0_PLUS:
self._trigger_dag_af_3(
context=context, run_id=self.trigger_run_id, parsed_logical_date=parsed_logical_date
context=context,
run_id=self.trigger_run_id,
parsed_logical_date=parsed_logical_date,
parsed_run_after=parsed_run_after if self.run_after is not NOTSET else None,
)
else:
self._trigger_dag_af_2(
context=context, run_id=self.trigger_run_id, parsed_logical_date=parsed_logical_date
)

def _trigger_dag_af_3(self, context, run_id, parsed_logical_date):
def _trigger_dag_af_3(self, context, run_id, parsed_logical_date, parsed_run_after=None):
from airflow.providers.common.compat.sdk import DagRunTriggerException

kwargs_accepted = dict(
Expand All @@ -291,16 +303,28 @@ def _trigger_dag_af_3(self, context, run_id, parsed_logical_date):
deferrable=self.deferrable,
)

if self.note and "note" in inspect.signature(DagRunTriggerException.__init__).parameters:
parameters = inspect.signature(DagRunTriggerException.__init__).parameters
if self.note and "note" in parameters:
kwargs_accepted["note"] = self.note

if parsed_run_after and "run_after" in parameters:
kwargs_accepted["run_after"] = parsed_run_after

raise DagRunTriggerException(**kwargs_accepted)

def _trigger_dag_af_2(self, context, run_id, parsed_logical_date):
try:
if self.note:
self.log.warning("Parameter 'note' is not supported in Airflow 2.x and will be ignored.")

unsupported_parameters = []
for attr, default_value in self.attributes_not_supported_in_airflow_2.items():
value = getattr(self, attr, default_value)
if value is not default_value:
unsupported_parameters.append(attr)

if unsupported_parameters:
self.log.warning(
"The following parameters are not supported in Airflow 2.x and will be ignored: %s",
", ".join(unsupported_parameters),
)
dag_run = trigger_dag(
dag_id=self.trigger_dag_id,
run_id=run_id,
Expand Down Expand Up @@ -453,3 +477,42 @@ def _trigger_dag_run_af_2_execute_complete(
f"{self.trigger_dag_id} return {state} which is not in {self.failed_states}"
f" or {self.allowed_states}"
)


@overload
def _validate_datetime_param(name: str, value: ArgNotSet) -> ArgNotSet: ...
@overload
def _validate_datetime_param(name: str, value: None) -> None: ...
@overload
def _validate_datetime_param(name: str, value: str) -> str: ...
@overload
def _validate_datetime_param(name: str, value: datetime.datetime) -> datetime.datetime: ...


def _validate_datetime_param(
    name: str,
    value: str | datetime.datetime | None | ArgNotSet,
) -> str | datetime.datetime | None | ArgNotSet:
    """Validate a datetime-like operator argument without converting it.

    Unset values collapse to the ``NOTSET`` sentinel. ``None``, strings and
    ``datetime.datetime`` instances pass through unchanged; any other type is
    rejected with a ``TypeError`` naming the offending parameter.
    """
    if not is_arg_set(value):
        return NOTSET
    if value is not None and not isinstance(value, (str, datetime.datetime)):
        raise TypeError(
            f"Expected str, datetime.datetime, or None for parameter '{name}'. Got {type(value).__name__}"
        )
    return value


@overload
def _parse_datetime_param(value: None) -> None: ...
@overload
def _parse_datetime_param(value: datetime.datetime) -> datetime.datetime: ...
@overload
def _parse_datetime_param(value: str) -> datetime.datetime: ...


def _parse_datetime_param(
    value: str | datetime.datetime | None,
) -> datetime.datetime | None:
    """Normalize an already-validated datetime parameter.

    Strings are parsed via ``timezone.parse``; ``None`` and
    ``datetime.datetime`` values are returned as-is.
    """
    if isinstance(value, str):
        return timezone.parse(value)
    return value
Original file line number Diff line number Diff line change
Expand Up @@ -41,15 +41,24 @@ def get_base_airflow_version_tuple() -> tuple[int, int, int]:
# This is needed for DecoratedOperator compatibility
if AIRFLOW_V_3_1_PLUS:
from airflow.sdk import BaseOperator
from airflow.sdk.definitions._internal.types import ArgNotSet
from airflow.sdk.definitions._internal.types import NOTSET, ArgNotSet
else:
from airflow.models.baseoperator import BaseOperator # type: ignore[no-redef]
from airflow.utils.types import ArgNotSet # type: ignore[attr-defined,no-redef]
from airflow.utils.types import NOTSET, ArgNotSet # type: ignore[attr-defined,no-redef]

# ``is_arg_set`` may not be importable on every supported Airflow / Task SDK
# version; provide an equivalent fallback based on the NOTSET sentinel.
try:
    from airflow.sdk.definitions._internal.types import is_arg_set
except ImportError:

    def is_arg_set(value):  # type: ignore[misc,no-redef]
        # Mirror the SDK helper: an argument counts as "set" unless it is
        # exactly the NOTSET sentinel object.
        return value is not NOTSET


__all__ = [
"AIRFLOW_V_3_0_PLUS",
"AIRFLOW_V_3_1_PLUS",
"AIRFLOW_V_3_2_PLUS",
"ArgNotSet",
"BaseOperator",
"is_arg_set",
]
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,44 @@ def teardown_method(self):
session.execute(delete(DagBundleModel))
session.commit()

@pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Implementation is different for Airflow 2 & 3")
def test_trigger_dagrun_with_run_after(self):
    """
    Test TriggerDagRunOperator.

    We only verify that the operator runs and raises correct exception. The actual execution logic
    after the exception is in Task SDK code.
    """
    with time_machine.travel("2025-02-18T08:04:46Z", tick=False):
        task = TriggerDagRunOperator(
            task_id="test_task",
            trigger_dag_id=TRIGGERED_DAG_ID,
            conf={"foo": "bar"},
            run_after=timezone.datetime(2025, 2, 19, 12, 0, 0),
        )

        # Ensure correct exception is raised
        with pytest.raises(DagRunTriggerException) as exc_info:
            task.execute(context={})

        assert exc_info.value.trigger_dag_id == TRIGGERED_DAG_ID
        assert exc_info.value.conf == {"foo": "bar"}
        # Providing run_after without logical_date must leave logical_date unset.
        assert exc_info.value.logical_date is None
        assert exc_info.value.reset_dag_run is False
        assert exc_info.value.skip_when_already_exists is False
        assert exc_info.value.wait_for_completion is False
        assert exc_info.value.allowed_states == [DagRunState.SUCCESS]
        assert exc_info.value.failed_states == [DagRunState.FAILED]
        # Fixed: previously this read ``getattr(exc_info, "note", None)`` — the
        # pytest ExceptionInfo wrapper has no ``note`` attribute, so the guarded
        # assertion was dead code, and it asserted "Test note" although this
        # test never passes a note. No note is supplied here, so when the
        # exception exposes the field it must be None.
        if hasattr(exc_info.value, "note"):
            assert exc_info.value.note is None

        expected_run_id = DagRun.generate_run_id(
            run_type=DagRunType.MANUAL, run_after=task.run_after
        ).rsplit("_", 1)[0]
        # rsplit because last few characters are random.
        assert exc_info.value.dag_run_id.rsplit("_", 1)[0] == expected_run_id
        assert task.trigger_run_id.rsplit("_", 1)[0] == expected_run_id  # run_id is saved as attribute

@pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Implementation is different for Airflow 2 & 3")
def test_trigger_dagrun(self):
"""
Expand Down Expand Up @@ -557,7 +595,7 @@ def test_trigger_dagrun(self, dag_maker, mock_supervisor_comms):
with time_machine.travel("2025-02-18T08:04:46Z", tick=False):
with dag_maker(TEST_DAG_ID, default_args={"start_date": DEFAULT_DATE}, serialized=True):
task = TriggerDagRunOperator(
task_id="test_task", trigger_dag_id=TRIGGERED_DAG_ID, note="Test note"
task_id="test_task", trigger_dag_id=TRIGGERED_DAG_ID, note="Test note", run_after=None
)
mock_warning = mock.patch.object(task.log, "warning").start()
dag_maker.sync_dagbag_to_db()
Expand All @@ -566,12 +604,31 @@ def test_trigger_dagrun(self, dag_maker, mock_supervisor_comms):
task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)

dagrun = dag_maker.session.scalar(select(DagRun).where(DagRun.dag_id == TRIGGERED_DAG_ID))
unsupported_params = ["run_after", "note"]
assert mock_warning.mock_calls == [
mock.call("Parameter 'note' is not supported in Airflow 2.x and will be ignored.")
mock.call(
"The following parameters are not supported in Airflow 2.x and will be ignored: %s",
", ".join(unsupported_params),
)
]
assert dagrun.run_type == DagRunType.MANUAL
assert dagrun.run_id == DagRun.generate_run_id(DagRunType.MANUAL, dagrun.logical_date)

def test_trigger_dagrun_does_not_warn_for_default_unsupported_params(
    self, dag_maker, mock_supervisor_comms
):
    """Test TriggerDagRunOperator does not warn for unsupported params when they are not provided."""
    with time_machine.travel("2025-02-18T08:04:46Z", tick=False):
        with dag_maker(TEST_DAG_ID, default_args={"start_date": DEFAULT_DATE}, serialized=True):
            # Neither ``note`` nor ``run_after`` is supplied, so the Airflow 2
            # code path must not log any "unsupported parameters" warning.
            task = TriggerDagRunOperator(task_id="test_task", trigger_dag_id=TRIGGERED_DAG_ID)
        mock_warning = mock.patch.object(task.log, "warning").start()
        dag_maker.sync_dagbag_to_db()
        parse_and_sync_to_db(self.f_name)
        dag_maker.create_dagrun()
        task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)

        mock_warning.assert_not_called()

def test_explicitly_provided_trigger_run_id_is_saved_as_attr(self, dag_maker, session):
with dag_maker(TEST_DAG_ID, default_args={"start_date": DEFAULT_DATE}, serialized=True):
task = TriggerDagRunOperator(
Expand Down
7 changes: 6 additions & 1 deletion task-sdk/src/airflow/sdk/api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -711,12 +711,17 @@ def trigger(
run_id: str,
conf: dict | None = None,
logical_date: datetime | None = None,
run_after: datetime | None = None,
reset_dag_run: bool = False,
note: str | None = None,
) -> OKResponse | ErrorResponse:
"""Trigger a Dag run via the API server."""
body = TriggerDAGRunPayload(
logical_date=logical_date, conf=conf or {}, reset_dag_run=reset_dag_run, note=note
logical_date=logical_date,
conf=conf or {},
reset_dag_run=reset_dag_run,
note=note,
run_after=run_after,
)

try:
Expand Down
1 change: 1 addition & 0 deletions task-sdk/src/airflow/sdk/api/datamodels/_generated.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,6 +371,7 @@ class TriggerDAGRunPayload(BaseModel):
extra="forbid",
)
logical_date: Annotated[AwareDatetime | None, Field(title="Logical Date")] = None
run_after: Annotated[AwareDatetime | None, Field(title="Run After")] = None
conf: Annotated[dict[str, Any] | None, Field(title="Conf")] = None
reset_dag_run: Annotated[bool | None, Field(title="Reset Dag Run")] = False
partition_key: Annotated[str | None, Field(title="Partition Key")] = None
Expand Down
Loading
Loading