
Conversation

@xmfan
Member

@xmfan xmfan commented Nov 11, 2025

Stacked PRs:


Currently, the forward outputs match per microbatch (no batch invariance).

Intended usage:

> torchrun --standalone --nproc-per-node 4 examples/example_ds3_local_map.py --rng-seed 42; torchrun --standalone --nproc-per-node 8 examples/example_ds3_pp.py --rng-seed 42

(a) [14:59:59] ~/core/a/autoparallel (mybranch) > diff out/0/diff.log out/1/diff.log 
(a) [15:00:07] ~/core/a/autoparallel (mybranch) > diff out/0/weights.log out/1/pp_weights.log 
--- out/0/weights.log   2025-11-19 14:23:31.313739075 -0800
+++ out/1/pp_weights.log        2025-11-19 14:24:33.369228991 -0800
@@ -60,9 +60,12 @@
 name='freqs_cis' hash=DTensor(real=54976837666734080, imag=9351734845035773952))
 name='layers.0.moe.expert_bias' hash=DTensor(0)
 name='layers.0.moe.tokens_per_expert' hash=DTensor(0)
+name='freqs_cis' hash=DTensor(real=54976837666734080, imag=9351734845035773952))
 name='layers.1.moe.expert_bias' hash=DTensor(0)
 name='layers.1.moe.tokens_per_expert' hash=DTensor(0)
+name='freqs_cis' hash=DTensor(real=54976837666734080, imag=9351734845035773952))
 name='layers.2.moe.expert_bias' hash=DTensor(0)
 name='layers.2.moe.tokens_per_expert' hash=DTensor(0)
+name='freqs_cis' hash=DTensor(real=54976837666734080, imag=9351734845035773952))
 name='layers.3.moe.expert_bias' hash=DTensor(0)
 name='layers.3.moe.tokens_per_expert' hash=DTensor(0)

Currently, fw ins are the same, but the forward is being run with a different rng state between the two setups, so there are some numerical differences.
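The hash= values above are content hashes of each parameter/buffer, logged per run so the two setups can be diffed. As a rough illustration of the idea only (hypothetical helper, not the actual NumericsLogger code), such a hash can be computed from the raw bit pattern of a tensor:

```python
import torch

def tensor_hash(t: torch.Tensor) -> int:
    # Hypothetical sketch: sum the raw byte values of the tensor so that
    # bit-identical tensors produce identical hashes across runs/ranks.
    t = t.detach().contiguous().flatten()
    return int(t.view(torch.uint8).to(torch.int64).sum().item())
```

Complex buffers such as freqs_cis would be hashed per component (real/imag), matching the log lines above; writing one `name=... hash=...` line per parameter makes `diff out/0/weights.log out/1/pp_weights.log` a quick divergence check.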

xmfan added a commit that referenced this pull request Nov 11, 2025
stack-info: PR: #246, branch: xmfan/stack/20
@meta-cla meta-cla bot added the CLA Signed label Nov 11, 2025
@xmfan xmfan changed the title Log forward intermediates hashes w/pp vs w/o pp Log forward intermediates/output hashes w/o pp Nov 11, 2025
@xmfan xmfan changed the title Log forward intermediates/output hashes w/o pp Log forward intermediates hashes w/pp vs w/o pp Nov 12, 2025
@xmfan xmfan marked this pull request as ready for review November 12, 2025 07:18
@xmfan xmfan marked this pull request as draft November 13, 2025 20:09
@xmfan xmfan changed the base branch from xmfan/stack/19 to main November 13, 2025 22:55
xmfan added a commit that referenced this pull request Nov 13, 2025
stack-info: PR: #246, branch: xmfan/stack/20
@xmfan xmfan changed the title Log forward intermediates hashes w/pp vs w/o pp Compare microbatch forward outputs and gradients Nov 13, 2025
@xmfan xmfan marked this pull request as ready for review November 13, 2025 22:57
xmfan added a commit that referenced this pull request Nov 14, 2025
stack-info: PR: #246, branch: xmfan/stack/20
@wconstab
Contributor

granted the rng affects the grads, why does the diff show 'none' rather than a different hash?

if rng_seed is not None:
    numerics_logger = NumericsLogger(logs_dir)
with AutoParallel(
    model, input_fn, mesh, dynamic=True, numerics_logger=None
Contributor

should this be numerics_logger = numerics_logger?

Member Author

It's too noisy, it logs the intermediates for each op in the graph. I haven't thought of how to address it yet.
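To illustrate the volume involved (a sketch only, assuming an FX-style graph and a `log_diff(tensor, prefix=...)`-style method on the logger), per-op intermediate logging means one entry per graph node:

```python
import torch
from torch.fx import GraphModule, Interpreter

def log_graph_intermediates(gm: GraphModule, args, numerics_logger):
    # One log line per graph node: for a full model graph this is thousands
    # of entries per forward, which is why the logger is not threaded into
    # AutoParallel here.
    class LoggingInterpreter(Interpreter):
        def run_node(self, node):
            out = super().run_node(node)
            if isinstance(out, torch.Tensor):
                numerics_logger.log_diff(out, prefix=f"node {node.name}")
            return out

    return LoggingInterpreter(gm).run(*args)
```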

return

rank = torch.distributed.get_rank()
if rank == 4:
Contributor

can you somehow not hardcode this

Member Author

take a look at the new logic

action: _Action,
ctx: _PipelineContext,
numerics_logs: Optional[list[str]] = None,
forward_hook: Callable | None = None,
Contributor

nit: Optional[Callable]

if self.rank == 0:
    print(f"Weight hashes written to {path}")

def log_pp_grads(self, orig_mod, stage_mods, num_world_stages, ranks):
Contributor

What is num_world_stages?

Member Author

number of stages

rank = torch.distributed.get_rank()
if rank == 4:
    numerics_logger.log_diff(
        output, rank=4, prefix=f"mb{action.microbatch_index} fwd out"
Contributor

Yeah, very confusing. Also, do we care about pp_rank or global rank? Finally, V-style schedules will have the last stage on rank 0?

Member Author

I just want to log from the last pp stage, and only want to log it once

Member Author

take a look at the new logic
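For reference, last-stage-only logging can be expressed without a hardcoded global rank; a minimal sketch under assumed names (`stage_index` on the stage, `num_world_stages` as the total stage count), which may differ from the PR's actual new logic:

```python
def should_log_fwd_output(stage_index: int, num_world_stages: int) -> bool:
    # Log only from the rank that owns the final pipeline stage, whatever its
    # global rank happens to be (V-style schedules can place it on rank 0).
    # The last stage sees each microbatch exactly once, so each fwd output is
    # logged exactly once.
    return stage_index == num_world_stages - 1
```

The forward action would then guard its log_diff call with this check instead of `rank == 4`.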

@sanketpurandare
Contributor

> But for the backward, all grads are None
> Currently, fw ins are the same, but the forward is being run with a different rng state between the two setups, so there are some numerical differences

If we land #250 first, it fixes the grad issue.

@sanketpurandare
Contributor

> granted the rng affects the grads, why does the diff show 'none' rather than a different hash?

There was a bug in gradient accumulation that is fixed by #250.
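For context on why buggy accumulation shows up as 'none' rather than a different hash: grads only appear in param.grad if the per-microbatch accumulation step actually runs. A generic sketch of that step (not the actual fix in #250):

```python
import torch

def accumulate_grads(params, microbatch_grads):
    # Generic per-microbatch accumulation. If this step is skipped, or grads
    # are stashed somewhere other than param.grad, param.grad stays None and
    # a later log/diff prints 'None' instead of a hash.
    for param, mb_grad in zip(params, microbatch_grads):
        if mb_grad is None:
            continue
        if param.grad is None:
            param.grad = mb_grad.detach().clone()
        else:
            param.grad.add_(mb_grad)
```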

xmfan added a commit that referenced this pull request Nov 19, 2025
stack-info: PR: #246, branch: xmfan/stack/20
@sanketpurandare
Contributor

@xmfan Would it be possible to add the numerics logging logic to the GraphPipelineStage class? That way, when we create the stage, we can pass the numerics logging args to the stage itself. Then, from the stage object, you can grab the variables or callables you want and create logs for any of the action methods. This way you don't need to change the signature of stage_forward etc.
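A rough sketch of that suggestion (constructor args and method names are assumptions, and the real GraphPipelineStage takes many more arguments): store the logging state on the stage at construction time, and have the action methods read it from the stage instead of taking new parameters.

```python
from typing import Any, Optional

class GraphPipelineStage:
    # Hypothetical sketch: only the numerics-logging plumbing is shown.
    def __init__(self, stage_index: int, numerics_logger: Optional[Any] = None):
        self.stage_index = stage_index
        self.numerics_logger = numerics_logger

    def maybe_log(self, prefix: str, value: Any) -> None:
        # No-op unless this stage was constructed with a logger, so action
        # methods like stage_forward can call it unconditionally.
        if self.numerics_logger is not None:
            self.numerics_logger.log_diff(value, prefix=prefix)
```

stage_forward could then call something like `stage.maybe_log(f"mb{action.microbatch_index} fwd out", output)` without the extra `numerics_logs`/`forward_hook` parameters in its signature.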

@xmfan
Member Author

xmfan commented Nov 20, 2025

Verified that the failures are due to recent nightlies.

@xmfan xmfan merged commit 10d8208 into main Nov 20, 2025
4 of 6 checks passed