From fa8c826835e44165436102852cbacf9f8632d73d Mon Sep 17 00:00:00 2001 From: Siddhartha Menon Date: Tue, 12 May 2026 11:46:08 +0100 Subject: [PATCH] feat: Expose accumulation mode flag via Conv2dInfo Users of the functional and experimental operator convolution APIs, e.g, arm_compute::NEGEMMConv2d, or arm_compute::experimental::op::CpuGemmDirectConv2d, can make use of fp32 accumulation by setting this flag in Conv2dInfo during the validate() and configure() steps. Commit 5e40456e changed the default behaviour of CpuGemmDirectConv2d to accumulate in f32 unless enable_fast_math was set. However, this can produce regressions for users expecting the old behaviour. This change exposes the flag to user directly, making fp32 accumulation opt-in. Change-Id: I3203bdbbfa5152a64438941dd138bab6feb1cec2 Signed-off-by: Siddhartha Menon --- arm_compute/runtime/FunctionDescriptors.h | 9 ++++++--- src/cpu/operators/CpuGemmDirectConv2d.cpp | 2 +- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/arm_compute/runtime/FunctionDescriptors.h b/arm_compute/runtime/FunctionDescriptors.h index 4691b059fb..2eccb01fa6 100644 --- a/arm_compute/runtime/FunctionDescriptors.h +++ b/arm_compute/runtime/FunctionDescriptors.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2023, 2025 Arm Limited. + * Copyright (c) 2019-2023, 2025-2026 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -67,13 +67,15 @@ struct Conv2dInfo const ActivationLayerInfo &act_info, bool enable_fast_math, unsigned int num_groups, - const WeightsInfo &weights_info = WeightsInfo()) + const WeightsInfo &weights_info = WeightsInfo(), + bool use_fp32_acc = false) : conv_info(conv_info), dilation(dilation), act_info(act_info), enable_fast_math(enable_fast_math), num_groups(num_groups), - weights_info(weights_info) + weights_info(weights_info), + use_fp32_acc(use_fp32_acc) { } @@ -83,6 +85,7 @@ struct Conv2dInfo bool enable_fast_math{false}; unsigned int num_groups{1}; WeightsInfo weights_info{}; + bool use_fp32_acc{false}; }; /** Descriptor used by the 3d Convolution function */ diff --git a/src/cpu/operators/CpuGemmDirectConv2d.cpp b/src/cpu/operators/CpuGemmDirectConv2d.cpp index e1f8225a41..caac7761e2 100644 --- a/src/cpu/operators/CpuGemmDirectConv2d.cpp +++ b/src/cpu/operators/CpuGemmDirectConv2d.cpp @@ -93,7 +93,7 @@ cpu::AsmGemmInfo init_assembly_metadata(const Conv2dInfo &info, bool is_indirect asm_info.fast_mode = info.enable_fast_math; asm_info.fixed_format = info.weights_info.weight_format() != WeightFormat::UNSPECIFIED; asm_info.weight_format = info.weights_info.weight_format(); - asm_info.use_fp32_acc = !info.enable_fast_math; + asm_info.use_fp32_acc = info.use_fp32_acc; return asm_info; } } // namespace