From 5a8172cb7ecb52e8977ad1e6fa9b564743dae9e3 Mon Sep 17 00:00:00 2001 From: Carlo Lucibello Date: Wed, 25 Dec 2024 23:57:28 +0100 Subject: [PATCH] no constraints on numbers of features gdata for heterographs (#570) --- GNNGraphs/Project.toml | 2 +- GNNGraphs/src/gnnheterograph/gnnheterograph.jl | 3 ++- GNNGraphs/test/gnnheterograph.jl | 2 +- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/GNNGraphs/Project.toml b/GNNGraphs/Project.toml index d56d27399..0db318a93 100644 --- a/GNNGraphs/Project.toml +++ b/GNNGraphs/Project.toml @@ -1,7 +1,7 @@ name = "GNNGraphs" uuid = "aed8fd31-079b-4b5a-b342-a13352159b8c" authors = ["Carlo Lucibello and contributors"] -version = "1.4.0" +version = "1.4.1" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/GNNGraphs/src/gnnheterograph/gnnheterograph.jl b/GNNGraphs/src/gnnheterograph/gnnheterograph.jl index 24d57d453..b449b42dd 100644 --- a/GNNGraphs/src/gnnheterograph/gnnheterograph.jl +++ b/GNNGraphs/src/gnnheterograph/gnnheterograph.jl @@ -144,7 +144,8 @@ function GNNHeteroGraph(data::EDict; ndata = normalize_heterographdata(ndata, default_name = :x, ns = num_nodes) edata = normalize_heterographdata(edata, default_name = :e, ns = num_edges, duplicate_if_needed = true) - gdata = normalize_graphdata(gdata, default_name = :u, n = num_graphs) + gdata = normalize_graphdata(gdata, default_name = :u, + n = num_graphs > 1 ? num_graphs : -1) end return GNNHeteroGraph(graph, diff --git a/GNNGraphs/test/gnnheterograph.jl b/GNNGraphs/test/gnnheterograph.jl index f3c29b80f..f92bbad0f 100644 --- a/GNNGraphs/test/gnnheterograph.jl +++ b/GNNGraphs/test/gnnheterograph.jl @@ -31,9 +31,9 @@ end @test hg.ndata isa Dict{Symbol, DataStore} @test hg.edata isa Dict{Tuple{Symbol, Symbol, Symbol}, DataStore} @test isempty(hg.gdata) + @test hg.gdata._n == -1 # no constraints on gdata @test sort(hg.ntypes) == [:A, :B] @test sort(hg.etypes) == [(:A, :rel1, :B), (:B, :rel2, :A)] - end @testset "features" begin