diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala
index 72b466f5a0f9a..0d37a0b7ce7ab 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala
@@ -22,6 +22,7 @@ import org.apache.spark.internal.LogKeys.EXPR
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction, Complete}
 import org.apache.spark.sql.catalyst.expressions.objects.{Invoke, StaticInvoke}
+import org.apache.spark.sql.catalyst.expressions.variant.VariantGet
 import org.apache.spark.sql.catalyst.optimizer.ConstantFolding
 import org.apache.spark.sql.connector.catalog.functions.ScalarFunction
 import org.apache.spark.sql.connector.expressions.{Cast => V2Cast, Expression => V2Expression, Extract => V2Extract, FieldReference, GeneralScalarExpression, GetArrayItem => V2GetArrayItem, LiteralValue, NullOrdering, SortDirection, SortValue, UserDefinedScalarFunc}
@@ -29,6 +30,7 @@ import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Avg,
 import org.apache.spark.sql.connector.expressions.filter.{AlwaysFalse, AlwaysTrue, And => V2And, Not => V2Not, Or => V2Or, Predicate => V2Predicate}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{BooleanType, DataType, IntegerType, StringType}
+import org.apache.spark.unsafe.types.UTF8String
 
 /**
  * The builder to generate V2 expressions from catalyst expressions.
@@ -333,6 +335,25 @@ class V2ExpressionBuilder(e: Expression, isPredicate: Boolean = false) extends L
         case _ => None
       }
 
+    // Translate variant_get / try_variant_get on a plain column reference into a
+    // v2 UserDefinedScalarFunc(name, canonicalName, [column, path, targetType]) so
+    // data sources can push the extraction down. The eval() != null guard avoids
+    // an NPE below when the foldable path expression folds to null (foldable does
+    // not imply non-null); eval of a foldable expression is deterministic.
+    case v: VariantGet
+        if v.path.foldable && v.path.eval() != null
+          && v.child.isInstanceOf[Attribute] =>
+      val colName = v.child.asInstanceOf[Attribute].name
+      val path = v.path.eval().toString
+      val typeName = v.dataType.catalogString
+      val colRef = FieldReference.column(colName)
+      val pathLit = LiteralValue(UTF8String.fromString(path), StringType)
+      val typeLit = LiteralValue(UTF8String.fromString(typeName), StringType)
+      // prettyName is "variant_get" when failOnError, "try_variant_get" otherwise.
+      val canonName = v.prettyName
+      Some(new UserDefinedScalarFunc(
+        canonName,
+        canonName,
+        Array[V2Expression](colRef, pathLit, typeLit)))
     // TODO supports other expressions
     case ApplyFunctionExpression(function, children) =>
       val childrenExpressions = children.flatMap(generateExpression(_))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StrategySuite.scala
index 5f89a618edd53..e7b8dce0dff98 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StrategySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StrategySuite.scala
@@ -21,12 +21,13 @@ import org.apache.spark.SparkConf
 import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.variant.VariantGet
 import org.apache.spark.sql.catalyst.util.V2ExpressionBuilder
-import org.apache.spark.sql.connector.expressions.{Expression => V2Expression, FieldReference, GeneralScalarExpression, LiteralValue}
+import org.apache.spark.sql.connector.expressions.{Expression => V2Expression, FieldReference, GeneralScalarExpression, LiteralValue, UserDefinedScalarFunc}
 import org.apache.spark.sql.connector.expressions.filter.{AlwaysFalse, AlwaysTrue, And => V2And, Not => V2Not, Or => V2Or, Predicate}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
-import org.apache.spark.sql.types.{BooleanType, DoubleType, IntegerType, LongType, StringType, StructField, StructType}
+import org.apache.spark.sql.types.{BooleanType, DoubleType, IntegerType, LongType, StringType, StructField, StructType, VariantType}
 import org.apache.spark.unsafe.types.UTF8String
 
 class DataSourceV2StrategySuite extends SharedSparkSession {
@@ -818,6 +819,57 @@ class DataSourceV2StrategySuite extends SharedSparkSession {
         FieldReference("cdouble"))))
   }
 
+  test("VariantGet serializes to UserDefinedScalarFunc") {
+    val ref = AttributeReference("v", VariantType)()
+    val path = Literal.create("$.city", StringType)
+    val expr = VariantGet(ref, path, StringType, failOnError = true)
+    val gt = GreaterThan(expr, Literal.create("NYC", StringType))
+    val result = new V2ExpressionBuilder(gt, isPredicate = true).build()
+    result match {
+      case Some(v2pred: Predicate) if v2pred.name() == ">" =>
+        v2pred.children()(0) match {
+          case udf: UserDefinedScalarFunc =>
+            assert(udf.name() == "variant_get")
+            assert(udf.children().length == 3)
+          case _ => fail("expected UserDefinedScalarFunc")
+        }
+      case _ => fail("expected predicate with name '>'")
+    }
+  }
+
+  test("VariantGet predicate is translated by translateFilterV2") {
+    val ref = AttributeReference("v", VariantType)()
+    val path = Literal.create("$.city", StringType)
+    val expr = VariantGet(ref, path, StringType, failOnError = true)
+    val gt = GreaterThan(expr, Literal.create("NYC", StringType))
+    val result = DataSourceV2Strategy.translateFilterV2(gt)
+    assert(result.isDefined)
+    result.get.children()(0) match {
+      case udf: UserDefinedScalarFunc =>
+        assert(udf.name() == "variant_get")
+        assert(udf.children().length == 3)
+      case _ => fail("expected UserDefinedScalarFunc in translated predicate")
+    }
+  }
+
+  test("try_variant_get serializes to UserDefinedScalarFunc with try_variant_get name") {
+    val ref = AttributeReference("v", VariantType)()
+    val path = Literal.create("$.city", StringType)
+    val expr = VariantGet(ref, path, StringType, failOnError = false)
+    val gt = GreaterThan(expr, Literal.create("NYC", StringType))
+    val result = new V2ExpressionBuilder(gt, isPredicate = true).build()
+    result match {
+      case Some(v2pred: Predicate) if v2pred.name() == ">" =>
+        v2pred.children()(0) match {
+          case udf: UserDefinedScalarFunc =>
+            assert(udf.name() == "try_variant_get")
+            assert(udf.children().length == 3)
+          case _ => fail("expected UserDefinedScalarFunc")
+        }
+      case _ => fail("expected predicate with name '>'")
+    }
+  }
+
   test("Current Like functions are not supported") {
     val currentFunctions = Seq(
       CurrentDate(),