diff --git a/task-sdk/src/airflow/sdk/bases/decorator.py b/task-sdk/src/airflow/sdk/bases/decorator.py index 2e98f826c7ea1..8634fbe99647c 100644 --- a/task-sdk/src/airflow/sdk/bases/decorator.py +++ b/task-sdk/src/airflow/sdk/bases/decorator.py @@ -30,6 +30,7 @@ from airflow.sdk import TriggerRule, timezone from airflow.sdk.bases.operator import ( + BASEOPERATOR_ARGS_EXPECTED_TYPES, BaseOperator, coerce_resources, coerce_timedelta, @@ -498,6 +499,13 @@ def __attrs_post_init__(self): if "self" in self.function_signature.parameters: raise TypeError(f"@{self.decorator_name} does not support methods") self.kwargs.setdefault("task_id", self.function.__name__) + for arg_name, expected_type in BASEOPERATOR_ARGS_EXPECTED_TYPES.items(): + if arg_name in self.kwargs: + value = self.kwargs[arg_name] + if value is not None and not isinstance(value, expected_type): + raise TypeError( + f"Expected {arg_name!r} to be {expected_type}, got {type(value).__name__}: {value!r}" + ) update_wrapper(self, self.function) def __call__(self, *args: FParams.args, **kwargs: FParams.kwargs) -> XComArg: diff --git a/task-sdk/tests/task_sdk/bases/test_decorator.py b/task-sdk/tests/task_sdk/bases/test_decorator.py index 6021f6f0aad41..860eb6e7b3312 100644 --- a/task-sdk/tests/task_sdk/bases/test_decorator.py +++ b/task-sdk/tests/task_sdk/bases/test_decorator.py @@ -197,6 +197,39 @@ def dummy_task(a, b=1, *args, kw_required, **kwargs): assert make_op(dummy_task, op_kwargs={"a": 1, "kw_required": "x"}) is not None + @pytest.mark.parametrize( + ("kwargs", "match"), + [ + pytest.param( + {"task_id": "fetch_{}".format}, + "Expected 'task_id' to be .*, got builtin_function_or_method", + id="task_id_bound_method", + ), + pytest.param( + {"retries": "three"}, + "Expected 'retries' to be .*, got str", + id="retries_string", + ), + pytest.param( + {"queue": 42}, + "Expected 'queue' to be .*, got int", + id="queue_int", + ), + pytest.param( + {"priority_weight": 1.5}, + "Expected 'priority_weight' to be .*, got float", + id="priority_weight_float", + ), + ], + ) + def test_wrong_arg_type_raises_type_error_at_decoration_time(self, kwargs, match): + """Non-matching types for operator kwargs raise TypeError at decoration time.""" + with pytest.raises(TypeError, match=match): + + @task(**kwargs) + def my_task(): + return 1 + def make_op(func, op_args=None, op_kwargs=None, **kwargs): return DummyDecoratedOperator(