Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 55 additions & 27 deletions src/dstack/_internal/server/services/runs/plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,19 @@ async def get_job_plans(
run_spec: RunSpec,
max_offers: Optional[int],
) -> list[JobPlan]:
"""
Returns job plans for the given run spec.

Normal run planning (`dstack apply`) selects the best fleet candidate for each planned job
and builds offers from that path. `dstack offer` without `--group-by` uses the same
`/runs/get_plan` API, but its synthetic run spec is detected by
`_should_select_best_fleet_candidate()`. In that case, planning skips
best-fleet-candidate selection and collects offers directly: global offers when no fleets
are specified, or offers from the selected fleets when `--fleet` is used.

Services are planned per replica group. Other run types are planned once and then expanded
into per-job `JobPlan` results.
"""
run_name = run_spec.run_name
if run_spec.run_name is None:
# Set/unset dummy run name to generate job names for run plan.
Expand Down Expand Up @@ -120,7 +133,7 @@ async def get_job_plans(
volumes=volumes,
exclude_not_available=False,
)
if _should_force_non_fleet_offers(run_spec):
if not _should_select_best_fleet_candidate(run_spec):
if profile.fleets is None:
instance_offers, backend_offers = await _get_non_fleet_offers(
session=session,
Expand Down Expand Up @@ -160,23 +173,7 @@ async def get_job_plans(
run_spec=run_spec,
job_num=0,
)
candidate_fleet_models = await _select_candidate_fleet_models(
session=session,
project=project,
run_model=None,
run_spec=run_spec,
)
fleet_model, instance_offers, backend_offers = await find_optimal_fleet_with_offers(
project=project,
fleet_models=candidate_fleet_models,
run_model=None,
run_spec=run_spec,
job=jobs[0],
master_job_provisioning_data=None,
volumes=volumes,
exclude_not_available=False,
)
if _should_force_non_fleet_offers(run_spec):
if not _should_select_best_fleet_candidate(run_spec):
if profile.fleets is None:
instance_offers, backend_offers = await _get_non_fleet_offers(
session=session,
Expand All @@ -194,6 +191,23 @@ async def get_job_plans(
job=jobs[0],
volumes=volumes,
)
else:
candidate_fleet_models = await _select_candidate_fleet_models(
session=session,
project=project,
run_model=None,
run_spec=run_spec,
)
fleet_model, instance_offers, backend_offers = await find_optimal_fleet_with_offers(
project=project,
fleet_models=candidate_fleet_models,
run_model=None,
run_spec=run_spec,
job=jobs[0],
master_job_provisioning_data=None,
volumes=volumes,
exclude_not_available=False,
)

for job in jobs:
job_plan = _get_job_plan(
Expand Down Expand Up @@ -724,10 +738,10 @@ async def _get_offers_in_run_candidate_fleets(
"""
Returns existing-instance and backend offers across the run's candidate fleets.

Used by plain/json `dstack offer --fleet ...`. Unlike normal `dstack apply`, it does not
choose a single best fleet. Instead, it gathers existing-instance and backend offers from
each selected fleet, keeps existing instances as separate reusable options, and deduplicates
identical backend offers across fleets.
Used by `dstack offer --fleet ...` without `--group-by`. Unlike normal `dstack apply`, it
does not choose a single best fleet. Instead, it gathers existing-instance and backend
offers from each selected fleet, keeps existing instances as separate reusable options, and
deduplicates identical backend offers across fleets.
"""
candidate_fleet_models = await _select_candidate_fleet_models(
session=session,
Expand Down Expand Up @@ -820,11 +834,25 @@ def _get_job_plan(
)


def _should_force_non_fleet_offers(run_spec: RunSpec) -> bool:
# A hack to force non-fleet offers for `dstack offer` command that uses
# get run plan API to show offers and the only way to distinguish it is commands.
# Assuming real runs will not use such commands.
return run_spec.configuration.type == "task" and run_spec.configuration.commands == [":"]
def _should_select_best_fleet_candidate(run_spec: RunSpec) -> bool:
"""
Returns ``True`` for normal run planning and ``False`` for `dstack offer` without
`--group-by`.

Both `dstack apply` and `dstack offer` without `--group-by` call `/runs/get_plan`. The
current way to recognize `dstack offer` without `--group-by` is the synthetic task spec
that the CLI sends with `type == "task"` and `commands == [":"]`.
TODO: Replace this command-shape hack with an explicit request/API signal for
`dstack offer` without `--group-by`.

When this function returns ``False``, the planner skips best-fleet-candidate selection
and goes directly to the special `dstack offer` collection path:
global offers when no fleets are specified, or offers from the selected fleets when
`--fleet` is used.

A real task with `commands == [":"]` would also match this special `dstack offer` path.
"""
return not (run_spec.configuration.type == "task" and run_spec.configuration.commands == [":"])


def _get_offers_from_instances(
Expand Down
191 changes: 191 additions & 0 deletions src/tests/_internal/server/routers/test_runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@
create_user,
get_auth_headers,
get_fleet_spec,
get_instance_offer_with_availability,
get_job_provisioning_data,
get_job_runtime_data,
get_run_spec,
Expand Down Expand Up @@ -2131,6 +2132,196 @@ async def test_offer_cli_without_fleet_keeps_global_offers(
offers = response.json()["job_plans"][0]["offers"]
assert [offer["backend"] for offer in offers] == ["aws", "runpod"]

# Checks that a `/runs/get_plan` request carrying the `dstack offer`-style task spec
# (`commands == [":"]`) and no explicit fleets is served by `_get_non_fleet_offers` only.
# Runs against both the sqlite and postgres `test_db` parametrizations.
@pytest.mark.asyncio
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
async def test_offer_without_fleets_uses_global_offer_collection(
self,
test_db,
session: AsyncSession,
client: AsyncClient,
) -> None:
# Arrange: a user who owns a project and is also added as a project member, plus a repo.
user = await create_user(session=session, global_role=GlobalRole.USER)
project = await create_project(session=session, owner=user)
await add_project_member(
session=session,
project=project,
user=user,
project_role=ProjectRole.USER,
)
repo = await create_repo(session=session, project_id=project.id)
# No `profile` is passed here; presumably the default profile from `get_run_spec`
# has no fleets set - TODO confirm against the helper.
run_spec = get_run_spec(
repo_id=repo.name,
configuration=TaskConfiguration(
commands=[":"],
image="scratch",
user="root",
),
)
global_offer = get_instance_offer_with_availability(price=1.0)
with (
# The expected path: returns a single offer priced 1.0 and an empty second list
# (the planner unpacks this pair as `instance_offers, backend_offers`).
patch(
"dstack._internal.server.services.runs.plan._get_non_fleet_offers",
new=AsyncMock(return_value=([(Mock(), global_offer)], [])),
) as get_non_fleet_offers_mock,
# The two paths below must not run; calling either raises AssertionError.
patch(
"dstack._internal.server.services.runs.plan._get_offers_in_run_candidate_fleets",
new=AsyncMock(
side_effect=AssertionError(
"_get_offers_in_run_candidate_fleets should not be called"
)
),
) as get_offers_in_run_candidate_fleets_mock,
patch(
"dstack._internal.server.services.runs.plan.find_optimal_fleet_with_offers",
new=AsyncMock(
side_effect=AssertionError(
"find_optimal_fleet_with_offers should not be called"
)
),
) as find_optimal_fleet_with_offers_mock,
):
# Act: request the plan through the HTTP API.
response = await client.post(
f"/api/project/{project.name}/runs/get_plan",
headers=get_auth_headers(user.token),
json={"run_spec": run_spec.dict()},
)

# Assert: only the global offer collection ran, and its single offer is in the plan.
assert response.status_code == 200, response.json()
get_non_fleet_offers_mock.assert_awaited_once()
get_offers_in_run_candidate_fleets_mock.assert_not_called()
find_optimal_fleet_with_offers_mock.assert_not_called()
job_plan = response.json()["job_plans"][0]
assert job_plan["total_offers"] == 1
assert job_plan["offers"][0]["price"] == 1.0

# Checks that a `/runs/get_plan` request carrying the `dstack offer`-style task spec
# (`commands == [":"]`) together with explicit fleets is served by
# `_get_offers_in_run_candidate_fleets` only.
# Runs against both the sqlite and postgres `test_db` parametrizations.
@pytest.mark.asyncio
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
async def test_offer_with_fleets_uses_selected_fleet_offer_collection(
self,
test_db,
session: AsyncSession,
client: AsyncClient,
) -> None:
# Arrange: a user who owns a project and is also added as a project member, plus a repo.
user = await create_user(session=session, global_role=GlobalRole.USER)
project = await create_project(session=session, owner=user)
await add_project_member(
session=session,
project=project,
user=user,
project_role=ProjectRole.USER,
)
repo = await create_repo(session=session, project_id=project.id)
# The same fleet names are set on both the profile and the task configuration.
# No fleets with these names are created in this test; the fleet offer collection
# is mocked below.
selected_fleets = ["fleet-a", "fleet-b"]
run_spec = get_run_spec(
repo_id=repo.name,
profile=Profile(name="default", fleets=selected_fleets),
configuration=TaskConfiguration(
commands=[":"],
image="scratch",
user="root",
fleets=selected_fleets,
),
)
fleet_offer = get_instance_offer_with_availability(price=2.0)
with (
# The global offer path must not run; calling it raises AssertionError.
patch(
"dstack._internal.server.services.runs.plan._get_non_fleet_offers",
new=AsyncMock(
side_effect=AssertionError("_get_non_fleet_offers should not be called")
),
) as get_non_fleet_offers_mock,
# The expected path: returns a single offer priced 2.0 and an empty second list.
patch(
"dstack._internal.server.services.runs.plan._get_offers_in_run_candidate_fleets",
new=AsyncMock(return_value=([(Mock(), fleet_offer)], [])),
) as get_offers_in_run_candidate_fleets_mock,
# Best-fleet selection must not run either.
patch(
"dstack._internal.server.services.runs.plan.find_optimal_fleet_with_offers",
new=AsyncMock(
side_effect=AssertionError(
"find_optimal_fleet_with_offers should not be called"
)
),
) as find_optimal_fleet_with_offers_mock,
):
# Act: request the plan through the HTTP API.
response = await client.post(
f"/api/project/{project.name}/runs/get_plan",
headers=get_auth_headers(user.token),
json={"run_spec": run_spec.dict()},
)

# Assert: only the selected-fleet offer collection ran, and its single offer is in the plan.
assert response.status_code == 200, response.json()
get_non_fleet_offers_mock.assert_not_called()
get_offers_in_run_candidate_fleets_mock.assert_awaited_once()
find_optimal_fleet_with_offers_mock.assert_not_called()
job_plan = response.json()["job_plans"][0]
assert job_plan["total_offers"] == 1
assert job_plan["offers"][0]["price"] == 2.0

# Checks that a `/runs/get_plan` request for an ordinary task (commands other than `[":"]`)
# goes through `_select_candidate_fleet_models` and `find_optimal_fleet_with_offers`, and
# does not touch either of the `dstack offer` collection helpers.
# Runs against both the sqlite and postgres `test_db` parametrizations.
@pytest.mark.asyncio
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
async def test_regular_run_plan_uses_best_fleet_candidate_selection(
self,
test_db,
session: AsyncSession,
client: AsyncClient,
) -> None:
# Arrange: a user who owns a project and is also added as a project member, plus a repo.
user = await create_user(session=session, global_role=GlobalRole.USER)
project = await create_project(session=session, owner=user)
await add_project_member(
session=session,
project=project,
user=user,
project_role=ProjectRole.USER,
)
repo = await create_repo(session=session, project_id=project.id)
# A real command, so the spec does not match the `dstack offer` command shape.
run_spec = get_run_spec(
repo_id=repo.name,
configuration=TaskConfiguration(
commands=["echo ok"],
image="scratch",
user="root",
),
)
chosen_fleet_offer = get_instance_offer_with_availability(price=3.0)
with (
# The expected path: candidate selection yields one mocked fleet model ...
patch(
"dstack._internal.server.services.runs.plan._select_candidate_fleet_models",
new=AsyncMock(return_value=[Mock()]),
) as select_candidate_fleet_models_mock,
# ... and the optimal-fleet search returns a 3-tuple: a mocked fleet, one offer
# priced 3.0, and an empty third list.
patch(
"dstack._internal.server.services.runs.plan.find_optimal_fleet_with_offers",
new=AsyncMock(return_value=(Mock(), [(Mock(), chosen_fleet_offer)], [])),
) as find_optimal_fleet_with_offers_mock,
# Neither `dstack offer` collection helper may run; calling one raises AssertionError.
patch(
"dstack._internal.server.services.runs.plan._get_non_fleet_offers",
new=AsyncMock(
side_effect=AssertionError("_get_non_fleet_offers should not be called")
),
) as get_non_fleet_offers_mock,
patch(
"dstack._internal.server.services.runs.plan._get_offers_in_run_candidate_fleets",
new=AsyncMock(
side_effect=AssertionError(
"_get_offers_in_run_candidate_fleets should not be called"
)
),
) as get_offers_in_run_candidate_fleets_mock,
):
# Act: request the plan through the HTTP API.
response = await client.post(
f"/api/project/{project.name}/runs/get_plan",
headers=get_auth_headers(user.token),
json={"run_spec": run_spec.dict()},
)

# Assert: best-fleet selection ran exactly once and its single offer is in the plan.
assert response.status_code == 200, response.json()
select_candidate_fleet_models_mock.assert_awaited_once()
find_optimal_fleet_with_offers_mock.assert_awaited_once()
get_non_fleet_offers_mock.assert_not_called()
get_offers_in_run_candidate_fleets_mock.assert_not_called()
job_plan = response.json()["job_plans"][0]
assert job_plan["total_offers"] == 1
assert job_plan["offers"][0]["price"] == 3.0

@pytest.mark.parametrize(
("client_version", "expected_availability"),
[
Expand Down
Loading