diff --git a/frontend/src/locale/en.json b/frontend/src/locale/en.json index 7d8c9f2465..804134f3d1 100644 --- a/frontend/src/locale/en.json +++ b/frontend/src/locale/en.json @@ -144,6 +144,7 @@ "region_description": "Select a region", "default": "Default", "default_checkbox": "Turn on default", + "hostname": "Hostname", "external_ip": "External IP", "wildcard_domain": "Wildcard domain", "wildcard_domain_description": "Specify the wildcard domain mapped to the external IP.", diff --git a/frontend/src/pages/Project/Gateways/Table/hooks/useColumnsDefinitions.tsx b/frontend/src/pages/Project/Gateways/Table/hooks/useColumnsDefinitions.tsx index 79b4aac22c..f63a77fa14 100644 --- a/frontend/src/pages/Project/Gateways/Table/hooks/useColumnsDefinitions.tsx +++ b/frontend/src/pages/Project/Gateways/Table/hooks/useColumnsDefinitions.tsx @@ -30,13 +30,15 @@ export const useColumnsDefinitions = ({ loading, projectName, onDeleteClick, onE { id: 'type', header: t('gateway.edit.backend'), - cell: (gateway: IGateway) => gateway.backend, + cell: (gateway: IGateway) => + gateway.replicas.length > 0 ? gateway.replicas.map((r, i) =>
{r.backend}
) : null, }, { id: 'region', header: t('gateway.edit.region'), - cell: (gateway: IGateway) => gateway.region, + cell: (gateway: IGateway) => + gateway.replicas.length > 0 ? gateway.replicas.map((r, i) =>
{r.region}
) : null, }, { @@ -46,9 +48,13 @@ export const useColumnsDefinitions = ({ loading, projectName, onDeleteClick, onE }, { - id: 'external_ip', - header: t('gateway.edit.external_ip'), - cell: (gateway: IGateway) => gateway.ip_address, + id: 'hostname', + header: t('gateway.edit.hostname'), + cell: (gateway: IGateway) => { + if (gateway.hostname) return gateway.hostname; + if (gateway.replicas.length > 0) return gateway.replicas.map((r, i) =>
{r.hostname}
); + return null; + }, }, { diff --git a/frontend/src/types/gateway.d.ts b/frontend/src/types/gateway.d.ts index 4ef2eeeb54..1442cf4d62 100644 --- a/frontend/src/types/gateway.d.ts +++ b/frontend/src/types/gateway.d.ts @@ -1,3 +1,9 @@ +declare interface IGatewayReplica { + hostname: string, + backend: string, + region: string, +} + declare interface IGateway { backend: string, name: string, @@ -5,8 +11,10 @@ declare interface IGateway { ip_address: string, instance_id: string, region:string + hostname?: string, wildcard_domain?: string default: boolean + replicas: IGatewayReplica[], created_at?: number, } diff --git a/mkdocs/docs/concepts/gateways.md b/mkdocs/docs/concepts/gateways.md index bd71187964..b71a23d7b6 100644 --- a/mkdocs/docs/concepts/gateways.md +++ b/mkdocs/docs/concepts/gateways.md @@ -182,6 +182,50 @@ domain: example.com +### Replicas + +A gateway can have multiple replicas for improved availability. + +
+ +```yaml +type: gateway +name: example-gateway + +backend: aws +region: eu-west-1 + +domain: example.com + +certificate: null +replicas: 2 +``` + +
+ +To balance requests between gateway replicas, add a DNS record for each replica's hostname or set up a load balancer outside of `dstack`. Replica hostnames are displayed in the `dstack` CLI and UI. + +
+ +```shell +$ dstack gateway list + NAME BACKEND HOSTNAME DOMAIN DEFAULT STATUS + example-gateway example.com ✓ running + replica=0 aws (eu-west-1) 34.244.128.46 + replica=1 aws (eu-west-1) 18.201.201.174 +``` + +
+ +!!! warning "Experimental" + Replicated gateways are an experimental feature and currently have limitations: + + - Changing the number of replicas or redeploying replicas is not supported. + - HTTPS is not supported. Use an external load balancer for TLS termination. + - An unavailable gateway replica prevents any new services or service replicas from being added. + - All replicas are bound to the same backend and region. + - At most 3 replicas are allowed per gateway. + !!! info "Reference" For all gateway configuration options, refer to the [reference](../reference/dstack.yml/gateway.md). diff --git a/src/dstack/_internal/cli/services/configurators/gateway.py b/src/dstack/_internal/cli/services/configurators/gateway.py index 0b6993e18b..9f8e6cd0d1 100644 --- a/src/dstack/_internal/cli/services/configurators/gateway.py +++ b/src/dstack/_internal/cli/services/configurators/gateway.py @@ -236,6 +236,10 @@ def th(s: str) -> str: configuration_table.add_row(th("Region"), plan.spec.configuration.region) configuration_table.add_row(th("Domain"), domain) + if plan.spec.configuration.replicas is not None: + assert isinstance(plan.spec.configuration.replicas, int) + configuration_table.add_row(th("Replicas"), str(plan.spec.configuration.replicas)) + console.print(configuration_table) console.print() diff --git a/src/dstack/_internal/cli/utils/gateway.py b/src/dstack/_internal/cli/utils/gateway.py index 5605326211..0d873a9a5d 100644 --- a/src/dstack/_internal/cli/utils/gateway.py +++ b/src/dstack/_internal/cli/utils/gateway.py @@ -95,15 +95,39 @@ def get_gateways_table( # Ignore errors in case future server versions introduce more interpolation variables exception_type=None, ) - row = { + + gateway_row = { "NAME": name, - "BACKEND": format_backend(gateway.configuration.backend, gateway.configuration.region), - "HOSTNAME": gateway.hostname, "DOMAIN": domain, "DEFAULT": "✓" if gateway.default else "", "STATUS": gateway.status, "CREATED": format_date(gateway.created_at), 
"ERROR": gateway.status_message, } - add_row_from_dict(table, row) + if gateway.hostname is not None: + gateway_row["HOSTNAME"] = gateway.hostname + if len(gateway.replicas) == 0: + # replicas not yet created, or it's a pre-0.20.25 server without replica support + gateway_row["BACKEND"] = format_backend( + gateway.configuration.backend, gateway.configuration.region + ) + gateway_row["HOSTNAME"] = gateway_row.get("HOSTNAME", gateway.ip_address) + if len(gateway.replicas) == 1: + # compact display for single-replica gateway + gateway_row["BACKEND"] = format_backend( + gateway.replicas[0].backend, gateway.replicas[0].region + ) + gateway_row["HOSTNAME"] = gateway_row.get("HOSTNAME", gateway.replicas[0].hostname) + add_row_from_dict(table, gateway_row) + + if len(gateway.replicas) > 1: + for replica in gateway.replicas: + replica_row = { + "NAME": f" replica={replica.replica_num}", + "BACKEND": format_backend(replica.backend, replica.region), + "HOSTNAME": replica.hostname, + "CREATED": format_date(replica.created_at), + } + add_row_from_dict(table, replica_row, style="secondary") + return table diff --git a/src/dstack/_internal/core/compatibility/gateways.py b/src/dstack/_internal/core/compatibility/gateways.py index a2fc6101e6..0a89e86113 100644 --- a/src/dstack/_internal/core/compatibility/gateways.py +++ b/src/dstack/_internal/core/compatibility/gateways.py @@ -41,5 +41,7 @@ def _get_gateway_configuration_excludes( if configuration.router is None: configuration_excludes["router"] = True + if configuration.replicas is None: + configuration_excludes["replicas"] = True return configuration_excludes diff --git a/src/dstack/_internal/core/models/gateways.py b/src/dstack/_internal/core/models/gateways.py index 6f92b449b7..74b3f4e835 100644 --- a/src/dstack/_internal/core/models/gateways.py +++ b/src/dstack/_internal/core/models/gateways.py @@ -11,6 +11,8 @@ from dstack._internal.core.models.routers import AnyGatewayRouterConfig from dstack._internal.utils.tags import 
tags_validator +GATEWAY_REPLICAS_DEFAULT = 1 + class GatewayStatus(str, Enum): SUBMITTED = "submitted" @@ -90,6 +92,13 @@ class GatewayConfiguration(CoreModel): " Set to `null` to disable. Defaults to `type: lets-encrypt`" ), ] = LetsEncryptGatewayCertificate() + replicas: Annotated[ + Optional[int], + Field( + description=f"The number of gateway replicas. Defaults to `{GATEWAY_REPLICAS_DEFAULT}`", + ge=1, + ), + ] = None tags: Annotated[ Optional[Dict[str, str]], Field( @@ -109,6 +118,14 @@ class GatewaySpec(CoreModel): configuration_path: Optional[str] = None +class GatewayReplica(CoreModel): + hostname: str + replica_num: int + backend: BackendType + region: str + created_at: datetime.datetime + + class Gateway(CoreModel): # TODO(0.21): Make `id` required. id: Optional[uuid.UUID] = None @@ -121,14 +138,13 @@ class Gateway(CoreModel): status: GatewayStatus status_message: Optional[str] hostname: Optional[str] - """`hostname` is the IP address or hostname the user should set up the domain for. - Could be the same as `ip_address` but also different, for example a gateway behind ALB. + """Hostname of the load balancer. + Unset if there is no load balancer, in which case users are expected to point the gateway's + wildcard domain name to `replicas[i].hostname`. """ - ip_address: Optional[str] - """`ip_address` is the IP address of the gateway instance.""" - instance_id: Optional[str] wildcard_domain: Optional[str] default: bool + replicas: list[GatewayReplica] = [] backend: Optional[BackendType] = None """`backend` duplicates a configuration field on the top level for backward compatibility with 0.19.x clients that expect it to be required. @@ -139,6 +155,10 @@ class Gateway(CoreModel): with 0.19.x clients that expect it to be required. Remove after 0.21. 
""" + ip_address: Optional[str] = None + """Deprecated in favor of `replicas[i].hostname`, only set for pre-0.20.25 clients.""" + instance_id: Optional[str] = None + """Deprecated, unused, kept for pre-0.20.25 clients.""" class GatewayPlan(CoreModel): diff --git a/src/dstack/_internal/server/background/pipeline_tasks/gateways.py b/src/dstack/_internal/server/background/pipeline_tasks/gateways.py index 1f8f0c64f5..05393eb1ae 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/gateways.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/gateways.py @@ -5,12 +5,12 @@ from typing import Optional, Sequence, TypedDict from sqlalchemy import delete, or_, select, update -from sqlalchemy.orm import joinedload, load_only +from sqlalchemy.orm import joinedload, load_only, selectinload from dstack._internal.core.backends.base.compute import ComputeWithGatewaySupport from dstack._internal.core.errors import BackendError, BackendNotAvailable from dstack._internal.core.models.backends.base import BackendType -from dstack._internal.core.models.gateways import GatewayStatus +from dstack._internal.core.models.gateways import GATEWAY_REPLICAS_DEFAULT, GatewayStatus from dstack._internal.server.background.pipeline_tasks.base import ( Fetcher, Heartbeater, @@ -34,7 +34,10 @@ from dstack._internal.server.services import backends as backends_services from dstack._internal.server.services import events from dstack._internal.server.services import gateways as gateways_services -from dstack._internal.server.services.gateways import emit_gateway_status_change_event +from dstack._internal.server.services.gateways import ( + emit_gateway_status_change_event, + get_gateway_compute_models, +) from dstack._internal.server.services.gateways.pool import gateway_connections_pool from dstack._internal.server.services.locking import get_locker from dstack._internal.server.services.logging import fmt @@ -239,11 +242,8 @@ async def _process_submitted_item(item: 
GatewayPipelineItem): set_processed_update_map_fields(update_map) set_unlock_update_map_fields(update_map) async with get_session_ctx() as session: - gateway_compute_model = result.gateway_compute_model - if gateway_compute_model is not None: + for gateway_compute_model in result.gateway_compute_models: session.add(gateway_compute_model) - await session.flush() - update_map["gateway_compute_id"] = gateway_compute_model.id now = get_current_datetime() resolve_now_placeholders(update_map, now=now) res = await session.execute( @@ -258,7 +258,7 @@ async def _process_submitted_item(item: GatewayPipelineItem): updated_ids = list(res.scalars().all()) if len(updated_ids) == 0: log_lock_token_changed_after_processing(logger, item) - # TODO: Clean up gateway_compute_model. + # TODO: Clean up gateway_compute_models. return emit_gateway_status_change_event( session=session, @@ -272,7 +272,6 @@ async def _process_submitted_item(item: GatewayPipelineItem): class _GatewayUpdateMap(ItemUpdateMap, total=False): status: GatewayStatus status_message: str - gateway_compute_id: uuid.UUID class _GatewayComputeUpdateMap(TypedDict, total=False): @@ -283,7 +282,7 @@ class _GatewayComputeUpdateMap(TypedDict, total=False): @dataclass class _SubmittedResult: update_map: _GatewayUpdateMap = field(default_factory=_GatewayUpdateMap) - gateway_compute_model: Optional[GatewayComputeModel] = None + gateway_compute_models: list[GatewayComputeModel] = field(default_factory=list) async def _process_submitted_gateway(gateway_model: GatewayModel) -> _SubmittedResult: @@ -303,16 +302,28 @@ async def _process_submitted_gateway(gateway_model: GatewayModel) -> _SubmittedR "status_message": "Backend not available", } ) + replicas = ( + configuration.replicas if configuration.replicas is not None else GATEWAY_REPLICAS_DEFAULT + ) + gateway_compute_models = [] try: - gateway_compute_model = await gateways_services.create_gateway_compute( - backend_compute=backend.compute(), - 
project_name=gateway_model.project.name, - configuration=configuration, - backend_id=backend_model.id, - ) + for replica_num in range(replicas): + logger.debug( + "%s replica %d: creating gateway compute", fmt(gateway_model), replica_num + ) + gateway_compute_model = await gateways_services.create_gateway_compute( + backend_compute=backend.compute(), + project_name=gateway_model.project.name, + configuration=configuration, + replica_num=replica_num, + gateway_id=gateway_model.id, + backend_id=backend_model.id, + ) + logger.info("%s replica %d: gateway compute created", fmt(gateway_model), replica_num) + gateway_compute_models.append(gateway_compute_model) return _SubmittedResult( update_map={"status": GatewayStatus.PROVISIONING}, - gateway_compute_model=gateway_compute_model, + gateway_compute_models=gateway_compute_models, ) except BackendError as e: status_message = f"Backend error: {repr(e)}" @@ -322,7 +333,8 @@ async def _process_submitted_gateway(gateway_model: GatewayModel) -> _SubmittedR update_map={ "status": GatewayStatus.FAILED, "status_message": status_message, - } + }, + gateway_compute_models=gateway_compute_models, ) except Exception as e: logger.exception("%s: got exception when creating gateway compute", fmt(gateway_model)) @@ -330,7 +342,8 @@ async def _process_submitted_gateway(gateway_model: GatewayModel) -> _SubmittedR update_map={ "status": GatewayStatus.FAILED, "status_message": f"Unexpected error: {repr(e)}", - } + }, + gateway_compute_models=gateway_compute_models, ) @@ -343,6 +356,7 @@ async def _process_provisioning_item(item: GatewayPipelineItem): GatewayModel.lock_token == item.lock_token, ) .options(joinedload(GatewayModel.gateway_compute)) + .options(selectinload(GatewayModel.gateway_computes)) ) gateway_model = res.unique().scalar_one_or_none() if gateway_model is None: @@ -377,34 +391,39 @@ async def _process_provisioning_item(item: GatewayPipelineItem): new_status=gateway_update_map.get("status", gateway_model.status), 
status_message=gateway_update_map.get("status_message", gateway_model.status_message), ) - if result.gateway_compute_update_map: + if result.all_computes_update_map: res = await session.execute( update(GatewayComputeModel) - .where(GatewayComputeModel.id == gateway_model.gateway_compute_id) - .values(**result.gateway_compute_update_map) + .where( + or_( + GatewayComputeModel.gateway_id == gateway_model.id, + GatewayComputeModel.id == gateway_model.gateway_compute_id, + ) + ) + .values(**result.all_computes_update_map) .returning(GatewayComputeModel.id) ) updated_ids = list(res.scalars().all()) - if len(updated_ids) == 0: + if len(updated_ids) < len(get_gateway_compute_models(gateway_model)): logger.error( - "Failed to update compute model %s for gateway %s." + "Failed to update compute models for gateway %s." " This is unexpected and may happen only if the compute model was manually deleted.", gateway_model.id, - item.id, ) @dataclass class _ProvisioningResult: gateway_update_map: _GatewayUpdateMap = field(default_factory=_GatewayUpdateMap) - gateway_compute_update_map: _GatewayComputeUpdateMap = field( + all_computes_update_map: _GatewayComputeUpdateMap = field( default_factory=_GatewayComputeUpdateMap ) async def _process_provisioning_gateway(gateway_model: GatewayModel) -> _ProvisioningResult: + gateway_computes = get_gateway_compute_models(gateway_model) # Provisioning gateways must have compute. - assert gateway_model.gateway_compute is not None + assert len(gateway_computes) > 0 # FIXME: problems caused by blocking on connect_to_gateway_with_retry and configure_gateway: # - cannot delete the gateway before it is provisioned because the DB model is locked @@ -413,32 +432,58 @@ async def _process_provisioning_gateway(gateway_model: GatewayModel) -> _Provisi # Easy to fix by doing only one connection/configuration attempt per processing iteration. 
The # main challenge is applying the same provisioning model to the dstack Sky gateway to avoid # maintaining a different model for Sky. - connection = await gateways_services.connect_to_gateway_with_retry( - gateway_model.gateway_compute + + errors = await asyncio.gather( + *(_connect_and_configure_gateway_replica(gateway_model, gc) for gc in gateway_computes) ) - if connection is None: + if any(errors): return _ProvisioningResult( gateway_update_map={ "status": GatewayStatus.FAILED, - "status_message": "Failed to connect to gateway", + "status_message": next(e for e in errors if e), }, - gateway_compute_update_map={"active": False}, + all_computes_update_map={"active": False}, + ) + + return _ProvisioningResult( + gateway_update_map={"status": GatewayStatus.RUNNING}, + ) + + +async def _connect_and_configure_gateway_replica( + gateway_model: GatewayModel, + gateway_compute: GatewayComputeModel, +) -> Optional[str]: + """Returns an error message on failure, None on success.""" + logger.debug( + "%s replica %d: connecting to gateway compute", + fmt(gateway_model), + gateway_compute.replica_num, + ) + connection = await gateways_services.connect_to_gateway_with_retry(gateway_compute) + if connection is None: + logger.warning( + "%s replica %d: failed to connect to gateway compute", + fmt(gateway_model), + gateway_compute.replica_num, ) + return "Failed to connect to gateway" try: await gateways_services.configure_gateway(connection) except Exception: - logger.exception("%s: failed to configure gateway", fmt(gateway_model)) - await gateway_connections_pool.remove(gateway_model.gateway_compute.ip_address) - return _ProvisioningResult( - gateway_update_map={ - "status": GatewayStatus.FAILED, - "status_message": "Failed to configure gateway", - }, - gateway_compute_update_map={"active": False}, + logger.exception( + "%s replica %d: failed to configure gateway", + fmt(gateway_model), + gateway_compute.replica_num, ) - return _ProvisioningResult( - 
gateway_update_map={"status": GatewayStatus.RUNNING}, + await gateway_connections_pool.remove(gateway_compute.ip_address) + return "Failed to configure gateway" + logger.info( + "%s replica %d: gateway compute connected and configured", + fmt(gateway_model), + gateway_compute.replica_num, ) + return None async def _process_to_be_deleted_item(item: GatewayPipelineItem): @@ -451,6 +496,7 @@ async def _process_to_be_deleted_item(item: GatewayPipelineItem): ) .options(joinedload(GatewayModel.project).joinedload(ProjectModel.backends)) .options(joinedload(GatewayModel.gateway_compute)) + .options(selectinload(GatewayModel.gateway_computes)) .options(joinedload(GatewayModel.backend).load_only(BackendModel.type)) ) gateway_model = res.unique().scalar_one_or_none() @@ -460,6 +506,27 @@ async def _process_to_be_deleted_item(item: GatewayPipelineItem): result = await _process_to_be_deleted_gateway(gateway_model) async with get_session_ctx() as session: + if result.all_computes_update_map: + res = await session.execute( + update(GatewayComputeModel) + .where( + or_( + GatewayComputeModel.gateway_id == gateway_model.id, + GatewayComputeModel.id == gateway_model.gateway_compute_id, + ) + ) + .values(**result.all_computes_update_map) + .returning(GatewayComputeModel.id) + ) + updated_ids = list(res.scalars().all()) + if len(updated_ids) < len(get_gateway_compute_models(gateway_model)): + logger.error( + "Failed to update compute models for gateway %s." 
+ " This is unexpected and may happen only if the compute model was manually deleted.", + gateway_model.id, + ) + return + if result.delete_gateway: res = await session.execute( delete(GatewayModel) @@ -503,28 +570,11 @@ async def _process_to_be_deleted_item(item: GatewayPipelineItem): log_lock_token_changed_after_processing(logger, item) return - if result.gateway_compute_update_map: - res = await session.execute( - update(GatewayComputeModel) - .where(GatewayComputeModel.id == gateway_model.gateway_compute_id) - .values(**result.gateway_compute_update_map) - .returning(GatewayComputeModel.id) - ) - updated_ids = list(res.scalars().all()) - if len(updated_ids) == 0: - logger.error( - "Failed to update compute model %s for gateway %s." - " This is unexpected and may happen only if the compute model was manually deleted.", - gateway_model.id, - item.id, - ) - return - @dataclass class _ProcessToBeDeletedResult: delete_gateway: bool - gateway_compute_update_map: _GatewayComputeUpdateMap = field( + all_computes_update_map: _GatewayComputeUpdateMap = field( default_factory=_GatewayComputeUpdateMap ) @@ -536,27 +586,39 @@ async def _process_to_be_deleted_gateway(gateway_model: GatewayModel) -> _Proces ) compute = backend.compute() assert isinstance(compute, ComputeWithGatewaySupport) - gateway_compute_configuration = gateways_services.get_gateway_compute_configuration( - gateway_model - ) - if gateway_model.gateway_compute is not None and gateway_compute_configuration is not None: - logger.info("Deleting gateway compute for %s...", gateway_model.name) + + for gateway_compute in get_gateway_compute_models(gateway_model): + gateway_compute_configuration = gateways_services.get_gateway_compute_configuration( + gateway_compute=gateway_compute, + gateway_model=gateway_model, + ) + logger.debug( + "%s replica %d: terminating gateway compute", + fmt(gateway_model), + gateway_compute.replica_num, + ) try: await run_async( compute.terminate_gateway, - 
gateway_model.gateway_compute.instance_id, + gateway_compute.instance_id, gateway_compute_configuration, - gateway_model.gateway_compute.backend_data, + gateway_compute.backend_data, ) except Exception: logger.exception( - "Error when deleting gateway compute for %s", - gateway_model.name, + "%s replica %d: error when terminating gateway compute", + fmt(gateway_model), + gateway_compute.replica_num, ) return _ProcessToBeDeletedResult(delete_gateway=False) - logger.info("Deleted gateway compute for %s", gateway_model.name) - result = _ProcessToBeDeletedResult(delete_gateway=True) - if gateway_model.gateway_compute is not None: - await gateway_connections_pool.remove(gateway_model.gateway_compute.ip_address) - result.gateway_compute_update_map = {"active": False, "deleted": True} - return result + logger.info( + "%s replica %d: gateway compute terminated", + fmt(gateway_model), + gateway_compute.replica_num, + ) + await gateway_connections_pool.remove(gateway_compute.ip_address) + + return _ProcessToBeDeletedResult( + delete_gateway=True, + all_computes_update_map={"active": False, "deleted": True}, + ) diff --git a/src/dstack/_internal/server/background/pipeline_tasks/jobs_running.py b/src/dstack/_internal/server/background/pipeline_tasks/jobs_running.py index 014f84c604..61599172b5 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/jobs_running.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/jobs_running.py @@ -76,7 +76,7 @@ get_instance_specific_mounts, resolve_provisioning_image, ) -from dstack._internal.server.services.gateways import get_or_add_gateway_connection +from dstack._internal.server.services.gateways import get_or_add_gateway_connections from dstack._internal.server.services.instances import ( get_instance_remote_connection_info, get_instance_ssh_private_keys, @@ -1185,7 +1185,7 @@ async def _register_service_replica( return None async with get_session_ctx() as session: - gateway_model, conn = await 
get_or_add_gateway_connection( + gateway_model, connections = await get_or_add_gateway_connections( session, context.run_model.gateway_id ) gateway_target = events.Target.from_model(gateway_model) @@ -1197,35 +1197,40 @@ async def _register_service_replica( # so we must update job_submission with the result value. job_submission = context.job_submission.copy(deep=True) job_submission.job_runtime_data = _get_result_job_runtime_data(context.job_model, result) - try: - logger.debug( - "%s: registering replica for service %s", fmt(context.job_model), context.run.id.hex - ) - async with conn.client() as gateway_client: - await gateway_client.register_replica( - run=context.run, - job_spec=job_spec, - job_submission=job_submission, - instance_project_ssh_private_key=instance_project_ssh_private_key, - ssh_head_proxy=ssh_head_proxy, - ssh_head_proxy_private_key=ssh_head_proxy_private_key, - ) - except (httpx.RequestError, SSHError) as e: - logger.debug("Gateway request failed", exc_info=True) - raise GatewayError(repr(e)) - except GatewayError as e: - if "already exists in service" in e.msg: - logger.warning( - ( - "%s: could not register replica in gateway: %s." 
- " NOTE: if you just updated dstack from pre-0.19.25 to 0.19.25+," - " expect to see this warning once for every running service replica" - ), + for conn in connections: + try: + logger.debug( + "%s: registering replica for service %s on gateway replica %s", fmt(context.job_model), - e.msg, + context.run.id.hex, + conn.ip_address, ) - else: - raise + async with conn.client() as gateway_client: + await gateway_client.register_replica( + run=context.run, + job_spec=job_spec, + job_submission=job_submission, + instance_project_ssh_private_key=instance_project_ssh_private_key, + ssh_head_proxy=ssh_head_proxy, + ssh_head_proxy_private_key=ssh_head_proxy_private_key, + ) + except (httpx.RequestError, SSHError) as e: + logger.debug("Gateway request failed", exc_info=True) + raise GatewayError(repr(e)) + except GatewayError as e: + if "already exists in service" in e.msg: + logger.warning( + ( + "%s: could not register replica in gateway %s: %s." + " NOTE: if you just updated dstack from pre-0.19.25 to 0.19.25+," + " expect to see this warning once for every running service replica" + ), + fmt(context.job_model), + conn.ip_address, + e.msg, + ) + else: + raise return gateway_target diff --git a/src/dstack/_internal/server/background/pipeline_tasks/jobs_terminating.py b/src/dstack/_internal/server/background/pipeline_tasks/jobs_terminating.py index 3ae30c3ef2..adedf9bb4b 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/jobs_terminating.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/jobs_terminating.py @@ -51,7 +51,7 @@ ) from dstack._internal.server.services import backends as backends_services from dstack._internal.server.services import events -from dstack._internal.server.services.gateways import get_or_add_gateway_connection +from dstack._internal.server.services.gateways import get_or_add_gateway_connections from dstack._internal.server.services.instances import ( emit_instance_status_change_event, get_instance_ssh_private_keys, @@ 
-795,25 +795,36 @@ async def _unregister_replica( run_model = job_model.run if run_model.gateway_id is not None: async with get_session_ctx() as session: - gateway, conn = await get_or_add_gateway_connection(session, run_model.gateway_id) - gateway_target = events.Target.from_model(gateway) - try: - logger.debug( - "%s: unregistering replica from service %s", fmt(job_model), job_model.run_id.hex + gateway, connections = await get_or_add_gateway_connections( + session, run_model.gateway_id ) - async with conn.client() as client: - await client.unregister_replica( - project=run_model.project.name, - run_name=run_model.run_name, - job_id=job_model.id, + gateway_target = events.Target.from_model(gateway) + for conn in connections: + try: + logger.debug( + "%s: unregistering replica from service %s on gateway replica %s", + fmt(job_model), + job_model.run_id.hex, + conn.ip_address, + ) + async with conn.client() as client: + await client.unregister_replica( + project=run_model.project.name, + run_name=run_model.run_name, + job_id=job_model.id, + ) + except GatewayError as e: + logger.warning( + "%s: unregistering replica from service on gateway replica %s: %s", + fmt(job_model), + conn.ip_address, + e, ) - except GatewayError as e: - logger.warning("%s: unregistering replica from service: %s", fmt(job_model), e) - except (httpx.RequestError, SSHError) as e: - logger.debug("Gateway request failed", exc_info=True) - # FIXME: Unhandled exception raised. - # Handle and retry unregister with timeout. - raise GatewayError(repr(e)) + except (httpx.RequestError, SSHError) as e: + logger.debug("Gateway request failed", exc_info=True) + # FIXME: Unhandled exception raised. + # Handle and retry unregister with timeout. 
+ raise GatewayError(repr(e)) return gateway_target diff --git a/src/dstack/_internal/server/background/pipeline_tasks/runs/__init__.py b/src/dstack/_internal/server/background/pipeline_tasks/runs/__init__.py index c26df9d4d3..071af9fbd3 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/runs/__init__.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/runs/__init__.py @@ -28,7 +28,7 @@ from dstack._internal.server.db import get_db, get_session_ctx from dstack._internal.server.models import InstanceModel, JobModel, ProjectModel, RunModel from dstack._internal.server.services import events -from dstack._internal.server.services.gateways import get_or_add_gateway_connection +from dstack._internal.server.services.gateways import get_combined_gateway_stats from dstack._internal.server.services.jobs import emit_job_status_change_event from dstack._internal.server.services.locking import get_locker from dstack._internal.server.services.pipelines import PipelineHinterProtocol @@ -313,8 +313,9 @@ async def _load_pending_context( gateway_stats = None if run_spec.configuration.type == "service" and run_model.gateway_id is not None: - _, conn = await get_or_add_gateway_connection(session, run_model.gateway_id) - gateway_stats = await conn.get_stats(run_model.project.name, run_model.run_name) + gateway_stats = await get_combined_gateway_stats( + session, run_model.gateway_id, run_model.project.name, run_model.run_name + ) return pending.PendingContext( run_model=run_model, @@ -494,8 +495,9 @@ async def _load_active_context( gateway_stats = None if run_spec.configuration.type == "service" and run_model.gateway_id is not None: - _, conn = await get_or_add_gateway_connection(session, run_model.gateway_id) - gateway_stats = await conn.get_stats(run_model.project.name, run_model.run_name) + gateway_stats = await get_combined_gateway_stats( + session, run_model.gateway_id, run_model.project.name, run_model.run_name + ) return active.ActiveContext( 
run_model=run_model, diff --git a/src/dstack/_internal/server/background/pipeline_tasks/runs/terminating.py b/src/dstack/_internal/server/background/pipeline_tasks/runs/terminating.py index eece7dfa7c..c9a75e3c71 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/runs/terminating.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/runs/terminating.py @@ -16,7 +16,7 @@ from dstack._internal.server.background.pipeline_tasks.base import ItemUpdateMap from dstack._internal.server.db import get_session_ctx from dstack._internal.server.services import events -from dstack._internal.server.services.gateways import get_or_add_gateway_connection +from dstack._internal.server.services.gateways import get_or_add_gateway_connections from dstack._internal.server.services.logging import fmt from dstack._internal.server.services.runs import _get_next_triggered_at, get_run_spec from dstack._internal.utils.common import get_or_error @@ -148,24 +148,37 @@ async def _unregister_service(run_model: models.RunModel) -> Optional[ServiceUnr return None async with get_session_ctx() as session: - gateway, conn = await get_or_add_gateway_connection(session, run_model.gateway_id) + gateway, connections = await get_or_add_gateway_connections(session, run_model.gateway_id) gateway_target = events.Target.from_model(gateway) - try: - logger.debug("%s: unregistering service", fmt(run_model)) - async with conn.client() as client: - await client.unregister_service( - project=run_model.project.name, - run_name=run_model.run_name, + gateway_errors = [] + for conn in connections: + try: + logger.debug( + "%s: unregistering service on gateway replica %s", fmt(run_model), conn.ip_address + ) + async with conn.client() as client: + await client.unregister_service( + project=run_model.project.name, + run_name=run_model.run_name, + ) + except GatewayError as e: + # Ignore if the service is not registered on this replica. 
+ logger.warning( + "%s: unregistering service on gateway replica %s: %s", + fmt(run_model), + conn.ip_address, + e, ) + gateway_errors.append(str(e)) + except (httpx.RequestError, SSHError) as e: + logger.debug("Gateway request failed", exc_info=True) + raise GatewayError(repr(e)) + + if gateway_errors: + event_message = f"Gateway error when unregistering service: {'; '.join(gateway_errors)}" + else: event_message = "Service unregistered from gateway" - except GatewayError as e: - # Ignore if the service is not registered. - logger.warning("%s: unregistering service: %s", fmt(run_model), e) - event_message = f"Gateway error when unregistering service: {e}" - except (httpx.RequestError, SSHError) as e: - logger.debug("Gateway request failed", exc_info=True) - raise GatewayError(repr(e)) return ServiceUnregistration( event_message=event_message, gateway_target=gateway_target, diff --git a/src/dstack/_internal/server/compatibility/gateways.py b/src/dstack/_internal/server/compatibility/gateways.py new file mode 100644 index 0000000000..3e410b5a9c --- /dev/null +++ b/src/dstack/_internal/server/compatibility/gateways.py @@ -0,0 +1,15 @@ +from typing import Optional + +from packaging.version import Version + +from dstack._internal.core.models.gateways import Gateway + + +def patch_gateway(gateway: Gateway, client_version: Optional[Version]) -> None: + if client_version is None: + return + if client_version < Version("0.20.25") and len(gateway.replicas) < 2: + gateway.instance_id = "" + gateway.ip_address = gateway.replicas[0].hostname if gateway.replicas else "" + if gateway.hostname is None: + gateway.hostname = gateway.ip_address diff --git a/src/dstack/_internal/server/migrations/versions/2026/06_01_1911_b7609b94ea4d_add_gatewaycomputemodel_gateway_id.py b/src/dstack/_internal/server/migrations/versions/2026/06_01_1911_b7609b94ea4d_add_gatewaycomputemodel_gateway_id.py new file mode 100644 index 0000000000..2729699af1 --- /dev/null +++ 
b/src/dstack/_internal/server/migrations/versions/2026/06_01_1911_b7609b94ea4d_add_gatewaycomputemodel_gateway_id.py @@ -0,0 +1,52 @@ +"""Add GatewayComputeModel.gateway_id + +Revision ID: b7609b94ea4d +Revises: 201cb7ccd0d3 +Create Date: 2026-06-01 19:11:30.641417+00:00 + +""" + +import sqlalchemy as sa +import sqlalchemy_utils +from alembic import op + +# revision identifiers, used by Alembic. +revision = "b7609b94ea4d" +down_revision = "201cb7ccd0d3" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table("gateway_computes", schema=None) as batch_op: + batch_op.add_column( + sa.Column("replica_num", sa.Integer(), server_default="0", nullable=False) + ) + batch_op.add_column( + sa.Column( + "gateway_id", sqlalchemy_utils.types.uuid.UUIDType(binary=False), nullable=True + ) + ) + batch_op.create_foreign_key( + batch_op.f("fk_gateway_computes_gateway_id_gateways"), + "gateways", + ["gateway_id"], + ["id"], + ondelete="SET NULL", + use_alter=True, + ) + + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table("gateway_computes", schema=None) as batch_op: + batch_op.drop_constraint( + batch_op.f("fk_gateway_computes_gateway_id_gateways"), type_="foreignkey" + ) + batch_op.drop_column("gateway_id") + batch_op.drop_column("replica_num") + + # ### end Alembic commands ### diff --git a/src/dstack/_internal/server/models.py b/src/dstack/_internal/server/models.py index d433244ea3..8d6f3c512c 100644 --- a/src/dstack/_internal/server/models.py +++ b/src/dstack/_internal/server/models.py @@ -629,7 +629,21 @@ class GatewayModel(PipelineModelMixin, BaseModel): gateway_compute_id: Mapped[Optional[uuid.UUID]] = mapped_column( ForeignKey("gateway_computes.id", ondelete="CASCADE") ) - gateway_compute: Mapped[Optional["GatewayComputeModel"]] = relationship() + gateway_compute: Mapped[Optional["GatewayComputeModel"]] = relationship( + foreign_keys=[gateway_compute_id] + ) + """ + Relationship with gateway computes for pre-0.20.25 gateways. + Use `get_gateway_compute_models()` for version-agnostic gateway compute retrieval. + """ + gateway_computes: Mapped[List["GatewayComputeModel"]] = relationship( + back_populates="gateway", + foreign_keys="GatewayComputeModel.gateway_id", + ) + """ + Relationship with gateway computes for 0.20.25+ gateways. + Use `get_gateway_compute_models()` for version-agnostic gateway compute retrieval. + """ runs: Mapped[List["RunModel"]] = relationship(back_populates="gateway") @@ -639,15 +653,26 @@ class GatewayModel(PipelineModelMixin, BaseModel): class GatewayComputeModel(BaseModel): + """A single gateway replica. + **TODO**: consider renaming to `GatewayReplicaModel`. 
+ """ + __tablename__ = "gateway_computes" id: Mapped[uuid.UUID] = mapped_column( UUIDType(binary=False), primary_key=True, default=uuid.uuid4 ) created_at: Mapped[datetime] = mapped_column(NaiveDateTime, default=get_current_datetime) + replica_num: Mapped[int] = mapped_column(Integer, server_default="0") instance_id: Mapped[str] = mapped_column(String(100)) ip_address: Mapped[str] = mapped_column(String(100)) + """Gateway replica IP address or domain name (e.g., k8s can use domain names). + **TODO**: rename. + """ hostname: Mapped[Optional[str]] = mapped_column(String(100)) + """Hostname of the gateway's load balancer. + **TODO**: move to `GatewayModel`. + """ configuration: Mapped[Optional[str]] = mapped_column(Text) """`configuration` is optional for compatibility with pre-0.18.2 gateways. Use `get_gateway_compute_configuration` to construct `configuration` for old gateways. @@ -655,6 +680,22 @@ class GatewayComputeModel(BaseModel): backend_data: Mapped[Optional[str]] = mapped_column(Text) region: Mapped[str] = mapped_column(String(100)) + gateway_id: Mapped[Optional[uuid.UUID]] = mapped_column( + ForeignKey( + "gateways.id", + ondelete="SET NULL", + use_alter=True, + ) + ) + gateway: Mapped[Optional["GatewayModel"]] = relationship( + back_populates="gateway_computes", + foreign_keys=[gateway_id], + ) + """ + Gateway. Can be None for pre-0.20.25 gateways, which use GatewayModel.gateway_compute_id to + establish the relationship. 
+ """ + backend_id: Mapped[Optional[uuid.UUID]] = mapped_column( ForeignKey("backends.id", ondelete="CASCADE") ) diff --git a/src/dstack/_internal/server/routers/gateways.py b/src/dstack/_internal/server/routers/gateways.py index eee99077c1..6b9a6718dd 100644 --- a/src/dstack/_internal/server/routers/gateways.py +++ b/src/dstack/_internal/server/routers/gateways.py @@ -1,6 +1,7 @@ from typing import List, Optional, Tuple from fastapi import APIRouter, Depends +from packaging.version import Version from sqlalchemy.ext.asyncio import AsyncSession import dstack._internal.core.models.gateways as models @@ -8,6 +9,7 @@ import dstack._internal.server.services.gateways as gateways from dstack._internal.core.errors import ResourceNotExistsError from dstack._internal.core.models.common import EntityReference +from dstack._internal.server.compatibility.gateways import patch_gateway from dstack._internal.server.db import get_session from dstack._internal.server.deps import Project from dstack._internal.server.models import ProjectModel, UserModel @@ -21,6 +23,7 @@ from dstack._internal.server.utils.routers import ( CustomORJSONResponse, get_base_api_additional_responses, + get_client_version, ) router = APIRouter( @@ -35,17 +38,19 @@ async def list_gateways( body: Optional[schemas.ListGatewaysRequest] = None, session: AsyncSession = Depends(get_session), user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectMemberOrPublicAccess()), + client_version: Optional[Version] = Depends(get_client_version), ): _, project = user_project if body is None: body = schemas.ListGatewaysRequest() - return CustomORJSONResponse( - await gateways.list_project_gateways( - session=session, - project=project, - include_imported=body.include_imported, - ) + gateway_list = await gateways.list_project_gateways( + session=session, + project=project, + include_imported=body.include_imported, ) + for gateway in gateway_list: + patch_gateway(gateway, client_version) + return 
CustomORJSONResponse(gateway_list) @router.post("/get", summary="Get gateway", response_model=models.Gateway) @@ -54,6 +59,7 @@ async def get_gateway( session: AsyncSession = Depends(get_session), user: UserModel = Depends(Authenticated()), project: ProjectModel = Depends(Project()), + client_version: Optional[Version] = Depends(get_client_version), ): await check_can_access_gateway( session=session, user=user, gateway_project=project, gateway_name=body.name @@ -61,6 +67,7 @@ async def get_gateway( gateway = await gateways.get_gateway_by_name(session=session, project=project, name=body.name) if gateway is None: raise ResourceNotExistsError() + patch_gateway(gateway, client_version) return CustomORJSONResponse(gateway) @@ -70,17 +77,18 @@ async def create_gateway( session: AsyncSession = Depends(get_session), user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectAdmin()), pipeline_hinter: PipelineHinterProtocol = Depends(get_pipeline_hinter), + client_version: Optional[Version] = Depends(get_client_version), ): user, project = user_project - return CustomORJSONResponse( - await gateways.create_gateway( - session=session, - user=user, - project=project, - configuration=body.configuration, - pipeline_hinter=pipeline_hinter, - ) + gateway = await gateways.create_gateway( + session=session, + user=user, + project=project, + configuration=body.configuration, + pipeline_hinter=pipeline_hinter, ) + patch_gateway(gateway, client_version) + return CustomORJSONResponse(gateway) @router.post("/delete", summary="Delete gateways") @@ -118,14 +126,15 @@ async def set_gateway_wildcard_domain( body: schemas.SetWildcardDomainRequest, session: AsyncSession = Depends(get_session), user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectAdmin()), + client_version: Optional[Version] = Depends(get_client_version), ): user, project = user_project - return CustomORJSONResponse( - await gateways.set_gateway_wildcard_domain( - session=session, - project=project, - 
name=body.name, - wildcard_domain=body.wildcard_domain, - user=user, - ) + gateway = await gateways.set_gateway_wildcard_domain( + session=session, + project=project, + name=body.name, + wildcard_domain=body.wildcard_domain, + user=user, ) + patch_gateway(gateway, client_version) + return CustomORJSONResponse(gateway) diff --git a/src/dstack/_internal/server/services/gateways/__init__.py b/src/dstack/_internal/server/services/gateways/__init__.py index 287117e2ed..bfd05cecf9 100644 --- a/src/dstack/_internal/server/services/gateways/__init__.py +++ b/src/dstack/_internal/server/services/gateways/__init__.py @@ -10,7 +10,7 @@ import httpx from sqlalchemy import exists, func, or_, select, update from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.orm import joinedload +from sqlalchemy.orm import joinedload, selectinload import dstack._internal.utils.random_names as random_names from dstack._internal.core.backends.base.compute import ( @@ -32,15 +32,19 @@ from dstack._internal.core.models.backends.base import BackendType from dstack._internal.core.models.common import EntityReference from dstack._internal.core.models.gateways import ( + GATEWAY_REPLICAS_DEFAULT, AnyGatewayRouterConfig, Gateway, GatewayComputeConfiguration, GatewayConfiguration, + GatewayReplica, GatewaySpec, GatewayStatus, LetsEncryptGatewayCertificate, ) from dstack._internal.core.services import validate_dstack_resource_name +from dstack._internal.proxy.gateway.const import SERVICE_SCALING_WINDOWS +from dstack._internal.proxy.gateway.schemas.stats import PerWindowStats, Stat from dstack._internal.server import settings from dstack._internal.server.db import get_db, is_db_postgres, is_db_sqlite from dstack._internal.server.models import ( @@ -130,6 +134,10 @@ def get_gateway_status_change_message( GATEWAY_CONNECT_DELAY = 10 GATEWAY_CONFIGURE_ATTEMPTS = 50 GATEWAY_CONFIGURE_DELAY = 3 +# Artificial limit to avoid doing too many per-replica operations (gateway replica provisioning, +# 
service registration, etc) in a single pipeline tick. Can be lifted once the implementation is +# more mature. +GATEWAY_MAX_REPLICAS = 3 # documented in gateways.md, keep in sync async def list_project_gateways( @@ -169,6 +177,8 @@ async def create_gateway_compute( project_name: str, backend_compute: Compute, configuration: GatewayConfiguration, + replica_num: int, + gateway_id: Optional[uuid.UUID] = None, backend_id: Optional[uuid.UUID] = None, ) -> GatewayComputeModel: assert isinstance(backend_compute, ComputeWithGatewaySupport) @@ -180,7 +190,7 @@ async def create_gateway_compute( compute_configuration = GatewayComputeConfiguration( project_name=project_name, - instance_name=configuration.name, + instance_name=f"{configuration.name}-{replica_num}", backend=configuration.backend, region=configuration.region, instance_type=configuration.instance_type, @@ -197,7 +207,9 @@ async def create_gateway_compute( ) return GatewayComputeModel( + gateway_id=gateway_id, backend_id=backend_id, + replica_num=replica_num, region=gpd.region, ip_address=gpd.ip_address, instance_id=gpd.instance_id, @@ -467,6 +479,7 @@ async def list_project_gateway_models( stmt = stmt.where(GatewayModel.project_id == project.id) if load_gateway_compute: stmt = stmt.options(joinedload(GatewayModel.gateway_compute)) + stmt = stmt.options(selectinload(GatewayModel.gateway_computes)) if load_backend_type: stmt = stmt.options(joinedload(GatewayModel.backend).load_only(BackendModel.type)) res = await session.execute(stmt) @@ -495,6 +508,7 @@ async def get_project_gateway_model_by_reference( ) if load_gateway_compute: stmt = stmt.options(joinedload(GatewayModel.gateway_compute)) + stmt = stmt.options(selectinload(GatewayModel.gateway_computes)) if load_backend_type: stmt = stmt.options(joinedload(GatewayModel.backend).load_only(BackendModel.type)) res = await session.execute(stmt) @@ -529,6 +543,7 @@ async def get_project_gateway_model_by_name_for_update( select(GatewayModel) 
.where(GatewayModel.id.in_([gateway_id]), *filters) .options(joinedload(GatewayModel.gateway_compute)) + .options(selectinload(GatewayModel.gateway_computes)) .options(joinedload(GatewayModel.backend).load_only(BackendModel.type)) .with_for_update(key_share=True, of=GatewayModel) ) @@ -555,6 +570,7 @@ async def get_project_default_gateway_model( ) if load_gateway_compute: stmt = stmt.options(joinedload(GatewayModel.gateway_compute)) + stmt = stmt.options(selectinload(GatewayModel.gateway_computes)) if load_backend_type: stmt = stmt.options(joinedload(GatewayModel.backend).load_only(BackendModel.type)) res = await session.execute(stmt) @@ -571,30 +587,73 @@ async def generate_gateway_name(session: AsyncSession, project: ProjectModel) -> # TODO: Connect to gateway outside session -async def get_or_add_gateway_connection( +async def get_or_add_gateway_connections( session: AsyncSession, gateway_id: uuid.UUID -) -> tuple[GatewayModel, GatewayConnection]: - gateway = await session.get( - GatewayModel, - gateway_id, - options=[joinedload(GatewayModel.gateway_compute)], - populate_existing=True, +) -> tuple[GatewayModel, List[GatewayConnection]]: + res = await session.execute( + select(GatewayModel) + .where(GatewayModel.id == gateway_id) + .options(joinedload(GatewayModel.gateway_compute)) + .options(selectinload(GatewayModel.gateway_computes)) ) + gateway = res.scalar_one_or_none() if gateway is None: raise GatewayError("Gateway not found") - if gateway.gateway_compute is None: + computes = get_gateway_compute_models(gateway) + if not computes: raise GatewayError("Gateway compute not found") + connections: List[GatewayConnection] = [] + for compute in computes: + try: + conn = await gateway_connections_pool.get_or_add( + hostname=compute.ip_address, + id_rsa=compute.ssh_private_key, + ) + connections.append(conn) + except Exception as e: + logger.warning("Failed to connect to gateway %s: %s", compute.ip_address, e) + raise GatewayError("Failed to connect to gateway") + 
return gateway, connections + + +async def get_combined_gateway_stats( + session: AsyncSession, + gateway_id: uuid.UUID, + project_name: str, + run_name: str, +) -> Optional[PerWindowStats]: + """ + Return stats for *run_name* aggregated across all replicas of *gateway_id*. + """ try: - conn = await gateway_connections_pool.get_or_add( - hostname=gateway.gateway_compute.ip_address, - id_rsa=gateway.gateway_compute.ssh_private_key, - ) - except Exception as e: - logger.warning( - "Failed to connect to gateway %s: %s", gateway.gateway_compute.ip_address, e + _, connections = await get_or_add_gateway_connections(session, gateway_id) + except GatewayError: + return None + per_replica: list[PerWindowStats] = [] + for conn in connections: + stats = await conn.get_stats(project_name, run_name) + if stats is None: # Stats not fetched yet + # TODO: find a way to make service scaling decisions even if some gateway replicas are + # unavailable for fetching stats. + return None + per_replica.append(stats) + return _merge_per_window_stats(per_replica) if per_replica else None + + +def _merge_per_window_stats(stats_per_gateway_replica: list[PerWindowStats]) -> PerWindowStats: + merged: PerWindowStats = {} + for window in SERVICE_SCALING_WINDOWS: + total_requests = 0 + total_time_of_all_requests = 0.0 + for gateway_replica_stats in stats_per_gateway_replica: + stat = gateway_replica_stats[window] + total_requests += stat.requests + total_time_of_all_requests += stat.requests * stat.request_time + merged[window] = Stat( + requests=total_requests, + request_time=(total_time_of_all_requests / total_requests if total_requests else 0.0), ) - raise GatewayError("Failed to connect to gateway") - return gateway, conn + return merged async def init_gateways(session: AsyncSession): @@ -732,6 +791,14 @@ async def configure_gateway( logger.info("Gateway %s configured", connection.ip_address) +def get_gateway_compute_models(gateway_model: GatewayModel) -> List[GatewayComputeModel]: + if 
gateway_model.gateway_computes: # 0.20.25+ gateway + return list(gateway_model.gateway_computes) + if gateway_model.gateway_compute is not None: # pre-0.20.25 gateway + return [gateway_model.gateway_compute] + return [] + + def get_gateway_configuration(gateway_model: GatewayModel) -> GatewayConfiguration: if gateway_model.configuration is not None: return GatewayConfiguration.__response__.parse_raw(gateway_model.configuration) @@ -746,22 +813,19 @@ def get_gateway_configuration(gateway_model: GatewayModel) -> GatewayConfigurati def get_gateway_compute_configuration( + gateway_compute: GatewayComputeModel, gateway_model: GatewayModel, -) -> Optional[GatewayComputeConfiguration]: - if gateway_model.gateway_compute is None: - return None - if gateway_model.gateway_compute.configuration is not None: - return GatewayComputeConfiguration.__response__.parse_raw( - gateway_model.gateway_compute.configuration - ) +) -> GatewayComputeConfiguration: + if gateway_compute.configuration is not None: + return GatewayComputeConfiguration.__response__.parse_raw(gateway_compute.configuration) # Handle gateways created before GatewayComputeConfiguration was introduced return GatewayComputeConfiguration( project_name=gateway_model.project.name, - instance_name=gateway_model.gateway_compute.instance_id, + instance_name=gateway_compute.instance_id, backend=gateway_model.backend.type, - region=gateway_model.gateway_compute.region, + region=gateway_compute.region, public_ip=True, - ssh_key_pub=gateway_model.gateway_compute.ssh_public_key, + ssh_key_pub=gateway_compute.ssh_public_key, certificate=LetsEncryptGatewayCertificate(), ) @@ -775,28 +839,34 @@ def gateway_model_to_gateway( default_gateway_id: ID of the default gateway in the project where `gateway_model` is being viewed. Can be different from `gateway_model.project` if the gateway is imported. 
""" - ip_address = "" - instance_id = "" - hostname = "" - if gateway_model.gateway_compute is not None: - ip_address = gateway_model.gateway_compute.ip_address - instance_id = gateway_model.gateway_compute.instance_id - hostname = gateway_model.gateway_compute.hostname - if hostname is None: - hostname = ip_address backend_type = gateway_model.backend.type if gateway_model.backend.type == BackendType.DSTACK: backend_type = BackendType.AWS is_default = default_gateway_id == gateway_model.id configuration = get_gateway_configuration(gateway_model) configuration.default = is_default + + compute_models = sorted(get_gateway_compute_models(gateway_model), key=lambda c: c.replica_num) + gateway_hostname = None + replicas = [] + for compute in compute_models: + compute_configuration = get_gateway_compute_configuration(compute, gateway_model) + replicas.append( + GatewayReplica( + hostname=compute.ip_address, + replica_num=compute.replica_num, + backend=compute_configuration.backend, + region=compute_configuration.region, + created_at=compute.created_at, + ) + ) + gateway_hostname = compute.hostname + return Gateway( id=gateway_model.id, name=gateway_model.name, project_name=gateway_model.project.name, - ip_address=ip_address, - instance_id=instance_id, - hostname=hostname, + hostname=gateway_hostname, backend=backend_type, region=gateway_model.region, wildcard_domain=gateway_model.wildcard_domain, @@ -805,6 +875,7 @@ def gateway_model_to_gateway( status=gateway_model.status, status_message=gateway_model.status_message, configuration=configuration, + replicas=replicas, ) @@ -838,6 +909,15 @@ def _validate_gateway_configuration(configuration: GatewayConfiguration): f" {[b.value for b in BACKENDS_WITH_PRIVATE_GATEWAY_SUPPORT]}." ) + replicas = ( + configuration.replicas if configuration.replicas is not None else GATEWAY_REPLICAS_DEFAULT + ) + + if replicas > GATEWAY_MAX_REPLICAS: + raise ServerClientError( + f"Cannot provision {replicas} gateway replicas. 
This server allows at most {GATEWAY_MAX_REPLICAS}" + ) + if configuration.certificate is not None: if configuration.certificate.type == "lets-encrypt" and not configuration.public_ip: raise ServerClientError( @@ -845,3 +925,13 @@ def _validate_gateway_configuration(configuration: GatewayConfiguration): ) if configuration.certificate.type == "acm" and configuration.backend != BackendType.AWS: raise ServerClientError("acm certificate type is supported for aws backend only") + if replicas > 1: + raise ServerClientError( + "Replicated gateways do not support certificates." + " Set either `certificate: null` or `replicas: 1` in the gateway configuration" + ) + + if configuration.router is not None and replicas > 1: + raise ServerClientError( + "The deprecated `router` property is not supported for multi-replica gateways" + ) diff --git a/src/dstack/_internal/server/services/services/__init__.py b/src/dstack/_internal/server/services/services/__init__.py index 273054e74f..b637683af7 100644 --- a/src/dstack/_internal/server/services/services/__init__.py +++ b/src/dstack/_internal/server/services/services/__init__.py @@ -32,8 +32,9 @@ from dstack._internal.server.models import GatewayModel, RunModel from dstack._internal.server.services import events from dstack._internal.server.services.gateways import ( + get_gateway_compute_models, get_gateway_configuration, - get_or_add_gateway_connection, + get_or_add_gateway_connections, get_project_default_gateway_model, get_project_gateway_model_by_reference, ) @@ -100,7 +101,7 @@ async def _register_service_in_gateway( ) -> ServiceSpec: assert run_spec.configuration.type == "service" - if gateway.gateway_compute is None: + if not get_gateway_compute_models(gateway): raise ServerClientError("Gateway has no instance associated with it") if gateway.status != GatewayStatus.RUNNING: @@ -178,50 +179,51 @@ async def _register_service_in_gateway( domain = service_spec.get_domain() assert domain is not None - _, conn = await 
get_or_add_gateway_connection(session, gateway.id) - try: - logger.debug("%s: registering service as %s", fmt(run_model), service_spec.url) - async with conn.client() as client: - do_register = partial( - client.register_service, - project=run_model.project.name, - run_name=run_model.run_name, - domain=domain, - service_https=configure_service_https, - gateway_https=gateway_https, - auth=run_spec.configuration.auth, - client_max_body_size=settings.DEFAULT_SERVICE_CLIENT_MAX_BODY_SIZE, - options=service_spec.options, - rate_limits=run_spec.configuration.rate_limits, - ssh_private_key=run_model.project.ssh_private_key, - has_router_replica=has_replica_group_router, - router=router, - ) - try: - await do_register() - except GatewayError as e: - if e.msg == SERVICE_ALREADY_REGISTERED_ERROR_TEMPLATE.format( - ref=f"{run_model.project.name}/{run_model.run_name}" - ): - # Happens if there was a communication issue with the gateway when last unregistering - logger.warning( - "Service %s/%s is dangling on gateway %s, unregistering and re-registering", - run_model.project.name, - run_model.run_name, - gateway.name, - ) - await client.unregister_service( - project=run_model.project.name, - run_name=run_model.run_name, - ) + _, connections = await get_or_add_gateway_connections(session, gateway.id) + for conn in connections: + try: + logger.debug("%s: registering service as %s", fmt(run_model), service_spec.url) + async with conn.client() as client: + do_register = partial( + client.register_service, + project=run_model.project.name, + run_name=run_model.run_name, + domain=domain, + service_https=configure_service_https, + gateway_https=gateway_https, + auth=run_spec.configuration.auth, + client_max_body_size=settings.DEFAULT_SERVICE_CLIENT_MAX_BODY_SIZE, + options=service_spec.options, + rate_limits=run_spec.configuration.rate_limits, + ssh_private_key=run_model.project.ssh_private_key, + has_router_replica=has_replica_group_router, + router=router, + ) + try: await 
do_register() - else: - raise - except SSHError: - raise ServerClientError("Gateway tunnel is not working") - except httpx.RequestError as e: - logger.debug("Gateway request failed", exc_info=True) - raise GatewayError(f"Gateway is not working: {e!r}") + except GatewayError as e: + if e.msg == SERVICE_ALREADY_REGISTERED_ERROR_TEMPLATE.format( + ref=f"{run_model.project.name}/{run_model.run_name}" + ): + # Happens if there was a communication issue with the gateway when last (un)registering + logger.warning( + "Service %s/%s is dangling on gateway replica %s, unregistering and re-registering", + run_model.project.name, + run_model.run_name, + conn.ip_address, + ) + await client.unregister_service( + project=run_model.project.name, + run_name=run_model.run_name, + ) + await do_register() + else: + raise + except SSHError: + raise ServerClientError("Gateway tunnel is not working") + except httpx.RequestError as e: + logger.debug("Gateway request failed", exc_info=True) + raise GatewayError(f"Gateway is not working: {e!r}") events.emit( session, diff --git a/src/dstack/_internal/server/testing/common.py b/src/dstack/_internal/server/testing/common.py index 6c8b7233f6..2c0a66be5a 100644 --- a/src/dstack/_internal/server/testing/common.py +++ b/src/dstack/_internal/server/testing/common.py @@ -638,7 +638,6 @@ async def create_gateway( name: str = "test_gateway", region: str = "us", wildcard_domain: Optional[str] = None, - gateway_compute_id: Optional[UUID] = None, status: Optional[GatewayStatus] = GatewayStatus.SUBMITTED, last_processed_at: datetime = datetime(2023, 1, 2, 3, 4, tzinfo=timezone.utc), forbid_new_services: bool = False, @@ -649,7 +648,6 @@ async def create_gateway( name=name, region=region, wildcard_domain=wildcard_domain, - gateway_compute_id=gateway_compute_id, status=status, last_processed_at=last_processed_at, forbid_new_services=forbid_new_services, @@ -661,6 +659,7 @@ async def create_gateway( async def create_gateway_compute( session: AsyncSession, + 
gateway_id: Optional[UUID] = None, backend_id: Optional[UUID] = None, ip_address: Optional[str] = "1.1.1.1", region: str = "us", @@ -669,6 +668,7 @@ async def create_gateway_compute( ssh_public_key: str = "", ) -> GatewayComputeModel: gateway_compute = GatewayComputeModel( + gateway_id=gateway_id, backend_id=backend_id, ip_address=ip_address, region=region, diff --git a/src/tests/_internal/server/background/pipeline_tasks/test_gateways.py b/src/tests/_internal/server/background/pipeline_tasks/test_gateways.py index d1113c90d1..2759d8a236 100644 --- a/src/tests/_internal/server/background/pipeline_tasks/test_gateways.py +++ b/src/tests/_internal/server/background/pipeline_tasks/test_gateways.py @@ -6,10 +6,15 @@ import pytest from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.orm import joinedload +from sqlalchemy.orm import selectinload from dstack._internal.core.errors import BackendError -from dstack._internal.core.models.gateways import GatewayProvisioningData, GatewayStatus +from dstack._internal.core.models.backends.base import BackendType +from dstack._internal.core.models.gateways import ( + GatewayConfiguration, + GatewayProvisioningData, + GatewayStatus, +) from dstack._internal.server.background.pipeline_tasks.gateways import ( GatewayFetcher, GatewayPipeline, @@ -257,12 +262,12 @@ async def test_submitted_to_provisioning( res = await session.execute( select(GatewayModel) .where(GatewayModel.id == gateway.id) - .options(joinedload(GatewayModel.gateway_compute)) + .options(selectinload(GatewayModel.gateway_computes)) ) gateway = res.unique().scalar_one() assert gateway.status == GatewayStatus.PROVISIONING - assert gateway.gateway_compute is not None - assert gateway.gateway_compute.ip_address == "2.2.2.2" + assert len(gateway.gateway_computes) > 0 + assert gateway.gateway_computes[0].ip_address == "2.2.2.2" events = await list_events(session) assert len(events) == 1 assert events[0].message == "Gateway status 
changed SUBMITTED -> PROVISIONING" @@ -300,23 +305,129 @@ async def test_marks_gateway_as_failed_if_gateway_creation_errors( assert len(events) == 1 assert events[0].message == "Gateway status changed SUBMITTED -> FAILED (Some error)" + async def test_submitted_creates_multiple_computes_for_multi_replica( + self, test_db, session: AsyncSession, worker: GatewayWorker + ): + project = await create_project(session=session) + backend = await create_backend(session=session, project_id=project.id) + gateway = await create_gateway( + session=session, + project_id=project.id, + backend_id=backend.id, + status=GatewayStatus.SUBMITTED, + ) + config = GatewayConfiguration( + name=gateway.name, + backend=BackendType.AWS, + region=gateway.region, + replicas=2, + ) + gateway.configuration = config.json() + gateway.lock_token = uuid.uuid4() + gateway.lock_expires_at = datetime(2025, 1, 2, 3, 4, tzinfo=timezone.utc) + await session.commit() + + with patch( + "dstack._internal.server.services.backends.get_project_backend_with_model_by_type_or_error" + ) as m: + aws = Mock() + m.return_value = (backend, aws) + aws.compute.return_value = Mock(spec=ComputeMockSpec) + aws.compute.return_value.create_gateway.side_effect = [ + GatewayProvisioningData(instance_id="i-aaa", ip_address="2.2.2.2", region="us"), + GatewayProvisioningData(instance_id="i-bbb", ip_address="3.3.3.3", region="us"), + ] + await worker.process(_gateway_to_pipeline_item(gateway)) + assert aws.compute.return_value.create_gateway.call_count == 2 + + await session.refresh(gateway) + res = await session.execute( + select(GatewayModel) + .where(GatewayModel.id == gateway.id) + .options(selectinload(GatewayModel.gateway_computes)) + ) + gateway = res.unique().scalar_one() + assert gateway.status == GatewayStatus.PROVISIONING + computes = sorted(gateway.gateway_computes, key=lambda c: c.replica_num) + assert len(computes) == 2 + assert computes[0].ip_address == "2.2.2.2" + assert computes[0].replica_num == 0 + assert 
computes[1].ip_address == "3.3.3.3" + assert computes[1].replica_num == 1 + + async def test_marks_gateway_as_failed_if_second_replica_creation_errors( + self, test_db, session: AsyncSession, worker: GatewayWorker + ): + project = await create_project(session=session) + backend = await create_backend(session=session, project_id=project.id) + gateway = await create_gateway( + session=session, + project_id=project.id, + backend_id=backend.id, + status=GatewayStatus.SUBMITTED, + ) + config = GatewayConfiguration( + name=gateway.name, + backend=BackendType.AWS, + region=gateway.region, + replicas=2, + ) + gateway.configuration = config.json() + gateway.lock_token = uuid.uuid4() + gateway.lock_expires_at = datetime(2025, 1, 2, 3, 4, tzinfo=timezone.utc) + await session.commit() + + with patch( + "dstack._internal.server.services.backends.get_project_backend_with_model_by_type_or_error" + ) as m: + aws = Mock() + m.return_value = (backend, aws) + aws.compute.return_value = Mock(spec=ComputeMockSpec) + aws.compute.return_value.create_gateway.side_effect = [ + GatewayProvisioningData(instance_id="i-aaa", ip_address="2.2.2.2", region="us"), + BackendError("Some error"), + ] + await worker.process(_gateway_to_pipeline_item(gateway)) + assert aws.compute.return_value.create_gateway.call_count == 2 + + await session.refresh(gateway) + res = await session.execute( + select(GatewayModel) + .where(GatewayModel.id == gateway.id) + .options(selectinload(GatewayModel.gateway_computes)) + ) + gateway = res.unique().scalar_one() + assert gateway.status == GatewayStatus.FAILED + assert gateway.status_message == "Some error" + # The first replica's compute is saved even though the second failed + assert len(gateway.gateway_computes) == 1 + assert gateway.gateway_computes[0].ip_address == "2.2.2.2" + assert gateway.gateway_computes[0].replica_num == 0 + events = await list_events(session) + assert len(events) == 1 + assert events[0].message == "Gateway status changed SUBMITTED -> FAILED 
(Some error)" + @pytest.mark.asyncio @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) class TestGatewayWorkerProvisioning: + @pytest.mark.parametrize("legacy_compute", [False, True]) async def test_provisioning_to_running( - self, test_db, session: AsyncSession, worker: GatewayWorker + self, test_db, session: AsyncSession, worker: GatewayWorker, legacy_compute: bool ): project = await create_project(session=session) backend = await create_backend(session=session, project_id=project.id) - gateway_compute = await create_gateway_compute(session) gateway = await create_gateway( session=session, project_id=project.id, backend_id=backend.id, - gateway_compute_id=gateway_compute.id, status=GatewayStatus.PROVISIONING, ) + if legacy_compute: + gateway_compute = await create_gateway_compute(session=session, backend_id=backend.id) + gateway.gateway_compute_id = gateway_compute.id # pre-0.20.25 relationship style + else: + await create_gateway_compute(session, gateway_id=gateway.id) gateway.lock_token = uuid.uuid4() gateway.lock_expires_at = datetime(2025, 1, 2, 3, 4, tzinfo=timezone.utc) await session.commit() @@ -335,19 +446,57 @@ async def test_provisioning_to_running( assert len(events) == 1 assert events[0].message == "Gateway status changed PROVISIONING -> RUNNING" - async def test_marks_gateway_as_failed_if_fails_to_connect( + async def test_provisioning_to_running_with_multiple_replicas( self, test_db, session: AsyncSession, worker: GatewayWorker ): project = await create_project(session=session) backend = await create_backend(session=session, project_id=project.id) - gateway_compute = await create_gateway_compute(session) gateway = await create_gateway( session=session, project_id=project.id, backend_id=backend.id, - gateway_compute_id=gateway_compute.id, status=GatewayStatus.PROVISIONING, ) + await create_gateway_compute(session, gateway_id=gateway.id, ip_address="1.1.1.1") + compute1 = await create_gateway_compute( + session, 
gateway_id=gateway.id, ip_address="2.2.2.2" + ) + compute1.replica_num = 1 + gateway.lock_token = uuid.uuid4() + gateway.lock_expires_at = datetime(2025, 1, 2, 3, 4, tzinfo=timezone.utc) + await session.commit() + + with patch( + "dstack._internal.server.services.gateways.gateway_connections_pool.get_or_add" + ) as pool_add: + pool_add.return_value = MagicMock() + pool_add.return_value.client.return_value = MagicMock(AsyncContextManager()) + await worker.process(_gateway_to_pipeline_item(gateway)) + assert pool_add.call_count == 2 + + await session.refresh(gateway) + assert gateway.status == GatewayStatus.RUNNING + events = await list_events(session) + assert len(events) == 1 + assert events[0].message == "Gateway status changed PROVISIONING -> RUNNING" + + @pytest.mark.parametrize("legacy_compute", [False, True]) + async def test_marks_gateway_as_failed_if_fails_to_connect( + self, test_db, session: AsyncSession, worker: GatewayWorker, legacy_compute: bool + ): + project = await create_project(session=session) + backend = await create_backend(session=session, project_id=project.id) + gateway = await create_gateway( + session=session, + project_id=project.id, + backend_id=backend.id, + status=GatewayStatus.PROVISIONING, + ) + if legacy_compute: + gateway_compute = await create_gateway_compute(session=session, backend_id=backend.id) + gateway.gateway_compute_id = gateway_compute.id # pre-0.20.25 relationship style + else: + gateway_compute = await create_gateway_compute(session, gateway_id=gateway.id) gateway.lock_token = uuid.uuid4() gateway.lock_expires_at = datetime(2025, 1, 2, 3, 4, tzinfo=timezone.utc) await session.commit() @@ -360,8 +509,55 @@ async def test_marks_gateway_as_failed_if_fails_to_connect( connect_to_gateway_with_retry_mock.assert_called_once() await session.refresh(gateway) + await session.refresh(gateway_compute) assert gateway.status == GatewayStatus.FAILED assert gateway.status_message == "Failed to connect to gateway" + assert 
gateway_compute.active is False + events = await list_events(session) + assert len(events) == 1 + assert ( + events[0].message + == "Gateway status changed PROVISIONING -> FAILED (Failed to connect to gateway)" + ) + + async def test_marks_gateway_as_failed_if_any_replica_fails_to_connect( + self, test_db, session: AsyncSession, worker: GatewayWorker + ): + project = await create_project(session=session) + backend = await create_backend(session=session, project_id=project.id) + gateway = await create_gateway( + session=session, + project_id=project.id, + backend_id=backend.id, + status=GatewayStatus.PROVISIONING, + ) + compute0 = await create_gateway_compute( + session, gateway_id=gateway.id, ip_address="1.1.1.1" + ) + compute1 = await create_gateway_compute( + session, gateway_id=gateway.id, ip_address="2.2.2.2" + ) + compute1.replica_num = 1 + gateway.lock_token = uuid.uuid4() + gateway.lock_expires_at = datetime(2025, 1, 2, 3, 4, tzinfo=timezone.utc) + await session.commit() + + with patch( + "dstack._internal.server.services.gateways.connect_to_gateway_with_retry" + ) as connect_mock: + connect_mock.return_value = None + await worker.process(_gateway_to_pipeline_item(gateway)) + assert connect_mock.call_count == 2 + + await session.refresh(gateway) + assert gateway.status == GatewayStatus.FAILED + assert gateway.status_message == "Failed to connect to gateway" + + await session.refresh(compute0) + await session.refresh(compute1) + assert compute0.active is False + assert compute1.active is False + events = await list_events(session) assert len(events) == 1 assert ( @@ -373,19 +569,25 @@ async def test_marks_gateway_as_failed_if_fails_to_connect( @pytest.mark.asyncio @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) class TestGatewayWorkerDeleted: + @pytest.mark.parametrize("legacy_compute", [False, True]) async def test_deletes_gateway_and_marks_compute_deleted( - self, test_db, session: AsyncSession, worker: GatewayWorker + self, 
test_db, session: AsyncSession, worker: GatewayWorker, legacy_compute: bool ): project = await create_project(session=session) backend = await create_backend(session=session, project_id=project.id) - gateway_compute = await create_gateway_compute(session=session, backend_id=backend.id) gateway = await create_gateway( session=session, project_id=project.id, backend_id=backend.id, - gateway_compute_id=gateway_compute.id, status=GatewayStatus.RUNNING, ) + if legacy_compute: + gateway_compute = await create_gateway_compute(session=session, backend_id=backend.id) + gateway.gateway_compute_id = gateway_compute.id # pre-0.20.25 relationship style + else: + gateway_compute = await create_gateway_compute( + session=session, backend_id=backend.id, gateway_id=gateway.id + ) gateway.lock_token = uuid.uuid4() gateway.lock_expires_at = datetime(2025, 1, 2, 3, 4, tzinfo=timezone.utc) gateway.to_be_deleted = True @@ -418,19 +620,25 @@ async def test_deletes_gateway_and_marks_compute_deleted( assert len(events) == 1 assert events[0].message == "Gateway deleted" + @pytest.mark.parametrize("legacy_compute", [False, True]) async def test_keeps_gateway_if_terminate_fails( - self, test_db, session: AsyncSession, worker: GatewayWorker + self, test_db, session: AsyncSession, worker: GatewayWorker, legacy_compute: bool ): project = await create_project(session=session) backend = await create_backend(session=session, project_id=project.id) - gateway_compute = await create_gateway_compute(session=session, backend_id=backend.id) gateway = await create_gateway( session=session, project_id=project.id, backend_id=backend.id, - gateway_compute_id=gateway_compute.id, status=GatewayStatus.RUNNING, ) + if legacy_compute: + gateway_compute = await create_gateway_compute(session=session, backend_id=backend.id) + gateway.gateway_compute_id = gateway_compute.id # pre-0.20.25 relationship style + else: + gateway_compute = await create_gateway_compute( + session=session, backend_id=backend.id, 
gateway_id=gateway.id + ) gateway.lock_token = uuid.uuid4() gateway.lock_expires_at = datetime(2025, 1, 2, 3, 4, tzinfo=timezone.utc) gateway.lock_owner = "GatewayPipeline" @@ -470,3 +678,112 @@ async def test_keeps_gateway_if_terminate_fails( assert gateway_compute.deleted is False events = await list_events(session) assert len(events) == 0 + + async def test_deletes_gateway_with_multiple_replicas( + self, test_db, session: AsyncSession, worker: GatewayWorker + ): + project = await create_project(session=session) + backend = await create_backend(session=session, project_id=project.id) + gateway = await create_gateway( + session=session, + project_id=project.id, + backend_id=backend.id, + status=GatewayStatus.RUNNING, + ) + compute0 = await create_gateway_compute( + session=session, backend_id=backend.id, gateway_id=gateway.id, ip_address="1.1.1.1" + ) + compute1 = await create_gateway_compute( + session=session, backend_id=backend.id, gateway_id=gateway.id, ip_address="2.2.2.2" + ) + compute1.replica_num = 1 + gateway.lock_token = uuid.uuid4() + gateway.lock_expires_at = datetime(2025, 1, 2, 3, 4, tzinfo=timezone.utc) + gateway.to_be_deleted = True + await session.commit() + + with ( + patch( + "dstack._internal.server.services.backends.get_project_backend_by_type_or_error" + ) as get_backend_mock, + patch( + "dstack._internal.server.background.pipeline_tasks.gateways.gateway_connections_pool.remove" + ) as remove_connection_mock, + ): + backend_mock = Mock() + backend_mock.compute.return_value = Mock(spec=ComputeMockSpec) + get_backend_mock.return_value = backend_mock + + await worker.process(_gateway_to_pipeline_item(gateway)) + + assert backend_mock.compute.return_value.terminate_gateway.call_count == 2 + assert remove_connection_mock.call_count == 2 + + await session.refresh(compute0) + await session.refresh(compute1) + res = await session.execute(select(GatewayModel.id).where(GatewayModel.id == gateway.id)) + assert res.scalar_one_or_none() is None + assert 
compute0.active is False + assert compute0.deleted is True + assert compute1.active is False + assert compute1.deleted is True + events = await list_events(session) + assert len(events) == 1 + assert events[0].message == "Gateway deleted" + + async def test_keeps_gateway_if_second_replica_terminate_fails( + self, test_db, session: AsyncSession, worker: GatewayWorker + ): + project = await create_project(session=session) + backend = await create_backend(session=session, project_id=project.id) + gateway = await create_gateway( + session=session, + project_id=project.id, + backend_id=backend.id, + status=GatewayStatus.RUNNING, + ) + compute0 = await create_gateway_compute( + session=session, backend_id=backend.id, gateway_id=gateway.id, ip_address="1.1.1.1" + ) + compute1 = await create_gateway_compute( + session=session, backend_id=backend.id, gateway_id=gateway.id, ip_address="2.2.2.2" + ) + compute1.replica_num = 1 + gateway.lock_token = uuid.uuid4() + gateway.lock_expires_at = datetime(2025, 1, 2, 3, 4, tzinfo=timezone.utc) + gateway.lock_owner = "GatewayPipeline" + gateway.to_be_deleted = True + original_last_processed_at = gateway.last_processed_at + await session.commit() + + with ( + patch( + "dstack._internal.server.services.backends.get_project_backend_by_type_or_error" + ) as get_backend_mock, + patch( + "dstack._internal.server.background.pipeline_tasks.gateways.gateway_connections_pool.remove" + ) as remove_connection_mock, + ): + backend_mock = Mock() + backend_mock.compute.return_value = Mock(spec=ComputeMockSpec) + backend_mock.compute.return_value.terminate_gateway.side_effect = [ + None, + BackendError("Terminate failed"), + ] + get_backend_mock.return_value = backend_mock + + await worker.process(_gateway_to_pipeline_item(gateway)) + + assert backend_mock.compute.return_value.terminate_gateway.call_count == 2 + remove_connection_mock.assert_called_once_with(compute0.ip_address) + + await session.refresh(gateway) + await session.refresh(compute0) + 
await session.refresh(compute1) + assert gateway.to_be_deleted is True + assert gateway.last_processed_at > original_last_processed_at + assert gateway.lock_token is None + assert gateway.lock_expires_at is None + assert gateway.lock_owner is None + assert compute0.deleted is False + assert compute1.deleted is False diff --git a/src/tests/_internal/server/background/pipeline_tasks/test_running_jobs.py b/src/tests/_internal/server/background/pipeline_tasks/test_running_jobs.py index e308b89ce8..85aa00e0b6 100644 --- a/src/tests/_internal/server/background/pipeline_tasks/test_running_jobs.py +++ b/src/tests/_internal/server/background/pipeline_tasks/test_running_jobs.py @@ -1892,19 +1892,19 @@ async def test_registers_service_replica_in_gateway( project = await create_project(session=session, owner=user) repo = await create_repo(session=session, project_id=project.id) backend = await create_backend(session=session, project_id=project.id) - gateway_compute = await create_gateway_compute( - session=session, - backend_id=backend.id, - ) gateway = await create_gateway( session=session, project_id=project.id, backend_id=backend.id, - gateway_compute_id=gateway_compute.id, status=GatewayStatus.RUNNING, name="test-gateway", wildcard_domain="example.com", ) + await create_gateway_compute( + session=session, + backend_id=backend.id, + gateway_id=gateway.id, + ) run = await create_run( session=session, project=project, @@ -1985,19 +1985,19 @@ async def test_registers_service_replica_in_gateway_when_running_on_imported_ins ) repo = await create_repo(session=session, project_id=importer_project.id) backend = await create_backend(session=session, project_id=importer_project.id) - gateway_compute = await create_gateway_compute( - session=session, - backend_id=backend.id, - ) gateway = await create_gateway( session=session, project_id=importer_project.id, backend_id=backend.id, - gateway_compute_id=gateway_compute.id, status=GatewayStatus.RUNNING, name="test-gateway", 
wildcard_domain="example.com", ) + await create_gateway_compute( + session=session, + backend_id=backend.id, + gateway_id=gateway.id, + ) run = await create_run( session=session, project=importer_project, diff --git a/src/tests/_internal/server/compatibility/__init__.py b/src/tests/_internal/server/compatibility/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/tests/_internal/server/compatibility/test_gateways.py b/src/tests/_internal/server/compatibility/test_gateways.py new file mode 100644 index 0000000000..4bbd2a80d1 --- /dev/null +++ b/src/tests/_internal/server/compatibility/test_gateways.py @@ -0,0 +1,88 @@ +import uuid +from datetime import datetime, timezone + +from packaging.version import Version + +from dstack._internal.core.models.backends.base import BackendType +from dstack._internal.core.models.gateways import ( + Gateway, + GatewayConfiguration, + GatewayReplica, + GatewayStatus, +) +from dstack._internal.server.compatibility.gateways import patch_gateway +from dstack._internal.utils.common import get_current_datetime + +_CREATED_AT = datetime(2025, 1, 1, tzinfo=timezone.utc) +_CONFIG = GatewayConfiguration(name="gw", backend=BackendType.AWS, region="us") + + +def _make_gateway_replica(hostname: str = "1.2.3.4") -> GatewayReplica: + return GatewayReplica( + hostname=hostname, + replica_num=0, + backend=BackendType.AWS, + region="us", + created_at=get_current_datetime(), + ) + + +def _make_gateway(replicas=None, hostname=None) -> Gateway: + return Gateway( + id=uuid.uuid4(), + name="test", + project_name="proj", + backend=BackendType.AWS, + region="us", + created_at=_CREATED_AT, + status=GatewayStatus.RUNNING, + status_message=None, + hostname=hostname, + wildcard_domain=None, + default=False, + replicas=replicas or [], + configuration=_CONFIG, + ) + + +class TestPatchGateway: + def test_none_version_is_noop(self): + replica = _make_gateway_replica("1.2.3.4") + gw = _make_gateway(replicas=[replica]) + patch_gateway(gw, 
None) + assert gw.ip_address is None + assert gw.instance_id is None + assert gw.hostname is None + + def test_new_version_is_noop(self): + replica = _make_gateway_replica("1.2.3.4") + gw = _make_gateway(replicas=[replica]) + patch_gateway(gw, Version("0.20.25")) + assert gw.ip_address is None + assert gw.instance_id is None + + def test_old_version_fills_hostname_from_replica(self): + replica = _make_gateway_replica("1.2.3.4") + gw = _make_gateway(replicas=[replica], hostname=None) + patch_gateway(gw, Version("0.20.24")) + assert gw.hostname == "1.2.3.4" + + def test_old_version_keeps_existing_hostname(self): + replica = _make_gateway_replica("1.2.3.4") + gw = _make_gateway(replicas=[replica], hostname="lb.example.com") + patch_gateway(gw, Version("0.20.24")) + assert gw.hostname == "lb.example.com" + + def test_old_version_no_replicas_sets_empty_strings(self): + gw = _make_gateway(replicas=[]) + patch_gateway(gw, Version("0.20.24")) + assert gw.ip_address == "" + assert gw.instance_id == "" + assert gw.hostname == "" + + def test_old_version_multi_replica_is_noop(self): + replicas = [_make_gateway_replica("1.2.3.4"), _make_gateway_replica("5.6.7.8")] + gw = _make_gateway(replicas=replicas) + patch_gateway(gw, Version("0.20.24")) + assert gw.ip_address is None + assert gw.instance_id is None diff --git a/src/tests/_internal/server/routers/test_gateways.py b/src/tests/_internal/server/routers/test_gateways.py index 5cc6bdd715..075d1f6d4a 100644 --- a/src/tests/_internal/server/routers/test_gateways.py +++ b/src/tests/_internal/server/routers/test_gateways.py @@ -1,3 +1,4 @@ +from typing import Any from unittest.mock import patch import pytest @@ -29,23 +30,29 @@ async def test_returns_40x_if_not_authenticated(self, client: AsyncClient): @pytest.mark.asyncio @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) - async def test_list(self, test_db, session: AsyncSession, client: AsyncClient): + @pytest.mark.parametrize("legacy_compute", [False, 
True]) + async def test_list( + self, test_db, session: AsyncSession, client: AsyncClient, legacy_compute: bool + ): user = await create_user(session, global_role=GlobalRole.USER) project = await create_project(session) await add_project_member( session=session, project=project, user=user, project_role=ProjectRole.USER ) backend = await create_backend(session=session, project_id=project.id) - gateway_compute = await create_gateway_compute( - session=session, - backend_id=backend.id, - ) gateway = await create_gateway( session=session, project_id=project.id, backend_id=backend.id, - gateway_compute_id=gateway_compute.id, ) + if legacy_compute: + gateway_compute = await create_gateway_compute(session=session, backend_id=backend.id) + gateway.gateway_compute_id = gateway_compute.id # pre-0.20.25 relationship style + else: + gateway_compute = await create_gateway_compute( + session=session, backend_id=backend.id, gateway_id=gateway.id + ) + await session.commit() response = await client.post( f"/api/project/{project.name}/gateways/list", headers=get_auth_headers(user.token), @@ -60,9 +67,18 @@ async def test_list(self, test_db, session: AsyncSession, client: AsyncClient): "default": False, "status": "submitted", "status_message": None, - "instance_id": gateway_compute.instance_id, - "ip_address": gateway_compute.ip_address, - "hostname": gateway_compute.ip_address, + "replicas": [ + { + "hostname": gateway_compute.ip_address, + "replica_num": 0, + "backend": backend.type.value, + "region": "us", + "created_at": response.json()[0]["replicas"][0]["created_at"], + } + ], + "instance_id": None, + "ip_address": None, + "hostname": None, "name": gateway.name, "region": gateway.region, "wildcard_domain": gateway.wildcard_domain, @@ -78,29 +94,36 @@ async def test_list(self, test_db, session: AsyncSession, client: AsyncClient): "public_ip": True, "certificate": {"type": "lets-encrypt"}, "tags": None, + "replicas": None, }, } ] @pytest.mark.asyncio 
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) - async def test_get(self, test_db, session: AsyncSession, client: AsyncClient): + @pytest.mark.parametrize("legacy_compute", [False, True]) + async def test_get( + self, test_db, session: AsyncSession, client: AsyncClient, legacy_compute: bool + ): user = await create_user(session, global_role=GlobalRole.USER) project = await create_project(session) await add_project_member( session=session, project=project, user=user, project_role=ProjectRole.USER ) backend = await create_backend(session, project.id) - gateway_compute = await create_gateway_compute( - session=session, - backend_id=backend.id, - ) gateway = await create_gateway( session=session, project_id=project.id, backend_id=backend.id, - gateway_compute_id=gateway_compute.id, ) + if legacy_compute: + gateway_compute = await create_gateway_compute(session=session, backend_id=backend.id) + gateway.gateway_compute_id = gateway_compute.id # pre-0.20.25 relationship style + else: + gateway_compute = await create_gateway_compute( + session=session, backend_id=backend.id, gateway_id=gateway.id + ) + await session.commit() response = await client.post( f"/api/project/{project.name}/gateways/get", json={"name": gateway.name}, @@ -115,9 +138,18 @@ async def test_get(self, test_db, session: AsyncSession, client: AsyncClient): "default": False, "status": "submitted", "status_message": None, - "instance_id": gateway_compute.instance_id, - "ip_address": gateway_compute.ip_address, - "hostname": gateway_compute.ip_address, + "replicas": [ + { + "hostname": gateway_compute.ip_address, + "replica_num": 0, + "backend": backend.type.value, + "region": "us", + "created_at": response.json()["replicas"][0]["created_at"], + } + ], + "instance_id": None, + "ip_address": None, + "hostname": None, "name": gateway.name, "region": gateway.region, "wildcard_domain": gateway.wildcard_domain, @@ -133,26 +165,60 @@ async def test_get(self, test_db, session: 
AsyncSession, client: AsyncClient): "public_ip": True, "certificate": {"type": "lets-encrypt"}, "tags": None, + "replicas": None, }, } @pytest.mark.asyncio @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) - async def test_list_non_member_public_project( + async def test_list_legacy_client_populates_compat_fields( self, test_db, session: AsyncSession, client: AsyncClient ): + """Old clients (< 0.20.25) get ip_address/instance_id/hostname back-filled.""" user = await create_user(session, global_role=GlobalRole.USER) - project = await create_project(session, is_public=True) + project = await create_project(session) + await add_project_member( + session=session, project=project, user=user, project_role=ProjectRole.USER + ) backend = await create_backend(session=session, project_id=project.id) + gateway = await create_gateway( + session=session, + project_id=project.id, + backend_id=backend.id, + ) gateway_compute = await create_gateway_compute( session=session, backend_id=backend.id, + gateway_id=gateway.id, ) + response = await client.post( + f"/api/project/{project.name}/gateways/list", + headers={**get_auth_headers(user.token), "x-api-version": "0.20.24"}, + ) + assert response.status_code == 200 + assert len(response.json()) == 1 + gw = response.json()[0] + assert gw["ip_address"] == gateway_compute.ip_address + assert gw["instance_id"] == "" + assert gw["hostname"] == gateway_compute.ip_address + + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_list_non_member_public_project( + self, test_db, session: AsyncSession, client: AsyncClient + ): + user = await create_user(session, global_role=GlobalRole.USER) + project = await create_project(session, is_public=True) + backend = await create_backend(session=session, project_id=project.id) gateway = await create_gateway( session=session, project_id=project.id, backend_id=backend.id, - gateway_compute_id=gateway_compute.id, + ) + 
await create_gateway_compute( + session=session, + backend_id=backend.id, + gateway_id=gateway.id, ) response = await client.post( f"/api/project/{project.name}/gateways/list", @@ -170,15 +236,15 @@ async def test_get_non_member_public_project( user = await create_user(session, global_role=GlobalRole.USER) project = await create_project(session, is_public=True) backend = await create_backend(session, project.id) - gateway_compute = await create_gateway_compute( + gateway = await create_gateway( session=session, + project_id=project.id, backend_id=backend.id, ) - gateway = await create_gateway( + await create_gateway_compute( session=session, - project_id=project.id, backend_id=backend.id, - gateway_compute_id=gateway_compute.id, + gateway_id=gateway.id, ) response = await client.post( f"/api/project/{project.name}/gateways/get", @@ -222,14 +288,13 @@ async def test_list_returns_imported_gateway_with_include_imported( project_role=ProjectRole.ADMIN, ) backend = await create_backend(session=session, project_id=exporter_project.id) - gateway_compute = await create_gateway_compute(session=session, backend_id=backend.id) gateway = await create_gateway( session=session, project_id=exporter_project.id, backend_id=backend.id, - gateway_compute_id=gateway_compute.id, name="exported-gateway", ) + await create_gateway_compute(session=session, backend_id=backend.id, gateway_id=gateway.id) await create_export( session=session, exporter_project=exporter_project, @@ -266,14 +331,13 @@ async def test_list_not_returns_imported_gateway_without_include_imported( project_role=ProjectRole.ADMIN, ) backend = await create_backend(session=session, project_id=exporter_project.id) - gateway_compute = await create_gateway_compute(session=session, backend_id=backend.id) gateway = await create_gateway( session=session, project_id=exporter_project.id, backend_id=backend.id, - gateway_compute_id=gateway_compute.id, name="exported-gateway", ) + await create_gateway_compute(session=session, 
backend_id=backend.id, gateway_id=gateway.id) await create_export( session=session, exporter_project=exporter_project, @@ -308,14 +372,13 @@ async def test_get_returns_imported_gateway( project_role=ProjectRole.ADMIN, ) backend = await create_backend(session=session, project_id=exporter_project.id) - gateway_compute = await create_gateway_compute(session=session, backend_id=backend.id) gateway = await create_gateway( session=session, project_id=exporter_project.id, backend_id=backend.id, - gateway_compute_id=gateway_compute.id, name="exported-gateway", ) + await create_gateway_compute(session=session, backend_id=backend.id, gateway_id=gateway.id) await create_export( session=session, exporter_project=exporter_project, @@ -356,14 +419,13 @@ async def test_get_returns_403_on_foreign_gateway_if_not_imported( project_role=ProjectRole.USER, ) backend = await create_backend(session=session, project_id=exporter_project.id) - gateway_compute = await create_gateway_compute(session=session, backend_id=backend.id) gateway = await create_gateway( session=session, project_id=exporter_project.id, backend_id=backend.id, - gateway_compute_id=gateway_compute.id, name="exported-gateway", ) + await create_gateway_compute(session=session, backend_id=backend.id, gateway_id=gateway.id) await create_export( session=session, exporter_project=exporter_project, @@ -426,9 +488,10 @@ async def test_create_gateway(self, test_db, session: AsyncSession, client: Asyn "region": "us", "status": "submitted", "status_message": None, - "instance_id": "", - "ip_address": "", - "hostname": "", + "replicas": [], + "instance_id": None, + "ip_address": None, + "hostname": None, "wildcard_domain": None, "default": True, "created_at": response.json()["created_at"], @@ -444,11 +507,43 @@ async def test_create_gateway(self, test_db, session: AsyncSession, client: Asyn "public_ip": True, "certificate": {"type": "lets-encrypt"}, "tags": None, + "replicas": None, }, } events = await list_events(session) assert 
events[0].message == "Gateway created. Status: SUBMITTED" + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_create_multi_replica_gateway( + self, test_db, session: AsyncSession, client: AsyncClient + ): + user = await create_user(session, global_role=GlobalRole.USER) + project = await create_project(session) + await add_project_member( + session=session, project=project, user=user, project_role=ProjectRole.ADMIN + ) + await create_backend(session, project.id, backend_type=BackendType.AWS) + response = await client.post( + f"/api/project/{project.name}/gateways/create", + json={ + "configuration": { + "type": "gateway", + "name": "test", + "backend": "aws", + "region": "us", + "replicas": 2, + "certificate": None, + }, + }, + headers=get_auth_headers(user.token), + ) + assert response.status_code == 200 + assert response.json()["configuration"]["replicas"] == 2 + assert response.json()["replicas"] == [] # populated later by pipelines + events = await list_events(session) + assert events[0].message == "Gateway created. 
Status: SUBMITTED" + @pytest.mark.asyncio @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) async def test_create_gateway_without_name( @@ -484,9 +579,10 @@ async def test_create_gateway_without_name( "region": "us", "status": "submitted", "status_message": None, - "instance_id": "", - "ip_address": "", - "hostname": "", + "replicas": [], + "instance_id": None, + "ip_address": None, + "hostname": None, "wildcard_domain": None, "default": True, "created_at": response.json()["created_at"], @@ -502,6 +598,7 @@ async def test_create_gateway_without_name( "public_ip": True, "certificate": {"type": "lets-encrypt"}, "tags": None, + "replicas": None, }, } events = await list_events(session) @@ -583,6 +680,100 @@ async def test_create_gateway_with_invalid_domain_interpolation( ) assert response.status_code == 400 + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + @pytest.mark.parametrize( + "configuration, expected_error", + [ + pytest.param( + { + "type": "gateway", + "name": "test", + "backend": "aws", + "region": "us", + "domain": "${{ run.unknown_variable }}.example.com", + }, + "Cannot interpolate gateway domain name: Failed to interpolate due to missing vars: ['run.unknown_variable']", + id="invalid-domain-interpolation", + ), + pytest.param( + { + "type": "gateway", + "name": "test", + "backend": "aws", + "region": "us", + "certificate": { + "type": "acm", + "arn": "arn:aws:acm:us-east-1:123456789:certificate/abc", + }, + "replicas": 2, + }, + "Replicated gateways do not support certificates." + " Set either `certificate: null` or `replicas: 1` in the gateway configuration", + id="multi-replica-with-acm-cert", + ), + pytest.param( + { + "type": "gateway", + "name": "test", + "backend": "aws", + "region": "us", + "certificate": {"type": "lets-encrypt"}, + "replicas": 2, + }, + "Replicated gateways do not support certificates." 
+ " Set either `certificate: null` or `replicas: 1` in the gateway configuration", + id="multi-replica-with-letsencrypt-cert", + ), + pytest.param( + { + "type": "gateway", + "name": "test", + "backend": "aws", + "region": "us", + "certificate": None, + "router": {"type": "sglang"}, + "replicas": 2, + }, + "The deprecated `router` property is not supported for multi-replica gateways", + id="multi-replica-with-router", + ), + pytest.param( + { + "type": "gateway", + "name": "test", + "backend": "aws", + "region": "us", + "certificate": None, + "replicas": 4, + }, + "Cannot provision 4 gateway replicas. This server allows at most 3", + id="replicas-exceed-max", + ), + ], + ) + async def test_invalid_configuration_rejected( + self, + test_db, + session: AsyncSession, + client: AsyncClient, + configuration: dict[str, Any], + expected_error: str, + ): + user = await create_user(session, global_role=GlobalRole.USER) + project = await create_project(session) + await add_project_member( + session=session, project=project, user=user, project_role=ProjectRole.ADMIN + ) + await create_backend(session, project.id, backend_type=BackendType.AWS) + response = await client.post( + f"/api/project/{project.name}/gateways/create", + json={"configuration": configuration}, + headers=get_auth_headers(user.token), + ) + assert response.status_code == 400 + assert response.json()["detail"][0]["msg"] == expected_error + class TestDefaultGateway: @pytest.mark.asyncio @@ -613,17 +804,17 @@ async def test_set_default_gateway(self, test_db, session: AsyncSession, client: session=session, project=project, user=user, project_role=ProjectRole.ADMIN ) backend = await create_backend(session, project.id) - gateway_compute = await create_gateway_compute( - session=session, - backend_id=backend.id, - ) gateway = await create_gateway( session=session, project_id=project.id, backend_id=backend.id, - gateway_compute_id=gateway_compute.id, name="first_gateway", ) + gateway_compute = await 
create_gateway_compute( + session=session, + backend_id=backend.id, + gateway_id=gateway.id, + ) response = await client.post( f"/api/project/{project.name}/gateways/set_default", json={"name": gateway.name}, @@ -645,9 +836,18 @@ async def test_set_default_gateway(self, test_db, session: AsyncSession, client: "default": True, "status": "submitted", "status_message": None, - "instance_id": gateway_compute.instance_id, - "ip_address": gateway_compute.ip_address, - "hostname": gateway_compute.ip_address, + "replicas": [ + { + "hostname": gateway_compute.ip_address, + "replica_num": 0, + "backend": backend.type.value, + "region": "us", + "created_at": response.json()["replicas"][0]["created_at"], + } + ], + "instance_id": None, + "ip_address": None, + "hostname": None, "name": gateway.name, "region": gateway.region, "wildcard_domain": gateway.wildcard_domain, @@ -663,23 +863,24 @@ async def test_set_default_gateway(self, test_db, session: AsyncSession, client: "public_ip": True, "certificate": {"type": "lets-encrypt"}, "tags": None, + "replicas": None, }, } events = await list_events(session) assert len(events) == 1 assert events[0].message == "Gateway set as project default" - second_gateway_compute = await create_gateway_compute( - session=session, - backend_id=backend.id, - ) second_gateway = await create_gateway( session=session, project_id=project.id, backend_id=backend.id, - gateway_compute_id=second_gateway_compute.id, name="second_gateway", ) + await create_gateway_compute( + session=session, + backend_id=backend.id, + gateway_id=second_gateway.id, + ) await clear_events(session) response = await client.post( f"/api/project/{project.name}/gateways/set_default", @@ -775,14 +976,13 @@ async def test_set_imported_gateway_as_default( project_role=ProjectRole.ADMIN, ) backend = await create_backend(session=session, project_id=exporter_project.id) - gateway_compute = await create_gateway_compute(session=session, backend_id=backend.id) gateway = await create_gateway( 
session=session, project_id=exporter_project.id, backend_id=backend.id, - gateway_compute_id=gateway_compute.id, name="exported-gateway", ) + await create_gateway_compute(session=session, backend_id=backend.id, gateway_id=gateway.id) await create_export( session=session, exporter_project=exporter_project, @@ -820,14 +1020,13 @@ async def test_cannot_set_non_imported_foreign_gateway_as_default( project_role=ProjectRole.ADMIN, ) backend = await create_backend(session=session, project_id=exporter_project.id) - gateway_compute = await create_gateway_compute(session=session, backend_id=backend.id) gateway = await create_gateway( session=session, project_id=exporter_project.id, backend_id=backend.id, - gateway_compute_id=gateway_compute.id, name="exported-gateway", ) + await create_gateway_compute(session=session, backend_id=backend.id, gateway_id=gateway.id) await create_export( session=session, exporter_project=exporter_project, @@ -872,27 +1071,27 @@ async def test_marks_gateways_to_be_deleted( ) backend_aws = await create_backend(session, project.id) backend_gcp = await create_backend(session, project.id, backend_type=BackendType.GCP) - gateway_compute_aws = await create_gateway_compute( - session=session, - backend_id=backend_aws.id, - ) gateway_aws = await create_gateway( session=session, project_id=project.id, backend_id=backend_aws.id, name="gateway-aws", - gateway_compute_id=gateway_compute_aws.id, ) - gateway_compute_gcp = await create_gateway_compute( + gateway_compute_aws = await create_gateway_compute( session=session, - backend_id=backend_gcp.id, + backend_id=backend_aws.id, + gateway_id=gateway_aws.id, ) gateway_gcp = await create_gateway( session=session, project_id=project.id, backend_id=backend_gcp.id, name="gateway-gcp", - gateway_compute_id=gateway_compute_gcp.id, + ) + gateway_compute_gcp = await create_gateway_compute( + session=session, + backend_id=backend_gcp.id, + gateway_id=gateway_gcp.id, ) response = await client.post( 
f"/api/project/{project.name}/gateways/delete", @@ -991,17 +1190,17 @@ async def test_set_wildcard_domain(self, test_db, session: AsyncSession, client: session=session, project=project, user=user, project_role=ProjectRole.ADMIN ) backend = await create_backend(session, project.id) - gateway_compute = await create_gateway_compute( - session=session, - backend_id=backend.id, - ) gateway = await create_gateway( session=session, project_id=project.id, backend_id=backend.id, - gateway_compute_id=gateway_compute.id, wildcard_domain="old.example", ) + gateway_compute = await create_gateway_compute( + session=session, + backend_id=backend.id, + gateway_id=gateway.id, + ) response = await client.post( f"/api/project/{project.name}/gateways/set_wildcard_domain", json={"name": gateway.name, "wildcard_domain": "new.example"}, @@ -1016,9 +1215,18 @@ async def test_set_wildcard_domain(self, test_db, session: AsyncSession, client: "status": "submitted", "status_message": None, "default": False, - "instance_id": gateway_compute.instance_id, - "ip_address": gateway_compute.ip_address, - "hostname": gateway_compute.ip_address, + "replicas": [ + { + "hostname": gateway_compute.ip_address, + "replica_num": 0, + "backend": backend.type.value, + "region": "us", + "created_at": response.json()["replicas"][0]["created_at"], + } + ], + "instance_id": None, + "ip_address": None, + "hostname": None, "name": gateway.name, "region": gateway.region, "wildcard_domain": "new.example", @@ -1034,6 +1242,7 @@ async def test_set_wildcard_domain(self, test_db, session: AsyncSession, client: "public_ip": True, "certificate": {"type": "lets-encrypt"}, "tags": None, + "replicas": None, }, } events = await list_events(session) diff --git a/src/tests/_internal/server/routers/test_runs.py b/src/tests/_internal/server/routers/test_runs.py index 93414c9f42..01bc19f5eb 100644 --- a/src/tests/_internal/server/routers/test_runs.py +++ b/src/tests/_internal/server/routers/test_runs.py @@ -3748,19 +3748,19 @@ 
async def test_submit_to_correct_proxy( repo = await create_repo(session=session, project_id=project.id) backend = await create_backend(session=session, project_id=project.id) for gateway_name, is_default in existing_gateways: - gateway_compute = await create_gateway_compute( - session=session, - backend_id=backend.id, - ) gateway = await create_gateway( session=session, project_id=project.id, backend_id=backend.id, - gateway_compute_id=gateway_compute.id, status=GatewayStatus.RUNNING, name=gateway_name, wildcard_domain=f"{gateway_name}.example", ) + await create_gateway_compute( + session=session, + backend_id=backend.id, + gateway_id=gateway.id, + ) if is_default: project.default_gateway_id = gateway.id await session.commit() @@ -3844,16 +3844,15 @@ async def test_submit_to_foreign_gateway_only_if_imported( session=session, owner=exporter_user, name="exporter-project" ) backend = await create_backend(session=session, project_id=exporter_project.id) - gateway_compute = await create_gateway_compute(session=session, backend_id=backend.id) gateway = await create_gateway( session=session, project_id=exporter_project.id, backend_id=backend.id, - gateway_compute_id=gateway_compute.id, status=GatewayStatus.RUNNING, name="exported-gateway", wildcard_domain="exported-gateway.example", ) + await create_gateway_compute(session=session, backend_id=backend.id, gateway_id=gateway.id) importer_user = await create_user( session=session, global_role=GlobalRole.USER, name="importer_user" @@ -3929,14 +3928,13 @@ async def test_not_submits_to_default_gateway_if_not_imported( user = await create_user(session=session, global_role=GlobalRole.USER) gateway_project = await create_project(session=session, owner=user, name="gateway-project") backend = await create_backend(session=session, project_id=gateway_project.id) - gateway_compute = await create_gateway_compute(session=session, backend_id=backend.id) gateway = await create_gateway( session=session, project_id=gateway_project.id, 
backend_id=backend.id, - gateway_compute_id=gateway_compute.id, status=GatewayStatus.RUNNING, ) + await create_gateway_compute(session=session, backend_id=backend.id, gateway_id=gateway.id) service_project = await create_project(session=session, owner=user, name="service-project") # The project's default_gateway_id may point to the gateway (e.g., if the gateway was @@ -3982,16 +3980,15 @@ async def test_interpolates_project_name_in_imported_gateway_domain( session=session, owner=exporter_user, name="exporter-project" ) backend = await create_backend(session=session, project_id=exporter_project.id) - gateway_compute = await create_gateway_compute(session=session, backend_id=backend.id) gateway = await create_gateway( session=session, project_id=exporter_project.id, backend_id=backend.id, - gateway_compute_id=gateway_compute.id, status=GatewayStatus.RUNNING, name="exported-gateway", wildcard_domain="${{ run.project_name }}.example.com", ) + await create_gateway_compute(session=session, backend_id=backend.id, gateway_id=gateway.id) importer_user = await create_user( session=session, global_role=GlobalRole.USER, name="importer_user" @@ -4041,16 +4038,15 @@ async def test_returns_error_if_imported_gateway_domain_has_unknown_variable( session=session, owner=exporter_user, name="exporter-project" ) backend = await create_backend(session=session, project_id=exporter_project.id) - gateway_compute = await create_gateway_compute(session=session, backend_id=backend.id) gateway = await create_gateway( session=session, project_id=exporter_project.id, backend_id=backend.id, - gateway_compute_id=gateway_compute.id, status=GatewayStatus.RUNNING, name="exported-gateway", wildcard_domain="${{ run.unknown_variable }}.example.com", ) + await create_gateway_compute(session=session, backend_id=backend.id, gateway_id=gateway.id) importer_user = await create_user( session=session, global_role=GlobalRole.USER, name="importer_user" @@ -4108,15 +4104,14 @@ async def 
test_unregister_dangling_service( ) repo = await create_repo(session=session, project_id=project.id) backend = await create_backend(session=session, project_id=project.id) - gateway_compute = await create_gateway_compute(session=session, backend_id=backend.id) gateway = await create_gateway( session=session, project_id=project.id, backend_id=backend.id, - gateway_compute_id=gateway_compute.id, status=GatewayStatus.RUNNING, wildcard_domain="example.com", ) + await create_gateway_compute(session=session, backend_id=backend.id, gateway_id=gateway.id) project.default_gateway_id = gateway.id await session.commit() @@ -4158,16 +4153,15 @@ async def test_return_error_if_default_gateway_forbids_new_services( ) repo = await create_repo(session=session, project_id=project.id) backend = await create_backend(session=session, project_id=project.id) - gateway_compute = await create_gateway_compute(session=session, backend_id=backend.id) gateway = await create_gateway( session=session, project_id=project.id, backend_id=backend.id, - gateway_compute_id=gateway_compute.id, status=GatewayStatus.RUNNING, wildcard_domain="example.com", forbid_new_services=True, ) + await create_gateway_compute(session=session, backend_id=backend.id, gateway_id=gateway.id) project.default_gateway_id = gateway.id await session.commit() @@ -4196,17 +4190,16 @@ async def test_return_error_if_explicitly_specified_gateway_forbids_new_services ) repo = await create_repo(session=session, project_id=project.id) backend = await create_backend(session=session, project_id=project.id) - gateway_compute = await create_gateway_compute(session=session, backend_id=backend.id) - await create_gateway( + gateway = await create_gateway( session=session, project_id=project.id, backend_id=backend.id, - gateway_compute_id=gateway_compute.id, status=GatewayStatus.RUNNING, name="restricted-gateway", wildcard_domain="example.com", forbid_new_services=True, ) + await create_gateway_compute(session=session, backend_id=backend.id, 
gateway_id=gateway.id) response = await client.post( "/api/project/test-project/runs/submit", diff --git a/src/tests/_internal/server/services/gateways/__init__.py b/src/tests/_internal/server/services/gateways/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/tests/_internal/server/services/gateways/test_gateways.py b/src/tests/_internal/server/services/gateways/test_gateways.py new file mode 100644 index 0000000000..aaf8fe6d52 --- /dev/null +++ b/src/tests/_internal/server/services/gateways/test_gateways.py @@ -0,0 +1,88 @@ +import pytest +from sqlalchemy.ext.asyncio import AsyncSession + +from dstack._internal.proxy.gateway.const import SERVICE_SCALING_WINDOWS +from dstack._internal.proxy.gateway.schemas.stats import Stat +from dstack._internal.server.services.gateways import ( + _merge_per_window_stats, + get_gateway_compute_models, +) +from dstack._internal.server.testing.common import ( + create_backend, + create_gateway, + create_gateway_compute, + create_project, +) + + +class TestMergePerWindowStats: + def test_empty_returns_zero_stats(self): + result = _merge_per_window_stats([]) + for window in SERVICE_SCALING_WINDOWS: + assert result[window].requests == 0 + assert result[window].request_time == 0.0 + + def test_single_replica_returns_same_values(self): + stats = {w: Stat(requests=10, request_time=0.5) for w in SERVICE_SCALING_WINDOWS} + result = _merge_per_window_stats([stats]) + for window in SERVICE_SCALING_WINDOWS: + assert result[window].requests == 10 + assert result[window].request_time == pytest.approx(0.5) + + def test_multiple_replicas_sums_requests_and_averages_time(self): + stats_a = {w: Stat(requests=10, request_time=1.0) for w in SERVICE_SCALING_WINDOWS} + stats_b = {w: Stat(requests=30, request_time=3.0) for w in SERVICE_SCALING_WINDOWS} + result = _merge_per_window_stats([stats_a, stats_b]) + for window in SERVICE_SCALING_WINDOWS: + assert result[window].requests == 40 + assert result[window].request_time == 
pytest.approx(2.5) # (10*1 + 30*3) / 40 + + def test_zero_requests_across_all_replicas_returns_zero_time(self): + stats_a = {w: Stat(requests=0, request_time=0.0) for w in SERVICE_SCALING_WINDOWS} + stats_b = {w: Stat(requests=0, request_time=0.0) for w in SERVICE_SCALING_WINDOWS} + result = _merge_per_window_stats([stats_a, stats_b]) + for window in SERVICE_SCALING_WINDOWS: + assert result[window].requests == 0 + assert result[window].request_time == 0.0 + + +@pytest.mark.asyncio +@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) +class TestGetGatewayComputeModels: + async def test_new_style_returns_gateway_computes(self, test_db, session: AsyncSession): + project = await create_project(session=session) + backend = await create_backend(session=session, project_id=project.id) + gateway = await create_gateway( + session=session, project_id=project.id, backend_id=backend.id + ) + compute = await create_gateway_compute( + session=session, gateway_id=gateway.id, backend_id=backend.id + ) + await session.refresh(gateway, ["gateway_computes", "gateway_compute"]) + result = get_gateway_compute_models(gateway) + assert len(result) == 1 + assert result[0].id == compute.id + + async def test_old_style_returns_single_compute(self, test_db, session: AsyncSession): + project = await create_project(session=session) + backend = await create_backend(session=session, project_id=project.id) + compute = await create_gateway_compute(session=session, backend_id=backend.id) + gateway = await create_gateway( + session=session, project_id=project.id, backend_id=backend.id + ) + gateway.gateway_compute_id = compute.id + await session.commit() + await session.refresh(gateway, ["gateway_computes", "gateway_compute"]) + result = get_gateway_compute_models(gateway) + assert len(result) == 1 + assert result[0].id == compute.id + + async def test_no_computes_returns_empty(self, test_db, session: AsyncSession): + project = await create_project(session=session) + backend = 
await create_backend(session=session, project_id=project.id) + gateway = await create_gateway( + session=session, project_id=project.id, backend_id=backend.id + ) + await session.refresh(gateway, ["gateway_computes", "gateway_compute"]) + result = get_gateway_compute_models(gateway) + assert result == []