Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
169 changes: 169 additions & 0 deletions backends/xnnpack/runtime/graph/graph.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
#include <executorch/backends/xnnpack/runtime/graph/graph.h>

#include <executorch/backends/xnnpack/runtime/core/variant_util.h>
#include <executorch/runtime/core/error.h>
#include <executorch/runtime/platform/log.h>

#include <cassert>

namespace executorch::backends::xnnpack::graph {

namespace {

void scan_spec(const TensorSpec& spec, uint32_t& max_id) {
for (auto& dim : spec.sizes) {
for (auto& term : dim.coeffs) {
if (term.sym >= max_id) {
max_id = term.sym + 1;
}
}
}
}

void scan_output_spec(const OutputSpec& os, uint32_t& max_id) {
std::visit(
overloaded{
[&](const TensorSpec& s) { scan_spec(s, max_id); },
[&](const std::vector<TensorSpec>& v) {
for (auto& s : v)
scan_spec(s, max_id);
},
},
os);
}

} // namespace

uint32_t Graph::symint_count() const {
uint32_t count = 0;
for (auto& spec : input_specs) {
scan_spec(spec, count);
}
for (auto& node : nodes) {
std::visit(
overloaded{
[](const InputNode&) {},
[](const ConstantNode&) {},
[&](const CallOperatorNode& n) {
scan_output_spec(n.output_specs, count);
},
[&](const CallSubgraphNode& n) {
scan_output_spec(n.output_specs, count);
},
},
node.value);
}
return count;
}

void Graph::update_users() {
for (auto& node : nodes) {
node.users.clear();
}

for (NodeHandle i = 0; i < nodes.size(); ++i) {
std::visit(
overloaded{
[](const InputNode&) {},
[](const ConstantNode&) {},
[&](const CallOperatorNode& n) {
for (auto arg : n.args) {
if (!arg.is_null()) {
nodes[arg.node].users.push_back(i);
}
}
},
[&](const CallSubgraphNode& n) {
for (auto arg : n.args) {
if (!arg.is_null()) {
nodes[arg.node].users.push_back(i);
}
}
},
},
nodes[i].value);
}
}

runtime::Error Graph::compact_nodes() {
std::vector<uint32_t> remap(nodes.size(), UINT32_MAX);
uint32_t new_idx = 0;
for (NodeHandle i = 0; i < nodes.size(); i++) {
if ((nodes[i].flags & NodeFlags::Dead) != NodeFlags::None) {
continue;
}
remap[i] = new_idx++;
}

// Validate that no live node or output references a dead/invalid node before
// mutating any handles, so a failure leaves the graph untouched.
bool valid = true;
auto check_vh = [&](const ValueHandle& vh) {
if (vh.is_null()) {
return;
}
if (vh.node >= remap.size() || remap[vh.node] == UINT32_MAX) {
valid = false;
}
};
for (NodeHandle i = 0; i < nodes.size(); i++) {
if ((nodes[i].flags & NodeFlags::Dead) != NodeFlags::None) {
continue;
}
for (const auto& a : nodes[i].get_args()) {
check_vh(a);
}
}
for (const auto& out : outputs) {
check_vh(out);
}
ET_CHECK_OR_RETURN_ERROR(
valid,
Internal,
"compact_nodes: a live node or output references a dead node");

auto rewrite_vh = [&](ValueHandle& vh) {
if (!vh.is_null()) {
vh.node = remap[vh.node];
}
};

for (NodeHandle i = 0; i < nodes.size(); i++) {
if ((nodes[i].flags & NodeFlags::Dead) != NodeFlags::None) {
continue;
}
std::visit(
overloaded{
[](InputNode&) {},
[](ConstantNode&) {},
[&](CallOperatorNode& n) {
for (auto& a : n.args)
rewrite_vh(a);
},
[&](CallSubgraphNode& n) {
for (auto& a : n.args)
rewrite_vh(a);
},
},
nodes[i].value);
}

for (auto& out : outputs) {
rewrite_vh(out);
}

std::vector<Node> compacted;
compacted.reserve(new_idx);
for (NodeHandle i = 0; i < nodes.size(); i++) {
if ((nodes[i].flags & NodeFlags::Dead) != NodeFlags::None) {
continue;
}
compacted.push_back(std::move(nodes[i]));
}
nodes = std::move(compacted);

update_users();
return runtime::Error::Ok;
}

} // namespace executorch::backends::xnnpack::graph
71 changes: 71 additions & 0 deletions backends/xnnpack/runtime/graph/graph.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
#pragma once

#include <executorch/backends/xnnpack/runtime/core/variant_util.h>
#include <executorch/backends/xnnpack/runtime/graph/handles.h>
#include <executorch/backends/xnnpack/runtime/graph/node.h>
#include <executorch/backends/xnnpack/runtime/graph/tensor_spec.h>
#include <executorch/runtime/core/error.h>
#include <vector>

namespace executorch::backends::xnnpack::graph {

/*
* Describes a computational graph.
*/
struct Graph {
std::vector<TensorSpec> input_specs;
std::vector<Node> nodes;
std::vector<ValueHandle> outputs;

/* Clean up nodes marked as dead. */
[[nodiscard]] runtime::Error compact_nodes();

/* Returns the number of symints referenced in the graph. */
uint32_t symint_count() const;

/* Regenerate user metadata on nodes. */
void update_users();

/* Retrieve the output specifier for a given node handle. */
inline OutputSpec get_output_spec_for_node(NodeHandle node) const {
return std::visit(
overloaded{
[&](const InputNode& n) -> OutputSpec {
return input_specs.at(n.input);
},
[](const ConstantNode& n) -> OutputSpec {
TensorSpec spec;
spec.dtype = n.tensor->dtype;
spec.sizes.reserve(n.tensor->sizes.size());
for (auto s : n.tensor->sizes) {
spec.sizes.push_back(
DimSizeSpec::constant(static_cast<int64_t>(s)));
}
spec.quant_params = n.quant_params;
return spec;
},
[](const CallOperatorNode& n) -> OutputSpec {
return n.output_specs;
},
[](const CallSubgraphNode& n) -> OutputSpec {
return n.output_specs;
},
},
nodes[node].value);
}

/* Retrieve the tensor spec for a given value handle. */
inline TensorSpec get_tensor_spec(ValueHandle vh) const {
auto spec = get_output_spec_for_node(vh.node);
return std::visit(
overloaded{
[](const TensorSpec& s) -> TensorSpec { return s; },
[&](const std::vector<TensorSpec>& v) -> TensorSpec {
return v.at(vh.output);
},
},
spec);
}
};

} // namespace executorch::backends::xnnpack::graph
102 changes: 102 additions & 0 deletions backends/xnnpack/runtime/graph/graph_builder.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
#include <executorch/backends/xnnpack/runtime/graph/graph_builder.h>

#include <utility>

namespace executorch::backends::xnnpack::graph {

Graph GraphBuilder::build() {
Graph g;
g.input_specs = std::move(input_specs_);
g.nodes = std::move(nodes_);
g.outputs = std::move(outputs_);
return g;
}

ValueHandle GraphBuilder::createInput(TensorSpec spec) {
input_specs_.push_back(std::move(spec));

InputHandle input = next_input_;
next_input_++;

ValueHandle handle{static_cast<uint32_t>(nodes_.size())};
Node node;
node.value = InputNode{input};
nodes_.push_back(std::move(node));
return handle;
}

ValueHandle GraphBuilder::createConstant(
std::shared_ptr<const core::Tensor> tensor,
std::optional<core::QuantParams> quant_params) {
ValueHandle handle{static_cast<uint32_t>(nodes_.size())};
ConstantNode cn;
cn.tensor = std::move(tensor);
cn.quant_params = std::move(quant_params);
Node node;
node.value = std::move(cn);
nodes_.push_back(std::move(node));
return handle;
}

ValueHandle GraphBuilder::createOperator(
Operator op,
TensorSpec output_spec,
ValueHandles args) {
ValueHandle handle{static_cast<uint32_t>(nodes_.size())};
CallOperatorNode con;
con.args = std::move(args);
con.op = op;
con.output_specs = std::move(output_spec);
Node node;
node.value = std::move(con);
nodes_.push_back(std::move(node));
return handle;
}

ValueHandle GraphBuilder::createOperator(
Operator op,
TensorSpec output_spec,
ValueHandles args,
std::vector<ConstantArg> constant_args,
float output_min,
float output_max) {
ValueHandle handle{static_cast<uint32_t>(nodes_.size())};
CallOperatorNode con;
con.args = std::move(args);
con.op = op;
con.output_specs = std::move(output_spec);
con.constant_args = std::move(constant_args);
con.output_min = output_min;
con.output_max = output_max;
Node node;
node.value = std::move(con);
nodes_.push_back(std::move(node));
return handle;
}

ValueHandle GraphBuilder::createOperatorM(
Operator op,
std::vector<TensorSpec> output_specs,
ValueHandles args) {
ValueHandle handle{static_cast<uint32_t>(nodes_.size())};
CallOperatorNode con;
con.args = std::move(args);
con.op = op;
con.output_specs = std::move(output_specs);
Node node;
node.value = std::move(con);
nodes_.push_back(std::move(node));
return handle;
}

OutputHandle GraphBuilder::createOutput(ValueHandle handle) {
OutputHandle output = static_cast<OutputHandle>(outputs_.size());
outputs_.push_back(handle);
return output;
}

SymIntHandle GraphBuilder::createSymInt() {
return next_sym_int_++;
}

} // namespace executorch::backends::xnnpack::graph
Loading
Loading