
Commit 647ac3c

added granite micro model; changed the threshold key from model_id to (model_id, is_tiny_model) -- in case the key is not found, we default to the full-size model thresholds
Signed-off-by: Joshua Rosenkranz <jmrosenk@us.ibm.com>
1 parent 8c4446d commit 647ac3c

File tree: 1 file changed, +20 -12 lines

tests/models/test_decoders.py

Lines changed: 20 additions & 12 deletions
@@ -35,7 +35,7 @@
 GPTQ_ENABLED = False

 ORIGINAL_HF_HOME = os.environ.get("HF_HOME", None)
-MODELS_HOME = os.environ.get("FMS_TEST_SHAPES_MODELS_HOME", "/home/senuser/models")
+MICRO_MODELS_HOME = os.environ.get("FMS_TEST_SHAPES_MICRO_MODELS_HOME", "/mnt/home")

 # Add models to test here
 LLAMA_3p1_8B_INSTRUCT = "meta-llama/Llama-3.1-8B-Instruct"
@@ -44,7 +44,8 @@
 LLAMA_3p1_70B_INSTRUCT = "meta-llama/Llama-3.1-70B-Instruct"

 micro_model_mapping = {
-    LLAMA_3p1_8B_INSTRUCT: os.path.join(MODELS_HOME, "llama-8b-layers-3-step-24000"),
+    LLAMA_3p1_8B_INSTRUCT: os.path.join(MICRO_MODELS_HOME, "llama-8b-layers-3-step-24000"),
+    GRANITE_3p2_8B_INSTRUCT: os.path.join(MICRO_MODELS_HOME, "granite-3.2-8b-layers-3-step-24000")
 }

 SHARE_GPT_DATASET_PATH = os.environ.get(
@@ -56,7 +57,7 @@
     os.environ.get("FMS_TEST_SHAPES_FORCE_VALIDATION_LEVEL_1", "0") == "1"
 )
 validation_info_dir = os.environ.get(
-    "FMS_TEST_SHAPES_VALIDATION_INFO_DIR", "/home/senuser/models/validation_info"
+    "FMS_TEST_SHAPES_VALIDATION_INFO_DIR", "/tmp/models/validation_info"
 )
 common_model_paths = os.environ.get(
     "FMS_TEST_SHAPES_COMMON_MODEL_PATHS",
@@ -125,21 +126,21 @@
 # thresholds are chosen based on 1024 tokens per sequence
 # 1% error threshold rate between cpu fp32 and cuda fp16
 # if a models failure thresholds do not exist in this dict, default to the default_metrics_threshold defined above
-# threshold key is model_id
+# threshold key is (model_id, is_tiny_model)
 fail_thresholds = {
-    LLAMA_3p1_8B_INSTRUCT: (
+    (LLAMA_3p1_8B_INSTRUCT, False): (
         2.6994638133048965,
         0.00047589250549208347,
     ),
-    GRANITE_3p2_8B_INSTRUCT: (
+    (GRANITE_3p2_8B_INSTRUCT, False): (
         2.3919514417648315,
         0.0005767398688476533,
     ),
-    GRANITE_20B_CODE_INSTRUCT_8K: (
+    (GRANITE_20B_CODE_INSTRUCT_8K, False): (
         2.640706129074097,
         0.00034344267623964697,
     ),
-    LLAMA_3p1_70B_INSTRUCT: (
+    (LLAMA_3p1_70B_INSTRUCT, False): (
         2.841279556751251,
         0.0044301633024588115,
     ),
@@ -303,7 +304,7 @@ def test_common_shapes(model_path, batch_size, seq_length, max_new_tokens):
     os.environ["COMPILATION_MODE"] = "offline_decoder"

     if "HF_HOME" not in os.environ:
-        os.environ["HF_HOME"] = "/home/senuser/models/hf_cache"
+        os.environ["HF_HOME"] = "/tmp/models/hf_cache"

     dprint(
         f"testing model={model_path}, batch_size={batch_size}, seq_length={seq_length}, max_new_tokens={max_new_tokens}, micro_model={USE_MICRO_MODELS}"
@@ -420,6 +421,7 @@ def test_common_shapes(model_path, batch_size, seq_length, max_new_tokens):

     # if level 0 fails validation, validate level 1
     if FORCE_VALIDATION_LEVEL_1 or failed_validation_level_0:
+
         if failed_validation_level_0:
             dprint("failed validation level 0, testing validation level 1")
         else:
@@ -508,9 +510,15 @@ def _metric_calculator(r: torch.Tensor, t: torch.Tensor):
             ce_threshold, diff_threshold = default_metrics_threshold
         # if we have real weights, try and get the proper validation metrics threshold
         else:
-            ce_threshold, diff_threshold = fail_thresholds.get(
-                model_path, default_metrics_threshold
-            )
+            # if we have a micro model with real weights, but no real thresholds, default to the full model thresholds
+            if USE_MICRO_MODELS:
+                ce_threshold, diff_threshold = fail_thresholds.get(
+                    (model_path, True), fail_thresholds.get((model_path, False), default_metrics_threshold)
+                )
+            else:
+                ce_threshold, diff_threshold = fail_thresholds.get(
+                    (model_path, False), default_metrics_threshold
+                )

         # get all failed responses for each metric
         ce_fail_responses = filter_failed_level_1_cases(
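
The threshold lookup introduced here keys fail_thresholds on (model_id, is_tiny_model) and, for micro models, falls back first to the full-size model's entry and then to the global default. Below is a minimal, self-contained sketch of that behavior: the full-size Llama threshold values are copied from the hunk above, while the default_metrics_threshold placeholder values and the helper name get_thresholds are illustrative only and not part of this commit.

# Sketch of the (model_id, is_tiny_model) threshold lookup from this commit.
# Only the Llama model ID and its full-size thresholds come from the diff above;
# the default values and helper function are assumptions for illustration.
LLAMA_3p1_8B_INSTRUCT = "meta-llama/Llama-3.1-8B-Instruct"

default_metrics_threshold = (3.0, 0.001)  # placeholder defaults, not the real ones

fail_thresholds = {
    # key is (model_id, is_tiny_model); only the full-size entry exists here
    (LLAMA_3p1_8B_INSTRUCT, False): (2.6994638133048965, 0.00047589250549208347),
}

def get_thresholds(model_path, use_micro_models):
    if use_micro_models:
        # prefer micro-model thresholds, then the full-size model's, then the global default
        return fail_thresholds.get(
            (model_path, True),
            fail_thresholds.get((model_path, False), default_metrics_threshold),
        )
    return fail_thresholds.get((model_path, False), default_metrics_threshold)

# With no (model_id, True) entry, a micro model falls back to the full-size thresholds:
print(get_thresholds(LLAMA_3p1_8B_INSTRUCT, True))  # (2.6994638133048965, 0.00047589250549208347)
print(get_thresholds("some/unknown-model", True))   # (3.0, 0.001) -- global default

Chaining the two dict.get calls keeps the fallback explicit without requiring every micro model to have its own measured thresholds.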
