|
35 | 35 | GPTQ_ENABLED = False |
36 | 36 |
37 | 37 | ORIGINAL_HF_HOME = os.environ.get("HF_HOME", None) |
38 | | -MICRO_MODELS_HOME = os.environ.get("FMS_TEST_SHAPES_MICRO_MODELS_HOME", "/mnt/home") |
| 38 | +MICRO_MODELS_HOME = os.environ.get("FMS_TEST_SHAPES_MICRO_MODELS_HOME", "/mnt/home/models/tiny-models") |
39 | 39 |
40 | 40 | # Add models to test here |
41 | 41 | LLAMA_3p1_8B_INSTRUCT = "meta-llama/Llama-3.1-8B-Instruct" |
|
44 | 44 | LLAMA_3p1_70B_INSTRUCT = "meta-llama/Llama-3.1-70B-Instruct" |
45 | 45 |
46 | 46 | micro_model_mapping = { |
47 | | - LLAMA_3p1_8B_INSTRUCT: os.path.join(MICRO_MODELS_HOME, "llama-8b-layers-3-step-24000"), |
48 | | - GRANITE_3p2_8B_INSTRUCT: os.path.join(MICRO_MODELS_HOME, "granite-3.2-8b-layers-3-step-24000") |
| 47 | + LLAMA_3p1_8B_INSTRUCT: os.path.join(MICRO_MODELS_HOME, "llama-3.1-8b-layers-3-step-24000"), |
| 48 | + GRANITE_3p2_8B_INSTRUCT: os.path.join(MICRO_MODELS_HOME, "granite-3.2-8b-layers-3-step-100000"), |
| 49 | + LLAMA_3p1_70B_INSTRUCT: os.path.join(MICRO_MODELS_HOME, "llama-3.1-70b-layers-3-step-24000") |
49 | 50 | } |
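
The mapping above pairs each full Hugging Face model id with a locally stored micro-model checkpoint. Below is a minimal sketch of how such a lookup could be consumed, assuming a fallback to the full model id when no micro model is registered; `resolve_model_path` is a hypothetical helper, not part of this change, and the real selection logic lives elsewhere in the test module.

```python
# Hypothetical helper: prefer the micro-model checkpoint when micro models are
# enabled and one is registered for this model id; otherwise fall back to the
# full Hugging Face model id. The actual tests may select models differently.
def resolve_model_path(model_id: str, use_micro: bool, mapping: dict) -> str:
    if use_micro and model_id in mapping:
        return mapping[model_id]
    return model_id

# Example (path is illustrative only):
mapping = {
    "meta-llama/Llama-3.1-8B-Instruct": "/mnt/home/models/tiny-models/llama-3.1-8b-layers-3-step-24000",
}
print(resolve_model_path("meta-llama/Llama-3.1-8B-Instruct", use_micro=True, mapping=mapping))
```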
50 | 51 |
51 | 52 | SHARE_GPT_DATASET_PATH = os.environ.get( |
52 | 53 | "SHARE_GPT_DATASET_PATH", os.path.expanduser("~/share_gpt.json") |
53 | 54 | ) |
54 | 55 | USE_MICRO_MODELS = os.environ.get("FMS_TEST_SHAPES_USE_MICRO_MODELS", "1") == "1" |
55 | 56 | USE_DISTRIBUTED = os.environ.get("FMS_TEST_SHAPES_DISTRIBUTED", "0") == "1" |
| 57 | + |
56 | 58 | FORCE_VALIDATION_LEVEL_1 = ( |
57 | 59 | os.environ.get("FMS_TEST_SHAPES_FORCE_VALIDATION_LEVEL_1", "0") == "1" |
58 | 60 | ) |
| 61 | +skip_assertions = os.environ.get("FMS_TEST_SHAPES_SKIP_ASSERTIONS", set())
59 | 62 | validation_info_dir = os.environ.get( |
60 | 63 | "FMS_TEST_SHAPES_VALIDATION_INFO_DIR", "/tmp/models/validation_info" |
61 | 64 | ) |
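
All of these switches are read from the environment at import time, so a run can be reconfigured without editing the test file. A minimal sketch of a local configuration, assuming the variables are set before pytest imports this module; the variable names come from the code above, and the values are illustrative only.

```python
import os

# Illustrative configuration for a local run; values are examples, not
# recommended defaults.
os.environ["FMS_TEST_SHAPES_USE_MICRO_MODELS"] = "1"        # use the tiny checkpoints
os.environ["FMS_TEST_SHAPES_DISTRIBUTED"] = "0"             # single-process run
os.environ["FMS_TEST_SHAPES_FORCE_VALIDATION_LEVEL_1"] = "1"
os.environ["FMS_TEST_SHAPES_SKIP_ASSERTIONS"] = "ce"        # new in this change
os.environ["FMS_TEST_SHAPES_VALIDATION_INFO_DIR"] = "/tmp/models/validation_info"
```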
|
114 | 117 | if isinstance(common_max_new_tokens, str): |
115 | 118 | common_max_new_tokens = [int(mnt) for mnt in common_max_new_tokens.split(",")] |
116 | 119 |
| 120 | +# pass metrics to skip as a comma-separated list (e.g. ce,mean_diff)
| 121 | +if isinstance(skip_assertions, str): |
| 122 | + _skip_assertions = [] |
| 123 | + for metric in skip_assertions.split(","): |
| 124 | + metric = metric.lower() |
| 125 | + if metric not in {"ce", "mean_diff"}: |
| 126 | + pytest.fail("FMS_TEST_SHAPES_SKIP_ASSERTIONS can only accept metrics ce and mean_diff") |
| 127 | + _skip_assertions.append(metric) |
| 128 | + skip_assertions = set(_skip_assertions) |
| 129 | + |
117 | 130 | common_shapes = list( |
118 | 131 | itertools.product( |
119 | 132 | common_model_paths, |
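
The added parsing block above accepts the skip list as a comma-separated string and rejects any metric other than `ce` and `mean_diff`. A standalone sketch of the same parsing, with `ValueError` standing in for `pytest.fail` so it can run outside a test session:

```python
def parse_skip_assertions(value: str) -> set:
    # Mirror of the parsing above: lower-case each entry and reject unknown
    # metric names. The real code calls pytest.fail instead of raising.
    allowed = {"ce", "mean_diff"}
    metrics = set()
    for metric in value.split(","):
        metric = metric.lower()
        if metric not in allowed:
            raise ValueError(
                "FMS_TEST_SHAPES_SKIP_ASSERTIONS can only accept metrics ce and mean_diff"
            )
        metrics.add(metric)
    return metrics

print(parse_skip_assertions("ce,MEAN_DIFF"))  # set containing 'ce' and 'mean_diff'
```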
@@ -538,12 +551,14 @@ def _metric_calculator(r: torch.Tensor, t: torch.Tensor): |
538 | 551 | ce_failure_rate = len(ce_fail_responses_list) / total_tokens |
539 | 552 | dprint(f"mean diff failure rate: {diff_failure_rate}") |
540 | 553 | dprint(f"cross entropy loss failure rate: {ce_failure_rate}") |
541 | | - assert diff_failure_rate < failure_rate_threshold, ( |
542 | | - f"failure rate for mean diff was too high: {diff_failure_rate}" |
543 | | - ) |
544 | | - assert ce_failure_rate < failure_rate_threshold, ( |
545 | | - f"failure rate for cross entropy loss was too high: {ce_failure_rate}" |
546 | | - ) |
| 554 | + if "mean_diff" not in skip_assertions: |
| 555 | + assert diff_failure_rate < failure_rate_threshold, ( |
| 556 | + f"failure rate for mean diff was too high: {diff_failure_rate}" |
| 557 | + ) |
| 558 | + if "ce" not in skip_assertions: |
| 559 | + assert ce_failure_rate < failure_rate_threshold, ( |
| 560 | + f"failure rate for cross entropy loss was too high: {ce_failure_rate}" |
| 561 | + ) |
547 | 562 |
548 | 563 | print("passed validation level 1") |
549 | 564 | else: |
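
With the new gating, each metric's failure-rate assertion can be disabled independently. A toy sketch of the resulting behavior, assuming `skip_assertions` was parsed as above; the rates and threshold below are invented for illustration.

```python
# Toy illustration of the gated assertions; all numbers are made up.
skip_assertions = {"ce"}
failure_rate_threshold = 0.1
diff_failure_rate = 0.05
ce_failure_rate = 0.5  # would fail, but the ce assertion is skipped

if "mean_diff" not in skip_assertions:
    assert diff_failure_rate < failure_rate_threshold
if "ce" not in skip_assertions:
    assert ce_failure_rate < failure_rate_threshold  # not reached in this example

print("validation level 1 checks passed (ce assertion skipped)")
```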