Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from datetime import timedelta
from typing import Optional, Sequence, TypedDict

from sqlalchemy import or_, select, update
from sqlalchemy import delete, or_, select, update
from sqlalchemy.ext.asyncio.session import AsyncSession
from sqlalchemy.orm import joinedload, load_only, selectinload

Expand Down Expand Up @@ -33,6 +33,7 @@
)
from dstack._internal.server.db import get_db, get_session_ctx
from dstack._internal.server.models import (
ExportedFleetModel,
FleetModel,
InstanceModel,
JobModel,
Expand Down Expand Up @@ -489,6 +490,11 @@ async def _apply_process_result(
.where(PlacementGroupModel.fleet_id == context.fleet_model.id)
.values(fleet_deleted=True)
)
await session.execute(
delete(ExportedFleetModel).where(
ExportedFleetModel.fleet_id == context.fleet_model.id
)
)
if instance_update_rows:
await session.execute(
update(InstanceModel),
Expand Down
5 changes: 5 additions & 0 deletions src/dstack/_internal/server/services/projects.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@
from dstack._internal.core.models.runs import RunStatus
from dstack._internal.core.models.users import GlobalRole, ProjectRole
from dstack._internal.server.models import (
ExportModel,
FleetModel,
ImportModel,
MemberModel,
ProjectModel,
RunModel,
Expand Down Expand Up @@ -272,6 +274,7 @@ async def delete_projects(
# so there can be dangling active resources due to race conditions.
await _check_project_has_active_resources(session=session, project_id=p.id)

project_ids = {p.id for p in projects}
timestamp = str(int(get_current_datetime().timestamp()))
updates = []
for p in projects:
Expand All @@ -290,6 +293,8 @@ async def delete_projects(
targets=[events.Target.from_model(p)],
)
await session.execute(update(ProjectModel), updates)
await session.execute(delete(ExportModel).where(ExportModel.project_id.in_(project_ids)))
await session.execute(delete(ImportModel).where(ImportModel.project_id.in_(project_ids)))
await session.commit()


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,16 @@
FleetPipeline,
FleetWorker,
)
from dstack._internal.server.models import EventModel, EventTargetModel, FleetModel, InstanceModel
from dstack._internal.server.models import (
EventModel,
EventTargetModel,
ExportedFleetModel,
FleetModel,
InstanceModel,
)
from dstack._internal.server.services.projects import add_project_member
from dstack._internal.server.testing.common import (
create_export,
create_fleet,
create_instance,
create_placement_group,
Expand Down Expand Up @@ -1109,7 +1116,7 @@ async def test_consolidation_attempt_increments_when_over_max_and_no_idle_instan
assert instance2.status == InstanceStatus.BUSY
assert fleet.consolidation_attempt == 3

async def test_marks_placement_groups_fleet_deleted_on_fleet_delete(
async def test_deletes_related_resources_on_fleet_delete(
self, test_db, session: AsyncSession, worker: FleetWorker
):
project = await create_project(session)
Expand All @@ -1130,6 +1137,12 @@ async def test_marks_placement_groups_fleet_deleted_on_fleet_delete(
fleet=fleet,
name="test-pg-2",
)
await create_export(
session=session,
exporter_project=project,
importer_projects=[],
exported_fleets=[fleet],
)

fleet.lock_token = uuid.uuid4()
fleet.lock_expires_at = datetime(2025, 1, 2, 3, 4, tzinfo=timezone.utc)
Expand All @@ -1143,6 +1156,8 @@ async def test_marks_placement_groups_fleet_deleted_on_fleet_delete(
assert fleet.deleted
assert placement_group1.fleet_deleted
assert placement_group2.fleet_deleted
res = await session.execute(select(ExportedFleetModel))
assert len(res.scalars().all()) == 0

async def test_consolidation_respects_retry_delay(
self, test_db, session: AsyncSession, worker: FleetWorker
Expand Down
59 changes: 58 additions & 1 deletion src/tests/_internal/server/routers/test_projects.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from dstack._internal.core.models.fleets import FleetStatus
from dstack._internal.core.models.runs import RunStatus
from dstack._internal.core.models.users import GlobalRole, ProjectRole
from dstack._internal.server.models import MemberModel, ProjectModel
from dstack._internal.server.models import ExportModel, ImportModel, MemberModel, ProjectModel
from dstack._internal.server.services.permissions import DefaultPermissions
from dstack._internal.server.services.projects import add_project_member
from dstack._internal.server.testing.common import (
Expand Down Expand Up @@ -1367,6 +1367,63 @@ async def test_errors_if_project_has_active_volumes(
res = await session.execute(select(ProjectModel).where(ProjectModel.deleted.is_(False)))
assert len(res.all()) == 0

@pytest.mark.asyncio
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
async def test_deletes_export_models_on_project_delete(
self, test_db, session: AsyncSession, client: AsyncClient
):
user = await create_user(session=session, global_role=GlobalRole.ADMIN)
project = await create_project(session=session, owner=user)
fleet = await create_fleet(session=session, project=project, deleted=True)
await create_export(
session=session,
exporter_project=project,
importer_projects=[],
exported_fleets=[fleet],
)

res = await session.execute(select(ExportModel))
assert len(res.scalars().all()) == 1

response = await client.post(
"/api/projects/delete",
headers=get_auth_headers(user.token),
json={"projects_names": [project.name]},
)
assert response.status_code == 200

res = await session.execute(select(ExportModel))
assert len(res.scalars().all()) == 0

@pytest.mark.asyncio
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
async def test_deletes_import_models_on_project_delete(
self, test_db, session: AsyncSession, client: AsyncClient
):
user = await create_user(session=session, global_role=GlobalRole.ADMIN)
exporter_project = await create_project(session=session, owner=user, name="exporter")
importer_project = await create_project(session=session, owner=user, name="importer")
fleet = await create_fleet(session=session, project=exporter_project, deleted=True)
await create_export(
session=session,
exporter_project=exporter_project,
importer_projects=[importer_project],
exported_fleets=[fleet],
)

res = await session.execute(select(ImportModel))
assert len(res.scalars().all()) == 1

response = await client.post(
"/api/projects/delete",
headers=get_auth_headers(user.token),
json={"projects_names": [importer_project.name]},
)
assert response.status_code == 200

res = await session.execute(select(ImportModel))
assert len(res.scalars().all()) == 0


class TestGetProject:
@pytest.mark.asyncio
Expand Down
Loading