
Commit 9eddf47

fix: CteWorkTable: properly apply TableProvider::scan projection argument (#18993)
It was previously ignored.

## Which issue does this PR close?

- Closes #18992.

## Rationale for this change

All `TableProvider` implementations must support the `projection` argument of the `scan` method. This was not the case in `CteWorkTable`.

## What changes are included in this PR?

A minimal implementation of projection support. The projection is applied before the plan node returns results. It might be nice to push it further into the recursion implementation to reduce memory consumption, but I preferred to keep the fix minimal.

## Are these changes tested?

I have not yet figured out a nice SQL query that triggers an error without this change. Some existing queries in `cte.slt` do set a projection (i.e. not `None`), so the code is very likely working. There is also a test of the projection itself in `WorkTableExec`.
1 parent 6ac7b89 · commit 9eddf47
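For illustration, here is a minimal, hypothetical sketch (plain arrow, outside of DataFusion) of what applying a `projection` means: the same list of column indices is applied to the schema and to the buffered record batches, which is what the change does inside `WorkTableExec`. The column names and data below are made up for the example.

```rust
use std::sync::Arc;

use arrow::array::{ArrayRef, Int32Array, Int64Array};
use arrow::datatypes::{DataType, Field, Schema};
use arrow::record_batch::RecordBatch;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Hypothetical work-table schema with two columns.
    let schema = Arc::new(Schema::new(vec![
        Field::new("a", DataType::Int64, false),
        Field::new("b", DataType::Int32, false),
    ]));
    let batch = RecordBatch::try_new(
        Arc::clone(&schema),
        vec![
            Arc::new(Int64Array::from(vec![1, 2, 3])) as ArrayRef,
            Arc::new(Int32Array::from(vec![10, 20, 30])) as ArrayRef,
        ],
    )?;

    // A projection of [1] keeps only column "b". The fix projects the schema
    // in `WorkTableExec::new` and the batches in `execute` with the same indices.
    let projected_schema = schema.project(&[1])?;
    let projected_batch = batch.project(&[1])?;

    assert_eq!(projected_schema.fields().len(), 1);
    assert_eq!(projected_batch.num_columns(), 1);
    assert_eq!(projected_batch.schema().field(0).name(), "b");
    Ok(())
}
```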

File tree

2 files changed: +106 −23 lines changed


datafusion/catalog/src/cte_worktable.rs

Lines changed: 24 additions & 14 deletions
@@ -17,20 +17,18 @@
 
 //! CteWorkTable implementation used for recursive queries
 
+use std::any::Any;
+use std::borrow::Cow;
 use std::sync::Arc;
-use std::{any::Any, borrow::Cow};
 
-use crate::Session;
 use arrow::datatypes::SchemaRef;
 use async_trait::async_trait;
-use datafusion_physical_plan::work_table::WorkTableExec;
-
-use datafusion_physical_plan::ExecutionPlan;
-
 use datafusion_common::error::Result;
 use datafusion_expr::{Expr, LogicalPlan, TableProviderFilterPushDown, TableType};
+use datafusion_physical_plan::ExecutionPlan;
+use datafusion_physical_plan::work_table::WorkTableExec;
 
-use crate::TableProvider;
+use crate::{ScanArgs, ScanResult, Session, TableProvider};
 
 /// The temporary working table where the previous iteration of a recursive query is stored
 /// Naming is based on PostgreSQL's implementation.
@@ -85,16 +83,28 @@ impl TableProvider for CteWorkTable {
 
     async fn scan(
         &self,
-        _state: &dyn Session,
-        _projection: Option<&Vec<usize>>,
-        _filters: &[Expr],
-        _limit: Option<usize>,
+        state: &dyn Session,
+        projection: Option<&Vec<usize>>,
+        filters: &[Expr],
+        limit: Option<usize>,
     ) -> Result<Arc<dyn ExecutionPlan>> {
-        // TODO: pushdown filters and limits
-        Ok(Arc::new(WorkTableExec::new(
+        let options = ScanArgs::default()
+            .with_projection(projection.map(|p| p.as_slice()))
+            .with_filters(Some(filters))
+            .with_limit(limit);
+        Ok(self.scan_with_args(state, options).await?.into_inner())
+    }
+
+    async fn scan_with_args<'a>(
+        &self,
+        _state: &dyn Session,
+        args: ScanArgs<'a>,
+    ) -> Result<ScanResult> {
+        Ok(ScanResult::new(Arc::new(WorkTableExec::new(
             self.name.clone(),
             Arc::clone(&self.table_schema),
-        )))
+            args.projection().map(|p| p.to_vec()),
+        )?)))
     }
 
     fn supports_filters_pushdown(

datafusion/physical-plan/src/work_table.rs

Lines changed: 82 additions & 9 deletions
@@ -102,6 +102,8 @@ pub struct WorkTableExec {
     name: String,
     /// The schema of the stream
     schema: SchemaRef,
+    /// Projection to apply to build the output stream from the recursion state
+    projection: Option<Vec<usize>>,
     /// The work table
     work_table: Arc<WorkTable>,
     /// Execution metrics
@@ -112,15 +114,23 @@ pub struct WorkTableExec {
 
 impl WorkTableExec {
     /// Create a new execution plan for a worktable exec.
-    pub fn new(name: String, schema: SchemaRef) -> Self {
+    pub fn new(
+        name: String,
+        mut schema: SchemaRef,
+        projection: Option<Vec<usize>>,
+    ) -> Result<Self> {
+        if let Some(projection) = &projection {
+            schema = Arc::new(schema.project(projection)?);
+        }
         let cache = Self::compute_properties(Arc::clone(&schema));
-        Self {
+        Ok(Self {
             name: name.clone(),
             schema,
-            metrics: ExecutionPlanMetricsSet::new(),
+            projection,
             work_table: Arc::new(WorkTable::new(name)),
+            metrics: ExecutionPlanMetricsSet::new(),
             cache,
-        }
+        })
     }
 
     /// Ref to name
@@ -198,11 +208,22 @@ impl ExecutionPlan for WorkTableExec {
             0,
             "WorkTableExec got an invalid partition {partition} (expected 0)"
         );
-        let batch = self.work_table.take()?;
+        let ReservedBatches {
+            mut batches,
+            reservation,
+        } = self.work_table.take()?;
+        if let Some(projection) = &self.projection {
+            // We apply the projection
+            // TODO: it would be better to apply it as soon as possible and not only here
+            // TODO: an aggressive projection makes the memory reservation smaller, even if we do not edit it
+            batches = batches
+                .into_iter()
+                .map(|b| b.project(projection))
+                .collect::<Result<Vec<_>, _>>()?;
+        }
 
-        let stream =
-            MemoryStream::try_new(batch.batches, Arc::clone(&self.schema), None)?
-                .with_reservation(batch.reservation);
+        let stream = MemoryStream::try_new(batches, Arc::clone(&self.schema), None)?
+            .with_reservation(reservation);
         Ok(Box::pin(cooperative(stream)))
     }
 
@@ -239,6 +260,7 @@ impl ExecutionPlan for WorkTableExec {
         Some(Arc::new(Self {
             name: self.name.clone(),
             schema: Arc::clone(&self.schema),
+            projection: self.projection.clone(),
             metrics: ExecutionPlanMetricsSet::new(),
             work_table,
             cache: self.cache.clone(),
@@ -249,8 +271,10 @@ impl ExecutionPlan for WorkTableExec {
 #[cfg(test)]
 mod tests {
     use super::*;
-    use arrow::array::{ArrayRef, Int32Array};
+    use arrow::array::{ArrayRef, Int16Array, Int32Array, Int64Array};
+    use arrow_schema::{DataType, Field, Schema};
     use datafusion_execution::memory_pool::{MemoryConsumer, UnboundedMemoryPool};
+    use futures::StreamExt;
 
     #[test]
     fn test_work_table() {
@@ -283,4 +307,53 @@ mod tests {
         drop(memory_stream);
         assert_eq!(pool.reserved(), 0);
     }
+
+    #[tokio::test]
+    async fn test_work_table_exec() {
+        let schema = Arc::new(Schema::new(vec![
+            Field::new("a", DataType::Int64, false),
+            Field::new("b", DataType::Int32, false),
+            Field::new("c", DataType::Int16, false),
+        ]));
+        let work_table_exec =
+            WorkTableExec::new("wt".into(), Arc::clone(&schema), Some(vec![2, 1]))
+                .unwrap();
+
+        // We inject the work table
+        let work_table = Arc::new(WorkTable::new("wt".into()));
+        let work_table_exec = work_table_exec
+            .with_new_state(Arc::clone(&work_table) as _)
+            .unwrap();
+
+        // We update the work table
+        let pool = Arc::new(UnboundedMemoryPool::default()) as _;
+        let reservation = MemoryConsumer::new("test_work_table").register(&pool);
+        let batch = RecordBatch::try_new(
+            Arc::clone(&schema),
+            vec![
+                Arc::new(Int64Array::from(vec![1, 2, 3, 4, 5])),
+                Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5])),
+                Arc::new(Int16Array::from(vec![1, 2, 3, 4, 5])),
+            ],
+        )
+        .unwrap();
+        work_table.update(ReservedBatches::new(vec![batch], reservation));
+
+        // We get back the batch from the work table
+        let returned_batch = work_table_exec
+            .execute(0, Arc::new(TaskContext::default()))
+            .unwrap()
+            .next()
+            .await
+            .unwrap()
+            .unwrap();
+        assert_eq!(
+            returned_batch,
+            RecordBatch::try_from_iter(vec![
+                ("c", Arc::new(Int16Array::from(vec![1, 2, 3, 4, 5])) as _),
+                ("b", Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5])) as _),
+            ])
+            .unwrap()
+        );
+    }
 }
