diff --git a/frontend/src/locale/en.json b/frontend/src/locale/en.json
index 7d8c9f2465..804134f3d1 100644
--- a/frontend/src/locale/en.json
+++ b/frontend/src/locale/en.json
@@ -144,6 +144,7 @@
"region_description": "Select a region",
"default": "Default",
"default_checkbox": "Turn on default",
+ "hostname": "Hostname",
"external_ip": "External IP",
"wildcard_domain": "Wildcard domain",
"wildcard_domain_description": "Specify the wildcard domain mapped to the external IP.",
diff --git a/frontend/src/pages/Project/Gateways/Table/hooks/useColumnsDefinitions.tsx b/frontend/src/pages/Project/Gateways/Table/hooks/useColumnsDefinitions.tsx
index 79b4aac22c..f63a77fa14 100644
--- a/frontend/src/pages/Project/Gateways/Table/hooks/useColumnsDefinitions.tsx
+++ b/frontend/src/pages/Project/Gateways/Table/hooks/useColumnsDefinitions.tsx
@@ -30,13 +30,15 @@ export const useColumnsDefinitions = ({ loading, projectName, onDeleteClick, onE
{
id: 'type',
header: t('gateway.edit.backend'),
- cell: (gateway: IGateway) => gateway.backend,
+ cell: (gateway: IGateway) =>
+ gateway.replicas.length > 0 ? gateway.replicas.map((r, i) =>
{r.backend}
) : null,
},
{
id: 'region',
header: t('gateway.edit.region'),
- cell: (gateway: IGateway) => gateway.region,
+ cell: (gateway: IGateway) =>
+ gateway.replicas.length > 0 ? gateway.replicas.map((r, i) => {r.region}
) : null,
},
{
@@ -46,9 +48,13 @@ export const useColumnsDefinitions = ({ loading, projectName, onDeleteClick, onE
},
{
- id: 'external_ip',
- header: t('gateway.edit.external_ip'),
- cell: (gateway: IGateway) => gateway.ip_address,
+ id: 'hostname',
+ header: t('gateway.edit.hostname'),
+ cell: (gateway: IGateway) => {
+ if (gateway.hostname) return gateway.hostname;
+ if (gateway.replicas.length > 0) return gateway.replicas.map((r, i) => {r.hostname}
);
+ return null;
+ },
},
{
diff --git a/frontend/src/types/gateway.d.ts b/frontend/src/types/gateway.d.ts
index 4ef2eeeb54..1442cf4d62 100644
--- a/frontend/src/types/gateway.d.ts
+++ b/frontend/src/types/gateway.d.ts
@@ -1,3 +1,9 @@
+declare interface IGatewayReplica {
+ hostname: string,
+ backend: string,
+ region: string,
+}
+
declare interface IGateway {
backend: string,
name: string,
@@ -5,8 +11,10 @@ declare interface IGateway {
ip_address: string,
instance_id: string,
region:string
+ hostname?: string,
wildcard_domain?: string
default: boolean
+ replicas: IGatewayReplica[],
created_at?: number,
}
diff --git a/mkdocs/docs/concepts/gateways.md b/mkdocs/docs/concepts/gateways.md
index bd71187964..b71a23d7b6 100644
--- a/mkdocs/docs/concepts/gateways.md
+++ b/mkdocs/docs/concepts/gateways.md
@@ -182,6 +182,50 @@ domain: example.com
+### Replicas
+
+A gateway can have multiple replicas for improved availability.
+
+
+
+```yaml
+type: gateway
+name: example-gateway
+
+backend: aws
+region: eu-west-1
+
+domain: example.com
+
+certificate: null
+replicas: 2
+```
+
+
+
+To balance requests between gateway replicas, add DNS records for each replica or set up a load balancer outside of `dstack`. Replica hostnames are displayed in `dstack` CLI and UI.
+
+
+
+```shell
+$ dstack gateway list
+ NAME BACKEND HOSTNAME DOMAIN DEFAULT STATUS
+ example-gateway example.com ✓ running
+ replica=0 aws (eu-west-1) 34.244.128.46
+ replica=1 aws (eu-west-1) 18.201.201.174
+```
+
+
+
+!!! warning "Experimental"
+ Replicated gateways are an experimental feature and currently have limitations:
+
+ - Changing the number of replicas or redeploying replicas is not supported.
+ - HTTPS is not supported. Use an external load balancer for TLS termination.
+ - An unavailable gateway replica prevents any new services or service replicas from being added.
+ - All replicas are bound to the same backend and region.
+ - At most 3 replicas are allowed per gateway.
+
!!! info "Reference"
For all gateway configuration options, refer to the [reference](../reference/dstack.yml/gateway.md).
diff --git a/src/dstack/_internal/cli/services/configurators/gateway.py b/src/dstack/_internal/cli/services/configurators/gateway.py
index 0b6993e18b..9f8e6cd0d1 100644
--- a/src/dstack/_internal/cli/services/configurators/gateway.py
+++ b/src/dstack/_internal/cli/services/configurators/gateway.py
@@ -236,6 +236,10 @@ def th(s: str) -> str:
configuration_table.add_row(th("Region"), plan.spec.configuration.region)
configuration_table.add_row(th("Domain"), domain)
+ if plan.spec.configuration.replicas is not None:
+ assert isinstance(plan.spec.configuration.replicas, int)
+ configuration_table.add_row(th("Replicas"), str(plan.spec.configuration.replicas))
+
console.print(configuration_table)
console.print()
diff --git a/src/dstack/_internal/cli/utils/gateway.py b/src/dstack/_internal/cli/utils/gateway.py
index 5605326211..0d873a9a5d 100644
--- a/src/dstack/_internal/cli/utils/gateway.py
+++ b/src/dstack/_internal/cli/utils/gateway.py
@@ -95,15 +95,39 @@ def get_gateways_table(
# Ignore errors in case future server versions introduce more interpolation variables
exception_type=None,
)
- row = {
+
+ gateway_row = {
"NAME": name,
- "BACKEND": format_backend(gateway.configuration.backend, gateway.configuration.region),
- "HOSTNAME": gateway.hostname,
"DOMAIN": domain,
"DEFAULT": "✓" if gateway.default else "",
"STATUS": gateway.status,
"CREATED": format_date(gateway.created_at),
"ERROR": gateway.status_message,
}
- add_row_from_dict(table, row)
+ if gateway.hostname is not None:
+ gateway_row["HOSTNAME"] = gateway.hostname
+ if len(gateway.replicas) == 0:
+ # replicas not yet created, or it's a pre-0.20.25 server without replica support
+ gateway_row["BACKEND"] = format_backend(
+ gateway.configuration.backend, gateway.configuration.region
+ )
+ gateway_row["HOSTNAME"] = gateway_row.get("HOSTNAME", gateway.ip_address)
+ if len(gateway.replicas) == 1:
+ # compact display for single-replica gateway
+ gateway_row["BACKEND"] = format_backend(
+ gateway.replicas[0].backend, gateway.replicas[0].region
+ )
+ gateway_row["HOSTNAME"] = gateway_row.get("HOSTNAME", gateway.replicas[0].hostname)
+ add_row_from_dict(table, gateway_row)
+
+ if len(gateway.replicas) > 1:
+ for replica in gateway.replicas:
+ replica_row = {
+ "NAME": f" replica={replica.replica_num}",
+ "BACKEND": format_backend(replica.backend, replica.region),
+ "HOSTNAME": replica.hostname,
+ "CREATED": format_date(replica.created_at),
+ }
+ add_row_from_dict(table, replica_row, style="secondary")
+
return table
diff --git a/src/dstack/_internal/core/compatibility/gateways.py b/src/dstack/_internal/core/compatibility/gateways.py
index a2fc6101e6..0a89e86113 100644
--- a/src/dstack/_internal/core/compatibility/gateways.py
+++ b/src/dstack/_internal/core/compatibility/gateways.py
@@ -41,5 +41,7 @@ def _get_gateway_configuration_excludes(
if configuration.router is None:
configuration_excludes["router"] = True
+ if configuration.replicas is None:
+ configuration_excludes["replicas"] = True
return configuration_excludes
diff --git a/src/dstack/_internal/core/models/gateways.py b/src/dstack/_internal/core/models/gateways.py
index 6f92b449b7..74b3f4e835 100644
--- a/src/dstack/_internal/core/models/gateways.py
+++ b/src/dstack/_internal/core/models/gateways.py
@@ -11,6 +11,8 @@
from dstack._internal.core.models.routers import AnyGatewayRouterConfig
from dstack._internal.utils.tags import tags_validator
+GATEWAY_REPLICAS_DEFAULT = 1
+
class GatewayStatus(str, Enum):
SUBMITTED = "submitted"
@@ -90,6 +92,13 @@ class GatewayConfiguration(CoreModel):
" Set to `null` to disable. Defaults to `type: lets-encrypt`"
),
] = LetsEncryptGatewayCertificate()
+ replicas: Annotated[
+ Optional[int],
+ Field(
+ description=f"The number of gateway replicas. Defaults to `{GATEWAY_REPLICAS_DEFAULT}`",
+ ge=1,
+ ),
+ ] = None
tags: Annotated[
Optional[Dict[str, str]],
Field(
@@ -109,6 +118,14 @@ class GatewaySpec(CoreModel):
configuration_path: Optional[str] = None
+class GatewayReplica(CoreModel):
+ hostname: str
+ replica_num: int
+ backend: BackendType
+ region: str
+ created_at: datetime.datetime
+
+
class Gateway(CoreModel):
# TODO(0.21): Make `id` required.
id: Optional[uuid.UUID] = None
@@ -121,14 +138,13 @@ class Gateway(CoreModel):
status: GatewayStatus
status_message: Optional[str]
hostname: Optional[str]
- """`hostname` is the IP address or hostname the user should set up the domain for.
- Could be the same as `ip_address` but also different, for example a gateway behind ALB.
+ """Hostname of the load balancer.
+ Unset if there is no load balancer, in which case users are expected to point the gateway's
+ wildcard domain name to `replicas[i].hostname`.
"""
- ip_address: Optional[str]
- """`ip_address` is the IP address of the gateway instance."""
- instance_id: Optional[str]
wildcard_domain: Optional[str]
default: bool
+ replicas: list[GatewayReplica] = []
backend: Optional[BackendType] = None
"""`backend` duplicates a configuration field on the top level for backward compatibility
with 0.19.x clients that expect it to be required.
@@ -139,6 +155,10 @@ class Gateway(CoreModel):
with 0.19.x clients that expect it to be required.
Remove after 0.21.
"""
+ ip_address: Optional[str] = None
+ """Deprecated in favor of `replicas[i].hostname`, only set for pre-0.20.25 clients."""
+ instance_id: Optional[str] = None
+ """Deprecated, unused, kept for pre-0.20.25 clients."""
class GatewayPlan(CoreModel):
diff --git a/src/dstack/_internal/server/background/pipeline_tasks/gateways.py b/src/dstack/_internal/server/background/pipeline_tasks/gateways.py
index 1f8f0c64f5..05393eb1ae 100644
--- a/src/dstack/_internal/server/background/pipeline_tasks/gateways.py
+++ b/src/dstack/_internal/server/background/pipeline_tasks/gateways.py
@@ -5,12 +5,12 @@
from typing import Optional, Sequence, TypedDict
from sqlalchemy import delete, or_, select, update
-from sqlalchemy.orm import joinedload, load_only
+from sqlalchemy.orm import joinedload, load_only, selectinload
from dstack._internal.core.backends.base.compute import ComputeWithGatewaySupport
from dstack._internal.core.errors import BackendError, BackendNotAvailable
from dstack._internal.core.models.backends.base import BackendType
-from dstack._internal.core.models.gateways import GatewayStatus
+from dstack._internal.core.models.gateways import GATEWAY_REPLICAS_DEFAULT, GatewayStatus
from dstack._internal.server.background.pipeline_tasks.base import (
Fetcher,
Heartbeater,
@@ -34,7 +34,10 @@
from dstack._internal.server.services import backends as backends_services
from dstack._internal.server.services import events
from dstack._internal.server.services import gateways as gateways_services
-from dstack._internal.server.services.gateways import emit_gateway_status_change_event
+from dstack._internal.server.services.gateways import (
+ emit_gateway_status_change_event,
+ get_gateway_compute_models,
+)
from dstack._internal.server.services.gateways.pool import gateway_connections_pool
from dstack._internal.server.services.locking import get_locker
from dstack._internal.server.services.logging import fmt
@@ -239,11 +242,8 @@ async def _process_submitted_item(item: GatewayPipelineItem):
set_processed_update_map_fields(update_map)
set_unlock_update_map_fields(update_map)
async with get_session_ctx() as session:
- gateway_compute_model = result.gateway_compute_model
- if gateway_compute_model is not None:
+ for gateway_compute_model in result.gateway_compute_models:
session.add(gateway_compute_model)
- await session.flush()
- update_map["gateway_compute_id"] = gateway_compute_model.id
now = get_current_datetime()
resolve_now_placeholders(update_map, now=now)
res = await session.execute(
@@ -258,7 +258,7 @@ async def _process_submitted_item(item: GatewayPipelineItem):
updated_ids = list(res.scalars().all())
if len(updated_ids) == 0:
log_lock_token_changed_after_processing(logger, item)
- # TODO: Clean up gateway_compute_model.
+ # TODO: Clean up gateway_compute_models.
return
emit_gateway_status_change_event(
session=session,
@@ -272,7 +272,6 @@ async def _process_submitted_item(item: GatewayPipelineItem):
class _GatewayUpdateMap(ItemUpdateMap, total=False):
status: GatewayStatus
status_message: str
- gateway_compute_id: uuid.UUID
class _GatewayComputeUpdateMap(TypedDict, total=False):
@@ -283,7 +282,7 @@ class _GatewayComputeUpdateMap(TypedDict, total=False):
@dataclass
class _SubmittedResult:
update_map: _GatewayUpdateMap = field(default_factory=_GatewayUpdateMap)
- gateway_compute_model: Optional[GatewayComputeModel] = None
+ gateway_compute_models: list[GatewayComputeModel] = field(default_factory=list)
async def _process_submitted_gateway(gateway_model: GatewayModel) -> _SubmittedResult:
@@ -303,16 +302,28 @@ async def _process_submitted_gateway(gateway_model: GatewayModel) -> _SubmittedR
"status_message": "Backend not available",
}
)
+ replicas = (
+ configuration.replicas if configuration.replicas is not None else GATEWAY_REPLICAS_DEFAULT
+ )
+ gateway_compute_models = []
try:
- gateway_compute_model = await gateways_services.create_gateway_compute(
- backend_compute=backend.compute(),
- project_name=gateway_model.project.name,
- configuration=configuration,
- backend_id=backend_model.id,
- )
+ for replica_num in range(replicas):
+ logger.debug(
+ "%s replica %d: creating gateway compute", fmt(gateway_model), replica_num
+ )
+ gateway_compute_model = await gateways_services.create_gateway_compute(
+ backend_compute=backend.compute(),
+ project_name=gateway_model.project.name,
+ configuration=configuration,
+ replica_num=replica_num,
+ gateway_id=gateway_model.id,
+ backend_id=backend_model.id,
+ )
+ logger.info("%s replica %d: gateway compute created", fmt(gateway_model), replica_num)
+ gateway_compute_models.append(gateway_compute_model)
return _SubmittedResult(
update_map={"status": GatewayStatus.PROVISIONING},
- gateway_compute_model=gateway_compute_model,
+ gateway_compute_models=gateway_compute_models,
)
except BackendError as e:
status_message = f"Backend error: {repr(e)}"
@@ -322,7 +333,8 @@ async def _process_submitted_gateway(gateway_model: GatewayModel) -> _SubmittedR
update_map={
"status": GatewayStatus.FAILED,
"status_message": status_message,
- }
+ },
+ gateway_compute_models=gateway_compute_models,
)
except Exception as e:
logger.exception("%s: got exception when creating gateway compute", fmt(gateway_model))
@@ -330,7 +342,8 @@ async def _process_submitted_gateway(gateway_model: GatewayModel) -> _SubmittedR
update_map={
"status": GatewayStatus.FAILED,
"status_message": f"Unexpected error: {repr(e)}",
- }
+ },
+ gateway_compute_models=gateway_compute_models,
)
@@ -343,6 +356,7 @@ async def _process_provisioning_item(item: GatewayPipelineItem):
GatewayModel.lock_token == item.lock_token,
)
.options(joinedload(GatewayModel.gateway_compute))
+ .options(selectinload(GatewayModel.gateway_computes))
)
gateway_model = res.unique().scalar_one_or_none()
if gateway_model is None:
@@ -377,34 +391,39 @@ async def _process_provisioning_item(item: GatewayPipelineItem):
new_status=gateway_update_map.get("status", gateway_model.status),
status_message=gateway_update_map.get("status_message", gateway_model.status_message),
)
- if result.gateway_compute_update_map:
+ if result.all_computes_update_map:
res = await session.execute(
update(GatewayComputeModel)
- .where(GatewayComputeModel.id == gateway_model.gateway_compute_id)
- .values(**result.gateway_compute_update_map)
+ .where(
+ or_(
+ GatewayComputeModel.gateway_id == gateway_model.id,
+ GatewayComputeModel.id == gateway_model.gateway_compute_id,
+ )
+ )
+ .values(**result.all_computes_update_map)
.returning(GatewayComputeModel.id)
)
updated_ids = list(res.scalars().all())
- if len(updated_ids) == 0:
+ if len(updated_ids) < len(get_gateway_compute_models(gateway_model)):
logger.error(
- "Failed to update compute model %s for gateway %s."
+ "Failed to update compute models for gateway %s."
" This is unexpected and may happen only if the compute model was manually deleted.",
gateway_model.id,
- item.id,
)
@dataclass
class _ProvisioningResult:
gateway_update_map: _GatewayUpdateMap = field(default_factory=_GatewayUpdateMap)
- gateway_compute_update_map: _GatewayComputeUpdateMap = field(
+ all_computes_update_map: _GatewayComputeUpdateMap = field(
default_factory=_GatewayComputeUpdateMap
)
async def _process_provisioning_gateway(gateway_model: GatewayModel) -> _ProvisioningResult:
+ gateway_computes = get_gateway_compute_models(gateway_model)
# Provisioning gateways must have compute.
- assert gateway_model.gateway_compute is not None
+ assert len(gateway_computes) > 0
# FIXME: problems caused by blocking on connect_to_gateway_with_retry and configure_gateway:
# - cannot delete the gateway before it is provisioned because the DB model is locked
@@ -413,32 +432,58 @@ async def _process_provisioning_gateway(gateway_model: GatewayModel) -> _Provisi
# Easy to fix by doing only one connection/configuration attempt per processing iteration. The
# main challenge is applying the same provisioning model to the dstack Sky gateway to avoid
# maintaining a different model for Sky.
- connection = await gateways_services.connect_to_gateway_with_retry(
- gateway_model.gateway_compute
+
+ errors = await asyncio.gather(
+ *(_connect_and_configure_gateway_replica(gateway_model, gc) for gc in gateway_computes)
)
- if connection is None:
+ if any(errors):
return _ProvisioningResult(
gateway_update_map={
"status": GatewayStatus.FAILED,
- "status_message": "Failed to connect to gateway",
+ "status_message": next(e for e in errors if e),
},
- gateway_compute_update_map={"active": False},
+ all_computes_update_map={"active": False},
+ )
+
+ return _ProvisioningResult(
+ gateway_update_map={"status": GatewayStatus.RUNNING},
+ )
+
+
+async def _connect_and_configure_gateway_replica(
+ gateway_model: GatewayModel,
+ gateway_compute: GatewayComputeModel,
+) -> Optional[str]:
+ """Returns an error message on failure, None on success."""
+ logger.debug(
+ "%s replica %d: connecting to gateway compute",
+ fmt(gateway_model),
+ gateway_compute.replica_num,
+ )
+ connection = await gateways_services.connect_to_gateway_with_retry(gateway_compute)
+ if connection is None:
+ logger.warning(
+ "%s replica %d: failed to connect to gateway compute",
+ fmt(gateway_model),
+ gateway_compute.replica_num,
)
+ return "Failed to connect to gateway"
try:
await gateways_services.configure_gateway(connection)
except Exception:
- logger.exception("%s: failed to configure gateway", fmt(gateway_model))
- await gateway_connections_pool.remove(gateway_model.gateway_compute.ip_address)
- return _ProvisioningResult(
- gateway_update_map={
- "status": GatewayStatus.FAILED,
- "status_message": "Failed to configure gateway",
- },
- gateway_compute_update_map={"active": False},
+ logger.exception(
+ "%s replica %d: failed to configure gateway",
+ fmt(gateway_model),
+ gateway_compute.replica_num,
)
- return _ProvisioningResult(
- gateway_update_map={"status": GatewayStatus.RUNNING},
+ await gateway_connections_pool.remove(gateway_compute.ip_address)
+ return "Failed to configure gateway"
+ logger.info(
+ "%s replica %d: gateway compute connected and configured",
+ fmt(gateway_model),
+ gateway_compute.replica_num,
)
+ return None
async def _process_to_be_deleted_item(item: GatewayPipelineItem):
@@ -451,6 +496,7 @@ async def _process_to_be_deleted_item(item: GatewayPipelineItem):
)
.options(joinedload(GatewayModel.project).joinedload(ProjectModel.backends))
.options(joinedload(GatewayModel.gateway_compute))
+ .options(selectinload(GatewayModel.gateway_computes))
.options(joinedload(GatewayModel.backend).load_only(BackendModel.type))
)
gateway_model = res.unique().scalar_one_or_none()
@@ -460,6 +506,27 @@ async def _process_to_be_deleted_item(item: GatewayPipelineItem):
result = await _process_to_be_deleted_gateway(gateway_model)
async with get_session_ctx() as session:
+ if result.all_computes_update_map:
+ res = await session.execute(
+ update(GatewayComputeModel)
+ .where(
+ or_(
+ GatewayComputeModel.gateway_id == gateway_model.id,
+ GatewayComputeModel.id == gateway_model.gateway_compute_id,
+ )
+ )
+ .values(**result.all_computes_update_map)
+ .returning(GatewayComputeModel.id)
+ )
+ updated_ids = list(res.scalars().all())
+ if len(updated_ids) < len(get_gateway_compute_models(gateway_model)):
+ logger.error(
+ "Failed to update compute models for gateway %s."
+ " This is unexpected and may happen only if the compute model was manually deleted.",
+ gateway_model.id,
+ )
+ return
+
if result.delete_gateway:
res = await session.execute(
delete(GatewayModel)
@@ -503,28 +570,11 @@ async def _process_to_be_deleted_item(item: GatewayPipelineItem):
log_lock_token_changed_after_processing(logger, item)
return
- if result.gateway_compute_update_map:
- res = await session.execute(
- update(GatewayComputeModel)
- .where(GatewayComputeModel.id == gateway_model.gateway_compute_id)
- .values(**result.gateway_compute_update_map)
- .returning(GatewayComputeModel.id)
- )
- updated_ids = list(res.scalars().all())
- if len(updated_ids) == 0:
- logger.error(
- "Failed to update compute model %s for gateway %s."
- " This is unexpected and may happen only if the compute model was manually deleted.",
- gateway_model.id,
- item.id,
- )
- return
-
@dataclass
class _ProcessToBeDeletedResult:
delete_gateway: bool
- gateway_compute_update_map: _GatewayComputeUpdateMap = field(
+ all_computes_update_map: _GatewayComputeUpdateMap = field(
default_factory=_GatewayComputeUpdateMap
)
@@ -536,27 +586,39 @@ async def _process_to_be_deleted_gateway(gateway_model: GatewayModel) -> _Proces
)
compute = backend.compute()
assert isinstance(compute, ComputeWithGatewaySupport)
- gateway_compute_configuration = gateways_services.get_gateway_compute_configuration(
- gateway_model
- )
- if gateway_model.gateway_compute is not None and gateway_compute_configuration is not None:
- logger.info("Deleting gateway compute for %s...", gateway_model.name)
+
+ for gateway_compute in get_gateway_compute_models(gateway_model):
+ gateway_compute_configuration = gateways_services.get_gateway_compute_configuration(
+ gateway_compute=gateway_compute,
+ gateway_model=gateway_model,
+ )
+ logger.debug(
+ "%s replica %d: terminating gateway compute",
+ fmt(gateway_model),
+ gateway_compute.replica_num,
+ )
try:
await run_async(
compute.terminate_gateway,
- gateway_model.gateway_compute.instance_id,
+ gateway_compute.instance_id,
gateway_compute_configuration,
- gateway_model.gateway_compute.backend_data,
+ gateway_compute.backend_data,
)
except Exception:
logger.exception(
- "Error when deleting gateway compute for %s",
- gateway_model.name,
+ "%s replica %d: error when terminating gateway compute",
+ fmt(gateway_model),
+ gateway_compute.replica_num,
)
return _ProcessToBeDeletedResult(delete_gateway=False)
- logger.info("Deleted gateway compute for %s", gateway_model.name)
- result = _ProcessToBeDeletedResult(delete_gateway=True)
- if gateway_model.gateway_compute is not None:
- await gateway_connections_pool.remove(gateway_model.gateway_compute.ip_address)
- result.gateway_compute_update_map = {"active": False, "deleted": True}
- return result
+ logger.info(
+ "%s replica %d: gateway compute terminated",
+ fmt(gateway_model),
+ gateway_compute.replica_num,
+ )
+ await gateway_connections_pool.remove(gateway_compute.ip_address)
+
+ return _ProcessToBeDeletedResult(
+ delete_gateway=True,
+ all_computes_update_map={"active": False, "deleted": True},
+ )
diff --git a/src/dstack/_internal/server/background/pipeline_tasks/jobs_running.py b/src/dstack/_internal/server/background/pipeline_tasks/jobs_running.py
index 014f84c604..61599172b5 100644
--- a/src/dstack/_internal/server/background/pipeline_tasks/jobs_running.py
+++ b/src/dstack/_internal/server/background/pipeline_tasks/jobs_running.py
@@ -76,7 +76,7 @@
get_instance_specific_mounts,
resolve_provisioning_image,
)
-from dstack._internal.server.services.gateways import get_or_add_gateway_connection
+from dstack._internal.server.services.gateways import get_or_add_gateway_connections
from dstack._internal.server.services.instances import (
get_instance_remote_connection_info,
get_instance_ssh_private_keys,
@@ -1185,7 +1185,7 @@ async def _register_service_replica(
return None
async with get_session_ctx() as session:
- gateway_model, conn = await get_or_add_gateway_connection(
+ gateway_model, connections = await get_or_add_gateway_connections(
session, context.run_model.gateway_id
)
gateway_target = events.Target.from_model(gateway_model)
@@ -1197,35 +1197,40 @@ async def _register_service_replica(
# so we must update job_submission with the result value.
job_submission = context.job_submission.copy(deep=True)
job_submission.job_runtime_data = _get_result_job_runtime_data(context.job_model, result)
- try:
- logger.debug(
- "%s: registering replica for service %s", fmt(context.job_model), context.run.id.hex
- )
- async with conn.client() as gateway_client:
- await gateway_client.register_replica(
- run=context.run,
- job_spec=job_spec,
- job_submission=job_submission,
- instance_project_ssh_private_key=instance_project_ssh_private_key,
- ssh_head_proxy=ssh_head_proxy,
- ssh_head_proxy_private_key=ssh_head_proxy_private_key,
- )
- except (httpx.RequestError, SSHError) as e:
- logger.debug("Gateway request failed", exc_info=True)
- raise GatewayError(repr(e))
- except GatewayError as e:
- if "already exists in service" in e.msg:
- logger.warning(
- (
- "%s: could not register replica in gateway: %s."
- " NOTE: if you just updated dstack from pre-0.19.25 to 0.19.25+,"
- " expect to see this warning once for every running service replica"
- ),
+ for conn in connections:
+ try:
+ logger.debug(
+ "%s: registering replica for service %s on gateway replica %s",
fmt(context.job_model),
- e.msg,
+ context.run.id.hex,
+ conn.ip_address,
)
- else:
- raise
+ async with conn.client() as gateway_client:
+ await gateway_client.register_replica(
+ run=context.run,
+ job_spec=job_spec,
+ job_submission=job_submission,
+ instance_project_ssh_private_key=instance_project_ssh_private_key,
+ ssh_head_proxy=ssh_head_proxy,
+ ssh_head_proxy_private_key=ssh_head_proxy_private_key,
+ )
+ except (httpx.RequestError, SSHError) as e:
+ logger.debug("Gateway request failed", exc_info=True)
+ raise GatewayError(repr(e))
+ except GatewayError as e:
+ if "already exists in service" in e.msg:
+ logger.warning(
+ (
+ "%s: could not register replica in gateway %s: %s."
+ " NOTE: if you just updated dstack from pre-0.19.25 to 0.19.25+,"
+ " expect to see this warning once for every running service replica"
+ ),
+ fmt(context.job_model),
+ conn.ip_address,
+ e.msg,
+ )
+ else:
+ raise
return gateway_target
diff --git a/src/dstack/_internal/server/background/pipeline_tasks/jobs_terminating.py b/src/dstack/_internal/server/background/pipeline_tasks/jobs_terminating.py
index 3ae30c3ef2..adedf9bb4b 100644
--- a/src/dstack/_internal/server/background/pipeline_tasks/jobs_terminating.py
+++ b/src/dstack/_internal/server/background/pipeline_tasks/jobs_terminating.py
@@ -51,7 +51,7 @@
)
from dstack._internal.server.services import backends as backends_services
from dstack._internal.server.services import events
-from dstack._internal.server.services.gateways import get_or_add_gateway_connection
+from dstack._internal.server.services.gateways import get_or_add_gateway_connections
from dstack._internal.server.services.instances import (
emit_instance_status_change_event,
get_instance_ssh_private_keys,
@@ -795,25 +795,36 @@ async def _unregister_replica(
run_model = job_model.run
if run_model.gateway_id is not None:
async with get_session_ctx() as session:
- gateway, conn = await get_or_add_gateway_connection(session, run_model.gateway_id)
- gateway_target = events.Target.from_model(gateway)
- try:
- logger.debug(
- "%s: unregistering replica from service %s", fmt(job_model), job_model.run_id.hex
+ gateway, connections = await get_or_add_gateway_connections(
+ session, run_model.gateway_id
)
- async with conn.client() as client:
- await client.unregister_replica(
- project=run_model.project.name,
- run_name=run_model.run_name,
- job_id=job_model.id,
+ gateway_target = events.Target.from_model(gateway)
+ for conn in connections:
+ try:
+ logger.debug(
+ "%s: unregistering replica from service %s on gateway replica %s",
+ fmt(job_model),
+ job_model.run_id.hex,
+ conn.ip_address,
+ )
+ async with conn.client() as client:
+ await client.unregister_replica(
+ project=run_model.project.name,
+ run_name=run_model.run_name,
+ job_id=job_model.id,
+ )
+ except GatewayError as e:
+ logger.warning(
+ "%s: unregistering replica from service on gateway replica %s: %s",
+ fmt(job_model),
+ conn.ip_address,
+ e,
)
- except GatewayError as e:
- logger.warning("%s: unregistering replica from service: %s", fmt(job_model), e)
- except (httpx.RequestError, SSHError) as e:
- logger.debug("Gateway request failed", exc_info=True)
- # FIXME: Unhandled exception raised.
- # Handle and retry unregister with timeout.
- raise GatewayError(repr(e))
+ except (httpx.RequestError, SSHError) as e:
+ logger.debug("Gateway request failed", exc_info=True)
+ # FIXME: Unhandled exception raised.
+ # Handle and retry unregister with timeout.
+ raise GatewayError(repr(e))
return gateway_target
diff --git a/src/dstack/_internal/server/background/pipeline_tasks/runs/__init__.py b/src/dstack/_internal/server/background/pipeline_tasks/runs/__init__.py
index c26df9d4d3..071af9fbd3 100644
--- a/src/dstack/_internal/server/background/pipeline_tasks/runs/__init__.py
+++ b/src/dstack/_internal/server/background/pipeline_tasks/runs/__init__.py
@@ -28,7 +28,7 @@
from dstack._internal.server.db import get_db, get_session_ctx
from dstack._internal.server.models import InstanceModel, JobModel, ProjectModel, RunModel
from dstack._internal.server.services import events
-from dstack._internal.server.services.gateways import get_or_add_gateway_connection
+from dstack._internal.server.services.gateways import get_combined_gateway_stats
from dstack._internal.server.services.jobs import emit_job_status_change_event
from dstack._internal.server.services.locking import get_locker
from dstack._internal.server.services.pipelines import PipelineHinterProtocol
@@ -313,8 +313,9 @@ async def _load_pending_context(
gateway_stats = None
if run_spec.configuration.type == "service" and run_model.gateway_id is not None:
- _, conn = await get_or_add_gateway_connection(session, run_model.gateway_id)
- gateway_stats = await conn.get_stats(run_model.project.name, run_model.run_name)
+ gateway_stats = await get_combined_gateway_stats(
+ session, run_model.gateway_id, run_model.project.name, run_model.run_name
+ )
return pending.PendingContext(
run_model=run_model,
@@ -494,8 +495,9 @@ async def _load_active_context(
gateway_stats = None
if run_spec.configuration.type == "service" and run_model.gateway_id is not None:
- _, conn = await get_or_add_gateway_connection(session, run_model.gateway_id)
- gateway_stats = await conn.get_stats(run_model.project.name, run_model.run_name)
+ gateway_stats = await get_combined_gateway_stats(
+ session, run_model.gateway_id, run_model.project.name, run_model.run_name
+ )
return active.ActiveContext(
run_model=run_model,
diff --git a/src/dstack/_internal/server/background/pipeline_tasks/runs/terminating.py b/src/dstack/_internal/server/background/pipeline_tasks/runs/terminating.py
index eece7dfa7c..c9a75e3c71 100644
--- a/src/dstack/_internal/server/background/pipeline_tasks/runs/terminating.py
+++ b/src/dstack/_internal/server/background/pipeline_tasks/runs/terminating.py
@@ -16,7 +16,7 @@
from dstack._internal.server.background.pipeline_tasks.base import ItemUpdateMap
from dstack._internal.server.db import get_session_ctx
from dstack._internal.server.services import events
-from dstack._internal.server.services.gateways import get_or_add_gateway_connection
+from dstack._internal.server.services.gateways import get_or_add_gateway_connections
from dstack._internal.server.services.logging import fmt
from dstack._internal.server.services.runs import _get_next_triggered_at, get_run_spec
from dstack._internal.utils.common import get_or_error
@@ -148,24 +148,37 @@ async def _unregister_service(run_model: models.RunModel) -> Optional[ServiceUnr
return None
async with get_session_ctx() as session:
- gateway, conn = await get_or_add_gateway_connection(session, run_model.gateway_id)
+ gateway, connections = await get_or_add_gateway_connections(session, run_model.gateway_id)
gateway_target = events.Target.from_model(gateway)
- try:
- logger.debug("%s: unregistering service", fmt(run_model))
- async with conn.client() as client:
- await client.unregister_service(
- project=run_model.project.name,
- run_name=run_model.run_name,
+ gateway_errors = []
+ for conn in connections:
+ try:
+ logger.debug(
+ "%s: unregistering service on gateway replica %s", fmt(run_model), conn.ip_address
+ )
+ async with conn.client() as client:
+ await client.unregister_service(
+ project=run_model.project.name,
+ run_name=run_model.run_name,
+ )
+ except GatewayError as e:
+ # Ignore if the service is not registered on this replica.
+ logger.warning(
+ "%s: unregistering service on gateway replica %s: %s",
+ fmt(run_model),
+ conn.ip_address,
+ e,
)
+ gateway_errors.append(str(e))
+ except (httpx.RequestError, SSHError) as e:
+ logger.debug("Gateway request failed", exc_info=True)
+ raise GatewayError(repr(e))
+
+ if gateway_errors:
+ event_message = f"Gateway error when unregistering service: {'; '.join(gateway_errors)}"
+ else:
event_message = "Service unregistered from gateway"
- except GatewayError as e:
- # Ignore if the service is not registered.
- logger.warning("%s: unregistering service: %s", fmt(run_model), e)
- event_message = f"Gateway error when unregistering service: {e}"
- except (httpx.RequestError, SSHError) as e:
- logger.debug("Gateway request failed", exc_info=True)
- raise GatewayError(repr(e))
return ServiceUnregistration(
event_message=event_message,
gateway_target=gateway_target,
diff --git a/src/dstack/_internal/server/compatibility/gateways.py b/src/dstack/_internal/server/compatibility/gateways.py
new file mode 100644
index 0000000000..3e410b5a9c
--- /dev/null
+++ b/src/dstack/_internal/server/compatibility/gateways.py
@@ -0,0 +1,15 @@
+from typing import Optional
+
+from packaging.version import Version
+
+from dstack._internal.core.models.gateways import Gateway
+
+
+def patch_gateway(gateway: Gateway, client_version: Optional[Version]) -> None:
+ if client_version is None:
+ return
+ if client_version < Version("0.20.25") and len(gateway.replicas) < 2:
+ gateway.instance_id = ""
+ gateway.ip_address = gateway.replicas[0].hostname if gateway.replicas else ""
+ if gateway.hostname is None:
+ gateway.hostname = gateway.ip_address
diff --git a/src/dstack/_internal/server/migrations/versions/2026/06_01_1911_b7609b94ea4d_add_gatewaycomputemodel_gateway_id.py b/src/dstack/_internal/server/migrations/versions/2026/06_01_1911_b7609b94ea4d_add_gatewaycomputemodel_gateway_id.py
new file mode 100644
index 0000000000..2729699af1
--- /dev/null
+++ b/src/dstack/_internal/server/migrations/versions/2026/06_01_1911_b7609b94ea4d_add_gatewaycomputemodel_gateway_id.py
@@ -0,0 +1,52 @@
+"""Add GatewayComputeModel.gateway_id
+
+Revision ID: b7609b94ea4d
+Revises: 201cb7ccd0d3
+Create Date: 2026-06-01 19:11:30.641417+00:00
+
+"""
+
+import sqlalchemy as sa
+import sqlalchemy_utils
+from alembic import op
+
+# revision identifiers, used by Alembic.
+revision = "b7609b94ea4d"
+down_revision = "201cb7ccd0d3"
+branch_labels = None
+depends_on = None
+
+
+def upgrade() -> None:
+ # ### commands auto generated by Alembic - please adjust! ###
+ with op.batch_alter_table("gateway_computes", schema=None) as batch_op:
+ batch_op.add_column(
+ sa.Column("replica_num", sa.Integer(), server_default="0", nullable=False)
+ )
+ batch_op.add_column(
+ sa.Column(
+ "gateway_id", sqlalchemy_utils.types.uuid.UUIDType(binary=False), nullable=True
+ )
+ )
+ batch_op.create_foreign_key(
+ batch_op.f("fk_gateway_computes_gateway_id_gateways"),
+ "gateways",
+ ["gateway_id"],
+ ["id"],
+ ondelete="SET NULL",
+ use_alter=True,
+ )
+
+ # ### end Alembic commands ###
+
+
+def downgrade() -> None:
+ # ### commands auto generated by Alembic - please adjust! ###
+ with op.batch_alter_table("gateway_computes", schema=None) as batch_op:
+ batch_op.drop_constraint(
+ batch_op.f("fk_gateway_computes_gateway_id_gateways"), type_="foreignkey"
+ )
+ batch_op.drop_column("gateway_id")
+ batch_op.drop_column("replica_num")
+
+ # ### end Alembic commands ###
diff --git a/src/dstack/_internal/server/models.py b/src/dstack/_internal/server/models.py
index d433244ea3..8d6f3c512c 100644
--- a/src/dstack/_internal/server/models.py
+++ b/src/dstack/_internal/server/models.py
@@ -629,7 +629,21 @@ class GatewayModel(PipelineModelMixin, BaseModel):
gateway_compute_id: Mapped[Optional[uuid.UUID]] = mapped_column(
ForeignKey("gateway_computes.id", ondelete="CASCADE")
)
- gateway_compute: Mapped[Optional["GatewayComputeModel"]] = relationship()
+ gateway_compute: Mapped[Optional["GatewayComputeModel"]] = relationship(
+ foreign_keys=[gateway_compute_id]
+ )
+ """
+ Relationship with gateway computes for pre-0.20.25 gateways.
+ Use `get_gateway_compute_models()` for version-agnostic gateway compute retrieval.
+ """
+ gateway_computes: Mapped[List["GatewayComputeModel"]] = relationship(
+ back_populates="gateway",
+ foreign_keys="GatewayComputeModel.gateway_id",
+ )
+ """
+ Relationship with gateway computes for 0.20.25+ gateways.
+ Use `get_gateway_compute_models()` for version-agnostic gateway compute retrieval.
+ """
runs: Mapped[List["RunModel"]] = relationship(back_populates="gateway")
@@ -639,15 +653,26 @@ class GatewayModel(PipelineModelMixin, BaseModel):
class GatewayComputeModel(BaseModel):
+ """A single gateway replica.
+ **TODO**: consider renaming to `GatewayReplicaModel`.
+ """
+
__tablename__ = "gateway_computes"
id: Mapped[uuid.UUID] = mapped_column(
UUIDType(binary=False), primary_key=True, default=uuid.uuid4
)
created_at: Mapped[datetime] = mapped_column(NaiveDateTime, default=get_current_datetime)
+ replica_num: Mapped[int] = mapped_column(Integer, server_default="0")
instance_id: Mapped[str] = mapped_column(String(100))
ip_address: Mapped[str] = mapped_column(String(100))
+ """Gateway replica IP address or domain name (e.g., k8s can use domain names).
+ **TODO**: rename.
+ """
hostname: Mapped[Optional[str]] = mapped_column(String(100))
+ """Hostname of the gateway's load balancer.
+ **TODO**: move to `GatewayModel`.
+ """
configuration: Mapped[Optional[str]] = mapped_column(Text)
"""`configuration` is optional for compatibility with pre-0.18.2 gateways.
Use `get_gateway_compute_configuration` to construct `configuration` for old gateways.
@@ -655,6 +680,22 @@ class GatewayComputeModel(BaseModel):
backend_data: Mapped[Optional[str]] = mapped_column(Text)
region: Mapped[str] = mapped_column(String(100))
+ gateway_id: Mapped[Optional[uuid.UUID]] = mapped_column(
+ ForeignKey(
+ "gateways.id",
+ ondelete="SET NULL",
+ use_alter=True,
+ )
+ )
+ gateway: Mapped[Optional["GatewayModel"]] = relationship(
+ back_populates="gateway_computes",
+ foreign_keys=[gateway_id],
+ )
+ """
+ Gateway. Can be None for pre-0.20.25 gateways, which use GatewayModel.gateway_compute_id to
+ establish the relationship.
+ """
+
backend_id: Mapped[Optional[uuid.UUID]] = mapped_column(
ForeignKey("backends.id", ondelete="CASCADE")
)
diff --git a/src/dstack/_internal/server/routers/gateways.py b/src/dstack/_internal/server/routers/gateways.py
index eee99077c1..6b9a6718dd 100644
--- a/src/dstack/_internal/server/routers/gateways.py
+++ b/src/dstack/_internal/server/routers/gateways.py
@@ -1,6 +1,7 @@
from typing import List, Optional, Tuple
from fastapi import APIRouter, Depends
+from packaging.version import Version
from sqlalchemy.ext.asyncio import AsyncSession
import dstack._internal.core.models.gateways as models
@@ -8,6 +9,7 @@
import dstack._internal.server.services.gateways as gateways
from dstack._internal.core.errors import ResourceNotExistsError
from dstack._internal.core.models.common import EntityReference
+from dstack._internal.server.compatibility.gateways import patch_gateway
from dstack._internal.server.db import get_session
from dstack._internal.server.deps import Project
from dstack._internal.server.models import ProjectModel, UserModel
@@ -21,6 +23,7 @@
from dstack._internal.server.utils.routers import (
CustomORJSONResponse,
get_base_api_additional_responses,
+ get_client_version,
)
router = APIRouter(
@@ -35,17 +38,19 @@ async def list_gateways(
body: Optional[schemas.ListGatewaysRequest] = None,
session: AsyncSession = Depends(get_session),
user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectMemberOrPublicAccess()),
+ client_version: Optional[Version] = Depends(get_client_version),
):
_, project = user_project
if body is None:
body = schemas.ListGatewaysRequest()
- return CustomORJSONResponse(
- await gateways.list_project_gateways(
- session=session,
- project=project,
- include_imported=body.include_imported,
- )
+ gateway_list = await gateways.list_project_gateways(
+ session=session,
+ project=project,
+ include_imported=body.include_imported,
)
+ for gateway in gateway_list:
+ patch_gateway(gateway, client_version)
+ return CustomORJSONResponse(gateway_list)
@router.post("/get", summary="Get gateway", response_model=models.Gateway)
@@ -54,6 +59,7 @@ async def get_gateway(
session: AsyncSession = Depends(get_session),
user: UserModel = Depends(Authenticated()),
project: ProjectModel = Depends(Project()),
+ client_version: Optional[Version] = Depends(get_client_version),
):
await check_can_access_gateway(
session=session, user=user, gateway_project=project, gateway_name=body.name
@@ -61,6 +67,7 @@ async def get_gateway(
gateway = await gateways.get_gateway_by_name(session=session, project=project, name=body.name)
if gateway is None:
raise ResourceNotExistsError()
+ patch_gateway(gateway, client_version)
return CustomORJSONResponse(gateway)
@@ -70,17 +77,18 @@ async def create_gateway(
session: AsyncSession = Depends(get_session),
user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectAdmin()),
pipeline_hinter: PipelineHinterProtocol = Depends(get_pipeline_hinter),
+ client_version: Optional[Version] = Depends(get_client_version),
):
user, project = user_project
- return CustomORJSONResponse(
- await gateways.create_gateway(
- session=session,
- user=user,
- project=project,
- configuration=body.configuration,
- pipeline_hinter=pipeline_hinter,
- )
+ gateway = await gateways.create_gateway(
+ session=session,
+ user=user,
+ project=project,
+ configuration=body.configuration,
+ pipeline_hinter=pipeline_hinter,
)
+ patch_gateway(gateway, client_version)
+ return CustomORJSONResponse(gateway)
@router.post("/delete", summary="Delete gateways")
@@ -118,14 +126,15 @@ async def set_gateway_wildcard_domain(
body: schemas.SetWildcardDomainRequest,
session: AsyncSession = Depends(get_session),
user_project: Tuple[UserModel, ProjectModel] = Depends(ProjectAdmin()),
+ client_version: Optional[Version] = Depends(get_client_version),
):
user, project = user_project
- return CustomORJSONResponse(
- await gateways.set_gateway_wildcard_domain(
- session=session,
- project=project,
- name=body.name,
- wildcard_domain=body.wildcard_domain,
- user=user,
- )
+ gateway = await gateways.set_gateway_wildcard_domain(
+ session=session,
+ project=project,
+ name=body.name,
+ wildcard_domain=body.wildcard_domain,
+ user=user,
)
+ patch_gateway(gateway, client_version)
+ return CustomORJSONResponse(gateway)
diff --git a/src/dstack/_internal/server/services/gateways/__init__.py b/src/dstack/_internal/server/services/gateways/__init__.py
index 287117e2ed..bfd05cecf9 100644
--- a/src/dstack/_internal/server/services/gateways/__init__.py
+++ b/src/dstack/_internal/server/services/gateways/__init__.py
@@ -10,7 +10,7 @@
import httpx
from sqlalchemy import exists, func, or_, select, update
from sqlalchemy.ext.asyncio import AsyncSession
-from sqlalchemy.orm import joinedload
+from sqlalchemy.orm import joinedload, selectinload
import dstack._internal.utils.random_names as random_names
from dstack._internal.core.backends.base.compute import (
@@ -32,15 +32,19 @@
from dstack._internal.core.models.backends.base import BackendType
from dstack._internal.core.models.common import EntityReference
from dstack._internal.core.models.gateways import (
+ GATEWAY_REPLICAS_DEFAULT,
AnyGatewayRouterConfig,
Gateway,
GatewayComputeConfiguration,
GatewayConfiguration,
+ GatewayReplica,
GatewaySpec,
GatewayStatus,
LetsEncryptGatewayCertificate,
)
from dstack._internal.core.services import validate_dstack_resource_name
+from dstack._internal.proxy.gateway.const import SERVICE_SCALING_WINDOWS
+from dstack._internal.proxy.gateway.schemas.stats import PerWindowStats, Stat
from dstack._internal.server import settings
from dstack._internal.server.db import get_db, is_db_postgres, is_db_sqlite
from dstack._internal.server.models import (
@@ -130,6 +134,10 @@ def get_gateway_status_change_message(
GATEWAY_CONNECT_DELAY = 10
GATEWAY_CONFIGURE_ATTEMPTS = 50
GATEWAY_CONFIGURE_DELAY = 3
+# Artificial limit to avoid doing too many per-replica operations (gateway replica provisioning,
+# service registration, etc) in a single pipeline tick. Can be lifted once the implementation is
+# more mature.
+GATEWAY_MAX_REPLICAS = 3 # documented in gateways.md, keep in sync
async def list_project_gateways(
@@ -169,6 +177,8 @@ async def create_gateway_compute(
project_name: str,
backend_compute: Compute,
configuration: GatewayConfiguration,
+ replica_num: int,
+ gateway_id: Optional[uuid.UUID] = None,
backend_id: Optional[uuid.UUID] = None,
) -> GatewayComputeModel:
assert isinstance(backend_compute, ComputeWithGatewaySupport)
@@ -180,7 +190,7 @@ async def create_gateway_compute(
compute_configuration = GatewayComputeConfiguration(
project_name=project_name,
- instance_name=configuration.name,
+ instance_name=f"{configuration.name}-{replica_num}",
backend=configuration.backend,
region=configuration.region,
instance_type=configuration.instance_type,
@@ -197,7 +207,9 @@ async def create_gateway_compute(
)
return GatewayComputeModel(
+ gateway_id=gateway_id,
backend_id=backend_id,
+ replica_num=replica_num,
region=gpd.region,
ip_address=gpd.ip_address,
instance_id=gpd.instance_id,
@@ -467,6 +479,7 @@ async def list_project_gateway_models(
stmt = stmt.where(GatewayModel.project_id == project.id)
if load_gateway_compute:
stmt = stmt.options(joinedload(GatewayModel.gateway_compute))
+ stmt = stmt.options(selectinload(GatewayModel.gateway_computes))
if load_backend_type:
stmt = stmt.options(joinedload(GatewayModel.backend).load_only(BackendModel.type))
res = await session.execute(stmt)
@@ -495,6 +508,7 @@ async def get_project_gateway_model_by_reference(
)
if load_gateway_compute:
stmt = stmt.options(joinedload(GatewayModel.gateway_compute))
+ stmt = stmt.options(selectinload(GatewayModel.gateway_computes))
if load_backend_type:
stmt = stmt.options(joinedload(GatewayModel.backend).load_only(BackendModel.type))
res = await session.execute(stmt)
@@ -529,6 +543,7 @@ async def get_project_gateway_model_by_name_for_update(
select(GatewayModel)
.where(GatewayModel.id.in_([gateway_id]), *filters)
.options(joinedload(GatewayModel.gateway_compute))
+ .options(selectinload(GatewayModel.gateway_computes))
.options(joinedload(GatewayModel.backend).load_only(BackendModel.type))
.with_for_update(key_share=True, of=GatewayModel)
)
@@ -555,6 +570,7 @@ async def get_project_default_gateway_model(
)
if load_gateway_compute:
stmt = stmt.options(joinedload(GatewayModel.gateway_compute))
+ stmt = stmt.options(selectinload(GatewayModel.gateway_computes))
if load_backend_type:
stmt = stmt.options(joinedload(GatewayModel.backend).load_only(BackendModel.type))
res = await session.execute(stmt)
@@ -571,30 +587,73 @@ async def generate_gateway_name(session: AsyncSession, project: ProjectModel) ->
# TODO: Connect to gateway outside session
-async def get_or_add_gateway_connection(
+async def get_or_add_gateway_connections(
session: AsyncSession, gateway_id: uuid.UUID
-) -> tuple[GatewayModel, GatewayConnection]:
- gateway = await session.get(
- GatewayModel,
- gateway_id,
- options=[joinedload(GatewayModel.gateway_compute)],
- populate_existing=True,
+) -> tuple[GatewayModel, List[GatewayConnection]]:
+ res = await session.execute(
+ select(GatewayModel)
+ .where(GatewayModel.id == gateway_id)
+ .options(joinedload(GatewayModel.gateway_compute))
+ .options(selectinload(GatewayModel.gateway_computes))
)
+ gateway = res.scalar_one_or_none()
if gateway is None:
raise GatewayError("Gateway not found")
- if gateway.gateway_compute is None:
+ computes = get_gateway_compute_models(gateway)
+ if not computes:
raise GatewayError("Gateway compute not found")
+ connections: List[GatewayConnection] = []
+ for compute in computes:
+ try:
+ conn = await gateway_connections_pool.get_or_add(
+ hostname=compute.ip_address,
+ id_rsa=compute.ssh_private_key,
+ )
+ connections.append(conn)
+ except Exception as e:
+ logger.warning("Failed to connect to gateway %s: %s", compute.ip_address, e)
+ raise GatewayError("Failed to connect to gateway")
+ return gateway, connections
+
+
+async def get_combined_gateway_stats(
+ session: AsyncSession,
+ gateway_id: uuid.UUID,
+ project_name: str,
+ run_name: str,
+) -> Optional[PerWindowStats]:
+ """
+ Return stats for *run_name* aggregated across all replicas of *gateway_id*.
+ """
try:
- conn = await gateway_connections_pool.get_or_add(
- hostname=gateway.gateway_compute.ip_address,
- id_rsa=gateway.gateway_compute.ssh_private_key,
- )
- except Exception as e:
- logger.warning(
- "Failed to connect to gateway %s: %s", gateway.gateway_compute.ip_address, e
+ _, connections = await get_or_add_gateway_connections(session, gateway_id)
+ except GatewayError:
+ return None
+ per_replica: list[PerWindowStats] = []
+ for conn in connections:
+ stats = await conn.get_stats(project_name, run_name)
+ if stats is None: # Stats not fetched yet
+ # TODO: find a way to make service scaling decisions even if some gateway replicas are
+ # unavailable for fetching stats.
+ return None
+ per_replica.append(stats)
+ return _merge_per_window_stats(per_replica) if per_replica else None
+
+
+def _merge_per_window_stats(stats_per_gateway_replica: list[PerWindowStats]) -> PerWindowStats:
+ merged: PerWindowStats = {}
+ for window in SERVICE_SCALING_WINDOWS:
+ total_requests = 0
+ total_time_of_all_requests = 0.0
+ for gateway_replica_stats in stats_per_gateway_replica:
+ stat = gateway_replica_stats[window]
+ total_requests += stat.requests
+ total_time_of_all_requests += stat.requests * stat.request_time
+ merged[window] = Stat(
+ requests=total_requests,
+ request_time=(total_time_of_all_requests / total_requests if total_requests else 0.0),
)
- raise GatewayError("Failed to connect to gateway")
- return gateway, conn
+ return merged
async def init_gateways(session: AsyncSession):
@@ -732,6 +791,14 @@ async def configure_gateway(
logger.info("Gateway %s configured", connection.ip_address)
+def get_gateway_compute_models(gateway_model: GatewayModel) -> List[GatewayComputeModel]:
+ if gateway_model.gateway_computes: # 0.20.25+ gateway
+ return list(gateway_model.gateway_computes)
+ if gateway_model.gateway_compute is not None: # pre-0.20.25 gateway
+ return [gateway_model.gateway_compute]
+ return []
+
+
def get_gateway_configuration(gateway_model: GatewayModel) -> GatewayConfiguration:
if gateway_model.configuration is not None:
return GatewayConfiguration.__response__.parse_raw(gateway_model.configuration)
@@ -746,22 +813,19 @@ def get_gateway_configuration(gateway_model: GatewayModel) -> GatewayConfigurati
def get_gateway_compute_configuration(
+ gateway_compute: GatewayComputeModel,
gateway_model: GatewayModel,
-) -> Optional[GatewayComputeConfiguration]:
- if gateway_model.gateway_compute is None:
- return None
- if gateway_model.gateway_compute.configuration is not None:
- return GatewayComputeConfiguration.__response__.parse_raw(
- gateway_model.gateway_compute.configuration
- )
+) -> GatewayComputeConfiguration:
+ if gateway_compute.configuration is not None:
+ return GatewayComputeConfiguration.__response__.parse_raw(gateway_compute.configuration)
# Handle gateways created before GatewayComputeConfiguration was introduced
return GatewayComputeConfiguration(
project_name=gateway_model.project.name,
- instance_name=gateway_model.gateway_compute.instance_id,
+ instance_name=gateway_compute.instance_id,
backend=gateway_model.backend.type,
- region=gateway_model.gateway_compute.region,
+ region=gateway_compute.region,
public_ip=True,
- ssh_key_pub=gateway_model.gateway_compute.ssh_public_key,
+ ssh_key_pub=gateway_compute.ssh_public_key,
certificate=LetsEncryptGatewayCertificate(),
)
@@ -775,28 +839,34 @@ def gateway_model_to_gateway(
default_gateway_id: ID of the default gateway in the project where `gateway_model` is being
viewed. Can be different from `gateway_model.project` if the gateway is imported.
"""
- ip_address = ""
- instance_id = ""
- hostname = ""
- if gateway_model.gateway_compute is not None:
- ip_address = gateway_model.gateway_compute.ip_address
- instance_id = gateway_model.gateway_compute.instance_id
- hostname = gateway_model.gateway_compute.hostname
- if hostname is None:
- hostname = ip_address
backend_type = gateway_model.backend.type
if gateway_model.backend.type == BackendType.DSTACK:
backend_type = BackendType.AWS
is_default = default_gateway_id == gateway_model.id
configuration = get_gateway_configuration(gateway_model)
configuration.default = is_default
+
+ compute_models = sorted(get_gateway_compute_models(gateway_model), key=lambda c: c.replica_num)
+ gateway_hostname = None
+ replicas = []
+ for compute in compute_models:
+ compute_configuration = get_gateway_compute_configuration(compute, gateway_model)
+ replicas.append(
+ GatewayReplica(
+ hostname=compute.ip_address,
+ replica_num=compute.replica_num,
+ backend=compute_configuration.backend,
+ region=compute_configuration.region,
+ created_at=compute.created_at,
+ )
+ )
+ gateway_hostname = compute.hostname
+
return Gateway(
id=gateway_model.id,
name=gateway_model.name,
project_name=gateway_model.project.name,
- ip_address=ip_address,
- instance_id=instance_id,
- hostname=hostname,
+ hostname=gateway_hostname,
backend=backend_type,
region=gateway_model.region,
wildcard_domain=gateway_model.wildcard_domain,
@@ -805,6 +875,7 @@ def gateway_model_to_gateway(
status=gateway_model.status,
status_message=gateway_model.status_message,
configuration=configuration,
+ replicas=replicas,
)
@@ -838,6 +909,15 @@ def _validate_gateway_configuration(configuration: GatewayConfiguration):
f" {[b.value for b in BACKENDS_WITH_PRIVATE_GATEWAY_SUPPORT]}."
)
+ replicas = (
+ configuration.replicas if configuration.replicas is not None else GATEWAY_REPLICAS_DEFAULT
+ )
+
+ if replicas > GATEWAY_MAX_REPLICAS:
+ raise ServerClientError(
+ f"Cannot provision {replicas} gateway replicas. This server allows at most {GATEWAY_MAX_REPLICAS}"
+ )
+
if configuration.certificate is not None:
if configuration.certificate.type == "lets-encrypt" and not configuration.public_ip:
raise ServerClientError(
@@ -845,3 +925,13 @@ def _validate_gateway_configuration(configuration: GatewayConfiguration):
)
if configuration.certificate.type == "acm" and configuration.backend != BackendType.AWS:
raise ServerClientError("acm certificate type is supported for aws backend only")
+ if replicas > 1:
+ raise ServerClientError(
+ "Replicated gateways do not support certificates."
+ " Set either `certificate: null` or `replicas: 1` in the gateway configuration"
+ )
+
+ if configuration.router is not None and replicas > 1:
+ raise ServerClientError(
+ "The deprecated `router` property is not supported for multi-replica gateways"
+ )
diff --git a/src/dstack/_internal/server/services/services/__init__.py b/src/dstack/_internal/server/services/services/__init__.py
index 273054e74f..b637683af7 100644
--- a/src/dstack/_internal/server/services/services/__init__.py
+++ b/src/dstack/_internal/server/services/services/__init__.py
@@ -32,8 +32,9 @@
from dstack._internal.server.models import GatewayModel, RunModel
from dstack._internal.server.services import events
from dstack._internal.server.services.gateways import (
+ get_gateway_compute_models,
get_gateway_configuration,
- get_or_add_gateway_connection,
+ get_or_add_gateway_connections,
get_project_default_gateway_model,
get_project_gateway_model_by_reference,
)
@@ -100,7 +101,7 @@ async def _register_service_in_gateway(
) -> ServiceSpec:
assert run_spec.configuration.type == "service"
- if gateway.gateway_compute is None:
+ if not get_gateway_compute_models(gateway):
raise ServerClientError("Gateway has no instance associated with it")
if gateway.status != GatewayStatus.RUNNING:
@@ -178,50 +179,51 @@ async def _register_service_in_gateway(
domain = service_spec.get_domain()
assert domain is not None
- _, conn = await get_or_add_gateway_connection(session, gateway.id)
- try:
- logger.debug("%s: registering service as %s", fmt(run_model), service_spec.url)
- async with conn.client() as client:
- do_register = partial(
- client.register_service,
- project=run_model.project.name,
- run_name=run_model.run_name,
- domain=domain,
- service_https=configure_service_https,
- gateway_https=gateway_https,
- auth=run_spec.configuration.auth,
- client_max_body_size=settings.DEFAULT_SERVICE_CLIENT_MAX_BODY_SIZE,
- options=service_spec.options,
- rate_limits=run_spec.configuration.rate_limits,
- ssh_private_key=run_model.project.ssh_private_key,
- has_router_replica=has_replica_group_router,
- router=router,
- )
- try:
- await do_register()
- except GatewayError as e:
- if e.msg == SERVICE_ALREADY_REGISTERED_ERROR_TEMPLATE.format(
- ref=f"{run_model.project.name}/{run_model.run_name}"
- ):
- # Happens if there was a communication issue with the gateway when last unregistering
- logger.warning(
- "Service %s/%s is dangling on gateway %s, unregistering and re-registering",
- run_model.project.name,
- run_model.run_name,
- gateway.name,
- )
- await client.unregister_service(
- project=run_model.project.name,
- run_name=run_model.run_name,
- )
+ _, connections = await get_or_add_gateway_connections(session, gateway.id)
+ for conn in connections:
+ try:
+ logger.debug("%s: registering service as %s", fmt(run_model), service_spec.url)
+ async with conn.client() as client:
+ do_register = partial(
+ client.register_service,
+ project=run_model.project.name,
+ run_name=run_model.run_name,
+ domain=domain,
+ service_https=configure_service_https,
+ gateway_https=gateway_https,
+ auth=run_spec.configuration.auth,
+ client_max_body_size=settings.DEFAULT_SERVICE_CLIENT_MAX_BODY_SIZE,
+ options=service_spec.options,
+ rate_limits=run_spec.configuration.rate_limits,
+ ssh_private_key=run_model.project.ssh_private_key,
+ has_router_replica=has_replica_group_router,
+ router=router,
+ )
+ try:
await do_register()
- else:
- raise
- except SSHError:
- raise ServerClientError("Gateway tunnel is not working")
- except httpx.RequestError as e:
- logger.debug("Gateway request failed", exc_info=True)
- raise GatewayError(f"Gateway is not working: {e!r}")
+ except GatewayError as e:
+ if e.msg == SERVICE_ALREADY_REGISTERED_ERROR_TEMPLATE.format(
+ ref=f"{run_model.project.name}/{run_model.run_name}"
+ ):
+ # Happens if there was a communication issue with the gateway when last (un)registering
+ logger.warning(
+ "Service %s/%s is dangling on gateway replica %s, unregistering and re-registering",
+ run_model.project.name,
+ run_model.run_name,
+ conn.ip_address,
+ )
+ await client.unregister_service(
+ project=run_model.project.name,
+ run_name=run_model.run_name,
+ )
+ await do_register()
+ else:
+ raise
+ except SSHError:
+ raise ServerClientError("Gateway tunnel is not working")
+ except httpx.RequestError as e:
+ logger.debug("Gateway request failed", exc_info=True)
+ raise GatewayError(f"Gateway is not working: {e!r}")
events.emit(
session,
diff --git a/src/dstack/_internal/server/testing/common.py b/src/dstack/_internal/server/testing/common.py
index 6c8b7233f6..2c0a66be5a 100644
--- a/src/dstack/_internal/server/testing/common.py
+++ b/src/dstack/_internal/server/testing/common.py
@@ -638,7 +638,6 @@ async def create_gateway(
name: str = "test_gateway",
region: str = "us",
wildcard_domain: Optional[str] = None,
- gateway_compute_id: Optional[UUID] = None,
status: Optional[GatewayStatus] = GatewayStatus.SUBMITTED,
last_processed_at: datetime = datetime(2023, 1, 2, 3, 4, tzinfo=timezone.utc),
forbid_new_services: bool = False,
@@ -649,7 +648,6 @@ async def create_gateway(
name=name,
region=region,
wildcard_domain=wildcard_domain,
- gateway_compute_id=gateway_compute_id,
status=status,
last_processed_at=last_processed_at,
forbid_new_services=forbid_new_services,
@@ -661,6 +659,7 @@ async def create_gateway(
async def create_gateway_compute(
session: AsyncSession,
+ gateway_id: Optional[UUID] = None,
backend_id: Optional[UUID] = None,
ip_address: Optional[str] = "1.1.1.1",
region: str = "us",
@@ -669,6 +668,7 @@ async def create_gateway_compute(
ssh_public_key: str = "",
) -> GatewayComputeModel:
gateway_compute = GatewayComputeModel(
+ gateway_id=gateway_id,
backend_id=backend_id,
ip_address=ip_address,
region=region,
diff --git a/src/tests/_internal/server/background/pipeline_tasks/test_gateways.py b/src/tests/_internal/server/background/pipeline_tasks/test_gateways.py
index d1113c90d1..2759d8a236 100644
--- a/src/tests/_internal/server/background/pipeline_tasks/test_gateways.py
+++ b/src/tests/_internal/server/background/pipeline_tasks/test_gateways.py
@@ -6,10 +6,15 @@
import pytest
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
-from sqlalchemy.orm import joinedload
+from sqlalchemy.orm import selectinload
from dstack._internal.core.errors import BackendError
-from dstack._internal.core.models.gateways import GatewayProvisioningData, GatewayStatus
+from dstack._internal.core.models.backends.base import BackendType
+from dstack._internal.core.models.gateways import (
+ GatewayConfiguration,
+ GatewayProvisioningData,
+ GatewayStatus,
+)
from dstack._internal.server.background.pipeline_tasks.gateways import (
GatewayFetcher,
GatewayPipeline,
@@ -257,12 +262,12 @@ async def test_submitted_to_provisioning(
res = await session.execute(
select(GatewayModel)
.where(GatewayModel.id == gateway.id)
- .options(joinedload(GatewayModel.gateway_compute))
+ .options(selectinload(GatewayModel.gateway_computes))
)
gateway = res.unique().scalar_one()
assert gateway.status == GatewayStatus.PROVISIONING
- assert gateway.gateway_compute is not None
- assert gateway.gateway_compute.ip_address == "2.2.2.2"
+ assert len(gateway.gateway_computes) > 0
+ assert gateway.gateway_computes[0].ip_address == "2.2.2.2"
events = await list_events(session)
assert len(events) == 1
assert events[0].message == "Gateway status changed SUBMITTED -> PROVISIONING"
@@ -300,23 +305,129 @@ async def test_marks_gateway_as_failed_if_gateway_creation_errors(
assert len(events) == 1
assert events[0].message == "Gateway status changed SUBMITTED -> FAILED (Some error)"
+ async def test_submitted_creates_multiple_computes_for_multi_replica(
+ self, test_db, session: AsyncSession, worker: GatewayWorker
+ ):
+ project = await create_project(session=session)
+ backend = await create_backend(session=session, project_id=project.id)
+ gateway = await create_gateway(
+ session=session,
+ project_id=project.id,
+ backend_id=backend.id,
+ status=GatewayStatus.SUBMITTED,
+ )
+ config = GatewayConfiguration(
+ name=gateway.name,
+ backend=BackendType.AWS,
+ region=gateway.region,
+ replicas=2,
+ )
+ gateway.configuration = config.json()
+ gateway.lock_token = uuid.uuid4()
+ gateway.lock_expires_at = datetime(2025, 1, 2, 3, 4, tzinfo=timezone.utc)
+ await session.commit()
+
+ with patch(
+ "dstack._internal.server.services.backends.get_project_backend_with_model_by_type_or_error"
+ ) as m:
+ aws = Mock()
+ m.return_value = (backend, aws)
+ aws.compute.return_value = Mock(spec=ComputeMockSpec)
+ aws.compute.return_value.create_gateway.side_effect = [
+ GatewayProvisioningData(instance_id="i-aaa", ip_address="2.2.2.2", region="us"),
+ GatewayProvisioningData(instance_id="i-bbb", ip_address="3.3.3.3", region="us"),
+ ]
+ await worker.process(_gateway_to_pipeline_item(gateway))
+ assert aws.compute.return_value.create_gateway.call_count == 2
+
+ await session.refresh(gateway)
+ res = await session.execute(
+ select(GatewayModel)
+ .where(GatewayModel.id == gateway.id)
+ .options(selectinload(GatewayModel.gateway_computes))
+ )
+ gateway = res.unique().scalar_one()
+ assert gateway.status == GatewayStatus.PROVISIONING
+ computes = sorted(gateway.gateway_computes, key=lambda c: c.replica_num)
+ assert len(computes) == 2
+ assert computes[0].ip_address == "2.2.2.2"
+ assert computes[0].replica_num == 0
+ assert computes[1].ip_address == "3.3.3.3"
+ assert computes[1].replica_num == 1
+
+ async def test_marks_gateway_as_failed_if_second_replica_creation_errors(
+ self, test_db, session: AsyncSession, worker: GatewayWorker
+ ):
+ project = await create_project(session=session)
+ backend = await create_backend(session=session, project_id=project.id)
+ gateway = await create_gateway(
+ session=session,
+ project_id=project.id,
+ backend_id=backend.id,
+ status=GatewayStatus.SUBMITTED,
+ )
+ config = GatewayConfiguration(
+ name=gateway.name,
+ backend=BackendType.AWS,
+ region=gateway.region,
+ replicas=2,
+ )
+ gateway.configuration = config.json()
+ gateway.lock_token = uuid.uuid4()
+ gateway.lock_expires_at = datetime(2025, 1, 2, 3, 4, tzinfo=timezone.utc)
+ await session.commit()
+
+ with patch(
+ "dstack._internal.server.services.backends.get_project_backend_with_model_by_type_or_error"
+ ) as m:
+ aws = Mock()
+ m.return_value = (backend, aws)
+ aws.compute.return_value = Mock(spec=ComputeMockSpec)
+ aws.compute.return_value.create_gateway.side_effect = [
+ GatewayProvisioningData(instance_id="i-aaa", ip_address="2.2.2.2", region="us"),
+ BackendError("Some error"),
+ ]
+ await worker.process(_gateway_to_pipeline_item(gateway))
+ assert aws.compute.return_value.create_gateway.call_count == 2
+
+ await session.refresh(gateway)
+ res = await session.execute(
+ select(GatewayModel)
+ .where(GatewayModel.id == gateway.id)
+ .options(selectinload(GatewayModel.gateway_computes))
+ )
+ gateway = res.unique().scalar_one()
+ assert gateway.status == GatewayStatus.FAILED
+ assert gateway.status_message == "Some error"
+ # The first replica's compute is saved even though the second failed
+ assert len(gateway.gateway_computes) == 1
+ assert gateway.gateway_computes[0].ip_address == "2.2.2.2"
+ assert gateway.gateway_computes[0].replica_num == 0
+ events = await list_events(session)
+ assert len(events) == 1
+ assert events[0].message == "Gateway status changed SUBMITTED -> FAILED (Some error)"
+
@pytest.mark.asyncio
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
class TestGatewayWorkerProvisioning:
+ @pytest.mark.parametrize("legacy_compute", [False, True])
async def test_provisioning_to_running(
- self, test_db, session: AsyncSession, worker: GatewayWorker
+ self, test_db, session: AsyncSession, worker: GatewayWorker, legacy_compute: bool
):
project = await create_project(session=session)
backend = await create_backend(session=session, project_id=project.id)
- gateway_compute = await create_gateway_compute(session)
gateway = await create_gateway(
session=session,
project_id=project.id,
backend_id=backend.id,
- gateway_compute_id=gateway_compute.id,
status=GatewayStatus.PROVISIONING,
)
+ if legacy_compute:
+ gateway_compute = await create_gateway_compute(session=session, backend_id=backend.id)
+ gateway.gateway_compute_id = gateway_compute.id # pre-0.20.25 relationship style
+ else:
+ await create_gateway_compute(session, gateway_id=gateway.id)
gateway.lock_token = uuid.uuid4()
gateway.lock_expires_at = datetime(2025, 1, 2, 3, 4, tzinfo=timezone.utc)
await session.commit()
@@ -335,19 +446,57 @@ async def test_provisioning_to_running(
assert len(events) == 1
assert events[0].message == "Gateway status changed PROVISIONING -> RUNNING"
- async def test_marks_gateway_as_failed_if_fails_to_connect(
+ async def test_provisioning_to_running_with_multiple_replicas(
self, test_db, session: AsyncSession, worker: GatewayWorker
):
project = await create_project(session=session)
backend = await create_backend(session=session, project_id=project.id)
- gateway_compute = await create_gateway_compute(session)
gateway = await create_gateway(
session=session,
project_id=project.id,
backend_id=backend.id,
- gateway_compute_id=gateway_compute.id,
status=GatewayStatus.PROVISIONING,
)
+ await create_gateway_compute(session, gateway_id=gateway.id, ip_address="1.1.1.1")
+ compute1 = await create_gateway_compute(
+ session, gateway_id=gateway.id, ip_address="2.2.2.2"
+ )
+ compute1.replica_num = 1
+ gateway.lock_token = uuid.uuid4()
+ gateway.lock_expires_at = datetime(2025, 1, 2, 3, 4, tzinfo=timezone.utc)
+ await session.commit()
+
+ with patch(
+ "dstack._internal.server.services.gateways.gateway_connections_pool.get_or_add"
+ ) as pool_add:
+ pool_add.return_value = MagicMock()
+ pool_add.return_value.client.return_value = MagicMock(AsyncContextManager())
+ await worker.process(_gateway_to_pipeline_item(gateway))
+ assert pool_add.call_count == 2
+
+ await session.refresh(gateway)
+ assert gateway.status == GatewayStatus.RUNNING
+ events = await list_events(session)
+ assert len(events) == 1
+ assert events[0].message == "Gateway status changed PROVISIONING -> RUNNING"
+
+ @pytest.mark.parametrize("legacy_compute", [False, True])
+ async def test_marks_gateway_as_failed_if_fails_to_connect(
+ self, test_db, session: AsyncSession, worker: GatewayWorker, legacy_compute: bool
+ ):
+ project = await create_project(session=session)
+ backend = await create_backend(session=session, project_id=project.id)
+ gateway = await create_gateway(
+ session=session,
+ project_id=project.id,
+ backend_id=backend.id,
+ status=GatewayStatus.PROVISIONING,
+ )
+ if legacy_compute:
+ gateway_compute = await create_gateway_compute(session=session, backend_id=backend.id)
+ gateway.gateway_compute_id = gateway_compute.id # pre-0.20.25 relationship style
+ else:
+ gateway_compute = await create_gateway_compute(session, gateway_id=gateway.id)
gateway.lock_token = uuid.uuid4()
gateway.lock_expires_at = datetime(2025, 1, 2, 3, 4, tzinfo=timezone.utc)
await session.commit()
@@ -360,8 +509,55 @@ async def test_marks_gateway_as_failed_if_fails_to_connect(
connect_to_gateway_with_retry_mock.assert_called_once()
await session.refresh(gateway)
+ await session.refresh(gateway_compute)
assert gateway.status == GatewayStatus.FAILED
assert gateway.status_message == "Failed to connect to gateway"
+ assert gateway_compute.active is False
+ events = await list_events(session)
+ assert len(events) == 1
+ assert (
+ events[0].message
+ == "Gateway status changed PROVISIONING -> FAILED (Failed to connect to gateway)"
+ )
+
+ async def test_marks_gateway_as_failed_if_any_replica_fails_to_connect(
+ self, test_db, session: AsyncSession, worker: GatewayWorker
+ ):
+ project = await create_project(session=session)
+ backend = await create_backend(session=session, project_id=project.id)
+ gateway = await create_gateway(
+ session=session,
+ project_id=project.id,
+ backend_id=backend.id,
+ status=GatewayStatus.PROVISIONING,
+ )
+ compute0 = await create_gateway_compute(
+ session, gateway_id=gateway.id, ip_address="1.1.1.1"
+ )
+ compute1 = await create_gateway_compute(
+ session, gateway_id=gateway.id, ip_address="2.2.2.2"
+ )
+ compute1.replica_num = 1
+ gateway.lock_token = uuid.uuid4()
+ gateway.lock_expires_at = datetime(2025, 1, 2, 3, 4, tzinfo=timezone.utc)
+ await session.commit()
+
+ with patch(
+ "dstack._internal.server.services.gateways.connect_to_gateway_with_retry"
+ ) as connect_mock:
+ connect_mock.return_value = None
+ await worker.process(_gateway_to_pipeline_item(gateway))
+ assert connect_mock.call_count == 2
+
+ await session.refresh(gateway)
+ assert gateway.status == GatewayStatus.FAILED
+ assert gateway.status_message == "Failed to connect to gateway"
+
+ await session.refresh(compute0)
+ await session.refresh(compute1)
+ assert compute0.active is False
+ assert compute1.active is False
+
events = await list_events(session)
assert len(events) == 1
assert (
@@ -373,19 +569,25 @@ async def test_marks_gateway_as_failed_if_fails_to_connect(
@pytest.mark.asyncio
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
class TestGatewayWorkerDeleted:
+ @pytest.mark.parametrize("legacy_compute", [False, True])
async def test_deletes_gateway_and_marks_compute_deleted(
- self, test_db, session: AsyncSession, worker: GatewayWorker
+ self, test_db, session: AsyncSession, worker: GatewayWorker, legacy_compute: bool
):
project = await create_project(session=session)
backend = await create_backend(session=session, project_id=project.id)
- gateway_compute = await create_gateway_compute(session=session, backend_id=backend.id)
gateway = await create_gateway(
session=session,
project_id=project.id,
backend_id=backend.id,
- gateway_compute_id=gateway_compute.id,
status=GatewayStatus.RUNNING,
)
+ if legacy_compute:
+ gateway_compute = await create_gateway_compute(session=session, backend_id=backend.id)
+ gateway.gateway_compute_id = gateway_compute.id # pre-0.20.25 relationship style
+ else:
+ gateway_compute = await create_gateway_compute(
+ session=session, backend_id=backend.id, gateway_id=gateway.id
+ )
gateway.lock_token = uuid.uuid4()
gateway.lock_expires_at = datetime(2025, 1, 2, 3, 4, tzinfo=timezone.utc)
gateway.to_be_deleted = True
@@ -418,19 +620,25 @@ async def test_deletes_gateway_and_marks_compute_deleted(
assert len(events) == 1
assert events[0].message == "Gateway deleted"
+ @pytest.mark.parametrize("legacy_compute", [False, True])
async def test_keeps_gateway_if_terminate_fails(
- self, test_db, session: AsyncSession, worker: GatewayWorker
+ self, test_db, session: AsyncSession, worker: GatewayWorker, legacy_compute: bool
):
project = await create_project(session=session)
backend = await create_backend(session=session, project_id=project.id)
- gateway_compute = await create_gateway_compute(session=session, backend_id=backend.id)
gateway = await create_gateway(
session=session,
project_id=project.id,
backend_id=backend.id,
- gateway_compute_id=gateway_compute.id,
status=GatewayStatus.RUNNING,
)
+ if legacy_compute:
+ gateway_compute = await create_gateway_compute(session=session, backend_id=backend.id)
+ gateway.gateway_compute_id = gateway_compute.id # pre-0.20.25 relationship style
+ else:
+ gateway_compute = await create_gateway_compute(
+ session=session, backend_id=backend.id, gateway_id=gateway.id
+ )
gateway.lock_token = uuid.uuid4()
gateway.lock_expires_at = datetime(2025, 1, 2, 3, 4, tzinfo=timezone.utc)
gateway.lock_owner = "GatewayPipeline"
@@ -470,3 +678,112 @@ async def test_keeps_gateway_if_terminate_fails(
assert gateway_compute.deleted is False
events = await list_events(session)
assert len(events) == 0
+
+ async def test_deletes_gateway_with_multiple_replicas(
+ self, test_db, session: AsyncSession, worker: GatewayWorker
+ ):
+ project = await create_project(session=session)
+ backend = await create_backend(session=session, project_id=project.id)
+ gateway = await create_gateway(
+ session=session,
+ project_id=project.id,
+ backend_id=backend.id,
+ status=GatewayStatus.RUNNING,
+ )
+ compute0 = await create_gateway_compute(
+ session=session, backend_id=backend.id, gateway_id=gateway.id, ip_address="1.1.1.1"
+ )
+ compute1 = await create_gateway_compute(
+ session=session, backend_id=backend.id, gateway_id=gateway.id, ip_address="2.2.2.2"
+ )
+ compute1.replica_num = 1
+ gateway.lock_token = uuid.uuid4()
+ gateway.lock_expires_at = datetime(2025, 1, 2, 3, 4, tzinfo=timezone.utc)
+ gateway.to_be_deleted = True
+ await session.commit()
+
+ with (
+ patch(
+ "dstack._internal.server.services.backends.get_project_backend_by_type_or_error"
+ ) as get_backend_mock,
+ patch(
+ "dstack._internal.server.background.pipeline_tasks.gateways.gateway_connections_pool.remove"
+ ) as remove_connection_mock,
+ ):
+ backend_mock = Mock()
+ backend_mock.compute.return_value = Mock(spec=ComputeMockSpec)
+ get_backend_mock.return_value = backend_mock
+
+ await worker.process(_gateway_to_pipeline_item(gateway))
+
+ assert backend_mock.compute.return_value.terminate_gateway.call_count == 2
+ assert remove_connection_mock.call_count == 2
+
+ await session.refresh(compute0)
+ await session.refresh(compute1)
+ res = await session.execute(select(GatewayModel.id).where(GatewayModel.id == gateway.id))
+ assert res.scalar_one_or_none() is None
+ assert compute0.active is False
+ assert compute0.deleted is True
+ assert compute1.active is False
+ assert compute1.deleted is True
+ events = await list_events(session)
+ assert len(events) == 1
+ assert events[0].message == "Gateway deleted"
+
+ async def test_keeps_gateway_if_second_replica_terminate_fails(
+ self, test_db, session: AsyncSession, worker: GatewayWorker
+ ):
+ project = await create_project(session=session)
+ backend = await create_backend(session=session, project_id=project.id)
+ gateway = await create_gateway(
+ session=session,
+ project_id=project.id,
+ backend_id=backend.id,
+ status=GatewayStatus.RUNNING,
+ )
+ compute0 = await create_gateway_compute(
+ session=session, backend_id=backend.id, gateway_id=gateway.id, ip_address="1.1.1.1"
+ )
+ compute1 = await create_gateway_compute(
+ session=session, backend_id=backend.id, gateway_id=gateway.id, ip_address="2.2.2.2"
+ )
+ compute1.replica_num = 1
+ gateway.lock_token = uuid.uuid4()
+ gateway.lock_expires_at = datetime(2025, 1, 2, 3, 4, tzinfo=timezone.utc)
+ gateway.lock_owner = "GatewayPipeline"
+ gateway.to_be_deleted = True
+ original_last_processed_at = gateway.last_processed_at
+ await session.commit()
+
+ with (
+ patch(
+ "dstack._internal.server.services.backends.get_project_backend_by_type_or_error"
+ ) as get_backend_mock,
+ patch(
+ "dstack._internal.server.background.pipeline_tasks.gateways.gateway_connections_pool.remove"
+ ) as remove_connection_mock,
+ ):
+ backend_mock = Mock()
+ backend_mock.compute.return_value = Mock(spec=ComputeMockSpec)
+ backend_mock.compute.return_value.terminate_gateway.side_effect = [
+ None,
+ BackendError("Terminate failed"),
+ ]
+ get_backend_mock.return_value = backend_mock
+
+ await worker.process(_gateway_to_pipeline_item(gateway))
+
+ assert backend_mock.compute.return_value.terminate_gateway.call_count == 2
+ remove_connection_mock.assert_called_once_with(compute0.ip_address)
+
+ await session.refresh(gateway)
+ await session.refresh(compute0)
+ await session.refresh(compute1)
+ assert gateway.to_be_deleted is True
+ assert gateway.last_processed_at > original_last_processed_at
+ assert gateway.lock_token is None
+ assert gateway.lock_expires_at is None
+ assert gateway.lock_owner is None
+ assert compute0.deleted is False
+ assert compute1.deleted is False
diff --git a/src/tests/_internal/server/background/pipeline_tasks/test_running_jobs.py b/src/tests/_internal/server/background/pipeline_tasks/test_running_jobs.py
index e308b89ce8..85aa00e0b6 100644
--- a/src/tests/_internal/server/background/pipeline_tasks/test_running_jobs.py
+++ b/src/tests/_internal/server/background/pipeline_tasks/test_running_jobs.py
@@ -1892,19 +1892,19 @@ async def test_registers_service_replica_in_gateway(
project = await create_project(session=session, owner=user)
repo = await create_repo(session=session, project_id=project.id)
backend = await create_backend(session=session, project_id=project.id)
- gateway_compute = await create_gateway_compute(
- session=session,
- backend_id=backend.id,
- )
gateway = await create_gateway(
session=session,
project_id=project.id,
backend_id=backend.id,
- gateway_compute_id=gateway_compute.id,
status=GatewayStatus.RUNNING,
name="test-gateway",
wildcard_domain="example.com",
)
+ await create_gateway_compute(
+ session=session,
+ backend_id=backend.id,
+ gateway_id=gateway.id,
+ )
run = await create_run(
session=session,
project=project,
@@ -1985,19 +1985,19 @@ async def test_registers_service_replica_in_gateway_when_running_on_imported_ins
)
repo = await create_repo(session=session, project_id=importer_project.id)
backend = await create_backend(session=session, project_id=importer_project.id)
- gateway_compute = await create_gateway_compute(
- session=session,
- backend_id=backend.id,
- )
gateway = await create_gateway(
session=session,
project_id=importer_project.id,
backend_id=backend.id,
- gateway_compute_id=gateway_compute.id,
status=GatewayStatus.RUNNING,
name="test-gateway",
wildcard_domain="example.com",
)
+ await create_gateway_compute(
+ session=session,
+ backend_id=backend.id,
+ gateway_id=gateway.id,
+ )
run = await create_run(
session=session,
project=importer_project,
diff --git a/src/tests/_internal/server/compatibility/__init__.py b/src/tests/_internal/server/compatibility/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/src/tests/_internal/server/compatibility/test_gateways.py b/src/tests/_internal/server/compatibility/test_gateways.py
new file mode 100644
index 0000000000..4bbd2a80d1
--- /dev/null
+++ b/src/tests/_internal/server/compatibility/test_gateways.py
@@ -0,0 +1,88 @@
+import uuid
+from datetime import datetime, timezone
+
+from packaging.version import Version
+
+from dstack._internal.core.models.backends.base import BackendType
+from dstack._internal.core.models.gateways import (
+ Gateway,
+ GatewayConfiguration,
+ GatewayReplica,
+ GatewayStatus,
+)
+from dstack._internal.server.compatibility.gateways import patch_gateway
+from dstack._internal.utils.common import get_current_datetime
+
+_CREATED_AT = datetime(2025, 1, 1, tzinfo=timezone.utc)
+_CONFIG = GatewayConfiguration(name="gw", backend=BackendType.AWS, region="us")
+
+
+def _make_gateway_replica(hostname: str = "1.2.3.4") -> GatewayReplica:
+ return GatewayReplica(
+ hostname=hostname,
+ replica_num=0,
+ backend=BackendType.AWS,
+ region="us",
+ created_at=get_current_datetime(),
+ )
+
+
+def _make_gateway(replicas=None, hostname=None) -> Gateway:
+ return Gateway(
+ id=uuid.uuid4(),
+ name="test",
+ project_name="proj",
+ backend=BackendType.AWS,
+ region="us",
+ created_at=_CREATED_AT,
+ status=GatewayStatus.RUNNING,
+ status_message=None,
+ hostname=hostname,
+ wildcard_domain=None,
+ default=False,
+ replicas=replicas or [],
+ configuration=_CONFIG,
+ )
+
+
+class TestPatchGateway:
+ def test_none_version_is_noop(self):
+ replica = _make_gateway_replica("1.2.3.4")
+ gw = _make_gateway(replicas=[replica])
+ patch_gateway(gw, None)
+ assert gw.ip_address is None
+ assert gw.instance_id is None
+ assert gw.hostname is None
+
+ def test_new_version_is_noop(self):
+ replica = _make_gateway_replica("1.2.3.4")
+ gw = _make_gateway(replicas=[replica])
+ patch_gateway(gw, Version("0.20.25"))
+ assert gw.ip_address is None
+ assert gw.instance_id is None
+
+ def test_old_version_fills_hostname_from_replica(self):
+ replica = _make_gateway_replica("1.2.3.4")
+ gw = _make_gateway(replicas=[replica], hostname=None)
+ patch_gateway(gw, Version("0.20.24"))
+ assert gw.hostname == "1.2.3.4"
+
+ def test_old_version_keeps_existing_hostname(self):
+ replica = _make_gateway_replica("1.2.3.4")
+ gw = _make_gateway(replicas=[replica], hostname="lb.example.com")
+ patch_gateway(gw, Version("0.20.24"))
+ assert gw.hostname == "lb.example.com"
+
+ def test_old_version_no_replicas_sets_empty_strings(self):
+ gw = _make_gateway(replicas=[])
+ patch_gateway(gw, Version("0.20.24"))
+ assert gw.ip_address == ""
+ assert gw.instance_id == ""
+ assert gw.hostname == ""
+
+ def test_old_version_multi_replica_is_noop(self):
+ replicas = [_make_gateway_replica("1.2.3.4"), _make_gateway_replica("5.6.7.8")]
+ gw = _make_gateway(replicas=replicas)
+ patch_gateway(gw, Version("0.20.24"))
+ assert gw.ip_address is None
+ assert gw.instance_id is None
diff --git a/src/tests/_internal/server/routers/test_gateways.py b/src/tests/_internal/server/routers/test_gateways.py
index 5cc6bdd715..075d1f6d4a 100644
--- a/src/tests/_internal/server/routers/test_gateways.py
+++ b/src/tests/_internal/server/routers/test_gateways.py
@@ -1,3 +1,4 @@
+from typing import Any
from unittest.mock import patch
import pytest
@@ -29,23 +30,29 @@ async def test_returns_40x_if_not_authenticated(self, client: AsyncClient):
@pytest.mark.asyncio
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
- async def test_list(self, test_db, session: AsyncSession, client: AsyncClient):
+ @pytest.mark.parametrize("legacy_compute", [False, True])
+ async def test_list(
+ self, test_db, session: AsyncSession, client: AsyncClient, legacy_compute: bool
+ ):
user = await create_user(session, global_role=GlobalRole.USER)
project = await create_project(session)
await add_project_member(
session=session, project=project, user=user, project_role=ProjectRole.USER
)
backend = await create_backend(session=session, project_id=project.id)
- gateway_compute = await create_gateway_compute(
- session=session,
- backend_id=backend.id,
- )
gateway = await create_gateway(
session=session,
project_id=project.id,
backend_id=backend.id,
- gateway_compute_id=gateway_compute.id,
)
+ if legacy_compute:
+ gateway_compute = await create_gateway_compute(session=session, backend_id=backend.id)
+ gateway.gateway_compute_id = gateway_compute.id # pre-0.20.25 relationship style
+ else:
+ gateway_compute = await create_gateway_compute(
+ session=session, backend_id=backend.id, gateway_id=gateway.id
+ )
+ await session.commit()
response = await client.post(
f"/api/project/{project.name}/gateways/list",
headers=get_auth_headers(user.token),
@@ -60,9 +67,18 @@ async def test_list(self, test_db, session: AsyncSession, client: AsyncClient):
"default": False,
"status": "submitted",
"status_message": None,
- "instance_id": gateway_compute.instance_id,
- "ip_address": gateway_compute.ip_address,
- "hostname": gateway_compute.ip_address,
+ "replicas": [
+ {
+ "hostname": gateway_compute.ip_address,
+ "replica_num": 0,
+ "backend": backend.type.value,
+ "region": "us",
+ "created_at": response.json()[0]["replicas"][0]["created_at"],
+ }
+ ],
+ "instance_id": None,
+ "ip_address": None,
+ "hostname": None,
"name": gateway.name,
"region": gateway.region,
"wildcard_domain": gateway.wildcard_domain,
@@ -78,29 +94,36 @@ async def test_list(self, test_db, session: AsyncSession, client: AsyncClient):
"public_ip": True,
"certificate": {"type": "lets-encrypt"},
"tags": None,
+ "replicas": None,
},
}
]
@pytest.mark.asyncio
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
- async def test_get(self, test_db, session: AsyncSession, client: AsyncClient):
+ @pytest.mark.parametrize("legacy_compute", [False, True])
+ async def test_get(
+ self, test_db, session: AsyncSession, client: AsyncClient, legacy_compute: bool
+ ):
user = await create_user(session, global_role=GlobalRole.USER)
project = await create_project(session)
await add_project_member(
session=session, project=project, user=user, project_role=ProjectRole.USER
)
backend = await create_backend(session, project.id)
- gateway_compute = await create_gateway_compute(
- session=session,
- backend_id=backend.id,
- )
gateway = await create_gateway(
session=session,
project_id=project.id,
backend_id=backend.id,
- gateway_compute_id=gateway_compute.id,
)
+ if legacy_compute:
+ gateway_compute = await create_gateway_compute(session=session, backend_id=backend.id)
+ gateway.gateway_compute_id = gateway_compute.id # pre-0.20.25 relationship style
+ else:
+ gateway_compute = await create_gateway_compute(
+ session=session, backend_id=backend.id, gateway_id=gateway.id
+ )
+ await session.commit()
response = await client.post(
f"/api/project/{project.name}/gateways/get",
json={"name": gateway.name},
@@ -115,9 +138,18 @@ async def test_get(self, test_db, session: AsyncSession, client: AsyncClient):
"default": False,
"status": "submitted",
"status_message": None,
- "instance_id": gateway_compute.instance_id,
- "ip_address": gateway_compute.ip_address,
- "hostname": gateway_compute.ip_address,
+ "replicas": [
+ {
+ "hostname": gateway_compute.ip_address,
+ "replica_num": 0,
+ "backend": backend.type.value,
+ "region": "us",
+ "created_at": response.json()["replicas"][0]["created_at"],
+ }
+ ],
+ "instance_id": None,
+ "ip_address": None,
+ "hostname": None,
"name": gateway.name,
"region": gateway.region,
"wildcard_domain": gateway.wildcard_domain,
@@ -133,26 +165,60 @@ async def test_get(self, test_db, session: AsyncSession, client: AsyncClient):
"public_ip": True,
"certificate": {"type": "lets-encrypt"},
"tags": None,
+ "replicas": None,
},
}
@pytest.mark.asyncio
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
- async def test_list_non_member_public_project(
+ async def test_list_legacy_client_populates_compat_fields(
self, test_db, session: AsyncSession, client: AsyncClient
):
+ """Old clients (< 0.20.25) get ip_address/instance_id/hostname back-filled."""
user = await create_user(session, global_role=GlobalRole.USER)
- project = await create_project(session, is_public=True)
+ project = await create_project(session)
+ await add_project_member(
+ session=session, project=project, user=user, project_role=ProjectRole.USER
+ )
backend = await create_backend(session=session, project_id=project.id)
+ gateway = await create_gateway(
+ session=session,
+ project_id=project.id,
+ backend_id=backend.id,
+ )
gateway_compute = await create_gateway_compute(
session=session,
backend_id=backend.id,
+ gateway_id=gateway.id,
)
+ response = await client.post(
+ f"/api/project/{project.name}/gateways/list",
+ headers={**get_auth_headers(user.token), "x-api-version": "0.20.24"},
+ )
+ assert response.status_code == 200
+ assert len(response.json()) == 1
+ gw = response.json()[0]
+ assert gw["ip_address"] == gateway_compute.ip_address
+ assert gw["instance_id"] == ""
+ assert gw["hostname"] == gateway_compute.ip_address
+
+ @pytest.mark.asyncio
+ @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
+ async def test_list_non_member_public_project(
+ self, test_db, session: AsyncSession, client: AsyncClient
+ ):
+ user = await create_user(session, global_role=GlobalRole.USER)
+ project = await create_project(session, is_public=True)
+ backend = await create_backend(session=session, project_id=project.id)
gateway = await create_gateway(
session=session,
project_id=project.id,
backend_id=backend.id,
- gateway_compute_id=gateway_compute.id,
+ )
+ await create_gateway_compute(
+ session=session,
+ backend_id=backend.id,
+ gateway_id=gateway.id,
)
response = await client.post(
f"/api/project/{project.name}/gateways/list",
@@ -170,15 +236,15 @@ async def test_get_non_member_public_project(
user = await create_user(session, global_role=GlobalRole.USER)
project = await create_project(session, is_public=True)
backend = await create_backend(session, project.id)
- gateway_compute = await create_gateway_compute(
+ gateway = await create_gateway(
session=session,
+ project_id=project.id,
backend_id=backend.id,
)
- gateway = await create_gateway(
+ await create_gateway_compute(
session=session,
- project_id=project.id,
backend_id=backend.id,
- gateway_compute_id=gateway_compute.id,
+ gateway_id=gateway.id,
)
response = await client.post(
f"/api/project/{project.name}/gateways/get",
@@ -222,14 +288,13 @@ async def test_list_returns_imported_gateway_with_include_imported(
project_role=ProjectRole.ADMIN,
)
backend = await create_backend(session=session, project_id=exporter_project.id)
- gateway_compute = await create_gateway_compute(session=session, backend_id=backend.id)
gateway = await create_gateway(
session=session,
project_id=exporter_project.id,
backend_id=backend.id,
- gateway_compute_id=gateway_compute.id,
name="exported-gateway",
)
+ await create_gateway_compute(session=session, backend_id=backend.id, gateway_id=gateway.id)
await create_export(
session=session,
exporter_project=exporter_project,
@@ -266,14 +331,13 @@ async def test_list_not_returns_imported_gateway_without_include_imported(
project_role=ProjectRole.ADMIN,
)
backend = await create_backend(session=session, project_id=exporter_project.id)
- gateway_compute = await create_gateway_compute(session=session, backend_id=backend.id)
gateway = await create_gateway(
session=session,
project_id=exporter_project.id,
backend_id=backend.id,
- gateway_compute_id=gateway_compute.id,
name="exported-gateway",
)
+ await create_gateway_compute(session=session, backend_id=backend.id, gateway_id=gateway.id)
await create_export(
session=session,
exporter_project=exporter_project,
@@ -308,14 +372,13 @@ async def test_get_returns_imported_gateway(
project_role=ProjectRole.ADMIN,
)
backend = await create_backend(session=session, project_id=exporter_project.id)
- gateway_compute = await create_gateway_compute(session=session, backend_id=backend.id)
gateway = await create_gateway(
session=session,
project_id=exporter_project.id,
backend_id=backend.id,
- gateway_compute_id=gateway_compute.id,
name="exported-gateway",
)
+ await create_gateway_compute(session=session, backend_id=backend.id, gateway_id=gateway.id)
await create_export(
session=session,
exporter_project=exporter_project,
@@ -356,14 +419,13 @@ async def test_get_returns_403_on_foreign_gateway_if_not_imported(
project_role=ProjectRole.USER,
)
backend = await create_backend(session=session, project_id=exporter_project.id)
- gateway_compute = await create_gateway_compute(session=session, backend_id=backend.id)
gateway = await create_gateway(
session=session,
project_id=exporter_project.id,
backend_id=backend.id,
- gateway_compute_id=gateway_compute.id,
name="exported-gateway",
)
+ await create_gateway_compute(session=session, backend_id=backend.id, gateway_id=gateway.id)
await create_export(
session=session,
exporter_project=exporter_project,
@@ -426,9 +488,10 @@ async def test_create_gateway(self, test_db, session: AsyncSession, client: Asyn
"region": "us",
"status": "submitted",
"status_message": None,
- "instance_id": "",
- "ip_address": "",
- "hostname": "",
+ "replicas": [],
+ "instance_id": None,
+ "ip_address": None,
+ "hostname": None,
"wildcard_domain": None,
"default": True,
"created_at": response.json()["created_at"],
@@ -444,11 +507,43 @@ async def test_create_gateway(self, test_db, session: AsyncSession, client: Asyn
"public_ip": True,
"certificate": {"type": "lets-encrypt"},
"tags": None,
+ "replicas": None,
},
}
events = await list_events(session)
assert events[0].message == "Gateway created. Status: SUBMITTED"
+ @pytest.mark.asyncio
+ @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
+ async def test_create_multi_replica_gateway(
+ self, test_db, session: AsyncSession, client: AsyncClient
+ ):
+ user = await create_user(session, global_role=GlobalRole.USER)
+ project = await create_project(session)
+ await add_project_member(
+ session=session, project=project, user=user, project_role=ProjectRole.ADMIN
+ )
+ await create_backend(session, project.id, backend_type=BackendType.AWS)
+ response = await client.post(
+ f"/api/project/{project.name}/gateways/create",
+ json={
+ "configuration": {
+ "type": "gateway",
+ "name": "test",
+ "backend": "aws",
+ "region": "us",
+ "replicas": 2,
+ "certificate": None,
+ },
+ },
+ headers=get_auth_headers(user.token),
+ )
+ assert response.status_code == 200
+ assert response.json()["configuration"]["replicas"] == 2
+ assert response.json()["replicas"] == [] # populated later by pipelines
+ events = await list_events(session)
+ assert events[0].message == "Gateway created. Status: SUBMITTED"
+
@pytest.mark.asyncio
@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
async def test_create_gateway_without_name(
@@ -484,9 +579,10 @@ async def test_create_gateway_without_name(
"region": "us",
"status": "submitted",
"status_message": None,
- "instance_id": "",
- "ip_address": "",
- "hostname": "",
+ "replicas": [],
+ "instance_id": None,
+ "ip_address": None,
+ "hostname": None,
"wildcard_domain": None,
"default": True,
"created_at": response.json()["created_at"],
@@ -502,6 +598,7 @@ async def test_create_gateway_without_name(
"public_ip": True,
"certificate": {"type": "lets-encrypt"},
"tags": None,
+ "replicas": None,
},
}
events = await list_events(session)
@@ -583,6 +680,100 @@ async def test_create_gateway_with_invalid_domain_interpolation(
)
assert response.status_code == 400
+ @pytest.mark.asyncio
+ @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
+ @pytest.mark.parametrize(
+ "configuration, expected_error",
+ [
+ pytest.param(
+ {
+ "type": "gateway",
+ "name": "test",
+ "backend": "aws",
+ "region": "us",
+ "domain": "${{ run.unknown_variable }}.example.com",
+ },
+ "Cannot interpolate gateway domain name: Failed to interpolate due to missing vars: ['run.unknown_variable']",
+ id="invalid-domain-interpolation",
+ ),
+ pytest.param(
+ {
+ "type": "gateway",
+ "name": "test",
+ "backend": "aws",
+ "region": "us",
+ "certificate": {
+ "type": "acm",
+ "arn": "arn:aws:acm:us-east-1:123456789:certificate/abc",
+ },
+ "replicas": 2,
+ },
+ "Replicated gateways do not support certificates."
+ " Set either `certificate: null` or `replicas: 1` in the gateway configuration",
+ id="multi-replica-with-acm-cert",
+ ),
+ pytest.param(
+ {
+ "type": "gateway",
+ "name": "test",
+ "backend": "aws",
+ "region": "us",
+ "certificate": {"type": "lets-encrypt"},
+ "replicas": 2,
+ },
+ "Replicated gateways do not support certificates."
+ " Set either `certificate: null` or `replicas: 1` in the gateway configuration",
+ id="multi-replica-with-letsencrypt-cert",
+ ),
+ pytest.param(
+ {
+ "type": "gateway",
+ "name": "test",
+ "backend": "aws",
+ "region": "us",
+ "certificate": None,
+ "router": {"type": "sglang"},
+ "replicas": 2,
+ },
+ "The deprecated `router` property is not supported for multi-replica gateways",
+ id="multi-replica-with-router",
+ ),
+ pytest.param(
+ {
+ "type": "gateway",
+ "name": "test",
+ "backend": "aws",
+ "region": "us",
+ "certificate": None,
+ "replicas": 4,
+ },
+ "Cannot provision 4 gateway replicas. This server allows at most 3",
+ id="replicas-exceed-max",
+ ),
+ ],
+ )
+ async def test_invalid_configuration_rejected(
+ self,
+ test_db,
+ session: AsyncSession,
+ client: AsyncClient,
+ configuration: dict[str, Any],
+ expected_error: str,
+ ):
+ user = await create_user(session, global_role=GlobalRole.USER)
+ project = await create_project(session)
+ await add_project_member(
+ session=session, project=project, user=user, project_role=ProjectRole.ADMIN
+ )
+ await create_backend(session, project.id, backend_type=BackendType.AWS)
+ response = await client.post(
+ f"/api/project/{project.name}/gateways/create",
+ json={"configuration": configuration},
+ headers=get_auth_headers(user.token),
+ )
+ assert response.status_code == 400
+ assert response.json()["detail"][0]["msg"] == expected_error
+
class TestDefaultGateway:
@pytest.mark.asyncio
@@ -613,17 +804,17 @@ async def test_set_default_gateway(self, test_db, session: AsyncSession, client:
session=session, project=project, user=user, project_role=ProjectRole.ADMIN
)
backend = await create_backend(session, project.id)
- gateway_compute = await create_gateway_compute(
- session=session,
- backend_id=backend.id,
- )
gateway = await create_gateway(
session=session,
project_id=project.id,
backend_id=backend.id,
- gateway_compute_id=gateway_compute.id,
name="first_gateway",
)
+ gateway_compute = await create_gateway_compute(
+ session=session,
+ backend_id=backend.id,
+ gateway_id=gateway.id,
+ )
response = await client.post(
f"/api/project/{project.name}/gateways/set_default",
json={"name": gateway.name},
@@ -645,9 +836,18 @@ async def test_set_default_gateway(self, test_db, session: AsyncSession, client:
"default": True,
"status": "submitted",
"status_message": None,
- "instance_id": gateway_compute.instance_id,
- "ip_address": gateway_compute.ip_address,
- "hostname": gateway_compute.ip_address,
+ "replicas": [
+ {
+ "hostname": gateway_compute.ip_address,
+ "replica_num": 0,
+ "backend": backend.type.value,
+ "region": "us",
+ "created_at": response.json()["replicas"][0]["created_at"],
+ }
+ ],
+ "instance_id": None,
+ "ip_address": None,
+ "hostname": None,
"name": gateway.name,
"region": gateway.region,
"wildcard_domain": gateway.wildcard_domain,
@@ -663,23 +863,24 @@ async def test_set_default_gateway(self, test_db, session: AsyncSession, client:
"public_ip": True,
"certificate": {"type": "lets-encrypt"},
"tags": None,
+ "replicas": None,
},
}
events = await list_events(session)
assert len(events) == 1
assert events[0].message == "Gateway set as project default"
- second_gateway_compute = await create_gateway_compute(
- session=session,
- backend_id=backend.id,
- )
second_gateway = await create_gateway(
session=session,
project_id=project.id,
backend_id=backend.id,
- gateway_compute_id=second_gateway_compute.id,
name="second_gateway",
)
+ await create_gateway_compute(
+ session=session,
+ backend_id=backend.id,
+ gateway_id=second_gateway.id,
+ )
await clear_events(session)
response = await client.post(
f"/api/project/{project.name}/gateways/set_default",
@@ -775,14 +976,13 @@ async def test_set_imported_gateway_as_default(
project_role=ProjectRole.ADMIN,
)
backend = await create_backend(session=session, project_id=exporter_project.id)
- gateway_compute = await create_gateway_compute(session=session, backend_id=backend.id)
gateway = await create_gateway(
session=session,
project_id=exporter_project.id,
backend_id=backend.id,
- gateway_compute_id=gateway_compute.id,
name="exported-gateway",
)
+ await create_gateway_compute(session=session, backend_id=backend.id, gateway_id=gateway.id)
await create_export(
session=session,
exporter_project=exporter_project,
@@ -820,14 +1020,13 @@ async def test_cannot_set_non_imported_foreign_gateway_as_default(
project_role=ProjectRole.ADMIN,
)
backend = await create_backend(session=session, project_id=exporter_project.id)
- gateway_compute = await create_gateway_compute(session=session, backend_id=backend.id)
gateway = await create_gateway(
session=session,
project_id=exporter_project.id,
backend_id=backend.id,
- gateway_compute_id=gateway_compute.id,
name="exported-gateway",
)
+ await create_gateway_compute(session=session, backend_id=backend.id, gateway_id=gateway.id)
await create_export(
session=session,
exporter_project=exporter_project,
@@ -872,27 +1071,27 @@ async def test_marks_gateways_to_be_deleted(
)
backend_aws = await create_backend(session, project.id)
backend_gcp = await create_backend(session, project.id, backend_type=BackendType.GCP)
- gateway_compute_aws = await create_gateway_compute(
- session=session,
- backend_id=backend_aws.id,
- )
gateway_aws = await create_gateway(
session=session,
project_id=project.id,
backend_id=backend_aws.id,
name="gateway-aws",
- gateway_compute_id=gateway_compute_aws.id,
)
- gateway_compute_gcp = await create_gateway_compute(
+ gateway_compute_aws = await create_gateway_compute(
session=session,
- backend_id=backend_gcp.id,
+ backend_id=backend_aws.id,
+ gateway_id=gateway_aws.id,
)
gateway_gcp = await create_gateway(
session=session,
project_id=project.id,
backend_id=backend_gcp.id,
name="gateway-gcp",
- gateway_compute_id=gateway_compute_gcp.id,
+ )
+ gateway_compute_gcp = await create_gateway_compute(
+ session=session,
+ backend_id=backend_gcp.id,
+ gateway_id=gateway_gcp.id,
)
response = await client.post(
f"/api/project/{project.name}/gateways/delete",
@@ -991,17 +1190,17 @@ async def test_set_wildcard_domain(self, test_db, session: AsyncSession, client:
session=session, project=project, user=user, project_role=ProjectRole.ADMIN
)
backend = await create_backend(session, project.id)
- gateway_compute = await create_gateway_compute(
- session=session,
- backend_id=backend.id,
- )
gateway = await create_gateway(
session=session,
project_id=project.id,
backend_id=backend.id,
- gateway_compute_id=gateway_compute.id,
wildcard_domain="old.example",
)
+ gateway_compute = await create_gateway_compute(
+ session=session,
+ backend_id=backend.id,
+ gateway_id=gateway.id,
+ )
response = await client.post(
f"/api/project/{project.name}/gateways/set_wildcard_domain",
json={"name": gateway.name, "wildcard_domain": "new.example"},
@@ -1016,9 +1215,18 @@ async def test_set_wildcard_domain(self, test_db, session: AsyncSession, client:
"status": "submitted",
"status_message": None,
"default": False,
- "instance_id": gateway_compute.instance_id,
- "ip_address": gateway_compute.ip_address,
- "hostname": gateway_compute.ip_address,
+ "replicas": [
+ {
+ "hostname": gateway_compute.ip_address,
+ "replica_num": 0,
+ "backend": backend.type.value,
+ "region": "us",
+ "created_at": response.json()["replicas"][0]["created_at"],
+ }
+ ],
+ "instance_id": None,
+ "ip_address": None,
+ "hostname": None,
"name": gateway.name,
"region": gateway.region,
"wildcard_domain": "new.example",
@@ -1034,6 +1242,7 @@ async def test_set_wildcard_domain(self, test_db, session: AsyncSession, client:
"public_ip": True,
"certificate": {"type": "lets-encrypt"},
"tags": None,
+ "replicas": None,
},
}
events = await list_events(session)
diff --git a/src/tests/_internal/server/routers/test_runs.py b/src/tests/_internal/server/routers/test_runs.py
index 93414c9f42..01bc19f5eb 100644
--- a/src/tests/_internal/server/routers/test_runs.py
+++ b/src/tests/_internal/server/routers/test_runs.py
@@ -3748,19 +3748,19 @@ async def test_submit_to_correct_proxy(
repo = await create_repo(session=session, project_id=project.id)
backend = await create_backend(session=session, project_id=project.id)
for gateway_name, is_default in existing_gateways:
- gateway_compute = await create_gateway_compute(
- session=session,
- backend_id=backend.id,
- )
gateway = await create_gateway(
session=session,
project_id=project.id,
backend_id=backend.id,
- gateway_compute_id=gateway_compute.id,
status=GatewayStatus.RUNNING,
name=gateway_name,
wildcard_domain=f"{gateway_name}.example",
)
+ await create_gateway_compute(
+ session=session,
+ backend_id=backend.id,
+ gateway_id=gateway.id,
+ )
if is_default:
project.default_gateway_id = gateway.id
await session.commit()
@@ -3844,16 +3844,15 @@ async def test_submit_to_foreign_gateway_only_if_imported(
session=session, owner=exporter_user, name="exporter-project"
)
backend = await create_backend(session=session, project_id=exporter_project.id)
- gateway_compute = await create_gateway_compute(session=session, backend_id=backend.id)
gateway = await create_gateway(
session=session,
project_id=exporter_project.id,
backend_id=backend.id,
- gateway_compute_id=gateway_compute.id,
status=GatewayStatus.RUNNING,
name="exported-gateway",
wildcard_domain="exported-gateway.example",
)
+ await create_gateway_compute(session=session, backend_id=backend.id, gateway_id=gateway.id)
importer_user = await create_user(
session=session, global_role=GlobalRole.USER, name="importer_user"
@@ -3929,14 +3928,13 @@ async def test_not_submits_to_default_gateway_if_not_imported(
user = await create_user(session=session, global_role=GlobalRole.USER)
gateway_project = await create_project(session=session, owner=user, name="gateway-project")
backend = await create_backend(session=session, project_id=gateway_project.id)
- gateway_compute = await create_gateway_compute(session=session, backend_id=backend.id)
gateway = await create_gateway(
session=session,
project_id=gateway_project.id,
backend_id=backend.id,
- gateway_compute_id=gateway_compute.id,
status=GatewayStatus.RUNNING,
)
+ await create_gateway_compute(session=session, backend_id=backend.id, gateway_id=gateway.id)
service_project = await create_project(session=session, owner=user, name="service-project")
# The project's default_gateway_id may point to the gateway (e.g., if the gateway was
@@ -3982,16 +3980,15 @@ async def test_interpolates_project_name_in_imported_gateway_domain(
session=session, owner=exporter_user, name="exporter-project"
)
backend = await create_backend(session=session, project_id=exporter_project.id)
- gateway_compute = await create_gateway_compute(session=session, backend_id=backend.id)
gateway = await create_gateway(
session=session,
project_id=exporter_project.id,
backend_id=backend.id,
- gateway_compute_id=gateway_compute.id,
status=GatewayStatus.RUNNING,
name="exported-gateway",
wildcard_domain="${{ run.project_name }}.example.com",
)
+ await create_gateway_compute(session=session, backend_id=backend.id, gateway_id=gateway.id)
importer_user = await create_user(
session=session, global_role=GlobalRole.USER, name="importer_user"
@@ -4041,16 +4038,15 @@ async def test_returns_error_if_imported_gateway_domain_has_unknown_variable(
session=session, owner=exporter_user, name="exporter-project"
)
backend = await create_backend(session=session, project_id=exporter_project.id)
- gateway_compute = await create_gateway_compute(session=session, backend_id=backend.id)
gateway = await create_gateway(
session=session,
project_id=exporter_project.id,
backend_id=backend.id,
- gateway_compute_id=gateway_compute.id,
status=GatewayStatus.RUNNING,
name="exported-gateway",
wildcard_domain="${{ run.unknown_variable }}.example.com",
)
+ await create_gateway_compute(session=session, backend_id=backend.id, gateway_id=gateway.id)
importer_user = await create_user(
session=session, global_role=GlobalRole.USER, name="importer_user"
@@ -4108,15 +4104,14 @@ async def test_unregister_dangling_service(
)
repo = await create_repo(session=session, project_id=project.id)
backend = await create_backend(session=session, project_id=project.id)
- gateway_compute = await create_gateway_compute(session=session, backend_id=backend.id)
gateway = await create_gateway(
session=session,
project_id=project.id,
backend_id=backend.id,
- gateway_compute_id=gateway_compute.id,
status=GatewayStatus.RUNNING,
wildcard_domain="example.com",
)
+ await create_gateway_compute(session=session, backend_id=backend.id, gateway_id=gateway.id)
project.default_gateway_id = gateway.id
await session.commit()
@@ -4158,16 +4153,15 @@ async def test_return_error_if_default_gateway_forbids_new_services(
)
repo = await create_repo(session=session, project_id=project.id)
backend = await create_backend(session=session, project_id=project.id)
- gateway_compute = await create_gateway_compute(session=session, backend_id=backend.id)
gateway = await create_gateway(
session=session,
project_id=project.id,
backend_id=backend.id,
- gateway_compute_id=gateway_compute.id,
status=GatewayStatus.RUNNING,
wildcard_domain="example.com",
forbid_new_services=True,
)
+ await create_gateway_compute(session=session, backend_id=backend.id, gateway_id=gateway.id)
project.default_gateway_id = gateway.id
await session.commit()
@@ -4196,17 +4190,16 @@ async def test_return_error_if_explicitly_specified_gateway_forbids_new_services
)
repo = await create_repo(session=session, project_id=project.id)
backend = await create_backend(session=session, project_id=project.id)
- gateway_compute = await create_gateway_compute(session=session, backend_id=backend.id)
- await create_gateway(
+ gateway = await create_gateway(
session=session,
project_id=project.id,
backend_id=backend.id,
- gateway_compute_id=gateway_compute.id,
status=GatewayStatus.RUNNING,
name="restricted-gateway",
wildcard_domain="example.com",
forbid_new_services=True,
)
+ await create_gateway_compute(session=session, backend_id=backend.id, gateway_id=gateway.id)
response = await client.post(
"/api/project/test-project/runs/submit",
diff --git a/src/tests/_internal/server/services/gateways/__init__.py b/src/tests/_internal/server/services/gateways/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/src/tests/_internal/server/services/gateways/test_gateways.py b/src/tests/_internal/server/services/gateways/test_gateways.py
new file mode 100644
index 0000000000..aaf8fe6d52
--- /dev/null
+++ b/src/tests/_internal/server/services/gateways/test_gateways.py
@@ -0,0 +1,88 @@
+import pytest
+from sqlalchemy.ext.asyncio import AsyncSession
+
+from dstack._internal.proxy.gateway.const import SERVICE_SCALING_WINDOWS
+from dstack._internal.proxy.gateway.schemas.stats import Stat
+from dstack._internal.server.services.gateways import (
+ _merge_per_window_stats,
+ get_gateway_compute_models,
+)
+from dstack._internal.server.testing.common import (
+ create_backend,
+ create_gateway,
+ create_gateway_compute,
+ create_project,
+)
+
+
+class TestMergePerWindowStats:
+ def test_empty_returns_zero_stats(self):
+ result = _merge_per_window_stats([])
+ for window in SERVICE_SCALING_WINDOWS:
+ assert result[window].requests == 0
+ assert result[window].request_time == 0.0
+
+ def test_single_replica_returns_same_values(self):
+ stats = {w: Stat(requests=10, request_time=0.5) for w in SERVICE_SCALING_WINDOWS}
+ result = _merge_per_window_stats([stats])
+ for window in SERVICE_SCALING_WINDOWS:
+ assert result[window].requests == 10
+ assert result[window].request_time == pytest.approx(0.5)
+
+ def test_multiple_replicas_sums_requests_and_averages_time(self):
+ stats_a = {w: Stat(requests=10, request_time=1.0) for w in SERVICE_SCALING_WINDOWS}
+ stats_b = {w: Stat(requests=30, request_time=3.0) for w in SERVICE_SCALING_WINDOWS}
+ result = _merge_per_window_stats([stats_a, stats_b])
+ for window in SERVICE_SCALING_WINDOWS:
+ assert result[window].requests == 40
+ assert result[window].request_time == pytest.approx(2.5) # (10*1 + 30*3) / 40
+
+ def test_zero_requests_across_all_replicas_returns_zero_time(self):
+ stats_a = {w: Stat(requests=0, request_time=0.0) for w in SERVICE_SCALING_WINDOWS}
+ stats_b = {w: Stat(requests=0, request_time=0.0) for w in SERVICE_SCALING_WINDOWS}
+ result = _merge_per_window_stats([stats_a, stats_b])
+ for window in SERVICE_SCALING_WINDOWS:
+ assert result[window].requests == 0
+ assert result[window].request_time == 0.0
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True)
+class TestGetGatewayComputeModels:
+ async def test_new_style_returns_gateway_computes(self, test_db, session: AsyncSession):
+ project = await create_project(session=session)
+ backend = await create_backend(session=session, project_id=project.id)
+ gateway = await create_gateway(
+ session=session, project_id=project.id, backend_id=backend.id
+ )
+ compute = await create_gateway_compute(
+ session=session, gateway_id=gateway.id, backend_id=backend.id
+ )
+ await session.refresh(gateway, ["gateway_computes", "gateway_compute"])
+ result = get_gateway_compute_models(gateway)
+ assert len(result) == 1
+ assert result[0].id == compute.id
+
+ async def test_old_style_returns_single_compute(self, test_db, session: AsyncSession):
+ project = await create_project(session=session)
+ backend = await create_backend(session=session, project_id=project.id)
+ compute = await create_gateway_compute(session=session, backend_id=backend.id)
+ gateway = await create_gateway(
+ session=session, project_id=project.id, backend_id=backend.id
+ )
+ gateway.gateway_compute_id = compute.id
+ await session.commit()
+ await session.refresh(gateway, ["gateway_computes", "gateway_compute"])
+ result = get_gateway_compute_models(gateway)
+ assert len(result) == 1
+ assert result[0].id == compute.id
+
+ async def test_no_computes_returns_empty(self, test_db, session: AsyncSession):
+ project = await create_project(session=session)
+ backend = await create_backend(session=session, project_id=project.id)
+ gateway = await create_gateway(
+ session=session, project_id=project.id, backend_id=backend.id
+ )
+ await session.refresh(gateway, ["gateway_computes", "gateway_compute"])
+ result = get_gateway_compute_models(gateway)
+ assert result == []