Skip to content

Commit b16668c

Browse files
committed
Add a function with SYCL kernel for divmod
1 parent fc95d01 commit b16668c

File tree

1 file changed

+120
-0
lines changed
  • dpnp/backend/kernels/elementwise_functions

1 file changed

+120
-0
lines changed
Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
//*****************************************************************************
2+
// Copyright (c) 2025, Intel Corporation
3+
// All rights reserved.
4+
//
5+
// Redistribution and use in source and binary forms, with or without
6+
// modification, are permitted provided that the following conditions are met:
7+
// - Redistributions of source code must retain the above copyright notice,
8+
// this list of conditions and the following disclaimer.
9+
// - Redistributions in binary form must reproduce the above copyright notice,
10+
// this list of conditions and the following disclaimer in the documentation
11+
// and/or other materials provided with the distribution.
12+
// - Neither the name of the copyright holder nor the names of its contributors
13+
// may be used to endorse or promote products derived from this software
14+
// without specific prior written permission.
15+
//
16+
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17+
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18+
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19+
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20+
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21+
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22+
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23+
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24+
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25+
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
26+
// THE POSSIBILITY OF SUCH DAMAGE.
27+
//*****************************************************************************
28+
29+
#pragma once
30+
31+
#include <sycl/sycl.hpp>
32+
33+
namespace dpnp::kernels::divmod
34+
{
35+
template <typename argT1, typename argT2, typename divT, typename modT>
36+
struct DivmodFunctor
37+
{
38+
using argT = argT1;
39+
40+
static_assert(std::is_same_v<argT, argT2>,
41+
"Input types are expected to be the same");
42+
static_assert(std::is_integral_v<argT> || std::is_floating_point_v<argT> ||
43+
std::is_same_v<argT, sycl::half>,
44+
"Input types are expected to be integral or floating");
45+
46+
using supports_vec = typename std::false_type;
47+
using supports_sg_loadstore = typename std::true_type;
48+
49+
divT operator()(const argT &in1, const argT &in2, modT &mod) const
50+
{
51+
if constexpr (std::is_integral_v<argT>) {
52+
if (in2 == argT(0)) {
53+
mod = modT(0);
54+
return divT(0);
55+
}
56+
57+
if constexpr (std::is_signed_v<argT>) {
58+
if ((in1 == std::numeric_limits<argT>::min()) &&
59+
(in2 == argT(-1))) {
60+
mod = modT(0);
61+
return std::numeric_limits<argT>::min();
62+
}
63+
}
64+
65+
divT div = in1 / in2;
66+
mod = in1 % in2;
67+
68+
if constexpr (std::is_signed_v<argT>) {
69+
if (l_xor(in1 > 0, in2 > 0) && (mod != 0)) {
70+
div -= divT(1);
71+
mod += in2;
72+
}
73+
}
74+
return div;
75+
}
76+
else {
77+
mod = sycl::fmod(in1, in2);
78+
if (!in2) {
79+
// in2 == 0 (not NaN): return result of fmod (for IEEE is nan)
80+
return in1 / in2;
81+
}
82+
83+
// (in1 - mod) should be very nearly an integer multiple of in2
84+
auto div = (in1 - mod) / in2;
85+
86+
// adjust fmod result to conform to Python convention of remainder
87+
if (mod) {
88+
if (l_xor(in2 < 0, mod < 0)) {
89+
mod += in2;
90+
div -= divT(1.0);
91+
}
92+
}
93+
else {
94+
// if mod is zero ensure correct sign
95+
mod = sycl::copysign(modT(0), in2);
96+
}
97+
98+
// snap quotient to nearest integral value
99+
if (div) {
100+
auto floordiv = sycl::floor(div);
101+
if (div - floordiv > divT(0.5)) {
102+
floordiv += divT(1.0);
103+
}
104+
div = floordiv;
105+
}
106+
else {
107+
// if div is zero ensure correct sign
108+
div = sycl::copysign(divT(0), in1 / in2);
109+
}
110+
return div;
111+
}
112+
}
113+
114+
private:
115+
bool l_xor(bool b1, bool b2) const
116+
{
117+
return (b1 != b2);
118+
}
119+
};
120+
} // namespace dpnp::kernels::divmod

0 commit comments

Comments
 (0)