Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions datafusion/expr/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -598,7 +598,7 @@ impl Alias {
}
}

/// Binary expression
/// Binary expression for [`Expr::BinaryExpr`]
#[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)]
pub struct BinaryExpr {
/// Left-hand side of the expression
Expand Down Expand Up @@ -901,7 +901,7 @@ impl<'a> TreeNodeContainer<'a, Expr> for Sort {
}
}

/// Aggregate function
/// Aggregate Function
///
/// See also [`ExprFunctionExt`] to set these fields on `Expr`
///
Expand Down
36 changes: 16 additions & 20 deletions datafusion/expr/src/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ pub use datafusion_functions_aggregate_common::accumulator::{
AccumulatorArgs, AccumulatorFactoryFunction, StateFieldsArgs,
};

use crate::expr::{AggregateFunction, WindowFunction};
use crate::simplify::SimplifyContext;
pub use datafusion_functions_window_common::expr::ExpressionArgs;
pub use datafusion_functions_window_common::field::WindowUDFFieldArgs;
pub use datafusion_functions_window_common::partition::PartitionEvaluatorArgs;
Expand Down Expand Up @@ -64,28 +66,22 @@ pub type PartitionEvaluatorFactory =
pub type StateTypeFunction =
Arc<dyn Fn(&DataType) -> Result<Arc<Vec<DataType>>> + Send + Sync>;

/// [crate::udaf::AggregateUDFImpl::simplify] simplifier closure
/// A closure with two arguments:
/// * 'aggregate_function': [crate::expr::AggregateFunction] for which simplified has been invoked
/// * 'info': [crate::simplify::SimplifyContext]
/// Return type for [crate::udaf::AggregateUDFImpl::simplify]
///
/// This closure is invoked with two arguments:
/// * 'aggregate_function': [AggregateFunction] with already simplified arguments
/// * 'info': [SimplifyContext]
///
/// Closure returns simplified [Expr] or an error.
pub type AggregateFunctionSimplification = Box<
dyn Fn(
crate::expr::AggregateFunction,
&crate::simplify::SimplifyContext,
) -> Result<Expr>,
>;
pub type AggregateFunctionSimplification =
Box<dyn Fn(AggregateFunction, &SimplifyContext) -> Result<Expr>>;

/// [crate::udwf::WindowUDFImpl::simplify] simplifier closure
/// A closure with two arguments:
/// * 'window_function': [crate::expr::WindowFunction] for which simplified has been invoked
/// * 'info': [crate::simplify::SimplifyContext]
/// Return type for [crate::udwf::WindowUDFImpl::simplify]
///
/// This closure is invoked with two arguments:
/// * 'window_function': [WindowFunction] for which simplified has been invoked
/// * 'info': [SimplifyContext]
///
/// Closure returns simplified [Expr] or an error.
pub type WindowFunctionSimplification = Box<
dyn Fn(
crate::expr::WindowFunction,
&crate::simplify::SimplifyContext,
) -> Result<Expr>,
>;
pub type WindowFunctionSimplification =
Box<dyn Fn(WindowFunction, &SimplifyContext) -> Result<Expr>>;
9 changes: 5 additions & 4 deletions datafusion/expr/src/udaf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -651,10 +651,10 @@ pub trait AggregateUDFImpl: Debug + DynEq + DynHash + Send + Sync {
AggregateOrderSensitivity::HardRequirement
}

/// Optionally apply per-UDaF simplification / rewrite rules.
/// Return a closure for simplifying a user defined aggregate.
///
/// This can be used to apply function specific simplification rules during
/// optimization (e.g. `arrow_cast` --> `Expr::Cast`). The default
/// optimization (e.g. `percentile_cont(` --> `Min`). The default
/// implementation does nothing.
///
/// Note that DataFusion handles simplifying arguments and "constant
Expand All @@ -664,10 +664,11 @@ pub trait AggregateUDFImpl: Debug + DynEq + DynHash + Send + Sync {
///
/// # Returns
///
/// [None] if simplify is not defined or,
/// [None] if simplify is not defined
///
/// Or, a closure with two arguments:
/// * 'aggregate_function': [AggregateFunction] for which simplified has been invoked
/// * 'aggregate_function': [AggregateFunction], which includes already simplified
/// arguments
/// * 'info': [crate::simplify::SimplifyContext]
///
/// closure returns simplified [Expr] or an error.
Expand Down
96 changes: 91 additions & 5 deletions datafusion/functions-aggregate/src/sum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,16 @@ use datafusion_common::types::{
logical_int64, logical_uint8, logical_uint16, logical_uint32, logical_uint64,
};
use datafusion_common::{HashMap, Result, ScalarValue, exec_err, not_impl_err};
use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs};
use datafusion_expr::expr::{AggregateFunction, AggregateFunctionParams};
use datafusion_expr::function::{
AccumulatorArgs, AggregateFunctionSimplification, StateFieldsArgs,
};
use datafusion_expr::simplify::SimplifyContext;
use datafusion_expr::utils::{AggregateOrderSensitivity, format_state_name};
use datafusion_expr::{
Accumulator, AggregateUDFImpl, Coercion, Documentation, Expr, GroupsAccumulator,
ReversedUDAF, SetMonotonicity, Signature, TypeSignature, TypeSignatureClass,
Volatility,
Accumulator, AggregateUDFImpl, BinaryExpr, Coercion, Documentation, Expr,
GroupsAccumulator, Operator, ReversedUDAF, SetMonotonicity, Signature, TypeSignature,
TypeSignatureClass, Volatility,
};
use datafusion_functions_aggregate_common::aggregate::groups_accumulator::prim_op::PrimitiveGroupsAccumulator;
use datafusion_functions_aggregate_common::aggregate::sum_distinct::DistinctSumAccumulator;
Expand All @@ -54,7 +58,7 @@ make_udaf_expr_and_func!(
);

pub fn sum_distinct(expr: Expr) -> Expr {
Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf(
Expr::AggregateFunction(AggregateFunction::new_udf(
sum_udaf(),
vec![expr],
true,
Expand Down Expand Up @@ -346,6 +350,88 @@ impl AggregateUDFImpl for Sum {
_ => SetMonotonicity::NotMonotonic,
}
}

/// Simplification Rules
fn simplify(&self) -> Option<AggregateFunctionSimplification> {
Some(Box::new(sum_simplifier))
}
}

/// Implement ClickBench Q29 specific optimization:
/// `SUM(arg + constant)` --> `SUM(arg) + constant * COUNT(arg)`
///
/// Backstory: TODO
///
fn sum_simplifier(mut agg: AggregateFunction, _info: &SimplifyContext) -> Result<Expr> {
// Explicitly destructure to ensure we check all relevant fields
let AggregateFunctionParams {
args,
distinct,
filter,
order_by,
null_treatment,
} = &agg.params;

if *distinct
|| filter.is_some()
|| !order_by.is_empty()
|| null_treatment.is_some()
|| args.len() != 1
{
return Ok(Expr::AggregateFunction(agg));
}

// otherwise check the arguments if they are <col> <op> scalar
let (arg, lit) = match SplitResult::new(agg.params.args.swap_remove(0)) {
SplitResult::Original(expr) => {
agg.params.args.push(expr); // put it back
return Ok(Expr::AggregateFunction(agg));
}
SplitResult::Split { arg, lit } => (arg, lit),
};

// Rewrite to SUM(arg)
agg.params.args.push(arg.clone());
let sum_agg = Expr::AggregateFunction(agg);

// sum(arg) + scalar * COUNT(arg)
Ok(sum_agg + (lit * crate::count::count(arg)))
}

/// Result of trying to split an expression into an arg and constant
#[derive(Debug, Clone)]
enum SplitResult {
/// if the expression is either of
/// * `<arg> <op> <lit>`
/// * `<lit> <op> <arg>`
///
/// When `op` is `+`
Split { arg: Expr, lit: Expr },
/// If the expression is something else
Original(Expr),
}

impl SplitResult {
fn new(expr: Expr) -> Self {
let Expr::BinaryExpr(BinaryExpr { left, op, right }) = expr else {
return Self::Original(expr);
};
if op != Operator::Plus {
return Self::Original(Expr::BinaryExpr(BinaryExpr { left, op, right }));
}

match (left.as_ref(), right.as_ref()) {
(Expr::Literal(..), _) => Self::Split {
arg: *right,
lit: *left,
},
(_, Expr::Literal(..)) => Self::Split {
arg: *left,
lit: *right,
},
_ => Self::Original(Expr::BinaryExpr(BinaryExpr { left, op, right })),
}
}
}

/// This accumulator computes SUM incrementally
Expand Down
97 changes: 97 additions & 0 deletions datafusion/sqllogictest/test_files/aggregates_simplify.slt
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

#######
# Tests for aggregate optimizations / simplifications
#######

statement ok
CREATE TABLE sum_simplify_t AS VALUES (1), (2), (NULL);

#######
# Positive EXPLAIN cases for SUM(arg + literal) simplification
#######

# Expect to see one COUNT and one SUM in each query below
query TT
EXPLAIN SELECT SUM(column1 + 1), SUM(column1 + 2) FROM sum_simplify_t;
----
logical_plan
01)Projection: __common_expr_1 + __common_expr_2 AS sum(sum_simplify_t.column1 + Int64(1)), __common_expr_1 + Int64(2) * __common_expr_2 AS sum(sum_simplify_t.column1 + Int64(2))
02)--Aggregate: groupBy=[[]], aggr=[[sum(sum_simplify_t.column1) AS __common_expr_1, count(sum_simplify_t.column1) AS __common_expr_2]]
03)----TableScan: sum_simplify_t projection=[column1]
physical_plan
01)ProjectionExec: expr=[__common_expr_1@0 + __common_expr_2@1 as sum(sum_simplify_t.column1 + Int64(1)), __common_expr_1@0 + 2 * __common_expr_2@1 as sum(sum_simplify_t.column1 + Int64(2))]
02)--AggregateExec: mode=Single, gby=[], aggr=[__common_expr_1, __common_expr_2]
03)----DataSourceExec: partitions=1, partition_sizes=[1]

query TT
EXPLAIN SELECT SUM(1 + column1), SUM(column1 + 2) FROM sum_simplify_t;
----
logical_plan
01)Projection: __common_expr_1 + __common_expr_2 AS sum(Int64(1) + sum_simplify_t.column1), __common_expr_1 + Int64(2) * __common_expr_2 AS sum(sum_simplify_t.column1 + Int64(2))
02)--Aggregate: groupBy=[[]], aggr=[[sum(sum_simplify_t.column1) AS __common_expr_1, count(sum_simplify_t.column1) AS __common_expr_2]]
03)----TableScan: sum_simplify_t projection=[column1]
physical_plan
01)ProjectionExec: expr=[__common_expr_1@0 + __common_expr_2@1 as sum(Int64(1) + sum_simplify_t.column1), __common_expr_1@0 + 2 * __common_expr_2@1 as sum(sum_simplify_t.column1 + Int64(2))]
02)--AggregateExec: mode=Single, gby=[], aggr=[__common_expr_1, __common_expr_2]
03)----DataSourceExec: partitions=1, partition_sizes=[1]

#######
# Cases where rewrite should not apply
#######

query TT
EXPLAIN SELECT SUM(DISTINCT column1 + 1), SUM(DISTINCT column1 + 2) FROM sum_simplify_t;
----
logical_plan
01)Aggregate: groupBy=[[]], aggr=[[sum(DISTINCT sum_simplify_t.column1 + Int64(1)), sum(DISTINCT sum_simplify_t.column1 + Int64(2))]]
02)--TableScan: sum_simplify_t projection=[column1]
physical_plan
01)AggregateExec: mode=Single, gby=[], aggr=[sum(DISTINCT sum_simplify_t.column1 + Int64(1)), sum(DISTINCT sum_simplify_t.column1 + Int64(2))]
02)--DataSourceExec: partitions=1, partition_sizes=[1]

query TT
EXPLAIN SELECT SUM(column1 + 1) FILTER (WHERE column1 > 1), SUM(column1 + 2) FILTER (WHERE column1 > 2 ) FROM sum_simplify_t;
----
logical_plan
01)Aggregate: groupBy=[[]], aggr=[[sum(sum_simplify_t.column1 + Int64(1)) FILTER (WHERE sum_simplify_t.column1 > Int64(1)), sum(sum_simplify_t.column1 + Int64(2)) FILTER (WHERE sum_simplify_t.column1 > Int64(2))]]
02)--TableScan: sum_simplify_t projection=[column1]
physical_plan
01)AggregateExec: mode=Single, gby=[], aggr=[sum(sum_simplify_t.column1 + Int64(1)) FILTER (WHERE sum_simplify_t.column1 > Int64(1)), sum(sum_simplify_t.column1 + Int64(2)) FILTER (WHERE sum_simplify_t.column1 > Int64(2))]
02)--DataSourceExec: partitions=1, partition_sizes=[1]

# This test should work
query error
SELECT SUM(random() + 1), SUM(random() + 2) FROM sum_simplify_t;


#######
# Reproducers for known issues
#######

# Blocking: single rewritten SUM fails with "Invalid aggregate expression"
query error
SELECT SUM(column1 + 1) FROM sum_simplify_t;

# Blocking: CSE can fail with "No field named ... Valid fields are __common_expr_1"
query error
SELECT SUM(column1), SUM(column1 + 1) FROM sum_simplify_t;


statement ok
DROP TABLE sum_simplify_t;
Loading