Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions task-sdk/src/airflow/sdk/bases/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@

from airflow.sdk import TriggerRule, timezone
from airflow.sdk.bases.operator import (
BASEOPERATOR_ARGS_EXPECTED_TYPES,
BaseOperator,
coerce_resources,
coerce_timedelta,
Expand Down Expand Up @@ -498,6 +499,13 @@ def __attrs_post_init__(self):
if "self" in self.function_signature.parameters:
raise TypeError(f"@{self.decorator_name} does not support methods")
self.kwargs.setdefault("task_id", self.function.__name__)
for arg_name, expected_type in BASEOPERATOR_ARGS_EXPECTED_TYPES.items():
if arg_name in self.kwargs:
value = self.kwargs[arg_name]
if value is not None and not isinstance(value, expected_type):
raise TypeError(
f"Expected {arg_name!r} to be {expected_type}, got {type(value).__name__}: {value!r}"
)
update_wrapper(self, self.function)

def __call__(self, *args: FParams.args, **kwargs: FParams.kwargs) -> XComArg:
Expand Down
33 changes: 33 additions & 0 deletions task-sdk/tests/task_sdk/bases/test_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,39 @@ def dummy_task(a, b=1, *args, kw_required, **kwargs):

assert make_op(dummy_task, op_kwargs={"a": 1, "kw_required": "x"}) is not None

@pytest.mark.parametrize(
("kwargs", "match"),
[
pytest.param(
{"task_id": "fetch_{}".format},
"Expected 'task_id' to be .*, got builtin_function_or_method",
id="task_id_bound_method",
),
pytest.param(
{"retries": "three"},
"Expected 'retries' to be .*, got str",
id="retries_string",
),
pytest.param(
{"queue": 42},
"Expected 'queue' to be .*, got int",
id="queue_int",
),
pytest.param(
{"priority_weight": 1.5},
"Expected 'priority_weight' to be .*, got float",
id="priority_weight_float",
),
],
)
def test_wrong_arg_type_raises_type_error_at_decoration_time(self, kwargs, match):
"""Non-matching types for operator kwargs raise TypeError at decoration time."""
with pytest.raises(TypeError, match=match):

@task(**kwargs)
def my_task():
return 1


def make_op(func, op_args=None, op_kwargs=None, **kwargs):
return DummyDecoratedOperator(
Expand Down
Loading