Skip to content

Commit e3ca660

Browse files
authored
add bfs algorithm (#24)
1 parent 553d76e commit e3ca660

File tree

6 files changed

+129
-1
lines changed

6 files changed

+129
-1
lines changed

nebula-algorithm/src/main/resources/application.conf

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,12 @@
141141
# ClosenessAlgo parameter
142142
closeness:{}
143143

144+
# BFS parameter
145+
bfs:{
146+
maxIter:5
147+
root:"10"
148+
}
149+
144150
# HanpAlgo parameter
145151
hanp:{
146152
hopAttenuation:0.1

nebula-algorithm/src/main/scala/com/vesoft/nebula/algorithm/Main.scala

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import com.vesoft.nebula.algorithm.config.Configs.Argument
99
import com.vesoft.nebula.algorithm.config.{
1010
AlgoConfig,
1111
BetweennessConfig,
12+
BfsConfig,
1213
CcConfig,
1314
CoefficientConfig,
1415
Configs,
@@ -23,8 +24,9 @@ import com.vesoft.nebula.algorithm.config.{
2324
}
2425
import com.vesoft.nebula.algorithm.lib.{
2526
BetweennessCentralityAlgo,
26-
ClusteringCoefficientAlgo,
27+
BfsAlgo,
2728
ClosenessAlgo,
29+
ClusteringCoefficientAlgo,
2830
ConnectedComponentsAlgo,
2931
DegreeStaticAlgo,
3032
GraphTriangleCountAlgo,
@@ -185,6 +187,10 @@ object Main {
185187
val node2vecConfig = Node2vecConfig.getNode2vecConfig(configs)
186188
Node2vecAlgo(spark, dataSet, node2vecConfig, hasWeight)
187189
}
190+
case "bfs" => {
191+
val bfsConfig = BfsConfig.getBfsConfig(configs)
192+
BfsAlgo(spark, dataSet, bfsConfig)
193+
}
188194
case _ => throw new UnknownParameterException("unknown executeAlgo name.")
189195
}
190196
}

nebula-algorithm/src/main/scala/com/vesoft/nebula/algorithm/config/AlgoConfig.scala

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,22 @@ object CoefficientConfig {
167167
}
168168
}
169169

170+
/**
 * Configuration for the BFS algorithm.
 *
 * @param maxIter maximum number of pregel supersteps the traversal may run
 * @param root    id of the vertex the breadth-first search starts from
 */
case class BfsConfig(maxIter: Int, root: Long)

object BfsConfig {

  /**
   * Read the BFS parameters from the application config.
   *
   * Expects `algorithm.bfs.maxIter` and `algorithm.bfs.root` to be present in
   * `configs.algorithmConfig.map`; throws NoSuchElementException when missing
   * and NumberFormatException when not numeric.
   *
   * NOTE(review): earlier revision stashed the parsed values in object-level
   * `var`s, which is not thread-safe and leaks state between calls; plain
   * local vals keep this method side-effect free.
   */
  def getBfsConfig(configs: Configs): BfsConfig = {
    val bfsConfig = configs.algorithmConfig.map
    val maxIter   = bfsConfig("algorithm.bfs.maxIter").toInt
    val root      = bfsConfig("algorithm.bfs.root").toLong
    BfsConfig(maxIter, root)
  }
}
185+
170186
/**
171187
* Hanp
172188
*/

nebula-algorithm/src/main/scala/com/vesoft/nebula/algorithm/config/Configs.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -365,4 +365,5 @@ object AlgoConstants {
365365
val CLOSENESS_RESULT_COL: String = "closeness"
366366
val HANP_RESULT_COL: String = "hanp"
367367
val NODE2VEC_RESULT_COL: String = "node2vec"
368+
val BFS_RESULT_COL: String = "bfs"
368369
}
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
/* Copyright (c) 2021 vesoft inc. All rights reserved.
2+
*
3+
* This source code is licensed under Apache 2.0 License.
4+
*/
5+
6+
package com.vesoft.nebula.algorithm.lib
7+
8+
import com.vesoft.nebula.algorithm.config.{AlgoConstants, BfsConfig}
9+
import com.vesoft.nebula.algorithm.utils.NebulaUtil
10+
import org.apache.log4j.Logger
11+
import org.apache.spark.graphx.{EdgeTriplet, Graph, VertexId}
12+
import org.apache.spark.sql.functions.col
13+
import org.apache.spark.sql.types.{DoubleType, IntegerType, LongType, StructField, StructType}
14+
import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
15+
16+
/**
 * Breadth-First Search for an un-weighted graph.
 *
 * Each visited vertex is labelled with its hop distance (as a Double) from the
 * root vertex; unreached vertices are dropped from the result.
 */
object BfsAlgo {
  private val LOGGER = Logger.getLogger(this.getClass)

  val ALGORITHM: String = "BFS"

  /**
   * Run the BFS algorithm for nebula graph.
   *
   * (Fixed doc comment: the earlier revision said "louvain", a copy-paste
   * error from another algorithm.)
   *
   * @param spark     active spark session used to build the result DataFrame
   * @param dataset   edge data loaded from nebula, passed to NebulaUtil.loadInitGraph
   * @param bfsConfig carries maxIter (pregel superstep limit) and root vertex id
   * @return DataFrame with columns (vertex id, bfs depth), ordered by depth;
   *         only vertices actually reached from the root are included
   */
  def apply(spark: SparkSession, dataset: Dataset[Row], bfsConfig: BfsConfig): DataFrame = {
    // hasWeight = false: BFS treats every edge as unit cost
    val graph: Graph[None.type, Double] = NebulaUtil.loadInitGraph(dataset, false)
    val bfsGraph = execute(graph, bfsConfig.maxIter, bfsConfig.root)

    // drop vertices the traversal never reached (still at +Infinity)
    val visitedVertices = bfsGraph.vertices.filter(v => v._2 != Double.PositiveInfinity)

    val schema = StructType(
      List(
        StructField(AlgoConstants.ALGO_ID_COL, LongType, nullable = false),
        StructField(AlgoConstants.BFS_RESULT_COL, DoubleType, nullable = true)
      ))
    val resultRDD = visitedVertices.map(vertex => Row(vertex._1, vertex._2))
    val algoResult = spark.sqlContext
      .createDataFrame(resultRDD, schema)
      .orderBy(col(AlgoConstants.BFS_RESULT_COL))
    algoResult
  }

  /**
   * Execute BFS as a pregel computation.
   *
   * The root vertex starts at depth 0.0, every other vertex at +Infinity;
   * each superstep propagates depth + 1 to unvisited neighbours.
   *
   * @param graph   input graph; vertex attributes are ignored
   * @param maxIter upper bound on pregel supersteps (bounds the BFS depth)
   * @param root    id of the start vertex
   * @return graph whose vertex attribute is the hop distance from root
   *         (+Infinity for unreached vertices)
   */
  def execute(graph: Graph[None.type, Double], maxIter: Int, root: Long): Graph[Double, Double] = {
    val initialGraph = graph.mapVertices(
      (id, _) =>
        if (id == root) 0.0
        else Double.PositiveInfinity)

    // vertex program: keep the smallest depth seen so far
    val vprog = { (id: VertexId, attr: Double, msg: Double) =>
      math.min(attr, msg)
    }

    // Send depth + 1 across an edge only when exactly one endpoint has been
    // visited; messages flow in both directions, so the traversal effectively
    // treats edges as undirected.
    val sendMsg = { (triplet: EdgeTriplet[Double, Double]) =>
      var iter: Iterator[(VertexId, Double)] = Iterator.empty
      val isSrcMarked = triplet.srcAttr != Double.PositiveInfinity
      val isDstMarked = triplet.dstAttr != Double.PositiveInfinity
      if (!(isSrcMarked && isDstMarked)) {
        if (isSrcMarked) {
          iter = Iterator((triplet.dstId, triplet.srcAttr + 1))
        } else {
          iter = Iterator((triplet.srcId, triplet.dstAttr + 1))
        }
      }
      iter
    }

    // merge concurrent messages by taking the shortest depth
    val mergeMsg = { (a: Double, b: Double) =>
      math.min(a, b)
    }

    initialGraph.pregel(Double.PositiveInfinity, maxIter)(vprog, sendMsg, mergeMsg)
  }
}
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
/* Copyright (c) 2021 vesoft inc. All rights reserved.
2+
*
3+
* This source code is licensed under Apache 2.0 License.
4+
*/
5+
6+
package com.vesoft.nebula.algorithm.lib
7+
8+
import com.vesoft.nebula.algorithm.config.{BfsConfig, CcConfig}
9+
import org.apache.spark.sql.SparkSession
10+
import org.junit.Test
11+
12+
class BfsAlgoSuite {

  /**
   * BFS over the test edge file rooted at vertex "1" with maxIter 5 is
   * expected to reach exactly 4 vertices.
   */
  @Test
  def bfsAlgoSuite(): Unit = {
    val spark = SparkSession.builder().master("local").getOrCreate()
    try {
      val data = spark.read.option("header", true).csv("src/test/resources/edge.csv")
      // BfsConfig is a case class; `new` is unnecessary
      val bfsAlgoConfig = BfsConfig(5, 1)
      val result = BfsAlgo.apply(spark, data, bfsAlgoConfig)
      result.show()
      assert(result.count() == 4)
    } finally {
      // stop the session so repeated test runs don't leak local Spark contexts
      spark.stop()
    }
  }
}

0 commit comments

Comments
 (0)