From fefebc240225c7a3961d3d08610a23e3cd2a4767 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 5 Dec 2024 12:38:56 +0100 Subject: [PATCH 01/19] replaced a closure with `Fix1` --- src/varinfo.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/varinfo.jl b/src/varinfo.jl index 4cf1f1b02..98f020101 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -626,7 +626,7 @@ setrange!(md::Metadata, vn::VarName, range) = md.ranges[getidx(md, vn)] = range Return the indices of `vns` in the metadata of `vi` corresponding to `vn`. """ function getranges(vi::VarInfo, vns::Vector{<:VarName}) - return mapreduce(vn -> getrange(vi, vn), vcat, vns; init=Int[]) + return mapreduce(Base.Fix1(getrange, vi), vcat, vns; init=Int[]) end """ From 76a91833f47dade5cf89f5fc10e15b964418c69e Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 5 Dec 2024 12:54:26 +0100 Subject: [PATCH 02/19] added correct implementation of `getrange` for `TypedVarInfo` --- src/varinfo.jl | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/src/varinfo.jl b/src/varinfo.jl index 98f020101..99ea9c54a 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -610,6 +610,21 @@ getidx(md::Metadata, vn::VarName) = md.idcs[vn] Return the index range of `vn` in the metadata of `vi`. """ getrange(vi::VarInfo, vn::VarName) = getrange(getmetadata(vi, vn), vn) +# For `TypedVarInfo` it's more difficult since we need to keep track of the offset. +# TOOD: Should we unroll this using `@generated`? +function getrange(vi::TypedVarInfo, vn::VarName) + offset = 0 + for md in values(vi.metadata) + # First, we need to check if `vn` is in `md`. + # In this case, we can just return the corresponding range + offset. + haskey(md, vn) && return getrange(md, vn) .+ offset + # Otherwise, we need to get the cumulative length of the ranges in `md` + # and add it to the offset. + offset += sum(length, md.ranges) + end + # If we reach this point, `vn` is not in `vi.metadata`. + throw(KeyError(vn)) +end getrange(md::Metadata, vn::VarName) = md.ranges[getidx(md, vn)] """ From 92b96d78799e38bba80b8b62552a63a21084b1e2 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 5 Dec 2024 12:54:54 +0100 Subject: [PATCH 03/19] fixed calls to varinfo methods which should be metadata methods --- src/varinfo.jl | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/varinfo.jl b/src/varinfo.jl index 99ea9c54a..04b8c4601 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -1329,12 +1329,13 @@ end function _inner_transform!(md::Metadata, vi::VarInfo, vn::VarName, f) # TODO: Use inplace versions to avoid allocations - yvec, logjac = with_logabsdet_jacobian(f, getindex_internal(vi, vn)) + yvec, logjac = with_logabsdet_jacobian(f, getindex_internal(vi, md)) # Determine the new range. - start = first(getrange(vi, vn)) + start = first(getrange(md, vn)) # NOTE: `length(yvec)` should never be longer than `getrange(vi, vn)`. - setrange!(vi, vn, start:(start + length(yvec) - 1)) + setrange!(md, vn, start:(start + length(yvec) - 1)) # Set the new value. + # TODO: should replace this with a `setval!` for `Metadata`. setval!(vi, yvec, vn) acclogp!!(vi, -logjac) return vi From 73d120f1ad1343ce332f6a4ddc25b7658c68911e Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 5 Dec 2024 13:06:56 +0100 Subject: [PATCH 04/19] fixed typo --- src/varinfo.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/varinfo.jl b/src/varinfo.jl index 04b8c4601..052bb0968 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -1329,7 +1329,7 @@ end function _inner_transform!(md::Metadata, vi::VarInfo, vn::VarName, f) # TODO: Use inplace versions to avoid allocations - yvec, logjac = with_logabsdet_jacobian(f, getindex_internal(vi, md)) + yvec, logjac = with_logabsdet_jacobian(f, getindex_internal(md, vn)) # Determine the new range. start = first(getrange(md, vn)) # NOTE: `length(yvec)` should never be longer than `getrange(vi, vn)`. From b1b8a00f3a2f3d5d5d95e5bf045280630e160a50 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 5 Dec 2024 13:08:31 +0100 Subject: [PATCH 05/19] use `setval!` on the metadata directly instead of on the varinfo --- src/varinfo.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/varinfo.jl b/src/varinfo.jl index 052bb0968..0b5164b0e 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -1336,7 +1336,7 @@ function _inner_transform!(md::Metadata, vi::VarInfo, vn::VarName, f) setrange!(md, vn, start:(start + length(yvec) - 1)) # Set the new value. # TODO: should replace this with a `setval!` for `Metadata`. - setval!(vi, yvec, vn) + setval!(md, yvec, vn) acclogp!!(vi, -logjac) return vi end From 23d561f0099966ce6dc3a8d4dfec644b68516b29 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 5 Dec 2024 13:50:11 +0100 Subject: [PATCH 06/19] added `length` implementation for `VarInfo` and `Metadata` --- src/varinfo.jl | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/src/varinfo.jl b/src/varinfo.jl index 0b5164b0e..e08a945bd 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -202,6 +202,11 @@ function VarInfo( end VarInfo(model::Model, args...) = VarInfo(Random.default_rng(), model, args...) + +Base.length(varinfo::VarInfo) = length(varinfo.metadata) +Base.length(varinfo::TypedVarInfo) = sum(length, varinfo.metadata) +Base.length(md::Metadata) = sum(length, md.ranges) + unflatten(vi::VarInfo, x::AbstractVector) = unflatten(vi, SampleFromPrior(), x) # TODO: deprecate. @@ -643,6 +648,29 @@ Return the indices of `vns` in the metadata of `vi` corresponding to `vn`. function getranges(vi::VarInfo, vns::Vector{<:VarName}) return mapreduce(Base.Fix1(getrange, vi), vcat, vns; init=Int[]) end +# A more efficient version for `TypedVarInfo`. +function getranges(varinfo::DynamicPPL.TypedVarInfo, vns::Vector{<:DynamicPPL.VarName}) + # TODO: Does it help if we _don't_ convert to a vector here? + metadatas = collect(values(varinfo.metadata)) + # Extract the offsets. + offsets = cumsum(map(length, metadatas)) + # Extract the ranges from each metadata. + ranges = Vector{UnitRange{Int}}(undef, length(vns)) + for (i, metadata) in enumerate(metadatas) + vns_metadata = filter(Base.Fix1(haskey, metadata), vns) + # If none of the variables exist in the metadata, we return an empty array. + isempty(vns_metadata) && continue + # Otherwise, we extract the ranges. + offset = i == 1 ? 0 : offsets[i - 1] + for vn in vns_metadata + r_vn = getrange(metadata, vn) + # Get the index, so we return in the same order as `vns`. + idx = findfirst(==(vn), vns) + ranges[idx] = r_vn .+ offset + end + end + return ranges +end """ getdist(vi::VarInfo, vn::VarName) From 9cf72a0d72cb52f257be0e08d9ca5ccdb728aaca Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 5 Dec 2024 13:50:29 +0100 Subject: [PATCH 07/19] added testing for `getranges --- test/varinfo.jl | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/test/varinfo.jl b/test/varinfo.jl index c45fb47e0..8f033bd44 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -813,4 +813,21 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,) @test DynamicPPL.istrans(varinfo2, vn) end end + + @testset "getranges" begin + @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS + vns = DynamicPPL.TestUtils.varnames(model) + varinfo = DynamicPPL.typed_varinfo(model) + x = values_as(varinfo, Vector) + + # Let's just check all the subsets of `vns`. + @testset "$(convert(Vector{Any},vns_subset))" for vns_subset in combinations(vns) + ranges = DynamicPPL.getranges(varinfo, vns_subset) + @test length(ranges) == length(vns_subset) + for (r, vn) in zip(ranges, vns_subset) + @test x[r] == DynamicPPL.tovec(varinfo[vn]) + end + end + end + end end From 23bfe5407b8cdcb186e2776231fc1d1eb10f5d37 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 5 Dec 2024 14:06:26 +0100 Subject: [PATCH 08/19] introduce `vector_length` instead of `length`, since `length` already refers to the dictionary-like length impl, not vector-like --- src/varinfo.jl | 8 ++++---- src/varnamedvector.jl | 2 ++ 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/src/varinfo.jl b/src/varinfo.jl index e08a945bd..1d32dee01 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -203,9 +203,9 @@ end VarInfo(model::Model, args...) = VarInfo(Random.default_rng(), model, args...) -Base.length(varinfo::VarInfo) = length(varinfo.metadata) -Base.length(varinfo::TypedVarInfo) = sum(length, varinfo.metadata) -Base.length(md::Metadata) = sum(length, md.ranges) +vector_length(varinfo::VarInfo) = length(varinfo.metadata) +vector_length(varinfo::TypedVarInfo) = sum(length, varinfo.metadata) +vector_length(md::Metadata) = sum(length, md.ranges) unflatten(vi::VarInfo, x::AbstractVector) = unflatten(vi, SampleFromPrior(), x) @@ -653,7 +653,7 @@ function getranges(varinfo::DynamicPPL.TypedVarInfo, vns::Vector{<:DynamicPPL.Va # TODO: Does it help if we _don't_ convert to a vector here? metadatas = collect(values(varinfo.metadata)) # Extract the offsets. - offsets = cumsum(map(length, metadatas)) + offsets = cumsum(map(vector_length, metadatas)) # Extract the ranges from each metadata. ranges = Vector{UnitRange{Int}}(undef, length(vns)) for (i, metadata) in enumerate(metadatas) diff --git a/src/varnamedvector.jl b/src/varnamedvector.jl index a5097602d..039b549d6 100644 --- a/src/varnamedvector.jl +++ b/src/varnamedvector.jl @@ -1036,6 +1036,8 @@ function replace_raw_storage(vnv::VarNamedVector, ::Val{space}, vals) where {spa return replace_raw_storage(vnv, vals) end +vector_length(vnv::VarNamedVector) = length(vnv.vals) - num_inactive(vnv) + """ unflatten(vnv::VarNamedVector, vals::AbstractVector) From bdcc69f7ef9a186e3f284116d511248482256a54 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 5 Dec 2024 14:07:00 +0100 Subject: [PATCH 09/19] fixed bug in `getranges` for untyped varinfo --- src/varinfo.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/varinfo.jl b/src/varinfo.jl index 1d32dee01..3a1d7f570 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -646,7 +646,7 @@ setrange!(md::Metadata, vn::VarName, range) = md.ranges[getidx(md, vn)] = range Return the indices of `vns` in the metadata of `vi` corresponding to `vn`. """ function getranges(vi::VarInfo, vns::Vector{<:VarName}) - return mapreduce(Base.Fix1(getrange, vi), vcat, vns; init=Int[]) + return map(Base.Fix1(getrange, vi), vns) end # A more efficient version for `TypedVarInfo`. function getranges(varinfo::DynamicPPL.TypedVarInfo, vns::Vector{<:DynamicPPL.VarName}) From 90aef0ba2c2da9f1d1ba0dae37d4c4c9ad6995c9 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 5 Dec 2024 14:07:11 +0100 Subject: [PATCH 10/19] added proper testing for other `VarInfo` types --- test/varinfo.jl | 28 ++++++++++++++++++---------- 1 file changed, 18 insertions(+), 10 deletions(-) diff --git a/test/varinfo.jl b/test/varinfo.jl index 8f033bd44..36c07b82d 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -814,18 +814,26 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,) end end - @testset "getranges" begin + # NOTE: It is not yet clear if this is something we want from all varinfo types. + # Hence, we only test the `VarInfo` types here. + @testset "getranges for `VarInfo`" begin @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS vns = DynamicPPL.TestUtils.varnames(model) - varinfo = DynamicPPL.typed_varinfo(model) - x = values_as(varinfo, Vector) - - # Let's just check all the subsets of `vns`. - @testset "$(convert(Vector{Any},vns_subset))" for vns_subset in combinations(vns) - ranges = DynamicPPL.getranges(varinfo, vns_subset) - @test length(ranges) == length(vns_subset) - for (r, vn) in zip(ranges, vns_subset) - @test x[r] == DynamicPPL.tovec(varinfo[vn]) + nt = DynamicPPL.TestUtils.rand_prior_true(model) + varinfos = DynamicPPL.TestUtils.setup_varinfos(model, nt, vns) + # Only keep `VarInfo` types. + varinfos = filter(Base.Fix2(isa, VarInfo), varinfos) + @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos + x = values_as(varinfo, Vector) + + # Let's just check all the subsets of `vns`. + @testset "$(convert(Vector{Any},vns_subset))" for vns_subset in + combinations(vns) + ranges = DynamicPPL.getranges(varinfo, vns_subset) + @test length(ranges) == length(vns_subset) + for (r, vn) in zip(ranges, vns_subset) + @test x[r] == DynamicPPL.tovec(varinfo[vn]) + end end end end From f500c237b580b6759aa12ec778f03128f0d618fd Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 5 Dec 2024 14:07:35 +0100 Subject: [PATCH 11/19] bump patch version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index fafeec2ba..df7357d3c 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.31.1" +version = "0.31.2" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" From 8afe6813a64b75dd420f57a71d5e74bf9daa33ca Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 5 Dec 2024 14:33:04 +0100 Subject: [PATCH 12/19] separated the `getrange` version which returns the range of the vecto representaiton rather than the internal representaiton into `vector_getrange` to make its function explicit --- src/threadsafe.jl | 4 ++++ src/varinfo.jl | 55 ++++++++++++++++++++++++++++++++--------------- test/varinfo.jl | 4 ++-- 3 files changed, 44 insertions(+), 19 deletions(-) diff --git a/src/threadsafe.jl b/src/threadsafe.jl index ec890a674..1bf2a6058 100644 --- a/src/threadsafe.jl +++ b/src/threadsafe.jl @@ -178,6 +178,10 @@ function BangBang.setindex!!(vi::ThreadSafeVarInfo, vals, vns::AbstractVector{<: return Accessors.@set vi.varinfo = BangBang.setindex!!(vi.varinfo, vals, vns) end +vector_length(vi::ThreadSafeVarInfo) = vector_length(vi.varinfo) +vector_getrange(vi::ThreadSafeVarInfo) = vector_getrange(vi.varinfo) +vector_getranges(vi::ThreadSafeVarInfo) = vector_getranges(vi.varinfo) + function set_retained_vns_del_by_spl!(vi::ThreadSafeVarInfo, spl::Sampler) return set_retained_vns_del_by_spl!(vi.varinfo, spl) end diff --git a/src/varinfo.jl b/src/varinfo.jl index 3a1d7f570..78acc9f93 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -203,6 +203,11 @@ end VarInfo(model::Model, args...) = VarInfo(Random.default_rng(), model, args...) +""" + vector_length(varinfo::VarInfo) + +Return the length of the vector representation of `varinfo`. +""" vector_length(varinfo::VarInfo) = length(varinfo.metadata) vector_length(varinfo::TypedVarInfo) = sum(length, varinfo.metadata) vector_length(md::Metadata) = sum(length, md.ranges) @@ -615,21 +620,6 @@ getidx(md::Metadata, vn::VarName) = md.idcs[vn] Return the index range of `vn` in the metadata of `vi`. """ getrange(vi::VarInfo, vn::VarName) = getrange(getmetadata(vi, vn), vn) -# For `TypedVarInfo` it's more difficult since we need to keep track of the offset. -# TOOD: Should we unroll this using `@generated`? -function getrange(vi::TypedVarInfo, vn::VarName) - offset = 0 - for md in values(vi.metadata) - # First, we need to check if `vn` is in `md`. - # In this case, we can just return the corresponding range + offset. - haskey(md, vn) && return getrange(md, vn) .+ offset - # Otherwise, we need to get the cumulative length of the ranges in `md` - # and add it to the offset. - offset += sum(length, md.ranges) - end - # If we reach this point, `vn` is not in `vi.metadata`. - throw(KeyError(vn)) -end getrange(md::Metadata, vn::VarName) = md.ranges[getidx(md, vn)] """ @@ -648,8 +638,38 @@ Return the indices of `vns` in the metadata of `vi` corresponding to `vn`. function getranges(vi::VarInfo, vns::Vector{<:VarName}) return map(Base.Fix1(getrange, vi), vns) end -# A more efficient version for `TypedVarInfo`. -function getranges(varinfo::DynamicPPL.TypedVarInfo, vns::Vector{<:DynamicPPL.VarName}) + +""" + vector_getrange(varinfo::VarInfo, varname::VarName) + +Return the range corresponding to `varname` in the vector representation of `varinfo`. +""" +vector_getrange(vi::VarInfo, vn::VarName) = getrange(getmetadata(vi, vn), vn) +function vector_getrange(vi::TypedVarInfo, vn::VarName) + offset = 0 + for md in values(vi.metadata) + # First, we need to check if `vn` is in `md`. + # In this case, we can just return the corresponding range + offset. + haskey(md, vn) && return vector_getrange(md, vn) .+ offset + # Otherwise, we need to get the cumulative length of the ranges in `md` + # and add it to the offset. + offset += sum(length, md.ranges) + end + # If we reach this point, `vn` is not in `vi.metadata`. + throw(KeyError(vn)) +end +vector_getrange(md::Metadata, vn::VarName) = getrange(md, vn) + +""" + vector_getranges(varinfo::VarInfo, varnames::Vector{<:VarName}) + +Return the range corresponding to `varname` in the vector representation of `varinfo`. +""" +function vector_getranges(varinfo::VarInfo, varname::Vector{<:VarName}) + return map(Base.Fix1(vector_getrange, varinfo), varname) +end +# Specialized version for `TypedVarInfo`. +function vector_getranges(varinfo::TypedVarInfo, vns::Vector{<:VarName}) # TODO: Does it help if we _don't_ convert to a vector here? metadatas = collect(values(varinfo.metadata)) # Extract the offsets. @@ -672,6 +692,7 @@ function getranges(varinfo::DynamicPPL.TypedVarInfo, vns::Vector{<:DynamicPPL.Va return ranges end + """ getdist(vi::VarInfo, vn::VarName) diff --git a/test/varinfo.jl b/test/varinfo.jl index 36c07b82d..1f9f5fd21 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -816,7 +816,7 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,) # NOTE: It is not yet clear if this is something we want from all varinfo types. # Hence, we only test the `VarInfo` types here. - @testset "getranges for `VarInfo`" begin + @testset "vector_getranges for `VarInfo`" begin @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS vns = DynamicPPL.TestUtils.varnames(model) nt = DynamicPPL.TestUtils.rand_prior_true(model) @@ -829,7 +829,7 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,) # Let's just check all the subsets of `vns`. @testset "$(convert(Vector{Any},vns_subset))" for vns_subset in combinations(vns) - ranges = DynamicPPL.getranges(varinfo, vns_subset) + ranges = DynamicPPL.vector_getranges(varinfo, vns_subset) @test length(ranges) == length(vns_subset) for (r, vn) in zip(ranges, vns_subset) @test x[r] == DynamicPPL.tovec(varinfo[vn]) From 25b19a45b6178069d541b1e529c0aaafbf9b803c Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 5 Dec 2024 14:34:25 +0100 Subject: [PATCH 13/19] formatting --- src/varinfo.jl | 2 -- test/varinfo.jl | 2 +- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/src/varinfo.jl b/src/varinfo.jl index 78acc9f93..98ba09894 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -202,7 +202,6 @@ function VarInfo( end VarInfo(model::Model, args...) = VarInfo(Random.default_rng(), model, args...) - """ vector_length(varinfo::VarInfo) @@ -692,7 +691,6 @@ function vector_getranges(varinfo::TypedVarInfo, vns::Vector{<:VarName}) return ranges end - """ getdist(vi::VarInfo, vn::VarName) diff --git a/test/varinfo.jl b/test/varinfo.jl index 1f9f5fd21..7e67fe787 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -817,7 +817,7 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,) # NOTE: It is not yet clear if this is something we want from all varinfo types. # Hence, we only test the `VarInfo` types here. @testset "vector_getranges for `VarInfo`" begin - @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS + @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS vns = DynamicPPL.TestUtils.varnames(model) nt = DynamicPPL.TestUtils.rand_prior_true(model) varinfos = DynamicPPL.TestUtils.setup_varinfos(model, nt, vns) From 2734070aed39f0bb37c7c37b643a8ef02559ce3d Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 5 Dec 2024 14:35:48 +0100 Subject: [PATCH 14/19] removed `vector_getrange` for metadata --- src/varinfo.jl | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/varinfo.jl b/src/varinfo.jl index 98ba09894..cebe15fda 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -643,13 +643,13 @@ end Return the range corresponding to `varname` in the vector representation of `varinfo`. """ -vector_getrange(vi::VarInfo, vn::VarName) = getrange(getmetadata(vi, vn), vn) +vector_getrange(vi::VarInfo, vn::VarName) = getrange(vi.metadata, vn) function vector_getrange(vi::TypedVarInfo, vn::VarName) offset = 0 for md in values(vi.metadata) # First, we need to check if `vn` is in `md`. # In this case, we can just return the corresponding range + offset. - haskey(md, vn) && return vector_getrange(md, vn) .+ offset + haskey(md, vn) && return getrange(md, vn) .+ offset # Otherwise, we need to get the cumulative length of the ranges in `md` # and add it to the offset. offset += sum(length, md.ranges) @@ -657,7 +657,6 @@ function vector_getrange(vi::TypedVarInfo, vn::VarName) # If we reach this point, `vn` is not in `vi.metadata`. throw(KeyError(vn)) end -vector_getrange(md::Metadata, vn::VarName) = getrange(md, vn) """ vector_getranges(varinfo::VarInfo, varnames::Vector{<:VarName}) From cd78d24a84541a06ae61892366a8880263d7cfa9 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 5 Dec 2024 18:50:30 +0000 Subject: [PATCH 15/19] added handling of missing indices + tests for these cases --- src/varinfo.jl | 7 +++++++ test/varinfo.jl | 18 ++++++++++++++++-- 2 files changed, 23 insertions(+), 2 deletions(-) diff --git a/src/varinfo.jl b/src/varinfo.jl index cebe15fda..98b1db669 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -687,6 +687,13 @@ function vector_getranges(varinfo::TypedVarInfo, vns::Vector{<:VarName}) ranges[idx] = r_vn .+ offset end end + # Raise key error if any of the variables were not found. + if any(!isassigned, ranges) + inds = findall(!isassigned, ranges) + # Just use a `convert` to get the same type as the input; don't want to confuse by overly + # specilizing the types in the error message. + throw(KeyError(convert(typeof(vns), vns[inds]))) + end return ranges end diff --git a/test/varinfo.jl b/test/varinfo.jl index 7e67fe787..fae3dea48 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -820,9 +820,13 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,) @testset "$(model.f)" for model in DynamicPPL.TestUtils.DEMO_MODELS vns = DynamicPPL.TestUtils.varnames(model) nt = DynamicPPL.TestUtils.rand_prior_true(model) - varinfos = DynamicPPL.TestUtils.setup_varinfos(model, nt, vns) + varinfos = DynamicPPL.TestUtils.setup_varinfos( + model, nt, vns; include_threadsafe=true + ) # Only keep `VarInfo` types. - varinfos = filter(Base.Fix2(isa, VarInfo), varinfos) + varinfos = filter( + Base.Fix2(isa, DynamicPPL.VarInfoOrThreadSafeVarInfo), varinfos + ) @testset "$(short_varinfo_name(varinfo))" for varinfo in varinfos x = values_as(varinfo, Vector) @@ -835,6 +839,16 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,) @test x[r] == DynamicPPL.tovec(varinfo[vn]) end end + + # Let's try some failure cases. + @test DynamicPPL.vector_getranges(varinfo, VarName[]) == UnitRange{Int}[] + # Non-existent variables. + @test_throws KeyError DynamicPPL.vector_getranges( + varinfo, [VarName{gensym("vn")}()] + ) + @test_throws KeyError DynamicPPL.vector_getranges( + varinfo, [VarName{gensym("vn")}(), VarName{gensym("vn")}()] + ) end end end From 65b2de40cff0a34ebc5cda7d1a3a37d4c1a5a950 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 5 Dec 2024 18:55:08 +0000 Subject: [PATCH 16/19] added handling of duplicated values --- src/varinfo.jl | 7 +++++-- test/varinfo.jl | 3 +++ 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/src/varinfo.jl b/src/varinfo.jl index 98b1db669..986f0c97a 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -683,8 +683,11 @@ function vector_getranges(varinfo::TypedVarInfo, vns::Vector{<:VarName}) for vn in vns_metadata r_vn = getrange(metadata, vn) # Get the index, so we return in the same order as `vns`. - idx = findfirst(==(vn), vns) - ranges[idx] = r_vn .+ offset + # NOTE: There might be duplicates in `vns`, so we need to handle that. + indices = findall(==(vn), vns) + for idx in indices + ranges[idx] = r_vn .+ offset + end end end # Raise key error if any of the variables were not found. diff --git a/test/varinfo.jl b/test/varinfo.jl index fae3dea48..ae4319904 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -849,6 +849,9 @@ DynamicPPL.getspace(::DynamicPPL.Sampler{MySAlg}) = (:s,) @test_throws KeyError DynamicPPL.vector_getranges( varinfo, [VarName{gensym("vn")}(), VarName{gensym("vn")}()] ) + # Duplicate variables. + ranges_duplicated = DynamicPPL.vector_getranges(varinfo, repeat(vns, 2)) + @test x[reduce(vcat, ranges_duplicated)] == repeat(x, 2) end end end From 5fc1b302b97607a4a1888e96d1a1a3973d615e8c Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 5 Dec 2024 18:55:39 +0000 Subject: [PATCH 17/19] removed no-longer relevant comment --- src/varinfo.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/varinfo.jl b/src/varinfo.jl index 986f0c97a..7eb3c8601 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -1391,7 +1391,6 @@ function _inner_transform!(md::Metadata, vi::VarInfo, vn::VarName, f) # NOTE: `length(yvec)` should never be longer than `getrange(vi, vn)`. setrange!(md, vn, start:(start + length(yvec) - 1)) # Set the new value. - # TODO: should replace this with a `setval!` for `Metadata`. setval!(md, yvec, vn) acclogp!!(vi, -logjac) return vi From 5e89f95a3ae6bf201f32e5684042ed4f6afd3e79 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 6 Dec 2024 07:03:42 +0000 Subject: [PATCH 18/19] fixed impl of `vector_getrange` and `vector_getranges` for threadsafe varinfo --- src/threadsafe.jl | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/threadsafe.jl b/src/threadsafe.jl index 1bf2a6058..f7ce569b3 100644 --- a/src/threadsafe.jl +++ b/src/threadsafe.jl @@ -179,8 +179,10 @@ function BangBang.setindex!!(vi::ThreadSafeVarInfo, vals, vns::AbstractVector{<: end vector_length(vi::ThreadSafeVarInfo) = vector_length(vi.varinfo) -vector_getrange(vi::ThreadSafeVarInfo) = vector_getrange(vi.varinfo) -vector_getranges(vi::ThreadSafeVarInfo) = vector_getranges(vi.varinfo) +vector_getrange(vi::ThreadSafeVarInfo, vn::VarName) = vector_getrange(vi.varinfo, vn) +function vector_getranges(vi::ThreadSafeVarInfo, vns::Vector{<:VarName}) + return vector_getranges(vi.varinfo, vns) +end function set_retained_vns_del_by_spl!(vi::ThreadSafeVarInfo, spl::Sampler) return set_retained_vns_del_by_spl!(vi.varinfo, spl) From b5e20d579aa31dfdd8c474ae65f6bdd6aeeb279e Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 6 Dec 2024 07:04:42 +0000 Subject: [PATCH 19/19] fixed `vector_getranges` when `vns` are not found --- src/varinfo.jl | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/varinfo.jl b/src/varinfo.jl index 7eb3c8601..2c07d4298 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -674,6 +674,8 @@ function vector_getranges(varinfo::TypedVarInfo, vns::Vector{<:VarName}) offsets = cumsum(map(vector_length, metadatas)) # Extract the ranges from each metadata. ranges = Vector{UnitRange{Int}}(undef, length(vns)) + # Need to keep track of which ones we've seen. + not_seen = fill(true, length(vns)) for (i, metadata) in enumerate(metadatas) vns_metadata = filter(Base.Fix1(haskey, metadata), vns) # If none of the variables exist in the metadata, we return an empty array. @@ -686,13 +688,14 @@ function vector_getranges(varinfo::TypedVarInfo, vns::Vector{<:VarName}) # NOTE: There might be duplicates in `vns`, so we need to handle that. indices = findall(==(vn), vns) for idx in indices + not_seen[idx] = false ranges[idx] = r_vn .+ offset end end end # Raise key error if any of the variables were not found. - if any(!isassigned, ranges) - inds = findall(!isassigned, ranges) + if any(not_seen) + inds = findall(not_seen) # Just use a `convert` to get the same type as the input; don't want to confuse by overly # specilizing the types in the error message. throw(KeyError(convert(typeof(vns), vns[inds])))