Skip to content

Commit 3aa311d

Browse files
committed
added a model to micro models; fixed merge conflicts
Signed-off-by: Joshua Rosenkranz <jmrosenk@us.ibm.com>
2 parents 647ac3c + aeb7c9d commit 3aa311d

File tree

2 files changed

+25
-13
lines changed

2 files changed

+25
-13
lines changed

tests/models/test_decoders.py

Lines changed: 24 additions & 9 deletions
Original file line number · Diff line number · Diff line change
@@ -35,7 +35,7 @@
3535
GPTQ_ENABLED = False
3636

3737
ORIGINAL_HF_HOME = os.environ.get("HF_HOME", None)
38-
MICRO_MODELS_HOME = os.environ.get("FMS_TEST_SHAPES_MICRO_MODELS_HOME", "/mnt/home")
38+
MICRO_MODELS_HOME = os.environ.get("FMS_TEST_SHAPES_MICRO_MODELS_HOME", "/mnt/home/models/tiny-models")
3939

4040
# Add models to test here
4141
LLAMA_3p1_8B_INSTRUCT = "meta-llama/Llama-3.1-8B-Instruct"
@@ -44,18 +44,21 @@
4444
LLAMA_3p1_70B_INSTRUCT = "meta-llama/Llama-3.1-70B-Instruct"
4545

4646
micro_model_mapping = {
47-
LLAMA_3p1_8B_INSTRUCT: os.path.join(MICRO_MODELS_HOME, "llama-8b-layers-3-step-24000"),
48-
GRANITE_3p2_8B_INSTRUCT: os.path.join(MICRO_MODELS_HOME, "granite-3.2-8b-layers-3-step-24000")
47+
LLAMA_3p1_8B_INSTRUCT: os.path.join(MICRO_MODELS_HOME, "llama-3.1-8b-layers-3-step-24000"),
48+
GRANITE_3p2_8B_INSTRUCT: os.path.join(MICRO_MODELS_HOME, "granite-3.2-8b-layers-3-step-100000"),
49+
LLAMA_3p1_70B_INSTRUCT: os.path.join(MICRO_MODELS_HOME, "llama-3.1-70b-layers-3-step-24000")
4950
}
5051

5152
SHARE_GPT_DATASET_PATH = os.environ.get(
5253
"SHARE_GPT_DATASET_PATH", os.path.expanduser("~/share_gpt.json")
5354
)
5455
USE_MICRO_MODELS = os.environ.get("FMS_TEST_SHAPES_USE_MICRO_MODELS", "1") == "1"
5556
USE_DISTRIBUTED = os.environ.get("FMS_TEST_SHAPES_DISTRIBUTED", "0") == "1"
57+
5658
FORCE_VALIDATION_LEVEL_1 = (
5759
os.environ.get("FMS_TEST_SHAPES_FORCE_VALIDATION_LEVEL_1", "0") == "1"
5860
)
61+
skip_assertions = os.environ.get("FMS_TEST_SHAPES_SKIP_ASSERTIONS", {})
5962
validation_info_dir = os.environ.get(
6063
"FMS_TEST_SHAPES_VALIDATION_INFO_DIR", "/tmp/models/validation_info"
6164
)
@@ -114,6 +117,16 @@
114117
if isinstance(common_max_new_tokens, str):
115118
common_max_new_tokens = [int(mnt) for mnt in common_max_new_tokens.split(",")]
116119

120+
# pass metrics to skip as a comma separated list (ce,mean_diff)
121+
if isinstance(skip_assertions, str):
122+
_skip_assertions = []
123+
for metric in skip_assertions.split(","):
124+
metric = metric.lower()
125+
if metric not in {"ce", "mean_diff"}:
126+
pytest.fail("FMS_TEST_SHAPES_SKIP_ASSERTIONS can only accept metrics ce and mean_diff")
127+
_skip_assertions.append(metric)
128+
skip_assertions = set(_skip_assertions)
129+
117130
common_shapes = list(
118131
itertools.product(
119132
common_model_paths,
@@ -538,12 +551,14 @@ def _metric_calculator(r: torch.Tensor, t: torch.Tensor):
538551
ce_failure_rate = len(ce_fail_responses_list) / total_tokens
539552
dprint(f"mean diff failure rate: {diff_failure_rate}")
540553
dprint(f"cross entropy loss failure rate: {ce_failure_rate}")
541-
assert diff_failure_rate < failure_rate_threshold, (
542-
f"failure rate for mean diff was too high: {diff_failure_rate}"
543-
)
544-
assert ce_failure_rate < failure_rate_threshold, (
545-
f"failure rate for cross entropy loss was too high: {ce_failure_rate}"
546-
)
554+
if "mean_diff" not in skip_assertions:
555+
assert diff_failure_rate < failure_rate_threshold, (
556+
f"failure rate for mean diff was too high: {diff_failure_rate}"
557+
)
558+
if "ce" not in skip_assertions:
559+
assert ce_failure_rate < failure_rate_threshold, (
560+
f"failure rate for cross entropy loss was too high: {ce_failure_rate}"
561+
)
547562

548563
print("passed validation level 1")
549564
else:

tests/models/test_encoders.py

Lines changed: 1 addition & 4 deletions
Original file line number · Diff line number · Diff line change
@@ -86,10 +86,7 @@ def reset_compiler():
8686
else:
8787
os.environ['HF_HOME'] = ORIGINAL_HF_HOME
8888

89-
encoder_paths = ["deepset/roberta-base-squad2"]
90-
common_encoder_shapes = list(itertools.product(encoder_paths, common_batch_sizes, common_seq_lengths))
91-
92-
@pytest.mark.parametrize("model_path,batch_size,seq_length", common_encoder_shapes)
89+
@pytest.mark.parametrize("model_path,batch_size,seq_length", common_shapes)
9390
def test_common_shapes(model_path, batch_size, seq_length):
9491
os.environ["COMPILATION_MODE"] = "offline"
9592

0 commit comments

Comments (0)