@@ -6,47 +6,10 @@
package com.vesoft.nebula.algorithm

import com.vesoft.nebula.algorithm.config.Configs.Argument
-import com.vesoft.nebula.algorithm.config.{
-  AlgoConfig,
-  BetweennessConfig,
-  BfsConfig,
-  CcConfig,
-  CoefficientConfig,
-  Configs,
-  DfsConfig,
-  HanpConfig,
-  JaccardConfig,
-  KCoreConfig,
-  LPAConfig,
-  LouvainConfig,
-  Node2vecConfig,
-  PRConfig,
-  ShortestPathConfig,
-  SparkConfig,
-  DegreeStaticConfig
-}
-import com.vesoft.nebula.algorithm.lib.{
-  BetweennessCentralityAlgo,
-  BfsAlgo,
-  ClosenessAlgo,
-  ClusteringCoefficientAlgo,
-  ConnectedComponentsAlgo,
-  DegreeStaticAlgo,
-  DfsAlgo,
-  GraphTriangleCountAlgo,
-  HanpAlgo,
-  JaccardAlgo,
-  KCoreAlgo,
-  LabelPropagationAlgo,
-  LouvainAlgo,
-  Node2vecAlgo,
-  PageRankAlgo,
-  ShortestPathAlgo,
-  StronglyConnectedComponentsAlgo,
-  TriangleCountAlgo
-}
-import com.vesoft.nebula.algorithm.reader.{CsvReader, JsonReader, NebulaReader}
-import com.vesoft.nebula.algorithm.writer.{CsvWriter, NebulaWriter, TextWriter}
+import com.vesoft.nebula.algorithm.config._
+import com.vesoft.nebula.algorithm.lib._
+import com.vesoft.nebula.algorithm.reader.{CsvReader, DataReader, JsonReader, NebulaReader}
+import com.vesoft.nebula.algorithm.writer.{AlgoWriter, CsvWriter, NebulaWriter, TextWriter}
import org.apache.commons.math3.ode.UnknownParameterException
import org.apache.log4j.Logger
import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
@@ -114,26 +77,8 @@ object Main {
  private[this] def createDataSource(spark: SparkSession,
                                     configs: Configs,
                                     partitionNum: String): DataFrame = {
-    val dataSource = configs.dataSourceSinkEntry.source
-    val dataSet: Dataset[Row] = dataSource.toLowerCase match {
-      case "nebula" => {
-        val reader = new NebulaReader(spark, configs, partitionNum)
-        reader.read()
-      }
-      case "nebula-ngql" => {
-        val reader = new NebulaReader(spark, configs, partitionNum)
-        reader.readNgql()
-      }
-      case "csv" => {
-        val reader = new CsvReader(spark, configs, partitionNum)
-        reader.read()
-      }
-      case "json" => {
-        val reader = new JsonReader(spark, configs, partitionNum)
-        reader.read()
-      }
-    }
-    dataSet
+    val dataSource = DataReader.make(configs)
+    dataSource.read(spark, configs, partitionNum)
  }
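Note: `DataReader.make` above replaces the deleted inline source match, but the factory itself is defined outside this diff. The following is a minimal sketch of the shape this call site relies on; only the `make(configs)` and `read(spark, configs, partitionNum)` signatures are visible in the commit, while the dispatch keys (taken from the deleted match) and the no-arg reader constructors are assumptions.

    import com.vesoft.nebula.algorithm.config.Configs
    import org.apache.spark.sql.{DataFrame, SparkSession}

    // Hypothetical sketch; not the code from this commit.
    trait DataReader {
      def read(spark: SparkSession, configs: Configs, partitionNum: String): DataFrame
    }

    object DataReader {
      // Choose the reader once from the configured source. The deleted match
      // accepted "nebula", "nebula-ngql", "csv" and "json"; whether the nebula
      // reader now picks between scan and nGQL reads internally is assumed.
      def make(configs: Configs): DataReader =
        configs.dataSourceSinkEntry.source.toLowerCase match {
          case "nebula" | "nebula-ngql" => new NebulaReader
          case "csv"                    => new CsvReader
          case "json"                   => new JsonReader
          case other =>
            throw new UnsupportedOperationException(s"unsupported data source: $other")
        }
    }

    // Stub readers so the sketch stands alone; the real classes differ.
    final class NebulaReader extends DataReader {
      def read(spark: SparkSession, configs: Configs, partitionNum: String): DataFrame = ???
    }
    final class CsvReader extends DataReader {
      def read(spark: SparkSession, configs: Configs, partitionNum: String): DataFrame = ???
    }
    final class JsonReader extends DataReader {
      def read(spark: SparkSession, configs: Configs, partitionNum: String): DataFrame = ???
    }

Whatever the real implementation looks like, the effect at this call site is the same: createDataSource no longer branches on the source type.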

  /**
@@ -149,99 +94,63 @@ object Main {
                                 configs: Configs,
                                 dataSet: DataFrame): DataFrame = {
    val hasWeight = configs.dataSourceSinkEntry.hasWeight
-    val algoResult = {
-      algoName.toLowerCase match {
-        case "pagerank" => {
-          val pageRankConfig = PRConfig.getPRConfig(configs)
-          PageRankAlgo(spark, dataSet, pageRankConfig, hasWeight)
-        }
-        case "louvain" => {
-          val louvainConfig = LouvainConfig.getLouvainConfig(configs)
-          LouvainAlgo(spark, dataSet, louvainConfig, hasWeight)
-        }
-        case "connectedcomponent" => {
-          val ccConfig = CcConfig.getCcConfig(configs)
-          ConnectedComponentsAlgo(spark, dataSet, ccConfig, hasWeight)
-        }
-        case "labelpropagation" => {
-          val lpaConfig = LPAConfig.getLPAConfig(configs)
-          LabelPropagationAlgo(spark, dataSet, lpaConfig, hasWeight)
-        }
-        case "shortestpaths" => {
-          val spConfig = ShortestPathConfig.getShortestPathConfig(configs)
-          ShortestPathAlgo(spark, dataSet, spConfig, hasWeight)
-        }
-        case "degreestatic" => {
-          val dsConfig = DegreeStaticConfig.getDegreeStaticConfig(configs)
-          DegreeStaticAlgo(spark, dataSet, dsConfig)
-        }
-        case "kcore" => {
-          val kCoreConfig = KCoreConfig.getKCoreConfig(configs)
-          KCoreAlgo(spark, dataSet, kCoreConfig)
-        }
-        case "stronglyconnectedcomponent" => {
-          val ccConfig = CcConfig.getCcConfig(configs)
-          StronglyConnectedComponentsAlgo(spark, dataSet, ccConfig, hasWeight)
-        }
-        case "betweenness" => {
-          val betweennessConfig = BetweennessConfig.getBetweennessConfig(configs)
-          BetweennessCentralityAlgo(spark, dataSet, betweennessConfig, hasWeight)
-        }
-        case "trianglecount" => {
-          TriangleCountAlgo(spark, dataSet)
-        }
-        case "graphtrianglecount" => {
-          GraphTriangleCountAlgo(spark, dataSet)
-        }
-        case "clusteringcoefficient" => {
-          val coefficientConfig = CoefficientConfig.getCoefficientConfig(configs)
-          ClusteringCoefficientAlgo(spark, dataSet, coefficientConfig)
-        }
-        case "closeness" => {
-          ClosenessAlgo(spark, dataSet, hasWeight)
-        }
-        case "hanp" => {
-          val hanpConfig = HanpConfig.getHanpConfig(configs)
-          HanpAlgo(spark, dataSet, hanpConfig, hasWeight)
-        }
-        case "node2vec" => {
-          val node2vecConfig = Node2vecConfig.getNode2vecConfig(configs)
-          Node2vecAlgo(spark, dataSet, node2vecConfig, hasWeight)
-        }
-        case "bfs" => {
-          val bfsConfig = BfsConfig.getBfsConfig(configs)
-          BfsAlgo(spark, dataSet, bfsConfig)
-        }
-        case "dfs" => {
-          val dfsConfig = DfsConfig.getDfsConfig(configs)
-          DfsAlgo(spark, dataSet, dfsConfig)
-        }
-        case "jaccard" => {
-          val jaccardConfig = JaccardConfig.getJaccardConfig(configs)
-          JaccardAlgo(spark, dataSet, jaccardConfig)
-        }
-        case _ => throw new UnknownParameterException("unknown executeAlgo name.")
-      }
+    AlgorithmType.mapping.getOrElse(algoName.toLowerCase, throw new UnknownParameterException("unknown executeAlgo name.")) match {
+      case AlgorithmType.Bfs =>
+        val bfsConfig = BfsConfig.getBfsConfig(configs)
+        BfsAlgo(spark, dataSet, bfsConfig)
+      case AlgorithmType.Closeness =>
+        ClosenessAlgo(spark, dataSet, hasWeight)
+      case AlgorithmType.ClusteringCoefficient =>
+        val coefficientConfig = CoefficientConfig.getCoefficientConfig(configs)
+        ClusteringCoefficientAlgo(spark, dataSet, coefficientConfig)
+      case AlgorithmType.ConnectedComponents =>
+        val ccConfig = CcConfig.getCcConfig(configs)
+        ConnectedComponentsAlgo(spark, dataSet, ccConfig, hasWeight)
+      case AlgorithmType.DegreeStatic =>
+        val dsConfig = DegreeStaticConfig.getDegreeStaticConfig(configs)
+        DegreeStaticAlgo(spark, dataSet, dsConfig)
+      case AlgorithmType.Dfs =>
+        val dfsConfig = DfsConfig.getDfsConfig(configs)
+        DfsAlgo(spark, dataSet, dfsConfig)
+      case AlgorithmType.GraphTriangleCount =>
+        GraphTriangleCountAlgo(spark, dataSet)
+      case AlgorithmType.Hanp =>
+        val hanpConfig = HanpConfig.getHanpConfig(configs)
+        HanpAlgo(spark, dataSet, hanpConfig, hasWeight)
+      case AlgorithmType.Jaccard =>
+        val jaccardConfig = JaccardConfig.getJaccardConfig(configs)
+        JaccardAlgo(spark, dataSet, jaccardConfig)
+      case AlgorithmType.KCore =>
+        val kCoreConfig = KCoreConfig.getKCoreConfig(configs)
+        KCoreAlgo(spark, dataSet, kCoreConfig)
+      case AlgorithmType.LabelPropagation =>
+        val lpaConfig = LPAConfig.getLPAConfig(configs)
+        LabelPropagationAlgo(spark, dataSet, lpaConfig, hasWeight)
+      case AlgorithmType.Louvain =>
+        val louvainConfig = LouvainConfig.getLouvainConfig(configs)
+        LouvainAlgo(spark, dataSet, louvainConfig, hasWeight)
+      case AlgorithmType.Node2vec =>
+        val node2vecConfig = Node2vecConfig.getNode2vecConfig(configs)
+        Node2vecAlgo(spark, dataSet, node2vecConfig, hasWeight)
+      case AlgorithmType.PageRank =>
+        val pageRankConfig = PRConfig.getPRConfig(configs)
+        PageRankAlgo(spark, dataSet, pageRankConfig, hasWeight)
+      case AlgorithmType.ShortestPath =>
+        val spConfig = ShortestPathConfig.getShortestPathConfig(configs)
+        ShortestPathAlgo(spark, dataSet, spConfig, hasWeight)
+      case AlgorithmType.StronglyConnectedComponents =>
+        val ccConfig = CcConfig.getCcConfig(configs)
+        StronglyConnectedComponentsAlgo(spark, dataSet, ccConfig, hasWeight)
+      case AlgorithmType.TriangleCount =>
+        TriangleCountAlgo(spark, dataSet)
+      case AlgorithmType.BetweennessCentrality =>
+        val betweennessConfig = BetweennessConfig.getBetweennessConfig(configs)
+        BetweennessCentralityAlgo(spark, dataSet, betweennessConfig, hasWeight)
    }
-    algoResult
  }
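Note: `AlgorithmType` and its `mapping` are likewise introduced elsewhere in this commit. A plausible sketch, assuming `mapping` is a plain `Map[String, AlgorithmType]`; the case-object names come from the new match above, the keys from the deleted string match, and only a few of the eighteen variants are spelled out here.

    // Hypothetical sketch; not the code from this commit.
    sealed trait AlgorithmType

    object AlgorithmType {
      case object PageRank            extends AlgorithmType
      case object Louvain             extends AlgorithmType
      case object ConnectedComponents extends AlgorithmType
      case object Bfs                 extends AlgorithmType
      case object TriangleCount       extends AlgorithmType
      // ...one case object per algorithm, eighteen in total.

      // Keys are the lower-cased names the deleted match accepted, so
      // existing configs keep resolving to the same algorithms.
      val mapping: Map[String, AlgorithmType] = Map(
        "pagerank"           -> PageRank,
        "louvain"            -> Louvain,
        "connectedcomponent" -> ConnectedComponents,
        "bfs"                -> Bfs,
        "trianglecount"      -> TriangleCount
      )
    }

Because `Map.getOrElse` evaluates its default argument by name, the `throw` at the lookup site above fires only when the name is actually missing from the map.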

  private[this] def saveAlgoResult(algoResult: DataFrame, configs: Configs): Unit = {
-    val dataSink = configs.dataSourceSinkEntry.sink
-    dataSink.toLowerCase match {
-      case "nebula" => {
-        val writer = new NebulaWriter(algoResult, configs)
-        writer.write()
-      }
-      case "csv" => {
-        val writer = new CsvWriter(algoResult, configs)
-        writer.write()
-      }
-      case "text" => {
-        val writer = new TextWriter(algoResult, configs)
-        writer.write()
-      }
-      case _ => throw new UnsupportedOperationException("unsupported data sink")
-    }
+    val writer = AlgoWriter.make(configs)
+    writer.write(algoResult, configs)
  }
}
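Note: `AlgoWriter.make` mirrors the reader factory on the sink side. A sketch under the same caveats; `make(configs)` and `write(algoResult, configs)` are the only signatures this commit shows, the dispatch keys come from the deleted sink match, and the no-arg writer constructors are assumptions.

    import com.vesoft.nebula.algorithm.config.Configs
    import org.apache.spark.sql.DataFrame

    // Hypothetical sketch; not the code from this commit.
    trait AlgoWriter {
      def write(data: DataFrame, configs: Configs): Unit
    }

    object AlgoWriter {
      def make(configs: Configs): AlgoWriter =
        configs.dataSourceSinkEntry.sink.toLowerCase match {
          case "nebula" => new NebulaWriter
          case "csv"    => new CsvWriter
          case "text"   => new TextWriter
          case other =>
            throw new UnsupportedOperationException(s"unsupported data sink: $other")
        }
    }

    // Stub writers so the sketch stands alone; the real classes differ.
    final class NebulaWriter extends AlgoWriter {
      def write(data: DataFrame, configs: Configs): Unit = ???
    }
    final class CsvWriter extends AlgoWriter {
      def write(data: DataFrame, configs: Configs): Unit = ???
    }
    final class TextWriter extends AlgoWriter {
      def write(data: DataFrame, configs: Configs): Unit = ???
    }

With both factories in place, Main reduces to "make a reader, run the algorithm, make a writer", and every source/sink branch lives behind an interface instead of in the driver.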