diff --git a/kernels/optimized/CMakeLists.txt b/kernels/optimized/CMakeLists.txt index 69dae952255..210a489f38a 100644 --- a/kernels/optimized/CMakeLists.txt +++ b/kernels/optimized/CMakeLists.txt @@ -75,6 +75,22 @@ target_link_libraries( kernels_util_all_deps ) target_compile_options(optimized_kernels PUBLIC ${_common_compile_options}) + +# op_grid_sampler_2d_fp16_hw.cpp uses hardware fp16 NEON intrinsics +# (vcvt_f32_f16 / vld1_f16). Those are part of the ARMv8.2-a+fp16 extension and +# raise SIGILL on chips without it. Scope the `-march` flag to just that +# translation unit. The main op_grid_sampler_2d.cpp (which hosts the runtime +# dispatcher via cpuinfo_has_arm_neon_fp16) and the fp16 software-convert path +# stay on plain ARMv8 so they can run on any chip. +if(CMAKE_SYSTEM_PROCESSOR MATCHES "aarch64|arm64" OR ANDROID_ABI STREQUAL + "arm64-v8a" +) + set_source_files_properties( + ${EXECUTORCH_ROOT}/kernels/optimized/cpu/op_grid_sampler_2d_fp16_hw.cpp + PROPERTIES COMPILE_OPTIONS "-march=armv8.2-a+fp16" + ) +endif() + # Build a library for _optimized_kernels_srcs # # optimized_ops_lib: Register optimized ops kernels into Executorch runtime diff --git a/kernels/optimized/cpu/op_grid_sampler_2d.cpp b/kernels/optimized/cpu/op_grid_sampler_2d.cpp new file mode 100644 index 00000000000..f4738040f39 --- /dev/null +++ b/kernels/optimized/cpu/op_grid_sampler_2d.cpp @@ -0,0 +1,434 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +// Optimized grid_sampler_2d.out for CPU. On aarch64 this is a NEON-vectorized +// implementation for the common (bilinear + zeros padding) case. fp16 inputs +// are promoted to fp32 for weight computation and accumulation and cast back +// on store — this avoids fp16 catastrophic cancellation on `ix_se - ix`-style +// weight subtractions in the portable kernel. +// +// fp16 comes in two flavors to avoid SIGILL on ARMv8 chips without the +// +fp16 extension: +// +// * Hardware path (op_grid_sampler_2d_fp16_hw.cpp) — compiled with +// `-march=armv8.2-a+fp16`. Uses hardware fp16 NEON instructions +// (vld1_f16 / vcvt_f32_f16 / ...). Fast on capable chips; illegal +// instructions on older ones. +// +// * Software path (below) — plain ARMv8 NEON. Converts fp16<->fp32 in +// software via `c10::Half`'s portable conversion. Slower per +// conversion but safe on any ARMv8 CPU. +// +// A runtime cpuinfo_has_arm_neon_fp16() check picks the right one. Non-aarch64 +// targets, and any unsupported interpolation/padding/layout combination, +// delegate to the portable kernel. + +#include + +#ifdef __aarch64__ +#include +#include +#endif + +#include + +#include + +namespace torch { +namespace executor { +namespace native { + +using executorch::aten::ScalarType; +using executorch::aten::Tensor; + +// Portable kernel (same-op fallback). Both libs link into the same binary. +Tensor& grid_sampler_2d_out( + KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& grid, + int64_t interpolation_mode, + int64_t padding_mode, + bool align_corners, + Tensor& out); + +#ifdef __aarch64__ +namespace opt_grid_sampler_2d_internal { +// Declared in op_grid_sampler_2d_fp16_hw.cpp, compiled separately with +// `-march=armv8.2-a+fp16`. Only safe to call when +// cpuinfo_has_arm_neon_fp16() is true. +void grid_sampler_2d_bilinear_fp16_hw( + const void* input, + const void* grid, + void* output, + int N, + int C, + int H_in, + int W_in, + int H_out, + int W_out, + bool align_corners); +} // namespace opt_grid_sampler_2d_internal +#endif + +#ifdef __aarch64__ +namespace { + +// -------------------- fp32 (plain ARMv8 NEON) -------------------- + +inline void bilinear_all_channels_f32( + const float* input_n, + float* output_n, + int C, + int H_in, + int W_in, + int H_out, + int W_out, + int h_out, + int w_out, + float gx, + float gy) { + const int x0 = static_cast(std::floor(gx)); + const int y0 = static_cast(std::floor(gy)); + const int x1 = x0 + 1; + const int y1 = y0 + 1; + const float fx = gx - static_cast(x0); + const float fy = gy - static_cast(y0); + + const bool tl_v = static_cast(x0) < static_cast(W_in) && + static_cast(y0) < static_cast(H_in); + const bool tr_v = static_cast(x1) < static_cast(W_in) && + static_cast(y0) < static_cast(H_in); + const bool bl_v = static_cast(x0) < static_cast(W_in) && + static_cast(y1) < static_cast(H_in); + const bool br_v = static_cast(x1) < static_cast(W_in) && + static_cast(y1) < static_cast(H_in); + + const int off_tl = y0 * W_in + x0; + const int off_tr = y0 * W_in + x1; + const int off_bl = y1 * W_in + x0; + const int off_br = y1 * W_in + x1; + const int spatial_in = H_in * W_in; + const int spatial_out = H_out * W_out; + const int out_off = h_out * W_out + w_out; + + const float32x4_t vw_tl = vdupq_n_f32((1.0f - fx) * (1.0f - fy)); + const float32x4_t vw_tr = vdupq_n_f32(fx * (1.0f - fy)); + const float32x4_t vw_bl = vdupq_n_f32((1.0f - fx) * fy); + const float32x4_t vw_br = vdupq_n_f32(fx * fy); + + int c = 0; + for (; c + 3 < C; c += 4) { + const float* p0 = input_n + (c + 0) * spatial_in; + const float* p1 = input_n + (c + 1) * spatial_in; + const float* p2 = input_n + (c + 2) * spatial_in; + const float* p3 = input_n + (c + 3) * spatial_in; + + float tl[4] = {0}, tr[4] = {0}, bl[4] = {0}, br[4] = {0}; + if (tl_v) { + tl[0] = p0[off_tl]; + tl[1] = p1[off_tl]; + tl[2] = p2[off_tl]; + tl[3] = p3[off_tl]; + } + if (tr_v) { + tr[0] = p0[off_tr]; + tr[1] = p1[off_tr]; + tr[2] = p2[off_tr]; + tr[3] = p3[off_tr]; + } + if (bl_v) { + bl[0] = p0[off_bl]; + bl[1] = p1[off_bl]; + bl[2] = p2[off_bl]; + bl[3] = p3[off_bl]; + } + if (br_v) { + br[0] = p0[off_br]; + br[1] = p1[off_br]; + br[2] = p2[off_br]; + br[3] = p3[off_br]; + } + + float32x4_t result = vmulq_f32(vw_tl, vld1q_f32(tl)); + result = vfmaq_f32(result, vw_tr, vld1q_f32(tr)); + result = vfmaq_f32(result, vw_bl, vld1q_f32(bl)); + result = vfmaq_f32(result, vw_br, vld1q_f32(br)); + + float res[4]; + vst1q_f32(res, result); + output_n[(c + 0) * spatial_out + out_off] = res[0]; + output_n[(c + 1) * spatial_out + out_off] = res[1]; + output_n[(c + 2) * spatial_out + out_off] = res[2]; + output_n[(c + 3) * spatial_out + out_off] = res[3]; + } + + const float w_tl = (1.0f - fx) * (1.0f - fy); + const float w_tr = fx * (1.0f - fy); + const float w_bl = (1.0f - fx) * fy; + const float w_br = fx * fy; + for (; c < C; ++c) { + const float* p = input_n + c * spatial_in; + float v = 0.0f; + if (tl_v) + v += w_tl * p[off_tl]; + if (tr_v) + v += w_tr * p[off_tr]; + if (bl_v) + v += w_bl * p[off_bl]; + if (br_v) + v += w_br * p[off_br]; + output_n[c * spatial_out + out_off] = v; + } +} + +// -------------------- fp16 software-convert path -------------------- +// +// Uses only plain ARMv8 NEON. fp16 <-> fp32 conversion goes through +// c10::Half's portable `operator float()` / constructor, which is a +// software conversion on chips that lack the +fp16 extension. + +inline void bilinear_all_channels_f16_sw( + const c10::Half* input_n, + c10::Half* output_n, + int C, + int H_in, + int W_in, + int H_out, + int W_out, + int h_out, + int w_out, + float gx, + float gy) { + const int x0 = static_cast(std::floor(gx)); + const int y0 = static_cast(std::floor(gy)); + const int x1 = x0 + 1; + const int y1 = y0 + 1; + const float fx = gx - static_cast(x0); + const float fy = gy - static_cast(y0); + + const bool tl_v = static_cast(x0) < static_cast(W_in) && + static_cast(y0) < static_cast(H_in); + const bool tr_v = static_cast(x1) < static_cast(W_in) && + static_cast(y0) < static_cast(H_in); + const bool bl_v = static_cast(x0) < static_cast(W_in) && + static_cast(y1) < static_cast(H_in); + const bool br_v = static_cast(x1) < static_cast(W_in) && + static_cast(y1) < static_cast(H_in); + + const int off_tl = y0 * W_in + x0; + const int off_tr = y0 * W_in + x1; + const int off_bl = y1 * W_in + x0; + const int off_br = y1 * W_in + x1; + const int spatial_in = H_in * W_in; + const int spatial_out = H_out * W_out; + const int out_off = h_out * W_out + w_out; + + const float32x4_t vw_tl = vdupq_n_f32((1.0f - fx) * (1.0f - fy)); + const float32x4_t vw_tr = vdupq_n_f32(fx * (1.0f - fy)); + const float32x4_t vw_bl = vdupq_n_f32((1.0f - fx) * fy); + const float32x4_t vw_br = vdupq_n_f32(fx * fy); + + int c = 0; + for (; c + 3 < C; c += 4) { + const c10::Half* p0 = input_n + (c + 0) * spatial_in; + const c10::Half* p1 = input_n + (c + 1) * spatial_in; + const c10::Half* p2 = input_n + (c + 2) * spatial_in; + const c10::Half* p3 = input_n + (c + 3) * spatial_in; + + // SW fp16 -> fp32: use c10::Half's portable conversion on each lane. + float tl[4] = {0}, tr[4] = {0}, bl[4] = {0}, br[4] = {0}; + if (tl_v) { + tl[0] = static_cast(p0[off_tl]); + tl[1] = static_cast(p1[off_tl]); + tl[2] = static_cast(p2[off_tl]); + tl[3] = static_cast(p3[off_tl]); + } + if (tr_v) { + tr[0] = static_cast(p0[off_tr]); + tr[1] = static_cast(p1[off_tr]); + tr[2] = static_cast(p2[off_tr]); + tr[3] = static_cast(p3[off_tr]); + } + if (bl_v) { + bl[0] = static_cast(p0[off_bl]); + bl[1] = static_cast(p1[off_bl]); + bl[2] = static_cast(p2[off_bl]); + bl[3] = static_cast(p3[off_bl]); + } + if (br_v) { + br[0] = static_cast(p0[off_br]); + br[1] = static_cast(p1[off_br]); + br[2] = static_cast(p2[off_br]); + br[3] = static_cast(p3[off_br]); + } + + float32x4_t result = vmulq_f32(vw_tl, vld1q_f32(tl)); + result = vfmaq_f32(result, vw_tr, vld1q_f32(tr)); + result = vfmaq_f32(result, vw_bl, vld1q_f32(bl)); + result = vfmaq_f32(result, vw_br, vld1q_f32(br)); + + float res[4]; + vst1q_f32(res, result); + // SW fp32 -> fp16 on store. + output_n[(c + 0) * spatial_out + out_off] = c10::Half(res[0]); + output_n[(c + 1) * spatial_out + out_off] = c10::Half(res[1]); + output_n[(c + 2) * spatial_out + out_off] = c10::Half(res[2]); + output_n[(c + 3) * spatial_out + out_off] = c10::Half(res[3]); + } + + const float w_tl = (1.0f - fx) * (1.0f - fy); + const float w_tr = fx * (1.0f - fy); + const float w_bl = (1.0f - fx) * fy; + const float w_br = fx * fy; + for (; c < C; ++c) { + const c10::Half* p = input_n + c * spatial_in; + float v = 0.0f; + if (tl_v) + v += w_tl * static_cast(p[off_tl]); + if (tr_v) + v += w_tr * static_cast(p[off_tr]); + if (bl_v) + v += w_bl * static_cast(p[off_bl]); + if (br_v) + v += w_br * static_cast(p[off_br]); + output_n[c * spatial_out + out_off] = c10::Half(v); + } +} + +template +void grid_sampler_2d_neon( + const SCALAR* input, + const SCALAR* grid, + SCALAR* output, + int N, + int C, + int H_in, + int W_in, + int H_out, + int W_out, + bool align_corners, + SampleFn sample_fn) { + const int spatial_in = H_in * W_in; + const int spatial_out = H_out * W_out; + + for (int n = 0; n < N; ++n) { + const SCALAR* input_n = input + n * C * spatial_in; + SCALAR* output_n = output + n * C * spatial_out; + const SCALAR* grid_n = grid + n * H_out * W_out * 2; + + for (int h = 0; h < H_out; ++h) { + if (h + 1 < H_out) { + __builtin_prefetch(grid_n + (h + 1) * W_out * 2, 0, 1); + } + for (int w = 0; w < W_out; ++w) { + const int grid_off = (h * W_out + w) * 2; + float gx = static_cast(grid_n[grid_off]); + float gy = static_cast(grid_n[grid_off + 1]); + if (align_corners) { + gx = (gx + 1.0f) * (W_in - 1) * 0.5f; + gy = (gy + 1.0f) * (H_in - 1) * 0.5f; + } else { + gx = (gx + 1.0f) * W_in * 0.5f - 0.5f; + gy = (gy + 1.0f) * H_in * 0.5f - 0.5f; + } + sample_fn(input_n, output_n, C, H_in, W_in, H_out, W_out, h, w, gx, gy); + } + } + } +} + +} // namespace +#endif // __aarch64__ + +Tensor& opt_grid_sampler_2d_out( + KernelRuntimeContext& ctx, + const Tensor& input, + const Tensor& grid, + int64_t interpolation_mode, + int64_t padding_mode, + bool align_corners, + Tensor& out) { + // The NEON paths index input/grid/out directly assuming a contiguous NCHW + // default-dim-order layout — no use of .strides() or .dim_order(). Fall + // back to portable for anything else. + const bool fast_eligible = tensor_is_default_dim_order(input) && + tensor_is_default_dim_order(grid) && tensor_is_default_dim_order(out) && + tensor_is_contiguous(input) && tensor_is_contiguous(grid) && + tensor_is_contiguous(out); + + if (interpolation_mode != 0 || padding_mode != 0 || !fast_eligible) { + return grid_sampler_2d_out( + ctx, input, grid, interpolation_mode, padding_mode, align_corners, out); + } +#ifndef __aarch64__ + return grid_sampler_2d_out( + ctx, input, grid, interpolation_mode, padding_mode, align_corners, out); +#else + const int N = static_cast(input.size(0)); + const int C = static_cast(input.size(1)); + const int H_in = static_cast(input.size(2)); + const int W_in = static_cast(input.size(3)); + const int H_out = static_cast(grid.size(1)); + const int W_out = static_cast(grid.size(2)); + + if (input.scalar_type() == ScalarType::Float) { + grid_sampler_2d_neon( + input.const_data_ptr(), + grid.const_data_ptr(), + out.mutable_data_ptr(), + N, + C, + H_in, + W_in, + H_out, + W_out, + align_corners, + bilinear_all_channels_f32); + return out; + } + if (input.scalar_type() == ScalarType::Half) { + if (cpuinfo_initialize() && cpuinfo_has_arm_neon_fp16()) { + // Hardware fp16 path — safe because the CPU supports the +fp16 + // extension. Declared in op_grid_sampler_2d_fp16_hw.cpp. + opt_grid_sampler_2d_internal::grid_sampler_2d_bilinear_fp16_hw( + input.const_data_ptr(), + grid.const_data_ptr(), + out.mutable_data_ptr(), + N, + C, + H_in, + W_in, + H_out, + W_out, + align_corners); + return out; + } + // Software fp16<->fp32 conversion path. Works on any ARMv8. + grid_sampler_2d_neon( + input.const_data_ptr(), + grid.const_data_ptr(), + out.mutable_data_ptr(), + N, + C, + H_in, + W_in, + H_out, + W_out, + align_corners, + bilinear_all_channels_f16_sw); + return out; + } + // Any other dtype: let portable handle it. + return grid_sampler_2d_out( + ctx, input, grid, interpolation_mode, padding_mode, align_corners, out); +#endif +} + +} // namespace native +} // namespace executor +} // namespace torch diff --git a/kernels/optimized/cpu/op_grid_sampler_2d_fp16_hw.cpp b/kernels/optimized/cpu/op_grid_sampler_2d_fp16_hw.cpp new file mode 100644 index 00000000000..92051cd5c9e --- /dev/null +++ b/kernels/optimized/cpu/op_grid_sampler_2d_fp16_hw.cpp @@ -0,0 +1,205 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +// Hardware-fp16 variant of the NEON grid_sampler_2d.out bilinear + zeros- +// padding fast path. This translation unit is compiled with +// `-march=armv8.2-a+fp16`, which lets the compiler emit hardware fp16 +// load/store/convert intrinsics (vld1_f16 / vcvt_f32_f16 / vst1_f16 / +// vcvt_f16_f32). Those instructions are undefined on ARMv8.0 and ARMv8.1 +// chips without the fp16 extension, so this entry point must only be +// invoked after a runtime CPU-feature check — see the dispatcher in +// op_grid_sampler_2d.cpp (cpuinfo_has_arm_neon_fp16). +// +// Math happens in fp32 regardless: we load fp16 from memory, convert to +// fp32 via the hardware instruction, do the weighted-sum FMA chain in +// fp32, convert back to fp16 on store. This matches the precision of +// the portable kernel once #19117 lands. + +#ifdef __aarch64__ + +#include +#include + +namespace torch { +namespace executor { +namespace native { +namespace opt_grid_sampler_2d_internal { + +namespace { + +// One output spatial location, all channels. +inline void bilinear_all_channels_fp16_hw_sample( + const __fp16* input_n, + __fp16* output_n, + int C, + int H_in, + int W_in, + int H_out, + int W_out, + int h_out, + int w_out, + float gx, + float gy) { + const int x0 = static_cast(std::floor(gx)); + const int y0 = static_cast(std::floor(gy)); + const int x1 = x0 + 1; + const int y1 = y0 + 1; + const float fx = gx - static_cast(x0); + const float fy = gy - static_cast(y0); + + const bool tl_v = static_cast(x0) < static_cast(W_in) && + static_cast(y0) < static_cast(H_in); + const bool tr_v = static_cast(x1) < static_cast(W_in) && + static_cast(y0) < static_cast(H_in); + const bool bl_v = static_cast(x0) < static_cast(W_in) && + static_cast(y1) < static_cast(H_in); + const bool br_v = static_cast(x1) < static_cast(W_in) && + static_cast(y1) < static_cast(H_in); + + const int off_tl = y0 * W_in + x0; + const int off_tr = y0 * W_in + x1; + const int off_bl = y1 * W_in + x0; + const int off_br = y1 * W_in + x1; + const int spatial_in = H_in * W_in; + const int spatial_out = H_out * W_out; + const int out_off = h_out * W_out + w_out; + + const float32x4_t vw_tl = vdupq_n_f32((1.0f - fx) * (1.0f - fy)); + const float32x4_t vw_tr = vdupq_n_f32(fx * (1.0f - fy)); + const float32x4_t vw_bl = vdupq_n_f32((1.0f - fx) * fy); + const float32x4_t vw_br = vdupq_n_f32(fx * fy); + + int c = 0; + for (; c + 3 < C; c += 4) { + const __fp16* p0 = input_n + (c + 0) * spatial_in; + const __fp16* p1 = input_n + (c + 1) * spatial_in; + const __fp16* p2 = input_n + (c + 2) * spatial_in; + const __fp16* p3 = input_n + (c + 3) * spatial_in; + + __fp16 tl[4] = {0}, tr[4] = {0}, bl[4] = {0}, br[4] = {0}; + if (tl_v) { + tl[0] = p0[off_tl]; + tl[1] = p1[off_tl]; + tl[2] = p2[off_tl]; + tl[3] = p3[off_tl]; + } + if (tr_v) { + tr[0] = p0[off_tr]; + tr[1] = p1[off_tr]; + tr[2] = p2[off_tr]; + tr[3] = p3[off_tr]; + } + if (bl_v) { + bl[0] = p0[off_bl]; + bl[1] = p1[off_bl]; + bl[2] = p2[off_bl]; + bl[3] = p3[off_bl]; + } + if (br_v) { + br[0] = p0[off_br]; + br[1] = p1[off_br]; + br[2] = p2[off_br]; + br[3] = p3[off_br]; + } + + // Hardware fp16 -> fp32 conversion (requires +fp16 extension). + const float32x4_t v_tl = vcvt_f32_f16(vld1_f16(tl)); + const float32x4_t v_tr = vcvt_f32_f16(vld1_f16(tr)); + const float32x4_t v_bl = vcvt_f32_f16(vld1_f16(bl)); + const float32x4_t v_br = vcvt_f32_f16(vld1_f16(br)); + + float32x4_t result = vmulq_f32(vw_tl, v_tl); + result = vfmaq_f32(result, vw_tr, v_tr); + result = vfmaq_f32(result, vw_bl, v_bl); + result = vfmaq_f32(result, vw_br, v_br); + + __fp16 res[4]; + vst1_f16(res, vcvt_f16_f32(result)); + output_n[(c + 0) * spatial_out + out_off] = res[0]; + output_n[(c + 1) * spatial_out + out_off] = res[1]; + output_n[(c + 2) * spatial_out + out_off] = res[2]; + output_n[(c + 3) * spatial_out + out_off] = res[3]; + } + + // Scalar tail. + const float w_tl = (1.0f - fx) * (1.0f - fy); + const float w_tr = fx * (1.0f - fy); + const float w_bl = (1.0f - fx) * fy; + const float w_br = fx * fy; + for (; c < C; ++c) { + const __fp16* p = input_n + c * spatial_in; + float v = 0.0f; + if (tl_v) + v += w_tl * static_cast(p[off_tl]); + if (tr_v) + v += w_tr * static_cast(p[off_tr]); + if (bl_v) + v += w_bl * static_cast(p[off_bl]); + if (br_v) + v += w_br * static_cast(p[off_br]); + output_n[c * spatial_out + out_off] = static_cast<__fp16>(v); + } +} + +} // namespace + +// Exposed entry point. Called by op_grid_sampler_2d.cpp's dispatcher only +// when cpuinfo_has_arm_neon_fp16() reports true. Input/output data are +// raw uint16_t buffers interpreted as __fp16; N/C/H/W/grid come pre- +// computed from the dispatcher. +void grid_sampler_2d_bilinear_fp16_hw( + const void* input, + const void* grid, + void* output, + int N, + int C, + int H_in, + int W_in, + int H_out, + int W_out, + bool align_corners) { + const __fp16* in = reinterpret_cast(input); + const __fp16* gd = reinterpret_cast(grid); + __fp16* out = reinterpret_cast<__fp16*>(output); + + const int spatial_in = H_in * W_in; + const int spatial_out = H_out * W_out; + + for (int n = 0; n < N; ++n) { + const __fp16* input_n = in + n * C * spatial_in; + __fp16* output_n = out + n * C * spatial_out; + const __fp16* grid_n = gd + n * H_out * W_out * 2; + + for (int h = 0; h < H_out; ++h) { + if (h + 1 < H_out) { + __builtin_prefetch(grid_n + (h + 1) * W_out * 2, 0, 1); + } + for (int w = 0; w < W_out; ++w) { + const int grid_off = (h * W_out + w) * 2; + float gx = static_cast(grid_n[grid_off]); + float gy = static_cast(grid_n[grid_off + 1]); + if (align_corners) { + gx = (gx + 1.0f) * (W_in - 1) * 0.5f; + gy = (gy + 1.0f) * (H_in - 1) * 0.5f; + } else { + gx = (gx + 1.0f) * W_in * 0.5f - 0.5f; + gy = (gy + 1.0f) * H_in * 0.5f - 0.5f; + } + bilinear_all_channels_fp16_hw_sample( + input_n, output_n, C, H_in, W_in, H_out, W_out, h, w, gx, gy); + } + } + } +} + +} // namespace opt_grid_sampler_2d_internal +} // namespace native +} // namespace executor +} // namespace torch + +#endif // __aarch64__ diff --git a/kernels/optimized/cpu/op_sum.cpp b/kernels/optimized/cpu/op_sum.cpp new file mode 100644 index 00000000000..059153120ab --- /dev/null +++ b/kernels/optimized/cpu/op_sum.cpp @@ -0,0 +1,203 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include + +#include +#include +#include +#include + +#include +#include +#include + +namespace torch { +namespace executor { +namespace native { + +using executorch::aten::ArrayRef; +using executorch::aten::ScalarType; +using executorch::aten::Tensor; + +// Forward decl of the portable kernel — used as a fallback for shapes and +// dtype combinations the optimized path doesn't specialize. Both libraries +// live in the same binary, so direct call is fine. +Tensor& sum_dim_out( + KernelRuntimeContext& ctx, + const Tensor& in, + std::optional> dim_list, + bool keepdim, + std::optional dtype, + Tensor& out); + +namespace { + +// Contiguous innermost reduction: sum each row of the inner axis into one +// scalar. fp16/bf16 accumulate in fp32 for precision; fp32 accumulates in +// fp32 directly. Uses at::vec::Vectorized for cross-arch SIMD. +template +inline void sum_innermost( + const CTYPE* in, + CTYPE* out, + int64_t outer_size, + int64_t reduce_size) { + using Vec = at::vec::Vectorized; + constexpr int64_t kVecSize = static_cast(Vec::size()); + for (int64_t i = 0; i < outer_size; ++i) { + const CTYPE* row = in + i * reduce_size; + Vec acc(0.0f); + int64_t j = 0; + for (; j + kVecSize - 1 < reduce_size; j += kVecSize) { + if constexpr (std::is_same_v) { + acc = acc + Vec::loadu(row + j); + } else { + // Half / BFloat16: load N elements, convert to float, add. + float tmp[kVecSize]; + for (int64_t k = 0; k < kVecSize; ++k) { + tmp[k] = static_cast(row[j + k]); + } + acc = acc + Vec::loadu(tmp); + } + } + float sum = + at::vec::vec_reduce_all([](Vec a, Vec b) { return a + b; }, acc); + for (; j < reduce_size; ++j) { + sum += static_cast(row[j]); + } + out[i] = static_cast(sum); + } +} + +// Non-innermost (strided) single-dim reduction. For each (outer, inner) pair, +// sum over reduce_size elements spaced `inner_size` apart. Vectorize across +// the contiguous inner axis (so each add-step processes kVecSize output +// positions at once). +template +inline void sum_strided( + const CTYPE* in, + CTYPE* out, + int64_t outer_size, + int64_t reduce_size, + int64_t inner_size) { + using Vec = at::vec::Vectorized; + constexpr int64_t kVecSize = static_cast(Vec::size()); + const int64_t outer_stride = reduce_size * inner_size; + for (int64_t o = 0; o < outer_size; ++o) { + const CTYPE* in_o = in + o * outer_stride; + CTYPE* out_o = out + o * inner_size; + int64_t j = 0; + for (; j + kVecSize - 1 < inner_size; j += kVecSize) { + Vec acc(0.0f); + for (int64_t k = 0; k < reduce_size; ++k) { + const CTYPE* p = in_o + k * inner_size + j; + if constexpr (std::is_same_v) { + acc = acc + Vec::loadu(p); + } else { + float tmp[kVecSize]; + for (int64_t m = 0; m < kVecSize; ++m) { + tmp[m] = static_cast(p[m]); + } + acc = acc + Vec::loadu(tmp); + } + } + if constexpr (std::is_same_v) { + acc.store(out_o + j); + } else { + float tmp[kVecSize]; + acc.store(tmp); + for (int64_t m = 0; m < kVecSize; ++m) { + out_o[j + m] = static_cast(tmp[m]); + } + } + } + for (; j < inner_size; ++j) { + float sum = 0.0f; + for (int64_t k = 0; k < reduce_size; ++k) { + sum += static_cast(in_o[k * inner_size + j]); + } + out_o[j] = static_cast(sum); + } + } +} + +} // namespace + +Tensor& opt_sum_dim_out( + KernelRuntimeContext& ctx, + const Tensor& in, + std::optional> dim_list, + bool keepdim, + std::optional dtype, + Tensor& out) { + ET_KERNEL_CHECK( + ctx, + check_reduction_args(in, dim_list, keepdim, dtype, out), + InvalidArgument, + out); + ET_KERNEL_CHECK( + ctx, + resize_reduction_out(in, dim_list, keepdim, out) == Error::Ok, + InvalidArgument, + out); + ET_KERNEL_CHECK( + ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out); + ET_KERNEL_CHECK(ctx, tensor_is_default_dim_order(in), InvalidArgument, out); + + if (in.numel() == 0) { + if (out.numel() > 0) { + std::memset(out.mutable_data_ptr(), 0, out.nbytes()); + } + return out; + } + + // Fast path: single reduction dim, matching dtype, non-complex, contiguous. + // Anything else falls through to the portable kernel. + const bool fast_eligible = dim_list.has_value() && + dim_list.value().size() == 1 && in.scalar_type() == out.scalar_type() && + !executorch::runtime::isComplexType(in.scalar_type()) && + tensor_is_contiguous(in); + + if (fast_eligible) { + const int64_t d = dim_list.value()[0] < 0 ? dim_list.value()[0] + in.dim() + : dim_list.value()[0]; + int64_t outer_size = 1, reduce_size = in.size(d), inner_size = 1; + for (int64_t i = 0; i < d; ++i) { + outer_size *= in.size(i); + } + for (int64_t i = d + 1; i < in.dim(); ++i) { + inner_size *= in.size(i); + } + + // @lint-ignore CLANGTIDY facebook-hte-CArray + static constexpr const char op_name[] = "sum.IntList_out"; + bool handled = false; + ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, op_name, CTYPE, [&] { + const CTYPE* ip = in.const_data_ptr(); + CTYPE* op = out.mutable_data_ptr(); + if (inner_size == 1) { + sum_innermost(ip, op, outer_size, reduce_size); + handled = true; + } else { + sum_strided(ip, op, outer_size, reduce_size, inner_size); + handled = true; + } + }); + if (handled) { + return out; + } + } + + // Fallback. + return sum_dim_out(ctx, in, dim_list, keepdim, dtype, out); +} + +} // namespace native +} // namespace executor +} // namespace torch diff --git a/kernels/optimized/cpu/targets.bzl b/kernels/optimized/cpu/targets.bzl index 78bbecd9e2c..9da8d67ab38 100644 --- a/kernels/optimized/cpu/targets.bzl +++ b/kernels/optimized/cpu/targets.bzl @@ -76,6 +76,26 @@ def define_common_targets(): ], ) + # Hardware fp16 variant of grid_sampler_2d. Needs ARMv8.2-a+fp16 so it + # must be a separate translation unit — op_grid_sampler_2d.cpp (the + # runtime dispatcher) remains on plain ARMv8 and only calls into this + # after cpuinfo_has_arm_neon_fp16() reports true. Scoped compile flag + # stays local to this library. Named without the "op_" prefix so the + # op_registration_util dependency check (which forbids op_target -> + # op_target edges) still lets op_grid_sampler_2d depend on it. + runtime.cxx_library( + name = "grid_sampler_2d_fp16_hw_impl", + srcs = ["op_grid_sampler_2d_fp16_hw.cpp"], + visibility = ["PUBLIC"], + compiler_flags = select({ + "DEFAULT": [], + "ovr_config//cpu:arm64": ["-march=armv8.2-a+fp16"], + }), + exported_deps = [ + "//executorch/runtime/kernel:kernel_includes", + ], + ) + # Used for dtype selective build. Collect source and header files. runtime.filegroup( name = "optimized_source_files", diff --git a/kernels/optimized/optimized.yaml b/kernels/optimized/optimized.yaml index 58121549ea5..5a001afc7a0 100644 --- a/kernels/optimized/optimized.yaml +++ b/kernels/optimized/optimized.yaml @@ -57,6 +57,11 @@ - arg_meta: null kernel_name: torch::executor::opt_gelu_out +- op: grid_sampler_2d.out + kernels: + - arg_meta: null + kernel_name: torch::executor::opt_grid_sampler_2d_out + - op: le.Scalar_out kernels: - arg_meta: null @@ -97,6 +102,11 @@ - arg_meta: null kernel_name: torch::executor::opt_sub_out +- op: sum.IntList_out + kernels: + - arg_meta: null + kernel_name: torch::executor::opt_sum_dim_out + - op: sub.Scalar_out kernels: - arg_meta: null diff --git a/shim_et/xplat/executorch/build/build_variables.bzl b/shim_et/xplat/executorch/build/build_variables.bzl index edddc1da916..a4329155355 100644 --- a/shim_et/xplat/executorch/build/build_variables.bzl +++ b/shim_et/xplat/executorch/build/build_variables.bzl @@ -267,6 +267,8 @@ OPTIMIZED_KERNELS_SRCS = [ "kernels/optimized/cpu/op_fft_c2r.cpp", "kernels/optimized/cpu/op_fft_r2c.cpp", "kernels/optimized/cpu/op_gelu.cpp", + "kernels/optimized/cpu/op_grid_sampler_2d.cpp", + "kernels/optimized/cpu/op_grid_sampler_2d_fp16_hw.cpp", "kernels/optimized/cpu/op_le.cpp", "kernels/optimized/cpu/op_linear.cpp", "kernels/optimized/cpu/op_log_softmax.cpp", @@ -274,6 +276,7 @@ OPTIMIZED_KERNELS_SRCS = [ "kernels/optimized/cpu/op_mul.cpp", "kernels/optimized/cpu/op_native_layer_norm.cpp", "kernels/optimized/cpu/op_sub.cpp", + "kernels/optimized/cpu/op_sum.cpp", "kernels/optimized/cpu/op_where.cpp", ] diff --git a/shim_et/xplat/executorch/kernels/optimized/op_registration_util.bzl b/shim_et/xplat/executorch/kernels/optimized/op_registration_util.bzl index bc43688b04e..fe77affcf36 100644 --- a/shim_et/xplat/executorch/kernels/optimized/op_registration_util.bzl +++ b/shim_et/xplat/executorch/kernels/optimized/op_registration_util.bzl @@ -217,6 +217,21 @@ OPTIMIZED_ATEN_OPS = ( "//executorch/runtime/core/portable_type/c10/c10:aten_headers_for_executorch", ], ), + op_target( + name = "op_grid_sampler_2d", + deps = [ + "//executorch/kernels/portable/cpu:op_grid_sampler_2d", + # Hardware fp16 path lives in a separate translation unit so the + # ARMv8.2-a+fp16 compile flag can be scoped locally. A runtime + # cpuinfo_has_arm_neon_fp16() check in op_grid_sampler_2d.cpp + # picks between it and the software-convert fp16 path. Named + # without the "op_" prefix so _enforce_deps doesn't reject it + # as an op_target-to-op_target edge. + ":grid_sampler_2d_fp16_hw_impl", + "fbsource//third-party/cpuinfo:cpuinfo", + "//executorch/runtime/core/portable_type/c10/c10:aten_headers_for_executorch", + ], + ), op_target( name = "op_le", deps = [ @@ -282,6 +297,13 @@ OPTIMIZED_ATEN_OPS = ( "//executorch/runtime/core/portable_type/c10/c10:aten_headers_for_executorch", ], ), + op_target( + name = "op_sum", + deps = [ + "//executorch/kernels/portable/cpu:op_sum", + "//executorch/kernels/portable/cpu/util:reduce_util", + ], + ), op_target( name = "op_where", deps = [