Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -72,10 +72,9 @@ case object LinkConfig {
case BroadcastPartition() =>
BroadcastPartitioning(
dataTransferBatchSize,
fromWorkerIds.zip(toWorkerIds).map {
case (fromWorkerId, toWorkerId) =>
ChannelIdentity(fromWorkerId, toWorkerId, isControl = false)
}
fromWorkerIds.flatMap(fromId =>
toWorkerIds.map(toId => ChannelIdentity(fromId, toId, isControl = false))
)
)

case UnknownPartition() =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -144,13 +144,7 @@ class LinkConfigSpec extends AnyFlatSpec with Matchers {

// ----- BroadcastPartition -----

"BroadcastPartition" should "produce a BroadcastPartitioning whose channels follow zip pairing today (current behavior)" in {
// Pin: BroadcastPartition currently uses `fromWorkerIds.zip(toWorkerIds)`
// — the SAME 1:1 pairing as OneToOnePartition. ChannelConfig in the same
// package emits a full cross product for the BroadcastPartition arm,
// which matches broadcast semantics ("each sender targets every
// receiver"). The two helpers diverge today; pinning this so a fix that
// realigns the contract surfaces here. Filed as a Bug.
"BroadcastPartition" should "produce a BroadcastPartitioning with the full sender x receiver cross product" in {
val out = LinkConfig.toPartitioning(
List(w1, w2, w3),
List(u1, u2, u3),
Expand All @@ -160,18 +154,35 @@ class LinkConfigSpec extends AnyFlatSpec with Matchers {
out shouldBe a[BroadcastPartitioning]
val bp = out.asInstanceOf[BroadcastPartitioning]
bp.batchSize shouldBe batch
endpoints(bp.channels) shouldBe Seq(("w1", "u1"), ("w2", "u2"), ("w3", "u3"))
endpoints(bp.channels) shouldBe Seq(
("w1", "u1"),
("w1", "u2"),
("w1", "u3"),
("w2", "u1"),
("w2", "u2"),
("w2", "u3"),
("w3", "u1"),
("w3", "u2"),
("w3", "u3")
)
}

it should "silently truncate broadcast pairings when sides differ in length (current behavior)" in {
it should "emit the full cross product even when sender and receiver counts differ" in {
val out = LinkConfig.toPartitioning(
List(w1, w2, w3),
List(u1, u2),
BroadcastPartition(),
batch
)
val bp = out.asInstanceOf[BroadcastPartitioning]
endpoints(bp.channels) shouldBe Seq(("w1", "u1"), ("w2", "u2"))
endpoints(bp.channels) shouldBe Seq(
("w1", "u1"),
("w1", "u2"),
("w2", "u1"),
("w2", "u2"),
("w3", "u1"),
("w3", "u2")
)
}

// ----- UnknownPartition -----
Expand Down
Loading