Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions native-engine/datafusion-ext-functions/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ pub fn create_auron_ext_function(
"Spark_Hour" => Arc::new(spark_dates::spark_hour),
"Spark_Minute" => Arc::new(spark_dates::spark_minute),
"Spark_Second" => Arc::new(spark_dates::spark_second),
"Spark_MonthsBetween" => Arc::new(spark_dates::spark_months_between),
"Spark_BrickhouseArrayUnion" => Arc::new(brickhouse::array_union::array_union),
"Spark_Round" => Arc::new(spark_round::spark_round),
"Spark_BRound" => Arc::new(spark_bround::spark_bround),
Expand Down
318 changes: 310 additions & 8 deletions native-engine/datafusion-ext-functions/src/spark_dates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,13 @@
use std::sync::Arc;

use arrow::{
array::{ArrayRef, Date32Array, Int32Array, TimestampMillisecondArray},
array::{
ArrayRef, BooleanArray, Date32Array, Float64Array, Int32Array, TimestampMillisecondArray,
},
compute::{DatePart, date_part},
datatypes::{DataType, TimeUnit},
};
use chrono::{Duration, TimeZone, Utc, prelude::*};
use chrono::{Duration, LocalResult, NaiveDate, TimeZone, Utc, prelude::*};
use chrono_tz::Tz;
use datafusion::{
common::{DataFusionError, Result, ScalarValue},
Expand Down Expand Up @@ -179,15 +181,114 @@ pub fn spark_quarter(args: &[ColumnarValue]) -> Result<ColumnarValue> {

/// Parse optional timezone (2nd argument) into `Option<Tz>`.
fn parse_tz(args: &[ColumnarValue]) -> Option<Tz> {
if args.len() < 2 {
return None;
}
match &args[1] {
ColumnarValue::Scalar(ScalarValue::Utf8(Some(s))) => s.parse::<Tz>().ok(),
parse_tz_value(args.get(1))
}

/// Interpret a single argument as a timezone string and parse it.
///
/// Returns `None` when the argument is absent, not a non-null UTF-8 scalar,
/// or does not name a valid IANA timezone.
fn parse_tz_value(arg: Option<&ColumnarValue>) -> Option<Tz> {
    if let Some(ColumnarValue::Scalar(ScalarValue::Utf8(Some(name)))) = arg {
        name.parse::<Tz>().ok()
    } else {
        None
    }
}

/// Convert epoch milliseconds to a naive local datetime in `tz_opt`
/// (or UTC when no timezone is given). Returns `None` for out-of-range input.
fn local_datetime(epoch_ms: i64, tz_opt: Option<Tz>) -> Option<NaiveDateTime> {
    let utc_dt = Utc.timestamp_millis_opt(epoch_ms).single()?;
    match tz_opt {
        Some(tz) => Some(utc_dt.with_timezone(&tz).naive_local()),
        None => Some(utc_dt.naive_utc()),
    }
}

/// Epoch milliseconds of the start of `local_date` in `tz_opt`
/// (or in UTC when no timezone is given), DST-aware.
fn start_of_local_day_ms(local_date: NaiveDate, tz_opt: Option<Tz>) -> Option<i64> {
    let local_midnight = local_date.and_hms_opt(0, 0, 0)?;

    match tz_opt {
        Some(tz) => match tz.from_local_datetime(&local_midnight) {
            // Normal case: local midnight maps to exactly one instant.
            LocalResult::Single(dt) => Some(dt.with_timezone(&Utc).timestamp_millis()),
            // Fall-back transition: midnight occurs twice; take the earlier instant.
            LocalResult::Ambiguous(dt1, dt2) => {
                Some(dt1.min(dt2).with_timezone(&Utc).timestamp_millis())
            }
            LocalResult::None => {
                // Align with Java's LocalDate.atStartOfDay(zone): choose the first valid
                // local time on that date if midnight itself falls in a gap.
                // Scan forward in one-minute steps (at most a full day) until a
                // local time exists.
                for minute in 1..=(24 * 60) {
                    let candidate = local_midnight + chrono::Duration::minutes(minute);
                    match tz.from_local_datetime(&candidate) {
                        LocalResult::Single(dt) => {
                            return Some(dt.with_timezone(&Utc).timestamp_millis());
                        }
                        LocalResult::Ambiguous(dt1, dt2) => {
                            return Some(dt1.min(dt2).with_timezone(&Utc).timestamp_millis());
                        }
                        LocalResult::None => continue,
                    }
                }
                // No valid local time anywhere on the date; not expected for real zones.
                None
            }
        },
        // No session timezone: interpret midnight as a UTC instant directly.
        None => Some(local_midnight.and_utc().timestamp_millis()),
    }
}

/// Number of days in the given month of the given (proleptic Gregorian) year.
///
/// # Panics
/// Panics if `month` is outside `1..=12`.
fn days_in_month(year: i32, month: u32) -> u32 {
    // Gregorian leap-year rule: divisible by 4, except centuries not
    // divisible by 400.
    let is_leap = year % 4 == 0 && (year % 100 != 0 || year % 400 == 0);
    match month {
        2 => {
            if is_leap {
                29
            } else {
                28
            }
        }
        4 | 6 | 9 | 11 => 30,
        1 | 3 | 5 | 7 | 8 | 10 | 12 => 31,
        _ => unreachable!("month must be in 1..=12"),
    }
}

/// Round `value` to 8 fractional digits using half-up semantics,
/// i.e. `floor(x * 1e8 + 0.5) / 1e8` — the same rule as Java's `Math.round`.
fn round_to_8_digits(value: f64) -> f64 {
    const SCALE: f64 = 1.0e8;
    let scaled = value * SCALE;
    (scaled + 0.5).floor() / SCALE
}

/// Spark-style `months_between` for one pair of timestamps (epoch ms).
///
/// Returns a whole number of months when both local dates share the same
/// day-of-month, or when both are the last day of their month; otherwise a
/// fractional part is added, computed against a fixed 31-day month.
/// `None` on out-of-range timestamps or unresolvable local days.
fn months_between_value(
    timestamp1_ms: i64,
    timestamp2_ms: i64,
    round_off: bool,
    tz_opt: Option<Tz>,
) -> Option<f64> {
    const SECONDS_PER_DAY: i64 = 86_400;
    // The fractional part always assumes a 31-day month.
    const SECONDS_PER_MONTH: i64 = 31 * SECONDS_PER_DAY;

    let date1 = local_datetime(timestamp1_ms, tz_opt)?.date();
    let date2 = local_datetime(timestamp2_ms, tz_opt)?.date();

    let total_months1 = date1.year() * 12 + date1.month() as i32;
    let total_months2 = date2.year() * 12 + date2.month() as i32;
    let month_diff = (total_months1 - total_months2) as f64;

    let (day1, day2) = (date1.day(), date2.day());
    let is_month_end1 = day1 == days_in_month(date1.year(), date1.month());
    let is_month_end2 = day2 == days_in_month(date2.year(), date2.month());

    // Whole months: matching day-of-month, or both dates at month end.
    if day1 == day2 || (is_month_end1 && is_month_end2) {
        return Some(month_diff);
    }

    // Fractional part: compare seconds elapsed since each side's local midnight.
    let seconds_in_day1 = (timestamp1_ms - start_of_local_day_ms(date1, tz_opt)?) / 1000;
    let seconds_in_day2 = (timestamp2_ms - start_of_local_day_ms(date2, tz_opt)?) / 1000;
    let seconds_diff =
        (day1 as i64 - day2 as i64) * SECONDS_PER_DAY + seconds_in_day1 - seconds_in_day2;

    let raw = month_diff + seconds_diff as f64 / SECONDS_PER_MONTH as f64;
    Some(if round_off { round_to_8_digits(raw) } else { raw })
}

/// Return the UTC offset in **seconds** for `epoch_ms` at the given `tz`
/// (DST-aware).
fn offset_seconds_at(tz: Tz, epoch_ms: i64) -> i32 {
Expand Down Expand Up @@ -299,11 +400,75 @@ pub fn spark_second(args: &[ColumnarValue]) -> Result<ColumnarValue> {
))))
}

/// Compute Spark-compatible `months_between(timestamp1, timestamp2, roundOff)`.
///
/// The first two arguments are timestamps in physical UTC milliseconds and the
/// fourth argument is an optional session timezone string used for local date
/// boundary calculations.
///
/// # Errors
/// Returns an execution error unless exactly four arguments are supplied.
pub fn spark_months_between(args: &[ColumnarValue]) -> Result<ColumnarValue> {
    if args.len() != 4 {
        return Err(DataFusionError::Execution(
            "spark_months_between() requires four arguments".to_string(),
        ));
    }

    // Row count comes from the first array argument; all-scalar input means
    // a single logical row.
    let num_rows = args
        .iter()
        .filter_map(|arg| match arg {
            ColumnarValue::Array(arr) => Some(arr.len()),
            ColumnarValue::Scalar(_) => None,
        })
        .next()
        .unwrap_or(1);

    // Normalize every input to an array of the common row count.
    let ts1_arr = cast(
        &args[0].clone().into_array(num_rows)?,
        &DataType::Timestamp(TimeUnit::Millisecond, None),
    )?;
    let ts2_arr = cast(
        &args[1].clone().into_array(num_rows)?,
        &DataType::Timestamp(TimeUnit::Millisecond, None),
    )?;
    let round_arr = cast(&args[2].clone().into_array(num_rows)?, &DataType::Boolean)?;

    let ts1 = ts1_arr
        .as_any()
        .downcast_ref::<TimestampMillisecondArray>()
        .expect("internal cast to Timestamp(Millisecond, None) must succeed");
    let ts2 = ts2_arr
        .as_any()
        .downcast_ref::<TimestampMillisecondArray>()
        .expect("internal cast to Timestamp(Millisecond, None) must succeed");
    let round_off = round_arr
        .as_any()
        .downcast_ref::<BooleanArray>()
        .expect("internal cast to Boolean must succeed");
    let tz = parse_tz_value(args.get(3));

    // Null in any of the three row-wise inputs yields a null result.
    let mut values: Vec<Option<f64>> = Vec::with_capacity(num_rows);
    for ((t1, t2), r) in ts1.iter().zip(ts2.iter()).zip(round_off.iter()) {
        let value = if let (Some(t1), Some(t2), Some(r)) = (t1, t2, r) {
            months_between_value(t1, t2, r, tz)
        } else {
            None
        };
        values.push(value);
    }

    Ok(ColumnarValue::Array(Arc::new(Float64Array::from(values))))
}

#[cfg(test)]
mod tests {
use std::sync::Arc;

use arrow::array::{ArrayRef, Date32Array, Int32Array, TimestampMillisecondArray};
use arrow::array::{
ArrayRef, Date32Array, Float64Array, Int32Array, TimestampMillisecondArray,
};

use super::*;

Expand Down Expand Up @@ -775,4 +940,141 @@ mod tests {

Ok(())
}

/// Build the four scalar arguments accepted by `spark_months_between`.
fn months_between_args(
    timestamp1_ms: Option<i64>,
    timestamp2_ms: Option<i64>,
    round_off: Option<bool>,
    timezone: Option<&str>,
) -> [ColumnarValue; 4] {
    let tz = timezone.map(|s| s.to_string());
    [
        ColumnarValue::Scalar(ScalarValue::TimestampMillisecond(timestamp1_ms, None)),
        ColumnarValue::Scalar(ScalarValue::TimestampMillisecond(timestamp2_ms, None)),
        ColumnarValue::Scalar(ScalarValue::Boolean(round_off)),
        ColumnarValue::Scalar(ScalarValue::Utf8(tz)),
    ]
}

#[test]
fn test_spark_months_between_ignores_time_for_same_day() -> Result<()> {
    // Same day-of-month on both sides: the time component must be ignored
    // and the result is an exact whole number of months.
    let out = spark_months_between(&months_between_args(
        Some(utc_ms(2024, 3, 15, 23, 59, 59)),
        Some(utc_ms(2024, 1, 15, 0, 0, 0)),
        Some(true),
        Some("UTC"),
    ))?
    .into_array(1)?;

    let expected: ArrayRef = Arc::new(Float64Array::from(vec![Some(2.0)]));
    assert_eq!(&out, &expected);
    Ok(())
}

#[test]
fn test_spark_months_between_ignores_time_for_last_day_of_month() -> Result<()> {
    // Both inputs fall on the last day of their month, so the result is a
    // whole number of months regardless of time-of-day.
    let args = months_between_args(
        Some(utc_ms(2024, 2, 29, 12, 0, 0)),
        Some(utc_ms(2024, 1, 31, 0, 0, 0)),
        Some(true),
        Some("UTC"),
    );
    let actual = spark_months_between(&args)?.into_array(1)?;

    let expected: ArrayRef = Arc::new(Float64Array::from(vec![Some(1.0)]));
    assert_eq!(&actual, &expected);
    Ok(())
}

#[test]
fn test_spark_months_between_fractional_rounding() -> Result<()> {
    let ts1 = Some(utc_ms(2024, 3, 2, 12, 0, 0));
    let ts2 = Some(utc_ms(2024, 1, 1, 0, 0, 0));

    // roundOff = true: result rounded to 8 fractional digits.
    let rounded = spark_months_between(&months_between_args(ts1, ts2, Some(true), Some("UTC")))?
        .into_array(1)?;
    let expected_rounded: ArrayRef = Arc::new(Float64Array::from(vec![Some(2.0483871)]));
    assert_eq!(&rounded, &expected_rounded);

    // roundOff = false: the full-precision value is returned unchanged.
    let raw = spark_months_between(&months_between_args(ts1, ts2, Some(false), Some("UTC")))?
        .into_array(1)?;
    let raw = raw
        .as_any()
        .downcast_ref::<Float64Array>()
        .expect("months_between should return Float64Array");
    assert!((raw.value(0) - 2.0483870967741935).abs() < 1e-12);
    Ok(())
}

#[test]
fn test_spark_months_between_respects_dst_gap_in_session_timezone() -> Result<()> {
    // 2024-03-10 07:30:00 UTC -> 2024-03-10 03:30:00 local in America/New_York.
    // The 02:00-02:59 hour does not exist on this day due to spring-forward DST.
    // 2024-02-09 06:30:00 UTC -> 2024-02-09 01:30:00 local.
    let args = months_between_args(
        Some(utc_ms(2024, 3, 10, 7, 30, 0)),
        Some(utc_ms(2024, 2, 9, 6, 30, 0)),
        Some(false),
        Some("America/New_York"),
    );
    let result = spark_months_between(&args)?.into_array(1)?;

    let result = result
        .as_any()
        .downcast_ref::<Float64Array>()
        .expect("months_between should return Float64Array");
    assert!((result.value(0) - 1.0336021505376345).abs() < 1e-12);
    Ok(())
}

#[test]
fn test_spark_months_between_negative_when_timestamp1_is_earlier() -> Result<()> {
    // Swapped operand order relative to the same-day case: result is negated.
    let args = months_between_args(
        Some(utc_ms(2024, 1, 15, 0, 0, 0)),
        Some(utc_ms(2024, 3, 15, 23, 59, 59)),
        Some(true),
        Some("UTC"),
    );
    let actual = spark_months_between(&args)?.into_array(1)?;

    let expected: ArrayRef = Arc::new(Float64Array::from(vec![Some(-2.0)]));
    assert_eq!(&actual, &expected);
    Ok(())
}

#[test]
fn test_spark_months_between_fractional_rounding_keeps_negative_sign() -> Result<()> {
    // Rounding must preserve the sign of a negative fractional result.
    let args = months_between_args(
        Some(utc_ms(2024, 1, 1, 0, 0, 0)),
        Some(utc_ms(2024, 3, 2, 12, 0, 0)),
        Some(true),
        Some("UTC"),
    );
    let actual = spark_months_between(&args)?.into_array(1)?;

    let expected: ArrayRef = Arc::new(Float64Array::from(vec![Some(-2.0483871)]));
    assert_eq!(&actual, &expected);
    Ok(())
}

#[test]
fn test_spark_months_between_null_propagation() -> Result<()> {
    // A null timestamp input must produce a null output row.
    let args = months_between_args(
        None,
        Some(utc_ms(2024, 1, 1, 0, 0, 0)),
        Some(true),
        Some("UTC"),
    );
    let actual = spark_months_between(&args)?.into_array(1)?;

    let expected: ArrayRef = Arc::new(Float64Array::from(vec![None]));
    assert_eq!(&actual, &expected);
    Ok(())
}
}
Loading
Loading