From 2a1cd8a5c7989ccb41af3dc83bcdf892f6f6cfa1 Mon Sep 17 00:00:00 2001 From: Vedaanta Agarwalla Date: Wed, 20 May 2026 22:51:29 -0700 Subject: [PATCH 1/3] tests/attention: shrink fp8_vs_f16 configs from B=2 to B=1 Eight of the nine fp8_9..fp8_17 configs in `model_configs_fp8_vs_f16` (all except fp8_10, which already uses B=1) use shapes (B=2, S=4096-8192, H=32-128, D=64-192) for the bf16-vs-fp8 reference comparison. The reference path in `test_dpa_fp8_vs_f16` materializes the full (B, H, S, S) attention matrix in bf16, and keeps a handful of them live (S, P, dP, dS, dropout-mask) simultaneously. At B=2, S=8192, H=64 the per-test peak is ~70 GiB, which exceeds the memory of common 80 GB cards (H100) and pushes the suite into OOM territory on Blackwell (~91 GB measured with the cuDNN caching allocator residue). Halving B to 1 halves the bytes of every (B, H, S, S) tensor. Measured on B200 (SM_100, cuDNN 9.23, TE main): per-test peak `torch.cuda.max_memory_allocated`: before: 70.0 GiB (fp8_14) after : 36.1 GiB (fp8_14) -48% per-test peak `nvidia-smi memory.used`: before: 96.8 GiB after : 51.3 GiB -47% test outcome (B200, develop FE, this TE): identical 618F / 2196P / 891S, wall time within ~3% The shrunk configs still exercise every distinct shape/mask/SWA/GQA combination that the originals did -- only B is smaller. The suite now fits comfortably on 80 GB cards. fp8_19/20 (B=2, S=2048) are left at B=2 because their peak is small (~few GiB) and the larger batch is useful coverage for padding_causal. 
Signed-off-by: Vedaanta Agarwalla --- tests/pytorch/attention/test_attention.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index 32ea1694ee..84c5d6c37f 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -1911,7 +1911,7 @@ def get_model(dtype, config): model_configs_fp8_vs_f16 = { # test: ModelConfig(b, sq, hq, dqk) "fp8_9": ModelConfig( - 2, + 1, 4096, 128, 192, @@ -1926,22 +1926,22 @@ def get_model(dtype, config): attn_mask_type="causal", ), "fp8_11": ModelConfig( - 2, + 1, 4096, 128, 192, head_dim_v=128, attn_mask_type="causal_bottom_right", ), - "fp8_12": ModelConfig(2, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="causal"), - "fp8_13": ModelConfig(2, 8192, 32, 128, attn_mask_type="causal", window_size=(128, 0)), - "fp8_14": ModelConfig(2, 8192, 64, 64, num_gqa_groups=8, attn_mask_type="causal"), - "fp8_15": ModelConfig(2, 8192, 64, 64, attn_mask_type="causal", window_size=(128, 0)), + "fp8_12": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="causal"), + "fp8_13": ModelConfig(1, 8192, 32, 128, attn_mask_type="causal", window_size=(128, 0)), + "fp8_14": ModelConfig(1, 8192, 64, 64, num_gqa_groups=8, attn_mask_type="causal"), + "fp8_15": ModelConfig(1, 8192, 64, 64, attn_mask_type="causal", window_size=(128, 0)), "fp8_16": ModelConfig( - 2, 8192, 64, 64, num_gqa_groups=8, attn_mask_type="causal", softmax_type="learnable" + 1, 8192, 64, 64, num_gqa_groups=8, attn_mask_type="causal", softmax_type="learnable" ), "fp8_17": ModelConfig( - 2, 8192, 64, 64, attn_mask_type="causal", window_size=(128, 0), softmax_type="learnable" + 1, 8192, 64, 64, attn_mask_type="causal", window_size=(128, 0), softmax_type="learnable" ), "fp8_18": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="padding"), "fp8_19": ModelConfig(2, 2048, 16, 128, 
attn_mask_type="padding_causal"), From a48e7c54ff2c0ebb468ef61a47c9cb8bda312214 Mon Sep 17 00:00:00 2001 From: Vedaanta Agarwalla <142048820+vedaanta@users.noreply.github.com> Date: Thu, 21 May 2026 15:27:12 -0700 Subject: [PATCH 2/3] address changes recommended by Kshitij Signed-off-by: Vedaanta Agarwalla <142048820+vedaanta@users.noreply.github.com> --- tests/pytorch/attention/test_attention.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index 84c5d6c37f..157de12e56 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -1911,37 +1911,37 @@ def get_model(dtype, config): model_configs_fp8_vs_f16 = { # test: ModelConfig(b, sq, hq, dqk) "fp8_9": ModelConfig( - 1, - 4096, + 2, + 2048, 128, 192, head_dim_v=128, ), "fp8_10": ModelConfig( - 1, - 4096, + 2, + 2048, 128, 192, head_dim_v=128, attn_mask_type="causal", ), "fp8_11": ModelConfig( - 1, - 4096, + 2, + 2048, 128, 192, head_dim_v=128, attn_mask_type="causal_bottom_right", ), "fp8_12": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="causal"), - "fp8_13": ModelConfig(1, 8192, 32, 128, attn_mask_type="causal", window_size=(128, 0)), - "fp8_14": ModelConfig(1, 8192, 64, 64, num_gqa_groups=8, attn_mask_type="causal"), + "fp8_13": ModelConfig(2, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="causal", window_size=(128, 0)), + "fp8_14": ModelConfig(2, 4096, 64, 64, num_gqa_groups=8, attn_mask_type="causal"), "fp8_15": ModelConfig(1, 8192, 64, 64, attn_mask_type="causal", window_size=(128, 0)), "fp8_16": ModelConfig( 1, 8192, 64, 64, num_gqa_groups=8, attn_mask_type="causal", softmax_type="learnable" ), "fp8_17": ModelConfig( - 1, 8192, 64, 64, attn_mask_type="causal", window_size=(128, 0), softmax_type="learnable" + 2, 4096, 64, 64, attn_mask_type="causal", window_size=(128, 0), softmax_type="learnable" ), "fp8_18": 
ModelConfig(1, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="padding"), "fp8_19": ModelConfig(2, 2048, 16, 128, attn_mask_type="padding_causal"), From cd297630e4a5713005d167a0aca7c64b7775a291 Mon Sep 17 00:00:00 2001 From: Vedaanta Agarwalla Date: Thu, 21 May 2026 15:44:52 -0700 Subject: [PATCH 3/3] tests/attention: black format fp8_13 ModelConfig Line was 105 chars; black requires <=100 with the project's preview+ string_processing settings. Signed-off-by: Vedaanta Agarwalla --- tests/pytorch/attention/test_attention.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index 157de12e56..769170d3b7 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -1934,7 +1934,9 @@ def get_model(dtype, config): attn_mask_type="causal_bottom_right", ), "fp8_12": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="causal"), - "fp8_13": ModelConfig(2, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="causal", window_size=(128, 0)), + "fp8_13": ModelConfig( + 2, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="causal", window_size=(128, 0) + ), "fp8_14": ModelConfig(2, 4096, 64, 64, num_gqa_groups=8, attn_mask_type="causal"), "fp8_15": ModelConfig(1, 8192, 64, 64, attn_mask_type="causal", window_size=(128, 0)), "fp8_16": ModelConfig(