Skip to content

Commit c4ccf0f

Browse files
ChangRui-Ryanti-chi-bot
authored andcommitted
This is an automated cherry-pick of #10364
Signed-off-by: ti-chi-bot <ti-community-prow-bot@tidb.io>
1 parent c9d93fc commit c4ccf0f

File tree

4 files changed

+700
-5
lines changed

4 files changed

+700
-5
lines changed

dbms/src/Functions/FunctionsRound.h

Lines changed: 97 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,7 @@ enum class RoundingMode
206206
#endif
207207
};
208208

209+
<<<<<<< HEAD
209210
/** Rounding functions for decimal values
210211
*/
211212

@@ -276,6 +277,8 @@ struct DecimalRoundingComputation
276277
};
277278

278279

280+
=======
281+
>>>>>>> f5f30b8ffc (Fix decimal floor/ceil (#10365) (#10364))
279282
/** Rounding functions for integer values.
280283
*/
281284
template <typename T, RoundingMode rounding_mode, ScaleMode scale_mode>
@@ -336,12 +339,74 @@ struct IntegerRoundingComputation
336339
}
337340
}
338341

339-
static ALWAYS_INLINE void compute(const T * __restrict in, size_t scale, T * __restrict out)
342+
static ALWAYS_INLINE void compute(const T * __restrict in, T scale, T * __restrict out)
340343
{
341344
*out = compute(*in, scale);
342345
}
343346
};
344347

348+
/** Rounding functions for decimal values
349+
*/
350+
351+
template <typename T, RoundingMode rounding_mode, ScaleMode scale_mode, typename OutputType>
352+
struct DecimalRoundingComputation
353+
{
354+
static_assert(IsDecimal<T>);
355+
using NativeType = T::NativeType;
356+
static const size_t data_count = 1;
357+
static size_t prepare(size_t scale) { return scale; }
358+
// compute need decimal_scale to interpret decimals
359+
static inline void compute(
360+
const T * __restrict in,
361+
size_t scale,
362+
OutputType * __restrict out,
363+
NativeType decimal_scale)
364+
{
365+
static_assert(std::is_same_v<T, OutputType> || std::is_same_v<OutputType, Int64>);
366+
// Currently, we only use DecimalRoundingComputation for floor/ceil.
367+
// As for round/truncate, we always use tidbRoundWithFrac/tidbTruncateWithFrac.
368+
// So, we only handle ScaleMode::Zero here.
369+
if constexpr (scale_mode == ScaleMode::Zero)
370+
{
371+
try
372+
{
373+
if constexpr (rounding_mode == RoundingMode::Floor)
374+
{
375+
auto x = in->value;
376+
if (x < 0)
377+
x -= decimal_scale - 1;
378+
*out = static_cast<OutputType>(x / decimal_scale);
379+
}
380+
else if constexpr (rounding_mode == RoundingMode::Ceil)
381+
{
382+
auto x = in->value;
383+
if (x >= 0)
384+
x += decimal_scale - 1;
385+
*out = static_cast<OutputType>(x / decimal_scale);
386+
}
387+
else
388+
{
389+
throw Exception(
390+
"Logical error: unexpected 'rounding_mode' of DecimalRoundingComputation",
391+
ErrorCodes::LOGICAL_ERROR);
392+
}
393+
}
394+
catch (const std::overflow_error & e)
395+
{
396+
throw Exception(
397+
"Logical error: unexpected overflow in DecimalRoundingComputation",
398+
ErrorCodes::LOGICAL_ERROR);
399+
}
400+
}
401+
else
402+
{
403+
throw Exception(
404+
"Logical error: unexpected 'scale_mode' of DecimalRoundingComputation and unexpected scale: "
405+
+ toString(scale),
406+
ErrorCodes::LOGICAL_ERROR);
407+
}
408+
}
409+
};
345410

346411
#if __SSE4_1__
347412

@@ -554,7 +619,7 @@ struct IntegerRoundingImpl
554619

555620
while (p_in < end_in)
556621
{
557-
Op::compute(p_in, scale, p_out);
622+
Op::compute(p_in, static_cast<T>(scale), p_out);
558623
++p_in;
559624
++p_out;
560625
}
@@ -620,14 +685,18 @@ struct DecimalRoundingImpl;
620685
template <typename T, RoundingMode rounding_mode, ScaleMode scale_mode>
621686
struct DecimalRoundingImpl<T, rounding_mode, scale_mode, Int64>
622687
{
688+
static_assert(IsDecimal<T>);
689+
using NativeType = typename T::NativeType;
690+
623691
private:
624692
using Op = DecimalRoundingComputation<T, rounding_mode, scale_mode, Int64>;
625693
using Data = T;
626694

627695
public:
628696
static NO_INLINE void apply(const DecimalPaddedPODArray<T> & in, size_t scale, typename ColumnVector<Int64>::Container & out)
629697
{
630-
ScaleType decimal_scale = in.getScale();
698+
ScaleType in_scale = in.getScale();
699+
auto decimal_scale = intExp10OfSize<NativeType>(in_scale);
631700
const T * end_in = in.data() + in.size();
632701

633702
const T * __restrict p_in = in.data();
@@ -645,14 +714,18 @@ struct DecimalRoundingImpl<T, rounding_mode, scale_mode, Int64>
645714
template <typename T, RoundingMode rounding_mode, ScaleMode scale_mode>
646715
struct DecimalRoundingImpl<T, rounding_mode, scale_mode, T>
647716
{
717+
static_assert(IsDecimal<T>);
718+
using NativeType = typename T::NativeType;
719+
648720
private:
649721
using Op = DecimalRoundingComputation<T, rounding_mode, scale_mode, T>;
650722
using Data = T;
651723

652724
public:
653725
static NO_INLINE void apply(const DecimalPaddedPODArray<T> & in, size_t scale, typename ColumnDecimal<T>::Container & out)
654726
{
655-
ScaleType decimal_scale = in.getScale();
727+
ScaleType in_scale = in.getScale();
728+
auto decimal_scale = intExp10OfSize<NativeType>(in_scale);
656729
const T * end_in = in.data() + in.size();
657730

658731
const T * __restrict p_in = in.data();
@@ -705,7 +778,12 @@ struct Dispatcher
705778

706779
if constexpr (IsDecimal<OutputType>)
707780
{
708-
auto col_res = ColumnDecimal<OutputType>::create(col->getData().size(), col->getData().getScale());
781+
UInt32 res_scale = 0;
782+
if constexpr (rounding_mode == RoundingMode::Round || rounding_mode == RoundingMode::Trunc)
783+
{
784+
res_scale = col->getData().getScale();
785+
}
786+
auto col_res = ColumnDecimal<OutputType>::create(col->getData().size(), res_scale);
709787
typename ColumnDecimal<OutputType>::Container & vec_res = col_res->getData();
710788
applyInternal(col, vec_res, col_res, block, scale_arg, result);
711789
}
@@ -808,6 +886,20 @@ class FunctionRounding : public IFunction
808886
fmt::format("Illegal type {} of argument of function {}", arguments[0]->getName(), getName()),
809887
ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
810888

889+
if constexpr (rounding_mode == RoundingMode::Ceil || rounding_mode == RoundingMode::Floor)
890+
{
891+
if (arguments[0]->isDecimal())
892+
{
893+
if (const auto * decimal_type32 = checkAndGetDataType<DataTypeDecimal32>(arguments[0].get()))
894+
return std::make_shared<DataTypeDecimal32>(decimal_type32->getPrec(), 0);
895+
else if (const auto * decimal_type64 = checkAndGetDataType<DataTypeDecimal64>(arguments[0].get()))
896+
return std::make_shared<DataTypeDecimal64>(decimal_type64->getPrec(), 0);
897+
else if (const auto * decimal_type128 = checkAndGetDataType<DataTypeDecimal128>(arguments[0].get()))
898+
return std::make_shared<DataTypeDecimal128>(decimal_type128->getPrec(), 0);
899+
else if (const auto * decimal_type256 = checkAndGetDataType<DataTypeDecimal256>(arguments[0].get()))
900+
return std::make_shared<DataTypeDecimal256>(decimal_type256->getPrec(), 0);
901+
}
902+
}
811903
return arguments[0];
812904
}
813905

0 commit comments

Comments
 (0)