91 changes: 44 additions & 47 deletions diracx-db/src/diracx/db/sql/sandbox_metadata/db.py
@@ -1,16 +1,10 @@
from __future__ import annotations

import logging
from contextlib import asynccontextmanager
from functools import partial
from typing import Any, AsyncGenerator
from typing import Any

from sqlalchemy import (
BigInteger,
Column,
Executable,
MetaData,
Table,
and_,
delete,
exists,
@@ -40,14 +34,6 @@
class SandboxMetadataDB(BaseSQLDB):
metadata = SandboxMetadataDBBase.metadata

# Temporary table to store the sandboxes to delete, see `select_and_delete_expired`
_temp_table = Table(
"sb_to_delete",
MetaData(),
Column("SBId", BigInteger, primary_key=True),
prefixes=["TEMPORARY"],
)

async def get_owner_id(self, user: UserInfo) -> int | None:
"""Get the id of the owner from the database."""
stmt = select(SBOwners.OwnerID).where(
@@ -221,18 +207,20 @@ async def unassign_sandboxes_to_jobs(self, jobs_ids: list[int]) -> None:
)
await self.conn.execute(unassign_stmt)

@asynccontextmanager
async def delete_unused_sandboxes(
self, *, limit: int | None = None
) -> AsyncGenerator[AsyncGenerator[str, None], None]:
"""Get the sandbox PFNs to delete.
async def select_sandboxes_for_deletion(
self, *, batch_size: int = 500
) -> tuple[list[int], list[str]]:
"""Select and lock a batch of sandboxes for deletion.

The result of this function can be used as an async context manager
to yield the PFNs of the sandboxes to delete. The context manager
will automatically remove the sandboxes from the database upon exit.
Uses FOR UPDATE SKIP LOCKED on MySQL to allow concurrent workers to
process different sandboxes in parallel without conflicts.

Args:
limit: If not None, the maximum number of sandboxes to delete.
batch_size: Maximum number of sandboxes to select.

Returns:
Tuple of (sb_ids, pfns) for the selected sandboxes.
On MySQL, the rows remain locked until the transaction commits or rolls back.

"""
conditions = [
@@ -247,32 +235,41 @@ async def delete_unused_sandboxes(
# Sandboxes which are not on S3 will be handled by legacy DIRAC
condition = and_(SandBoxes.SEPFN.like("/S3/%"), or_(*conditions))

# Copy the in-flight rows to a temporary table
await self.conn.run_sync(partial(self._temp_table.create, checkfirst=True))
select_stmt = select(SandBoxes.SBId).where(condition)
if limit:
select_stmt = select_stmt.limit(limit)
insert_stmt = insert(self._temp_table).from_select(["SBId"], select_stmt)
await self.conn.execute(insert_stmt)
select_stmt = (
select(SandBoxes.SBId, SandBoxes.SEPFN).where(condition).limit(batch_size)
)

try:
# Select the sandbox PFNs from the temporary table and yield them
select_stmt = select(SandBoxes.SEPFN).join(
self._temp_table, self._temp_table.c.SBId == SandBoxes.SBId
# FOR UPDATE SKIP LOCKED is only supported on MySQL
# SQLite is used for testing and doesn't support row locking
if self.conn.dialect.name == "mysql":
select_stmt = select_stmt.with_for_update(skip_locked=True)
elif self.conn.dialect.name != "sqlite":
raise NotImplementedError(
f"Unsupported database dialect: {self.conn.dialect.name}"
)

async def yield_pfns() -> AsyncGenerator[str, None]:
async for row in await self.conn.stream(select_stmt):
yield row.SEPFN
result = await self.conn.execute(select_stmt)
rows = result.all()

yield yield_pfns()
sb_ids = [row.SBId for row in rows]
pfns = [row.SEPFN for row in rows]

# Delete the sandboxes from the main table
delete_stmt = delete(SandBoxes).where(
SandBoxes.SBId.in_(select(self._temp_table.c.SBId))
)
result = await self.conn.execute(delete_stmt)
logger.info("Deleted %d expired/unassigned sandboxes", result.rowcount)
return sb_ids, pfns

async def delete_sandboxes(self, sb_ids: list[int]) -> int:
"""Delete sandboxes by their IDs.

Args:
sb_ids: List of sandbox IDs to delete.

Returns:
Number of rows deleted.

"""
if not sb_ids:
return 0

finally:
await self.conn.run_sync(partial(self._temp_table.drop, checkfirst=True))
delete_stmt = delete(SandBoxes).where(SandBoxes.SBId.in_(sb_ids))
result = await self.conn.execute(delete_stmt)
logger.info("Deleted %d expired/unassigned sandboxes", result.rowcount)
return result.rowcount
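
A minimal, self-contained sketch of the claim-then-delete pattern that `select_sandboxes_for_deletion` and `delete_sandboxes` implement; the table definition and helper names below are illustrative stand-ins, not the real DiracX schema or API:

```python
# Hypothetical sketch of the SKIP LOCKED claim pattern; the table and
# column definitions are illustrative, not the DiracX schema.
from sqlalchemy import BigInteger, Column, MetaData, String, Table, delete, select
from sqlalchemy.ext.asyncio import AsyncConnection

metadata = MetaData()
sandboxes = Table(
    "sb_SandBoxes",
    metadata,
    Column("SBId", BigInteger, primary_key=True),
    Column("SEPFN", String(256)),
)


async def claim_batch(
    conn: AsyncConnection, batch_size: int = 500
) -> tuple[list[int], list[str]]:
    """Select and row-lock up to batch_size rows slated for deletion."""
    stmt = select(sandboxes.c.SBId, sandboxes.c.SEPFN).limit(batch_size)
    if conn.dialect.name == "mysql":
        # SKIP LOCKED makes a second transaction skip rows this one holds,
        # so concurrent workers always claim disjoint batches.
        stmt = stmt.with_for_update(skip_locked=True)
    rows = (await conn.execute(stmt)).all()
    return [r.SBId for r in rows], [r.SEPFN for r in rows]


async def drop_batch(conn: AsyncConnection, sb_ids: list[int]) -> int:
    """Delete the claimed rows once external cleanup has succeeded."""
    if not sb_ids:
        return 0
    result = await conn.execute(delete(sandboxes).where(sandboxes.c.SBId.in_(sb_ids)))
    return result.rowcount
```

SQLite accepts the plain SELECT for single-worker tests, while MySQL adds FOR UPDATE SKIP LOCKED so concurrent transactions never contend for the same rows.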
93 changes: 63 additions & 30 deletions diracx-logic/src/diracx/logic/jobs/sandboxes.py
@@ -18,7 +18,6 @@
s3_object_exists,
)
from diracx.core.settings import SandboxStoreSettings
from diracx.core.utils import batched_async
from diracx.db.sql.sandbox_metadata.db import SandboxMetadataDB

if TYPE_CHECKING:
@@ -202,34 +201,68 @@ async def clean_sandboxes(
sandbox_metadata_db: SandboxMetadataDB,
settings: SandboxStoreSettings,
*,
limit: int = 10_000,
max_concurrent_batches: int = 10,
batch_size: int = 500,
max_workers: int = 10,
) -> int:
"""Delete sandboxes that are not assigned to any job."""
semaphore = asyncio.Semaphore(max_concurrent_batches)
n_deleted = 0
async with (
sandbox_metadata_db.delete_unused_sandboxes(limit=limit) as generator,
asyncio.TaskGroup() as tg,
):
async for batch in batched_async(generator, 500):
objects: list[S3Object] = [{"Key": pfn_to_key(pfn)} for pfn in batch]
if logger.isEnabledFor(logging.DEBUG):
for pfn in batch:
logger.debug("Deleting sandbox %s from S3", pfn)
tg.create_task(delete_batch_and_log(settings, objects, semaphore))
n_deleted += len(objects)
return n_deleted


async def delete_batch_and_log(
settings: SandboxStoreSettings,
objects: list[S3Object],
semaphore: asyncio.Semaphore,
) -> None:
"""Helper function to delete a batch of objects and log the result."""
async with semaphore:
await s3_bulk_delete_with_retry(
settings.s3_client, settings.bucket_name, objects
"""Delete sandboxes that are not assigned to any job.

Uses SELECT FOR UPDATE SKIP LOCKED to allow multiple workers to run
in parallel without conflicts. Each batch:
1. Selects and locks rows
2. Deletes from S3
3. Deletes from DB

Args:
sandbox_metadata_db: Database connection (not yet entered).
settings: Sandbox store settings with S3 client.
batch_size: Number of sandboxes to process per batch.
max_workers: Maximum number of concurrent workers processing batches.

Returns:
Total number of sandboxes deleted.

"""
# Check if parallel workers are supported
async with sandbox_metadata_db:
dialect = sandbox_metadata_db.conn.dialect.name
if max_workers > 1 and dialect == "sqlite":
raise NotImplementedError(
"SQLite does not support parallel workers (no SKIP LOCKED support)"
)
logger.info("Deleted %d sandboxes from %s", len(objects), settings.bucket_name)

async def worker() -> int:
"""Process batches until no more work is available."""
worker_deleted = 0
while True:
async with sandbox_metadata_db:
# Select and lock a batch of sandboxes
sb_ids, pfns = await sandbox_metadata_db.select_sandboxes_for_deletion(
batch_size=batch_size
)

if not pfns:
break

# Delete from S3 first (while rows are locked)
objects: list[S3Object] = [{"Key": pfn_to_key(pfn)} for pfn in pfns]
if logger.isEnabledFor(logging.DEBUG):
for pfn in pfns:
logger.debug("Deleting sandbox %s from S3", pfn)

await s3_bulk_delete_with_retry(
settings.s3_client, settings.bucket_name, objects
)
logger.info(
"Deleted %d sandboxes from %s", len(objects), settings.bucket_name
)

# Then delete from DB
await sandbox_metadata_db.delete_sandboxes(sb_ids)
worker_deleted += len(sb_ids)

return worker_deleted

async with asyncio.TaskGroup() as tg:
tasks = [tg.create_task(worker()) for _ in range(max_workers)]

return sum(task.result() for task in tasks)
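
Stripped of the S3 and DB specifics, the worker fan-out in `clean_sandboxes` reduces to the following runnable toy: each worker loops claim → process until a claim comes back empty, and `asyncio.TaskGroup` joins the workers and sums their counts. `claim` and `process` are placeholder callables standing in for the locked batch select and the S3-plus-DB deletion:

```python
import asyncio


async def fan_out(claim, process, *, max_workers: int = 10) -> int:
    """Run max_workers loops that each claim batches until none remain."""

    async def worker() -> int:
        done = 0
        while batch := await claim():
            await process(batch)
            done += len(batch)
        return done

    async with asyncio.TaskGroup() as tg:  # Python 3.11+, as in the PR itself
        tasks = [tg.create_task(worker()) for _ in range(max_workers)]
    return sum(t.result() for t in tasks)


async def main() -> None:
    # Toy backlog of ten 5-item batches; claim pops one, process is a no-op.
    backlog = [list(range(i, i + 5)) for i in range(0, 50, 5)]

    async def claim() -> list[int]:
        return backlog.pop() if backlog else []

    async def process(batch: list[int]) -> None:
        await asyncio.sleep(0)  # stand-in for the S3 and DB deletes

    print(await fan_out(claim, process, max_workers=3))  # prints 50


asyncio.run(main())
```

Deleting from S3 before the DB, while the rows stay locked, means a worker that crashes mid-batch leaves its rows in place to be claimed and retried, whereas deleting the rows first could orphan objects in the bucket.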
6 changes: 2 additions & 4 deletions diracx-logic/tests/jobs/test_sandboxes.py
@@ -117,8 +117,7 @@ async def test_upload_and_clean(
assert response.content == data

# There should be no sandboxes to remove
async with sandbox_metadata_db:
await clean_sandboxes(sandbox_metadata_db, sandbox_settings)
await clean_sandboxes(sandbox_metadata_db, sandbox_settings, max_workers=1)

# Try to download the sandbox
async with sandbox_metadata_db:
@@ -139,8 +138,7 @@
)

# Now the sandbox should be removed
async with sandbox_metadata_db:
await clean_sandboxes(sandbox_metadata_db, sandbox_settings)
await clean_sandboxes(sandbox_metadata_db, sandbox_settings, max_workers=1)

# Check that the sandbox was actually removed from the bucket
with pytest.raises(botocore.exceptions.ClientError, match="Not Found"):