Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 81 additions & 49 deletions backends/vulkan/_passes/tag_memory_meta_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

import logging
import operator

from collections import deque
from typing import Any

import executorch.backends.vulkan.utils as utils
Expand Down Expand Up @@ -332,81 +334,111 @@ def trace_node_users_to_constrain_repset( # noqa: C901
search_depth: list[int] | None = None,
) -> utils.TensorRepSet:
"""
For an ambiguous repset, try to constrain the repset by tracing the required
repsets of the users of `origin_node`. The idea is to try to find a representation
that can be used the longest without needing user nodes to insert a transition
for its arguments.
BFS over downstream users to constrain an ambiguous repset. Explores all
immediate users at each level before going deeper, so that nearby constrained
ops (e.g. linear requiring width_packed) are discovered before the search
budget is spent on a single deep branch.
"""
# Optionally limit the total number of nodes explored to improve export
# time. search_depth is a mutable list so that all branches of a fan-out
# share a single counter, preventing exponential blowup.
if self.max_trace_search_depth is not None:
if search_depth is None:
search_depth = [self.max_trace_search_depth]
search_depth[0] -= 1
if search_depth[0] <= 0:

queue: deque[torch.fx.Node] = deque()
queue.append(origin_node)

while queue:
if repset.is_constrained():
return repset

users_to_trace = origin_node.users
if self.max_trace_search_depth is not None:
search_depth[0] -= 1
if search_depth[0] <= 0:
return repset

node = queue.popleft()

users_to_trace = node.users

sync_outs_repr = True
if self.is_valid_op_node(node):
sync_outs_repr = self.get_node_cached_repsets(node).sync_outs_repr

sync_outs_repr = True
if self.is_valid_op_node(origin_node):
sync_outs_repr = self.get_node_cached_repsets(origin_node).sync_outs_repr
if utils.num_tensors_in_node(node) > 1 and not sync_outs_repr:
users_to_trace = []
for usage_node in node.users:
if (
usage_node.target == operator.getitem
and usage_node.args[1] == 1
):
users_to_trace.append(usage_node)

if utils.num_tensors_in_node(origin_node) > 1 and not sync_outs_repr:
users_to_trace = []
for usage_node in origin_node.users:
if usage_node.target == operator.getitem and usage_node.args[1] == 1:
users_to_trace.append(usage_node)
for usage_node in users_to_trace:
if repset.is_constrained():
return repset

for usage_node in users_to_trace:
arg_i_in_user = None
for i in range(len(usage_node.args)):
if origin_node == usage_node.args[i]:
arg_i_in_user = i
break
arg_i_in_user = None
for i in range(len(usage_node.args)):
if node == usage_node.args[i]:
arg_i_in_user = i
break

if arg_i_in_user is not None:
repset = self.constrain_repset_with_user(
usage_node, arg_i_in_user, repset, search_depth
if arg_i_in_user is None:
continue

if not self.is_valid_op_node(usage_node):
continue

cur_node_repsets = self.get_node_cached_repsets(usage_node)
req_arg_repset = cur_node_repsets.get_arg_repset(arg_i_in_user)

if not req_arg_repset.any_in_common(repset):
continue

repset = repset.make_intersect(req_arg_repset)

repset_propagates_to_output = (
cur_node_repsets.sync_primary_io_repr
and (
cur_node_repsets.sync_args_repr
or arg_i_in_user == cur_node_repsets.primary_arg_idx
)
)

if repset.is_constrained():
return repset
if repset_propagates_to_output:
queue.append(usage_node)

return repset

def constrain_op_arg_repset(self, arg_i: int, op_repsets: utils.OpRepSets) -> None:
    """
    Attempt to constrain the repset of the argument at index `arg_i` of the op
    associated with `op_repsets`. Prefers downstream consumers' layout
    requirements over the upstream source's existing layout, falling back to
    the source only when downstream tracing does not fully constrain the
    repset.

    NOTE(review): this body was reconstructed from a diff view that had lost
    its +/- markers; the removed pre-change lines (which constrained from the
    upstream source *first* and early-returned, then traced users a second
    time) have been dropped so the code matches the documented
    downstream-first policy. Confirm against the merged revision.
    """
    # If forcing fp16, then try to use texture storage whenever possible. This
    # is a temporary stopgap measure until all buffer implementations properly
    # account for potential overflow of the fp16 representation range when
    # doing math in fp16.
    if self.force_fp16:
        op_repsets.try_constrain_with_arg_repset(arg_i, utils.ANY_TEXTURE)

    # First, trace downstream users to discover what layout they prefer.
    arg_node = op_repsets.op_node.args[arg_i]

    # Some args are lists of nodes (e.g. cat); trace from the first element.
    # presumably all list elements share a representation — TODO confirm.
    if isinstance(arg_node, list):
        arg_node = arg_node[0]

    arg_repset = op_repsets.get_arg_repset(arg_i)
    if not arg_repset.is_constrained():
        downstream_repset = self.trace_node_users_to_constrain_repset(
            arg_node, arg_repset
        )
        op_repsets.try_constrain_with_arg_repset(arg_i, downstream_repset)

    # Fall back to the upstream source's existing layout only if downstream
    # tracing did not fully constrain the repset.
    arg_repset = op_repsets.get_arg_repset(arg_i)
    if not arg_repset.is_constrained():
        arg_source_repset = self.get_arg_tensor_source_repset(
            op_repsets.op_node, arg_i
        )
        op_repsets.try_constrain_with_arg_repset(arg_i, arg_source_repset)

def constrain_op_out_repset(self, op_repsets: utils.OpRepSets) -> None:
"""
Expand Down
Loading