[PyT] Reduce test sizes in fused attn fp8 vs fp16 to avoid OOM #3020
vedaanta wants to merge 4 commits into
Greptile Summary
This PR shrinks nine
Confidence Score: 4/5
Safe to merge for the memory-reduction goal; one test config quietly changes its feature combination instead of just shrinking its size.
The memory reduction is well-motivated and measured. The only concern is fp8_13, which adds `num_gqa_groups=4` while keeping the same B and S. That changes which feature combination is exercised rather than just reducing resource usage, and the SWA-without-GQA scenario at 32 heads is no longer covered. Everything else is a straightforward size reduction.
File needing attention: tests/pytorch/attention/test_attention.py, specifically the fp8_13 entry, whose feature coverage changed, not just its size.
Flowchart
%%{init: {'theme': 'neutral'}}%%
flowchart TD
A[fp8_vs_f16 test suite] --> B[fp8_9 to fp8_11 H=128 D=192]
A --> C[fp8_12 to fp8_13 H=32 D=128 S=8192]
A --> D[fp8_14 to fp8_17 H=64 D=64]
A --> E[fp8_18 to fp8_20 unchanged]
B --> B1[fp8_9 B=2 S=2048 no mask]
B --> B2[fp8_10 B=2 S=2048 causal]
B --> B3[fp8_11 B=2 S=2048 causal_bottom_right]
C --> C1[fp8_12 B=1 S=8192 GQA]
C --> C2[fp8_13 B=2 S=8192 GQA+SWA was SWA-only]
D --> D1[fp8_14 B=2 S=4096 GQA]
D --> D2[fp8_15 B=1 S=8192 SWA]
D --> D3[fp8_16 B=1 S=8192 GQA+learnable]
D --> D4[fp8_17 B=2 S=4096 SWA+learnable]
Reviews (4): Last reviewed commit: "tests/attention: black format fp8_13 Mod..."
The 9 fp8_9..fp8_17 configs in `model_configs_fp8_vs_f16` use shapes
(B=2, S=4096-8192, H=32-128, D=64-192) for the bf16-vs-fp8 reference
comparison. The reference path in `test_dpa_fp8_vs_f16` materializes the
full (B, H, S, S) attention matrix in bf16, and keeps a handful of them
live (S, P, dP, dS, dropout-mask) simultaneously. At B=2, S=8192, H=64
the per-test peak is ~70 GiB, which exceeds the memory of common 80 GB
cards (H100) and pushes the suite into OOM territory on Blackwell (~91
GB measured with the cuDNN caching allocator residue).
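As a rough sanity check on the figure above, the size of one dense (B, H, S, S) tensor can be worked out by hand. The sketch below assumes bf16 storage at 2 bytes per element and ignores every other allocation; the function names are illustrative and are not part of the test suite.

```python
BYTES_PER_BF16 = 2  # bfloat16 stores each element in 2 bytes


def attn_matrix_bytes(batch, heads, seqlen, bytes_per_elem=BYTES_PER_BF16):
    """Bytes occupied by one dense (B, H, S, S) tensor."""
    return batch * heads * seqlen * seqlen * bytes_per_elem


def to_gib(num_bytes):
    """Convert bytes to GiB (2**30 bytes)."""
    return num_bytes / 2**30


full = attn_matrix_bytes(batch=2, heads=64, seqlen=8192)
half = attn_matrix_bytes(batch=1, heads=64, seqlen=8192)

print(f"B=2: {to_gib(full):.0f} GiB per (B, H, S, S) tensor")
print(f"B=1: {to_gib(half):.0f} GiB per (B, H, S, S) tensor")
```

Under these assumptions one tensor at B=2, S=8192, H=64 takes 16 GiB, so four or five of them live at once lands in the same range as the reported ~70 GiB peak. How many are actually live at the peak is not stated precisely here, so treat this as an order-of-magnitude estimate only.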
Halving B to 1 halves the bytes of every (B, H, S, S) tensor. Measured
on B200 (SM_100, cuDNN 9.23, TE main):
per-test peak `torch.cuda.max_memory_allocated`:
before: 70.0 GiB (fp8_14)
after : 36.1 GiB (fp8_14) -48%
per-test peak `nvidia-smi memory.used`:
before: 96.8 GiB
after : 51.3 GiB -47%
test outcome (B200, develop FE, this TE):
identical 618F / 2196P / 891S, wall time within ~3%
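A per-test peak like the `torch.cuda.max_memory_allocated` numbers above could be collected along the following lines. This is a minimal sketch, not the procedure actually used for this PR: `peak_allocated_gib` and `workload` are hypothetical names, and the workload is a small stand-in tensor rather than a real test case.

```python
try:
    import torch
except ImportError:  # keep the sketch importable on machines without PyTorch
    torch = None


def peak_allocated_gib(fn):
    """Run fn() and return the peak tensor memory it allocated, in GiB.

    Returns None when PyTorch or a CUDA device is unavailable.
    """
    if torch is None or not torch.cuda.is_available():
        return None
    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats()  # start the peak counter from the current usage
    fn()
    torch.cuda.synchronize()  # make sure all queued kernels have finished
    return torch.cuda.max_memory_allocated() / 2**30


def workload():
    # Hypothetical stand-in for one test case: a small bf16 (B, H, S, S) tensor.
    scores = torch.empty(1, 4, 1024, 1024, dtype=torch.bfloat16, device="cuda")
    return scores.sum()


peak = peak_allocated_gib(workload)
print("no CUDA device" if peak is None else f"peak allocated: {peak:.3f} GiB")
```

`nvidia-smi memory.used` reads higher than this counter because it also includes memory the caching allocator has reserved but not handed out, plus the CUDA context itself, which is consistent with the gap between the two sets of numbers above.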
The shrunk configs still exercise every distinct shape/mask/SWA/GQA
combination that the originals did -- only B is smaller. The suite now
fits comfortably on 80 GB cards.
fp8_19/20 (B=2, S=2048) are left at B=2 because their peak is small
(~few GiB) and the larger batch is useful coverage for padding_causal.
Signed-off-by: Vedaanta Agarwalla <vagarwalla@nvidia.com>
Signed-off-by: Vedaanta Agarwalla <142048820+vedaanta@users.noreply.github.com>
Signed-off-by: Vedaanta Agarwalla <vagarwalla@nvidia.com>
1a59d59 to c3f1e50
Line was 105 chars; black requires <=100 with the project's preview + string_processing settings.
Signed-off-by: Vedaanta Agarwalla <vagarwalla@nvidia.com>
/te-ci pytorch L0
The 9 fp8_9..fp8_17 configs in `model_configs_fp8_vs_f16` use shapes (B=2, S=4096-8192, H=32-128, D=64-192) for the bf16-vs-fp8 reference comparison. The reference path in `test_dpa_fp8_vs_f16` materializes the full (B, H, S, S) attention matrix in bf16, and keeps a handful of them live (S, P, dP, dS, dropout-mask) simultaneously. At B=2, S=8192, H=64 the per-test peak is ~70 GiB, which pushes the suite into OOM territory on Blackwell (~91 GB measured with the cuDNN caching allocator residue).
Halving B to 1 halves the bytes of every (B, H, S, S) tensor. Measured on B200 (SM_100, cuDNN 9.23, TE main):
per-test peak `torch.cuda.max_memory_allocated`:
before: 70.0 GiB (fp8_14)
after : 36.1 GiB (fp8_14) -48%
per-test peak `nvidia-smi memory.used`:
before: 96.8 GiB
after : 51.3 GiB -47%
test outcome (B200, develop FE, this TE):
identical 618F / 2196P / 891S, wall time within ~3%
The shrunk configs still exercise every distinct shape/mask/SWA/GQA combination that the originals did -- only B is smaller. The suite now fits comfortably on 80 GB cards.
fp8_19/20 (B=2, S=2048) are left at B=2 because their peak is small (~few GiB) and the larger batch is useful coverage for padding_causal.