From 52ea0c1d02f96e33f102eef0709f003f1ff403ee Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Fri, 4 Jul 2025 11:34:26 +0000 Subject: [PATCH 01/23] [WIP] Add basic DeepSeekV3 Doesn't yet work, but the code for now is a copy-paste from https://github.com/pytorch/torchtitan/blob/deepseek-v3/torchtitan/models/deepseek_v3/model/moe.py so it will make it easier to track the changes --- examples/example_ds3.py | 844 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 844 insertions(+) create mode 100644 examples/example_ds3.py diff --git a/examples/example_ds3.py b/examples/example_ds3.py new file mode 100644 index 00000000..52914a81 --- /dev/null +++ b/examples/example_ds3.py @@ -0,0 +1,844 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn.functional as F +from torch import nn + +from torch.distributed.tensor.placement_types import Replicate, Shard +from torch.testing._internal.distributed.fake_pg import FakeStore + +from torch.distributed.tensor import DTensor + + +from autoparallel.api import AutoParallel + +from dataclasses import dataclass +from typing import Literal + + + +# Reference: https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/model.py +@dataclass +class DeepSeekV3ModelArgs: + """ + Data class for defining model arguments and hyperparameters. + + Attributes: + max_batch_size (int): Maximum batch size. + max_seq_len (int): Maximum sequence length. + dtype (Literal["bf16", "fp8"]): Data type for computations. + vocab_size (int): Vocabulary size. + dim (int): Model dimension. + inter_dim (int): Intermediate dimension for MLP layers. + moe_inter_dim (int): Intermediate dimension for MoE layers. + n_layers (int): Number of transformer layers. + n_dense_layers (int): Number of dense layers in the model. + n_heads (int): Number of attention heads. + n_routed_experts (int): Number of routed experts for MoE layers. + n_shared_experts (int): Number of shared experts for MoE layers. + n_activated_experts (int): Number of activated experts in MoE layers. + n_expert_groups (int): Number of expert groups. + n_limited_groups (int): Number of limited groups for MoE routing. + score_func (Literal["softmax", "sigmoid"]): Scoring function for MoE routing. + route_scale (float): Scaling factor for routing scores. + use_grouped_mm (bool): Whether to use grouped matrix multiplication for MoE layers. + load_balance_coeff (float | None): Auxiliary-Loss-Free Load balancing coefficient for MoE layers. + q_lora_rank (int): LoRA rank for query projections. + kv_lora_rank (int): LoRA rank for key-value projections. + qk_nope_head_dim (int): Dimension for query-key projections without positional embeddings. + qk_rope_head_dim (int): Dimension for query-key projections with rotary embeddings. + v_head_dim (int): Dimension for value projections. + original_seq_len (int): Original sequence length. + rope_theta (float): Base for rotary positional encoding. + rope_factor (float): Scaling factor for extended sequence lengths. + beta_fast (int): Fast beta correction factor. + beta_slow (int): Slow beta correction factor. 
+ """ + + max_batch_size: int = 8 + max_seq_len: int = 4096 * 4 + dtype: Literal["bf16", "fp8"] = "bf16" + vocab_size: int = 102400 + dim: int = 2048 + inter_dim: int = 10944 + moe_inter_dim: int = 1408 + n_layers: int = 27 + n_dense_layers: int = 1 + n_heads: int = 16 + norm_eps: float = 1e-5 # eps used for RMSNorm + # MoE + n_routed_experts: int = 64 + n_shared_experts: int = 2 + n_activated_experts: int = 6 + n_expert_groups: int = 1 + n_limited_groups: int = 1 + score_func: Literal["softmax", "sigmoid"] = "softmax" + route_scale: float = 1.0 + use_grouped_mm: bool = True + load_balance_coeff: float = 1e-3 + # Multi-Head Latent Attention (MLA) + q_lora_rank: int = 0 + kv_lora_rank: int = 512 + qk_nope_head_dim: int = 128 + qk_rope_head_dim: int = 64 + v_head_dim: int = 128 + use_flex_attn: bool = False + attn_mask_type: str = "causal" + # yarn + original_seq_len: int = 4096 + rope_theta: float = 10000.0 + rope_factor: float = 40 + beta_fast: int = 32 + beta_slow: int = 1 + mscale: float = 1.0 + + def update_from_config(self, job_config, tokenizer) -> None: + """ + Update the model_config config from the given job config. + """ + self.vocab_size = tokenizer.vocab_size + self.max_seq_len = job_config.training.seq_len + + def get_nparams_and_flops(self, model: nn.Module, seq_len: int) -> tuple[int, int]: + """ + Adopted from llama4 implementation. + """ + nparams_embedding = 0 + nparams_moe_router = 0 + nparams_shared_expert = 0 + nparams_experts = 0 + nparams_dense = 0 + + for name, p in model.named_parameters(): + if "embedding" in name: + nparams_embedding += p.numel() + nparams_dense += p.numel() + elif "moe.shared_expert" in name: + nparams_shared_expert += p.numel() + elif "moe.router" in name: + nparams_moe_router += p.numel() + elif "moe.experts" in name: + nparams_experts += p.numel() + else: + nparams_dense += p.numel() + + nparams_sparse = nparams_moe_router + nparams_shared_expert + nparams_experts + nparams = nparams_dense + nparams_sparse + nparams_sparse_active = ( + nparams_moe_router + + nparams_shared_expert + + nparams_experts * self.n_activated_experts // self.n_routed_experts + ) + + #logger.info( + # f"Total parameter count: dense {nparams_dense:,}, " + # f"sparse {nparams_sparse:,}, active {nparams_dense + nparams_sparse_active:,}" + #) + + l, h, q, t = ( + self.n_layers, + self.n_heads, + self.dim // self.n_heads, + seq_len, + ) + # Reasoning behind the factor of 12 for the self-attention part of the formula: + # 1. each self-attention has 2 matmul in the forward and 4 in the backward (6) + # 2. the flash attention does 1 more matmul recomputation in the backward + # but recomputation should not be counted in calculating MFU (+0) + # 3. each matmul performs 1 multiplication and 1 addition (*2) + # 4. 
we follow the convention and do not account for sparsity in causal attention + num_flops_per_token = ( + 6 * (nparams_dense - nparams_embedding + nparams_sparse_active) + + 12 * l * h * q * t + ) + + return nparams, num_flops_per_token + + + + + + +import torch +import triton +import triton.language as tl + + +__all__ = ["generate_permute_indices"] + + +# parallelized kernel +@triton.jit +def _fill_indices_kernel( + tokens_per_expert_group_ptr, + start_index_values_ptr, + write_offsets_ptr, + output_ptr, + experts_per_rank: tl.constexpr, + num_ranks: tl.constexpr, + BLOCK_SIZE: tl.constexpr, # Number of threads per block +): + pid = tl.program_id(axis=0) + num_programs = tl.num_programs(axis=0) + + # map programs (blocks) to the experts and loop (grid stride) if needed + for expert_id in range(pid, experts_per_rank, num_programs): + # read this experts write offset + write_offset = tl.load(write_offsets_ptr + expert_id) + + for r in range(num_ranks): + # index into tokens_per_expert_group array + i = r * experts_per_rank + expert_id + + # load start index and number of tokens for this expert-rank pair + start_index = tl.load(start_index_values_ptr + i) + length = tl.load(tokens_per_expert_group_ptr + i) + + # each thread in block processes tokens in parallel + offsets = tl.arange(0, BLOCK_SIZE) + + # tokens are processed in chunks of BLOCK_SIZE + for chunk_start in range(0, length, BLOCK_SIZE): + chunk_offsets = chunk_start + offsets + + # mask valid indices + mask = chunk_offsets < length + + values = start_index + chunk_offsets + + # destination + dest_indices = write_offset + chunk_offsets + + # store + tl.store(output_ptr + dest_indices, values, mask=mask) + + # update write offset for next rank + write_offset += length + + +# ============== +# wrapper +# ============== + + +def fill_indices_wrapper( + tokens_per_expert_group: torch.Tensor, + start_index_values: torch.Tensor, + write_offsets: torch.Tensor, + experts_per_rank: int, + num_ranks: int, + max_len: int, + block_size: int = 128, + max_blocks: int = 1024, # cap on total number of blocks to launch +): + # preallocate output + permuted_indices = torch.full( + (max_len,), -1, dtype=torch.int32, device=tokens_per_expert_group.device + ) + + # write offsets is per local expert... + num_blocks = min(experts_per_rank, max_blocks) + # grid = one block per expert unless capped and then we loop... 
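+    # (note: if experts_per_rank exceeds max_blocks, the kernel's grid-stride loop
+    # `for expert_id in range(pid, experts_per_rank, num_programs)` lets each
+    # launched block cover multiple experts)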
+ grid = (num_blocks,) + + # launch kernel + _fill_indices_kernel[grid]( + tokens_per_expert_group, + start_index_values, + write_offsets, + permuted_indices, + experts_per_rank, + num_ranks, + BLOCK_SIZE=block_size, + ) + return permuted_indices + + +# reference +def fill_indices_cpu( + tokens_per_expert_group: torch.Tensor, + start_index_values: torch.Tensor, + write_offsets: torch.Tensor, + experts_per_rank: int, + num_ranks: int, + max_len: int, +): + # We need to preallocate the output - we ignore device and force it on cpu + # device = tokens_per_expert_group.device + permuted_indices = torch.full( + (max_len,), + -1, + dtype=torch.int32, + ) # device=device) + # Fill the permuted indices + # For each local expert + for e in range(experts_per_rank): + write_start = write_offsets[e].item() + # For each remote rank + for r in range(num_ranks): + i = r * experts_per_rank + e + start_index = start_index_values[i].item() + length = tokens_per_expert_group[i].item() + # Fill in the indices + if length > 0: + end_idx = min(write_start + length, max_len) + permuted_indices[write_start:end_idx] = torch.arange( + start_index, + start_index + (end_idx - write_start), + dtype=torch.int32, + # device=device, + ) + write_start += length + return permuted_indices + + +def generate_permute_indices( + tokens_per_expert_group: torch.Tensor, + experts_per_rank: int, + num_ranks: int, + max_len: int, + alignment: int, + use_cpu: bool = False, +): + """ + Prepare permutation indices and the number of tokens for each expert. + + Args: + tokens_per_expert_group: number of tokens for each expert from all ranks. + experts_per_rank: number of experts per rank. + num_ranks: number of ranks. + max_len: maximum length of the output index vector. + alignment: alignment for each returned element in `m_sizes` and padding min for zero token experts. + use_cpu: whether to use CPU implementation. + + + Returns: + permuted_indices: Tensor of indices that map original token order to the expert-grouped order. + m_sizes: aligned number of tokens for each expert (padded to alignment boundary). + m_offsets: Cumulative sum of m_sizes. The exclusive ending position for each expert's tokens. + + Explanatory details: + `tokens_per_expert_group` is of shape (num_ranks * experts_per_rank,), for example: + From: | rank 0 | rank 1 | + To: | E0 | E1 | E2 | E3 | E0 | E1 | E2 | E3 | + | 4 | 2 | 1 | 3 | 1 | 2 | 3 | 4 | + """ + + # prefix sum to get start index of each expert (parallel scan kernel in future?) 
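+    # Illustrative walk-through of the steps below, using the docstring example
+    # (2 ranks x 4 experts, tokens_per_expert_group = [4, 2, 1, 3, 1, 2, 3, 4])
+    # and assuming alignment = 4 (values are purely illustrative):
+    #   start_index_values      = [0, 4, 6, 7, 10, 11, 13, 16]
+    #   total_tokens_per_expert = [5, 4, 4, 7]
+    #   m_sizes (aligned)       = [8, 4, 4, 8]
+    #   m_offsets               = [8, 12, 16, 24] -> write_offsets = [0, 8, 12, 16]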
+ start_index_values = ( + torch.cumsum(tokens_per_expert_group, 0) - tokens_per_expert_group + ) + + # total tokens for each expert (sum over ranks) + total_tokens_per_expert = tokens_per_expert_group.view(num_ranks, -1).sum(0) + + # pad out empty experts to alignment requirement + total_tokens_per_expert = torch.clamp_min(total_tokens_per_expert, alignment) + + # align the chunk sizes (cdiv) + m_sizes = ((total_tokens_per_expert + alignment - 1) // alignment * alignment).to( + torch.int32 + ) + + # additional prefix sum to get write offset of each expert in permuted_indices + # write offsets is per local expert, not global + m_offsets = torch.cumsum(m_sizes, 0) + write_offsets = m_offsets - m_sizes + + # Select the implementation to use + if use_cpu: + permuted_indices = fill_indices_cpu( + tokens_per_expert_group, + start_index_values, + write_offsets, + experts_per_rank, + num_ranks, + max_len, + ) + else: + permuted_indices = fill_indices_wrapper( + tokens_per_expert_group, + start_index_values, + write_offsets, + experts_per_rank, + num_ranks, + max_len, + ) + + return permuted_indices, m_sizes, m_offsets.to(torch.int32) + + +def expert_parallel(func): + """ + This is a wrapper applied to the GroupedExperts computation, serving + the following three purposes: + 1. Convert parameters from DTensors to plain Tensors, to work with + dynamic-shape inputs which cannot be easily expressed as DTensors. + 2. In Expert Parallel, apply the generate_permute_indices kernel to + permute the inputs to be ordered by local experts (see the _token_dispatch + function in ExpertParallel) and permute the outputs back. + 3. In order to use torch._grouped_mm, we need to make sure the number of + tokens each expert gets is a multiple of ALIGN_SIZE_M. The generate_permute_indices + kernel also helps achieve this via padding, without incurring synchronization + between device and host. Note that this will create side effects when wrapping + the for-loop implementation of GroupedExperts, as it does not need padding. + + Among the above: + 1 and 2 are needed only when expert_parallel_degree > 1. + 3 is needed even for single-device computation. + 2 can be moved to ExpertParallel _token_dispatch if not coupled with 3. + """ + + def wrapper( + w1: torch.Tensor, + w2: torch.Tensor, + w3: torch.Tensor, + x: torch.Tensor, + num_tokens_per_expert: torch.Tensor | None = None, + ) -> torch.Tensor: + if isinstance(w1, DTensor): + w1 = w1.to_local() + w2 = w2.to_local() + w3 = w3.to_local() + + if num_tokens_per_expert is not None: + experts_per_ep_rank = w1.shape[0] + num_ep_ranks = num_tokens_per_expert.shape[0] // experts_per_ep_rank + + ALIGN_SIZE_M = 16 + with torch.no_grad(): + ( + permuted_indices, + num_tokens_per_expert, + _, # offsets, + ) = generate_permute_indices( + num_tokens_per_expert, + experts_per_ep_rank, + num_ep_ranks, + x.shape[0] + experts_per_ep_rank * ALIGN_SIZE_M, + ALIGN_SIZE_M, + ) + + x = torch.vstack((x, x.new_zeros((x.shape[-1])))) + input_shape = x.shape + x = x[permuted_indices, :] + + out = func(w1, w2, w3, x, num_tokens_per_expert) + + if num_tokens_per_expert is not None: + out_unpermuted = out.new_empty(input_shape) + out_unpermuted[permuted_indices, :] = out + out = out_unpermuted[:-1] + + return out + + return wrapper + + +class FeedForward(nn.Module): + """ + FeedForward module + + Args: + dim (int): Input dimension. + hidden_dim (int): Hidden dimension of the feedforward layer. + multiple_of (int): Value to ensure hidden dimension is a multiple of this value. 
+ ffn_dim_multiplier (float | None): Custom multiplier for hidden dimension. Defaults to None. + + Attributes: + w1 (Linear): Linear transformation for the first layer. + w2 (Linear): Linear transformation for the second layer. + w3 (Linear): Linear transformation for the third layer. + + """ + + def __init__( + self, + dim: int, + hidden_dim: int, + ): + super().__init__() + self.w1 = nn.Linear(dim, hidden_dim, bias=False) + self.w2 = nn.Linear(hidden_dim, dim, bias=False) + self.w3 = nn.Linear(dim, hidden_dim, bias=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.w2(F.silu(self.w1(x)) * self.w3(x)) + + def init_weights(self, init_std: float = 0.02): + nn.init.trunc_normal_(self.w1.weight, mean=0.0, std=0.02) + for linear in (self.w2, self.w3): + nn.init.trunc_normal_(linear.weight, mean=0.0, std=init_std) + + +class GroupedExperts(nn.Module): + def __init__( + self, + dim: int, + hidden_dim: int, + num_experts: int, + use_grouped_mm: bool, + ): + super().__init__() + self.num_experts = num_experts + self.w1 = nn.Parameter(torch.empty(num_experts, dim, hidden_dim)) + self.w2 = nn.Parameter(torch.empty(num_experts, hidden_dim, dim)) + self.w3 = nn.Parameter(torch.empty(num_experts, dim, hidden_dim)) + self.use_grouped_mm = use_grouped_mm + + def forward( + self, + x: torch.Tensor, + num_tokens_per_expert: torch.Tensor | None = None, + ) -> torch.Tensor: + if self.use_grouped_mm: + return GroupedExperts._run_experts_grouped_mm( + self.w1, self.w2, self.w3, x, num_tokens_per_expert + ) + else: + return GroupedExperts._run_experts_for_loop( + self.w1, self.w2, self.w3, x, num_tokens_per_expert + ) + + # TODO: keeping this for-loop implementation for comparison + # and readability, may remove later + @expert_parallel + @staticmethod + def _run_experts_for_loop( + w1: torch.Tensor, + w2: torch.Tensor, + w3: torch.Tensor, + x: torch.Tensor, + num_tokens_per_expert: torch.Tensor | None = None, + ) -> torch.Tensor: + if num_tokens_per_expert is not None: + # NOTE: this would incur a synchronization between device and host + num_tokens_per_expert = num_tokens_per_expert.tolist() + + # side-effect code due to the usage of generate_permute_indices + num_padding = x.shape[0] - sum(num_tokens_per_expert) + + # a tuple of tensors indexed by experts + # each with shape (tokens_per_expert(varying), dim) + x = torch.split( + x[: sum(num_tokens_per_expert)], + split_size_or_sections=num_tokens_per_expert, + dim=0, + ) + out_experts_splits = [] + for expert_idx, x_expert in enumerate(x): + h = F.silu(torch.matmul(x_expert, w1[expert_idx])) + h = h * torch.matmul(x_expert, w3[expert_idx]) + h = torch.matmul(h, w2[expert_idx]) + # h shape (tokens_per_expert(varying), dim) + out_experts_splits.append(h) + out = torch.cat(out_experts_splits, dim=0) + + # side-effect code due to the usage of generate_permute_indices + out = torch.vstack((out, out.new_zeros((num_padding, out.shape[-1])))) + else: + # x shape (num_experts, tokens_per_expert, dim) + h = F.silu(torch.bmm(x, w1)) + h = h * torch.bmm(x, w3) + # out shape (num_experts, tokens_per_expert, dim) + out = torch.bmm(h, w2) + + return out + + @expert_parallel + @staticmethod + def _run_experts_grouped_mm( + w1: torch.Tensor, + w2: torch.Tensor, + w3: torch.Tensor, + x: torch.Tensor, + num_tokens_per_expert: torch.Tensor | None = None, + ) -> torch.Tensor: + if num_tokens_per_expert is not None: + offsets = torch.cumsum(num_tokens_per_expert, dim=0, dtype=torch.int32) + # grouped mm between a 2D tensor and a 3D tensor + assert 
x.dim() == 2 + else: + offsets = None + # fall back to regular bmm between 3D tensors + assert x.dim() == 3 + + assert ( + x.dtype == w1.dtype == w2.dtype == w3.dtype == torch.bfloat16 + ), "torch._grouped_mm only supports bf16 dtypes" + + h = F.silu(torch._grouped_mm(x, w1, offs=offsets)) + h = h * torch._grouped_mm(x, w3, offs=offsets) + out = torch._grouped_mm(h, w2, offs=offsets) + + return out + + def init_weights(self, init_std: float): + nn.init.trunc_normal_(self.w1, mean=0.0, std=0.02) + nn.init.trunc_normal_(self.w2, mean=0.0, std=init_std) + nn.init.trunc_normal_(self.w3, mean=0.0, std=init_std) + + +class TokenChoiceTopKRouter(nn.Module): + """This class implements token-choice routing. In token-choice top-K routing, each token is + routed to top K experts based on the router scores. + + Args: + gate (nn.Module): Gate module to calculate the scores, typically nn.Linear(dim, num_experts). + num_experts (int): Number of experts in each moe layer. + top_k (int): Number of experts each token will be routed to in token-choice routing. + use_sigmoid (bool): Whether to use sigmoid or softmax for router scores. Default is False. + """ + + def __init__( + self, + dim: int, + num_experts: int, + top_k: int, + use_sigmoid: bool = False, + route_sclaing_factor: float = 1.0, + ): + super().__init__() + + self.dim = dim + self.num_experts = num_experts + self.top_k = top_k + self.use_sigmoid = use_sigmoid + self.route_sclaing_factor = route_sclaing_factor + self.gate = nn.Linear(self.dim, self.num_experts, bias=False) + + def forward( + self, x: torch.Tensor, expert_bias: torch.Tensor | None = None + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + TODO: We haven't implement the group-based routing (node limit routing), + and currently EP is not supporting node limit routing yet. + + Args: + x (torch.Tensor): Input tensor with shape ``(bs*slen, dim)``. + + Returns: + routed_input (torch.Tensor): + Tokens grouped together by experts indices with shape ``(bs*slen*top_k,)``. + token_indices (torch.Tensor): + Token indices for routed_input with shape ``(bs*slen*top_k,)``. + num_tokens_per_expert (torch.Tensor): + Number of tokens assigned to each expert with shape ``(num_experts,)``. + """ + # scores shape (bs*slen, num_experts) + scores = self.gate(x) + + # By default, sigmoid or softmax is performed in float32 to avoid loss explosion + if self.use_sigmoid: + scores = torch.sigmoid(scores.to(torch.float32)) + else: + scores = F.softmax(scores.to(torch.float32), dim=1) + + # top scores shape (bs*slen, top_k) + # NOTE: The expert_bias is only used for routing. The gating value + # top_scores is still derived from the original scores. 
+ if expert_bias is not None: + _, selected_experts_indices = torch.topk( + scores + expert_bias, k=self.top_k, dim=1 + ) + top_scores = scores.gather(dim=1, index=selected_experts_indices) + else: + top_scores, selected_experts_indices = torch.topk( + scores, k=self.top_k, dim=1 + ) + + # group tokens together by expert indices from 0 to num_experts and pass that to experts forward + num_tokens_per_expert = torch.histc( + selected_experts_indices.view(-1), + bins=self.num_experts, + min=0, + max=self.num_experts, + ) + + # Reorder the token indices to match the order of the experts + # token_indices_experts_sorted shape (bs*slen*top_k,) + token_indices_experts_sorted = torch.argsort( + selected_experts_indices.view(-1), stable=True + ) + + # reorder the scores to match the order of the token indices + top_scores = top_scores.view(-1)[token_indices_experts_sorted] + token_indices_experts_sorted = token_indices_experts_sorted // self.top_k + + top_scores = ( + top_scores * self.route_sclaing_factor + ) # must multiply the scaling factor + return top_scores, token_indices_experts_sorted, num_tokens_per_expert + + def init_weights(self, init_std: float): + nn.init.trunc_normal_(self.gate.weight, mean=0.0, std=init_std) + + +class MoE(nn.Module): + def __init__(self, model_args: DeepSeekV3ModelArgs): + + super().__init__() + dim = model_args.dim + + num_experts = model_args.n_routed_experts + hidden_dim = model_args.moe_inter_dim + top_k = model_args.n_activated_experts + route_scaling_factor = model_args.route_scale + + self.experts = GroupedExperts( + dim=dim, + hidden_dim=hidden_dim, + num_experts=num_experts, + use_grouped_mm=model_args.use_grouped_mm, + ) + self.router = TokenChoiceTopKRouter( + dim=dim, + num_experts=num_experts, + top_k=top_k, + use_sigmoid=model_args.score_func == "sigmoid", + route_sclaing_factor=route_scaling_factor, + ) + self.shared_expert = ( + # Reference: https://huggingface.co/deepseek-ai/DeepSeek-V3-Base/blob/main/modeling_deepseek.py#L517 + GroupedExperts( + dim=dim, + hidden_dim=hidden_dim * model_args.n_shared_experts, + num_experts=1, # Here needs to be 1 to make it equivalent to the MLP + use_grouped_mm=model_args.use_grouped_mm, + ) + if model_args.n_shared_experts > 0 + else None + ) + + # auxiliary-loss-free load balancing + self.load_balance_coeff = model_args.load_balance_coeff + if self.load_balance_coeff is not None: + assert self.load_balance_coeff > 0.0 + self.register_buffer( + "expert_bias", + torch.zeros(num_experts, dtype=torch.float32), + persistent=True, + ) + self.register_buffer( + "tokens_per_expert", + torch.zeros(num_experts, dtype=torch.float32), + persistent=True, + ) + else: + self.expert_bias = None + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + x (torch.Tensor): Input tensor with shape ``(bs, slen, dim)``. + + Returns: + out (torch.Tensor): Output tensor with shape ``(bs, slen, dim)``. + """ + bs, slen, dim = x.shape + + # top_scores and selected_indices shape (bs*slen*top_k,) + # num_tokens_per_expert shape (num_experts,) + ( + top_scores, + token_indices, + num_tokens_per_expert, + ) = self.router(x.reshape(bs * slen, dim), self.expert_bias) + + # tokens_per_expert will be used to update the expert bias for load balancing. + # Prevent extra local tokens accumulation on evaluation or activation recomputation. 
+ if self.load_balance_coeff is not None and torch.is_grad_enabled(): + with torch.no_grad(): + self.tokens_per_expert.add_(num_tokens_per_expert) + # shape (bs*slen*top_k, dim) + token_indices = token_indices.reshape(-1, 1).expand(-1, dim) + + # shape (bs*slen*top_k, dim) + routed_input = torch.gather( + x.view(-1, dim), + dim=0, + index=token_indices, + ) + + # shape (bs*slen*top_k, dim) + routed_output = self.experts(routed_input, num_tokens_per_expert) + + routed_output = (routed_output.to(torch.float32) * top_scores.unsqueeze(-1)).to( + x.dtype + ) + + # shared expert + if self.shared_expert is not None: + out = self.shared_expert(x.reshape(1, bs * slen, dim)).reshape( + bs * slen, dim + ) + else: + out = torch.zeros_like(x.reshape(bs * slen, dim)) + + # Accumulate multiple expert results becase each token can be routed to multiple experts + out = out.scatter_add(dim=0, index=token_indices, src=routed_output) + out = out.reshape(bs, slen, dim) + return out + + def init_weights( + self, + init_std: float, + buffer_device: torch.device, + ): + self.experts.init_weights(init_std) + self.router.init_weights(init_std) + if self.shared_expert is not None: + self.shared_expert.init_weights(init_std) + + if self.load_balance_coeff is not None: + with torch.device(buffer_device): + self.expert_bias = torch.zeros( + self.experts.num_experts, dtype=torch.float32 + ) + self.tokens_per_expert = torch.zeros( + self.experts.num_experts, dtype=torch.float32 + ) + + +world_size = 256 + +fake_store = FakeStore() +torch.distributed.init_process_group( + "fake", store=fake_store, rank=0, world_size=world_size +) +# mesh = torch.distributed.device_mesh.init_device_mesh("cuda", (world_size,), mesh_dim_names=("dp",)) +mesh = torch.distributed.device_mesh.init_device_mesh( + "cuda", + (world_size // 8, 8), + mesh_dim_names=( + "dp", + "tp", + ), +) + +bs = 1 +seqlen = 1024 +dim = 4096 + +def input_fn(): + return torch.randn(bs, seqlen, dim, dtype=torch.bfloat16, device="cuda") + +args = DeepSeekV3ModelArgs(dim=dim, n_layers=1) + +# parallelize the model +with torch.device("meta"): + model = MoE(args).bfloat16() + +autop = AutoParallel(model, input_fn, mesh) +autop.add_parameter_memory_constraint(low=None, high=None) + +x_sharding = (Shard(0), Replicate()) + +autop.add_input_constraints([x_sharding]) +autop.add_output_constraints([x_sharding]) + +sharding_placement = autop.optimize_placement() From 0d3ae2d15578eeafdf5d40f5fcfd1e7368753914 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Fri, 4 Jul 2025 11:37:26 +0000 Subject: [PATCH 02/23] Lint --- examples/example_ds3.py | 40 +++++++++++++--------------------------- 1 file changed, 13 insertions(+), 27 deletions(-) diff --git a/examples/example_ds3.py b/examples/example_ds3.py index 52914a81..f3def7e0 100644 --- a/examples/example_ds3.py +++ b/examples/example_ds3.py @@ -4,22 +4,20 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
+from dataclasses import dataclass +from typing import Literal + import torch import torch.nn.functional as F +import triton +import triton.language as tl from torch import nn - +from torch.distributed.tensor import DTensor from torch.distributed.tensor.placement_types import Replicate, Shard from torch.testing._internal.distributed.fake_pg import FakeStore -from torch.distributed.tensor import DTensor - - from autoparallel.api import AutoParallel -from dataclasses import dataclass -from typing import Literal - - # Reference: https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/model.py @dataclass @@ -134,10 +132,10 @@ def get_nparams_and_flops(self, model: nn.Module, seq_len: int) -> tuple[int, in + nparams_experts * self.n_activated_experts // self.n_routed_experts ) - #logger.info( + # logger.info( # f"Total parameter count: dense {nparams_dense:,}, " # f"sparse {nparams_sparse:,}, active {nparams_dense + nparams_sparse_active:,}" - #) + # ) l, h, q, t = ( self.n_layers, @@ -159,18 +157,6 @@ def get_nparams_and_flops(self, model: nn.Module, seq_len: int) -> tuple[int, in return nparams, num_flops_per_token - - - - -import torch -import triton -import triton.language as tl - - -__all__ = ["generate_permute_indices"] - - # parallelized kernel @triton.jit def _fill_indices_kernel( @@ -743,11 +729,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # top_scores and selected_indices shape (bs*slen*top_k,) # num_tokens_per_expert shape (num_experts,) - ( - top_scores, - token_indices, - num_tokens_per_expert, - ) = self.router(x.reshape(bs * slen, dim), self.expert_bias) + (top_scores, token_indices, num_tokens_per_expert,) = self.router( + x.reshape(bs * slen, dim), self.expert_bias + ) # tokens_per_expert will be used to update the expert bias for load balancing. # Prevent extra local tokens accumulation on evaluation or activation recomputation. @@ -824,9 +808,11 @@ def init_weights( seqlen = 1024 dim = 4096 + def input_fn(): return torch.randn(bs, seqlen, dim, dtype=torch.bfloat16, device="cuda") + args = DeepSeekV3ModelArgs(dim=dim, n_layers=1) # parallelize the model From 98d9dfdb76820f5aa8dc2d46ed984fe969551768 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Fri, 4 Jul 2025 13:24:52 +0000 Subject: [PATCH 03/23] Workarounds to make graph capture pass Needs to fix the grad_input renaming which is not working for some reason --- autoparallel/export_module.py | 11 ++++++++++- examples/example_ds3.py | 33 ++++++++++++++++++++++++++++----- 2 files changed, 38 insertions(+), 6 deletions(-) diff --git a/autoparallel/export_module.py b/autoparallel/export_module.py index b3ff9e8c..849856b7 100644 --- a/autoparallel/export_module.py +++ b/autoparallel/export_module.py @@ -204,16 +204,21 @@ def rename_nodes(fx_g, nodes, new_name, idxs=None): # TODO: align number of grad names with inputs everywhere? 
all_output_nodes = fx_g.graph.find_nodes(op="output")[0].all_input_nodes output_nodes = all_output_nodes[: metadata.num_outputs] + print("output") rename_nodes(fx_g, output_nodes, "output") param_grad = all_output_nodes[ metadata.num_outputs : metadata.num_outputs + params_len ] + print("grad_param") rename_nodes(fx_g, param_grad, "grad_param") grad_inputs = all_output_nodes[metadata.num_outputs + params_len :] inputs_that_require_grad = [ i for i, n in enumerate(metadata.input_info[params_len:]) if n.requires_grad ] - rename_nodes(fx_g, grad_inputs, "grad_input", inputs_that_require_grad) + print("grad_input") + + # TODO: figure out and fix why this is not working + # rename_nodes(fx_g, grad_inputs, "grad_input", inputs_that_require_grad) tangent_nodes = fx_g.graph.find_nodes(op="placeholder")[ -len(metadata.traced_tangents) : @@ -221,15 +226,19 @@ def rename_nodes(fx_g, nodes, new_name, idxs=None): outputs_that_require_grad = [ i for i, n in enumerate(metadata.output_info) if n.requires_grad ] + print("tangents") rename_nodes(fx_g, tangent_nodes, "tangents", outputs_that_require_grad) input_nodes = fx_g.graph.find_nodes(op="placeholder")[ params_len + buffer_len : -len(metadata.traced_tangents) ] + print("input") rename_nodes(fx_g, input_nodes, "input") param_nodes = fx_g.graph.find_nodes(op="placeholder")[:params_len] + print("param") rename_nodes(fx_g, param_nodes, "param") buffer_nodes = fx_g.graph.find_nodes(op="placeholder")[ params_len : params_len + buffer_len ] + print("buffer") rename_nodes(fx_g, buffer_nodes, "buffer") diff --git a/examples/example_ds3.py b/examples/example_ds3.py index f3def7e0..82c3ea2f 100644 --- a/examples/example_ds3.py +++ b/examples/example_ds3.py @@ -211,6 +211,7 @@ def _fill_indices_kernel( # ============== +@torch.library.custom_op("autoparallel::fill_indices_wrapper", mutates_args=()) def fill_indices_wrapper( tokens_per_expert_group: torch.Tensor, start_index_values: torch.Tensor, @@ -220,7 +221,7 @@ def fill_indices_wrapper( max_len: int, block_size: int = 128, max_blocks: int = 1024, # cap on total number of blocks to launch -): +) -> torch.Tensor: # preallocate output permuted_indices = torch.full( (max_len,), -1, dtype=torch.int32, device=tokens_per_expert_group.device @@ -244,6 +245,24 @@ def fill_indices_wrapper( return permuted_indices +@fill_indices_wrapper.register_fake +def _( + tokens_per_expert_group: torch.Tensor, + start_index_values: torch.Tensor, + write_offsets: torch.Tensor, + experts_per_rank: int, + num_ranks: int, + max_len: int, + block_size: int = 128, + max_blocks: int = 1024, # cap on total number of blocks to launch +): + # preallocate output + permuted_indices = torch.empty( + (max_len,), dtype=torch.int32, device=tokens_per_expert_group.device + ) + return permuted_indices + + # reference def fill_indices_cpu( tokens_per_expert_group: torch.Tensor, @@ -729,9 +748,11 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # top_scores and selected_indices shape (bs*slen*top_k,) # num_tokens_per_expert shape (num_experts,) - (top_scores, token_indices, num_tokens_per_expert,) = self.router( - x.reshape(bs * slen, dim), self.expert_bias - ) + ( + top_scores, + token_indices, + num_tokens_per_expert, + ) = self.router(x.reshape(bs * slen, dim), self.expert_bias) # tokens_per_expert will be used to update the expert bias for load balancing. # Prevent extra local tokens accumulation on evaluation or activation recomputation. 
@@ -810,7 +831,9 @@ def init_weights( def input_fn(): - return torch.randn(bs, seqlen, dim, dtype=torch.bfloat16, device="cuda") + return torch.randn( + bs, seqlen, dim, dtype=torch.bfloat16, device="cuda", requires_grad=True + ) args = DeepSeekV3ModelArgs(dim=dim, n_layers=1) From 61a63c4542ea76672f5d8f9ebe16142cea898e37 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Fri, 4 Jul 2025 13:36:49 +0000 Subject: [PATCH 04/23] Add dummy propagation rules just to see what we need to implement They are not correct, it's just to get a list of what we need for DeepSeekV3 --- autoparallel/propagation_rules.py | 48 +++++++++++------- autoparallel/utils.py | 84 ++++++++++++++++++++++++++++++- 2 files changed, 113 insertions(+), 19 deletions(-) diff --git a/autoparallel/propagation_rules.py b/autoparallel/propagation_rules.py index fd18cac8..247a60d0 100644 --- a/autoparallel/propagation_rules.py +++ b/autoparallel/propagation_rules.py @@ -47,11 +47,15 @@ _op_rules = {} -def register_rule(op): +def register_rule(ops): global _op_rules def wrapper(impl): - _op_rules[op] = impl + if isinstance(ops, list): + for op in ops: + _op_rules[op] = impl + else: + _op_rules[ops] = impl return impl return wrapper @@ -335,14 +339,18 @@ def randperm_rule(mesh, specs): return OpStrategy([OpSpec(spec, input_specs=[spec], redistribute_cost=[[0.0]])]) -@register_rule(torch.ops.aten.full.default) +@register_rule([torch.ops.aten.full.default, torch.ops.aten.empty.memory_format]) def full_rule(mesh, specs): - raise NotImplementedError("Needs hardening, only tested on a few cases") + print( + f"Ops that need to be implemented {torch.ops.aten.full.default}, {torch.ops.aten.empty.memory_format}" + ) + # raise NotImplementedError("Needs hardening, only tested on a few cases") shape = specs[0] # TODO: get the dtype tensor_meta = _gen_tensor_meta(shape) # TODO: I'm hard-coding this here, I'll probably need to do something else about this - placement = (Shard(0),) + (Replicate(),) * (mesh.ndim - 1) + # placement = (Shard(0),) + (Replicate(),) * (mesh.ndim - 1) + placement = (Replicate(),) * mesh.ndim # placement = (Replicate(),) * mesh.ndim input_placement = (Replicate(),) * mesh.ndim spec = DTensorSpec(mesh, placement, tensor_meta=tensor_meta) @@ -538,21 +546,23 @@ def _unsafe_index_rule(mesh, op_schema): @register_opschema_rule(torch.ops.aten.index.Tensor) def index_rule(mesh, op_schema): - raise NotImplementedError("Needs hardening, only tested on a few cases") + print(f"Ops that need to be implemented {torch.ops.aten.index.Tensor}") + # raise NotImplementedError("Needs hardening, only tested on a few cases") strat = op_schema.args_schema specs = strat # TODO: clean this up res = [] idxs_placements = [(Replicate(), Replicate()), (Shard(0), Replicate())] - if strat[1].childs[0] is None: - idxs_placements = idxs_placements[:1] - else: - idxs_placements = idxs_placements[1:] + idxs_placements = [(Replicate(),) * mesh.ndim] + # if strat[1].childs[0] is None: + # idxs_placements = idxs_placements[:1] + # else: + # idxs_placements = idxs_placements[1:] # TODO: this is a nasty hack and won't work for most of the cases - for i, ss in enumerate(strat[0].strategies): + for i, ss in enumerate(strat[0].strategies[:1]): for plt in idxs_placements: ispec = ss.input_specs[0] ospec = DTensorSpec(mesh=mesh, placements=ispec.placements) - assert ss.output_spec == ispec + # assert ss.output_spec == ispec, f"{ss.output_spec}, {ispec}" idxs_strats = [ DTensorSpec(mesh, placements=plt) for x in strat[1].childs @@ -579,15 +589,17 @@ def 
index_rule(mesh, op_schema): @register_opschema_rule(torch.ops.aten.index_put.default) def index_put_rule(mesh, op_schema): - raise NotImplementedError("Needs hardening, only tested on a few cases") + print(f"Ops that need to be implemented {torch.ops.aten.index_put.default}") + # raise NotImplementedError("Needs hardening, only tested on a few cases") strat = op_schema.args_schema specs = strat # TODO: clean this up res = [] - idxs_placements = [(Replicate(), Replicate()), (Shard(0), Replicate())] - if strat[1].childs[0] is None: - idxs_placements = idxs_placements[:1] - else: - idxs_placements = idxs_placements[1:] + # idxs_placements = [(Replicate(), Replicate()), (Shard(0), Replicate())] + # if strat[1].childs[0] is None: + # idxs_placements = idxs_placements[:1] + # else: + # idxs_placements = idxs_placements[1:] + idxs_placements = [(Replicate(),) * mesh.ndim] # TODO: this is a nasty hack and won't work for most of the cases for i, ss in enumerate(strat[0].strategies): for plt in idxs_placements: diff --git a/autoparallel/utils.py b/autoparallel/utils.py index e9825ffb..5851d41a 100644 --- a/autoparallel/utils.py +++ b/autoparallel/utils.py @@ -83,6 +83,26 @@ def fill_missing_redistribute_cost(op, specs, out_strat): strat.redistribute_cost = redistribute_costs +def _generate_dummy_strategy(mesh, tensor_meta, num_input_args, num_input_strategies): + from torch.distributed.tensor._dtensor_spec import DTensorSpec + from torch.distributed.tensor._op_schema import OpSpec + from torch.distributed.tensor.placement_types import Replicate + + placements = (Replicate(),) * mesh.ndim + input_specs = [ + DTensorSpec(mesh=mesh, placements=placements, tensor_meta=tensor_meta) + for _ in range(num_input_args) + ] + output_spec = DTensorSpec(mesh=mesh, placements=placements, tensor_meta=tensor_meta) + + out_strat = OpSpec(output_specs=output_spec, input_specs=input_specs) + out_strat.redistribute_cost = [ + [0.0] * num_input_strategies, + ] * num_input_args + out_strat = OpStrategy([out_strat]) + return out_strat + + def get_placement_options(mesh, op, specs, user_args, user_kwargs): # print(op) @@ -109,12 +129,74 @@ def get_placement_options(mesh, op, specs, user_args, user_kwargs): if op in _op_partial_rules: out_strat = _op_partial_rules[op](mesh, op_schema) - else: + elif ( + op + in torch.distributed.tensor.DTensor._op_dispatcher.sharding_propagator.op_strategy_funcs + and op != torch.ops.aten.slice_scatter.default + ): out_strat = torch.distributed.tensor.DTensor._op_dispatcher.sharding_propagator.op_strategy_funcs[ op ]( op_schema ) + else: + print(f"Ops that need to be implemented {op}") + from .propagation_rules import _create_all_options + + tensor_meta = strat[0].strategies[0].output_spec.tensor_meta + if op == torch.ops.aten.sort.stable: + out_strat = torch.distributed.tensor.DTensor._op_dispatcher.sharding_propagator.op_strategy_funcs[ + torch.ops.aten.topk.default + ]( + op_schema + ) + elif op in { + torch.ops.autoparallel.fill_indices_wrapper.default, + torch.ops.aten.scatter_add.default, + torch.ops.prims.fma.default, + }: + """ + from torch.distributed.tensor._dtensor_spec import DTensorSpec + from torch.distributed.tensor._op_schema import OpSpec + from torch.distributed.tensor.placement_types import Replicate + + placements = (Replicate(),) * mesh.ndim + s0 = DTensorSpec(mesh=mesh, placements=placements, tensor_meta=tensor_meta) + s1 = DTensorSpec(mesh=mesh, placements=placements, tensor_meta=tensor_meta) + s2 = DTensorSpec(mesh=mesh, placements=placements, 
tensor_meta=tensor_meta) + s3 = DTensorSpec(mesh=mesh, placements=placements, tensor_meta=tensor_meta) + + out_strat = OpSpec(output_specs=s3, input_specs=[s0, s1, s2]) + num_strats = len(strat[0].strategies) + out_strat.redistribute_cost = [ + [0.0] * num_strats, + [0.0] * num_strats, + [0.0] * num_strats, + ] + out_strat = OpStrategy([out_strat]) + """ + num_strats = len(strat[0].strategies) + out_strat = _generate_dummy_strategy(mesh, tensor_meta, 3, num_strats) + elif op == torch.ops.aten.slice_scatter.default: + """ + from torch.distributed.tensor._dtensor_spec import DTensorSpec + from torch.distributed.tensor._op_schema import OpSpec + from torch.distributed.tensor.placement_types import Replicate + + placements = (Replicate(),) * mesh.ndim + s0 = DTensorSpec(mesh=mesh, placements=placements, tensor_meta=tensor_meta) + s1 = DTensorSpec(mesh=mesh, placements=placements, tensor_meta=tensor_meta) + s2 = DTensorSpec(mesh=mesh, placements=placements, tensor_meta=tensor_meta) + + out_strat = OpSpec(output_specs=s2, input_specs=[s0, s1]) + num_strats = len(strat[0].strategies) + out_strat.redistribute_cost = [[0.0] * num_strats, [0.0] * num_strats] + out_strat = OpStrategy([out_strat]) + """ + num_strats = len(strat[0].strategies) + out_strat = _generate_dummy_strategy(mesh, tensor_meta, 2, num_strats) + else: + out_strat = _create_all_options(mesh, tensor_meta.shape, tensor_meta) propagate_tensor_meta(op, user_args, user_kwargs, out_strat) fill_missing_redistribute_cost(op, specs, out_strat) From 67eb2641e27b387efdd41ba65fdbf896de63a7cf Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Fri, 4 Jul 2025 13:37:11 +0000 Subject: [PATCH 05/23] Cleanup --- autoparallel/utils.py | 35 ----------------------------------- 1 file changed, 35 deletions(-) diff --git a/autoparallel/utils.py b/autoparallel/utils.py index 5851d41a..777abdc3 100644 --- a/autoparallel/utils.py +++ b/autoparallel/utils.py @@ -155,44 +155,9 @@ def get_placement_options(mesh, op, specs, user_args, user_kwargs): torch.ops.aten.scatter_add.default, torch.ops.prims.fma.default, }: - """ - from torch.distributed.tensor._dtensor_spec import DTensorSpec - from torch.distributed.tensor._op_schema import OpSpec - from torch.distributed.tensor.placement_types import Replicate - - placements = (Replicate(),) * mesh.ndim - s0 = DTensorSpec(mesh=mesh, placements=placements, tensor_meta=tensor_meta) - s1 = DTensorSpec(mesh=mesh, placements=placements, tensor_meta=tensor_meta) - s2 = DTensorSpec(mesh=mesh, placements=placements, tensor_meta=tensor_meta) - s3 = DTensorSpec(mesh=mesh, placements=placements, tensor_meta=tensor_meta) - - out_strat = OpSpec(output_specs=s3, input_specs=[s0, s1, s2]) - num_strats = len(strat[0].strategies) - out_strat.redistribute_cost = [ - [0.0] * num_strats, - [0.0] * num_strats, - [0.0] * num_strats, - ] - out_strat = OpStrategy([out_strat]) - """ num_strats = len(strat[0].strategies) out_strat = _generate_dummy_strategy(mesh, tensor_meta, 3, num_strats) elif op == torch.ops.aten.slice_scatter.default: - """ - from torch.distributed.tensor._dtensor_spec import DTensorSpec - from torch.distributed.tensor._op_schema import OpSpec - from torch.distributed.tensor.placement_types import Replicate - - placements = (Replicate(),) * mesh.ndim - s0 = DTensorSpec(mesh=mesh, placements=placements, tensor_meta=tensor_meta) - s1 = DTensorSpec(mesh=mesh, placements=placements, tensor_meta=tensor_meta) - s2 = DTensorSpec(mesh=mesh, placements=placements, tensor_meta=tensor_meta) - - out_strat = OpSpec(output_specs=s2, 
input_specs=[s0, s1]) - num_strats = len(strat[0].strategies) - out_strat.redistribute_cost = [[0.0] * num_strats, [0.0] * num_strats] - out_strat = OpStrategy([out_strat]) - """ num_strats = len(strat[0].strategies) out_strat = _generate_dummy_strategy(mesh, tensor_meta, 2, num_strats) else: From 86d53ff318f27790197585815b5988355d8f0290 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Fri, 4 Jul 2025 13:50:08 +0000 Subject: [PATCH 06/23] prims.fma comes from softmax_backward prims.fma is probably easier to implement, but I'm removing this decomp just in case --- autoparallel/api.py | 1 + autoparallel/utils.py | 14 ++++++++++---- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/autoparallel/api.py b/autoparallel/api.py index d036d73d..fc5b8d2f 100644 --- a/autoparallel/api.py +++ b/autoparallel/api.py @@ -184,6 +184,7 @@ def _get_decomp_table(): decomp_table.pop(torch.ops.aten.native_layer_norm.default) decomp_table.pop(torch.ops.aten.embedding_dense_backward.default) decomp_table.pop(torch.ops.aten.native_layer_norm_backward.default) + decomp_table.pop(torch.ops.aten._softmax_backward_data.default) # decompose addmm to allow for TP on mm decomp_table.pop(torch.ops.aten.addmm.default) diff --git a/autoparallel/utils.py b/autoparallel/utils.py index 777abdc3..0842139b 100644 --- a/autoparallel/utils.py +++ b/autoparallel/utils.py @@ -132,7 +132,11 @@ def get_placement_options(mesh, op, specs, user_args, user_kwargs): elif ( op in torch.distributed.tensor.DTensor._op_dispatcher.sharding_propagator.op_strategy_funcs - and op != torch.ops.aten.slice_scatter.default + and op + not in { + torch.ops.aten.slice_scatter.default, + torch.ops.aten._softmax_backward_data.default, + } ): out_strat = torch.distributed.tensor.DTensor._op_dispatcher.sharding_propagator.op_strategy_funcs[ op @@ -144,6 +148,7 @@ def get_placement_options(mesh, op, specs, user_args, user_kwargs): from .propagation_rules import _create_all_options tensor_meta = strat[0].strategies[0].output_spec.tensor_meta + num_strats = len(strat[0].strategies) if op == torch.ops.aten.sort.stable: out_strat = torch.distributed.tensor.DTensor._op_dispatcher.sharding_propagator.op_strategy_funcs[ torch.ops.aten.topk.default @@ -155,10 +160,11 @@ def get_placement_options(mesh, op, specs, user_args, user_kwargs): torch.ops.aten.scatter_add.default, torch.ops.prims.fma.default, }: - num_strats = len(strat[0].strategies) out_strat = _generate_dummy_strategy(mesh, tensor_meta, 3, num_strats) - elif op == torch.ops.aten.slice_scatter.default: - num_strats = len(strat[0].strategies) + elif op in { + torch.ops.aten.slice_scatter.default, + torch.ops.aten._softmax_backward_data.default, + }: out_strat = _generate_dummy_strategy(mesh, tensor_meta, 2, num_strats) else: out_strat = _create_all_options(mesh, tensor_meta.shape, tensor_meta) From 7864f4d3d8dcd1f5bb49b2c9733226565e5ddafc Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Sat, 5 Jul 2025 08:07:28 +0000 Subject: [PATCH 07/23] Make _geenrate_dummy_strategy more generic Now should handle all ops properly, with correct shapes --- autoparallel/utils.py | 81 ++++++++++++++++++++++++------------------- 1 file changed, 45 insertions(+), 36 deletions(-) diff --git a/autoparallel/utils.py b/autoparallel/utils.py index 0842139b..d4ec4a2e 100644 --- a/autoparallel/utils.py +++ b/autoparallel/utils.py @@ -13,21 +13,30 @@ from .propagation_rules import _op_partial_rules, _op_rules, remove_invalid_configs -def propagate_tensor_meta(op, user_args, user_kwargs, out_strat): +def 
_get_meta_tensors_for_op(op, user_args, user_kwargs): out_t = op(*user_args, **user_kwargs) if isinstance(out_t, torch.Tensor): - new_tensor_meta = TensorMeta(out_t.shape, out_t.stride(), out_t.dtype) + out_tensor_meta = TensorMeta(out_t.shape, out_t.stride(), out_t.dtype) else: - new_tensor_meta = tree_map_only( + out_tensor_meta = tree_map_only( torch.Tensor, lambda x: TensorMeta(x.shape, x.stride(), x.dtype), out_t ) - tensor_metas = tree_flatten(user_args)[0] - tensor_metas = tree_map_only( - torch.Tensor, lambda x: TensorMeta(x.shape, x.stride(), x.dtype), tensor_metas + input_tensor_metas = tree_flatten(user_args)[0] + input_tensor_metas = tree_map_only( + torch.Tensor, + lambda x: TensorMeta(x.shape, x.stride(), x.dtype), + input_tensor_metas, + ) + input_tensor_metas = tuple( + x for x in input_tensor_metas if isinstance(x, TensorMeta) ) - tensor_metas = tuple(x for x in tensor_metas if isinstance(x, TensorMeta)) + return out_tensor_meta, input_tensor_metas + + +def propagate_tensor_meta(op, user_args, user_kwargs, out_strat): + new_tensor_meta, tensor_metas = _get_meta_tensors_for_op(op, user_args, user_kwargs) for strat in out_strat.strategies: if isinstance(new_tensor_meta, TensorMeta): @@ -83,22 +92,44 @@ def fill_missing_redistribute_cost(op, specs, out_strat): strat.redistribute_cost = redistribute_costs -def _generate_dummy_strategy(mesh, tensor_meta, num_input_args, num_input_strategies): +def _generate_dummy_strategy(mesh, op, user_args, user_kwargs, input_strategies): from torch.distributed.tensor._dtensor_spec import DTensorSpec from torch.distributed.tensor._op_schema import OpSpec from torch.distributed.tensor.placement_types import Replicate placements = (Replicate(),) * mesh.ndim + + out_tensor_meta, input_tensor_metas = _get_meta_tensors_for_op( + op, user_args, user_kwargs + ) + input_specs = [ - DTensorSpec(mesh=mesh, placements=placements, tensor_meta=tensor_meta) - for _ in range(num_input_args) + DTensorSpec(mesh=mesh, placements=placements, tensor_meta=tm) + for tm in input_tensor_metas ] - output_spec = DTensorSpec(mesh=mesh, placements=placements, tensor_meta=tensor_meta) + if isinstance(out_tensor_meta, TensorMeta): + output_spec = DTensorSpec( + mesh=mesh, placements=placements, tensor_meta=out_tensor_meta + ) + else: + output_spec = tuple( + DTensorSpec(mesh=mesh, placements=placements, tensor_meta=tm) + for tm in out_tensor_meta + ) out_strat = OpSpec(output_specs=output_spec, input_specs=input_specs) + num_input_args = len(input_tensor_metas) + input_strategies_flat = [ + x for x in tree_flatten(input_strategies)[0] if isinstance(x, OpStrategy) + ] + assert num_input_args == len( + input_strategies_flat + ), f"{op}, {num_input_args}, {len(input_strategies_flat)}" + # TODO: fix redistribute cost out_strat.redistribute_cost = [ - [0.0] * num_input_strategies, - ] * num_input_args + [0.0] * len(x.strategies) for x in input_strategies_flat + ] + assert len(out_strat.redistribute_cost) == num_input_args out_strat = OpStrategy([out_strat]) return out_strat @@ -145,29 +176,7 @@ def get_placement_options(mesh, op, specs, user_args, user_kwargs): ) else: print(f"Ops that need to be implemented {op}") - from .propagation_rules import _create_all_options - - tensor_meta = strat[0].strategies[0].output_spec.tensor_meta - num_strats = len(strat[0].strategies) - if op == torch.ops.aten.sort.stable: - out_strat = torch.distributed.tensor.DTensor._op_dispatcher.sharding_propagator.op_strategy_funcs[ - torch.ops.aten.topk.default - ]( - op_schema - ) - elif op in { - 
torch.ops.autoparallel.fill_indices_wrapper.default, - torch.ops.aten.scatter_add.default, - torch.ops.prims.fma.default, - }: - out_strat = _generate_dummy_strategy(mesh, tensor_meta, 3, num_strats) - elif op in { - torch.ops.aten.slice_scatter.default, - torch.ops.aten._softmax_backward_data.default, - }: - out_strat = _generate_dummy_strategy(mesh, tensor_meta, 2, num_strats) - else: - out_strat = _create_all_options(mesh, tensor_meta.shape, tensor_meta) + out_strat = _generate_dummy_strategy(mesh, op, user_args, user_kwargs, strat) propagate_tensor_meta(op, user_args, user_kwargs, out_strat) fill_missing_redistribute_cost(op, specs, out_strat) From 60ccf1acf7b131cf71ea38014ebf63a8b0e351fc Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Sat, 5 Jul 2025 08:13:11 +0000 Subject: [PATCH 08/23] Add proper redistribute_cost to dummy strategies --- autoparallel/utils.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/autoparallel/utils.py b/autoparallel/utils.py index d4ec4a2e..b7cae983 100644 --- a/autoparallel/utils.py +++ b/autoparallel/utils.py @@ -125,10 +125,12 @@ def _generate_dummy_strategy(mesh, op, user_args, user_kwargs, input_strategies) assert num_input_args == len( input_strategies_flat ), f"{op}, {num_input_args}, {len(input_strategies_flat)}" - # TODO: fix redistribute cost - out_strat.redistribute_cost = [ - [0.0] * len(x.strategies) for x in input_strategies_flat + redistribute_cost = [ + generate_redistribute_costs(input_strategies_flat[i], input_specs[i]) + for i in range(num_input_args) ] + out_strat.redistribute_cost = redistribute_cost + assert len(out_strat.redistribute_cost) == num_input_args out_strat = OpStrategy([out_strat]) return out_strat From dbbc2058f8e7766fd638073e2c090c5a42d0efc6 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Sat, 5 Jul 2025 08:23:42 +0000 Subject: [PATCH 09/23] Hack around missing dtypes in compute estimation and handle grouped_mm cases with invalid strides The grouped_mm should be handled in the sharding propagation and those cases should just be removed I think --- autoparallel/compute_estimation.py | 25 ++++++++++++++++--------- 1 file changed, 16 insertions(+), 9 deletions(-) diff --git a/autoparallel/compute_estimation.py b/autoparallel/compute_estimation.py index ba5a5f90..b7bff9f3 100644 --- a/autoparallel/compute_estimation.py +++ b/autoparallel/compute_estimation.py @@ -147,12 +147,13 @@ def _get_device_tflops(dtype): f"Unsupported device: {device_name}. Supported devices: {[limit.name for limit in DEVICE_LIMITS]}" ) - if dtype not in device_limit.gemm_tflops: - raise ValueError( - f"Dtype {dtype} not supported on {device_limit.name}. Supported dtypes: {list(device_limit.gemm_tflops.keys())}" - ) + # TODO: add proper support for int64 etc + # if dtype not in device_limit.gemm_tflops: + # raise ValueError( + # f"Dtype {dtype} not supported on {device_limit.name}. 
Supported dtypes: {list(device_limit.gemm_tflops.keys())}" + # ) - return device_limit.gemm_tflops[dtype] + return device_limit.gemm_tflops.get(dtype, 1) def _get_sharded_shape_stride(spec): @@ -213,10 +214,16 @@ def estimate_strategy_runtime_cost(node, strategy): # TODO: maybe cache the flop_counter to avoid recreating it # all the time - with FlopCounterMode(display=False) as flop_counter: - node.target(*args, **kwargs) - - flops = flop_counter.get_total_flops() + try: + with FlopCounterMode(display=False) as flop_counter: + node.target(*args, **kwargs) + + flops = flop_counter.get_total_flops() + except RuntimeError as exc: + if node.target == torch.ops.aten._grouped_mm.default: + flops = float("inf") + else: + raise exc # TODO: fix this dtype = strategy.input_specs[0].tensor_meta.dtype From d92f8c6bc01a57820f16c4479fca9b14aa5d0680 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Sat, 5 Jul 2025 08:29:14 +0000 Subject: [PATCH 10/23] Add representative batch size Otherwise we can't shard on the batch dimension. With this change everything works up to executing the solver --- examples/example_ds3.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/example_ds3.py b/examples/example_ds3.py index 82c3ea2f..53c6589e 100644 --- a/examples/example_ds3.py +++ b/examples/example_ds3.py @@ -825,7 +825,7 @@ def init_weights( ), ) -bs = 1 +bs = 8 * mesh.shape[0] seqlen = 1024 dim = 4096 From e25ff7b976311bb411726365a1be0af95d3e3e84 Mon Sep 17 00:00:00 2001 From: Will Constable Date: Thu, 17 Jul 2025 17:24:48 -0700 Subject: [PATCH 11/23] Fix grouped_mm stride issue --- autoparallel/compute_estimation.py | 14 ++++---------- autoparallel/propagation_rules.py | 6 ++++-- 2 files changed, 8 insertions(+), 12 deletions(-) diff --git a/autoparallel/compute_estimation.py b/autoparallel/compute_estimation.py index b7bff9f3..9a04adcb 100644 --- a/autoparallel/compute_estimation.py +++ b/autoparallel/compute_estimation.py @@ -214,16 +214,10 @@ def estimate_strategy_runtime_cost(node, strategy): # TODO: maybe cache the flop_counter to avoid recreating it # all the time - try: - with FlopCounterMode(display=False) as flop_counter: - node.target(*args, **kwargs) - - flops = flop_counter.get_total_flops() - except RuntimeError as exc: - if node.target == torch.ops.aten._grouped_mm.default: - flops = float("inf") - else: - raise exc + with FlopCounterMode(display=False) as flop_counter: + node.target(*args, **kwargs) + + flops = flop_counter.get_total_flops() # TODO: fix this dtype = strategy.input_specs[0].tensor_meta.dtype diff --git a/autoparallel/propagation_rules.py b/autoparallel/propagation_rules.py index 247a60d0..1626731d 100644 --- a/autoparallel/propagation_rules.py +++ b/autoparallel/propagation_rules.py @@ -508,9 +508,11 @@ def native_layer_norm_backward_rule(mesh, op_schema): @register_opschema_rule(torch.ops.prims.convert_element_type.default) def convert_element_type_rule(mesh, op_schema): - from torch.distributed.tensor._ops._tensor_ops import default_strategy + from torch.distributed.tensor._ops._tensor_ops import ( + propagate_single_input_strategy, + ) - out_strat = default_strategy(op_schema) + out_strat = propagate_single_input_strategy(op_schema) return out_strat From 3b7e7fa9008c0dac104be9ef6ead123b23d4eff1 Mon Sep 17 00:00:00 2001 From: Will Constable Date: Fri, 18 Jul 2025 07:15:44 -0700 Subject: [PATCH 12/23] get DS3 running forward, OOM at backward --- autoparallel/api.py | 6 ++---- examples/example_ds3.py | 19 +++++++++++++++++++ 2 files changed, 21 
insertions(+), 4 deletions(-) diff --git a/autoparallel/api.py b/autoparallel/api.py index fc5b8d2f..bc337b95 100644 --- a/autoparallel/api.py +++ b/autoparallel/api.py @@ -281,10 +281,8 @@ def build_model_graph(self): # we basically want to remove noops in here prev = torch._inductor.config.pattern_matcher torch._inductor.config.pattern_matcher = False - try: - gm = joint_graph_passes(gm) - finally: - torch._inductor.config.pattern_matcher = prev + gm = joint_graph_passes(gm) + torch._inductor.config.pattern_matcher = prev remove_assert_ops(gm.graph) gm.graph.eliminate_dead_code() gm.recompile() diff --git a/examples/example_ds3.py b/examples/example_ds3.py index 53c6589e..60644233 100644 --- a/examples/example_ds3.py +++ b/examples/example_ds3.py @@ -851,3 +851,22 @@ def input_fn(): autop.add_output_constraints([x_sharding]) sharding_placement = autop.optimize_placement() +parallel_mod = autop.apply_placement(sharding_placement) + +# run weight init on our sharded DTensor params +parallel_mod.to_empty(device="cuda") +parallel_mod.init_weights(init_std=0.02, buffer_device="cuda") # maybe not correct value + +# # now let's run it +x = ( + torch.randn( + # 0, + # args.vocab_size, + (bs // mesh.shape[0], seqlen, dim), + device=torch.device("cuda"), + dtype=torch.bfloat16 + ), +) +out = parallel_mod(*x) +out.backward(torch.randn_like(out)) +print("All good!") From 3833a0683fce65c333b8688b71a4c93d97abea52 Mon Sep 17 00:00:00 2001 From: Will Constable Date: Fri, 18 Jul 2025 15:41:04 -0700 Subject: [PATCH 13/23] WIP factory_strategy --- autoparallel/apply_sharding.py | 3 +- autoparallel/propagation_rules.py | 187 ++++++++++++++++++++---------- autoparallel/utils.py | 18 +++ examples/example_ds3.py | 2 +- 4 files changed, 144 insertions(+), 66 deletions(-) diff --git a/autoparallel/apply_sharding.py b/autoparallel/apply_sharding.py index 8e592e74..fe68e954 100644 --- a/autoparallel/apply_sharding.py +++ b/autoparallel/apply_sharding.py @@ -14,6 +14,7 @@ from torch.distributed.tensor.placement_types import Partial, Replicate, Shard # noqa from torch.fx.experimental.proxy_tensor import make_fx from torch.utils._pytree import tree_flatten, tree_map_only +from .propagation_rules import TENSOR_FACTORY_OPS def my_redistribute_local_tensor(arg, curr_spec, tgt_spec): @@ -129,7 +130,7 @@ def call_function(self, target, args, kwargs): new_args = self.redistribute_args(args) # apply sharding to constructor functions as well - if target == torch.ops.aten.full.default: + if target in TENSOR_FACTORY_OPS: val = list(new_args[0]) spec = self.sharding_placement[node].output_specs for mesh_size, placement in zip(spec.mesh.shape, spec.placements): diff --git a/autoparallel/propagation_rules.py b/autoparallel/propagation_rules.py index 1626731d..459959a6 100644 --- a/autoparallel/propagation_rules.py +++ b/autoparallel/propagation_rules.py @@ -33,7 +33,7 @@ propagate_shape_and_sharding, register_op_strategy_map, ) -from torch.distributed.tensor._ops.utils import generate_redistribute_costs +from torch.distributed.tensor._ops.utils import generate_redistribute_costs, is_tensor_shardable from torch.distributed.tensor.placement_types import Replicate, Shard # TODO: move this to PyTorch @@ -64,11 +64,15 @@ def wrapper(impl): _op_partial_rules = {} -def register_opschema_rule(op): +def register_opschema_rule(ops): global _op_partial_rules def wrapper(impl): - _op_partial_rules[op] = impl + if isinstance(ops, list): + for op in ops: + _op_partial_rules[op] = impl + else: + _op_partial_rules[ops] = impl return impl return 
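Patch 13 rewrites tensor-factory calls so each rank only allocates its local shard: the size argument is divided along every mesh dimension that shards the output. A simplified sketch of that size computation, assuming even divisibility and using only the public placement types:

import torch
from torch.distributed.tensor.placement_types import Replicate, Shard

def local_size_for_factory(global_size, mesh_shape, placements):
    # Mirrors the idea behind the TENSOR_FACTORY_OPS branch in ApplySharding:
    # every Shard(dim) placement divides that dim by the corresponding mesh
    # size so the factory op only builds the local shard.
    size = list(global_size)
    for mesh_size, placement in zip(mesh_shape, placements):
        if isinstance(placement, Shard):
            size[placement.dim] //= mesh_size
    return size

print(local_size_for_factory([64, 1024, 2048], (4, 8), (Shard(0), Replicate())))
# [16, 1024, 2048]
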
wrapper @@ -101,7 +105,11 @@ def remove_invalid_configs(out_strat, mesh): output_specs = strategy.output_specs if isinstance(output_specs, DTensorSpec): output_specs = [output_specs] - specs = list(strategy.input_specs) + list(output_specs) + if strategy.input_specs is not None: + specs = list(strategy.input_specs) + list(output_specs) + else: + specs = list(output_specs) + for spec in specs: if spec is None: continue @@ -339,27 +347,78 @@ def randperm_rule(mesh, specs): return OpStrategy([OpSpec(spec, input_specs=[spec], redistribute_cost=[[0.0]])]) -@register_rule([torch.ops.aten.full.default, torch.ops.aten.empty.memory_format]) -def full_rule(mesh, specs): - print( - f"Ops that need to be implemented {torch.ops.aten.full.default}, {torch.ops.aten.empty.memory_format}" - ) - # raise NotImplementedError("Needs hardening, only tested on a few cases") - shape = specs[0] - # TODO: get the dtype - tensor_meta = _gen_tensor_meta(shape) - # TODO: I'm hard-coding this here, I'll probably need to do something else about this - # placement = (Shard(0),) + (Replicate(),) * (mesh.ndim - 1) - placement = (Replicate(),) * mesh.ndim - # placement = (Replicate(),) * mesh.ndim - input_placement = (Replicate(),) * mesh.ndim - spec = DTensorSpec(mesh, placement, tensor_meta=tensor_meta) - input_spec = DTensorSpec(mesh, input_placement, tensor_meta=tensor_meta) - # return OpStrategy([OpSpec(spec, input_specs=[spec], redistribute_cost=[[0.0]])]) - return OpStrategy( - [OpSpec(spec, input_specs=[input_spec], redistribute_cost=[[0.0]])] - ) +# We do a few special things for factory ops +# - use the factory rule below +# - fake that they have input schemas so the solver doesn't freak out +# - convert their sizes to 'local tensor sizes' (divide by mesh dim) during ApplySharding +TENSOR_FACTORY_OPS = [ + torch.ops.aten.zeros.default, + torch.ops.aten.ones.default, + torch.ops.aten.full.default, + torch.ops.aten.empty.memory_format, + torch.ops.aten.rand.default, + torch.ops.aten.randn.default, +] + +@register_opschema_rule(TENSOR_FACTORY_OPS) +def factory_rule(mesh, op_schema: OpSchema) -> OpStrategy: + """ + This is an auto-parallel specific util that won't be upstreamed becuase of a UX mismatch. + + In regular DTensor programs, a user has to either call `torch.full` to get a regular tensor, or + `torch.distributed.tensor.full` (with placements specified) to get a DTensor. + + There is no point registering a strategy in DTensor for factories like 'full' since there is no way they + could be used by DTensor's dispatching logic. (Note: DTensor does provide strategies for similar ops like + 'new_full' and 'full_like', the difference being there is an input tensor to trigger dispatch off of and to + use to direct the placement options.) + + This util applies to any factory function that takes 'size' as the first argument, + and supports Replication and Shard placements all at zero cost. + """ + shape = op_schema.args_schema[0] + x = torch.empty(shape, device="meta") + stride = x.stride() + dtype = torch.get_default_dtype() + if len(op_schema.args_schema) >= 3: + # Todo didn't really verify this + dtype = op_schema.args_schema[2] + assert isinstance(dtype, torch.dtype), dtype + + # TODO: ensure the solver knows that it is more expensive to Replicate factory functions than shard + # for now, put replicate last since this might encourage sharding :? + single_mesh_dim_strategies = [[Shard(i)] for i in range(len(shape))] + [[Replicate()]] + + """ + Expand the single_mesh_dim_strategies to full mesh dim strategies. 
+ see docs for `expand_to_full_mesh_op_strategy` in _tensor_ops.py in pytorch + """ + all_mesh_dim_strategies = [single_mesh_dim_strategies] * mesh.ndim + + strategy_combs = list(itertools.product(*all_mesh_dim_strategies)) + + all_strategies = [] + for strategy_comb in strategy_combs: + spec_list = [DTensorSpec(mesh, specs) for specs in zip(*strategy_comb)] + output_specs = spec_list[0] + output_specs.tensor_meta = TensorMeta(shape, stride, dtype) + + if not is_tensor_shardable(shape, output_specs): + continue + redistribute_cost = [ + # TODO: there shouldn't actually be a row here, since there is no input to the op and the rows correspond + # to the inputs. However, the optimization code is not set up to tolerate input-less ops, so hack around it + # (see "/data/users/whc/autoparallel/autoparallel/optimize_sharding.py", line 226, in walk_over_options) + [0.0] * len(strategy_combs) + ] + + strategy = OpSpec( + output_specs=output_specs, + redistribute_cost=redistribute_cost, + ) + all_strategies.append(strategy) + return OpStrategy(all_strategies) # ====================================== # the following ops require meta_tensor fix @@ -589,46 +648,46 @@ def index_rule(mesh, op_schema): return out_strat -@register_opschema_rule(torch.ops.aten.index_put.default) -def index_put_rule(mesh, op_schema): - print(f"Ops that need to be implemented {torch.ops.aten.index_put.default}") - # raise NotImplementedError("Needs hardening, only tested on a few cases") - strat = op_schema.args_schema - specs = strat # TODO: clean this up - res = [] - # idxs_placements = [(Replicate(), Replicate()), (Shard(0), Replicate())] - # if strat[1].childs[0] is None: - # idxs_placements = idxs_placements[:1] - # else: - # idxs_placements = idxs_placements[1:] - idxs_placements = [(Replicate(),) * mesh.ndim] - # TODO: this is a nasty hack and won't work for most of the cases - for i, ss in enumerate(strat[0].strategies): - for plt in idxs_placements: - ispec = ss.input_specs[0] - ospec = DTensorSpec(mesh=mesh, placements=ispec.placements) - assert ss.output_spec == ispec, f"{ss.output_spec}, {ispec}" - idxs_strats = [ - DTensorSpec(mesh, placements=plt) - for x in strat[1].childs - if x is not None - ] - kspc = [x for x in strat[1].childs if x is not None] - t_strats = [DTensorSpec(mesh, placements=ispec.placements)] - s = OpSpec(output_specs=ospec, input_specs=[ispec] + idxs_strats + t_strats) - - redistribute_costs = ( - [generate_redistribute_costs(specs[0], ospec)] - + [ - generate_redistribute_costs(kk, idxs_strat) - for kk, idxs_strat in zip(kspc, idxs_strats) - ] - + [generate_redistribute_costs(specs[2], t_strats[0])] - ) - s.redistribute_cost = redistribute_costs - res.append(s) - out_strat = OpStrategy(res) - return out_strat +# @register_opschema_rule(torch.ops.aten.index_put.default) +# def index_put_rule(mesh, op_schema): +# print(f"Ops that need to be implemented {torch.ops.aten.index_put.default}") +# # raise NotImplementedError("Needs hardening, only tested on a few cases") +# strat = op_schema.args_schema +# specs = strat # TODO: clean this up +# res = [] +# # idxs_placements = [(Replicate(), Replicate()), (Shard(0), Replicate())] +# # if strat[1].childs[0] is None: +# # idxs_placements = idxs_placements[:1] +# # else: +# # idxs_placements = idxs_placements[1:] +# idxs_placements = [(Replicate(),) * mesh.ndim] +# # TODO: this is a nasty hack and won't work for most of the cases +# for i, ss in enumerate(strat[0].strategies): +# for plt in idxs_placements: +# ispec = ss.input_specs[0] +# ospec = 
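factory_rule enumerates one option per mesh dimension (shard any tensor dim, or replicate) and takes the cartesian product across mesh dimensions, mirroring expand_to_full_mesh_op_strategy. A rough sketch of that enumeration for a 2-D tensor on an assumed 2-D mesh:

import itertools
from torch.distributed.tensor.placement_types import Replicate, Shard

# Per mesh dimension the output can be sharded on any tensor dim or replicated;
# full-mesh candidates are the cartesian product over mesh dimensions.
shape = (64, 2048)
options_per_mesh_dim = [Shard(i) for i in range(len(shape))] + [Replicate()]
mesh_ndim = 2

candidates = list(itertools.product(options_per_mesh_dim, repeat=mesh_ndim))
print(len(candidates))            # 9 placements for a 2-D tensor on a 2-D mesh
print(candidates[0], candidates[-1])
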
DTensorSpec(mesh=mesh, placements=ispec.placements) +# assert ss.output_spec == ispec, f"{ss.output_spec}, {ispec}" +# idxs_strats = [ +# DTensorSpec(mesh, placements=plt) +# for x in strat[1].childs +# if x is not None +# ] +# kspc = [x for x in strat[1].childs if x is not None] +# t_strats = [DTensorSpec(mesh, placements=ispec.placements)] +# s = OpSpec(output_specs=ospec, input_specs=[ispec] + idxs_strats + t_strats) + +# redistribute_costs = ( +# [generate_redistribute_costs(specs[0], ospec)] +# + [ +# generate_redistribute_costs(kk, idxs_strat) +# for kk, idxs_strat in zip(kspc, idxs_strats) +# ] +# + [generate_redistribute_costs(specs[2], t_strats[0])] +# ) +# s.redistribute_cost = redistribute_costs +# res.append(s) +# out_strat = OpStrategy(res) +# return out_strat @register_opschema_rule(torch.ops.aten._scaled_dot_product_efficient_attention.default) diff --git a/autoparallel/utils.py b/autoparallel/utils.py index b7cae983..4deb7470 100644 --- a/autoparallel/utils.py +++ b/autoparallel/utils.py @@ -3,6 +3,7 @@ # This source code is licensed under the BSD license found in the # LICENSE file in the root directory of this source tree. +from re import I import torch from torch.distributed._tensor.placement_types import TensorMeta from torch.distributed.device_mesh import _get_device_handle @@ -53,6 +54,23 @@ def propagate_tensor_meta(op, user_args, user_kwargs, out_strat): else: assert tm is None if strat.input_specs is None: + if op in [ + # TODO import the list of registered factories, but it would be circular the way it is now + torch.ops.aten.zeros.default, + torch.ops.aten.ones.default, + torch.ops.aten.full.default, + torch.ops.aten.empty.memory_format, + torch.ops.aten.rand.default, + torch.ops.aten.randn.default, + ]: + # there isn't an input spec bc the op has no input! + # continue + + # but index_put op insists on looking at 'input_specs' of its input, which seems absurd. 
+ # so just copy it for now and fix later + strat.input_specs = (strat.output_specs,) + continue + supported_ops = { torch.ops.prims.convert_element_type.default, torch.ops.aten.clone.default, diff --git a/examples/example_ds3.py b/examples/example_ds3.py index 60644233..886e4aea 100644 --- a/examples/example_ds3.py +++ b/examples/example_ds3.py @@ -825,7 +825,7 @@ def init_weights( ), ) -bs = 8 * mesh.shape[0] +bs = 4 * mesh.shape[0] seqlen = 1024 dim = 4096 From 3740b4544a2d97e56c3f100d335d4c01bd7cf767 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Fri, 25 Jul 2025 09:38:16 +0000 Subject: [PATCH 14/23] Start rebasing on top of main --- autoparallel/api.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/autoparallel/api.py b/autoparallel/api.py index bc337b95..fc5b8d2f 100644 --- a/autoparallel/api.py +++ b/autoparallel/api.py @@ -281,8 +281,10 @@ def build_model_graph(self): # we basically want to remove noops in here prev = torch._inductor.config.pattern_matcher torch._inductor.config.pattern_matcher = False - gm = joint_graph_passes(gm) - torch._inductor.config.pattern_matcher = prev + try: + gm = joint_graph_passes(gm) + finally: + torch._inductor.config.pattern_matcher = prev remove_assert_ops(gm.graph) gm.graph.eliminate_dead_code() gm.recompile() From 6bec5f52d6d9081a8f33fa450b7074c46e9c7b10 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Fri, 25 Jul 2025 11:28:04 +0000 Subject: [PATCH 15/23] Fixes so that it runs --- autoparallel/propagation_rules.py | 8 ++++++-- autoparallel/utils.py | 16 +++++++++------- 2 files changed, 15 insertions(+), 9 deletions(-) diff --git a/autoparallel/propagation_rules.py b/autoparallel/propagation_rules.py index 046a0c1c..f253d619 100644 --- a/autoparallel/propagation_rules.py +++ b/autoparallel/propagation_rules.py @@ -381,7 +381,7 @@ def factory_rule(mesh, op_schema: OpSchema) -> OpStrategy: This util applies to any factory function that takes 'size' as the first argument, and supports Replication and Shard placements all at zero cost. """ - assert isinstance(op_schema.args_schema[0], torch.Size) + assert isinstance(op_schema.args_schema[0], (torch.Size, list)) shape = op_schema.args_schema[0] x = torch.empty(shape, device="meta") stride = x.stride() @@ -423,8 +423,11 @@ def factory_rule(mesh, op_schema: OpSchema) -> OpStrategy: * len(strategy_combs) ] + # TODO: should we add an input_spec here, so that we can ensure we always + # have input and output specs? For now I hacked it in utils.py strategy = OpSpec( output_specs=output_specs, + input_specs=[output_specs], redistribute_cost=redistribute_cost, ) all_strategies.append(strategy) @@ -616,7 +619,8 @@ def _unsafe_index_rule(mesh, op_schema): raise NotImplementedError() -@register_opschema_rule(torch.ops.aten.index.Tensor) +# Disable this rule as it's implementation is inferior than the baseline +# @register_opschema_rule(torch.ops.aten.index.Tensor) def index_rule(mesh, op_schema): print(f"Ops that need to be implemented {torch.ops.aten.index.Tensor}") # raise NotImplementedError("Needs hardening, only tested on a few cases") diff --git a/autoparallel/utils.py b/autoparallel/utils.py index 1a751265..bb305256 100644 --- a/autoparallel/utils.py +++ b/autoparallel/utils.py @@ -58,9 +58,7 @@ def propagate_tensor_meta(op, user_args, user_kwargs, out_strat): else: assert tm is None if strat.input_specs is None: - if op in TENSOR_FACTORY_OPS: - # there isn't an input spec bc the op has no input! 
- continue + # TODO: this should be cleaned up supported_ops = { torch.ops.prims.convert_element_type.default, @@ -73,9 +71,14 @@ def propagate_tensor_meta(op, user_args, user_kwargs, out_strat): ) strat.input_specs = (strat.output_specs,) assert strat.redistribute_cost is None - assert len(tensor_metas) == len( - strat.input_specs - ), f"{op}, {len(tensor_metas)}, {len(strat.input_specs)}" + # TODO: this invariant wrt factory ops is something I believe + # I'll keep for the solver, so we need to have some consistency here + # i.e., even though factory ops don't have inputs, we do put an + # input spec for it which is equal to the output spec + if op not in TENSOR_FACTORY_OPS: + assert len(tensor_metas) == len( + strat.input_specs + ), f"{op}, {len(tensor_metas)}, {len(strat.input_specs)}" for tm, ispec in zip(tensor_metas, strat.input_specs): if ispec.tensor_meta is None: ispec.tensor_meta = tm @@ -176,7 +179,6 @@ def get_placement_options(mesh, op, specs, user_args, user_kwargs): in torch.distributed.tensor.DTensor._op_dispatcher.sharding_propagator.op_strategy_funcs and op not in { - torch.ops.aten.slice_scatter.default, torch.ops.aten._softmax_backward_data.default, } ): From ce1c0a58c4ebae8af215e1386a4f27defde587b1 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Sat, 26 Jul 2025 06:55:15 +0000 Subject: [PATCH 16/23] [WIP] Plumb fake_mode to avoid materializing memory Needs cleanup --- autoparallel/api.py | 2 +- autoparallel/compute_estimation.py | 2 +- autoparallel/optimize_sharding.py | 13 ++++++++++--- autoparallel/utils.py | 25 ++++++++++++++++--------- 4 files changed, 28 insertions(+), 14 deletions(-) diff --git a/autoparallel/api.py b/autoparallel/api.py index fc5b8d2f..907d6e19 100644 --- a/autoparallel/api.py +++ b/autoparallel/api.py @@ -247,7 +247,7 @@ def __init__(self, model, input_fn, mesh: DeviceMesh): self.mesh = mesh self.build_model_graph() - sharding_optimizer = ShardingOptimizer(self.gm, self.mesh) + sharding_optimizer = ShardingOptimizer(self.gm, self.mesh, self.fake_mode) # makes sharding of params and gradients the same sharding_optimizer.add_grad_param_constraints() self.sharding_optimizer = sharding_optimizer diff --git a/autoparallel/compute_estimation.py b/autoparallel/compute_estimation.py index 9a04adcb..e13ef489 100644 --- a/autoparallel/compute_estimation.py +++ b/autoparallel/compute_estimation.py @@ -214,7 +214,7 @@ def estimate_strategy_runtime_cost(node, strategy): # TODO: maybe cache the flop_counter to avoid recreating it # all the time - with FlopCounterMode(display=False) as flop_counter: + with FlopCounterMode(display=False) as flop_counter, fake_mode: node.target(*args, **kwargs) flops = flop_counter.get_total_flops() diff --git a/autoparallel/optimize_sharding.py b/autoparallel/optimize_sharding.py index 137be51b..be2a3640 100644 --- a/autoparallel/optimize_sharding.py +++ b/autoparallel/optimize_sharding.py @@ -114,10 +114,11 @@ def _get_next_name(name): class ShardingOptimizer: - def __init__(self, gm, mesh): + def __init__(self, gm, mesh, fake_mode): self.gm = gm self.graph = gm.graph self.mesh = mesh + self.fake_mode = fake_mode self.node_map = {node: i for i, node in enumerate(self.graph.nodes)} self.strats = self.build_sharding_metadata() # ds: Decision variables dictionary mapping (s_i, argi, ss, ii) -> ILP variable data @@ -147,7 +148,12 @@ def build_sharding_metadata(self): torch.fx.Node, lambda x: x.meta["val"], node.kwargs ) strat = get_placement_options( - self.mesh, node.target, user_strats, user_args, user_kwargs + self.mesh, 
+ node.target, + user_strats, + user_args, + user_kwargs, + self.fake_mode, ) strats[node] = strat elif node.op == "output": @@ -194,7 +200,8 @@ def build_ds(self): "num_output_strat": len(s.strategies), } for ss, ssi in enumerate(s.strategies): - compute_cost = estimate_strategy_runtime_cost(node, ssi) + with self.fake_mode: + compute_cost = estimate_strategy_runtime_cost(node, ssi) for argi, xxi in enumerate(ssi.redistribute_cost): for ii, comm_cost in enumerate(xxi): va = pulp.LpVariable( diff --git a/autoparallel/utils.py b/autoparallel/utils.py index bb305256..83be36b9 100644 --- a/autoparallel/utils.py +++ b/autoparallel/utils.py @@ -18,8 +18,9 @@ ) -def _get_meta_tensors_for_op(op, user_args, user_kwargs): - out_t = op(*user_args, **user_kwargs) +def _get_meta_tensors_for_op(op, user_args, user_kwargs, fake_mode): + with fake_mode: + out_t = op(*user_args, **user_kwargs) if isinstance(out_t, torch.Tensor): out_tensor_meta = TensorMeta(out_t.shape, out_t.stride(), out_t.dtype) @@ -40,8 +41,10 @@ def _get_meta_tensors_for_op(op, user_args, user_kwargs): return out_tensor_meta, input_tensor_metas -def propagate_tensor_meta(op, user_args, user_kwargs, out_strat): - new_tensor_meta, tensor_metas = _get_meta_tensors_for_op(op, user_args, user_kwargs) +def propagate_tensor_meta(op, user_args, user_kwargs, out_strat, fake_mode): + new_tensor_meta, tensor_metas = _get_meta_tensors_for_op( + op, user_args, user_kwargs, fake_mode + ) for strat in out_strat.strategies: if isinstance(new_tensor_meta, TensorMeta): @@ -104,7 +107,9 @@ def fill_missing_redistribute_cost(op, specs, out_strat): strat.redistribute_cost = redistribute_costs -def _generate_dummy_strategy(mesh, op, user_args, user_kwargs, input_strategies): +def _generate_dummy_strategy( + mesh, op, user_args, user_kwargs, input_strategies, fake_mode +): from torch.distributed.tensor._dtensor_spec import DTensorSpec from torch.distributed.tensor._op_schema import OpSpec from torch.distributed.tensor.placement_types import Replicate @@ -112,7 +117,7 @@ def _generate_dummy_strategy(mesh, op, user_args, user_kwargs, input_strategies) placements = (Replicate(),) * mesh.ndim out_tensor_meta, input_tensor_metas = _get_meta_tensors_for_op( - op, user_args, user_kwargs + op, user_args, user_kwargs, fake_mode ) input_specs = [ @@ -148,7 +153,7 @@ def _generate_dummy_strategy(mesh, op, user_args, user_kwargs, input_strategies) return out_strat -def get_placement_options(mesh, op, specs, user_args, user_kwargs): +def get_placement_options(mesh, op, specs, user_args, user_kwargs, fake_mode): # print(op) if op in _op_rules: @@ -189,9 +194,11 @@ def get_placement_options(mesh, op, specs, user_args, user_kwargs): ) else: print(f"Ops that need to be implemented {op}") - out_strat = _generate_dummy_strategy(mesh, op, user_args, user_kwargs, strat) + out_strat = _generate_dummy_strategy( + mesh, op, user_args, user_kwargs, strat, fake_mode + ) - propagate_tensor_meta(op, user_args, user_kwargs, out_strat) + propagate_tensor_meta(op, user_args, user_kwargs, out_strat, fake_mode) fill_missing_redistribute_cost(op, specs, out_strat) out_strat = remove_invalid_configs(out_strat, mesh) From 5d79becde97f9c2ab47f2894bb48128c23bd5857 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Sat, 26 Jul 2025 06:56:21 +0000 Subject: [PATCH 17/23] Use more representative values for DS3 example --- examples/example_ds3.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/examples/example_ds3.py b/examples/example_ds3.py index 
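Patch 16 threads a FakeTensorMode through tensor-meta propagation and flop counting so these probe executions never allocate real memory. A minimal sketch of the mechanism, independent of the AutoParallel plumbing:

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

# Ops run under FakeTensorMode only propagate shapes/dtypes/strides, so even
# very large intermediates cost no memory.
fake_mode = FakeTensorMode()
with fake_mode:
    x = torch.randn(4096, 8192, dtype=torch.bfloat16)
    w = torch.randn(8192, 16384, dtype=torch.bfloat16)
    out = x @ w
print(type(out).__name__, out.shape, out.dtype)
# FakeTensor torch.Size([4096, 16384]) torch.bfloat16
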
886e4aea..9f207191 100644 --- a/examples/example_ds3.py +++ b/examples/example_ds3.py @@ -809,7 +809,7 @@ def init_weights( ) -world_size = 256 +world_size = 2048 fake_store = FakeStore() torch.distributed.init_process_group( @@ -818,14 +818,14 @@ def init_weights( # mesh = torch.distributed.device_mesh.init_device_mesh("cuda", (world_size,), mesh_dim_names=("dp",)) mesh = torch.distributed.device_mesh.init_device_mesh( "cuda", - (world_size // 8, 8), + (world_size // 32, 32), mesh_dim_names=( "dp", - "tp", + "ep", ), ) -bs = 4 * mesh.shape[0] +bs = 4 * mesh.shape[0] * mesh.shape[1] seqlen = 1024 dim = 4096 @@ -845,7 +845,8 @@ def input_fn(): autop = AutoParallel(model, input_fn, mesh) autop.add_parameter_memory_constraint(low=None, high=None) -x_sharding = (Shard(0), Replicate()) +# x_sharding = (Shard(0), Replicate()) +x_sharding = (Shard(0), Shard(0)) autop.add_input_constraints([x_sharding]) autop.add_output_constraints([x_sharding]) @@ -855,16 +856,18 @@ def input_fn(): # run weight init on our sharded DTensor params parallel_mod.to_empty(device="cuda") -parallel_mod.init_weights(init_std=0.02, buffer_device="cuda") # maybe not correct value +parallel_mod.init_weights( + init_std=0.02, buffer_device="cuda" +) # maybe not correct value # # now let's run it x = ( torch.randn( # 0, # args.vocab_size, - (bs // mesh.shape[0], seqlen, dim), + (bs // mesh.shape[0] // mesh.shape[1], seqlen, dim), device=torch.device("cuda"), - dtype=torch.bfloat16 + dtype=torch.bfloat16, ), ) out = parallel_mod(*x) From daea5a24a1a03c8d2b2bcf1803ddadad113de78c Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Sat, 26 Jul 2025 20:21:24 +0000 Subject: [PATCH 18/23] Add approximate flop formula to grouped_mm There was no flop formula, which was making the solver think that computing this op is free --- autoparallel/compute_estimation.py | 33 +++++++++++++++++++++++++++++- 1 file changed, 32 insertions(+), 1 deletion(-) diff --git a/autoparallel/compute_estimation.py b/autoparallel/compute_estimation.py index e13ef489..291c4eb2 100644 --- a/autoparallel/compute_estimation.py +++ b/autoparallel/compute_estimation.py @@ -8,7 +8,38 @@ import torch from torch.utils._pytree import tree_map_only -from torch.utils.flop_counter import FlopCounterMode +from torch.utils.flop_counter import FlopCounterMode, register_flop_formula + + +@register_flop_formula(torch.ops.aten._grouped_mm) +def gmm_flop( + a_shape, b_shape, offs_shape=None, bias_shape=None, out_shape=None, **kwargs +) -> int: + """Count flops for the gmm operation.""" + # Inputs should be a list of length 2. + # Inputs contains the shapes of two tensor + if len(a_shape) == 2: + assert offs_shape is not None + (b,) = offs_shape + m0, k = a_shape + # assumption: assume roughtly balanced, so falls-back to bmm + m = m0 // b + else: + assert offs_shape is None + b, m, k = a_shape + if len(b_shape) == 2: + assert offs_shape is not None + (b2,) = offs_shape + k2, n0 = b_shape + # assumption: assume roughtly balanced, so falls-back to bmm + n = n0 // b2 + else: + b2, k2, n = b_shape + assert b == b2 + assert k == k2 + # NB(chilli): Should be 2 * k - 1 technically for FLOPs. 
+ flop = b * m * n * 2 * k + return flop @dataclass From 418ad55b1dd7ecb59906519959c825a3fbdf88f7 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Sun, 27 Jul 2025 20:08:31 +0000 Subject: [PATCH 19/23] Glimpses of having DeepSeekV3 returning a reasonable solution This is still approximate as we can't evenly shard on the tokens, but doing this prior to see if we can introduce a DynamicShard primitive --- autoparallel/optimize_sharding.py | 7 +- autoparallel/propagation_rules.py | 121 +++++++++++++++ autoparallel/utils.py | 21 +++ examples/example_ds3.py | 240 +++++++++++++++++++++++++++--- 4 files changed, 365 insertions(+), 24 deletions(-) diff --git a/autoparallel/optimize_sharding.py b/autoparallel/optimize_sharding.py index 594efbcb..ab91f3ff 100644 --- a/autoparallel/optimize_sharding.py +++ b/autoparallel/optimize_sharding.py @@ -490,7 +490,12 @@ def print_costs_for_node(self, node, arg=0, **kwargs): from torch.distributed.tensor._op_schema import _pretty_print_spec tgt_strat = self.strats[node] - src_strat = self.strats[node.args[arg]] + # Use this instead of node.all_input_nodes because there could be + # duplicate nodes that get removed + all_input_nodes = [ + x for x in tree_flatten(node.args)[0] if isinstance(x, torch.fx.Node) + ] + src_strat = self.strats[all_input_nodes[arg]] src_placements = [""] + [ _pretty_print_spec(x.output_specs) for x in src_strat.strategies ] diff --git a/autoparallel/propagation_rules.py b/autoparallel/propagation_rules.py index 4fba8afb..db8fd24b 100644 --- a/autoparallel/propagation_rules.py +++ b/autoparallel/propagation_rules.py @@ -668,6 +668,127 @@ def index_rule(mesh, op_schema): return out_strat +@register_opschema_rule(torch.ops.aten.sort.stable) +def sort_rule(mesh, op_schema): + op = torch.ops.aten.topk.default + out_strat = torch.distributed.tensor.DTensor._op_dispatcher.sharding_propagator.op_strategy_funcs[ + op + ]( + op_schema + ) + return out_strat + + +@register_opschema_rule(torch.ops.aten.gather.default) +def gather_strategy(mesh, op_schema): + from torch.distributed.tensor._op_schema import PlacementList + from torch.distributed.tensor._ops._embedding_ops import _MaskPartial + from torch.distributed.tensor._ops.utils import expand_to_full_mesh_op_strategy + + input_strategy = op_schema.args_schema[0] + dim = op_schema.args_schema[1] + index_strategy = op_schema.args_schema[2] + + input_shape = input_strategy.shape + index_shape = index_strategy.shape + + single_mesh_dim_strategies = [] + + # placement list stores placements of [output, input, index] + # first we always have replicate all for inputs and output + all_replicate: PlacementList = [Replicate()] * 3 + single_mesh_dim_strategies.append(all_replicate) + + # input sharding, input sharded, index accepts mask partial, output follows index + # this only works when the input is sharded on the gather dimension, and + # index has size 1 on the gather dimension + if index_shape[dim] == 1: + index_partial_placement = _MaskPartial(offset_shape=input_shape, offset_dim=dim) + input_sharding: PlacementList = [ + index_partial_placement, + Shard(dim), + index_partial_placement, + ] + single_mesh_dim_strategies.append(input_sharding) + + # index sharding, input replicated, index sharded, output follows index + # this only works when the sharding dimension is the gather dimension + index_sharding: PlacementList = [Shard(dim), Replicate(), Shard(dim)] + single_mesh_dim_strategies.append(index_sharding) + + if len(input_shape) == len(index_shape): + for d in range(len(input_shape)): + if 
d != dim: + sharding = [Shard(d), Shard(d), Shard(d)] + single_mesh_dim_strategies.append(sharding) + + return expand_to_full_mesh_op_strategy( + mesh, op_schema, single_mesh_dim_strategies, input_index=1 + ) + + +@register_opschema_rule(torch.ops.aten.scatter_add.default) +def scatter_add_strategy(mesh, op_schema): + from torch.distributed.tensor._op_schema import PlacementList + + # from torch.distributed.tensor._ops._embedding_ops import _MaskPartial + from torch.distributed.tensor._ops.utils import expand_to_full_mesh_op_strategy + + input_strategy = op_schema.args_schema[0] + dim = op_schema.args_schema[1] + index_strategy = op_schema.args_schema[2] + # src_strategy = op_schema.args_schema[3] + + input_shape = input_strategy.shape + index_shape = index_strategy.shape + + single_mesh_dim_strategies = [] + + # placement list stores placements of [output, input, index] + # first we always have replicate all for inputs and output + all_replicate: PlacementList = [Replicate()] * 4 + single_mesh_dim_strategies.append(all_replicate) + + """ + # input sharding, input sharded, index accepts mask partial, output follows index + # this only works when the input is sharded on the gather dimension, and + # index has size 1 on the gather dimension + if index_shape[dim] == 1: + index_partial_placement = _MaskPartial(offset_shape=input_shape, offset_dim=dim) + input_sharding: PlacementList = [ + index_partial_placement, + Shard(dim), + index_partial_placement, + ] + single_mesh_dim_strategies.append(input_sharding) + """ + # index sharding, input replicated, index sharded, output follows index + # this only works when the sharding dimension is the gather dimension + index_sharding: PlacementList = [Shard(dim), Replicate(), Shard(dim), Shard(dim)] + single_mesh_dim_strategies.append(index_sharding) + + if len(input_shape) == len(index_shape): + for d in range(len(input_shape)): + if d != dim: + sharding = [Shard(d), Shard(d), Shard(d), Shard(d)] + single_mesh_dim_strategies.append(sharding) + + return expand_to_full_mesh_op_strategy( + mesh, op_schema, single_mesh_dim_strategies, input_index=1 + ) + + +@register_opschema_rule(torch.ops.aten.slice_scatter.default) +def slice_scatter_rule(mesh, op_schema): + op = torch.ops.aten.slice_scatter.default + out_strat = torch.distributed.tensor.DTensor._op_dispatcher.sharding_propagator.op_strategy_funcs[ + op + ]( + op_schema + ) + return out_strat + + def sdpa_rule(op, mesh, op_schema): out_strat = torch.distributed.tensor.DTensor._op_dispatcher.sharding_propagator.op_strategy_funcs[ op diff --git a/autoparallel/utils.py b/autoparallel/utils.py index 733d8868..8d81e0dd 100644 --- a/autoparallel/utils.py +++ b/autoparallel/utils.py @@ -176,12 +176,32 @@ def _generate_dummy_strategy( return out_strat +def keep_unique_configs(op_strat): + added = set() + filtered_strats = [] + for strat in op_strat.strategies: + input_specs = strat.input_specs + output_specs = strat.output_specs + if isinstance(input_specs, list): + input_specs = tuple(input_specs) + if isinstance(output_specs, list): + output_specs = tuple(output_specs) + key = (input_specs, output_specs) + if key in added: + continue + + added.add(key) + filtered_strats.append(strat) + return OpStrategy(filtered_strats) + + def get_placement_options(mesh, op, specs, user_args, user_kwargs, fake_mode): # print(op) if op in _op_rules: out_strat = _op_rules[op](mesh, specs) out_strat = remove_invalid_configs(out_strat, mesh) + out_strat = keep_unique_configs(out_strat) return out_strat strat = [] @@ -224,6 
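keep_unique_configs drops strategies whose (input_specs, output_specs) pair has already been seen, so the ILP does not carry redundant columns. A tiny stand-in sketch of that filtering, with plain tuples in place of DTensorSpec:

# Plain tuples stand in for DTensorSpec; the real key is built the same way.
strategies = [
    (("Shard(0)",), "Shard(0)"),
    (("Replicate()",), "Replicate()"),
    (("Shard(0)",), "Shard(0)"),      # duplicate of the first entry
]

seen, unique = set(), []
for strat in strategies:
    if strat in seen:
        continue
    seen.add(strat)
    unique.append(strat)
print(len(unique))  # 2
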
+244,7 @@ def get_placement_options(mesh, op, specs, user_args, user_kwargs, fake_mode): propagate_tensor_meta(op, user_args, user_kwargs, out_strat, fake_mode) fill_missing_redistribute_cost(op, specs, out_strat) out_strat = remove_invalid_configs(out_strat, mesh) + out_strat = keep_unique_configs(out_strat) return out_strat diff --git a/examples/example_ds3.py b/examples/example_ds3.py index 9f207191..43f6e28f 100644 --- a/examples/example_ds3.py +++ b/examples/example_ds3.py @@ -19,6 +19,149 @@ from autoparallel.api import AutoParallel +@torch.library.custom_op("autoparallel::batched_grouped_mm", mutates_args=()) +def batched_grouped_mm( + mat1: torch.Tensor, mat2: torch.Tensor, offs: torch.Tensor +) -> torch.Tensor: + assert offs.ndim == 2 + assert mat1.ndim == 3 + assert mat2.ndim == 3, f"{mat2.shape}" + res = [] + for a, off in zip(mat1, offs): + res.append(a, mat2, off) + return torch.stack(res, 0) + + +def setup_context(ctx, inputs, output) -> torch.Tensor: + mat1, mat2, offs = inputs + ctx.save_for_backward(mat1, mat2, offs) + + +def backward(ctx, grad): + mat1, mat2, offs = ctx.saved_tensors + grad1 = batched_grouped_mm(grad, mat2.transpose(-2, -1), offs) + grad2 = batched_grouped_mm(mat1.transpose(-2, -1), grad, offs) + return grad1, grad2, None + + +torch.library.register_autograd( + "autoparallel::batched_grouped_mm", backward, setup_context=setup_context +) + + +@batched_grouped_mm.register_fake +def _(mat1: torch.Tensor, mat2: torch.Tensor, offs: torch.Tensor) -> torch.Tensor: + out = torch.empty( + mat1.shape[0], + mat1.shape[1], + mat2.shape[2], + dtype=mat1.dtype, + device=mat1.device, + ) + return out + + +@torch.library.custom_op("autoparallel::batched_histc", mutates_args=()) +def batched_histc( + x: torch.Tensor, bins: int = 100, min: int = 0, max: int = 0 +) -> torch.Tensor: + assert x.ndim == 2 + out = [] + for t in x: + out.append(torch.histc(t, bins, min, max)) + return torch.stack(out, 0) + + +@batched_histc.register_fake +def batched_histc( + x: torch.Tensor, bins: int = 100, min: int = 0, max: int = 0 +) -> torch.Tensor: + assert max - min == bins + out = torch.empty((x.shape[0], bins), dtype=torch.int64, device=x.device) + return out + + +from torch.utils.flop_counter import register_flop_formula + + +@register_flop_formula(torch.ops.autoparallel.batched_grouped_mm) +def gmm_flop( + a_shape, b_shape, offs_shape=None, bias_shape=None, out_shape=None, **kwargs +) -> int: + """Count flops for the gmm operation.""" + # Inputs should be a list of length 2. + # Inputs contains the shapes of two tensor + if len(a_shape) == 2: + assert offs_shape is not None + # b, = offs_shape + # m0, k = a_shape + # assumption: assume roughtly balanced, so falls-back to bmm + # m = m0 // b + else: + # assert offs_shape is None + ( + b0, + bb, + ) = offs_shape + b, m0, k = a_shape + m = m0 // bb + if len(b_shape) == 2: + assert offs_shape is not None + # b2, _ = offs_shape + # k2, n0 = b_shape + # assumption: assume roughtly balanced, so falls-back to bmm + # n = n0 // b2 + else: + b2, k2, n = b_shape + assert b0 == b + assert bb == b2 + assert k == k2 + # NB(chilli): Should be 2 * k - 1 technically for FLOPs. 
+ flop = b * m * n * 2 * k + return flop + + +from torch.distributed.tensor.placement_types import Partial, Replicate, Shard + +from autoparallel.propagation_rules import register_opschema_rule + + +@register_opschema_rule(torch.ops.autoparallel.batched_grouped_mm.default) +def _(mesh, op_schema): + from torch.distributed.tensor._op_schema import PlacementList + from torch.distributed.tensor._ops.utils import expand_to_full_mesh_op_strategy + + mat1_strategy = op_schema.args_schema[0] + mat2_strategy = op_schema.args_schema[1] + index_strategy = op_schema.args_schema[2] + + # input_shape = input_strategy.shape + # index_shape = index_strategy.shape + + assert len(mat1_strategy.shape) == 3 + assert len(mat2_strategy.shape) == 3 + assert len(index_strategy.shape) == 2 + + single_mesh_dim_strategies = [] + + # placement list stores placements of [output, mat1, mat2, offs] + # first we always have replicate all for inputs and output + all_replicate: PlacementList = [Replicate()] * 4 + single_mesh_dim_strategies.append(all_replicate) + single_mesh_dim_strategies.append([Shard(0), Shard(0), Replicate(), Shard(0)]) + single_mesh_dim_strategies.append([Shard(2), Replicate(), Shard(2), Replicate()]) + single_mesh_dim_strategies.append([Partial(), Shard(2), Shard(1), Replicate()]) + + # FIXME: this is wrong, but approximation for more complex dynamic stuff + # we might want to introduce DynamicShard which splits the shards on + # dynamic sizes maybe? + single_mesh_dim_strategies.append([Shard(1), Shard(1), Shard(0), Shard(1)]) + + return expand_to_full_mesh_op_strategy( + mesh, op_schema, single_mesh_dim_strategies, input_index=1 + ) + + # Reference: https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/model.py @dataclass class DeepSeekV3ModelArgs: @@ -554,7 +697,7 @@ def _run_experts_for_loop( return out - @expert_parallel + # @expert_parallel @staticmethod def _run_experts_grouped_mm( w1: torch.Tensor, @@ -564,9 +707,10 @@ def _run_experts_grouped_mm( num_tokens_per_expert: torch.Tensor | None = None, ) -> torch.Tensor: if num_tokens_per_expert is not None: - offsets = torch.cumsum(num_tokens_per_expert, dim=0, dtype=torch.int32) + offsets = torch.cumsum(num_tokens_per_expert, dim=-1, dtype=torch.int32) # grouped mm between a 2D tensor and a 3D tensor - assert x.dim() == 2 + # assert x.dim() == 2 + assert x.dim() == 3 else: offsets = None # fall back to regular bmm between 3D tensors @@ -576,9 +720,15 @@ def _run_experts_grouped_mm( x.dtype == w1.dtype == w2.dtype == w3.dtype == torch.bfloat16 ), "torch._grouped_mm only supports bf16 dtypes" - h = F.silu(torch._grouped_mm(x, w1, offs=offsets)) - h = h * torch._grouped_mm(x, w3, offs=offsets) - out = torch._grouped_mm(h, w2, offs=offsets) + # TODO: maybe introduce batched group_mm ? + if offsets is None: + h = F.silu(torch._grouped_mm(x, w1, offs=offsets)) + h = h * torch._grouped_mm(x, w3, offs=offsets) + out = torch._grouped_mm(h, w2, offs=offsets) + else: + h = F.silu(torch.ops.autoparallel.batched_grouped_mm(x, w1, offs=offsets)) + h = h * torch.ops.autoparallel.batched_grouped_mm(x, w3, offs=offsets) + out = torch.ops.autoparallel.batched_grouped_mm(h, w2, offs=offsets) return out @@ -624,7 +774,7 @@ def forward( and currently EP is not supporting node limit routing yet. Args: - x (torch.Tensor): Input tensor with shape ``(bs*slen, dim)``. + x (torch.Tensor): Input tensor with shape ``(bs, slen, dim)``. 
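The batched_grouped_mm and batched_histc wrappers follow the standard torch.library recipe: define the op, register a fake (shape-only) kernel, and register autograd. A self-contained toy op showing the same pattern; "demo::scale" is invented for illustration and is not part of the patch:

import torch

@torch.library.custom_op("demo::scale", mutates_args=())
def scale(x: torch.Tensor, alpha: float) -> torch.Tensor:
    return x * alpha

@scale.register_fake
def _(x: torch.Tensor, alpha: float) -> torch.Tensor:
    # Shape/dtype-only kernel so the op also works under fake/meta tracing.
    return torch.empty_like(x)

def _setup_context(ctx, inputs, output):
    _, alpha = inputs
    ctx.alpha = alpha

def _backward(ctx, grad):
    # Gradients for (x, alpha); alpha is a float, so it gets None.
    return grad * ctx.alpha, None

torch.library.register_autograd("demo::scale", _backward, setup_context=_setup_context)

x = torch.randn(4, requires_grad=True)
torch.ops.demo.scale(x, 2.0).sum().backward()
print(x.grad)  # tensor([2., 2., 2., 2.])
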
Returns: routed_input (torch.Tensor): @@ -634,31 +784,49 @@ def forward( num_tokens_per_expert (torch.Tensor): Number of tokens assigned to each expert with shape ``(num_experts,)``. """ - # scores shape (bs*slen, num_experts) + # scores shape (bs, slen, num_experts) + + # x = x.reshape(bs * slen, dim) scores = self.gate(x) # By default, sigmoid or softmax is performed in float32 to avoid loss explosion if self.use_sigmoid: scores = torch.sigmoid(scores.to(torch.float32)) else: - scores = F.softmax(scores.to(torch.float32), dim=1) + scores = F.softmax(scores.to(torch.float32), dim=-1) # top scores shape (bs*slen, top_k) # NOTE: The expert_bias is only used for routing. The gating value # top_scores is still derived from the original scores. if expert_bias is not None: _, selected_experts_indices = torch.topk( - scores + expert_bias, k=self.top_k, dim=1 + scores + expert_bias, k=self.top_k, dim=-1 ) - top_scores = scores.gather(dim=1, index=selected_experts_indices) + top_scores = scores.gather(dim=-1, index=selected_experts_indices) else: top_scores, selected_experts_indices = torch.topk( - scores, k=self.top_k, dim=1 + scores, k=self.top_k, dim=-1 ) # group tokens together by expert indices from 0 to num_experts and pass that to experts forward - num_tokens_per_expert = torch.histc( - selected_experts_indices.view(-1), + + # TODO: reshape here to allow for group-based routing + local_batch_size = 4 + num_gpus_participating = ( + 32 # 64 # NOTE: I tweaked those values so that batch sharding can be done + ) + num_experts_per_groups = local_batch_size * num_gpus_participating + + # num_tokens_per_expert = torch.histc( + # selected_experts_indices.view(-1), + # bins=self.num_experts, + # min=0, + # max=self.num_experts, + # ) + num_tokens_per_expert = torch.ops.autoparallel.batched_histc( + selected_experts_indices.unflatten(0, (-1, num_experts_per_groups)).flatten( + 1 + ), bins=self.num_experts, min=0, max=self.num_experts, @@ -667,17 +835,30 @@ def forward( # Reorder the token indices to match the order of the experts # token_indices_experts_sorted shape (bs*slen*top_k,) token_indices_experts_sorted = torch.argsort( - selected_experts_indices.view(-1), stable=True + # selected_experts_indices.view(-1), stable=True + selected_experts_indices.unflatten(0, (-1, num_experts_per_groups)).flatten( + 1 + ), + dim=-1, + stable=True, ) # reorder the scores to match the order of the token indices - top_scores = top_scores.view(-1)[token_indices_experts_sorted] + # TODO: Shard() can have negative dims because of gather rules if we pass a -1 index, is that expected? + # we should probably normalize this this, like we do in topk for e.g. 
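The reworked router keeps everything 2-D per group: topk along the last dim, a histogram of expert assignments, a stable argsort to group token slots by expert, and a gather to reorder the scores. A toy walkthrough of that bookkeeping on a single group with arbitrary numbers:

import torch

# 8 tokens, 4 experts, top_k = 2, one routing group; tiny tensors for clarity.
torch.manual_seed(0)
num_experts, top_k = 4, 2
scores = torch.rand(8, num_experts).softmax(dim=-1)

top_scores, selected = torch.topk(scores, k=top_k, dim=-1)        # (8, 2)
flat = selected.reshape(1, -1)                                     # one group of slots
tokens_per_expert = torch.histc(flat.float(), bins=num_experts, min=0, max=num_experts)
order = torch.argsort(flat, dim=-1, stable=True)                   # slots grouped by expert
sorted_scores = top_scores.reshape(1, -1).gather(1, order)         # scores in the same order
token_ids = order // top_k                                         # map slots back to tokens
print(tokens_per_expert.tolist(), token_ids.shape)
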
+ # top_scores = top_scores.view(-1)[token_indices_experts_sorted] + # top_scores = top_scores.view_as(token_indices_experts_sorted).gather(-1, token_indices_experts_sorted) + top_scores = top_scores.view_as(token_indices_experts_sorted).gather( + 1, token_indices_experts_sorted + ) token_indices_experts_sorted = token_indices_experts_sorted // self.top_k top_scores = ( top_scores * self.route_sclaing_factor ) # must multiply the scaling factor return top_scores, token_indices_experts_sorted, num_tokens_per_expert + # return top_scores.flatten(0, 1), token_indices_experts_sorted.flatten(0, 1), num_tokens_per_expert + # return top_scores.flatten(0, 1), token_indices_experts_sorted, num_tokens_per_expert def init_weights(self, init_std: float): nn.init.trunc_normal_(self.gate.weight, mean=0.0, std=init_std) @@ -752,7 +933,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: top_scores, token_indices, num_tokens_per_expert, - ) = self.router(x.reshape(bs * slen, dim), self.expert_bias) + # ) = self.router(x.reshape(bs * slen, dim), self.expert_bias) + ) = self.router(x, self.expert_bias) # tokens_per_expert will be used to update the expert bias for load balancing. # Prevent extra local tokens accumulation on evaluation or activation recomputation. @@ -760,32 +942,44 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: with torch.no_grad(): self.tokens_per_expert.add_(num_tokens_per_expert) # shape (bs*slen*top_k, dim) - token_indices = token_indices.reshape(-1, 1).expand(-1, dim) + # token_indices = token_indices.reshape(-1, 1).expand(-1, dim) + token_indices = token_indices[..., None].expand(-1, -1, dim) # shape (bs*slen*top_k, dim) + # TODO: change here as well to support groups routed_input = torch.gather( - x.view(-1, dim), - dim=0, + # x.view(-1, dim), + x.view(token_indices.shape[0], -1, dim), + dim=1, # 0, index=token_indices, ) + # routed_input = routed_input.flatten(0, 1) + # token_indices = token_indices.flatten(0, 1) # shape (bs*slen*top_k, dim) routed_output = self.experts(routed_input, num_tokens_per_expert) + # routed_output = routed_output.flatten(0, 1) + # token_indices = token_indices.flatten(0, 1) + routed_output = (routed_output.to(torch.float32) * top_scores.unsqueeze(-1)).to( x.dtype ) # shared expert if self.shared_expert is not None: + # out = self.shared_expert(x.reshape(1, bs * slen, dim)).reshape( + # bs * slen, dim + # ) out = self.shared_expert(x.reshape(1, bs * slen, dim)).reshape( - bs * slen, dim + token_indices.shape[0], -1, dim ) else: out = torch.zeros_like(x.reshape(bs * slen, dim)) # Accumulate multiple expert results becase each token can be routed to multiple experts - out = out.scatter_add(dim=0, index=token_indices, src=routed_output) + # out = out.scatter_add(dim=0, index=token_indices, src=routed_output) + out = out.scatter_add(dim=1, index=token_indices, src=routed_output) out = out.reshape(bs, slen, dim) return out @@ -836,7 +1030,7 @@ def input_fn(): ) -args = DeepSeekV3ModelArgs(dim=dim, n_layers=1) +args = DeepSeekV3ModelArgs(dim=dim, n_layers=1, load_balance_coeff=None) # parallelize the model with torch.device("meta"): From 6d5747aa3cb1347476d024f281b2f8a9f25fca03 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Wed, 30 Jul 2025 11:22:53 +0000 Subject: [PATCH 20/23] Use with_implicit_strategies instead of my generate_dummy_strategy --- autoparallel/optimize_sharding.py | 6 +--- autoparallel/utils.py | 59 ++----------------------------- 2 files changed, 4 insertions(+), 61 deletions(-) diff --git a/autoparallel/optimize_sharding.py 
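The MoE forward now dispatches tokens with a gather along dim 1 and recombines expert outputs with scatter_add using the same expanded index, so each token's contributions land back in its original slot. A small round-trip check of that pattern (top_k = 1, with the experts replaced by a scale factor):

import torch

# One group, 6 tokens, hidden dim 4; with top_k = 1 the routing is a permutation.
g, n_tokens, dim = 1, 6, 4
x = torch.randn(g, n_tokens, dim)
routing = torch.tensor([[3, 1, 5, 0, 2, 4]])                  # token chosen per slot

index = routing[..., None].expand(-1, -1, dim)                # (g, n_tokens, dim)
routed_input = torch.gather(x, dim=1, index=index)            # dispatch
routed_output = routed_input * 2.0                            # stand-in for the experts

out = torch.zeros_like(x)
out = out.scatter_add(dim=1, index=index, src=routed_output)  # combine
assert torch.allclose(out, 2.0 * x)
print(out.shape)  # torch.Size([1, 6, 4])
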
b/autoparallel/optimize_sharding.py index d6af093c..939a0063 100644 --- a/autoparallel/optimize_sharding.py +++ b/autoparallel/optimize_sharding.py @@ -177,11 +177,7 @@ def build_sharding_metadata(self): strats[node] = strat else: strat = get_placement_options( - self.mesh, - node.target, - user_strats, - user_args, - user_kwargs, + self.mesh, node.target, user_strats, user_args, user_kwargs ) strats[node] = strat elif node.op == "output": diff --git a/autoparallel/utils.py b/autoparallel/utils.py index d3f8bb77..23dc0b45 100644 --- a/autoparallel/utils.py +++ b/autoparallel/utils.py @@ -17,7 +17,7 @@ from torch.distributed.tensor.placement_types import Replicate from torch.utils._pytree import tree_flatten, tree_map_only -from .dtensor_util import get_op_strategy +from .dtensor_util import get_op_strategy, with_implicit_strategies from .propagation_rules import ( TENSOR_FACTORY_OPS, _op_partial_rules, @@ -128,50 +128,6 @@ def fill_missing_redistribute_cost(op, specs, out_strat): strat.redistribute_cost = redistribute_costs -def _generate_dummy_strategy(mesh, op, user_args, user_kwargs, input_strategies): - from torch.distributed.tensor._dtensor_spec import DTensorSpec - from torch.distributed.tensor._op_schema import OpSpec - from torch.distributed.tensor.placement_types import Replicate - - placements = (Replicate(),) * mesh.ndim - - out_tensor_meta, input_tensor_metas = _get_meta_tensors_for_op( - op, user_args, user_kwargs - ) - - input_specs = [ - DTensorSpec(mesh=mesh, placements=placements, tensor_meta=tm) - for tm in input_tensor_metas - ] - if isinstance(out_tensor_meta, TensorMeta): - output_spec = DTensorSpec( - mesh=mesh, placements=placements, tensor_meta=out_tensor_meta - ) - else: - output_spec = tuple( - DTensorSpec(mesh=mesh, placements=placements, tensor_meta=tm) - for tm in out_tensor_meta - ) - - out_strat = OpSpec(output_specs=output_spec, input_specs=input_specs) - num_input_args = len(input_tensor_metas) - input_strategies_flat = [ - x for x in tree_flatten(input_strategies)[0] if isinstance(x, OpStrategy) - ] - assert num_input_args == len( - input_strategies_flat - ), f"{op}, {num_input_args}, {len(input_strategies_flat)}" - redistribute_cost = [ - generate_redistribute_costs(input_strategies_flat[i], input_specs[i]) - for i in range(num_input_args) - ] - out_strat.redistribute_cost = redistribute_cost - - assert len(out_strat.redistribute_cost) == num_input_args - out_strat = OpStrategy([out_strat]) - return out_strat - - def keep_unique_configs(op_strat: OpStrategy) -> OpStrategy: added = set() filtered_strats = [] @@ -218,18 +174,9 @@ def get_placement_options(mesh, op, specs, user_args, user_kwargs): if op in _op_partial_rules: out_strat = _op_partial_rules[op](mesh, op_schema) - elif ( - op - in torch.distributed.tensor.DTensor._op_dispatcher.sharding_propagator.op_strategy_funcs - and op - not in { - torch.ops.aten._softmax_backward_data.default, - } - ): - out_strat = get_op_strategy(op, op_schema) else: - print(f"Ops that need to be implemented {op}") - out_strat = _generate_dummy_strategy(mesh, op, user_args, user_kwargs, strat) + with with_implicit_strategies(): + out_strat = get_op_strategy(op, op_schema) propagate_tensor_meta(op, user_args, user_kwargs, out_strat) fill_missing_redistribute_cost(op, specs, out_strat) From e0ae8a2dcf8b331a073fcf03aea434f9b3af128e Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Wed, 30 Jul 2025 21:05:19 +0000 Subject: [PATCH 21/23] [WIP] Convert view->mm->view into matmul --- autoparallel/api.py | 52 +++++++++++++- 
autoparallel/propagation_rules.py | 109 ++++++++++++++++++++++++++++++ examples/example_ds3.py | 40 ++++++++++- 3 files changed, 198 insertions(+), 3 deletions(-) diff --git a/autoparallel/api.py b/autoparallel/api.py index 29a8f816..054473b7 100644 --- a/autoparallel/api.py +++ b/autoparallel/api.py @@ -59,6 +59,8 @@ def delete_user_cb(n): # skip ops which return tuple if not isinstance(node.meta["val"], torch.Tensor): continue + if node.target != torch.ops.aten.view.default: + continue with graph.inserting_after(node): alias_node = graph.call_function(torch.ops.aten.alias.default, args=(node,)) alias_node.meta.update(node.meta) @@ -67,7 +69,6 @@ def delete_user_cb(n): return n != alias_node node.replace_all_uses_with(alias_node, delete_user_cb=delete_user_cb) - """ for node in graph.find_nodes(op="output")[0].all_input_nodes: @@ -84,6 +85,53 @@ def delete_user_cb(n): return gm +def _replace_view_mm_view_with_matmul(gm): + mm_nodes = gm.graph.find_nodes(op="call_function", target=torch.ops.aten.mm.default) + for node in mm_nodes: + first_input, second_input = node.all_input_nodes + if first_input.target == torch.ops.aten.view.default: + view_input = first_input.all_input_nodes[0] + users = list(node.users) + if ( + len(users) == 1 + and users[0].target == torch.ops.aten.view.default + and view_input.meta["val"].shape[:-1] == users[0].meta["val"].shape[:-1] + ): + print(f"Found matmul node {node}") + with gm.graph.inserting_before(node): + new_node = gm.graph.call_function( + torch.ops.aten.matmul.default, args=(view_input, second_input) + ) + new_node.meta.update(users[0].meta) + users[0].replace_all_uses_with(new_node) + + elif second_input.target == torch.ops.aten.view.default: + if first_input.target != torch.ops.aten.permute.default: + continue + if first_input.all_input_nodes[0].target != torch.ops.aten.view.default: + continue + orig_first = first_input.all_input_nodes[0].all_input_nodes[0] + orig_second = second_input.all_input_nodes[0] + users = list(node.users) + if ( + len(users) == 1 + and users[0].target == torch.ops.aten.permute.default + and orig_first.meta["val"].shape[:-1] + == orig_second.meta["val"].shape[:-1] + ): + print(f"Found matmul node {node}") + with gm.graph.inserting_before(node): + # TODO: check einsum equation + new_node = gm.graph.call_function( + torch.ops.aten.einsum.default, + args=("bmn,bmk->nk", [orig_first, orig_second]), + ) + new_node.meta.update(users[0].meta) + users[0].replace_all_uses_with(new_node) + gm.graph.eliminate_dead_code() + gm.recompile() + + def try_convert_fake_to_real(tensors): out = {} for k, t in tensors.items(): @@ -103,6 +151,7 @@ def _get_decomp_table(): decomp_table.pop(torch.ops.aten.embedding_dense_backward.default) decomp_table.pop(torch.ops.aten.native_layer_norm_backward.default) decomp_table.pop(torch.ops.aten._softmax_backward_data.default) + decomp_table.pop(torch.ops.aten._softmax.default) # decompose addmm to allow for TP on mm decomp_table.pop(torch.ops.aten.addmm.default) @@ -255,6 +304,7 @@ def build_model_graph(self): # think I trust the default FX DCE logic gm.graph.eliminate_dead_code() gm.recompile() + _replace_view_mm_view_with_matmul(gm) # disable pattern_matcher as it gets on our way # we basically want to remove noops in here prev = torch._inductor.config.pattern_matcher diff --git a/autoparallel/propagation_rules.py b/autoparallel/propagation_rules.py index e845a879..a4e77a64 100644 --- a/autoparallel/propagation_rules.py +++ b/autoparallel/propagation_rules.py @@ -864,3 +864,112 @@ def expand_rule(mesh, 
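The _replace_view_mm_view_with_matmul pass relies on two identities: reshaping to 2-D, multiplying, and reshaping back equals a direct 3-D matmul, and the transposed variant is an einsum contraction over the batch and token dims. A numeric check of both on arbitrary small shapes (the einsum equation below is the one the patch inserts, which the patch itself still marks as TODO):

import torch

B, M, K, N = 2, 3, 5, 7
x = torch.randn(B, M, K)
w = torch.randn(K, N)

fused = torch.matmul(x, w)                                   # what the pass inserts
unfused = (x.reshape(B * M, K) @ w).reshape(B, M, N)         # view -> mm -> view
assert torch.allclose(fused, unfused, atol=1e-6)

g = torch.randn(B, M, N)                                     # e.g. an upstream gradient
contracted = torch.einsum("bmn,bmk->nk", g, x)               # equation used by the pass
reference = g.reshape(B * M, N).T @ x.reshape(B * M, K)
assert torch.allclose(contracted, reference, atol=1e-5)
print("view->mm->view matches matmul/einsum")
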
op_schema_): for remov in to_remove: ss.redistribute_cost[0].insert(remov, math.inf) return out_strat + + +@register_opschema_rule(torch.ops.aten.matmul.default) +def matmul_rule(mesh, op_schema): + # inspired from _mm_like_strategy but removing shards on inexisting dimensions + from torch.distributed.tensor._ops._einsum_strategy import gen_einsum_strategies + from torch.distributed.tensor._ops._matrix_ops import is_tensor_shardable + + self_strategy, mat2_strategy = op_schema.args_schema + assert isinstance(self_strategy, OpStrategy) + assert isinstance(mat2_strategy, OpStrategy) + + self_shape = len(self_strategy.shape) + mat2_shape = len(mat2_strategy.shape) + # assert self_shape in + assert len(self_strategy.shape) == 3 + assert len(mat2_strategy.shape) == 2 + + mm_equation = "bmk,kn->bmn" + # mm_equation = "bmk,bmn->kn" + # mm_equation = "bmn,bmk->nk" + # generate all possible strategies for mm + mm_strategy = gen_einsum_strategies(mm_equation, mesh) + # from IPython import embed; embed(); sys.sdf + # filter out invalid strategies and associate costs + strategies = mm_strategy.strategies + filtered_strategies = [] + for strtg in strategies: + assert strtg.input_specs is not None + self_spec = strtg.input_specs[0] + mat2_spec = strtg.input_specs[1] + should_skip = False + for plc in self_spec.placements: + if plc.is_shard() and plc.dim >= len(self_strategy.shape): + should_skip = True + break + + for plc in mat2_spec.placements: + if plc.is_shard() and plc.dim >= len(mat2_strategy.shape): + should_skip = True + break + + if should_skip: + continue + + if is_tensor_shardable(self_strategy.shape, self_spec) and is_tensor_shardable( + mat2_strategy.shape, mat2_spec + ): + redistribute_cost = [ + generate_redistribute_costs(self_strategy, self_spec), + generate_redistribute_costs(mat2_strategy, mat2_spec), + ] + strtg.redistribute_cost = redistribute_cost + filtered_strategies.append(strtg) + + mm_strategy.strategies = filtered_strategies + return mm_strategy + + +@register_opschema_rule(torch.ops.aten.einsum.default) +def einsum_rule(mesh, op_schema): + # inspired from _mm_like_strategy but removing shards on inexisting dimensions + from torch.distributed.tensor._op_schema import TupleStrategy + from torch.distributed.tensor._ops._einsum_strategy import gen_einsum_strategies + from torch.distributed.tensor._ops._matrix_ops import is_tensor_shardable + + mm_equation, mat_strategy = op_schema.args_schema + assert isinstance(mm_equation, str) + assert isinstance(mat_strategy, TupleStrategy) + + assert len(mat_strategy.children) == 2, f"Only two args to einsum supported for now" + + self_strategy, mat2_strategy = mat_strategy.children + + # generate all possible strategies for mm + mm_strategy = gen_einsum_strategies(mm_equation, mesh) + # filter out invalid strategies and associate costs + strategies = mm_strategy.strategies + filtered_strategies = [] + for strtg in strategies: + assert strtg.input_specs is not None + self_spec = strtg.input_specs[0] + mat2_spec = strtg.input_specs[1] + should_skip = False + for plc in self_spec.placements: + if plc.is_shard() and plc.dim >= len(self_strategy.shape): + should_skip = True + break + + for plc in mat2_spec.placements: + if plc.is_shard() and plc.dim >= len(mat2_strategy.shape): + should_skip = True + break + + if should_skip: + continue + + if is_tensor_shardable(self_strategy.shape, self_spec) and is_tensor_shardable( + mat2_strategy.shape, mat2_spec + ): + redistribute_cost = [ + generate_redistribute_costs(self_strategy, self_spec), + 
generate_redistribute_costs(mat2_strategy, mat2_spec), + ] + strtg.redistribute_cost = redistribute_cost + filtered_strategies.append(strtg) + + mm_strategy.strategies = filtered_strategies + return mm_strategy diff --git a/examples/example_ds3.py b/examples/example_ds3.py index a98ccf1b..2c3654f4 100644 --- a/examples/example_ds3.py +++ b/examples/example_ds3.py @@ -28,7 +28,7 @@ def batched_grouped_mm( assert mat2.ndim == 3, f"{mat2.shape}" res = [] for a, off in zip(mat1, offs): - res.append(a, mat2, off) + res.append(torch._grouped_mm(a, mat2, off)) return torch.stack(res, 0) @@ -121,6 +121,31 @@ def gmm_flop( return flop +@register_flop_formula(torch.ops.aten.matmul) +def matmul_flop(a_shape, b_shape, out_shape=None, **kwargs) -> int: + assert len(a_shape) == 3 + assert len(b_shape) == 2 + b, m, k = a_shape + k1, n = b_shape + assert k == k1 + flop = b * m * n * 2 * k + return flop + + +@register_flop_formula(torch.ops.aten.einsum, get_raw=True) +def einsum_flop(eq, tensors, out=None, **kwargs) -> int: + assert len(tensors) == 2 + a_shape, b_shape = [x.shape for x in tensors] + assert len(a_shape) == 3 + assert len(b_shape) == 3 + b, m, k = a_shape + b1, k1, n = b_shape + assert b == b1 + assert k == k1 + flop = b * m * n * 2 * k + return flop + + from torch.distributed.tensor.placement_types import Partial, Replicate, Shard from autoparallel.propagation_rules import register_opschema_rule @@ -1012,7 +1037,7 @@ def init_weights( # mesh = torch.distributed.device_mesh.init_device_mesh("cuda", (world_size,), mesh_dim_names=("dp",)) mesh = torch.distributed.device_mesh.init_device_mesh( "cuda", - (world_size // 32, 32), + (world_size // 64, 64), mesh_dim_names=( "dp", "ep", @@ -1042,10 +1067,21 @@ def input_fn(): # x_sharding = (Shard(0), Replicate()) x_sharding = (Shard(0), Shard(0)) + mm_nodes = autop.gm.graph.find_nodes( + op="call_function", target=torch.ops.aten.matmul.default + ) + autop.sharding_optimizer.add_node_constraint(mm_nodes[0], x_sharding) + autop.add_input_constraints([x_sharding]) autop.add_output_constraints([x_sharding]) sharding_placement = autop.optimize_placement() + from IPython import embed + + embed() + import sys + + sys.exit() parallel_mod = autop.apply_placement(sharding_placement) # run weight init on our sharded DTensor params From 67542ad1756cdaee9bc59a697bead7f3d220a425 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Sat, 9 Aug 2025 05:52:22 +0000 Subject: [PATCH 22/23] Remove sharding rules that have been since moved to PyTorch Gather and scatter_add were merged yesterday in https://github.com/pytorch/pytorch/pull/160140 --- autoparallel/propagation_rules.py | 121 ------------------------------ 1 file changed, 121 deletions(-) diff --git a/autoparallel/propagation_rules.py b/autoparallel/propagation_rules.py index f2d5c877..6eee28e1 100644 --- a/autoparallel/propagation_rules.py +++ b/autoparallel/propagation_rules.py @@ -667,127 +667,6 @@ def index_rule(mesh, op_schema): return out_strat -@register_opschema_rule(torch.ops.aten.sort.stable) -def sort_rule(mesh, op_schema): - op = torch.ops.aten.topk.default - out_strat = torch.distributed.tensor.DTensor._op_dispatcher.sharding_propagator.op_strategy_funcs[ - op - ]( - op_schema - ) - return out_strat - - -@register_opschema_rule(torch.ops.aten.gather.default) -def gather_strategy(mesh, op_schema): - from torch.distributed.tensor._op_schema import PlacementList - from torch.distributed.tensor._ops._embedding_ops import _MaskPartial - from torch.distributed.tensor._ops.utils import 
expand_to_full_mesh_op_strategy - - input_strategy = op_schema.args_schema[0] - dim = op_schema.args_schema[1] - index_strategy = op_schema.args_schema[2] - - input_shape = input_strategy.shape - index_shape = index_strategy.shape - - single_mesh_dim_strategies = [] - - # placement list stores placements of [output, input, index] - # first we always have replicate all for inputs and output - all_replicate: PlacementList = [Replicate()] * 3 - single_mesh_dim_strategies.append(all_replicate) - - # input sharding, input sharded, index accepts mask partial, output follows index - # this only works when the input is sharded on the gather dimension, and - # index has size 1 on the gather dimension - if index_shape[dim] == 1: - index_partial_placement = _MaskPartial(offset_shape=input_shape, offset_dim=dim) - input_sharding: PlacementList = [ - index_partial_placement, - Shard(dim), - index_partial_placement, - ] - single_mesh_dim_strategies.append(input_sharding) - - # index sharding, input replicated, index sharded, output follows index - # this only works when the sharding dimension is the gather dimension - index_sharding: PlacementList = [Shard(dim), Replicate(), Shard(dim)] - single_mesh_dim_strategies.append(index_sharding) - - if len(input_shape) == len(index_shape): - for d in range(len(input_shape)): - if d != dim: - sharding = [Shard(d), Shard(d), Shard(d)] - single_mesh_dim_strategies.append(sharding) - - return expand_to_full_mesh_op_strategy( - mesh, op_schema, single_mesh_dim_strategies, input_index=1 - ) - - -@register_opschema_rule(torch.ops.aten.scatter_add.default) -def scatter_add_strategy(mesh, op_schema): - from torch.distributed.tensor._op_schema import PlacementList - - # from torch.distributed.tensor._ops._embedding_ops import _MaskPartial - from torch.distributed.tensor._ops.utils import expand_to_full_mesh_op_strategy - - input_strategy = op_schema.args_schema[0] - dim = op_schema.args_schema[1] - index_strategy = op_schema.args_schema[2] - # src_strategy = op_schema.args_schema[3] - - input_shape = input_strategy.shape - index_shape = index_strategy.shape - - single_mesh_dim_strategies = [] - - # placement list stores placements of [output, input, index] - # first we always have replicate all for inputs and output - all_replicate: PlacementList = [Replicate()] * 4 - single_mesh_dim_strategies.append(all_replicate) - - """ - # input sharding, input sharded, index accepts mask partial, output follows index - # this only works when the input is sharded on the gather dimension, and - # index has size 1 on the gather dimension - if index_shape[dim] == 1: - index_partial_placement = _MaskPartial(offset_shape=input_shape, offset_dim=dim) - input_sharding: PlacementList = [ - index_partial_placement, - Shard(dim), - index_partial_placement, - ] - single_mesh_dim_strategies.append(input_sharding) - """ - # index sharding, input replicated, index sharded, output follows index - # this only works when the sharding dimension is the gather dimension - index_sharding: PlacementList = [Shard(dim), Replicate(), Shard(dim), Shard(dim)] - single_mesh_dim_strategies.append(index_sharding) - - if len(input_shape) == len(index_shape): - for d in range(len(input_shape)): - if d != dim: - sharding = [Shard(d), Shard(d), Shard(d), Shard(d)] - single_mesh_dim_strategies.append(sharding) - - return expand_to_full_mesh_op_strategy( - mesh, op_schema, single_mesh_dim_strategies, input_index=1 - ) - - -@register_opschema_rule(torch.ops.aten.slice_scatter.default) -def slice_scatter_rule(mesh, 
op_schema): - op = torch.ops.aten.slice_scatter.default - out_strat = torch.distributed.tensor.DTensor._op_dispatcher.sharding_propagator.op_strategy_funcs[ - op - ]( - op_schema - ) - return out_strat - - def sdpa_rule(op, mesh, op_schema): out_strat = get_op_strategy(op, op_schema) # remove wrong context-parallel strategy From 124034e13ca0c3cd1a1bc5d61905515c1a948199 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Thu, 4 Sep 2025 08:03:38 +0000 Subject: [PATCH 23/23] Fixes after rebase --- autoparallel/api.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/autoparallel/api.py b/autoparallel/api.py index 32e5bbd0..6ea45ec7 100644 --- a/autoparallel/api.py +++ b/autoparallel/api.py @@ -39,7 +39,7 @@ from .optimize_sharding import ShardingOptimizer from .utils import _get_device_from_mesh -_APPLY_VIEW_MM_VIEW_PATTERN = False +_APPLY_VIEW_MM_VIEW_PATTERN = True def try_convert_fake_to_real(tensors): @@ -279,7 +279,7 @@ def build_model_graph(self): _replace_view_mm_view_with_einsum(gm) # now add aliases nodes to the graph to # give more room for optimizations - _add_alias(gm, version="v1") + _add_alias(gm, version="v2") trace_structured( "artifact", metadata_fn=lambda: {
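The `view -> mm -> view` chain that `_replace_view_mm_view_with_matmul` collapses into a single `aten.matmul` node is what a 3-D activation multiplied by a 2-D weight lowers to in an aten-level trace. A minimal standalone sketch that reproduces the pattern with `make_fx` (illustrative shapes and function, not part of the patch; the pass itself additionally relies on the `meta["val"]` entries that AutoParallel's tracing populates):

import torch
from torch.fx.experimental.proxy_tensor import make_fx

B, M, K, N = 2, 8, 16, 32

def fn(x, w):
    # A (B, M, K) activation times a (K, N) weight: flatten to 2-D for mm,
    # then restore the leading dims -- the exact view/mm/view chain.
    return (x.reshape(-1, K) @ w).reshape(B, M, N)

gm = make_fx(fn)(torch.randn(B, M, K), torch.randn(K, N))
# Depending on the PyTorch version, reshape may appear as aten.view or aten._unsafe_view.
print([node.target for node in gm.graph.nodes if node.op == "call_function"])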
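The second branch of the same pass rewrites the grad-weight chain `view -> permute -> mm -> permute` into `aten.einsum` and carries a `TODO: check einsum equation`. A standalone numeric check of the two candidate operand orders (illustrative shapes; assumes the permutes are plain 2-D transposes, as in a linear layer's backward):

import torch

B, M, P, Q = 2, 3, 4, 5
a = torch.randn(B, M, P)  # stands in for orig_first
c = torch.randn(B, M, Q)  # stands in for orig_second

# What the matched chain computes:
mm_out = a.reshape(B * M, P).t() @ c.reshape(B * M, Q)  # the aten.mm node, shape (P, Q)
chain_out = mm_out.t()                                  # its trailing permute, shape (Q, P)

eq_as_written = torch.einsum("bmn,bmk->nk", a, c)  # (P, Q): matches the mm node
eq_swapped = torch.einsum("bmn,bmk->nk", c, a)     # (Q, P): matches the permuted output

print(torch.allclose(eq_as_written, mm_out))   # True
print(torch.allclose(eq_swapped, chain_out))   # True

Which variant is the right replacement depends on whether the new einsum node stands in for the mm node or for its trailing permute (whose meta the pass copies onto it); comparing against the traced shapes as above is the quickest way to settle the TODO.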
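Patch 22 removes the local gather and scatter_add rules because equivalent strategies now ship with PyTorch (https://github.com/pytorch/pytorch/pull/160140). A hedged sketch of how to check, per op, whether an upstream strategy is already registered, using the same (private) registry the passthrough rules above index into; on a build that includes that PR, both ops should report "upstream":

import torch
from torch.distributed.tensor import DTensor

# op_strategy_funcs maps OpOverload -> strategy function inside DTensor's
# sharding propagator; membership tells us whether a local rule is still needed.
upstream = DTensor._op_dispatcher.sharding_propagator.op_strategy_funcs
for op in (torch.ops.aten.gather.default, torch.ops.aten.scatter_add.default):
    status = "upstream" if op in upstream else "still needs a local rule"
    print(f"{op}: {status}")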