diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py
index 91d679829bb..8ece5903cea 100644
--- a/backends/vulkan/op_registry.py
+++ b/backends/vulkan/op_registry.py
@@ -1405,7 +1405,7 @@ def register_native_group_norm():
 @update_features(exir_ops.edge.aten.native_layer_norm.default)
 def register_native_layer_norm():
     return OpFeatures(
-        inputs_storage=utils.ANY_TEXTURE,
+        inputs_storage=utils.ANY_STORAGE,
         inputs_dtypes=utils.FP_T,
         supports_prepacking=True,
         supports_resize=True,
diff --git a/backends/vulkan/runtime/graph/ops/glsl/native_layer_norm_buffer.glsl b/backends/vulkan/runtime/graph/ops/glsl/native_layer_norm_buffer.glsl
new file mode 100644
index 00000000000..22d32d09b89
--- /dev/null
+++ b/backends/vulkan/runtime/graph/ops/glsl/native_layer_norm_buffer.glsl
@@ -0,0 +1,145 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#version 450 core
+
+${define_required_extensions("buffer", DTYPE)}
+
+#define PRECISION ${PRECISION}
+
+#define T ${buffer_scalar_type(DTYPE)}
+
+${define_active_storage_type("buffer")}
+
+#extension GL_EXT_control_flow_attributes : require
+
+layout(std430) buffer;
+
+#include "indexing.glslh"
+
+${layout_declare_tensor(B, "w", "t_out", DTYPE, "buffer")}
+${layout_declare_tensor(B, "w", "t_mean", DTYPE, "buffer")}
+${layout_declare_tensor(B, "w", "t_rstd", DTYPE, "buffer")}
+
+${layout_declare_tensor(B, "r", "t_in", DTYPE, "buffer")}
+${layout_declare_tensor(B, "r", "t_weight", DTYPE, "buffer")}
+${layout_declare_tensor(B, "r", "t_bias", DTYPE, "buffer")}
+
+${layout_declare_ubo(B, "BufferMetadata", "outp")}
+${layout_declare_ubo(B, "BufferMetadata", "inp")}
+${layout_declare_ubo(B, "BufferMetadata", "mean_meta")}
+
+layout(push_constant) uniform PRECISION restrict Block {
+  float epsilon;
+};
+
+// Number of invocations cooperating on one row; the host must dispatch a local
+// workgroup size of (NUM_WORKERS, 1, 1).
+#define NUM_WORKERS 64
+
+layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
+
+${layout_declare_spec_const(C, "int", "in_layout", "CONTIG_LAYOUT_INT")}
+
+// Partial sums are kept in fp32 regardless of DTYPE so that the half variant
+// does not overflow or lose precision while reducing wide rows.
+shared float shared_sum[NUM_WORKERS];
+
+// Tree-reduces shared_sum[0 .. NUM_WORKERS) into shared_sum[0]. Every
+// invocation of the workgroup must call this since it executes barrier().
+void reduce_shared(const uint worker_id) {
+  memoryBarrierShared();
+  barrier();
+
+  [[unroll]] for (int stride = NUM_WORKERS / 2; stride > 0; stride >>= 1) {
+    if (worker_id < stride) {
+      shared_sum[worker_id] += shared_sum[worker_id + stride];
+    }
+    memoryBarrierShared();
+    barrier();
+  }
+}
+
+void main() {
+  // Each workgroup handles one output row (one mean/rstd element).
+  // gl_GlobalInvocationID.y = row index
+  // gl_LocalInvocationID.x = worker_id within the row
+  const uint row_idx = gl_GlobalInvocationID.y;
+  const uint worker_id = gl_LocalInvocationID.x;
+
+  const uint row_width = width(inp);
+
+  if (row_idx >= numel(mean_meta)) {
+    return;
+  }
+
+  // Convert row_idx to a tensor index using the mean/rstd metadata.
+  // The mean/rstd tensor has shape [..., 1] (width dimension is 1).
+  // This gives us the outer dimension indices for this row.
+  TensorIndex row_tidx = linear_idx_to_tensor_idx(mean_meta, row_idx, in_layout);
+
+  // The width stride in the input buffer tells us how to step through width
+  // elements. For contiguous layout, stride_at(inp, 0) == 1; for other
+  // layouts it may differ.
+  const uint width_stride = stride_at(inp, 0);
+
+  // Compute the base buffer index for this row in the input tensor.
+  // Set width component to 0 and compute the buffer offset.
+  row_tidx.data[0][0] = 0;
+  const uint base_bufi = tensor_idx_to_linear_idx(inp, row_tidx);
+
+  // Phase 1: Compute mean via cooperative reduction
+  float local_sum = 0.0;
+  for (uint x = worker_id; x < row_width; x += NUM_WORKERS) {
+    const uint in_bufi = base_bufi + x * width_stride;
+    local_sum += float(t_in[in_bufi]);
+  }
+
+  shared_sum[worker_id] = local_sum;
+  reduce_shared(worker_id);
+
+  const float mean_val = shared_sum[0] / float(row_width);
+
+  // All invocations must have read shared_sum[0] before it is overwritten.
+  memoryBarrierShared();
+  barrier();
+
+  // Phase 2: Compute variance via cooperative reduction
+  float local_var = 0.0;
+  for (uint x = worker_id; x < row_width; x += NUM_WORKERS) {
+    const uint in_bufi = base_bufi + x * width_stride;
+    const float delta = float(t_in[in_bufi]) - mean_val;
+    local_var += delta * delta;
+  }
+
+  shared_sum[worker_id] = local_var;
+  reduce_shared(worker_id);
+
+  const float var_val = shared_sum[0] / float(row_width);
+  const float rstd_val = inversesqrt(var_val + epsilon);
+
+  // Phase 3: Normalize and write output
+  // Weight and bias are 1D tensors of size [width], indexed directly by x.
+  // t_out is addressed with the input's buffer index, i.e. the output is
+  // assumed to share the input's strides.
+  for (uint x = worker_id; x < row_width; x += NUM_WORKERS) {
+    const uint in_bufi = base_bufi + x * width_stride;
+    const float in_val = float(t_in[in_bufi]);
+    const float normalized = (in_val - mean_val) * rstd_val;
+    const float w = float(t_weight[x]);
+    const float b = float(t_bias[x]);
+    t_out[in_bufi] = T(normalized * w + b);
+  }
+
+  // Write mean and rstd (only one thread per row)
+  if (worker_id == 0) {
+    const uint mean_bufi = tensor_idx_to_linear_idx(mean_meta, row_tidx);
+    t_mean[mean_bufi] = T(mean_val);
+    t_rstd[mean_bufi] = T(rstd_val);
+  }
+}
diff --git a/backends/vulkan/runtime/graph/ops/glsl/native_layer_norm_buffer.yaml b/backends/vulkan/runtime/graph/ops/glsl/native_layer_norm_buffer.yaml
new file mode 100644
index 00000000000..1978f237ea5
--- /dev/null
+++ b/backends/vulkan/runtime/graph/ops/glsl/native_layer_norm_buffer.yaml
@@ -0,0 +1,16 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+native_layer_norm_buffer:
+  parameter_names_with_default_values:
+    DTYPE: float
+    STORAGE: buffer
+  generate_variant_forall:
+    DTYPE:
+      - VALUE: half
+      - VALUE: float
+  shader_variants:
+    - NAME: native_layer_norm_buffer
diff --git a/backends/vulkan/runtime/graph/ops/glsl/native_layer_norm.glsl b/backends/vulkan/runtime/graph/ops/glsl/native_layer_norm_texture.glsl
similarity index 100%
rename from backends/vulkan/runtime/graph/ops/glsl/native_layer_norm.glsl
rename to backends/vulkan/runtime/graph/ops/glsl/native_layer_norm_texture.glsl
diff --git a/backends/vulkan/runtime/graph/ops/glsl/native_layer_norm.yaml b/backends/vulkan/runtime/graph/ops/glsl/native_layer_norm_texture.yaml
similarity index 85%
rename from backends/vulkan/runtime/graph/ops/glsl/native_layer_norm.yaml
rename to backends/vulkan/runtime/graph/ops/glsl/native_layer_norm_texture.yaml
index ac478599f8a..ee3c3b96b28 100644
--- a/backends/vulkan/runtime/graph/ops/glsl/native_layer_norm.yaml
+++ b/backends/vulkan/runtime/graph/ops/glsl/native_layer_norm_texture.yaml
@@ -4,7 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-native_layer_norm:
+native_layer_norm_texture:
   parameter_names_with_default_values:
     DTYPE: float
     STORAGE: texture3d
@@ -13,4 +13,4 @@ native_layer_norm:
       - VALUE: half
       - VALUE: float
   shader_variants:
-    - NAME: native_layer_norm
+    - NAME: native_layer_norm_texture3d
diff --git a/backends/vulkan/runtime/graph/ops/impl/NativeLayerNorm.cpp b/backends/vulkan/runtime/graph/ops/impl/NativeLayerNorm.cpp
index 8e15b56b208..ec4bc7ff943 100644
--- a/backends/vulkan/runtime/graph/ops/impl/NativeLayerNorm.cpp
+++ b/backends/vulkan/runtime/graph/ops/impl/NativeLayerNorm.cpp
@@ -50,7 +50,41 @@ void resize_native_layer_norm_node(
   graph->virtual_resize(rstd, mean_size);
 }
 
-void add_native_layer_norm_node(
+// Global workgroup size for buffer path: one workgroup per row. The number of
+// rows is the numel of the mean tensor, i.e. the second ref of the first arg
+// group (the outputs).
+utils::uvec3 layer_norm_buffer_global_wg_size(
+    ComputeGraph* graph,
+    const vkapi::ShaderInfo& shader,
+    const std::vector<ArgGroup>& args,
+    const std::vector<ValueRef>& resize_args) {
+  (void)shader;
+  (void)resize_args;
+  const ValueRef mean_tensor = args.at(0).refs.at(1);
+  const uint32_t num_rows =
+      utils::safe_downcast<uint32_t>(graph->numel_of(mean_tensor));
+  return {1u, num_rows, 1u};
+}
+
+// Local workgroup size for buffer path: NUM_WORKERS threads per row. The value
+// must match NUM_WORKERS (64) in native_layer_norm_buffer.glsl.
+utils::uvec3 layer_norm_buffer_local_wg_size(
+    ComputeGraph* graph,
+    const vkapi::ShaderInfo& shader,
+    const utils::uvec3& global_workgroup_size,
+    const std::vector<ArgGroup>& args,
+    const std::vector<ValueRef>& resize_args) {
+  (void)graph;
+  (void)shader;
+  (void)global_workgroup_size;
+  (void)args;
+  (void)resize_args;
+  return {64u, 1u, 1u};
+}
+
+// Records the layer norm dispatch for buffer storage. The shader reduces each
+// row cooperatively, hence the custom global/local workgroup size pickers.
+void add_native_layer_norm_buffer_node(
     ComputeGraph& graph,
     const ValueRef in,
     const ValueRef normalized_shape,
@@ -58,20 +92,51 @@ void add_native_layer_norm_node(
     const ValueRef bias_data,
     const ValueRef eps,
     const ValueRef out) {
-  const auto normalized_shape_dim =
-      graph.get_int_list(normalized_shape)->size();
-  if (normalized_shape_dim > 1) {
-    VK_THROW("native_layer_norm only supports normalized_shape with dim == 1");
-  }
+  ValueRef arg_weight = prepack_standard_like(graph, weight_data, in);
+  ValueRef arg_bias = prepack_standard_like(graph, bias_data, in);
 
-  if (graph.val_is_none(weight_data)) {
-    VK_THROW("native_layer_norm requires weight to be non-None");
-  }
+  const auto out_val = graph.get_value_list(out);
+  const ValueRef out_tensor = out_val->at(0);
+  const ValueRef mean_tensor = out_val->at(1);
+  const ValueRef rstd_tensor = out_val->at(2);
 
-  if (graph.val_is_none(bias_data)) {
-    VK_THROW("native_layer_norm requires bias to be non-None");
-  }
+  float epsilon = graph.extract_scalar<float>(eps);
 
+  std::string kernel_name("native_layer_norm");
+  kernel_name.reserve(kShaderNameReserve);
+  add_storage_type_suffix(kernel_name, graph.storage_type_of(out_tensor));
+  add_dtype_suffix(kernel_name, graph.dtype_of(out_tensor));
+
+  graph.execute_nodes().emplace_back(new DynamicDispatchNode(
+      graph,
+      VK_KERNEL_FROM_STR(kernel_name),
+      layer_norm_buffer_global_wg_size,
+      layer_norm_buffer_local_wg_size,
+      // Inputs and Outputs
+      {{{out_tensor, mean_tensor, rstd_tensor}, vkapi::kWrite},
+       {{in, arg_weight, arg_bias}, vkapi::kRead}},
+      // Shader params buffers
+      {graph.buffer_meta_ubo(out_tensor),
+       graph.buffer_meta_ubo(in),
+       graph.buffer_meta_ubo(mean_tensor)},
+      // Push Constants
+      {PushConstantDataInfo(&epsilon, sizeof(epsilon))},
+      // Specialization Constants
+      {graph.hashed_layout_of(in)},
+      // Resize Args
+      {normalized_shape},
+      // Resizing Logic
+      resize_native_layer_norm_node));
+}
+
+void add_native_layer_norm_texture_node(
+    ComputeGraph& graph,
+    const ValueRef in,
+    const ValueRef normalized_shape,
+    const ValueRef weight_data,
+    const ValueRef bias_data,
+    const ValueRef eps,
+    const ValueRef out) {
   ValueRef arg_weight = prepack_standard_like(graph, weight_data, in);
   ValueRef arg_bias = prepack_standard_like(graph, bias_data, in);
 
@@ -84,25 +149,9 @@ void add_native_layer_norm_node(
 
   VK_CHECK_COND(check_same_packed_dim(graph, in, out_tensor));
 
-  const std::vector<int64_t> in_sizes = graph.sizes_of(in);
-
-  utils::uvec3 global_size = graph.logical_limits_of(out_tensor);
-  utils::uvec3 local_size;
-
-  // Since the shader sets shared memory scale factor > 1, if dispatch is
-  // greater than maximum WG size. Setting WG size in X axis to max WG size,
-  // would allow best thread utilization.
-  if (global_size[0] > 64) {
-    local_size = {64, 1, 1};
-  } else {
-    // If thread size in X axis is smaller or equal to maximum WG size, we can
-    // let the function decide the best WG size.
-    local_size = graph.create_local_wg_size(global_size);
-  }
-
   std::string kernel_name("native_layer_norm");
   kernel_name.reserve(kShaderNameReserve);
-
+  add_storage_type_suffix(kernel_name, graph.storage_type_of(out_tensor));
   add_dtype_suffix(kernel_name, graph.dtype_of(out_tensor));
 
   graph.execute_nodes().emplace_back(new DynamicDispatchNode(
@@ -132,6 +181,39 @@ void add_native_layer_norm_node(
       resize_native_layer_norm_node));
 }
 
+// Validates arguments shared by both paths, then dispatches to the buffer or
+// texture implementation based on the storage type of the input tensor.
+void add_native_layer_norm_node(
+    ComputeGraph& graph,
+    const ValueRef in,
+    const ValueRef normalized_shape,
+    const ValueRef weight_data,
+    const ValueRef bias_data,
+    const ValueRef eps,
+    const ValueRef out) {
+  const auto normalized_shape_dim =
+      graph.get_int_list(normalized_shape)->size();
+  if (normalized_shape_dim > 1) {
+    VK_THROW("native_layer_norm only supports normalized_shape with dim == 1");
+  }
+
+  if (graph.val_is_none(weight_data)) {
+    VK_THROW("native_layer_norm requires weight to be non-None");
+  }
+
+  if (graph.val_is_none(bias_data)) {
+    VK_THROW("native_layer_norm requires bias to be non-None");
+  }
+
+  if (graph.is_buffer_storage(in)) {
+    add_native_layer_norm_buffer_node(
+        graph, in, normalized_shape, weight_data, bias_data, eps, out);
+  } else {
+    add_native_layer_norm_texture_node(
+        graph, in, normalized_shape, weight_data, bias_data, eps, out);
+  }
+}
+
 void native_layer_norm(ComputeGraph& graph, const std::vector<ValueRef>& args) {
   return add_native_layer_norm_node(
       graph, args[0], args[1], args[2], args[3], args[4], args[5]);
diff --git a/backends/vulkan/test/op_tests/cases.py b/backends/vulkan/test/op_tests/cases.py
index 4caaa772267..87a086db831 100644
--- a/backends/vulkan/test/op_tests/cases.py
+++ b/backends/vulkan/test/op_tests/cases.py
@@ -721,6 +721,10 @@ def get_native_layer_norm_inputs():
         "utils::kHeightPacked",
         "utils::kChannelsPacked",
     ]
+    test_suite.storage_types = [
+        "utils::kTexture3D",
+        "utils::kBuffer",
+    ]
     return test_suite