From a0e6f325c78dbeede9f098e6ef561e538595eecb Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 8 Jan 2025 13:57:33 +0000 Subject: [PATCH] Remove `resume_from` argument --- Project.toml | 2 +- docs/src/api.md | 7 ++++++- src/DynamicPPL.jl | 1 + src/sampler.jl | 19 ++++++++++++++++--- 4 files changed, 24 insertions(+), 5 deletions(-) diff --git a/Project.toml b/Project.toml index 60dbcdc81..fb9a1c55f 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.33.0" +version = "0.34.0" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/docs/src/api.md b/docs/src/api.md index d5c6bd690..8321777e0 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -134,6 +134,12 @@ returned(::Model) ## Utilities +To retrieve the final sampler state from a chain of samples (useful for resuming sampling from a previous point): + +```@docs +loadstate +``` + It is possible to manually increase (or decrease) the accumulated log density from within a model function. ```@docs @@ -425,7 +431,6 @@ The default implementation of [`Sampler`](@ref) uses the following unexported fu ```@docs DynamicPPL.initialstep -DynamicPPL.loadstate DynamicPPL.initialsampler ``` diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index c1cdbd94e..166d27f69 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -92,6 +92,7 @@ export AbstractVarInfo, Sampler, SampleFromPrior, SampleFromUniform, + loadstate, # Contexts SamplingContext, DefaultContext, diff --git a/src/sampler.jl b/src/sampler.jl index 974828e8b..dfdc21295 100644 --- a/src/sampler.jl +++ b/src/sampler.jl @@ -100,10 +100,16 @@ function AbstractMCMC.sample( sampler::Sampler, N::Integer; chain_type=default_chain_type(sampler), - resume_from=nothing, - initial_state=loadstate(resume_from), + initial_state=nothing, kwargs..., ) + if haskey(kwargs, :resume_from) + throw( + ArgumentError( + "The `resume_from` keyword argument is no longer supported. Please use `initial_state=loadstate(chain)` instead of `resume_from=chain`.", + ), + ) + end return AbstractMCMC.mcmcsample( rng, model, sampler, N; chain_type, initial_state, kwargs... ) @@ -135,7 +141,14 @@ end Load sampler state from `data`. -By default, `data` is returned. +If `data` isa MCMCChains.Chains object, this attempts to fetch the last state +of the sampler from the metadata stored inside the Chains object. This requires +you to have passed the `save_state=true` keyword argument to the `sample()` +when generating the chain. + +This function can be overloaded for specific types of `data` if desired. If +there is no specific implementation for a given type, it falls back to just +returning `data`, i.e. acts as an identity function. """ loadstate(data) = data