From 664e682d2daf2eaf2143357ec937ca5845592ef4 Mon Sep 17 00:00:00 2001 From: rbSparky Date: Wed, 21 Feb 2024 18:00:34 +0530 Subject: [PATCH 01/12] SAGEConv Hetero Layer --- src/layers/conv.jl | 5 +++-- test/layers/heteroconv.jl | 8 ++++++++ 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/src/layers/conv.jl b/src/layers/conv.jl index 006b03091..8517ebca9 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -801,9 +801,10 @@ function SAGEConv(ch::Pair{Int, Int}, σ = identity; aggr = mean, SAGEConv(W, b, σ, aggr) end -function (l::SAGEConv)(g::GNNGraph, x::AbstractMatrix) +function (l::SAGEConv)(g::AbstractGNNGraph, x) check_num_nodes(g, x) - m = propagate(copy_xj, g, l.aggr, xj = x) + xj, _ = expand_srcdst(g, x) + m = propagate(copy_xj, g, l.aggr, xj = xj) x = l.σ.(l.weight * vcat(x, m) .+ l.bias) return x end diff --git a/test/layers/heteroconv.jl b/test/layers/heteroconv.jl index f1a07b7a7..e86e8ec6f 100644 --- a/test/layers/heteroconv.jl +++ b/test/layers/heteroconv.jl @@ -109,4 +109,12 @@ @test size(y.A) == (2,2) && size(y.B) == (2,3) end + @testset "SAGEConv" begin + x = (A = rand(Float32, 4, 2), B = rand(Float32, 4, 3)) + layers = HeteroGraphConv((:A, :to, :B) => SAGEConv(4 => 2, relu), + (:B, :to, :A) => SAGEConv(4 => 2, relu)); + y = layers(hg, x); + @test size(y.A) == (2, 2) && size(y.B) == (2, 3) + end + end From 7daefc637c2370a6df233eeba65bf3d8e38af3c2 Mon Sep 17 00:00:00 2001 From: rbSparky Date: Thu, 22 Feb 2024 00:19:22 +0530 Subject: [PATCH 02/12] without tests --- test/layers/heteroconv.jl | 9 --------- 1 file changed, 9 deletions(-) diff --git a/test/layers/heteroconv.jl b/test/layers/heteroconv.jl index e86e8ec6f..d94a4c363 100644 --- a/test/layers/heteroconv.jl +++ b/test/layers/heteroconv.jl @@ -108,13 +108,4 @@ y = layers(hg, x); @test size(y.A) == (2,2) && size(y.B) == (2,3) end - - @testset "SAGEConv" begin - x = (A = rand(Float32, 4, 2), B = rand(Float32, 4, 3)) - layers = HeteroGraphConv((:A, :to, :B) => SAGEConv(4 => 2, relu), - (:B, :to, :A) => SAGEConv(4 => 2, relu)); - y = layers(hg, x); - @test size(y.A) == (2, 2) && size(y.B) == (2, 3) - end - end From 7fa51cf5356a4219a36e3fde1e9adf5a7c686131 Mon Sep 17 00:00:00 2001 From: rbSparky Date: Thu, 22 Feb 2024 18:16:29 +0530 Subject: [PATCH 03/12] test doesnt work yet --- test/layers/heteroconv.jl | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/test/layers/heteroconv.jl b/test/layers/heteroconv.jl index d94a4c363..f4ae23968 100644 --- a/test/layers/heteroconv.jl +++ b/test/layers/heteroconv.jl @@ -108,4 +108,12 @@ y = layers(hg, x); @test size(y.A) == (2,2) && size(y.B) == (2,3) end + + @testset "SAGEConv" begin + x = (A = rand(Float32, 4, 2), B = rand(Float32, 4, 3)) + layers = HeteroGraphConv((:A, :to, :B) => EdgeConv(Dense(2 * 4, 2), aggr = +), + (:B, :to, :A) => EdgeConv(Dense(2 * 4, 2), aggr = +)); + y = layers(hg, x); + @test size(y.A) == (2, 2) && size(y.B) == (2, 3) + end end From e6f826d4e5bbec3a56839b800cfbae94edc97ca6 Mon Sep 17 00:00:00 2001 From: Rishabh <59335537+rbSparky@users.noreply.github.com> Date: Thu, 22 Feb 2024 18:19:59 +0530 Subject: [PATCH 04/12] tests --- test/layers/heteroconv.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/layers/heteroconv.jl b/test/layers/heteroconv.jl index f4ae23968..28c7f930d 100644 --- a/test/layers/heteroconv.jl +++ b/test/layers/heteroconv.jl @@ -111,8 +111,8 @@ @testset "SAGEConv" begin x = (A = rand(Float32, 4, 2), B = rand(Float32, 4, 3)) - layers = HeteroGraphConv((:A, :to, :B) => EdgeConv(Dense(2 * 4, 2), aggr = +), - (:B, :to, :A) => EdgeConv(Dense(2 * 4, 2), aggr = +)); + layers = HeteroGraphConv((:A, :to, :B) => SAGEConv(Dense(2 * 4, 2), relu, aggr = +), + (:B, :to, :A) => SAGEConv(Dense(2 * 4, 2), relu, aggr = +)); y = layers(hg, x); @test size(y.A) == (2, 2) && size(y.B) == (2, 3) end From 4f4565c98703c6de100039fbb1ed8eafd9bde46a Mon Sep 17 00:00:00 2001 From: Rishabh <59335537+rbSparky@users.noreply.github.com> Date: Thu, 22 Feb 2024 18:27:22 +0530 Subject: [PATCH 05/12] tests should work --- src/layers/conv.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/layers/conv.jl b/src/layers/conv.jl index 8517ebca9..1afc1d1fb 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -803,9 +803,9 @@ end function (l::SAGEConv)(g::AbstractGNNGraph, x) check_num_nodes(g, x) - xj, _ = expand_srcdst(g, x) + xj, xi = expand_srcdst(g, x) m = propagate(copy_xj, g, l.aggr, xj = xj) - x = l.σ.(l.weight * vcat(x, m) .+ l.bias) + x = l.σ.(l.weight * vcat(xi, m) .+ l.bias) return x end From 61cc6646a0f309234c84c4a90c5e0643a9340105 Mon Sep 17 00:00:00 2001 From: Rishabh <59335537+rbSparky@users.noreply.github.com> Date: Thu, 22 Feb 2024 18:30:20 +0530 Subject: [PATCH 06/12] test update --- test/layers/heteroconv.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/layers/heteroconv.jl b/test/layers/heteroconv.jl index 28c7f930d..e4d0fd40a 100644 --- a/test/layers/heteroconv.jl +++ b/test/layers/heteroconv.jl @@ -111,8 +111,8 @@ @testset "SAGEConv" begin x = (A = rand(Float32, 4, 2), B = rand(Float32, 4, 3)) - layers = HeteroGraphConv((:A, :to, :B) => SAGEConv(Dense(2 * 4, 2), relu, aggr = +), - (:B, :to, :A) => SAGEConv(Dense(2 * 4, 2), relu, aggr = +)); + layers = HeteroGraphConv((:A, :to, :B) => SAGEConv(4 => 2, relu, bias = false, aggr = +), + (:B, :to, :A) => SAGEConv(4 => 2, relu, bias = false, aggr = +)); y = layers(hg, x); @test size(y.A) == (2, 2) && size(y.B) == (2, 3) end From d8dde07b415c57d34ddc716de6f3d6743bee670f Mon Sep 17 00:00:00 2001 From: Rishabh <59335537+rbSparky@users.noreply.github.com> Date: Thu, 22 Feb 2024 18:32:08 +0530 Subject: [PATCH 07/12] temporary testing fast --- test/runtests.jl | 20 -------------------- 1 file changed, 20 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index 271373ecc..e2d5d9cd9 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -25,27 +25,7 @@ ENV["DATADEPS_ALWAYS_ACCEPT"] = true # for MLDatasets include("test_utils.jl") tests = [ - "GNNGraphs/chainrules", - "GNNGraphs/datastore", - "GNNGraphs/gnngraph", - "GNNGraphs/convert", - "GNNGraphs/transform", - "GNNGraphs/operators", - "GNNGraphs/generate", - "GNNGraphs/query", - "GNNGraphs/sampling", - "GNNGraphs/gnnheterograph", - "GNNGraphs/temporalsnapshotsgnngraph", - "utils", - "msgpass", - "layers/basic", - "layers/conv", "layers/heteroconv", - "layers/temporalconv", - "layers/pool", - "mldatasets", - "examples/node_classification_cora", - "deprecations", ] !CUDA.functional() && @warn("CUDA unavailable, not testing GPU support") From fe683edf13b5227301fab739aeacabe9ffe086fb Mon Sep 17 00:00:00 2001 From: Rishabh <59335537+rbSparky@users.noreply.github.com> Date: Thu, 22 Feb 2024 18:38:46 +0530 Subject: [PATCH 08/12] final tests --- test/runtests.jl | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/test/runtests.jl b/test/runtests.jl index e2d5d9cd9..271373ecc 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -25,7 +25,27 @@ ENV["DATADEPS_ALWAYS_ACCEPT"] = true # for MLDatasets include("test_utils.jl") tests = [ + "GNNGraphs/chainrules", + "GNNGraphs/datastore", + "GNNGraphs/gnngraph", + "GNNGraphs/convert", + "GNNGraphs/transform", + "GNNGraphs/operators", + "GNNGraphs/generate", + "GNNGraphs/query", + "GNNGraphs/sampling", + "GNNGraphs/gnnheterograph", + "GNNGraphs/temporalsnapshotsgnngraph", + "utils", + "msgpass", + "layers/basic", + "layers/conv", "layers/heteroconv", + "layers/temporalconv", + "layers/pool", + "mldatasets", + "examples/node_classification_cora", + "deprecations", ] !CUDA.functional() && @warn("CUDA unavailable, not testing GPU support") From 2ca6246e94e8f007855b33a0f5e3a1f916f12507 Mon Sep 17 00:00:00 2001 From: rbSparky Date: Fri, 23 Feb 2024 01:14:31 +0530 Subject: [PATCH 09/12] EGNN Hetero --- .gitignore | 1 + src/layers/conv.jl | 10 +++++++--- src/layers/heteroconv.jl | 21 +++++++++++++++++++++ test/layers/heteroconv.jl | 29 +++++++++++++++++++++++------ 4 files changed, 52 insertions(+), 9 deletions(-) diff --git a/.gitignore b/.gitignore index 3d1804049..13cacaa12 100644 --- a/.gitignore +++ b/.gitignore @@ -10,3 +10,4 @@ Manifest.toml LocalPreferences.toml .DS_Store /test.jl +try.jl diff --git a/src/layers/conv.jl b/src/layers/conv.jl index 1afc1d1fb..0a7ba0b91 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -1451,13 +1451,17 @@ function EGNNConv(ch::Pair{NTuple{2, Int}, Int}; hidden_size::Int = 2 * ch[1][1] return EGNNConv(ϕe, ϕx, ϕh, num_features, residual) end -function (l::EGNNConv)(g::GNNGraph, h::AbstractMatrix, x::AbstractMatrix, e = nothing) +function (l::EGNNConv)(g::AbstractGNNGraph, h, x, e = nothing) if l.num_features.edge > 0 @assert e!==nothing "Edge features must be provided." end + @assert size(h, 1)==l.num_features.in "Input features must match layer input size." - - x_diff = apply_edges(xi_sub_xj, g, x, x) + print("\n\n\n\nPANG\n\n\n\n") + xj, xi = expand_srcdst(g, x) + #hj, hi = expand_srcdst(g, h) not needed since its invariant node features + + x_diff = apply_edges(xi_sub_xj, g, xi, xj) sqnorm_xdiff = sum(x_diff .^ 2, dims = 1) x_diff = x_diff ./ (sqrt.(sqnorm_xdiff) .+ 1.0f-6) diff --git a/src/layers/heteroconv.jl b/src/layers/heteroconv.jl index ec75c8922..fe8205199 100644 --- a/src/layers/heteroconv.jl +++ b/src/layers/heteroconv.jl @@ -65,6 +65,27 @@ function (hgc::HeteroGraphConv)(g::GNNHeteroGraph, x::Union{NamedTuple,Dict}) return _reduceby_node_t(hgc.aggr, outs, dst_ntypes) end + +function (hgc::HeteroGraphConv)(g::GNNHeteroGraph, x::NamedTuple, h::AbstractMatrix) + function forw(l, et) + sg = edge_type_subgraph(g, et) + node1_t, _, node2_t = et + + print(x,"\n\n", h,"before\n\n\n") + + x_features = (x[node1_t], x[node2_t]) + h_features = h # temporary + + return l(sg, h_features, x_features) + + end + outs = [forw(l, et) for (l, et) in zip(hgc.layers, hgc.etypes)] + dst_ntypes = [et[3] for et in hgc.etypes] + return _reduceby_node_t(hgc.aggr, outs, dst_ntypes) +end + + + function _reduceby_node_t(aggr, outs, ntypes) function _reduce(node_t) idxs = findall(x -> x == node_t, ntypes) diff --git a/test/layers/heteroconv.jl b/test/layers/heteroconv.jl index e4d0fd40a..9733bea99 100644 --- a/test/layers/heteroconv.jl +++ b/test/layers/heteroconv.jl @@ -109,11 +109,28 @@ @test size(y.A) == (2,2) && size(y.B) == (2,3) end - @testset "SAGEConv" begin - x = (A = rand(Float32, 4, 2), B = rand(Float32, 4, 3)) - layers = HeteroGraphConv((:A, :to, :B) => SAGEConv(4 => 2, relu, bias = false, aggr = +), - (:B, :to, :A) => SAGEConv(4 => 2, relu, bias = false, aggr = +)); - y = layers(hg, x); - @test size(y.A) == (2, 2) && size(y.B) == (2, 3) + @testset "EGNNConv with Heterogeneous Graphs" begin + # Tests are work in progress + hin_A, hout_A, hidden_A = 5, 5, 10 + hin_B, hout_B, hidden_B = 3, 3, 6 + num_nodes_A, num_nodes_B = 5, 3 + + hg = rand_bipartite_heterograph((num_nodes_A, num_nodes_B), 15) + + layers = HeteroGraphConv([ + (:A, :to, :B) => EGNNConv((hin_A, 0) => hout_B; hidden_size = hidden_A, residual = false), + (:B, :to, :A) => EGNNConv((hin_B, 0) => hout_A; hidden_size = hidden_B, residual = false) + ]) + + T = Float32 + h = (A = randn(T, hin_A, num_nodes_A), B = randn(T, hin_B, num_nodes_B)) + x = (A = rand(T, 3, num_nodes_A), B = rand(T, 3, num_nodes_B)) + + y = layers(hg, x, h) + + @test size(y[:A].h) == (hout_A, num_nodes_A) + @test size(y[:B].h) == (hout_B, num_nodes_B) + @test size(y[:A].x) == (3, num_nodes_A) + @test size(y[:B].x) == (3, num_nodes_B) end end From 4cbcfb7c9cf05e261f43667bb00f0d57515984aa Mon Sep 17 00:00:00 2001 From: Rishabh <59335537+rbSparky@users.noreply.github.com> Date: Fri, 23 Feb 2024 01:17:59 +0530 Subject: [PATCH 10/12] Update conv.jl --- src/layers/conv.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/layers/conv.jl b/src/layers/conv.jl index 0a7ba0b91..eb1558bb9 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -1457,7 +1457,6 @@ function (l::EGNNConv)(g::AbstractGNNGraph, h, x, e = nothing) end @assert size(h, 1)==l.num_features.in "Input features must match layer input size." - print("\n\n\n\nPANG\n\n\n\n") xj, xi = expand_srcdst(g, x) #hj, hi = expand_srcdst(g, h) not needed since its invariant node features From bc75200608db1970dddd14fdaea5f0d504bd814f Mon Sep 17 00:00:00 2001 From: Rishabh <59335537+rbSparky@users.noreply.github.com> Date: Fri, 23 Feb 2024 01:22:48 +0530 Subject: [PATCH 11/12] remove old changes --- src/layers/conv.jl | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/layers/conv.jl b/src/layers/conv.jl index eb1558bb9..9e32ef060 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -801,11 +801,10 @@ function SAGEConv(ch::Pair{Int, Int}, σ = identity; aggr = mean, SAGEConv(W, b, σ, aggr) end -function (l::SAGEConv)(g::AbstractGNNGraph, x) +function (l::SAGEConv)(g::GNNGraph, x::AbstractMatrix) check_num_nodes(g, x) - xj, xi = expand_srcdst(g, x) - m = propagate(copy_xj, g, l.aggr, xj = xj) - x = l.σ.(l.weight * vcat(xi, m) .+ l.bias) + m = propagate(copy_xj, g, l.aggr, xj = x) + x = l.σ.(l.weight * vcat(x, m) .+ l.bias) return x end From 7ad941036df600e543b4389ce018fb3b5d6c25ee Mon Sep 17 00:00:00 2001 From: rbSparky Date: Thu, 7 Mar 2024 01:44:39 +0530 Subject: [PATCH 12/12] tests still dont work, wip --- src/layers/conv.jl | 4 ++-- test/layers/heteroconv.jl | 31 ++++++++++--------------------- 2 files changed, 12 insertions(+), 23 deletions(-) diff --git a/src/layers/conv.jl b/src/layers/conv.jl index 3ed439b4e..76fb3ec0f 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -1663,14 +1663,14 @@ function (l::EGNNConv)(g::AbstractGNNGraph, h, x, e = nothing) @assert size(h, 1)==l.num_features.in "Input features must match layer input size." xj, xi = expand_srcdst(g, x) - #hj, hi = expand_srcdst(g, h) not needed since its invariant node features + hj, hi = expand_srcdst(g, h) #not needed since its invariant node features x_diff = apply_edges(xi_sub_xj, g, xi, xj) sqnorm_xdiff = sum(x_diff .^ 2, dims = 1) x_diff = x_diff ./ (sqrt.(sqnorm_xdiff) .+ 1.0f-6) msg = apply_edges(message, g, l, - xi = (; h), xj = (; h), e = (; e, x_diff, sqnorm_xdiff)) + xi = (; hi), xj = (; hj), e = (; e, x_diff, sqnorm_xdiff)) h_aggr = aggregate_neighbors(g, +, msg.h) x_aggr = aggregate_neighbors(g, mean, msg.x) diff --git a/test/layers/heteroconv.jl b/test/layers/heteroconv.jl index 1b0bdd16e..c987eb58d 100644 --- a/test/layers/heteroconv.jl +++ b/test/layers/heteroconv.jl @@ -126,28 +126,17 @@ end @testset "EGNNConv with Heterogeneous Graphs" begin - # Tests are work in progress - hin_A, hout_A, hidden_A = 5, 5, 10 - hin_B, hout_B, hidden_B = 3, 3, 6 - num_nodes_A, num_nodes_B = 5, 3 - - hg = rand_bipartite_heterograph((num_nodes_A, num_nodes_B), 15) - - layers = HeteroGraphConv([ - (:A, :to, :B) => EGNNConv((hin_A, 0) => hout_B; hidden_size = hidden_A, residual = false), - (:B, :to, :A) => EGNNConv((hin_B, 0) => hout_A; hidden_size = hidden_B, residual = false) - ]) - - T = Float32 - h = (A = randn(T, hin_A, num_nodes_A), B = randn(T, hin_B, num_nodes_B)) - x = (A = rand(T, 3, num_nodes_A), B = rand(T, 3, num_nodes_B)) - + hin = 5 + hout = 5 + hidden = 5 + hg = rand_bipartite_heterograph((2,3), 6) + hg.num_nodes + x = (A = rand(Float32, 4, 2), B = rand(Float32, 4, 3)) + h = (A = rand(Float32, 5, 2), B = rand(Float32, 5, 3)) + layers = HeteroGraphConv((:A, :to, :B) => EGNNConv(4 => 2), + (:B, :to, :A) => EGNNConv(4 => 2)); y = layers(hg, x, h) - - @test size(y[:A].h) == (hout_A, num_nodes_A) - @test size(y[:B].h) == (hout_B, num_nodes_B) - @test size(y[:A].x) == (3, num_nodes_A) - @test size(y[:B].x) == (3, num_nodes_B) + @test size(y.A) == (2, 2) && size(y.B) == (2, 3) end @testset "GINConv" begin