Skip to content

Commit 05a5c5e

Browse files
committed
WIP: Transducer legacy deserialization
1 parent 0139bd2 commit 05a5c5e

File tree

3 files changed

+65
-5
lines changed

3 files changed

+65
-5
lines changed

src/main/scala/com/johnsnowlabs/nlp/serialization/Feature.scala

Lines changed: 33 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -338,7 +338,17 @@ class ArrayFeature[TValue: ClassTag](model: HasFeatures, override val name: Stri
338338
val fs: FileSystem = FileSystem.get(uri, spark.sparkContext.hadoopConfiguration)
339339
val dataPath = getFieldPath(path, field)
340340
if (fs.exists(dataPath)) {
341-
Some(spark.sparkContext.objectFile[TValue](dataPath.toString).collect())
341+
try {
342+
Some(spark.sparkContext.objectFile[TValue](dataPath.toString).collect())
343+
} catch {
344+
case e: org.apache.spark.SparkException
345+
if e.getCause.isInstanceOf[java.io.InvalidClassException] =>
346+
println(
347+
"WARNING: Detected InvalidClassException during deserialization, attempting to load as legacy object.")
348+
Some(deserializeLegacyObject[TValue](spark, dataPath.toString).collect())
349+
case e: Exception =>
350+
throw e
351+
}
342352
} else {
343353
None
344354
}
@@ -391,7 +401,17 @@ class SetFeature[TValue: ClassTag](model: HasFeatures, override val name: String
391401
val fs: FileSystem = FileSystem.get(uri, spark.sparkContext.hadoopConfiguration)
392402
val dataPath = getFieldPath(path, field)
393403
if (fs.exists(dataPath)) {
394-
Some(spark.sparkContext.objectFile[TValue](dataPath.toString).collect().toSet)
404+
try {
405+
Some(spark.sparkContext.objectFile[TValue](dataPath.toString).collect.toSet)
406+
} catch {
407+
case e: org.apache.spark.SparkException
408+
if e.getCause.isInstanceOf[java.io.InvalidClassException] =>
409+
println(
410+
"WARNING: Detected InvalidClassException during deserialization, attempting to load as legacy object.")
411+
Some(deserializeLegacyObject[TValue](spark, dataPath.toString).collect.toSet)
412+
case e: Exception =>
413+
throw e
414+
}
395415
} else {
396416
None
397417
}
@@ -444,8 +464,17 @@ class TransducerFeature(model: HasFeatures, override val name: String)
444464
val fs: FileSystem = FileSystem.get(uri, spark.sparkContext.hadoopConfiguration)
445465
val dataPath = getFieldPath(path, field)
446466
if (fs.exists(dataPath)) {
447-
val sc = spark.sparkContext.objectFile[VocabParser](dataPath.toString).collect().head
448-
Some(sc)
467+
try {
468+
Some(spark.sparkContext.objectFile[VocabParser](dataPath.toString).collect().head)
469+
} catch {
470+
case e: org.apache.spark.SparkException
471+
if e.getCause.isInstanceOf[java.io.InvalidClassException] =>
472+
println(
473+
"WARNING: Detected InvalidClassException during deserialization, attempting to load as legacy object.")
474+
Some(deserializeLegacyObject[VocabParser](spark, dataPath.toString).collect().head)
475+
case e: Exception =>
476+
throw e
477+
}
449478
} else {
450479
None
451480
}
src/main/scala/com/johnsnowlabs/nlp/serialization/LegacyHashSetSerializationProxy.scala

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
package com.johnsnowlabs.nlp.serialization
2+
3+
import scala.collection.immutable.HashSet
4+
5+
/** Serialization proxy for an immutable [[scala.collection.immutable.HashSet]], copied from
  * Scala 2.12.
  *
  * Scala 2.13 changed the serialized form of the immutable collections, so object streams
  * written under 2.12 contain a `HashSet$SerializationProxy` record that 2.13+ cannot resolve.
  * This class reproduces that proxy's wire format (element count followed by each element) so
  * that `LegacyObjectInputStream` can map the 2.12 class descriptor onto it and read legacy
  * streams.
  *
  * @param orig
  *   the set being proxied; marked `@transient` because its elements are written explicitly by
  *   `writeObject` rather than through default field serialization
  */
@SerialVersionUID(212L)
private class LegacyHashSetSerializationProxy(@transient private var orig: HashSet[Any])
    extends Serializable {

  /** Writes the element count, then each element — the 2.12 proxy wire format. */
  private def writeObject(out: java.io.ObjectOutputStream): Unit = {
    val count = orig.size
    out.writeInt(count)
    for (elem <- orig) {
      out.writeObject(elem)
    }
  }

  /** Reads the element count, then rebuilds the set one element at a time. */
  private def readObject(in: java.io.ObjectInputStream): Unit = {
    orig = HashSet.empty
    val count = in.readInt()
    var read = 0
    while (read < count) {
      orig = orig + in.readObject()
      read += 1
    }
  }

  /** Called by Java serialization after `readObject`: substitutes the rebuilt set for the proxy. */
  private def readResolve(): AnyRef = orig
}

src/main/scala/com/johnsnowlabs/nlp/serialization/LegacyObjectInputStream.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,9 @@ class LegacyObjectInputStream(
4545
resultClassDescriptor.getName match {
4646
case "scala.collection.immutable.HashMap$SerializationProxy" =>
4747
ObjectStreamClass.lookup(classOf[LegacyHashMapSerializationProxy])
48+
case "scala.collection.immutable.HashSet$SerializationProxy" =>
49+
ObjectStreamClass.lookup(classOf[LegacyHashSetSerializationProxy])
4850
case "scala.collection.immutable.List$SerializationProxy" =>
49-
/* println("DHA: Using LegacyListSerializationProxy")*/
5051
ObjectStreamClass.lookup(classOf[LegacyListSerializationProxy])
5152
case "scala.collection.immutable.ListSerializeEnd$" =>
5253
println("DHA: Using LegacyListSerializationEnd")

0 commit comments

Comments (0)