diff --git a/backends/vulkan/op_registry.py b/backends/vulkan/op_registry.py index 2849a1550a3..861a4c067ed 100644 --- a/backends/vulkan/op_registry.py +++ b/backends/vulkan/op_registry.py @@ -800,9 +800,36 @@ def check_conv_node(node: torch.fx.Node) -> bool: return True + def pick_conv_storage( + node: torch.fx.Node, + ) -> Tuple[List[utils.TensorRepSet], utils.TensorRepSet]: + x_shape = node.args[0].meta["val"].size() # type: ignore[union-attr] + no_storage_tail = [utils.NO_STORAGE] * (len(node.args) - 1) + if len(x_shape) == 3: + weight = node.args[1] # type: ignore[union-attr] + weight_shape = weight.meta["val"].size() # type: ignore[union-attr] + groups = node.args[8] # type: ignore[union-attr] + groups_val = groups if isinstance(groups, int) else int(groups) + is_depthwise = weight_shape[0] == groups_val and weight_shape[1] == 1 + if weight_shape[2] == 1 or is_depthwise: + # Pointwise and depthwise 1D conv both have texture implementations + # using width-packed TEXTURE_3D. + return ( + [utils.WIDTH_PACKED_TEXTURE] + no_storage_tail, + utils.WIDTH_PACKED_TEXTURE, + ) + # General (non-pointwise, non-depthwise) 1D convolution: buffer path + return [utils.CONTIGUOUS_BUFFER] + no_storage_tail, utils.CONTIGUOUS_BUFFER + else: + # 2D convolution: channels-packed texture path + return ( + [utils.CHANNELS_PACKED_TEXTURE] + no_storage_tail, + utils.CHANNELS_PACKED_TEXTURE, + ) + return OpFeatures( inputs_storage=[ - utils.CHANNELS_PACKED_TEXTURE, # input + utils.CHANNELS_PACKED_TEXTURE, # input (overridden by pick_conv_storage) utils.NO_STORAGE, # weight (prepacked) utils.NO_STORAGE, # bias (prepacked) utils.NO_STORAGE, # stride (non tensor) @@ -818,6 +845,7 @@ def check_conv_node(node: torch.fx.Node) -> bool: supports_resize=True, supports_prepacking=True, are_node_inputs_supported_fn=check_conv_node, + pick_io_storage_fn=pick_conv_storage, ) diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv1d.glsl b/backends/vulkan/runtime/graph/ops/glsl/conv1d.glsl index 
4e3b91e6c49..dbd1f8f3359 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/conv1d.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/conv1d.glsl @@ -8,114 +8,71 @@ #version 450 core +${define_required_extensions("buffer", DTYPE)} + #define PRECISION ${PRECISION} -#define VEC4_T ${texel_type(DTYPE)} +#define T ${buffer_scalar_type(DTYPE)} #define op(X, A, B) ${OPERATOR} layout(std430) buffer; -${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE)} -${layout_declare_tensor(B, "r", "t_in", DTYPE, STORAGE)} -${layout_declare_tensor(B, "r", "kernel_in", DTYPE, STORAGE)} -${layout_declare_tensor(B, "r", "bias_in", DTYPE, STORAGE)} +${layout_declare_tensor(B, "w", "t_out", DTYPE, "buffer")} +${layout_declare_tensor(B, "r", "t_in", DTYPE, "buffer")} +${layout_declare_tensor(B, "r", "t_weight", DTYPE, "buffer")} +${layout_declare_tensor(B, "r", "t_bias", DTYPE, "buffer")} -${layout_declare_ubo(B, "ivec3", "out_limits")} +${layout_declare_ubo(B, "ivec4", "out_sizes")} +${layout_declare_ubo(B, "ivec4", "out_strides")} ${layout_declare_ubo(B, "ivec4", "in_sizes")} - -${layout_declare_ubo(B,"int", "kernel_size", "int", "stride", "int", "padding", "int", "dilation", "int", "in_group_size", "int", "out_group_size")} - +${layout_declare_ubo(B, "ivec4", "in_strides")} +${layout_declare_ubo(B, "ivec4", "weight_strides")} +${layout_declare_ubo(B, "int", "kernel_size", "int", "stride", "int", "padding", "int", "dilation", "int", "in_group_size", "int", "out_group_size")} ${layout_declare_ubo(B, "float", "out_min", "float", "out_max")} #include "indexing_utils.h" layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; -${layout_declare_spec_const(C, "int", "out_layout", "DEFAULT_LAYOUT")} -const lowp ivec4 out_axis_map = unhash_axis_map(out_layout); - -${layout_declare_spec_const(C, "int", "in_layout", "DEFAULT_LAYOUT")} -const lowp ivec4 in_axis_map = unhash_axis_map(in_layout); - -${layout_declare_spec_const(C, "int", "kernel_layout", "DEFAULT_LAYOUT")} -const 
lowp ivec4 kernel_axis_map = unhash_axis_map(kernel_layout); - -${layout_declare_spec_const(C, "int", "bias_layout", "DEFAULT_LAYOUT")} -const lowp ivec4 bias_axis_map = unhash_axis_map(bias_layout); - -// Let us define -// -// input = (N, in_C, in_L), -// output = (N, out_C, out_L), -// groups = G, -// kernel = K, -// -// which results in shapes -// -// weight = (out_C, in_C / G, K), -// bias = (out_C,). -// -// This implementation performs N x out_C x out_L shader invocations, where each invocation -// calculates the rolling kernel of the length dimension for each batch, i.e., -// computes out_L results. +/* + * Computes a 1D convolution over width-packed buffer tensors. Each shader + * invocation computes one output element at position (n, out_c, out_l). + * + * Tensor sizes/strides are in WHCN order: + * out_sizes.x = L_out, out_sizes.y = C_out, out_sizes.z = N + * in_sizes.x = L_in, in_sizes.y = C_in + */ void main() { - const ivec3 lpos = ivec3(gl_GlobalInvocationID); + const int out_l = int(gl_GlobalInvocationID.x); + const int out_c = int(gl_GlobalInvocationID.y); + const int n = int(gl_GlobalInvocationID.z); - if (any(greaterThanEqual(lpos, out_limits))) { + // WHCN sizes for [N, C, L]: (L, C, N, 1) -> sizes.y=C, sizes.z=N + if (out_l >= out_sizes.x || out_c >= out_sizes.y || n >= out_sizes.z) { return; } - // "out_c" is the output's channel index where we write our result. - // Across shader invocations, this is the only value that varies. - const int out_c = lpos.y; - - // "in_c" tracks the input's channel start index. - // We iterate over the input group that corresponds to the output group. const int c_start = (out_c / out_group_size) * in_group_size; - const int c_end = c_start + in_group_size; - - // "out_l" tracks the output's length index where we write our result. - const int out_l = lpos.x; - - // "N" is the batch index - const int N = lpos.z; - - // "in_l" tracks the input's length start index for our input-kernel overlay - // region.
- const int in_l = out_l * stride - padding; - VEC4_T sum = VEC4_T(0); - - const int out_c_packed_index = out_c >> 2; - const int out_c_packed_lane = out_c & 0x3; - - for (int in_c = c_start; in_c < c_end; ++in_c) { - // "k" tracks the kernel's index for our input-kernel computation. - // It reads out-of-bound zeros, but trying to avoid them complicates - // for-loop conditions, which results in worse performance. - - // The weight tensor is channel-packed. It may not be trival choice for - // performance reason since need to have more data fetch. The reason is - // for some sequence model, we found that the weight tensor - // (out_channel, in_channel / group, kernel) often has a large - // out_channel >> kernel, leading to non-optimal use of memory as the - // weight tensor gets very deep. As a mitigation, we use channel-packing - // for the weight tensor, yielding a 75% reduction in weight-tensor - // memory. - - // It is possible to further reduce the memory footprint by swapping the - // dimensions, using x extent for out_channel, and y for kernel. 
- for (int k = 0; k < kernel_size; k++) { - const ivec3 w_lpos = ivec3(k, in_c % in_group_size, out_c_packed_index); - const VEC4_T weight_texel = load_texel_lpos(kernel_in, w_lpos, kernel_axis_map); - VEC4_T weight = VEC4_T(weight_texel[out_c_packed_lane]); - const ivec3 in_pos = lpos_to_pos(ivec3(in_l + k * dilation, in_c, N), in_axis_map); - sum = fma(weight, load_texel(t_in, in_pos), sum); + T sum = T(0); + for (int ic = 0; ic < in_group_size; ic++) { + const int in_c = c_start + ic; + for (int k = 0; k < kernel_size; k++) { + const int in_l = out_l * stride - padding + k * dilation; + if (in_l >= 0 && in_l < in_sizes.x) { + // WHCN tidx for (n, in_c, in_l) in [N, C, L] tensor: (in_l, in_c, n, 0) + const int in_idx = tidx_to_bufi(ivec4(in_l, in_c, n, 0), in_strides); + // WHCN tidx for weight (k, ic, out_c) in [C_out, C_in/g, K]: (k, ic, out_c, 0) + const int w_idx = tidx_to_bufi(ivec4(k, ic, out_c, 0), weight_strides); + sum += t_in[in_idx] * t_weight[w_idx]; + } } } - const VEC4_T bias = load_texel_lpos(bias_in, ivec3(out_c_packed_index, 0, 0), bias_axis_map); - const ivec3 out_lpos = ivec3(out_l, out_c, N); - write_texel_lpos(t_out, out_lpos, op(sum + bias[out_c_packed_lane], out_min, out_max), out_axis_map); + sum += T(t_bias[out_c]); + + // WHCN tidx for (n, out_c, out_l): (out_l, out_c, n, 0) + const int out_idx = tidx_to_bufi(ivec4(out_l, out_c, n, 0), out_strides); + t_out[out_idx] = op(sum, T(out_min), T(out_max)); } diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv1d.yaml b/backends/vulkan/runtime/graph/ops/glsl/conv1d.yaml index 2266649d2b9..bf5a65068a3 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/conv1d.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/conv1d.yaml @@ -8,7 +8,6 @@ conv1d: parameter_names_with_default_values: OPERATOR: X DTYPE: float - STORAGE: texture3d generate_variant_forall: DTYPE: - VALUE: half diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv1d_dw.glsl b/backends/vulkan/runtime/graph/ops/glsl/conv1d_dw.glsl 
new file mode 100644 index 00000000000..a6340531568 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/conv1d_dw.glsl @@ -0,0 +1,72 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +${define_required_extensions("buffer", DTYPE)} + +#define PRECISION ${PRECISION} + +#define T ${buffer_scalar_type(DTYPE)} + +#define op(X, A, B) ${OPERATOR} + +layout(std430) buffer; + +${layout_declare_tensor(B, "w", "t_out", DTYPE, "buffer")} +${layout_declare_tensor(B, "r", "t_in", DTYPE, "buffer")} +${layout_declare_tensor(B, "r", "t_weight", DTYPE, "buffer")} +${layout_declare_tensor(B, "r", "t_bias", DTYPE, "buffer")} + +${layout_declare_ubo(B, "ivec4", "out_sizes")} +${layout_declare_ubo(B, "ivec4", "out_strides")} +${layout_declare_ubo(B, "ivec4", "in_sizes")} +${layout_declare_ubo(B, "ivec4", "in_strides")} +${layout_declare_ubo(B, "ivec4", "weight_strides")} +${layout_declare_ubo(B, "int", "kernel_size", "int", "stride", "int", "padding", "int", "dilation", "int", "in_group_size", "int", "out_group_size")} +${layout_declare_ubo(B, "float", "out_min", "float", "out_max")} + +#include "indexing_utils.h" + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +/* + * Computes a depthwise 1D convolution over width-packed buffer tensors. Each + * shader invocation computes one output element at position (n, c, out_l). + * + * For depthwise conv: groups == C_in == C_out, so each output channel uses + * exactly one input channel. Weight shape is [C, 1, K]. 
+ */ +void main() { + const int out_l = int(gl_GlobalInvocationID.x); + const int c = int(gl_GlobalInvocationID.y); + const int n = int(gl_GlobalInvocationID.z); + + // WHCN sizes for [N, C, L]: (L, C, N, 1) -> sizes.y=C, sizes.z=N + if (out_l >= out_sizes.x || c >= out_sizes.y || n >= out_sizes.z) { + return; + } + + T sum = T(0); + for (int k = 0; k < kernel_size; k++) { + const int in_l = out_l * stride - padding + k * dilation; + if (in_l >= 0 && in_l < in_sizes.x) { + // WHCN tidx for (n, c, in_l) in [N, C, L] tensor: (in_l, c, n, 0) + const int in_idx = tidx_to_bufi(ivec4(in_l, c, n, 0), in_strides); + // WHCN tidx for weight (k, 0, c) in [C, 1, K]: (k, 0, c, 0) + const int w_idx = tidx_to_bufi(ivec4(k, 0, c, 0), weight_strides); + sum += t_in[in_idx] * t_weight[w_idx]; + } + } + + sum += T(t_bias[c]); + + // WHCN tidx for (n, c, out_l): (out_l, c, n, 0) + const int out_idx = tidx_to_bufi(ivec4(out_l, c, n, 0), out_strides); + t_out[out_idx] = op(sum, T(out_min), T(out_max)); +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv1d_dw.yaml b/backends/vulkan/runtime/graph/ops/glsl/conv1d_dw.yaml new file mode 100644 index 00000000000..79ca0ec0627 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/conv1d_dw.yaml @@ -0,0 +1,18 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +conv1d_dw: + parameter_names_with_default_values: + OPERATOR: X + DTYPE: float + generate_variant_forall: + DTYPE: + - VALUE: half + - VALUE: float + shader_variants: + - NAME: conv1d_dw + - NAME: conv1d_dw_clamp + OPERATOR: clamp(X, A, B) diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv1d_dw_texture.glsl b/backends/vulkan/runtime/graph/ops/glsl/conv1d_dw_texture.glsl new file mode 100644 index 00000000000..f639bc1bf26 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/conv1d_dw_texture.glsl @@ -0,0 +1,93 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +${define_required_extensions("texture3d", DTYPE)} +${define_required_extensions("buffer", DTYPE)} + +#define PRECISION ${PRECISION} + +#define VEC4_T ${texel_type(DTYPE)} +#define T ${buffer_scalar_type(DTYPE)} + +#define op(X, A, B) ${OPERATOR} + +layout(std430) buffer; + +${layout_declare_tensor(B, "w", "t_out", DTYPE, "texture3d")} +${layout_declare_tensor(B, "r", "t_in", DTYPE, "texture3d")} +${layout_declare_tensor(B, "r", "t_weight", DTYPE, "buffer")} +${layout_declare_tensor(B, "r", "t_bias", DTYPE, "buffer")} + +${layout_declare_ubo(B, "ivec3", "out_limits")} +${layout_declare_ubo(B, "int", "kernel_size", "int", "stride", "int", "padding", "int", "dilation")} +${layout_declare_ubo(B, "int", "in_length")} +${layout_declare_ubo(B, "float", "out_min", "float", "out_max")} + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +/* + * Depthwise 1D convolution for width-packed TEXTURE_3D tensors. + * + * Each invocation computes one output texel which contains up to 4 adjacent + * width (length) positions for a single channel c at batch n. 
+ * + * Tensor layout: [N, C, L] stored as width-packed texture3d where + * texture x = L / 4 (packed width texel index), y = C, z = N. + * + * For depthwise conv: groups == C_in == C_out, so each output channel c reads + * only from input channel c. Weight shape is [C, 1, K] stored as a contiguous + * buffer: t_weight[c * kernel_size + k]. + */ +void main() { + const int out_l = int(gl_GlobalInvocationID.x); + const int out_c = int(gl_GlobalInvocationID.y); + const int n = int(gl_GlobalInvocationID.z); + + if (out_l >= out_limits.x || out_c >= out_limits.y || n >= out_limits.z) { + return; + } + + // out_l is a texel index; each texel holds 4 width positions. + // The 4 logical output positions are base_out_l + {0, 1, 2, 3}. + const int base_out_l = out_l * 4; + const int w_base = out_c * kernel_size; + + VEC4_T sum = VEC4_T(0); + + for (int k = 0; k < kernel_size; k++) { + const T w = t_weight[w_base + k]; + + // For each of the 4 packed width lanes, compute the input position and + // accumulate. All 4 lanes share the same weight scalar w. + const ivec4 in_l = ivec4( + base_out_l * stride - padding + k * dilation, + (base_out_l + 1) * stride - padding + k * dilation, + (base_out_l + 2) * stride - padding + k * dilation, + (base_out_l + 3) * stride - padding + k * dilation); + + // Each lane reads from a potentially different input texel.
+ const ivec4 in_texel_idx = in_l >> 2; // divide by 4 + const ivec4 in_lane = in_l & 3; // mod 4 + + for (int lane = 0; lane < 4; lane++) { + if (in_l[lane] >= 0 && in_l[lane] < in_length) { + const VEC4_T in_texel = + texelFetch(t_in, ivec3(in_texel_idx[lane], out_c, n), 0); + sum[lane] += w * in_texel[in_lane[lane]]; + } + } + } + + sum += VEC4_T(T(t_bias[out_c])); + imageStore( + t_out, + ivec3(out_l, out_c, n), + op(sum, VEC4_T(out_min), VEC4_T(out_max))); +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv1d_dw_texture.yaml b/backends/vulkan/runtime/graph/ops/glsl/conv1d_dw_texture.yaml new file mode 100644 index 00000000000..36287d5a35f --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/conv1d_dw_texture.yaml @@ -0,0 +1,18 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +conv1d_dw_texture: + parameter_names_with_default_values: + OPERATOR: X + DTYPE: float + generate_variant_forall: + DTYPE: + - VALUE: half + - VALUE: float + shader_variants: + - NAME: conv1d_dw_texture + - NAME: conv1d_dw_texture_clamp + OPERATOR: clamp(X, A, B) diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv1d_pw.glsl b/backends/vulkan/runtime/graph/ops/glsl/conv1d_pw.glsl new file mode 100644 index 00000000000..3c5c8991b16 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/conv1d_pw.glsl @@ -0,0 +1,76 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#version 450 core + +${define_required_extensions("buffer", DTYPE)} + +#define PRECISION ${PRECISION} + +#define T ${buffer_scalar_type(DTYPE)} + +#define op(X, A, B) ${OPERATOR} + +layout(std430) buffer; + +${layout_declare_tensor(B, "w", "t_out", DTYPE, "buffer")} +${layout_declare_tensor(B, "r", "t_in", DTYPE, "buffer")} +${layout_declare_tensor(B, "r", "t_weight", DTYPE, "buffer")} +${layout_declare_tensor(B, "r", "t_bias", DTYPE, "buffer")} + +${layout_declare_ubo(B, "ivec4", "out_sizes")} +${layout_declare_ubo(B, "ivec4", "out_strides")} +${layout_declare_ubo(B, "ivec4", "in_sizes")} +${layout_declare_ubo(B, "ivec4", "in_strides")} +${layout_declare_ubo(B, "ivec4", "weight_strides")} +${layout_declare_ubo(B, "int", "kernel_size", "int", "stride", "int", "padding", "int", "dilation", "int", "in_group_size", "int", "out_group_size")} +${layout_declare_ubo(B, "float", "out_min", "float", "out_max")} + +#include "indexing_utils.h" + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +/* + * Computes a pointwise (kernel_size=1) 1D convolution over width-packed buffer + * tensors. Each shader invocation computes one output element at (n, out_c, out_l). + * + * Since kernel_size=1 there is no spatial loop; only the channel reduction loop. 
+ */ +void main() { + const int out_l = int(gl_GlobalInvocationID.x); + const int out_c = int(gl_GlobalInvocationID.y); + const int n = int(gl_GlobalInvocationID.z); + + // WHCN sizes for [N, C, L]: (L, C, N, 1) -> sizes.y=C, sizes.z=N + if (out_l >= out_sizes.x || out_c >= out_sizes.y || n >= out_sizes.z) { + return; + } + + const int c_start = (out_c / out_group_size) * in_group_size; + + // Pointwise: kernel_size=1, k=0 always + const int in_l = out_l * stride - padding; + + T sum = T(0); + if (in_l >= 0 && in_l < in_sizes.x) { + for (int ic = 0; ic < in_group_size; ic++) { + const int in_c = c_start + ic; + // WHCN tidx for (n, in_c, in_l) in [N, C, L] tensor: (in_l, in_c, n, 0) + const int in_idx = tidx_to_bufi(ivec4(in_l, in_c, n, 0), in_strides); + // WHCN tidx for weight (0, ic, out_c) in [C_out, C_in/g, 1]: (0, ic, out_c, 0) + const int w_idx = tidx_to_bufi(ivec4(0, ic, out_c, 0), weight_strides); + sum += t_in[in_idx] * t_weight[w_idx]; + } + } + + sum += T(t_bias[out_c]); + + // WHCN tidx for (n, out_c, out_l): (out_l, out_c, n, 0) + const int out_idx = tidx_to_bufi(ivec4(out_l, out_c, n, 0), out_strides); + t_out[out_idx] = op(sum, T(out_min), T(out_max)); +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv1d_pw.yaml b/backends/vulkan/runtime/graph/ops/glsl/conv1d_pw.yaml new file mode 100644 index 00000000000..2533b17b2b1 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/conv1d_pw.yaml @@ -0,0 +1,18 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +conv1d_pw: + parameter_names_with_default_values: + OPERATOR: X + DTYPE: float + generate_variant_forall: + DTYPE: + - VALUE: half + - VALUE: float + shader_variants: + - NAME: conv1d_pw + - NAME: conv1d_pw_clamp + OPERATOR: clamp(X, A, B) diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv1d_pw_texture.glsl b/backends/vulkan/runtime/graph/ops/glsl/conv1d_pw_texture.glsl new file mode 100644 index 00000000000..95c542946c0 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/conv1d_pw_texture.glsl @@ -0,0 +1,68 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +${define_required_extensions("texture3d", DTYPE)} +${define_required_extensions("buffer", DTYPE)} + +#define PRECISION ${PRECISION} + +#define VEC4_T ${texel_type(DTYPE)} +#define T ${buffer_scalar_type(DTYPE)} + +#define op(X, A, B) ${OPERATOR} + +layout(std430) buffer; + +${layout_declare_tensor(B, "w", "t_out", DTYPE, "texture3d")} +${layout_declare_tensor(B, "r", "t_in", DTYPE, "texture3d")} +${layout_declare_tensor(B, "r", "t_weight", DTYPE, "buffer")} +${layout_declare_tensor(B, "r", "t_bias", DTYPE, "buffer")} + +${layout_declare_ubo(B, "ivec3", "out_limits")} +${layout_declare_ubo(B, "int", "in_group_size", "int", "out_group_size")} +${layout_declare_ubo(B, "float", "out_min", "float", "out_max")} + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +/* + * Pointwise (kernel_size=1) 1D convolution for width-packed texture3d tensors. + * + * Each invocation computes one output texel containing 4 adjacent width + * positions. The reduction loops over input channels, multiplying a scalar + * weight by the full 4-wide input texel. + * + * Tensor layout: [N, C, L] stored as width-packed texture3d where + * texture x = L / 4 (packed width), y = C, z = N. 
+ * + * Weight is [C_out, C_in/groups, 1] stored as a contiguous buffer with + * row-major layout: weight[c_out, ic] = t_weight[c_out * in_group_size + ic]. + */ +void main() { + const int out_l = int(gl_GlobalInvocationID.x); + const int out_c = int(gl_GlobalInvocationID.y); + const int n = int(gl_GlobalInvocationID.z); + + if (out_l >= out_limits.x || out_c >= out_limits.y || n >= out_limits.z) { + return; + } + + const int c_start = (out_c / out_group_size) * in_group_size; + const int w_base = out_c * in_group_size; + + VEC4_T sum = VEC4_T(0); + for (int ic = 0; ic < in_group_size; ic++) { + const VEC4_T in_texel = texelFetch(t_in, ivec3(out_l, c_start + ic, n), 0); + const T w = t_weight[w_base + ic]; + sum = fma(VEC4_T(w), in_texel, sum); + } + + sum += VEC4_T(T(t_bias[out_c])); + imageStore(t_out, ivec3(out_l, out_c, n), op(sum, VEC4_T(out_min), VEC4_T(out_max))); +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv1d_pw_texture.yaml b/backends/vulkan/runtime/graph/ops/glsl/conv1d_pw_texture.yaml new file mode 100644 index 00000000000..67c8dddd623 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/conv1d_pw_texture.yaml @@ -0,0 +1,18 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +conv1d_pw_texture: + parameter_names_with_default_values: + OPERATOR: X + DTYPE: float + generate_variant_forall: + DTYPE: + - VALUE: half + - VALUE: float + shader_variants: + - NAME: conv1d_pw_texture + - NAME: conv1d_pw_texture_clamp + OPERATOR: clamp(X, A, B) diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv1d_texture.glsl b/backends/vulkan/runtime/graph/ops/glsl/conv1d_texture.glsl new file mode 100644 index 00000000000..d370561f9c8 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/conv1d_texture.glsl @@ -0,0 +1,83 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. 
+ * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +#define PRECISION ${PRECISION} + +#define VEC4_T ${texel_type(DTYPE)} + +#define op(X, A, B) ${OPERATOR} + +layout(std430) buffer; + +${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE)} +${layout_declare_tensor(B, "r", "t_in", DTYPE, STORAGE)} +${layout_declare_tensor(B, "r", "kernel_in", DTYPE, STORAGE)} +${layout_declare_tensor(B, "r", "bias_in", DTYPE, STORAGE)} + +${layout_declare_ubo(B, "ivec3", "out_limits")} +${layout_declare_ubo(B, "ivec4", "in_sizes")} + +${layout_declare_ubo(B,"int", "kernel_size", "int", "stride", "int", "padding", "int", "dilation", "int", "in_group_size", "int", "out_group_size")} + +${layout_declare_ubo(B, "float", "out_min", "float", "out_max")} + +#include "indexing_utils.h" + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +${layout_declare_spec_const(C, "int", "out_layout", "DEFAULT_LAYOUT")} +const lowp ivec4 out_axis_map = unhash_axis_map(out_layout); + +${layout_declare_spec_const(C, "int", "in_layout", "DEFAULT_LAYOUT")} +const lowp ivec4 in_axis_map = unhash_axis_map(in_layout); + +${layout_declare_spec_const(C, "int", "kernel_layout", "DEFAULT_LAYOUT")} +const lowp ivec4 kernel_axis_map = unhash_axis_map(kernel_layout); + +${layout_declare_spec_const(C, "int", "bias_layout", "DEFAULT_LAYOUT")} +const lowp ivec4 bias_axis_map = unhash_axis_map(bias_layout); + +void main() { + const ivec3 lpos = ivec3(gl_GlobalInvocationID); + + if (any(greaterThanEqual(lpos, out_limits))) { + return; + } + + const int out_c = lpos.y; + + const int c_start = (out_c / out_group_size) * in_group_size; + const int c_end = c_start + in_group_size; + + const int out_l = lpos.x; + + const int N = lpos.z; + + const int in_l = out_l * stride - padding; + VEC4_T sum = VEC4_T(0); + + const int out_c_packed_index = out_c >> 2; + 
const int out_c_packed_lane = out_c & 0x3; + + for (int in_c = c_start; in_c < c_end; ++in_c) { + for (int k = 0; k < kernel_size; k++) { + const ivec3 w_lpos = ivec3(k, in_c % in_group_size, out_c_packed_index); + const VEC4_T weight_texel = load_texel_lpos(kernel_in, w_lpos, kernel_axis_map); + VEC4_T weight = VEC4_T(weight_texel[out_c_packed_lane]); + + const ivec3 in_pos = lpos_to_pos(ivec3(in_l + k * dilation, in_c, N), in_axis_map); + sum = fma(weight, load_texel(t_in, in_pos), sum); + } + } + + const VEC4_T bias = load_texel_lpos(bias_in, ivec3(out_c_packed_index, 0, 0), bias_axis_map); + const ivec3 out_lpos = ivec3(out_l, out_c, N); + write_texel_lpos(t_out, out_lpos, op(sum + bias[out_c_packed_lane], out_min, out_max), out_axis_map); +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/conv1d_texture.yaml b/backends/vulkan/runtime/graph/ops/glsl/conv1d_texture.yaml new file mode 100644 index 00000000000..b97b1dde63d --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/conv1d_texture.yaml @@ -0,0 +1,19 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +conv1d_texture: + parameter_names_with_default_values: + OPERATOR: X + DTYPE: float + STORAGE: texture3d + generate_variant_forall: + DTYPE: + - VALUE: half + - VALUE: float + shader_variants: + - NAME: conv1d_texture + - NAME: conv1d_texture_clamp + OPERATOR: clamp(X, A, B) diff --git a/backends/vulkan/runtime/graph/ops/impl/Conv1d.cpp b/backends/vulkan/runtime/graph/ops/impl/Conv1d.cpp new file mode 100644 index 00000000000..21188d9e88f --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/impl/Conv1d.cpp @@ -0,0 +1,302 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. 
+ * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include + +#include +#include + +#include + +#include +#include + +#include + +namespace vkcompute { + +enum class Conv1dMethod : uint8_t { + Depthwise, + Pointwise, + General, +}; + +void resize_conv1d_buf_node( + ComputeGraph* graph, + const std::vector& args, + const std::vector& extra_args) { + const ValueRef out = args.at(0).refs.at(0); + const ValueRef self = args.at(1).refs.at(0); + TensorRefPtr weight_ref = graph->get_tref(extra_args.at(0)); + + const int64_t stride_size = graph->get_int_list(extra_args.at(1))->at(0); + const int64_t padding_size = graph->get_int_list(extra_args.at(2))->at(0); + const int64_t dilation_size = graph->get_int_list(extra_args.at(3))->at(0); + + const std::vector& weight_sizes = weight_ref->sizes; + const std::vector in_sizes = graph->sizes_of(self); + const size_t ndim = in_sizes.size(); + std::vector new_out_sizes(ndim); + + const int64_t kernel_size = weight_sizes.at(2); + const int64_t in_length = in_sizes.at(2); + + new_out_sizes.at(0) = in_sizes.at(0); + new_out_sizes.at(1) = weight_sizes.at(0); + new_out_sizes.at(2) = calc_out_size( + in_length, kernel_size, stride_size, padding_size, dilation_size, false); + + graph->virtual_resize(out, new_out_sizes); +} + +ValueRef prepack_conv1d_bias( + ComputeGraph& graph, + const ValueRef vref, + const ValueRef weight_data, + const int64_t out_channels) { + ValueRef v = graph.add_tensor( + {out_channels}, + graph.dtype_of(weight_data), + utils::kBuffer, + utils::kWidthPacked); + + // Use staging dtype from weight (vref may be None for bias=None). + vkapi::ShaderInfo shader = get_nchw_to_tensor_shader( + graph, v, graph.get_staging_dtype_for(weight_data)); + + // Must match add_prepack_standard_node's bindings for buffer-backed tensors. 
+ graph.prepack_nodes().emplace_back(new PrepackNode( + graph, + shader, + graph.create_global_wg_size(v), + graph.create_local_wg_size(v), + vref, + v, + // Parameter Buffers + {graph.buffer_meta_ubo(v)}, + // Specialization Constants: layout hash + transpose_hw=0 + {graph.hashed_layout_of(v), 0}, + // Push Constants: sizes, strides, numel + {graph.sizes_pc_of(v), graph.strides_pc_of(v), graph.numel_pc_of(v)})); + + return v; +} + +utils::uvec3 conv1d_buf_global_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const std::vector& args, + const std::vector& resize_args) { + (void)shader; + (void)resize_args; + const ValueRef out = args.at(0).refs.at(0); + return { + graph->size_at(-1, out), // L_out + graph->size_at(-2, out), // C_out + graph->size_at(-3, out), // N + }; +} + +static void add_conv1d_general_buf_node( + ComputeGraph& graph, + const ValueRef in, + const ValueRef arg_weight, + const ValueRef arg_bias, + const ValueRef out, + const Kernel1dParams& kernel_params, + const float out_min_val, + const float out_max_val, + const bool clamp_out, + const ValueRef weight_data, + const ValueRef stride, + const ValueRef padding, + const ValueRef dilation) { + struct OutputParams { + float out_min; + float out_max; + } out_params{out_min_val, out_max_val}; + + std::string kernel_name = clamp_out ? 
"conv1d_clamp" : "conv1d"; + add_dtype_suffix(kernel_name, graph.dtype_of(out)); + + graph.execute_nodes().emplace_back(new DynamicDispatchNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + conv1d_buf_global_wg_size, + default_pick_local_wg_size, + // Inputs and Outputs + {{out, vkapi::kWrite}, {{in, arg_weight, arg_bias}, vkapi::kRead}}, + // Shader params buffers (UBOs) - must match conv1d.glsl binding order + { + graph.sizes_ubo(out), + graph.strides_ubo(out), + graph.sizes_ubo(in), + graph.strides_ubo(in), + graph.strides_ubo(arg_weight), + graph.create_params_buffer(kernel_params), + graph.create_params_buffer(out_params), + }, + // Push Constants + {}, + // Specialization Constants + {}, + // Resize Args + {weight_data, stride, padding, dilation}, + // Resizing Logic + resize_conv1d_buf_node)); +} + +static Conv1dMethod get_conv1d_method( + const std::vector& weight_sizes, + const int64_t groups) { + if (weight_sizes.at(0) == groups && weight_sizes.at(1) == 1) { + return Conv1dMethod::Depthwise; + } + if (weight_sizes.at(2) == 1) { + return Conv1dMethod::Pointwise; + } + return Conv1dMethod::General; +} + +void add_conv1d_dw_texture_entry( + ComputeGraph& graph, + const ValueRef in, + const ValueRef weight_data, + const ValueRef bias, + const ValueRef stride, + const ValueRef padding, + const ValueRef dilation, + const ValueRef groups, + const float out_min_val, + const float out_max_val, + const ValueRef out, + const bool clamp_out) { + const int64_t groups_val = graph.get_int(groups); + const auto weight_sizes = graph.sizes_of(weight_data); + + ValueRef arg_weight = + prepack_standard(graph, weight_data, utils::kBuffer, utils::kWidthPacked); + + const int64_t out_channels = weight_sizes.at(0); + ValueRef arg_bias = + prepack_conv1d_bias(graph, bias, weight_data, out_channels); + + const Kernel1dParams kernel_params = { + static_cast(weight_sizes.at(2)), // kernel_size + static_cast(graph.get_int_list(stride)->at(0)), + 
static_cast(graph.get_int_list(padding)->at(0)), + static_cast(graph.get_int_list(dilation)->at(0)), + static_cast(weight_sizes.at(1)), // in_group_size = C_in/groups (=1) + static_cast(out_channels / groups_val), // out_group_size (=1) + }; + + add_conv1d_dw_texture_node( + graph, + in, + arg_weight, + arg_bias, + out, + kernel_params, + out_min_val, + out_max_val, + clamp_out, + weight_data, + stride, + padding, + dilation); +} + +void add_conv1d_buf_node( + ComputeGraph& graph, + const ValueRef in, + const ValueRef weight_data, + const ValueRef bias, + const ValueRef stride, + const ValueRef padding, + const ValueRef dilation, + const ValueRef groups, + const float out_min_val, + const float out_max_val, + const ValueRef out, + const bool clamp_out) { + const int64_t groups_val = graph.get_int(groups); + const auto weight_sizes = graph.sizes_of(weight_data); + + ValueRef arg_weight = + prepack_standard(graph, weight_data, utils::kBuffer, utils::kWidthPacked); + + const int64_t out_channels = weight_sizes.at(0); + ValueRef arg_bias = + prepack_conv1d_bias(graph, bias, weight_data, out_channels); + + const Kernel1dParams kernel_params = { + static_cast(weight_sizes.at(2)), // kernel_size + static_cast(graph.get_int_list(stride)->at(0)), + static_cast(graph.get_int_list(padding)->at(0)), + static_cast(graph.get_int_list(dilation)->at(0)), + static_cast(weight_sizes.at(1)), // in_group_size = C_in/groups + static_cast(out_channels / groups_val), // out_group_size + }; + + const Conv1dMethod method = get_conv1d_method(weight_sizes, groups_val); + + switch (method) { + case Conv1dMethod::Depthwise: + add_conv1d_dw_buf_node( + graph, + in, + arg_weight, + arg_bias, + out, + kernel_params, + out_min_val, + out_max_val, + clamp_out, + weight_data, + stride, + padding, + dilation); + break; + case Conv1dMethod::Pointwise: + add_conv1d_pw_buf_node( + graph, + in, + arg_weight, + arg_bias, + out, + kernel_params, + out_min_val, + out_max_val, + clamp_out, + weight_data, + 
stride, + padding, + dilation); + break; + case Conv1dMethod::General: + add_conv1d_general_buf_node( + graph, + in, + arg_weight, + arg_bias, + out, + kernel_params, + out_min_val, + out_max_val, + clamp_out, + weight_data, + stride, + padding, + dilation); + break; + } +} + +} // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/impl/Conv1d.h b/backends/vulkan/runtime/graph/ops/impl/Conv1d.h new file mode 100644 index 00000000000..e5034c1364d --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/impl/Conv1d.h @@ -0,0 +1,139 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include + +namespace vkcompute { + +// Resize function shared by all buffer conv1d dispatch nodes. +void resize_conv1d_buf_node( + ComputeGraph* graph, + const std::vector& args, + const std::vector& extra_args); + +// Global workgroup size function shared by all buffer conv1d dispatch nodes. +// Returns (L_out, C_out, N) from the output tensor dimensions. +utils::uvec3 conv1d_buf_global_wg_size( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const std::vector& args, + const std::vector& resize_args); + +// Prepack a 1D bias tensor into a width-packed buffer. Uses weight_data's dtype +// and staging dtype so that bias=None (where vref has no dtype) is handled. +ValueRef prepack_conv1d_bias( + ComputeGraph& graph, + const ValueRef vref, + const ValueRef weight_data, + const int64_t out_channels); + +// Dispatch a depthwise 1D convolution node using width-packed buffer tensors. +// arg_weight and arg_bias must already be prepacked. 
+void add_conv1d_dw_buf_node( + ComputeGraph& graph, + const ValueRef in, + const ValueRef arg_weight, + const ValueRef arg_bias, + const ValueRef out, + const Kernel1dParams& kernel_params, + const float out_min_val, + const float out_max_val, + const bool clamp_out, + const ValueRef weight_data, + const ValueRef stride, + const ValueRef padding, + const ValueRef dilation); + +// Dispatch a depthwise 1D convolution node using width-packed TEXTURE_3D +// tensors. arg_weight (buffer) and arg_bias (buffer) must already be prepacked. +void add_conv1d_dw_texture_node( + ComputeGraph& graph, + const ValueRef in, + const ValueRef arg_weight, + const ValueRef arg_bias, + const ValueRef out, + const Kernel1dParams& kernel_params, + const float out_min_val, + const float out_max_val, + const bool clamp_out, + const ValueRef weight_data, + const ValueRef stride, + const ValueRef padding, + const ValueRef dilation); + +// Dispatch a pointwise (kernel_size=1) 1D convolution node using width-packed +// buffer tensors. arg_weight and arg_bias must already be prepacked. +void add_conv1d_pw_buf_node( + ComputeGraph& graph, + const ValueRef in, + const ValueRef arg_weight, + const ValueRef arg_bias, + const ValueRef out, + const Kernel1dParams& kernel_params, + const float out_min_val, + const float out_max_val, + const bool clamp_out, + const ValueRef weight_data, + const ValueRef stride, + const ValueRef padding, + const ValueRef dilation); + +// Top-level entry point. Determines whether the convolution is depthwise, +// pointwise, or general, prepacks weight/bias, and dispatches accordingly. +// Requires that `in` is a width-packed buffer tensor. 
+void add_conv1d_buf_node( + ComputeGraph& graph, + const ValueRef in, + const ValueRef weight_data, + const ValueRef bias, + const ValueRef stride, + const ValueRef padding, + const ValueRef dilation, + const ValueRef groups, + const float out_min_val, + const float out_max_val, + const ValueRef out, + const bool clamp_out); + +// Entry point for depthwise 1D convolution using width-packed texture3d +// input/output with buffer weight/bias. Handles prepacking internally. +void add_conv1d_dw_texture_entry( + ComputeGraph& graph, + const ValueRef in, + const ValueRef weight_data, + const ValueRef bias, + const ValueRef stride, + const ValueRef padding, + const ValueRef dilation, + const ValueRef groups, + const float out_min_val, + const float out_max_val, + const ValueRef out, + const bool clamp_out); + +// Entry point for pointwise (kernel_size=1) 1D convolution using width-packed +// texture3d input/output with buffer weight/bias. Handles prepacking +// internally. +void add_conv1d_pw_texture_node( + ComputeGraph& graph, + const ValueRef in, + const ValueRef weight_data, + const ValueRef bias, + const ValueRef stride, + const ValueRef padding, + const ValueRef dilation, + const ValueRef groups, + const float out_min_val, + const float out_max_val, + const ValueRef out, + const bool clamp_out); + +} // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/impl/Conv1dDW.cpp b/backends/vulkan/runtime/graph/ops/impl/Conv1dDW.cpp new file mode 100644 index 00000000000..cd123867002 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/impl/Conv1dDW.cpp @@ -0,0 +1,142 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include + +#include + +#include + +namespace vkcompute { + +void add_conv1d_dw_buf_node( + ComputeGraph& graph, + const ValueRef in, + const ValueRef arg_weight, + const ValueRef arg_bias, + const ValueRef out, + const Kernel1dParams& kernel_params, + const float out_min_val, + const float out_max_val, + const bool clamp_out, + const ValueRef weight_data, + const ValueRef stride, + const ValueRef padding, + const ValueRef dilation) { + struct OutputParams { + float out_min; + float out_max; + } out_params{out_min_val, out_max_val}; + + std::string kernel_name = clamp_out ? "conv1d_dw_clamp" : "conv1d_dw"; + add_dtype_suffix(kernel_name, graph.dtype_of(out)); + + graph.execute_nodes().emplace_back(new DynamicDispatchNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + conv1d_buf_global_wg_size, + default_pick_local_wg_size, + // Inputs and Outputs + {{out, vkapi::kWrite}, {{in, arg_weight, arg_bias}, vkapi::kRead}}, + // Shader params buffers (UBOs) - must match conv1d_dw.glsl binding order + { + graph.sizes_ubo(out), + graph.strides_ubo(out), + graph.sizes_ubo(in), + graph.strides_ubo(in), + graph.strides_ubo(arg_weight), + graph.create_params_buffer(kernel_params), + graph.create_params_buffer(out_params), + }, + // Push Constants + {}, + // Specialization Constants + {}, + // Resize Args + {weight_data, stride, padding, dilation}, + // Resizing Logic + resize_conv1d_buf_node)); +} + +void add_conv1d_dw_texture_node( + ComputeGraph& graph, + const ValueRef in, + const ValueRef arg_weight, + const ValueRef arg_bias, + const ValueRef out, + const Kernel1dParams& kernel_params, + const float out_min_val, + const float out_max_val, + const bool clamp_out, + const ValueRef weight_data, + const ValueRef stride, + const ValueRef padding, + const ValueRef dilation) { + struct ConvDWParams { + int32_t kernel_size; + int32_t stride; + int32_t padding; + int32_t dilation; + } conv_dw_params{ + kernel_params.kernel_size, + kernel_params.stride, + 
kernel_params.padding, + kernel_params.dilation, + }; + + struct InLengthParams { + int32_t in_length; + } in_length_params{ + graph.size_at(-1, in), + }; + + struct OutputParams { + float out_min; + float out_max; + } out_params{out_min_val, out_max_val}; + + std::string kernel_name = + clamp_out ? "conv1d_dw_texture_clamp" : "conv1d_dw_texture"; + add_dtype_suffix(kernel_name, graph.dtype_of(out)); + + graph.execute_nodes().emplace_back(new DynamicDispatchNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + // Global workgroup size: (L_out_texels, C_out, N) + [](ComputeGraph* g, + const vkapi::ShaderInfo&, + const std::vector& a, + const std::vector&) -> utils::uvec3 { + const ValueRef o = a.at(0).refs.at(0); + const auto limits = g->logical_limits_of(o); + return { + static_cast(limits[0]), + static_cast(limits[1]), + static_cast(limits[2])}; + }, + default_pick_local_wg_size, + // Inputs and Outputs + {{out, vkapi::kWrite}, {{in, arg_weight, arg_bias}, vkapi::kRead}}, + // UBOs - must match conv1d_dw_texture.glsl binding order + { + graph.logical_limits_ubo(out), + graph.create_params_buffer(conv_dw_params), + graph.create_params_buffer(in_length_params), + graph.create_params_buffer(out_params), + }, + // Push Constants + {}, + // Specialization Constants + {}, + // Resize Args + {weight_data, stride, padding, dilation}, + // Resizing Logic + resize_conv1d_buf_node)); +} + +} // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/impl/Conv1dPW.cpp b/backends/vulkan/runtime/graph/ops/impl/Conv1dPW.cpp new file mode 100644 index 00000000000..28971707bfa --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/impl/Conv1dPW.cpp @@ -0,0 +1,169 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include + +#include +#include + +#include + +#include + +namespace vkcompute { + +void add_conv1d_pw_buf_node( + ComputeGraph& graph, + const ValueRef in, + const ValueRef arg_weight, + const ValueRef arg_bias, + const ValueRef out, + const Kernel1dParams& kernel_params, + const float out_min_val, + const float out_max_val, + const bool clamp_out, + const ValueRef weight_data, + const ValueRef stride, + const ValueRef padding, + const ValueRef dilation) { + struct OutputParams { + float out_min; + float out_max; + } out_params{out_min_val, out_max_val}; + + std::string kernel_name = clamp_out ? "conv1d_pw_clamp" : "conv1d_pw"; + add_dtype_suffix(kernel_name, graph.dtype_of(out)); + + graph.execute_nodes().emplace_back(new DynamicDispatchNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + conv1d_buf_global_wg_size, + default_pick_local_wg_size, + // Inputs and Outputs + {{out, vkapi::kWrite}, {{in, arg_weight, arg_bias}, vkapi::kRead}}, + // Shader params buffers (UBOs) - must match conv1d_pw.glsl binding order + { + graph.sizes_ubo(out), + graph.strides_ubo(out), + graph.sizes_ubo(in), + graph.strides_ubo(in), + graph.strides_ubo(arg_weight), + graph.create_params_buffer(kernel_params), + graph.create_params_buffer(out_params), + }, + // Push Constants + {}, + // Specialization Constants + {}, + // Resize Args + {weight_data, stride, padding, dilation}, + // Resizing Logic + resize_conv1d_buf_node)); +} + +static ValueRef prepack_conv1d_pw_weight_buffer( + ComputeGraph& graph, + const ValueRef vref) { + const auto sizes = graph.sizes_of(vref); + // Weight [C_out, C_in/g, 1] -> flatten to [C_out * C_in/g] buffer + const int64_t numel = sizes.at(0) * sizes.at(1); + ValueRef v = graph.add_tensor( + {numel}, graph.dtype_of(vref), utils::kBuffer, utils::kWidthPacked); + + vkapi::ShaderInfo shader = + get_nchw_to_tensor_shader(graph, v, graph.get_staging_dtype_for(vref)); + + graph.prepack_nodes().emplace_back(new PrepackNode( + graph, + shader, + 
graph.create_global_wg_size(v), + graph.create_local_wg_size(v), + vref, + v, + {graph.buffer_meta_ubo(v)}, + {graph.hashed_layout_of(v), 0}, + {graph.sizes_pc_of(v), graph.strides_pc_of(v), graph.numel_pc_of(v)})); + + return v; +} + +void add_conv1d_pw_texture_node( + ComputeGraph& graph, + const ValueRef in, + const ValueRef weight_data, + const ValueRef bias, + const ValueRef stride, + const ValueRef padding, + const ValueRef dilation, + const ValueRef groups, + const float out_min_val, + const float out_max_val, + const ValueRef out, + const bool clamp_out) { + const int64_t groups_val = graph.get_int(groups); + const auto weight_sizes = graph.sizes_of(weight_data); + const int64_t out_channels = weight_sizes.at(0); + const int64_t in_group_size = weight_sizes.at(1); + const int64_t out_group_size = out_channels / groups_val; + + ValueRef arg_weight = prepack_conv1d_pw_weight_buffer(graph, weight_data); + ValueRef arg_bias = + prepack_conv1d_bias(graph, bias, weight_data, out_channels); + + struct ConvParams { + int32_t in_group_size; + int32_t out_group_size; + } conv_params{ + static_cast(in_group_size), + static_cast(out_group_size), + }; + + struct OutputParams { + float out_min; + float out_max; + } out_params{out_min_val, out_max_val}; + + std::string kernel_name = + clamp_out ? 
"conv1d_pw_texture_clamp" : "conv1d_pw_texture"; + add_dtype_suffix(kernel_name, graph.dtype_of(out)); + + graph.execute_nodes().emplace_back(new DynamicDispatchNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + // Global workgroup size: (L_out_texels, C_out, N) + [](ComputeGraph* g, + const vkapi::ShaderInfo&, + const std::vector& a, + const std::vector&) -> utils::uvec3 { + const ValueRef o = a.at(0).refs.at(0); + const auto limits = g->logical_limits_of(o); + return { + static_cast(limits[0]), + static_cast(limits[1]), + static_cast(limits[2])}; + }, + default_pick_local_wg_size, + // Inputs and Outputs + {{out, vkapi::kWrite}, {{in, arg_weight, arg_bias}, vkapi::kRead}}, + // UBOs - must match conv1d_pw_texture.glsl binding order + { + graph.logical_limits_ubo(out), + graph.create_params_buffer(conv_params), + graph.create_params_buffer(out_params), + }, + // Push Constants + {}, + // Specialization Constants + {}, + // Resize Args + {weight_data, stride, padding, dilation}, + // Resizing Logic + resize_conv1d_buf_node)); +} + +} // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp b/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp index 2bf3f8f726d..e62ffb52551 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Convolution.cpp @@ -9,6 +9,7 @@ #include #include +#include #include #include @@ -442,12 +443,13 @@ utils::uvec3 conv1d_global_wg_size( (void)resize_args; const ValueRef out = args.at(0).refs.at(0); - return {// out length - graph->size_at(-1, out), - // out channels - static_cast(graph->size_at(-2, out)), - // out batches - utils::div_up_4(graph->size_at(-3, out))}; + return { + // out length + graph->size_at(-1, out), + // out channels + static_cast(graph->size_at(-2, out)), + // out batches + utils::div_up_4(graph->size_at(-3, out))}; } void add_conv2d_node( @@ -709,7 +711,7 @@ void add_conv1d_node( const OutputParams out_params = {out_min_val, 
out_max_val}; - std::string kernel_name("conv1d"); + std::string kernel_name("conv1d_texture"); if (clamp_out) { kernel_name += "_clamp"; } @@ -783,36 +785,142 @@ void conv(ComputeGraph& graph, const std::vector& args) { true); } } else { - if (args.size() == 10) { - // ordinary conv1d - return add_conv1d_node( - graph, - args[0], - args[1], - args[2], - args[3], - args[4], - args[5], - args[8], - /*out_min = */ kDummyValueRef, - /*out_max = */ kDummyValueRef, - args[9], - false); + const bool use_buf_path = graph.is_buffer_storage(args[0]) && + graph.packed_dim_of(args[0]) == WHCN::kWidthDim; + // Width-packed texture pointwise conv1d: avoid texture->buffer transition + const bool use_texture_pw_path = !use_buf_path && + !graph.is_buffer_storage(args[0]) && + graph.packed_dim_of(args[0]) == WHCN::kWidthDim && + graph.sizes_of(args[1]).at(2) == 1; + // Width-packed texture depthwise conv1d + const auto weight_sizes_1d = graph.sizes_of(args[1]); + const int64_t groups_1d = graph.get_int(args[8]); + const bool use_texture_dw_path = !use_buf_path && !use_texture_pw_path && + !graph.is_buffer_storage(args[0]) && + graph.packed_dim_of(args[0]) == WHCN::kWidthDim && + weight_sizes_1d.at(0) == groups_1d && weight_sizes_1d.at(1) == 1; + if (use_texture_dw_path) { + if (args.size() == 10) { + return add_conv1d_dw_texture_entry( + graph, + args[0], + args[1], + args[2], + args[3], + args[4], + args[5], + args[8], + /*out_min_val = */ 0.0f, + /*out_max_val = */ 0.0f, + args[9], + false); + } else { + return add_conv1d_dw_texture_entry( + graph, + args[0], + args[1], + args[2], + args[3], + args[4], + args[5], + args[8], + graph.extract_scalar(args[9]), + graph.extract_scalar(args[10]), + args[11], + true); + } + } else if (use_texture_pw_path) { + if (args.size() == 10) { + return add_conv1d_pw_texture_node( + graph, + args[0], + args[1], + args[2], + args[3], + args[4], + args[5], + args[8], + /*out_min_val = */ 0.0f, + /*out_max_val = */ 0.0f, + args[9], + false); + } else 
{ + return add_conv1d_pw_texture_node( + graph, + args[0], + args[1], + args[2], + args[3], + args[4], + args[5], + args[8], + graph.extract_scalar(args[9]), + graph.extract_scalar(args[10]), + args[11], + true); + } + } else if (use_buf_path) { + if (args.size() == 10) { + return add_conv1d_buf_node( + graph, + args[0], + args[1], + args[2], + args[3], + args[4], + args[5], + args[8], + /*out_min_val = */ 0.0f, + /*out_max_val = */ 0.0f, + args[9], + false); + } else { + return add_conv1d_buf_node( + graph, + args[0], + args[1], + args[2], + args[3], + args[4], + args[5], + args[8], + graph.extract_scalar(args[9]), + graph.extract_scalar(args[10]), + args[11], + true); + } } else { - // conv1d with clamp - return add_conv1d_node( - graph, - args[0], - args[1], - args[2], - args[3], - args[4], - args[5], - args[8], - args[9], - args[10], - args[11], - true); + if (args.size() == 10) { + // ordinary conv1d (texture path) + return add_conv1d_node( + graph, + args[0], + args[1], + args[2], + args[3], + args[4], + args[5], + args[8], + /*out_min = */ kDummyValueRef, + /*out_max = */ kDummyValueRef, + args[9], + false); + } else { + // conv1d with clamp (texture path) + return add_conv1d_node( + graph, + args[0], + args[1], + args[2], + args[3], + args[4], + args[5], + args[8], + args[9], + args[10], + args[11], + true); + } } } } diff --git a/backends/vulkan/test/op_tests/cases.py b/backends/vulkan/test/op_tests/cases.py index 7a18a1282ec..b632bfffbc8 100644 --- a/backends/vulkan/test/op_tests/cases.py +++ b/backends/vulkan/test/op_tests/cases.py @@ -704,7 +704,118 @@ def get_conv_inputs(): "utils::kChannelsPacked", ] test_suite_dw.test_name_suffix = "dw" - return [test_suite, test_suite_pw, test_suite_dw] + + # Extract 1D conv cases (3D input tensors) from test_cases for buffer path + test_cases_1d = [tc for tc in test_cases if len(tc.self) == 3] + test_cases_1d_dw = [ + Test( + self=(1, 6, 7), + weight=(6, 1, 3), + bias=(6,), + stride=[1], + padding=[0], + 
dilation=[1], + transposed=False, + output_padding=[0], + groups=6, + ), + Test( + self=(2, 20, 30), + weight=(10, 4, 6), + bias=(10,), + stride=[5], + padding=[5], + dilation=[3], + transposed=False, + output_padding=[0], + groups=5, + ), + Test( + self=(1, 9, 11), + weight=(9, 1, 3), + bias=None, + stride=[1], + padding=[0], + dilation=[1], + transposed=False, + output_padding=[0], + groups=9, + ), + Test( + self=(5, 15, 30), + weight=(20, 3, 3), + bias=None, + stride=[3], + padding=[5], + dilation=[7], + transposed=False, + output_padding=[0], + groups=5, + ), + ] + test_cases_1d_pw = [ + Test( + self=(1, 16, 64), + weight=(8, 16, 1), + bias=(8,), + stride=[1], + padding=[0], + dilation=[1], + transposed=False, + output_padding=[0], + groups=1, + ), + Test( + self=(2, 8, 32), + weight=(16, 8, 1), + bias=(16,), + stride=[1], + padding=[0], + dilation=[1], + transposed=False, + output_padding=[0], + groups=1, + ), + ] + + # Buffer path for non-pointwise 1D convolution + test_suite_1d_buf = VkTestSuite(test_cases_1d_dw) + test_suite_1d_buf.layouts = ["utils::kWidthPacked"] + test_suite_1d_buf.storage_types = ["utils::kBuffer"] + test_suite_1d_buf.test_name_suffix = "1d_buf" + + # Buffer path for pointwise 1D convolution + test_suite_1d_pw_buf = VkTestSuite(test_cases_1d_pw) + test_suite_1d_pw_buf.layouts = ["utils::kWidthPacked"] + test_suite_1d_pw_buf.storage_types = ["utils::kBuffer"] + test_suite_1d_pw_buf.test_name_suffix = "1d_pw_buf" + + # Texture path for pointwise 1D convolution + test_suite_1d_pw_tex = VkTestSuite(test_cases_1d_pw) + test_suite_1d_pw_tex.layouts = ["utils::kWidthPacked"] + test_suite_1d_pw_tex.storage_types = ["utils::kTexture3D"] + test_suite_1d_pw_tex.test_name_suffix = "1d_pw_tex" + + # Depthwise-only cases for 1D depthwise convolution + test_cases_1d_dw_only = [ + tc for tc in test_cases_1d_dw if tc.weight[1] == 1 and tc.weight[0] == tc.groups + ] + + # Texture path for depthwise 1D convolution + test_suite_1d_dw_tex = 
VkTestSuite(test_cases_1d_dw_only) + test_suite_1d_dw_tex.layouts = ["utils::kWidthPacked"] + test_suite_1d_dw_tex.storage_types = ["utils::kTexture3D"] + test_suite_1d_dw_tex.test_name_suffix = "1d_dw_tex" + + return [ + test_suite, + test_suite_pw, + test_suite_dw, + test_suite_1d_buf, + test_suite_1d_pw_buf, + test_suite_1d_pw_tex, + test_suite_1d_dw_tex, + ] @register_test_suite("aten.native_layer_norm.default")