[PyTorch] Integrate cuDNN GQA + DSA backend into DotProductAttention #3028

@nvMelissa

Description

Is your feature request related to a problem? Please describe.

Transformer Engine does not currently expose a path that combines Grouped Query Attention (GQA) with DeepSeek-style sparse attention (DSA), in which each query token attends only to a TopK subset of key/value tokens. Several training workloads need this combination: a GQA attention shape (many query heads sharing fewer K/V heads) together with a sparsity pattern that restricts each query to a small list of K/V indices. Without a TE-native backend, teams either fall back to community Triton kernels, which do not reach production-scale performance, or implement sparse attention outside TE, losing autograd integration, kernel fusion, and parity with TE's existing attention features.
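To make the TopK pattern concrete, here is a minimal NumPy sketch of how a per-query index tensor of shape [B, S_q, topk] could be derived from relevance scores. The function name `select_topk_indices` and the `index_scores` input (standing in for an indexer's output) are hypothetical, causal masking is ignored for brevity, and none of this is TE or cuDNN API.

```python
import numpy as np


def select_topk_indices(index_scores: np.ndarray, topk: int) -> np.ndarray:
    """Pick the `topk` highest-scoring K/V positions for every query.

    index_scores: [B, S_q, S_kv] relevance scores (hypothetical indexer output).
    Returns int64 indices of shape [B, S_q, topk], ordered by descending score.
    """
    # argpartition places the topk largest scores (smallest negated scores)
    # in the first topk slots without fully sorting each row.
    part = np.argpartition(-index_scores, topk - 1, axis=-1)[..., :topk]
    # Sort just the selected positions by score so the output is deterministic.
    part_scores = np.take_along_axis(index_scores, part, axis=-1)
    order = np.argsort(-part_scores, axis=-1, kind="stable")
    return np.take_along_axis(part, order, axis=-1).astype(np.int64)
```

Each query then carries only `topk` K/V positions, which is the `sparse_indices` layout the request below asks DotProductAttention to accept.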

Describe the solution you'd like

Add a cuDNN-backed sparse-attention path inside DotProductAttention for the PyTorch frontend that:

  • Recognizes a sparse-attention mode and dispatches to the new cuDNN GQA + DSA kernel
  • Accepts a per-query sparse_indices tensor of shape [B, S_q, topk] selecting which K/V positions each query attends to
  • Supports the standard GQA shape (num_attention_heads ≠ num_gqa_groups)
  • Supports BF16 attention at minimum (FP8 indexer extension as a follow-on if needed)
  • Integrates cleanly with TE's autograd and existing context-parallelism path
  • Ships with numerical-equivalence tests against a reference dense-attention baseline restricted to the same TopK indices
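The numerical-equivalence tests in the last item could compare two formulations of the same computation. This is a minimal NumPy sketch under these assumptions: Q is laid out as [B, S_q, H_q, D] and K/V as [B, S_kv, H_kv, D]; indices are unique per query and shared by all heads of that query; and query head h reads K/V head h // (H_q // H_kv). Both function names are hypothetical, not part of TE. The first gathers only the TopK positions (the work a sparse kernel does); the second runs dense attention with every non-selected position masked to -inf.

```python
import numpy as np


def _softmax(x: np.ndarray) -> np.ndarray:
    # Numerically stable softmax over the last axis; -inf entries become 0.
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)


def _repeat_kv(k, v, num_q_heads):
    # GQA: repeat each K/V head so query head h lines up with K/V head h // group.
    num_kv_heads = k.shape[2]
    assert num_q_heads % num_kv_heads == 0, "query heads must be a multiple of K/V heads"
    group = num_q_heads // num_kv_heads
    return np.repeat(k, group, axis=2), np.repeat(v, group, axis=2)


def sparse_gqa_attention_reference(q, k, v, sparse_indices, scale=None):
    """Attend only to the gathered TopK K/V positions. Returns [B, S_q, H_q, D]."""
    B, _, H_q, D = q.shape
    scale = 1.0 / np.sqrt(D) if scale is None else scale
    k_rep, v_rep = _repeat_kv(k, v, H_q)          # [B, S_kv, H_q, D]
    b_idx = np.arange(B)[:, None, None]           # broadcasts against [B, S_q, topk]
    k_sel = k_rep[b_idx, sparse_indices]          # [B, S_q, topk, H_q, D]
    v_sel = v_rep[b_idx, sparse_indices]
    scores = np.einsum("bshd,bsthd->bsht", q, k_sel) * scale  # [B, S_q, H_q, topk]
    return np.einsum("bsht,bsthd->bshd", _softmax(scores), v_sel)


def dense_masked_attention_reference(q, k, v, sparse_indices, scale=None):
    """Full dense attention with non-selected K/V positions masked out."""
    B, S_q, H_q, D = q.shape
    S_kv = k.shape[1]
    scale = 1.0 / np.sqrt(D) if scale is None else scale
    k_rep, v_rep = _repeat_kv(k, v, H_q)
    scores = np.einsum("bshd,bthd->bsht", q, k_rep) * scale   # [B, S_q, H_q, S_kv]
    # Boolean mask that is True only at each query's TopK positions.
    mask = np.zeros((B, S_q, S_kv), dtype=bool)
    np.put_along_axis(mask, sparse_indices, True, axis=-1)
    scores = np.where(mask[:, :, None, :], scores, -np.inf)
    return np.einsum("bsht,bthd->bshd", _softmax(scores), v_rep)
```

The gathered version touches only `topk` keys per query, so its cost scales with S_q × topk rather than S_q × S_kv; the masked dense version is slower but simple enough to trust as the baseline a cuDNN kernel's output is compared against.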

cc: @cyanguwa
