From 3de977f3afb055dddb80801c08c9bd20acd8730f Mon Sep 17 00:00:00 2001 From: Gregory Comer Date: Fri, 12 Jun 2026 14:30:16 -0700 Subject: [PATCH] Update [ghstack-poisoned] --- backends/xnnpack/CMakeLists.txt | 2 + backends/xnnpack/runtime/plan/partition.cpp | 491 ++++++++++++++++++ backends/xnnpack/runtime/plan/partition.h | 30 ++ backends/xnnpack/runtime/plan/xnn_support.cpp | 118 +++++ backends/xnnpack/runtime/plan/xnn_support.h | 18 + backends/xnnpack/test/CMakeLists.txt | 1 + .../xnnpack/test/runtime/test_partition.cpp | 390 ++++++++++++++ 7 files changed, 1050 insertions(+) create mode 100644 backends/xnnpack/runtime/plan/partition.cpp create mode 100644 backends/xnnpack/runtime/plan/partition.h create mode 100644 backends/xnnpack/runtime/plan/xnn_support.cpp create mode 100644 backends/xnnpack/runtime/plan/xnn_support.h create mode 100644 backends/xnnpack/test/runtime/test_partition.cpp diff --git a/backends/xnnpack/CMakeLists.txt b/backends/xnnpack/CMakeLists.txt index 32bafa4ec59..02fd8373275 100644 --- a/backends/xnnpack/CMakeLists.txt +++ b/backends/xnnpack/CMakeLists.txt @@ -111,6 +111,8 @@ list( backends/xnnpack/runtime/operators/operator.cpp backends/xnnpack/runtime/executor/arena.cpp backends/xnnpack/runtime/executor/shape_env.cpp + backends/xnnpack/runtime/plan/xnn_support.cpp + backends/xnnpack/runtime/plan/partition.cpp ) list(TRANSFORM _xnnpack_backend__srcs PREPEND "${EXECUTORCH_ROOT}/") diff --git a/backends/xnnpack/runtime/plan/partition.cpp b/backends/xnnpack/runtime/plan/partition.cpp new file mode 100644 index 00000000000..13d6fc26b4d --- /dev/null +++ b/backends/xnnpack/runtime/plan/partition.cpp @@ -0,0 +1,491 @@ +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +/* + * Partitioning groups the nodes XNNPACK can execute into maximal subgraphs, + * which are later fused into single delegated nodes. The partitions must form a + * valid topological cut: two subgraphs may not depend on each other (even + * transitively through non-delegated nodes), or fusing them would create a + * cycle. + * + * assign_partitions builds the partitions with a topological sweep, pulling + * ready delegated nodes (by input-readiness) into the current partition. When a + * delegated node feeds a non-delegated node (it "escapes" the subgraph), the + * delegated nodes reachable past that boundary are blocked from the current + * partition and deferred to a later one -- this is what keeps the subgraphs + * acyclic. A new partition is started once no more nodes can join. Each node's + * tag holds its partition id (0 = none); fuse_partitions then collapses each + * partition into one subgraph node. + */ + +namespace executorch::backends::xnnpack::plan { + +using namespace graph; + +namespace { + +std::optional take_from_queue( + Graph& graph, + std::deque& queue, + std::vector& block_in_partition, + std::vector& in_edges, + std::vector& deferred, + uint32_t current_partition) { + while (!queue.empty()) { + auto node_handle = queue.front(); + queue.pop_front(); + + auto& node = graph.nodes[node_handle]; + + assert((node.flags & NodeFlags::UseXnnpack) != NodeFlags::None); + + if (node.tag != 0) { + continue; + } + if (in_edges[node_handle] != 0) { + continue; + } + if (block_in_partition[node_handle] == current_partition) { + deferred.push_back(node_handle); + continue; + } + + return node_handle; + } + + return {}; +} + +void update_block_frontier( + Graph& graph, + Node& node, + std::vector& block_in_partition, + std::vector& seen_in_partition, + uint32_t current_partition) { + std::deque queue; + for (auto user_handle : node.users) { + auto& user = graph.nodes[user_handle]; + if ((user.flags & NodeFlags::UseXnnpack) == NodeFlags::None) { + queue.push_back(user_handle); + } + } + + while (!queue.empty()) { + auto handle = queue.front(); + queue.pop_front(); + + if (seen_in_partition[handle] == current_partition) { + continue; + } + seen_in_partition[handle] = current_partition; + + for (auto user_handle : graph.nodes[handle].users) { + auto& user = graph.nodes[user_handle]; + if ((user.flags & NodeFlags::UseXnnpack) != NodeFlags::None) { + block_in_partition[user_handle] = current_partition; + } else { + queue.push_back(user_handle); + } + } + } +} + +} // namespace + +runtime::Result assign_partitions(Graph& graph) { + uint32_t current_partition_id = 0; + + std::deque delegated_escape_queue; + std::deque delegated_noescape_queue; + std::deque non_delegated_queue; + + uint32_t remaining_delegated_node_count = 0; + + std::vector block_in_partition(graph.nodes.size(), 0); + std::vector seen_in_partition(graph.nodes.size(), 0); + std::vector in_edges(graph.nodes.size(), 0); + + for (NodeHandle n = 0u; n < graph.nodes.size(); n++) { + auto& node = graph.nodes[n]; + auto& args = node.get_args(); + + auto has_nondelegated_user = + std::any_of(node.users.begin(), node.users.end(), [&](NodeHandle u) { + return (graph.nodes[u].flags & NodeFlags::UseXnnpack) == + NodeFlags::None; + }); + if (has_nondelegated_user) { + node.flags |= NodeFlags::PassInternal1; + } else { + node.flags &= ~NodeFlags::PassInternal1; + } + + in_edges[n] = + std::count_if(args.begin(), args.end(), [&](const ValueHandle& a) { + if (a.is_null()) + return false; + return !std::holds_alternative( + graph.nodes[a.node].value) && + !std::holds_alternative(graph.nodes[a.node].value); + }); + + if ((node.flags & NodeFlags::UseXnnpack) != NodeFlags::None) { + remaining_delegated_node_count++; + } + + if (in_edges[n] == 0) { + if ((node.flags & NodeFlags::UseXnnpack) != NodeFlags::None) { + if ((node.flags & NodeFlags::PassInternal1) != NodeFlags::None) { + delegated_escape_queue.push_back(n); + } else { + delegated_noescape_queue.push_back(n); + } + } else { + non_delegated_queue.push_back(n); + } + } + } + + while (remaining_delegated_node_count > 0) { + std::vector deferred; + + ET_CHECK_OR_RETURN_ERROR( + current_partition_id != std::numeric_limits::max(), + Internal, + "assign_partitions exceeded the maximum partition count"); + current_partition_id++; + + while (true) { + while (!non_delegated_queue.empty()) { + auto ndh = non_delegated_queue.front(); + non_delegated_queue.pop_front(); + + bool is_input = + std::holds_alternative(graph.nodes[ndh].value) || + std::holds_alternative(graph.nodes[ndh].value); + + for (auto user : graph.nodes[ndh].users) { + if (is_input) { + continue; + } + assert(in_edges[user] > 0); + in_edges[user]--; + if (in_edges[user] == 0) { + if ((graph.nodes[user].flags & NodeFlags::UseXnnpack) != + NodeFlags::None) { + if ((graph.nodes[user].flags & NodeFlags::PassInternal1) != + NodeFlags::None) { + delegated_escape_queue.push_back(user); + } else { + delegated_noescape_queue.push_back(user); + } + } else { + non_delegated_queue.push_back(user); + } + } + } + } + + std::optional nh = take_from_queue( + graph, + delegated_noescape_queue, + block_in_partition, + in_edges, + deferred, + current_partition_id); + + if (!nh) { + nh = take_from_queue( + graph, + delegated_escape_queue, + block_in_partition, + in_edges, + deferred, + current_partition_id); + } + + if (!nh) { + break; + } + + auto& node = graph.nodes[*nh]; + node.tag = current_partition_id; + remaining_delegated_node_count--; + + if ((node.flags & NodeFlags::PassInternal1) != NodeFlags::None) { + update_block_frontier( + graph, + node, + block_in_partition, + seen_in_partition, + current_partition_id); + } + + for (auto user : node.users) { + assert(in_edges[user] > 0); + + in_edges[user]--; + if (in_edges[user] == 0) { + if ((graph.nodes[user].flags & NodeFlags::UseXnnpack) != + NodeFlags::None) { + if ((graph.nodes[user].flags & NodeFlags::PassInternal1) != + NodeFlags::None) { + delegated_escape_queue.push_back(user); + } else { + delegated_noescape_queue.push_back(user); + } + } else { + non_delegated_queue.push_back(user); + } + } + } + } + + for (auto handle : deferred) { + if ((graph.nodes[handle].flags & NodeFlags::PassInternal1) != + NodeFlags::None) { + delegated_escape_queue.push_back(handle); + } else { + delegated_noescape_queue.push_back(handle); + } + } + } + + return current_partition_id; +} + +namespace { + +void tag_xnn_nodes(Graph& graph) { + for (auto& node : graph.nodes) { + auto* op_node = std::get_if(&node.value); + if (op_node && check_xnn_node_support(*op_node, graph) && + !prefer_in_tree_kernel(*op_node, graph)) { + node.flags |= NodeFlags::UseXnnpack; + } + } +} + +} // namespace + +runtime::Error fuse_partitions(Graph& graph, uint32_t partition_count) { + const auto sentinel = std::numeric_limits::max(); + + for (uint32_t p = 1; p <= partition_count; p++) { + std::vector members; + for (NodeHandle n = 0; n < graph.nodes.size(); n++) { + if (graph.nodes[n].tag == p) { + members.push_back(n); + } + } + if (members.empty()) { + continue; + } + + std::vector member_mask(graph.nodes.size(), 0); + for (auto m : members) { + member_mask[m] = 1; + } + auto is_member = [&](NodeHandle h) { return member_mask[h] != 0; }; + + // The fused-node remapping below is keyed per node and assumes each member + // produces a single output, so multi-output nodes are not yet supported. + for (auto m : members) { + auto* op = std::get_if(&graph.nodes[m].value); + if (op && + std::holds_alternative>(op->output_specs)) { + ET_LOG( + Error, + "Multi-output nodes are not supported in XNNPACK partitions"); + return runtime::Error::NotSupported; + } + } + + std::vector ext_inputs; + for (auto m : members) { + for (auto arg : graph.nodes[m].get_args()) { + if (arg.is_null()) + continue; + if (std::holds_alternative(graph.nodes[arg.node].value)) + continue; + if (!is_member(arg.node) && + std::find(ext_inputs.begin(), ext_inputs.end(), arg) == + ext_inputs.end()) { + ext_inputs.push_back(arg); + } + } + } + + std::vector output_members; + for (auto m : members) { + bool external = std::any_of( + graph.nodes[m].users.begin(), + graph.nodes[m].users.end(), + [&](NodeHandle u) { return !is_member(u); }); + if (!external) { + external = std::any_of( + graph.outputs.begin(), + graph.outputs.end(), + [&](const ValueHandle& vh) { return vh.node == m; }); + } + if (external) { + output_members.push_back(m); + } + } + if (output_members.empty()) { + continue; + } + + NodeHandle anchor = members.back(); + + auto subgraph = std::make_unique(); + + std::vector handle_map(graph.nodes.size(), sentinel); + + for (size_t i = 0; i < ext_inputs.size(); i++) { + auto ext_node = ext_inputs[i].node; + if (handle_map[ext_node] == sentinel) { + handle_map[ext_node] = static_cast(subgraph->nodes.size()); + + auto spec = graph.get_tensor_spec(ext_inputs[i]); + subgraph->input_specs.push_back(spec); + + Node node; + node.value = InputNode{static_cast(i)}; + subgraph->nodes.push_back(std::move(node)); + } + } + + for (auto m : members) { + for (auto arg : graph.nodes[m].get_args()) { + if (arg.is_null()) + continue; + if (handle_map[arg.node] != sentinel) + continue; + auto* cn = std::get_if(&graph.nodes[arg.node].value); + if (!cn) + continue; + handle_map[arg.node] = static_cast(subgraph->nodes.size()); + ConstantNode cloned; + cloned.tensor = cn->tensor; + cloned.quant_params = cn->quant_params; + Node node; + node.value = std::move(cloned); + subgraph->nodes.push_back(std::move(node)); + } + } + + { + auto next_pos = static_cast(subgraph->nodes.size()); + for (auto m : members) { + handle_map[m] = next_pos++; + } + } + + for (auto m : members) { + auto* op = std::get_if(&graph.nodes[m].value); + assert(op); + + CallOperatorNode remapped; + remapped.op = op->op; + remapped.output_specs = op->output_specs; + remapped.constant_args = op->constant_args; + remapped.output_min = op->output_min; + remapped.output_max = op->output_max; + for (auto arg : op->args) { + if (arg.is_null()) { + remapped.args.push_back(ValueHandle::null()); + continue; + } + assert(handle_map[arg.node] != sentinel); + remapped.args.push_back(ValueHandle{ + handle_map[arg.node], + arg.output, + }); + } + + Node node; + node.value = std::move(remapped); + subgraph->nodes.push_back(std::move(node)); + } + + std::vector output_index(graph.nodes.size(), sentinel); + std::vector output_specs; + for (auto m : output_members) { + output_index[m] = static_cast(subgraph->outputs.size()); + subgraph->outputs.push_back(ValueHandle{ + handle_map[m], + }); + output_specs.push_back( + std::get(graph.get_output_spec_for_node(m))); + } + + CallSubgraphNode fused; + fused.args = std::move(ext_inputs); + if (output_members.size() == 1) { + fused.output_specs = output_specs[0]; + } else { + fused.output_specs = std::move(output_specs); + } + fused.subgraph = std::move(subgraph); + + graph.nodes[anchor].value = std::move(fused); + graph.nodes[anchor].tag = 0; + graph.nodes[anchor].flags = NodeFlags::None; + + for (NodeHandle n = 0; n < graph.nodes.size(); n++) { + if (is_member(n)) { + continue; + } + auto* op = std::get_if(&graph.nodes[n].value); + if (!op) { + continue; + } + for (auto& arg : op->args) { + if (arg.is_null()) + continue; + if (output_index[arg.node] != sentinel) { + arg = ValueHandle{anchor, output_index[arg.node]}; + } + } + } + for (auto& out : graph.outputs) { + if (output_index[out.node] != sentinel) { + out = ValueHandle{anchor, output_index[out.node]}; + } + } + + for (auto m : members) { + if (m == anchor) { + continue; + } + graph.nodes[m].value = InputNode{0}; + graph.nodes[m].tag = 0; + graph.nodes[m].flags = NodeFlags::Dead; + } + } + + graph.update_users(); + return runtime::Error::Ok; +} + +runtime::Error partition_xnn_subgraphs(Graph& graph) { + graph.update_users(); + tag_xnn_nodes(graph); + ET_UNWRAP(partition_count, assign_partitions(graph)); + ET_CHECK_OK_OR_RETURN_ERROR(fuse_partitions(graph, partition_count)); + ET_CHECK_OK_OR_RETURN_ERROR(graph.compact_nodes()); + return runtime::Error::Ok; +} + +} // namespace executorch::backends::xnnpack::plan diff --git a/backends/xnnpack/runtime/plan/partition.h b/backends/xnnpack/runtime/plan/partition.h new file mode 100644 index 00000000000..0114e37f4ce --- /dev/null +++ b/backends/xnnpack/runtime/plan/partition.h @@ -0,0 +1,30 @@ +#pragma once + +#include +#include +#include + +namespace executorch::backends::xnnpack::plan { + +/* + * Partitions the graph into XNNPACK-delegated subgraphs: tags the nodes + * XNNPACK can run, groups them into partitions, and fuses each partition into + * a single subgraph node. + */ +runtime::Error partition_xnn_subgraphs(graph::Graph& graph); + +/* + * Groups the tagged XNNPACK nodes into partitions, recording each node's + * partition in its tag, and returns the number of partitions. This is + * primarily an internal step in partition_xnn_subgraphs. + */ +runtime::Result assign_partitions(graph::Graph& graph); + +/* + * Replaces the nodes of each of the `partition_count` partitions with a single + * fused subgraph node. This is primarily an internal step in the partitioning + * process. + */ +runtime::Error fuse_partitions(graph::Graph& graph, uint32_t partition_count); + +} // namespace executorch::backends::xnnpack::plan diff --git a/backends/xnnpack/runtime/plan/xnn_support.cpp b/backends/xnnpack/runtime/plan/xnn_support.cpp new file mode 100644 index 00000000000..a5a1b648876 --- /dev/null +++ b/backends/xnnpack/runtime/plan/xnn_support.cpp @@ -0,0 +1,118 @@ +#include + +#include +#include +#include + +namespace executorch::backends::xnnpack::plan { + +namespace { + +using namespace graph; + +bool check_xnn_dtype_support(core::DType dtype) { + switch (dtype) { + case core::DType::Float32: + case core::DType::Float16: + case core::DType::QUInt8: + case core::DType::QInt8: + case core::DType::QInt32: + return true; + default: + return false; + } +} + +bool check_xnn_op_support(Operator op) { + switch (op) { + case Operator::Add: + case Operator::Subtract: + case Operator::Multiply: + case Operator::Divide: + case Operator::Maximum: + case Operator::Minimum: + case Operator::CopySign: + case Operator::SquaredDifference: + case Operator::PReLU: + case Operator::Modulus: + case Operator::Atan2: + case Operator::Pow: + case Operator::Abs: + case Operator::Negate: + case Operator::Clamp: + case Operator::Ceiling: + case Operator::Floor: + case Operator::Round: + case Operator::Square: + case Operator::SquareRoot: + case Operator::ReciprocalSquareRoot: + case Operator::Exp: + case Operator::Log: + case Operator::Sigmoid: + case Operator::Tanh: + case Operator::ELU: + case Operator::GELU: + case Operator::HardSwish: + case Operator::LeakyReLU: + case Operator::Sine: + case Operator::Cosine: + case Operator::Sign: + case Operator::ReLU: + case Operator::Linear: + case Operator::BatchMatrixMultiply: + case Operator::Conv2d: + case Operator::ConvTranspose2d: + case Operator::DepthwiseConv2d: + case Operator::AvgPool2d: + case Operator::AdaptiveAvgPool2d: + case Operator::MaxPool2d: + case Operator::Softmax: + case Operator::Mean: + case Operator::Sum: + case Operator::Reshape: + case Operator::View: + case Operator::Transpose: + case Operator::Permute: + case Operator::Slice: + case Operator::Cat: + case Operator::Unsqueeze: + case Operator::Expand: + case Operator::Clone: + case Operator::Pad: + case Operator::StaticResizeBilinear2D: + case Operator::Quantize: + case Operator::Dequantize: + return true; + default: + return false; + } +} + +} // namespace + +bool check_xnn_node_support(const CallOperatorNode& node, const Graph& graph) { + if (!check_xnn_op_support(node.op)) { + return false; + } + + for (auto& arg : node.args) { + if (arg.is_null()) + continue; + const auto& tensor_spec = graph.get_tensor_spec(arg); + + if (!check_xnn_dtype_support(tensor_spec.dtype)) { + return false; + } + } + + return true; +} + +bool prefer_in_tree_kernel( + const CallOperatorNode& /*node*/, + const Graph& /*graph*/) { + // TODO Add logic here once we have in-tree kernels... + return false; +} + +} // namespace executorch::backends::xnnpack::plan diff --git a/backends/xnnpack/runtime/plan/xnn_support.h b/backends/xnnpack/runtime/plan/xnn_support.h new file mode 100644 index 00000000000..ba3a67ac34d --- /dev/null +++ b/backends/xnnpack/runtime/plan/xnn_support.h @@ -0,0 +1,18 @@ +#pragma once + +#include + +namespace executorch::backends::xnnpack::plan { + +// Returns true if XNNPACK can run the given operator node. +bool check_xnn_node_support( + const graph::CallOperatorNode& node, + const graph::Graph& graph); + +// Returns true if we have a preferred in-tree kernel for this node, meaning it +// should not be delegated to XNNPACK even when XNNPACK supports it. +bool prefer_in_tree_kernel( + const graph::CallOperatorNode& node, + const graph::Graph& graph); + +} // namespace executorch::backends::xnnpack::plan diff --git a/backends/xnnpack/test/CMakeLists.txt b/backends/xnnpack/test/CMakeLists.txt index a9b432d2de4..0e2be1aacc4 100644 --- a/backends/xnnpack/test/CMakeLists.txt +++ b/backends/xnnpack/test/CMakeLists.txt @@ -45,6 +45,7 @@ target_include_directories( set(_graph_runtime_test_srcs runtime/test_quant_params.cpp runtime/test_graph_builder.cpp runtime/test_shape_env.cpp runtime/test_arena.cpp + runtime/test_partition.cpp ) et_cxx_test( diff --git a/backends/xnnpack/test/runtime/test_partition.cpp b/backends/xnnpack/test/runtime/test_partition.cpp new file mode 100644 index 00000000000..3afaa1aa121 --- /dev/null +++ b/backends/xnnpack/test/runtime/test_partition.cpp @@ -0,0 +1,390 @@ +#include + +#include +#include + +using namespace executorch::backends::xnnpack::core; +using namespace executorch::backends::xnnpack::graph; +using namespace executorch::backends::xnnpack::plan; +using executorch::runtime::Error; + +TEST(TestPartition, single_node) { + auto builder = GraphBuilder(); + + auto spec = TensorSpec{ + .dtype = DType::Float32, + .sizes = {DimSizeSpec::constant(1), DimSizeSpec::constant(10)}}; + + auto input_a = builder.createInput(spec); + auto input_b = builder.createInput(spec); + auto add = builder.createOperator(Operator::Add, spec, input_a, input_b); + builder.createOutput(add); + + auto graph = builder.build(); + + // Pre-tag the add node for XNNPACK. + graph.nodes[add.node].flags |= NodeFlags::UseXnnpack; + + graph.update_users(); + auto partition_count_result = assign_partitions(graph); + ASSERT_TRUE(partition_count_result.ok()); + auto partition_count = *partition_count_result; + + // There should be exactly one partition. + EXPECT_EQ(partition_count, 1); + + // The add node should be assigned to partition 1. + EXPECT_EQ(graph.nodes[add.node].tag, 1); + + // Input nodes should not be assigned to any partition. + EXPECT_EQ(graph.nodes[input_a.node].tag, 0); + EXPECT_EQ(graph.nodes[input_b.node].tag, 0); +} + +// Helper to set UseXnnpack on a node. +static void set_xnnpack(Graph& graph, ValueHandle handle) { + graph.nodes[handle.node].flags |= NodeFlags::UseXnnpack; +} + +TEST(TestPartition, sequential_all_delegated) { + // Build a chain: input_a, input_b -> add1 -> add2 -> add3 -> add4 -> output + auto builder = GraphBuilder(); + + auto spec = TensorSpec{ + .dtype = DType::Float32, + .sizes = {DimSizeSpec::constant(1), DimSizeSpec::constant(10)}}; + + auto input_a = builder.createInput(spec); + auto input_b = builder.createInput(spec); + auto add1 = builder.createOperator(Operator::Add, spec, input_a, input_b); + auto add2 = builder.createOperator(Operator::Add, spec, add1, input_b); + auto add3 = builder.createOperator(Operator::Add, spec, add2, input_b); + auto add4 = builder.createOperator(Operator::Add, spec, add3, input_b); + builder.createOutput(add4); + + auto graph = builder.build(); + + set_xnnpack(graph, add1); + set_xnnpack(graph, add2); + set_xnnpack(graph, add3); + set_xnnpack(graph, add4); + + graph.update_users(); + auto partition_count_result = assign_partitions(graph); + ASSERT_TRUE(partition_count_result.ok()); + auto partition_count = *partition_count_result; + + // All delegated nodes should be in a single partition. + EXPECT_EQ(partition_count, 1); + + EXPECT_EQ(graph.nodes[add1.node].tag, 1); + EXPECT_EQ(graph.nodes[add2.node].tag, 1); + EXPECT_EQ(graph.nodes[add3.node].tag, 1); + EXPECT_EQ(graph.nodes[add4.node].tag, 1); + + // Input nodes should not be assigned to any partition. + EXPECT_EQ(graph.nodes[input_a.node].tag, 0); + EXPECT_EQ(graph.nodes[input_b.node].tag, 0); +} + +TEST(TestPartition, sequential_alternating) { + // Chain: input_a, input_b -> add1 (D) -> add2 -> add3 (D) -> add4 -> output + // Delegated nodes are separated by undelegated nodes, so they can't be in + // the same partition without creating a cycle. + auto builder = GraphBuilder(); + + auto spec = TensorSpec{ + .dtype = DType::Float32, + .sizes = {DimSizeSpec::constant(1), DimSizeSpec::constant(10)}}; + + auto input_a = builder.createInput(spec); + auto input_b = builder.createInput(spec); + auto add1 = builder.createOperator(Operator::Add, spec, input_a, input_b); + auto add2 = builder.createOperator(Operator::Add, spec, add1, input_b); + auto add3 = builder.createOperator(Operator::Add, spec, add2, input_b); + auto add4 = builder.createOperator(Operator::Add, spec, add3, input_b); + builder.createOutput(add4); + + auto graph = builder.build(); + + set_xnnpack(graph, add1); + set_xnnpack(graph, add3); + + graph.update_users(); + auto partition_count_result = assign_partitions(graph); + ASSERT_TRUE(partition_count_result.ok()); + auto partition_count = *partition_count_result; + + EXPECT_EQ(partition_count, 2); + + // Each delegated node should be in a separate partition. + EXPECT_EQ(graph.nodes[add1.node].tag, 1); + EXPECT_EQ(graph.nodes[add3.node].tag, 2); + + // Undelegated nodes should not be assigned. + EXPECT_EQ(graph.nodes[add2.node].tag, 0); + EXPECT_EQ(graph.nodes[add4.node].tag, 0); + EXPECT_EQ(graph.nodes[input_a.node].tag, 0); + EXPECT_EQ(graph.nodes[input_b.node].tag, 0); +} + +TEST(TestPartition, diamond_skip_connection) { + auto builder = GraphBuilder(); + + auto spec = TensorSpec{ + .dtype = DType::Float32, + .sizes = {DimSizeSpec::constant(1), DimSizeSpec::constant(10)}}; + + auto input_a = builder.createInput(spec); + auto input_b = builder.createInput(spec); + auto add1 = builder.createOperator(Operator::Add, spec, input_a, input_b); + auto add2 = builder.createOperator(Operator::Add, spec, add1, input_b); + auto add3 = builder.createOperator(Operator::Add, spec, add1, input_b); + auto add4 = builder.createOperator(Operator::Add, spec, add2, input_b); + auto add5 = builder.createOperator(Operator::Add, spec, add3, input_b); + auto add6 = builder.createOperator(Operator::Add, spec, add4, add5); + builder.createOutput(add6); + + auto graph = builder.build(); + + set_xnnpack(graph, add1); + // add2 is NOT delegated. + set_xnnpack(graph, add3); + set_xnnpack(graph, add4); + set_xnnpack(graph, add5); + set_xnnpack(graph, add6); + + graph.update_users(); + auto partition_count_result = assign_partitions(graph); + ASSERT_TRUE(partition_count_result.ok()); + auto partition_count = *partition_count_result; + + EXPECT_EQ(partition_count, 2); + + // Partition 1: add1, add3, add5 (connected through delegated nodes only). + EXPECT_EQ(graph.nodes[add1.node].tag, 1); + EXPECT_EQ(graph.nodes[add3.node].tag, 1); + EXPECT_EQ(graph.nodes[add5.node].tag, 1); + + // Partition 2: add4, add6 (add4 blocked from partition 1; add6 depends on + // add4). + EXPECT_EQ(graph.nodes[add4.node].tag, 2); + EXPECT_EQ(graph.nodes[add6.node].tag, 2); + + // Non-delegated / inputs unassigned. + EXPECT_EQ(graph.nodes[add2.node].tag, 0); + EXPECT_EQ(graph.nodes[input_a.node].tag, 0); + EXPECT_EQ(graph.nodes[input_b.node].tag, 0); +} + +TEST(TestPartition, converging_delegated) { + auto builder = GraphBuilder(); + + auto spec = TensorSpec{ + .dtype = DType::Float32, + .sizes = {DimSizeSpec::constant(1), DimSizeSpec::constant(10)}}; + + auto input_a = builder.createInput(spec); + auto input_b = builder.createInput(spec); + auto add1 = builder.createOperator(Operator::Add, spec, input_a, input_a); + auto add2 = builder.createOperator(Operator::Add, spec, input_b, input_b); + auto add3 = builder.createOperator(Operator::Add, spec, add1, input_a); + auto add_sink = builder.createOperator(Operator::Add, spec, add2, input_b); + auto add5 = builder.createOperator(Operator::Add, spec, add3, input_a); + auto add6 = builder.createOperator(Operator::Add, spec, add5, add2); + builder.createOutput(add6); + builder.createOutput(add_sink); + + auto graph = builder.build(); + + set_xnnpack(graph, add1); + set_xnnpack(graph, add2); + set_xnnpack(graph, add3); + // add_sink is NOT delegated — makes add2 escape. + set_xnnpack(graph, add5); + set_xnnpack(graph, add6); + + graph.update_users(); + auto partition_count_result = assign_partitions(graph); + ASSERT_TRUE(partition_count_result.ok()); + auto partition_count = *partition_count_result; + + // All delegated nodes should be in one partition. + EXPECT_EQ(partition_count, 1); + EXPECT_EQ(graph.nodes[add1.node].tag, 1); + EXPECT_EQ(graph.nodes[add2.node].tag, 1); + EXPECT_EQ(graph.nodes[add3.node].tag, 1); + EXPECT_EQ(graph.nodes[add5.node].tag, 1); + EXPECT_EQ(graph.nodes[add6.node].tag, 1); + EXPECT_EQ(graph.nodes[add_sink.node].tag, 0); +} + +// --- fuse_partitions tests --- + +TEST(TestFusePartitions, single_node) { + auto builder = GraphBuilder(); + + auto spec = TensorSpec{ + .dtype = DType::Float32, + .sizes = {DimSizeSpec::constant(1), DimSizeSpec::constant(10)}}; + + auto input_a = builder.createInput(spec); + auto input_b = builder.createInput(spec); + auto add = builder.createOperator(Operator::Add, spec, input_a, input_b); + builder.createOutput(add); + + auto graph = builder.build(); + + graph.nodes[add.node].flags |= NodeFlags::UseXnnpack; + + graph.update_users(); + auto partition_count_result = assign_partitions(graph); + ASSERT_TRUE(partition_count_result.ok()); + auto partition_count = *partition_count_result; + + EXPECT_EQ(partition_count, 1); + + ASSERT_EQ(fuse_partitions(graph, partition_count), Error::Ok); + + // The add node should have been replaced with a CallSubgraphNode. + EXPECT_TRUE( + std::holds_alternative(graph.nodes[add.node].value)); + + // The subgraph should contain the original operator. + auto& subgraph_node = std::get(graph.nodes[add.node].value); + EXPECT_NE(subgraph_node.subgraph, nullptr); + + // The fused node's args should be the original inputs. + EXPECT_EQ(subgraph_node.args.size(), 2); + + // Graph outputs should still reference the fused node. + EXPECT_EQ(graph.outputs.size(), 1); + EXPECT_EQ(graph.outputs[0].node, add.node); +} + +TEST(TestFusePartitions, parallel_multi_output) { + auto builder = GraphBuilder(); + + auto spec = TensorSpec{ + .dtype = DType::Float32, + .sizes = {DimSizeSpec::constant(1), DimSizeSpec::constant(10)}}; + + auto input_a = builder.createInput(spec); + auto input_b = builder.createInput(spec); + auto add1 = builder.createOperator(Operator::Add, spec, input_a, input_b); + auto add2 = builder.createOperator(Operator::Add, spec, input_a, input_b); + builder.createOutput(add1); + builder.createOutput(add2); + + auto graph = builder.build(); + + set_xnnpack(graph, add1); + set_xnnpack(graph, add2); + + graph.update_users(); + auto partition_count_result = assign_partitions(graph); + ASSERT_TRUE(partition_count_result.ok()); + auto partition_count = *partition_count_result; + + EXPECT_EQ(partition_count, 1); + EXPECT_EQ(graph.nodes[add1.node].tag, 1); + EXPECT_EQ(graph.nodes[add2.node].tag, 1); + + ASSERT_EQ(fuse_partitions(graph, partition_count), Error::Ok); + + // The anchor (add2) should be a CallSubgraphNode. + EXPECT_TRUE( + std::holds_alternative(graph.nodes[add2.node].value)); + auto& fused = std::get(graph.nodes[add2.node].value); + EXPECT_NE(fused.subgraph, nullptr); + + // The subgraph should have two outputs. + EXPECT_EQ(fused.subgraph->outputs.size(), 2); + + // The fused node should have a multi-output spec. + EXPECT_TRUE( + std::holds_alternative>(fused.output_specs)); + EXPECT_EQ(std::get>(fused.output_specs).size(), 2); + + // graph.outputs[0] (was add1) should now reference the anchor with output 0. + EXPECT_EQ(graph.outputs[0].node, add2.node); + EXPECT_EQ(graph.outputs[0].output, 0); + + // graph.outputs[1] (was add2/anchor) keeps its original reference. + EXPECT_EQ(graph.outputs[1].node, add2.node); + + // --- Compaction --- + // Before compaction, add1 is a Dead tombstone. + EXPECT_EQ((graph.nodes[add1.node].flags & NodeFlags::Dead), NodeFlags::Dead); + + size_t pre_compact_size = graph.nodes.size(); + ASSERT_EQ(graph.compact_nodes(), executorch::runtime::Error::Ok); + + // Dead nodes should be removed. + EXPECT_LT(graph.nodes.size(), pre_compact_size); + + // Graph outputs should still be valid and reference the fused node. + EXPECT_EQ(graph.outputs.size(), 2); + + auto& fused_after = + std::get(graph.nodes[graph.outputs[0].node].value); + EXPECT_NE(fused_after.subgraph, nullptr); + EXPECT_EQ(graph.outputs[0].node, graph.outputs[1].node); +} + +TEST(TestCompactNodes, sequential_chain) { + auto builder = GraphBuilder(); + + auto spec = TensorSpec{ + .dtype = DType::Float32, + .sizes = {DimSizeSpec::constant(1), DimSizeSpec::constant(10)}}; + + auto input_a = builder.createInput(spec); + auto input_b = builder.createInput(spec); + auto add1 = builder.createOperator(Operator::Add, spec, input_a, input_b); + auto add2 = builder.createOperator(Operator::Add, spec, add1, input_b); + auto add3 = builder.createOperator(Operator::Add, spec, add2, input_b); + builder.createOutput(add3); + + auto graph = builder.build(); + size_t original_size = graph.nodes.size(); + + set_xnnpack(graph, add1); + set_xnnpack(graph, add2); + set_xnnpack(graph, add3); + + graph.update_users(); + auto partition_count_result = assign_partitions(graph); + ASSERT_TRUE(partition_count_result.ok()); + auto partition_count = *partition_count_result; + EXPECT_EQ(partition_count, 1); + + ASSERT_EQ(fuse_partitions(graph, partition_count), Error::Ok); + + // add1, add2 are tombstoned; add3 (anchor) is the CallSubgraphNode. + EXPECT_EQ((graph.nodes[add1.node].flags & NodeFlags::Dead), NodeFlags::Dead); + EXPECT_EQ((graph.nodes[add2.node].flags & NodeFlags::Dead), NodeFlags::Dead); + + ASSERT_EQ(graph.compact_nodes(), executorch::runtime::Error::Ok); + + // Two dead nodes removed: inputs (2) + anchor (1) = 3 live nodes. + EXPECT_EQ(graph.nodes.size(), original_size - 2); + + // No Dead nodes remain. + for (size_t i = 0; i < graph.nodes.size(); i++) { + EXPECT_EQ((graph.nodes[i].flags & NodeFlags::Dead), NodeFlags::None) + << "Node " << i << " is still dead after compaction"; + } + + // Graph output should point to a valid CallSubgraphNode. + EXPECT_EQ(graph.outputs.size(), 1); + auto& out_node = graph.nodes[graph.outputs[0].node]; + EXPECT_TRUE(std::holds_alternative(out_node.value)); + + // The fused node's args should reference valid input nodes. + auto& fused = std::get(out_node.value); + for (auto& arg : fused.args) { + EXPECT_LT(arg.node, graph.nodes.size()); + EXPECT_TRUE(std::holds_alternative(graph.nodes[arg.node].value)); + } +}