Changing RMS layer norm to accept DTensors. #982
Conversation
Force-pushed from ff318e5 to 23fc67f
Can you run the example e2e training script just using FSDP2 (since TP won't work until all kernels are tp-compatible) to ensure the perf/loss is correct?
    # needs to be gathered to a local tensor to compute
    # RMS layer norm on each TP worker.
    # TODO: support CP.
    X = X.full_tensor()
For my understanding, pytorch native TP keeps activations as DTensors and lets subsequent ops decide what to do?
That is my understanding as well.
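For illustration, here is a minimal sketch of the gather-then-compute pattern discussed above. It assumes a recent PyTorch where DTensor is importable from torch.distributed.tensor, and the eager-mode RMS norm is only a stand-in for Liger's Triton kernel; the function name is made up for this example.

```python
# Hedged sketch: gather a sharded DTensor input before computing RMS norm
# locally on each TP worker. Regular tensors fall through unchanged.
import torch

try:
    from torch.distributed.tensor import DTensor  # PyTorch >= 2.4
except ImportError:
    from torch.distributed._tensor import DTensor  # older releases

def rms_norm_forward(X: torch.Tensor, W: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    if isinstance(X, DTensor):
        # Under tensor parallelism the activations (and their gradients) may
        # arrive sharded; gather to a full local tensor so each TP worker
        # normalizes over the complete hidden dimension.
        X = X.full_tensor()
    # Eager RMS norm, standing in for the Triton kernel.
    variance = X.pow(2).mean(-1, keepdim=True)
    return W * (X * torch.rsqrt(variance + eps))
```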
src/liger_kernel/utils.py (outdated diff)
    return PEFT_AVAILABLE


    def infer_backend():
Suggest renaming to infer_comm_backend
Changed as suggested.
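For context, a purely hypothetical sketch of what a helper with that name might look like; the actual function added to src/liger_kernel/utils.py is not shown in this thread and may differ.

```python
# Hypothetical illustration only; not the implementation from this PR.
import torch

def infer_comm_backend() -> str:
    """Guess a torch.distributed communication backend from the available device."""
    return "nccl" if torch.cuda.is_available() else "gloo"
```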
Force-pushed from 80279eb to 75a54f8
I did not see any benchmarks using FSDP in the scripts folder. I think we can say with a very high degree of confidence that this PR will not impact FSDP, since if it did, the op would have crashed earlier (e.g. if any inputs were DTensors). Conversely, this PR only affects DTensor inputs; behavior for regular tensors remains unchanged.
Force-pushed from 75a54f8 to 9af46d9
Force-pushed from 9af46d9 to b445975
Summary
Changing RMS layer norm to accept DTensors.
Details
RMS layer norm parameters are NOT sharded under typical tensor parallelism implementations, but the inputs (and their gradients) may arrive sharded as DTensors. This PR resolves the issue by gathering the input tensors and performing the full computation on each device.
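To make the scenario concrete, here is an illustrative snippet (not from the PR) of how an activation can become a sharded DTensor under PyTorch native parallelism and then be handed to the norm layer; the mesh size and shapes are arbitrary, and the import paths assume a recent PyTorch.

```python
# Illustrative only: producing a sharded DTensor activation (assumes
# torch.distributed is initialized and two CUDA devices are available).
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard, distribute_tensor

mesh = init_device_mesh("cuda", (2,))                      # 1-D mesh of 2 GPUs
x = torch.randn(8, 4096, device="cuda")
x_dt = distribute_tensor(x, mesh, placements=[Shard(0)])   # sharded along dim 0
# With this change, the RMS norm op can accept x_dt directly: it gathers the
# shards via full_tensor() and runs the kernel on the full local tensor.
```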
Related Issues:
Testing Done
- make test to ensure correctness
- make checkstyle to ensure code style
- make test-convergence to ensure convergence