Skip to content

Commit 2378e6e

Browse files
committed
Use memory pool for DistinctCountAccumulator as well
1 parent af8b94f commit 2378e6e

File tree

1 file changed

+68
-10
lines changed
  • datafusion/functions-aggregate/src

1 file changed

+68
-10
lines changed

datafusion/functions-aggregate/src/count.rs

Lines changed: 68 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18+
use crate::utils::claim_buffers_recursive;
1819
use ahash::RandomState;
1920
use arrow::{
2021
array::{Array, ArrayRef, AsArray, BooleanArray, Int64Array, PrimitiveArray},
@@ -773,15 +774,25 @@ impl DistinctCountAccumulator {
773774

774775
// calculates the size as accurately as possible. Note that calling this
775776
// method is expensive
776-
fn full_size(&self) -> usize {
777-
size_of_val(self)
777+
fn full_size(&self, pool: Option<&dyn MemoryPool>) -> usize {
778+
let mut total = size_of_val(self)
778779
+ (size_of::<ScalarValue>() * self.values.capacity())
779-
+ self
780-
.values
781-
.iter()
782-
.map(|vals| ScalarValue::size(vals) - size_of_val(vals))
783-
.sum::<usize>()
784-
+ size_of::<DataType>()
780+
+ size_of::<DataType>();
781+
782+
for scalar in &self.values {
783+
if let Some(array) = scalar.get_array_ref() {
784+
total += size_of::<Arc<dyn Array>>();
785+
if let Some(pool) = pool {
786+
claim_buffers_recursive(&array.to_data(), pool);
787+
} else {
788+
total += scalar.size() - size_of_val(scalar);
789+
}
790+
} else {
791+
total += scalar.size() - size_of_val(scalar);
792+
}
793+
}
794+
795+
total
785796
}
786797
}
787798

@@ -840,11 +851,11 @@ impl Accumulator for DistinctCountAccumulator {
840851
Ok(ScalarValue::Int64(Some(self.values.len() as i64)))
841852
}
842853

843-
fn size(&self, _pool: Option<&dyn MemoryPool>) -> usize {
854+
fn size(&self, pool: Option<&dyn MemoryPool>) -> usize {
844855
match &self.state_data_type {
845856
DataType::Boolean | DataType::Null => self.fixed_size(),
846857
d if d.is_primitive() => self.fixed_size(),
847-
_ => self.full_size(),
858+
_ => self.full_size(pool),
848859
}
849860
}
850861
}
@@ -1046,4 +1057,51 @@ mod tests {
10461057
assert_eq!(merged.evaluate()?, ScalarValue::Int64(Some(3)));
10471058
Ok(())
10481059
}
1060+
1061+
#[test]
1062+
fn distinct_count_does_not_over_account_memory() -> Result<()> {
1063+
use arrow::array::ListArray;
1064+
use arrow_buffer::TrackingMemoryPool;
1065+
1066+
// Create a DistinctCountAccumulator for List<Int32> (array type)
1067+
let mut acc = DistinctCountAccumulator {
1068+
values: HashSet::default(),
1069+
state_data_type: DataType::List(Arc::new(Field::new_list_field(
1070+
DataType::Int32,
1071+
true,
1072+
))),
1073+
};
1074+
1075+
// Create list arrays with shared buffers (slices of the same underlying data)
1076+
let list_array = ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
1077+
Some(vec![Some(1), Some(2), Some(3)]),
1078+
Some(vec![Some(4), Some(5), Some(6)]),
1079+
Some(vec![Some(1), Some(2), Some(3)]), // duplicate
1080+
Some(vec![Some(7), Some(8), Some(9)]),
1081+
Some(vec![Some(4), Some(5), Some(6)]), // duplicate
1082+
]);
1083+
1084+
acc.update_batch(&[Arc::new(list_array)])?;
1085+
1086+
// Should have 3 distinct arrays
1087+
assert_eq!(acc.values.len(), 3);
1088+
1089+
// Test with memory pool - should not over-account shared buffers
1090+
let pool = TrackingMemoryPool::default();
1091+
let structural_size = acc.size(Some(&pool));
1092+
let total_size_with_pool = structural_size + pool.used();
1093+
1094+
// Test without pool - uses scalar.size() which may over-account
1095+
let size_without_pool = acc.size(None);
1096+
1097+
// With pool should be much smaller than without pool due to deduplication
1098+
// The pool tracks actual physical buffers, avoiding double-counting
1099+
// Without the pool we get 13544, while when using the pool we get 4728
1100+
assert!(
1101+
total_size_with_pool < size_without_pool,
1102+
"Pool-based size ({total_size_with_pool}) should be less than non-pool size ({size_without_pool}) due to buffer deduplication"
1103+
);
1104+
1105+
Ok(())
1106+
}
10491107
}

0 commit comments

Comments
 (0)