 import torch.distributed as dist

 import deepspeed
-from transformers import AutoConfig, AutoTokenizer
+from huggingface_hub import try_to_load_from_cache
+from transformers import AutoConfig

 from ..utils import print_rank_n, run_rank_n
-from .model import Model, get_downloaded_model_path, get_hf_model_class, load_tokenizer
+from .model import Model, get_hf_model_class


 # basic DeepSpeed inference model class for benchmarking
@@ -24,26 +25,23 @@ def __init__(self, args: Namespace) -> None:

         world_size = int(os.getenv("WORLD_SIZE", "1"))

-        downloaded_model_path = get_downloaded_model_path(args.model_name)
-
-        self.tokenizer = load_tokenizer(downloaded_model_path)
-        self.pad = self.tokenizer.pad_token_id
-
         # create dummy tensors for allocating space which will be filled with
         # the actual weights while calling deepspeed.init_inference in the
         # following code
         with deepspeed.OnDevice(dtype=torch.float16, device="meta"):
             self.model = get_hf_model_class(args.model_class).from_config(
-                AutoConfig.from_pretrained(downloaded_model_path), torch_dtype=torch.bfloat16
+                AutoConfig.from_pretrained(args.model_name), torch_dtype=torch.bfloat16
             )
         self.model = self.model.eval()

+        downloaded_model_path = get_model_path(args.model_name)
+
         if args.dtype in [torch.float16, torch.int8]:
             # We currently support the weights provided by microsoft (which are
             # pre-sharded)
-            if args.use_pre_sharded_checkpoints:
-                checkpoints_json = os.path.join(downloaded_model_path, "ds_inference_config.json")
+            checkpoints_json = os.path.join(downloaded_model_path, "ds_inference_config.json")

+            if os.path.isfile(checkpoints_json):
                 self.model = deepspeed.init_inference(
                     self.model,
                     mp_size=world_size,
@@ -60,6 +58,7 @@ def __init__(self, args: Namespace) -> None:
                     self.model = deepspeed.init_inference(
                         self.model,
                         mp_size=world_size,
+                        base_dir=downloaded_model_path,
                         dtype=args.dtype,
                         checkpoint=checkpoints_json,
                         replace_with_kernel_inject=True,
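
Both `init_inference` call sites in this diff follow the same meta-tensor loading pattern. A minimal sketch of that pattern, with an illustrative model name and paths (only the `deepspeed` and `transformers` calls are taken from the diff):

```python
import os

import deepspeed
import torch
from transformers import AutoConfig, AutoModelForCausalLM

world_size = int(os.getenv("WORLD_SIZE", "1"))

# instantiate weight-less "meta" tensors so the full model never has to
# fit in memory before it is sharded
with deepspeed.OnDevice(dtype=torch.float16, device="meta"):
    model = AutoModelForCausalLM.from_config(
        AutoConfig.from_pretrained("bigscience/bloom-560m"), torch_dtype=torch.bfloat16
    )
model = model.eval()

# init_inference partitions the model across ranks and fills the meta
# tensors with real weights from the checkpoint files listed in the JSON
model = deepspeed.init_inference(
    model,
    mp_size=world_size,
    base_dir="/path/to/snapshot",
    dtype=torch.float16,
    checkpoint="/path/to/snapshot/ds_inference_config.json",
    replace_with_kernel_inject=True,
)
```
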
@@ -74,6 +73,8 @@ def __init__(self, args: Namespace) -> None:
         print_rank_n("Model loaded")
         dist.barrier()

+        self.post_init(args.model_name)
+


 class TemporaryCheckpointsJSON:
     def __init__(self, model_path: str):
@@ -93,3 +94,16 @@ def __enter__(self):

     def __exit__(self, type, value, traceback):
         return
+
+
+def get_model_path(model_name: str):
+    config_file = "config.json"
+
+    # will fall back to HUGGINGFACE_HUB_CACHE
+    config_path = try_to_load_from_cache(model_name, config_file, cache_dir=os.getenv("TRANSFORMERS_CACHE"))
+
+    if config_path is not None:
+        return os.path.dirname(config_path)
+    # treat the model name as an explicit model path
+    elif os.path.isfile(os.path.join(model_name, config_file)):
+        return model_name
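
For reference, `try_to_load_from_cache` returns the local path of a file that is already present in the Hugging Face cache, so taking its `dirname` yields the cached snapshot directory. A hedged usage sketch of the new helper, with hypothetical model names:

```python
# hub model: resolves to the cached snapshot directory, provided
# config.json has already been downloaded
path = get_model_path("bigscience/bloom-560m")

# local directory containing config.json: returned unchanged
path = get_model_path("/data/checkpoints/my-model")

# note: if neither branch matches, the helper implicitly returns None
```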