Commit c5edfbc

Merge branch 'main' into featrue/add_vllm_support

2 parents efc5a0d + 83dd000 · commit c5edfbc

11 files changed: +60 −13 lines

The commit threads a new --use-hfcache flag from both CLI entry points through the scheduler, executor, HTTP server, and model-download utilities, so model files are resolved from the local Hugging Face cache (local_files_only=True) instead of being downloaded over the network.

src/backend/main.py
Lines changed: 1 addition & 0 deletions

@@ -140,6 +140,7 @@ async def serve_index():
         ],
         announce_maddrs=args.announce_maddrs,
         http_port=args.port,
+        use_hfcache=args.use_hfcache,
     )

     request_handler.set_scheduler_manage(scheduler_manage)

src/backend/server/scheduler_manage.py
Lines changed: 3 additions & 1 deletion

@@ -32,6 +32,7 @@ def __init__(
         host_maddrs: List[str] = [],
         announce_maddrs: List[str] = [],
         http_port: int = 3001,
+        use_hfcache: bool = False,
     ):
         """Initialize the manager with networking bootstrap parameters."""
         self.initial_peers = initial_peers
@@ -40,6 +41,7 @@ def __init__(
         self.host_maddrs = host_maddrs
         self.announce_maddrs = announce_maddrs
         self.http_port = http_port
+        self.use_hfcache = use_hfcache
         self.model_name = None
         self.init_nodes_num = None
         self.scheduler = None
@@ -134,7 +136,7 @@ def _start_scheduler(self, model_name, init_nodes_num):
         self.model_name = model_name
         self.init_nodes_num = init_nodes_num

-        model_info = get_model_info(model_name)
+        model_info = get_model_info(model_name, self.use_hfcache)
         self.scheduler = Scheduler(model_info, [], min_nodes_bootstrapping=init_nodes_num)

         # Run the scheduler's event/dispatch loops in background so the process
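
Together with the main.py hunk above, the flag is captured once at construction time and consumed only later, when a model is actually chosen. A minimal runnable sketch of that pattern (the class and model id below are hypothetical stand-ins, not the real Parallax types):

def get_model_info(model_name: str, use_hfcache: bool = False) -> dict:
    # Stand-in for static_config.get_model_info (see its diff below).
    return {"model": model_name, "local_files_only": use_hfcache}

class SchedulerManageSketch:
    def __init__(self, use_hfcache: bool = False):
        # Stored now; no model is known yet at construction time.
        self.use_hfcache = use_hfcache

    def start(self, model_name: str) -> dict:
        # Forwarded only when the scheduler actually starts.
        return get_model_info(model_name, self.use_hfcache)

manage = SchedulerManageSketch(use_hfcache=True)
print(manage.start("some-org/some-model"))  # {'model': ..., 'local_files_only': True}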

src/backend/server/server_args.py
Lines changed: 6 additions & 0 deletions

@@ -35,6 +35,12 @@ def parse_args() -> argparse.Namespace:
     parser.add_argument(
         "--is-local-network", type=bool, default=True, help="Whether to use local network"
     )
+    parser.add_argument(
+        "--use-hfcache",
+        action="store_true",
+        default=False,
+        help="Use local Hugging Face cache only (no network download)",
+    )

     parser.add_argument(
         "--gpu-backend",

src/backend/server/static_config.py
Lines changed: 4 additions & 2 deletions

@@ -66,7 +66,7 @@
 NODE_JOIN_COMMAND_PUBLIC_NETWORK = """parallax join -s {scheduler_addr} """


-def get_model_info(model_name):
+def get_model_info(model_name, use_hfcache: bool = False):
     def _load_config_only(name: str) -> dict:
         local_path = Path(name)
         if local_path.exists():
@@ -77,7 +77,9 @@ def _load_config_only(name: str) -> dict:
         # Hugging Face only – download just config.json
         from huggingface_hub import hf_hub_download  # type: ignore

-        config_file = hf_hub_download(repo_id=name, filename="config.json")
+        config_file = hf_hub_download(
+            repo_id=name, filename="config.json", local_files_only=use_hfcache
+        )
         with open(config_file, "r") as f:
             return json.load(f)
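
With local_files_only=True, hf_hub_download serves config.json from the local cache and raises instead of touching the network when the file is not cached. A sketch of the behavior the new argument switches on (the repo id is a placeholder):

from huggingface_hub import hf_hub_download
from huggingface_hub.utils import LocalEntryNotFoundError

try:
    # Resolved entirely from the local HF cache when local_files_only=True.
    config_file = hf_hub_download(
        repo_id="some-org/some-model",  # placeholder repo id
        filename="config.json",
        local_files_only=True,
    )
    print("served from cache:", config_file)
except LocalEntryNotFoundError:
    print("config.json not in the local cache; run once without --use-hfcache")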

src/parallax/server/executor.py
Lines changed: 8 additions & 1 deletion

@@ -73,6 +73,7 @@ def __init__(
         dtype: str = "float16",
         # Backend selection
         gpu_backend: str = "sglang",
+        use_hfcache: bool = False,
         # Scheduler Configs
         max_batch_size: Optional[int] = 8,
         max_sequence_length: Optional[int] = None,
@@ -108,6 +109,7 @@ def __init__(
     ):
         # Backend
         self.device = get_current_device()
+        self.use_hfcache = use_hfcache
         logger.debug(f"Executor initializing on device: {self.device}")
         self.backend_type = gpu_backend

@@ -150,6 +152,7 @@ def __init__(
                 "tp_rank": tp_rank,
                 "tp_size": tp_size,
                 "nccl_port": nccl_port,
+                "using_hfcache": use_hfcache,
             }

             self.model_runner, self.config, self.tokenizer = initialize_cuda_model_runner(
@@ -176,7 +179,10 @@ def __init__(
                 f"Initializing MLX sharded model loader for repo={model_repo}, layers=[{start_layer}, {end_layer})"
             )
             self.shard_loader = MLXModelLoader(
-                model_repo, start_layer=start_layer, end_layer=end_layer
+                model_repo,
+                start_layer=start_layer,
+                end_layer=end_layer,
+                use_hfcache=self.use_hfcache,
             )
             t0 = time.time()
             self.model_shard, self.config, self.tokenizer = self.shard_loader.load()
@@ -1629,5 +1635,6 @@ def create_executor_config(args: argparse.Namespace, gradient_server=None):
         "tp_size": args.tp_size,
         "nccl_port": args.nccl_port,
         "gradient_server": gradient_server,
+        "use_hfcache": args.use_hfcache,
     }
     return config
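
One inconsistency worth flagging: this diff writes the flag into the CUDA-backend kwargs under the key "using_hfcache", while initialize_sgl_model_runner (model_runner.py diff below) reads kwargs.get("use_hfcache", False). Unless the key is remapped somewhere outside these hunks, the SGLang path would always see False. A two-line reproduction of the mismatch:

# Sketch of the key mismatch between the executor's kwargs dict and the
# model runner's lookup, as the two hunks read.
kwargs = {"tp_rank": 0, "tp_size": 1, "using_hfcache": True}  # executor side

use_hfcache = kwargs.get("use_hfcache", False)  # model-runner side
print(use_hfcache)  # False: the value stored under "using_hfcache" is never seen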

src/parallax/server/http_server.py
Lines changed: 11 additions & 2 deletions

@@ -98,6 +98,7 @@ def __init__(
         executor_input_ipc_name,
         executor_output_ipc_name,
         model_path_str,
+        use_hfcache: bool = False,
     ):
         self.asyncio_tasks = set()
         # Init inter-process communication
@@ -114,9 +115,10 @@ def __init__(
         if Path(model_path_str).exists():
             model_path = Path(model_path_str)
         else:
-            model_path = download_metadata_only(model_path_str)
+            model_path = download_metadata_only(model_path_str, local_files_only=use_hfcache)
         config = load_config(model_path)
         self.model_path_str = model_path_str
+        self.use_hfcache = use_hfcache
         self.tokenizer = load_tokenizer(model_path, eos_token_ids=config.get("eos_token_id", None))
         self.detokenizer_class, self.tokenmap = load_detokenizer(model_path, self.tokenizer)

@@ -338,13 +340,18 @@ def create_error_response(


 async def init_app_states(
-    state: State, executor_input_ipc: str, executor_output_ipc: str, model_path: str
+    state: State,
+    executor_input_ipc: str,
+    executor_output_ipc: str,
+    model_path: str,
+    use_hfcache: bool = False,
 ):
     """Init FastAPI app states, including http handler, etc."""
     state.http_handler = HTTPHandler(
         executor_input_ipc,
         executor_output_ipc,
         model_path,
+        use_hfcache,
     )


@@ -433,6 +440,7 @@ def __init__(self, args):
         self.executor_input_ipc_name = args.executor_input_ipc
         self.executor_output_ipc_name = args.executor_output_ipc
         self.model_path = args.model_path
+        self.use_hfcache = args.use_hfcache

     async def run_uvicorn(self):
         """
@@ -467,6 +475,7 @@ def run(self):
                 self.executor_input_ipc_name,
                 self.executor_output_ipc_name,
                 self.model_path,
+                self.use_hfcache,
             )
         )
         asyncio.run(self.run_tasks())
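
init_app_states forwards the flag to HTTPHandler positionally as the fourth argument. That works because use_hfcache sits directly after model_path_str in the constructor, but it is fragile as the signature grows. A runnable sketch (stand-in class, not the real HTTPHandler) contrasting the positional call in the diff with the sturdier keyword form:

class HTTPHandlerSketch:
    """Stand-in with the same parameter order as the diff's HTTPHandler."""

    def __init__(self, in_ipc, out_ipc, model_path_str, use_hfcache: bool = False):
        self.use_hfcache = use_hfcache

# Positional (as in the diff) and keyword forms are equivalent today, but the
# keyword form survives a future parameter being inserted before use_hfcache.
h1 = HTTPHandlerSketch("in.ipc", "out.ipc", "some-org/some-model", True)
h2 = HTTPHandlerSketch("in.ipc", "out.ipc", "some-org/some-model", use_hfcache=True)
assert h1.use_hfcache and h2.use_hfcache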

src/parallax/server/server_args.py
Lines changed: 7 additions & 0 deletions

The node-side CLI gains the same flag as the backend scheduler CLI above; the argparse sketch after that diff applies here unchanged.

@@ -200,6 +200,13 @@ def parse_args() -> argparse.Namespace:
         help="GPU backend to use",
     )

+    parser.add_argument(
+        "--use-hfcache",
+        action="store_true",
+        default=False,
+        help="Use local Hugging Face cache only (no network download)",
+    )
+
     args = parser.parse_args()

     # Validate arguments

src/parallax/server/shard_loader.py
Lines changed: 4 additions & 0 deletions

@@ -36,6 +36,7 @@ def __init__(
         *,
         start_layer: Optional[int] = None,
         end_layer: Optional[int] = None,
+        use_hfcache: bool = False,
     ):
         """
         Initializes the model loader.
@@ -47,10 +48,12 @@ def __init__(
                 Defaults to the beginning of the model.
             end_layer (Optional[int]): The ending layer index for the shard (exclusive).
                 Defaults to the end of the model.
+            use_hfcache (bool): If True, use local Hugging Face cache only (no network download).
         """
         self.model_path_str = model_path_or_hf_repo
         self.start_layer = start_layer
         self.end_layer = end_layer
+        self.use_hfcache = use_hfcache
         self.register_block_class()

     def register_block_class(self):
@@ -113,6 +116,7 @@ def load(
                 self.model_path_str,
                 start_layer=self.start_layer,
                 end_layer=self.end_layer,
+                local_files_only=self.use_hfcache,
             )
         else:
             model_path = get_model_path(self.model_path_str)[0]
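
Putting the loader changes together, usage would look roughly like the sketch below. The import path is inferred from the file layout, the repo id and layer range are placeholders, and with use_hfcache=True the shard's files must already be in the local cache:

from parallax.server.shard_loader import MLXModelLoader

# Placeholder repo id and layer range.
loader = MLXModelLoader(
    "some-org/some-mlx-model",
    start_layer=0,
    end_layer=16,
    use_hfcache=True,  # forwarded as local_files_only to the selective download
)
model_shard, config, tokenizer = loader.load()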

src/parallax/sglang/model_runner.py
Lines changed: 6 additions & 7 deletions

@@ -241,26 +241,25 @@ def initialize_sgl_model_runner(
     """
     apply_parallax_sglang_monkey_patch()

+    # Extract TP-related parameters from kwargs or use defaults
+    tp_rank = kwargs.get("tp_rank", 0)
+    tp_size = kwargs.get("tp_size", 1)
+    use_hfcache = kwargs.get("use_hfcache", False)
+    nccl_port = kwargs.get("nccl_port", None)
     # Use selective download for GPU models to save bandwidth and disk space
     from parallax.utils.selective_download import get_model_path_with_selective_download

     logger.info(
         f"Downloading model with selective weight files for layers [{start_layer}, {end_layer})"
     )
     model_path = get_model_path_with_selective_download(
-        model_repo,
-        start_layer=start_layer,
-        end_layer=end_layer,
+        model_repo, start_layer=start_layer, end_layer=end_layer, use_hfcache=use_hfcache
     )

     config = load_config(model_path)
     tokenizer = load_tokenizer(model_path, eos_token_ids=config.get("eos_token_id", None))
     dtype = config.get("torch_dtype", "bfloat16")

-    # Extract TP-related parameters from kwargs or use defaults
-    tp_rank = kwargs.get("tp_rank", 0)
-    tp_size = kwargs.get("tp_size", 1)
-    nccl_port = kwargs.get("nccl_port", None)
     if nccl_port is None:
         nccl_port = random.randint(4000, 5000)
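
A second naming wrinkle: the rewritten call passes use_hfcache=use_hfcache, but get_model_path_with_selective_download (next diff) declares the parameter as local_files_only. Unless a **kwargs catch-all exists outside these hunks, Python would reject the call. A minimal reproduction against a stand-in with the signature from the next diff:

def get_model_path_with_selective_download(
    model_path_or_repo, start_layer=None, end_layer=None, local_files_only=False
):
    """Stand-in with the signature shown in the selective_download.py diff."""
    return model_path_or_repo

try:
    get_model_path_with_selective_download(
        "some-org/some-model", start_layer=0, end_layer=8, use_hfcache=True
    )
except TypeError as e:
    print(e)  # ... got an unexpected keyword argument 'use_hfcache'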

src/parallax/utils/selective_download.py
Lines changed: 9 additions & 0 deletions

@@ -24,6 +24,7 @@ def download_metadata_only(
     repo_id: str,
     cache_dir: Optional[str] = None,
     force_download: bool = False,
+    local_files_only: bool = False,
 ) -> Path:
     # If a local path is provided, return it directly without contacting HF Hub
     local_path = Path(repo_id)
@@ -35,6 +36,7 @@ def download_metadata_only(
         cache_dir=cache_dir,
         ignore_patterns=EXCLUDE_WEIGHT_PATTERNS,
         force_download=force_download,
+        local_files_only=local_files_only,
     )
     return Path(path)

@@ -45,6 +47,7 @@ def selective_model_download(
     end_layer: Optional[int] = None,
     cache_dir: Optional[str] = None,
     force_download: bool = False,
+    local_files_only: bool = False,
 ) -> Path:
     # Handle local model directory
     local_path = Path(repo_id)
@@ -58,6 +61,7 @@ def selective_model_download(
             repo_id=repo_id,
             cache_dir=cache_dir,
             force_download=force_download,
+            local_files_only=local_files_only,
         )
         logger.debug(f"Downloaded model metadata to {model_path}")
         is_remote = True
@@ -78,6 +82,7 @@ def selective_model_download(
                 repo_id=repo_id,
                 cache_dir=cache_dir,
                 force_download=force_download,
+                local_files_only=local_files_only,
             )
         else:
             # Step 3: Download only the needed weight files
@@ -90,6 +95,7 @@ def selective_model_download(
                     filename=weight_file,
                     cache_dir=cache_dir,
                     force_download=force_download,
+                    local_files_only=local_files_only,
                 )

             logger.debug(f"Downloaded weight files for layers [{start_layer}, {end_layer})")
@@ -104,6 +110,7 @@ def selective_model_download(
             repo_id=repo_id,
             cache_dir=cache_dir,
             force_download=force_download,
+            local_files_only=local_files_only,
         )
     else:
         logger.debug("No layer range specified and using local path; nothing to download")
@@ -115,9 +122,11 @@ def get_model_path_with_selective_download(
     model_path_or_repo: str,
     start_layer: Optional[int] = None,
     end_layer: Optional[int] = None,
+    local_files_only: bool = False,
 ) -> Path:
     return selective_model_download(
         repo_id=model_path_or_repo,
         start_layer=start_layer,
         end_layer=end_layer,
+        local_files_only=local_files_only,
     )
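
End to end, --use-hfcache presumes the cache was populated beforehand. One way to pre-warm it for a fully offline run, using only documented huggingface_hub calls (the repo id is a placeholder):

from huggingface_hub import snapshot_download

# One-time, online: populate the local Hugging Face cache.
snapshot_download(repo_id="some-org/some-model")  # placeholder repo id

# Afterwards, any hf_hub_download/snapshot_download call made with
# local_files_only=True (what --use-hfcache switches on) resolves from
# the local cache without any network access.
path = snapshot_download(repo_id="some-org/some-model", local_files_only=True)
print(path)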
