diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py
index f348cbbce9e..91d679829bb 100644
--- a/backends/vulkan/op_registry.py
+++ b/backends/vulkan/op_registry.py
@@ -327,7 +327,7 @@ def check_to_copy_node(node: torch.fx.Node) -> bool:
 )
 def register_softmax_cpp_ops():
     return OpFeatures(
-        inputs_storage=utils.ANY_TEXTURE,
+        inputs_storage=utils.ANY_STORAGE,
         inputs_dtypes=utils.FP_T,
         supports_resize=True,
     )
diff --git a/backends/vulkan/runtime/graph/ops/glsl/softmax_buffer.glsl b/backends/vulkan/runtime/graph/ops/glsl/softmax_buffer.glsl
new file mode 100644
index 00000000000..4e5e034a33c
--- /dev/null
+++ b/backends/vulkan/runtime/graph/ops/glsl/softmax_buffer.glsl
@@ -0,0 +1,124 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#version 450 core
+
+${define_required_extensions(STORAGE, DTYPE)}
+
+#define PRECISION ${PRECISION}
+#define T ${buffer_scalar_type(DTYPE)}
+
+#define op1(X) ${OPERATOR1}
+
+#define op2(X, Y) ${OPERATOR2}
+
+layout(std430) buffer;
+
+${layout_declare_tensor(B, "w", "out_buf", DTYPE, STORAGE)}
+${layout_declare_tensor(B, "r", "in_buf", DTYPE, STORAGE)}
+
+${layout_declare_ubo(B, "ivec4", "in_sizes")}
+${layout_declare_ubo(B, "ivec4", "in_strides")}
+${layout_declare_ubo(B, "ivec4", "out_sizes")}
+${layout_declare_ubo(B, "ivec4", "out_strides")}
+
+layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
+
+// WHCN axis to reduce over. Indexes 3-component IDs below, so must be < 3.
+layout(constant_id = 3) const int reduce_dim = 0;
+
+// Must match nworkers_per_group in pick_softmax_buffer_local_wg_size.
+#define NWORKERS 4
+#define MAX_NTHREADS 16
+
+shared T shared_max[NWORKERS];
+shared T shared_sum[NWORKERS];
+
+#include "indexing_utils.h"
+
+/*
+ * Buffer-based softmax. Each workgroup processes one "row" along the reduction
+ * dimension. Within a workgroup, NWORKERS threads cooperate to compute the max
+ * and sum reductions, then each thread writes its portion of the final outputs.
+ *
+ * Thread mapping: the global WG size has 1 along reduce_dim, and all other
+ * dimensions correspond to output tensor sizes (WHCN order, with z encoding
+ * C*N). The local WG size has NWORKERS along reduce_dim. Each workgroup
+ * identifies a unique reduction "row" via the non-reduce dimensions of
+ * gl_GlobalInvocationID, and the NWORKERS threads within that workgroup
+ * cooperate on the reduction.
+ */
+void main() {
+  // Build the base 4D index for this workgroup's reduction row.
+  // gl_GlobalInvocationID has 0..NWORKERS-1 along reduce_dim; zero it out
+  // since the tid will iterate over the reduce_dim explicitly.
+  ivec3 gid = ivec3(gl_GlobalInvocationID);
+  gid[reduce_dim] = 0;
+
+  ivec4 base_idx = ivec4(gid.x, gid.y, gid.z % in_sizes.z, gid.z / in_sizes.z);
+
+  if (any(greaterThanEqual(base_idx, in_sizes))) {
+    return;
+  }
+
+  const uint tid = gl_LocalInvocationID[reduce_dim];
+  const int R = in_sizes[reduce_dim];
+
+  // Phase 1: Find maximum along reduce_dim
+  ivec4 in_idx = base_idx;
+
+  T local_max = T(-3.402823e+38);
+  for (int i = int(tid); i < R; i += NWORKERS) {
+    in_idx[reduce_dim] = i;
+    T v = in_buf[tidx_to_bufi(in_idx, in_strides)];
+    local_max = max(local_max, v);
+  }
+  shared_max[tid] = local_max;
+  barrier();
+
+  // Reduce partial maximums across workers
+  T max_val = shared_max[0];
+  for (int i = 1; i < NWORKERS; ++i) {
+    max_val = max(max_val, shared_max[i]);
+  }
+
+  // Phase 2: Compute sum of exp(x - max_val)
+  T local_sum = T(0);
+  for (int i = int(tid); i < R; i += NWORKERS) {
+    in_idx[reduce_dim] = i;
+    T v = in_buf[tidx_to_bufi(in_idx, in_strides)];
+    local_sum += exp(v - max_val);
+  }
+  shared_sum[tid] = local_sum;
+  barrier();
+
+  // Reduce partial sums across workers
+  T sum_val = shared_sum[0];
+  for (int i = 1; i < NWORKERS; ++i) {
+    sum_val += shared_sum[i];
+  }
+  // Clamp denominator to avoid 0/0 = NaN when all exp values underflow.
+  sum_val = max(sum_val, T(1e-37));
+
+  // Phase 3: Write outputs
+  for (int i = int(tid); i < R; i += NWORKERS) {
+    in_idx[reduce_dim] = i;
+    int in_buf_idx = tidx_to_bufi(in_idx, in_strides);
+    T v = in_buf[in_buf_idx];
+    T numerator = op1(v - max_val);
+    T result = op2(numerator, sum_val);
+
+    // Replace NaN/Inf with 0 using IEEE 754 bit-level manipulation
+    uint bits = floatBitsToUint(result);
+    if ((bits & 0x7F800000u) == 0x7F800000u) {
+      result = T(0);
+    }
+
+    out_buf[tidx_to_bufi(in_idx, out_strides)] = result;
+  }
+}
diff --git a/backends/vulkan/runtime/graph/ops/glsl/softmax_buffer.yaml b/backends/vulkan/runtime/graph/ops/glsl/softmax_buffer.yaml
new file mode 100644
index 00000000000..419e1a01ea7
--- /dev/null
+++ b/backends/vulkan/runtime/graph/ops/glsl/softmax_buffer.yaml
@@ -0,0 +1,21 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+softmax_buffer:
+  parameter_names_with_default_values:
+    OPERATOR1: exp(X)
+    OPERATOR2: X / Y
+    DTYPE: float
+    STORAGE: buffer
+  generate_variant_forall:
+    DTYPE:
+      - VALUE: half
+      - VALUE: float
+  shader_variants:
+    - NAME: softmax_buffer
+    - NAME: log_softmax_buffer
+      OPERATOR1: X
+      OPERATOR2: X - log(Y)
diff --git a/backends/vulkan/runtime/graph/ops/impl/Softmax.cpp b/backends/vulkan/runtime/graph/ops/impl/Softmax.cpp
index 2d683719ba2..102a0c13384 100644
--- a/backends/vulkan/runtime/graph/ops/impl/Softmax.cpp
+++ b/backends/vulkan/runtime/graph/ops/impl/Softmax.cpp
@@ -18,13 +18,16 @@ namespace vkcompute {
 
 using namespace utils;
 
-utils::uvec3 pick_softmax_global_wg_size(
+//
+// Texture path
+//
+
+utils::uvec3 pick_softmax_texture_global_wg_size(
     ComputeGraph* graph,
     const vkapi::ShaderInfo& shader,
     const std::vector& args,
     const std::vector& resize_args) {
   (void)shader;
-  (void)resize_args;
 
   const ValueRef out = args.at(0).refs.at(0);
   const int32_t reduce_dim_xyz =
@@ -35,7 +38,7 @@ utils::uvec3 pick_softmax_global_wg_size(
   return global_size;
 }
 
-utils::uvec3 pick_softmax_local_wg_size(
+utils::uvec3 pick_softmax_texture_local_wg_size(
     ComputeGraph* graph,
     const vkapi::ShaderInfo& shader,
     const utils::uvec3& global_workgroup_size,
@@ -51,7 +54,6 @@ utils::uvec3 pick_softmax_local_wg_size(
   const int32_t reduce_dim_xyz =
       graph->extract_scalar(resize_args.at(1));
 
-  // These values are hardcoded in add_softmax_node
   const uint32_t nworkers_per_group = 4;
   const uint32_t ngroups = 4;
 
@@ -74,16 +76,12 @@ void resize_softmax_node(
   graph->virtual_resize(out, in_sizes);
 }
 
-void add_softmax_node(
+void add_softmax_texture_node(
     ComputeGraph& graph,
     const ValueRef in,
     const ValueRef dim_ref,
     const ValueRef out,
     bool log_softmax) {
-  VK_CHECK_COND(
-      !graph.is_buffer_storage(in) && !graph.is_buffer_storage(out),
-      "Vulkan softmax only supports texture storage");
-
   const int64_t ndim = graph.dim_of(in);
 
   int32_t reduce_dim_nchw = graph.extract_scalar(dim_ref);
@@ -101,7 +99,6 @@ void add_softmax_node(
         "Softmax shader currently does not support concat dim == reduce dim");
   }
 
-  vkapi::ShaderInfo shader_descriptor;
   std::string kernel_name = "softmax";
   kernel_name.reserve(kShaderNameReserve);
   add_dtype_suffix(kernel_name, graph.dtype_of(out));
@@ -134,8 +131,8 @@ void add_softmax_node(
   graph.execute_nodes().emplace_back(new DynamicDispatchNode(
       graph,
       VK_KERNEL_FROM_STR(kernel_name),
-      pick_softmax_global_wg_size,
-      pick_softmax_local_wg_size,
+      pick_softmax_texture_global_wg_size,
+      pick_softmax_texture_local_wg_size,
       // Inputs and Outputs
       {{out, vkapi::kWrite}, {in, vkapi::kRead}},
       // Shader params buffers
@@ -150,6 +147,152 @@ void add_softmax_node(
       resize_softmax_node));
 }
 
+//
+// Buffer path
+//
+
+// Global WG size for the buffer shader: output extents as (W, H, C * N), with
+// the reduce axis collapsed to 1 so that each reduction row gets one workgroup.
+utils::uvec3 pick_softmax_buffer_global_wg_size(
+    ComputeGraph* graph,
+    const vkapi::ShaderInfo& shader,
+    const std::vector& args,
+    const std::vector& resize_args) {
+  (void)shader;
+  const ValueRef out = args.at(0).refs.at(0);
+  const ValueRef in = args.at(1).refs.at(0);
+  const int dim = resize_args.at(0);
+
+  const int64_t ndim = graph->dim_of(in);
+  int32_t reduce_dim = normalize(dim, ndim);
+  reduce_dim = nchw_dim_to_whcn_dim(reduce_dim, ndim);
+
+  utils::uvec3 global_size = {
+      graph->size_at(-1, out),
+      graph->size_at(-2, out),
+      graph->size_at(-3, out) * graph->size_at(-4, out)};
+  global_size[reduce_dim] = 1;
+  return global_size;
+}
+
+// Local WG size for the buffer shader: nworkers_per_group threads along the
+// reduce axis and 1 along the other two.
+utils::uvec3 pick_softmax_buffer_local_wg_size(
+    ComputeGraph* graph,
+    const vkapi::ShaderInfo& shader,
+    const utils::uvec3& global_workgroup_size,
+    const std::vector& args,
+    const std::vector& resize_args) {
+  (void)shader;
+  (void)global_workgroup_size;
+  const ValueRef in = args.at(1).refs.at(0);
+  const int dim = resize_args.at(0);
+
+  const int64_t ndim = graph->dim_of(in);
+  int32_t reduce_dim = normalize(dim, ndim);
+  reduce_dim = nchw_dim_to_whcn_dim(reduce_dim, ndim);
+
+  // Must match NWORKERS in softmax_buffer.glsl.
+  const uint32_t nworkers_per_group = 4;
+  utils::uvec3 local_wg_size{1, 1, 1};
+  local_wg_size[reduce_dim] = nworkers_per_group;
+  return local_wg_size;
+}
+
+// Adds a dispatch of the (log_)softmax_buffer shader that reduces `in` along
+// the dim held by `dim_ref` and writes the result to `out`.
+void add_softmax_buffer_node(
+    ComputeGraph& graph,
+    const ValueRef in,
+    const ValueRef dim_ref,
+    const ValueRef out,
+    bool log_softmax) {
+  const int64_t ndim = graph.dim_of(in);
+
+  int32_t reduce_dim_nchw = graph.extract_scalar(dim_ref);
+  reduce_dim_nchw = normalize(reduce_dim_nchw, ndim);
+  const int32_t reduce_dim = nchw_dim_to_whcn_dim(reduce_dim_nchw, ndim);
+
+  // reduce_dim indexes 3-component vectors, both in the workgroup size pickers
+  // above and in the shader, so WHCN index 3 (batch) has to be rejected here.
+  VK_CHECK_COND(
+      reduce_dim >= 0 && reduce_dim < 3,
+      "Softmax buffer shader does not support reduce dim == batch dim");
+
+  // Check that the concat dim is not the reduction dim, if the tensor has a
+  // batch dim greater than 1.
+  if (graph.dim_of(in) == 4 && graph.size_at(0, in) > 1) {
+    VK_CHECK_COND(
+        graph.concat_dim_of(in) != reduce_dim,
+        "Softmax shader currently does not support concat dim == reduce dim");
+    VK_CHECK_COND(
+        graph.concat_dim_of(out) != reduce_dim,
+        "Softmax shader currently does not support concat dim == reduce dim");
+    // The shader folds C and N into the z axis, and that axis collapses to a
+    // single workgroup when reducing over channels, so only batch 0 would be
+    // written.
+    VK_CHECK_COND(
+        reduce_dim != 2,
+        "Softmax buffer shader cannot reduce over channels when batch > 1");
+  }
+
+  std::string kernel_name = "softmax_buffer";
+  kernel_name.reserve(kShaderNameReserve);
+  add_dtype_suffix(kernel_name, graph.dtype_of(out));
+  if (log_softmax) {
+    kernel_name = "log_" + kernel_name;
+  }
+
+  const int dim_val = graph.extract_scalar(dim_ref);
+
+  graph.execute_nodes().emplace_back(new DynamicDispatchNode(
+      graph,
+      VK_KERNEL_FROM_STR(kernel_name),
+      pick_softmax_buffer_global_wg_size,
+      pick_softmax_buffer_local_wg_size,
+      // Inputs and Outputs
+      {{out, vkapi::kWrite}, {in, vkapi::kRead}},
+      // Shader params buffers
+      {
+          graph.sizes_ubo(in),
+          graph.strides_ubo(in),
+          graph.sizes_ubo(out),
+          graph.strides_ubo(out),
+      },
+      // Push Constants
+      {},
+      // Specialization Constants
+      {reduce_dim},
+      // Resize Args
+      {dim_val},
+      // Resizing Logic
+      resize_softmax_node));
+}
+
+//
+// Dispatch
+//
+
+// Routes to the buffer or texture implementation based on storage type.
+void add_softmax_node(
+    ComputeGraph& graph,
+    const ValueRef in,
+    const ValueRef dim_ref,
+    const ValueRef out,
+    bool log_softmax) {
+  const bool out_is_buffer = graph.is_buffer_storage(out);
+  // Each path binds `in` and `out` with a single storage type, so mixed
+  // buffer/texture storage cannot be served by either shader.
+  VK_CHECK_COND(
+      graph.is_buffer_storage(in) == out_is_buffer,
+      "Vulkan softmax requires in and out to share the same storage type");
+  if (out_is_buffer) {
+    add_softmax_buffer_node(graph, in, dim_ref, out, log_softmax);
+  } else {
+    add_softmax_texture_node(graph, in, dim_ref, out, log_softmax);
+  }
+}
+
 void softmax(ComputeGraph& graph, const std::vector& args) {
   // args[1] bool half_to_float is unused
   return add_softmax_node(
diff --git a/backends/vulkan/test/op_tests/cases.py b/backends/vulkan/test/op_tests/cases.py
index e06b7f3ce6b..4caaa772267 100644
--- a/backends/vulkan/test/op_tests/cases.py
+++ b/backends/vulkan/test/op_tests/cases.py
@@ -1590,6 +1590,7 @@ def get_softmax_inputs():
         "utils::kWidthPacked",
         "utils::kChannelsPacked",
     ]
+    test_suite.storage_types = ["utils::kTexture3D", "utils::kBuffer"]
 
     # Large negative values regression test (edgeTAM attention scores that
     # produced NaN due to missing max-shift in softmax numerics)
@@ -1602,6 +1603,7 @@ def get_softmax_inputs():
         "utils::kWidthPacked",
         "utils::kChannelsPacked",
     ]
+    large_neg_test_suite.storage_types = ["utils::kTexture3D", "utils::kBuffer"]
     large_neg_test_suite.data_range = (-1.8e10, -6.5e9)
     large_neg_test_suite.test_name_suffix = "large_negative"
     large_neg_test_suite.dtypes = ["at::kFloat"]