diff --git a/native-engine/datafusion-ext-functions/src/lib.rs b/native-engine/datafusion-ext-functions/src/lib.rs index 33926436e..9555112d9 100644 --- a/native-engine/datafusion-ext-functions/src/lib.rs +++ b/native-engine/datafusion-ext-functions/src/lib.rs @@ -82,6 +82,7 @@ pub fn create_auron_ext_function( "Spark_Hour" => Arc::new(spark_dates::spark_hour), "Spark_Minute" => Arc::new(spark_dates::spark_minute), "Spark_Second" => Arc::new(spark_dates::spark_second), + "Spark_MonthsBetween" => Arc::new(spark_dates::spark_months_between), "Spark_BrickhouseArrayUnion" => Arc::new(brickhouse::array_union::array_union), "Spark_Round" => Arc::new(spark_round::spark_round), "Spark_BRound" => Arc::new(spark_bround::spark_bround), diff --git a/native-engine/datafusion-ext-functions/src/spark_dates.rs b/native-engine/datafusion-ext-functions/src/spark_dates.rs index 70d1a0a22..6693d8d8e 100644 --- a/native-engine/datafusion-ext-functions/src/spark_dates.rs +++ b/native-engine/datafusion-ext-functions/src/spark_dates.rs @@ -16,11 +16,13 @@ use std::sync::Arc; use arrow::{ - array::{ArrayRef, Date32Array, Int32Array, TimestampMillisecondArray}, + array::{ + ArrayRef, BooleanArray, Date32Array, Float64Array, Int32Array, TimestampMillisecondArray, + }, compute::{DatePart, date_part}, datatypes::{DataType, TimeUnit}, }; -use chrono::{Duration, TimeZone, Utc, prelude::*}; +use chrono::{Duration, LocalResult, NaiveDate, TimeZone, Utc, prelude::*}; use chrono_tz::Tz; use datafusion::{ common::{DataFusionError, Result, ScalarValue}, @@ -179,15 +181,114 @@ pub fn spark_quarter(args: &[ColumnarValue]) -> Result<ColumnarValue> { /// Parse optional timezone (2nd argument) into `Option<Tz>`.
fn parse_tz(args: &[ColumnarValue]) -> Option<Tz> { - if args.len() < 2 { - return None; - } - match &args[1] { - ColumnarValue::Scalar(ScalarValue::Utf8(Some(s))) => s.parse::<Tz>().ok(), + parse_tz_value(args.get(1)) +} + +fn parse_tz_value(arg: Option<&ColumnarValue>) -> Option<Tz> { + match arg { + Some(ColumnarValue::Scalar(ScalarValue::Utf8(Some(s)))) => s.parse::<Tz>().ok(), _ => None, } } +fn local_datetime(epoch_ms: i64, tz_opt: Option<Tz>) -> Option<NaiveDateTime> { + let dt_utc = Utc.timestamp_millis_opt(epoch_ms).single()?; + Some(match tz_opt { + Some(tz) => dt_utc.with_timezone(&tz).naive_local(), + None => dt_utc.naive_utc(), + }) +} + +fn start_of_local_day_ms(local_date: NaiveDate, tz_opt: Option<Tz>) -> Option<i64> { + let local_midnight = local_date.and_hms_opt(0, 0, 0)?; + + match tz_opt { + Some(tz) => match tz.from_local_datetime(&local_midnight) { + LocalResult::Single(dt) => Some(dt.with_timezone(&Utc).timestamp_millis()), + LocalResult::Ambiguous(dt1, dt2) => { + Some(dt1.min(dt2).with_timezone(&Utc).timestamp_millis()) + } + LocalResult::None => { + // Align with Java's LocalDate.atStartOfDay(zone): choose the first valid + // local time on that date if midnight itself falls in a gap.
+ for minute in 1..=(24 * 60) { + let candidate = local_midnight + chrono::Duration::minutes(minute); + match tz.from_local_datetime(&candidate) { + LocalResult::Single(dt) => { + return Some(dt.with_timezone(&Utc).timestamp_millis()); + } + LocalResult::Ambiguous(dt1, dt2) => { + return Some(dt1.min(dt2).with_timezone(&Utc).timestamp_millis()); + } + LocalResult::None => continue, + } + } + None + } + }, + None => Some(local_midnight.and_utc().timestamp_millis()), + } +} + +fn days_in_month(year: i32, month: u32) -> u32 { + match month { + 1 | 3 | 5 | 7 | 8 | 10 | 12 => 31, + 4 | 6 | 9 | 11 => 30, + 2 => { + let leap_year = (year % 4 == 0 && year % 100 != 0) || year % 400 == 0; + if leap_year { 29 } else { 28 } + } + _ => unreachable!("month must be in 1..=12"), + } +} + +fn round_to_8_digits(value: f64) -> f64 { + const SCALE: f64 = 1.0e8; + ((value * SCALE) + 0.5).floor() / SCALE +} + +fn months_between_value( + timestamp1_ms: i64, + timestamp2_ms: i64, + round_off: bool, + tz_opt: Option<Tz>, +) -> Option<f64> { + const SECONDS_PER_DAY: i64 = 86_400; + const SECONDS_PER_MONTH: i64 = 31 * SECONDS_PER_DAY; + + let local_dt1 = local_datetime(timestamp1_ms, tz_opt)?; + let local_dt2 = local_datetime(timestamp2_ms, tz_opt)?; + let date1 = local_dt1.date(); + let date2 = local_dt2.date(); + + let months1 = date1.year() * 12 + date1.month() as i32; + let months2 = date2.year() * 12 + date2.month() as i32; + let month_diff = (months1 - months2) as f64; + + let day1 = date1.day(); + let day2 = date2.day(); + let days_to_month_end1 = days_in_month(date1.year(), date1.month()) - day1; + let days_to_month_end2 = days_in_month(date2.year(), date2.month()) - day2; + + if day1 == day2 || (days_to_month_end1 == 0 && days_to_month_end2 == 0) { + return Some(month_diff); + } + + let start_of_day1_ms = start_of_local_day_ms(date1, tz_opt)?; + let start_of_day2_ms = start_of_local_day_ms(date2, tz_opt)?; + let seconds_in_day1 = (timestamp1_ms - start_of_day1_ms) / 1000; + let
seconds_in_day2 = (timestamp2_ms - start_of_day2_ms) / 1000; + let seconds_diff = + (day1 as i64 - day2 as i64) * SECONDS_PER_DAY + seconds_in_day1 - seconds_in_day2; + + let result = month_diff + seconds_diff as f64 / SECONDS_PER_MONTH as f64; + Some(if round_off { + round_to_8_digits(result) + } else { + result + }) +} + /// Return the UTC offset in **seconds** for `epoch_ms` at the given `tz` /// (DST-aware). fn offset_seconds_at(tz: Tz, epoch_ms: i64) -> i32 { @@ -299,11 +400,75 @@ pub fn spark_second(args: &[ColumnarValue]) -> Result<ColumnarValue> { )))) } +/// Compute Spark-compatible `months_between(timestamp1, timestamp2, roundOff)`. +/// +/// The first two arguments are timestamps in physical UTC milliseconds and the +/// fourth argument is an optional session timezone string used for local date +/// boundary calculations. +pub fn spark_months_between(args: &[ColumnarValue]) -> Result<ColumnarValue> { + if args.len() != 4 { + return Err(DataFusionError::Execution( + "spark_months_between() requires four arguments".to_string(), + )); + } + + let num_rows = args + .iter() + .find_map(|arg| match arg { + ColumnarValue::Array(arr) => Some(arr.len()), + ColumnarValue::Scalar(_) => None, + }) + .unwrap_or(1); + + let timestamp1 = cast( + &args[0].clone().into_array(num_rows)?, + &DataType::Timestamp(TimeUnit::Millisecond, None), + )?; + let timestamp2 = cast( + &args[1].clone().into_array(num_rows)?, + &DataType::Timestamp(TimeUnit::Millisecond, None), + )?; + let round_off = cast(&args[2].clone().into_array(num_rows)?, &DataType::Boolean)?; + + let timestamp1 = timestamp1 + .as_any() + .downcast_ref::<TimestampMillisecondArray>() + .expect("internal cast to Timestamp(Millisecond, None) must succeed"); + let timestamp2 = timestamp2 + .as_any() + .downcast_ref::<TimestampMillisecondArray>() + .expect("internal cast to Timestamp(Millisecond, None) must succeed"); + let round_off = round_off + .as_any() + .downcast_ref::<BooleanArray>() + .expect("internal cast to Boolean must succeed"); + let tz = parse_tz_value(args.get(3)); + + let result =
Float64Array::from_iter( + timestamp1 + .iter() + .zip(timestamp2.iter()) + .zip(round_off.iter()) + .map(|((timestamp1_ms, timestamp2_ms), round_off)| { + match (timestamp1_ms, timestamp2_ms, round_off) { + (Some(timestamp1_ms), Some(timestamp2_ms), Some(round_off)) => { + months_between_value(timestamp1_ms, timestamp2_ms, round_off, tz) + } + _ => None, + } + }), + ); + + Ok(ColumnarValue::Array(Arc::new(result))) +} + #[cfg(test)] mod tests { use std::sync::Arc; - use arrow::array::{ArrayRef, Date32Array, Int32Array, TimestampMillisecondArray}; + use arrow::array::{ + ArrayRef, Date32Array, Float64Array, Int32Array, TimestampMillisecondArray, + }; use super::*; @@ -775,4 +940,141 @@ mod tests { Ok(()) } + + fn months_between_args( + timestamp1_ms: Option<i64>, + timestamp2_ms: Option<i64>, + round_off: Option<bool>, + timezone: Option<&str>, + ) -> [ColumnarValue; 4] { + [ + ColumnarValue::Scalar(ScalarValue::TimestampMillisecond(timestamp1_ms, None)), + ColumnarValue::Scalar(ScalarValue::TimestampMillisecond(timestamp2_ms, None)), + ColumnarValue::Scalar(ScalarValue::Boolean(round_off)), + ColumnarValue::Scalar(ScalarValue::Utf8(timezone.map(str::to_string))), + ] + } + + #[test] + fn test_spark_months_between_ignores_time_for_same_day() -> Result<()> { + let out = spark_months_between(&months_between_args( + Some(utc_ms(2024, 3, 15, 23, 59, 59)), + Some(utc_ms(2024, 1, 15, 0, 0, 0)), + Some(true), + Some("UTC"), + ))? + .into_array(1)?; + + let expected: ArrayRef = Arc::new(Float64Array::from(vec![Some(2.0)])); + assert_eq!(&out, &expected); + Ok(()) + } + + #[test] + fn test_spark_months_between_ignores_time_for_last_day_of_month() -> Result<()> { + let out = spark_months_between(&months_between_args( + Some(utc_ms(2024, 2, 29, 12, 0, 0)), + Some(utc_ms(2024, 1, 31, 0, 0, 0)), + Some(true), + Some("UTC"), + ))?
+ .into_array(1)?; + + let expected: ArrayRef = Arc::new(Float64Array::from(vec![Some(1.0)])); + assert_eq!(&out, &expected); + Ok(()) + } + + #[test] + fn test_spark_months_between_fractional_rounding() -> Result<()> { + let rounded = spark_months_between(&months_between_args( + Some(utc_ms(2024, 3, 2, 12, 0, 0)), + Some(utc_ms(2024, 1, 1, 0, 0, 0)), + Some(true), + Some("UTC"), + ))? + .into_array(1)?; + let expected_rounded: ArrayRef = Arc::new(Float64Array::from(vec![Some(2.0483871)])); + assert_eq!(&rounded, &expected_rounded); + + let unrounded = spark_months_between(&months_between_args( + Some(utc_ms(2024, 3, 2, 12, 0, 0)), + Some(utc_ms(2024, 1, 1, 0, 0, 0)), + Some(false), + Some("UTC"), + ))? + .into_array(1)?; + let unrounded = unrounded + .as_any() + .downcast_ref::<Float64Array>() + .expect("months_between should return Float64Array"); + assert!((unrounded.value(0) - 2.0483870967741935).abs() < 1e-12); + Ok(()) + } + + #[test] + fn test_spark_months_between_respects_dst_gap_in_session_timezone() -> Result<()> { + let out = spark_months_between(&months_between_args( + // 2024-03-10 07:30:00 UTC -> 2024-03-10 03:30:00 local in America/New_York. + // The 02:00-02:59 hour does not exist on this day due to spring-forward DST. + Some(utc_ms(2024, 3, 10, 7, 30, 0)), + // 2024-02-09 06:30:00 UTC -> 2024-02-09 01:30:00 local. + Some(utc_ms(2024, 2, 9, 6, 30, 0)), + Some(false), + Some("America/New_York"), + ))? + .into_array(1)?; + + let out = out + .as_any() + .downcast_ref::<Float64Array>() + .expect("months_between should return Float64Array"); + assert!((out.value(0) - 1.0336021505376345).abs() < 1e-12); + Ok(()) + } + + #[test] + fn test_spark_months_between_negative_when_timestamp1_is_earlier() -> Result<()> { + let out = spark_months_between(&months_between_args( + Some(utc_ms(2024, 1, 15, 0, 0, 0)), + Some(utc_ms(2024, 3, 15, 23, 59, 59)), + Some(true), + Some("UTC"), + ))?
+ .into_array(1)?; + + let expected: ArrayRef = Arc::new(Float64Array::from(vec![Some(-2.0)])); + assert_eq!(&out, &expected); + Ok(()) + } + + #[test] + fn test_spark_months_between_fractional_rounding_keeps_negative_sign() -> Result<()> { + let out = spark_months_between(&months_between_args( + Some(utc_ms(2024, 1, 1, 0, 0, 0)), + Some(utc_ms(2024, 3, 2, 12, 0, 0)), + Some(true), + Some("UTC"), + ))? + .into_array(1)?; + + let expected: ArrayRef = Arc::new(Float64Array::from(vec![Some(-2.0483871)])); + assert_eq!(&out, &expected); + Ok(()) + } + + #[test] + fn test_spark_months_between_null_propagation() -> Result<()> { + let out = spark_months_between(&months_between_args( + None, + Some(utc_ms(2024, 1, 1, 0, 0, 0)), + Some(true), + Some("UTC"), + ))? + .into_array(1)?; + + let expected: ArrayRef = Arc::new(Float64Array::from(vec![None])); + assert_eq!(&out, &expected); + Ok(()) + } } diff --git a/spark-extension-shims-spark/src/test/scala/org/apache/auron/AuronFunctionSuite.scala b/spark-extension-shims-spark/src/test/scala/org/apache/auron/AuronFunctionSuite.scala index 0df644787..df0708815 100644 --- a/spark-extension-shims-spark/src/test/scala/org/apache/auron/AuronFunctionSuite.scala +++ b/spark-extension-shims-spark/src/test/scala/org/apache/auron/AuronFunctionSuite.scala @@ -728,4 +728,53 @@ class AuronFunctionSuite extends AuronQueryTest with BaseAuronSQLSuite { } } } + + test("months_between function") { + withSQLConf("spark.sql.session.timeZone" -> "UTC") { + withTable("t1") { + sql(""" + |create table t1( + | same_day_of_month_later_ts timestamp, + | same_day_of_month_earlier_ts timestamp, + | last_day_of_month_later_dt date, + | last_day_of_month_earlier_dt date, + | fractional_later_ts timestamp, + | fractional_earlier_ts timestamp, + | null_ts timestamp + |) using parquet + |""".stripMargin) + sql(""" + |insert into t1 values ( + | timestamp'2024-03-15 23:59:59', + | timestamp'2024-01-15 00:00:00', + | date'2024-02-29', + | date'2024-01-31', + | 
timestamp'2024-03-02 12:00:00', + | timestamp'2024-01-01 00:00:00', + | null + |) + |""".stripMargin) + val query = + """ + |select + | months_between(same_day_of_month_later_ts, same_day_of_month_earlier_ts) + | as same_day_of_month_ignores_time_positive, + | months_between(last_day_of_month_later_dt, last_day_of_month_earlier_dt) + | as last_day_of_month_ignores_time, + | months_between(fractional_later_ts, fractional_earlier_ts) + | as rounded_fractional_positive, + | months_between(same_day_of_month_earlier_ts, same_day_of_month_later_ts) + | as same_day_of_month_ignores_time_negative, + | months_between(fractional_earlier_ts, fractional_later_ts) + | as rounded_fractional_negative, + | months_between(fractional_later_ts, fractional_earlier_ts, false) + | as unrounded_fractional_positive, + | months_between(null_ts, fractional_earlier_ts) + | as null_propagation + |from t1 + |""".stripMargin + checkSparkAnswerAndOperator(query) + } + } + } } diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/auron/NativeConverters.scala b/spark-extension/src/main/scala/org/apache/spark/sql/auron/NativeConverters.scala index 89839da1e..7db4374bd 100644 --- a/spark-extension/src/main/scala/org/apache/spark/sql/auron/NativeConverters.scala +++ b/spark-extension/src/main/scala/org/apache/spark/sql/auron/NativeConverters.scala @@ -955,6 +955,8 @@ object NativeConverters extends Logging { buildTimePartExt("Spark_Minute", e.children.head, isPruningExpr, fallback) case e: Second if datetimeExtractEnabled => buildTimePartExt("Spark_Second", e.children.head, isPruningExpr, fallback) + case e: MonthsBetween => + buildMonthsBetweenExt("Spark_MonthsBetween", e, isPruningExpr, fallback) // startswith is converted to scalar function in pruning-expr mode case StartsWith(expr, Literal(prefix, StringType)) if isPruningExpr => @@ -1391,6 +1393,24 @@ object NativeConverters extends Logging { buildExtScalarFunctionNode(name, Seq(child, tzArg), IntegerType, isPruningExpr, fallback) } + 
private def buildMonthsBetweenExt( + name: String, + expr: MonthsBetween, + isPruningExpr: Boolean, + fallback: Expression => pb.PhysicalExprNode): pb.PhysicalExprNode = { + val tzArg: Expression = if (Seq(expr.date1, expr.date2).exists(_.dataType == TimestampType)) { + Literal.create(SQLConf.get.sessionLocalTimeZone, StringType) + } else { + Literal.create(null, StringType) + } + buildExtScalarFunctionNode( + name, + Seq(expr.date1, expr.date2, expr.roundOff, tzArg), + DoubleType, + isPruningExpr, + fallback) + } + def castIfNecessary(expr: Expression, dataType: DataType): Expression = { if (expr.dataType == dataType) { return expr