Skip to content
This repository was archived by the owner on Oct 9, 2024. It is now read-only.

Commit 6b9b96f

Browse files
authored
fix model path for int8 (#54)
1 parent 0619d9a commit 6b9b96f

File tree

3 files changed

+15
-14
lines changed

3 files changed

+15
-14
lines changed

inference_server/models/ds_inference.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,12 @@
66
from functools import partial
77

88
import torch
9-
import torch.distributed as dist
109

1110
import deepspeed
1211
from huggingface_hub import try_to_load_from_cache
1312
from transformers import AutoConfig
1413

15-
from ..utils import get_world_size, print_rank_n, run_rank_n
14+
from ..utils import get_world_size, run_rank_n
1615
from .model import Model, get_hf_model_class
1716

1817

@@ -90,13 +89,17 @@ def __exit__(self, type, value, traceback):
9089

9190

9291
def get_model_path(model_name: str):
    """Resolve ``model_name`` to a local directory containing the model files.

    Resolution order:
      1. Look the model up in the local Hugging Face cache (falling back to
         ``HUGGINGFACE_HUB_CACHE`` when ``TRANSFORMERS_CACHE`` is unset) and,
         if found, return the cached snapshot directory.
      2. Otherwise — including when the cache lookup raises (e.g. the name is
         not a valid hub repo id) — treat ``model_name`` as an explicit
         filesystem path and return it unchanged.

    Args:
        model_name: A Hugging Face hub model id or an explicit local path.

    Returns:
        str: A directory path usable for loading the model.
    """
    config_file = "config.json"
    try:
        # will fall back to HUGGINGFACE_HUB_CACHE when TRANSFORMERS_CACHE is unset
        config_path = try_to_load_from_cache(model_name, config_file, cache_dir=os.getenv("TRANSFORMERS_CACHE"))
        if config_path is not None:
            # cached snapshot found: the model directory is where config.json lives
            return os.path.dirname(config_path)
    except Exception:
        # NOTE(review): the original used a bare `except:`; narrowed to
        # Exception so KeyboardInterrupt/SystemExit still propagate. A failed
        # cache lookup simply means the name is not a hub id.
        pass
    # treat the model name as an explicit model path
    return model_name

inference_server/models/ds_zero.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
1-
import os
21
from argparse import Namespace
32

43
import torch
5-
import torch.distributed as dist
64

75
import deepspeed
86
from transformers import AutoConfig

inference_server/models/hf_accelerate.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import torch
44

5-
from ..utils import get_world_size, print_rank_n
5+
from ..utils import get_world_size
66
from .model import Model, get_hf_model_class
77

88

0 commit comments

Comments (0)