From 15ed950d7a0835bcc5451edb1b60a808c966a277 Mon Sep 17 00:00:00 2001 From: Jvst Me Date: Mon, 27 Apr 2026 09:26:57 +0200 Subject: [PATCH] Clean up exports on project and fleet deletion - Delete exported fleets on fleet deletion - Delete exports and imports on project deletion --- .../background/pipeline_tasks/fleets.py | 8 ++- .../_internal/server/services/projects.py | 5 ++ .../background/pipeline_tasks/test_fleets.py | 19 +++++- .../_internal/server/routers/test_projects.py | 59 ++++++++++++++++++- 4 files changed, 87 insertions(+), 4 deletions(-) diff --git a/src/dstack/_internal/server/background/pipeline_tasks/fleets.py b/src/dstack/_internal/server/background/pipeline_tasks/fleets.py index 1e033865f..8050c552c 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/fleets.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/fleets.py @@ -4,7 +4,7 @@ from datetime import timedelta from typing import Optional, Sequence, TypedDict -from sqlalchemy import or_, select, update +from sqlalchemy import delete, or_, select, update from sqlalchemy.ext.asyncio.session import AsyncSession from sqlalchemy.orm import joinedload, load_only, selectinload @@ -33,6 +33,7 @@ ) from dstack._internal.server.db import get_db, get_session_ctx from dstack._internal.server.models import ( + ExportedFleetModel, FleetModel, InstanceModel, JobModel, @@ -489,6 +490,11 @@ async def _apply_process_result( .where(PlacementGroupModel.fleet_id == context.fleet_model.id) .values(fleet_deleted=True) ) + await session.execute( + delete(ExportedFleetModel).where( + ExportedFleetModel.fleet_id == context.fleet_model.id + ) + ) if instance_update_rows: await session.execute( update(InstanceModel), diff --git a/src/dstack/_internal/server/services/projects.py b/src/dstack/_internal/server/services/projects.py index 15438efbc..499d6c039 100644 --- a/src/dstack/_internal/server/services/projects.py +++ b/src/dstack/_internal/server/services/projects.py @@ -27,7 +27,9 @@ from dstack._internal.core.models.runs import RunStatus from dstack._internal.core.models.users import GlobalRole, ProjectRole from dstack._internal.server.models import ( + ExportModel, FleetModel, + ImportModel, MemberModel, ProjectModel, RunModel, @@ -272,6 +274,7 @@ async def delete_projects( # so there can be dangling active resources due to race conditions. await _check_project_has_active_resources(session=session, project_id=p.id) + project_ids = {p.id for p in projects} timestamp = str(int(get_current_datetime().timestamp())) updates = [] for p in projects: @@ -290,6 +293,8 @@ async def delete_projects( targets=[events.Target.from_model(p)], ) await session.execute(update(ProjectModel), updates) + await session.execute(delete(ExportModel).where(ExportModel.project_id.in_(project_ids))) + await session.execute(delete(ImportModel).where(ImportModel.project_id.in_(project_ids))) await session.commit() diff --git a/src/tests/_internal/server/background/pipeline_tasks/test_fleets.py b/src/tests/_internal/server/background/pipeline_tasks/test_fleets.py index 1a47d6ca5..2268ad940 100644 --- a/src/tests/_internal/server/background/pipeline_tasks/test_fleets.py +++ b/src/tests/_internal/server/background/pipeline_tasks/test_fleets.py @@ -23,9 +23,16 @@ FleetPipeline, FleetWorker, ) -from dstack._internal.server.models import EventModel, EventTargetModel, FleetModel, InstanceModel +from dstack._internal.server.models import ( + EventModel, + EventTargetModel, + ExportedFleetModel, + FleetModel, + InstanceModel, +) from dstack._internal.server.services.projects import add_project_member from dstack._internal.server.testing.common import ( + create_export, create_fleet, create_instance, create_placement_group, @@ -1109,7 +1116,7 @@ async def test_consolidation_attempt_increments_when_over_max_and_no_idle_instan assert instance2.status == InstanceStatus.BUSY assert fleet.consolidation_attempt == 3 - async def test_marks_placement_groups_fleet_deleted_on_fleet_delete( + async def test_deletes_related_resources_on_fleet_delete( self, test_db, session: AsyncSession, worker: FleetWorker ): project = await create_project(session) @@ -1130,6 +1137,12 @@ async def test_marks_placement_groups_fleet_deleted_on_fleet_delete( fleet=fleet, name="test-pg-2", ) + await create_export( + session=session, + exporter_project=project, + importer_projects=[], + exported_fleets=[fleet], + ) fleet.lock_token = uuid.uuid4() fleet.lock_expires_at = datetime(2025, 1, 2, 3, 4, tzinfo=timezone.utc) @@ -1143,6 +1156,8 @@ async def test_marks_placement_groups_fleet_deleted_on_fleet_delete( assert fleet.deleted assert placement_group1.fleet_deleted assert placement_group2.fleet_deleted + res = await session.execute(select(ExportedFleetModel)) + assert len(res.scalars().all()) == 0 async def test_consolidation_respects_retry_delay( self, test_db, session: AsyncSession, worker: FleetWorker diff --git a/src/tests/_internal/server/routers/test_projects.py b/src/tests/_internal/server/routers/test_projects.py index 3b8b4aab0..6d7bcca0e 100644 --- a/src/tests/_internal/server/routers/test_projects.py +++ b/src/tests/_internal/server/routers/test_projects.py @@ -11,7 +11,7 @@ from dstack._internal.core.models.fleets import FleetStatus from dstack._internal.core.models.runs import RunStatus from dstack._internal.core.models.users import GlobalRole, ProjectRole -from dstack._internal.server.models import MemberModel, ProjectModel +from dstack._internal.server.models import ExportModel, ImportModel, MemberModel, ProjectModel from dstack._internal.server.services.permissions import DefaultPermissions from dstack._internal.server.services.projects import add_project_member from dstack._internal.server.testing.common import ( @@ -1367,6 +1367,63 @@ async def test_errors_if_project_has_active_volumes( res = await session.execute(select(ProjectModel).where(ProjectModel.deleted.is_(False))) assert len(res.all()) == 0 + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_deletes_export_models_on_project_delete( + self, test_db, session: AsyncSession, client: AsyncClient + ): + user = await create_user(session=session, global_role=GlobalRole.ADMIN) + project = await create_project(session=session, owner=user) + fleet = await create_fleet(session=session, project=project, deleted=True) + await create_export( + session=session, + exporter_project=project, + importer_projects=[], + exported_fleets=[fleet], + ) + + res = await session.execute(select(ExportModel)) + assert len(res.scalars().all()) == 1 + + response = await client.post( + "/api/projects/delete", + headers=get_auth_headers(user.token), + json={"projects_names": [project.name]}, + ) + assert response.status_code == 200 + + res = await session.execute(select(ExportModel)) + assert len(res.scalars().all()) == 0 + + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_deletes_import_models_on_project_delete( + self, test_db, session: AsyncSession, client: AsyncClient + ): + user = await create_user(session=session, global_role=GlobalRole.ADMIN) + exporter_project = await create_project(session=session, owner=user, name="exporter") + importer_project = await create_project(session=session, owner=user, name="importer") + fleet = await create_fleet(session=session, project=exporter_project, deleted=True) + await create_export( + session=session, + exporter_project=exporter_project, + importer_projects=[importer_project], + exported_fleets=[fleet], + ) + + res = await session.execute(select(ImportModel)) + assert len(res.scalars().all()) == 1 + + response = await client.post( + "/api/projects/delete", + headers=get_auth_headers(user.token), + json={"projects_names": [importer_project.name]}, + ) + assert response.status_code == 200 + + res = await session.execute(select(ImportModel)) + assert len(res.scalars().all()) == 0 + class TestGetProject: @pytest.mark.asyncio