diff --git a/airflow-core/src/airflow/api_fastapi/core_api/datamodels/connections.py b/airflow-core/src/airflow/api_fastapi/core_api/datamodels/connections.py index a44edb5b91b43..6b4154bb9b525 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/datamodels/connections.py +++ b/airflow-core/src/airflow/api_fastapi/core_api/datamodels/connections.py @@ -21,11 +21,12 @@ from collections.abc import Iterable, Mapping from typing import Annotated, Any -from pydantic import Field, field_validator +from pydantic import Field, field_validator, model_validator from pydantic_core.core_schema import ValidationInfo from airflow._shared.secrets_masker import redact, should_hide_value_for_key from airflow.api_fastapi.core_api.base import BaseModel, StrictBaseModel, make_partial_model +from airflow.configuration import conf # Response Models @@ -199,5 +200,13 @@ def validate_extra(cls, v: str | None) -> str | None: ) return v + @model_validator(mode="after") + def validate_team_name(self) -> ConnectionBody: + if self.team_name is not None and not conf.getboolean("core", "multi_team"): + raise ValueError( + "team_name cannot be set when multi_team mode is disabled. Please contact your administrator." + ) + return self + ConnectionBodyPartial = make_partial_model(ConnectionBody) diff --git a/airflow-core/src/airflow/api_fastapi/core_api/datamodels/pools.py b/airflow-core/src/airflow/api_fastapi/core_api/datamodels/pools.py index a661308d4434c..55c7c3ad35e56 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/datamodels/pools.py +++ b/airflow-core/src/airflow/api_fastapi/core_api/datamodels/pools.py @@ -20,9 +20,10 @@ from collections.abc import Callable, Iterable from typing import Annotated -from pydantic import BeforeValidator, Field +from pydantic import BeforeValidator, Field, model_validator from airflow.api_fastapi.core_api.base import BaseModel, StrictBaseModel +from airflow.configuration import conf def _call_function(function: Callable[[], int]) -> int: @@ -83,6 +84,14 @@ class PoolPatchBody(StrictBaseModel): include_deferred: bool | None = None team_name: str | None = Field(max_length=50, default=None) + @model_validator(mode="after") + def validate_team_name(self) -> PoolPatchBody: + if self.team_name is not None and not conf.getboolean("core", "multi_team"): + raise ValueError( + "team_name cannot be set when multi_team mode is disabled. Please contact your administrator." + ) + return self + class PoolBody(BasePool, StrictBaseModel): """Pool serializer for post bodies.""" @@ -91,3 +100,11 @@ class PoolBody(BasePool, StrictBaseModel): description: str | None = None include_deferred: bool = False team_name: str | None = Field(max_length=50, default=None) + + @model_validator(mode="after") + def validate_team_name(self) -> PoolBody: + if self.team_name is not None and not conf.getboolean("core", "multi_team"): + raise ValueError( + "team_name cannot be set when multi_team mode is disabled. Please contact your administrator." + ) + return self diff --git a/airflow-core/src/airflow/api_fastapi/core_api/datamodels/variables.py b/airflow-core/src/airflow/api_fastapi/core_api/datamodels/variables.py index 9dc7969b69b6a..001d8c70e95ad 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/datamodels/variables.py +++ b/airflow-core/src/airflow/api_fastapi/core_api/datamodels/variables.py @@ -24,6 +24,7 @@ from airflow._shared.secrets_masker import redact from airflow.api_fastapi.core_api.base import BaseModel, StrictBaseModel, make_partial_model +from airflow.configuration import conf from airflow.models.base import ID_LEN from airflow.typing_compat import Self @@ -60,6 +61,14 @@ class VariableBody(StrictBaseModel): description: str | None = Field(default=None) team_name: str | None = Field(max_length=50, default=None) + @model_validator(mode="after") + def validate_team_name(self) -> VariableBody: + if self.team_name is not None and not conf.getboolean("core", "multi_team"): + raise ValueError( + "team_name cannot be set when multi_team mode is disabled. Please contact your administrator." + ) + return self + VariableBodyPartial = make_partial_model(VariableBody) diff --git a/airflow-core/src/airflow/ui/src/queries/useEditConnection.tsx b/airflow-core/src/airflow/ui/src/queries/useEditConnection.tsx index 54f5cef48c210..0b7b9cefd0af4 100644 --- a/airflow-core/src/airflow/ui/src/queries/useEditConnection.tsx +++ b/airflow-core/src/airflow/ui/src/queries/useEditConnection.tsx @@ -89,6 +89,11 @@ export const useEditConnection = ( if (requestBody.schema !== initialConnection.schema) { updateMask.push("schema"); } + if (requestBody.team_name !== initialConnection.team_name) { + updateMask.push("team_name"); + } + + const teamName = requestBody.team_name === "" ? undefined : requestBody.team_name; mutate({ connectionId: initialConnection.connection_id, @@ -99,6 +104,7 @@ export const useEditConnection = ( extra: requestBody.extra === "{}" ? undefined : requestBody.extra, // eslint-disable-next-line unicorn/no-null port: requestBody.port === "" ? null : Number(requestBody.port), + team_name: teamName, }, updateMask, }); diff --git a/airflow-core/src/airflow/ui/src/queries/useEditVariable.ts b/airflow-core/src/airflow/ui/src/queries/useEditVariable.ts index 73a0fcab7dca7..ae33993058181 100644 --- a/airflow-core/src/airflow/ui/src/queries/useEditVariable.ts +++ b/airflow-core/src/airflow/ui/src/queries/useEditVariable.ts @@ -78,12 +78,13 @@ export const useEditVariable = ( const parsedDescription = editVariableRequestBody.description === "" ? undefined : editVariableRequestBody.description; + const teamName = editVariableRequestBody.team_name === "" ? undefined : editVariableRequestBody.team_name; mutate({ requestBody: { description: parsedDescription, key: editVariableRequestBody.key, - team_name: editVariableRequestBody.team_name, + team_name: teamName, value: editVariableRequestBody.value, }, updateMask, diff --git a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_connections.py b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_connections.py index 1f63247cfa9ab..f631f197651cc 100644 --- a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_connections.py +++ b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_connections.py @@ -34,6 +34,7 @@ from tests_common.test_utils.api_fastapi import _check_last_log from tests_common.test_utils.asserts import assert_queries_count +from tests_common.test_utils.config import conf_vars from tests_common.test_utils.db import clear_db_connections, clear_db_logs, clear_test_connections from tests_common.test_utils.markers import skip_if_force_lowest_dependencies_marker @@ -286,10 +287,15 @@ def test_post_should_respond_201(self, test_client, session, body): assert len(connection) == 1 _check_last_log(session, dag_id=None, event="post_connection", logical_date=None) + @conf_vars({("core", "multi_team"): "True"}) def test_post_should_respond_201_with_team(self, test_client, session, testing_team): response = test_client.post( "/connections", - json={"connection_id": TEST_CONN_ID, "conn_type": TEST_CONN_TYPE, "team_name": testing_team.name}, + json={ + "connection_id": TEST_CONN_ID, + "conn_type": TEST_CONN_TYPE, + "team_name": testing_team.name, + }, ) assert response.status_code == 201 assert response.json() == { @@ -338,6 +344,22 @@ def test_post_should_respond_422_for_invalid_conn_id(self, test_client, body): ] } + @conf_vars({("core", "multi_team"): "False"}) + def test_post_rejects_team_name_when_multi_team_disabled(self, test_client): + response = test_client.post( + "/connections", + json={ + "connection_id": TEST_CONN_ID_2, + "conn_type": TEST_CONN_TYPE_2, + "team_name": "test_team", + }, + ) + assert response.status_code == 422 + assert ( + "team_name cannot be set when multi_team mode is disabled. Please contact your administrator." + in response.json()["detail"][0]["msg"] + ) + @pytest.mark.parametrize( "body", [ @@ -602,6 +624,7 @@ def test_patch_should_respond_200( assert response.json() == expected_result + @conf_vars({("core", "multi_team"): "True"}) def test_patch_with_team_should_respond_200(self, test_client, testing_team, session): self.create_connection() @@ -966,6 +989,23 @@ def test_patch_with_update_mask_rejects_extra_fields(self, test_client): ) assert response.status_code == 422 + @conf_vars({("core", "multi_team"): "False"}) + def test_patch_rejects_team_name_when_multi_team_disabled(self, test_client): + self.create_connection() + response = test_client.patch( + f"/connections/{TEST_CONN_ID_2}", + json={ + "connection_id": TEST_CONN_ID_2, + "conn_type": "new_type", + "team_name": "test_team", + }, + ) + assert response.status_code == 422 + assert ( + "team_name cannot be set when multi_team mode is disabled. Please contact your administrator." + in response.json()["detail"][0]["msg"] + ) + class TestConnection(TestConnectionEndpoint): def setup_method(self): @@ -1616,6 +1656,57 @@ def test_bulk_delete_avoids_n_plus_one_queries(self, session): assert sorted(results.success) == [TEST_CONN_ID, TEST_CONN_ID_2] + @conf_vars({("core", "multi_team"): "False"}) + def test_bulk_rejects_team_name_when_multi_team_is_disabled(self, test_client): + actions = { + "actions": [ + { + "action": "create", + "entities": [ + { + "connection_id": "test_conn_id_1", + "conn_type": TEST_CONN_TYPE, + "description": "description", + }, + { + "connection_id": "test_conn_id_2", + "conn_type": TEST_CONN_TYPE_2, + "description": "description_2", + "team_name": "test_team", + }, + ], + }, + { + "action": "update", + "entities": [ + { + "connection_id": "test_conn_id_3", + "conn_type": TEST_CONN_TYPE, + "description": "updated_description", + "team_name": "test_team", + }, + { + "connection_id": "test_conn_id_4", + "conn_type": TEST_CONN_TYPE_2, + "description": "updated_description_2", + }, + ], + }, + ] + } + response = test_client.patch("/connections", json=actions) + assert response.status_code == 422 + detail = response.json()["detail"] + + assert all( + "team_name cannot be set when multi_team mode is disabled. Please contact your administrator." + in err["msg"] + for err in detail + ), f"Unexpected errors in detail: {detail}" + + expected_error_conn_ids = {err["input"]["connection_id"] for err in detail} + assert sorted(expected_error_conn_ids) == ["test_conn_id_2", "test_conn_id_3"] + class TestPostConnectionExtraBackwardCompatibility(TestConnectionEndpoint): def test_post_should_accept_empty_string_as_extra(self, test_client, session): diff --git a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_pools.py b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_pools.py index e2ef872073530..08aa24cd76721 100644 --- a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_pools.py +++ b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_pools.py @@ -25,6 +25,7 @@ from airflow.models.team import Team from airflow.utils.session import provide_session +from tests_common.test_utils.config import conf_vars from tests_common.test_utils.db import clear_db_pools, clear_db_teams from tests_common.test_utils.logs import check_last_log @@ -416,8 +417,27 @@ def test_patch_pool3_should_respond_200(self, test_client, session): assert response.json() == expected_response check_last_log(session, dag_id=None, event="patch_pool", logical_date=None) + @conf_vars({("core", "multi_team"): "False"}) + def test_patch_pool_rejects_team_name_when_multi_team_disabled(self, test_client): + self.create_pools() + response = test_client.patch( + f"/pools/{POOL2_NAME}", + json={ + "name": POOL2_NAME, + "slots": POOL2_SLOT, + "include_deferred": POOL2_INCLUDE_DEFERRED, + "team_name": "test_team", + }, + ) + assert response.status_code == 422 + assert ( + "team_name cannot be set when multi_team mode is disabled. Please contact your administrator." + in response.json()["detail"][0]["msg"] + ) + class TestPostPool(TestPoolsEndpoint): + @conf_vars({("core", "multi_team"): "True"}) @pytest.mark.parametrize( ("body", "expected_status_code", "expected_response"), [ @@ -483,9 +503,10 @@ class TestPostPool(TestPoolsEndpoint): def test_should_respond_200(self, test_client, session, body, expected_status_code, expected_response): self.create_pools() n_pools = session.scalar(select(func.count()).select_from(Pool)) + response = test_client.post("/pools", json=body) - assert response.status_code == expected_status_code + assert response.status_code == expected_status_code assert response.json() == expected_response assert session.scalar(select(func.count()).select_from(Pool)) == n_pools + 1 check_last_log(session, dag_id=None, event="post_pool", logical_date=None) @@ -523,6 +544,22 @@ def test_post_pool_rejects_infinity_string(self, test_client, session): ) assert response.status_code == 422 + @conf_vars({("core", "multi_team"): "False"}) + def test_post_pool_rejects_team_name_when_multi_team_disabled(self, test_client): + response = test_client.post( + "/pools", + json={ + "name": "bad_team_pool", + "slots": 1, + "team_name": "test_team", + }, + ) + assert response.status_code == 422 + assert ( + "team_name cannot be set when multi_team mode is disabled. Please contact your administrator." + in response.json()["detail"][0]["msg"] + ) + def test_should_respond_401(self, unauthenticated_test_client): response = unauthenticated_test_client.post("/pools", json={}) assert response.status_code == 401 @@ -590,6 +627,7 @@ def test_should_response_409( class TestBulkPools(TestPoolsEndpoint): + @conf_vars({("core", "multi_team"): "True"}) @pytest.mark.enable_redact @pytest.mark.parametrize( ("actions", "expected_results"), @@ -1045,7 +1083,9 @@ class TestBulkPools(TestPoolsEndpoint): ) def test_bulk_pools(self, test_client, actions, expected_results, session): self.create_pools() + response = test_client.patch("/pools", json=actions) + response_data = response.json() for key, value in expected_results.items(): assert response_data[key] == value @@ -1105,3 +1145,54 @@ def test_should_respond_403(self, unauthorized_test_client): }, ) assert response.status_code == 403 + + @conf_vars({("core", "multi_team"): "False"}) + def test_bulk_rejects_team_name_when_multi_team_is_disabled(self, test_client): + actions = { + "actions": [ + { + "action": "create", + "entities": [ + { + "name": "pool_1", + "slots": 1, + "description": "description", + }, + { + "name": "pool_2", + "slots": 2, + "description": "description_2", + "team_name": "test_team", + }, + ], + }, + { + "action": "update", + "entities": [ + { + "name": "pool_3", + "slots": 3, + "description": "updated_description", + "team_name": "test_team", + }, + { + "name": "pool_4", + "slots": 4, + "description": "updated_description_2", + }, + ], + }, + ] + } + response = test_client.patch("/pools", json=actions) + assert response.status_code == 422 + detail = response.json()["detail"] + + assert all( + "team_name cannot be set when multi_team mode is disabled. Please contact your administrator." + in err["msg"] + for err in detail + ), f"Unexpected errors in detail: {detail}" + + expected_error_names = {err["input"]["name"] for err in detail} + assert sorted(expected_error_names) == ["pool_2", "pool_3"] diff --git a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_variables.py b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_variables.py index 508718796755a..045b78acc5636 100644 --- a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_variables.py +++ b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_variables.py @@ -521,6 +521,21 @@ def test_patch_should_respond_404(self, test_client): body = response.json() assert f"The Variable with key: `{TEST_VARIABLE_KEY}` was not found" == body["detail"] + @conf_vars({("core", "multi_team"): "False"}) + def test_patch_rejects_team_name_when_multi_team_disabled(self, test_client): + body = { + "key": TEST_VARIABLE_KEY, + "value": "The new value", + "description": "The new description", + "team_name": "test_team", + } + response = test_client.patch(f"/variables/{TEST_VARIABLE_KEY}", json=body) + assert response.status_code == 422 + assert ( + "team_name cannot be set when multi_team mode is disabled. Please contact your administrator." + in response.json()["detail"][0]["msg"] + ) + @pytest.mark.enable_redact def test_patch_with_update_mask_description_only(self, test_client, session): """PATCH with update_mask=['description'] should only update description, keeping value unchanged.""" @@ -686,6 +701,21 @@ def test_post_should_respond_422_when_key_too_large(self, test_client): ] } + @conf_vars({("core", "multi_team"): "False"}) + def test_post_rejects_team_name_when_multi_team_disabled(self, test_client): + body = { + "key": "new variable key", + "value": "new variable value", + "description": "new variable description", + "team_name": "test_team", + } + response = test_client.post("/variables", json=body) + assert response.status_code == 422 + assert ( + "team_name cannot be set when multi_team mode is disabled. Please contact your administrator." + in response.json()["detail"][0]["msg"] + ) + @pytest.mark.parametrize( "body", [ @@ -1366,3 +1396,55 @@ def test_bulk_variables_should_respond_403(self, unauthorized_test_client): }, ) assert response.status_code == 403 + + @conf_vars({("core", "multi_team"): "False"}) + def test_bulk_rejects_team_name_when_multi_team_is_disabled(self, test_client): + actions = { + "actions": [ + { + "action": "create", + "entities": [ + { + "key": "var_1", + "value": "value_1", + "description": "description", + }, + { + "key": "var_2", + "value": "value_2", + "description": "description_2", + "team_name": "test_team", + }, + ], + }, + { + "action": "update", + "entities": [ + { + "key": "var_3", + "value": "value_3", + "description": "updated_description", + "team_name": "test_team", + }, + { + "key": "var_4", + "value": "value_4", + "description": "updated_description_2", + }, + ], + }, + ] + } + response = test_client.patch("/variables", json=actions) + + assert response.status_code == 422 + detail = response.json()["detail"] + + assert all( + "team_name cannot be set when multi_team mode is disabled. Please contact your administrator." + in err["msg"] + for err in detail + ), f"Unexpected errors in detail: {detail}" + + expected_error_keys = {err["input"]["key"] for err in detail} + assert sorted(expected_error_keys) == ["var_2", "var_3"]