Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 62 additions & 19 deletions backends/vulkan/runtime/VulkanBackend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -671,6 +671,7 @@ class VulkanBackend final : public ::executorch::runtime::BackendInterface {
ComputeGraph* compute_graph = static_cast<ComputeGraph*>(handle);

const size_t num_inputs = compute_graph->inputs().size();
const size_t num_outputs = compute_graph->outputs().size();
bool should_propagate_resize = false;
#ifdef ET_EVENT_TRACER_ENABLED
runtime::EventTracer* event_tracer = context.event_tracer();
Expand All @@ -690,22 +691,51 @@ class VulkanBackend final : public ::executorch::runtime::BackendInterface {
for (size_t i = 0; i < num_inputs; i++) {
const ValueRef iref = compute_graph->inputs()[i].value;
if (compute_graph->val_is_tensor(iref)) {
VK_CHECK_COND(args[i]->isTensor());
bool was_resized =
maybe_resize_input(compute_graph, i, args[i]->toTensor());
should_propagate_resize = should_propagate_resize || was_resized;
compute_graph->maybe_cast_and_copy_into_staging(
compute_graph->inputs()[i].staging,
args[i]->toTensor().const_data_ptr(),
args[i]->toTensor().numel(),
equivalent_scalar_type(args[i]->toTensor().scalar_type()));
if (args[i]->isTensor()) {
bool was_resized =
maybe_resize_input(compute_graph, i, args[i]->toTensor());
should_propagate_resize = should_propagate_resize || was_resized;
compute_graph->maybe_cast_and_copy_into_staging(
compute_graph->inputs()[i].staging,
args[i]->toTensor().const_data_ptr(),
args[i]->toTensor().numel(),
equivalent_scalar_type(args[i]->toTensor().scalar_type()));
} else if (args[i]->isInt() || args[i]->isBool()) {
int64_t val =
args[i]->isInt() ? args[i]->toInt() : (args[i]->toBool() ? 1 : 0);
vkapi::ScalarType tensor_dtype = compute_graph->dtype_of(iref);
if (tensor_dtype == vkapi::kFloat) {
float fval = static_cast<float>(val);
compute_graph->maybe_cast_and_copy_into_staging(
compute_graph->inputs()[i].staging, &fval, 1, vkapi::kFloat);
} else if (tensor_dtype == vkapi::kInt) {
int32_t ival = static_cast<int32_t>(val);
compute_graph->maybe_cast_and_copy_into_staging(
compute_graph->inputs()[i].staging, &ival, 1, vkapi::kInt);
} else {
compute_graph->maybe_cast_and_copy_into_staging(
compute_graph->inputs()[i].staging, &val, 1, vkapi::kLong);
}
} else {
VK_THROW(
"Tensor input[",
i,
"] has unsupported EValue tag ",
static_cast<int>(args[i]->tag));
}
} else if (compute_graph->val_is_symint(iref)) {
VK_CHECK_COND(
args[i]->isTensor(),
"Cannot handle symint arg to graph that is not derived from a "
"scalar tensor at the moment.");
bool was_updated = maybe_update_scalar_tensor(
compute_graph, iref, args[i]->toTensor());
bool was_updated = false;
if (args[i]->isTensor()) {
was_updated = maybe_update_scalar_tensor(
compute_graph, iref, args[i]->toTensor());
} else if (args[i]->isInt()) {
const int32_t new_val = static_cast<int32_t>(args[i]->toInt());
const int32_t cur_val = compute_graph->read_symint(iref);
if (new_val != cur_val) {
compute_graph->set_symint(iref, new_val);
was_updated = true;
}
}
// Since symint inputs may impact tensor's sizes, trigger a resize if
// any symbolic integer shapes are updated.
should_propagate_resize = should_propagate_resize || was_updated;
Expand Down Expand Up @@ -770,14 +800,13 @@ class VulkanBackend final : public ::executorch::runtime::BackendInterface {
"ETVK_COPY_OUTPUTS",
/* delegate_debug_id = */ -1);
#endif // ET_EVENT_TRACER_ENABLED
for (size_t i = 0; i < compute_graph->outputs().size(); i++) {
const size_t o = i + num_inputs;
const size_t output_offset = args.size() - num_outputs;
for (size_t i = 0; i < num_outputs; i++) {
const size_t o = output_offset + i;
const ValueRef oref = compute_graph->outputs()[i].value;
if (compute_graph->val_is_tensor(oref)) {
VK_CHECK_COND(args[o]->isTensor());
maybe_resize_output(compute_graph, i, args[o]->toTensor());
// args holds inputs directly followed by outputs, so the i'th output
// for compute_graph corresponds to the o'th arg
compute_graph->maybe_cast_and_copy_from_staging(
compute_graph->outputs()[i].staging,
args[o]->toTensor().mutable_data_ptr(),
Expand All @@ -789,6 +818,20 @@ class VulkanBackend final : public ::executorch::runtime::BackendInterface {
// returned as an output, no action is required.
else if (compute_graph->val_is_tref(oref)) {
continue;
} else if (compute_graph->val_is_symint(oref)) {
const int32_t symint_val = compute_graph->read_symint(oref);
if (args[o]->isTensor()) {
executorch::aten::Tensor& out_tensor = args[o]->toTensor();
executorch::aten::ScalarType dtype = out_tensor.scalar_type();
if (dtype == executorch::aten::ScalarType::Int) {
*out_tensor.mutable_data_ptr<int32_t>() = symint_val;
} else if (dtype == executorch::aten::ScalarType::Long) {
*out_tensor.mutable_data_ptr<int64_t>() =
static_cast<int64_t>(symint_val);
}
} else if (args[o]->isInt()) {
*args[o] = EValue(static_cast<int64_t>(symint_val));
}
} else {
VK_THROW(
"Could not handle output with type ",
Expand Down
20 changes: 12 additions & 8 deletions backends/vulkan/runtime/graph/ComputeGraph.cpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
Expand Down Expand Up @@ -452,14 +452,15 @@
const utils::AxisMapLayout axis_map_layout) {
ValueRef idx(static_cast<int>(values_.size()));
check_no_active_value_ptrs();
values_.emplace_back(api::vTensor(
context(),
sizes,
dtype,
storage_type,
memory_layout,
false,
axis_map_layout));
values_.emplace_back(
api::vTensor(
context(),
sizes,
dtype,
storage_type,
memory_layout,
false,
axis_map_layout));

if (shared_object_idx >= 0) {
get_shared_object(shared_object_idx).add_user(this, idx);
Expand Down Expand Up @@ -725,6 +726,9 @@
}

int32_t ComputeGraph::read_symint(const ValueRef idx) {
if (values_.at(idx).isInt()) {
return static_cast<int32_t>(values_.at(idx).toInt());
}
return get_symint(idx)->get();
}

Expand Down
3 changes: 3 additions & 0 deletions backends/vulkan/runtime/graph/ComputeGraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -573,6 +573,9 @@ class ComputeGraph final {
if (value.isBool()) {
return static_cast<T>(value.toBool());
}
if (value.isSymInt()) {
return utils::safe_downcast<T>(read_symint(idx));
}
VK_THROW("Cannot extract scalar from Value with type ", value.type());
}

Expand Down
Loading