Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -36,27 +36,43 @@ use super::aggregate_distinct_state::AggregateDistinctState;
use super::aggregate_distinct_state::AggregateDistinctStringState;
use super::aggregate_distinct_state::AggregateUniqStringState;
use super::aggregate_distinct_state::DistinctStateFunc;
use super::aggregate_null_result::AggregateNullResultFunction;
use super::assert_variadic_arguments;
use super::AggrState;
use super::AggrStateLoc;
use super::AggregateCountFunction;
use super::AggregateFunction;
use super::AggregateFunctionCombinatorNull;
use super::AggregateFunctionCreator;
use super::AggregateFunctionDescription;
use super::AggregateFunctionFeatures;
use super::AggregateFunctionSortDesc;
use super::CombinatorDescription;
use super::StateAddr;

#[derive(Clone)]
pub struct AggregateDistinctCombinator<State> {
name: String,

nested_name: String,
arguments: Vec<DataType>,
skip_null: bool,
nested: Arc<dyn AggregateFunction>,
_s: PhantomData<fn(State)>,
}

impl<State> Clone for AggregateDistinctCombinator<State> {
fn clone(&self) -> Self {
Self {
name: self.name.clone(),
nested_name: self.nested_name.clone(),
arguments: self.arguments.clone(),
skip_null: self.skip_null,
nested: self.nested.clone(),
_s: PhantomData,
}
}
}

impl<State> AggregateDistinctCombinator<State>
where State: Send + 'static
{
Expand Down Expand Up @@ -104,12 +120,12 @@ where State: DistinctStateFunc
input_rows: usize,
) -> Result<()> {
let state = Self::get_state(place);
state.batch_add(columns, validity, input_rows)
state.batch_add(columns, validity, input_rows, self.skip_null)
}

fn accumulate_row(&self, place: AggrState, columns: ProjectedBlock, row: usize) -> Result<()> {
let state = Self::get_state(place);
state.add(columns, row)
state.add(columns, row, self.skip_null)
}

fn serialize_type(&self) -> Vec<StateSerdeItem> {
Expand Down Expand Up @@ -202,32 +218,63 @@ pub fn aggregate_combinator_distinct_desc() -> CombinatorDescription {
CombinatorDescription::creator(Box::new(try_create))
}

pub fn aggregate_combinator_uniq_desc() -> AggregateFunctionDescription {
let features = super::AggregateFunctionFeatures {
pub fn aggregate_uniq_desc() -> AggregateFunctionDescription {
let features = AggregateFunctionFeatures {
returns_default_when_only_null: true,
..Default::default()
};
AggregateFunctionDescription::creator_with_features(Box::new(try_create_uniq), features)
AggregateFunctionDescription::creator_with_features(
Box::new(|nested_name, params, arguments, sort_descs| {
let creator = Box::new(AggregateCountFunction::try_create) as _;
try_create(nested_name, params, arguments, sort_descs, &creator)
}),
features,
)
}

pub fn try_create_uniq(
nested_name: &str,
params: Vec<Scalar>,
arguments: Vec<DataType>,
sort_descs: Vec<AggregateFunctionSortDesc>,
) -> Result<Arc<dyn AggregateFunction>> {
let creator: AggregateFunctionCreator = Box::new(AggregateCountFunction::try_create);
try_create(nested_name, params, arguments, sort_descs, &creator)
pub fn aggregate_count_distinct_desc() -> AggregateFunctionDescription {
AggregateFunctionDescription::creator_with_features(
Box::new(|_, params, arguments, _| {
let count_creator = Box::new(AggregateCountFunction::try_create) as _;
match *arguments {
[DataType::Nullable(_)] => {
let new_arguments =
AggregateFunctionCombinatorNull::transform_arguments(&arguments)?;
let nested = try_create(
"count",
params.clone(),
new_arguments,
vec![],
&count_creator,
)?;
AggregateFunctionCombinatorNull::try_create(params, arguments, nested, true)
}
ref arguments
if !arguments.is_empty() && arguments.iter().all(DataType::is_null) =>
{
AggregateNullResultFunction::try_create(DataType::Number(
NumberDataType::UInt64,
))
}
_ => try_create("count", params, arguments, vec![], &count_creator),
}
}),
AggregateFunctionFeatures {
returns_default_when_only_null: true,
keep_nullable: true,
..Default::default()
},
)
}

pub fn try_create(
fn try_create(
nested_name: &str,
params: Vec<Scalar>,
arguments: Vec<DataType>,
sort_descs: Vec<AggregateFunctionSortDesc>,
nested_creator: &AggregateFunctionCreator,
) -> Result<Arc<dyn AggregateFunction>> {
let name = format!("DistinctCombinator({})", nested_name);
let name = format!("DistinctCombinator({nested_name})");
assert_variadic_arguments(&name, arguments.len(), (1, 32))?;

let nested_arguments = match nested_name {
Expand All @@ -236,53 +283,54 @@ pub fn try_create(
};
let nested = nested_creator(nested_name, params, nested_arguments, sort_descs)?;

if arguments.len() == 1 {
match &arguments[0] {
DataType::Number(ty) => with_number_mapped_type!(|NUM_TYPE| match ty {
NumberDataType::NUM_TYPE => {
return Ok(Arc::new(AggregateDistinctCombinator::<
AggregateDistinctNumberState<NUM_TYPE>,
> {
nested_name: nested_name.to_owned(),
arguments,
nested,
name,
_s: PhantomData,
}));
}
}),
DataType::String => {
return match nested_name {
"count" | "uniq" => Ok(Arc::new(AggregateDistinctCombinator::<
AggregateUniqStringState,
> {
name,
arguments,
nested,
nested_name: nested_name.to_owned(),
_s: PhantomData,
})),
_ => Ok(Arc::new(AggregateDistinctCombinator::<
AggregateDistinctStringState,
> {
nested_name: nested_name.to_owned(),
arguments,
nested,
name,
_s: PhantomData,
})),
};
match *arguments {
[DataType::Number(ty)] => with_number_mapped_type!(|NUM_TYPE| match ty {
NumberDataType::NUM_TYPE => {
Ok(Arc::new(AggregateDistinctCombinator::<
AggregateDistinctNumberState<NUM_TYPE>,
> {
nested_name: nested_name.to_owned(),
arguments,
skip_null: false,
nested,
name,
_s: PhantomData,
}))
}
_ => {}
}),
[DataType::String] if matches!(nested_name, "count" | "uniq") => {
Ok(Arc::new(AggregateDistinctCombinator::<
AggregateUniqStringState,
> {
name,
arguments,
skip_null: false,
nested,
nested_name: nested_name.to_owned(),
_s: PhantomData,
}))
}
[DataType::String] => Ok(Arc::new(AggregateDistinctCombinator::<
AggregateDistinctStringState,
> {
nested_name: nested_name.to_owned(),
arguments,
skip_null: false,
nested,
name,
_s: PhantomData,
})),
_ => Ok(Arc::new(AggregateDistinctCombinator::<
AggregateDistinctState,
> {
nested_name: nested_name.to_owned(),
skip_null: nested_name == "count"
&& arguments.len() > 1
&& arguments.iter().any(DataType::is_nullable_or_null),
arguments,
nested,
name,
_s: PhantomData,
})),
}
Ok(Arc::new(AggregateDistinctCombinator::<
AggregateDistinctState,
> {
nested_name: nested_name.to_owned(),
arguments,
nested,
name,
_s: PhantomData,
}))
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ use super::AggrStateLoc;
use super::AggrStateRegistry;
use super::AggrStateType;
use super::AggregateFunction;
use super::AggregateFunctionFeatures;
use super::AggregateFunctionRef;
use super::AggregateNullResultFunction;
use super::StateAddr;
Expand All @@ -57,28 +56,22 @@ impl AggregateFunctionCombinatorNull {
Ok(results)
}

pub fn transform_params(params: &[Scalar]) -> Result<Vec<Scalar>> {
Ok(params.to_owned())
}

pub fn try_create(
_name: &str,
params: Vec<Scalar>,
arguments: Vec<DataType>,
nested: AggregateFunctionRef,
properties: AggregateFunctionFeatures,
returns_default_when_only_null: bool,
) -> Result<AggregateFunctionRef> {
// has_null_types
if arguments.iter().any(|f| f == &DataType::Null) {
if properties.returns_default_when_only_null {
if returns_default_when_only_null {
return AggregateNullResultFunction::try_create(DataType::Number(
NumberDataType::UInt64,
));
} else {
return AggregateNullResultFunction::try_create(DataType::Null);
}
}
let params = Self::transform_params(&params)?;
let arguments = Self::transform_arguments(&arguments)?;
let size = arguments.len();

Expand All @@ -90,8 +83,7 @@ impl AggregateFunctionCombinatorNull {
}

let return_type = nested.return_type()?;
let result_is_null =
!properties.returns_default_when_only_null && return_type.can_inside_nullable();
let result_is_null = !returns_default_when_only_null && return_type.can_inside_nullable();

match size {
1 => match result_is_null {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ use databend_common_expression::StateSerdeItem;
use super::AggrState;
use super::AggrStateLoc;
use super::AggregateFunction;
use super::AggregateFunctionFeatures;
use super::AggregateFunctionRef;
use super::StateAddr;

Expand All @@ -44,13 +43,10 @@ pub struct AggregateFunctionOrNullAdaptor {
}

impl AggregateFunctionOrNullAdaptor {
pub fn create(
nested: AggregateFunctionRef,
features: AggregateFunctionFeatures,
) -> Result<AggregateFunctionRef> {
pub fn create(nested: AggregateFunctionRef) -> Result<AggregateFunctionRef> {
// count/count distinct should not be nullable for empty set, just return zero
let inner_return_type = nested.return_type()?;
if features.returns_default_when_only_null || inner_return_type == DataType::Null {
if inner_return_type == DataType::Null {
return Ok(nested);
}

Expand Down
10 changes: 9 additions & 1 deletion src/query/functions/src/aggregates/aggregate_array_agg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ use super::AggrState;
use super::AggrStateLoc;
use super::AggregateFunction;
use super::AggregateFunctionDescription;
use super::AggregateFunctionFeatures;
use super::AggregateFunctionSortDesc;
use super::SerializeInfo;
use super::StateAddr;
Expand Down Expand Up @@ -809,5 +810,12 @@ fn try_create_aggregate_array_agg_function(
}

pub fn aggregate_array_agg_function_desc() -> AggregateFunctionDescription {
AggregateFunctionDescription::creator(Box::new(try_create_aggregate_array_agg_function))
AggregateFunctionDescription::creator_with_features(
Box::new(try_create_aggregate_array_agg_function),
AggregateFunctionFeatures {
allow_sort: true,
keep_nullable: true,
..Default::default()
},
)
}
17 changes: 15 additions & 2 deletions src/query/functions/src/aggregates/aggregate_array_moving.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ use super::AggrState;
use super::AggrStateLoc;
use super::AggregateFunction;
use super::AggregateFunctionDescription;
use super::AggregateFunctionFeatures;
use super::AggregateFunctionRef;
use super::AggregateFunctionSortDesc;
use super::SerializeInfo;
Expand Down Expand Up @@ -678,7 +679,13 @@ pub fn try_create_aggregate_array_moving_avg_function(
}

pub fn aggregate_array_moving_avg_function_desc() -> AggregateFunctionDescription {
AggregateFunctionDescription::creator(Box::new(try_create_aggregate_array_moving_avg_function))
AggregateFunctionDescription::creator_with_features(
Box::new(try_create_aggregate_array_moving_avg_function),
AggregateFunctionFeatures {
keep_nullable: true,
..Default::default()
},
)
}

#[derive(Clone)]
Expand Down Expand Up @@ -859,5 +866,11 @@ pub fn try_create_aggregate_array_moving_sum_function(
}

pub fn aggregate_array_moving_sum_function_desc() -> AggregateFunctionDescription {
AggregateFunctionDescription::creator(Box::new(try_create_aggregate_array_moving_sum_function))
AggregateFunctionDescription::creator_with_features(
Box::new(try_create_aggregate_array_moving_sum_function),
AggregateFunctionFeatures {
keep_nullable: true,
..Default::default()
},
)
}
Loading
Loading