@@ -343,7 +343,7 @@ function (l::GATConv)(g::GNNGraph, x::AbstractMatrix, e::Union{Nothing,AbstractM
     Wx = l.dense_x(x)
     Wx = reshape(Wx, chout, heads, :)                   # chout × nheads × nnodes

-    # a hand-writtent message passing
+    # a hand-written message passing
     m = apply_edges((xi, xj, e) -> message(l, xi, xj, e), g, Wx, Wx, e)
     α = softmax_edge_neighbors(g, m.logα)
     β = α .* m.Wxj
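For reference, the reshape above just re-indexes the dense layer's (chout*heads) × nnodes output as a 3-dimensional chout × heads × nnodes array. A quick sanity check in plain Julia, with illustrative sizes (chout=4, heads=2, nnodes=3 are made up for the example):

    chout, heads, nnodes = 4, 2, 3
    Wx = rand(Float32, chout * heads, nnodes)   # shape a Dense(in, chout*heads) layer would produce
    Wx = reshape(Wx, chout, heads, :)           # chout × heads × nnodes
    @assert size(Wx) == (chout, heads, nnodes)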
@@ -371,7 +371,7 @@ function message(l::GATConv, Wxi, Wxj, e)
     end
     aWW = sum(l.a .* Wxx, dims=1)   # 1 × nheads × nedges
     logα = leakyrelu.(aWW, l.negative_slope)
-    return (logα = logα, Wxj = Wxj)
+    return (; logα, Wxj)
 end

 function Base.show(io::IO, l::GATConv)
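The (; logα, Wxj) form introduced here is Julia's named-tuple shorthand: the field names are taken from the variable names, so it builds the same value as the explicit version it replaces. A minimal check (values are made up):

    logα = [0.1 0.2]
    Wxj = [1.0 2.0]
    @assert (; logα, Wxj) == (logα = logα, Wxj = Wxj)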
@@ -480,11 +480,13 @@ function (l::GATv2Conv)(g::GNNGraph, x::AbstractMatrix, e::Union{Nothing, Abstra
     _, out = l.channel
     heads = l.heads

-    Wix = reshape(l.dense_i(x), out, heads, :)    # out × heads × nnodes
-    Wjx = reshape(l.dense_j(x), out, heads, :)    # out × heads × nnodes
+    Wxi = reshape(l.dense_i(x), out, heads, :)    # out × heads × nnodes
+    Wxj = reshape(l.dense_j(x), out, heads, :)    # out × heads × nnodes

-    m = propagate(message, g, +, l; xi=Wix, xj=Wjx, e)   # out × heads × nnodes
-    x = m.β ./ m.α
+    m = apply_edges((xi, xj, e) -> message(l, xi, xj, e), g, Wxi, Wxj, e)
+    α = softmax_edge_neighbors(g, m.logα)
+    β = α .* m.Wxj
+    x = aggregate_neighbors(g, +, β)

     if !l.concat
         x = mean(x, dims=2)
@@ -494,17 +496,16 @@ function (l::GATv2Conv)(g::GNNGraph, x::AbstractMatrix, e::Union{Nothing, Abstra
     return x
 end

-function message(l::GATv2Conv, Wix, Wjx, e)
+function message(l::GATv2Conv, Wxi, Wxj, e)
     _, out = l.channel
     heads = l.heads

-    Wx = Wix + Wjx  # Note: this is equivalent to W * vcat(x_i, x_j) as in "How Attentive are Graph Attention Networks?"
+    Wx = Wxi + Wxj  # Note: this is equivalent to W * vcat(x_i, x_j) as in "How Attentive are Graph Attention Networks?"
     if e !== nothing
         Wx += reshape(l.dense_e(e), out, heads, :)
     end
-    eij = sum(l.a .* leakyrelu.(Wx, l.negative_slope), dims=1)   # 1 × heads × nedges
-    α = exp.(eij)
-    return (α = α, β = α .* Wjx)
+    logα = sum(l.a .* leakyrelu.(Wx, l.negative_slope), dims=1)   # 1 × heads × nedges
+    return (; logα, Wxj)
 end

 function Base.show(io::IO, l::GATv2Conv)
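With this change both GATConv and GATv2Conv follow the same hand-written attention pattern instead of a single propagate call: apply_edges computes per-edge messages, softmax_edge_neighbors normalizes the attention logits over each destination node's incoming edges, and aggregate_neighbors sums the weighted messages back onto the nodes. A rough standalone sketch of that pattern, assuming GraphNeuralNetworks.jl and the same API calls that appear in the diff; the graph, sizes, and toy_message function below are invented for illustration and are not the layers' actual fields:

    using GraphNeuralNetworks

    g = rand_graph(5, 14)                       # toy graph: 5 nodes, 14 edges
    chout, heads = 4, 2
    Wx = rand(Float32, chout, heads, 5)         # chout × heads × nnodes, as in the layers above
    a  = rand(Float32, 2chout, heads)           # toy attention vector, stand-in for l.a

    function toy_message(xi, xj, e)
        logα = sum(a .* vcat(xi, xj), dims=1)   # 1 × heads × nedges, unnormalized attention
        return (; logα, Wxj = xj)
    end

    m = apply_edges(toy_message, g, Wx, Wx, nothing)   # per-edge messages
    α = softmax_edge_neighbors(g, m.logα)              # softmax over each node's neighborhood
    β = α .* m.Wxj                                     # attention-weighted messages
    x = aggregate_neighbors(g, +, β)                   # chout × heads × nnodes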