
Conversation

@shantanugupta2004 commented Dec 23, 2025

What does this PR do?

This PR replaces the matrix multiplication operator (@) with broadcasting element-wise multiplication (*) in the RotaryEmbedding implementation of several major models (Llama, Mistral, Mixtral, Qwen2, Gemma, Gemma2).
When compiling a model with torch.compile in bfloat16, the RoPE frequency calculation (which is intentionally kept in float32 for precision) triggers a UserWarning about TensorFloat32 (TF32) if TF32 is not enabled.
Since the shapes involved in this operation, [batch, dim/2, 1] and [batch, 1, seq_len], reduce the matmul to an outer product, @ is mathematically equivalent to * with broadcasting. Using * instead avoids the matrix-multiplication code path in the compiler, silencing the false-positive warning and potentially offering a minor performance gain by skipping a full GEMM call for a simple outer product.
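
For illustration, here is a minimal, self-contained sketch of the two formulations (not the actual Transformers modeling code; dim, base, batch, and seq_len are placeholder values) showing that they produce the same float32 frequencies:

```python
import torch

# Placeholder values for illustration only; the real RotaryEmbedding derives
# these from the model config.
dim, base, seq_len = 64, 10000.0, 16

inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))   # [dim/2]
position_ids = torch.arange(seq_len)[None, :]                        # [batch, seq_len]

inv_freq_expanded = inv_freq[None, :, None].float()                  # [batch, dim/2, 1]
position_ids_expanded = position_ids[:, None, :].float()             # [batch, 1, seq_len]

# Before: @ routes the outer product through a GEMM, which is what triggers
# the TF32 UserWarning under torch.compile.
freqs_matmul = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2)

# After: broadcasting * computes the same outer product element-wise and
# never enters the matrix-multiplication code path.
freqs_mul = (inv_freq_expanded * position_ids_expanded).transpose(1, 2)

assert torch.allclose(freqs_matmul, freqs_mul)   # both are [batch, seq_len, dim/2]
```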

Fixes #43012

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline, Pull Request section?
  • Was this discussed/approved via a GitHub issue? Please add a link if that's the case.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed.

@github-actions

[For maintainers] Suggested jobs to run (before merge)

run-slow: afmoe, apertus, arcee, aria, bamba, bitnet, chameleon, csm, cwm, dbrx, deepseek_v3, dia, diffllama, doge, dots1, emu3



Development

Successfully merging this pull request may close these issues.

Compiling a bfloat16 model triggers float32 precision PyTorch warning
