@@ -199,77 +199,6 @@ enum class RoundingMode
199199#endif
200200};
201201
202- /* * Rounding functions for decimal values
203- */
204-
205- template <typename T, RoundingMode rounding_mode, ScaleMode scale_mode, typename OutputType>
206- struct DecimalRoundingComputation
207- {
208- static_assert (IsDecimal<T>);
209- static const size_t data_count = 1 ;
210- static size_t prepare (size_t scale) { return scale; }
211- // compute need decimal_scale to interpret decimals
212- static inline void compute (
213- const T * __restrict in,
214- size_t scale,
215- OutputType * __restrict out,
216- ScaleType decimal_scale)
217- {
218- static_assert (std::is_same_v<T, OutputType> || std::is_same_v<OutputType, Int64>);
219- Float64 val = in->template toFloat <Float64>(decimal_scale);
220-
221- if constexpr (scale_mode == ScaleMode::Positive)
222- {
223- val = val * scale;
224- }
225- else if constexpr (scale_mode == ScaleMode::Negative)
226- {
227- val = val / scale;
228- }
229-
230- if constexpr (rounding_mode == RoundingMode::Round)
231- {
232- val = round (val);
233- }
234- else if constexpr (rounding_mode == RoundingMode::Floor)
235- {
236- val = floor (val);
237- }
238- else if constexpr (rounding_mode == RoundingMode::Ceil)
239- {
240- val = ceil (val);
241- }
242- else if constexpr (rounding_mode == RoundingMode::Trunc)
243- {
244- val = trunc (val);
245- }
246-
247-
248- if constexpr (scale_mode == ScaleMode::Positive)
249- {
250- val = val / scale;
251- }
252- else if constexpr (scale_mode == ScaleMode::Negative)
253- {
254- val = val * scale;
255- }
256-
257- if constexpr (std::is_same_v<T, OutputType>)
258- {
259- *out = ToDecimal<Float64, T>(val, decimal_scale);
260- }
261- else if constexpr (std::is_same_v<OutputType, Int64>)
262- {
263- *out = static_cast <Int64>(val);
264- }
265- else
266- {
267- ; // never arrived here
268- }
269- }
270- };
271-
272-
273202/* * Rounding functions for integer values.
274203 */
275204template <typename T, RoundingMode rounding_mode, ScaleMode scale_mode>
@@ -327,12 +256,75 @@ struct IntegerRoundingComputation
327256 }
328257 }
329258
330- static ALWAYS_INLINE void compute (const T * __restrict in, size_t scale, T * __restrict out)
259+ static ALWAYS_INLINE void compute (const T * __restrict in, T scale, T * __restrict out)
331260 {
332261 *out = compute (*in, scale);
333262 }
334263};
335264
265+ /* * Rounding functions for decimal values
266+ */
267+
268+ template <typename T, RoundingMode rounding_mode, ScaleMode scale_mode, typename OutputType>
269+ struct DecimalRoundingComputation
270+ {
271+ static_assert (IsDecimal<T>);
272+ static const size_t data_count = 1 ;
273+ static size_t prepare (size_t scale) { return scale; }
274+ // compute need decimal_scale to interpret decimals
275+ static inline void compute (
276+ const T * __restrict in,
277+ size_t scale,
278+ OutputType * __restrict out,
279+ ScaleType decimal_scale)
280+ {
281+ static_assert (std::is_same_v<T, OutputType> || std::is_same_v<OutputType, Int64>);
282+ using NativeType = T::NativeType;
283+ // Currently, we only use DecimalRoundingComputation for floor/ceil.
284+ // As for round/truncate, we always use tidbRoundWithFrac/tidbTruncateWithFrac.
285+ // So, we only handle ScaleMode::Zero here.
286+ if constexpr (scale_mode == ScaleMode::Zero)
287+ {
288+ auto scale_factor = intExp10OfSize<NativeType>(decimal_scale);
289+ try
290+ {
291+ if constexpr (rounding_mode == RoundingMode::Floor)
292+ {
293+ auto x = in->value ;
294+ if (x < 0 )
295+ x -= scale_factor - 1 ;
296+ *out = static_cast <OutputType>(x / scale_factor);
297+ }
298+ else if constexpr (rounding_mode == RoundingMode::Ceil)
299+ {
300+ auto x = in->value ;
301+ if (x >= 0 )
302+ x += scale_factor - 1 ;
303+ *out = static_cast <OutputType>(x / scale_factor);
304+ }
305+ else
306+ {
307+ throw Exception (
308+ " Logical error: unexpected 'rounding_mode' of DecimalRoundingComputation" ,
309+ ErrorCodes::LOGICAL_ERROR);
310+ }
311+ }
312+ catch (const std::overflow_error & e)
313+ {
314+ throw Exception (
315+ " Logical error: unexpected overflow in DecimalRoundingComputation" ,
316+ ErrorCodes::LOGICAL_ERROR);
317+ }
318+ }
319+ else
320+ {
321+ throw Exception (
322+ " Logical error: unexpected 'scale_mode' of DecimalRoundingComputation and unexpected scale: "
323+ + toString (scale),
324+ ErrorCodes::LOGICAL_ERROR);
325+ }
326+ }
327+ };
336328
337329#if __SSE4_1__
338330
@@ -540,7 +532,7 @@ struct IntegerRoundingImpl
540532
541533 while (p_in < end_in)
542534 {
543- Op::compute (p_in, scale, p_out);
535+ Op::compute (p_in, static_cast <T>( scale) , p_out);
544536 ++p_in;
545537 ++p_out;
546538 }
@@ -698,7 +690,12 @@ struct Dispatcher
698690
699691 if constexpr (IsDecimal<OutputType>)
700692 {
701- auto col_res = ColumnDecimal<OutputType>::create (col->getData ().size (), col->getData ().getScale ());
693+ UInt32 res_scale = 0 ;
694+ if constexpr (rounding_mode == RoundingMode::Round || rounding_mode == RoundingMode::Trunc)
695+ {
696+ res_scale = col->getData ().getScale ();
697+ }
698+ auto col_res = ColumnDecimal<OutputType>::create (col->getData ().size (), res_scale);
702699 typename ColumnDecimal<OutputType>::Container & vec_res = col_res->getData ();
703700 applyInternal (col, vec_res, col_res, block, scale_arg, result);
704701 }
@@ -813,6 +810,20 @@ class FunctionRounding : public IFunction
813810 fmt::format (" Illegal type {} of argument of function {}" , arguments[0 ]->getName (), getName ()),
814811 ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT);
815812
813+ if constexpr (rounding_mode == RoundingMode::Ceil || rounding_mode == RoundingMode::Floor)
814+ {
815+ if (arguments[0 ]->isDecimal ())
816+ {
817+ if (const auto * decimal_type32 = checkAndGetDataType<DataTypeDecimal32>(arguments[0 ].get ()))
818+ return std::make_shared<DataTypeDecimal32>(decimal_type32->getPrec (), 0 );
819+ else if (const auto * decimal_type64 = checkAndGetDataType<DataTypeDecimal64>(arguments[0 ].get ()))
820+ return std::make_shared<DataTypeDecimal64>(decimal_type64->getPrec (), 0 );
821+ else if (const auto * decimal_type128 = checkAndGetDataType<DataTypeDecimal128>(arguments[0 ].get ()))
822+ return std::make_shared<DataTypeDecimal128>(decimal_type128->getPrec (), 0 );
823+ else if (const auto * decimal_type256 = checkAndGetDataType<DataTypeDecimal256>(arguments[0 ].get ()))
824+ return std::make_shared<DataTypeDecimal256>(decimal_type256->getPrec (), 0 );
825+ }
826+ }
816827 return arguments[0 ];
817828 }
818829
0 commit comments