|
15 | 15 | // specific language governing permissions and limitations |
16 | 16 | // under the License. |
17 | 17 |
|
| 18 | +use crate::utils::claim_buffers_recursive; |
18 | 19 | use ahash::RandomState; |
19 | 20 | use arrow::{ |
20 | 21 | array::{Array, ArrayRef, AsArray, BooleanArray, Int64Array, PrimitiveArray}, |
@@ -773,15 +774,25 @@ impl DistinctCountAccumulator { |
773 | 774 |
|
774 | 775 | // calculates the size as accurately as possible. Note that calling this |
775 | 776 | // method is expensive |
776 | | - fn full_size(&self) -> usize { |
777 | | - size_of_val(self) |
| 777 | + fn full_size(&self, pool: Option<&dyn MemoryPool>) -> usize { |
| 778 | + let mut total = size_of_val(self) |
778 | 779 | + (size_of::<ScalarValue>() * self.values.capacity()) |
779 | | - + self |
780 | | - .values |
781 | | - .iter() |
782 | | - .map(|vals| ScalarValue::size(vals) - size_of_val(vals)) |
783 | | - .sum::<usize>() |
784 | | - + size_of::<DataType>() |
| 780 | + + size_of::<DataType>(); |
| 781 | + |
| 782 | + for scalar in &self.values { |
| 783 | + if let Some(array) = scalar.get_array_ref() { |
| 784 | + total += size_of::<Arc<dyn Array>>(); |
| 785 | + if let Some(pool) = pool { |
| 786 | + claim_buffers_recursive(&array.to_data(), pool); |
| 787 | + } else { |
| 788 | + total += scalar.size() - size_of_val(scalar); |
| 789 | + } |
| 790 | + } else { |
| 791 | + total += scalar.size() - size_of_val(scalar); |
| 792 | + } |
| 793 | + } |
| 794 | + |
| 795 | + total |
785 | 796 | } |
786 | 797 | } |
787 | 798 |
|
@@ -840,11 +851,11 @@ impl Accumulator for DistinctCountAccumulator { |
840 | 851 | Ok(ScalarValue::Int64(Some(self.values.len() as i64))) |
841 | 852 | } |
842 | 853 |
|
843 | | - fn size(&self, _pool: Option<&dyn MemoryPool>) -> usize { |
| 854 | + fn size(&self, pool: Option<&dyn MemoryPool>) -> usize { |
844 | 855 | match &self.state_data_type { |
845 | 856 | DataType::Boolean | DataType::Null => self.fixed_size(), |
846 | 857 | d if d.is_primitive() => self.fixed_size(), |
847 | | - _ => self.full_size(), |
| 858 | + _ => self.full_size(pool), |
848 | 859 | } |
849 | 860 | } |
850 | 861 | } |
@@ -1046,4 +1057,51 @@ mod tests { |
1046 | 1057 | assert_eq!(merged.evaluate()?, ScalarValue::Int64(Some(3))); |
1047 | 1058 | Ok(()) |
1048 | 1059 | } |
| 1060 | + |
| 1061 | + #[test] |
| 1062 | + fn distinct_count_does_not_over_account_memory() -> Result<()> { |
| 1063 | + use arrow::array::ListArray; |
| 1064 | + use arrow_buffer::TrackingMemoryPool; |
| 1065 | + |
| 1066 | + // Create a DistinctCountAccumulator for List<Int32> (array type) |
| 1067 | + let mut acc = DistinctCountAccumulator { |
| 1068 | + values: HashSet::default(), |
| 1069 | + state_data_type: DataType::List(Arc::new(Field::new_list_field( |
| 1070 | + DataType::Int32, |
| 1071 | + true, |
| 1072 | + ))), |
| 1073 | + }; |
| 1074 | + |
| 1075 | + // Create list arrays with shared buffers (slices of the same underlying data) |
| 1076 | + let list_array = ListArray::from_iter_primitive::<Int32Type, _, _>(vec![ |
| 1077 | + Some(vec![Some(1), Some(2), Some(3)]), |
| 1078 | + Some(vec![Some(4), Some(5), Some(6)]), |
| 1079 | + Some(vec![Some(1), Some(2), Some(3)]), // duplicate |
| 1080 | + Some(vec![Some(7), Some(8), Some(9)]), |
| 1081 | + Some(vec![Some(4), Some(5), Some(6)]), // duplicate |
| 1082 | + ]); |
| 1083 | + |
| 1084 | + acc.update_batch(&[Arc::new(list_array)])?; |
| 1085 | + |
| 1086 | + // Should have 3 distinct arrays |
| 1087 | + assert_eq!(acc.values.len(), 3); |
| 1088 | + |
| 1089 | + // Test with memory pool - should not over-account shared buffers |
| 1090 | + let pool = TrackingMemoryPool::default(); |
| 1091 | + let structural_size = acc.size(Some(&pool)); |
| 1092 | + let total_size_with_pool = structural_size + pool.used(); |
| 1093 | + |
| 1094 | + // Test without pool - uses scalar.size() which may over-account |
| 1095 | + let size_without_pool = acc.size(None); |
| 1096 | + |
| 1097 | + // With pool should be much smaller than without pool due to deduplication |
| 1098 | + // The pool tracks actual physical buffers, avoiding double-counting |
| 1099 | + // With the pool we get 13544 while when using the pool we get 4728 |
| 1100 | + assert!( |
| 1101 | + total_size_with_pool < size_without_pool, |
| 1102 | + "Pool-based size ({total_size_with_pool}) should be less than non-pool size ({size_without_pool}) due to buffer deduplication" |
| 1103 | + ); |
| 1104 | + |
| 1105 | + Ok(()) |
| 1106 | + } |
1049 | 1107 | } |
0 commit comments