
Commit 51d7a20

harshmotw-db authored and cloud-fan committed
[SPARK-54306] Annotate Variant columns with Variant logical type annotation
### What changes were proposed in this pull request?

This PR changes the Parquet writer so that it annotates Variant columns with the Parquet Variant logical type annotation.

### Why are the changes needed?

The Parquet spec has formally adopted the Variant logical type, so Variant columns should be properly annotated in Spark 4.1.0, which depends on Parquet-java 1.16.0, the release that contains the Variant logical type annotation. This change is hidden behind a flag that is disabled by default until read support can be properly implemented.

### Does this PR introduce _any_ user-facing change?

Yes. Parquet files written by Spark 4.1.0 with the flag enabled (which it eventually will be by default) can contain the Variant logical type annotation, which readers without support for the type will not be able to read.

### How was this patch tested?

A unit test checks that nested as well as top-level Variants are properly annotated and that the data is written correctly.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #53005 from harshmotw-db/harshmotw-db/variant_annotation_write.

Authored-by: Harsh Motwani <harsh.motwani@databricks.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
(cherry picked from commit 5270c99)
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
1 parent 2ee1dcb commit 51d7a20
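
To see the effect end to end, here is a minimal usage sketch (not part of the diff): it enables the flag introduced by this commit and writes a Variant column. It assumes a running Spark 4.1.0 SparkSession named `spark`; the config key comes from the diff below, and parse_json is the standard Variant constructor.

    // Enable the annotation flag added by this commit (default: false).
    spark.conf.set("spark.sql.parquet.variant.annotateLogicalType.enabled", "true")
    // Write a single Variant column; with the flag on, the Parquet group
    // for `v` carries the Variant logical type annotation.
    spark.sql("""select parse_json('{"id": 1}') as v""")
      .write.mode("overwrite").parquet("/tmp/variant_annotated")
    // Note: readers without Variant logical type support cannot read this file.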

4 files changed: +104 additions, -5 deletions

sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala

Lines changed: 10 additions & 0 deletions
@@ -1585,6 +1585,14 @@ object SQLConf {
       .booleanConf
       .createWithDefault(true)
 
+  val PARQUET_ANNOTATE_VARIANT_LOGICAL_TYPE =
+    buildConf("spark.sql.parquet.variant.annotateLogicalType.enabled")
+      .doc("When enabled, Spark annotates the variant groups written to Parquet as the parquet " +
+        "variant logical type.")
+      .version("4.1.0")
+      .booleanConf
+      .createWithDefault(false)
+
   val PARQUET_FIELD_ID_READ_ENABLED =
     buildConf("spark.sql.parquet.fieldId.read.enabled")
       .doc("Field ID is a native field of the Parquet schema spec. When enabled, Parquet readers " +

@@ -7638,6 +7646,8 @@ class SQLConf extends Serializable with Logging with SqlApiConf {
 
   def parquetFieldIdWriteEnabled: Boolean = getConf(SQLConf.PARQUET_FIELD_ID_WRITE_ENABLED)
 
+  def parquetAnnotateVariantLogicalType: Boolean = getConf(PARQUET_ANNOTATE_VARIANT_LOGICAL_TYPE)
+
   def ignoreMissingParquetFieldId: Boolean = getConf(SQLConf.IGNORE_MISSING_PARQUET_FIELD_ID)
 
   def legacyParquetNanosAsLong: Boolean = getConf(SQLConf.LEGACY_PARQUET_NANOS_AS_LONG)
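
For illustration, a minimal sketch of reading the flag through the accessor added above; it assumes code running on the driver, where SQLConf.get resolves to the active session's conf.

    import org.apache.spark.sql.internal.SQLConf

    // Returns false unless spark.sql.parquet.variant.annotateLogicalType.enabled is set.
    val annotateEnabled: Boolean = SQLConf.get.parquetAnnotateVariantLogicalType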

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala

Lines changed: 20 additions & 4 deletions
@@ -373,6 +373,10 @@ class ParquetToSparkSchemaConverter(
 
     Option(field.getLogicalTypeAnnotation).fold(
       convertInternal(groupColumn, sparkReadType.map(_.asInstanceOf[StructType]))) {
+      // Temporary workaround to read Shredded variant data
+      case v: VariantLogicalTypeAnnotation if v.getSpecVersion == 1 && sparkReadType.isEmpty =>
+        convertInternal(groupColumn, None)
+
       // A Parquet list is represented as a 3-level structure:
       //
       //   <list-repetition> group <name> (LIST) {

@@ -552,7 +556,9 @@ class SparkToParquetSchemaConverter(
     writeLegacyParquetFormat: Boolean = SQLConf.PARQUET_WRITE_LEGACY_FORMAT.defaultValue.get,
     outputTimestampType: SQLConf.ParquetOutputTimestampType.Value =
       SQLConf.ParquetOutputTimestampType.INT96,
-    useFieldId: Boolean = SQLConf.PARQUET_FIELD_ID_WRITE_ENABLED.defaultValue.get) {
+    useFieldId: Boolean = SQLConf.PARQUET_FIELD_ID_WRITE_ENABLED.defaultValue.get,
+    annotateVariantLogicalType: Boolean =
+      SQLConf.PARQUET_ANNOTATE_VARIANT_LOGICAL_TYPE.defaultValue.get) {
 
   def this(conf: SQLConf) = this(
     writeLegacyParquetFormat = conf.writeLegacyParquetFormat,

@@ -563,7 +569,9 @@ class SparkToParquetSchemaConverter(
     writeLegacyParquetFormat = conf.get(SQLConf.PARQUET_WRITE_LEGACY_FORMAT.key).toBoolean,
     outputTimestampType = SQLConf.ParquetOutputTimestampType.withName(
       conf.get(SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key)),
-    useFieldId = conf.get(SQLConf.PARQUET_FIELD_ID_WRITE_ENABLED.key).toBoolean)
+    useFieldId = conf.get(SQLConf.PARQUET_FIELD_ID_WRITE_ENABLED.key).toBoolean,
+    annotateVariantLogicalType =
+      conf.get(SQLConf.PARQUET_ANNOTATE_VARIANT_LOGICAL_TYPE.key).toBoolean)
 
   /**
    * Converts a Spark SQL [[StructType]] to a Parquet [[MessageType]].

@@ -817,14 +825,22 @@
       // ===========
 
       case VariantType =>
-        Types.buildGroup(repetition)
+        (if (annotateVariantLogicalType) {
+          Types.buildGroup(repetition).as(LogicalTypeAnnotation.variantType(1))
+        } else {
+          Types.buildGroup(repetition)
+        })
           .addField(convertField(StructField("value", BinaryType, nullable = false), inShredded))
           .addField(convertField(StructField("metadata", BinaryType, nullable = false), inShredded))
           .named(field.name)
 
       case s: StructType if SparkShreddingUtils.isVariantShreddingStruct(s) =>
         // Variant struct takes a Variant and writes to Parquet as a shredded schema.
-        val group = Types.buildGroup(repetition)
+        val group = if (annotateVariantLogicalType) {
+          Types.buildGroup(repetition).as(LogicalTypeAnnotation.variantType(1))
+        } else {
+          Types.buildGroup(repetition)
+        }
         s.fields.foreach { f =>
           group.addField(convertField(f, inShredded = true))
         }
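
To make the writer-side behavior concrete, here is a sketch of driving the converter directly; the parameter name and the expected group shape come from the diff above, while the printed form of the annotation is approximate (its rendering may differ by parquet-java version).

    import org.apache.spark.sql.types.{StructField, StructType, VariantType}

    val schema = StructType(Seq(StructField("v", VariantType)))
    // With the flag on, the variant group is annotated with spec version 1,
    // roughly:
    //   optional group v (VARIANT(1)) {
    //     required binary value;
    //     required binary metadata;
    //   }
    val annotated = new SparkToParquetSchemaConverter(annotateVariantLogicalType = true)
      .convert(schema)
    // With the flag off, the same value/metadata group is produced, but with
    // no logical type annotation on `v`.
    val plain = new SparkToParquetSchemaConverter(annotateVariantLogicalType = false)
      .convert(schema)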

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala

Lines changed: 4 additions & 0 deletions
@@ -523,6 +523,10 @@ object ParquetUtils extends Logging {
       SQLConf.LEGACY_PARQUET_NANOS_AS_LONG.key,
       sqlConf.legacyParquetNanosAsLong.toString)
 
+    conf.set(
+      SQLConf.PARQUET_ANNOTATE_VARIANT_LOGICAL_TYPE.key,
+      sqlConf.parquetAnnotateVariantLogicalType.toString)
+
     // Sets compression scheme
     conf.set(ParquetOutputFormat.COMPRESSION, parquetOptions.compressionCodecClassName)
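
For completeness, a sketch of how a writer-side component could read the propagated value back out of the Hadoop Configuration. It uses only the standard Hadoop API plus the key constant added in SQLConf above; the helper name is hypothetical.

    import org.apache.hadoop.conf.Configuration
    import org.apache.spark.sql.internal.SQLConf

    // Hypothetical helper: ParquetUtils stores the flag as "true"/"false",
    // so getBoolean parses it; the default mirrors the conf's default (false).
    def annotateVariantEnabled(conf: Configuration): Boolean =
      conf.getBoolean(SQLConf.PARQUET_ANNOTATE_VARIANT_LOGICAL_TYPE.key, false)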

sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetVariantShreddingSuite.scala

Lines changed: 70 additions & 1 deletion
@@ -19,11 +19,13 @@ package org.apache.spark.sql.execution.datasources.parquet
 
 import java.io.File
 
+import scala.jdk.CollectionConverters._
+
 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs.Path
 import org.apache.parquet.hadoop.ParquetFileReader
 import org.apache.parquet.hadoop.util.HadoopInputFile
-import org.apache.parquet.schema.{LogicalTypeAnnotation, PrimitiveType}
+import org.apache.parquet.schema.{LogicalTypeAnnotation, PrimitiveType, Type}
 import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName
 
 import org.apache.spark.sql.{QueryTest, Row}

@@ -154,6 +156,73 @@ class ParquetVariantShreddingSuite extends QueryTest with ParquetTest with Share
     }
   }
 
+  test("variant logical type annotation") {
+    Seq(false, true).foreach { annotateVariantLogicalType =>
+      Seq(false, true).foreach { shredVariant =>
+        Seq(false, true).foreach { allowReadingShredded =>
+          withSQLConf(SQLConf.VARIANT_WRITE_SHREDDING_ENABLED.key -> shredVariant.toString,
+            SQLConf.VARIANT_INFER_SHREDDING_SCHEMA.key -> shredVariant.toString,
+            SQLConf.VARIANT_ALLOW_READING_SHREDDED.key ->
+              (allowReadingShredded || shredVariant).toString,
+            SQLConf.PARQUET_ANNOTATE_VARIANT_LOGICAL_TYPE.key ->
+              annotateVariantLogicalType.toString) {
+            def validateAnnotation(g: Type): Unit = {
+              if (annotateVariantLogicalType) {
+                assert(g.getLogicalTypeAnnotation == LogicalTypeAnnotation.variantType(1))
+              } else {
+                assert(g.getLogicalTypeAnnotation == null)
+              }
+            }
+            withTempDir { dir =>
+              // write parquet file
+              val df = spark.sql(
+                """
+                  | select
+                  |   id * 2 i,
+                  |   to_variant_object(named_struct('id', id)) v,
+                  |   named_struct('i', (id * 2)::string,
+                  |     'nv', to_variant_object(named_struct('id', 30 + id))) ns,
+                  |   array(to_variant_object(named_struct('id', 10 + id))) av,
+                  |   map('v2', to_variant_object(named_struct('id', 20 + id))) mv
+                  | from range(0,3,1,1)""".stripMargin)
+              df.write.mode("overwrite").parquet(dir.getAbsolutePath)
+              val file = dir.listFiles().find(_.getName.endsWith(".parquet")).get
+              val parquetFilePath = file.getAbsolutePath
+              val inputFile = HadoopInputFile.fromPath(new Path(parquetFilePath),
+                new Configuration())
+              val reader = ParquetFileReader.open(inputFile)
+              val footer = reader.getFooter
+              val schema = footer.getFileMetaData.getSchema
+              val vGroup = schema.getType(schema.getFieldIndex("v"))
+              validateAnnotation(vGroup)
+              assert(vGroup.asGroupType().getFields.asScala.toSeq
+                .exists(_.getName == "typed_value") == shredVariant)
+              val nsGroup = schema.getType(schema.getFieldIndex("ns")).asGroupType()
+              val nvGroup = nsGroup.getType(nsGroup.getFieldIndex("nv"))
+              validateAnnotation(nvGroup)
+              val avGroup = schema.getType(schema.getFieldIndex("av")).asGroupType()
+              val avList = avGroup.getType(avGroup.getFieldIndex("list")).asGroupType()
+              val avElement = avList.getType(avList.getFieldIndex("element"))
+              validateAnnotation(avElement)
+              val mvGroup = schema.getType(schema.getFieldIndex("mv")).asGroupType()
+              val mvList = mvGroup.getType(mvGroup.getFieldIndex("key_value")).asGroupType()
+              val mvValue = mvList.getType(mvList.getFieldIndex("value"))
+              validateAnnotation(mvValue)
+              // verify result
+              val result = spark.read.format("parquet")
+                .schema("v variant, ns struct<nv variant>, av array<variant>, " +
+                  "mv map<string, variant>")
+                .load(dir.getAbsolutePath)
+                .selectExpr("v:id::int i1", "ns.nv:id::int i2", "av[0]:id::int i3",
+                  "mv['v2']:id::int i4")
+              checkAnswer(result, Array(Row(0, 30, 10, 20), Row(1, 31, 11, 21), Row(2, 32, 12, 22)))
+              reader.close()
+            }
+          }
+        }
+      }
+    }
+  }
 
   testWithTempDir("write shredded variant basic") { dir =>
     val schema = "a int, b string, c decimal(15, 1)"
