From 538ce0a0eb1bf9d521d5e419fd1a2f6860294448 Mon Sep 17 00:00:00 2001 From: Adrian Lundell Date: Tue, 2 Jun 2026 10:18:11 +0200 Subject: [PATCH] Arm backend: Add dim mapping helpers Dim args such as sum(dim=1) needs to be transformed when swapping place with operators which change shape, i.e. permutes and views. ViewMap and PermuteMap handles and validates these transforms for reduction dims and permute dims. Signed-off-by: Adrian Lundell Change-Id: I634f4494df37294d4ac3397f8457bedfd20f5830 --- backends/arm/_passes/dim_maps.py | 650 ++++++++++++++++++++++ backends/arm/test/passes/test_dim_maps.py | 614 ++++++++++++++++++++ 2 files changed, 1264 insertions(+) create mode 100644 backends/arm/_passes/dim_maps.py create mode 100644 backends/arm/test/passes/test_dim_maps.py diff --git a/backends/arm/_passes/dim_maps.py b/backends/arm/_passes/dim_maps.py new file mode 100644 index 00000000000..8f852908893 --- /dev/null +++ b/backends/arm/_passes/dim_maps.py @@ -0,0 +1,650 @@ +# Copyright 2026 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from collections import Counter, defaultdict, deque +from dataclasses import dataclass +from typing import cast, Iterable, Sequence + +import sympy # type: ignore[import-untyped] +import torch + +from executorch.exir.dialects._ops import ops as exir_ops +from torch.fx import Node + +_Dim = int | torch.SymInt +_FactorKey = tuple[str, int | str] + + +@dataclass(frozen=True) +class _Factor: + key: _FactorKey + axis: int + + +@dataclass +class _ViewGroups: + source_axis_to_groups: list[list[int]] + target_axis_to_groups: list[list[int]] + group_to_source_axes: dict[int, list[int]] + group_to_target_axes: dict[int, list[int]] + + +def _is_permutation(dims: Sequence[int], rank: int) -> bool: + return sorted(dims) == list(range(rank)) + + +def _normalize_dim(dim: int, rank: int) -> int: + normalized = dim if dim >= 0 else dim + rank + assert 0 <= normalized < rank, f"Invalid dim {dim} for rank {rank}" + return normalized + + +def _normalize_dims(dims: int | Sequence[int], rank: int) -> list[int]: + if isinstance(dims, int): + return [_normalize_dim(dims, rank)] + return [_normalize_dim(dim, rank) for dim in dims] + + +def _normalize_permutation(dims: Sequence[int], rank: int) -> list[int] | None: + if len(dims) != rank: + return None + try: + normalized = [_normalize_dim(dim, rank) for dim in dims] + except AssertionError: + return None + return normalized if _is_permutation(normalized, rank) else None + + +def _extend_permutation_with_singletons( + dims: Sequence[int], shape: Sequence[_Dim] +) -> list[int] | None: + """Extend a partial permutation with missing singleton axes.""" + try: + extended_dims = _normalize_dims(dims, len(shape)) + except AssertionError: + return None + if len(set(extended_dims)) != len(extended_dims): + return None + + missing_dims = [dim for dim in range(len(shape)) if dim not in set(extended_dims)] + if any(not _dim_equals(shape[dim], 1) for dim in missing_dims): + return None + + for dim in reversed(missing_dims): + insert_at = next( + ( + index + for index, existing_dim in enumerate(extended_dims) + if existing_dim > dim + ), + len(extended_dims), + ) + extended_dims.insert(insert_at, dim) + return extended_dims if _is_permutation(extended_dims, len(shape)) else None + + +def _dim_expr(dim: _Dim) -> sympy.Basic: + return sympy.Integer(dim) if isinstance(dim, int) else dim.node.expr + + +def _dim_equals(lhs: _Dim, rhs: _Dim) -> bool: + if isinstance(lhs, int) and isinstance(rhs, int): + return lhs == rhs + return sympy.simplify(_dim_expr(lhs) - _dim_expr(rhs)) == 0 + + +def _factor_int(dim: int) -> list[_FactorKey] | None: + if dim < 1: + return None + factors: list[_FactorKey] = [] + divisor = 2 + while divisor * divisor <= dim: + while dim % divisor == 0: + factors.append(("int", divisor)) + dim //= divisor + divisor += 1 if divisor == 2 else 2 + if dim > 1: + factors.append(("int", dim)) + return factors + + +def _factor_dim(dim: _Dim) -> list[_FactorKey] | None: + if _dim_equals(dim, 1): + return [] + if isinstance(dim, int): + return _factor_int(dim) + return [("sym", sympy.srepr(_dim_expr(dim)))] + + +def _factor_shape(shape: Sequence[_Dim]) -> list[_Factor] | None: + factors: list[_Factor] = [] + for axis, dim in enumerate(shape): + dim_factors = _factor_dim(dim) + if dim_factors is None: + return None + factors.extend(_Factor(factor, axis) for factor in dim_factors) + return factors + + +def _dedupe(items: Iterable[int]) -> list[int]: + deduped: list[int] = [] + seen: set[int] = set() + for item in items: + if item not in seen: + deduped.append(item) + seen.add(item) + return deduped + + +def numel(shape: Iterable[_Dim]) -> _Dim: + numel: _Dim = 1 + for dim in shape: + numel *= dim + return numel + + +def same_numel(first_shape: Iterable[_Dim], second_shape: Iterable[_Dim]) -> bool: + return _dim_equals(numel(first_shape), numel(second_shape)) + + +class _UnionFind: + def __init__(self, size: int) -> None: + self.parents = list(range(size)) + + def find(self, item: int) -> int: + parent = self.parents[item] + if parent != item: + self.parents[item] = self.find(parent) + return self.parents[item] + + def union(self, first: int, second: int) -> None: + first_root = self.find(first) + second_root = self.find(second) + if first_root != second_root: + self.parents[second_root] = first_root + + +class ViewMap: + """Maps dims before and after a view operator. + + The map models a view by expanding both shapes into ordered prime-factor + streams and finding the permutation between them. Singleton dims are not counted. + For example, the view from [4, 3, 10] to [2, 2, 1, 5, 3, 2] is represented as: + + Source: [4, 3, 10] + Source primes: [2, 2, 3, 2, 5] + Permutation: [0, 1, 4, 2, 3] + Target primes: [2, 2, 5, 3, 2] + Target: [2, 2, 1, 5, 3, 2] + + Dim mappings are derived by unioning factors into groups where + - factors from the same source axis belong to the same group; + e.g. [2,2] [3] [2,5] + - factors from the same target axis belong to the same group; + e.g. [2] [2] [5] [3] [2] + - factors whose source-to-target permutation order crosses belong to the same group. + e.g. [2] [2] [3,5,2] + + The final groups are formed by the union of all conditions, in this case + [2, 2] and [3, 5, 2]. A source dim maps to all target dims that share its + group with any of its factors, and vice versa. + + Additional conditions apply for the map being valid depending on if the mapped dim + is a reduction operator or a permutation operator, as described in the respective methods. + + SymInts are partialy supported by factorizing them as single primes as the true + value is not known, causing potentially fewer valid mappings. + + """ + + def __init__(self, view_node: Node) -> None: + """Build a view map from an FX view_copy node.""" + input_node = view_node.args[0] + assert isinstance(input_node, Node) and ( + view_node.target == exir_ops.edge.aten.view_copy.default + ) + input_val = input_node.meta["val"] + assert isinstance(input_val, torch.Tensor) + + self.source_shape = cast(list[_Dim], list(input_val.shape)) + self.target_shape = list(cast(Sequence[_Dim], view_node.args[1])) + self._groups = self._build_groups(self.source_shape, self.target_shape) + + @classmethod + def from_shapes( + cls, source_shape: Sequence[_Dim], target_shape: Sequence[_Dim] + ) -> ViewMap: + """Build a view map directly from source and target shapes.""" + view_map = cls.__new__(cls) + view_map.source_shape = list(source_shape) + view_map.target_shape = list(target_shape) + view_map._groups = cls._build_groups( + view_map.source_shape, view_map.target_shape + ) + return view_map + + @property + def is_valid_map(self) -> bool: + """Return whether the shapes can be represented by grouped factors.""" + return self._groups is not None + + @property + def source_rank(self) -> int: + """Return the source shape rank.""" + return len(self.source_shape) + + @property + def target_rank(self) -> int: + """Return the target shape rank.""" + return len(self.target_shape) + + def map_dim( + self, + source_dims: int | Sequence[int], + ) -> list[int] | None: + """Map source reduction dims (e.g. `x.sum(dim)`, `x.max(dim)`) to valid target + reduction dims: + + x.op(dims).view(S) == x.view(S').op(mapped_dims) + + Reduction dims are valid only when the selected dims and mapped dims both cover + complete groups. E.g. in the example view [4, 3, 10] -> [2, 2, 1, 5, 3, 2] the + valid maps are. + [0] <=> [0, 1] and [1, 2] <=> [3, 4, 5] + """ + try: + normalized_dims = _normalize_dims(source_dims, self.source_rank) + except AssertionError: + return None + + groups = self._valid_groups() + if not self._is_valid_reduction(normalized_dims, groups.source_axis_to_groups): + return None + + target_dims = self._map_dims( + normalized_dims, + groups.source_axis_to_groups, + groups.group_to_target_axes, + ) + if not target_dims or not self._is_valid_reduction( + target_dims, groups.target_axis_to_groups + ): + return None + return target_dims + + def map_dim_inverse( + self, + target_dims: int | Sequence[int], + ) -> list[int] | None: + """Map target reduction dims to valid source reduction dims, inverse map + of map_dim. + + x.view(S).op(dims) == x.op(mapped_dims).view(S') + + """ + try: + normalized_dims = _normalize_dims(target_dims, self.target_rank) + except AssertionError: + return None + + groups = self._valid_groups() + if not self._is_valid_reduction(normalized_dims, groups.target_axis_to_groups): + return None + + source_dims = self._map_dims( + normalized_dims, + groups.target_axis_to_groups, + groups.group_to_source_axes, + ) + if not source_dims or not self._is_valid_reduction( + source_dims, groups.source_axis_to_groups + ): + return None + return source_dims + + def map_permutation( + self, + source_permutation: Sequence[int], + ) -> list[int] | None: + """Map a source permutation to a valid target permutation. + + Permutation dims have an additional constraint on the order of dims: + Dims are valid only when dims mapped through one group appear as contiguous + increasing blocks dims in both source and target. + + In the example view [4, 3, 10] -> [2, 2, 1, 5, 3, 2], [1, 2, 0] is a valid + permutation, but [2, 0, 1] and [0, 2, 1] are not since [1, 2] belong to the same + group but are not a) contiguous, or b) in increasing order. + + """ + source_permutation = _normalize_permutation( + source_permutation, self.source_rank + ) + if source_permutation is None: + return None + + groups = self._valid_groups() + target_permutation = _extend_permutation_with_singletons( + self._map_dims( + source_permutation, + groups.source_axis_to_groups, + groups.group_to_target_axes, + ), + self.target_shape, + ) + if target_permutation is None: + return None + + return ( + target_permutation + if self._matching_permuted_group_blocks( + source_permutation, + target_permutation, + groups.source_axis_to_groups, + groups.target_axis_to_groups, + ) + else None + ) + + def map_permutation_inverse( + self, + target_permutation: Sequence[int], + ) -> list[int] | None: + """Map a target permutation to a valid source permutation. + + Inverse of map_permutation. + + """ + target_permutation = _normalize_permutation( + target_permutation, self.target_rank + ) + if target_permutation is None: + return None + + groups = self._valid_groups() + source_permutation = _extend_permutation_with_singletons( + self._map_dims( + target_permutation, + groups.target_axis_to_groups, + groups.group_to_source_axes, + ), + self.source_shape, + ) + if source_permutation is None: + return None + + return ( + source_permutation + if self._matching_permuted_group_blocks( + source_permutation, + target_permutation, + groups.source_axis_to_groups, + groups.target_axis_to_groups, + ) + else None + ) + + @staticmethod + def _map_dims( + source_dims: Iterable[int], + source_axis_to_groups: Sequence[Sequence[int]], + group_to_target_axes: dict[int, list[int]], + ) -> list[int]: + return _dedupe( + target_axis + for source_axis in source_dims + for group in source_axis_to_groups[source_axis] + for target_axis in group_to_target_axes[group] + ) + + @staticmethod + def _matching_permuted_group_blocks( + source_permutation: Sequence[int], + target_permutation: Sequence[int], + source_axis_to_groups: Sequence[Sequence[int]], + target_axis_to_groups: Sequence[Sequence[int]], + ) -> bool: + """Return whether source and target permutations consume groups + equally. + """ + closed_groups: set[int] = set() + source_index = 0 + target_index = 0 + + while True: + source_index, source_group = ViewMap._next_group( + source_permutation, source_axis_to_groups, source_index + ) + target_index, target_group = ViewMap._next_group( + target_permutation, target_axis_to_groups, target_index + ) + + if source_group is None or target_group is None: + return source_group is None and target_group is None + if source_group != target_group or source_group in closed_groups: + return False + + source_index, source_axes = ViewMap._consume_group( + source_permutation, + source_axis_to_groups, + source_index, + source_group, + ) + target_index, target_axes = ViewMap._consume_group( + target_permutation, + target_axis_to_groups, + target_index, + target_group, + ) + if source_axes != sorted(source_axes) or target_axes != sorted(target_axes): + return False + + closed_groups.add(source_group) + + @staticmethod + def _next_group( + permutation: Sequence[int], + axis_to_groups: Sequence[Sequence[int]], + index: int, + ) -> tuple[int, int | None]: + """Return the next grouped axis index and group, skipping singletons.""" + while index < len(permutation): + axis = permutation[index] + axis_groups = axis_to_groups[axis] + if not axis_groups: + index += 1 + continue + assert len(axis_groups) == 1 + return index, axis_groups[0] + return index, None + + @staticmethod + def _consume_group( + permutation: Sequence[int], + axis_to_groups: Sequence[Sequence[int]], + index: int, + group: int, + ) -> tuple[int, list[int]]: + """Consume one group block, ignoring singleton axes.""" + axes: list[int] = [] + while index < len(permutation): + axis = permutation[index] + axis_groups = axis_to_groups[axis] + if not axis_groups: + index += 1 + continue + assert len(axis_groups) == 1 + if axis_groups[0] != group: + break + axes.append(axis) + index += 1 + return index, axes + + @staticmethod + def _is_valid_reduction( + normalized_dims: Iterable[int], + axis_to_groups: Sequence[Sequence[int]], + ) -> bool: + """Return whether dims cover every selected group in one shape.""" + normalized_dims = set(normalized_dims) + if not normalized_dims: + return False + + group_to_axes: dict[int, set[int]] = defaultdict(set) + selected_groups: set[int] = set() + for axis, groups in enumerate(axis_to_groups): + for group in groups: + group_to_axes[group].add(axis) + if axis in normalized_dims: + selected_groups.add(group) + + if any(not axis_to_groups[axis] for axis in normalized_dims): + return False + + return all( + group_to_axes[group].issubset(normalized_dims) for group in selected_groups + ) + + @classmethod + def _build_groups( + cls, source_shape: Sequence[_Dim], target_shape: Sequence[_Dim] + ) -> _ViewGroups | None: + """Build source/target axis groups from ordered prime factors.""" + + # Compute orderd prime factorizations of input and output shapes + source_factors = _factor_shape(source_shape) + target_factors = _factor_shape(target_shape) + assert ( + source_factors is not None + and (target_factors is not None) + and Counter(factor.key for factor in source_factors) + == Counter(factor.key for factor in target_factors) + ), "Invalid view shapes" + + # Compute prime factor permutation between input and output shapes + factor_count = len(source_factors) + permutation = cls._find_permutation(source_factors, target_factors) + assert permutation is not None, "Invalid view shapes" + + # Find groups of factors that must be mapped together to preserve view equivalence + union_find = _UnionFind(factor_count) + cls._union_factors_sharing_axes( + union_find, (factor.axis for factor in source_factors) + ) + + cls._union_factors_sharing_axes( + union_find, + ( + target_factors[permutation[source_position]].axis + for source_position in range(factor_count) + ), + ) + + cls._union_crossing_factors(union_find, permutation) + + # Create group data structure + source_axis_groups: list[set[int]] = [set() for _ in source_shape] + target_axis_groups: list[set[int]] = [set() for _ in target_shape] + group_to_source_axes: dict[int, set[int]] = defaultdict(set) + group_to_target_axes: dict[int, set[int]] = defaultdict(set) + + for source_position, source_factor in enumerate(source_factors): + group = union_find.find(source_position) + target_factor = target_factors[permutation[source_position]] + + source_axis_groups[source_factor.axis].add(group) + target_axis_groups[target_factor.axis].add(group) + group_to_source_axes[group].add(source_factor.axis) + group_to_target_axes[group].add(target_factor.axis) + + return _ViewGroups( + source_axis_to_groups=[sorted(groups) for groups in source_axis_groups], + target_axis_to_groups=[sorted(groups) for groups in target_axis_groups], + group_to_source_axes={ + group: sorted(axes) for group, axes in group_to_source_axes.items() + }, + group_to_target_axes={ + group: sorted(axes) for group, axes in group_to_target_axes.items() + }, + ) + + @staticmethod + def _find_permutation( + X: Sequence[_Factor], Y: Sequence[_Factor] + ) -> list[int] | None: + """Computes the permutation from X -> Y, handling duplicates.""" + duplicates: dict[_FactorKey, deque[int]] = defaultdict(deque) + for i, y in enumerate(Y): + duplicates[y.key].append(i) + + permutation: list[int] = [] + for x in X: + positions = duplicates[x.key] + if not positions: + return None + permutation.append(positions.popleft()) + + return permutation + + @staticmethod + def _union_factors_sharing_axes( + union_find: _UnionFind, axes: Iterable[int] + ) -> None: + """Union factor positions that belong to the same axis.""" + first_position_by_axis: dict[int, int] = {} + for position, axis in enumerate(axes): + if axis in first_position_by_axis: + union_find.union(first_position_by_axis[axis], position) + else: + first_position_by_axis[axis] = position + + @staticmethod + def _union_crossing_factors( + union_find: _UnionFind, permutation: Sequence[int] + ) -> None: + """Union factor positions whose target ordering crosses.""" + for first in range(len(permutation)): + for second in range(first + 1, len(permutation)): + if permutation[first] > permutation[second]: + union_find.union(first, second) + + def _valid_groups(self) -> _ViewGroups: + """Return built groups for a valid map.""" + assert self._groups is not None + return self._groups + + +class PermuteMap: + """Maps dims to equivalent dims before and after a permute.""" + + def __init__(self, permute_node: Node) -> None: + permute_dims = permute_node.args[1] + assert isinstance(permute_dims, Sequence) and not isinstance( + permute_dims, (str, bytes) + ) + self.permute_dims = list(cast(Sequence[int], permute_dims)) + + def map_dims(self, dims: int | Sequence[int]) -> list[int]: + """Computes mapped dims s.t. + + x.op(dims).permute(P) == x.permute(P).op(mapped_dims) + + """ + normalized_dims = _normalize_dims(dims, len(self.permute_dims)) + inverse_permute = [0] * len(self.permute_dims) + for target_dim, source_dim in enumerate(self.permute_dims): + inverse_permute[source_dim] = target_dim + return [inverse_permute[dim] for dim in normalized_dims] + + def map_dims_inverse(self, dims: int | Sequence[int]) -> list[int]: + """Computes mapped dims s.t. + + x.permute(P).op(dims) == x.op(mapped_dims).permute(P) + + """ + normalized_dims = _normalize_dims(dims, len(self.permute_dims)) + return [self.permute_dims[dim] for dim in normalized_dims] diff --git a/backends/arm/test/passes/test_dim_maps.py b/backends/arm/test/passes/test_dim_maps.py new file mode 100644 index 00000000000..486fbec060b --- /dev/null +++ b/backends/arm/test/passes/test_dim_maps.py @@ -0,0 +1,614 @@ +# Copyright 2026 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from itertools import combinations, permutations +from typing import cast, Sequence, TypeVar + +import sympy # type: ignore[import-untyped] +import torch + +from executorch.backends.arm._passes.dim_maps import PermuteMap, ViewMap +from torch.fx.experimental.symbolic_shapes import ShapeEnv + + +_RNG = torch.Generator().manual_seed(0) +_T = TypeVar("_T") +_Dim = int | torch.SymInt +_DimT = TypeVar("_DimT", bound=_Dim) + + +def _make_symint( + shape_env: ShapeEnv, symbol: str, hint: int, min: int = 1, max: int = 64 +) -> torch.SymInt: + symint = shape_env.create_symintnode(sympy.Symbol(symbol), hint=hint) + assert isinstance(symint, torch.SymInt) + shape_env.constrain_symbol_range( + symint.node.expr, compiler_min=min, compiler_max=max + ) + return symint + + +def _numel(shape: list[int]) -> int: + numel = 1 + for dim in shape: + numel *= dim + return numel + + +def _factorizations(numel: int, rank: int) -> list[list[int]]: + shapes: list[list[int]] = [] + + def recurse(remaining: int, remaining_rank: int, shape: list[int]) -> None: + if remaining_rank == 0: + if remaining == 1: + shapes.append(list(shape)) + return + + for dim in range(1, remaining + 1): + if remaining % dim == 0: + shape.append(dim) + recurse(remaining // dim, remaining_rank - 1, shape) + shape.pop() + + recurse(numel, rank, []) + return shapes + + +def _randint(low: int, high: int) -> int: + return int(torch.randint(low, high + 1, (), generator=_RNG).item()) + + +def _choice(choices: list[_T]) -> _T: + return choices[_randint(0, len(choices) - 1)] + + +def _shuffle(values: list[int]) -> None: + indices = torch.randperm(len(values), generator=_RNG).tolist() + values[:] = [values[index] for index in indices] + + +def _random_shape(rank: int, max_dim: int = 4) -> list[int]: + return [_randint(1, max_dim) for _ in range(rank)] + + +def _random_view_shape(numel: int, max_rank: int = 4) -> list[int]: + rank = _randint(1, max_rank) + return _choice(_factorizations(numel, rank)) + + +def _tensor(shape: list[int]) -> torch.Tensor: + return torch.arange(_numel(shape), dtype=torch.float32).reshape(shape) + + +def _inverse_permutation(permutation: list[int]) -> list[int]: + inverse = [0] * len(permutation) + for index, dim in enumerate(permutation): + inverse[dim] = index + return inverse + + +def _permute_map(permutation: list[int]) -> PermuteMap: + graph = torch.fx.Graph() + x = graph.placeholder("x") + permute = graph.call_function(torch.ops.aten.permute.default, args=(x, permutation)) + return PermuteMap(permute) + + +def _all_dim_subsets(rank: int) -> list[list[int]]: + return [ + list(dims) + for subset_size in range(1, rank + 1) + for dims in combinations(range(rank), subset_size) + ] + + +def _reduce_shape(shape: Sequence[_DimT], dims: list[int]) -> list[_DimT]: + reduced_shape = list(shape) + for dim in dims: + reduced_shape[dim] = cast(_DimT, 1) + return reduced_shape + + +def _reduce(tensor: torch.Tensor, dims: list[int]) -> torch.Tensor: + return tensor.sum(dim=tuple(dims), keepdim=True) + + +def _same(lhs: torch.Tensor, rhs: torch.Tensor) -> bool: + return lhs.shape == rhs.shape and torch.equal(lhs, rhs) + + +def _propose_permute_view_swap( + input_shape: Sequence[_DimT], + permutation: list[int], + output_shape: Sequence[_DimT], +) -> tuple[list[_DimT], list[int]] | None: + permuted_shape = [input_shape[dim] for dim in permutation] + view_map = ViewMap.from_shapes(permuted_shape, output_shape) + if not view_map.is_valid_map: + return None + + permuted_axis = _inverse_permutation(permutation) + target_axis_order = view_map.map_permutation(permuted_axis) + if target_axis_order is None: + return None + + return ( + [output_shape[target_axis] for target_axis in target_axis_order], + _inverse_permutation(target_axis_order), + ) + + +def _propose_view_permute_swap( + input_shape: Sequence[_DimT], + view_shape: Sequence[_DimT], + permutation: list[int], +) -> tuple[list[int], list[_DimT]] | None: + view_map = ViewMap.from_shapes(input_shape, view_shape) + if not view_map.is_valid_map: + return None + + mapped_dims = view_map.map_permutation_inverse(permutation) + if mapped_dims is None: + return None + + output_shape = [view_shape[dim] for dim in permutation] + return mapped_dims, output_shape + + +def _propose_reduction_view_swap( + input_shape: Sequence[_DimT], + source_dims: list[int], + view_shape: Sequence[_DimT], +) -> tuple[list[_DimT], list[int]] | None: + view_map = ViewMap.from_shapes(input_shape, view_shape) + if not view_map.is_valid_map: + return None + + target_dims = view_map.map_dim(source_dims) + if target_dims is None: + return None + return list(view_shape), target_dims + + +def _propose_view_reduction_swap( + input_shape: Sequence[_DimT], + view_shape: Sequence[_DimT], + target_dims: list[int], +) -> tuple[list[int], list[_DimT]] | None: + view_map = ViewMap.from_shapes(input_shape, view_shape) + if not view_map.is_valid_map: + return None + + source_dims = view_map.map_dim_inverse(target_dims) + if source_dims is None: + return None + return source_dims, _reduce_shape(view_shape, target_dims) + + +def _bruteforce_permute_view_swaps( + x: torch.Tensor, + permutation: list[int], + output_shape: list[int], +) -> list[tuple[list[int], list[int]]]: + original = x.permute(permutation).reshape(output_shape) + candidates: list[tuple[list[int], list[int]]] = [] + for candidate_permutation in permutations(range(len(output_shape))): + candidate_permutation_list = list(candidate_permutation) + for candidate_shape in _factorizations(_numel(output_shape), len(output_shape)): + candidate = x.reshape(candidate_shape).permute(candidate_permutation_list) + if _same(original, candidate): + candidates.append((candidate_shape, candidate_permutation_list)) + return candidates + + +def _bruteforce_view_permute_swaps( + x: torch.Tensor, + view_shape: list[int], + permutation: list[int], +) -> list[tuple[list[int], list[int]]]: + output_shape = [view_shape[dim] for dim in permutation] + original = x.reshape(view_shape).permute(permutation) + candidates: list[tuple[list[int], list[int]]] = [] + for candidate_permutation in permutations(range(len(x.shape))): + candidate_permutation_list = list(candidate_permutation) + candidate = x.permute(candidate_permutation_list).reshape(output_shape) + if _same(original, candidate): + candidates.append((candidate_permutation_list, output_shape)) + return candidates + + +def _bruteforce_reduction_view_swaps( + x: torch.Tensor, + source_dims: list[int], + view_shape: list[int], +) -> list[tuple[list[int], list[int]]]: + candidates: list[tuple[list[int], list[int]]] = [] + for target_dims in _all_dim_subsets(len(view_shape)): + output_shape = _reduce_shape(view_shape, target_dims) + reduced = _reduce(x, source_dims) + if reduced.numel() != _numel(output_shape): + continue + original = reduced.reshape(output_shape) + candidate = _reduce(x.reshape(view_shape), target_dims) + if _same(original, candidate): + candidates.append((view_shape, target_dims)) + return candidates + + +def _bruteforce_view_reduction_swaps( + x: torch.Tensor, + view_shape: list[int], + target_dims: list[int], +) -> list[tuple[list[int], list[int]]]: + original = _reduce(x.reshape(view_shape), target_dims) + candidates: list[tuple[list[int], list[int]]] = [] + for source_dims in _all_dim_subsets(len(x.shape)): + reduced = _reduce(x, source_dims) + if reduced.numel() != original.numel(): + continue + candidate = reduced.reshape(original.shape) + if _same(original, candidate): + candidates.append((source_dims, list(original.shape))) + return candidates + + +def test_dim_map_maps_split_and_merged_prime_factor_groups() -> None: + view_map = ViewMap.from_shapes([1, 2, 3, 4], [1, 6, 2, 2]) + + assert view_map.is_valid_map + assert view_map.map_dim(0) is None + assert view_map.map_dim(1) is None + assert view_map.map_dim(2) is None + assert view_map.map_dim(3) == [2, 3] + assert view_map.map_dim([1, 2]) == [1] + assert view_map.map_dim([3, 1]) is None + assert view_map.map_dim([3, 1, 2]) == [2, 3, 1] + + assert view_map.map_dim_inverse(0) is None + assert view_map.map_dim_inverse(1) == [1, 2] + assert view_map.map_dim_inverse(2) is None + assert view_map.map_dim_inverse([3, 1, 2]) == [3, 1, 2] + assert view_map.map_dim_inverse([2, 3, 1]) == [3, 1, 2] + + +def test_dim_map_groups_reordered_crossing_prime_factors() -> None: + view_map = ViewMap.from_shapes([2, 3], [3, 2]) + + assert view_map.is_valid_map + assert view_map.map_dim(0) is None + assert view_map.map_dim(1) is None + assert view_map.map_dim([0, 1]) == [0, 1] + assert view_map.map_dim_inverse(0) is None + assert view_map.map_dim_inverse(1) is None + assert view_map.map_dim_inverse([0, 1]) == [0, 1] + + +def test_dim_map_matches_view_map_docstring_example_reduction_dims() -> None: + view_map = ViewMap.from_shapes([4, 3, 10], [2, 2, 1, 5, 3, 2]) + + assert view_map.is_valid_map + assert view_map.map_dim(0) == [0, 1] + assert view_map.map_dim([1, 2]) == [3, 4, 5] + assert view_map.map_dim_inverse([0, 1]) == [0] + assert view_map.map_dim_inverse([3, 4, 5]) == [1, 2] + + assert view_map.map_dim(1) is None + assert view_map.map_dim(2) is None + assert view_map.map_dim_inverse(2) is None + + +def test_dim_map_matches_view_map_docstring_example_permutation_dims() -> None: + view_map = ViewMap.from_shapes([4, 3, 10], [2, 2, 1, 5, 3, 2]) + + assert view_map.map_permutation([1, 2, 0]) == [2, 3, 4, 5, 0, 1] + assert view_map.map_permutation_inverse([2, 3, 4, 5, 0, 1]) == [1, 2, 0] + + assert view_map.map_permutation([0, 2, 1]) is None + assert view_map.map_permutation([2, 0, 1]) is None + + +def test_dim_map_validates_reductions_by_whole_groups() -> None: + view_map = ViewMap.from_shapes([2, 3], [3, 2]) + + assert view_map.map_dim([0]) is None + assert view_map.map_dim_inverse([1]) is None + assert view_map.map_dim([0, 1]) == [0, 1] + assert view_map.map_dim_inverse([0, 1]) == [0, 1] + + +def test_dim_map_validates_permuted_group_blocks() -> None: + view_map = ViewMap.from_shapes([2, 3, 5], [3, 2, 5]) + + assert view_map.map_permutation([0, 1, 2]) == [0, 1, 2] + assert view_map.map_permutation([2, 0, 1]) == [2, 0, 1] + assert view_map.map_permutation_inverse([0, 2, 1]) is None + + merged_view_map = ViewMap.from_shapes([2, 3], [6]) + assert merged_view_map.map_permutation([0, 1]) == [0] + assert merged_view_map.map_permutation([1, 0]) is None + + +def test_extends_mapped_permutation_with_singletons() -> None: + view_map = ViewMap.from_shapes([2, 2], [2, 1, 2]) + assert view_map.map_permutation([0, 1]) == [0, 1, 2] + assert view_map.map_permutation([1, 0]) == [1, 2, 0] + + singleton_view_map = ViewMap.from_shapes([2], [1, 2]) + assert singleton_view_map.map_permutation([0]) == [0, 1] + assert singleton_view_map.map_permutation_inverse([1, 0]) == [0] + + assert view_map.map_permutation([0, 0]) is None + + +def test_dim_map_uses_strict_no_mapping_for_singletons() -> None: + view_map = ViewMap.from_shapes([1, 4], [4]) + + assert view_map.is_valid_map + assert view_map.map_dim(0) is None + assert view_map.map_dim(1) == [0] + assert view_map.map_dim_inverse(0) == [1] + + split_view_map = ViewMap.from_shapes([4], [2, 1, 2]) + assert split_view_map.map_dim(0) == [0, 2] + assert split_view_map.map_dim_inverse(1) is None + assert split_view_map.map_dim_inverse([0, 2]) == [0] + + +def test_dim_map_preserves_symbolic_dimensions_as_prime_factors() -> None: + shape_env = ShapeEnv() + batch = _make_symint(shape_env, "batch", hint=4) + + view_map = ViewMap.from_shapes([batch, 6], [batch, 2, 3]) + + assert view_map.is_valid_map + assert view_map.map_dim(0) == [0] + assert view_map.map_dim(1) == [1, 2] + assert view_map.map_dim_inverse(0) == [0] + + +def test_dim_map_permute_view_swap_preserves_symbolic_view_shape_dims() -> None: + shape_env = ShapeEnv() + batch = _make_symint(shape_env, "batch", hint=4) + input_shape: list[_Dim] = [batch, 6] + output_shape: list[_Dim] = [2, 3, batch] + + proposal = _propose_permute_view_swap(input_shape, [1, 0], output_shape) + + assert proposal is not None + view_shape, permutation = proposal + assert isinstance(view_shape[0], torch.SymInt) + assert view_shape[0] is batch + assert view_shape[1:] == [2, 3] + assert permutation == [1, 2, 0] + + +def test_dim_map_view_permute_swap_preserves_symbolic_output_shape_dims() -> None: + shape_env = ShapeEnv() + batch = _make_symint(shape_env, "batch", hint=4) + input_shape: list[_Dim] = [batch, 6] + view_shape: list[_Dim] = [batch, 2, 3] + + proposal = _propose_view_permute_swap(input_shape, view_shape, [1, 2, 0]) + + assert proposal is not None + permutation, output_shape = proposal + assert permutation == [1, 0] + assert output_shape[:2] == [2, 3] + assert isinstance(output_shape[2], torch.SymInt) + assert output_shape[2] is batch + + +def test_dim_map_reduction_view_swap_preserves_symbolic_view_shape_dims() -> None: + shape_env = ShapeEnv() + batch = _make_symint(shape_env, "batch", hint=4) + input_shape: list[_Dim] = [batch, 6] + view_shape: list[_Dim] = [batch, 2, 3] + + proposal = _propose_reduction_view_swap(input_shape, [1], view_shape) + + assert proposal is not None + view_shape, target_dims = proposal + assert isinstance(view_shape[0], torch.SymInt) + assert view_shape[0] is batch + assert view_shape[1:] == [2, 3] + assert target_dims == [1, 2] + + +def test_dim_map_view_reduction_swap_preserves_symbolic_output_shape_dims() -> None: + shape_env = ShapeEnv() + batch = _make_symint(shape_env, "batch", hint=4) + input_shape: list[_Dim] = [batch, 6] + view_shape: list[_Dim] = [batch, 2, 3] + + proposal = _propose_view_reduction_swap(input_shape, view_shape, [1, 2]) + + assert proposal is not None + source_dims, output_shape = proposal + assert source_dims == [1] + assert isinstance(output_shape[0], torch.SymInt) + assert output_shape[0] is batch + assert output_shape[1:] == [1, 1] + + +def test_permute_map_matches_docstring_reduction_identities() -> None: + input_shape = [2, 3, 5] + permutation = [2, 0, 1] + permute_map = _permute_map(permutation) + x = _tensor(input_shape) + + source_dims = [0, 2] + target_dims = permute_map.map_dims(source_dims) + assert target_dims == [1, 0] + assert _same( + _reduce(x, source_dims).permute(permutation), + _reduce(x.permute(permutation), target_dims), + ) + + target_dims = [0, 2] + source_dims = permute_map.map_dims_inverse(target_dims) + assert source_dims == [2, 1] + assert _same( + _reduce(x.permute(permutation), target_dims), + _reduce(x, source_dims).permute(permutation), + ) + + +def test_dim_map_randomized_permute_view_swaps_match_bruteforce() -> None: + accepted = 0 + rejected = 0 + + for _ in range(80): + input_shape = _random_shape(_randint(1, 4), max_dim=3) + permutation = list(range(len(input_shape))) + _shuffle(permutation) + output_shape = _random_view_shape(_numel(input_shape), max_rank=4) + x = _tensor(input_shape) + + proposal = _propose_permute_view_swap(input_shape, permutation, output_shape) + brute_force_swaps = _bruteforce_permute_view_swaps(x, permutation, output_shape) + if proposal is None and brute_force_swaps: + proposal = brute_force_swaps[0] + + if proposal is None: + rejected += 1 + assert brute_force_swaps == [] + continue + + accepted += 1 + assert proposal in brute_force_swaps + view_shape, new_permutation = proposal + original = x.permute(permutation).reshape(output_shape) + candidate = x.reshape(view_shape).permute(new_permutation) + assert _same(original, candidate) + + +def test_dim_map_randomized_view_permute_swaps_match_bruteforce() -> None: + accepted = 0 + rejected = 0 + + for _ in range(80): + input_shape = _random_shape(_randint(1, 4), max_dim=3) + view_shape = _random_view_shape(_numel(input_shape), max_rank=4) + permutation = list(range(len(view_shape))) + _shuffle(permutation) + x = _tensor(input_shape) + + proposal = _propose_view_permute_swap(input_shape, view_shape, permutation) + brute_force_swaps = _bruteforce_view_permute_swaps(x, view_shape, permutation) + if proposal is None and brute_force_swaps: + proposal = brute_force_swaps[0] + + if proposal is None: + rejected += 1 + assert brute_force_swaps == [] + continue + + accepted += 1 + assert proposal in brute_force_swaps + new_permutation, output_shape = proposal + original = x.reshape(view_shape).permute(permutation) + candidate = x.permute(new_permutation).reshape(output_shape) + assert _same(original, candidate) + + +def test_dim_map_randomized_reduction_view_swaps_match_bruteforce() -> None: + accepted = 0 + rejected = 0 + + for _ in range(80): + input_shape = _random_shape(_randint(1, 4), max_dim=3) + source_dims = _choice(_all_dim_subsets(len(input_shape))) + view_shape = _random_view_shape(_numel(input_shape), max_rank=4) + x = _tensor(input_shape) + + proposal = _propose_reduction_view_swap(input_shape, source_dims, view_shape) + brute_force_swaps = _bruteforce_reduction_view_swaps(x, source_dims, view_shape) + if proposal is None and brute_force_swaps: + proposal = brute_force_swaps[0] + + if proposal is None: + rejected += 1 + assert brute_force_swaps == [] + continue + + accepted += 1 + assert proposal in brute_force_swaps + new_shape, target_dims = proposal + output_shape = _reduce_shape(new_shape, target_dims) + original = _reduce(x, source_dims).reshape(output_shape) + candidate = _reduce(x.reshape(new_shape), target_dims) + assert _same(original, candidate) + + +def test_dim_map_randomized_view_reduction_swaps_match_bruteforce() -> None: + accepted = 0 + rejected = 0 + + for _ in range(80): + input_shape = _random_shape(_randint(1, 4), max_dim=3) + view_shape = _random_view_shape(_numel(input_shape), max_rank=4) + target_dims = _choice(_all_dim_subsets(len(view_shape))) + x = _tensor(input_shape) + + proposal = _propose_view_reduction_swap(input_shape, view_shape, target_dims) + brute_force_swaps = _bruteforce_view_reduction_swaps(x, view_shape, target_dims) + if proposal is None and brute_force_swaps: + proposal = brute_force_swaps[0] + + if proposal is None: + rejected += 1 + assert brute_force_swaps == [] + continue + + accepted += 1 + assert proposal in brute_force_swaps + source_dims, output_shape = proposal + original = _reduce(x.reshape(view_shape), target_dims) + candidate = _reduce(x, source_dims).reshape(output_shape) + assert _same(original, candidate) + + +def test_permute_map_randomized_reduction_permute_swaps_match_bruteforce() -> None: + for _ in range(80): + input_shape = _random_shape(_randint(1, 4), max_dim=3) + source_dims = _choice(_all_dim_subsets(len(input_shape))) + permutation = list(range(len(input_shape))) + _shuffle(permutation) + permute_map = _permute_map(permutation) + target_dims = permute_map.map_dims(source_dims) + x = _tensor(input_shape) + + original = _reduce(x, source_dims).permute(permutation) + candidate = _reduce(x.permute(permutation), target_dims) + assert _same(original, candidate) + + brute_force_dims = [ + dims + for dims in _all_dim_subsets(len(input_shape)) + if _same(original, _reduce(x.permute(permutation), dims)) + ] + assert sorted(target_dims) in brute_force_dims + + +def test_permute_map_randomized_permute_reduction_swaps_match_bruteforce() -> None: + for _ in range(80): + input_shape = _random_shape(_randint(1, 4), max_dim=3) + permutation = list(range(len(input_shape))) + _shuffle(permutation) + target_dims = _choice(_all_dim_subsets(len(input_shape))) + permute_map = _permute_map(permutation) + source_dims = permute_map.map_dims_inverse(target_dims) + x = _tensor(input_shape) + + original = _reduce(x.permute(permutation), target_dims) + candidate = _reduce(x, source_dims).permute(permutation) + assert _same(original, candidate) + + brute_force_dims = [ + dims + for dims in _all_dim_subsets(len(input_shape)) + if _same(original, _reduce(x, dims).permute(permutation)) + ] + assert sorted(source_dims) in brute_force_dims