diff --git a/datafusion/substrait/src/logical_plan/consumer/plan.rs b/datafusion/substrait/src/logical_plan/consumer/plan.rs index d5e10fb604017..82ef80892bf1c 100644 --- a/datafusion/substrait/src/logical_plan/consumer/plan.rs +++ b/datafusion/substrait/src/logical_plan/consumer/plan.rs @@ -15,10 +15,10 @@ // specific language governing permissions and limitations // under the License. -use super::utils::{make_renamed_schema, rename_expressions}; +use super::utils::{alias_expressions, make_renamed_schema, rename_expressions}; use super::{DefaultSubstraitConsumer, SubstraitConsumer}; use crate::extensions::Extensions; -use datafusion::common::{not_impl_err, plan_err}; +use datafusion::common::{DFSchema, not_impl_err, plan_err}; use datafusion::execution::SessionState; use datafusion::logical_expr::{Aggregate, LogicalPlan, Projection, col}; use std::sync::Arc; @@ -56,7 +56,6 @@ pub async fn from_substrait_plan_with_consumer( let plan = consumer.consume_rel(root.input.as_ref().unwrap()).await?; if root.names.is_empty() { - // Backwards compatibility for plans missing names return Ok(plan); } let renamed_schema = @@ -65,55 +64,9 @@ pub async fn from_substrait_plan_with_consumer( .has_equivalent_names_and_types(plan.schema()) .is_ok() { - // Nothing to do if the schema is already equivalent return Ok(plan); } - match plan { - // If the last node of the plan produces expressions, bake the renames into those expressions. - // This isn't necessary for correctness, but helps with roundtrip tests. - LogicalPlan::Projection(p) => { - Ok(LogicalPlan::Projection(Projection::try_new( - rename_expressions( - p.expr, - p.input.schema(), - renamed_schema.fields(), - )?, - p.input, - )?)) - } - LogicalPlan::Aggregate(a) => { - let (group_fields, expr_fields) = - renamed_schema.fields().split_at(a.group_expr.len()); - let new_group_exprs = rename_expressions( - a.group_expr, - a.input.schema(), - group_fields, - )?; - let new_aggr_exprs = rename_expressions( - a.aggr_expr, - a.input.schema(), - expr_fields, - )?; - Ok(LogicalPlan::Aggregate(Aggregate::try_new( - a.input, - new_group_exprs, - new_aggr_exprs, - )?)) - } - // There are probably more plans where we could bake things in, can add them later as needed. - // Otherwise, add a new Project to handle the renaming. - _ => Ok(LogicalPlan::Projection(Projection::try_new( - rename_expressions( - plan.schema() - .columns() - .iter() - .map(|c| col(c.to_owned())), - plan.schema(), - renamed_schema.fields(), - )?, - Arc::new(plan), - )?)), - } + apply_renames(plan, &renamed_schema) } }, None => plan_err!("Cannot parse plan relation: None"), @@ -125,3 +78,71 @@ pub async fn from_substrait_plan_with_consumer( ), } } + +/// Apply the root-level schema renames to the given plan. +/// +/// The strategy depends on the plan type: +/// - **Projection**: renames (aliases + casts) are baked directly into the +/// projection expressions. +/// - **Aggregate**: only safe aliases are applied to the aggregate's +/// expressions. If struct-field casts are also needed, a wrapping Projection +/// is added on top (the physical planner rejects Cast-wrapped aggregates). +/// - **Other nodes**: a new Projection is added on top to carry the renames. +fn apply_renames( + plan: LogicalPlan, + renamed_schema: &DFSchema, +) -> datafusion::common::Result { + match plan { + LogicalPlan::Projection(p) => { + Ok(LogicalPlan::Projection(Projection::try_new( + rename_expressions( + p.expr, + p.input.schema(), + renamed_schema.fields(), + )?, + p.input, + )?)) + } + LogicalPlan::Aggregate(a) => { + let (group_fields, expr_fields) = + renamed_schema.fields().split_at(a.group_expr.len()); + let agg = LogicalPlan::Aggregate(Aggregate::try_new( + a.input, + alias_expressions(a.group_expr, group_fields)?, + alias_expressions(a.aggr_expr, expr_fields)?, + )?); + // If aliasing alone didn't satisfy the target schema + // (e.g. nested struct field renames require casts), wrap + // in a Projection that can safely carry those casts. + if renamed_schema + .has_equivalent_names_and_types(agg.schema()) + .is_ok() + { + Ok(agg) + } else { + Ok(LogicalPlan::Projection(Projection::try_new( + rename_expressions( + agg.schema() + .columns() + .iter() + .map(|c| col(c.to_owned())), + agg.schema(), + renamed_schema.fields(), + )?, + Arc::new(agg), + )?)) + } + } + _ => Ok(LogicalPlan::Projection(Projection::try_new( + rename_expressions( + plan.schema() + .columns() + .iter() + .map(|c| col(c.to_owned())), + plan.schema(), + renamed_schema.fields(), + )?, + Arc::new(plan), + )?)), + } +} diff --git a/datafusion/substrait/src/logical_plan/consumer/utils.rs b/datafusion/substrait/src/logical_plan/consumer/utils.rs index 59cdf4a8fc93f..d81d6fc63a83f 100644 --- a/datafusion/substrait/src/logical_plan/consumer/utils.rs +++ b/datafusion/substrait/src/logical_plan/consumer/utils.rs @@ -257,6 +257,24 @@ pub(super) fn make_renamed_schema( ) } +/// Apply only top-level column name aliases to expressions, without casting. +/// Unlike `rename_expressions`, this never injects `Expr::Cast` for nested type +/// differences, making it safe for aggregate expressions which the physical +/// planner requires to be pure AggregateFunctions (optionally aliased). +pub(super) fn alias_expressions( + exprs: impl IntoIterator, + new_schema_fields: &[Arc], +) -> datafusion::common::Result> { + exprs + .into_iter() + .zip(new_schema_fields) + .map(|(old_expr, new_field)| match &old_expr { + Expr::Column(c) if &c.name == new_field.name() => Ok(old_expr), + _ => old_expr.alias_if_changed(new_field.name().to_owned()), + }) + .collect() +} + /// Ensure the expressions have the right name(s) according to the new schema. /// This includes the top-level (column) name, which will be renamed through aliasing if needed, /// as well as nested names (if the expression produces any struct types), which will be renamed diff --git a/datafusion/substrait/tests/cases/aggregation_tests.rs b/datafusion/substrait/tests/cases/aggregation_tests.rs index 92a41850b208d..f48397d433a96 100644 --- a/datafusion/substrait/tests/cases/aggregation_tests.rs +++ b/datafusion/substrait/tests/cases/aggregation_tests.rs @@ -25,6 +25,19 @@ mod tests { use datafusion::prelude::SessionContext; use datafusion_substrait::logical_plan::consumer::from_substrait_plan; use insta::assert_snapshot; + use substrait::proto::aggregate_rel::Measure; + use substrait::proto::expression::field_reference::ReferenceType; + use substrait::proto::expression::reference_segment::ReferenceType as SegRefType; + use substrait::proto::expression::{FieldReference, ReferenceSegment, RexType}; + use substrait::proto::extensions::simple_extension_declaration::MappingType; + use substrait::proto::extensions::SimpleExtensionDeclaration; + use substrait::proto::function_argument::ArgType; + use substrait::proto::read_rel::{ReadType, VirtualTable}; + use substrait::proto::rel::RelType; + use substrait::proto::{ + AggregateFunction, Expression, FunctionArgument, NamedStruct, Plan, PlanRel, + ReadRel, Rel, RelRoot, Type, r#type, + }; #[tokio::test] async fn no_grouping_set() -> Result<()> { @@ -68,4 +81,171 @@ mod tests { Ok(()) } + + /// When root names rename struct fields inside an aggregate measure's + /// return type (e.g. `List` → `List`), + /// `rename_expressions` injects `Expr::Cast` around the aggregate + /// function. The physical planner rejects Cast-wrapped aggregates. + /// This test verifies that the consumer wraps the Aggregate in a + /// Projection instead. + #[tokio::test] + #[expect(deprecated)] + async fn aggregate_with_struct_field_rename() -> Result<()> { + // Build a Substrait plan: + // ReadRel(VirtualTable) with one column: List + // AggregateRel with array_agg on that column + // Root names rename struct fields: c0 → one, c1 → two + + let utf8_nullable = Type { + kind: Some(r#type::Kind::String(r#type::String { + type_variation_reference: 0, + nullability: r#type::Nullability::Nullable as i32, + })), + }; + + let struct_type = r#type::Struct { + types: vec![utf8_nullable.clone(), utf8_nullable.clone()], + type_variation_reference: 0, + nullability: r#type::Nullability::Nullable as i32, + }; + + let list_of_struct = Type { + kind: Some(r#type::Kind::List(Box::new(r#type::List { + r#type: Some(Box::new(Type { + kind: Some(r#type::Kind::Struct(struct_type.clone())), + })), + type_variation_reference: 0, + nullability: r#type::Nullability::Nullable as i32, + }))), + }; + + // ReadRel with VirtualTable (empty) and base_schema + let read_rel = Rel { + rel_type: Some(RelType::Read(Box::new(ReadRel { + common: None, + base_schema: Some(NamedStruct { + names: vec![ + "col0".to_string(), + "c0".to_string(), + "c1".to_string(), + ], + r#struct: Some(r#type::Struct { + types: vec![list_of_struct.clone()], + type_variation_reference: 0, + nullability: r#type::Nullability::Required as i32, + }), + }), + filter: None, + best_effort_filter: None, + projection: None, + advanced_extension: None, + read_type: Some(ReadType::VirtualTable(VirtualTable { + values: vec![], + expressions: vec![], + })), + }))), + }; + + // AggregateRel with array_agg(col0) + let field_ref = Expression { + rex_type: Some(RexType::Selection(Box::new(FieldReference { + reference_type: Some(ReferenceType::DirectReference( + ReferenceSegment { + reference_type: Some(SegRefType::StructField(Box::new( + substrait::proto::expression::reference_segment::StructField { + field: 0, + child: None, + }, + ))), + }, + )), + root_type: Some( + substrait::proto::expression::field_reference::RootType::RootReference( + substrait::proto::expression::field_reference::RootReference {}, + ), + ), + }))), + }; + + let aggregate_rel = Rel { + rel_type: Some(RelType::Aggregate(Box::new( + substrait::proto::AggregateRel { + common: None, + input: Some(Box::new(read_rel)), + grouping_expressions: vec![], + groupings: vec![], + measures: vec![Measure { + measure: Some(AggregateFunction { + function_reference: 1, + arguments: vec![FunctionArgument { + arg_type: Some(ArgType::Value(field_ref)), + }], + sorts: vec![], + output_type: Some(list_of_struct), + invocation: 1, // AGGREGATION_INVOCATION_ALL + phase: 3, // AGGREGATION_PHASE_INITIAL_TO_RESULT + args: vec![], + options: vec![], + }), + filter: None, + }], + advanced_extension: None, + }, + ))), + }; + + // Root names: rename struct fields c0 → one, c1 → two + let proto_plan = Plan { + version: None, + extension_uris: vec![substrait::proto::extensions::SimpleExtensionUri { + extension_uri_anchor: 1, + uri: "/functions_aggregate.yaml".to_string(), + }], + extension_urns: vec![], + extensions: vec![SimpleExtensionDeclaration { + mapping_type: Some(MappingType::ExtensionFunction( + substrait::proto::extensions::simple_extension_declaration::ExtensionFunction { + extension_uri_reference: 1, + extension_urn_reference: 0, + function_anchor: 1, + name: "array_agg:list".to_string(), + }, + )), + }], + relations: vec![PlanRel { + rel_type: Some(substrait::proto::plan_rel::RelType::Root(RelRoot { + input: Some(aggregate_rel), + names: vec![ + "result".to_string(), + "one".to_string(), + "two".to_string(), + ], + })), + }], + advanced_extensions: None, + expected_type_urls: vec![], + parameter_bindings: vec![], + type_aliases: vec![], + }; + + let ctx = SessionContext::new(); + let plan = from_substrait_plan(&ctx.state(), &proto_plan).await?; + + // The plan should contain a Projection wrapping the Aggregate, + // because rename_expressions injects Cast for the struct field rename. + let plan_str = format!("{plan}"); + assert!( + plan_str.contains("Projection:"), + "Expected Projection wrapper but got:\n{plan_str}" + ); + assert!( + plan_str.contains("Aggregate:"), + "Expected Aggregate in plan but got:\n{plan_str}" + ); + + // Execute to confirm the physical planner accepts this plan + DataFrame::new(ctx.state(), plan).show().await?; + + Ok(()) + } }