Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion backends/vulkan/op_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,7 @@ def check_to_copy_node(node: torch.fx.Node) -> bool:
)
def register_softmax_cpp_ops():
return OpFeatures(
inputs_storage=utils.ANY_TEXTURE,
inputs_storage=utils.ANY_STORAGE,
inputs_dtypes=utils.FP_T,
supports_resize=True,
)
Expand Down
122 changes: 122 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/softmax_buffer.glsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#version 450 core

${define_required_extensions(STORAGE, DTYPE)}

#define PRECISION ${PRECISION}
#define T ${buffer_scalar_type(DTYPE)}

#define op1(X) ${OPERATOR1}

#define op2(X, Y) ${OPERATOR2}

layout(std430) buffer;

${layout_declare_tensor(B, "w", "out_buf", DTYPE, STORAGE)}
${layout_declare_tensor(B, "r", "in_buf", DTYPE, STORAGE)}

${layout_declare_ubo(B, "ivec4", "in_sizes")}
${layout_declare_ubo(B, "ivec4", "in_strides")}
${layout_declare_ubo(B, "ivec4", "out_sizes")}
${layout_declare_ubo(B, "ivec4", "out_strides")}

// Workgroup size is supplied at pipeline-creation time via specialization
// constants 0-2 (chosen on the host by pick_softmax_buffer_local_wg_size).
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

// WHCN index of the reduction dimension (specialization constant 3, set from
// the host-side reduce_dim value).
layout(constant_id = 3) const int reduce_dim = 0;

// Number of cooperating threads per reduction row; must stay in sync with the
// host's local workgroup size along reduce_dim.
#define NWORKERS 4
// NOTE(review): MAX_NTHREADS appears unused in this shader — confirm whether
// it can be removed.
#define MAX_NTHREADS 16

// Per-workgroup scratch for the two reduction phases (running max, then
// exp-sum), one slot per worker thread.
shared T shared_max[NWORKERS];
shared T shared_sum[NWORKERS];

#include "indexing_utils.h"

/*
* Buffer-based softmax. Each workgroup processes one "row" along the reduction
* dimension. Within a workgroup, NWORKERS threads cooperate to compute the max
* and sum reductions, then each thread writes its portion of the final outputs.
*
* Thread mapping: the global WG size has 1 along reduce_dim, and all other
* dimensions correspond to output tensor sizes (WHCN order, with z encoding
* C*N). The local WG size has NWORKERS along reduce_dim. Each workgroup
* identifies a unique reduction "row" via the non-reduce dimensions of
* gl_GlobalInvocationID, and the NWORKERS threads within that workgroup
* cooperate on the reduction.
*/
void main() {
  // Build the base 4D index for this workgroup's reduction row.
  // gl_GlobalInvocationID has 0..NWORKERS-1 along reduce_dim; zero it out
  // since the tid will iterate over the reduce_dim explicitly.
  ivec3 gid = ivec3(gl_GlobalInvocationID);
  // NOTE(review): gid is an ivec3, so this assumes reduce_dim < 3 — i.e. the
  // batch (WHCN dim 3) is never the reduction dim. Confirm against the host
  // side, which folds batch into the z component (C*N).
  gid[reduce_dim] = 0;

  // WHCN 4D tensor index: x = width, y = height, z encodes channels (z % C)
  // and batch (z / C).
  ivec4 base_idx = ivec4(gid.x, gid.y, gid.z % in_sizes.z, gid.z / in_sizes.z);

  // base_idx is identical for every thread in the workgroup (the only
  // per-thread component was zeroed above), so this early return is uniform
  // across the workgroup and the barrier() calls below remain valid.
  if (any(greaterThanEqual(base_idx, in_sizes))) {
    return;
  }

  // This thread's worker slot within the row, and the reduction length.
  const uint tid = gl_LocalInvocationID[reduce_dim];
  const int R = in_sizes[reduce_dim];

  // Phase 1: Find maximum along reduce_dim
  ivec4 in_idx = base_idx;

  // Start from the most negative finite float32 so any real input wins.
  T local_max = T(-3.402823e+38);
  // Strided loop: worker tid handles elements tid, tid+NWORKERS, ...
  for (int i = int(tid); i < R; i += NWORKERS) {
    in_idx[reduce_dim] = i;
    T v = in_buf[tidx_to_bufi(in_idx, in_strides)];
    local_max = max(local_max, v);
  }
  shared_max[tid] = local_max;
  barrier();

  // Reduce partial maximums across workers. Every thread computes the same
  // max_val redundantly; NWORKERS is small so a tree reduction isn't needed.
  T max_val = shared_max[0];
  for (int i = 1; i < NWORKERS; ++i) {
    max_val = max(max_val, shared_max[i]);
  }

  // Phase 2: Compute sum of exp(x - max_val). Subtracting the max is the
  // standard shift that keeps exp() from overflowing.
  T local_sum = T(0);
  for (int i = int(tid); i < R; i += NWORKERS) {
    in_idx[reduce_dim] = i;
    T v = in_buf[tidx_to_bufi(in_idx, in_strides)];
    local_sum += exp(v - max_val);
  }
  shared_sum[tid] = local_sum;
  barrier();

  // Reduce partial sums across workers (again redundantly per thread).
  T sum_val = shared_sum[0];
  for (int i = 1; i < NWORKERS; ++i) {
    sum_val += shared_sum[i];
  }
  // Clamp denominator to avoid 0/0 = NaN when all exp values underflow.
  sum_val = max(sum_val, T(1e-37));

  // Phase 3: Write outputs. op1/op2 come from the codegen YAML:
  // softmax:     op1(X) = exp(X),  op2(X, Y) = X / Y
  // log_softmax: op1(X) = X,       op2(X, Y) = X - log(Y)
  for (int i = int(tid); i < R; i += NWORKERS) {
    in_idx[reduce_dim] = i;
    int in_buf_idx = tidx_to_bufi(in_idx, in_strides);
    T v = in_buf[in_buf_idx];
    T numerator = op1(v - max_val);
    T result = op2(numerator, sum_val);

    // Replace NaN/Inf with 0 using IEEE 754 bit-level manipulation: an
    // all-ones exponent field (0x7F800000 in binary32) means Inf or NaN.
    // NOTE(review): when T is float16_t this relies on the implicit
    // half->float conversion granted by the arithmetic-types extension
    // before floatBitsToUint — confirm the codegen enables it.
    uint bits = floatBitsToUint(result);
    if ((bits & 0x7F800000u) == 0x7F800000u) {
      result = T(0);
    }

    // Input and output rows coincide except for strides, so in_idx doubles
    // as the output tensor index here.
    out_buf[tidx_to_bufi(in_idx, out_strides)] = result;
  }
}
21 changes: 21 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/softmax_buffer.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# Codegen spec for the buffer-storage softmax kernels. OPERATOR1/OPERATOR2 are
# substituted into softmax_buffer.glsl as op1/op2: op1 transforms the
# max-shifted input (the "numerator"), op2 combines it with the exp-sum
# denominator.
softmax_buffer:
  parameter_names_with_default_values:
    # Defaults produce regular softmax: exp(x - max) / sum(exp(x - max)).
    OPERATOR1: exp(X)
    OPERATOR2: X / Y
    DTYPE: float
    STORAGE: buffer
  # Emit one shader variant per dtype for each named variant below.
  generate_variant_forall:
    DTYPE:
      - VALUE: half
      - VALUE: float
  shader_variants:
    - NAME: softmax_buffer
    # log_softmax(x) = (x - max) - log(sum(exp(x - max)))
    - NAME: log_softmax_buffer
      OPERATOR1: X
      OPERATOR2: X - log(Y)
141 changes: 129 additions & 12 deletions backends/vulkan/runtime/graph/ops/impl/Softmax.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,16 @@ namespace vkcompute {

using namespace utils;

utils::uvec3 pick_softmax_global_wg_size(
//
// Texture path
//

utils::uvec3 pick_softmax_texture_global_wg_size(
ComputeGraph* graph,
const vkapi::ShaderInfo& shader,
const std::vector<ArgGroup>& args,
const std::vector<ValueRef>& resize_args) {
(void)shader;
(void)resize_args;

const ValueRef out = args.at(0).refs.at(0);
const int32_t reduce_dim_xyz =
Expand All @@ -35,7 +38,7 @@ utils::uvec3 pick_softmax_global_wg_size(
return global_size;
}

utils::uvec3 pick_softmax_local_wg_size(
utils::uvec3 pick_softmax_texture_local_wg_size(
ComputeGraph* graph,
const vkapi::ShaderInfo& shader,
const utils::uvec3& global_workgroup_size,
Expand All @@ -51,7 +54,6 @@ utils::uvec3 pick_softmax_local_wg_size(
const int32_t reduce_dim_xyz =
graph->extract_scalar<int32_t>(resize_args.at(1));

// These values are hardcoded in add_softmax_node
const uint32_t nworkers_per_group = 4;
const uint32_t ngroups = 4;

Expand All @@ -74,16 +76,12 @@ void resize_softmax_node(
graph->virtual_resize(out, in_sizes);
}

void add_softmax_node(
void add_softmax_texture_node(
ComputeGraph& graph,
const ValueRef in,
const ValueRef dim_ref,
const ValueRef out,
bool log_softmax) {
VK_CHECK_COND(
!graph.is_buffer_storage(in) && !graph.is_buffer_storage(out),
"Vulkan softmax only supports texture storage");

const int64_t ndim = graph.dim_of(in);

int32_t reduce_dim_nchw = graph.extract_scalar<int32_t>(dim_ref);
Expand All @@ -101,7 +99,6 @@ void add_softmax_node(
"Softmax shader currently does not support concat dim == reduce dim");
}

vkapi::ShaderInfo shader_descriptor;
std::string kernel_name = "softmax";
kernel_name.reserve(kShaderNameReserve);
add_dtype_suffix(kernel_name, graph.dtype_of(out));
Expand Down Expand Up @@ -134,8 +131,8 @@ void add_softmax_node(
graph.execute_nodes().emplace_back(new DynamicDispatchNode(
graph,
VK_KERNEL_FROM_STR(kernel_name),
pick_softmax_global_wg_size,
pick_softmax_local_wg_size,
pick_softmax_texture_global_wg_size,
pick_softmax_texture_local_wg_size,
// Inputs and Outputs
{{out, vkapi::kWrite}, {in, vkapi::kRead}},
// Shader params buffers
Expand All @@ -150,6 +147,126 @@ void add_softmax_node(
resize_softmax_node));
}

//
// Buffer path
//

// Chooses the global workgroup count for the buffer softmax shader: one
// invocation slot per output element outside the reduction dim, and a single
// slot along the reduction dim itself (each workgroup owns one whole row).
utils::uvec3 pick_softmax_buffer_global_wg_size(
    ComputeGraph* graph,
    const vkapi::ShaderInfo& shader,
    const std::vector<ArgGroup>& args,
    const std::vector<ValueRef>& resize_args) {
  (void)shader;
  const ValueRef out_ref = args.at(0).refs.at(0);
  const ValueRef in_ref = args.at(1).refs.at(0);
  // resize_args[0] carries the raw NCHW dim value rather than a graph ref.
  const int raw_dim = resize_args.at(0);

  // Map the (possibly negative) NCHW dim to WHCN order, matching the
  // shader's reduce_dim specialization constant.
  const int64_t rank = graph->dim_of(in_ref);
  int32_t whcn_reduce_dim = normalize(raw_dim, rank);
  whcn_reduce_dim = nchw_dim_to_whcn_dim(whcn_reduce_dim, rank);

  // x = width, y = height, z folds channels and batch together, mirroring
  // the shader's index decomposition.
  utils::uvec3 wg_size = {
      graph->size_at<uint32_t>(-1, out_ref),
      graph->size_at<uint32_t>(-2, out_ref),
      graph->size_at<uint32_t>(-3, out_ref) *
          graph->size_at<uint32_t>(-4, out_ref)};
  // NOTE(review): assumes whcn_reduce_dim < 3 (batch is never the reduction
  // dim) — confirm, since uvec3 has only three components.
  wg_size[whcn_reduce_dim] = 1;
  return wg_size;
}

// Chooses the local workgroup size for the buffer softmax shader: NWORKERS
// cooperating threads along the reduction dim, 1 elsewhere.
utils::uvec3 pick_softmax_buffer_local_wg_size(
    ComputeGraph* graph,
    const vkapi::ShaderInfo& shader,
    const utils::uvec3& global_workgroup_size,
    const std::vector<ArgGroup>& args,
    const std::vector<ValueRef>& resize_args) {
  (void)shader;
  (void)global_workgroup_size;
  const ValueRef in_ref = args.at(1).refs.at(0);
  // resize_args[0] carries the raw NCHW dim value rather than a graph ref.
  const int raw_dim = resize_args.at(0);

  // Map the (possibly negative) NCHW dim to WHCN order.
  const int64_t rank = graph->dim_of(in_ref);
  int32_t whcn_reduce_dim = normalize(raw_dim, rank);
  whcn_reduce_dim = nchw_dim_to_whcn_dim(whcn_reduce_dim, rank);

  // Must match NWORKERS in softmax_buffer.glsl.
  constexpr uint32_t kWorkersPerRow = 4;
  utils::uvec3 local_size{1u, 1u, 1u};
  local_size[whcn_reduce_dim] = kWorkersPerRow;
  return local_size;
}

// Appends a buffer-storage (log_)softmax dispatch node to the graph.
//
// in / dim_ref / out are graph ValueRefs; log_softmax selects the
// log_softmax_buffer shader variant instead of softmax_buffer.
void add_softmax_buffer_node(
    ComputeGraph& graph,
    const ValueRef in,
    const ValueRef dim_ref,
    const ValueRef out,
    bool log_softmax) {
  const int64_t ndim = graph.dim_of(in);

  // Normalize a possibly-negative NCHW dim, then convert to WHCN order —
  // the convention the shader's reduce_dim specialization constant uses.
  int32_t reduce_dim_nchw = graph.extract_scalar<int32_t>(dim_ref);
  reduce_dim_nchw = normalize(reduce_dim_nchw, ndim);
  const int32_t reduce_dim = nchw_dim_to_whcn_dim(reduce_dim_nchw, ndim);

  // Check that the concat dim is not the reduction dim, if the tensor has a
  // batch dim greater than 1.
  if (graph.dim_of(in) == 4 && graph.size_at<int>(0, in) > 1) {
    VK_CHECK_COND(
        graph.concat_dim_of(in) != reduce_dim,
        "Softmax shader currently does not support concat dim == reduce dim");
    VK_CHECK_COND(
        graph.concat_dim_of(out) != reduce_dim,
        "Softmax shader currently does not support concat dim == reduce dim");
  }

  // Shader name: [log_]softmax_buffer plus the output dtype suffix, matching
  // the variants generated from softmax_buffer.yaml.
  std::string kernel_name = "softmax_buffer";
  kernel_name.reserve(kShaderNameReserve);
  add_dtype_suffix(kernel_name, graph.dtype_of(out));
  if (log_softmax) {
    kernel_name = "log_" + kernel_name;
  }

  // NOTE(review): the raw NCHW dim value (not a ValueRef) is smuggled
  // through the resize args; the wg-size pickers read it back as an int.
  const int dim_val = graph.extract_scalar<int>(dim_ref);

  graph.execute_nodes().emplace_back(new DynamicDispatchNode(
      graph,
      VK_KERNEL_FROM_STR(kernel_name),
      pick_softmax_buffer_global_wg_size,
      pick_softmax_buffer_local_wg_size,
      // Inputs and Outputs
      {{out, vkapi::kWrite}, {in, vkapi::kRead}},
      // Shader params buffers — order must match the UBO declarations in
      // softmax_buffer.glsl: in_sizes, in_strides, out_sizes, out_strides.
      {
          graph.sizes_ubo(in),
          graph.strides_ubo(in),
          graph.sizes_ubo(out),
          graph.strides_ubo(out),
      },
      // Push Constants
      {},
      // Specialization Constants (constant_id = 3 in the shader)
      {reduce_dim},
      // Resize Args
      {dim_val},
      // Resizing Logic
      resize_softmax_node));
}

//
// Dispatch
//

// Routes softmax to the storage-specific implementation, keyed off the
// output tensor's storage type.
void add_softmax_node(
    ComputeGraph& graph,
    const ValueRef in,
    const ValueRef dim_ref,
    const ValueRef out,
    bool log_softmax) {
  if (graph.is_buffer_storage(out)) {
    add_softmax_buffer_node(graph, in, dim_ref, out, log_softmax);
    return;
  }
  add_softmax_texture_node(graph, in, dim_ref, out, log_softmax);
}

void softmax(ComputeGraph& graph, const std::vector<ValueRef>& args) {
// args[1] bool half_to_float is unused
return add_softmax_node(
Expand Down
2 changes: 2 additions & 0 deletions backends/vulkan/test/op_tests/cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -1590,6 +1590,7 @@ def get_softmax_inputs():
"utils::kWidthPacked",
"utils::kChannelsPacked",
]
test_suite.storage_types = ["utils::kTexture3D", "utils::kBuffer"]

# Large negative values regression test (edgeTAM attention scores that
# produced NaN due to missing max-shift in softmax numerics)
Expand All @@ -1602,6 +1603,7 @@ def get_softmax_inputs():
"utils::kWidthPacked",
"utils::kChannelsPacked",
]
large_neg_test_suite.storage_types = ["utils::kTexture3D", "utils::kBuffer"]
large_neg_test_suite.data_range = (-1.8e10, -6.5e9)
large_neg_test_suite.test_name_suffix = "large_negative"
large_neg_test_suite.dtypes = ["at::kFloat"]
Expand Down
Loading