diff --git a/backends/arm/_passes/dim_maps.py b/backends/arm/_passes/dim_maps.py new file mode 100644 index 00000000000..8f852908893 --- /dev/null +++ b/backends/arm/_passes/dim_maps.py @@ -0,0 +1,650 @@ +# Copyright 2026 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +from collections import Counter, defaultdict, deque +from dataclasses import dataclass +from typing import cast, Iterable, Sequence + +import sympy # type: ignore[import-untyped] +import torch + +from executorch.exir.dialects._ops import ops as exir_ops +from torch.fx import Node + +_Dim = int | torch.SymInt +_FactorKey = tuple[str, int | str] + + +@dataclass(frozen=True) +class _Factor: + key: _FactorKey + axis: int + + +@dataclass +class _ViewGroups: + source_axis_to_groups: list[list[int]] + target_axis_to_groups: list[list[int]] + group_to_source_axes: dict[int, list[int]] + group_to_target_axes: dict[int, list[int]] + + +def _is_permutation(dims: Sequence[int], rank: int) -> bool: + return sorted(dims) == list(range(rank)) + + +def _normalize_dim(dim: int, rank: int) -> int: + normalized = dim if dim >= 0 else dim + rank + assert 0 <= normalized < rank, f"Invalid dim {dim} for rank {rank}" + return normalized + + +def _normalize_dims(dims: int | Sequence[int], rank: int) -> list[int]: + if isinstance(dims, int): + return [_normalize_dim(dims, rank)] + return [_normalize_dim(dim, rank) for dim in dims] + + +def _normalize_permutation(dims: Sequence[int], rank: int) -> list[int] | None: + if len(dims) != rank: + return None + try: + normalized = [_normalize_dim(dim, rank) for dim in dims] + except AssertionError: + return None + return normalized if _is_permutation(normalized, rank) else None + + +def _extend_permutation_with_singletons( + dims: Sequence[int], shape: Sequence[_Dim] +) -> list[int] | None: + """Extend a partial permutation with missing singleton axes.""" + try: + extended_dims = _normalize_dims(dims, len(shape)) + except AssertionError: + return None + if len(set(extended_dims)) != len(extended_dims): + return None + + missing_dims = [dim for dim in range(len(shape)) if dim not in set(extended_dims)] + if any(not _dim_equals(shape[dim], 1) for dim in missing_dims): + return None + + for dim in reversed(missing_dims): + insert_at = next( + ( + index + for index, existing_dim in enumerate(extended_dims) + if existing_dim > dim + ), + len(extended_dims), + ) + extended_dims.insert(insert_at, dim) + return extended_dims if _is_permutation(extended_dims, len(shape)) else None + + +def _dim_expr(dim: _Dim) -> sympy.Basic: + return sympy.Integer(dim) if isinstance(dim, int) else dim.node.expr + + +def _dim_equals(lhs: _Dim, rhs: _Dim) -> bool: + if isinstance(lhs, int) and isinstance(rhs, int): + return lhs == rhs + return sympy.simplify(_dim_expr(lhs) - _dim_expr(rhs)) == 0 + + +def _factor_int(dim: int) -> list[_FactorKey] | None: + if dim < 1: + return None + factors: list[_FactorKey] = [] + divisor = 2 + while divisor * divisor <= dim: + while dim % divisor == 0: + factors.append(("int", divisor)) + dim //= divisor + divisor += 1 if divisor == 2 else 2 + if dim > 1: + factors.append(("int", dim)) + return factors + + +def _factor_dim(dim: _Dim) -> list[_FactorKey] | None: + if _dim_equals(dim, 1): + return [] + if isinstance(dim, int): + return _factor_int(dim) + return [("sym", sympy.srepr(_dim_expr(dim)))] + + +def _factor_shape(shape: Sequence[_Dim]) -> list[_Factor] | None: + factors: list[_Factor] = [] + for axis, dim in enumerate(shape): + dim_factors = _factor_dim(dim) + if dim_factors is None: + return None + factors.extend(_Factor(factor, axis) for factor in dim_factors) + return factors + + +def _dedupe(items: Iterable[int]) -> list[int]: + deduped: list[int] = [] + seen: set[int] = set() + for item in items: + if item not in seen: + deduped.append(item) + seen.add(item) + return deduped + + +def numel(shape: Iterable[_Dim]) -> _Dim: + numel: _Dim = 1 + for dim in shape: + numel *= dim + return numel + + +def same_numel(first_shape: Iterable[_Dim], second_shape: Iterable[_Dim]) -> bool: + return _dim_equals(numel(first_shape), numel(second_shape)) + + +class _UnionFind: + def __init__(self, size: int) -> None: + self.parents = list(range(size)) + + def find(self, item: int) -> int: + parent = self.parents[item] + if parent != item: + self.parents[item] = self.find(parent) + return self.parents[item] + + def union(self, first: int, second: int) -> None: + first_root = self.find(first) + second_root = self.find(second) + if first_root != second_root: + self.parents[second_root] = first_root + + +class ViewMap: + """Maps dims before and after a view operator. + + The map models a view by expanding both shapes into ordered prime-factor + streams and finding the permutation between them. Singleton dims are not counted. + For example, the view from [4, 3, 10] to [2, 2, 1, 5, 3, 2] is represented as: + + Source: [4, 3, 10] + Source primes: [2, 2, 3, 2, 5] + Permutation: [0, 1, 4, 2, 3] + Target primes: [2, 2, 5, 3, 2] + Target: [2, 2, 1, 5, 3, 2] + + Dim mappings are derived by unioning factors into groups where + - factors from the same source axis belong to the same group; + e.g. [2,2] [3] [2,5] + - factors from the same target axis belong to the same group; + e.g. [2] [2] [5] [3] [2] + - factors whose source-to-target permutation order crosses belong to the same group. + e.g. [2] [2] [3,5,2] + + The final groups are formed by the union of all conditions, in this case + [2, 2] and [3, 5, 2]. A source dim maps to all target dims that share its + group with any of its factors, and vice versa. + + Additional conditions apply for the map being valid depending on if the mapped dim + is a reduction operator or a permutation operator, as described in the respective methods. + + SymInts are partialy supported by factorizing them as single primes as the true + value is not known, causing potentially fewer valid mappings. + + """ + + def __init__(self, view_node: Node) -> None: + """Build a view map from an FX view_copy node.""" + input_node = view_node.args[0] + assert isinstance(input_node, Node) and ( + view_node.target == exir_ops.edge.aten.view_copy.default + ) + input_val = input_node.meta["val"] + assert isinstance(input_val, torch.Tensor) + + self.source_shape = cast(list[_Dim], list(input_val.shape)) + self.target_shape = list(cast(Sequence[_Dim], view_node.args[1])) + self._groups = self._build_groups(self.source_shape, self.target_shape) + + @classmethod + def from_shapes( + cls, source_shape: Sequence[_Dim], target_shape: Sequence[_Dim] + ) -> ViewMap: + """Build a view map directly from source and target shapes.""" + view_map = cls.__new__(cls) + view_map.source_shape = list(source_shape) + view_map.target_shape = list(target_shape) + view_map._groups = cls._build_groups( + view_map.source_shape, view_map.target_shape + ) + return view_map + + @property + def is_valid_map(self) -> bool: + """Return whether the shapes can be represented by grouped factors.""" + return self._groups is not None + + @property + def source_rank(self) -> int: + """Return the source shape rank.""" + return len(self.source_shape) + + @property + def target_rank(self) -> int: + """Return the target shape rank.""" + return len(self.target_shape) + + def map_dim( + self, + source_dims: int | Sequence[int], + ) -> list[int] | None: + """Map source reduction dims (e.g. `x.sum(dim)`, `x.max(dim)`) to valid target + reduction dims: + + x.op(dims).view(S) == x.view(S').op(mapped_dims) + + Reduction dims are valid only when the selected dims and mapped dims both cover + complete groups. E.g. in the example view [4, 3, 10] -> [2, 2, 1, 5, 3, 2] the + valid maps are. + [0] <=> [0, 1] and [1, 2] <=> [3, 4, 5] + """ + try: + normalized_dims = _normalize_dims(source_dims, self.source_rank) + except AssertionError: + return None + + groups = self._valid_groups() + if not self._is_valid_reduction(normalized_dims, groups.source_axis_to_groups): + return None + + target_dims = self._map_dims( + normalized_dims, + groups.source_axis_to_groups, + groups.group_to_target_axes, + ) + if not target_dims or not self._is_valid_reduction( + target_dims, groups.target_axis_to_groups + ): + return None + return target_dims + + def map_dim_inverse( + self, + target_dims: int | Sequence[int], + ) -> list[int] | None: + """Map target reduction dims to valid source reduction dims, inverse map + of map_dim. + + x.view(S).op(dims) == x.op(mapped_dims).view(S') + + """ + try: + normalized_dims = _normalize_dims(target_dims, self.target_rank) + except AssertionError: + return None + + groups = self._valid_groups() + if not self._is_valid_reduction(normalized_dims, groups.target_axis_to_groups): + return None + + source_dims = self._map_dims( + normalized_dims, + groups.target_axis_to_groups, + groups.group_to_source_axes, + ) + if not source_dims or not self._is_valid_reduction( + source_dims, groups.source_axis_to_groups + ): + return None + return source_dims + + def map_permutation( + self, + source_permutation: Sequence[int], + ) -> list[int] | None: + """Map a source permutation to a valid target permutation. + + Permutation dims have an additional constraint on the order of dims: + Dims are valid only when dims mapped through one group appear as contiguous + increasing blocks dims in both source and target. + + In the example view [4, 3, 10] -> [2, 2, 1, 5, 3, 2], [1, 2, 0] is a valid + permutation, but [2, 0, 1] and [0, 2, 1] are not since [1, 2] belong to the same + group but are not a) contiguous, or b) in increasing order. + + """ + source_permutation = _normalize_permutation( + source_permutation, self.source_rank + ) + if source_permutation is None: + return None + + groups = self._valid_groups() + target_permutation = _extend_permutation_with_singletons( + self._map_dims( + source_permutation, + groups.source_axis_to_groups, + groups.group_to_target_axes, + ), + self.target_shape, + ) + if target_permutation is None: + return None + + return ( + target_permutation + if self._matching_permuted_group_blocks( + source_permutation, + target_permutation, + groups.source_axis_to_groups, + groups.target_axis_to_groups, + ) + else None + ) + + def map_permutation_inverse( + self, + target_permutation: Sequence[int], + ) -> list[int] | None: + """Map a target permutation to a valid source permutation. + + Inverse of map_permutation. + + """ + target_permutation = _normalize_permutation( + target_permutation, self.target_rank + ) + if target_permutation is None: + return None + + groups = self._valid_groups() + source_permutation = _extend_permutation_with_singletons( + self._map_dims( + target_permutation, + groups.target_axis_to_groups, + groups.group_to_source_axes, + ), + self.source_shape, + ) + if source_permutation is None: + return None + + return ( + source_permutation + if self._matching_permuted_group_blocks( + source_permutation, + target_permutation, + groups.source_axis_to_groups, + groups.target_axis_to_groups, + ) + else None + ) + + @staticmethod + def _map_dims( + source_dims: Iterable[int], + source_axis_to_groups: Sequence[Sequence[int]], + group_to_target_axes: dict[int, list[int]], + ) -> list[int]: + return _dedupe( + target_axis + for source_axis in source_dims + for group in source_axis_to_groups[source_axis] + for target_axis in group_to_target_axes[group] + ) + + @staticmethod + def _matching_permuted_group_blocks( + source_permutation: Sequence[int], + target_permutation: Sequence[int], + source_axis_to_groups: Sequence[Sequence[int]], + target_axis_to_groups: Sequence[Sequence[int]], + ) -> bool: + """Return whether source and target permutations consume groups + equally. + """ + closed_groups: set[int] = set() + source_index = 0 + target_index = 0 + + while True: + source_index, source_group = ViewMap._next_group( + source_permutation, source_axis_to_groups, source_index + ) + target_index, target_group = ViewMap._next_group( + target_permutation, target_axis_to_groups, target_index + ) + + if source_group is None or target_group is None: + return source_group is None and target_group is None + if source_group != target_group or source_group in closed_groups: + return False + + source_index, source_axes = ViewMap._consume_group( + source_permutation, + source_axis_to_groups, + source_index, + source_group, + ) + target_index, target_axes = ViewMap._consume_group( + target_permutation, + target_axis_to_groups, + target_index, + target_group, + ) + if source_axes != sorted(source_axes) or target_axes != sorted(target_axes): + return False + + closed_groups.add(source_group) + + @staticmethod + def _next_group( + permutation: Sequence[int], + axis_to_groups: Sequence[Sequence[int]], + index: int, + ) -> tuple[int, int | None]: + """Return the next grouped axis index and group, skipping singletons.""" + while index < len(permutation): + axis = permutation[index] + axis_groups = axis_to_groups[axis] + if not axis_groups: + index += 1 + continue + assert len(axis_groups) == 1 + return index, axis_groups[0] + return index, None + + @staticmethod + def _consume_group( + permutation: Sequence[int], + axis_to_groups: Sequence[Sequence[int]], + index: int, + group: int, + ) -> tuple[int, list[int]]: + """Consume one group block, ignoring singleton axes.""" + axes: list[int] = [] + while index < len(permutation): + axis = permutation[index] + axis_groups = axis_to_groups[axis] + if not axis_groups: + index += 1 + continue + assert len(axis_groups) == 1 + if axis_groups[0] != group: + break + axes.append(axis) + index += 1 + return index, axes + + @staticmethod + def _is_valid_reduction( + normalized_dims: Iterable[int], + axis_to_groups: Sequence[Sequence[int]], + ) -> bool: + """Return whether dims cover every selected group in one shape.""" + normalized_dims = set(normalized_dims) + if not normalized_dims: + return False + + group_to_axes: dict[int, set[int]] = defaultdict(set) + selected_groups: set[int] = set() + for axis, groups in enumerate(axis_to_groups): + for group in groups: + group_to_axes[group].add(axis) + if axis in normalized_dims: + selected_groups.add(group) + + if any(not axis_to_groups[axis] for axis in normalized_dims): + return False + + return all( + group_to_axes[group].issubset(normalized_dims) for group in selected_groups + ) + + @classmethod + def _build_groups( + cls, source_shape: Sequence[_Dim], target_shape: Sequence[_Dim] + ) -> _ViewGroups | None: + """Build source/target axis groups from ordered prime factors.""" + + # Compute orderd prime factorizations of input and output shapes + source_factors = _factor_shape(source_shape) + target_factors = _factor_shape(target_shape) + assert ( + source_factors is not None + and (target_factors is not None) + and Counter(factor.key for factor in source_factors) + == Counter(factor.key for factor in target_factors) + ), "Invalid view shapes" + + # Compute prime factor permutation between input and output shapes + factor_count = len(source_factors) + permutation = cls._find_permutation(source_factors, target_factors) + assert permutation is not None, "Invalid view shapes" + + # Find groups of factors that must be mapped together to preserve view equivalence + union_find = _UnionFind(factor_count) + cls._union_factors_sharing_axes( + union_find, (factor.axis for factor in source_factors) + ) + + cls._union_factors_sharing_axes( + union_find, + ( + target_factors[permutation[source_position]].axis + for source_position in range(factor_count) + ), + ) + + cls._union_crossing_factors(union_find, permutation) + + # Create group data structure + source_axis_groups: list[set[int]] = [set() for _ in source_shape] + target_axis_groups: list[set[int]] = [set() for _ in target_shape] + group_to_source_axes: dict[int, set[int]] = defaultdict(set) + group_to_target_axes: dict[int, set[int]] = defaultdict(set) + + for source_position, source_factor in enumerate(source_factors): + group = union_find.find(source_position) + target_factor = target_factors[permutation[source_position]] + + source_axis_groups[source_factor.axis].add(group) + target_axis_groups[target_factor.axis].add(group) + group_to_source_axes[group].add(source_factor.axis) + group_to_target_axes[group].add(target_factor.axis) + + return _ViewGroups( + source_axis_to_groups=[sorted(groups) for groups in source_axis_groups], + target_axis_to_groups=[sorted(groups) for groups in target_axis_groups], + group_to_source_axes={ + group: sorted(axes) for group, axes in group_to_source_axes.items() + }, + group_to_target_axes={ + group: sorted(axes) for group, axes in group_to_target_axes.items() + }, + ) + + @staticmethod + def _find_permutation( + X: Sequence[_Factor], Y: Sequence[_Factor] + ) -> list[int] | None: + """Computes the permutation from X -> Y, handling duplicates.""" + duplicates: dict[_FactorKey, deque[int]] = defaultdict(deque) + for i, y in enumerate(Y): + duplicates[y.key].append(i) + + permutation: list[int] = [] + for x in X: + positions = duplicates[x.key] + if not positions: + return None + permutation.append(positions.popleft()) + + return permutation + + @staticmethod + def _union_factors_sharing_axes( + union_find: _UnionFind, axes: Iterable[int] + ) -> None: + """Union factor positions that belong to the same axis.""" + first_position_by_axis: dict[int, int] = {} + for position, axis in enumerate(axes): + if axis in first_position_by_axis: + union_find.union(first_position_by_axis[axis], position) + else: + first_position_by_axis[axis] = position + + @staticmethod + def _union_crossing_factors( + union_find: _UnionFind, permutation: Sequence[int] + ) -> None: + """Union factor positions whose target ordering crosses.""" + for first in range(len(permutation)): + for second in range(first + 1, len(permutation)): + if permutation[first] > permutation[second]: + union_find.union(first, second) + + def _valid_groups(self) -> _ViewGroups: + """Return built groups for a valid map.""" + assert self._groups is not None + return self._groups + + +class PermuteMap: + """Maps dims to equivalent dims before and after a permute.""" + + def __init__(self, permute_node: Node) -> None: + permute_dims = permute_node.args[1] + assert isinstance(permute_dims, Sequence) and not isinstance( + permute_dims, (str, bytes) + ) + self.permute_dims = list(cast(Sequence[int], permute_dims)) + + def map_dims(self, dims: int | Sequence[int]) -> list[int]: + """Computes mapped dims s.t. + + x.op(dims).permute(P) == x.permute(P).op(mapped_dims) + + """ + normalized_dims = _normalize_dims(dims, len(self.permute_dims)) + inverse_permute = [0] * len(self.permute_dims) + for target_dim, source_dim in enumerate(self.permute_dims): + inverse_permute[source_dim] = target_dim + return [inverse_permute[dim] for dim in normalized_dims] + + def map_dims_inverse(self, dims: int | Sequence[int]) -> list[int]: + """Computes mapped dims s.t. + + x.permute(P).op(dims) == x.op(mapped_dims).permute(P) + + """ + normalized_dims = _normalize_dims(dims, len(self.permute_dims)) + return [self.permute_dims[dim] for dim in normalized_dims] diff --git a/backends/arm/test/passes/test_dim_maps.py b/backends/arm/test/passes/test_dim_maps.py new file mode 100644 index 00000000000..486fbec060b --- /dev/null +++ b/backends/arm/test/passes/test_dim_maps.py @@ -0,0 +1,614 @@ +# Copyright 2026 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from itertools import combinations, permutations +from typing import cast, Sequence, TypeVar + +import sympy # type: ignore[import-untyped] +import torch + +from executorch.backends.arm._passes.dim_maps import PermuteMap, ViewMap +from torch.fx.experimental.symbolic_shapes import ShapeEnv + + +_RNG = torch.Generator().manual_seed(0) +_T = TypeVar("_T") +_Dim = int | torch.SymInt +_DimT = TypeVar("_DimT", bound=_Dim) + + +def _make_symint( + shape_env: ShapeEnv, symbol: str, hint: int, min: int = 1, max: int = 64 +) -> torch.SymInt: + symint = shape_env.create_symintnode(sympy.Symbol(symbol), hint=hint) + assert isinstance(symint, torch.SymInt) + shape_env.constrain_symbol_range( + symint.node.expr, compiler_min=min, compiler_max=max + ) + return symint + + +def _numel(shape: list[int]) -> int: + numel = 1 + for dim in shape: + numel *= dim + return numel + + +def _factorizations(numel: int, rank: int) -> list[list[int]]: + shapes: list[list[int]] = [] + + def recurse(remaining: int, remaining_rank: int, shape: list[int]) -> None: + if remaining_rank == 0: + if remaining == 1: + shapes.append(list(shape)) + return + + for dim in range(1, remaining + 1): + if remaining % dim == 0: + shape.append(dim) + recurse(remaining // dim, remaining_rank - 1, shape) + shape.pop() + + recurse(numel, rank, []) + return shapes + + +def _randint(low: int, high: int) -> int: + return int(torch.randint(low, high + 1, (), generator=_RNG).item()) + + +def _choice(choices: list[_T]) -> _T: + return choices[_randint(0, len(choices) - 1)] + + +def _shuffle(values: list[int]) -> None: + indices = torch.randperm(len(values), generator=_RNG).tolist() + values[:] = [values[index] for index in indices] + + +def _random_shape(rank: int, max_dim: int = 4) -> list[int]: + return [_randint(1, max_dim) for _ in range(rank)] + + +def _random_view_shape(numel: int, max_rank: int = 4) -> list[int]: + rank = _randint(1, max_rank) + return _choice(_factorizations(numel, rank)) + + +def _tensor(shape: list[int]) -> torch.Tensor: + return torch.arange(_numel(shape), dtype=torch.float32).reshape(shape) + + +def _inverse_permutation(permutation: list[int]) -> list[int]: + inverse = [0] * len(permutation) + for index, dim in enumerate(permutation): + inverse[dim] = index + return inverse + + +def _permute_map(permutation: list[int]) -> PermuteMap: + graph = torch.fx.Graph() + x = graph.placeholder("x") + permute = graph.call_function(torch.ops.aten.permute.default, args=(x, permutation)) + return PermuteMap(permute) + + +def _all_dim_subsets(rank: int) -> list[list[int]]: + return [ + list(dims) + for subset_size in range(1, rank + 1) + for dims in combinations(range(rank), subset_size) + ] + + +def _reduce_shape(shape: Sequence[_DimT], dims: list[int]) -> list[_DimT]: + reduced_shape = list(shape) + for dim in dims: + reduced_shape[dim] = cast(_DimT, 1) + return reduced_shape + + +def _reduce(tensor: torch.Tensor, dims: list[int]) -> torch.Tensor: + return tensor.sum(dim=tuple(dims), keepdim=True) + + +def _same(lhs: torch.Tensor, rhs: torch.Tensor) -> bool: + return lhs.shape == rhs.shape and torch.equal(lhs, rhs) + + +def _propose_permute_view_swap( + input_shape: Sequence[_DimT], + permutation: list[int], + output_shape: Sequence[_DimT], +) -> tuple[list[_DimT], list[int]] | None: + permuted_shape = [input_shape[dim] for dim in permutation] + view_map = ViewMap.from_shapes(permuted_shape, output_shape) + if not view_map.is_valid_map: + return None + + permuted_axis = _inverse_permutation(permutation) + target_axis_order = view_map.map_permutation(permuted_axis) + if target_axis_order is None: + return None + + return ( + [output_shape[target_axis] for target_axis in target_axis_order], + _inverse_permutation(target_axis_order), + ) + + +def _propose_view_permute_swap( + input_shape: Sequence[_DimT], + view_shape: Sequence[_DimT], + permutation: list[int], +) -> tuple[list[int], list[_DimT]] | None: + view_map = ViewMap.from_shapes(input_shape, view_shape) + if not view_map.is_valid_map: + return None + + mapped_dims = view_map.map_permutation_inverse(permutation) + if mapped_dims is None: + return None + + output_shape = [view_shape[dim] for dim in permutation] + return mapped_dims, output_shape + + +def _propose_reduction_view_swap( + input_shape: Sequence[_DimT], + source_dims: list[int], + view_shape: Sequence[_DimT], +) -> tuple[list[_DimT], list[int]] | None: + view_map = ViewMap.from_shapes(input_shape, view_shape) + if not view_map.is_valid_map: + return None + + target_dims = view_map.map_dim(source_dims) + if target_dims is None: + return None + return list(view_shape), target_dims + + +def _propose_view_reduction_swap( + input_shape: Sequence[_DimT], + view_shape: Sequence[_DimT], + target_dims: list[int], +) -> tuple[list[int], list[_DimT]] | None: + view_map = ViewMap.from_shapes(input_shape, view_shape) + if not view_map.is_valid_map: + return None + + source_dims = view_map.map_dim_inverse(target_dims) + if source_dims is None: + return None + return source_dims, _reduce_shape(view_shape, target_dims) + + +def _bruteforce_permute_view_swaps( + x: torch.Tensor, + permutation: list[int], + output_shape: list[int], +) -> list[tuple[list[int], list[int]]]: + original = x.permute(permutation).reshape(output_shape) + candidates: list[tuple[list[int], list[int]]] = [] + for candidate_permutation in permutations(range(len(output_shape))): + candidate_permutation_list = list(candidate_permutation) + for candidate_shape in _factorizations(_numel(output_shape), len(output_shape)): + candidate = x.reshape(candidate_shape).permute(candidate_permutation_list) + if _same(original, candidate): + candidates.append((candidate_shape, candidate_permutation_list)) + return candidates + + +def _bruteforce_view_permute_swaps( + x: torch.Tensor, + view_shape: list[int], + permutation: list[int], +) -> list[tuple[list[int], list[int]]]: + output_shape = [view_shape[dim] for dim in permutation] + original = x.reshape(view_shape).permute(permutation) + candidates: list[tuple[list[int], list[int]]] = [] + for candidate_permutation in permutations(range(len(x.shape))): + candidate_permutation_list = list(candidate_permutation) + candidate = x.permute(candidate_permutation_list).reshape(output_shape) + if _same(original, candidate): + candidates.append((candidate_permutation_list, output_shape)) + return candidates + + +def _bruteforce_reduction_view_swaps( + x: torch.Tensor, + source_dims: list[int], + view_shape: list[int], +) -> list[tuple[list[int], list[int]]]: + candidates: list[tuple[list[int], list[int]]] = [] + for target_dims in _all_dim_subsets(len(view_shape)): + output_shape = _reduce_shape(view_shape, target_dims) + reduced = _reduce(x, source_dims) + if reduced.numel() != _numel(output_shape): + continue + original = reduced.reshape(output_shape) + candidate = _reduce(x.reshape(view_shape), target_dims) + if _same(original, candidate): + candidates.append((view_shape, target_dims)) + return candidates + + +def _bruteforce_view_reduction_swaps( + x: torch.Tensor, + view_shape: list[int], + target_dims: list[int], +) -> list[tuple[list[int], list[int]]]: + original = _reduce(x.reshape(view_shape), target_dims) + candidates: list[tuple[list[int], list[int]]] = [] + for source_dims in _all_dim_subsets(len(x.shape)): + reduced = _reduce(x, source_dims) + if reduced.numel() != original.numel(): + continue + candidate = reduced.reshape(original.shape) + if _same(original, candidate): + candidates.append((source_dims, list(original.shape))) + return candidates + + +def test_dim_map_maps_split_and_merged_prime_factor_groups() -> None: + view_map = ViewMap.from_shapes([1, 2, 3, 4], [1, 6, 2, 2]) + + assert view_map.is_valid_map + assert view_map.map_dim(0) is None + assert view_map.map_dim(1) is None + assert view_map.map_dim(2) is None + assert view_map.map_dim(3) == [2, 3] + assert view_map.map_dim([1, 2]) == [1] + assert view_map.map_dim([3, 1]) is None + assert view_map.map_dim([3, 1, 2]) == [2, 3, 1] + + assert view_map.map_dim_inverse(0) is None + assert view_map.map_dim_inverse(1) == [1, 2] + assert view_map.map_dim_inverse(2) is None + assert view_map.map_dim_inverse([3, 1, 2]) == [3, 1, 2] + assert view_map.map_dim_inverse([2, 3, 1]) == [3, 1, 2] + + +def test_dim_map_groups_reordered_crossing_prime_factors() -> None: + view_map = ViewMap.from_shapes([2, 3], [3, 2]) + + assert view_map.is_valid_map + assert view_map.map_dim(0) is None + assert view_map.map_dim(1) is None + assert view_map.map_dim([0, 1]) == [0, 1] + assert view_map.map_dim_inverse(0) is None + assert view_map.map_dim_inverse(1) is None + assert view_map.map_dim_inverse([0, 1]) == [0, 1] + + +def test_dim_map_matches_view_map_docstring_example_reduction_dims() -> None: + view_map = ViewMap.from_shapes([4, 3, 10], [2, 2, 1, 5, 3, 2]) + + assert view_map.is_valid_map + assert view_map.map_dim(0) == [0, 1] + assert view_map.map_dim([1, 2]) == [3, 4, 5] + assert view_map.map_dim_inverse([0, 1]) == [0] + assert view_map.map_dim_inverse([3, 4, 5]) == [1, 2] + + assert view_map.map_dim(1) is None + assert view_map.map_dim(2) is None + assert view_map.map_dim_inverse(2) is None + + +def test_dim_map_matches_view_map_docstring_example_permutation_dims() -> None: + view_map = ViewMap.from_shapes([4, 3, 10], [2, 2, 1, 5, 3, 2]) + + assert view_map.map_permutation([1, 2, 0]) == [2, 3, 4, 5, 0, 1] + assert view_map.map_permutation_inverse([2, 3, 4, 5, 0, 1]) == [1, 2, 0] + + assert view_map.map_permutation([0, 2, 1]) is None + assert view_map.map_permutation([2, 0, 1]) is None + + +def test_dim_map_validates_reductions_by_whole_groups() -> None: + view_map = ViewMap.from_shapes([2, 3], [3, 2]) + + assert view_map.map_dim([0]) is None + assert view_map.map_dim_inverse([1]) is None + assert view_map.map_dim([0, 1]) == [0, 1] + assert view_map.map_dim_inverse([0, 1]) == [0, 1] + + +def test_dim_map_validates_permuted_group_blocks() -> None: + view_map = ViewMap.from_shapes([2, 3, 5], [3, 2, 5]) + + assert view_map.map_permutation([0, 1, 2]) == [0, 1, 2] + assert view_map.map_permutation([2, 0, 1]) == [2, 0, 1] + assert view_map.map_permutation_inverse([0, 2, 1]) is None + + merged_view_map = ViewMap.from_shapes([2, 3], [6]) + assert merged_view_map.map_permutation([0, 1]) == [0] + assert merged_view_map.map_permutation([1, 0]) is None + + +def test_extends_mapped_permutation_with_singletons() -> None: + view_map = ViewMap.from_shapes([2, 2], [2, 1, 2]) + assert view_map.map_permutation([0, 1]) == [0, 1, 2] + assert view_map.map_permutation([1, 0]) == [1, 2, 0] + + singleton_view_map = ViewMap.from_shapes([2], [1, 2]) + assert singleton_view_map.map_permutation([0]) == [0, 1] + assert singleton_view_map.map_permutation_inverse([1, 0]) == [0] + + assert view_map.map_permutation([0, 0]) is None + + +def test_dim_map_uses_strict_no_mapping_for_singletons() -> None: + view_map = ViewMap.from_shapes([1, 4], [4]) + + assert view_map.is_valid_map + assert view_map.map_dim(0) is None + assert view_map.map_dim(1) == [0] + assert view_map.map_dim_inverse(0) == [1] + + split_view_map = ViewMap.from_shapes([4], [2, 1, 2]) + assert split_view_map.map_dim(0) == [0, 2] + assert split_view_map.map_dim_inverse(1) is None + assert split_view_map.map_dim_inverse([0, 2]) == [0] + + +def test_dim_map_preserves_symbolic_dimensions_as_prime_factors() -> None: + shape_env = ShapeEnv() + batch = _make_symint(shape_env, "batch", hint=4) + + view_map = ViewMap.from_shapes([batch, 6], [batch, 2, 3]) + + assert view_map.is_valid_map + assert view_map.map_dim(0) == [0] + assert view_map.map_dim(1) == [1, 2] + assert view_map.map_dim_inverse(0) == [0] + + +def test_dim_map_permute_view_swap_preserves_symbolic_view_shape_dims() -> None: + shape_env = ShapeEnv() + batch = _make_symint(shape_env, "batch", hint=4) + input_shape: list[_Dim] = [batch, 6] + output_shape: list[_Dim] = [2, 3, batch] + + proposal = _propose_permute_view_swap(input_shape, [1, 0], output_shape) + + assert proposal is not None + view_shape, permutation = proposal + assert isinstance(view_shape[0], torch.SymInt) + assert view_shape[0] is batch + assert view_shape[1:] == [2, 3] + assert permutation == [1, 2, 0] + + +def test_dim_map_view_permute_swap_preserves_symbolic_output_shape_dims() -> None: + shape_env = ShapeEnv() + batch = _make_symint(shape_env, "batch", hint=4) + input_shape: list[_Dim] = [batch, 6] + view_shape: list[_Dim] = [batch, 2, 3] + + proposal = _propose_view_permute_swap(input_shape, view_shape, [1, 2, 0]) + + assert proposal is not None + permutation, output_shape = proposal + assert permutation == [1, 0] + assert output_shape[:2] == [2, 3] + assert isinstance(output_shape[2], torch.SymInt) + assert output_shape[2] is batch + + +def test_dim_map_reduction_view_swap_preserves_symbolic_view_shape_dims() -> None: + shape_env = ShapeEnv() + batch = _make_symint(shape_env, "batch", hint=4) + input_shape: list[_Dim] = [batch, 6] + view_shape: list[_Dim] = [batch, 2, 3] + + proposal = _propose_reduction_view_swap(input_shape, [1], view_shape) + + assert proposal is not None + view_shape, target_dims = proposal + assert isinstance(view_shape[0], torch.SymInt) + assert view_shape[0] is batch + assert view_shape[1:] == [2, 3] + assert target_dims == [1, 2] + + +def test_dim_map_view_reduction_swap_preserves_symbolic_output_shape_dims() -> None: + shape_env = ShapeEnv() + batch = _make_symint(shape_env, "batch", hint=4) + input_shape: list[_Dim] = [batch, 6] + view_shape: list[_Dim] = [batch, 2, 3] + + proposal = _propose_view_reduction_swap(input_shape, view_shape, [1, 2]) + + assert proposal is not None + source_dims, output_shape = proposal + assert source_dims == [1] + assert isinstance(output_shape[0], torch.SymInt) + assert output_shape[0] is batch + assert output_shape[1:] == [1, 1] + + +def test_permute_map_matches_docstring_reduction_identities() -> None: + input_shape = [2, 3, 5] + permutation = [2, 0, 1] + permute_map = _permute_map(permutation) + x = _tensor(input_shape) + + source_dims = [0, 2] + target_dims = permute_map.map_dims(source_dims) + assert target_dims == [1, 0] + assert _same( + _reduce(x, source_dims).permute(permutation), + _reduce(x.permute(permutation), target_dims), + ) + + target_dims = [0, 2] + source_dims = permute_map.map_dims_inverse(target_dims) + assert source_dims == [2, 1] + assert _same( + _reduce(x.permute(permutation), target_dims), + _reduce(x, source_dims).permute(permutation), + ) + + +def test_dim_map_randomized_permute_view_swaps_match_bruteforce() -> None: + accepted = 0 + rejected = 0 + + for _ in range(80): + input_shape = _random_shape(_randint(1, 4), max_dim=3) + permutation = list(range(len(input_shape))) + _shuffle(permutation) + output_shape = _random_view_shape(_numel(input_shape), max_rank=4) + x = _tensor(input_shape) + + proposal = _propose_permute_view_swap(input_shape, permutation, output_shape) + brute_force_swaps = _bruteforce_permute_view_swaps(x, permutation, output_shape) + if proposal is None and brute_force_swaps: + proposal = brute_force_swaps[0] + + if proposal is None: + rejected += 1 + assert brute_force_swaps == [] + continue + + accepted += 1 + assert proposal in brute_force_swaps + view_shape, new_permutation = proposal + original = x.permute(permutation).reshape(output_shape) + candidate = x.reshape(view_shape).permute(new_permutation) + assert _same(original, candidate) + + +def test_dim_map_randomized_view_permute_swaps_match_bruteforce() -> None: + accepted = 0 + rejected = 0 + + for _ in range(80): + input_shape = _random_shape(_randint(1, 4), max_dim=3) + view_shape = _random_view_shape(_numel(input_shape), max_rank=4) + permutation = list(range(len(view_shape))) + _shuffle(permutation) + x = _tensor(input_shape) + + proposal = _propose_view_permute_swap(input_shape, view_shape, permutation) + brute_force_swaps = _bruteforce_view_permute_swaps(x, view_shape, permutation) + if proposal is None and brute_force_swaps: + proposal = brute_force_swaps[0] + + if proposal is None: + rejected += 1 + assert brute_force_swaps == [] + continue + + accepted += 1 + assert proposal in brute_force_swaps + new_permutation, output_shape = proposal + original = x.reshape(view_shape).permute(permutation) + candidate = x.permute(new_permutation).reshape(output_shape) + assert _same(original, candidate) + + +def test_dim_map_randomized_reduction_view_swaps_match_bruteforce() -> None: + accepted = 0 + rejected = 0 + + for _ in range(80): + input_shape = _random_shape(_randint(1, 4), max_dim=3) + source_dims = _choice(_all_dim_subsets(len(input_shape))) + view_shape = _random_view_shape(_numel(input_shape), max_rank=4) + x = _tensor(input_shape) + + proposal = _propose_reduction_view_swap(input_shape, source_dims, view_shape) + brute_force_swaps = _bruteforce_reduction_view_swaps(x, source_dims, view_shape) + if proposal is None and brute_force_swaps: + proposal = brute_force_swaps[0] + + if proposal is None: + rejected += 1 + assert brute_force_swaps == [] + continue + + accepted += 1 + assert proposal in brute_force_swaps + new_shape, target_dims = proposal + output_shape = _reduce_shape(new_shape, target_dims) + original = _reduce(x, source_dims).reshape(output_shape) + candidate = _reduce(x.reshape(new_shape), target_dims) + assert _same(original, candidate) + + +def test_dim_map_randomized_view_reduction_swaps_match_bruteforce() -> None: + accepted = 0 + rejected = 0 + + for _ in range(80): + input_shape = _random_shape(_randint(1, 4), max_dim=3) + view_shape = _random_view_shape(_numel(input_shape), max_rank=4) + target_dims = _choice(_all_dim_subsets(len(view_shape))) + x = _tensor(input_shape) + + proposal = _propose_view_reduction_swap(input_shape, view_shape, target_dims) + brute_force_swaps = _bruteforce_view_reduction_swaps(x, view_shape, target_dims) + if proposal is None and brute_force_swaps: + proposal = brute_force_swaps[0] + + if proposal is None: + rejected += 1 + assert brute_force_swaps == [] + continue + + accepted += 1 + assert proposal in brute_force_swaps + source_dims, output_shape = proposal + original = _reduce(x.reshape(view_shape), target_dims) + candidate = _reduce(x, source_dims).reshape(output_shape) + assert _same(original, candidate) + + +def test_permute_map_randomized_reduction_permute_swaps_match_bruteforce() -> None: + for _ in range(80): + input_shape = _random_shape(_randint(1, 4), max_dim=3) + source_dims = _choice(_all_dim_subsets(len(input_shape))) + permutation = list(range(len(input_shape))) + _shuffle(permutation) + permute_map = _permute_map(permutation) + target_dims = permute_map.map_dims(source_dims) + x = _tensor(input_shape) + + original = _reduce(x, source_dims).permute(permutation) + candidate = _reduce(x.permute(permutation), target_dims) + assert _same(original, candidate) + + brute_force_dims = [ + dims + for dims in _all_dim_subsets(len(input_shape)) + if _same(original, _reduce(x.permute(permutation), dims)) + ] + assert sorted(target_dims) in brute_force_dims + + +def test_permute_map_randomized_permute_reduction_swaps_match_bruteforce() -> None: + for _ in range(80): + input_shape = _random_shape(_randint(1, 4), max_dim=3) + permutation = list(range(len(input_shape))) + _shuffle(permutation) + target_dims = _choice(_all_dim_subsets(len(input_shape))) + permute_map = _permute_map(permutation) + source_dims = permute_map.map_dims_inverse(target_dims) + x = _tensor(input_shape) + + original = _reduce(x.permute(permutation), target_dims) + candidate = _reduce(x, source_dims).permute(permutation) + assert _same(original, candidate) + + brute_force_dims = [ + dims + for dims in _all_dim_subsets(len(input_shape)) + if _same(original, _reduce(x, dims).permute(permutation)) + ] + assert sorted(source_dims) in brute_force_dims