diff --git a/airflow-ctl/src/airflowctl/ctl/cli_config.py b/airflow-ctl/src/airflowctl/ctl/cli_config.py index 5f17c60335734..d561567ebda8c 100755 --- a/airflow-ctl/src/airflowctl/ctl/cli_config.py +++ b/airflow-ctl/src/airflowctl/ctl/cli_config.py @@ -30,7 +30,7 @@ from enum import Enum from functools import partial from pathlib import Path -from typing import Any, NamedTuple +from typing import Any, NamedTuple, cast import httpx import rich @@ -469,6 +469,11 @@ def _is_primitive_type(type_name: str) -> bool: base_type = type_name.replace(" | None", "").strip() return base_type in primitive_types + @staticmethod + def _is_optional_type(type_name: str) -> bool: + normalized = str(type_name).replace("typing.", "").strip() + return " | None" in normalized or normalized.startswith("Optional[") + @staticmethod def _python_type_from_string(type_name: str | type) -> type | Callable: """ @@ -627,34 +632,39 @@ def _get_func(args: Namespace, api_operation: dict, api_client: Client = NEW_API operation_method_object = getattr(operation_class, api_operation["name"]) # Walk through all args and create a dictionary such as args.abc -> {"abc": "value"} - method_params = {} + method_params: dict[str, Any] = {} datamodel = None datamodel_param_name = None args_dict = vars(args) for parameter in api_operation["parameters"]: for parameter_key, parameter_type in parameter.items(): if self._is_primitive_type(type_name=parameter_type): - method_params[self._sanitize_method_param_key(parameter_key)] = args_dict[ - parameter_key - ] + val = args_dict.get(parameter_key) + is_optional_primitive = self._is_optional_type(parameter_type) + if val is not None or not is_optional_primitive: + method_params[self._sanitize_method_param_key(parameter_key)] = val else: datamodel = getattr(generated_datamodels, parameter_type) for expanded_parameter in self.datamodels_extended_map[parameter_type]: - if parameter_key not in method_params: + if parameter_key not in method_params or not isinstance( + method_params[parameter_key], dict + ): method_params[parameter_key] = {} datamodel_param_name = parameter_key if expanded_parameter in self.excluded_parameters: continue if expanded_parameter in args_dict.keys(): - method_params[parameter_key][ - self._sanitize_method_param_key(expanded_parameter) - ] = args_dict[expanded_parameter] + datamodel_params = cast("dict[str, Any]", method_params[parameter_key]) + datamodel_params[self._sanitize_method_param_key(expanded_parameter)] = ( + args_dict[expanded_parameter] + ) if datamodel: if datamodel_param_name: + datamodel_params = cast("dict[str, Any]", method_params[datamodel_param_name]) # Apply datamodel-specific defaults (e.g., logical_date for TriggerDAGRunPostBody) method_params[datamodel_param_name] = self._apply_datamodel_defaults( - datamodel, method_params[datamodel_param_name] + datamodel, datamodel_params ) method_params[datamodel_param_name] = datamodel.model_validate( method_params[datamodel_param_name] diff --git a/airflow-ctl/tests/airflow_ctl/ctl/test_cli_config.py b/airflow-ctl/tests/airflow_ctl/ctl/test_cli_config.py index 117f874c34651..ff012b0b88a38 100644 --- a/airflow-ctl/tests/airflow_ctl/ctl/test_cli_config.py +++ b/airflow-ctl/tests/airflow_ctl/ctl/test_cli_config.py @@ -20,6 +20,7 @@ import argparse from argparse import BooleanOptionalAction from textwrap import dedent +from types import SimpleNamespace import pytest @@ -289,6 +290,66 @@ def delete(self, backfill_id: str) -> ServerResponseError | None: class TestCliConfigMethods: + @staticmethod + def _run_list_command_with_dag_id( + monkeypatch, + *, + dag_id_value, + dag_id_param_type: str, + method_has_default: bool, + ): + import airflowctl.api.operations as operations + + captured: dict[str, object] = {} + dummy_operations_cls: type[object] + + if method_has_default: + + class DummyOperationsWithDefault: + def __init__(self, client): + self.client = client + + def list(self, limit: int, dag_id: str = "default-dag"): + captured["limit"] = limit + captured["dag_id"] = dag_id + return {"items": []} + + dummy_operations_cls = DummyOperationsWithDefault + + else: + + class DummyOperationsNoDefault: + def __init__(self, client): + self.client = client + + def list(self, limit: int, dag_id: str): + captured["limit"] = limit + captured["dag_id"] = dag_id + return {"items": []} + + dummy_operations_cls = DummyOperationsNoDefault + + monkeypatch.setattr(operations, "DummyOperations", dummy_operations_cls, raising=False) + monkeypatch.setattr( + "airflowctl.ctl.cli_config.AirflowConsole.print_as", + lambda self, data, output: None, + ) + + command_factory = CommandFactory() + command_factory.operations = [ + { + "name": "list", + "parameters": [{"limit": "int"}, {"dag_id": dag_id_param_type}], + "return_type": "dict", + "parent": SimpleNamespace(name="DummyOperations"), + } + ] + + command_factory._create_func_map_from_operation() + generated_func = command_factory.func_map[("list", "DummyOperations")] + generated_func(argparse.Namespace(limit=10, dag_id=dag_id_value, output="json"), api_client=object()) + return captured + def test_add_to_parser_drops_type_for_boolean_optional_action(self): """Test add_to_parser removes type for BooleanOptionalAction.""" parser = argparse.ArgumentParser() @@ -554,3 +615,36 @@ def test_apply_datamodel_defaults_other_datamodel(self): # Should return params unchanged for other datamodels assert result == params, "Params should be unchanged for non-TriggerDAGRunPostBody datamodels" + + @pytest.mark.parametrize( + ("dag_id_value", "expected_dag_id"), + [ + (None, "default-dag"), + ("manual-dag", "manual-dag"), + ], + ) + def test_create_func_map_handles_optional_primitive_params( + self, monkeypatch, dag_id_value, expected_dag_id + ): + """Test optional primitive params are skipped when None and passed when set.""" + captured = self._run_list_command_with_dag_id( + monkeypatch, + dag_id_value=dag_id_value, + dag_id_param_type="str | None", + method_has_default=True, + ) + + assert captured["limit"] == 10 + assert captured["dag_id"] == expected_dag_id + + def test_create_func_map_keeps_none_for_required_primitive_params(self, monkeypatch): + """Test required primitive params are passed even when parsed value is None.""" + captured = self._run_list_command_with_dag_id( + monkeypatch, + dag_id_value=None, + dag_id_param_type="str", + method_has_default=False, + ) + + assert captured["limit"] == 10 + assert captured["dag_id"] is None