Skip to content

Commit 9a3a532

Browse files
sanketpurandareSanket Jayant Purandare
andauthored
linting fix (#197)
Co-authored-by: Sanket Jayant Purandare <sanketpurandare@fb.com>
1 parent 4ccb360 commit 9a3a532

File tree

7 files changed

+96
-337
lines changed

7 files changed

+96
-337
lines changed

examples/native_ds3/example_deepseek.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,7 @@
66
from dataclasses import dataclass
77
from typing import Literal
88

9-
import moe_ops
10-
9+
import moe_ops # noqa: F401
1110
import torch
1211
import torch.nn.functional as F
1312
from torch import nn
@@ -346,11 +345,12 @@ def init_weights(
346345

347346
if __name__ == "__main__":
348347

349-
from autoparallel.api import AutoParallel
350348
from torch.distributed.fsdp import MixedPrecisionPolicy
351349
from torch.distributed.tensor.placement_types import Replicate, Shard
352350
from torch.testing._internal.distributed.fake_pg import FakeStore
353351

352+
from autoparallel.api import AutoParallel
353+
354354
# Model configuration
355355
world_size = 256
356356
fake_store = FakeStore()
@@ -409,10 +409,7 @@ def input_fn():
409409
assert any(n.meta.get("fwd_nn_module_stack") for n in autop.gm.graph.nodes)
410410
autop.add_parameter_memory_constraint(low=None, high=None)
411411

412-
x_sharding = (
413-
Shard(0),
414-
Shard(1),
415-
) + (
412+
x_sharding = (Shard(0), Shard(1),) + (
416413
Replicate(),
417414
) * (mesh.ndim - 2)
418415

0 commit comments

Comments
 (0)