3434except ImportError :
3535 GPTQ_ENABLED = False
3636
# Root directory that holds the pre-built micro-model checkpoints.
# Override it with the FMS_TEST_SHAPES_MICRO_MODELS_HOME environment variable.
MICRO_MODELS_HOME = os.environ.get(
    "FMS_TEST_SHAPES_MICRO_MODELS_HOME", "/mnt/home/models/tiny-models"
)

# Add models to test here
LLAMA_3p1_8B_INSTRUCT = "meta-llama/Llama-3.1-8B-Instruct"
GRANITE_3p2_8B_INSTRUCT = "ibm-granite/granite-3.2-8b-instruct"
GRANITE_3p3_8B_INSTRUCT = "ibm-granite/granite-3.3-8b-instruct"
GRANITE_20B_CODE_INSTRUCT_8K = "ibm-granite/granite-20b-code-instruct-8k"
LLAMA_3p1_70B_INSTRUCT = "meta-llama/Llama-3.1-70B-Instruct"
MISTRAL_0p3_7B_INSTRUCT = "mistralai/Mistral-7B-Instruct-v0.3"

# Maps a model id to the directory of its micro-model checkpoint under
# MICRO_MODELS_HOME. test_common_shapes looks a model up here; a model with no
# entry falls back to a randomly initialized micro model when micro models
# are enabled.
micro_model_mapping = {
    LLAMA_3p1_8B_INSTRUCT: os.path.join(
        MICRO_MODELS_HOME, "llama-3.1-8b-layers-3-step-24000"
    ),
    GRANITE_3p2_8B_INSTRUCT: os.path.join(
        MICRO_MODELS_HOME, "granite-3.2-8b-layers-3-step-100000"
    ),
    # FIXME: Because this uses the same config as 3.2, re-using here, but should update
    GRANITE_3p3_8B_INSTRUCT: os.path.join(
        MICRO_MODELS_HOME, "granite-3.2-8b-layers-3-step-100000"
    ),
    LLAMA_3p1_70B_INSTRUCT: os.path.join(
        MICRO_MODELS_HOME, "llama-3.1-70b-layers-3-step-24000"
    ),
}
4454
# Test configuration read from the environment at import time.
# NOTE: os.environ.get returns a str whenever the variable is set, so any
# setting below whose default is not a str (dict, list, float, tuple) holds a
# raw string when overridden. default_metrics_threshold is converted further
# down in this file; presumably the others are too -- confirm.
SHARE_GPT_DATASET_PATH = os.environ.get(
    "SHARE_GPT_DATASET_PATH", os.path.expanduser("~/share_gpt.json")
)
# boolean switches: only the exact value "1" enables them
USE_MICRO_MODELS = os.environ.get("FMS_TEST_SHAPES_USE_MICRO_MODELS", "1") == "1"
USE_DISTRIBUTED = os.environ.get("FMS_TEST_SHAPES_DISTRIBUTED", "0") == "1"

FORCE_VALIDATION_LEVEL_1 = (
    os.environ.get("FMS_TEST_SHAPES_FORCE_VALIDATION_LEVEL_1", "0") == "1"
)
skip_assertions = os.environ.get("FMS_TEST_SHAPES_SKIP_ASSERTIONS", {})
validation_info_dir = os.environ.get(
    "FMS_TEST_SHAPES_VALIDATION_INFO_DIR", "/tmp/models/validation_info"
)
common_model_paths = os.environ.get(
    "FMS_TEST_SHAPES_COMMON_MODEL_PATHS",
    [
        LLAMA_3p1_8B_INSTRUCT,
        GRANITE_3p2_8B_INSTRUCT,
        GRANITE_3p3_8B_INSTRUCT,
        GRANITE_20B_CODE_INSTRUCT_8K,
        LLAMA_3p1_70B_INSTRUCT,
        MISTRAL_0p3_7B_INSTRUCT,
    ],
)
# for validation level 1, the default is a failure rate of 1%
# set this environment variable if you would like to relax that threshold
failure_rate_threshold = os.environ.get("FMS_TEST_SHAPES_FAILURE_THRESHOLD", 0.01)
# default (cross-entropy threshold, mean diff threshold) pair used when a model
# has no dedicated entry in fail_thresholds
default_metrics_threshold = os.environ.get(
    "FMS_TEST_SHAPES_METRICS_THRESHOLD", (3.0, 0.001)
)
6585save_validation_info_outputs = (
6686 os .environ .get ("FMS_TEST_SHAPES_SAVE_VALIDATION_INFO_OUTPUTS" , "0" ) == "1"
86106
# pass custom default metrics threshold as a comma separated str of floats <cross-entropy threshold>,<mean diff threshold>
if isinstance(default_metrics_threshold, str):
    # an environment override arrives as text such as "3.0,0.001"; the built-in
    # default is already a tuple and skips this branch
    default_metrics_threshold = tuple(
        float(m) for m in default_metrics_threshold.split(",")
    )
90112
91113# pass custom common batch sizes as a comma separated str of ints
92114if isinstance (common_batch_sizes , str ):
# if a model's failure thresholds do not exist in this dict, default to the default_metrics_threshold defined above
# threshold key is (model_id, is_tiny_model)
# threshold value is (cross-entropy threshold, mean diff threshold)
fail_thresholds = {
    (LLAMA_3p1_8B_INSTRUCT, False): (
        2.7080255031585696,
        0.0004068055667448795,
    ),
    (GRANITE_3p2_8B_INSTRUCT, False): (
        2.3919514417648315,
        0.0005767398688476533,
    ),
    (GRANITE_3p2_8B_INSTRUCT, True): (
        2.7449850964546205,
        0.00018840670207282534,
    ),
    (GRANITE_3p3_8B_INSTRUCT, False): (
        2.4444521379470827,
        0.0004970188625156878,
    ),
    (GRANITE_20B_CODE_INSTRUCT_8K, False): (
        2.646075320243838,
        0.0003458251833217223,
    ),
    # TODO: run llama 70B with 1,2,4,8 batches
    (LLAMA_3p1_70B_INSTRUCT, False): (
        2.841279556751251,
        0.0044301633024588115,
    ),
    (MISTRAL_0p3_7B_INSTRUCT, False): (
        2.846206340789795,
        0.0008768103783950205,
    ),
}
160179# custom weight adaptation to be used in future. For instance if we would like to add some other adaptation, we can register it with this custom adapter
161180# and provide it when converting from an aiu fms model's weights to a cpu fms model's weights. Currently this is only done for gptq, but may be done for other
@@ -170,10 +189,6 @@ def reset_compiler():
170189 torch .compiler .reset ()
171190 torch ._dynamo .reset ()
172191 os .environ .pop ("COMPILATION_MODE" , None )
173- if ORIGINAL_HF_HOME is None :
174- os .environ .pop ("HF_HOME" , None )
175- else :
176- os .environ ["HF_HOME" ] = ORIGINAL_HF_HOME
177192
178193
179194# TODO: Currently, gptq does not have the same level of support as non-gptq models for get_model. This method provides the extra requirements for gptq for get_model,
@@ -315,9 +330,6 @@ def test_common_shapes(model_path, batch_size, seq_length, max_new_tokens):
315330 torch .manual_seed (42 )
316331 os .environ ["COMPILATION_MODE" ] = "offline_decoder"
317332
318- if "HF_HOME" not in os .environ :
319- os .environ ["HF_HOME" ] = "/tmp/models/hf_cache"
320-
321333 dprint (
322334 f"testing model={ model_path } , batch_size={ batch_size } , seq_length={ seq_length } , max_new_tokens={ max_new_tokens } , micro_model={ USE_MICRO_MODELS } "
323335 )
@@ -326,13 +338,18 @@ def test_common_shapes(model_path, batch_size, seq_length, max_new_tokens):
326338 gptq_kwargs_aiu , gptq_kwargs_cpu = __maybe_get_gptq_kwargs (model_path )
327339 is_gptq = len (gptq_kwargs_aiu ) != 0
328340
329- if USE_MICRO_MODELS :
341+ micro_model_path = micro_model_mapping .get (model_path , None )
342+ if USE_MICRO_MODELS and micro_model_path is None :
343+ dprint ("using randomly initialized model" )
330344 micro_model_kwargs = {"architecture" : "hf_configured" , "nlayers" : 3 }
331345 else :
346+ dprint ("using trained model" )
332347 micro_model_kwargs = {"architecture" : "hf_pretrained" }
333348
334349 if not USE_MICRO_MODELS and os .path .exists (model_path ):
335350 model_path_kwargs = {"model_path" : model_path }
351+ elif USE_MICRO_MODELS and micro_model_path is not None :
352+ model_path_kwargs = {"model_path" : micro_model_path }
336353 else :
337354 model_path_kwargs = {"variant" : model_path }
338355
@@ -439,10 +456,12 @@ def _metric_calculator(r: torch.Tensor, t: torch.Tensor):
439456 cross_entropy = torch .nn .CrossEntropyLoss ()(
440457 r , t .softmax (dim = 1 ).to (dtype = torch .float32 )
441458 )
442- diff = torch .mean (torch .abs (
443- r .softmax (dim = 1 ).to (dtype = torch .float32 )
444- - t .softmax (dim = 1 ).to (dtype = torch .float32 )
445- ))
459+ diff = torch .mean (
460+ torch .abs (
461+ r .softmax (dim = 1 ).to (dtype = torch .float32 )
462+ - t .softmax (dim = 1 ).to (dtype = torch .float32 )
463+ )
464+ )
446465 return (cross_entropy , diff )
447466
448467 iters = 1024 // max_new_tokens
@@ -510,9 +529,20 @@ def _metric_calculator(r: torch.Tensor, t: torch.Tensor):
510529 # only consider those metrics captured prior to the eos
511530 level_1_metrics = __filter_before_eos (level_1_metrics , eos_indexes )
512531
513- ce_threshold , diff_threshold = fail_thresholds .get (
514- (model_path , USE_MICRO_MODELS ), default_metrics_threshold
515- )
532+ # if we do not have real model weights, use a default_metrics_threshold
533+ if USE_MICRO_MODELS and micro_model_path is None :
534+ ce_threshold , diff_threshold = default_metrics_threshold
535+ # if we have real weights, try and get the proper validation metrics threshold
536+ else :
537+ # if we have a micro model with real weights, but no real thresholds, default to the full model thresholds
538+ if USE_MICRO_MODELS :
539+ ce_threshold , diff_threshold = fail_thresholds .get (
540+ (model_path , True ), fail_thresholds .get ((model_path , False ), default_metrics_threshold )
541+ )
542+ else :
543+ ce_threshold , diff_threshold = fail_thresholds .get (
544+ (model_path , False ), default_metrics_threshold
545+ )
516546
517547 # get all failed responses for each metric
518548 ce_fail_responses = filter_failed_level_1_cases (
0 commit comments