Skip to content

Commit 8f51d3b

Browse files
committed
Add CUMULATIVE_TEST_TOKENS_PER_SEQUENCE env var
Signed-off-by: Abhishek <maurya.abhishek@ibm.com>
1 parent 178bc89 commit 8f51d3b

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

tests/models/test_decoders.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@
6969
USE_MICRO_MODELS = os.environ.get("FMS_TEST_SHAPES_USE_MICRO_MODELS", "1") == "1"
7070
USE_DISTRIBUTED = os.environ.get("FMS_TEST_SHAPES_DISTRIBUTED", "0") == "1"
7171
TIMING = os.environ.get("TIMING", "")
72-
72+
CUMULATIVE_TEST_TOKENS_PER_SEQUENCE = os.environ.get("FMS_TEST_SHAPES_CUMULATIVE_TEST_TOKENS_PER_SEQUENCE", "1024")
7373
ATTN_TYPE = os.environ.get("FMS_TEST_SHAPES_ATTN_TYPE", "sdpa")
7474
attention_map = {
7575
"sdpa": "sdpa_causal",
@@ -608,7 +608,7 @@ def _metric_calculator(r: torch.Tensor, t: torch.Tensor):
608608
)
609609
return (cross_entropy, diff)
610610

611-
iters = 1024 // max_new_tokens
611+
iters = CUMULATIVE_TEST_TOKENS_PER_SEQUENCE // max_new_tokens
612612
ce_fail_responses_list = []
613613
diff_fail_responses_list = []
614614
total_tokens = 0

0 commit comments

Comments
 (0)