Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 20 additions & 10 deletions airflow-ctl/src/airflowctl/ctl/cli_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from enum import Enum
from functools import partial
from pathlib import Path
from typing import Any, NamedTuple
from typing import Any, NamedTuple, cast

import httpx
import rich
Expand Down Expand Up @@ -469,6 +469,11 @@ def _is_primitive_type(type_name: str) -> bool:
base_type = type_name.replace(" | None", "").strip()
return base_type in primitive_types

@staticmethod
def _is_optional_type(type_name: str) -> bool:
normalized = str(type_name).replace("typing.", "").strip()
return " | None" in normalized or normalized.startswith("Optional[")

@staticmethod
def _python_type_from_string(type_name: str | type) -> type | Callable:
"""
Expand Down Expand Up @@ -627,34 +632,39 @@ def _get_func(args: Namespace, api_operation: dict, api_client: Client = NEW_API
operation_method_object = getattr(operation_class, api_operation["name"])

# Walk through all args and create a dictionary such as args.abc -> {"abc": "value"}
method_params = {}
method_params: dict[str, Any] = {}
datamodel = None
datamodel_param_name = None
args_dict = vars(args)
for parameter in api_operation["parameters"]:
for parameter_key, parameter_type in parameter.items():
if self._is_primitive_type(type_name=parameter_type):
method_params[self._sanitize_method_param_key(parameter_key)] = args_dict[
parameter_key
]
val = args_dict.get(parameter_key)
is_optional_primitive = self._is_optional_type(parameter_type)
if val is not None or not is_optional_primitive:
method_params[self._sanitize_method_param_key(parameter_key)] = val
else:
datamodel = getattr(generated_datamodels, parameter_type)
for expanded_parameter in self.datamodels_extended_map[parameter_type]:
if parameter_key not in method_params:
if parameter_key not in method_params or not isinstance(
method_params[parameter_key], dict
):
method_params[parameter_key] = {}
datamodel_param_name = parameter_key
if expanded_parameter in self.excluded_parameters:
continue
if expanded_parameter in args_dict.keys():
method_params[parameter_key][
self._sanitize_method_param_key(expanded_parameter)
] = args_dict[expanded_parameter]
datamodel_params = cast("dict[str, Any]", method_params[parameter_key])
datamodel_params[self._sanitize_method_param_key(expanded_parameter)] = (
args_dict[expanded_parameter]
)

if datamodel:
if datamodel_param_name:
datamodel_params = cast("dict[str, Any]", method_params[datamodel_param_name])
# Apply datamodel-specific defaults (e.g., logical_date for TriggerDAGRunPostBody)
method_params[datamodel_param_name] = self._apply_datamodel_defaults(
datamodel, method_params[datamodel_param_name]
datamodel, datamodel_params
)
method_params[datamodel_param_name] = datamodel.model_validate(
method_params[datamodel_param_name]
Expand Down
94 changes: 94 additions & 0 deletions airflow-ctl/tests/airflow_ctl/ctl/test_cli_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import argparse
from argparse import BooleanOptionalAction
from textwrap import dedent
from types import SimpleNamespace

import pytest

Expand Down Expand Up @@ -289,6 +290,66 @@ def delete(self, backfill_id: str) -> ServerResponseError | None:


class TestCliConfigMethods:
@staticmethod
def _run_list_command_with_dag_id(
monkeypatch,
*,
dag_id_value,
dag_id_param_type: str,
method_has_default: bool,
):
import airflowctl.api.operations as operations

captured: dict[str, object] = {}
dummy_operations_cls: type[object]

if method_has_default:

class DummyOperationsWithDefault:
def __init__(self, client):
self.client = client

def list(self, limit: int, dag_id: str = "default-dag"):
captured["limit"] = limit
captured["dag_id"] = dag_id
return {"items": []}

dummy_operations_cls = DummyOperationsWithDefault

else:

class DummyOperationsNoDefault:
def __init__(self, client):
self.client = client

def list(self, limit: int, dag_id: str):
captured["limit"] = limit
captured["dag_id"] = dag_id
return {"items": []}

dummy_operations_cls = DummyOperationsNoDefault

monkeypatch.setattr(operations, "DummyOperations", dummy_operations_cls, raising=False)
monkeypatch.setattr(
"airflowctl.ctl.cli_config.AirflowConsole.print_as",
lambda self, data, output: None,
)

command_factory = CommandFactory()
command_factory.operations = [
{
"name": "list",
"parameters": [{"limit": "int"}, {"dag_id": dag_id_param_type}],
"return_type": "dict",
"parent": SimpleNamespace(name="DummyOperations"),
}
]

command_factory._create_func_map_from_operation()
generated_func = command_factory.func_map[("list", "DummyOperations")]
generated_func(argparse.Namespace(limit=10, dag_id=dag_id_value, output="json"), api_client=object())
return captured

def test_add_to_parser_drops_type_for_boolean_optional_action(self):
"""Test add_to_parser removes type for BooleanOptionalAction."""
parser = argparse.ArgumentParser()
Expand Down Expand Up @@ -554,3 +615,36 @@ def test_apply_datamodel_defaults_other_datamodel(self):

# Should return params unchanged for other datamodels
assert result == params, "Params should be unchanged for non-TriggerDAGRunPostBody datamodels"

@pytest.mark.parametrize(
("dag_id_value", "expected_dag_id"),
[
(None, "default-dag"),
("manual-dag", "manual-dag"),
],
)
def test_create_func_map_handles_optional_primitive_params(
self, monkeypatch, dag_id_value, expected_dag_id
):
"""Test optional primitive params are skipped when None and passed when set."""
captured = self._run_list_command_with_dag_id(
monkeypatch,
dag_id_value=dag_id_value,
dag_id_param_type="str | None",
method_has_default=True,
)

assert captured["limit"] == 10
assert captured["dag_id"] == expected_dag_id

def test_create_func_map_keeps_none_for_required_primitive_params(self, monkeypatch):
"""Test required primitive params are passed even when parsed value is None."""
captured = self._run_list_command_with_dag_id(
monkeypatch,
dag_id_value=None,
dag_id_param_type="str",
method_has_default=False,
)

assert captured["limit"] == 10
assert captured["dag_id"] is None
Loading