From c2cbb795200734b68c2bf1a1d1b2afdb5877f4c6 Mon Sep 17 00:00:00 2001 From: fys Date: Wed, 3 Jun 2026 22:56:28 +0800 Subject: [PATCH] fix: timestamp comparisons to coerce to finer unit --- datafusion/common/src/scalar/mod.rs | 46 +++++++++- datafusion/expr-common/src/columnar_value.rs | 86 +++++++++++++++++-- .../expr-common/src/type_coercion/binary.rs | 20 +---- .../type_coercion/binary/tests/arithmetic.rs | 4 +- .../type_coercion/binary/tests/comparison.rs | 36 ++++++++ .../test_files/datetime/timestamps.slt | 53 +++++++++++- 6 files changed, 214 insertions(+), 31 deletions(-) diff --git a/datafusion/common/src/scalar/mod.rs b/datafusion/common/src/scalar/mod.rs index 3e154b491eda7..ec8b5dd00ed42 100644 --- a/datafusion/common/src/scalar/mod.rs +++ b/datafusion/common/src/scalar/mod.rs @@ -148,6 +148,30 @@ pub fn date_to_timestamp_multiplier( } } +/// Returns the multiplier that converts the input timestamp representation into +/// the desired timestamp unit, if the conversion requires a multiplication that +/// can overflow an `i64`. +pub fn timestamp_to_timestamp_multiplier( + source_type: &DataType, + target_type: &DataType, +) -> Option<i64> { + let (DataType::Timestamp(source_unit, _), DataType::Timestamp(target_unit, _)) = + (source_type, target_type) + else { + return None; + }; + + match (source_unit, target_unit) { + (TimeUnit::Second, TimeUnit::Millisecond) => Some(1_000), + (TimeUnit::Second, TimeUnit::Microsecond) => Some(1_000_000), + (TimeUnit::Second, TimeUnit::Nanosecond) => Some(1_000_000_000), + (TimeUnit::Millisecond, TimeUnit::Microsecond) => Some(1_000), + (TimeUnit::Millisecond, TimeUnit::Nanosecond) => Some(1_000_000), + (TimeUnit::Microsecond, TimeUnit::Nanosecond) => Some(1_000), + _ => None, + } +} + /// Ensures the provided value can be represented as a timestamp with the given /// multiplier. Returns an [`DataFusionError::Execution`] when the converted /// value would overflow the timestamp range. 
@@ -4265,7 +4289,8 @@ impl ScalarValue { } if let Some(multiplier) = date_to_timestamp_multiplier(&source_type, target_type) - && let Some(value) = self.date_scalar_value_as_i64() + .or_else(|| timestamp_to_timestamp_multiplier(&source_type, target_type)) + && let Some(value) = self.temporal_scalar_value_as_i64() { ensure_timestamp_in_bounds(value, multiplier, &source_type, target_type)?; } @@ -4287,10 +4312,14 @@ impl ScalarValue { ScalarValue::try_from_array(&cast_arr, 0) } - fn date_scalar_value_as_i64(&self) -> Option<i64> { + fn temporal_scalar_value_as_i64(&self) -> Option<i64> { match self { ScalarValue::Date32(Some(value)) => Some(i64::from(*value)), ScalarValue::Date64(Some(value)) => Some(*value), + ScalarValue::TimestampSecond(Some(value), _) + | ScalarValue::TimestampMillisecond(Some(value), _) + | ScalarValue::TimestampMicrosecond(Some(value), _) + | ScalarValue::TimestampNanosecond(Some(value), _) => Some(*value), _ => None, } } @@ -10149,6 +10178,19 @@ mod tests { ); } + #[test] + fn cast_timestamp_to_timestamp_overflow_returns_error() { + let scalar = ScalarValue::TimestampSecond(Some(i64::MAX), None); + let err = scalar + .cast_to(&DataType::Timestamp(TimeUnit::Nanosecond, None)) + .expect_err("expected cast to fail"); + assert!( + err.to_string() + .contains("converted value exceeds the representable i64 range"), + "unexpected error: {err}" + ); + } + #[test] fn null_dictionary_scalar_produces_null_dictionary_array() { let dictionary_scalar = ScalarValue::Dictionary( diff --git a/datafusion/expr-common/src/columnar_value.rs b/datafusion/expr-common/src/columnar_value.rs index bc6b8177ab3cf..caeb3f10da752 100644 --- a/datafusion/expr-common/src/columnar_value.rs +++ b/datafusion/expr-common/src/columnar_value.rs @@ -18,9 +18,12 @@ //! [`ColumnarValue`] represents the result of evaluating an expression. 
use arrow::{ - array::{Array, ArrayRef, Date32Array, Date64Array, NullArray}, + array::{ + Array, ArrayRef, Date32Array, Date64Array, NullArray, TimestampMicrosecondArray, + TimestampMillisecondArray, TimestampNanosecondArray, TimestampSecondArray, + }, compute::{CastOptions, kernels, max, min}, - datatypes::DataType, + datatypes::{DataType, TimeUnit}, util::pretty::pretty_format_columns, }; use datafusion_common::internal_datafusion_err; @@ -28,7 +31,10 @@ use datafusion_common::{ Result, ScalarValue, format::DEFAULT_CAST_OPTIONS, internal_err, - scalar::{date_to_timestamp_multiplier, ensure_timestamp_in_bounds}, + scalar::{ + date_to_timestamp_multiplier, ensure_timestamp_in_bounds, + timestamp_to_timestamp_multiplier, + }, }; use std::fmt; use std::sync::Arc; @@ -319,7 +325,7 @@ fn cast_array_by_name( ) { datafusion_common::nested_struct::cast_column(array, cast_type, cast_options) } else { - ensure_date_array_timestamp_bounds(array, cast_type)?; + ensure_temporal_array_timestamp_bounds(array, cast_type)?; Ok(kernels::cast::cast_with_options( array, cast_type, @@ -328,12 +334,14 @@ fn cast_array_by_name( } } -fn ensure_date_array_timestamp_bounds( +fn ensure_temporal_array_timestamp_bounds( array: &ArrayRef, cast_type: &DataType, ) -> Result<()> { let source_type = array.data_type().clone(); - let Some(multiplier) = date_to_timestamp_multiplier(&source_type, cast_type) else { + let Some(multiplier) = date_to_timestamp_multiplier(&source_type, cast_type) + .or_else(|| timestamp_to_timestamp_multiplier(&source_type, cast_type)) + else { return Ok(()); }; @@ -367,7 +375,55 @@ fn ensure_date_array_timestamp_bounds( })?; (min(arr), max(arr)) } - _ => return Ok(()), // Not a date type, nothing to do + DataType::Timestamp(TimeUnit::Second, _) => { + let arr = array + .as_any() + .downcast_ref::<TimestampSecondArray>() + .ok_or_else(|| { + internal_datafusion_err!( + "Expected TimestampSecondArray but found {}", + array.data_type() + ) + })?; + (min(arr), max(arr)) + } + 
DataType::Timestamp(TimeUnit::Millisecond, _) => { + let arr = array + .as_any() + .downcast_ref::<TimestampMillisecondArray>() + .ok_or_else(|| { + internal_datafusion_err!( + "Expected TimestampMillisecondArray but found {}", + array.data_type() + ) + })?; + (min(arr), max(arr)) + } + DataType::Timestamp(TimeUnit::Microsecond, _) => { + let arr = array + .as_any() + .downcast_ref::<TimestampMicrosecondArray>() + .ok_or_else(|| { + internal_datafusion_err!( + "Expected TimestampMicrosecondArray but found {}", + array.data_type() + ) + })?; + (min(arr), max(arr)) + } + DataType::Timestamp(TimeUnit::Nanosecond, _) => { + let arr = array + .as_any() + .downcast_ref::<TimestampNanosecondArray>() + .ok_or_else(|| { + internal_datafusion_err!( + "Expected TimestampNanosecondArray but found {}", + array.data_type() + ) + })?; + (min(arr), max(arr)) + } + _ => return Ok(()), // Not a temporal type that needs checking. }; // Only validate the min and max values instead of all elements @@ -694,4 +750,20 @@ mod tests { "unexpected error: {err}" ); } + + #[test] + fn cast_timestamp_array_to_timestamp_overflow() { + let overflow_value = i64::MAX / 1_000_000_000 + 1; + let array: ArrayRef = + Arc::new(TimestampSecondArray::from(vec![Some(overflow_value)])); + let value = ColumnarValue::Array(array); + let result = + value.cast_to(&DataType::Timestamp(TimeUnit::Nanosecond, None), None); + let err = result.expect_err("expected overflow to be detected"); + assert!( + err.to_string() + .contains("converted value exceeds the representable i64 range"), + "unexpected error: {err}" + ); + } } diff --git a/datafusion/expr-common/src/type_coercion/binary.rs b/datafusion/expr-common/src/type_coercion/binary.rs index aec87ec5ff853..4581745ccbb8c 100644 --- a/datafusion/expr-common/src/type_coercion/binary.rs +++ b/datafusion/expr-common/src/type_coercion/binary.rs @@ -2048,22 +2048,10 @@ fn temporal_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option TimeUnit { use arrow::datatypes::TimeUnit::*; match (lhs_unit, rhs_unit) { - (Second, Millisecond) => Second, 
- (Second, Microsecond) => Second, - (Second, Nanosecond) => Second, - (Millisecond, Second) => Second, - (Millisecond, Microsecond) => Millisecond, - (Millisecond, Nanosecond) => Millisecond, - (Microsecond, Second) => Second, - (Microsecond, Millisecond) => Millisecond, - (Microsecond, Nanosecond) => Microsecond, - (Nanosecond, Second) => Second, - (Nanosecond, Millisecond) => Millisecond, - (Nanosecond, Microsecond) => Microsecond, - (l, r) => { - assert_eq!(l, r); - *l - } + (Second, Second) => Second, + (Nanosecond, _) | (_, Nanosecond) => Nanosecond, + (Microsecond, _) | (_, Microsecond) => Microsecond, + (Millisecond, _) | (_, Millisecond) => Millisecond, } } diff --git a/datafusion/expr-common/src/type_coercion/binary/tests/arithmetic.rs b/datafusion/expr-common/src/type_coercion/binary/tests/arithmetic.rs index eb5622fedb8aa..70a8fc0e35a15 100644 --- a/datafusion/expr-common/src/type_coercion/binary/tests/arithmetic.rs +++ b/datafusion/expr-common/src/type_coercion/binary/tests/arithmetic.rs @@ -40,8 +40,8 @@ fn test_date_timestamp_arithmetic_error() -> Result<()> { &DataType::Timestamp(Millisecond, None), ) .get_input_types()?; - assert_eq!(lhs, DataType::Timestamp(Millisecond, None)); - assert_eq!(rhs, DataType::Timestamp(Millisecond, None)); + assert_eq!(lhs, DataType::Timestamp(Nanosecond, None)); + assert_eq!(rhs, DataType::Timestamp(Nanosecond, None)); let err = BinaryTypeCoercer::new(&DataType::Date32, &Operator::Plus, &DataType::Date64) diff --git a/datafusion/expr-common/src/type_coercion/binary/tests/comparison.rs b/datafusion/expr-common/src/type_coercion/binary/tests/comparison.rs index f8bff3ca90ecf..5f6b7dfcc1d4f 100644 --- a/datafusion/expr-common/src/type_coercion/binary/tests/comparison.rs +++ b/datafusion/expr-common/src/type_coercion/binary/tests/comparison.rs @@ -575,6 +575,24 @@ fn test_type_coercion_compare() -> Result<()> { Operator::Eq, DataType::Timestamp(Second, Some("Europe/Brussels".into())) ); + test_coercion_binary_rule!( + 
DataType::Timestamp(Second, None), + DataType::Timestamp(Millisecond, None), + Operator::Eq, + DataType::Timestamp(Millisecond, None) + ); + test_coercion_binary_rule!( + DataType::Timestamp(Second, Some("America/New_York".into())), + DataType::Timestamp(Nanosecond, Some("Europe/Brussels".into())), + Operator::Lt, + DataType::Timestamp(Nanosecond, Some("America/New_York".into())) + ); + test_coercion_binary_rule!( + DataType::Timestamp(Microsecond, None), + DataType::Timestamp(Nanosecond, None), + Operator::GtEq, + DataType::Timestamp(Nanosecond, None) + ); // list let inner_field = Arc::new(Field::new_list_field(DataType::Int64, true)); @@ -872,6 +890,24 @@ fn test_type_union_coercion_prefers_string() { ); } +#[test] +fn test_type_union_coercion_prefers_finer_timestamp_unit() { + assert_eq!( + type_union_coercion( + &DataType::Timestamp(Second, None), + &DataType::Timestamp(Millisecond, None), + ), + Some(DataType::Timestamp(Millisecond, None)) + ); + assert_eq!( + type_union_resolution(&[ + DataType::Timestamp(Second, None), + DataType::Timestamp(Nanosecond, None), + ]), + Some(DataType::Timestamp(Nanosecond, None)) + ); +} + /// Tests that comparison operators coerce to numeric when comparing /// numeric and string types. 
#[test] diff --git a/datafusion/sqllogictest/test_files/datetime/timestamps.slt b/datafusion/sqllogictest/test_files/datetime/timestamps.slt index 958ff86b4fb4d..89c6f0a12139e 100644 --- a/datafusion/sqllogictest/test_files/datetime/timestamps.slt +++ b/datafusion/sqllogictest/test_files/datetime/timestamps.slt @@ -2480,6 +2480,51 @@ SELECT TIMESTAMPTZ '2020-01-01 00:00:00Z' = TIMESTAMP '2020-01-01' ---- true +query BBB +SELECT + arrow_cast(TIMESTAMP '2024-01-01 00:00:00', 'Timestamp(Second, None)') = + arrow_cast(TIMESTAMP '2024-01-01 00:00:00.123', 'Timestamp(Millisecond, None)'), + arrow_cast(TIMESTAMP '2024-01-01 00:00:00', 'Timestamp(Second, None)') = + arrow_cast(TIMESTAMP '2024-01-01 00:00:00.000', 'Timestamp(Millisecond, None)'), + arrow_cast(TIMESTAMP '2024-01-01 00:00:00', 'Timestamp(Second, None)') < + arrow_cast(TIMESTAMP '2024-01-01 00:00:00.123', 'Timestamp(Millisecond, None)') +---- +false true true + +query ? +SELECT + arrow_cast(TIMESTAMP '2024-01-01 00:00:00.123', 'Timestamp(Millisecond, None)') - + arrow_cast(TIMESTAMP '2024-01-01 00:00:00', 'Timestamp(Second, None)') +---- +0 days 0 hours 0 mins 0.123 secs + +query TP +SELECT arrow_typeof(ts), ts +FROM ( + SELECT arrow_cast(TIMESTAMP '2024-01-01 00:00:00', 'Timestamp(Second, None)') AS ts + UNION ALL + SELECT arrow_cast(TIMESTAMP '2024-01-01 00:00:00.123', 'Timestamp(Millisecond, None)') AS ts +) +ORDER BY ts +---- +Timestamp(ms) 2024-01-01T00:00:00 +Timestamp(ms) 2024-01-01T00:00:00.123 + +query TP +SELECT + arrow_typeof( + coalesce( + arrow_cast(NULL, 'Timestamp(Second, None)'), + arrow_cast(TIMESTAMP '2024-01-01 00:00:00.123', 'Timestamp(Millisecond, None)') + ) + ), + coalesce( + arrow_cast(NULL, 'Timestamp(Second, None)'), + arrow_cast(TIMESTAMP '2024-01-01 00:00:00.123', 'Timestamp(Millisecond, None)') + ) +---- +Timestamp(ms) 2024-01-01T00:00:00.123 + # verify timestamp cast with integer input query PPPPPP SELECT to_timestamp(null), to_timestamp(0), to_timestamp(1926632005), 
to_timestamp(1), to_timestamp(-1), to_timestamp(0-1) @@ -3959,17 +4004,17 @@ true query ? select arrow_cast('2024-06-17T11:00:00', 'Timestamp(Nanosecond, Some("UTC"))') - arrow_cast('2024-06-17T12:00:00', 'Timestamp(Microsecond, Some("UTC"))'); ---- -0 days -1 hours 0 mins 0.000000 secs +0 days -1 hours 0 mins 0.000000000 secs query ? select arrow_cast('2024-06-17T13:00:00', 'Timestamp(Nanosecond, Some("+00:00"))') - arrow_cast('2024-06-17T12:00:00', 'Timestamp(Microsecond, Some("UTC"))'); ---- -0 days 1 hours 0 mins 0.000000 secs +0 days 1 hours 0 mins 0.000000000 secs query ? select arrow_cast('2024-06-17T13:00:00', 'Timestamp(Nanosecond, Some("UTC"))') - arrow_cast('2024-06-17T12:00:00', 'Timestamp(Microsecond, Some("+00:00"))'); ---- -0 days 1 hours 0 mins 0.000000 secs +0 days 1 hours 0 mins 0.000000000 secs # not supported: coercion across timezones query error @@ -5331,7 +5376,7 @@ SELECT to_timestamp(arrow_cast(-9223372036, 'Int64')); 1677-09-21T00:12:44 # Overflow error when value exceeds valid range -query error Arithmetic overflow +query error converted value exceeds the representable i64 range SELECT to_timestamp(arrow_cast(9223372037, 'Int64')); # Float truncation behavior