# NOTE(review): SOURCE was a unified git diff mangled onto a few physical
# lines. Below is the reconstructed POST-patch code for the two methods the
# patch touches (methods of the memory-metadata tagging pass class in
# backends/vulkan/_passes/tag_memory_meta_pass.py), re-formatted
# conventionally. The patch also adds `from collections import deque` to the
# file's import block.
#
# The leading parameters of trace_node_users_to_constrain_repset are
# reconstructed from the call site
# `self.trace_node_users_to_constrain_repset(arg_node, arg_repset)` and from
# body usage — TODO confirm names/annotations against the full file.


def trace_node_users_to_constrain_repset(  # noqa: C901
    self,
    origin_node: torch.fx.Node,
    repset: utils.TensorRepSet,
    search_depth: list[int] | None = None,
) -> utils.TensorRepSet:
    """
    BFS over downstream users to constrain an ambiguous repset.

    Explores all immediate users at each level before going deeper, so that
    nearby constrained ops (e.g. linear requiring width_packed) are discovered
    before the search budget is spent on a single deep branch.

    ``search_depth`` is a single-element mutable list so every branch of a
    fan-out shares one counter, bounding the total number of nodes explored.
    Returns the (possibly narrowed) repset, returning early as soon as it is
    fully constrained or the budget is exhausted.
    """
    if self.max_trace_search_depth is not None:
        if search_depth is None:
            search_depth = [self.max_trace_search_depth]

    queue: deque[torch.fx.Node] = deque()
    queue.append(origin_node)
    # FIX: track already-enqueued nodes. Without this, diamond-shaped
    # fan-out/fan-in in the graph enqueues the same node once per producer
    # path; with no search budget (max_trace_search_depth is None) that can
    # grow exponentially in the number of paths. Re-processing a node is
    # redundant here because make_intersect() only ever narrows `repset`.
    visited: set[torch.fx.Node] = {origin_node}

    while queue:
        if repset.is_constrained():
            return repset

        # Shared budget: decremented once per node popped so all branches
        # draw from the same counter.
        if self.max_trace_search_depth is not None:
            search_depth[0] -= 1
            if search_depth[0] <= 0:
                return repset

        node = queue.popleft()

        users_to_trace = node.users

        sync_outs_repr = True
        if self.is_valid_op_node(node):
            sync_outs_repr = self.get_node_cached_repsets(node).sync_outs_repr

        # Multi-output node whose outputs need not share a representation:
        # only follow users of output index 1 (i.e. getitem(node, 1)).
        if utils.num_tensors_in_node(node) > 1 and not sync_outs_repr:
            users_to_trace = []
            for usage_node in node.users:
                if (
                    usage_node.target == operator.getitem
                    and usage_node.args[1] == 1
                ):
                    users_to_trace.append(usage_node)

        for usage_node in users_to_trace:
            if repset.is_constrained():
                return repset

            # Find which positional arg of the user is `node`. NOTE(review):
            # kwargs are not searched — matches the pre-patch behavior.
            arg_i_in_user = None
            for i in range(len(usage_node.args)):
                if node == usage_node.args[i]:
                    arg_i_in_user = i
                    break

            if arg_i_in_user is None:
                continue

            if not self.is_valid_op_node(usage_node):
                continue

            cur_node_repsets = self.get_node_cached_repsets(usage_node)
            req_arg_repset = cur_node_repsets.get_arg_repset(arg_i_in_user)

            # Skip users whose requirement has nothing in common with the
            # current repset; a transition will be inserted for them anyway.
            if not req_arg_repset.any_in_common(repset):
                continue

            repset = repset.make_intersect(req_arg_repset)

            # Only keep tracing through this user if the argument's
            # representation propagates to the user's output (primary io
            # synced, and this arg is — or is synced with — the primary arg).
            repset_propagates_to_output = cur_node_repsets.sync_primary_io_repr and (
                cur_node_repsets.sync_args_repr
                or arg_i_in_user == cur_node_repsets.primary_arg_idx
            )

            if repset_propagates_to_output and usage_node not in visited:
                visited.add(usage_node)
                queue.append(usage_node)

    return repset


def constrain_op_arg_repset(self, arg_i: int, op_repsets: utils.OpRepSets) -> None:
    """
    Attempts to constrain the repset of the argument at index `arg_i` of the op
    associated with `op_repsets`. Prefers downstream consumers' layout
    requirements over the upstream source's existing layout, falling back to
    the source only when downstream tracing does not fully constrain the
    repset.
    """
    # If forcing fp16, try texture storage whenever possible — per the
    # pre-patch comment, a stopgap until buffer implementations account for
    # fp16 range overflow.
    if self.force_fp16:
        op_repsets.try_constrain_with_arg_repset(arg_i, utils.ANY_TEXTURE)

    arg_node = op_repsets.op_node.args[arg_i]
    if isinstance(arg_node, list):
        # List-valued arg (e.g. cat inputs): trace from the first element.
        arg_node = arg_node[0]

    # First, trace downstream users to discover what layout they prefer.
    arg_repset = op_repsets.get_arg_repset(arg_i)
    if not arg_repset.is_constrained():
        downstream_repset = self.trace_node_users_to_constrain_repset(
            arg_node, arg_repset
        )
        op_repsets.try_constrain_with_arg_repset(arg_i, downstream_repset)

    # Fall back to the upstream source's existing layout only if downstream
    # tracing did not fully constrain the repset.
    arg_repset = op_repsets.get_arg_repset(arg_i)
    if not arg_repset.is_constrained():
        arg_source_repset = self.get_arg_tensor_source_repset(
            op_repsets.op_node, arg_i
        )
        op_repsets.try_constrain_with_arg_repset(arg_i, arg_source_repset)