diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/auron/AuronConverters.scala b/spark-extension/src/main/scala/org/apache/spark/sql/auron/AuronConverters.scala index b3c55da54..cc12a176a 100644 --- a/spark-extension/src/main/scala/org/apache/spark/sql/auron/AuronConverters.scala +++ b/spark-extension/src/main/scala/org/apache/spark/sql/auron/AuronConverters.scala @@ -1102,6 +1102,8 @@ object AuronConverters extends Logging { case _: NativeParquetScanBase | _: NativeOrcScanBase | _: NativeHiveTableScanBase | _: NativeUnionBase => true + case exec if exec.nodeName == "NativeIcebergTableScan" => + true case _: ConvertToNativeBase => needRenameColumns(plan.children.head) case exec if NativeHelper.isNative(exec) => NativeHelper.getUnderlyingNativePlan(exec).output != plan.output diff --git a/thirdparty/auron-iceberg/src/main/scala/org/apache/spark/sql/auron/iceberg/IcebergScanSupport.scala b/thirdparty/auron-iceberg/src/main/scala/org/apache/spark/sql/auron/iceberg/IcebergScanSupport.scala index 7ac671638..39e14f3e2 100644 --- a/thirdparty/auron-iceberg/src/main/scala/org/apache/spark/sql/auron/iceberg/IcebergScanSupport.scala +++ b/thirdparty/auron-iceberg/src/main/scala/org/apache/spark/sql/auron/iceberg/IcebergScanSupport.scala @@ -20,17 +20,22 @@ import scala.collection.JavaConverters._ import scala.util.control.NonFatal import org.apache.iceberg.{FileFormat, FileScanTask, MetadataColumns} -import org.apache.iceberg.expressions.Expressions +import org.apache.iceberg.expressions.{And => IcebergAnd, BoundPredicate, Expression => IcebergExpression, Not => IcebergNot, Or => IcebergOr, UnboundPredicate} import org.apache.spark.internal.Logging import org.apache.spark.sql.auron.NativeConverters +import org.apache.spark.sql.catalyst.expressions.{And => SparkAnd, AttributeReference, EqualTo, Expression => SparkExpression, GreaterThan, GreaterThanOrEqual, In, IsNaN, IsNotNull, IsNull, LessThan, LessThanOrEqual, Literal, Not => SparkNot, Or => SparkOr} import org.apache.spark.sql.connector.read.InputPartition import org.apache.spark.sql.execution.datasources.v2.BatchScanExec -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.{BinaryType, DataType, DecimalType, StringType, StructField, StructType} + +import org.apache.auron.{protobuf => pb} final case class IcebergScanPlan( fileTasks: Seq[FileScanTask], fileFormat: FileFormat, - readSchema: StructType) + readSchema: StructType, + pruningPredicates: Seq[pb.PhysicalExprNode]) object IcebergScanSupport extends Logging { @@ -61,7 +66,7 @@ object IcebergScanSupport extends Logging { // Empty scan (e.g. empty table) should still build a plan to return no rows. if (partitions.isEmpty) { logWarning(s"Native Iceberg scan planned with empty partitions for $scanClassName.") - return Some(IcebergScanPlan(Seq.empty, FileFormat.PARQUET, readSchema)) + return Some(IcebergScanPlan(Seq.empty, FileFormat.PARQUET, readSchema, Seq.empty)) } val icebergPartitions = partitions.flatMap(icebergPartition) @@ -77,11 +82,6 @@ object IcebergScanSupport extends Logging { return None } - // Residual filters require row-level evaluation, not supported in native scan. - if (!fileTasks.forall(task => Expressions.alwaysTrue().equals(task.residual()))) { - return None - } - // Native scan handles a single file format; mixed formats must fallback. val formats = fileTasks.map(_.file().format()).distinct if (formats.size > 1) { @@ -93,7 +93,9 @@ object IcebergScanSupport extends Logging { return None } - Some(IcebergScanPlan(fileTasks, format, readSchema)) + val pruningPredicates = collectPruningPredicates(scan.asInstanceOf[AnyRef], readSchema) + + Some(IcebergScanPlan(fileTasks, format, readSchema, pruningPredicates)) } private def hasMetadataColumns(schema: StructType): Boolean = @@ -188,4 +190,240 @@ object IcebergScanSupport extends Logging { None } } + + private def collectPruningPredicates( + scan: AnyRef, + readSchema: StructType): Seq[pb.PhysicalExprNode] = { + scanFilterExpressions(scan).flatMap { expr => + convertIcebergFilterExpression(expr, readSchema) match { + case Some(converted) => + Some(NativeConverters.convertScanPruningExpr(converted)) + case None => + logDebug(s"Skip unsupported Iceberg pruning expression: $expr") + None + } + } + } + + private def scanFilterExpressions(scan: AnyRef): Seq[IcebergExpression] = { + invokeDeclaredMethod(scan, "filterExpressions") match { + case Some(values: java.util.Collection[_]) => + values.asScala.collect { case expr: IcebergExpression => expr }.toSeq + case Some(values: Seq[_]) => + values.collect { case expr: IcebergExpression => expr } + case _ => + Seq.empty + } + } + + private def invokeDeclaredMethod(target: AnyRef, methodName: String): Option[Any] = { + try { + var cls: Class[_] = target.getClass + while (cls != null) { + cls.getDeclaredMethods.find(_.getName == methodName) match { + case Some(method) => + method.setAccessible(true) + return Some(method.invoke(target)) + case None => + cls = cls.getSuperclass + } + } + None + } catch { + case NonFatal(t) => + logDebug(s"Failed to invoke $methodName on ${target.getClass.getName}.", t) + None + } + } + + private def convertIcebergFilterExpression( + expr: IcebergExpression, + readSchema: StructType): Option[SparkExpression] = { + expr match { + case and: IcebergAnd => + for { + left <- convertIcebergFilterExpression(and.left(), readSchema) + right <- convertIcebergFilterExpression(and.right(), readSchema) + } yield SparkAnd(left, right) + case or: IcebergOr => + for { + left <- convertIcebergFilterExpression(or.left(), readSchema) + right <- convertIcebergFilterExpression(or.right(), readSchema) + } yield SparkOr(left, right) + case not: IcebergNot => + convertIcebergFilterExpression(not.child(), readSchema).map(SparkNot) + case predicate: UnboundPredicate[_] => + convertUnboundPredicate(predicate, readSchema) + case predicate: BoundPredicate[_] => + convertBoundPredicate(predicate, readSchema) + case _ => + expr.op() match { + case org.apache.iceberg.expressions.Expression.Operation.TRUE => + Some(Literal(true)) + case org.apache.iceberg.expressions.Expression.Operation.FALSE => + Some(Literal(false)) + case _ => + None + } + } + } + + private def convertUnboundPredicate( + predicate: UnboundPredicate[_], + readSchema: StructType): Option[SparkExpression] = { + findField(predicate.ref().name(), readSchema).flatMap { field => + val attr = toAttribute(field) + val op = predicate.op() + + op match { + case org.apache.iceberg.expressions.Expression.Operation.IS_NULL => + Some(IsNull(attr)) + case org.apache.iceberg.expressions.Expression.Operation.NOT_NULL => + Some(IsNotNull(attr)) + case org.apache.iceberg.expressions.Expression.Operation.IS_NAN => + Some(IsNaN(attr)) + case org.apache.iceberg.expressions.Expression.Operation.NOT_NAN => + Some(SparkNot(IsNaN(attr))) + case org.apache.iceberg.expressions.Expression.Operation.IN => + convertInPredicate( + attr, + field.dataType, + predicate.literals().asScala.map(_.value()).toSeq) + case org.apache.iceberg.expressions.Expression.Operation.NOT_IN => + convertInPredicate( + attr, + field.dataType, + predicate.literals().asScala.map(_.value()).toSeq).map(SparkNot) + case _ => + convertBinaryPredicate(attr, field.dataType, op, predicate.literal().value()) + } + } + } + + private def convertBoundPredicate( + predicate: BoundPredicate[_], + readSchema: StructType): Option[SparkExpression] = { + findField(predicate.ref().name(), readSchema).flatMap { field => + val attr = toAttribute(field) + val op = predicate.op() + + if (predicate.isUnaryPredicate()) { + op match { + case org.apache.iceberg.expressions.Expression.Operation.IS_NULL => + Some(IsNull(attr)) + case org.apache.iceberg.expressions.Expression.Operation.NOT_NULL => + Some(IsNotNull(attr)) + case org.apache.iceberg.expressions.Expression.Operation.IS_NAN => + Some(IsNaN(attr)) + case org.apache.iceberg.expressions.Expression.Operation.NOT_NAN => + Some(SparkNot(IsNaN(attr))) + case _ => + None + } + } else if (predicate.isLiteralPredicate()) { + val literalValue = predicate.asLiteralPredicate().literal().value() + op match { + case _ => + convertBinaryPredicate(attr, field.dataType, op, literalValue) + } + } else if (predicate.isSetPredicate()) { + val values = predicate.asSetPredicate().literalSet().asScala.toSeq + op match { + case org.apache.iceberg.expressions.Expression.Operation.IN => + convertInPredicate(attr, field.dataType, values) + case org.apache.iceberg.expressions.Expression.Operation.NOT_IN => + convertInPredicate(attr, field.dataType, values).map(SparkNot) + case _ => + None + } + } else { + None + } + } + } + + private def convertBinaryPredicate( + attr: AttributeReference, + dataType: DataType, + op: org.apache.iceberg.expressions.Expression.Operation, + literalValue: Any): Option[SparkExpression] = { + if (!supportsScanPruningLiteralType(dataType)) { + return None + } + toLiteral(literalValue, dataType).flatMap { literal => + op match { + case org.apache.iceberg.expressions.Expression.Operation.EQ => + Some(EqualTo(attr, literal)) + case org.apache.iceberg.expressions.Expression.Operation.NOT_EQ => + Some(SparkNot(EqualTo(attr, literal))) + case org.apache.iceberg.expressions.Expression.Operation.LT => + Some(LessThan(attr, literal)) + case org.apache.iceberg.expressions.Expression.Operation.LT_EQ => + Some(LessThanOrEqual(attr, literal)) + case org.apache.iceberg.expressions.Expression.Operation.GT => + Some(GreaterThan(attr, literal)) + case org.apache.iceberg.expressions.Expression.Operation.GT_EQ => + Some(GreaterThanOrEqual(attr, literal)) + case _ => + None + } + } + } + + private def convertInPredicate( + attr: AttributeReference, + dataType: DataType, + values: Seq[Any]): Option[SparkExpression] = { + if (!supportsScanPruningLiteralType(dataType)) { + return None + } + val literals = values.map(toLiteral(_, dataType)) + if (literals.forall(_.nonEmpty)) { + Some(In(attr, literals.flatten)) + } else { + None + } + } + + private def supportsScanPruningLiteralType(dataType: DataType): Boolean = { + dataType match { + case StringType | BinaryType => false + case _: DecimalType => false + case _ => true + } + } + + private def toLiteral(value: Any, dataType: DataType): Option[Literal] = { + if (value == null) { + return Some(Literal.create(null, dataType)) + } + dataType match { + case _: DecimalType => + None + case BinaryType => + value match { + case bytes: Array[Byte] => + Some(Literal(bytes, BinaryType)) + case byteBuffer: java.nio.ByteBuffer => + val duplicated = byteBuffer.duplicate() + val bytes = new Array[Byte](duplicated.remaining()) + duplicated.get(bytes) + Some(Literal(bytes, BinaryType)) + case _ => + None + } + case StringType => + Some(Literal.create(value.toString, StringType)) + case _ => + Some(Literal.create(value, dataType)) + } + } + + private def toAttribute(field: StructField): AttributeReference = + AttributeReference(field.name, field.dataType, nullable = true)() + + private def findField(name: String, readSchema: StructType): Option[StructField] = { + val resolver = SQLConf.get.resolver + readSchema.fields.find(field => resolver(field.name, name)) + } } diff --git a/thirdparty/auron-iceberg/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeIcebergTableScanExec.scala b/thirdparty/auron-iceberg/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeIcebergTableScanExec.scala index 63927e0b0..3cf5cf562 100644 --- a/thirdparty/auron-iceberg/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeIcebergTableScanExec.scala +++ b/thirdparty/auron-iceberg/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeIcebergTableScanExec.scala @@ -59,6 +59,7 @@ case class NativeIcebergTableScanExec(basedScan: BatchScanExec, plan: IcebergSca private lazy val readSchema: StructType = plan.readSchema private lazy val fileTasks: Seq[FileScanTask] = plan.fileTasks + private lazy val pruningPredicates: Seq[pb.PhysicalExprNode] = plan.pruningPredicates private lazy val partitions: Array[FilePartition] = buildFilePartitions() private lazy val fileSizes: Map[String, Long] = buildFileSizes() @@ -166,8 +167,7 @@ case class NativeIcebergTableScanExec(basedScan: BatchScanExec, plan: IcebergSca .newBuilder() .setBaseConf(nativeFileScanConf) .setFsResourceId(resourceId) - // No pruning predicates are pushed down in the native scan yet. - .addAllPruningPredicates(new java.util.ArrayList()) + .addAllPruningPredicates(pruningPredicates.asJava) pb.PhysicalPlanNode .newBuilder() @@ -178,8 +178,7 @@ case class NativeIcebergTableScanExec(basedScan: BatchScanExec, plan: IcebergSca .newBuilder() .setBaseConf(nativeFileScanConf) .setFsResourceId(resourceId) - // No pruning predicates are pushed down in the native scan yet. - .addAllPruningPredicates(new java.util.ArrayList()) + .addAllPruningPredicates(pruningPredicates.asJava) pb.PhysicalPlanNode .newBuilder() diff --git a/thirdparty/auron-iceberg/src/test/scala/org/apache/auron/iceberg/AuronIcebergIntegrationSuite.scala b/thirdparty/auron-iceberg/src/test/scala/org/apache/auron/iceberg/AuronIcebergIntegrationSuite.scala index 6472fcd86..d4d1c19ac 100644 --- a/thirdparty/auron-iceberg/src/test/scala/org/apache/auron/iceberg/AuronIcebergIntegrationSuite.scala +++ b/thirdparty/auron-iceberg/src/test/scala/org/apache/auron/iceberg/AuronIcebergIntegrationSuite.scala @@ -24,7 +24,9 @@ import org.apache.iceberg.{FileFormat, FileScanTask} import org.apache.iceberg.data.{GenericAppenderFactory, Record} import org.apache.iceberg.deletes.PositionDelete import org.apache.iceberg.spark.Spark3Util -import org.apache.spark.sql.Row +import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.auron.iceberg.IcebergScanSupport +import org.apache.spark.sql.execution.datasources.v2.BatchScanExec class AuronIcebergIntegrationSuite extends org.apache.spark.sql.QueryTest @@ -132,14 +134,77 @@ class AuronIcebergIntegrationSuite } } - test("iceberg scan falls back for residual filters on data columns") { + test("iceberg scan pushes residual filters into native scan pruning predicates") { withTable("local.db.t_residual") { sql("create table local.db.t_residual (id int, v string) using iceberg") sql("insert into local.db.t_residual values (1, 'a'), (2, 'b')") - val df = sql("select * from local.db.t_residual where v = 'a'") + val df = sql("select * from local.db.t_residual where id = 1") checkAnswer(df, Seq(Row(1, "a"))) + val nativeScanPlan = icebergScanPlan(df) + assert(nativeScanPlan.nonEmpty) + assert(nativeScanPlan.get.pruningPredicates.nonEmpty) val plan = df.queryExecution.executedPlan.toString() - assert(!plan.contains("NativeIcebergTableScan")) + assert(plan.contains("NativeIcebergTableScan")) + assert(plan.contains("NativeFilter")) + } + } + + test("iceberg scan pushes supported IN filters into native scan pruning predicates") { + withTable("local.db.t_residual_supported") { + sql("create table local.db.t_residual_supported (id int, v string) using iceberg") + sql( + "insert into local.db.t_residual_supported values (1, 'alpha'), (2, 'beta'), (3, 'atom')") + val df = sql(""" + |select * from local.db.t_residual_supported + |where id in (1, 3) + |""".stripMargin) + checkAnswer(df, Seq(Row(1, "alpha"), Row(3, "atom"))) + val nativeScanPlan = icebergScanPlan(df) + assert(nativeScanPlan.nonEmpty) + assert(nativeScanPlan.get.pruningPredicates.nonEmpty) + val plan = df.queryExecution.executedPlan.toString() + assert(plan.contains("NativeIcebergTableScan")) + assert(plan.contains("NativeFilter")) + } + } + + test("iceberg scan keeps native post-scan filter when only part of the predicate is pushed") { + withTable("local.db.t_residual_partial_pushdown") { + sql("create table local.db.t_residual_partial_pushdown (id int, v string) using iceberg") + sql( + "insert into local.db.t_residual_partial_pushdown values (1, 'alpha'), (2, 'beta'), (3, 'atom'), (4, 'delta')") + val df = sql(""" + |select * from local.db.t_residual_partial_pushdown + |where id in (1, 2, 3) and id % 2 = 1 + |""".stripMargin) + checkAnswer(df, Seq(Row(1, "alpha"), Row(3, "atom"))) + val nativeScanPlan = icebergScanPlan(df) + assert(nativeScanPlan.nonEmpty) + assert(nativeScanPlan.get.pruningPredicates.nonEmpty) + val plan = df.queryExecution.executedPlan.toString() + assert(plan.contains("NativeIcebergTableScan")) + assert(plan.contains("NativeFilter")) + } + } + + test("iceberg scan keeps native string filter outside scan pruning") { + withTable("local.db.t_residual_string") { + sql("create table local.db.t_residual_string (id int, v string) using iceberg") + sql("insert into local.db.t_residual_string values (1, 'a'), (2, 'b'), (3, null)") + val df = sql(""" + |select * from local.db.t_residual_string + |where v = 'a' + |""".stripMargin) + checkAnswer(df, Seq(Row(1, "a"))) + val nativeScanPlan = icebergScanPlan(df) + assert(nativeScanPlan.nonEmpty) + assert(nativeScanPlan.get.pruningPredicates.nonEmpty) + assert( + !nativeScanPlan.get.pruningPredicates.exists(_.toString.contains("binary_expr")), + "string equality should remain on the post-scan native filter path") + val plan = df.queryExecution.executedPlan.toString() + assert(plan.contains("NativeIcebergTableScan")) + assert(plan.contains("NativeFilter")) } } @@ -238,4 +303,9 @@ class AuronIcebergIntegrationSuite taskIterable.close() } } + + private def icebergScanPlan(df: DataFrame) = + df.queryExecution.sparkPlan.collectFirst { case scan: BatchScanExec => + IcebergScanSupport.plan(scan) + }.flatten }