Skip to content

Commit 397a8a3

Browse files
authored
encode root for path algo (#71)
1 parent 2b364a2 commit 397a8a3

File tree

6 files changed

+74
-29
lines changed

6 files changed

+74
-29
lines changed

nebula-algorithm/src/main/scala/com/vesoft/nebula/algorithm/config/AlgoConfig.scala

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -200,16 +200,16 @@ object CoefficientConfig {
200200
/**
201201
* bfs
202202
*/
203-
case class BfsConfig(maxIter: Int, root: Long, encodeId: Boolean = false)
203+
case class BfsConfig(maxIter: Int, root: String, encodeId: Boolean = false)
204204
object BfsConfig {
205205
var maxIter: Int = _
206-
var root: Long = _
206+
var root: String = _
207207
var encodeId: Boolean = false
208208

209209
def getBfsConfig(configs: Configs): BfsConfig = {
210210
val bfsConfig = configs.algorithmConfig.map
211211
maxIter = bfsConfig("algorithm.bfs.maxIter").toInt
212-
root = bfsConfig("algorithm.bfs.root").toLong
212+
root = bfsConfig("algorithm.bfs.root").toString
213213
encodeId = ConfigUtil.getOrElseBoolean(bfsConfig, "algorithm.bfs.encodeId", false)
214214
BfsConfig(maxIter, root, encodeId)
215215
}
@@ -218,16 +218,16 @@ object BfsConfig {
218218
/**
219219
* dfs
220220
*/
221-
case class DfsConfig(maxIter: Int, root: Long, encodeId: Boolean = false)
221+
case class DfsConfig(maxIter: Int, root: String, encodeId: Boolean = false)
222222
object DfsConfig {
223223
var maxIter: Int = _
224-
var root: Long = _
224+
var root: String = _
225225
var encodeId: Boolean = false
226226

227227
def getDfsConfig(configs: Configs): DfsConfig = {
228228
val dfsConfig = configs.algorithmConfig.map
229229
maxIter = dfsConfig("algorithm.dfs.maxIter").toInt
230-
root = dfsConfig("algorithm.dfs.root").toLong
230+
root = dfsConfig("algorithm.dfs.root").toString
231231
encodeId = ConfigUtil.getOrElseBoolean(dfsConfig, "algorithm.dfs.encodeId", false)
232232
DfsConfig(maxIter, root, encodeId)
233233
}

nebula-algorithm/src/main/scala/com/vesoft/nebula/algorithm/config/Configs.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -386,6 +386,7 @@ object AlgoConstants {
386386
val HANP_RESULT_COL: String = "hanp"
387387
val NODE2VEC_RESULT_COL: String = "node2vec"
388388
val BFS_RESULT_COL: String = "bfs"
389+
val DFS_RESULT_COL: String = "dfs"
389390
val ENCODE_ID_COL: String = "encodedId"
390391
val ORIGIN_ID_COL: String = "id"
391392
}

nebula-algorithm/src/main/scala/com/vesoft/nebula/algorithm/lib/BfsAlgo.scala

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,15 +26,18 @@ object BfsAlgo {
2626
*/
2727
def apply(spark: SparkSession, dataset: Dataset[Row], bfsConfig: BfsConfig): DataFrame = {
2828
var encodeIdDf: DataFrame = null
29+
var finalRoot: Long = 0
2930

3031
val graph: Graph[None.type, Double] = if (bfsConfig.encodeId) {
3132
val (data, encodeId) = DecodeUtil.convertStringId2LongId(dataset, false)
3233
encodeIdDf = encodeId
34+
finalRoot = encodeIdDf.filter(row => row.get(0).toString == bfsConfig.root).first().getLong(1)
3335
NebulaUtil.loadInitGraph(data, false)
3436
} else {
37+
finalRoot = bfsConfig.root.toLong
3538
NebulaUtil.loadInitGraph(dataset, false)
3639
}
37-
val bfsGraph = execute(graph, bfsConfig.maxIter, bfsConfig.root)
40+
val bfsGraph = execute(graph, bfsConfig.maxIter, finalRoot)
3841

3942
// filter out the not traversal vertices
4043
val visitedVertices = bfsGraph.vertices.filter(v => v._2 != Double.PositiveInfinity)

nebula-algorithm/src/main/scala/com/vesoft/nebula/algorithm/lib/DfsAlgo.scala

Lines changed: 47 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,16 @@
55

66
package com.vesoft.nebula.algorithm.lib
77

8+
import com.vesoft.nebula.algorithm.config.AlgoConstants.{
9+
ALGO_ID_COL,
10+
DFS_RESULT_COL,
11+
ENCODE_ID_COL,
12+
ORIGIN_ID_COL
13+
}
814
import com.vesoft.nebula.algorithm.config.{AlgoConstants, BfsConfig, DfsConfig}
915
import com.vesoft.nebula.algorithm.utils.{DecodeUtil, NebulaUtil}
10-
import org.apache.spark.graphx.{EdgeDirection, Graph, VertexId}
16+
import org.apache.spark.graphx.{EdgeDirection, EdgeTriplet, Graph, Pregel, VertexId}
17+
import org.apache.spark.sql.functions.col
1118
import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
1219
import org.apache.spark.sql.types.{DoubleType, LongType, StringType, StructField, StructType}
1320

@@ -18,21 +25,28 @@ object DfsAlgo {
1825

1926
def apply(spark: SparkSession, dataset: Dataset[Row], dfsConfig: DfsConfig): DataFrame = {
2027
var encodeIdDf: DataFrame = null
28+
var finalRoot: Long = 0
2129

2230
val graph: Graph[None.type, Double] = if (dfsConfig.encodeId) {
2331
val (data, encodeId) = DecodeUtil.convertStringId2LongId(dataset, false)
2432
encodeIdDf = encodeId
33+
finalRoot = encodeIdDf.filter(row => row.get(0).toString == dfsConfig.root).first().getLong(1)
2534
NebulaUtil.loadInitGraph(data, false)
2635
} else {
36+
finalRoot = dfsConfig.root.toLong
2737
NebulaUtil.loadInitGraph(dataset, false)
2838
}
29-
val bfsVertices = dfs(graph, dfsConfig.root, mutable.Seq.empty[VertexId])(dfsConfig.maxIter)
39+
val bfsVertices =
40+
dfs(graph, finalRoot, mutable.Seq.empty[VertexId])(dfsConfig.maxIter).vertices.filter(v =>
41+
v._2 != Double.PositiveInfinity)
3042

31-
val schema = StructType(List(StructField("dfs", LongType, nullable = false)))
43+
val schema = StructType(
44+
List(StructField(ALGO_ID_COL, LongType, nullable = false),
45+
StructField(DFS_RESULT_COL, DoubleType, nullable = true)))
3246

33-
val rdd = spark.sparkContext.parallelize(bfsVertices.toSeq, 1).map(row => Row(row))
34-
val algoResult = spark.sqlContext
35-
.createDataFrame(rdd, schema)
47+
val resultRDD = bfsVertices.map(v => Row(v._1, v._2))
48+
val algoResult =
49+
spark.sqlContext.createDataFrame(resultRDD, schema).orderBy(col(DFS_RESULT_COL))
3650

3751
if (dfsConfig.encodeId) {
3852
DecodeUtil.convertAlgoId2StringId(algoResult, encodeIdDf).coalesce(1)
@@ -42,18 +56,35 @@ object DfsAlgo {
4256
}
4357

4458
def dfs(g: Graph[None.type, Double], vertexId: VertexId, visited: mutable.Seq[VertexId])(
45-
maxIter: Int): mutable.Seq[VertexId] = {
46-
if (visited.contains(vertexId)) {
47-
visited
48-
} else {
49-
if (iterNums > maxIter) {
50-
return visited
59+
maxIter: Int): Graph[Double, Double] = {
60+
61+
val initialGraph =
62+
g.mapVertices((id, _) => if (id == vertexId) 0.0 else Double.PositiveInfinity)
63+
64+
def vertexProgram(id: VertexId, attr: Double, msg: Double): Double = {
65+
math.min(attr, msg)
66+
}
67+
68+
def sendMessage(edge: EdgeTriplet[Double, Double]): Iterator[(VertexId, Double)] = {
69+
val sourceVertex = edge.srcAttr
70+
val targetVertex = edge.dstAttr
71+
if (sourceVertex + 1 < targetVertex && sourceVertex < maxIter) {
72+
Iterator((edge.dstId, sourceVertex + 1))
73+
} else {
74+
Iterator.empty
5175
}
52-
val newVisited = visited :+ vertexId
53-
val neighbors = g.collectNeighbors(EdgeDirection.Out).lookup(vertexId).flatten
54-
iterNums = iterNums + 1
55-
neighbors.foldLeft(newVisited)((visited, neighbor) => dfs(g, neighbor._1, visited)(maxIter))
5676
}
77+
78+
def mergeMessage(a: Double, b: Double): Double = {
79+
math.min(a, b)
80+
}
81+
82+
// Start the Pregel iteration
83+
val resultGraph =
84+
Pregel(initialGraph, Double.PositiveInfinity)(vertexProgram, sendMessage, mergeMessage)
85+
86+
resultGraph
87+
5788
}
5889

5990
}

nebula-algorithm/src/test/scala/com/vesoft/nebula/algorithm/lib/BfsAlgoSuite.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ class BfsAlgoSuite {
1414
def bfsAlgoSuite(): Unit = {
1515
val spark = SparkSession.builder().master("local").getOrCreate()
1616
val data = spark.read.option("header", true).csv("src/test/resources/edge.csv")
17-
val bfsAlgoConfig = new BfsConfig(5, 1)
17+
val bfsAlgoConfig = new BfsConfig(5, "1")
1818
val result = BfsAlgo.apply(spark, data, bfsAlgoConfig)
1919
result.show()
2020
assert(result.count() == 4)

nebula-algorithm/src/test/scala/com/vesoft/nebula/algorithm/lib/DfsAlgoSuite.scala

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,21 @@ import org.junit.Test
1313
class DfsAlgoSuite {
1414
@Test
1515
def bfsAlgoSuite(): Unit = {
16-
val spark = SparkSession.builder().master("local").getOrCreate()
16+
val spark = SparkSession
17+
.builder()
18+
.master("local")
19+
.config("spark.sql.shuffle.partitions", 5)
20+
.getOrCreate()
1721
val data = spark.read.option("header", true).csv("src/test/resources/edge.csv")
18-
val dfsAlgoConfig = new DfsConfig(5, 3)
19-
val result = DfsAlgo.apply(spark, data, dfsAlgoConfig)
20-
result.show()
21-
assert(result.count() == 4)
22+
val dfsAlgoConfig = new DfsConfig(5, "3")
23+
// val result = DfsAlgo.apply(spark, data, dfsAlgoConfig)
24+
// result.show()
25+
// assert(result.count() == 4)
26+
27+
val encodeDfsConfig = new DfsConfig(5, "3", true)
28+
val encodeResult = DfsAlgo.apply(spark, data, encodeDfsConfig)
29+
30+
encodeResult.show()
31+
assert(encodeResult.count() == 4)
2232
}
2333
}

0 commit comments

Comments
 (0)