Skip to content

Commit 579f624

Browse files
optimize and more like Scala (#89)
* optimize
* add license
* fix typo

Co-authored-by: Anqi <anqi.wang@vesoft.com>
1 parent 3fc5955 commit 579f624

File tree

9 files changed

+259
-180
lines changed

9 files changed

+259
-180
lines changed

example/src/main/scala/com/vesoft/nebula/algorithm/AlgoPerformanceTest.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ object AlgoPerformanceTest {
4444
.builder()
4545
.withMetaAddress("127.0.0.0.1:9559")
4646
.withTimeout(6000)
47-
.withConenctionRetry(2)
47+
.withConnectionRetry(2)
4848
.build()
4949
val nebulaReadEdgeConfig: ReadNebulaConfig = ReadNebulaConfig
5050
.builder()

example/src/main/scala/com/vesoft/nebula/algorithm/DeepQueryTest.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ object DeepQueryTest {
3939
.builder()
4040
.withMetaAddress("192.168.15.5:9559")
4141
.withTimeout(6000)
42-
.withConenctionRetry(2)
42+
.withConnectionRetry(2)
4343
.build()
4444
val nebulaReadEdgeConfig: ReadNebulaConfig = ReadNebulaConfig
4545
.builder()

example/src/main/scala/com/vesoft/nebula/algorithm/ReadData.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ object ReadData {
6767
.builder()
6868
.withMetaAddress("127.0.0.1:9559")
6969
.withTimeout(6000)
70-
.withConenctionRetry(2)
70+
.withConnectionRetry(2)
7171
.build()
7272
val nebulaReadEdgeConfig: ReadNebulaConfig = ReadNebulaConfig
7373
.builder()

nebula-algorithm/src/main/scala/com/vesoft/nebula/algorithm/Main.scala

Lines changed: 60 additions & 151 deletions
Original file line numberDiff line numberDiff line change
@@ -6,47 +6,10 @@
66
package com.vesoft.nebula.algorithm
77

88
import com.vesoft.nebula.algorithm.config.Configs.Argument
9-
import com.vesoft.nebula.algorithm.config.{
10-
AlgoConfig,
11-
BetweennessConfig,
12-
BfsConfig,
13-
CcConfig,
14-
CoefficientConfig,
15-
Configs,
16-
DfsConfig,
17-
HanpConfig,
18-
JaccardConfig,
19-
KCoreConfig,
20-
LPAConfig,
21-
LouvainConfig,
22-
Node2vecConfig,
23-
PRConfig,
24-
ShortestPathConfig,
25-
SparkConfig,
26-
DegreeStaticConfig
27-
}
28-
import com.vesoft.nebula.algorithm.lib.{
29-
BetweennessCentralityAlgo,
30-
BfsAlgo,
31-
ClosenessAlgo,
32-
ClusteringCoefficientAlgo,
33-
ConnectedComponentsAlgo,
34-
DegreeStaticAlgo,
35-
DfsAlgo,
36-
GraphTriangleCountAlgo,
37-
HanpAlgo,
38-
JaccardAlgo,
39-
KCoreAlgo,
40-
LabelPropagationAlgo,
41-
LouvainAlgo,
42-
Node2vecAlgo,
43-
PageRankAlgo,
44-
ShortestPathAlgo,
45-
StronglyConnectedComponentsAlgo,
46-
TriangleCountAlgo
47-
}
48-
import com.vesoft.nebula.algorithm.reader.{CsvReader, JsonReader, NebulaReader}
49-
import com.vesoft.nebula.algorithm.writer.{CsvWriter, NebulaWriter, TextWriter}
9+
import com.vesoft.nebula.algorithm.config._
10+
import com.vesoft.nebula.algorithm.lib._
11+
import com.vesoft.nebula.algorithm.reader.{CsvReader, DataReader, JsonReader, NebulaReader}
12+
import com.vesoft.nebula.algorithm.writer.{AlgoWriter, CsvWriter, NebulaWriter, TextWriter}
5013
import org.apache.commons.math3.ode.UnknownParameterException
5114
import org.apache.log4j.Logger
5215
import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
@@ -114,26 +77,8 @@ object Main {
11477
private[this] def createDataSource(spark: SparkSession,
11578
configs: Configs,
11679
partitionNum: String): DataFrame = {
117-
val dataSource = configs.dataSourceSinkEntry.source
118-
val dataSet: Dataset[Row] = dataSource.toLowerCase match {
119-
case "nebula" => {
120-
val reader = new NebulaReader(spark, configs, partitionNum)
121-
reader.read()
122-
}
123-
case "nebula-ngql" => {
124-
val reader = new NebulaReader(spark, configs, partitionNum)
125-
reader.readNgql()
126-
}
127-
case "csv" => {
128-
val reader = new CsvReader(spark, configs, partitionNum)
129-
reader.read()
130-
}
131-
case "json" => {
132-
val reader = new JsonReader(spark, configs, partitionNum)
133-
reader.read()
134-
}
135-
}
136-
dataSet
80+
val dataSource = DataReader.make(configs)
81+
dataSource.read(spark, configs, partitionNum)
13782
}
13883

13984
/**
@@ -149,99 +94,63 @@ object Main {
14994
configs: Configs,
15095
dataSet: DataFrame): DataFrame = {
15196
val hasWeight = configs.dataSourceSinkEntry.hasWeight
152-
val algoResult = {
153-
algoName.toLowerCase match {
154-
case "pagerank" => {
155-
val pageRankConfig = PRConfig.getPRConfig(configs)
156-
PageRankAlgo(spark, dataSet, pageRankConfig, hasWeight)
157-
}
158-
case "louvain" => {
159-
val louvainConfig = LouvainConfig.getLouvainConfig(configs)
160-
LouvainAlgo(spark, dataSet, louvainConfig, hasWeight)
161-
}
162-
case "connectedcomponent" => {
163-
val ccConfig = CcConfig.getCcConfig(configs)
164-
ConnectedComponentsAlgo(spark, dataSet, ccConfig, hasWeight)
165-
}
166-
case "labelpropagation" => {
167-
val lpaConfig = LPAConfig.getLPAConfig(configs)
168-
LabelPropagationAlgo(spark, dataSet, lpaConfig, hasWeight)
169-
}
170-
case "shortestpaths" => {
171-
val spConfig = ShortestPathConfig.getShortestPathConfig(configs)
172-
ShortestPathAlgo(spark, dataSet, spConfig, hasWeight)
173-
}
174-
case "degreestatic" => {
175-
val dsConfig = DegreeStaticConfig.getDegreeStaticConfig(configs)
176-
DegreeStaticAlgo(spark, dataSet, dsConfig)
177-
}
178-
case "kcore" => {
179-
val kCoreConfig = KCoreConfig.getKCoreConfig(configs)
180-
KCoreAlgo(spark, dataSet, kCoreConfig)
181-
}
182-
case "stronglyconnectedcomponent" => {
183-
val ccConfig = CcConfig.getCcConfig(configs)
184-
StronglyConnectedComponentsAlgo(spark, dataSet, ccConfig, hasWeight)
185-
}
186-
case "betweenness" => {
187-
val betweennessConfig = BetweennessConfig.getBetweennessConfig(configs)
188-
BetweennessCentralityAlgo(spark, dataSet, betweennessConfig, hasWeight)
189-
}
190-
case "trianglecount" => {
191-
TriangleCountAlgo(spark, dataSet)
192-
}
193-
case "graphtrianglecount" => {
194-
GraphTriangleCountAlgo(spark, dataSet)
195-
}
196-
case "clusteringcoefficient" => {
197-
val coefficientConfig = CoefficientConfig.getCoefficientConfig(configs)
198-
ClusteringCoefficientAlgo(spark, dataSet, coefficientConfig)
199-
}
200-
case "closeness" => {
201-
ClosenessAlgo(spark, dataSet, hasWeight)
202-
}
203-
case "hanp" => {
204-
val hanpConfig = HanpConfig.getHanpConfig(configs)
205-
HanpAlgo(spark, dataSet, hanpConfig, hasWeight)
206-
}
207-
case "node2vec" => {
208-
val node2vecConfig = Node2vecConfig.getNode2vecConfig(configs)
209-
Node2vecAlgo(spark, dataSet, node2vecConfig, hasWeight)
210-
}
211-
case "bfs" => {
212-
val bfsConfig = BfsConfig.getBfsConfig(configs)
213-
BfsAlgo(spark, dataSet, bfsConfig)
214-
}
215-
case "dfs" => {
216-
val dfsConfig = DfsConfig.getDfsConfig(configs)
217-
DfsAlgo(spark, dataSet, dfsConfig)
218-
}
219-
case "jaccard" => {
220-
val jaccardConfig = JaccardConfig.getJaccardConfig(configs)
221-
JaccardAlgo(spark, dataSet, jaccardConfig)
222-
}
223-
case _ => throw new UnknownParameterException("unknown executeAlgo name.")
224-
}
97+
AlgorithmType.mapping.getOrElse(algoName.toLowerCase, throw new UnknownParameterException("unknown executeAlgo name.")) match {
98+
case AlgorithmType.Bfs =>
99+
val bfsConfig = BfsConfig.getBfsConfig(configs)
100+
BfsAlgo(spark, dataSet, bfsConfig)
101+
case AlgorithmType.Closeness =>
102+
ClosenessAlgo(spark, dataSet, hasWeight)
103+
case AlgorithmType.ClusteringCoefficient =>
104+
val coefficientConfig = CoefficientConfig.getCoefficientConfig(configs)
105+
ClusteringCoefficientAlgo(spark, dataSet, coefficientConfig)
106+
case AlgorithmType.ConnectedComponents =>
107+
val ccConfig = CcConfig.getCcConfig(configs)
108+
ConnectedComponentsAlgo(spark, dataSet, ccConfig, hasWeight)
109+
case AlgorithmType.DegreeStatic =>
110+
val dsConfig = DegreeStaticConfig.getDegreeStaticConfig(configs)
111+
DegreeStaticAlgo(spark, dataSet, dsConfig)
112+
case AlgorithmType.Dfs =>
113+
val dfsConfig = DfsConfig.getDfsConfig(configs)
114+
DfsAlgo(spark, dataSet, dfsConfig)
115+
case AlgorithmType.GraphTriangleCount =>
116+
GraphTriangleCountAlgo(spark, dataSet)
117+
case AlgorithmType.Hanp =>
118+
val hanpConfig = HanpConfig.getHanpConfig(configs)
119+
HanpAlgo(spark, dataSet, hanpConfig, hasWeight)
120+
case AlgorithmType.Jaccard =>
121+
val jaccardConfig = JaccardConfig.getJaccardConfig(configs)
122+
JaccardAlgo(spark, dataSet, jaccardConfig)
123+
case AlgorithmType.KCore =>
124+
val kCoreConfig = KCoreConfig.getKCoreConfig(configs)
125+
KCoreAlgo(spark, dataSet, kCoreConfig)
126+
case AlgorithmType.LabelPropagation =>
127+
val lpaConfig = LPAConfig.getLPAConfig(configs)
128+
LabelPropagationAlgo(spark, dataSet, lpaConfig, hasWeight)
129+
case AlgorithmType.Louvain =>
130+
val louvainConfig = LouvainConfig.getLouvainConfig(configs)
131+
LouvainAlgo(spark, dataSet, louvainConfig, hasWeight)
132+
case AlgorithmType.Node2vec =>
133+
val node2vecConfig = Node2vecConfig.getNode2vecConfig(configs)
134+
Node2vecAlgo(spark, dataSet, node2vecConfig, hasWeight)
135+
case AlgorithmType.PageRank =>
136+
val pageRankConfig = PRConfig.getPRConfig(configs)
137+
PageRankAlgo(spark, dataSet, pageRankConfig, hasWeight)
138+
case AlgorithmType.ShortestPath =>
139+
val spConfig = ShortestPathConfig.getShortestPathConfig(configs)
140+
ShortestPathAlgo(spark, dataSet, spConfig, hasWeight)
141+
case AlgorithmType.StronglyConnectedComponents =>
142+
val ccConfig = CcConfig.getCcConfig(configs)
143+
StronglyConnectedComponentsAlgo(spark, dataSet, ccConfig, hasWeight)
144+
case AlgorithmType.TriangleCount =>
145+
TriangleCountAlgo(spark, dataSet)
146+
case AlgorithmType.BetweennessCentrality =>
147+
val betweennessConfig = BetweennessConfig.getBetweennessConfig(configs)
148+
BetweennessCentralityAlgo(spark, dataSet, betweennessConfig, hasWeight)
225149
}
226-
algoResult
227150
}
228151

229152
private[this] def saveAlgoResult(algoResult: DataFrame, configs: Configs): Unit = {
230-
val dataSink = configs.dataSourceSinkEntry.sink
231-
dataSink.toLowerCase match {
232-
case "nebula" => {
233-
val writer = new NebulaWriter(algoResult, configs)
234-
writer.write()
235-
}
236-
case "csv" => {
237-
val writer = new CsvWriter(algoResult, configs)
238-
writer.write()
239-
}
240-
case "text" => {
241-
val writer = new TextWriter(algoResult, configs)
242-
writer.write()
243-
}
244-
case _ => throw new UnsupportedOperationException("unsupported data sink")
245-
}
153+
val writer = AlgoWriter.make(configs)
154+
writer.write(algoResult, configs)
246155
}
247156
}
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
/* Copyright (c) 2021 vesoft inc. All rights reserved.
2+
*
3+
* This source code is licensed under Apache 2.0 License.
4+
*/
5+
6+
package com.vesoft.nebula.algorithm.lib
7+
8+
/**
9+
*
10+
* @author 梦境迷离
11+
* @version 1.0,2023/9/12
12+
*/
13+
sealed trait AlgorithmType {
14+
self =>
15+
def stringify: String = self match {
16+
case AlgorithmType.Bfs => "bfs"
17+
case AlgorithmType.Closeness => "closeness"
18+
case AlgorithmType.ClusteringCoefficient => "clusteringcoefficient"
19+
case AlgorithmType.ConnectedComponents => "connectedcomponent"
20+
case AlgorithmType.DegreeStatic => "degreestatic"
21+
case AlgorithmType.Dfs => "dfs"
22+
case AlgorithmType.GraphTriangleCount => "graphtrianglecount"
23+
case AlgorithmType.Hanp => "hanp"
24+
case AlgorithmType.Jaccard => "jaccard"
25+
case AlgorithmType.KCore => "kcore"
26+
case AlgorithmType.LabelPropagation => "labelpropagation"
27+
case AlgorithmType.Louvain => "louvain"
28+
case AlgorithmType.Node2vec => "node2vec"
29+
case AlgorithmType.PageRank => "pagerank"
30+
case AlgorithmType.ShortestPath => "shortestpaths"
31+
case AlgorithmType.StronglyConnectedComponents => "stronglyconnectedcomponent"
32+
case AlgorithmType.TriangleCount => "trianglecount"
33+
case AlgorithmType.BetweennessCentrality => "betweenness"
34+
}
35+
}
36+
object AlgorithmType {
37+
lazy val mapping: Map[String, AlgorithmType] = Map(
38+
Bfs.stringify -> Bfs,
39+
Closeness.stringify -> Closeness,
40+
ClusteringCoefficient.stringify -> ClusteringCoefficient,
41+
ConnectedComponents.stringify -> ConnectedComponents,
42+
DegreeStatic.stringify -> DegreeStatic,
43+
GraphTriangleCount.stringify -> GraphTriangleCount,
44+
Hanp.stringify -> Hanp,
45+
Jaccard.stringify -> Jaccard,
46+
KCore.stringify -> KCore,
47+
LabelPropagation.stringify -> LabelPropagation,
48+
Louvain.stringify -> Louvain,
49+
Node2vec.stringify -> Node2vec,
50+
PageRank.stringify -> PageRank,
51+
ShortestPath.stringify -> ShortestPath,
52+
StronglyConnectedComponents.stringify -> StronglyConnectedComponents,
53+
TriangleCount.stringify -> TriangleCount,
54+
BetweennessCentrality.stringify -> BetweennessCentrality
55+
)
56+
object BetweennessCentrality extends AlgorithmType
57+
object Bfs extends AlgorithmType
58+
object Closeness extends AlgorithmType
59+
object ClusteringCoefficient extends AlgorithmType
60+
object ConnectedComponents extends AlgorithmType
61+
object DegreeStatic extends AlgorithmType
62+
object Dfs extends AlgorithmType
63+
object GraphTriangleCount extends AlgorithmType
64+
object Hanp extends AlgorithmType
65+
object Jaccard extends AlgorithmType
66+
object KCore extends AlgorithmType
67+
object LabelPropagation extends AlgorithmType
68+
object Louvain extends AlgorithmType
69+
object Node2vec extends AlgorithmType
70+
object PageRank extends AlgorithmType
71+
object ShortestPath extends AlgorithmType
72+
object StronglyConnectedComponents extends AlgorithmType
73+
object TriangleCount extends AlgorithmType
74+
}

0 commit comments

Comments (0)