diff --git a/native-engine/datafusion-ext-plans/src/agg/collect.rs b/native-engine/datafusion-ext-plans/src/agg/collect.rs
index b20acafdd..467e1f772 100644
--- a/native-engine/datafusion-ext-plans/src/agg/collect.rs
+++ b/native-engine/datafusion-ext-plans/src/agg/collect.rs
@@ -561,10 +561,6 @@ impl AccSet {
     }
 
     pub fn merge(&mut self, other: &mut Self) {
-        if self.set.len() < other.set.len() {
-            // ensure the probed set is smaller
-            std::mem::swap(self, other);
-        }
         for pos_len in std::mem::take(&mut other.set).into_iter() {
             self.append_raw(other.list.ref_raw(pos_len));
         }
@@ -707,6 +703,26 @@ mod tests {
 
         assert_eq!(acc_set1.list.raw.len(), 12); // 4 bytes for each int32
         assert_eq!(acc_set1.set.len(), 3);
+        let values: Vec<ScalarValue> = acc_set1.into_values(DataType::Int32, false).collect();
+        assert_eq!(values, vec![value1, value2, value3]);
+    }
+
+    #[test]
+    fn test_acc_set_merge_preserves_first_occurrence_order_when_rhs_is_larger() {
+        let mut acc_set1 = AccSet::default();
+        let mut acc_set2 = AccSet::default();
+        let value1 = ScalarValue::Int32(Some(1));
+        let value2 = ScalarValue::Int32(Some(2));
+        let value3 = ScalarValue::Int32(Some(3));
+
+        acc_set1.append(&value1, false);
+        acc_set2.append(&value2, false);
+        acc_set2.append(&value3, false);
+
+        acc_set1.merge(&mut acc_set2);
+
+        let values: Vec<ScalarValue> = acc_set1.into_values(DataType::Int32, false).collect();
+        assert_eq!(values, vec![value1, value2, value3]);
     }
 
     #[test]
@@ -759,4 +775,33 @@ mod tests {
         assert_eq!(acc_col.take_values(2), acc_col_unspill.take_values(2));
         Ok(())
     }
+
+    #[test]
+    fn test_acc_set_merge_preserves_first_occurrence_order_after_rhs_spill() -> Result<()> {
+        let value1 = ScalarValue::Int32(Some(1));
+        let value2 = ScalarValue::Int32(Some(2));
+        let value3 = ScalarValue::Int32(Some(3));
+
+        let mut lhs = AccSetColumn::empty(DataType::Int32);
+        lhs.resize(1);
+        lhs.append_item(0, &value1);
+
+        let mut rhs = AccSetColumn::empty(DataType::Int32);
+        rhs.resize(1);
+        rhs.append_item(0, &value2);
+        rhs.append_item(0, &value3);
+
+        let mut spill: Box<dyn Spill> = Box::new(vec![]);
+        let mut spill_writer = spill.get_compressed_writer();
+        rhs.spill(IdxSelection::Range(0, 1), &mut spill_writer)?;
+        spill_writer.finish()?;
+
+        let mut rhs_unspill = AccSetColumn::empty(DataType::Int32);
+        rhs_unspill.unspill(1, &mut spill.get_compressed_reader())?;
+
+        lhs.merge_items(0, &mut rhs_unspill, 0);
+
+        assert_eq!(lhs.take_values(0), vec![value1, value2, value3]);
+        Ok(())
+    }
 }