Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,12 @@
from collections.abc import Iterable, Mapping
from typing import Annotated, Any

from pydantic import Field, field_validator
from pydantic import Field, field_validator, model_validator
from pydantic_core.core_schema import ValidationInfo

from airflow._shared.secrets_masker import redact, should_hide_value_for_key
from airflow.api_fastapi.core_api.base import BaseModel, StrictBaseModel, make_partial_model
from airflow.configuration import conf


# Response Models
Expand Down Expand Up @@ -199,5 +200,13 @@ def validate_extra(cls, v: str | None) -> str | None:
)
return v

@model_validator(mode="after")
def validate_team_name(self) -> ConnectionBody:
if self.team_name is not None and not conf.getboolean("core", "multi_team"):
raise ValueError(
"team_name cannot be set when multi_team mode is disabled. Please contact your administrator."
)
return self


ConnectionBodyPartial = make_partial_model(ConnectionBody)
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,10 @@
from collections.abc import Callable, Iterable
from typing import Annotated

from pydantic import BeforeValidator, Field
from pydantic import BeforeValidator, Field, model_validator

from airflow.api_fastapi.core_api.base import BaseModel, StrictBaseModel
from airflow.configuration import conf


def _call_function(function: Callable[[], int]) -> int:
Expand Down Expand Up @@ -83,6 +84,14 @@ class PoolPatchBody(StrictBaseModel):
include_deferred: bool | None = None
team_name: str | None = Field(max_length=50, default=None)

@model_validator(mode="after")
def validate_team_name(self) -> PoolPatchBody:
if self.team_name is not None and not conf.getboolean("core", "multi_team"):
raise ValueError(
"team_name cannot be set when multi_team mode is disabled. Please contact your administrator."
)
return self


class PoolBody(BasePool, StrictBaseModel):
"""Pool serializer for post bodies."""
Expand All @@ -91,3 +100,11 @@ class PoolBody(BasePool, StrictBaseModel):
description: str | None = None
include_deferred: bool = False
team_name: str | None = Field(max_length=50, default=None)

@model_validator(mode="after")
def validate_team_name(self) -> PoolBody:
if self.team_name is not None and not conf.getboolean("core", "multi_team"):
raise ValueError(
"team_name cannot be set when multi_team mode is disabled. Please contact your administrator."
)
return self
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

from airflow._shared.secrets_masker import redact
from airflow.api_fastapi.core_api.base import BaseModel, StrictBaseModel, make_partial_model
from airflow.configuration import conf
from airflow.models.base import ID_LEN
from airflow.typing_compat import Self

Expand Down Expand Up @@ -60,6 +61,14 @@ class VariableBody(StrictBaseModel):
description: str | None = Field(default=None)
team_name: str | None = Field(max_length=50, default=None)

@model_validator(mode="after")
def validate_team_name(self) -> VariableBody:
if self.team_name is not None and not conf.getboolean("core", "multi_team"):
raise ValueError(
"team_name cannot be set when multi_team mode is disabled. Please contact your administrator."
)
return self


VariableBodyPartial = make_partial_model(VariableBody)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,11 @@ export const useEditConnection = (
if (requestBody.schema !== initialConnection.schema) {
updateMask.push("schema");
}
if (requestBody.team_name !== initialConnection.team_name) {
updateMask.push("team_name");
}

const teamName = requestBody.team_name === "" ? undefined : requestBody.team_name;

mutate({
connectionId: initialConnection.connection_id,
Expand All @@ -99,6 +104,7 @@ export const useEditConnection = (
extra: requestBody.extra === "{}" ? undefined : requestBody.extra,
// eslint-disable-next-line unicorn/no-null
port: requestBody.port === "" ? null : Number(requestBody.port),
team_name: teamName,
},
updateMask,
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,12 +78,13 @@ export const useEditVariable = (

const parsedDescription =
editVariableRequestBody.description === "" ? undefined : editVariableRequestBody.description;
const teamName = editVariableRequestBody.team_name === "" ? undefined : editVariableRequestBody.team_name;

mutate({
requestBody: {
description: parsedDescription,
key: editVariableRequestBody.key,
team_name: editVariableRequestBody.team_name,
team_name: teamName,
value: editVariableRequestBody.value,
},
updateMask,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@

from tests_common.test_utils.api_fastapi import _check_last_log
from tests_common.test_utils.asserts import assert_queries_count
from tests_common.test_utils.config import conf_vars
from tests_common.test_utils.db import clear_db_connections, clear_db_logs, clear_test_connections
from tests_common.test_utils.markers import skip_if_force_lowest_dependencies_marker

Expand Down Expand Up @@ -286,10 +287,15 @@ def test_post_should_respond_201(self, test_client, session, body):
assert len(connection) == 1
_check_last_log(session, dag_id=None, event="post_connection", logical_date=None)

@conf_vars({("core", "multi_team"): "True"})
def test_post_should_respond_201_with_team(self, test_client, session, testing_team):
response = test_client.post(
"/connections",
json={"connection_id": TEST_CONN_ID, "conn_type": TEST_CONN_TYPE, "team_name": testing_team.name},
json={
"connection_id": TEST_CONN_ID,
"conn_type": TEST_CONN_TYPE,
"team_name": testing_team.name,
},
)
assert response.status_code == 201
assert response.json() == {
Expand Down Expand Up @@ -338,6 +344,22 @@ def test_post_should_respond_422_for_invalid_conn_id(self, test_client, body):
]
}

@conf_vars({("core", "multi_team"): "False"})
def test_post_rejects_team_name_when_multi_team_disabled(self, test_client):
response = test_client.post(
"/connections",
json={
"connection_id": TEST_CONN_ID_2,
"conn_type": TEST_CONN_TYPE_2,
"team_name": "test_team",
},
)
assert response.status_code == 422
assert (
"team_name cannot be set when multi_team mode is disabled. Please contact your administrator."
in response.json()["detail"][0]["msg"]
)

@pytest.mark.parametrize(
"body",
[
Expand Down Expand Up @@ -602,6 +624,7 @@ def test_patch_should_respond_200(

assert response.json() == expected_result

@conf_vars({("core", "multi_team"): "True"})
def test_patch_with_team_should_respond_200(self, test_client, testing_team, session):
self.create_connection()

Expand Down Expand Up @@ -966,6 +989,23 @@ def test_patch_with_update_mask_rejects_extra_fields(self, test_client):
)
assert response.status_code == 422

@conf_vars({("core", "multi_team"): "False"})
def test_patch_rejects_team_name_when_multi_team_disabled(self, test_client):
self.create_connection()
response = test_client.patch(
f"/connections/{TEST_CONN_ID_2}",
json={
"connection_id": TEST_CONN_ID_2,
"conn_type": "new_type",
"team_name": "test_team",
},
)
assert response.status_code == 422
assert (
"team_name cannot be set when multi_team mode is disabled. Please contact your administrator."
in response.json()["detail"][0]["msg"]
)


class TestConnection(TestConnectionEndpoint):
def setup_method(self):
Expand Down Expand Up @@ -1616,6 +1656,57 @@ def test_bulk_delete_avoids_n_plus_one_queries(self, session):

assert sorted(results.success) == [TEST_CONN_ID, TEST_CONN_ID_2]

@conf_vars({("core", "multi_team"): "False"})
def test_bulk_rejects_team_name_when_multi_team_is_disabled(self, test_client):
actions = {
"actions": [
{
"action": "create",
"entities": [
{
"connection_id": "test_conn_id_1",
"conn_type": TEST_CONN_TYPE,
"description": "description",
},
{
"connection_id": "test_conn_id_2",
"conn_type": TEST_CONN_TYPE_2,
"description": "description_2",
"team_name": "test_team",
},
],
},
{
"action": "update",
"entities": [
{
"connection_id": "test_conn_id_3",
"conn_type": TEST_CONN_TYPE,
"description": "updated_description",
"team_name": "test_team",
},
{
"connection_id": "test_conn_id_4",
"conn_type": TEST_CONN_TYPE_2,
"description": "updated_description_2",
},
],
},
]
}
response = test_client.patch("/connections", json=actions)
assert response.status_code == 422
detail = response.json()["detail"]

assert all(
"team_name cannot be set when multi_team mode is disabled. Please contact your administrator."
in err["msg"]
for err in detail
), f"Unexpected errors in detail: {detail}"

expected_error_conn_ids = {err["input"]["connection_id"] for err in detail}
assert sorted(expected_error_conn_ids) == ["test_conn_id_2", "test_conn_id_3"]


class TestPostConnectionExtraBackwardCompatibility(TestConnectionEndpoint):
def test_post_should_accept_empty_string_as_extra(self, test_client, session):
Expand Down
Loading
Loading