
Commit b556e59

Addison Higham authored and falaki committed
Add nullValue being respected when parsing CSVs
This change looks for a user-specified `nullValue` throughout CSV parsing, which allows handling CSVs that use something other than an empty string to represent nulls. It reuses the same option as CSV saving, `nullValue`, and should be non-breaking. The behavior is also pushed into inferSchema, so inferred schemas properly reflect the user-given null value.

Author: Addison Higham <ahigham@instructure.com>

Closes #224 from addisonj/master.
1 parent 44964a2 commit b556e59

File tree

11 files changed: +107 −20 lines

README.md

Lines changed: 6 additions & 5 deletions

@@ -56,6 +56,7 @@ When reading files the API accepts several options:
 * `inferSchema`: automatically infers column types. It requires one extra pass over the data and is false by default
 * `comment`: skip lines beginning with this character. Default is `"#"`. Disable comments by setting this to `null`.
 * `codec`: compression codec to use when saving to file. Should be the fully qualified name of a class implementing `org.apache.hadoop.io.compress.CompressionCodec`. Defaults to no compression when a codec is not specified.
+* `nullValue`: specify a string that indicates a null value; any fields matching this string will be set as nulls in the DataFrame

 The package also supports saving a simple (non-nested) DataFrame. When saving you can specify the delimiter and whether we should generate a header row for the table. See the following examples for more details.

@@ -109,7 +110,7 @@ import org.apache.spark.sql.types.{StructType, StructField, StringType, IntegerType}

 val sqlContext = new SQLContext(sc)
 val customSchema = StructType(
-    StructField("year", IntegerType, true),
+    StructField("year", IntegerType, true),
     StructField("make", StringType, true),
     StructField("model", StringType, true),
     StructField("comment", StringType, true),

@@ -155,7 +156,7 @@ import org.apache.spark.sql.SQLContext

 val sqlContext = new SQLContext(sc)
 val df = sqlContext.load(
-    "com.databricks.spark.csv",
+    "com.databricks.spark.csv",
     Map("path" -> "cars.csv", "header" -> "true", "inferSchema" -> "true"))
 val selectedData = df.select("year", "model")
 selectedData.save("newcars.csv", "com.databricks.spark.csv")

@@ -168,14 +169,14 @@ import org.apache.spark.sql.types.{StructType, StructField, StringType, IntegerType}

 val sqlContext = new SQLContext(sc)
 val customSchema = StructType(
-    StructField("year", IntegerType, true),
+    StructField("year", IntegerType, true),
     StructField("make", StringType, true),
     StructField("model", StringType, true),
     StructField("comment", StringType, true),
     StructField("blank", StringType, true))

 val df = sqlContext.load(
-    "com.databricks.spark.csv",
+    "com.databricks.spark.csv",
     schema = customSchema,
     Map("path" -> "cars.csv", "header" -> "true"))

@@ -210,7 +211,7 @@ import org.apache.spark.sql.types.*;

 SQLContext sqlContext = new SQLContext(sc);
 StructType customSchema = new StructType(new StructField[] {
-    new StructField("year", DataTypes.IntegerType, true, Metadata.empty()),
+    new StructField("year", DataTypes.IntegerType, true, Metadata.empty()),
     new StructField("make", DataTypes.StringType, true, Metadata.empty()),
     new StructField("model", DataTypes.StringType, true, Metadata.empty()),
     new StructField("comment", DataTypes.StringType, true, Metadata.empty()),

src/main/scala/com/databricks/spark/csv/CsvParser.scala

Lines changed: 10 additions & 2 deletions

@@ -40,6 +40,7 @@ class CsvParser extends Serializable {
   private var charset: String = TextFile.DEFAULT_CHARSET.name()
   private var inferSchema: Boolean = false
   private var codec: String = null
+  private var nullValue: String = ""

   def withUseHeader(flag: Boolean): CsvParser = {
     this.useHeader = flag
@@ -111,6 +112,11 @@ class CsvParser extends Serializable {
     this
   }

+  def withNullValue(nullValue: String): CsvParser = {
+    this.nullValue = nullValue
+    this
+  }
+
   /** Returns a Schema RDD for the given CSV path. */
   @throws[RuntimeException]
   def csvFile(sqlContext: SQLContext, path: String): DataFrame = {
@@ -129,7 +135,8 @@ class CsvParser extends Serializable {
       treatEmptyValuesAsNulls,
       schema,
       inferSchema,
-      codec)(sqlContext)
+      codec,
+      nullValue)(sqlContext)
     sqlContext.baseRelationToDataFrame(relation)
   }

@@ -149,7 +156,8 @@ class CsvParser extends Serializable {
       treatEmptyValuesAsNulls,
       schema,
       inferSchema,
-      codec)(sqlContext)
+      codec,
+      nullValue)(sqlContext)
     sqlContext.baseRelationToDataFrame(relation)
   }
 }
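With the new setter, the builder chain can request a custom null marker directly; a short sketch mirroring the tests added in this commit (it reads one of the new test resources and assumes a `sqlContext` in scope):

import com.databricks.spark.csv.CsvParser

// Fields equal to the literal string "\N" are returned as nulls.
val df = new CsvParser()
  .withUseHeader(true)
  .withNullValue("\\N")
  .csvFile(sqlContext, "src/test/resources/null_slashn_numbers.csv")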

src/main/scala/com/databricks/spark/csv/CsvRelation.scala

Lines changed: 7 additions & 4 deletions

@@ -46,7 +46,8 @@ case class CsvRelation protected[spark] (
     treatEmptyValuesAsNulls: Boolean,
     userSchema: StructType = null,
     inferCsvSchema: Boolean,
-    codec: String = null)(@transient val sqlContext: SQLContext)
+    codec: String = null,
+    nullValue: String = "")(@transient val sqlContext: SQLContext)
   extends BaseRelation with TableScan with PrunedScan with InsertableRelation {

   /**
@@ -116,7 +117,7 @@ case class CsvRelation protected[spark] (
       while (index < schemaFields.length) {
         val field = schemaFields(index)
         rowArray(index) = TypeCast.castTo(tokens(index), field.dataType, field.nullable,
-          treatEmptyValuesAsNulls)
+          treatEmptyValuesAsNulls, nullValue)
         index = index + 1
       }
       Some(Row.fromSeq(rowArray))
@@ -189,7 +190,9 @@ case class CsvRelation protected[spark] (
           indexSafeTokens(index),
           field.dataType,
           field.nullable,
-          treatEmptyValuesAsNulls)
+          treatEmptyValuesAsNulls,
+          nullValue
+        )
         subIndex = subIndex + 1
       }
       Some(Row.fromSeq(rowArray.take(requiredSize)))
@@ -235,7 +238,7 @@ case class CsvRelation protected[spark] (
       firstRow.zipWithIndex.map { case (value, index) => s"C$index"}
     }
     if (this.inferCsvSchema) {
-      InferSchema(tokenRdd(header), header)
+      InferSchema(tokenRdd(header), header, nullValue)
     } else {
       // By default fields are assumed to be StringType
       val schemaFields = header.map { fieldName =>

src/main/scala/com/databricks/spark/csv/DefaultSource.scala

Lines changed: 3 additions & 1 deletion

@@ -136,6 +136,7 @@ class DefaultSource
     } else {
       throw new Exception("Infer schema flag can be true or false")
     }
+    val nullValue = parameters.getOrElse("nullValue", "")

     val codec = parameters.getOrElse("codec", null)

@@ -154,7 +155,8 @@ class DefaultSource
       treatEmptyValuesAsNullsFlag,
       schema,
       inferSchemaFlag,
-      codec)(sqlContext)
+      codec,
+      nullValue)(sqlContext)
   }

   override def createRelation(
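Because DefaultSource reads `nullValue` out of the parameters map, the same option should also be reachable from the SQL API; a hedged sketch assuming the `OPTIONS` syntax the package documents elsewhere (table name and path are illustrative):

sqlContext.sql(
  """CREATE TEMPORARY TABLE cars
    |USING com.databricks.spark.csv
    |OPTIONS (path "cars.csv", header "true", nullValue "\N")""".stripMargin)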

src/main/scala/com/databricks/spark/csv/util/InferSchema.scala

Lines changed: 14 additions & 6 deletions

@@ -31,10 +31,15 @@ private[csv] object InferSchema {
    *  2. Merge row types to find common type
    *  3. Replace any null types with string type
    */
-  def apply(tokenRdd: RDD[Array[String]], header: Array[String]): StructType = {
+  def apply(
+      tokenRdd: RDD[Array[String]],
+      header: Array[String],
+      nullValue: String = ""): StructType = {

     val startType: Array[DataType] = Array.fill[DataType](header.length)(NullType)
-    val rootTypes: Array[DataType] = tokenRdd.aggregate(startType)(inferRowType, mergeRowTypes)
+    val rootTypes: Array[DataType] = tokenRdd.aggregate(startType)(
+      inferRowType(nullValue),
+      mergeRowTypes)

     val structFields = header.zip(rootTypes).map { case (thisHeader, rootType) =>
       StructField(thisHeader, rootType, nullable = true)
@@ -43,10 +48,11 @@ private[csv] object InferSchema {
     StructType(structFields)
   }

-  private def inferRowType(rowSoFar: Array[DataType], next: Array[String]): Array[DataType] = {
+  private def inferRowType(nullValue: String)
+      (rowSoFar: Array[DataType], next: Array[String]): Array[DataType] = {
     var i = 0
     while (i < math.min(rowSoFar.length, next.length)) { // May have columns on right missing.
-      rowSoFar(i) = inferField(rowSoFar(i), next(i))
+      rowSoFar(i) = inferField(rowSoFar(i), next(i), nullValue)
       i+=1
     }
     rowSoFar
@@ -68,8 +74,10 @@ private[csv] object InferSchema {
    * Infer type of string field. Given known type Double, and a string "1", there is no
    * point checking if it is an Int, as the final type must be Double or higher.
    */
-  private[csv] def inferField(typeSoFar: DataType, field: String): DataType = {
-    if (field == null || field.isEmpty) {
+  private[csv] def inferField(typeSoFar: DataType,
+      field: String,
+      nullValue: String = ""): DataType = {
+    if (field == null || field.isEmpty || field == nullValue) {
       typeSoFar
     } else {
       typeSoFar match {
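The effect on inference: a field equal to the nullValue token leaves the accumulated column type unchanged instead of widening the column to StringType. A minimal sketch (inferField is private[csv], so this is callable from the com.databricks.spark.csv.util package, as the test suite below does):

import org.apache.spark.sql.types.{IntegerType, StringType}

// "\N" no longer forces an otherwise-integer column over to StringType.
assert(InferSchema.inferField(IntegerType, "\\N", "\\N") == IntegerType)
// Without a matching nullValue, the same token still infers as StringType.
assert(InferSchema.inferField(IntegerType, "\\N") == StringType)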

src/main/scala/com/databricks/spark/csv/util/TypeCast.scala

Lines changed: 9 additions & 2 deletions

@@ -43,8 +43,15 @@ object TypeCast {
       datum: String,
       castType: DataType,
       nullable: Boolean = true,
-      treatEmptyValuesAsNulls: Boolean = false): Any = {
-    if (datum == "" && nullable && (!castType.isInstanceOf[StringType] || treatEmptyValuesAsNulls)){
+      treatEmptyValuesAsNulls: Boolean = false,
+      nullValue: String = ""): Any = {
+    // if nullValue is not an empty string, don't require treatEmptyValuesAsNulls
+    // to be set to true
+    val nullValueIsNotEmpty = nullValue != ""
+    if (datum == nullValue &&
+      nullable &&
+      (!castType.isInstanceOf[StringType] || treatEmptyValuesAsNulls || nullValueIsNotEmpty)
+    ){
       null
     } else {
       castType match {
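At the casting layer the change means a datum equal to nullValue becomes null for any nullable column, and, because the token is non-empty, StringType columns no longer need treatEmptyValuesAsNulls. A small sketch under those assumptions (the values are illustrative):

import com.databricks.spark.csv.util.TypeCast
import org.apache.spark.sql.types.{IntegerType, StringType}

// Non-string column: "\N" casts to null instead of failing to parse as an Int.
assert(TypeCast.castTo("\\N", IntegerType, nullable = true,
  treatEmptyValuesAsNulls = false, nullValue = "\\N") == null)
// String column: the non-empty token yields null even though
// treatEmptyValuesAsNulls is false.
assert(TypeCast.castTo("\\N", StringType, nullable = true,
  treatEmptyValuesAsNulls = false, nullValue = "\\N") == null)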
src/test/resources/null_null_numbers.csv (new file)

Lines changed: 4 additions & 0 deletions

@@ -0,0 +1,4 @@
+name,age
+alice,35
+bob,null
+null,24

src/test/resources/null_slashn_numbers.csv (new file)

Lines changed: 4 additions & 0 deletions

@@ -0,0 +1,4 @@
+name,age
+alice,35
+bob,\N
+\N,24

src/test/scala/com/databricks/spark/csv/CsvSuite.scala

Lines changed: 32 additions & 0 deletions

@@ -33,6 +33,8 @@ abstract class AbstractCsvSuite extends FunSuite with BeforeAndAfterAll {
   val carsAltFile = "src/test/resources/cars-alternative.csv"
   val carsUnbalancedQuotesFile = "src/test/resources/cars-unbalanced-quotes.csv"
   val nullNumbersFile = "src/test/resources/null-numbers.csv"
+  val nullNullNumbersFile = "src/test/resources/null_null_numbers.csv"
+  val nullSlashNNumbersFile = "src/test/resources/null_slashn_numbers.csv"
   val emptyFile = "src/test/resources/empty.csv"
   val ageFile = "src/test/resources/ages.csv"
   val escapeFile = "src/test/resources/escape.csv"
@@ -572,6 +574,36 @@ abstract class AbstractCsvSuite extends FunSuite with BeforeAndAfterAll {
     assert(results(2).toSeq === Seq("", 24))
   }

+  test("DSL test nullable fields with user defined null value of \"null\"") {
+    val results = new CsvParser()
+      .withSchema(StructType(List(StructField("name", StringType, false),
+        StructField("age", IntegerType, true))))
+      .withUseHeader(true)
+      .withParserLib(parserLib)
+      .withNullValue("null")
+      .csvFile(sqlContext, nullNullNumbersFile)
+      .collect()
+
+    assert(results.head.toSeq === Seq("alice", 35))
+    assert(results(1).toSeq === Seq("bob", null))
+    assert(results(2).toSeq === Seq("null", 24))
+  }
+
+  test("DSL test nullable fields with user defined null value of \"\\N\"") {
+    val results = new CsvParser()
+      .withSchema(StructType(List(StructField("name", StringType, false),
+        StructField("age", IntegerType, true))))
+      .withUseHeader(true)
+      .withParserLib(parserLib)
+      .withNullValue("\\N")
+      .csvFile(sqlContext, nullSlashNNumbersFile)
+      .collect()
+
+    assert(results.head.toSeq === Seq("alice", 35))
+    assert(results(1).toSeq === Seq("bob", null))
+    assert(results(2).toSeq === Seq("\\N", 24))
+  }
+
   test("Commented lines in CSV data") {
     val results: Array[Row] = new CsvParser()
       .withDelimiter(',')

src/test/scala/com/databricks/spark/csv/util/InferSchemaSuite.scala

Lines changed: 9 additions & 0 deletions

@@ -15,6 +15,15 @@ class InferSchemaSuite extends FunSuite {
     assert(InferSchema.inferField(NullType, "2015-08-20 15:57:00") == TimestampType)
   }

+  test("Null fields are handled properly when a nullValue is specified") {
+    assert(InferSchema.inferField(NullType, "null", "null") == NullType)
+    assert(InferSchema.inferField(StringType, "null", "null") == StringType)
+    assert(InferSchema.inferField(LongType, "null", "null") == LongType)
+    assert(InferSchema.inferField(IntegerType, "\\N", "\\N") == IntegerType)
+    assert(InferSchema.inferField(DoubleType, "\\N", "\\N") == DoubleType)
+    assert(InferSchema.inferField(TimestampType, "\\N", "\\N") == TimestampType)
+  }
+
   test("String fields types are inferred correctly from other types") {
     assert(InferSchema.inferField(LongType, "1.0") == DoubleType)
     assert(InferSchema.inferField(LongType, "test") == StringType)
