@@ -36,27 +36,43 @@ use super::aggregate_distinct_state::AggregateDistinctState;
3636use super :: aggregate_distinct_state:: AggregateDistinctStringState ;
3737use super :: aggregate_distinct_state:: AggregateUniqStringState ;
3838use super :: aggregate_distinct_state:: DistinctStateFunc ;
39+ use super :: aggregate_null_result:: AggregateNullResultFunction ;
3940use super :: assert_variadic_arguments;
4041use super :: AggrState ;
4142use super :: AggrStateLoc ;
4243use super :: AggregateCountFunction ;
4344use super :: AggregateFunction ;
45+ use super :: AggregateFunctionCombinatorNull ;
4446use super :: AggregateFunctionCreator ;
4547use super :: AggregateFunctionDescription ;
48+ use super :: AggregateFunctionFeatures ;
4649use super :: AggregateFunctionSortDesc ;
4750use super :: CombinatorDescription ;
4851use super :: StateAddr ;
4952
50- #[ derive( Clone ) ]
5153pub struct AggregateDistinctCombinator < State > {
5254 name : String ,
5355
5456 nested_name : String ,
5557 arguments : Vec < DataType > ,
58+ skip_null : bool ,
5659 nested : Arc < dyn AggregateFunction > ,
5760 _s : PhantomData < fn ( State ) > ,
5861}
5962
63+ impl < State > Clone for AggregateDistinctCombinator < State > {
64+ fn clone ( & self ) -> Self {
65+ Self {
66+ name : self . name . clone ( ) ,
67+ nested_name : self . nested_name . clone ( ) ,
68+ arguments : self . arguments . clone ( ) ,
69+ skip_null : self . skip_null ,
70+ nested : self . nested . clone ( ) ,
71+ _s : PhantomData ,
72+ }
73+ }
74+ }
75+
6076impl < State > AggregateDistinctCombinator < State >
6177where State : Send + ' static
6278{
@@ -104,12 +120,12 @@ where State: DistinctStateFunc
104120 input_rows : usize ,
105121 ) -> Result < ( ) > {
106122 let state = Self :: get_state ( place) ;
107- state. batch_add ( columns, validity, input_rows)
123+ state. batch_add ( columns, validity, input_rows, self . skip_null )
108124 }
109125
110126 fn accumulate_row ( & self , place : AggrState , columns : ProjectedBlock , row : usize ) -> Result < ( ) > {
111127 let state = Self :: get_state ( place) ;
112- state. add ( columns, row)
128+ state. add ( columns, row, self . skip_null )
113129 }
114130
115131 fn serialize_type ( & self ) -> Vec < StateSerdeItem > {
@@ -202,32 +218,63 @@ pub fn aggregate_combinator_distinct_desc() -> CombinatorDescription {
202218 CombinatorDescription :: creator ( Box :: new ( try_create) )
203219}
204220
205- pub fn aggregate_combinator_uniq_desc ( ) -> AggregateFunctionDescription {
206- let features = super :: AggregateFunctionFeatures {
221+ pub fn aggregate_uniq_desc ( ) -> AggregateFunctionDescription {
222+ let features = AggregateFunctionFeatures {
207223 returns_default_when_only_null : true ,
208224 ..Default :: default ( )
209225 } ;
210- AggregateFunctionDescription :: creator_with_features ( Box :: new ( try_create_uniq) , features)
226+ AggregateFunctionDescription :: creator_with_features (
227+ Box :: new ( |nested_name, params, arguments, sort_descs| {
228+ let creator = Box :: new ( AggregateCountFunction :: try_create) as _ ;
229+ try_create ( nested_name, params, arguments, sort_descs, & creator)
230+ } ) ,
231+ features,
232+ )
211233}
212234
213- pub fn try_create_uniq (
214- nested_name : & str ,
215- params : Vec < Scalar > ,
216- arguments : Vec < DataType > ,
217- sort_descs : Vec < AggregateFunctionSortDesc > ,
218- ) -> Result < Arc < dyn AggregateFunction > > {
219- let creator: AggregateFunctionCreator = Box :: new ( AggregateCountFunction :: try_create) ;
220- try_create ( nested_name, params, arguments, sort_descs, & creator)
235+ pub fn aggregate_count_distinct_desc ( ) -> AggregateFunctionDescription {
236+ AggregateFunctionDescription :: creator_with_features (
237+ Box :: new ( |_, params, arguments, _| {
238+ let count_creator = Box :: new ( AggregateCountFunction :: try_create) as _ ;
239+ match * arguments {
240+ [ DataType :: Nullable ( _) ] => {
241+ let new_arguments =
242+ AggregateFunctionCombinatorNull :: transform_arguments ( & arguments) ?;
243+ let nested = try_create (
244+ "count" ,
245+ params. clone ( ) ,
246+ new_arguments,
247+ vec ! [ ] ,
248+ & count_creator,
249+ ) ?;
250+ AggregateFunctionCombinatorNull :: try_create ( params, arguments, nested, true )
251+ }
252+ ref arguments
253+ if !arguments. is_empty ( ) && arguments. iter ( ) . all ( DataType :: is_null) =>
254+ {
255+ AggregateNullResultFunction :: try_create ( DataType :: Number (
256+ NumberDataType :: UInt64 ,
257+ ) )
258+ }
259+ _ => try_create ( "count" , params, arguments, vec ! [ ] , & count_creator) ,
260+ }
261+ } ) ,
262+ AggregateFunctionFeatures {
263+ returns_default_when_only_null : true ,
264+ keep_nullable : true ,
265+ ..Default :: default ( )
266+ } ,
267+ )
221268}
222269
223- pub fn try_create (
270+ fn try_create (
224271 nested_name : & str ,
225272 params : Vec < Scalar > ,
226273 arguments : Vec < DataType > ,
227274 sort_descs : Vec < AggregateFunctionSortDesc > ,
228275 nested_creator : & AggregateFunctionCreator ,
229276) -> Result < Arc < dyn AggregateFunction > > {
230- let name = format ! ( "DistinctCombinator({})" , nested_name ) ;
277+ let name = format ! ( "DistinctCombinator({nested_name })" ) ;
231278 assert_variadic_arguments ( & name, arguments. len ( ) , ( 1 , 32 ) ) ?;
232279
233280 let nested_arguments = match nested_name {
@@ -236,53 +283,54 @@ pub fn try_create(
236283 } ;
237284 let nested = nested_creator ( nested_name, params, nested_arguments, sort_descs) ?;
238285
239- if arguments. len ( ) == 1 {
240- match & arguments[ 0 ] {
241- DataType :: Number ( ty) => with_number_mapped_type ! ( |NUM_TYPE | match ty {
242- NumberDataType :: NUM_TYPE => {
243- return Ok ( Arc :: new( AggregateDistinctCombinator :: <
244- AggregateDistinctNumberState <NUM_TYPE >,
245- > {
246- nested_name: nested_name. to_owned( ) ,
247- arguments,
248- nested,
249- name,
250- _s: PhantomData ,
251- } ) ) ;
252- }
253- } ) ,
254- DataType :: String => {
255- return match nested_name {
256- "count" | "uniq" => Ok ( Arc :: new ( AggregateDistinctCombinator :: <
257- AggregateUniqStringState ,
258- > {
259- name,
260- arguments,
261- nested,
262- nested_name : nested_name. to_owned ( ) ,
263- _s : PhantomData ,
264- } ) ) ,
265- _ => Ok ( Arc :: new ( AggregateDistinctCombinator :: <
266- AggregateDistinctStringState ,
267- > {
268- nested_name : nested_name. to_owned ( ) ,
269- arguments,
270- nested,
271- name,
272- _s : PhantomData ,
273- } ) ) ,
274- } ;
286+ match * arguments {
287+ [ DataType :: Number ( ty) ] => with_number_mapped_type ! ( |NUM_TYPE | match ty {
288+ NumberDataType :: NUM_TYPE => {
289+ Ok ( Arc :: new( AggregateDistinctCombinator :: <
290+ AggregateDistinctNumberState <NUM_TYPE >,
291+ > {
292+ nested_name: nested_name. to_owned( ) ,
293+ arguments,
294+ skip_null: false ,
295+ nested,
296+ name,
297+ _s: PhantomData ,
298+ } ) )
275299 }
276- _ => { }
300+ } ) ,
301+ [ DataType :: String ] if matches ! ( nested_name, "count" | "uniq" ) => {
302+ Ok ( Arc :: new ( AggregateDistinctCombinator :: <
303+ AggregateUniqStringState ,
304+ > {
305+ name,
306+ arguments,
307+ skip_null : false ,
308+ nested,
309+ nested_name : nested_name. to_owned ( ) ,
310+ _s : PhantomData ,
311+ } ) )
277312 }
313+ [ DataType :: String ] => Ok ( Arc :: new ( AggregateDistinctCombinator :: <
314+ AggregateDistinctStringState ,
315+ > {
316+ nested_name : nested_name. to_owned ( ) ,
317+ arguments,
318+ skip_null : false ,
319+ nested,
320+ name,
321+ _s : PhantomData ,
322+ } ) ) ,
323+ _ => Ok ( Arc :: new ( AggregateDistinctCombinator :: <
324+ AggregateDistinctState ,
325+ > {
326+ nested_name : nested_name. to_owned ( ) ,
327+ skip_null : nested_name == "count"
328+ && arguments. len ( ) > 1
329+ && arguments. iter ( ) . any ( DataType :: is_nullable_or_null) ,
330+ arguments,
331+ nested,
332+ name,
333+ _s : PhantomData ,
334+ } ) ) ,
278335 }
279- Ok ( Arc :: new ( AggregateDistinctCombinator :: <
280- AggregateDistinctState ,
281- > {
282- nested_name : nested_name. to_owned ( ) ,
283- arguments,
284- nested,
285- name,
286- _s : PhantomData ,
287- } ) )
288336}
0 commit comments