From c1d9fa9be6235a92f0f6d67d615fa7c6c5243686 Mon Sep 17 00:00:00 2001 From: Andrey Cheptsov Date: Fri, 12 Jun 2026 09:46:29 +0200 Subject: [PATCH] Support targeting specific run instances Co-authored-by: Federico D'Agostino --- .../reference/dstack.yml/dev-environment.md | 34 + mkdocs/docs/reference/dstack.yml/service.md | 34 + mkdocs/docs/reference/dstack.yml/task.md | 34 + .../_internal/core/compatibility/common.py | 2 + .../_internal/core/compatibility/runs.py | 2 + src/dstack/_internal/core/models/profiles.py | 71 +++ .../pipeline_tasks/jobs_submitted.py | 108 ++++ .../_internal/server/services/instances.py | 181 +++++- .../_internal/server/services/runs/plan.py | 172 ++++- .../_internal/core/models/test_profiles.py | 83 ++- .../pipeline_tasks/test_submitted_jobs.py | 603 +++++++++++++++++- .../_internal/server/routers/test_fleets.py | 3 + .../_internal/server/routers/test_runs.py | 4 + .../server/services/runs/test_plan.py | 503 ++++++++++++++- .../server/services/test_instances.py | 131 +++- 15 files changed, 1948 insertions(+), 17 deletions(-) diff --git a/mkdocs/docs/reference/dstack.yml/dev-environment.md b/mkdocs/docs/reference/dstack.yml/dev-environment.md index 594b6e65d7..b2384d679a 100644 --- a/mkdocs/docs/reference/dstack.yml/dev-environment.md +++ b/mkdocs/docs/reference/dstack.yml/dev-environment.md @@ -34,6 +34,40 @@ The `dev-environment` configuration type allows running [dev environments](../.. type: required: true +### `instances[n]` { #_instances data-toc-label="instances" } + +When `instances` is set, the run is placed only on matching existing fleet instances. + +=== "By name" + + #SCHEMA# dstack._internal.core.models.profiles.InstanceNameSelector + overrides: + show_root_heading: false + type: + required: true + +=== "By hostname" + + #SCHEMA# dstack._internal.core.models.profiles.InstanceHostnameSelector + overrides: + show_root_heading: false + type: + required: true + +=== "By fleet and instance number" + + #SCHEMA# dstack._internal.core.models.profiles.FleetInstanceSelector + overrides: + show_root_heading: false + type: + required: true + +??? info "Short syntax" + + The short syntax for instances is an instance name string. + + * `my-fleet-1`, same as `{name: my-fleet-1}` + ### `resources` #SCHEMA# dstack._internal.core.models.resources.ResourcesSpec diff --git a/mkdocs/docs/reference/dstack.yml/service.md b/mkdocs/docs/reference/dstack.yml/service.md index 5ddfe46dd1..5f3aa3bd16 100644 --- a/mkdocs/docs/reference/dstack.yml/service.md +++ b/mkdocs/docs/reference/dstack.yml/service.md @@ -114,6 +114,40 @@ The `service` configuration type allows running [services](../../concepts/servic type: required: true +### `instances[n]` { #_instances data-toc-label="instances" } + +When `instances` is set, the run is placed only on matching existing fleet instances. + +=== "By name" + + #SCHEMA# dstack._internal.core.models.profiles.InstanceNameSelector + overrides: + show_root_heading: false + type: + required: true + +=== "By hostname" + + #SCHEMA# dstack._internal.core.models.profiles.InstanceHostnameSelector + overrides: + show_root_heading: false + type: + required: true + +=== "By fleet and instance number" + + #SCHEMA# dstack._internal.core.models.profiles.FleetInstanceSelector + overrides: + show_root_heading: false + type: + required: true + +??? info "Short syntax" + + The short syntax for instances is an instance name string. + + * `my-fleet-1`, same as `{name: my-fleet-1}` + ### `resources` #SCHEMA# dstack._internal.core.models.resources.ResourcesSpec diff --git a/mkdocs/docs/reference/dstack.yml/task.md b/mkdocs/docs/reference/dstack.yml/task.md index 96d05c325d..104333c1bf 100644 --- a/mkdocs/docs/reference/dstack.yml/task.md +++ b/mkdocs/docs/reference/dstack.yml/task.md @@ -34,6 +34,40 @@ The `task` configuration type allows running [tasks](../../concepts/tasks.md). type: required: true +### `instances[n]` { #_instances data-toc-label="instances" } + +When `instances` is set, the run is placed only on matching existing fleet instances. + +=== "By name" + + #SCHEMA# dstack._internal.core.models.profiles.InstanceNameSelector + overrides: + show_root_heading: false + type: + required: true + +=== "By hostname" + + #SCHEMA# dstack._internal.core.models.profiles.InstanceHostnameSelector + overrides: + show_root_heading: false + type: + required: true + +=== "By fleet and instance number" + + #SCHEMA# dstack._internal.core.models.profiles.FleetInstanceSelector + overrides: + show_root_heading: false + type: + required: true + +??? info "Short syntax" + + The short syntax for instances is an instance name string. + + * `my-fleet-1`, same as `{name: my-fleet-1}` + ### `resources` #SCHEMA# dstack._internal.core.models.resources.ResourcesSpec diff --git a/src/dstack/_internal/core/compatibility/common.py b/src/dstack/_internal/core/compatibility/common.py index 789e2d120a..8a34b40579 100644 --- a/src/dstack/_internal/core/compatibility/common.py +++ b/src/dstack/_internal/core/compatibility/common.py @@ -10,6 +10,8 @@ def get_profile_excludes(profile: Optional[ProfileParams]) -> IncludeExcludeSetT return excludes if profile.backend_options is None: excludes.add("backend_options") + if profile.instances is None: + excludes.add("instances") return excludes diff --git a/src/dstack/_internal/core/compatibility/runs.py b/src/dstack/_internal/core/compatibility/runs.py index 3fb92d9215..4b57db1d47 100644 --- a/src/dstack/_internal/core/compatibility/runs.py +++ b/src/dstack/_internal/core/compatibility/runs.py @@ -83,6 +83,8 @@ def get_run_spec_excludes(run_spec: RunSpec) -> IncludeExcludeDictType: spec_excludes: IncludeExcludeDictType = {} configuration_excludes: IncludeExcludeDictType = {} profile_excludes = get_profile_excludes(run_spec.profile) + for field in get_profile_excludes(run_spec.configuration): + configuration_excludes[field] = True if run_spec.configuration.backend_options is None: configuration_excludes["backend_options"] = True diff --git a/src/dstack/_internal/core/models/profiles.py b/src/dstack/_internal/core/models/profiles.py index f1beebd5b9..7a448486df 100644 --- a/src/dstack/_internal/core/models/profiles.py +++ b/src/dstack/_internal/core/models/profiles.py @@ -234,6 +234,58 @@ def crons(self) -> List[str]: return self.cron +class InstanceNameSelector(CoreModel): + name: Annotated[str, Field(description="The fleet instance name", min_length=1)] + + +class InstanceHostnameSelector(CoreModel): + hostname: Annotated[ + str, Field(description="The fleet instance hostname or IP address", min_length=1) + ] + + +def _parse_fleet_instance_selector_fleet(v: Any) -> Any: + if isinstance(v, str): + return EntityReference.parse(v) + return v + + +class FleetInstanceSelectorConfig(CoreConfig): + @staticmethod + def schema_extra(schema: Dict[str, Any]): + add_extra_schema_types( + schema["properties"]["fleet"], + extra_types=[{"type": "string", "minLength": 1}], + ) + + +class FleetInstanceSelector(generate_dual_core_model(FleetInstanceSelectorConfig)): + fleet: Annotated[ + EntityReference, + Field( + description=( + "The fleet reference. For fleets owned by the current project, specify" + " the fleet name. For a fleet from another project, specify" + " `/` or an object with `project` and `name`." + ), + ), + ] + instance: Annotated[int, Field(description="The fleet instance number", ge=0)] + + _validate_fleet = validator("fleet", pre=True, allow_reuse=True)( + _parse_fleet_instance_selector_fleet + ) + + +InstanceSelector = Union[InstanceNameSelector, InstanceHostnameSelector, FleetInstanceSelector] + + +def parse_instance_selector(v: Union[InstanceSelector, str]) -> InstanceSelector: + if isinstance(v, str): + return InstanceNameSelector(name=v) + return v + + class ProfileParamsConfig(CoreConfig): @staticmethod def schema_extra(schema: Dict[str, Any]): @@ -249,6 +301,10 @@ def schema_extra(schema: Dict[str, Any]): schema["properties"]["idle_duration"], extra_types=[{"type": "string"}], ) + add_extra_schema_types( + schema["properties"]["instances"]["items"], + extra_types=[{"type": "string", "minLength": 1}], + ) class ProfileParams(CoreModel): @@ -391,6 +447,18 @@ class ProfileParams(CoreModel): ), ), ] = None + instances: Annotated[ + Optional[List[InstanceSelector]], + Field( + description=( + "The specific fleet instances to consider for reuse." + " Each value can be an instance name string, or an object with" + " `name`, `hostname`, or `fleet` and `instance`." + " When set, the run is only placed on matching existing instances." + ), + min_items=1, + ), + ] = None tags: Annotated[ Optional[Dict[str, str]], Field( @@ -416,6 +484,9 @@ class ProfileParams(CoreModel): parse_idle_duration ) _validate_fleets = validator("fleets", allow_reuse=True, each_item=True)(EntityReference.parse) + _validate_instances = validator("instances", pre=True, allow_reuse=True, each_item=True)( + parse_instance_selector + ) _validate_tags = validator("tags", pre=True, allow_reuse=True)(tags_validator) _validate_backend_options = validator("backend_options", allow_reuse=True)( validate_backend_options diff --git a/src/dstack/_internal/server/background/pipeline_tasks/jobs_submitted.py b/src/dstack/_internal/server/background/pipeline_tasks/jobs_submitted.py index 2869374d55..35f613833c 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/jobs_submitted.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/jobs_submitted.py @@ -131,6 +131,7 @@ get_instance_offers_in_fleet, get_run_candidate_fleet_models_filters, get_run_profile_and_requirements_in_fleet, + get_targeted_instance_offers, select_run_candidate_fleet_models_with_filters, ) from dstack._internal.server.services.runs.spec import ( @@ -494,6 +495,12 @@ async def _process_assignment(context: _SubmittedJobContext) -> _AssignmentResul if not isinstance(preconditions, _ProcessedPreconditions): return preconditions + if context.run.run_spec.merged_profile.instances is not None: + return await _select_targeted_instance_assignment( + context=context, + preconditions=preconditions, + ) + candidate_fleet_models = await _load_assignment_candidate_fleets(context=context) return await _select_assignment( context=context, @@ -533,6 +540,30 @@ async def _select_assignment( return _NewCapacityAssignment(fleet_id=fleet_model.id) +async def _select_targeted_instance_assignment( + context: _SubmittedJobContext, + preconditions: _ProcessedPreconditions, +) -> _AssignmentResult: + async with get_session_ctx() as session: + instance_offers = await get_targeted_instance_offers( + session=session, + project=context.project, + run_spec=context.run.run_spec, + job=context.job, + master_job_provisioning_data=preconditions.master_job_provisioning_data, + volumes=preconditions.prepared_job_volumes.volumes, + exclude_not_available=True, + fleet_id=context.run_model.fleet_id, + ) + if len(instance_offers) < _get_required_targeted_instance_offers(context): + return _NoFleetAssignment() + return _ExistingInstanceAssignment( + fleet_id=get_or_error(instance_offers[0][0].fleet_id), + master_job_provisioning_data=preconditions.master_job_provisioning_data, + volumes=preconditions.prepared_job_volumes.volumes, + ) + + async def _apply_assignment_result( item: JobSubmittedPipelineItem, context: _SubmittedJobContext, @@ -621,6 +652,28 @@ async def _apply_assignment_result( return async with AsyncExitStack() as exit_stack: + if context.run.run_spec.merged_profile.instances is not None: + current_instance_offers = await _lock_targeted_instance_offers_for_assignment( + exit_stack=exit_stack, + session=session, + context=context, + assignment=assignment, + ) + if len(current_instance_offers) < _get_required_targeted_instance_offers(context): + await _reset_job_lock_for_retry(session=session, item=item) + return + + instance_model, current_offer = current_instance_offers[0] + _assign_instance_to_job( + session=session, + job_model=job_model, + instance_model=instance_model, + offer=current_offer, + multinode=context.multinode, + ) + await _mark_job_processed(session=session, job_model=job_model) + return + fleet_model = await _lock_assignment_fleet_for_existing_instance_assignment( exit_stack=exit_stack, session=session, @@ -905,6 +958,16 @@ async def _apply_no_fleet_selection( job_model: JobModel, run: Run, ) -> None: + if run.run_spec.merged_profile.instances is not None: + logger.debug("%s: failed to use specified instances", fmt(job_model)) + await _terminate_submitted_job( + session=session, + job_model=job_model, + reason=JobTerminationReason.FAILED_TO_START_DUE_TO_NO_CAPACITY, + message="Failed to use specified instances", + ) + return + if run.run_spec.merged_profile.fleets is not None: logger.debug("%s: failed to use specified fleets", fmt(job_model)) await _terminate_submitted_job( @@ -927,6 +990,45 @@ async def _apply_no_fleet_selection( ) +async def _lock_targeted_instance_offers_for_assignment( + exit_stack: AsyncExitStack, + session: AsyncSession, + context: _SubmittedJobContext, + assignment: _ExistingInstanceAssignment, +) -> list[tuple[InstanceModel, InstanceOfferWithAvailability]]: + instance_offers = await get_targeted_instance_offers( + session=session, + project=context.project, + run_spec=context.run.run_spec, + job=context.job, + master_job_provisioning_data=assignment.master_job_provisioning_data, + volumes=assignment.volumes, + exclude_not_available=True, + fleet_id=assignment.fleet_id, + lock_instances=True, + ) + instance_ids = sorted(instance.id for instance, _ in instance_offers) + if not instance_ids or not is_db_sqlite(): + return instance_offers + + await sqlite_commit(session) + await exit_stack.enter_async_context( + get_locker(get_db().dialect_name).lock_ctx(InstanceModel.__tablename__, instance_ids) + ) + return await get_targeted_instance_offers( + session=session, + project=context.project, + run_spec=context.run.run_spec, + job=context.job, + master_job_provisioning_data=assignment.master_job_provisioning_data, + volumes=assignment.volumes, + exclude_not_available=True, + fleet_id=assignment.fleet_id, + instance_ids=instance_ids, + lock_instances=True, + ) + + async def _lock_assignment_fleet_for_existing_instance_assignment( exit_stack: AsyncExitStack, session: AsyncSession, @@ -2046,6 +2148,12 @@ def _select_jobs_to_provision(job: Job, replica_jobs: list[Job], job_model: JobM return jobs_to_provision +def _get_required_targeted_instance_offers(context: _SubmittedJobContext) -> int: + if is_multinode_job(context.job) and is_master_job(context.job): + return len(context.jobs_to_provision) + return 1 + + def _release_replica_jobs_from_master_wait( job_model: JobModel, replica_job_models: list[JobModel], diff --git a/src/dstack/_internal/server/services/instances.py b/src/dstack/_internal/server/services/instances.py index ad48ff1f51..913d3c9f44 100644 --- a/src/dstack/_internal/server/services/instances.py +++ b/src/dstack/_internal/server/services/instances.py @@ -2,12 +2,12 @@ import uuid from collections.abc import Container, Iterable from datetime import datetime -from typing import Dict, List, Literal, Optional, Union +from typing import Dict, List, Literal, Optional, Sequence, Union import gpuhunt from sqlalchemy import and_, exists, false, or_, select from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.orm import joinedload, load_only +from sqlalchemy.orm import contains_eager, joinedload, load_only from dstack._internal.core.backends.base.offers import ( offer_to_catalog_item, @@ -16,6 +16,7 @@ from dstack._internal.core.backends.features import BACKENDS_WITH_MULTINODE_SUPPORT from dstack._internal.core.errors import ResourceNotExistsError from dstack._internal.core.models.backends.base import BackendType +from dstack._internal.core.models.common import EntityReference from dstack._internal.core.models.envs import Env from dstack._internal.core.models.health import HealthCheck, HealthEvent, HealthStatus from dstack._internal.core.models.instances import ( @@ -34,6 +35,10 @@ ) from dstack._internal.core.models.profiles import ( DEFAULT_FLEET_TERMINATION_IDLE_TIME, + FleetInstanceSelector, + InstanceHostnameSelector, + InstanceNameSelector, + InstanceSelector, Profile, TerminationPolicy, ) @@ -375,6 +380,178 @@ def get_instance_ssh_private_keys(instance_model: InstanceModel) -> tuple[str, O return host_private_key, proxy_private_keys[0] +async def select_instances_by_selectors( + session: AsyncSession, + project: ProjectModel, + selectors: Sequence[InstanceSelector], + *, + fleets: Optional[Sequence[Union[EntityReference, str]]] = None, + detaching_instance_ids: Optional[Sequence[uuid.UUID]] = None, + fleet_id: Optional[uuid.UUID] = None, + instance_ids: Optional[Sequence[uuid.UUID]] = None, + lock_instances: bool = False, +) -> list[InstanceModel]: + if instance_ids is not None and len(instance_ids) == 0: + return [] + is_instance_imported_subquery = exists().where( + ImportModel.project_id == project.id, + ImportModel.export_id == ExportedFleetModel.export_id, + ExportedFleetModel.fleet_id == InstanceModel.fleet_id, + ) + filters = [ + or_( + InstanceModel.project_id == project.id, + is_instance_imported_subquery, + ), + FleetModel.deleted == False, + InstanceModel.deleted == False, + ] + if detaching_instance_ids is not None: + filters.append(InstanceModel.id.not_in(detaching_instance_ids)) + if fleet_id is not None: + filters.append(InstanceModel.fleet_id == fleet_id) + if instance_ids is not None: + filters.append(InstanceModel.id.in_(instance_ids)) + if fleets is not None: + filters.append( + or_( + *[ + _get_fleet_reference_condition(project, EntityReference.parse(fleet)) + for fleet in fleets + ] + ) + ) + selector_conditions = _get_instance_selector_conditions(project, selectors) + if selector_conditions: + filters.append(or_(*selector_conditions)) + + stmt = ( + select(InstanceModel) + .join(InstanceModel.fleet) + .join(FleetModel.project) + .where(*filters) + .options( + contains_eager(InstanceModel.fleet) + .load_only(FleetModel.id, FleetModel.name, FleetModel.project_id, FleetModel.spec) + .contains_eager(FleetModel.project) + .load_only(ProjectModel.name) + ) + ) + if lock_instances: + stmt = stmt.where(InstanceModel.lock_expires_at.is_(None)) + stmt = stmt.order_by(InstanceModel.id).with_for_update( + skip_locked=True, key_share=True, of=InstanceModel + ) + res = await session.execute(stmt) + instances = list(res.unique().scalars().all()) + return [ + instance + for instance in instances + if instance_matches_selectors(instance, selectors, project=project) + ] + + +def instance_matches_selectors( + instance: InstanceModel, + selectors: Sequence[InstanceSelector], + *, + project: ProjectModel, +) -> bool: + return any( + instance_matches_selector(instance, selector, project=project) for selector in selectors + ) + + +def instance_matches_selector( + instance: InstanceModel, + selector: InstanceSelector, + *, + project: ProjectModel, +) -> bool: + if isinstance(selector, InstanceNameSelector): + return instance.name == selector.name + if isinstance(selector, InstanceHostnameSelector): + return instance_matches_hostname_selector(instance, selector) + if isinstance(selector, FleetInstanceSelector): + return _instance_matches_fleet_instance_selector(instance, selector, project=project) + return False + + +def instance_matches_hostname_selector( + instance: InstanceModel, selector: InstanceHostnameSelector +) -> bool: + candidates = set() + jpd = get_instance_provisioning_data(instance) + if jpd is not None: + if jpd.hostname is not None: + candidates.add(jpd.hostname.lower()) + if jpd.internal_ip is not None: + candidates.add(jpd.internal_ip.lower()) + rci = get_instance_remote_connection_info(instance) + if rci is not None: + candidates.add(rci.host.lower()) + return selector.hostname.lower() in candidates + + +def _instance_matches_fleet_instance_selector( + instance: InstanceModel, + selector: FleetInstanceSelector, + *, + project: ProjectModel, +) -> bool: + fleet = instance.fleet + if fleet is None: + return False + if fleet.name != selector.fleet.name: + return False + if instance.instance_num != selector.instance: + return False + if selector.fleet.project is None: + return fleet.project_id == project.id + return fleet.project.name == selector.fleet.project + + +def _get_instance_selector_conditions( + project: ProjectModel, + selectors: Sequence[InstanceSelector], +) -> list: + conditions = [] + for selector in selectors: + if isinstance(selector, InstanceNameSelector): + conditions.append(InstanceModel.name == selector.name) + elif isinstance(selector, InstanceHostnameSelector): + conditions.append(_get_hostname_selector_condition(selector)) + elif isinstance(selector, FleetInstanceSelector): + conditions.append( + and_( + _get_fleet_reference_condition(project, selector.fleet), + InstanceModel.instance_num == selector.instance, + ) + ) + return conditions + + +def _get_fleet_reference_condition(project: ProjectModel, ref: EntityReference): + if ref.project is None: + return and_( + FleetModel.name == ref.name, + FleetModel.project_id == project.id, + ) + return and_( + FleetModel.name == ref.name, + ProjectModel.name == ref.project, + ) + + +def _get_hostname_selector_condition(selector: InstanceHostnameSelector): + # This is only a DB prefilter. `instance_matches_selector` parses these JSON columns + # and performs the exact hostname/internal IP comparison in memory. + return or_( + InstanceModel.job_provisioning_data.icontains(selector.hostname, autoescape=True), + InstanceModel.remote_connection_info.icontains(selector.hostname, autoescape=True), + ) + + def instance_matches_constraints( instance: InstanceModel, *, diff --git a/src/dstack/_internal/server/services/runs/plan.py b/src/dstack/_internal/server/services/runs/plan.py index fa43f1576b..8d272bdb2c 100644 --- a/src/dstack/_internal/server/services/runs/plan.py +++ b/src/dstack/_internal/server/services/runs/plan.py @@ -1,4 +1,5 @@ import math +import uuid from collections.abc import Hashable, Mapping from dataclasses import dataclass from enum import Enum @@ -45,11 +46,13 @@ get_pool_instances, get_shared_instances_with_offers, is_placeholder_instance, + select_instances_by_selectors, ) from dstack._internal.server.services.jobs import ( get_instances_ids_with_detaching_volumes, get_job_configured_volumes, get_jobs_from_run_spec, + is_master_job, is_multinode_job, remove_job_spec_sensitive_info, ) @@ -115,7 +118,7 @@ async def get_job_plans( job_num=0, ) - if _should_select_best_fleet_candidate(run_spec): + if _should_select_best_fleet_candidate(run_spec) and run_spec.merged_profile.instances is None: candidate_fleet_models = await _select_candidate_fleet_models( session=session, project=project, @@ -137,8 +140,17 @@ async def get_job_plans( replica_num=0, replica_group_name=replica_group_name, ) - if candidate_fleet_models is None: # `dstack offer` path - if profile.fleets is None: + if candidate_fleet_models is None: + if profile.instances is not None: + instance_offers = await get_targeted_instance_offers( + session=session, + project=project, + run_spec=run_spec, + job=jobs[0], + volumes=volumes, + ) + backend_offers = [] + elif profile.fleets is None: instance_offers, backend_offers = await _get_non_fleet_offers( session=session, project=project, @@ -495,11 +507,29 @@ def get_instance_offers_in_fleet( master_job_provisioning_data: Optional[JobProvisioningData] = None, volumes: Optional[list[list[Volume]]] = None, exclude_not_available: bool = False, +) -> list[tuple[InstanceModel, InstanceOfferWithAvailability]]: + return get_instance_offers_from_instances( + instances=fleet_model.instances, + run_spec=run_spec, + job=job, + master_job_provisioning_data=master_job_provisioning_data, + volumes=volumes, + exclude_not_available=exclude_not_available, + ) + + +def get_instance_offers_from_instances( + instances: list[InstanceModel], + run_spec: RunSpec, + job: Job, + master_job_provisioning_data: Optional[JobProvisioningData] = None, + volumes: Optional[list[list[Volume]]] = None, + exclude_not_available: bool = False, ) -> list[tuple[InstanceModel, InstanceOfferWithAvailability]]: profile = run_spec.merged_profile multinode = is_multinode_job(job) nonshared_instances = filter_instances( - instances=fleet_model.instances, + instances=instances, profile=profile, requirements=job.job_spec.requirements, multinode=multinode, @@ -509,7 +539,7 @@ def get_instance_offers_in_fleet( ) instances_with_offers = _get_offers_from_instances(nonshared_instances) shared_instances_with_offers = get_shared_instances_with_offers( - instances=fleet_model.instances, + instances=instances, profile=profile, requirements=job.job_spec.requirements, multinode=multinode, @@ -522,6 +552,113 @@ def get_instance_offers_in_fleet( return instances_with_offers +async def get_targeted_instance_offers( + session: AsyncSession, + project: ProjectModel, + run_spec: RunSpec, + job: Job, + master_job_provisioning_data: Optional[JobProvisioningData] = None, + volumes: Optional[list[list[Volume]]] = None, + exclude_not_available: bool = False, + fleet_id: Optional[uuid.UUID] = None, + instance_ids: Optional[list[uuid.UUID]] = None, + lock_instances: bool = False, +) -> list[tuple[InstanceModel, InstanceOfferWithAvailability]]: + selectors = common_utils.get_or_error(run_spec.merged_profile.instances) + detaching_instance_ids = await get_instances_ids_with_detaching_volumes(session) + instances = await select_instances_by_selectors( + session=session, + project=project, + selectors=selectors, + fleets=run_spec.merged_profile.fleets, + detaching_instance_ids=detaching_instance_ids, + fleet_id=fleet_id, + instance_ids=instance_ids, + lock_instances=lock_instances, + ) + return select_targeted_instance_offers( + instances=instances, + run_spec=run_spec, + job=job, + master_job_provisioning_data=master_job_provisioning_data, + volumes=volumes, + exclude_not_available=exclude_not_available, + ) + + +def select_targeted_instance_offers( + instances: list[InstanceModel], + run_spec: RunSpec, + job: Job, + master_job_provisioning_data: Optional[JobProvisioningData] = None, + volumes: Optional[list[list[Volume]]] = None, + exclude_not_available: bool = False, +) -> list[tuple[InstanceModel, InstanceOfferWithAvailability]]: + candidates: list[_TargetedInstanceOffersCandidate] = [] + for fleet_instances in _group_instances_by_fleet(instances).values(): + fleet = common_utils.get_or_error(fleet_instances[0].fleet) + fleet_spec = get_fleet_spec(fleet) + if ( + is_multinode_job(job) + and fleet_spec.configuration.placement != InstanceGroupPlacement.CLUSTER + ): + continue + all_offers = get_instance_offers_from_instances( + instances=fleet_instances, + run_spec=run_spec, + job=job, + master_job_provisioning_data=master_job_provisioning_data, + volumes=volumes, + exclude_not_available=False, + ) + if len(all_offers) < _get_required_instance_offers(run_spec, job): + continue + available_offers = _exclude_non_available_instance_offers(all_offers) + if exclude_not_available: + all_offers = available_offers + if all_offers: + has_capacity = len(available_offers) >= _get_required_instance_offers(run_spec, job) + candidates.append( + _TargetedInstanceOffersCandidate( + lacks_capacity=not has_capacity, + available_price=_get_min_instance_or_backend_offer_price(available_offers), + selected_price=_get_min_instance_or_backend_offer_price(all_offers), + offers=all_offers, + ) + ) + if not candidates: + return [] + return min(candidates, key=lambda candidate: candidate.sort_key()).offers + + +@dataclass(frozen=True) +class _TargetedInstanceOffersCandidate: + lacks_capacity: bool + available_price: float + selected_price: float + offers: list[tuple[InstanceModel, InstanceOfferWithAvailability]] + + def sort_key(self) -> tuple[bool, float, float]: + return self.lacks_capacity, self.available_price, self.selected_price + + +def _group_instances_by_fleet( + instances: list[InstanceModel], +) -> dict[uuid.UUID, list[InstanceModel]]: + instances_by_fleet: dict[uuid.UUID, list[InstanceModel]] = {} + for instance in instances: + if instance.fleet_id is None: + continue + instances_by_fleet.setdefault(instance.fleet_id, []).append(instance) + return instances_by_fleet + + +def _get_required_instance_offers(run_spec: RunSpec, job: Job) -> int: + if is_multinode_job(job) and is_master_job(job): + return get_nodes_required_num(run_spec) + return 1 + + def _run_can_fit_into_fleet( run_spec: RunSpec, fleet_model: FleetModel, fleet_spec: FleetSpec ) -> bool: @@ -658,6 +795,16 @@ async def _get_non_fleet_offers( Returns instance and backend offers for job irrespective of fleets, i.e. all pool instances and project backends matching the spec. """ + if profile.instances is not None: + instance_offers = await get_targeted_instance_offers( + session=session, + project=project, + run_spec=run_spec, + job=job, + volumes=volumes, + ) + return instance_offers, [] + instance_offers = await _get_pool_offers( session=session, project=project, @@ -693,6 +840,9 @@ async def get_backend_offers_in_run_candidate_fleets( It resolves the selected fleets from `run_spec`, requests backend offers in each fleet, merges them, and deduplicates identical backend offers across fleets. """ + if run_spec.merged_profile.instances is not None: + return [] + candidate_fleet_models = await _select_candidate_fleet_models( session=session, project=project, @@ -737,6 +887,16 @@ async def _get_offers_in_run_candidate_fleets( offers from each selected fleet, keeps existing instances as separate reusable options, and deduplicates identical backend offers across fleets. """ + if run_spec.merged_profile.instances is not None: + instance_offers = await get_targeted_instance_offers( + session=session, + project=project, + run_spec=run_spec, + job=job, + volumes=volumes, + ) + return instance_offers, [] + candidate_fleet_models = await _select_candidate_fleet_models( session=session, project=project, @@ -816,7 +976,7 @@ def _get_job_plan( ) -> JobPlan: job_offers: list[InstanceOfferWithAvailability] = [] job_offers.extend(offer for _, offer in instance_offers) - if profile.creation_policy == CreationPolicy.REUSE_OR_CREATE: + if profile.creation_policy == CreationPolicy.REUSE_OR_CREATE and profile.instances is None: job_offers.extend(offer for _, offer in backend_offers) job_offers.sort(key=lambda offer: not offer.availability.is_available()) remove_job_spec_sensitive_info(job.job_spec) diff --git a/src/tests/_internal/core/models/test_profiles.py b/src/tests/_internal/core/models/test_profiles.py index 4a1caf8bf8..246435f6f4 100644 --- a/src/tests/_internal/core/models/test_profiles.py +++ b/src/tests/_internal/core/models/test_profiles.py @@ -2,7 +2,14 @@ from pydantic import ValidationError from dstack._internal.core.backends.vastai.profile_options import VastAIProfileOptions -from dstack._internal.core.models.profiles import Profile +from dstack._internal.core.compatibility.common import get_profile_excludes +from dstack._internal.core.models.common import EntityReference +from dstack._internal.core.models.profiles import ( + FleetInstanceSelector, + InstanceHostnameSelector, + InstanceNameSelector, + Profile, +) class TestValidateProfileBackendOptions: @@ -27,3 +34,77 @@ def test_none_backend_options_is_valid(self): def test_empty_list_backend_options_is_valid(self): profile = Profile(backend_options=[]) assert profile.backend_options == [] + + +class TestProfileInstances: + def test_string_is_parsed_as_instance_name_selector(self): + profile = Profile.parse_obj({"instances": ["my-fleet-1"]}) + + assert profile.instances == [InstanceNameSelector(name="my-fleet-1")] + + @pytest.mark.parametrize( + ("value", "expected"), + [ + ({"name": "my-fleet-1"}, InstanceNameSelector(name="my-fleet-1")), + ({"hostname": "worker-1"}, InstanceHostnameSelector(hostname="worker-1")), + ( + {"fleet": "my-fleet", "instance": 3}, + FleetInstanceSelector(fleet="my-fleet", instance=3), + ), + ( + {"fleet": "other-project/my-fleet", "instance": 3}, + FleetInstanceSelector(fleet="other-project/my-fleet", instance=3), + ), + ], + ) + def test_object_selectors_are_parsed(self, value, expected): + profile = Profile.parse_obj({"instances": [value]}) + + assert profile.instances == [expected] + + def test_parses_fleet_selector_object_notation(self): + profile = Profile.parse_obj( + {"instances": [{"fleet": {"project": "main", "name": "my-fleet"}, "instance": 0}]} + ) + + assert profile.instances == [ + FleetInstanceSelector( + fleet=EntityReference(project="main", name="my-fleet"), instance=0 + ) + ] + + @pytest.mark.parametrize( + "value", + [ + "", + {"name": "my-fleet-1", "hostname": "worker-1"}, + {"name": ""}, + {"hostname": ""}, + {"fleet": "", "instance": 0}, + {"fleet": "project/name/extra", "instance": 0}, + {"fleet": "my-fleet"}, + {"fleet": "my-fleet", "instance": -1}, + {"hostname": "worker-1", "extra": "value"}, + ], + ) + def test_invalid_selector_is_rejected(self, value): + with pytest.raises(ValidationError): + Profile.parse_obj({"instances": [value]}) + + def test_empty_instances_list_is_rejected(self): + with pytest.raises(ValidationError): + Profile.parse_obj({"instances": []}) + + +class TestProfileInstancesCompatibilityExcludes: + def test_excludes_unset_instances(self): + profile = Profile() + + assert "instances" not in profile.dict(exclude=get_profile_excludes(profile)) + + def test_preserves_configured_instances(self): + profile = Profile(instances=[InstanceNameSelector(name="my-fleet-1")]) + + assert profile.dict(exclude=get_profile_excludes(profile))["instances"] == [ + {"name": "my-fleet-1"} + ] diff --git a/src/tests/_internal/server/background/pipeline_tasks/test_submitted_jobs.py b/src/tests/_internal/server/background/pipeline_tasks/test_submitted_jobs.py index fd9cf3b58a..b00ce59029 100644 --- a/src/tests/_internal/server/background/pipeline_tasks/test_submitted_jobs.py +++ b/src/tests/_internal/server/background/pipeline_tasks/test_submitted_jobs.py @@ -2,7 +2,7 @@ import uuid from datetime import timedelta from typing import cast -from unittest.mock import Mock, call, patch +from unittest.mock import AsyncMock, Mock, call, patch import pytest from sqlalchemy import select @@ -11,14 +11,21 @@ from dstack._internal.core.errors import BackendError from dstack._internal.core.models.backends.base import BackendType -from dstack._internal.core.models.common import RegistryAuth -from dstack._internal.core.models.configurations import TaskConfiguration +from dstack._internal.core.models.common import EntityReference, NetworkMode, RegistryAuth +from dstack._internal.core.models.configurations import ServiceConfiguration, TaskConfiguration from dstack._internal.core.models.envs import Env from dstack._internal.core.models.fleets import FleetNodesSpec, InstanceGroupPlacement from dstack._internal.core.models.instances import InstanceStatus from dstack._internal.core.models.placement import PlacementGroup -from dstack._internal.core.models.profiles import Profile -from dstack._internal.core.models.runs import JobStatus, JobTerminationReason +from dstack._internal.core.models.profiles import ( + FleetInstanceSelector, + InstanceHostnameSelector, + InstanceNameSelector, + InstanceSelector, + Profile, +) +from dstack._internal.core.models.resources import CPUSpec, Memory, Range, ResourcesSpec +from dstack._internal.core.models.runs import JobRuntimeData, JobStatus, JobTerminationReason from dstack._internal.core.models.users import GlobalRole from dstack._internal.core.models.volumes import ( VolumeAttachmentData, @@ -61,6 +68,7 @@ get_instance_offer_with_availability, get_job_provisioning_data, get_placement_group_provisioning_data, + get_remote_connection_info, get_run_spec, get_ssh_fleet_configuration, get_volume_provisioning_data, @@ -1040,8 +1048,17 @@ async def test_ignores_lock_token_mismatch( assert job.lock_token is not None async def test_assigns_job_to_instance( - self, test_db, session: AsyncSession, worker: JobSubmittedWorker + self, + test_db, + session: AsyncSession, + worker: JobSubmittedWorker, + monkeypatch: pytest.MonkeyPatch, ): + get_targeted_instance_offers_mock = AsyncMock() + monkeypatch.setattr( + "dstack._internal.server.background.pipeline_tasks.jobs_submitted.get_targeted_instance_offers", + get_targeted_instance_offers_mock, + ) project = await create_project(session=session) user = await create_user(session=session) repo = await create_repo(session=session, project_id=project.id) @@ -1073,6 +1090,580 @@ async def test_assigns_job_to_instance( assert job.lock_expires_at is None assert instance.status == InstanceStatus.BUSY assert instance.busy_blocks == 1 + get_targeted_instance_offers_mock.assert_not_awaited() + + async def test_assigns_job_to_specific_instance( + self, test_db, session: AsyncSession, worker: JobSubmittedWorker + ): + project = await create_project(session=session) + user = await create_user(session=session) + repo = await create_repo(session=session, project_id=project.id) + fleet = await create_fleet(session=session, project=project, name="my-fleet") + await create_instance( + session=session, + project=project, + fleet=fleet, + status=InstanceStatus.IDLE, + name="my-fleet-0", + ) + selected = await create_instance( + session=session, + project=project, + fleet=fleet, + status=InstanceStatus.IDLE, + name="my-fleet-1", + ) + run_spec = get_run_spec( + repo_id=repo.name, + profile=Profile(instances=[InstanceNameSelector(name="my-fleet-1")]), + ) + run = await create_run( + session=session, + project=project, + repo=repo, + user=user, + fleet=fleet, + run_spec=run_spec, + ) + job = await create_job(session=session, run=run) + + await _process_job(session=session, worker=worker, job_model=job) + + job = await _get_job(session, job.id) + assert job.status == JobStatus.SUBMITTED + assert job.instance_assigned + assert job.instance is not None and job.instance.id == selected.id + assert job.fleet_id == fleet.id + + async def test_assigns_job_to_specific_hostname( + self, test_db, session: AsyncSession, worker: JobSubmittedWorker + ): + project = await create_project(session=session) + user = await create_user(session=session) + repo = await create_repo(session=session, project_id=project.id) + fleet = await create_fleet(session=session, project=project, name="my-fleet") + await create_instance( + session=session, + project=project, + fleet=fleet, + status=InstanceStatus.IDLE, + remote_connection_info=get_remote_connection_info(host="192.168.1.10"), + ) + selected = await create_instance( + session=session, + project=project, + fleet=fleet, + status=InstanceStatus.IDLE, + remote_connection_info=get_remote_connection_info(host="192.168.1.11"), + ) + run_spec = get_run_spec( + repo_id=repo.name, + profile=Profile(instances=[InstanceHostnameSelector(hostname="192.168.1.11")]), + ) + run = await create_run( + session=session, + project=project, + repo=repo, + user=user, + fleet=fleet, + run_spec=run_spec, + ) + job = await create_job(session=session, run=run) + + await _process_job(session=session, worker=worker, job_model=job) + + job = await _get_job(session, job.id) + assert job.status == JobStatus.SUBMITTED + assert job.instance_assigned + assert job.instance is not None and job.instance.id == selected.id + assert job.fleet_id == fleet.id + + async def test_assigns_service_replicas_to_specific_shared_instance_blocks( + self, test_db, session: AsyncSession, worker: JobSubmittedWorker + ): + project = await create_project(session=session) + user = await create_user(session=session) + repo = await create_repo(session=session, project_id=project.id) + fleet = await create_fleet(session=session, project=project, name="my-fleet") + selected = await create_instance( + session=session, + project=project, + fleet=fleet, + status=InstanceStatus.IDLE, + name="shared-worker", + total_blocks=2, + busy_blocks=0, + ) + run_spec = get_run_spec( + repo_id=repo.name, + configuration=ServiceConfiguration( + port=8080, + commands=["echo"], + replicas=Range[int](min=2, max=2), + resources=ResourcesSpec( + cpu=CPUSpec.parse("1"), + memory=Range[Memory](min=Memory.parse("1GB"), max=None), + gpu=None, + ), + ), + profile=Profile(instances=[InstanceNameSelector(name="shared-worker")]), + ) + run = await create_run( + session=session, + project=project, + repo=repo, + user=user, + fleet=fleet, + run_spec=run_spec, + ) + first_job = await create_job(session=session, run=run, replica_num=0) + second_job = await create_job(session=session, run=run, replica_num=1) + + await _process_job(session=session, worker=worker, job_model=first_job) + await session.refresh(selected) + assert selected.busy_blocks == 1 + + await _process_job(session=session, worker=worker, job_model=second_job) + + first_job = await _get_job(session, first_job.id) + second_job = await _get_job(session, second_job.id) + await session.refresh(selected) + assert first_job.instance is not None and first_job.instance.id == selected.id + assert second_job.instance is not None and second_job.instance.id == selected.id + assert selected.status == InstanceStatus.BUSY + assert selected.busy_blocks == 2 + + async def test_specific_instance_assignment_stays_in_run_fleet( + self, test_db, session: AsyncSession, worker: JobSubmittedWorker + ): + project = await create_project(session=session) + user = await create_user(session=session) + repo = await create_repo(session=session, project_id=project.id) + run_fleet = await create_fleet(session=session, project=project, name="run-fleet") + other_fleet = await create_fleet(session=session, project=project, name="other-fleet") + selected = await create_instance( + session=session, + project=project, + fleet=run_fleet, + status=InstanceStatus.IDLE, + name="run-fleet-0", + price=10, + ) + await create_instance( + session=session, + project=project, + fleet=other_fleet, + status=InstanceStatus.IDLE, + name="other-fleet-0", + price=1, + ) + run_spec = get_run_spec( + repo_id=repo.name, + profile=Profile( + instances=[ + InstanceNameSelector(name="run-fleet-0"), + InstanceNameSelector(name="other-fleet-0"), + ] + ), + ) + run = await create_run( + session=session, + project=project, + repo=repo, + user=user, + fleet=run_fleet, + run_spec=run_spec, + ) + job = await create_job(session=session, run=run) + + await _process_job(session=session, worker=worker, job_model=job) + + job = await _get_job(session, job.id) + assert job.status == JobStatus.SUBMITTED + assert job.instance_assigned + assert job.instance is not None and job.instance.id == selected.id + assert job.fleet_id == run_fleet.id + + async def test_assigns_job_to_specific_instance_in_imported_fleet( + self, test_db, session: AsyncSession, worker: JobSubmittedWorker + ): + exporter_user = await create_user( + session, name="exporter-user", global_role=GlobalRole.USER + ) + importer_user = await create_user( + session, name="importer-user", global_role=GlobalRole.USER + ) + exporter_project = await create_project( + session, name="exporter-project", owner=exporter_user + ) + importer_project = await create_project( + session, name="importer-project", owner=importer_user + ) + repo = await create_repo(session=session, project_id=importer_project.id) + local_fleet = await create_fleet( + session=session, + project=importer_project, + name="same-fleet", + spec=get_fleet_spec(get_ssh_fleet_configuration()), + ) + exported_fleet = await create_fleet( + session=session, + project=exporter_project, + name="same-fleet", + spec=get_fleet_spec(get_ssh_fleet_configuration()), + ) + await create_instance( + session=session, + project=importer_project, + fleet=local_fleet, + status=InstanceStatus.IDLE, + instance_num=1, + name="local-worker", + ) + selected = await create_instance( + session=session, + project=exporter_project, + fleet=exported_fleet, + status=InstanceStatus.IDLE, + instance_num=1, + name="exported-worker", + ) + await create_export( + session=session, + exporter_project=exporter_project, + importer_projects=[importer_project], + exported_fleets=[exported_fleet], + ) + selectors: list[InstanceSelector] = [ + FleetInstanceSelector( + fleet=EntityReference.parse("exporter-project/same-fleet"), + instance=1, + ) + ] + run_spec = get_run_spec( + repo_id=repo.name, + profile=Profile(instances=selectors), + ) + run = await create_run( + session=session, + project=importer_project, + repo=repo, + user=importer_user, + run_spec=run_spec, + ) + job = await create_job(session=session, run=run) + + await _process_job(session=session, worker=worker, job_model=job) + + job = await _get_job(session, job.id) + assert job.status == JobStatus.SUBMITTED + assert job.instance_assigned + assert job.instance is not None and job.instance.id == selected.id + assert job.fleet_id == exported_fleet.id + + async def test_does_not_assign_multinode_job_without_enough_specific_instances( + self, test_db, session: AsyncSession, worker: JobSubmittedWorker + ): + project = await create_project(session=session) + user = await create_user(session=session) + repo = await create_repo(session=session, project_id=project.id) + fleet_spec = get_fleet_spec() + fleet_spec.configuration.placement = InstanceGroupPlacement.CLUSTER + fleet = await create_fleet(session=session, project=project, spec=fleet_spec) + instance = await create_instance( + session=session, + project=project, + fleet=fleet, + status=InstanceStatus.IDLE, + name="shared-worker", + backend=BackendType.AWS, + total_blocks=2, + busy_blocks=0, + ) + run_spec = get_run_spec( + repo_id=repo.name, + configuration=TaskConfiguration(image="debian", nodes=2, commands=["echo"]), + profile=Profile(instances=[InstanceNameSelector(name="shared-worker")]), + ) + run = await create_run( + session=session, + project=project, + repo=repo, + user=user, + fleet=fleet, + run_spec=run_spec, + ) + master_job = await create_job( + session=session, + run=run, + job_num=0, + waiting_master_job=False, + ) + worker_job = await create_job( + session=session, + run=run, + job_num=1, + waiting_master_job=True, + ) + + await _process_job(session=session, worker=worker, job_model=master_job) + + master_job = await _get_job(session, master_job.id) + await session.refresh(worker_job) + await session.refresh(instance) + assert master_job.status == JobStatus.TERMINATING + assert ( + master_job.termination_reason + == JobTerminationReason.FAILED_TO_START_DUE_TO_NO_CAPACITY + ) + assert not master_job.instance_assigned + assert worker_job.waiting_master_job + assert instance.status == InstanceStatus.IDLE + assert instance.busy_blocks == 0 + + async def test_assigns_multinode_jobs_to_specific_shared_ssh_instances( + self, test_db, session: AsyncSession, worker: JobSubmittedWorker + ): + project = await create_project(session=session) + user = await create_user(session=session) + repo = await create_repo(session=session, project_id=project.id) + fleet = await create_fleet( + session=session, + project=project, + spec=get_fleet_spec( + get_ssh_fleet_configuration( + hosts=["10.0.0.1", "10.0.0.2"], + placement=InstanceGroupPlacement.CLUSTER, + blocks=2, + ) + ), + ) + selected_master = await create_instance( + session=session, + project=project, + fleet=fleet, + status=InstanceStatus.IDLE, + name="worker-0", + backend=BackendType.REMOTE, + region="remote", + price=1, + total_blocks=2, + busy_blocks=0, + offer=get_instance_offer_with_availability( + backend=BackendType.REMOTE, + region="remote", + cpu_count=2, + memory_gib=4, + total_blocks=2, + ), + job_provisioning_data=get_job_provisioning_data( + backend=BackendType.REMOTE, + region="remote", + cpu_count=2, + memory_gib=4, + ), + ) + selected_worker = await create_instance( + session=session, + project=project, + fleet=fleet, + status=InstanceStatus.IDLE, + name="worker-1", + backend=BackendType.REMOTE, + region="remote", + price=2, + instance_num=1, + total_blocks=2, + busy_blocks=0, + offer=get_instance_offer_with_availability( + backend=BackendType.REMOTE, + region="remote", + cpu_count=2, + memory_gib=4, + total_blocks=2, + ), + job_provisioning_data=get_job_provisioning_data( + backend=BackendType.REMOTE, + region="remote", + cpu_count=2, + memory_gib=4, + ), + ) + run_spec = get_run_spec( + repo_id=repo.name, + configuration=TaskConfiguration( + image="debian", + nodes=2, + commands=["echo"], + resources=ResourcesSpec( + cpu=CPUSpec.parse("1.."), + memory=Range[Memory](min=Memory.parse("1GB"), max=None), + gpu=None, + ), + ), + profile=Profile( + instances=[ + InstanceNameSelector(name="worker-0"), + InstanceNameSelector(name="worker-1"), + ] + ), + ) + run = await create_run( + session=session, + project=project, + repo=repo, + user=user, + fleet=fleet, + run_spec=run_spec, + ) + master_job = await create_job( + session=session, + run=run, + job_num=0, + waiting_master_job=False, + ) + worker_job = await create_job( + session=session, + run=run, + job_num=1, + waiting_master_job=True, + ) + + await _process_job(session=session, worker=worker, job_model=master_job) + master_job = await _get_job(session, master_job.id) + assert master_job.instance is not None and master_job.instance.id == selected_master.id + + await _process_job(session=session, worker=worker, job_model=master_job) + master_job = await _get_job(session, master_job.id) + await session.refresh(worker_job) + assert master_job.status == JobStatus.PROVISIONING + assert worker_job.waiting_master_job is False + + await _process_job(session=session, worker=worker, job_model=worker_job) + + worker_job = await _get_job(session, worker_job.id) + await session.refresh(selected_master) + await session.refresh(selected_worker) + assert worker_job.instance is not None and worker_job.instance.id == selected_worker.id + assert selected_master.busy_blocks == 2 + assert selected_worker.busy_blocks == 2 + master_runtime = JobRuntimeData.__response__.parse_raw(master_job.job_runtime_data) + worker_runtime = JobRuntimeData.__response__.parse_raw(worker_job.job_runtime_data) + assert master_runtime.network_mode == NetworkMode.HOST + assert worker_runtime.network_mode == NetworkMode.HOST + assert master_runtime.offer is not None and master_runtime.offer.blocks == 2 + assert worker_runtime.offer is not None and worker_runtime.offer.blocks == 2 + + async def test_assigns_multinode_jobs_to_specific_instances_in_same_cluster_fleet( + self, test_db, session: AsyncSession, worker: JobSubmittedWorker + ): + project = await create_project(session=session) + user = await create_user(session=session) + repo = await create_repo(session=session, project_id=project.id) + fleet_spec = get_fleet_spec() + fleet_spec.configuration.placement = InstanceGroupPlacement.CLUSTER + fleet = await create_fleet(session=session, project=project, spec=fleet_spec) + selected_master = await create_instance( + session=session, + project=project, + fleet=fleet, + status=InstanceStatus.IDLE, + name="worker-0", + backend=BackendType.AWS, + region="eu-west-1", + price=1, + job_provisioning_data=get_job_provisioning_data(region="eu-west-1"), + ) + selected_worker = await create_instance( + session=session, + project=project, + fleet=fleet, + status=InstanceStatus.IDLE, + name="worker-1", + backend=BackendType.AWS, + region="eu-west-1", + price=2, + job_provisioning_data=get_job_provisioning_data(region="eu-west-1"), + ) + run_spec = get_run_spec( + repo_id=repo.name, + configuration=TaskConfiguration(image="debian", nodes=2, commands=["echo"]), + profile=Profile( + instances=[ + InstanceNameSelector(name="worker-0"), + InstanceNameSelector(name="worker-1"), + ] + ), + ) + run = await create_run( + session=session, + project=project, + repo=repo, + user=user, + fleet=fleet, + run_spec=run_spec, + ) + master_job = await create_job( + session=session, + run=run, + job_num=0, + waiting_master_job=False, + ) + worker_job = await create_job( + session=session, + run=run, + job_num=1, + waiting_master_job=True, + ) + + await _process_job(session=session, worker=worker, job_model=master_job) + master_job = await _get_job(session, master_job.id) + assert master_job.instance is not None and master_job.instance.id == selected_master.id + assert master_job.status == JobStatus.SUBMITTED + + await _process_job(session=session, worker=worker, job_model=master_job) + master_job = await _get_job(session, master_job.id) + await session.refresh(worker_job) + assert master_job.status == JobStatus.PROVISIONING + assert worker_job.waiting_master_job is False + + await _process_job(session=session, worker=worker, job_model=worker_job) + + worker_job = await _get_job(session, worker_job.id) + await session.refresh(selected_master) + await session.refresh(selected_worker) + assert worker_job.status == JobStatus.SUBMITTED + assert worker_job.instance is not None and worker_job.instance.id == selected_worker.id + assert selected_master.busy_blocks == 1 + assert selected_worker.busy_blocks == 1 + + async def test_does_not_create_capacity_when_specific_instance_is_missing( + self, test_db, session: AsyncSession, worker: JobSubmittedWorker + ): + project = await create_project(session=session) + user = await create_user(session=session) + repo = await create_repo(session=session, project_id=project.id) + await create_fleet(session=session, project=project) + run_spec = get_run_spec( + repo_id=repo.name, + profile=Profile(instances=[InstanceNameSelector(name="missing-instance")]), + ) + run = await create_run( + session=session, + project=project, + repo=repo, + user=user, + run_spec=run_spec, + ) + job = await create_job(session=session, run=run) + + await _process_job(session=session, worker=worker, job_model=job) + + job = await _get_job(session, job.id) + assert job.status == JobStatus.TERMINATING + assert job.termination_reason == JobTerminationReason.FAILED_TO_START_DUE_TO_NO_CAPACITY + res = await session.execute(select(InstanceModel)) + assert res.scalars().all() == [] async def test_assigns_job_to_imported_fleet( self, test_db, session: AsyncSession, worker: JobSubmittedWorker diff --git a/src/tests/_internal/server/routers/test_fleets.py b/src/tests/_internal/server/routers/test_fleets.py index 16ce066866..04d0145dfe 100644 --- a/src/tests/_internal/server/routers/test_fleets.py +++ b/src/tests/_internal/server/routers/test_fleets.py @@ -975,6 +975,7 @@ async def test_creates_fleet(self, test_db, session: AsyncSession, client: Async "fleets": None, "tags": None, "backend_options": None, + "instances": None, }, "autocreated": False, }, @@ -1095,6 +1096,7 @@ async def test_creates_ssh_fleet(self, test_db, session: AsyncSession, client: A "fleets": None, "tags": None, "backend_options": None, + "instances": None, }, "autocreated": False, }, @@ -1314,6 +1316,7 @@ async def test_updates_ssh_fleet(self, test_db, session: AsyncSession, client: A "fleets": None, "tags": None, "backend_options": None, + "instances": None, }, "autocreated": False, }, diff --git a/src/tests/_internal/server/routers/test_runs.py b/src/tests/_internal/server/routers/test_runs.py index 93414c9f42..9be2294901 100644 --- a/src/tests/_internal/server/routers/test_runs.py +++ b/src/tests/_internal/server/routers/test_runs.py @@ -215,6 +215,7 @@ def get_dev_env_run_plan_dict( "fleets": None, "tags": None, "backend_options": None, + "instances": None, "priority": 0, }, "configuration_path": "dstack.yaml", @@ -241,6 +242,7 @@ def get_dev_env_run_plan_dict( "fleets": None, "tags": None, "backend_options": None, + "instances": None, }, "repo_code_hash": None, "repo_data": { @@ -460,6 +462,7 @@ def get_dev_env_run_dict( "fleets": None, "tags": None, "backend_options": None, + "instances": None, "priority": 0, }, "configuration_path": "dstack.yaml", @@ -486,6 +489,7 @@ def get_dev_env_run_dict( "fleets": None, "tags": None, "backend_options": None, + "instances": None, }, "repo_code_hash": None, "repo_data": { diff --git a/src/tests/_internal/server/services/runs/test_plan.py b/src/tests/_internal/server/services/runs/test_plan.py index 5836319bc1..ce586171c4 100644 --- a/src/tests/_internal/server/services/runs/test_plan.py +++ b/src/tests/_internal/server/services/runs/test_plan.py @@ -4,16 +4,35 @@ import pytest from sqlalchemy.ext.asyncio import AsyncSession -from dstack._internal.core.models.configurations import TaskConfiguration +from dstack._internal.core.models.backends.base import BackendType +from dstack._internal.core.models.common import EntityReference +from dstack._internal.core.models.configurations import ( + DevEnvironmentConfiguration, + TaskConfiguration, +) from dstack._internal.core.models.fleets import FleetNodesSpec, InstanceGroupPlacement from dstack._internal.core.models.instances import InstanceAvailability +from dstack._internal.core.models.profiles import ( + CreationPolicy, + FleetInstanceSelector, + InstanceHostnameSelector, + InstanceNameSelector, + Profile, +) +from dstack._internal.core.models.resources import CPUSpec, Memory, Range, ResourcesSpec from dstack._internal.server.services.jobs import get_jobs_from_run_spec +from dstack._internal.server.services.projects import get_project_model_by_name +from dstack._internal.server.services.runs import get_plan from dstack._internal.server.services.runs.plan import ( _freeze_offer_identity_value, _get_backend_offer_identity, _get_backend_offers_in_fleet, + _get_job_plan, + get_backend_offers_in_run_candidate_fleets, + get_targeted_instance_offers, ) from dstack._internal.server.testing.common import ( + create_export, create_fleet, create_instance, create_project, @@ -22,6 +41,7 @@ get_fleet_spec, get_instance_offer_with_availability, get_job_provisioning_data, + get_remote_connection_info, get_run_spec, ) @@ -66,6 +86,487 @@ def test_get_backend_offer_identity_uses_full_offer_payload(self) -> None: assert _get_backend_offer_identity(offer) != _get_backend_offer_identity(different_offer) +class TestGetJobPlan: + @pytest.mark.asyncio + async def test_excludes_backend_offers_when_instances_specified(self) -> None: + run_spec = get_run_spec( + repo_id="test-repo", + configuration=TaskConfiguration(image="debian", commands=["echo"]), + ) + jobs = await get_jobs_from_run_spec(run_spec=run_spec, secrets={}, replica_num=0) + instance_offer = get_instance_offer_with_availability() + backend_offer = get_instance_offer_with_availability() + + job_plan = _get_job_plan( + instance_offers=[(None, instance_offer)], # type: ignore[list-item] + backend_offers=[(None, backend_offer)], # type: ignore[list-item] + profile=Profile( + name="default", + creation_policy=CreationPolicy.REUSE_OR_CREATE, + instances=[InstanceNameSelector(name="my-fleet-0")], + ), + job=jobs[0], + max_offers=None, + ) + + assert job_plan.total_offers == 1 + assert job_plan.offers == [instance_offer] + + +class TestGetPlan: + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_empty_dev_environment_with_fleet_does_not_use_targeted_instances( + self, + test_db, + session: AsyncSession, + monkeypatch: pytest.MonkeyPatch, + ) -> None: + user = await create_user(session=session) + project = await create_project(session=session, owner=user) + repo = await create_repo(session=session, project_id=project.id) + await create_fleet(session=session, project=project) + project = await get_project_model_by_name(session=session, project_name=project.name) + assert project is not None + select_instances_mock = AsyncMock() + monkeypatch.setattr( + "dstack._internal.server.services.runs.plan.select_instances_by_selectors", + select_instances_mock, + ) + run_spec = get_run_spec( + repo_id=repo.name, + configuration=DevEnvironmentConfiguration(), + ) + + await get_plan( + session=session, + project=project, + user=user, + run_spec=run_spec, + max_offers=None, + ) + + select_instances_mock.assert_not_awaited() + + +class TestGetTargetedInstanceOffers: + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_returns_only_selected_instance(self, test_db, session: AsyncSession) -> None: + user = await create_user(session=session) + project = await create_project(session=session, owner=user) + repo = await create_repo(session=session, project_id=project.id) + fleet = await create_fleet(session=session, project=project) + await create_instance( + session=session, + project=project, + fleet=fleet, + name="worker-0", + ) + selected = await create_instance( + session=session, + project=project, + fleet=fleet, + name="worker-1", + ) + run_spec = get_run_spec( + repo_id=repo.name, + configuration=TaskConfiguration(image="debian", commands=["echo"]), + profile=Profile(instances=[InstanceNameSelector(name="worker-1")]), + ) + jobs = await get_jobs_from_run_spec(run_spec=run_spec, secrets={}, replica_num=0) + + offers = await get_targeted_instance_offers( + session=session, + project=project, + run_spec=run_spec, + job=jobs[0], + volumes=None, + exclude_not_available=True, + ) + + assert [instance for instance, _ in offers] == [selected] + + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_returns_selected_instance_by_hostname( + self, test_db, session: AsyncSession + ) -> None: + user = await create_user(session=session) + project = await create_project(session=session, owner=user) + repo = await create_repo(session=session, project_id=project.id) + fleet = await create_fleet(session=session, project=project) + await create_instance( + session=session, + project=project, + fleet=fleet, + name="worker-0", + remote_connection_info=get_remote_connection_info(host="192.168.1.10"), + ) + selected = await create_instance( + session=session, + project=project, + fleet=fleet, + name="worker-1", + remote_connection_info=get_remote_connection_info(host="192.168.1.11"), + ) + run_spec = get_run_spec( + repo_id=repo.name, + configuration=TaskConfiguration(image="debian", commands=["echo"]), + profile=Profile(instances=[InstanceHostnameSelector(hostname="192.168.1.11")]), + ) + jobs = await get_jobs_from_run_spec(run_spec=run_spec, secrets={}, replica_num=0) + + offers = await get_targeted_instance_offers( + session=session, + project=project, + run_spec=run_spec, + job=jobs[0], + volumes=None, + exclude_not_available=True, + ) + + assert [instance for instance, _ in offers] == [selected] + + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_returns_selected_instance_from_imported_fleet_reference( + self, test_db, session: AsyncSession + ) -> None: + user = await create_user(session=session) + project = await create_project(session=session, owner=user, name="importer-project") + exporter_project = await create_project( + session=session, owner=user, name="exporter-project" + ) + repo = await create_repo(session=session, project_id=project.id) + local_fleet = await create_fleet(session=session, project=project, name="same-fleet") + exported_fleet = await create_fleet( + session=session, project=exporter_project, name="same-fleet" + ) + await create_instance( + session=session, + project=project, + fleet=local_fleet, + instance_num=1, + name="local-worker", + ) + selected = await create_instance( + session=session, + project=exporter_project, + fleet=exported_fleet, + instance_num=1, + name="exported-worker", + ) + await create_export( + session=session, + exporter_project=exporter_project, + importer_projects=[project], + exported_fleets=[exported_fleet], + ) + run_spec = get_run_spec( + repo_id=repo.name, + configuration=TaskConfiguration(image="debian", commands=["echo"]), + profile=Profile( + instances=[ + FleetInstanceSelector( + fleet=EntityReference.parse("exporter-project/same-fleet"), + instance=1, + ) + ] + ), + ) + jobs = await get_jobs_from_run_spec(run_spec=run_spec, secrets={}, replica_num=0) + + offers = await get_targeted_instance_offers( + session=session, + project=project, + run_spec=run_spec, + job=jobs[0], + volumes=None, + exclude_not_available=True, + ) + + assert [instance for instance, _ in offers] == [selected] + + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_returns_shared_block_offer_for_selected_instance( + self, test_db, session: AsyncSession + ) -> None: + user = await create_user(session=session) + project = await create_project(session=session, owner=user) + repo = await create_repo(session=session, project_id=project.id) + fleet = await create_fleet(session=session, project=project) + instance = await create_instance( + session=session, + project=project, + fleet=fleet, + name="shared-worker", + total_blocks=2, + busy_blocks=1, + ) + run_spec = get_run_spec( + repo_id=repo.name, + configuration=TaskConfiguration( + image="debian", + commands=["echo"], + resources=ResourcesSpec( + cpu=CPUSpec.parse("1"), + memory=Range[Memory](min=Memory.parse("1GB"), max=None), + gpu=None, + ), + ), + profile=Profile(instances=[InstanceNameSelector(name="shared-worker")]), + ) + jobs = await get_jobs_from_run_spec(run_spec=run_spec, secrets={}, replica_num=0) + + offers = await get_targeted_instance_offers( + session=session, + project=project, + run_spec=run_spec, + job=jobs[0], + volumes=None, + exclude_not_available=True, + ) + + assert [selected for selected, _ in offers] == [instance] + assert offers[0][1].blocks == 1 + assert offers[0][1].total_blocks == 2 + + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_multinode_does_not_count_blocks_as_nodes( + self, test_db, session: AsyncSession + ) -> None: + user = await create_user(session=session) + project = await create_project(session=session, owner=user) + repo = await create_repo(session=session, project_id=project.id) + fleet_spec = get_fleet_spec() + fleet_spec.configuration.placement = InstanceGroupPlacement.CLUSTER + fleet = await create_fleet(session=session, project=project, spec=fleet_spec) + await create_instance( + session=session, + project=project, + fleet=fleet, + name="shared-worker", + backend=BackendType.AWS, + total_blocks=2, + busy_blocks=0, + ) + run_spec = get_run_spec( + repo_id=repo.name, + configuration=TaskConfiguration(image="debian", nodes=2, commands=["echo"]), + profile=Profile(instances=[InstanceNameSelector(name="shared-worker")]), + ) + jobs = await get_jobs_from_run_spec(run_spec=run_spec, secrets={}, replica_num=0) + + offers = await get_targeted_instance_offers( + session=session, + project=project, + run_spec=run_spec, + job=jobs[0], + volumes=None, + exclude_not_available=True, + ) + + assert offers == [] + + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_multinode_returns_full_host_offer_per_selected_shared_instance( + self, test_db, session: AsyncSession + ) -> None: + user = await create_user(session=session) + project = await create_project(session=session, owner=user) + repo = await create_repo(session=session, project_id=project.id) + fleet_spec = get_fleet_spec() + fleet_spec.configuration.placement = InstanceGroupPlacement.CLUSTER + fleet = await create_fleet(session=session, project=project, spec=fleet_spec) + selected_1 = await create_instance( + session=session, + project=project, + fleet=fleet, + name="worker-0", + backend=BackendType.REMOTE, + total_blocks=2, + busy_blocks=0, + ) + selected_2 = await create_instance( + session=session, + project=project, + fleet=fleet, + name="worker-1", + backend=BackendType.REMOTE, + total_blocks=2, + busy_blocks=0, + ) + run_spec = get_run_spec( + repo_id=repo.name, + configuration=TaskConfiguration( + image="debian", + nodes=2, + commands=["echo"], + resources=ResourcesSpec( + cpu=CPUSpec.parse("1.."), + memory=Range[Memory](min=Memory.parse("1GB"), max=None), + gpu=None, + ), + ), + profile=Profile( + instances=[ + InstanceNameSelector(name="worker-0"), + InstanceNameSelector(name="worker-1"), + ] + ), + ) + jobs = await get_jobs_from_run_spec(run_spec=run_spec, secrets={}, replica_num=0) + + offers = await get_targeted_instance_offers( + session=session, + project=project, + run_spec=run_spec, + job=jobs[0], + volumes=None, + exclude_not_available=True, + ) + + assert [instance for instance, _ in offers] == [selected_1, selected_2] + assert [offer.blocks for _, offer in offers] == [2, 2] + assert [offer.total_blocks for _, offer in offers] == [2, 2] + + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_multinode_returns_selected_instances_in_same_cluster_fleet( + self, test_db, session: AsyncSession + ) -> None: + user = await create_user(session=session) + project = await create_project(session=session, owner=user) + repo = await create_repo(session=session, project_id=project.id) + fleet_spec = get_fleet_spec() + fleet_spec.configuration.placement = InstanceGroupPlacement.CLUSTER + fleet = await create_fleet(session=session, project=project, spec=fleet_spec) + selected_1 = await create_instance( + session=session, + project=project, + fleet=fleet, + name="worker-0", + backend=BackendType.AWS, + job_provisioning_data=get_job_provisioning_data(region="eu-west-1"), + ) + selected_2 = await create_instance( + session=session, + project=project, + fleet=fleet, + name="worker-1", + backend=BackendType.AWS, + job_provisioning_data=get_job_provisioning_data(region="eu-west-1"), + ) + run_spec = get_run_spec( + repo_id=repo.name, + configuration=TaskConfiguration(image="debian", nodes=2, commands=["echo"]), + profile=Profile( + instances=[ + InstanceNameSelector(name="worker-0"), + InstanceNameSelector(name="worker-1"), + ] + ), + ) + jobs = await get_jobs_from_run_spec(run_spec=run_spec, secrets={}, replica_num=0) + + offers = await get_targeted_instance_offers( + session=session, + project=project, + run_spec=run_spec, + job=jobs[0], + volumes=None, + exclude_not_available=True, + ) + + assert [instance for instance, _ in offers] == [selected_1, selected_2] + + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_multinode_requires_selected_instances_in_one_cluster_fleet( + self, test_db, session: AsyncSession + ) -> None: + user = await create_user(session=session) + project = await create_project(session=session, owner=user) + repo = await create_repo(session=session, project_id=project.id) + fleet_spec = get_fleet_spec() + fleet_spec.configuration.placement = InstanceGroupPlacement.CLUSTER + fleet_1 = await create_fleet(session=session, project=project, spec=fleet_spec) + fleet_2 = await create_fleet(session=session, project=project, spec=fleet_spec) + await create_instance( + session=session, + project=project, + fleet=fleet_1, + name="worker-0", + backend=BackendType.AWS, + ) + await create_instance( + session=session, + project=project, + fleet=fleet_2, + name="worker-1", + backend=BackendType.AWS, + ) + run_spec = get_run_spec( + repo_id=repo.name, + configuration=TaskConfiguration(image="debian", nodes=2, commands=["echo"]), + profile=Profile( + instances=[ + InstanceNameSelector(name="worker-0"), + InstanceNameSelector(name="worker-1"), + ] + ), + ) + jobs = await get_jobs_from_run_spec(run_spec=run_spec, secrets={}, replica_num=0) + + offers = await get_targeted_instance_offers( + session=session, + project=project, + run_spec=run_spec, + job=jobs[0], + volumes=None, + exclude_not_available=True, + ) + + assert offers == [] + + +class TestGetBackendOffersInRunCandidateFleets: + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_skips_backend_offers_when_instances_specified( + self, test_db, session: AsyncSession, monkeypatch: pytest.MonkeyPatch + ) -> None: + user = await create_user(session=session) + project = await create_project(session=session, owner=user) + repo = await create_repo(session=session, project_id=project.id) + run_spec = get_run_spec( + repo_id=repo.name, + configuration=TaskConfiguration(image="debian", commands=["echo"]), + profile=Profile(instances=[InstanceNameSelector(name="missing-instance")]), + ) + jobs = await get_jobs_from_run_spec(run_spec=run_spec, secrets={}, replica_num=0) + select_candidate_fleet_models_mock = AsyncMock() + monkeypatch.setattr( + "dstack._internal.server.services.runs.plan._select_candidate_fleet_models", + select_candidate_fleet_models_mock, + ) + + offers = await get_backend_offers_in_run_candidate_fleets( + session=session, + project=project, + run_spec=run_spec, + job=jobs[0], + volumes=None, + ) + + assert offers == [] + select_candidate_fleet_models_mock.assert_not_awaited() + + class TestGetBackendOffersInFleet: @pytest.mark.asyncio @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) diff --git a/src/tests/_internal/server/services/test_instances.py b/src/tests/_internal/server/services/test_instances.py index aaa3e48d88..cba11c67ec 100644 --- a/src/tests/_internal/server/services/test_instances.py +++ b/src/tests/_internal/server/services/test_instances.py @@ -14,19 +14,28 @@ InstanceType, Resources, ) -from dstack._internal.core.models.profiles import Profile +from dstack._internal.core.models.profiles import ( + FleetInstanceSelector, + InstanceHostnameSelector, + InstanceNameSelector, + Profile, +) from dstack._internal.core.models.runs import JobStatus from dstack._internal.server.models import InstanceModel from dstack._internal.server.schemas.runner import TaskListItem, TaskListResponse, TaskStatus from dstack._internal.server.services.runner.client import ShimClient from dstack._internal.server.testing.common import ( + create_export, + create_fleet, create_instance, create_job, create_project, create_repo, create_run, create_user, + get_job_provisioning_data, get_kubernetes_volume_configuration, + get_remote_connection_info, get_volume, get_volume_configuration, get_volume_provisioning_data, @@ -234,6 +243,126 @@ async def test_returns_volume_instances_without_region(self, test_db, session: A assert res == [kubernetes_instance] +class TestSelectInstancesBySelectors: + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_selects_by_instance_name(self, test_db, session: AsyncSession): + project = await create_project(session=session) + fleet = await create_fleet(session=session, project=project) + await create_instance(session=session, project=project, fleet=fleet, name="worker-0") + instance = await create_instance( + session=session, + project=project, + fleet=fleet, + name="worker-1", + ) + + res = await instances_services.select_instances_by_selectors( + session=session, + project=project, + selectors=[InstanceNameSelector(name="worker-1")], + ) + + assert res == [instance] + + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_selects_by_cloud_hostname_and_internal_ip(self, test_db, session: AsyncSession): + project = await create_project(session=session) + fleet = await create_fleet(session=session, project=project) + instance = await create_instance( + session=session, + project=project, + fleet=fleet, + job_provisioning_data=get_job_provisioning_data( + hostname="203.0.113.8", + internal_ip="10.0.0.8", + ), + ) + + res = await instances_services.select_instances_by_selectors( + session=session, + project=project, + selectors=[InstanceHostnameSelector(hostname="10.0.0.8")], + ) + + assert res == [instance] + + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_selects_by_ssh_host(self, test_db, session: AsyncSession): + project = await create_project(session=session) + fleet = await create_fleet(session=session, project=project) + instance = await create_instance( + session=session, + project=project, + fleet=fleet, + remote_connection_info=get_remote_connection_info(host="192.168.1.11"), + ) + + res = await instances_services.select_instances_by_selectors( + session=session, + project=project, + selectors=[InstanceHostnameSelector(hostname="192.168.1.11")], + ) + + assert res == [instance] + + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + @pytest.mark.parametrize( + ("selector", "expected_name"), + [ + ("same-fleet", "local-worker"), + ("exporter-project/same-fleet", "exported-worker"), + ], + ) + async def test_fleet_instance_selector_respects_project_reference( + self, + test_db, + session: AsyncSession, + selector: str, + expected_name: str, + ): + user = await create_user(session=session) + project = await create_project(session=session, owner=user, name="importer-project") + exporter_project = await create_project( + session=session, owner=user, name="exporter-project" + ) + local_fleet = await create_fleet(session=session, project=project, name="same-fleet") + exported_fleet = await create_fleet( + session=session, project=exporter_project, name="same-fleet" + ) + await create_instance( + session=session, + project=project, + fleet=local_fleet, + instance_num=1, + name="local-worker", + ) + await create_instance( + session=session, + project=exporter_project, + fleet=exported_fleet, + instance_num=1, + name="exported-worker", + ) + await create_export( + session=session, + exporter_project=exporter_project, + importer_projects=[project], + exported_fleets=[exported_fleet], + ) + + res = await instances_services.select_instances_by_selectors( + session=session, + project=project, + selectors=[FleetInstanceSelector(fleet=selector, instance=1)], + ) + + assert [instance.name for instance in res] == [expected_name] + + @pytest.mark.asyncio @pytest.mark.usefixtures("image_config_mock") @pytest.mark.usefixtures("turn_off_keep_shim_tasks_setting")