Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1102,6 +1102,8 @@ object AuronConverters extends Logging {
case _: NativeParquetScanBase | _: NativeOrcScanBase | _: NativeHiveTableScanBase |
_: NativeUnionBase =>
true
case exec if exec.nodeName == "NativeIcebergTableScan" =>
true
case _: ConvertToNativeBase => needRenameColumns(plan.children.head)
case exec if NativeHelper.isNative(exec) =>
NativeHelper.getUnderlyingNativePlan(exec).output != plan.output
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,22 @@ import scala.collection.JavaConverters._
import scala.util.control.NonFatal

import org.apache.iceberg.{FileFormat, FileScanTask, MetadataColumns}
import org.apache.iceberg.expressions.Expressions
import org.apache.iceberg.expressions.{And => IcebergAnd, BoundPredicate, Expression => IcebergExpression, Not => IcebergNot, Or => IcebergOr, UnboundPredicate}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.auron.NativeConverters
import org.apache.spark.sql.catalyst.expressions.{And => SparkAnd, AttributeReference, EqualTo, Expression => SparkExpression, GreaterThan, GreaterThanOrEqual, In, IsNaN, IsNotNull, IsNull, LessThan, LessThanOrEqual, Literal, Not => SparkNot, Or => SparkOr}
import org.apache.spark.sql.connector.read.InputPartition
import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{BinaryType, DataType, DecimalType, StringType, StructField, StructType}

import org.apache.auron.{protobuf => pb}

final case class IcebergScanPlan(
fileTasks: Seq[FileScanTask],
fileFormat: FileFormat,
readSchema: StructType)
readSchema: StructType,
pruningPredicates: Seq[pb.PhysicalExprNode])

object IcebergScanSupport extends Logging {

Expand Down Expand Up @@ -61,7 +66,7 @@ object IcebergScanSupport extends Logging {
// Empty scan (e.g. empty table) should still build a plan to return no rows.
if (partitions.isEmpty) {
logWarning(s"Native Iceberg scan planned with empty partitions for $scanClassName.")
return Some(IcebergScanPlan(Seq.empty, FileFormat.PARQUET, readSchema))
return Some(IcebergScanPlan(Seq.empty, FileFormat.PARQUET, readSchema, Seq.empty))
}

val icebergPartitions = partitions.flatMap(icebergPartition)
Expand All @@ -77,11 +82,6 @@ object IcebergScanSupport extends Logging {
return None
}

// Residual filters require row-level evaluation, not supported in native scan.
if (!fileTasks.forall(task => Expressions.alwaysTrue().equals(task.residual()))) {
return None
}

// Native scan handles a single file format; mixed formats must fallback.
val formats = fileTasks.map(_.file().format()).distinct
if (formats.size > 1) {
Expand All @@ -93,7 +93,9 @@ object IcebergScanSupport extends Logging {
return None
}

Some(IcebergScanPlan(fileTasks, format, readSchema))
val pruningPredicates = collectPruningPredicates(scan.asInstanceOf[AnyRef], readSchema)

Some(IcebergScanPlan(fileTasks, format, readSchema, pruningPredicates))
}

private def hasMetadataColumns(schema: StructType): Boolean =
Expand Down Expand Up @@ -188,4 +190,240 @@ object IcebergScanSupport extends Logging {
None
}
}

/**
 * Translates the Iceberg filter expressions attached to `scan` into native
 * pruning expression nodes over `readSchema`.
 *
 * Filters that cannot be translated are dropped with a debug log; pruning is
 * an optimization, so omitting a predicate only reduces pruning, never
 * correctness.
 */
private def collectPruningPredicates(
    scan: AnyRef,
    readSchema: StructType): Seq[pb.PhysicalExprNode] = {
  scanFilterExpressions(scan).flatMap { icebergExpr =>
    val sparkExpr = convertIcebergFilterExpression(icebergExpr, readSchema)
    if (sparkExpr.isEmpty) {
      logDebug(s"Skip unsupported Iceberg pruning expression: $icebergExpr")
    }
    sparkExpr.map(NativeConverters.convertScanPruningExpr)
  }
}

/**
 * Extracts the scan's pushed-down Iceberg filters via reflection on its
 * `filterExpressions` member.
 *
 * Depending on the Iceberg/Spark integration version the member may surface a
 * Java collection or a Scala Seq; any other (or missing) result yields an
 * empty sequence. Non-expression elements are silently ignored.
 */
private def scanFilterExpressions(scan: AnyRef): Seq[IcebergExpression] = {
  val extracted = invokeDeclaredMethod(scan, "filterExpressions")
  extracted match {
    case Some(javaValues: java.util.Collection[_]) =>
      javaValues.asScala.collect { case e: IcebergExpression => e }.toSeq
    case Some(scalaValues: Seq[_]) =>
      scalaValues.collect { case e: IcebergExpression => e }
    case _ =>
      Seq.empty
  }
}

/**
 * Reflectively invokes the first zero-arg declared method named `methodName`
 * found on `target`'s class or any of its superclasses.
 *
 * Returns None when the method is missing or any non-fatal error occurs
 * (including `setAccessible` restrictions); failures are only logged at debug
 * level because callers treat a missing method as "no filters available".
 * Note: a method that returns Java `null` yields `Some(null)`, matching the
 * previous behavior.
 */
private def invokeDeclaredMethod(target: AnyRef, methodName: String): Option[Any] = {
  // Walk up the class hierarchy looking for the declared method.
  def findIn(cls: Class[_]): Option[java.lang.reflect.Method] =
    if (cls == null) {
      None
    } else {
      cls.getDeclaredMethods.find(_.getName == methodName) match {
        case found @ Some(_) => found
        case None => findIn(cls.getSuperclass)
      }
    }

  try {
    findIn(target.getClass).map { method =>
      method.setAccessible(true)
      method.invoke(target)
    }
  } catch {
    case NonFatal(t) =>
      logDebug(s"Failed to invoke $methodName on ${target.getClass.getName}.", t)
      None
  }
}

/**
 * Recursively converts an Iceberg filter expression into an equivalent Spark
 * Catalyst expression over `readSchema`.
 *
 * Returns None whenever any sub-expression is unsupported, so a partially
 * convertible AND/OR/NOT is rejected as a whole rather than converted into a
 * weaker (and for NOT, incorrect) predicate.
 */
private def convertIcebergFilterExpression(
    expr: IcebergExpression,
    readSchema: StructType): Option[SparkExpression] = {
  import org.apache.iceberg.expressions.Expression.Operation
  expr match {
    case and: IcebergAnd =>
      convertIcebergFilterExpression(and.left(), readSchema).flatMap { leftExpr =>
        convertIcebergFilterExpression(and.right(), readSchema).map(SparkAnd(leftExpr, _))
      }
    case or: IcebergOr =>
      convertIcebergFilterExpression(or.left(), readSchema).flatMap { leftExpr =>
        convertIcebergFilterExpression(or.right(), readSchema).map(SparkOr(leftExpr, _))
      }
    case not: IcebergNot =>
      convertIcebergFilterExpression(not.child(), readSchema).map(SparkNot)
    case unbound: UnboundPredicate[_] =>
      convertUnboundPredicate(unbound, readSchema)
    case bound: BoundPredicate[_] =>
      convertBoundPredicate(bound, readSchema)
    case other =>
      // Only the constant TRUE/FALSE expressions are representable here;
      // everything else is unsupported.
      other.op() match {
        case Operation.TRUE => Some(Literal(true))
        case Operation.FALSE => Some(Literal(false))
        case _ => None
      }
  }
}

/**
 * Converts an unbound Iceberg predicate into a Spark expression, resolving the
 * referenced column against `readSchema`.
 *
 * Unary null/NaN checks map directly; IN / NOT IN go through
 * [[convertInPredicate]]; every remaining operation is treated as a binary
 * comparison against the predicate's single literal. Returns None when the
 * column is not in the schema or the operation/type is unsupported.
 */
private def convertUnboundPredicate(
    predicate: UnboundPredicate[_],
    readSchema: StructType): Option[SparkExpression] = {
  import org.apache.iceberg.expressions.Expression.Operation
  findField(predicate.ref().name(), readSchema).flatMap { field =>
    val attr = toAttribute(field)
    // Evaluated lazily: only the IN / NOT_IN arms read the literal list.
    def inValues: Seq[Any] = predicate.literals().asScala.map(_.value()).toSeq

    predicate.op() match {
      case Operation.IS_NULL =>
        Some(IsNull(attr))
      case Operation.NOT_NULL =>
        Some(IsNotNull(attr))
      case Operation.IS_NAN =>
        Some(IsNaN(attr))
      case Operation.NOT_NAN =>
        Some(SparkNot(IsNaN(attr)))
      case Operation.IN =>
        convertInPredicate(attr, field.dataType, inValues)
      case Operation.NOT_IN =>
        convertInPredicate(attr, field.dataType, inValues).map(SparkNot)
      case binaryOp =>
        convertBinaryPredicate(attr, field.dataType, binaryOp, predicate.literal().value())
    }
  }
}

/**
 * Converts a bound Iceberg predicate into a Spark expression, resolving the
 * referenced column against `readSchema`.
 *
 * Mirrors [[convertUnboundPredicate]] but dispatches on the bound predicate's
 * shape: unary (null/NaN checks), single-literal (binary comparisons), or
 * literal-set (IN / NOT IN). Returns None when the column is missing from the
 * schema or the operation/type combination is unsupported.
 */
private def convertBoundPredicate(
    predicate: BoundPredicate[_],
    readSchema: StructType): Option[SparkExpression] = {
  import org.apache.iceberg.expressions.Expression.Operation
  findField(predicate.ref().name(), readSchema).flatMap { field =>
    val attr = toAttribute(field)
    val op = predicate.op()

    if (predicate.isUnaryPredicate()) {
      op match {
        case Operation.IS_NULL =>
          Some(IsNull(attr))
        case Operation.NOT_NULL =>
          Some(IsNotNull(attr))
        case Operation.IS_NAN =>
          Some(IsNaN(attr))
        case Operation.NOT_NAN =>
          Some(SparkNot(IsNaN(attr)))
        case _ =>
          None
      }
    } else if (predicate.isLiteralPredicate()) {
      // Previously wrapped in a degenerate `op match { case _ => ... }`;
      // the converter itself validates the operation, so call it directly.
      val literalValue = predicate.asLiteralPredicate().literal().value()
      convertBinaryPredicate(attr, field.dataType, op, literalValue)
    } else if (predicate.isSetPredicate()) {
      val values = predicate.asSetPredicate().literalSet().asScala.toSeq
      op match {
        case Operation.IN =>
          convertInPredicate(attr, field.dataType, values)
        case Operation.NOT_IN =>
          convertInPredicate(attr, field.dataType, values).map(SparkNot)
        case _ =>
          None
      }
    } else {
      // Unknown predicate shape (future Iceberg versions): fall back safely.
      None
    }
  }
}

/**
 * Builds a Spark binary comparison (`=`, `!=`, `<`, `<=`, `>`, `>=`) between
 * `attr` and the given Iceberg literal value.
 *
 * Returns None when the column type is excluded from native pruning, the
 * literal cannot be represented, or the operation is not a comparison.
 */
private def convertBinaryPredicate(
    attr: AttributeReference,
    dataType: DataType,
    op: org.apache.iceberg.expressions.Expression.Operation,
    literalValue: Any): Option[SparkExpression] = {
  import org.apache.iceberg.expressions.Expression.Operation
  if (!supportsScanPruningLiteralType(dataType)) {
    None
  } else {
    toLiteral(literalValue, dataType).flatMap { lit =>
      op match {
        case Operation.EQ => Some(EqualTo(attr, lit))
        case Operation.NOT_EQ => Some(SparkNot(EqualTo(attr, lit)))
        case Operation.LT => Some(LessThan(attr, lit))
        case Operation.LT_EQ => Some(LessThanOrEqual(attr, lit))
        case Operation.GT => Some(GreaterThan(attr, lit))
        case Operation.GT_EQ => Some(GreaterThanOrEqual(attr, lit))
        case _ => None
      }
    }
  }
}

/**
 * Builds a Spark `In` predicate for `attr` over the given Iceberg literal
 * values.
 *
 * All values must be convertible; a single non-representable literal rejects
 * the whole predicate (a partial IN-list would be a semantically different
 * filter). Unsupported column types are rejected up front.
 */
private def convertInPredicate(
    attr: AttributeReference,
    dataType: DataType,
    values: Seq[Any]): Option[SparkExpression] = {
  if (!supportsScanPruningLiteralType(dataType)) {
    None
  } else {
    val maybeLiterals = values.map(toLiteral(_, dataType))
    if (maybeLiterals.contains(None)) None
    else Some(In(attr, maybeLiterals.flatten))
  }
}

/**
 * Whether literals of `dataType` may participate in native scan pruning.
 *
 * String, binary, and decimal columns are excluded; everything else is
 * allowed. (NOTE(review): the exclusion reason is not visible here —
 * presumably the native side does not evaluate these literal encodings;
 * confirm before extending.)
 */
private def supportsScanPruningLiteralType(dataType: DataType): Boolean =
  dataType match {
    case StringType | BinaryType | (_: DecimalType) => false
    case _ => true
  }

/**
 * Wraps an Iceberg literal value as a Spark Catalyst `Literal` of `dataType`.
 *
 * Null maps to a typed null literal; decimals are rejected; binary accepts
 * byte arrays or ByteBuffers (copied without disturbing the buffer position);
 * strings are stringified. Everything else is handed to `Literal.create`
 * as-is. Returns None for values that cannot be represented.
 */
private def toLiteral(value: Any, dataType: DataType): Option[Literal] =
  value match {
    case null =>
      Some(Literal.create(null, dataType))
    case nonNull =>
      dataType match {
        case _: DecimalType =>
          None
        case BinaryType =>
          nonNull match {
            case bytes: Array[Byte] =>
              Some(Literal(bytes, BinaryType))
            case buffer: java.nio.ByteBuffer =>
              // Duplicate so the caller's buffer position is untouched.
              val view = buffer.duplicate()
              val copied = new Array[Byte](view.remaining())
              view.get(copied)
              Some(Literal(copied, BinaryType))
            case _ =>
              None
          }
        case StringType =>
          Some(Literal.create(nonNull.toString, StringType))
        case _ =>
          Some(Literal.create(nonNull, dataType))
      }
  }

/**
 * Builds a Catalyst attribute reference for a schema field.
 *
 * NOTE(review): nullability is hard-coded to true rather than taken from
 * `field.nullable` — presumably a conservative choice for pruning; confirm
 * before reusing this helper elsewhere.
 */
private def toAttribute(field: StructField): AttributeReference = {
  AttributeReference(field.name, field.dataType, nullable = true)()
}

/**
 * Looks up `name` in `readSchema` using the active session's resolver, so
 * case sensitivity follows the current SQLConf setting.
 */
private def findField(name: String, readSchema: StructType): Option[StructField] = {
  val resolve = SQLConf.get.resolver
  readSchema.fields.collectFirst { case field if resolve(field.name, name) => field }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ case class NativeIcebergTableScanExec(basedScan: BatchScanExec, plan: IcebergSca

private lazy val readSchema: StructType = plan.readSchema
private lazy val fileTasks: Seq[FileScanTask] = plan.fileTasks
private lazy val pruningPredicates: Seq[pb.PhysicalExprNode] = plan.pruningPredicates

private lazy val partitions: Array[FilePartition] = buildFilePartitions()
private lazy val fileSizes: Map[String, Long] = buildFileSizes()
Expand Down Expand Up @@ -166,8 +167,7 @@ case class NativeIcebergTableScanExec(basedScan: BatchScanExec, plan: IcebergSca
.newBuilder()
.setBaseConf(nativeFileScanConf)
.setFsResourceId(resourceId)
// No pruning predicates are pushed down in the native scan yet.
.addAllPruningPredicates(new java.util.ArrayList())
.addAllPruningPredicates(pruningPredicates.asJava)

pb.PhysicalPlanNode
.newBuilder()
Expand All @@ -178,8 +178,7 @@ case class NativeIcebergTableScanExec(basedScan: BatchScanExec, plan: IcebergSca
.newBuilder()
.setBaseConf(nativeFileScanConf)
.setFsResourceId(resourceId)
// No pruning predicates are pushed down in the native scan yet.
.addAllPruningPredicates(new java.util.ArrayList())
.addAllPruningPredicates(pruningPredicates.asJava)

pb.PhysicalPlanNode
.newBuilder()
Expand Down
Loading
Loading