Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions native-engine/auron-planner/proto/auron.proto
Original file line number Diff line number Diff line change
Expand Up @@ -809,6 +809,8 @@ message Field {
bool nullable = 3;
// for complex data types like structs, unions
repeated Field children = 4;
// Iceberg/Parquet field id. Zero means unset.
int32 field_id = 5;
}

message FixedSizeBinary {
Expand Down
51 changes: 28 additions & 23 deletions native-engine/auron-planner/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,13 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::sync::Arc;
use std::{collections::HashMap, sync::Arc};

use arrow::datatypes::{DataType, Field, Fields, IntervalUnit, Schema, TimeUnit};
use datafusion::{common::JoinSide, logical_expr::Operator, scalar::ScalarValue};
use datafusion::{
common::JoinSide, logical_expr::Operator, parquet::arrow::PARQUET_FIELD_ID_META_KEY,
scalar::ScalarValue,
};
use datafusion_ext_plans::{agg::AggFunction, joins::join_utils::JoinType};

use crate::error::PlanSerDeError;
Expand Down Expand Up @@ -406,17 +409,29 @@ impl TryInto<DataType> for &Box<protobuf::List> {
impl TryInto<Field> for &protobuf::Field {
type Error = PlanSerDeError;
fn try_into(self) -> Result<Field, Self::Error> {
let pb_datatype = self.arrow_type.as_ref().ok_or_else(|| {
proto_error(
"Protobuf deserialization error: Field message missing required field 'arrow_type'",
)
})?;
build_arrow_field(self)
}
}

fn build_arrow_field(field: &protobuf::Field) -> Result<Field, PlanSerDeError> {
let pb_datatype = field.arrow_type.as_ref().ok_or_else(|| {
proto_error(
"Protobuf deserialization error: Field message missing required field 'arrow_type'",
)
})?;
let arrow_field = Field::new(
field.name.as_str(),
pb_datatype.as_ref().try_into()?,
field.nullable,
);

Ok(Field::new(
self.name.as_str(),
pb_datatype.as_ref().try_into()?,
self.nullable,
))
if field.field_id == 0 {
Ok(arrow_field)
} else {
Ok(arrow_field.with_metadata(HashMap::from([(
PARQUET_FIELD_ID_META_KEY.to_string(),
field.field_id.to_string(),
)])))
}
}

Expand All @@ -427,17 +442,7 @@ impl TryInto<Schema> for &protobuf::Schema {
let fields = self
.columns
.iter()
.map(|c| {
let pb_arrow_type_res = c
.arrow_type
.as_ref()
.ok_or_else(|| proto_error("Protobuf deserialization error: Field message was missing required field 'arrow_type'"));
let pb_arrow_type: &protobuf::ArrowType = match pb_arrow_type_res {
Ok(res) => res,
Err(e) => return Err(e),
};
Ok(Field::new(&c.name, pb_arrow_type.try_into()?, c.nullable))
})
.map(build_arrow_field)
.collect::<Result<Vec<_>, _>>()?;
Ok(Schema::new(fields))
}
Expand Down
16 changes: 13 additions & 3 deletions native-engine/datafusion-ext-plans/src/scan/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ use datafusion::{
datasource::schema_adapter::{
SchemaAdapter, SchemaAdapterFactory, SchemaMapper, SchemaMapping,
},
parquet::arrow::PARQUET_FIELD_ID_META_KEY,
};
use datafusion_ext_commons::df_execution_err;

Expand Down Expand Up @@ -57,11 +58,10 @@ impl SchemaAdapter for AuronSchemaAdapter {
fn map_column_index(&self, index: usize, file_schema: &Schema) -> Option<usize> {
let field = self.table_schema.field(index);

// use case insensitive matching
file_schema
.fields()
.iter()
.position(|f| f.name().eq_ignore_ascii_case(field.name()))
.position(|file_field| fields_match(field, file_field))
}

fn map_schema(&self, file_schema: &Schema) -> Result<(Arc<dyn SchemaMapper>, Vec<usize>)> {
Expand All @@ -73,7 +73,7 @@ impl SchemaAdapter for AuronSchemaAdapter {
.table_schema
.fields()
.iter()
.position(|f| f.name().eq_ignore_ascii_case(file_field.name()))
.position(|table_field| fields_match(table_field, file_field))
{
field_mappings[table_idx] = Some(projection.len());
projection.push(file_idx);
Expand All @@ -89,6 +89,16 @@ impl SchemaAdapter for AuronSchemaAdapter {
}
}

fn fields_match(table_field: &Field, file_field: &Field) -> bool {
match table_field.metadata().get(PARQUET_FIELD_ID_META_KEY) {
Some(table_field_id) => file_field
.metadata()
.get(PARQUET_FIELD_ID_META_KEY)
.is_some_and(|file_field_id| file_field_id == table_field_id),
None => table_field.name().eq_ignore_ascii_case(file_field.name()),
}
}

pub fn create_auron_schema_mapper(
table_schema: &SchemaRef,
field_mappings: &[Option<usize>],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -216,18 +216,22 @@ object NativeConverters extends Logging {
arrowTypeBuilder.build()
}

def convertField(sparkField: StructField): pb.Field = {
pb.Field
def convertField(sparkField: StructField, fieldId: Option[Int] = None): pb.Field = {
val fieldBuilder = pb.Field
.newBuilder()
.setName(sparkField.name)
.setNullable(sparkField.nullable)
.setArrowType(convertDataType(sparkField.dataType))
.build()
fieldId.foreach(fieldBuilder.setFieldId)
fieldBuilder.build()
}

def convertSchema(sparkSchema: StructType): pb.Schema = {
def convertSchema(
sparkSchema: StructType,
fieldIdsByName: Map[String, Int] = Map.empty): pb.Schema = {
val schemaBuilder = pb.Schema.newBuilder()
sparkSchema.foreach(sparkField => schemaBuilder.addColumns(convertField(sparkField)))
sparkSchema.foreach(sparkField =>
schemaBuilder.addColumns(convertField(sparkField, fieldIdsByName.get(sparkField.name))))
schemaBuilder.build()
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ abstract class NativeGenerateBase(
}

private def nativeGeneratorOutput =
Util.getSchema(generatorOutput).map(NativeConverters.convertField)
Util.getSchema(generatorOutput).map(field => NativeConverters.convertField(field))

private def nativeRequiredChildOutput =
Util.getSchema(requiredChildOutput).map(_.name)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,77 @@
*/
package org.apache.iceberg.spark.source

import scala.collection.JavaConverters._

object AuronIcebergSourceUtil {

final case class RenameOrDrop(topLevel: Boolean, nested: Boolean)

def getClassOfSparkBatchQueryScan(): Class[SparkBatchQueryScan] = {
classOf[SparkBatchQueryScan]
}

def getClassOfSparkInputPartition(): Class[SparkInputPartition] = {
classOf[SparkInputPartition]
}

def expectedFieldIds(scan: AnyRef): Map[String, Int] = {
val expectedSchema = asBatchQueryScan(scan).expectedSchema()
expectedSchema.columns().asScala.map(field => field.name() -> field.fieldId()).toMap
}

def detectRenameOrDrop(scan: AnyRef): RenameOrDrop = {
val table = asBatchQueryScan(scan).table()
val currentFields = collectFieldIdToName(table.schema())

table
.schemas()
.values()
.asScala
.foldLeft(RenameOrDrop(topLevel = false, nested = false)) { (result, schema) =>
collectFieldIdToName(schema).foldLeft(result) {
case (currentResult, (fieldId, historicalField)) =>
currentFields.get(fieldId) match {
case Some(currentField) if currentField.name != historicalField.name =>
if (historicalField.topLevel || currentField.topLevel) {
currentResult.copy(topLevel = true)
} else {
currentResult.copy(nested = true)
}
case None =>
if (historicalField.topLevel) {
currentResult.copy(topLevel = true)
} else {
currentResult.copy(nested = true)
}
case _ =>
currentResult
}
}
}
}

final private case class FieldIdentity(name: String, topLevel: Boolean)

private def collectFieldIdToName(schema: org.apache.iceberg.Schema): Map[Int, FieldIdentity] = {
def collect(
fields: Seq[org.apache.iceberg.types.Types.NestedField],
topLevel: Boolean): Seq[(Int, FieldIdentity)] = {
fields.flatMap { field =>
val current = field.fieldId() -> FieldIdentity(field.name(), topLevel)
val nested =
if (field.`type`().isNestedType) {
collect(field.`type`().asNestedType().fields().asScala.toSeq, topLevel = false)
} else {
Seq.empty
}
current +: nested
}
}

collect(schema.columns().asScala.toSeq, topLevel = true).toMap
}

private def asBatchQueryScan(scan: AnyRef): SparkBatchQueryScan =
scan.asInstanceOf[SparkBatchQueryScan]
}
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@ final case class IcebergScanPlan(
readSchema: StructType,
fileSchema: StructType,
partitionSchema: StructType,
pruningPredicates: Seq[pb.PhysicalExprNode])
pruningPredicates: Seq[pb.PhysicalExprNode],
fieldIdsByName: Map[String, Int])

object IcebergScanSupport extends Logging {

Expand Down Expand Up @@ -75,6 +76,27 @@ object IcebergScanSupport extends Logging {
partitionSchema.fields.forall(field => NativeConverters.isTypeSupported(field.dataType)),
"Has unsupported schema type.")

val inspected: Option[(AuronIcebergSourceUtil.RenameOrDrop, Map[String, Int])] =
try {
Some(
(
AuronIcebergSourceUtil.detectRenameOrDrop(scan.asInstanceOf[AnyRef]),
AuronIcebergSourceUtil.expectedFieldIds(scan.asInstanceOf[AnyRef])))
} catch {
case NonFatal(t) =>
logWarning(s"Failed to inspect Iceberg field ids for $scanClassName.", t)
None
}
if (inspected.isEmpty) {
return None
}
val (renameOrDrop, fieldIdsByName) = inspected.get
assert(!renameOrDrop.nested, "Nested Iceberg rename or drop is not supported.")

assert(
fileSchema.fields.forall(field => fieldIdsByName.contains(field.name)),
"Failed to find field ids for all Iceberg data columns.")

val partitions = inputPartitions(exec)
// Empty scan (e.g. empty table) should still build a plan to return no rows.
if (partitions.isEmpty) {
Expand All @@ -86,7 +108,8 @@ object IcebergScanSupport extends Logging {
readSchema,
fileSchema,
partitionSchema,
Seq.empty))
Seq.empty,
fieldIdsByName))
}

val icebergPartitions = partitions.flatMap(icebergPartition)
Expand All @@ -110,6 +133,9 @@ object IcebergScanSupport extends Logging {
assert(
!(format != FileFormat.PARQUET && format != FileFormat.ORC),
"Only support parquet or orc.")
assert(
!(format == FileFormat.ORC && renameOrDrop.topLevel),
"Iceberg ORC rename or drop is not supported.")

val pruningPredicates = collectPruningPredicates(scan.asInstanceOf[AnyRef], readSchema)
Some(
Expand All @@ -119,7 +145,8 @@ object IcebergScanSupport extends Logging {
readSchema,
fileSchema,
partitionSchema,
pruningPredicates))
pruningPredicates,
fieldIdsByName))
}

private def collectUnsupportedMetadataColumns(schema: StructType): Seq[String] =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,8 @@ case class NativeIcebergTableScanExec(basedScan: BatchScanExec, plan: IcebergSca
private lazy val fileSizes: Map[String, Long] = buildFileSizes()
private lazy val fileSpecIds: Map[String, Int] = buildFileSpecIds()

private lazy val nativeFileSchema: pb.Schema = NativeConverters.convertSchema(fileSchema)
private lazy val nativeFileSchema: pb.Schema =
NativeConverters.convertSchema(fileSchema, plan.fieldIdsByName)
private lazy val nativePartitionSchema: pb.Schema =
NativeConverters.convertSchema(partitionSchema)

Expand Down
Loading
Loading