-
Notifications
You must be signed in to change notification settings - Fork 29.2k
[SPARK-56594][SQL] Add time_bucket scalar function #55535
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -13125,6 +13125,74 @@ def timestamp_add(unit: str, quantity: "ColumnOrName", ts: "ColumnOrName") -> Co | |||||||||||||||||||||||
| ) | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
@_try_remote_functions
def time_bucket(
    bucket_size: "ColumnOrName",
    ts: "ColumnOrName",
    origin: Optional["ColumnOrName"] = None,
) -> Column:
    """
    Aligns a timestamp to the start of a fixed-size interval bucket.

    Returns the start of the bucket that ``ts`` falls into, where buckets are defined by
    the given ``bucket_size`` interval aligned to ``origin``. All bucketing is performed on
    UTC micros; the session time zone does not affect bucket alignment. For local wall-clock
    alignment in a DST zone, cast the TIMESTAMP to TIMESTAMP_NTZ.

    .. versionadded:: 4.2.0

    Parameters
    ----------
    bucket_size : :class:`~pyspark.sql.Column` or column name
        A day-time or year-month interval defining the bucket size. Must be positive
        and foldable.
    ts : :class:`~pyspark.sql.Column` or column name
        A TIMESTAMP or TIMESTAMP_NTZ value to bucket.
    origin : :class:`~pyspark.sql.Column` or column name, optional
        Alignment anchor. Defaults to 1970-01-01 00:00:00 (UTC for TIMESTAMP). Must be
        the same type as ``ts`` and must be foldable.

    Returns
    -------
    :class:`~pyspark.sql.Column`
        The start of the bucket containing ``ts``, as the same type as ``ts``.

    Examples
    --------
    >>> spark.conf.set("spark.sql.session.timeZone", "UTC")
    >>> import datetime
    >>> from pyspark.sql import functions as sf
    >>> df = spark.createDataFrame(
    ...     [(datetime.datetime(2024, 1, 1, 11, 27, 0),)], ['ts'])
    >>> df.select(
    ...     sf.time_bucket(sf.expr("INTERVAL '15' MINUTE"), 'ts').alias("bucket")
    ... ).collect()
    [Row(bucket=datetime.datetime(2024, 1, 1, 11, 15))]

    Shift the grid with an explicit origin: buckets run at :05, :20, :35, :50:

    >>> df.select(
    ...     sf.time_bucket(
    ...         sf.expr("INTERVAL '15' MINUTE"),
    ...         'ts',
    ...         sf.expr("TIMESTAMP '1970-01-01 00:05:00'")
    ...     ).alias("bucket")
    ... ).collect()
    [Row(bucket=datetime.datetime(2024, 1, 1, 11, 20))]
    >>> spark.conf.unset("spark.sql.session.timeZone")
    """
    # Dispatch through the Column-aware helper used by the other functions in
    # this module; it accepts both Column objects and column-name strings and
    # avoids a direct dependency on the classic-only _to_java_column conversion.
    if origin is None:
        return _invoke_function_over_columns("time_bucket", bucket_size, ts)
    return _invoke_function_over_columns("time_bucket", bucket_size, ts, origin)
|
Comment on lines
+13186
to
+13193
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Recent additions in this file (
Suggested change
|
||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| @_try_remote_functions | ||||||||||||||||||||||||
| def window( | ||||||||||||||||||||||||
| timeColumn: "ColumnOrName", | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change | ||||||||
|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -24,9 +24,11 @@ import java.util.Locale | |||||||||
|
|
||||||||||
| import org.apache.commons.text.StringEscapeUtils | ||||||||||
|
|
||||||||||
| import org.apache.spark.{SparkDateTimeException, SparkIllegalArgumentException} | ||||||||||
| import org.apache.spark.{SparkDateTimeException, SparkException, SparkIllegalArgumentException} | ||||||||||
| import org.apache.spark.sql.catalyst.InternalRow | ||||||||||
| import org.apache.spark.sql.catalyst.analysis.{ExpressionBuilder, FunctionRegistry} | ||||||||||
| import org.apache.spark.sql.catalyst.analysis.{ExpressionBuilder, FunctionRegistry, TypeCheckResult} | ||||||||||
| import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{DataTypeMismatch, TypeCheckSuccess} | ||||||||||
| import org.apache.spark.sql.catalyst.expressions.Cast.{ordinalNumber, toSQLExpr, toSQLId, toSQLType, toSQLValue} | ||||||||||
| import org.apache.spark.sql.catalyst.expressions.codegen._ | ||||||||||
| import org.apache.spark.sql.catalyst.expressions.codegen.Block._ | ||||||||||
| import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke | ||||||||||
|
|
@@ -3897,3 +3899,178 @@ case class TimestampDiff( | |||||||||
| copy(startTimestamp = newLeft, endTimestamp = newRight) | ||||||||||
| } | ||||||||||
| } | ||||||||||
|
|
||||||||||
/**
 * Aligns a timestamp to the start of a fixed-size interval bucket.
 *
 * Returns the start of the half-open bucket [start, start + bucketSize) containing ts.
 * All computation is performed on UTC values.
 */
case class TimeBucket(
    bucketSize: Expression,
    ts: Expression,
    originTs: Expression)
  extends TernaryExpression with ExpectsInputTypes {

  // A null in any argument yields a null result.
  override def nullIntolerant: Boolean = true

  override def first: Expression = bucketSize
  override def second: Expression = ts
  override def third: Expression = originTs

  override def inputTypes: Seq[AbstractDataType] = Seq(
    TypeCollection(DayTimeIntervalType, YearMonthIntervalType),
    AnyTimestampType,
    AnyTimestampType)

  // Result type mirrors the bucketed timestamp (TIMESTAMP or TIMESTAMP_NTZ).
  override def dataType: DataType = ts.dataType

  override def checkInputDataTypes(): TypeCheckResult = {
    val defaultCheck = super.checkInputDataTypes()
    if (defaultCheck.isFailure) return defaultCheck

    // bucketSize must be a compile-time constant so the bucket grid is fixed.
    if (!bucketSize.foldable) {
      return DataTypeMismatch(
        errorSubClass = "NON_FOLDABLE_INPUT",
        messageParameters = Map(
          "inputName" -> toSQLId("bucketSize"),
          "inputType" -> toSQLType(bucketSize.dataType),
          "inputExpr" -> toSQLExpr(bucketSize)))
    }

    // Reject zero or negative bucket sizes at analysis time.
    val bucketSizeValue = bucketSize.eval()
    if (bucketSizeValue != null) {
      val isNonPositive = bucketSize.dataType match {
        case _: DayTimeIntervalType => bucketSizeValue.asInstanceOf[Long] <= 0
        case _: YearMonthIntervalType => bucketSizeValue.asInstanceOf[Int] <= 0
        case other => throw SparkException.internalError(
          s"Unexpected bucketSize type: $other")
      }
      if (isNonPositive) {
        // exprName identifies the offending input and is SQL-quoted via toSQLId,
        // consistent with other expressions that raise VALUE_OUT_OF_RANGE.
        return DataTypeMismatch(
          errorSubClass = "VALUE_OUT_OF_RANGE",
          messageParameters = Map(
            "exprName" -> toSQLId("bucketSize"),
            "valueRange" -> "(0, inf)",
            "currentValue" -> toSQLValue(bucketSizeValue, bucketSize.dataType)))
      }
    }

    // origin must also be foldable so the alignment anchor is constant.
    if (!originTs.foldable) {
      return DataTypeMismatch(
        errorSubClass = "NON_FOLDABLE_INPUT",
        messageParameters = Map(
          "inputName" -> toSQLId("origin"),
          "inputType" -> toSQLType(originTs.dataType),
          "inputExpr" -> toSQLExpr(originTs)))
    }

    // ts and origin must be the same timestamp flavor (both TIMESTAMP or both NTZ).
    if (ts.dataType != originTs.dataType) {
      return DataTypeMismatch(
        errorSubClass = "UNEXPECTED_INPUT_TYPE",
        messageParameters = Map(
          "paramIndex" -> ordinalNumber(2),
          "requiredType" -> toSQLType(ts.dataType),
          "inputSql" -> toSQLExpr(originTs),
          "inputType" -> toSQLType(originTs.dataType)))
    }

    TypeCheckSuccess
  }

  override def nullSafeEval(bucketSizeVal: Any, tsVal: Any, originVal: Any): Any = {
    first.dataType match {
      case _: DayTimeIntervalType =>
        DateTimeUtils.timeBucketDTInterval(
          bucketSizeVal.asInstanceOf[Long], tsVal.asInstanceOf[Long],
          originVal.asInstanceOf[Long])
      case _: YearMonthIntervalType =>
        DateTimeUtils.timeBucketYMInterval(
          bucketSizeVal.asInstanceOf[Int], tsVal.asInstanceOf[Long],
          originVal.asInstanceOf[Long])
      case other => throw SparkException.internalError(
        s"Unexpected bucketSize type: $other")
    }
  }

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
    first.dataType match {
      case _: DayTimeIntervalType =>
        defineCodeGen(ctx, ev, (bucketSizeCode, tsCode, originCode) =>
          s"$dtu.timeBucketDTInterval($bucketSizeCode, $tsCode, $originCode)")
      case _: YearMonthIntervalType =>
        defineCodeGen(ctx, ev, (bucketSizeCode, tsCode, originCode) =>
          s"$dtu.timeBucketYMInterval($bucketSizeCode, $tsCode, $originCode)")
      case other => throw SparkException.internalError(
        s"Unexpected bucketSize type: $other")
    }
  }

  override def prettyName: String = "time_bucket"

  override protected def withNewChildrenInternal(
      newFirst: Expression, newSecond: Expression, newThird: Expression): TimeBucket =
    copy(bucketSize = newFirst, ts = newSecond, originTs = newThird)
}
|
|
||||||||||
// scalastyle:off line.size.limit
@ExpressionDescription(
  usage = """
    _FUNC_(bucketSize, ts[, origin]) - Returns the start of the bucket that `ts` falls into,
    where buckets are defined by the given `bucketSize` interval aligned to `origin`. All
    bucketing is performed on UTC micros; the session time zone does not affect bucket
    alignment. For local wall-clock alignment in a DST zone, cast the TIMESTAMP to
    TIMESTAMP_NTZ.
  """,
  arguments = """
    Arguments:
      * bucketSize - A day-time or year-month interval defining the bucket size. Must be positive and foldable.
      * ts - A TIMESTAMP or TIMESTAMP_NTZ value to bucket.
      * origin - Optional TIMESTAMP or TIMESTAMP_NTZ alignment anchor. Defaults to 1970-01-01 00:00:00 (UTC for TIMESTAMP). Must be the same type as ts and must be foldable.
  """,
  examples = """
    Examples:
      > SELECT _FUNC_(INTERVAL '15' MINUTE, TIMESTAMP '2024-01-01 11:27:00', TIMESTAMP '1970-01-01 00:00:00');
       2024-01-01 11:15:00
      > SELECT _FUNC_(INTERVAL '1' HOUR, TIMESTAMP '2024-01-01 11:27:00');
       2024-01-01 11:00:00
      > SELECT _FUNC_(INTERVAL '1' MONTH, TIMESTAMP '2024-07-20 14:30:00', TIMESTAMP '2024-06-15 09:00:00');
       2024-07-15 09:00:00
  """,
  since = "4.2.0",
  group = "datetime_funcs")
// scalastyle:on line.size.limit
object TimeBucketExpressionBuilder extends ExpressionBuilder {
  // Give an untyped NULL literal a concrete type so ExpectsInputTypes can reason
  // about it; non-null expressions pass through unchanged.
  private def retypeNull(e: Expression, dt: DataType): Expression = e match {
    case Literal(null, NullType) => Literal(null, dt)
    case _ => e
  }

  override def build(funcName: String, expressions: Seq[Expression]): Expression = {
    expressions match {
      case Seq(rawBucketSize, rawTs) =>
        val bucketSize = retypeNull(rawBucketSize, DayTimeIntervalType())
        // Fall back to TimestampType for bad ts types; ExpectsInputTypes will report it.
        val tsType = rawTs.dataType match {
          case t if AnyTimestampType.acceptsType(t) => t
          case _ => TimestampType
        }
        val ts = retypeNull(rawTs, tsType)
        // Default origin is the epoch (1970-01-01 00:00:00), typed like ts.
        TimeBucket(bucketSize, ts, Literal(0L, tsType))
      case Seq(rawBucketSize, rawTs, rawOrigin) =>
        val bucketSize = retypeNull(rawBucketSize, DayTimeIntervalType())
        // Prefer ts's own type; if ts is an untyped NULL, borrow the origin's type.
        val tsType = (rawTs.dataType, rawOrigin.dataType) match {
          case (NullType, t) if AnyTimestampType.acceptsType(t) => t
          case (NullType, _) => TimestampType
          case (t, _) => t
        }
        val ts = retypeNull(rawTs, tsType)
        val originTs = retypeNull(rawOrigin, tsType)
        TimeBucket(bucketSize, ts, originTs)
      case _ =>
        throw QueryCompilationErrors.wrongNumArgsError(
          funcName, Seq(2, 3), expressions.length)
    }
  }
}
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1059,4 +1059,59 @@ object DateTimeUtils extends SparkDateTimeUtils { | |
| time, timePrecision, interval, intervalEndField) | ||
| } | ||
| } | ||
|
|
||
| /** | ||
| * DayTimeInterval bucketing: microsecond floor division against `originMicros`. | ||
| * Returns `originMicros + floorDiv(tsMicros - originMicros, bucketMicros) * bucketMicros`. | ||
| * | ||
| * `bucketMicros` must be positive; `TimeBucket.checkInputDataTypes` enforces | ||
| * this at analysis time. | ||
| * | ||
| * @param bucketMicros bucket size in microseconds. | ||
| * @param tsMicros timestamp to bucket, in microseconds since the epoch (UTC). | ||
| * @param originMicros grid alignment anchor, in microseconds since the epoch (UTC). | ||
| */ | ||
| def timeBucketDTInterval(bucketMicros: Long, tsMicros: Long, originMicros: Long): Long = { | ||
| val diff = Math.subtractExact(tsMicros, originMicros) | ||
| val bucketOffset = Math.multiplyExact(Math.floorDiv(diff, bucketMicros), bucketMicros) | ||
| Math.addExact(originMicros, bucketOffset) | ||
|
Comment on lines
+1075
to
+1077
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Spark's convention for arithmetic with overflow checks is |
||
| } | ||
|
|
||
| /** | ||
| * YearMonthInterval bucketing: month arithmetic with end-of-month capping and step-back. | ||
| * The origin's day-of-month and time-of-day determine the bucket boundaries. | ||
| * | ||
| * `bucketMonths` must be positive; `TimeBucket.checkInputDataTypes` enforces | ||
| * this at analysis time. | ||
| * | ||
| * @param bucketMonths bucket size in months. | ||
| * @param tsMicros timestamp to bucket, in microseconds since the epoch (UTC). | ||
| * @param originMicros grid alignment anchor, in microseconds since the epoch (UTC). | ||
| */ | ||
| def timeBucketYMInterval(bucketMonths: Int, tsMicros: Long, originMicros: Long): Long = { | ||
| val tsDays = microsToDays(tsMicros, ZoneOffset.UTC) | ||
| val originDays = microsToDays(originMicros, ZoneOffset.UTC) | ||
| val originTodMicros = | ||
| Math.subtractExact(originMicros, daysToMicros(originDays, ZoneOffset.UTC)) | ||
|
|
||
| val tsDate = daysToLocalDate(tsDays) | ||
| val originDate = daysToLocalDate(originDays) | ||
| val rawMonthDiff = (tsDate.getYear.toLong * 12 + tsDate.getMonthValue) - | ||
| (originDate.getYear.toLong * 12 + originDate.getMonthValue) | ||
|
|
||
| var k = Math.floorDiv(rawMonthDiff, bucketMonths.toLong) | ||
| var candidateDays = dateAddMonths(originDays, | ||
| Math.toIntExact(Math.multiplyExact(k, bucketMonths.toLong))) | ||
| var candidate = Math.addExact(daysToMicros(candidateDays, ZoneOffset.UTC), originTodMicros) | ||
|
|
||
| // End-of-month capping in dateAddMonths can overshoot; step back one bucket if so. | ||
| if (candidate > tsMicros) { | ||
| k -= 1 | ||
| candidateDays = dateAddMonths(originDays, | ||
| Math.toIntExact(Math.multiplyExact(k, bucketMonths.toLong))) | ||
| candidate = Math.addExact(daysToMicros(candidateDays, ZoneOffset.UTC), originTodMicros) | ||
| } | ||
|
|
||
| candidate | ||
| } | ||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same comma splice as in the Scala `@ExpressionDescription` usage text.