Skip to content

Commit 6fa2daa

Browse files
committed
fix
1 parent ac0fb76 commit 6fa2daa

12 files changed

+283
-219
lines changed

src/query/functions/src/aggregates/adaptors/aggregate_combinator_distinct.rs

Lines changed: 101 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -43,20 +43,35 @@ use super::AggregateCountFunction;
4343
use super::AggregateFunction;
4444
use super::AggregateFunctionCreator;
4545
use super::AggregateFunctionDescription;
46+
use super::AggregateFunctionFeatures;
4647
use super::AggregateFunctionSortDesc;
4748
use super::CombinatorDescription;
4849
use super::StateAddr;
50+
use crate::aggregates::adaptors::AggregateFunctionCombinatorNull;
4951

50-
#[derive(Clone)]
5152
pub struct AggregateDistinctCombinator<State> {
5253
name: String,
5354

5455
nested_name: String,
5556
arguments: Vec<DataType>,
57+
check_null: bool,
5658
nested: Arc<dyn AggregateFunction>,
5759
_s: PhantomData<fn(State)>,
5860
}
5961

62+
impl<State> Clone for AggregateDistinctCombinator<State> {
63+
fn clone(&self) -> Self {
64+
Self {
65+
name: self.name.clone(),
66+
nested_name: self.nested_name.clone(),
67+
arguments: self.arguments.clone(),
68+
check_null: self.check_null,
69+
nested: self.nested.clone(),
70+
_s: PhantomData,
71+
}
72+
}
73+
}
74+
6075
impl<State> AggregateDistinctCombinator<State>
6176
where State: Send + 'static
6277
{
@@ -104,12 +119,12 @@ where State: DistinctStateFunc
104119
input_rows: usize,
105120
) -> Result<()> {
106121
let state = Self::get_state(place);
107-
state.batch_add(columns, validity, input_rows)
122+
state.batch_add(columns, validity, input_rows, self.check_null)
108123
}
109124

110125
fn accumulate_row(&self, place: AggrState, columns: ProjectedBlock, row: usize) -> Result<()> {
111126
let state = Self::get_state(place);
112-
state.add(columns, row)
127+
state.add(columns, row, self.check_null)
113128
}
114129

115130
fn serialize_type(&self) -> Vec<StateSerdeItem> {
@@ -202,32 +217,55 @@ pub fn aggregate_combinator_distinct_desc() -> CombinatorDescription {
202217
CombinatorDescription::creator(Box::new(try_create))
203218
}
204219

205-
pub fn aggregate_combinator_uniq_desc() -> AggregateFunctionDescription {
206-
let features = super::AggregateFunctionFeatures {
220+
pub fn aggregate_uniq_desc() -> AggregateFunctionDescription {
221+
let features = AggregateFunctionFeatures {
207222
returns_default_when_only_null: true,
208223
..Default::default()
209224
};
210-
AggregateFunctionDescription::creator_with_features(Box::new(try_create_uniq), features)
225+
AggregateFunctionDescription::creator_with_features(
226+
Box::new(|nested_name, params, arguments, sort_descs| {
227+
let creator = Box::new(AggregateCountFunction::try_create) as _;
228+
try_create(nested_name, params, arguments, sort_descs, &creator)
229+
}),
230+
features,
231+
)
211232
}
212233

213-
pub fn try_create_uniq(
214-
nested_name: &str,
215-
params: Vec<Scalar>,
216-
arguments: Vec<DataType>,
217-
sort_descs: Vec<AggregateFunctionSortDesc>,
218-
) -> Result<Arc<dyn AggregateFunction>> {
219-
let creator: AggregateFunctionCreator = Box::new(AggregateCountFunction::try_create);
220-
try_create(nested_name, params, arguments, sort_descs, &creator)
234+
pub fn aggregate_count_distinct_desc() -> AggregateFunctionDescription {
235+
AggregateFunctionDescription::creator_with_features(
236+
Box::new(|_, params, arguments, _| {
237+
let count_creator = Box::new(AggregateCountFunction::try_create) as _;
238+
if matches!(*arguments, [DataType::Nullable(_)]) {
239+
let new_arguments =
240+
AggregateFunctionCombinatorNull::transform_arguments(&arguments)?;
241+
let nested = try_create(
242+
"count",
243+
params.clone(),
244+
new_arguments,
245+
vec![],
246+
&count_creator,
247+
)?;
248+
AggregateFunctionCombinatorNull::try_create(params, arguments, nested, true)
249+
} else {
250+
try_create("count", params, arguments, vec![], &count_creator)
251+
}
252+
}),
253+
AggregateFunctionFeatures {
254+
returns_default_when_only_null: true,
255+
keep_nullable: true,
256+
..Default::default()
257+
},
258+
)
221259
}
222260

223-
pub fn try_create(
261+
fn try_create(
224262
nested_name: &str,
225263
params: Vec<Scalar>,
226264
arguments: Vec<DataType>,
227265
sort_descs: Vec<AggregateFunctionSortDesc>,
228266
nested_creator: &AggregateFunctionCreator,
229267
) -> Result<Arc<dyn AggregateFunction>> {
230-
let name = format!("DistinctCombinator({})", nested_name);
268+
let name = format!("DistinctCombinator({nested_name})");
231269
assert_variadic_arguments(&name, arguments.len(), (1, 32))?;
232270

233271
let nested_arguments = match nested_name {
@@ -236,53 +274,54 @@ pub fn try_create(
236274
};
237275
let nested = nested_creator(nested_name, params, nested_arguments, sort_descs)?;
238276

239-
if arguments.len() == 1 {
240-
match &arguments[0] {
241-
DataType::Number(ty) => with_number_mapped_type!(|NUM_TYPE| match ty {
242-
NumberDataType::NUM_TYPE => {
243-
return Ok(Arc::new(AggregateDistinctCombinator::<
244-
AggregateDistinctNumberState<NUM_TYPE>,
245-
> {
246-
nested_name: nested_name.to_owned(),
247-
arguments,
248-
nested,
249-
name,
250-
_s: PhantomData,
251-
}));
252-
}
253-
}),
254-
DataType::String => {
255-
return match nested_name {
256-
"count" | "uniq" => Ok(Arc::new(AggregateDistinctCombinator::<
257-
AggregateUniqStringState,
258-
> {
259-
name,
260-
arguments,
261-
nested,
262-
nested_name: nested_name.to_owned(),
263-
_s: PhantomData,
264-
})),
265-
_ => Ok(Arc::new(AggregateDistinctCombinator::<
266-
AggregateDistinctStringState,
267-
> {
268-
nested_name: nested_name.to_owned(),
269-
arguments,
270-
nested,
271-
name,
272-
_s: PhantomData,
273-
})),
274-
};
277+
match *arguments {
278+
[DataType::Number(ty)] => with_number_mapped_type!(|NUM_TYPE| match ty {
279+
NumberDataType::NUM_TYPE => {
280+
Ok(Arc::new(AggregateDistinctCombinator::<
281+
AggregateDistinctNumberState<NUM_TYPE>,
282+
> {
283+
nested_name: nested_name.to_owned(),
284+
arguments,
285+
check_null: false,
286+
nested,
287+
name,
288+
_s: PhantomData,
289+
}))
275290
}
276-
_ => {}
291+
}),
292+
[DataType::String] if matches!(nested_name, "count" | "uniq") => {
293+
Ok(Arc::new(AggregateDistinctCombinator::<
294+
AggregateUniqStringState,
295+
> {
296+
name,
297+
arguments,
298+
check_null: false,
299+
nested,
300+
nested_name: nested_name.to_owned(),
301+
_s: PhantomData,
302+
}))
277303
}
304+
[DataType::String] => Ok(Arc::new(AggregateDistinctCombinator::<
305+
AggregateDistinctStringState,
306+
> {
307+
nested_name: nested_name.to_owned(),
308+
arguments,
309+
check_null: false,
310+
nested,
311+
name,
312+
_s: PhantomData,
313+
})),
314+
_ => Ok(Arc::new(AggregateDistinctCombinator::<
315+
AggregateDistinctState,
316+
> {
317+
nested_name: nested_name.to_owned(),
318+
check_null: nested_name == "count"
319+
&& arguments.len() > 1
320+
&& arguments.iter().any(DataType::is_nullable_or_null),
321+
arguments,
322+
nested,
323+
name,
324+
_s: PhantomData,
325+
})),
278326
}
279-
Ok(Arc::new(AggregateDistinctCombinator::<
280-
AggregateDistinctState,
281-
> {
282-
nested_name: nested_name.to_owned(),
283-
arguments,
284-
nested,
285-
name,
286-
_s: PhantomData,
287-
}))
288327
}

src/query/functions/src/aggregates/adaptors/aggregate_null_adaptor.rs

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@ use super::AggrStateLoc;
3232
use super::AggrStateRegistry;
3333
use super::AggrStateType;
3434
use super::AggregateFunction;
35-
use super::AggregateFunctionFeatures;
3635
use super::AggregateFunctionRef;
3736
use super::AggregateNullResultFunction;
3837
use super::StateAddr;
@@ -57,28 +56,22 @@ impl AggregateFunctionCombinatorNull {
5756
Ok(results)
5857
}
5958

60-
pub fn transform_params(params: &[Scalar]) -> Result<Vec<Scalar>> {
61-
Ok(params.to_owned())
62-
}
63-
6459
pub fn try_create(
65-
_name: &str,
6660
params: Vec<Scalar>,
6761
arguments: Vec<DataType>,
6862
nested: AggregateFunctionRef,
69-
properties: AggregateFunctionFeatures,
63+
returns_default_when_only_null: bool,
7064
) -> Result<AggregateFunctionRef> {
7165
// has_null_types
7266
if arguments.iter().any(|f| f == &DataType::Null) {
73-
if properties.returns_default_when_only_null {
67+
if returns_default_when_only_null {
7468
return AggregateNullResultFunction::try_create(DataType::Number(
7569
NumberDataType::UInt64,
7670
));
7771
} else {
7872
return AggregateNullResultFunction::try_create(DataType::Null);
7973
}
8074
}
81-
let params = Self::transform_params(&params)?;
8275
let arguments = Self::transform_arguments(&arguments)?;
8376
let size = arguments.len();
8477

@@ -90,8 +83,7 @@ impl AggregateFunctionCombinatorNull {
9083
}
9184

9285
let return_type = nested.return_type()?;
93-
let result_is_null =
94-
!properties.returns_default_when_only_null && return_type.can_inside_nullable();
86+
let result_is_null = !returns_default_when_only_null && return_type.can_inside_nullable();
9587

9688
match size {
9789
1 => match result_is_null {

src/query/functions/src/aggregates/adaptors/aggregate_ornull_adaptor.rs

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@ use databend_common_expression::StateSerdeItem;
3030
use super::AggrState;
3131
use super::AggrStateLoc;
3232
use super::AggregateFunction;
33-
use super::AggregateFunctionFeatures;
3433
use super::AggregateFunctionRef;
3534
use super::StateAddr;
3635

@@ -44,13 +43,10 @@ pub struct AggregateFunctionOrNullAdaptor {
4443
}
4544

4645
impl AggregateFunctionOrNullAdaptor {
47-
pub fn create(
48-
nested: AggregateFunctionRef,
49-
features: AggregateFunctionFeatures,
50-
) -> Result<AggregateFunctionRef> {
46+
pub fn create(nested: AggregateFunctionRef) -> Result<AggregateFunctionRef> {
5147
// count/count distinct should not be nullable for empty set, just return zero
5248
let inner_return_type = nested.return_type()?;
53-
if features.returns_default_when_only_null || inner_return_type == DataType::Null {
49+
if inner_return_type == DataType::Null {
5450
return Ok(nested);
5551
}
5652

src/query/functions/src/aggregates/aggregate_array_agg.rs

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ use super::AggrState;
6262
use super::AggrStateLoc;
6363
use super::AggregateFunction;
6464
use super::AggregateFunctionDescription;
65+
use super::AggregateFunctionFeatures;
6566
use super::AggregateFunctionSortDesc;
6667
use super::SerializeInfo;
6768
use super::StateAddr;
@@ -809,5 +810,12 @@ fn try_create_aggregate_array_agg_function(
809810
}
810811

811812
pub fn aggregate_array_agg_function_desc() -> AggregateFunctionDescription {
812-
AggregateFunctionDescription::creator(Box::new(try_create_aggregate_array_agg_function))
813+
AggregateFunctionDescription::creator_with_features(
814+
Box::new(try_create_aggregate_array_agg_function),
815+
AggregateFunctionFeatures {
816+
allow_sort: true,
817+
keep_nullable: true,
818+
..Default::default()
819+
},
820+
)
813821
}

src/query/functions/src/aggregates/aggregate_array_moving.rs

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ use super::AggrState;
4444
use super::AggrStateLoc;
4545
use super::AggregateFunction;
4646
use super::AggregateFunctionDescription;
47+
use super::AggregateFunctionFeatures;
4748
use super::AggregateFunctionRef;
4849
use super::AggregateFunctionSortDesc;
4950
use super::SerializeInfo;
@@ -678,7 +679,13 @@ pub fn try_create_aggregate_array_moving_avg_function(
678679
}
679680

680681
pub fn aggregate_array_moving_avg_function_desc() -> AggregateFunctionDescription {
681-
AggregateFunctionDescription::creator(Box::new(try_create_aggregate_array_moving_avg_function))
682+
AggregateFunctionDescription::creator_with_features(
683+
Box::new(try_create_aggregate_array_moving_avg_function),
684+
AggregateFunctionFeatures {
685+
keep_nullable: true,
686+
..Default::default()
687+
},
688+
)
682689
}
683690

684691
#[derive(Clone)]
@@ -859,5 +866,11 @@ pub fn try_create_aggregate_array_moving_sum_function(
859866
}
860867

861868
pub fn aggregate_array_moving_sum_function_desc() -> AggregateFunctionDescription {
862-
AggregateFunctionDescription::creator(Box::new(try_create_aggregate_array_moving_sum_function))
869+
AggregateFunctionDescription::creator_with_features(
870+
Box::new(try_create_aggregate_array_moving_sum_function),
871+
AggregateFunctionFeatures {
872+
keep_nullable: true,
873+
..Default::default()
874+
},
875+
)
863876
}

0 commit comments

Comments
 (0)