|
35 | 35 | GPTQ_ENABLED = False |
36 | 36 |
|
37 | 37 | ORIGINAL_HF_HOME = os.environ.get("HF_HOME", None) |
38 | | -MODELS_HOME = os.environ.get("FMS_TEST_SHAPES_MODELS_HOME", "/home/senuser/models") |
| 38 | +MICRO_MODELS_HOME = os.environ.get("FMS_TEST_SHAPES_MICRO_MODELS_HOME", "/mnt/home") |
39 | 39 |
|
40 | 40 | # Add models to test here |
41 | 41 | LLAMA_3p1_8B_INSTRUCT = "meta-llama/Llama-3.1-8B-Instruct" |
|
44 | 44 | LLAMA_3p1_70B_INSTRUCT = "meta-llama/Llama-3.1-70B-Instruct" |
45 | 45 |
|
46 | 46 | micro_model_mapping = { |
47 | | - LLAMA_3p1_8B_INSTRUCT: os.path.join(MODELS_HOME, "llama-8b-layers-3-step-24000"), |
| 47 | + LLAMA_3p1_8B_INSTRUCT: os.path.join(MICRO_MODELS_HOME, "llama-8b-layers-3-step-24000"), |
| 48 | + GRANITE_3p2_8B_INSTRUCT: os.path.join(MICRO_MODELS_HOME, "granite-3.2-8b-layers-3-step-24000") |
48 | 49 | } |
49 | 50 |
|
50 | 51 | SHARE_GPT_DATASET_PATH = os.environ.get( |
|
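The renamed MICRO_MODELS_HOME variable above only controls where the cut-down checkpoints live; whether they are used at all is governed by the USE_MICRO_MODELS flag referenced later in the test. A minimal sketch of how the mapping might be consulted (resolve_model_path is a hypothetical helper for illustration, not part of this commit):

    # hypothetical helper: fall back to the full model id when no micro
    # variant is registered in micro_model_mapping
    def resolve_model_path(model_id: str, use_micro_models: bool) -> str:
        if use_micro_models and model_id in micro_model_mapping:
            return micro_model_mapping[model_id]
        return model_id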
56 | 57 | os.environ.get("FMS_TEST_SHAPES_FORCE_VALIDATION_LEVEL_1", "0") == "1" |
57 | 58 | ) |
58 | 59 | validation_info_dir = os.environ.get( |
59 | | - "FMS_TEST_SHAPES_VALIDATION_INFO_DIR", "/home/senuser/models/validation_info" |
| 60 | + "FMS_TEST_SHAPES_VALIDATION_INFO_DIR", "/tmp/models/validation_info" |
60 | 61 | ) |
61 | 62 | common_model_paths = os.environ.get( |
62 | 63 | "FMS_TEST_SHAPES_COMMON_MODEL_PATHS", |
|
125 | 126 | # thresholds are chosen based on 1024 tokens per sequence |
126 | 127 | # 1% error threshold rate between cpu fp32 and cuda fp16 |
127 | 128 | # if a model's failure thresholds do not exist in this dict, default to the default_metrics_threshold defined above
128 | | -# threshold key is model_id |
| 129 | +# threshold key is (model_id, is_tiny_model) |
129 | 130 | fail_thresholds = { |
130 | | - LLAMA_3p1_8B_INSTRUCT: ( |
| 131 | + (LLAMA_3p1_8B_INSTRUCT, False): ( |
131 | 132 | 2.6994638133048965, |
132 | 133 | 0.00047589250549208347, |
133 | 134 | ), |
134 | | - GRANITE_3p2_8B_INSTRUCT: ( |
| 135 | + (GRANITE_3p2_8B_INSTRUCT, False): ( |
135 | 136 | 2.3919514417648315, |
136 | 137 | 0.0005767398688476533, |
137 | 138 | ), |
138 | | - GRANITE_20B_CODE_INSTRUCT_8K: ( |
| 139 | + (GRANITE_20B_CODE_INSTRUCT_8K, False): ( |
139 | 140 | 2.640706129074097, |
140 | 141 | 0.00034344267623964697, |
141 | 142 | ), |
142 | | - LLAMA_3p1_70B_INSTRUCT: ( |
| 143 | + (LLAMA_3p1_70B_INSTRUCT, False): ( |
143 | 144 | 2.841279556751251, |
144 | 145 | 0.0044301633024588115, |
145 | 146 | ), |
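A quick check of the intended fallback semantics for the new composite keys, using the values from the dict above (illustrative only, not part of the commit): with no (model_id, True) entry registered, a micro-model lookup should land on the full-model thresholds, and an unknown model id should land on default_metrics_threshold.

    # no (LLAMA_3p1_8B_INSTRUCT, True) entry exists yet, so the inner get()
    # supplies the full-model thresholds as the fallback value
    assert fail_thresholds.get(
        (LLAMA_3p1_8B_INSTRUCT, True),
        fail_thresholds.get((LLAMA_3p1_8B_INSTRUCT, False), default_metrics_threshold),
    ) == (2.6994638133048965, 0.00047589250549208347)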
@@ -303,7 +304,7 @@ def test_common_shapes(model_path, batch_size, seq_length, max_new_tokens): |
303 | 304 | os.environ["COMPILATION_MODE"] = "offline_decoder" |
304 | 305 |
|
305 | 306 | if "HF_HOME" not in os.environ: |
306 | | - os.environ["HF_HOME"] = "/home/senuser/models/hf_cache" |
| 307 | + os.environ["HF_HOME"] = "/tmp/models/hf_cache" |
307 | 308 |
|
308 | 309 | dprint( |
309 | 310 | f"testing model={model_path}, batch_size={batch_size}, seq_length={seq_length}, max_new_tokens={max_new_tokens}, micro_model={USE_MICRO_MODELS}" |
@@ -420,6 +421,7 @@ def test_common_shapes(model_path, batch_size, seq_length, max_new_tokens): |
420 | 421 |
|
421 | 422 | # if level 0 fails validation, validate level 1 |
422 | 423 | if FORCE_VALIDATION_LEVEL_1 or failed_validation_level_0: |
| 424 | + |
423 | 425 | if failed_validation_level_0: |
424 | 426 | dprint("failed validation level 0, testing validation level 1") |
425 | 427 | else: |
@@ -508,9 +510,15 @@ def _metric_calculator(r: torch.Tensor, t: torch.Tensor): |
508 | 510 | ce_threshold, diff_threshold = default_metrics_threshold |
509 | 511 | # if we have real weights, try and get the proper validation metrics threshold |
510 | 512 | else: |
511 | | - ce_threshold, diff_threshold = fail_thresholds.get( |
512 | | - model_path, default_metrics_threshold |
513 | | - ) |
| 513 | + # if we have a micro model with real weights, but no real thresholds, default to the full model thresholds |
| 514 | + if USE_MICRO_MODELS: |
| 515 | + ce_threshold, diff_threshold = fail_thresholds.get( |
| 516 | + (model_path, True), fail_thresholds.get((model_path, False), default_metrics_threshold) |
| 517 | + ) |
| 518 | + else: |
| 519 | + ce_threshold, diff_threshold = fail_thresholds.get( |
| 520 | + (model_path, False), default_metrics_threshold |
| 521 | + ) |
514 | 522 |
|
515 | 523 | # get all failed responses for each metric |
516 | 524 | ce_fail_responses = filter_failed_level_1_cases( |
|
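The two threshold-selection branches added above can be read as a single chained lookup: try the micro-model key, fall back to the full-model key, then to the default. A sketch of the equivalent expression, assuming the same fail_thresholds, model_path, and default_metrics_threshold names and that USE_MICRO_MODELS is a bool:

    # (model_path, True) falls back to (model_path, False), which falls back to
    # the default; when USE_MICRO_MODELS is False the two keys coincide, so the
    # behaviour matches the non-micro branch
    ce_threshold, diff_threshold = fail_thresholds.get(
        (model_path, USE_MICRO_MODELS),
        fail_thresholds.get((model_path, False), default_metrics_threshold),
    )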