2828import org .neo4j .graphalgo .core .utils .TerminationFlag ;
2929import org .neo4j .graphalgo .core .utils .paged .AllocationTracker ;
3030import org .neo4j .graphalgo .core .write .Exporter ;
31- import org .neo4j .graphalgo .impl .PageRankResult ;
31+ import org .neo4j .graphalgo .impl .pagerank . PageRankResult ;
3232import org .neo4j .graphalgo .impl .Algorithm ;
33- import org .neo4j .graphalgo .impl .PageRankAlgorithm ;
33+ import org .neo4j .graphalgo .impl .pagerank . PageRankAlgorithm ;
3434import org .neo4j .graphalgo .results .PageRankScore ;
3535import org .neo4j .graphdb .Direction ;
3636import org .neo4j .graphdb .Node ;
@@ -58,6 +58,8 @@ public final class PageRankProc {
5858 public static final Integer DEFAULT_ITERATIONS = 20 ;
5959 public static final String DEFAULT_SCORE_PROPERTY = "pagerank" ;
6060
61+ public static final String CONFIG_WEIGHT_KEY = "weightProperty" ;
62+
6163 @ Context
6264 public GraphDatabaseAPI api ;
6365
@@ -69,7 +71,7 @@ public final class PageRankProc {
6971
7072 @ Procedure (value = "algo.pageRank" , mode = Mode .WRITE )
7173 @ Description ("CALL algo.pageRank(label:String, relationship:String, " +
72- "{iterations:5, dampingFactor:0.85, write: true, writeProperty:'pagerank', concurrency:4}) " +
74+ "{iterations:5, dampingFactor:0.85, weightProperty: null, write: true, writeProperty:'pagerank', concurrency:4}) " +
7375 "YIELD nodes, iterations, loadMillis, computeMillis, writeMillis, dampingFactor, write, writeProperty" +
7476 " - calculates page rank and potentially writes back" )
7577 public Stream <PageRankScore .Stats > pageRank (
@@ -79,17 +81,19 @@ public Stream<PageRankScore.Stats> pageRank(
7981
8082 ProcedureConfiguration configuration = ProcedureConfiguration .create (config );
8183
84+ final String weightPropertyKey = configuration .getString (CONFIG_WEIGHT_KEY , null );
85+
8286 PageRankScore .Stats .Builder statsBuilder = new PageRankScore .Stats .Builder ();
8387 AllocationTracker tracker = AllocationTracker .create ();
84- final Graph graph = load (label , relationship , tracker , configuration .getGraphImpl (), statsBuilder , configuration );
88+ final Graph graph = load (label , relationship , tracker , configuration .getGraphImpl (), statsBuilder , configuration , weightPropertyKey );
8589
8690 if (graph .nodeCount () == 0 ) {
8791 graph .release ();
8892 return Stream .of (statsBuilder .build ());
8993 }
9094
9195 TerminationFlag terminationFlag = TerminationFlag .wrap (transaction );
92- PageRankResult scores = evaluate (graph , tracker , terminationFlag , configuration , statsBuilder );
96+ PageRankResult scores = evaluate (graph , tracker , terminationFlag , configuration , statsBuilder , weightPropertyKey );
9397
9498 log .info ("PageRank: overall memory usage: %s" , tracker .getUsageString ());
9599
@@ -100,7 +104,7 @@ public Stream<PageRankScore.Stats> pageRank(
100104
101105 @ Procedure (value = "algo.pageRank.stream" , mode = Mode .READ )
102106 @ Description ("CALL algo.pageRank.stream(label:String, relationship:String, " +
103- "{iterations:20, dampingFactor:0.85, concurrency:4}) " +
107+ "{iterations:20, dampingFactor:0.85, weightProperty: null, concurrency:4}) " +
104108 "YIELD node, score - calculates page rank and streams results" )
105109 public Stream <PageRankScore > pageRankStream (
106110 @ Name (value = "label" , defaultValue = "" ) String label ,
@@ -109,17 +113,19 @@ public Stream<PageRankScore> pageRankStream(
109113
110114 ProcedureConfiguration configuration = ProcedureConfiguration .create (config );
111115
116+ final String weightPropertyKey = configuration .getString (CONFIG_WEIGHT_KEY , null );
117+
112118 PageRankScore .Stats .Builder statsBuilder = new PageRankScore .Stats .Builder ();
113119 AllocationTracker tracker = AllocationTracker .create ();
114- final Graph graph = load (label , relationship , tracker , configuration .getGraphImpl (), statsBuilder , configuration );
120+ final Graph graph = load (label , relationship , tracker , configuration .getGraphImpl (), statsBuilder , configuration , weightPropertyKey );
115121
116122 if (graph .nodeCount () == 0 ) {
117123 graph .release ();
118124 return Stream .empty ();
119125 }
120126
121127 TerminationFlag terminationFlag = TerminationFlag .wrap (transaction );
122- PageRankResult scores = evaluate (graph , tracker , terminationFlag , configuration , statsBuilder );
128+ PageRankResult scores = evaluate (graph , tracker , terminationFlag , configuration , statsBuilder , weightPropertyKey );
123129
124130 log .info ("PageRank: overall memory usage: %s" , tracker .getUsageString ());
125131
@@ -152,11 +158,13 @@ private Graph load(
152158 String relationship ,
153159 AllocationTracker tracker ,
154160 Class <? extends GraphFactory > graphFactory ,
155- PageRankScore .Stats .Builder statsBuilder , ProcedureConfiguration configuration ) {
161+ PageRankScore .Stats .Builder statsBuilder ,
162+ ProcedureConfiguration configuration ,
163+ String weightPropertyKey ) {
156164 GraphLoader graphLoader = new GraphLoader (api , Pools .DEFAULT )
157165 .init (log , label , relationship , configuration )
158166 .withAllocationTracker (tracker )
159- .withoutRelationshipWeights ( );
167+ .withOptionalRelationshipWeightsFromProperty ( weightPropertyKey , configuration . getWeightPropertyDefaultValue ( 0.0 ) );
160168
161169 Direction direction = configuration .getDirection (Direction .OUTGOING );
162170 if (direction == Direction .BOTH ) {
@@ -178,7 +186,8 @@ private PageRankResult evaluate(
178186 AllocationTracker tracker ,
179187 TerminationFlag terminationFlag ,
180188 ProcedureConfiguration configuration ,
181- PageRankScore .Stats .Builder statsBuilder ) {
189+ PageRankScore .Stats .Builder statsBuilder ,
190+ String weightPropertyKey ) {
182191
183192 double dampingFactor = configuration .get (CONFIG_DAMPING , DEFAULT_DAMPING );
184193 int iterations = configuration .getIterations (DEFAULT_ITERATIONS );
@@ -189,14 +198,29 @@ private PageRankResult evaluate(
189198
190199 List <Node > sourceNodes = configuration .get ("sourceNodes" , new ArrayList <>());
191200 LongStream sourceNodeIds = sourceNodes .stream ().mapToLong (Node ::getId );
192- PageRankAlgorithm prAlgo = PageRankAlgorithm .of (
193- tracker ,
194- graph ,
195- dampingFactor ,
196- sourceNodeIds ,
197- Pools .DEFAULT ,
198- concurrency ,
199- batchSize );
201+
202+ PageRankAlgorithm prAlgo ;
203+ if (weightPropertyKey != null ) {
204+ prAlgo = PageRankAlgorithm .weightedOf (
205+ tracker ,
206+ graph ,
207+ dampingFactor ,
208+ sourceNodeIds ,
209+ Pools .DEFAULT ,
210+ concurrency ,
211+ batchSize );
212+ } else {
213+ prAlgo = PageRankAlgorithm .of (
214+ tracker ,
215+ graph ,
216+ dampingFactor ,
217+ sourceNodeIds ,
218+ Pools .DEFAULT ,
219+ concurrency ,
220+ batchSize );
221+ }
222+
223+
200224 Algorithm <?> algo = prAlgo
201225 .algorithm ()
202226 .withLog (log )
0 commit comments