From 03983dce061c09ff7dc35b75c7752f044db3dd54 Mon Sep 17 00:00:00 2001 From: ssjia Date: Fri, 24 Apr 2026 09:46:26 -0700 Subject: [PATCH 1/3] [ET-VK] Prefer downstream layout in TagMemoryMetaPass to reduce transitions MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Two changes to the layout assignment pass that together reduce layout transitions by ~89% for transformer-style models (73 → 9 for EdgeTAM ViT-S encoder): 1. BFS instead of DFS for downstream user tracing. The old DFS could exhaust the search budget (64 nodes) on one deep branch before discovering a constraining op on a sibling branch. BFS explores all immediate users at each level first, finding nearby layout-constrained ops (e.g. linear requiring width_packed) more reliably. 2. Prefer downstream consumers' layout over upstream source's layout. Previously, if the upstream source already had a representation (e.g. channels_packed from conv2d), that was applied first and locked in the layout via sync_primary_io_repr before downstream tracing could run. Now, downstream users are traced first to discover what layout they prefer, and the upstream source is only used as a fallback when downstream doesn't constrain. For ViT-style transformers, conv2d (patch embedding) forces channels_packed, which previously propagated through all residual connections via flexible ops (layer_norm, add, mul). With downstream-preferred layout, linear ops' width_packed requirement is discovered first, so the entire transformer stack stays width_packed. Transitions only occur at the conv2d↔transformer boundaries. Differential Revision: [D102360203](https://our.internmc.facebook.com/intern/diff/D102360203/) [ghstack-poisoned] --- .../vulkan/_passes/tag_memory_meta_pass.py | 130 +++++++++++------- 1 file changed, 81 insertions(+), 49 deletions(-) diff --git a/backends/vulkan/_passes/tag_memory_meta_pass.py b/backends/vulkan/_passes/tag_memory_meta_pass.py index 6f5cb10f1b2..d3b44e0a9ca 100644 --- a/backends/vulkan/_passes/tag_memory_meta_pass.py +++ b/backends/vulkan/_passes/tag_memory_meta_pass.py @@ -6,6 +6,8 @@ import logging import operator + +from collections import deque from typing import Any import executorch.backends.vulkan.utils as utils @@ -332,81 +334,111 @@ def trace_node_users_to_constrain_repset( # noqa: C901 search_depth: list[int] | None = None, ) -> utils.TensorRepSet: """ - For an ambiguous repset, try to constrain the repset by tracing the required - repsets of the users of `origin_node`. The idea is to try to find a representation - that can be used the longest without needing user nodes to insert a transition - for its arguments. + BFS over downstream users to constrain an ambiguous repset. Explores all + immediate users at each level before going deeper, so that nearby constrained + ops (e.g. linear requiring width_packed) are discovered before the search + budget is spent on a single deep branch. """ - # Optionally limit the total number of nodes explored to improve export - # time. search_depth is a mutable list so that all branches of a fan-out - # share a single counter, preventing exponential blowup. if self.max_trace_search_depth is not None: if search_depth is None: search_depth = [self.max_trace_search_depth] - search_depth[0] -= 1 - if search_depth[0] <= 0: + + queue: deque[torch.fx.Node] = deque() + queue.append(origin_node) + + while queue: + if repset.is_constrained(): return repset - users_to_trace = origin_node.users + if self.max_trace_search_depth is not None: + search_depth[0] -= 1 + if search_depth[0] <= 0: + return repset + + node = queue.popleft() + + users_to_trace = node.users + + sync_outs_repr = True + if self.is_valid_op_node(node): + sync_outs_repr = self.get_node_cached_repsets(node).sync_outs_repr - sync_outs_repr = True - if self.is_valid_op_node(origin_node): - sync_outs_repr = self.get_node_cached_repsets(origin_node).sync_outs_repr + if utils.num_tensors_in_node(node) > 1 and not sync_outs_repr: + users_to_trace = [] + for usage_node in node.users: + if ( + usage_node.target == operator.getitem + and usage_node.args[1] == 1 + ): + users_to_trace.append(usage_node) - if utils.num_tensors_in_node(origin_node) > 1 and not sync_outs_repr: - users_to_trace = [] - for usage_node in origin_node.users: - if usage_node.target == operator.getitem and usage_node.args[1] == 1: - users_to_trace.append(usage_node) + for usage_node in users_to_trace: + if repset.is_constrained(): + return repset - for usage_node in users_to_trace: - arg_i_in_user = None - for i in range(len(usage_node.args)): - if origin_node == usage_node.args[i]: - arg_i_in_user = i - break + arg_i_in_user = None + for i in range(len(usage_node.args)): + if node == usage_node.args[i]: + arg_i_in_user = i + break - if arg_i_in_user is not None: - repset = self.constrain_repset_with_user( - usage_node, arg_i_in_user, repset, search_depth + if arg_i_in_user is None: + continue + + if not self.is_valid_op_node(usage_node): + continue + + cur_node_repsets = self.get_node_cached_repsets(usage_node) + req_arg_repset = cur_node_repsets.get_arg_repset(arg_i_in_user) + + if not req_arg_repset.any_in_common(repset): + continue + + repset = repset.make_intersect(req_arg_repset) + + repset_propagates_to_output = ( + cur_node_repsets.sync_primary_io_repr + and ( + cur_node_repsets.sync_args_repr + or arg_i_in_user == cur_node_repsets.primary_arg_idx + ) ) - if repset.is_constrained(): - return repset + if repset_propagates_to_output: + queue.append(usage_node) return repset def constrain_op_arg_repset(self, arg_i: int, op_repsets: utils.OpRepSets) -> None: """ Attempts to constrain the repset of the argument at index `arg_i` of the op - associated with `op_repsets`. Does this with two stages: - - 1. First, account for any existing representation that has already been determined - for the argument. If no existing representation has been determined, then use - the output repset of the operator that produces the argument. - 2. Then, try to trace through the users of the argument to find a representation - that can be used for as long as possible without needing a transition. + associated with `op_repsets`. Prefers downstream consumers' layout requirements + over the upstream source's existing layout, falling back to the source only when + downstream tracing does not fully constrain the repset. """ - # If forcing fp16, then try to use texture storage whenever possible. This is - # a temporary stopgap measure until all buffer implementations properly account - # for potential overflow of fp16 representation range when doing math in fp16. if self.force_fp16: op_repsets.try_constrain_with_arg_repset(arg_i, utils.ANY_TEXTURE) - arg_source_repset = self.get_arg_tensor_source_repset(op_repsets.op_node, arg_i) - op_repsets.try_constrain_with_arg_repset(arg_i, arg_source_repset) - - arg_repset = op_repsets.get_arg_repset(arg_i) - if arg_repset.is_constrained(): - return - + # First, trace downstream users to discover what layout they prefer. arg_node = op_repsets.op_node.args[arg_i] - if isinstance(arg_node, list): arg_node = arg_node[0] - arg_repset = self.trace_node_users_to_constrain_repset(arg_node, arg_repset) - op_repsets.try_constrain_with_arg_repset(arg_i, arg_repset) + arg_repset = op_repsets.get_arg_repset(arg_i) + if not arg_repset.is_constrained(): + downstream_repset = self.trace_node_users_to_constrain_repset( + arg_node, arg_repset + ) + op_repsets.try_constrain_with_arg_repset(arg_i, downstream_repset) + + # Fall back to the upstream source's existing layout only if downstream + # tracing did not fully constrain the repset. + arg_repset = op_repsets.get_arg_repset(arg_i) + if not arg_repset.is_constrained(): + arg_source_repset = self.get_arg_tensor_source_repset( + op_repsets.op_node, arg_i + ) + op_repsets.try_constrain_with_arg_repset(arg_i, arg_source_repset) def constrain_op_out_repset(self, op_repsets: utils.OpRepSets) -> None: """ From 6e5042d0697c44a51d8138dca3543d2b126281e7 Mon Sep 17 00:00:00 2001 From: ssjia Date: Fri, 24 Apr 2026 12:29:41 -0700 Subject: [PATCH 2/3] Update on "[ET-VK] Prefer downstream layout in TagMemoryMetaPass to reduce transitions" MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Two changes to the layout assignment pass that together reduce layout transitions by ~89% for transformer-style models (73 → 9 for EdgeTAM ViT-S encoder): 1. BFS instead of DFS for downstream user tracing. The old DFS could exhaust the search budget (64 nodes) on one deep branch before discovering a constraining op on a sibling branch. BFS explores all immediate users at each level first, finding nearby layout-constrained ops (e.g. linear requiring width_packed) more reliably. 2. Prefer downstream consumers' layout over upstream source's layout. Previously, if the upstream source already had a representation (e.g. channels_packed from conv2d), that was applied first and locked in the layout via sync_primary_io_repr before downstream tracing could run. Now, downstream users are traced first to discover what layout they prefer, and the upstream source is only used as a fallback when downstream doesn't constrain. For ViT-style transformers, conv2d (patch embedding) forces channels_packed, which previously propagated through all residual connections via flexible ops (layer_norm, add, mul). With downstream-preferred layout, linear ops' width_packed requirement is discovered first, so the entire transformer stack stays width_packed. Transitions only occur at the conv2d↔transformer boundaries. Differential Revision: [D102360203](https://our.internmc.facebook.com/intern/diff/D102360203/) [ghstack-poisoned] From 7f44d91404f1e8ab91b9849ab99271ff96a0b28d Mon Sep 17 00:00:00 2001 From: ssjia Date: Fri, 24 Apr 2026 22:41:35 -0700 Subject: [PATCH 3/3] Update on "[ET-VK] Prefer downstream layout in TagMemoryMetaPass to reduce transitions" MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Two changes to the layout assignment pass that together reduce layout transitions by ~89% for transformer-style models (73 → 9 for EdgeTAM ViT-S encoder): 1. BFS instead of DFS for downstream user tracing. The old DFS could exhaust the search budget (64 nodes) on one deep branch before discovering a constraining op on a sibling branch. BFS explores all immediate users at each level first, finding nearby layout-constrained ops (e.g. linear requiring width_packed) more reliably. 2. Prefer downstream consumers' layout over upstream source's layout. Previously, if the upstream source already had a representation (e.g. channels_packed from conv2d), that was applied first and locked in the layout via sync_primary_io_repr before downstream tracing could run. Now, downstream users are traced first to discover what layout they prefer, and the upstream source is only used as a fallback when downstream doesn't constrain. For ViT-style transformers, conv2d (patch embedding) forces channels_packed, which previously propagated through all residual connections via flexible ops (layer_norm, add, mul). With downstream-preferred layout, linear ops' width_packed requirement is discovered first, so the entire transformer stack stays width_packed. Transitions only occur at the conv2d↔transformer boundaries. Differential Revision: [D102360203](https://our.internmc.facebook.com/intern/diff/D102360203/) [ghstack-poisoned]