
Conversation

@kolehma8 kolehma8 commented Dec 17, 2025

Summary

Changing RMS layer norm to accept DTensors.

Details

RMS layer norm parameters are NOT sharded under typical tensor parallelism implementations, but the inputs (and their gradients) may become sharded (DTensors). This PR resolves the issue by gathering the input into a full local tensor and performing the full computation on each device.
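For illustration, here is a minimal sketch of the gather-then-compute pattern (the eager-mode math below stands in for the actual Triton kernel, and the function name is illustrative only):

```python
import torch
from torch.distributed.tensor import DTensor  # torch.distributed._tensor on older PyTorch


def rms_norm_forward(X: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    if isinstance(X, DTensor):
        # Activations may arrive sharded under tensor parallelism; gather them
        # into a full local tensor so each device runs the complete computation
        # (the norm weight itself is replicated, not sharded).
        X = X.full_tensor()
    variance = X.pow(2).mean(-1, keepdim=True)
    return weight * (X * torch.rsqrt(variance + eps))
```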

Related Issues:

Testing Done

  • Hardware Type:
  • run make test to ensure correctness
  • run make checkstyle to ensure code style
  • run make test-convergence to ensure convergence

@kolehma8 kolehma8 self-assigned this Dec 17, 2025
@kolehma8 kolehma8 force-pushed the kolehma8/rms_norm_dtensor branch from ff318e5 to 23fc67f on December 17, 2025 00:18
@shimizust (Collaborator)

Can you run the example e2e training script just using FSDP2 (since TP won't work until all kernels are tp-compatible) to ensure the perf/loss is correct?

# needs to be gathered to a local tensor to compute
# RMS layer norm on each TP worker.
# TODO: support CP.
X = X.full_tensor()
Collaborator:

For my understanding, pytorch native TP keeps activations as DTensors and lets subsequent ops decide what to do?

Collaborator Author:

That is my understanding as well.
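(For context, a rough sketch of the setup being discussed; the layer and sizes are made up, and use_local_output=False is the knob that keeps a layer's output as a sharded DTensor instead of converting it back to a local tensor:)

```python
# Assumes a multi-process launch, e.g. torchrun --nproc_per_node=8.
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import ColwiseParallel, parallelize_module

mesh = init_device_mesh("cuda", (8,))
proj = nn.Linear(4096, 16384)
# The layer's output stays a sharded DTensor, so whatever op consumes it next
# (e.g. the RMS norm in this PR) has to decide how to handle it.
parallelize_module(proj, mesh, ColwiseParallel(use_local_output=False))
```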

return PEFT_AVAILABLE


def infer_backend():
Collaborator:

Suggest renaming to infer_comm_backend

Collaborator Author:

Changed as suggested.
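(Not from the PR itself, just a hypothetical sketch of what such a helper might look like, for readers following along:)

```python
import torch


def infer_comm_backend() -> str:
    # One plausible heuristic: NCCL for GPU jobs, Gloo as the CPU fallback.
    return "nccl" if torch.cuda.is_available() else "gloo"
```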

@kolehma8 kolehma8 force-pushed the kolehma8/rms_norm_dtensor branch 5 times, most recently from 80279eb to 75a54f8 on December 19, 2025 20:53
@kolehma8 (Collaborator Author)

> Can you run the example e2e training script just using FSDP2 (since TP won't work until all kernels are tp-compatible) to ensure the perf/loss is correct?

I did not see any benchmarks using FSDP in the scripts folder. I think we can say with very high confidence that this PR will not impact FSDP: if FSDP were passing DTensor inputs to this op, it would already have crashed before this change. Conversely, this PR only changes behavior for DTensor inputs; behavior for regular tensors remains unchanged.

@kolehma8 kolehma8 requested a review from shimizust December 19, 2025 21:31
@kolehma8 kolehma8 force-pushed the kolehma8/rms_norm_dtensor branch from 75a54f8 to 9af46d9 on January 5, 2026 18:34
@kolehma8 kolehma8 force-pushed the kolehma8/rms_norm_dtensor branch from 9af46d9 to b445975 on January 5, 2026 18:42
@kolehma8 kolehma8 merged commit 5101e3c into main Jan 5, 2026
5 of 7 checks passed
@kolehma8 kolehma8 deleted the kolehma8/rms_norm_dtensor branch January 5, 2026 20:54