Commit 36535ef

Fix configs for 'aten' bucketing (#236)

1 parent 90cd287 commit 36535ef

File tree

1 file changed: +6 -13 lines

autoparallel/auto_bucketing.py

Lines changed: 6 additions & 13 deletions
@@ -119,21 +119,14 @@ def aten_autobucketing_reordering_pass(
 def configure_inductor_for_autobucketing(mode: str = "aten"):
     # allow configuring inductor comms optimizations from torchtitan commandline
     if mode == "aten":
-        from autoparallel.auto_bucketing import (
-            aten_autobucketing_config,
-            aten_autobucketing_reordering_pass,
-        )
-
-        # this is from the stacked pr in https://github.com/pytorch/pytorch/pull/163960
-        torch._inductor.config.reorder_for_peak_memory = False
-        torch._inductor.config.reorder_for_compute_comm_overlap = False
-        aten_autobucketing_reordering_pass = partial(
-            aten_autobucketing_reordering_pass,
-            configs=aten_autobucketing_config,  # type: ignore
+        torch._inductor.config.aten_distributed_optimizations.enable_overlap_scheduling = (
+            True
         )
-        torch._inductor.config.post_grad_custom_post_pass = (
-            aten_autobucketing_reordering_pass  # type: ignore
+        torch._inductor.config.aten_distributed_optimizations.collective_bucketing = (
+            True
         )
+        torch._inductor.config.aten_distributed_optimizations.insert_overlap_deps = True
+        torch._inductor.config.aten_distributed_optimizations.max_compute_pre_fetch = 10
     elif mode == "inductor":
         from autoparallel.auto_bucketing import (
             simple_fsdp_autobucketing_reordering_pass,
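
For context: this commit drops the hand-wired post_grad_custom_post_pass bucketing hook in favor of inductor's built-in aten_distributed_optimizations flags, which the diff above enables directly. A minimal usage sketch follows (hypothetical driver code, not part of this commit; it assumes a PyTorch build where torch._inductor.config.aten_distributed_optimizations exists, and the model is a placeholder):

    import torch

    from autoparallel.auto_bucketing import configure_inductor_for_autobucketing

    # Flip the inductor flags shown in the diff above; inductor reads its
    # config when the model is compiled, so this must run before compile.
    configure_inductor_for_autobucketing(mode="aten")

    model = torch.nn.Linear(8, 8)  # placeholder model for illustration
    compiled = torch.compile(model)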
