Skip to content

Commit 3b58cae

Browse files
BihanBihan  Rana
andauthored
Add SGLang Router Support (#3267)
* Add SGLang Router Support * Rename router_config to router * Rename sglang_workers.jinja2 to router_workers.jinja2 * Resolve SGLang API expose issue * Resolve model_id based single router to multi router Test sglang router per service implementation Test sglang router per service implementation Test sglang router per service implementation Test sglang router per service implementation Test sglang router per service implementation Test sglang router per service implementation * Resolve gateway installation command * Resolved Minor Review Comments * Resolve comments and kubernetes sglang router intregration Test gateway package update Test gateway package update Test gateway package update Test gateway package update Test gateway package update Resolve rate limits and location issue Resolve rate limits and location issue Resolve rate limits and location issue Resolve all major comments Resolve all major comments Resolve kubernetes gateway issue with sglang intregration * Resolve additional comments * Pinned sglang-router to 0.2.1 * Fix linting error --------- Co-authored-by: Bihan Rana <bihan@Bihans-MacBook-Pro.local>
1 parent 5bb7e6f commit 3b58cae

File tree

22 files changed

+804
-27
lines changed

22 files changed

+804
-27
lines changed

gateway/pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@ dependencies = [
1414
"dstack[gateway] @ https://github.com/dstackai/dstack/archive/refs/heads/master.tar.gz",
1515
]
1616

17+
[project.optional-dependencies]
18+
sglang = ["sglang-router==0.2.1"]
19+
1720
[tool.setuptools.package-data]
1821
"dstack.gateway" = [
1922
"resources/systemd/*",

src/dstack/_internal/core/backends/aws/compute.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -460,7 +460,9 @@ def create_gateway(
460460
image_id=aws_resources.get_gateway_image_id(ec2_client),
461461
instance_type="t3.micro",
462462
iam_instance_profile=None,
463-
user_data=get_gateway_user_data(configuration.ssh_key_pub),
463+
user_data=get_gateway_user_data(
464+
configuration.ssh_key_pub, router=configuration.router
465+
),
464466
tags=tags,
465467
security_group_id=security_group_id,
466468
spot=False,

src/dstack/_internal/core/backends/azure/compute.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -277,7 +277,9 @@ def create_gateway(
277277
image_reference=_get_gateway_image_ref(),
278278
vm_size="Standard_B1ms",
279279
instance_name=instance_name,
280-
user_data=get_gateway_user_data(configuration.ssh_key_pub),
280+
user_data=get_gateway_user_data(
281+
configuration.ssh_key_pub, router=configuration.router
282+
),
281283
ssh_pub_keys=[configuration.ssh_key_pub],
282284
spot=False,
283285
disk_size=30,

src/dstack/_internal/core/backends/base/compute.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
SSHKey,
3939
)
4040
from dstack._internal.core.models.placement import PlacementGroup, PlacementGroupProvisioningData
41+
from dstack._internal.core.models.routers import AnyRouterConfig
4142
from dstack._internal.core.models.runs import Job, JobProvisioningData, Requirements, Run
4243
from dstack._internal.core.models.volumes import (
4344
Volume,
@@ -881,7 +882,7 @@ def get_run_shim_script(
881882
]
882883

883884

884-
def get_gateway_user_data(authorized_key: str) -> str:
885+
def get_gateway_user_data(authorized_key: str, router: Optional[AnyRouterConfig] = None) -> str:
885886
return get_cloud_config(
886887
package_update=True,
887888
packages=[
@@ -897,7 +898,7 @@ def get_gateway_user_data(authorized_key: str) -> str:
897898
"s/# server_names_hash_bucket_size 64;/server_names_hash_bucket_size 128;/",
898899
"/etc/nginx/nginx.conf",
899900
],
900-
["su", "ubuntu", "-c", " && ".join(get_dstack_gateway_commands())],
901+
["su", "ubuntu", "-c", " && ".join(get_dstack_gateway_commands(router))],
901902
],
902903
ssh_authorized_keys=[authorized_key],
903904
)
@@ -1018,24 +1019,29 @@ def get_latest_runner_build() -> Optional[str]:
10181019
return None
10191020

10201021

1021-
def get_dstack_gateway_wheel(build: str) -> str:
1022+
def get_dstack_gateway_wheel(build: str, router: Optional[AnyRouterConfig] = None) -> str:
10221023
channel = "release" if settings.DSTACK_RELEASE else "stgn"
10231024
base_url = f"https://dstack-gateway-downloads.s3.amazonaws.com/{channel}"
10241025
if build == "latest":
10251026
r = requests.get(f"{base_url}/latest-version", timeout=5)
10261027
r.raise_for_status()
10271028
build = r.text.strip()
10281029
logger.debug("Found the latest gateway build: %s", build)
1029-
return f"{base_url}/dstack_gateway-{build}-py3-none-any.whl"
1030+
wheel = f"{base_url}/dstack_gateway-{build}-py3-none-any.whl"
1031+
# Build package spec with extras if router is specified
1032+
if router:
1033+
return f"dstack-gateway[{router.type}] @ {wheel}"
1034+
return f"dstack-gateway @ {wheel}"
10301035

10311036

1032-
def get_dstack_gateway_commands() -> List[str]:
1037+
def get_dstack_gateway_commands(router: Optional[AnyRouterConfig] = None) -> List[str]:
10331038
build = get_dstack_runner_version()
1039+
gateway_package = get_dstack_gateway_wheel(build, router)
10341040
return [
10351041
"mkdir -p /home/ubuntu/dstack",
10361042
"python3 -m venv /home/ubuntu/dstack/blue",
10371043
"python3 -m venv /home/ubuntu/dstack/green",
1038-
f"/home/ubuntu/dstack/blue/bin/pip install {get_dstack_gateway_wheel(build)}",
1044+
f"/home/ubuntu/dstack/blue/bin/pip install '{gateway_package}'",
10391045
"sudo /home/ubuntu/dstack/blue/bin/python -m dstack.gateway.systemd install --run",
10401046
]
10411047

src/dstack/_internal/core/backends/gcp/compute.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -599,7 +599,9 @@ def create_gateway(
599599
machine_type="e2-medium",
600600
accelerators=[],
601601
spot=False,
602-
user_data=get_gateway_user_data(configuration.ssh_key_pub),
602+
user_data=get_gateway_user_data(
603+
configuration.ssh_key_pub, router=configuration.router
604+
),
603605
authorized_keys=[configuration.ssh_key_pub],
604606
labels=labels,
605607
tags=[gcp_resources.DSTACK_GATEWAY_TAG],

src/dstack/_internal/core/backends/kubernetes/compute.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
1+
import shlex
12
import subprocess
23
import tempfile
34
import threading
45
import time
56
from enum import Enum
6-
from typing import Optional
7+
from typing import List, Optional
78

89
from gpuhunt import KNOWN_AMD_GPUS, KNOWN_NVIDIA_GPUS, AcceleratorVendor
910
from kubernetes import client
@@ -51,6 +52,7 @@
5152
)
5253
from dstack._internal.core.models.placement import PlacementGroup
5354
from dstack._internal.core.models.resources import CPUSpec, GPUSpec, Memory
55+
from dstack._internal.core.models.routers import AnyRouterConfig
5456
from dstack._internal.core.models.runs import Job, JobProvisioningData, Requirements, Run
5557
from dstack._internal.core.models.volumes import Volume
5658
from dstack._internal.utils.common import get_or_error, parse_memory
@@ -371,7 +373,9 @@ def create_gateway(
371373
# Consider deploying an NLB. It seems it requires some extra configuration on the cluster:
372374
# https://docs.aws.amazon.com/eks/latest/userguide/network-load-balancing.html
373375
instance_name = generate_unique_gateway_instance_name(configuration)
374-
commands = _get_gateway_commands(authorized_keys=[configuration.ssh_key_pub])
376+
commands = _get_gateway_commands(
377+
authorized_keys=[configuration.ssh_key_pub], router=configuration.router
378+
)
375379
pod = client.V1Pod(
376380
metadata=client.V1ObjectMeta(
377381
name=instance_name,
@@ -983,9 +987,13 @@ def _add_authorized_key_to_jump_pod(
983987
)
984988

985989

986-
def _get_gateway_commands(authorized_keys: list[str]) -> list[str]:
990+
def _get_gateway_commands(
991+
authorized_keys: List[str], router: Optional[AnyRouterConfig] = None
992+
) -> List[str]:
987993
authorized_keys_content = "\n".join(authorized_keys).strip()
988-
gateway_commands = " && ".join(get_dstack_gateway_commands())
994+
gateway_commands = " && ".join(get_dstack_gateway_commands(router=router))
995+
quoted_gateway_commands = shlex.quote(gateway_commands)
996+
989997
commands = [
990998
# install packages
991999
"apt-get update && apt-get install -y sudo wget openssh-server nginx python3.10-venv libaugeas0",
@@ -1013,7 +1021,7 @@ def _get_gateway_commands(authorized_keys: list[str]) -> list[str]:
10131021
# start sshd
10141022
"/usr/sbin/sshd -p 22 -o PermitUserEnvironment=yes",
10151023
# run gateway
1016-
f"su ubuntu -c '{gateway_commands}'",
1024+
f"su ubuntu -c {quoted_gateway_commands}",
10171025
"sleep infinity",
10181026
]
10191027
return commands

src/dstack/_internal/core/models/gateways.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
from dstack._internal.core.models.backends.base import BackendType
99
from dstack._internal.core.models.common import CoreModel
10+
from dstack._internal.core.models.routers import AnyRouterConfig
1011
from dstack._internal.utils.tags import tags_validator
1112

1213

@@ -50,6 +51,10 @@ class GatewayConfiguration(CoreModel):
5051
default: Annotated[bool, Field(description="Make the gateway default")] = False
5152
backend: Annotated[BackendType, Field(description="The gateway backend")]
5253
region: Annotated[str, Field(description="The gateway region")]
54+
router: Annotated[
55+
Optional[AnyRouterConfig],
56+
Field(description="The router configuration"),
57+
] = None
5358
domain: Annotated[
5459
Optional[str], Field(description="The gateway domain, e.g. `example.com`")
5560
] = None
@@ -113,6 +118,7 @@ class GatewayComputeConfiguration(CoreModel):
113118
ssh_key_pub: str
114119
certificate: Optional[AnyGatewayCertificate] = None
115120
tags: Optional[Dict[str, str]] = None
121+
router: Optional[AnyRouterConfig] = None
116122

117123

118124
class GatewayProvisioningData(CoreModel):
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
from enum import Enum
2+
from typing import Literal
3+
4+
from dstack._internal.core.models.common import CoreModel
5+
6+
7+
class RouterType(str, Enum):
8+
SGLANG = "sglang"
9+
10+
11+
class SGLangRouterConfig(CoreModel):
12+
type: Literal["sglang"] = "sglang"
13+
policy: Literal["random", "round_robin", "cache_aware", "power_of_two"] = "cache_aware"
14+
15+
16+
AnyRouterConfig = SGLangRouterConfig
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
{% for replica in replicas %}
2+
# Worker {{ loop.index }}
3+
upstream router_worker_{{ domain|replace('.', '_') }}_{{ ports[loop.index0] }}_upstream {
4+
server unix:{{ replica.socket }};
5+
}
6+
7+
server {
8+
listen 127.0.0.1:{{ ports[loop.index0] }};
9+
access_log off; # disable access logs for this internal endpoint
10+
11+
proxy_read_timeout 300s;
12+
proxy_send_timeout 300s;
13+
14+
location / {
15+
proxy_pass http://router_worker_{{ domain|replace('.', '_') }}_{{ ports[loop.index0] }}_upstream;
16+
proxy_http_version 1.1;
17+
proxy_set_header Host $host;
18+
proxy_set_header X-Real-IP $remote_addr;
19+
proxy_set_header Connection "";
20+
proxy_set_header Upgrade $http_upgrade;
21+
}
22+
}
23+
{% endfor %}

src/dstack/_internal/proxy/gateway/resources/nginx/service.jinja2

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,13 @@ limit_req_zone {{ zone.key }} zone={{ zone.name }}:10m rate={{ zone.rpm }}r/m;
44

55
{% if replicas %}
66
upstream {{ domain }}.upstream {
7+
{% if router_port is not none %}
8+
server 127.0.0.1:{{ router_port }}; # SGLang router on the gateway
9+
{% else %}
710
{% for replica in replicas %}
811
server unix:{{ replica.socket }}; # replica {{ replica.id }}
912
{% endfor %}
13+
{% endif %}
1014
}
1115
{% else %}
1216

@@ -32,6 +36,13 @@ server {
3236
}
3337
{% endfor %}
3438

39+
{# For SGLang router: block all requests except whitelisted locations added dynamically above #}
40+
{% if router is not none and router.type == "sglang" %}
41+
location / {
42+
return 403;
43+
}
44+
{% endif %}
45+
3546
location @websocket {
3647
set $dstack_replica_hit 1;
3748
{% if replicas %}

0 commit comments

Comments
 (0)