diff --git a/.github/workflows/ci-platform-generic.yml b/.github/workflows/ci-platform-generic.yml index 321bed6998..fb39a9bd53 100644 --- a/.github/workflows/ci-platform-generic.yml +++ b/.github/workflows/ci-platform-generic.yml @@ -73,6 +73,10 @@ jobs: testFloatSoftmax testFloatTranspose testFloatMul + testFloatPowScalar + testFloatPowVector + testFloatSqrt + testFloatRMSNorm Quant Dequant QuantizedLinear diff --git a/.gitignore b/.gitignore index 1115bfa5ba..e0e99b33ba 100644 --- a/.gitignore +++ b/.gitignore @@ -24,6 +24,8 @@ package-lock.json .mypy_cache node_modules +.venv/* + compile_commands.json docs/_autosummary diff --git a/CHANGELOG.md b/CHANGELOG.md index a35473c98e..821dbaec51 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,7 @@ This file contains the changelog for the Deeploy project. The changelog is divid ## Unreleased (Planned Release Target: v0.2.1) ### List of Pull Requests +- Support for RMSNorm (Pow and Sqrt operators) [#136](https://github.com/pulp-platform/Deeploy/pull/136) - Demo TinyViT compatibility with tiled Siracusa [#124](https://github.com/pulp-platform/Deeploy/pull/124) - TinyViT on non-tiled Siracusa [#117](https://github.com/pulp-platform/Deeploy/pull/117) - Support Fully Asynchronous DMAs [#114](https://github.com/pulp-platform/Deeploy/pull/114) @@ -26,6 +27,8 @@ This file contains the changelog for the Deeploy project. The changelog is divid - Fix bias hoisting in generic GEMM with no bias [#126](https://github.com/pulp-platform/Deeploy/pull/126) ### Added +- Support for RMSNorm operation via operator decomposition. +- Support for `Pow` (Power) and `Sqrt` (Square Root) operators (parsers, layers, bindings, templates, and FP32 kernels) on the Generic platform. - Support for input tiling for PULP FP regular and DW conv 2D. - CI tests for tiled Siracusa FP regular and DW conv 2D, with and without bias, for skip connections, and for the demo version of TinyViT. 
- Documentation for PULP FP regular and DW conv 2D and MatMul tile constraints. diff --git a/Deeploy/Targets/Generic/Bindings.py b/Deeploy/Targets/Generic/Bindings.py index 6bfe805b39..1807864dfc 100644 --- a/Deeploy/Targets/Generic/Bindings.py +++ b/Deeploy/Targets/Generic/Bindings.py @@ -15,11 +15,11 @@ ConvTransposeTemplate, DebugPrintTemplate, DequantTemplate, DummyTemplate, DWConvTemplate, FloatAddTemplate, \ FloatConvTemplate, FloatDivTemplate, FloatDWConvTemplate, FloatGELUTemplate, FloatGemmTemplate, \ FloatLayernormTemplate, FloatMatMulTemplate, FloatMaxPoolTemplate, FloatMulTemplate, FloatPadTemplate, \ - FloatReduceMeanTemplate, FloatReluTemplate, FloatSoftmaxTemplate, GatherTemplate, GemmTemplate, \ - IntegerDivTemplate, ITAMaxTemplate, ITAPartialMaxTemplate, MatMulTemplate, MaxPoolTemplate, MulTemplate, \ - PadTemplate, QuantTemplate, ReduceMeanTemplate, ReduceSumTemplate, RequantShiftTemplate, ReshapeTemplate, \ - RQIntegerDivTemplate, RQSiGELUTemplate, SliceTemplate, TransposeTemplate, iGELUTemplate, iLayernormTemplate, \ - iRMSNormTemplate, iSoftmaxTemplate + FloatPowTemplate, FloatReduceMeanTemplate, FloatReluTemplate, FloatSoftmaxTemplate, FloatSqrtTemplate, \ + GatherTemplate, GemmTemplate, IntegerDivTemplate, ITAMaxTemplate, ITAPartialMaxTemplate, MatMulTemplate, \ + MaxPoolTemplate, MulTemplate, PadTemplate, QuantTemplate, ReduceMeanTemplate, ReduceSumTemplate, \ + RequantShiftTemplate, ReshapeTemplate, RQIntegerDivTemplate, RQSiGELUTemplate, SliceTemplate, TransposeTemplate, \ + iGELUTemplate, iLayernormTemplate, iRMSNormTemplate, iSoftmaxTemplate from Deeploy.Targets.Generic.TypeCheckers import AddChecker, BatchNormChecker, ConcatChecker, ConvChecker, \ DebugPrintChecker, DequantChecker, DivChecker, DummyChecker, GatherChecker, GELUChecker, GEMMChecker, \ LayerNormChecker, MatMulChecker, MaxPoolChecker, MulChecker, PadChecker, QuantChecker, ReduceMeanChecker, \ @@ -118,6 +118,16 @@ BasicTransformer) ] +BasicPowBindings = [ + 
NodeBinding(DummyChecker([PointerClass(float32_t), PointerClass(float32_t)], [PointerClass(float32_t)]), + FloatPowTemplate.referenceTemplate, BasicTransformer), +] + +BasicSqrtBindings = [ + NodeBinding(DummyChecker([PointerClass(float32_t)], [PointerClass(float32_t)]), FloatSqrtTemplate.referenceTemplate, + BasicTransformer), +] + BasicDivBindings = [ NodeBinding(DivChecker([PointerClass(int32_t), PointerClass(int32_t)], [PointerClass(int32_t)]), IntegerDivTemplate.referenceTemplate, BasicTransformer) diff --git a/Deeploy/Targets/Generic/Layers.py b/Deeploy/Targets/Generic/Layers.py index c924895c13..c61b3eb9a1 100644 --- a/Deeploy/Targets/Generic/Layers.py +++ b/Deeploy/Targets/Generic/Layers.py @@ -227,6 +227,18 @@ def computeOps(self): return matmul + rqs +class PowLayer(ONNXLayer): + + def __init__(self, maps: List[NodeMapper]): + super().__init__(maps) + + +class SqrtLayer(ONNXLayer): + + def __init__(self, maps: List[NodeMapper]): + super().__init__(maps) + + class DivLayer(ONNXLayer): def __init__(self, maps: List[NodeMapper]): diff --git a/Deeploy/Targets/Generic/Parsers.py b/Deeploy/Targets/Generic/Parsers.py index f63bb5411d..bc69e64dae 100644 --- a/Deeploy/Targets/Generic/Parsers.py +++ b/Deeploy/Targets/Generic/Parsers.py @@ -8,7 +8,7 @@ import numpy as np import onnx_graphsurgeon as gs -from Deeploy.DeeployTypes import NetworkContext, NodeParser, VariableBuffer +from Deeploy.DeeployTypes import ConstantBuffer, NetworkContext, NodeParser, VariableBuffer class ConcatParser(NodeParser): @@ -1964,6 +1964,32 @@ def parseNodeCtxt(self, return ctxt, True +class PowParser(NodeParser): + + def __init__(self): + super().__init__() + + def parseNode(self, node: gs.Node) -> bool: + return node.op == 'Pow' and len(node.inputs) == 2 and len(node.outputs) == 1 + + def parseNodeCtxt(self, + ctxt: NetworkContext, + node: gs.Node, + channels_first: bool = True) -> Tuple[NetworkContext, bool]: + + # Lookup both inputs (data and exponent) + data_in = 
ctxt.lookup(node.inputs[0].name) + exponent_tensor = ctxt.lookup(node.inputs[1].name) + data_out = ctxt.lookup(node.outputs[0].name) + + self.operatorRepresentation['data_in'] = data_in.name + self.operatorRepresentation['exponent'] = exponent_tensor.name + self.operatorRepresentation['data_out'] = data_out.name + self.operatorRepresentation['size'] = int(np.prod(data_in.shape)) + + return ctxt, True + + class DivParser(NodeParser): def __init__(self): @@ -2747,3 +2773,26 @@ def parseNodeCtxt(self, "ch_im_out"] * self.operatorRepresentation["dim_im_out_y"] return newCtxt, True return ctxt, False + + +class SqrtParser(NodeParser): + + def __init__(self): + super().__init__() + + def parseNode(self, node: gs.Node) -> bool: + return node.op == 'Sqrt' and len(node.inputs) == 1 and len(node.outputs) == 1 + + def parseNodeCtxt(self, + ctxt: NetworkContext, + node: gs.Node, + channels_first: bool = True) -> Tuple[NetworkContext, bool]: + + data_in = ctxt.lookup(node.inputs[0].name) + data_out = ctxt.lookup(node.outputs[0].name) + + self.operatorRepresentation['data_in'] = data_in.name + self.operatorRepresentation['data_out'] = data_out.name + self.operatorRepresentation['size'] = int(np.prod(data_in.shape)) + + return ctxt, True diff --git a/Deeploy/Targets/Generic/Platform.py b/Deeploy/Targets/Generic/Platform.py index a15b3db2e6..785d932776 100644 --- a/Deeploy/Targets/Generic/Platform.py +++ b/Deeploy/Targets/Generic/Platform.py @@ -11,21 +11,22 @@ BasicDequantBindings, BasicDivBindings, BasicDWConv1DBinding, BasicDWConv2DBindings, BasicGatherBindings, \ BasicGELUBindings, BasicGEMMBindings, BasicITAPartialSoftmaxBinding, BasicITASoftmaxBinding, \ BasicLayerNormBindings, BasicMatMulBindings, BasicMaxPool1DBindings, BasicMaxPool2DBindings, BasicMulBindings, \ - BasicPad1DBindings, BasicPad2DBindings, BasicQuantBindings, BasicReduceMeanBindings, BasicReduceSumBindings, \ - BasicReluBinding, BasicReshapeBindings, BasicRQIntegerDivBinding, BasicRQSBindings, 
BasicRQSGELUBinding, \ - BasicSliceBindings, BasicSoftmaxBindings, BasicTransposeBindings, DummyBinding + BasicPad1DBindings, BasicPad2DBindings, BasicPowBindings, BasicQuantBindings, BasicReduceMeanBindings, \ + BasicReduceSumBindings, BasicReluBinding, BasicReshapeBindings, BasicRQIntegerDivBinding, BasicRQSBindings, \ + BasicRQSGELUBinding, BasicSliceBindings, BasicSoftmaxBindings, BasicSqrtBindings, BasicTransposeBindings, \ + DummyBinding from Deeploy.Targets.Generic.Layers import AddLayer, BatchNormalizationLayer, ConcatLayer, ConvLayer, \ ConvTransposeLayer, DebugPrintLayer, DequantLayer, DivLayer, GatherLayer, GELULayer, GEMMLayer, ITAMaxLayer, \ - LayerNormLayer, MatMulLayer, MaxPoolLayer, MulLayer, PadLayer, QuantLayer, ReduceMeanLayer, ReduceSumLayer, \ - ReluLayer, RequantShiftLayer, ReshapeLayer, RQIntegerDivLayer, RQSiGELULayer, SliceLayer, SoftmaxLayer, \ - TransposeLayer + LayerNormLayer, MatMulLayer, MaxPoolLayer, MulLayer, PadLayer, PowLayer, QuantLayer, ReduceMeanLayer, \ + ReduceSumLayer, ReluLayer, RequantShiftLayer, ReshapeLayer, RQIntegerDivLayer, RQSiGELULayer, SliceLayer, \ + SoftmaxLayer, SqrtLayer, TransposeLayer from Deeploy.Targets.Generic.Parsers import AddParser, BatchNormParser, ConcatParser, ConvTranspose1DParser, \ DebugParser, DequantParser, DivParser, DummyParser, FlattenParser, GatherParser, GELUParser, GenericConv1DParser, \ GenericConv2DParser, GenericDWConv1DParser, GenericDWConv2DParser, GenericGEMMParser, GenericMaxPool2DParser, \ IntegerDivParser, ITAMaxParser, ITAPartialMaxParser, LayerNormParser, MatMulParser, MaxPool1DParser, MulParser, \ - Pad1DParser, Pad2DParser, QuantParser, ReduceMeanParser, ReduceSumParser, ReluParser, RequantShiftParser, \ - ReshapeParser, RQIntegerDivParser, RQSiGELUParser, SliceParser, SoftmaxParser, TransposeParser, UnsqueezeParser, \ - iLayerNormParser, iSoftmaxParser + Pad1DParser, Pad2DParser, PowParser, QuantParser, ReduceMeanParser, ReduceSumParser, ReluParser, \ + RequantShiftParser, 
ReshapeParser, RQIntegerDivParser, RQSiGELUParser, SliceParser, SoftmaxParser, SqrtParser, \ + TransposeParser, UnsqueezeParser, iLayerNormParser, iSoftmaxParser from Deeploy.Targets.Generic.Templates import AllocateTemplate, FreeTemplate from Deeploy.Targets.Generic.TopologyOptimizationPasses.Passes import DequantPatternPass, ExtractPaddingFromConvPass, \ ExtractPaddingFromPoolPass, MatMulAddMergePass, MergeConstAddAndRequantPass, QuantPatternPass, \ @@ -52,6 +53,8 @@ MaxPoolMapper = NodeMapper(GenericMaxPool2DParser(), BasicMaxPool2DBindings) MaxPool1DMapper = NodeMapper(MaxPool1DParser(), BasicMaxPool1DBindings) MulMapper = NodeMapper(MulParser(), BasicMulBindings) +PowMapper = NodeMapper(PowParser(), BasicPowBindings) +SqrtMapper = NodeMapper(SqrtParser(), BasicSqrtBindings) Pad1DMapper = NodeMapper(Pad1DParser(), BasicPad1DBindings) Pad2DMapper = NodeMapper(Pad2DParser(), BasicPad2DBindings) ReduceMeanMapper = NodeMapper(ReduceMeanParser(), BasicReduceMeanBindings) @@ -98,6 +101,8 @@ 'MatMulInteger': MatMulLayer([MatMulMapper]), 'MaxPool': MaxPoolLayer([MaxPool1DMapper, MaxPoolMapper]), 'Mul': MulLayer([MulMapper]), + 'Pow': PowLayer([PowMapper]), + 'Sqrt': SqrtLayer([SqrtMapper]), 'Pad': PadLayer([Pad1DMapper, Pad2DMapper]), 'ReduceMean': ReduceMeanLayer([ReduceMeanMapper]), 'ReduceSum': ReduceSumLayer([ReduceSumMapper]), diff --git a/Deeploy/Targets/Generic/Templates/FloatPowTemplate.py b/Deeploy/Targets/Generic/Templates/FloatPowTemplate.py new file mode 100644 index 0000000000..83d177cc39 --- /dev/null +++ b/Deeploy/Targets/Generic/Templates/FloatPowTemplate.py @@ -0,0 +1,59 @@ +# SPDX-FileCopyrightText: 2025 ETH Zurich and University of Bologna +# +# SPDX-License-Identifier: Apache-2.0 +from typing import Dict, List, Tuple + +import numpy as np + +from Deeploy.DeeployTypes import NetworkContext, NodeTemplate, OperatorRepresentation + + +class _PowTemplate(NodeTemplate): + + def alignToContext(self, ctxt: NetworkContext, + operatorRepresentation: 
OperatorRepresentation) -> Tuple[NetworkContext, Dict, List[str]]: + # Get input and output tensors + data_in = ctxt.lookup(operatorRepresentation['data_in']) + exponent = ctxt.lookup(operatorRepresentation['exponent']) + data_out = ctxt.lookup(operatorRepresentation['data_out']) + + # Get data type (fp32) + data_type = data_in._type.typeName + operatorRepresentation['data_type'] = data_type + + # Get type width dynamically (e.g., 32, 64) + type_width = data_in._type.referencedType.typeWidth + operatorRepresentation['type_width'] = type_width + + # Calculate size + input_size = int(np.prod(data_in.shape)) + exponent_size = int(np.prod(exponent.shape)) + operatorRepresentation['size'] = input_size + + # Check if exponent is scalar (broadcasting) + if exponent_size == 1: + operatorRepresentation['is_scalar'] = True + # Get the full variable name with prefix + exponent_name = operatorRepresentation['exponent'] + operatorRepresentation['exponent_scalar'] = f"DeeployNetwork_{exponent_name}[0]" + else: + # Since currently the kernel only supports equally sized base-exponent data, + # for non-scalar, let's add a size check here (length of data_in should be equal to exponent length). 
+ if input_size != exponent_size: + raise ValueError(f"Pow operator mismatch: input size ({input_size}) " + f"must equal exponent size ({exponent_size}) for non-scalar exponents.") + + operatorRepresentation['is_scalar'] = False + operatorRepresentation['exponent_scalar'] = "NULL" + + return ctxt, operatorRepresentation, [] + + +referenceTemplate = _PowTemplate(""" +// Pow (Name: ${nodeName}, Op: ${nodeOp}) +% if is_scalar: +Pow_fp${type_width}_scalar_fp${type_width}(${data_in}, ${exponent_scalar}, ${data_out}, ${size}); +% else: +Pow_fp${type_width}_fp${type_width}_fp${type_width}(${data_in}, ${exponent}, ${data_out}, ${size}); +% endif +""") diff --git a/Deeploy/Targets/Generic/Templates/FloatSqrtTemplate.py b/Deeploy/Targets/Generic/Templates/FloatSqrtTemplate.py new file mode 100644 index 0000000000..99d7ba0475 --- /dev/null +++ b/Deeploy/Targets/Generic/Templates/FloatSqrtTemplate.py @@ -0,0 +1,35 @@ +# SPDX-FileCopyrightText: 2025 ETH Zurich and University of Bologna +# +# SPDX-License-Identifier: Apache-2.0 +from typing import Dict, List, Tuple + +import numpy as np + +from Deeploy.DeeployTypes import NetworkContext, NodeTemplate, OperatorRepresentation + + +class _SqrtTemplate(NodeTemplate): + + def alignToContext(self, ctxt: NetworkContext, + operatorRepresentation: OperatorRepresentation) -> Tuple[NetworkContext, Dict, List[str]]: + # Get input and output tensors + data_in = ctxt.lookup(operatorRepresentation['data_in']) + data_out = ctxt.lookup(operatorRepresentation['data_out']) + + # Get data type (fp32) + data_type = data_in._type.typeName + operatorRepresentation['data_type'] = data_type + + type_width = data_in._type.referencedType.typeWidth + operatorRepresentation['type_width'] = type_width + + # Calculate size + operatorRepresentation['size'] = int(np.prod(data_in.shape)) + + return ctxt, operatorRepresentation, [] + + +referenceTemplate = _SqrtTemplate(""" +// Sqrt (Name: ${nodeName}, Op: ${nodeOp}) 
+Sqrt_fp${type_width}_fp${type_width}(${data_in}, ${data_out}, ${size}); +""") diff --git a/DeeployTest/Tests/testFloatPowScalar/inputs.npz b/DeeployTest/Tests/testFloatPowScalar/inputs.npz new file mode 100644 index 0000000000..1a1fbacabc Binary files /dev/null and b/DeeployTest/Tests/testFloatPowScalar/inputs.npz differ diff --git a/DeeployTest/Tests/testFloatPowScalar/network.onnx b/DeeployTest/Tests/testFloatPowScalar/network.onnx new file mode 100644 index 0000000000..50701ff540 Binary files /dev/null and b/DeeployTest/Tests/testFloatPowScalar/network.onnx differ diff --git a/DeeployTest/Tests/testFloatPowScalar/outputs.npz b/DeeployTest/Tests/testFloatPowScalar/outputs.npz new file mode 100644 index 0000000000..8e99e82420 Binary files /dev/null and b/DeeployTest/Tests/testFloatPowScalar/outputs.npz differ diff --git a/DeeployTest/Tests/testFloatPowVector/inputs.npz b/DeeployTest/Tests/testFloatPowVector/inputs.npz new file mode 100644 index 0000000000..fe8942ac67 Binary files /dev/null and b/DeeployTest/Tests/testFloatPowVector/inputs.npz differ diff --git a/DeeployTest/Tests/testFloatPowVector/network.onnx b/DeeployTest/Tests/testFloatPowVector/network.onnx new file mode 100644 index 0000000000..91be88483c --- /dev/null +++ b/DeeployTest/Tests/testFloatPowVector/network.onnx @@ -0,0 +1,23 @@ + +deeploy_test_generator:· +3 +data_in +exponentdata_outPow_Vector_Test"Powtest_float_pow_vectorZ! 
+data_in + + + + +Z" +exponent + + + + +b" +data_out + + + + +B \ No newline at end of file diff --git a/DeeployTest/Tests/testFloatPowVector/outputs.npz b/DeeployTest/Tests/testFloatPowVector/outputs.npz new file mode 100644 index 0000000000..ebe9468d52 Binary files /dev/null and b/DeeployTest/Tests/testFloatPowVector/outputs.npz differ diff --git a/DeeployTest/Tests/testFloatRMSNorm/inputs.npz b/DeeployTest/Tests/testFloatRMSNorm/inputs.npz new file mode 100644 index 0000000000..60df101e2e Binary files /dev/null and b/DeeployTest/Tests/testFloatRMSNorm/inputs.npz differ diff --git a/DeeployTest/Tests/testFloatRMSNorm/network.onnx b/DeeployTest/Tests/testFloatRMSNorm/network.onnx new file mode 100644 index 0000000000..906e25d254 Binary files /dev/null and b/DeeployTest/Tests/testFloatRMSNorm/network.onnx differ diff --git a/DeeployTest/Tests/testFloatRMSNorm/outputs.npz b/DeeployTest/Tests/testFloatRMSNorm/outputs.npz new file mode 100644 index 0000000000..eb8c1c4942 Binary files /dev/null and b/DeeployTest/Tests/testFloatRMSNorm/outputs.npz differ diff --git a/DeeployTest/Tests/testFloatSqrt/inputs.npz b/DeeployTest/Tests/testFloatSqrt/inputs.npz new file mode 100644 index 0000000000..c54577cb24 Binary files /dev/null and b/DeeployTest/Tests/testFloatSqrt/inputs.npz differ diff --git a/DeeployTest/Tests/testFloatSqrt/network.onnx b/DeeployTest/Tests/testFloatSqrt/network.onnx new file mode 100644 index 0000000000..c2f27907fa Binary files /dev/null and b/DeeployTest/Tests/testFloatSqrt/network.onnx differ diff --git a/DeeployTest/Tests/testFloatSqrt/outputs.npz b/DeeployTest/Tests/testFloatSqrt/outputs.npz new file mode 100644 index 0000000000..f6d42c73a1 Binary files /dev/null and b/DeeployTest/Tests/testFloatSqrt/outputs.npz differ diff --git a/TargetLibraries/Generic/inc/DeeployBasicMath.h b/TargetLibraries/Generic/inc/DeeployBasicMath.h index 288cb419ac..4fbbd00bf8 100644 --- a/TargetLibraries/Generic/inc/DeeployBasicMath.h +++ 
b/TargetLibraries/Generic/inc/DeeployBasicMath.h @@ -44,6 +44,7 @@ #include "kernel/MatMul.h" #include "kernel/MaxPool.h" #include "kernel/MaxPool1d.h" +#include "kernel/Pow.h" #include "kernel/RMSNorm.h" #include "kernel/RQDiv.h" #include "kernel/RQGELU.h" @@ -51,5 +52,6 @@ #include "kernel/Relu.h" #include "kernel/RequantShift.h" #include "kernel/Softmax.h" +#include "kernel/Sqrt.h" #endif //__DEEPLOY_BASIC_MATH_HEADER_ diff --git a/TargetLibraries/Generic/inc/kernel/Pow.h b/TargetLibraries/Generic/inc/kernel/Pow.h new file mode 100644 index 0000000000..f1d64859ed --- /dev/null +++ b/TargetLibraries/Generic/inc/kernel/Pow.h @@ -0,0 +1,24 @@ +/* + * SPDX-FileCopyrightText: 2025 ETH Zurich and University of Bologna + * + * SPDX-License-Identifier: Apache-2.0 + */ + +/* + * This file declares the element-wise power kernels (vector and scalar exponent). + */ + +#ifndef __DEEPLOY_MATH_POW_KERNEL_HEADER_ +#define __DEEPLOY_MATH_POW_KERNEL_HEADER_ + +#include "DeeployBasicMath.h" + +void Pow_fp32_fp32_fp32(const float32_t *__restrict__ data_in, + const float32_t *__restrict__ exponent, + float32_t *__restrict__ data_out, int32_t size); + +void Pow_fp32_scalar_fp32(const float32_t *__restrict__ data_in, + float32_t exponent, float32_t *__restrict__ data_out, + int32_t size); + +#endif diff --git a/TargetLibraries/Generic/inc/kernel/Sqrt.h b/TargetLibraries/Generic/inc/kernel/Sqrt.h new file mode 100644 index 0000000000..2c14e43bd3 --- /dev/null +++ b/TargetLibraries/Generic/inc/kernel/Sqrt.h @@ -0,0 +1,22 @@ +/* + * SPDX-FileCopyrightText: 2020 ETH Zurich and University of Bologna + * + * SPDX-License-Identifier: Apache-2.0 + */ + +#ifndef __DEEPLOY_BASIC_MATH_SQRT_KERNEL_HEADER_ +#define __DEEPLOY_BASIC_MATH_SQRT_KERNEL_HEADER_ + +#include "DeeployBasicMath.h" + +/* + * Square root operation - computes sqrt for each element + */ + +/******************************************************************************/ +/* Sqrt */ 
+/******************************************************************************/ + +void Sqrt_fp32_fp32(float32_t *data_in, float32_t *data_out, int32_t size); + +#endif //__DEEPLOY_BASIC_MATH_SQRT_KERNEL_HEADER_ diff --git a/TargetLibraries/Generic/src/Pow_fp32.c b/TargetLibraries/Generic/src/Pow_fp32.c new file mode 100644 index 0000000000..89c07c6bda --- /dev/null +++ b/TargetLibraries/Generic/src/Pow_fp32.c @@ -0,0 +1,24 @@ +/* + * SPDX-FileCopyrightText: 2025 ETH Zurich and University of Bologna + * + * SPDX-License-Identifier: Apache-2.0 + */ + +#include "DeeployBasicMath.h" +#include <math.h> + +void Pow_fp32_fp32_fp32(const float32_t *__restrict__ data_in, + const float32_t *__restrict__ exponent, + float32_t *__restrict__ data_out, int32_t size) { + for (int i = 0; i < size; i++) { + data_out[i] = powf(data_in[i], exponent[i]); + } +} + +void Pow_fp32_scalar_fp32(const float32_t *__restrict__ data_in, + float32_t exponent, float32_t *__restrict__ data_out, + int32_t size) { + for (int i = 0; i < size; i++) { + data_out[i] = powf(data_in[i], exponent); + } +} diff --git a/TargetLibraries/Generic/src/Sqrt_fp32.c b/TargetLibraries/Generic/src/Sqrt_fp32.c new file mode 100644 index 0000000000..06327fda4e --- /dev/null +++ b/TargetLibraries/Generic/src/Sqrt_fp32.c @@ -0,0 +1,13 @@ +/* + * SPDX-FileCopyrightText: 2025 ETH Zurich and University of Bologna + * + * SPDX-License-Identifier: Apache-2.0 + */ + +#include "DeeployBasicMath.h" + +void Sqrt_fp32_fp32(float32_t *data_in, float32_t *data_out, int32_t size) { + for (int i = 0; i < size; i++) { + data_out[i] = sqrtf(data_in[i]); + } +}