@@ -206,6 +206,7 @@ enum class RoundingMode
206206#endif
207207};
208208
209+ <<<<<<< HEAD
209210/* * Rounding functions for decimal values
210211 */
211212
@@ -276,6 +277,8 @@ struct DecimalRoundingComputation
276277};
277278
278279
280+ =======
281+ >>>>>>> f5f30b8ffc (Fix decimal floor/ceil (#10365 ) (#10364 ))
279282/* * Rounding functions for integer values.
280283 */
281284template <typename T, RoundingMode rounding_mode, ScaleMode scale_mode>
@@ -336,12 +339,74 @@ struct IntegerRoundingComputation
336339 }
337340 }
338341
339- static ALWAYS_INLINE void compute (const T * __restrict in, size_t scale, T * __restrict out)
342+ static ALWAYS_INLINE void compute (const T * __restrict in, T scale, T * __restrict out)
340343 {
341344 *out = compute (*in, scale);
342345 }
343346};
344347
348+ /* * Rounding functions for decimal values
349+ */
350+
351+ template <typename T, RoundingMode rounding_mode, ScaleMode scale_mode, typename OutputType>
352+ struct DecimalRoundingComputation
353+ {
354+ static_assert (IsDecimal<T>);
355+ using NativeType = T::NativeType;
356+ static const size_t data_count = 1 ;
357+ static size_t prepare (size_t scale) { return scale; }
358+ // compute need decimal_scale to interpret decimals
359+ static inline void compute (
360+ const T * __restrict in,
361+ size_t scale,
362+ OutputType * __restrict out,
363+ NativeType decimal_scale)
364+ {
365+ static_assert (std::is_same_v<T, OutputType> || std::is_same_v<OutputType, Int64>);
366+ // Currently, we only use DecimalRoundingComputation for floor/ceil.
367+ // As for round/truncate, we always use tidbRoundWithFrac/tidbTruncateWithFrac.
368+ // So, we only handle ScaleMode::Zero here.
369+ if constexpr (scale_mode == ScaleMode::Zero)
370+ {
371+ try
372+ {
373+ if constexpr (rounding_mode == RoundingMode::Floor)
374+ {
375+ auto x = in->value ;
376+ if (x < 0 )
377+ x -= decimal_scale - 1 ;
378+ *out = static_cast <OutputType>(x / decimal_scale);
379+ }
380+ else if constexpr (rounding_mode == RoundingMode::Ceil)
381+ {
382+ auto x = in->value ;
383+ if (x >= 0 )
384+ x += decimal_scale - 1 ;
385+ *out = static_cast <OutputType>(x / decimal_scale);
386+ }
387+ else
388+ {
389+ throw Exception (
390+ " Logical error: unexpected 'rounding_mode' of DecimalRoundingComputation" ,
391+ ErrorCodes::LOGICAL_ERROR);
392+ }
393+ }
394+ catch (const std::overflow_error & e)
395+ {
396+ throw Exception (
397+ " Logical error: unexpected overflow in DecimalRoundingComputation" ,
398+ ErrorCodes::LOGICAL_ERROR);
399+ }
400+ }
401+ else
402+ {
403+ throw Exception (
404+ " Logical error: unexpected 'scale_mode' of DecimalRoundingComputation and unexpected scale: "
405+ + toString (scale),
406+ ErrorCodes::LOGICAL_ERROR);
407+ }
408+ }
409+ };
345410
346411#if __SSE4_1__
347412
@@ -554,7 +619,7 @@ struct IntegerRoundingImpl
554619
555620 while (p_in < end_in)
556621 {
557- Op::compute (p_in, scale, p_out);
622+ Op::compute (p_in, static_cast <T>( scale) , p_out);
558623 ++p_in;
559624 ++p_out;
560625 }
@@ -620,14 +685,18 @@ struct DecimalRoundingImpl;
620685template <typename T, RoundingMode rounding_mode, ScaleMode scale_mode>
621686struct DecimalRoundingImpl <T, rounding_mode, scale_mode, Int64>
622687{
688+ static_assert (IsDecimal<T>);
689+ using NativeType = typename T::NativeType;
690+
623691private:
624692 using Op = DecimalRoundingComputation<T, rounding_mode, scale_mode, Int64>;
625693 using Data = T;
626694
627695public:
628696 static NO_INLINE void apply (const DecimalPaddedPODArray<T> & in, size_t scale, typename ColumnVector<Int64>::Container & out)
629697 {
630- ScaleType decimal_scale = in.getScale ();
698+ ScaleType in_scale = in.getScale ();
699+ auto decimal_scale = intExp10OfSize<NativeType>(in_scale);
631700 const T * end_in = in.data () + in.size ();
632701
633702 const T * __restrict p_in = in.data ();
@@ -645,14 +714,18 @@ struct DecimalRoundingImpl<T, rounding_mode, scale_mode, Int64>
645714template <typename T, RoundingMode rounding_mode, ScaleMode scale_mode>
646715struct DecimalRoundingImpl <T, rounding_mode, scale_mode, T>
647716{
717+ static_assert (IsDecimal<T>);
718+ using NativeType = typename T::NativeType;
719+
648720private:
649721 using Op = DecimalRoundingComputation<T, rounding_mode, scale_mode, T>;
650722 using Data = T;
651723
652724public:
653725 static NO_INLINE void apply (const DecimalPaddedPODArray<T> & in, size_t scale, typename ColumnDecimal<T>::Container & out)
654726 {
655- ScaleType decimal_scale = in.getScale ();
727+ ScaleType in_scale = in.getScale ();
728+ auto decimal_scale = intExp10OfSize<NativeType>(in_scale);
656729 const T * end_in = in.data () + in.size ();
657730
658731 const T * __restrict p_in = in.data ();
@@ -705,7 +778,12 @@ struct Dispatcher
705778
706779 if constexpr (IsDecimal<OutputType>)
707780 {
708- auto col_res = ColumnDecimal<OutputType>::create (col->getData ().size (), col->getData ().getScale ());
781+ UInt32 res_scale = 0 ;
782+ if constexpr (rounding_mode == RoundingMode::Round || rounding_mode == RoundingMode::Trunc)
783+ {
784+ res_scale = col->getData ().getScale ();
785+ }
786+ auto col_res = ColumnDecimal<OutputType>::create (col->getData ().size (), res_scale);
709787 typename ColumnDecimal<OutputType>::Container & vec_res = col_res->getData ();
710788 applyInternal (col, vec_res, col_res, block, scale_arg, result);
711789 }
@@ -808,6 +886,20 @@ class FunctionRounding : public IFunction
808886 fmt::format (" Illegal type {} of argument of function {}" , arguments[0 ]->getName (), getName ()),
809887 ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
810888
889+ if constexpr (rounding_mode == RoundingMode::Ceil || rounding_mode == RoundingMode::Floor)
890+ {
891+ if (arguments[0 ]->isDecimal ())
892+ {
893+ if (const auto * decimal_type32 = checkAndGetDataType<DataTypeDecimal32>(arguments[0 ].get ()))
894+ return std::make_shared<DataTypeDecimal32>(decimal_type32->getPrec (), 0 );
895+ else if (const auto * decimal_type64 = checkAndGetDataType<DataTypeDecimal64>(arguments[0 ].get ()))
896+ return std::make_shared<DataTypeDecimal64>(decimal_type64->getPrec (), 0 );
897+ else if (const auto * decimal_type128 = checkAndGetDataType<DataTypeDecimal128>(arguments[0 ].get ()))
898+ return std::make_shared<DataTypeDecimal128>(decimal_type128->getPrec (), 0 );
899+ else if (const auto * decimal_type256 = checkAndGetDataType<DataTypeDecimal256>(arguments[0 ].get ()))
900+ return std::make_shared<DataTypeDecimal256>(decimal_type256->getPrec (), 0 );
901+ }
902+ }
811903 return arguments[0 ];
812904 }
813905
0 commit comments