Skip to content

Commit 403795c

Browse files
committed
Update (base update)
[ghstack-poisoned]
1 parent b7014bd commit 403795c

File tree

2 files changed

+8
-6
lines changed

2 files changed

+8
-6
lines changed

autoparallel/autobucketing_util/bucket_func.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@
99

1010
import torch
1111
from torch._inductor import ir, scheduler
12-
from torch._inductor.dependencies import StarDep, WeakDep
1312
from torch._inductor.comms import get_op_idx
13+
from torch._inductor.dependencies import StarDep, WeakDep
1414
from torch._inductor.utils import is_collective, is_wait
1515
from torch._inductor.virtualized import V
1616
from torch.utils._ordered_set import OrderedSet

autoparallel/autobucketing_util/bucket_plan.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -187,11 +187,13 @@ def get_simplefsdp_auto_plan(
187187
)
188188

189189
# (3) Communication size criteria
190-
break_comm_size_criteria = (
191-
comm_cache.ag_max_inp_size <= comm_size_inp
192-
or comm_cache.rs_max_out_size
193-
<= heuristic_info["this_step_rs_comm_out_size"]
194-
)
190+
break_comm_size_criteria = comm_cache.ag_max_inp_size < comm_size_inp
191+
if comm_cache.rs_max_out_size > 0:
192+
break_comm_size_criteria = (
193+
break_comm_size_criteria
194+
or comm_cache.rs_max_out_size
195+
< heuristic_info["this_step_rs_comm_out_size"]
196+
)
195197

196198
if (
197199
break_overlap_criteria

0 commit comments

Comments
 (0)