Skip to content
255 changes: 234 additions & 21 deletions datafusion/optimizer/src/optimize_projections/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,8 @@ fn optimize_projections(
config: &dyn OptimizerConfig,
indices: RequiredIndices,
) -> Result<Transformed<LogicalPlan>> {
let volatile_in_plan = plan.expressions().iter().any(Expr::is_volatile);

// Recursively rewrite any nodes that may be able to avoid computation given
// their parents' required indices.
match plan {
Expand All @@ -141,6 +143,7 @@ fn optimize_projections(
});
}
LogicalPlan::Aggregate(aggregate) => {
let has_volatile_ancestor = indices.has_volatile_ancestor();
// Split parent requirements to GROUP BY and aggregate sections:
let n_group_exprs = aggregate.group_expr_len()?;
// Offset aggregate indices so that they point to valid indices at
Expand Down Expand Up @@ -188,6 +191,14 @@ fn optimize_projections(
let necessary_indices =
RequiredIndices::new().with_exprs(schema, all_exprs_iter);
let necessary_exprs = necessary_indices.get_required_exprs(schema);
let mut necessary_indices = if new_aggr_expr.is_empty() {
necessary_indices.for_multiplicity_insensitive_child()
} else {
necessary_indices.for_multiplicity_sensitive_child()
};
necessary_indices = necessary_indices
.with_volatile_ancestor_if(has_volatile_ancestor)
.with_plan_volatile(volatile_in_plan);

return optimize_projections(
Arc::unwrap_or_clone(aggregate.input),
Expand All @@ -213,6 +224,7 @@ fn optimize_projections(
});
}
LogicalPlan::Window(window) => {
let has_volatile_ancestor = indices.has_volatile_ancestor();
let input_schema = Arc::clone(window.input.schema());
// Split parent requirements to child and window expression sections:
let n_input_fields = input_schema.fields().len();
Expand All @@ -227,6 +239,14 @@ fn optimize_projections(
// Get all the required column indices at the input, either by the
// parent or window expression requirements.
let required_indices = child_reqs.with_exprs(&input_schema, &new_window_expr);
let mut required_indices = if new_window_expr.is_empty() {
required_indices.for_multiplicity_insensitive_child()
} else {
required_indices.for_multiplicity_sensitive_child()
};
required_indices = required_indices
.with_volatile_ancestor_if(has_volatile_ancestor)
.with_plan_volatile(volatile_in_plan);

return optimize_projections(
Arc::unwrap_or_clone(window.input),
Expand Down Expand Up @@ -293,10 +313,11 @@ fn optimize_projections(
plan.inputs()
.into_iter()
.map(|input| {
indices
let required = indices
.clone()
.with_projection_beneficial()
.with_plan_exprs(&plan, input.schema())
.with_plan_exprs(&plan, input.schema())?;
Ok(required.with_plan_volatile(volatile_in_plan))
})
.collect::<Result<_>>()?
}
Expand All @@ -307,7 +328,11 @@ fn optimize_projections(
// flag is `false`.
plan.inputs()
.into_iter()
.map(|input| indices.clone().with_plan_exprs(&plan, input.schema()))
.map(|input| {
let required =
indices.clone().with_plan_exprs(&plan, input.schema())?;
Ok(required.with_plan_volatile(volatile_in_plan))
})
.collect::<Result<_>>()?
}
LogicalPlan::Copy(_)
Expand All @@ -316,17 +341,19 @@ fn optimize_projections(
| LogicalPlan::Explain(_)
| LogicalPlan::Analyze(_)
| LogicalPlan::Subquery(_)
| LogicalPlan::Statement(_)
| LogicalPlan::Distinct(Distinct::All(_)) => {
| LogicalPlan::Statement(_) => {
// These plans require all their fields, and their children should
// be treated as final plans -- otherwise, we may have schema a
// mismatch.
// TODO: For some subquery variants (e.g. a subquery arising from an
// EXISTS expression), we may not need to require all indices.
plan.inputs()
.into_iter()
.map(RequiredIndices::new_for_all_exprs)
.collect()
.map(|input| {
let required = RequiredIndices::new_for_all_exprs(input);
Ok(required.with_plan_volatile(volatile_in_plan))
})
.collect::<Result<_>>()?
}
LogicalPlan::Extension(extension) => {
let Some(necessary_children_indices) =
Expand All @@ -348,8 +375,9 @@ fn optimize_projections(
.into_iter()
.zip(necessary_children_indices)
.map(|(child, necessary_indices)| {
RequiredIndices::new_from_indices(necessary_indices)
.with_plan_exprs(&plan, child.schema())
let required = RequiredIndices::new_from_indices(necessary_indices)
.with_plan_exprs(&plan, child.schema())?;
Ok(required.with_plan_volatile(volatile_in_plan))
})
.collect::<Result<Vec<_>>>()?
}
Expand All @@ -376,10 +404,11 @@ fn optimize_projections(
plan.inputs()
.into_iter()
.map(|input| {
indices
let required = indices
.clone()
.with_projection_beneficial()
.with_plan_exprs(&plan, input.schema())
.with_plan_exprs(&plan, input.schema())?;
Ok(required.with_plan_volatile(volatile_in_plan))
})
.collect::<Result<Vec<_>>>()?
}
Expand All @@ -391,13 +420,28 @@ fn optimize_projections(
left_req_indices.with_plan_exprs(&plan, join.left.schema())?;
let right_indices =
right_req_indices.with_plan_exprs(&plan, join.right.schema())?;
let left_indices = left_indices
.for_multiplicity_sensitive_child()
.with_plan_volatile(volatile_in_plan);
let right_indices = right_indices
.for_multiplicity_sensitive_child()
.with_plan_volatile(volatile_in_plan);
// Joins benefit from "small" input tables (lower memory usage).
// Therefore, each child benefits from projection:
vec![
left_indices.with_projection_beneficial(),
right_indices.with_projection_beneficial(),
]
}
LogicalPlan::Distinct(Distinct::All(_)) => plan
.inputs()
.into_iter()
.map(|input| {
let required = RequiredIndices::new_for_all_exprs(input)
.for_multiplicity_insensitive_child();
Ok(required.with_plan_volatile(volatile_in_plan))
})
.collect::<Result<_>>()?,
// these nodes are explicitly rewritten in the match statement above
LogicalPlan::Projection(_)
| LogicalPlan::Aggregate(_)
Expand All @@ -407,19 +451,29 @@ fn optimize_projections(
"OptimizeProjection: should have handled in the match statement above"
);
}
LogicalPlan::Unnest(Unnest {
input,
dependency_indices,
..
}) => {
LogicalPlan::Unnest(unnest) => {
if can_eliminate_unnest(unnest, &indices) {
let child_required_indices =
build_unnest_child_requirements(unnest, &indices);
let transformed_input = optimize_projections(
Arc::unwrap_or_clone(Arc::clone(&unnest.input)),
config,
child_required_indices,
)?;
return Ok(Transformed::yes(transformed_input.data));
}
// at least provide the indices for the exec-columns as a starting point
let required_indices =
RequiredIndices::new().with_plan_exprs(&plan, input.schema())?;
let mut required_indices =
RequiredIndices::new().with_plan_exprs(&plan, unnest.input.schema())?;
required_indices = required_indices
.for_multiplicity_sensitive_child()
.with_volatile_ancestor_if(indices.has_volatile_ancestor())
.with_plan_volatile(volatile_in_plan);

// Add additional required indices from the parent
let mut additional_necessary_child_indices = Vec::new();
indices.indices().iter().for_each(|idx| {
if let Some(index) = dependency_indices.get(*idx) {
if let Some(index) = unnest.dependency_indices.get(*idx) {
additional_necessary_child_indices.push(*index);
}
});
Expand Down Expand Up @@ -837,8 +891,14 @@ fn rewrite_projection_given_requirements(

let exprs_used = indices.get_at_indices(&expr);

let required_indices =
let mut required_indices =
RequiredIndices::new().with_exprs(input.schema(), exprs_used.iter());
if !indices.multiplicity_sensitive() {
required_indices = required_indices.for_multiplicity_insensitive_child();
}
if indices.has_volatile_ancestor() {
required_indices = required_indices.with_volatile_ancestor();
}

// rewrite the children projection, and if they are changed rewrite the
// projection down
Expand Down Expand Up @@ -909,6 +969,62 @@ fn plan_contains_other_subqueries(plan: &LogicalPlan, cte_name: &str) -> bool {
.any(|child| plan_contains_other_subqueries(child, cte_name))
}

/// Decides whether an [`Unnest`] node can be removed from the plan without
/// changing query results for the columns the parent actually requires.
///
/// Removal is only sound when all of the following hold:
/// - no ancestor is sensitive to row multiplicity and no volatile expression
///   sits above this node (flags carried on `indices`),
/// - the unnest has no list-type columns — list unnest can drop rows for
///   empty lists even when `preserve_nulls = true`, so it is kept
///   conservatively unless cardinality could be proven unchanged,
/// - every required output column is a plain pass-through of an input column.
fn can_eliminate_unnest(unnest: &Unnest, indices: &RequiredIndices) -> bool {
    let blocked_by_ancestor =
        indices.multiplicity_sensitive() || indices.has_volatile_ancestor();
    if blocked_by_ancestor {
        return false;
    }

    // List unnest can drop rows for empty lists even when preserve_nulls=true.
    // Without proving non-empty cardinality, keep UNNEST conservatively.
    if !unnest.list_type_columns.is_empty() {
        return false;
    }

    // preserve_nulls only affects list unnest semantics. For struct-only
    // unnest, row cardinality is unchanged and this option is not
    // semantically relevant, so only pass-through-ness remains to check.
    indices
        .indices()
        .iter()
        .copied()
        .all(|output_idx| unnest_output_is_passthrough(unnest, output_idx))
}

/// Returns `true` when output column `output_idx` of `unnest` is an
/// unmodified copy of one of its input columns, i.e. it maps (via
/// `dependency_indices`) to an in-bounds input column whose qualified field
/// is identical to the output's qualified field.
fn unnest_output_is_passthrough(unnest: &Unnest, output_idx: usize) -> bool {
    let input_schema = unnest.input.schema();
    match unnest.dependency_indices.get(output_idx) {
        // Guard the index before `qualified_field`, which would panic when
        // given an out-of-bounds position.
        Some(&dependency_idx) if dependency_idx < input_schema.fields().len() => {
            unnest.schema.qualified_field(output_idx)
                == input_schema.qualified_field(dependency_idx)
        }
        _ => false,
    }
}

/// Translates the parent's required output indices of an [`Unnest`] into the
/// equivalent requirements on the unnest's input, carrying over the
/// projection-beneficial, volatile-ancestor and multiplicity flags from the
/// parent requirements.
fn build_unnest_child_requirements(
    unnest: &Unnest,
    indices: &RequiredIndices,
) -> RequiredIndices {
    // Map each required output index back to the input column it mirrors;
    // outputs with no dependency mapping contribute nothing.
    let child_indices: Vec<_> = indices
        .indices()
        .iter()
        .filter_map(|&output_idx| unnest.dependency_indices.get(output_idx).copied())
        .collect();

    let mut required = RequiredIndices::new_from_indices(child_indices);
    if indices.projection_beneficial() {
        required = required.with_projection_beneficial();
    }
    if indices.has_volatile_ancestor() {
        required = required.with_volatile_ancestor();
    }
    if !indices.multiplicity_sensitive() {
        required = required.for_multiplicity_insensitive_child();
    }
    required
}

fn expr_contains_subquery(expr: &Expr) -> bool {
expr.exists(|e| match e {
Expr::ScalarSubquery(_) | Expr::Exists(_) | Expr::InSubquery(_) => Ok(true),
Expand Down Expand Up @@ -953,7 +1069,7 @@ mod tests {
use crate::{OptimizerContext, OptimizerRule};
use arrow::datatypes::{DataType, Field, Schema};
use datafusion_common::{
Column, DFSchema, DFSchemaRef, JoinType, Result, TableReference,
Column, DFSchema, DFSchemaRef, JoinType, Result, TableReference, UnnestOptions,
};
use datafusion_expr::ExprFunctionExt;
use datafusion_expr::{
Expand Down Expand Up @@ -2274,6 +2390,103 @@ mod tests {
)
}

#[test]
fn eliminate_struct_unnest_when_only_group_keys_are_required() -> Result<()> {
    // Struct-only unnest does not change row cardinality, and the plan above
    // it only needs the group key `id`, so the Unnest node should be removed.
    let user_struct = DataType::Struct(
        vec![
            Field::new("name", DataType::Utf8, true),
            Field::new("score", DataType::Int32, true),
        ]
        .into(),
    );
    let schema = Schema::new(vec![
        Field::new("id", DataType::UInt32, false),
        Field::new("user", user_struct, true),
    ]);

    let plan = scan_empty(Some("test"), &schema, None)?
        .unnest_column("user")?
        .aggregate(vec![col("id")], Vec::<Expr>::new())?
        .project(vec![col("id")])?
        .build()?;

    let optimized = optimize(plan)?;
    let plan_str = optimized.display_indent().to_string();
    assert!(!plan_str.contains("Unnest:"));
    Ok(())
}

#[test]
fn keep_list_unnest_when_group_keys_are_only_required_outputs() -> Result<()> {
    // List unnest can drop rows for empty lists, so even though only the
    // group key `id` is required, the Unnest node must be preserved.
    let schema = Schema::new(vec![
        Field::new("id", DataType::UInt32, false),
        Field::new(
            "vals",
            DataType::List(Arc::new(Field::new("item", DataType::Int32, true))),
            true,
        ),
    ]);

    let plan = scan_empty(Some("test"), &schema, None)?
        .unnest_column("vals")?
        .aggregate(vec![col("id")], Vec::<Expr>::new())?
        .project(vec![col("id")])?
        .build()?;

    let optimized = optimize(plan)?;
    let plan_str = optimized.display_indent().to_string();
    assert!(plan_str.contains("Unnest:"));
    Ok(())
}

#[test]
fn keep_unnest_when_count_depends_on_row_multiplicity() -> Result<()> {
    // A counting aggregate is sensitive to how many rows each input value
    // produces, so removing the Unnest would change `cnt`; it must be kept.
    let schema = Schema::new(vec![
        Field::new("id", DataType::UInt32, false),
        Field::new(
            "vals",
            DataType::List(Arc::new(Field::new("item", DataType::Int32, true))),
            true,
        ),
    ]);

    let plan = scan_empty(Some("test"), &schema, None)?
        .unnest_column("vals")?
        .aggregate(vec![col("id")], vec![count(lit(1)).alias("cnt")])?
        .project(vec![col("id"), col("cnt")])?
        .build()?;

    let optimized = optimize(plan)?;
    let plan_str = optimized.display_indent().to_string();
    assert!(plan_str.contains("Unnest:"));
    Ok(())
}

#[test]
fn keep_unnest_when_preserve_nulls_is_disabled() -> Result<()> {
    // With preserve_nulls disabled the list unnest filters rows, so the
    // Unnest node must survive optimization even when only `id` is needed.
    let schema = Schema::new(vec![
        Field::new("id", DataType::UInt32, false),
        Field::new(
            "vals",
            DataType::List(Arc::new(Field::new("item", DataType::Int32, true))),
            true,
        ),
    ]);

    let unnest_options = UnnestOptions::new().with_preserve_nulls(false);
    let plan = scan_empty(Some("test"), &schema, None)?
        .unnest_column_with_options("vals", unnest_options)?
        .aggregate(vec![col("id")], Vec::<Expr>::new())?
        .project(vec![col("id")])?
        .build()?;

    let optimized = optimize(plan)?;
    let plan_str = optimized.display_indent().to_string();
    assert!(plan_str.contains("Unnest:"));
    Ok(())
}

#[test]
fn test_window() -> Result<()> {
let table_scan = test_table_scan()?;
Expand Down
Loading
Loading