Skip to content

Commit fb9278c

Browse files
committed
Fix decimal floor/ceil (#10365)
1 parent c6dc5fc commit fb9278c

File tree

4 files changed

+688
-74
lines changed

4 files changed

+688
-74
lines changed

dbms/src/Functions/FunctionsRound.h

Lines changed: 85 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -199,77 +199,6 @@ enum class RoundingMode
199199
#endif
200200
};
201201

202-
/** Rounding functions for decimal values
203-
*/
204-
205-
template <typename T, RoundingMode rounding_mode, ScaleMode scale_mode, typename OutputType>
206-
struct DecimalRoundingComputation
207-
{
208-
static_assert(IsDecimal<T>);
209-
static const size_t data_count = 1;
210-
static size_t prepare(size_t scale) { return scale; }
211-
// compute need decimal_scale to interpret decimals
212-
static inline void compute(
213-
const T * __restrict in,
214-
size_t scale,
215-
OutputType * __restrict out,
216-
ScaleType decimal_scale)
217-
{
218-
static_assert(std::is_same_v<T, OutputType> || std::is_same_v<OutputType, Int64>);
219-
Float64 val = in->template toFloat<Float64>(decimal_scale);
220-
221-
if constexpr (scale_mode == ScaleMode::Positive)
222-
{
223-
val = val * scale;
224-
}
225-
else if constexpr (scale_mode == ScaleMode::Negative)
226-
{
227-
val = val / scale;
228-
}
229-
230-
if constexpr (rounding_mode == RoundingMode::Round)
231-
{
232-
val = round(val);
233-
}
234-
else if constexpr (rounding_mode == RoundingMode::Floor)
235-
{
236-
val = floor(val);
237-
}
238-
else if constexpr (rounding_mode == RoundingMode::Ceil)
239-
{
240-
val = ceil(val);
241-
}
242-
else if constexpr (rounding_mode == RoundingMode::Trunc)
243-
{
244-
val = trunc(val);
245-
}
246-
247-
248-
if constexpr (scale_mode == ScaleMode::Positive)
249-
{
250-
val = val / scale;
251-
}
252-
else if constexpr (scale_mode == ScaleMode::Negative)
253-
{
254-
val = val * scale;
255-
}
256-
257-
if constexpr (std::is_same_v<T, OutputType>)
258-
{
259-
*out = ToDecimal<Float64, T>(val, decimal_scale);
260-
}
261-
else if constexpr (std::is_same_v<OutputType, Int64>)
262-
{
263-
*out = static_cast<Int64>(val);
264-
}
265-
else
266-
{
267-
; // never arrived here
268-
}
269-
}
270-
};
271-
272-
273202
/** Rounding functions for integer values.
274203
*/
275204
template <typename T, RoundingMode rounding_mode, ScaleMode scale_mode>
@@ -327,12 +256,75 @@ struct IntegerRoundingComputation
327256
}
328257
}
329258

330-
static ALWAYS_INLINE void compute(const T * __restrict in, size_t scale, T * __restrict out)
259+
static ALWAYS_INLINE void compute(const T * __restrict in, T scale, T * __restrict out)
331260
{
332261
*out = compute(*in, scale);
333262
}
334263
};
335264

265+
/** Rounding functions for decimal values
266+
*/
267+
268+
template <typename T, RoundingMode rounding_mode, ScaleMode scale_mode, typename OutputType>
269+
struct DecimalRoundingComputation
270+
{
271+
static_assert(IsDecimal<T>);
272+
static const size_t data_count = 1;
273+
static size_t prepare(size_t scale) { return scale; }
274+
// compute need decimal_scale to interpret decimals
275+
static inline void compute(
276+
const T * __restrict in,
277+
size_t scale,
278+
OutputType * __restrict out,
279+
ScaleType decimal_scale)
280+
{
281+
static_assert(std::is_same_v<T, OutputType> || std::is_same_v<OutputType, Int64>);
282+
using NativeType = T::NativeType;
283+
// Currently, we only use DecimalRoundingComputation for floor/ceil.
284+
// As for round/truncate, we always use tidbRoundWithFrac/tidbTruncateWithFrac.
285+
// So, we only handle ScaleMode::Zero here.
286+
if constexpr (scale_mode == ScaleMode::Zero)
287+
{
288+
auto scale_factor = intExp10OfSize<NativeType>(decimal_scale);
289+
try
290+
{
291+
if constexpr (rounding_mode == RoundingMode::Floor)
292+
{
293+
auto x = in->value;
294+
if (x < 0)
295+
x -= scale_factor - 1;
296+
*out = static_cast<OutputType>(x / scale_factor);
297+
}
298+
else if constexpr (rounding_mode == RoundingMode::Ceil)
299+
{
300+
auto x = in->value;
301+
if (x >= 0)
302+
x += scale_factor - 1;
303+
*out = static_cast<OutputType>(x / scale_factor);
304+
}
305+
else
306+
{
307+
throw Exception(
308+
"Logical error: unexpected 'rounding_mode' of DecimalRoundingComputation",
309+
ErrorCodes::LOGICAL_ERROR);
310+
}
311+
}
312+
catch (const std::overflow_error & e)
313+
{
314+
throw Exception(
315+
"Logical error: unexpected overflow in DecimalRoundingComputation",
316+
ErrorCodes::LOGICAL_ERROR);
317+
}
318+
}
319+
else
320+
{
321+
throw Exception(
322+
"Logical error: unexpected 'scale_mode' of DecimalRoundingComputation and unexpected scale: "
323+
+ toString(scale),
324+
ErrorCodes::LOGICAL_ERROR);
325+
}
326+
}
327+
};
336328

337329
#if __SSE4_1__
338330

@@ -540,7 +532,7 @@ struct IntegerRoundingImpl
540532

541533
while (p_in < end_in)
542534
{
543-
Op::compute(p_in, scale, p_out);
535+
Op::compute(p_in, static_cast<T>(scale), p_out);
544536
++p_in;
545537
++p_out;
546538
}
@@ -698,7 +690,12 @@ struct Dispatcher
698690

699691
if constexpr (IsDecimal<OutputType>)
700692
{
701-
auto col_res = ColumnDecimal<OutputType>::create(col->getData().size(), col->getData().getScale());
693+
UInt32 res_scale = 0;
694+
if constexpr (rounding_mode == RoundingMode::Round || rounding_mode == RoundingMode::Trunc)
695+
{
696+
res_scale = col->getData().getScale();
697+
}
698+
auto col_res = ColumnDecimal<OutputType>::create(col->getData().size(), res_scale);
702699
typename ColumnDecimal<OutputType>::Container & vec_res = col_res->getData();
703700
applyInternal(col, vec_res, col_res, block, scale_arg, result);
704701
}
@@ -813,6 +810,20 @@ class FunctionRounding : public IFunction
813810
fmt::format("Illegal type {} of argument of function {}", arguments[0]->getName(), getName()),
814811
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
815812

813+
if constexpr (rounding_mode == RoundingMode::Ceil || rounding_mode == RoundingMode::Floor)
814+
{
815+
if (arguments[0]->isDecimal())
816+
{
817+
if (const auto * decimal_type32 = checkAndGetDataType<DataTypeDecimal32>(arguments[0].get()))
818+
return std::make_shared<DataTypeDecimal32>(decimal_type32->getPrec(), 0);
819+
else if (const auto * decimal_type64 = checkAndGetDataType<DataTypeDecimal64>(arguments[0].get()))
820+
return std::make_shared<DataTypeDecimal64>(decimal_type64->getPrec(), 0);
821+
else if (const auto * decimal_type128 = checkAndGetDataType<DataTypeDecimal128>(arguments[0].get()))
822+
return std::make_shared<DataTypeDecimal128>(decimal_type128->getPrec(), 0);
823+
else if (const auto * decimal_type256 = checkAndGetDataType<DataTypeDecimal256>(arguments[0].get()))
824+
return std::make_shared<DataTypeDecimal256>(decimal_type256->getPrec(), 0);
825+
}
826+
}
816827
return arguments[0];
817828
}
818829

0 commit comments

Comments
 (0)