Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions src/dstack/_internal/core/backends/base/configurator.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,13 @@ class BackendRecord(CoreModel):
This model includes backend parameters to store in the DB.
"""

# `config` stores text-encoded non-sensitive backend config parameters (e.g. json)
config: str
# `auth` stores text-encoded sensitive backend config parameters (e.g. json).
# Configurator should not encrypt/decrypt it. This is done by the caller.
"""`config` stores text-encoded non-sensitive backend config parameters (e.g. json)
"""
auth: str
"""`auth` stores text-encoded sensitive backend config parameters (e.g. json).
`Configurator` should not encrypt/decrypt it. This is done by the caller.
"""


class StoredBackendRecord(BackendRecord):
Expand All @@ -53,8 +55,8 @@ class Configurator(ABC, Generic[BackendConfigWithoutCredsT, BackendConfigWithCre
"""

TYPE: ClassVar[BackendType]
# `BACKEND_CLASS` is used to introspect backend features without initializing it.
BACKEND_CLASS: ClassVar[type[Backend]]
"""`BACKEND_CLASS` is used to introspect backend features without initializing it."""

@abstractmethod
def validate_config(self, config: BackendConfigWithCredsT, default_creds_enabled: bool):
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
"""Add BackendModel.source_config and BackendModel.source_auth

Revision ID: 1b9e2e7e7d35
Revises: ad8c50120507
Create Date: 2026-04-10 12:00:00.000000+00:00

"""

import sqlalchemy as sa
from alembic import op

import dstack._internal.server.models

# revision identifiers, used by Alembic.
revision = "1b9e2e7e7d35"
down_revision = "ad8c50120507"
branch_labels = None
depends_on = None


def upgrade() -> None:
with op.batch_alter_table("backends", schema=None) as batch_op:
batch_op.add_column(sa.Column("source_config", sa.String(length=20000), nullable=True))
batch_op.add_column(
sa.Column(
"source_auth",
dstack._internal.server.models.EncryptedString(20000),
nullable=True,
)
)


def downgrade() -> None:
with op.batch_alter_table("backends", schema=None) as batch_op:
batch_op.drop_column("source_auth")
batch_op.drop_column("source_config")
10 changes: 10 additions & 0 deletions src/dstack/_internal/server/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,16 @@ class BackendModel(BaseModel):

config: Mapped[str] = mapped_column(String(20000))
auth: Mapped[DecryptedString] = mapped_column(EncryptedString(20000))
source_config: Mapped[Optional[str]] = mapped_column(String(20000), nullable=True)
"""`source_config` stores the original non-sensitive backend config from user input
before configurators materialize defaults or generated values.
"""
source_auth: Mapped[Optional[DecryptedString]] = mapped_column(
EncryptedString(20000), nullable=True
)
"""`source_auth` stores the original sensitive backend config from user input
before configurators materialize defaults or generated values.
"""

gateways: Mapped[List["GatewayModel"]] = relationship(back_populates="backend")

Expand Down
77 changes: 77 additions & 0 deletions src/dstack/_internal/server/services/backends/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
import asyncio
import heapq
import json
import time
from collections.abc import Iterable, Iterator
from typing import Callable, Coroutine, Dict, List, Optional, Tuple
from uuid import UUID

from cachetools import TTLCache
from pydantic import Field, ValidationError
from sqlalchemy import delete, update
from sqlalchemy.ext.asyncio import AsyncSession
from typing_extensions import Annotated

from dstack._internal.core.backends.base.backend import Backend
from dstack._internal.core.backends.base.configurator import (
Expand All @@ -33,6 +36,7 @@
ServerClientError,
)
from dstack._internal.core.models.backends.base import BackendType
from dstack._internal.core.models.common import CoreModel
from dstack._internal.core.models.instances import (
InstanceOfferWithAvailability,
)
Expand All @@ -46,6 +50,20 @@
logger = get_logger(__name__)


class _BackendConfigWithCreds(CoreModel):
__root__: Annotated[AnyBackendConfigWithCreds, Field(..., discriminator="type")]


def serialize_source_backend_config(
config: AnyBackendConfigWithCreds,
) -> Tuple[str, Optional[str]]:
"""Split user-intent backend config into non-sensitive and sensitive JSON blobs."""
source_config_dict = config.dict()
source_auth = source_config_dict.pop("creds", None)
source_auth_json = None if source_auth is None else json.dumps(source_auth)
return json.dumps(source_config_dict), source_auth_json


async def create_backend(
session: AsyncSession,
project: ProjectModel,
Expand Down Expand Up @@ -89,6 +107,8 @@ async def update_backend(
.values(
config=backend.config,
auth=backend.auth,
source_config=backend.source_config,
source_auth=backend.source_auth,
)
)
return config
Expand All @@ -99,6 +119,9 @@ async def validate_and_create_backend_model(
configurator: Configurator,
config: AnyBackendConfigWithCreds,
) -> BackendModel:
# Configurators may mutate `config` while building the effective stored backend config,
# so capture the user-intent payload before validation/create_backend runs.
source_config, source_auth = serialize_source_backend_config(config)
await run_async(
configurator.validate_config, config, default_creds_enabled=settings.DEFAULT_CREDS_ENABLED
)
Expand All @@ -112,6 +135,8 @@ async def validate_and_create_backend_model(
type=configurator.TYPE,
config=backend_record.config,
auth=DecryptedString(plaintext=backend_record.auth),
source_config=source_config,
source_auth=None if source_auth is None else DecryptedString(plaintext=source_auth),
)


Expand All @@ -134,6 +159,16 @@ async def get_backend_config(
return None


async def get_source_backend_config(
project: ProjectModel,
backend_type: BackendType,
) -> Optional[AnyBackendConfigWithCreds]:
backend_model = await get_project_backend_model_by_type(project, backend_type)
if backend_model is None:
return None
return get_source_backend_config_from_backend_model(backend_model)


def get_backend_config_with_creds_from_backend_model(
configurator: Configurator,
backend_model: BackendModel,
Expand All @@ -152,6 +187,48 @@ def get_backend_config_without_creds_from_backend_model(
return backend_config


def get_source_backend_config_from_backend_model(
backend_model: BackendModel,
) -> Optional[AnyBackendConfigWithCreds]:
"""Reconstruct user-intent backend config from `source_config`/`source_auth`."""

if backend_model.source_config is None:
return None
try:
source_config_dict = json.loads(backend_model.source_config)
except ValueError:
logger.warning(
"Failed to parse source config for %s backend. Falling back to stored config.",
backend_model.type.value,
)
return None
if backend_model.source_auth is not None:
if not backend_model.source_auth.decrypted:
logger.warning(
"Failed to decrypt source creds for %s backend. Falling back to stored config.",
backend_model.type.value,
)
return None
try:
source_config_dict["creds"] = json.loads(
backend_model.source_auth.get_plaintext_or_error()
)
except ValueError:
logger.warning(
"Failed to parse source creds for %s backend. Falling back to stored config.",
backend_model.type.value,
)
return None
try:
return _BackendConfigWithCreds.parse_obj(source_config_dict).__root__
except ValidationError:
logger.warning(
"Failed to validate source config for %s backend. Falling back to stored config.",
backend_model.type.value,
)
return None


def get_stored_backend_record(backend_model: BackendModel) -> StoredBackendRecord:
return StoredBackendRecord(
config=backend_model.config,
Expand Down
13 changes: 10 additions & 3 deletions src/dstack/_internal/server/services/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ async def _apply_project_config(
backend_config = file_config_to_config(backend_file_config)
backend_type = BackendType(backend_config.type)
backends_to_delete.difference_update([backend_type])
backend_exists = any(backend_type == b.type for b in project.backends)
try:
current_backend_config = await backends_services.get_backend_config(
project=project,
Expand All @@ -154,9 +155,15 @@ async def _apply_project_config(
backend_type.value,
)
continue
if backend_config == current_backend_config:
continue
backend_exists = any(backend_type == b.type for b in project.backends)
if current_backend_config is not None:
current_source_backend_config = await backends_services.get_source_backend_config(
project=project,
backend_type=backend_type,
)
# current_source_backend_config may be missing for old backend records
comparable_backend_config = current_source_backend_config or current_backend_config
if backend_config == comparable_backend_config:
continue
try:
# current_backend_config may be None if backend exists
# but it's config is invalid (e.g. cannot be decrypted).
Expand Down
6 changes: 6 additions & 0 deletions src/dstack/_internal/server/testing/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,8 @@ async def create_backend(
backend_type: BackendType = BackendType.AWS,
config: Optional[Dict] = None,
auth: Optional[Dict] = None,
source_config: Optional[Dict] = None,
source_auth: Optional[Dict] = None,
) -> BackendModel:
if config is None:
config = {
Expand All @@ -239,6 +241,10 @@ async def create_backend(
type=backend_type,
config=json.dumps(config),
auth=DecryptedString(plaintext=json.dumps(auth)),
source_config=None if source_config is None else json.dumps(source_config),
source_auth=(
None if source_auth is None else DecryptedString(plaintext=json.dumps(source_auth))
),
)
session.add(backend)
await session.commit()
Expand Down
12 changes: 10 additions & 2 deletions src/tests/_internal/server/routers/test_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,11 @@ async def test_creates_aws_backend(self, test_db, session: AsyncSession, client:
)
assert response.status_code == 200, response.json()
res = await session.execute(select(BackendModel))
assert len(res.scalars().all()) == 1
backend = res.scalars().one()
assert backend.source_config is not None
assert backend.source_auth is not None
assert json.loads(backend.source_config)["regions"] == ["us-west-1"]
assert json.loads(backend.source_auth.get_plaintext_or_error()) == body["creds"]

@pytest.mark.asyncio
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
Expand Down Expand Up @@ -615,6 +619,10 @@ async def test_updates_backend(self, test_db, session: AsyncSession, client: Asy
assert response.status_code == 200, response.json()
await session.refresh(backend)
assert json.loads(backend.config)["regions"] == ["us-east-1"]
assert backend.source_config is not None
assert backend.source_auth is not None
assert json.loads(backend.source_config)["regions"] == ["us-east-1"]
assert json.loads(backend.source_auth.get_plaintext_or_error()) == body["creds"]

@pytest.mark.asyncio
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
Expand Down Expand Up @@ -815,7 +823,7 @@ async def test_returns_config_info(self, test_db, session: AsyncSession, client:
"iam_instance_profile": None,
"tags": None,
"os_images": None,
"creds": json.loads(backend.auth.plaintext),
"creds": json.loads(backend.auth.get_plaintext_or_error()),
}


Expand Down
Loading
Loading