Compare microbatch forward outputs and gradients #246
Conversation
Granted the RNG affects the grads, why does the diff show 'none' rather than a different hash?
```python
if rng_seed is not None:
    numerics_logger = NumericsLogger(logs_dir)
with AutoParallel(
    model, input_fn, mesh, dynamic=True, numerics_logger=None
```
Should this be `numerics_logger=numerics_logger`?
It's too noisy: it logs the intermediates for each op in the graph. I haven't thought of how to address that yet.
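One possible way around the noise, sketched below purely as an illustration (the wrapper class and `keep_prefixes` are hypothetical names, not part of this PR, and this assumes the per-op records flow through a `log_diff(tensor, rank=..., prefix=...)` call like the one used elsewhere in this PR, which may not hold for the real `NumericsLogger`):

```python
# Hypothetical sketch, not part of this PR: wrap the logger so only selected
# record prefixes (e.g. per-microbatch forward outputs and grads) are kept,
# instead of one record per op in the graph.
class FilteredNumericsLogger:
    def __init__(self, inner, keep_prefixes=("fwd out", "grad")):
        self.inner = inner
        self.keep_prefixes = keep_prefixes

    def log_diff(self, tensor, rank=None, prefix=""):
        # Forward only the records whose prefix we care about; drop per-op noise.
        if any(p in prefix for p in self.keep_prefixes):
            self.inner.log_diff(tensor, rank=rank, prefix=prefix)
```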
examples/example_ds3_pp.py (Outdated)

```python
        return

    rank = torch.distributed.get_rank()
    if rank == 4:
```
Can you somehow not hardcode this?
Take a look at the new logic.
autoparallel/graph_pp_runner.py (Outdated)

```python
    action: _Action,
    ctx: _PipelineContext,
    numerics_logs: Optional[list[str]] = None,
    forward_hook: Callable | None = None,
```
nit: Optional[Callable]
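For context on the nit, a small note (general Python typing behavior, not specific to this PR): the two spellings are equivalent, but `Optional[Callable]` also works on Python versions below 3.10, where the `X | None` union syntax is not usable in annotations evaluated at runtime.

```python
from typing import Callable, Optional

forward_hook: Optional[Callable] = None   # spelling suggested in the nit
# forward_hook: Callable | None = None    # current spelling; needs Python 3.10+
#                                         # (or `from __future__ import annotations`)
```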
autoparallel/utils.py (Outdated)

```python
        if self.rank == 0:
            print(f"Weight hashes written to {path}")

    def log_pp_grads(self, orig_mod, stage_mods, num_world_stages, ranks):
```
What is num_world_stages?
number of stages
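For readers following along, here is a rough sketch of the idea behind a `log_pp_grads`-style helper: walk the modules for the pipeline stages owned by this rank and record one gradient summary per parameter so two runs can be diffed line by line. This is an illustration only, not the PR's actual implementation, and the helper name is hypothetical.

```python
# Illustrative sketch only; the PR's log_pp_grads may differ.
import torch

def log_pp_grads_sketch(stage_mods, num_world_stages):
    lines = []
    for stage_idx, mod in enumerate(stage_mods):
        for name, param in mod.named_parameters():
            if param.grad is None:
                summary = "none"  # no gradient was accumulated for this parameter
            else:
                summary = f"norm={param.grad.float().norm().item():.6e}"
            lines.append(f"stage {stage_idx}/{num_world_stages} {name}: {summary}")
    return lines
```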
examples/example_ds3_pp.py (Outdated)

```python
rank = torch.distributed.get_rank()
if rank == 4:
    numerics_logger.log_diff(
        output, rank=4, prefix=f"mb{action.microbatch_index} fwd out"
```
Yeah, very confusing. Also, do we care about pp_rank or the global rank? Finally, V-style schedules will have the last stage on rank 0?
I just want to log from the last PP stage, and only log it once.
Take a look at the new logic.
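As a rough illustration of the kind of check being asked for (names are hypothetical, and the PR's actual "new logic" may differ): decide whether to log based on which pipeline stages this rank owns rather than its hardcoded global rank, which also covers V-style schedules where the last stage can live on rank 0.

```python
# Hypothetical sketch, not the PR's actual code: log from whichever rank owns
# the last pipeline stage, instead of hardcoding `rank == 4`.
def should_log_fwd_out(stage_indices, num_world_stages):
    """stage_indices: pipeline-stage ids owned by this rank (can be several)."""
    return (num_world_stages - 1) in stage_indices

# Example: a V-style schedule on 4 ranks with 8 stages; rank 0 owns {0, 7}.
print(should_log_fwd_out(stage_indices={0, 7}, num_world_stages=8))  # True
```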
If we land #250 first it fixes the grad issue.
There was a bug in gradient accumulation that is fixed by #250.
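This also bears on the earlier question about the diff showing 'none': a parameter whose `.grad` was never populated has nothing to hash. The snippet below only illustrates how microbatch gradient accumulation is supposed to behave; the actual bug fixed by #250 may be different.

```python
# Illustrative only; not the actual bug fixed by #250. With pipeline
# microbatching, every microbatch's backward should accumulate into
# param.grad. If a parameter's backward never runs, param.grad stays None,
# which would show up as "none" in a gradient hash diff.
import torch

model = torch.nn.Linear(4, 4)
for mb in [torch.randn(2, 4) for _ in range(3)]:
    model(mb).sum().backward()  # grads accumulate across microbatches

print(model.weight.grad is None)  # False: backward ran for every microbatch
```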
@xmfan Would it be possible to add numerics logging logic to the …
Verified that the failures are due to recent nightlies.
Stacked PRs:
Currently the forward matches per microbatch (no batch invariance).
Intended usage:
Currently, the forward inputs are the same, but the forward is being run with different RNG state between the two setups, so there are some numerical differences.
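As a hedged sketch of the intended comparison workflow (the helper names below are illustrative, not this PR's `NumericsLogger` API): seed the RNG identically in both setups, record one hash per microbatch forward output, and diff the resulting logs.

```python
# Illustrative sketch, not the PR's API: hash each microbatch's forward output
# under a fixed seed and write one line per microbatch, so the logs from two
# setups can be compared with a plain diff.
import hashlib
import torch

def tensor_hash(t: torch.Tensor) -> str:
    # Hash raw bytes on CPU so identical values always yield the same digest.
    return hashlib.sha256(t.detach().cpu().contiguous().numpy().tobytes()).hexdigest()[:16]

torch.manual_seed(0)  # without matching rng state, dropout etc. will diverge
model = torch.nn.Linear(8, 8)
microbatches = [torch.randn(2, 8) for _ in range(4)]

with open("fwd_hashes.log", "w") as f:
    for i, mb in enumerate(microbatches):
        f.write(f"mb{i} fwd out: {tensor_hash(model(mb))}\n")
# Run the same in the other setup and diff the two files; matching lines mean
# the per-microbatch forward outputs agree.
```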