Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions python/pyspark/sql/connect/functions/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -3728,6 +3728,19 @@ def timestamp_add(unit: str, quantity: "ColumnOrName", ts: "ColumnOrName") -> Co
timestamp_add.__doc__ = pysparkfuncs.timestamp_add.__doc__


def time_bucket(
    bucket_size: "ColumnOrName",
    ts: "ColumnOrName",
    origin: Optional["ColumnOrName"] = None,
) -> Column:
    # Forward to the server-side time_bucket; only send `origin` when the caller gave one.
    if origin is not None:
        return _invoke_function_over_columns("time_bucket", bucket_size, ts, origin)
    return _invoke_function_over_columns("time_bucket", bucket_size, ts)


time_bucket.__doc__ = pysparkfuncs.time_bucket.__doc__


def window(
timeColumn: "ColumnOrName",
windowDuration: str,
Expand Down
1 change: 1 addition & 0 deletions python/pyspark/sql/functions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,7 @@
"timestamp_micros",
"timestamp_millis",
"timestamp_seconds",
"time_bucket",
"time_diff",
"time_from_micros",
"time_from_millis",
Expand Down
68 changes: 68 additions & 0 deletions python/pyspark/sql/functions/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -13125,6 +13125,74 @@ def timestamp_add(unit: str, quantity: "ColumnOrName", ts: "ColumnOrName") -> Co
)


@_try_remote_functions
def time_bucket(
    bucket_size: "ColumnOrName",
    ts: "ColumnOrName",
    origin: Optional["ColumnOrName"] = None,
) -> Column:
    """
    Aligns a timestamp to the start of a fixed-size interval bucket.

    Returns the start of the bucket that ``ts`` falls into, where buckets are defined by
    the given ``bucket_size`` interval aligned to ``origin``. All bucketing is performed on
    UTC micros; the session time zone does not affect bucket alignment. For local wall-clock
    alignment in a DST zone, cast the TIMESTAMP to TIMESTAMP_NTZ.

    .. versionadded:: 4.2.0

    Parameters
    ----------
    bucket_size : :class:`~pyspark.sql.Column` or column name
        A day-time or year-month interval defining the bucket size. Must be positive
        and foldable.
    ts : :class:`~pyspark.sql.Column` or column name
        A TIMESTAMP or TIMESTAMP_NTZ value to bucket.
    origin : :class:`~pyspark.sql.Column` or column name, optional
        Alignment anchor. Defaults to 1970-01-01 00:00:00 (UTC for TIMESTAMP). Must be
        the same type as ``ts`` and must be foldable.

    Returns
    -------
    :class:`~pyspark.sql.Column`
        The start of the bucket containing ``ts``, as the same type as ``ts``.

    Examples
    --------
    >>> spark.conf.set("spark.sql.session.timeZone", "UTC")
    >>> import datetime
    >>> from pyspark.sql import functions as sf
    >>> df = spark.createDataFrame(
    ...     [(datetime.datetime(2024, 1, 1, 11, 27, 0),)], ['ts'])
    >>> df.select(
    ...     sf.time_bucket(sf.expr("INTERVAL '15' MINUTE"), 'ts').alias("bucket")
    ... ).collect()
    [Row(bucket=datetime.datetime(2024, 1, 1, 11, 15))]

    Shift the grid with an explicit origin: buckets run at :05, :20, :35, :50:

    >>> df.select(
    ...     sf.time_bucket(
    ...         sf.expr("INTERVAL '15' MINUTE"),
    ...         'ts',
    ...         sf.expr("TIMESTAMP '1970-01-01 00:05:00'")
    ...     ).alias("bucket")
    ... ).collect()
    [Row(bucket=datetime.datetime(2024, 1, 1, 11, 20))]
    >>> spark.conf.unset("spark.sql.session.timeZone")
    """
    # _invoke_function_over_columns resolves ColumnOrName -> java column internally,
    # matching the neighboring datetime functions (time_diff, time_trunc, ...).
    if origin is None:
        return _invoke_function_over_columns("time_bucket", bucket_size, ts)
    return _invoke_function_over_columns("time_bucket", bucket_size, ts, origin)
Comment on lines +13186 to +13193
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Recent additions in this file (time_diff, time_trunc, to_timestamp_ltz) use _invoke_function_over_columns, which already maps _to_java_column for you internally — equivalent here, but easier to read and consistent with the neighbors:

Suggested change
if origin is None:
return _invoke_function("time_bucket", _to_java_column(bucket_size), _to_java_column(ts))
return _invoke_function(
"time_bucket",
_to_java_column(bucket_size),
_to_java_column(ts),
_to_java_column(origin),
)
if origin is None:
return _invoke_function_over_columns("time_bucket", bucket_size, ts)
return _invoke_function_over_columns("time_bucket", bucket_size, ts, origin)



@_try_remote_functions
def window(
timeColumn: "ColumnOrName",
Expand Down
20 changes: 20 additions & 0 deletions sql/api/src/main/scala/org/apache/spark/sql/functions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8486,6 +8486,26 @@ object functions {
def timestamp_add(unit: String, quantity: Column, ts: Column): Column =
Column.internalFn("timestampadd", lit(unit), quantity, ts)

/**
 * Returns the start of the fixed-size bucket of `bucketSize` that contains `ts`, with buckets
 * aligned to the epoch (1970-01-01 00:00:00). All computation is in UTC.
 *
 * @param bucketSize a positive, foldable day-time or year-month interval giving the bucket width
 * @param ts the TIMESTAMP or TIMESTAMP_NTZ value to bucket
 * @group datetime_funcs
 * @since 4.2.0
 */
def time_bucket(bucketSize: Column, ts: Column): Column =
  Column.fn("time_bucket", bucketSize, ts)

/**
 * Returns the start of the fixed-size bucket of `bucketSize` that contains `ts`, with buckets
 * aligned to `origin`. All computation is in UTC.
 *
 * @param bucketSize a positive, foldable day-time or year-month interval giving the bucket width
 * @param ts the TIMESTAMP or TIMESTAMP_NTZ value to bucket
 * @param origin a foldable alignment anchor of the same type as `ts`
 * @group datetime_funcs
 * @since 4.2.0
 */
def time_bucket(bucketSize: Column, ts: Column, origin: Column): Column =
  Column.fn("time_bucket", bucketSize, ts, origin)

/**
* Returns the difference between two times, measured in specified units. Throws a
* SparkIllegalArgumentException, in case the specified unit is not supported.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -777,6 +777,7 @@ object FunctionRegistry {
expression[UnixMillis]("unix_millis"),
expression[UnixMicros]("unix_micros"),
expression[ConvertTimezone]("convert_timezone"),
expressionBuilder("time_bucket", TimeBucketExpressionBuilder),

// collection functions
expression[CreateArray]("array"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,11 @@ import java.util.Locale

import org.apache.commons.text.StringEscapeUtils

import org.apache.spark.{SparkDateTimeException, SparkIllegalArgumentException}
import org.apache.spark.{SparkDateTimeException, SparkException, SparkIllegalArgumentException}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.{ExpressionBuilder, FunctionRegistry}
import org.apache.spark.sql.catalyst.analysis.{ExpressionBuilder, FunctionRegistry, TypeCheckResult}
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{DataTypeMismatch, TypeCheckSuccess}
import org.apache.spark.sql.catalyst.expressions.Cast.{ordinalNumber, toSQLExpr, toSQLId, toSQLType, toSQLValue}
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
Expand Down Expand Up @@ -3897,3 +3899,178 @@ case class TimestampDiff(
copy(startTimestamp = newLeft, endTimestamp = newRight)
}
}

/**
 * Aligns a timestamp to the start of a fixed-size interval bucket.
 *
 * Returns the start of the half-open bucket [start, start + bucketSize) containing ts.
 * All computation is performed on UTC values.
 *
 * NOTE(review): bucketing LTZ input in UTC diverges from date_trunc, which aligns in the
 * session time zone via TimeZoneAwareExpression — confirm this semantic is intended
 * before release, since it cannot be changed after 4.2.0 ships.
 */
case class TimeBucket(
    bucketSize: Expression,
    ts: Expression,
    originTs: Expression)
  extends TernaryExpression with ExpectsInputTypes {

  override def nullIntolerant: Boolean = true

  override def first: Expression = bucketSize
  override def second: Expression = ts
  override def third: Expression = originTs

  override def inputTypes: Seq[AbstractDataType] = Seq(
    TypeCollection(DayTimeIntervalType, YearMonthIntervalType),
    AnyTimestampType,
    AnyTimestampType)

  // Result carries the same timestamp flavor (LTZ vs NTZ) as the input.
  override def dataType: DataType = ts.dataType

  override def checkInputDataTypes(): TypeCheckResult = {
    val defaultCheck = super.checkInputDataTypes()
    if (defaultCheck.isFailure) return defaultCheck

    // bucketSize must be a compile-time constant so the bucket grid is well-defined.
    if (!bucketSize.foldable) {
      return DataTypeMismatch(
        errorSubClass = "NON_FOLDABLE_INPUT",
        messageParameters = Map(
          "inputName" -> toSQLId("bucketSize"),
          "inputType" -> toSQLType(bucketSize.dataType),
          "inputExpr" -> toSQLExpr(bucketSize)))
    }

    val bucketSizeValue = bucketSize.eval()
    if (bucketSizeValue != null) {
      val isNonPositive = bucketSize.dataType match {
        case _: DayTimeIntervalType => bucketSizeValue.asInstanceOf[Long] <= 0
        case _: YearMonthIntervalType => bucketSizeValue.asInstanceOf[Int] <= 0
        case other => throw SparkException.internalError(
          s"Unexpected bucketSize type: $other")
      }
      if (isNonPositive) {
        // Name the constrained parameter (not the function) and use a concrete upper
        // bound, matching TimeWindow's VALUE_OUT_OF_RANGE message conventions.
        return DataTypeMismatch(
          errorSubClass = "VALUE_OUT_OF_RANGE",
          messageParameters = Map(
            "exprName" -> toSQLId("bucketSize"),
            "valueRange" -> s"(0, ${Long.MaxValue}]",
            "currentValue" -> toSQLValue(bucketSizeValue, bucketSize.dataType)))
      }
    }

    // The alignment anchor must also be a compile-time constant.
    if (!originTs.foldable) {
      return DataTypeMismatch(
        errorSubClass = "NON_FOLDABLE_INPUT",
        messageParameters = Map(
          "inputName" -> toSQLId("origin"),
          "inputType" -> toSQLType(originTs.dataType),
          "inputExpr" -> toSQLExpr(originTs)))
    }

    // ts and origin must share the same timestamp flavor (LTZ vs NTZ).
    if (ts.dataType != originTs.dataType) {
      return DataTypeMismatch(
        errorSubClass = "UNEXPECTED_INPUT_TYPE",
        messageParameters = Map(
          "paramIndex" -> ordinalNumber(2),
          "requiredType" -> toSQLType(ts.dataType),
          "inputSql" -> toSQLExpr(originTs),
          "inputType" -> toSQLType(originTs.dataType)))
    }

    TypeCheckSuccess
  }

  override def nullSafeEval(bucketSizeVal: Any, tsVal: Any, originVal: Any): Any = {
    first.dataType match {
      case _: DayTimeIntervalType =>
        DateTimeUtils.timeBucketDTInterval(
          bucketSizeVal.asInstanceOf[Long], tsVal.asInstanceOf[Long],
          originVal.asInstanceOf[Long])
      case _: YearMonthIntervalType =>
        DateTimeUtils.timeBucketYMInterval(
          bucketSizeVal.asInstanceOf[Int], tsVal.asInstanceOf[Long],
          originVal.asInstanceOf[Long])
      case other => throw SparkException.internalError(
        s"Unexpected bucketSize type: $other")
    }
  }

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
    first.dataType match {
      case _: DayTimeIntervalType =>
        defineCodeGen(ctx, ev, (bucketSizeCode, tsCode, originCode) =>
          s"$dtu.timeBucketDTInterval($bucketSizeCode, $tsCode, $originCode)")
      case _: YearMonthIntervalType =>
        defineCodeGen(ctx, ev, (bucketSizeCode, tsCode, originCode) =>
          s"$dtu.timeBucketYMInterval($bucketSizeCode, $tsCode, $originCode)")
      case other => throw SparkException.internalError(
        s"Unexpected bucketSize type: $other")
    }
  }

  override def prettyName: String = "time_bucket"

  override protected def withNewChildrenInternal(
      newFirst: Expression, newSecond: Expression, newThird: Expression): TimeBucket =
    copy(bucketSize = newFirst, ts = newSecond, originTs = newThird)
}

// scalastyle:off line.size.limit
@ExpressionDescription(
  usage = """
    _FUNC_(bucketSize, ts[, origin]) - Returns the start of the bucket that `ts` falls into,
      where buckets are defined by the given `bucketSize` interval aligned to `origin`. All
      bucketing is performed on UTC micros; the session time zone does not affect bucket
      alignment. For local wall-clock alignment in a DST zone, cast the TIMESTAMP to
      TIMESTAMP_NTZ.
  """,
  arguments = """
    Arguments:
      * bucketSize - A day-time or year-month interval defining the bucket size. Must be positive and foldable.
      * ts - A TIMESTAMP or TIMESTAMP_NTZ value to bucket.
      * origin - Optional TIMESTAMP or TIMESTAMP_NTZ alignment anchor. Defaults to 1970-01-01 00:00:00 (UTC for TIMESTAMP). Must be the same type as ts and must be foldable.
  """,
  examples = """
    Examples:
      > SELECT _FUNC_(INTERVAL '15' MINUTE, TIMESTAMP '2024-01-01 11:27:00', TIMESTAMP '1970-01-01 00:00:00');
       2024-01-01 11:15:00
      > SELECT _FUNC_(INTERVAL '1' HOUR, TIMESTAMP '2024-01-01 11:27:00');
       2024-01-01 11:00:00
      > SELECT _FUNC_(INTERVAL '1' MONTH, TIMESTAMP '2024-07-20 14:30:00', TIMESTAMP '2024-06-15 09:00:00');
       2024-07-15 09:00:00
  """,
  since = "4.2.0",
  group = "datetime_funcs")
// scalastyle:on line.size.limit
object TimeBucketExpressionBuilder extends ExpressionBuilder {
  // Replace an untyped NULL literal with a NULL of the expected type so that
  // TimeBucket's type checks see a concrete interval/timestamp type.
  private def retypeNull(e: Expression, dt: DataType): Expression = e match {
    case Literal(null, NullType) => Literal(null, dt)
    case _ => e
  }

  override def build(funcName: String, expressions: Seq[Expression]): Expression = {
    expressions match {
      case Seq(rawBucketSize, rawTs) =>
        val bucketSize = retypeNull(rawBucketSize, DayTimeIntervalType())
        // Fall back to TimestampType for bad ts types; ExpectsInputTypes will report it.
        val tsType = rawTs.dataType match {
          case t if AnyTimestampType.acceptsType(t) => t
          case _ => TimestampType
        }
        val ts = retypeNull(rawTs, tsType)
        // Default origin is the epoch (micros == 0), in the same type as ts.
        TimeBucket(bucketSize, ts, Literal(0L, tsType))
      case Seq(rawBucketSize, rawTs, rawOrigin) =>
        val bucketSize = retypeNull(rawBucketSize, DayTimeIntervalType())
        // If ts is an untyped NULL, borrow the origin's timestamp type when valid;
        // otherwise keep ts's own type and let type checking report mismatches.
        val tsType = (rawTs.dataType, rawOrigin.dataType) match {
          case (NullType, t) if AnyTimestampType.acceptsType(t) => t
          case (NullType, _) => TimestampType
          case (t, _) => t
        }
        val ts = retypeNull(rawTs, tsType)
        val originTs = retypeNull(rawOrigin, tsType)
        TimeBucket(bucketSize, ts, originTs)
      case _ =>
        throw QueryCompilationErrors.wrongNumArgsError(
          funcName, Seq(2, 3), expressions.length)
    }
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -1059,4 +1059,59 @@ object DateTimeUtils extends SparkDateTimeUtils {
time, timePrecision, interval, intervalEndField)
}
}

/**
 * DayTimeInterval bucketing: microsecond floor division against `originMicros`.
 * Returns `originMicros + floorDiv(tsMicros - originMicros, bucketMicros) * bucketMicros`.
 *
 * `bucketMicros` must be positive; `TimeBucket.checkInputDataTypes` enforces
 * this at analysis time.
 *
 * NOTE(review): Math.*Exact surfaces a raw java.lang.ArithmeticException on overflow;
 * confirm whether the MathUtils wrappers (ARITHMETIC_OVERFLOW error class) should be
 * used here instead.
 *
 * @param bucketMicros bucket size in microseconds.
 * @param tsMicros timestamp to bucket, in microseconds since the epoch (UTC).
 * @param originMicros grid alignment anchor, in microseconds since the epoch (UTC).
 */
def timeBucketDTInterval(bucketMicros: Long, tsMicros: Long, originMicros: Long): Long = {
  val offsetFromOrigin = Math.subtractExact(tsMicros, originMicros)
  val flooredOffset =
    Math.multiplyExact(Math.floorDiv(offsetFromOrigin, bucketMicros), bucketMicros)
  Math.addExact(originMicros, flooredOffset)
}

/**
 * YearMonthInterval bucketing: month arithmetic with end-of-month capping and step-back.
 * The origin's day-of-month and time-of-day determine the bucket boundaries.
 *
 * `bucketMonths` must be positive; `TimeBucket.checkInputDataTypes` enforces
 * this at analysis time.
 *
 * @param bucketMonths bucket size in months.
 * @param tsMicros timestamp to bucket, in microseconds since the epoch (UTC).
 * @param originMicros grid alignment anchor, in microseconds since the epoch (UTC).
 */
def timeBucketYMInterval(bucketMonths: Int, tsMicros: Long, originMicros: Long): Long = {
  val tsDays = microsToDays(tsMicros, ZoneOffset.UTC)
  val originDays = microsToDays(originMicros, ZoneOffset.UTC)
  // Time-of-day component of the origin, preserved on every bucket boundary.
  val originTodMicros =
    Math.subtractExact(originMicros, daysToMicros(originDays, ZoneOffset.UTC))

  // Start of bucket number `bucketIndex` counted from the origin.
  def bucketStart(bucketIndex: Long): Long = {
    val days = dateAddMonths(originDays,
      Math.toIntExact(Math.multiplyExact(bucketIndex, bucketMonths.toLong)))
    Math.addExact(daysToMicros(days, ZoneOffset.UTC), originTodMicros)
  }

  val tsDate = daysToLocalDate(tsDays)
  val originDate = daysToLocalDate(originDays)
  val monthDiff = (tsDate.getYear.toLong * 12 + tsDate.getMonthValue) -
    (originDate.getYear.toLong * 12 + originDate.getMonthValue)

  val bucketIndex = Math.floorDiv(monthDiff, bucketMonths.toLong)
  val candidate = bucketStart(bucketIndex)
  // End-of-month capping in dateAddMonths can overshoot; step back one bucket if so.
  if (candidate > tsMicros) bucketStart(bucketIndex - 1) else candidate
}
}
Loading