@@ -14,7 +14,7 @@ import scala.collection.mutable.ListBuffer
 
 abstract class DataReader {
   val tpe: ReaderType
-  def read(spark: SparkSession, configs: Configs, partitionNum: String): DataFrame
+  def read(spark: SparkSession, configs: Configs, partitionNum: Int): DataFrame
 }
 object DataReader {
   def make(configs: Configs): DataReader = {
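
With this change, every `read` implementation takes the partition count as an `Int` instead of a `String`. A minimal caller sketch (the call site below is hypothetical, not part of this diff; it only illustrates that the conversion now happens before `read` is invoked):

    val reader = DataReader.make(configs)
    // partitionNum is now an Int at the API boundary; no per-reader .toInt
    val df: DataFrame = reader.read(spark, configs, partitionNum = 10)
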
@@ -32,12 +32,11 @@ object DataReader {
 
 class NebulaReader extends DataReader {
   override val tpe: ReaderType = ReaderType.nebula
-  override def read(spark: SparkSession, configs: Configs, partitionNum: String): DataFrame = {
+  override def read(spark: SparkSession, configs: Configs, partitionNum: Int): DataFrame = {
     val metaAddress = configs.nebulaConfig.readConfigEntry.address
     val space = configs.nebulaConfig.readConfigEntry.space
     val labels = configs.nebulaConfig.readConfigEntry.labels
     val weights = configs.nebulaConfig.readConfigEntry.weightCols
-    val partition = partitionNum.toInt
 
     val config =
       NebulaConnectionConfig
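
Removing `val partition = partitionNum.toInt` also removes a per-reader `NumberFormatException` hazard: the String-to-Int parse now has to happen exactly once, before `read` is called. A hedged sketch of defensive parsing at that single point (the helper name and fallback are hypothetical, not part of this diff):

    import scala.util.Try
    // Hypothetical helper: parse the configured partition count once.
    // Falls back to 0, which the readers treat as "do not repartition".
    def parsePartitionNum(raw: String): Int =
      Try(raw.trim.toInt).getOrElse(0)
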
@@ -60,7 +59,7 @@ class NebulaReader extends DataReader {
         .withLabel(labels(i))
         .withNoColumn(noColumn)
         .withReturnCols(returnCols.toList)
-        .withPartitionNum(partition)
+        .withPartitionNum(partitionNum)
         .build()
       if (dataset == null) {
         dataset = spark.read.nebula(config, nebulaReadEdgeConfig).loadEdgesToDF()
@@ -85,13 +84,12 @@ final class NebulaNgqlReader extends NebulaReader {
 
   override val tpe: ReaderType = ReaderType.nebulaNgql
 
-  override def read(spark: SparkSession, configs: Configs, partitionNum: String): DataFrame = {
+  override def read(spark: SparkSession, configs: Configs, partitionNum: Int): DataFrame = {
     val metaAddress = configs.nebulaConfig.readConfigEntry.address
     val graphAddress = configs.nebulaConfig.readConfigEntry.graphAddress
     val space = configs.nebulaConfig.readConfigEntry.space
     val labels = configs.nebulaConfig.readConfigEntry.labels
     val weights = configs.nebulaConfig.readConfigEntry.weightCols
-    val partition = partitionNum.toInt
     val ngql = configs.nebulaConfig.readConfigEntry.ngql
 
     val config =
@@ -112,7 +110,7 @@ final class NebulaNgqlReader extends NebulaReader {
         .builder()
         .withSpace(space)
         .withLabel(labels(i))
-        .withPartitionNum(partition)
+        .withPartitionNum(partitionNum)
         .withNgql(ngql)
         .build()
       if (dataset == null) {
@@ -137,13 +135,11 @@ final class NebulaNgqlReader extends NebulaReader {
 
 final class CsvReader extends DataReader {
   override val tpe: ReaderType = ReaderType.csv
-  override def read(spark: SparkSession, configs: Configs, partitionNum: String): DataFrame = {
+  override def read(spark: SparkSession, configs: Configs, partitionNum: Int): DataFrame = {
     val delimiter = configs.localConfigEntry.delimiter
     val header = configs.localConfigEntry.header
     val localPath = configs.localConfigEntry.filePath
 
-    val partition = partitionNum.toInt
-
     val data =
       spark.read
         .option("header", header)
@@ -157,18 +153,17 @@ final class CsvReader extends DataReader {
     } else {
       data.select(src, dst)
     }
-    if (partition != 0) {
-      data.repartition(partition)
+    if (partitionNum != 0) {
+      data.repartition(partitionNum)
     }
     data
   }
 }
 final class JsonReader extends DataReader {
   override val tpe: ReaderType = ReaderType.json
-  override def read(spark: SparkSession, configs: Configs, partitionNum: String): DataFrame = {
+  override def read(spark: SparkSession, configs: Configs, partitionNum: Int): DataFrame = {
     val localPath = configs.localConfigEntry.filePath
     val data = spark.read.json(localPath)
-    val partition = partitionNum.toInt
 
     val weight = configs.localConfigEntry.weight
     val src = configs.localConfigEntry.srcId
@@ -178,8 +173,8 @@ final class JsonReader extends DataReader {
     } else {
       data.select(src, dst)
     }
-    if (partition != 0) {
-      data.repartition(partition)
+    if (partitionNum != 0) {
+      data.repartition(partitionNum)
     }
     data
   }
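
One pre-existing quirk the rename leaves intact: Spark's `DataFrame.repartition` returns a new `DataFrame` rather than mutating the receiver, so `data.repartition(partitionNum)` inside the `if` discards its result and the unpartitioned `data` is returned. A sketch of the expression that would actually apply the partitioning (an observation about the Spark API, not a change made by this diff):

    // repartition returns a new DataFrame; bind and return the result.
    if (partitionNum != 0) data.repartition(partitionNum) else data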