Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 71 additions & 50 deletions datafusion/substrait/src/logical_plan/consumer/plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@
// specific language governing permissions and limitations
// under the License.

use super::utils::{make_renamed_schema, rename_expressions};
use super::utils::{alias_expressions, make_renamed_schema, rename_expressions};
use super::{DefaultSubstraitConsumer, SubstraitConsumer};
use crate::extensions::Extensions;
use datafusion::common::{not_impl_err, plan_err};
use datafusion::common::{DFSchema, not_impl_err, plan_err};
use datafusion::execution::SessionState;
use datafusion::logical_expr::{Aggregate, LogicalPlan, Projection, col};
use std::sync::Arc;
Expand Down Expand Up @@ -56,7 +56,6 @@ pub async fn from_substrait_plan_with_consumer(
let plan =
consumer.consume_rel(root.input.as_ref().unwrap()).await?;
if root.names.is_empty() {
// Backwards compatibility for plans missing names
return Ok(plan);
}
let renamed_schema =
Expand All @@ -65,55 +64,9 @@ pub async fn from_substrait_plan_with_consumer(
.has_equivalent_names_and_types(plan.schema())
.is_ok()
{
// Nothing to do if the schema is already equivalent
return Ok(plan);
}
match plan {
// If the last node of the plan produces expressions, bake the renames into those expressions.
// This isn't necessary for correctness, but helps with roundtrip tests.
LogicalPlan::Projection(p) => {
Ok(LogicalPlan::Projection(Projection::try_new(
rename_expressions(
p.expr,
p.input.schema(),
renamed_schema.fields(),
)?,
p.input,
)?))
}
LogicalPlan::Aggregate(a) => {
let (group_fields, expr_fields) =
renamed_schema.fields().split_at(a.group_expr.len());
let new_group_exprs = rename_expressions(
a.group_expr,
a.input.schema(),
group_fields,
)?;
let new_aggr_exprs = rename_expressions(
a.aggr_expr,
a.input.schema(),
expr_fields,
)?;
Ok(LogicalPlan::Aggregate(Aggregate::try_new(
a.input,
new_group_exprs,
new_aggr_exprs,
)?))
}
// There are probably more plans where we could bake things in, can add them later as needed.
// Otherwise, add a new Project to handle the renaming.
_ => Ok(LogicalPlan::Projection(Projection::try_new(
rename_expressions(
plan.schema()
.columns()
.iter()
.map(|c| col(c.to_owned())),
plan.schema(),
renamed_schema.fields(),
)?,
Arc::new(plan),
)?)),
}
apply_renames(plan, &renamed_schema)
}
},
None => plan_err!("Cannot parse plan relation: None"),
Expand All @@ -125,3 +78,71 @@ pub async fn from_substrait_plan_with_consumer(
),
}
}

/// Apply the root-level schema renames to the given plan.
///
/// `renamed_schema` is the target output schema; the returned plan is built so
/// that its output columns carry the names from that schema.
///
/// The strategy depends on the plan type:
/// - **Projection**: the renames are baked directly into the projection's own
///   expressions via `rename_expressions`, so no extra node is added.
/// - **Aggregate**: only aliases are applied to the group and aggregate
///   expressions (via `alias_expressions`). If the aliased aggregate still
///   does not match the target schema (e.g. nested struct field renames that
///   require casts), a wrapping Projection is added on top, because the
///   physical planner rejects Cast-wrapped aggregates.
/// - **Other nodes**: a new Projection is added on top to carry the renames.
///
/// # Errors
///
/// Propagates any error returned while renaming/aliasing the expressions or
/// while constructing the resulting `Projection` / `Aggregate` node.
fn apply_renames(
    plan: LogicalPlan,
    renamed_schema: &DFSchema,
) -> datafusion::common::Result<LogicalPlan> {
    match plan {
        LogicalPlan::Projection(p) => {
            let renamed_exprs = rename_expressions(
                p.expr,
                p.input.schema(),
                renamed_schema.fields(),
            )?;
            Ok(LogicalPlan::Projection(Projection::try_new(
                renamed_exprs,
                p.input,
            )?))
        }
        LogicalPlan::Aggregate(a) => {
            // The target fields are split positionally: the first
            // `group_expr.len()` fields go to the grouping expressions, the
            // remainder to the aggregate expressions.
            let (group_fields, expr_fields) =
                renamed_schema.fields().split_at(a.group_expr.len());
            let group_exprs = alias_expressions(a.group_expr, group_fields)?;
            let aggr_exprs = alias_expressions(a.aggr_expr, expr_fields)?;
            let agg = LogicalPlan::Aggregate(Aggregate::try_new(
                a.input,
                group_exprs,
                aggr_exprs,
            )?);

            // Aliasing alone was enough: keep the bare Aggregate.
            if renamed_schema
                .has_equivalent_names_and_types(agg.schema())
                .is_ok()
            {
                return Ok(agg);
            }

            // Otherwise (e.g. nested struct field renames require casts),
            // wrap in a Projection that can safely carry those casts.
            project_with_renames(agg, renamed_schema)
        }
        other => project_with_renames(other, renamed_schema),
    }
}

/// Wrap `plan` in a new Projection that selects every output column of `plan`
/// and passes those column expressions through `rename_expressions` so they
/// match the fields of `renamed_schema`.
fn project_with_renames(
    plan: LogicalPlan,
    renamed_schema: &DFSchema,
) -> datafusion::common::Result<LogicalPlan> {
    // Build the renamed expressions first: they borrow `plan`, which is moved
    // into the `Arc` below.
    let renamed_exprs = rename_expressions(
        plan.schema()
            .columns()
            .iter()
            .map(|c| col(c.to_owned())),
        plan.schema(),
        renamed_schema.fields(),
    )?;
    Ok(LogicalPlan::Projection(Projection::try_new(
        renamed_exprs,
        Arc::new(plan),
    )?))
}
18 changes: 18 additions & 0 deletions datafusion/substrait/src/logical_plan/consumer/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,24 @@ pub(super) fn make_renamed_schema(
)
}

/// Give each expression the top-level name of its positionally matching field
/// in `new_schema_fields`, using aliases only.
///
/// In contrast to `rename_expressions`, no `Expr::Cast` is ever introduced for
/// nested type differences. That keeps the result usable as aggregate
/// expressions, which the physical planner requires to be pure
/// AggregateFunctions (optionally aliased).
///
/// Expressions and fields are paired in order; pairing stops at the shorter of
/// the two inputs. A plain column reference that already has the target name
/// is returned untouched.
pub(super) fn alias_expressions(
    exprs: impl IntoIterator<Item = Expr>,
    new_schema_fields: &[Arc<Field>],
) -> datafusion::common::Result<Vec<Expr>> {
    let mut aliased = Vec::new();
    for (expr, target_field) in exprs.into_iter().zip(new_schema_fields) {
        let target_name = target_field.name();
        // A bare column that is already correctly named needs no alias.
        let already_named = matches!(
            &expr,
            Expr::Column(column) if &column.name == target_name
        );
        let renamed = if already_named {
            expr
        } else {
            // Bail out on the first failure, as a fallible `collect` would.
            expr.alias_if_changed(target_name.to_owned())?
        };
        aliased.push(renamed);
    }
    Ok(aliased)
}

/// Ensure the expressions have the right name(s) according to the new schema.
/// This includes the top-level (column) name, which will be renamed through aliasing if needed,
/// as well as nested names (if the expression produces any struct types), which will be renamed
Expand Down
180 changes: 180 additions & 0 deletions datafusion/substrait/tests/cases/aggregation_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,19 @@ mod tests {
use datafusion::prelude::SessionContext;
use datafusion_substrait::logical_plan::consumer::from_substrait_plan;
use insta::assert_snapshot;
use substrait::proto::aggregate_rel::Measure;
use substrait::proto::expression::field_reference::ReferenceType;
use substrait::proto::expression::reference_segment::ReferenceType as SegRefType;
use substrait::proto::expression::{FieldReference, ReferenceSegment, RexType};
use substrait::proto::extensions::simple_extension_declaration::MappingType;
use substrait::proto::extensions::SimpleExtensionDeclaration;
use substrait::proto::function_argument::ArgType;
use substrait::proto::read_rel::{ReadType, VirtualTable};
use substrait::proto::rel::RelType;
use substrait::proto::{
AggregateFunction, Expression, FunctionArgument, NamedStruct, Plan, PlanRel,
ReadRel, Rel, RelRoot, Type, r#type,
};

#[tokio::test]
async fn no_grouping_set() -> Result<()> {
Expand Down Expand Up @@ -68,4 +81,171 @@ mod tests {

Ok(())
}

/// Regression test for root-level renames that reach into an aggregate
/// measure's return type.
///
/// When root names rename struct fields inside an aggregate measure's
/// return type (e.g. `List<Struct{c0,c1}>` → `List<Struct{one,two}>`),
/// `rename_expressions` injects `Expr::Cast` around the aggregate
/// function. The physical planner rejects Cast-wrapped aggregates.
/// This test verifies that the consumer wraps the Aggregate in a
/// Projection instead, and that the resulting plan can be executed.
#[tokio::test]
// NOTE(review): presumably needed because some of the proto fields populated
// below are marked deprecated in the `substrait` crate — confirm which ones.
#[expect(deprecated)]
async fn aggregate_with_struct_field_rename() -> Result<()> {
// Build a Substrait plan:
// ReadRel(VirtualTable) with one column: List<Struct{c0: Utf8, c1: Utf8}>
// AggregateRel with array_agg on that column
// Root names rename struct fields: c0 → one, c1 → two

// Nullable string type, used for both struct fields.
let utf8_nullable = Type {
kind: Some(r#type::Kind::String(r#type::String {
type_variation_reference: 0,
nullability: r#type::Nullability::Nullable as i32,
})),
};

// Nullable struct with two nullable string fields.
let struct_type = r#type::Struct {
types: vec![utf8_nullable.clone(), utf8_nullable.clone()],
type_variation_reference: 0,
nullability: r#type::Nullability::Nullable as i32,
};

// Nullable list whose element type is the struct above. This is both the
// input column type and the declared output type of the measure.
let list_of_struct = Type {
kind: Some(r#type::Kind::List(Box::new(r#type::List {
r#type: Some(Box::new(Type {
kind: Some(r#type::Kind::Struct(struct_type.clone())),
})),
type_variation_reference: 0,
nullability: r#type::Nullability::Nullable as i32,
}))),
};

// ReadRel with VirtualTable (empty) and base_schema.
// The names list carries the column name followed by the names of the
// nested struct fields, while the struct itself declares a single column.
let read_rel = Rel {
rel_type: Some(RelType::Read(Box::new(ReadRel {
common: None,
base_schema: Some(NamedStruct {
names: vec![
"col0".to_string(),
"c0".to_string(),
"c1".to_string(),
],
r#struct: Some(r#type::Struct {
types: vec![list_of_struct.clone()],
type_variation_reference: 0,
nullability: r#type::Nullability::Required as i32,
}),
}),
filter: None,
best_effort_filter: None,
projection: None,
advanced_extension: None,
read_type: Some(ReadType::VirtualTable(VirtualTable {
values: vec![],
expressions: vec![],
})),
}))),
};

// AggregateRel with array_agg(col0).
// `field_ref` is a direct reference to struct field 0 of the root input,
// i.e. the only column of the ReadRel.
let field_ref = Expression {
rex_type: Some(RexType::Selection(Box::new(FieldReference {
reference_type: Some(ReferenceType::DirectReference(
ReferenceSegment {
reference_type: Some(SegRefType::StructField(Box::new(
substrait::proto::expression::reference_segment::StructField {
field: 0,
child: None,
},
))),
},
)),
root_type: Some(
substrait::proto::expression::field_reference::RootType::RootReference(
substrait::proto::expression::field_reference::RootReference {},
),
),
}))),
};

// No grouping expressions; a single measure whose function_reference (1)
// matches the function_anchor declared in the plan's extensions below.
let aggregate_rel = Rel {
rel_type: Some(RelType::Aggregate(Box::new(
substrait::proto::AggregateRel {
common: None,
input: Some(Box::new(read_rel)),
grouping_expressions: vec![],
groupings: vec![],
measures: vec![Measure {
measure: Some(AggregateFunction {
function_reference: 1,
arguments: vec![FunctionArgument {
arg_type: Some(ArgType::Value(field_ref)),
}],
sorts: vec![],
output_type: Some(list_of_struct),
invocation: 1, // AGGREGATION_INVOCATION_ALL
phase: 3, // AGGREGATION_PHASE_INITIAL_TO_RESULT
args: vec![],
options: vec![],
}),
filter: None,
}],
advanced_extension: None,
},
))),
};

// Root names: rename struct fields c0 → one, c1 → two
// (and name the top-level output column "result").
let proto_plan = Plan {
version: None,
extension_uris: vec![substrait::proto::extensions::SimpleExtensionUri {
extension_uri_anchor: 1,
uri: "/functions_aggregate.yaml".to_string(),
}],
extension_urns: vec![],
extensions: vec![SimpleExtensionDeclaration {
mapping_type: Some(MappingType::ExtensionFunction(
substrait::proto::extensions::simple_extension_declaration::ExtensionFunction {
extension_uri_reference: 1,
extension_urn_reference: 0,
function_anchor: 1,
name: "array_agg:list".to_string(),
},
)),
}],
relations: vec![PlanRel {
rel_type: Some(substrait::proto::plan_rel::RelType::Root(RelRoot {
input: Some(aggregate_rel),
names: vec![
"result".to_string(),
"one".to_string(),
"two".to_string(),
],
})),
}],
advanced_extensions: None,
expected_type_urls: vec![],
parameter_bindings: vec![],
type_aliases: vec![],
};

// Convert the Substrait plan into a DataFusion logical plan.
let ctx = SessionContext::new();
let plan = from_substrait_plan(&ctx.state(), &proto_plan).await?;

// The plan should contain a Projection wrapping the Aggregate,
// because rename_expressions injects Cast for the struct field rename.
// Only the presence of both node names in the rendered plan is checked.
let plan_str = format!("{plan}");
assert!(
plan_str.contains("Projection:"),
"Expected Projection wrapper but got:\n{plan_str}"
);
assert!(
plan_str.contains("Aggregate:"),
"Expected Aggregate in plan but got:\n{plan_str}"
);

// Execute to confirm the physical planner accepts this plan
DataFrame::new(ctx.state(), plan).show().await?;

Ok(())
}
}
Loading