@@ -43,20 +43,35 @@ use super::AggregateCountFunction;
4343use super :: AggregateFunction ;
4444use super :: AggregateFunctionCreator ;
4545use super :: AggregateFunctionDescription ;
46+ use super :: AggregateFunctionFeatures ;
4647use super :: AggregateFunctionSortDesc ;
4748use super :: CombinatorDescription ;
4849use super :: StateAddr ;
50+ use crate :: aggregates:: adaptors:: AggregateFunctionCombinatorNull ;
4951
50- #[ derive( Clone ) ]
5152pub struct AggregateDistinctCombinator < State > {
5253 name : String ,
5354
5455 nested_name : String ,
5556 arguments : Vec < DataType > ,
57+ check_null : bool ,
5658 nested : Arc < dyn AggregateFunction > ,
5759 _s : PhantomData < fn ( State ) > ,
5860}
5961
62+ impl < State > Clone for AggregateDistinctCombinator < State > {
63+ fn clone ( & self ) -> Self {
64+ Self {
65+ name : self . name . clone ( ) ,
66+ nested_name : self . nested_name . clone ( ) ,
67+ arguments : self . arguments . clone ( ) ,
68+ check_null : self . check_null ,
69+ nested : self . nested . clone ( ) ,
70+ _s : PhantomData ,
71+ }
72+ }
73+ }
74+
6075impl < State > AggregateDistinctCombinator < State >
6176where State : Send + ' static
6277{
@@ -104,12 +119,12 @@ where State: DistinctStateFunc
104119 input_rows : usize ,
105120 ) -> Result < ( ) > {
106121 let state = Self :: get_state ( place) ;
107- state. batch_add ( columns, validity, input_rows)
122+ state. batch_add ( columns, validity, input_rows, self . check_null )
108123 }
109124
110125 fn accumulate_row ( & self , place : AggrState , columns : ProjectedBlock , row : usize ) -> Result < ( ) > {
111126 let state = Self :: get_state ( place) ;
112- state. add ( columns, row)
127+ state. add ( columns, row, self . check_null )
113128 }
114129
115130 fn serialize_type ( & self ) -> Vec < StateSerdeItem > {
@@ -202,32 +217,55 @@ pub fn aggregate_combinator_distinct_desc() -> CombinatorDescription {
202217 CombinatorDescription :: creator ( Box :: new ( try_create) )
203218}
204219
205- pub fn aggregate_combinator_uniq_desc ( ) -> AggregateFunctionDescription {
206- let features = super :: AggregateFunctionFeatures {
220+ pub fn aggregate_uniq_desc ( ) -> AggregateFunctionDescription {
221+ let features = AggregateFunctionFeatures {
207222 returns_default_when_only_null : true ,
208223 ..Default :: default ( )
209224 } ;
210- AggregateFunctionDescription :: creator_with_features ( Box :: new ( try_create_uniq) , features)
225+ AggregateFunctionDescription :: creator_with_features (
226+ Box :: new ( |nested_name, params, arguments, sort_descs| {
227+ let creator = Box :: new ( AggregateCountFunction :: try_create) as _ ;
228+ try_create ( nested_name, params, arguments, sort_descs, & creator)
229+ } ) ,
230+ features,
231+ )
211232}
212233
213- pub fn try_create_uniq (
214- nested_name : & str ,
215- params : Vec < Scalar > ,
216- arguments : Vec < DataType > ,
217- sort_descs : Vec < AggregateFunctionSortDesc > ,
218- ) -> Result < Arc < dyn AggregateFunction > > {
219- let creator: AggregateFunctionCreator = Box :: new ( AggregateCountFunction :: try_create) ;
220- try_create ( nested_name, params, arguments, sort_descs, & creator)
234+ pub fn aggregate_count_distinct_desc ( ) -> AggregateFunctionDescription {
235+ AggregateFunctionDescription :: creator_with_features (
236+ Box :: new ( |_, params, arguments, _| {
237+ let count_creator = Box :: new ( AggregateCountFunction :: try_create) as _ ;
238+ if matches ! ( * arguments, [ DataType :: Nullable ( _) ] ) {
239+ let new_arguments =
240+ AggregateFunctionCombinatorNull :: transform_arguments ( & arguments) ?;
241+ let nested = try_create (
242+ "count" ,
243+ params. clone ( ) ,
244+ new_arguments,
245+ vec ! [ ] ,
246+ & count_creator,
247+ ) ?;
248+ AggregateFunctionCombinatorNull :: try_create ( params, arguments, nested, true )
249+ } else {
250+ try_create ( "count" , params, arguments, vec ! [ ] , & count_creator)
251+ }
252+ } ) ,
253+ AggregateFunctionFeatures {
254+ returns_default_when_only_null : true ,
255+ keep_nullable : true ,
256+ ..Default :: default ( )
257+ } ,
258+ )
221259}
222260
223- pub fn try_create (
261+ fn try_create (
224262 nested_name : & str ,
225263 params : Vec < Scalar > ,
226264 arguments : Vec < DataType > ,
227265 sort_descs : Vec < AggregateFunctionSortDesc > ,
228266 nested_creator : & AggregateFunctionCreator ,
229267) -> Result < Arc < dyn AggregateFunction > > {
230- let name = format ! ( "DistinctCombinator({})" , nested_name ) ;
268+ let name = format ! ( "DistinctCombinator({nested_name })" ) ;
231269 assert_variadic_arguments ( & name, arguments. len ( ) , ( 1 , 32 ) ) ?;
232270
233271 let nested_arguments = match nested_name {
@@ -236,53 +274,54 @@ pub fn try_create(
236274 } ;
237275 let nested = nested_creator ( nested_name, params, nested_arguments, sort_descs) ?;
238276
239- if arguments. len ( ) == 1 {
240- match & arguments[ 0 ] {
241- DataType :: Number ( ty) => with_number_mapped_type ! ( |NUM_TYPE | match ty {
242- NumberDataType :: NUM_TYPE => {
243- return Ok ( Arc :: new( AggregateDistinctCombinator :: <
244- AggregateDistinctNumberState <NUM_TYPE >,
245- > {
246- nested_name: nested_name. to_owned( ) ,
247- arguments,
248- nested,
249- name,
250- _s: PhantomData ,
251- } ) ) ;
252- }
253- } ) ,
254- DataType :: String => {
255- return match nested_name {
256- "count" | "uniq" => Ok ( Arc :: new ( AggregateDistinctCombinator :: <
257- AggregateUniqStringState ,
258- > {
259- name,
260- arguments,
261- nested,
262- nested_name : nested_name. to_owned ( ) ,
263- _s : PhantomData ,
264- } ) ) ,
265- _ => Ok ( Arc :: new ( AggregateDistinctCombinator :: <
266- AggregateDistinctStringState ,
267- > {
268- nested_name : nested_name. to_owned ( ) ,
269- arguments,
270- nested,
271- name,
272- _s : PhantomData ,
273- } ) ) ,
274- } ;
277+ match * arguments {
278+ [ DataType :: Number ( ty) ] => with_number_mapped_type ! ( |NUM_TYPE | match ty {
279+ NumberDataType :: NUM_TYPE => {
280+ Ok ( Arc :: new( AggregateDistinctCombinator :: <
281+ AggregateDistinctNumberState <NUM_TYPE >,
282+ > {
283+ nested_name: nested_name. to_owned( ) ,
284+ arguments,
285+ check_null: false ,
286+ nested,
287+ name,
288+ _s: PhantomData ,
289+ } ) )
275290 }
276- _ => { }
291+ } ) ,
292+ [ DataType :: String ] if matches ! ( nested_name, "count" | "uniq" ) => {
293+ Ok ( Arc :: new ( AggregateDistinctCombinator :: <
294+ AggregateUniqStringState ,
295+ > {
296+ name,
297+ arguments,
298+ check_null : false ,
299+ nested,
300+ nested_name : nested_name. to_owned ( ) ,
301+ _s : PhantomData ,
302+ } ) )
277303 }
304+ [ DataType :: String ] => Ok ( Arc :: new ( AggregateDistinctCombinator :: <
305+ AggregateDistinctStringState ,
306+ > {
307+ nested_name : nested_name. to_owned ( ) ,
308+ arguments,
309+ check_null : false ,
310+ nested,
311+ name,
312+ _s : PhantomData ,
313+ } ) ) ,
314+ _ => Ok ( Arc :: new ( AggregateDistinctCombinator :: <
315+ AggregateDistinctState ,
316+ > {
317+ nested_name : nested_name. to_owned ( ) ,
318+ check_null : nested_name == "count"
319+ && arguments. len ( ) > 1
320+ && arguments. iter ( ) . any ( DataType :: is_nullable_or_null) ,
321+ arguments,
322+ nested,
323+ name,
324+ _s : PhantomData ,
325+ } ) ) ,
278326 }
279- Ok ( Arc :: new ( AggregateDistinctCombinator :: <
280- AggregateDistinctState ,
281- > {
282- nested_name : nested_name. to_owned ( ) ,
283- arguments,
284- nested,
285- name,
286- _s : PhantomData ,
287- } ) )
288327}
0 commit comments