Skip to content

Commit ea7efd1

Browse files
authored
Merge pull request #125 from Abhishek-TAMU/add_cumulative_env
Add Env var to control cumulative test tokens generated per sequence
2 parents 178bc89 + 823b561 commit ea7efd1

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

tests/models/test_decoders.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,9 @@
6969
USE_MICRO_MODELS = os.environ.get("FMS_TEST_SHAPES_USE_MICRO_MODELS", "1") == "1"
7070
USE_DISTRIBUTED = os.environ.get("FMS_TEST_SHAPES_DISTRIBUTED", "0") == "1"
7171
TIMING = os.environ.get("TIMING", "")
72-
72+
CUMULATIVE_TEST_TOKENS_PER_SEQUENCE = int(
73+
os.environ.get("FMS_TEST_SHAPES_CUMULATIVE_TEST_TOKENS_PER_SEQUENCE", "1024")
74+
)
7375
ATTN_TYPE = os.environ.get("FMS_TEST_SHAPES_ATTN_TYPE", "sdpa")
7476
attention_map = {
7577
"sdpa": "sdpa_causal",
@@ -608,7 +610,7 @@ def _metric_calculator(r: torch.Tensor, t: torch.Tensor):
608610
)
609611
return (cross_entropy, diff)
610612

611-
iters = 1024 // max_new_tokens
613+
iters = int(CUMULATIVE_TEST_TOKENS_PER_SEQUENCE) // max_new_tokens
612614
ce_fail_responses_list = []
613615
diff_fail_responses_list = []
614616
total_tokens = 0

0 commit comments

Comments
 (0)