|
5 | 5 |
|
6 | 6 | import pytest |
7 | 7 | import torch |
| 8 | +import torch.fx.traceback as fx_traceback |
8 | 9 | from torch import nn |
9 | | -from torch.distributed.tensor.placement_types import Shard |
| 10 | +from torch.distributed.tensor.placement_types import Replicate, Shard |
10 | 11 | from torch.testing._internal.distributed.fake_pg import FakeStore |
11 | 12 |
|
12 | 13 | from autoparallel.api import AutoParallel |
@@ -114,3 +115,193 @@ def input_fn(): |
114 | 115 | assert torch.equal( |
115 | 116 | parallel_mod.get_buffer("buf").full_tensor(), torch.arange(dim, device="cuda") |
116 | 117 | ) |
| 118 | + |
| 119 | + |
| 120 | +def test_fx_graph_annotate(device_mesh_1d): |
| 121 | + dim = 128 |
| 122 | + |
| 123 | + class Model(nn.Module): |
| 124 | + def __init__(self, dim): |
| 125 | + super().__init__() |
| 126 | + self.a = nn.Linear(dim, dim, bias=False) |
| 127 | + self.b = nn.Linear(dim, dim, bias=False) |
| 128 | + self.c = nn.Linear(dim, dim, bias=False) |
| 129 | + self.d = nn.Linear(dim, dim, bias=False) |
| 130 | + |
| 131 | + def forward(self, x): |
| 132 | + with fx_traceback.annotate({"outer": 0}): |
| 133 | + with fx_traceback.annotate({"inner": 0}): |
| 134 | + a = self.a(x) |
| 135 | + with fx_traceback.annotate({"inner": 1}): |
| 136 | + b = self.b(a) |
| 137 | + with fx_traceback.annotate({"inner": 2}): |
| 138 | + c = self.c(b) |
| 139 | + with fx_traceback.annotate({"inner": 3}): |
| 140 | + d = self.d(c) |
| 141 | + return d |
| 142 | + |
| 143 | + def input_fn(): |
| 144 | + b = 512 |
| 145 | + inputs = (torch.rand(b, dim, device="cuda"),) |
| 146 | + return inputs |
| 147 | + |
| 148 | + with torch.device("meta"): |
| 149 | + model = Model(dim) |
| 150 | + |
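| | +    # preserve_node_meta() keeps the fx_traceback.annotate() regions attached to traced nodes (as node.meta["custom"]) through AutoParallel's tracing |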
| 151 | + with fx_traceback.preserve_node_meta(), AutoParallel( |
| 152 | + model, |
| 153 | + input_fn, |
| 154 | + device_mesh_1d, |
| 155 | + ) as autop: |
| 156 | + x_sharding = (Shard(0),) |
| 157 | + autop.add_input_constraints([x_sharding]) |
| 158 | + sharding_placement = autop.optimize_placement() |
| 159 | + |
| 160 | + # AutoParallel produces a module with meta-DTensor parameters that need to be initialized |
| 161 | + _ = autop.apply_placement(sharding_placement) |
| 162 | + |
| 163 | + graph = autop.parallel_gm.graph |
| 164 | + |
| 165 | +    # 4 nn.Linear layers (bias=False) -> 4 mm ops in the forward graph, plus their backward mms |
| 166 | + fw_seen_annotations = set() |
| 167 | + bw_seen_annotations = set() |
| 168 | + for mm in [n for n in graph.nodes if "mm" in n.name]: |
| 169 | + assert mm.meta["custom"]["outer"] == 0 |
| 170 | + assert "inner" in mm.meta["custom"] |
| 171 | + if mm.meta.get("partitioner_tag", "") == "is_backward": |
| 172 | + bw_seen_annotations.add(mm.meta["custom"]["inner"]) |
| 173 | + else: |
| 174 | + fw_seen_annotations.add(mm.meta["custom"]["inner"]) |
| 175 | + assert fw_seen_annotations == bw_seen_annotations == {0, 1, 2, 3} |
| 176 | + |
| 177 | + for ph in graph.find_nodes(op="placeholder"): |
| 178 | + assert ( |
| 179 | + "custom" not in ph.meta |
| 180 | +        ), "Placeholders didn't have custom metadata before" |
| 181 | + for out in graph.find_nodes(op="output"): |
| 182 | + assert ( |
| 183 | + "custom" not in out.meta |
| 184 | +        ), "Output didn't have custom metadata before" |
| 185 | + |
| 186 | + # NOTE: The tests below are just to prevent semantics from changing silently. |
| 187 | + # Currently, custom metadata is not set for: |
| 188 | + # - graph inputs |
| 189 | + # - graph outputs |
| 190 | + # - collectives/waits added by AP |
| 191 | + for node in graph.nodes: |
| 192 | + if node.meta.get("custom", None) is None: |
| 193 | + assert ( |
| 194 | + node.op == "placeholder" |
| 195 | + or node.op == "output" |
| 196 | + or node.target.namespace == "_c10d_functional" |
| 197 | + ) |
| 198 | + |
| 199 | + |
| 200 | +def test_fx_graph_annotate_overlap_pass(device_mesh_1d): |
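| | +    # DummyOp is a trivial autograd op; the mock_* helpers below tag its nodes as "compute" or "comm" so the reordering pass has annotated nodes to move |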
| 201 | + class DummyOp(torch.autograd.Function): |
| 202 | + @staticmethod |
| 203 | + def forward(ctx, x, scalar): |
| 204 | + ctx.save_for_backward(x) |
| 205 | + return x + scalar |
| 206 | + |
| 207 | + @staticmethod |
| 208 | + def backward(ctx, grad_out): |
| 209 | + return grad_out, None |
| 210 | + |
| 211 | + def mock_fw_compute(x): |
| 212 | + with fx_traceback.annotate({"compute": 0}): |
| 213 | + return DummyOp.apply(x, 10) |
| 214 | + |
| 215 | + def mock_bw_comm(x): |
| 216 | + with fx_traceback.annotate({"comm": 0}): |
| 217 | + return DummyOp.apply(x, 20) |
| 218 | + |
| 219 | + def mock_bw_compute(x): |
| 220 | + return DummyOp.apply(x, 30) |
| 221 | + |
| 222 | + class Model(nn.Module): |
| 223 | + def forward(self, fw_in, bw_in): |
| 224 | + fw_out = mock_fw_compute(fw_in) |
| 225 | + # bw_in blocks bw_out |
| 226 | + bw_in = mock_bw_comm(bw_in) |
| 227 | + bw_out = mock_bw_compute(bw_in) |
| 228 | + return fw_out, bw_out |
| 229 | + |
| 230 | + def input_fn(): |
| 231 | + inputs = (torch.rand(2, 128, device="cuda", requires_grad=True),) |
| 232 | + grad_ins = (torch.rand(2, 128, device="cuda"),) |
| 233 | + return ( |
| 234 | + *inputs, |
| 235 | + *grad_ins, |
| 236 | + ) |
| 237 | + |
| 238 | + with torch.device("meta"): |
| 239 | + model = Model() |
| 240 | + |
| 241 | + with fx_traceback.preserve_node_meta(), AutoParallel( |
| 242 | + model, |
| 243 | + input_fn, |
| 244 | + device_mesh_1d, |
| 245 | + ) as autop: |
| 246 | + autop.add_input_constraints( |
| 247 | + [ |
| 248 | + (Replicate(),), |
| 249 | + (Replicate(),), |
| 250 | + ] |
| 251 | + ) |
| 252 | + autop.add_output_constraints( |
| 253 | + [ |
| 254 | + (Replicate(),), |
| 255 | + (Replicate(),), |
| 256 | + ] |
| 257 | + ) |
| 258 | + sharding_placement = autop.optimize_placement() |
| 259 | + |
| 260 | + # AutoParallel produces a module with meta-DTensor parameters that need to be initialized |
| 261 | + _ = autop.apply_placement(sharding_placement) |
| 262 | + |
| 263 | + graph = autop.parallel_gm.graph |
| 264 | + |
| 265 | + # At this point, the graph looks like: |
| 266 | + # graph(): |
| 267 | + # %primals_1 : [num_users=1] = placeholder[target=primals_1] |
| 268 | + # %primals_2 : [num_users=1] = placeholder[target=primals_2] |
| 269 | + # %tangents_1 : [num_users=1] = placeholder[target=tangents_1] |
| 270 | + # %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%primals_1, 10), kwargs = {}) |
| 271 | + # %add_1 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%primals_2, 20), kwargs = {}) |
| 272 | + # %add_2 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_1, 30), kwargs = {}) |
| 273 | + # return ((add, add_2), (tangents_1, None)) |
| 274 | + |
| 275 | + compute_nodes = { |
| 276 | + n for n in graph.nodes if n.meta.get("custom", {}).get("compute", None) == 0 |
| 277 | + } |
| 278 | + comm_nodes = [ |
| 279 | + n for n in graph.nodes if n.meta.get("custom", {}).get("comm", None) == 0 |
| 280 | + ] |
| 281 | + assert len(compute_nodes) == 1 |
| 282 | + assert len(comm_nodes) == 1 |
| 283 | + |
| 284 | + # move comm nodes before compute nodes |
| 285 | + first_compute_node = None |
| 286 | + for n in graph.nodes: |
| 287 | + if n in compute_nodes: |
| 288 | + first_compute_node = n |
| 289 | + break |
| 290 | + |
| 291 | + assert first_compute_node is not None |
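| | +    # Node.prepend(x) moves x to immediately before this node in graph order |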
| 292 | + for node in reversed(comm_nodes): |
| 293 | + first_compute_node.prepend(node) |
| 294 | + |
| 295 | +    # After the pass, add_1 (comm) should sit immediately before add (compute) |
| 296 | + node_names = [n.name for n in graph.nodes] |
| 297 | + assert node_names.index("add_1") == node_names.index("add") - 1 |
| 298 | + |
| 299 | + # The graph looks like: |
| 300 | + # graph(): |
| 301 | + # %primals_1 : [num_users=1] = placeholder[target=primals_1] |
| 302 | + # %primals_2 : [num_users=1] = placeholder[target=primals_2] |
| 303 | + # %tangents_1 : [num_users=1] = placeholder[target=tangents_1] |
| 304 | + # %add_1 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%primals_2, 20), kwargs = {}) |
| 305 | + # %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%primals_1, 10), kwargs = {}) |
| 306 | + # %add_2 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_1, 30), kwargs = {}) |
| 307 | + # return ((add, add_2), (tangents_1, None)) |