Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Switching to broadcasting random walk #747

Merged
merged 14 commits into from
Aug 29, 2024
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
- The interface for defining delay distributions has been generalised to also cater for continuous distributions
- When defining probability distributions these can now be truncated using the `tolerance` argument
- Ornstein-Uhlenbeck and 5 / 2 Matérn kernels have been added. By @sbfnk in # and reviewed by @.
- Switch to broadcasting from random walks and added unit tests. By @seabbs in #747 and reviewed by @jamesmbaazam.
- Switch to broadcasting the day of the week effect. By @seabbs in #746 and reviewed by @jamesmbaazam.
- A warning is now thrown if nonparametric PMFs passed to delay options have consecutive tail values that are below a certain low threshold as these lead to loss in speed with little gain in accuracy. By @jamesmbaazam in #752 and reviewed by @seabbs.

Expand All @@ -27,6 +28,7 @@

- Updated the documentation of the dots argument of the `stan_sampling_opts()` to add that the dots are passed to `cmdstanr::sample()`. By @jamesmbaazam in #699 and reviewed by @sbfnk.
- `generation_time_opts()` has been shortened to `gt_opts()` to make it easier to specify. Calls to both functions are equivalent. By @jamesmbaazam in #698 and reviewed by @seabbs and @sbfnk .
- Added stan documentation for `update_rt`. By @seabbs in #747 and reviewed by @jamesmbaazam.
seabbs marked this conversation as resolved.
Show resolved Hide resolved

# EpiNow2 1.5.2

Expand Down
27 changes: 21 additions & 6 deletions R/create.R
Original file line number Diff line number Diff line change
Expand Up @@ -261,15 +261,20 @@ create_future_rt <- function(future = c("latest", "project", "estimate"),
#'
#' # using breakpoints
#' create_rt_data(rt_opts(use_breakpoints = TRUE), breakpoints = rep(1, 10))
#'
#' # using random walk
#' create_rt_data(rt_opts(rw = 7), breakpoints = rep(1, 10))
#' }
create_rt_data <- function(rt = rt_opts(), breakpoints = NULL,
delay = 0, horizon = 0) {

# Define if GP is on or off
if (is.null(rt)) {
jamesmbaazam marked this conversation as resolved.
Show resolved Hide resolved
rt <- rt_opts(
use_rt = FALSE,
future = "project",
gp_on = "R0"
gp_on = "R0",
rw = 0
)
}
# define future Rt arguments
Expand All @@ -279,24 +284,34 @@ create_rt_data <- function(rt = rt_opts(), breakpoints = NULL,
)
# apply random walk
if (rt$rw != 0) {
breakpoints <- as.integer(seq_along(breakpoints) %% rt$rw == 0)
if (is.null(breakpoints)) {
stop("breakpoints must be supplied when using random walk")
}

breakpoints <- seq_along(breakpoints)
breakpoints <- floor(breakpoints / rt$rw)
if (!(rt$future == "project")) {
max_bps <- length(breakpoints) - horizon + future_rt$from
if (max_bps < length(breakpoints)) {
breakpoints[(max_bps + 1):length(breakpoints)] <- 0
breakpoints[(max_bps + 1):length(breakpoints)] <- breakpoints[max_bps]
}
}
}else {
breakpoints <- cumsum(breakpoints)
}
# check breakpoints
if (is.null(breakpoints) || sum(breakpoints) == 0) {

if (sum(breakpoints) == 0) {
rt$use_breakpoints <- FALSE
}
# add a shift for 0 effect in breakpoints
breakpoints <- breakpoints + 1

# map settings to underlying gp stan requirements
rt_data <- list(
r_mean = rt$prior$mean,
r_sd = rt$prior$sd,
estimate_r = as.numeric(rt$use_rt),
bp_n = ifelse(rt$use_breakpoints, sum(breakpoints, na.rm = TRUE), 0),
bp_n = ifelse(rt$use_breakpoints, max(breakpoints) - 1, 0),
breakpoints = breakpoints,
future_fixed = as.numeric(future_rt$fixed),
fixed_from = future_rt$from,
Expand Down
2 changes: 1 addition & 1 deletion inst/stan/estimate_infections.stan
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ parameters{
array[estimate_r] real initial_infections ; // seed infections
array[estimate_r && seeding_time > 1 ? 1 : 0] real initial_growth; // seed growth rate
array[bp_n > 0 ? 1 : 0] real<lower = 0> bp_sd; // standard deviation of breakpoint effect
array[bp_n] real bp_effects; // Rt breakpoint effects
vector[bp_n] bp_effects; // Rt breakpoint effects
// observation model

vector<lower = delay_params_lower>[delay_params_length] delay_params; // delay parameters
Expand Down
66 changes: 44 additions & 22 deletions inst/stan/functions/rt.stan
Original file line number Diff line number Diff line change
@@ -1,26 +1,34 @@
// update a vector of Rts
/**
* Update a vector of effective reproduction numbers (Rt) based on
* an intercept, breakpoints (i.e. a random walk), and a Gaussian
* process.
*
* @param t Length of the time series
* @param log_R Logarithm of the base reproduction number
* @param noise Vector of Gaussian process noise values
* @param bps Array of breakpoint indices
* @param bp_effects Vector of breakpoint effects
* @param stationary Flag indicating whether the Gaussian process is stationary
* (1) or non-stationary (0)
* @return A vector of length t containing the updated Rt values
*/
vector update_Rt(int t, real log_R, vector noise, array[] int bps,
array[] real bp_effects, int stationary) {
vector bp_effects, int stationary) {
// define control parameters
int bp_n = num_elements(bp_effects);
int bp_c = 0;
int gp_n = num_elements(noise);
// define result vectors
vector[t] bp = rep_vector(0, t);
vector[t] gp = rep_vector(0, t);
vector[t] R;
// initialise breakpoints
// initialise intercept
vector[t] R = rep_vector(log_R, t);
//initialise breakpoints + rw
if (bp_n) {
for (s in 1:t) {
if (bps[s]) {
bp_c += bps[s];
bp[s] = bp_effects[bp_c];
}
}
bp = cumulative_sum(bp);
vector[bp_n + 1] bp0;
bp0[1] = 0;
bp0[2:(bp_n + 1)] = cumulative_sum(bp_effects);
R = R + bp0[bps];
}
//initialise gaussian process
if (gp_n) {
vector[t] gp = rep_vector(0, t);
if (stationary) {
gp[1:gp_n] = noise;
// fix future gp based on last estimated
Expand All @@ -31,18 +39,31 @@ vector update_Rt(int t, real log_R, vector noise, array[] int bps,
gp[2:(gp_n + 1)] = noise;
gp = cumulative_sum(gp);
}
R = R + gp;
}
// Calculate Rt
R = rep_vector(log_R, t) + bp + gp;
R = exp(R);
return(R);

return exp(R);
}
// Rt priors

/**
* Calculate the log-probability of the reproduction number (Rt) priors
*
* @param log_R Logarithm of the base reproduction number
* @param initial_infections Array of initial infection values
* @param initial_growth Array of initial growth rates
* @param bp_effects Vector of breakpoint effects
* @param bp_sd Array of breakpoint standard deviations
* @param bp_n Number of breakpoints
* @param seeding_time Time point at which seeding occurs
* @param r_logmean Log-mean of the prior distribution for the base reproduction number
* @param r_logsd Log-standard deviation of the prior distribution for the base reproduction number
* @param prior_infections Prior mean for initial infections
* @param prior_growth Prior mean for initial growth rates
*/
void rt_lp(vector log_R, array[] real initial_infections, array[] real initial_growth,
array[] real bp_effects, array[] real bp_sd, int bp_n, int seeding_time,
vector bp_effects, array[] real bp_sd, int bp_n, int seeding_time,
real r_logmean, real r_logsd, real prior_infections,
real prior_growth) {
// prior on R
log_R ~ normal(r_logmean, r_logsd);
//breakpoint effects on Rt
if (bp_n > 0) {
Expand All @@ -51,6 +72,7 @@ void rt_lp(vector log_R, array[] real initial_infections, array[] real initial_g
}
// initial infections
initial_infections ~ normal(prior_infections, 0.2);

if (seeding_time > 1) {
initial_growth ~ normal(prior_growth, 0.2);
}
Expand Down
2 changes: 1 addition & 1 deletion man/EpiNow2-package.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions man/create_rt_data.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

88 changes: 88 additions & 0 deletions tests/testthat/test-create_rt_date.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
test_that("create_rt_data returns expected default values", {
result <- create_rt_data()

expect_type(result, "list")
expect_equal(result$r_mean, 1)
expect_equal(result$r_sd, 1)
expect_equal(result$estimate_r, 1)
expect_equal(result$bp_n, 0)
expect_equal(result$breakpoints, numeric(0))
expect_equal(result$future_fixed, 1)
expect_equal(result$fixed_from, 0)
expect_equal(result$pop, 0)
expect_equal(result$stationary, 0)
expect_equal(result$future_time, 0)
})

test_that("create_rt_data handles NULL rt input correctly", {
result <- create_rt_data(rt = NULL)

expect_equal(result$estimate_r, 0)
expect_equal(result$future_fixed, 0)
expect_equal(result$stationary, 1)
})

test_that("create_rt_data handles custom rt_opts correctly", {
custom_rt <- rt_opts(
prior = list(mean = 2, sd = 0.5),
use_rt = FALSE,
rw = 0,
use_breakpoints = FALSE,
future = "project",
gp_on = "R0",
pop = 1000000
)

result <- create_rt_data(rt = custom_rt, horizon = 7)

expect_equal(result$r_mean, 2)
expect_equal(result$r_sd, 0.5)
expect_equal(result$estimate_r, 0)
expect_equal(result$pop, 1000000)
expect_equal(result$stationary, 1)
expect_equal(result$future_time, 7)
})

test_that("create_rt_data handles breakpoints correctly", {
result <- create_rt_data(rt_opts(use_breakpoints = TRUE),
breakpoints = c(1, 0, 1, 0, 1))

expect_equal(result$bp_n, 3)
expect_equal(result$breakpoints, c(2, 2, 3, 3, 4))
})

test_that("create_rt_data handles random walk correctly", {
result <- create_rt_data(rt_opts(rw = 2),
breakpoints = rep(1, 10))

expect_equal(result$bp_n, 5)
expect_equal(result$breakpoints, c(1, 2, 2, 3, 3, 4, 4, 5, 5, 6))
})

test_that("create_rt_data throws error for invalid inputs", {
expect_error(create_rt_data(rt_opts(rw = 2)),
"breakpoints must be supplied when using random walk")
})

test_that("create_rt_data handles future projections correctly", {
result <- create_rt_data(rt_opts(future = "project"), horizon = 7)

expect_equal(result$future_fixed, 0)
expect_equal(result$fixed_from, 0)
expect_equal(result$future_time, 7)
})

test_that("create_rt_data handles zero sum breakpoints", {
result <- create_rt_data(rt_opts(use_breakpoints = TRUE),
breakpoints = rep(0, 5))

expect_equal(result$bp_n, 0)
})

test_that("create_rt_data adjusts breakpoints for horizon", {
result <- create_rt_data(rt_opts(rw = 2, future = "latest"),
breakpoints = rep(1, 10),
horizon = 3)

expect_equal(result$breakpoints, c(1, 2, 2, 3, 3, 4, 4, 4, 4, 4))
})
61 changes: 61 additions & 0 deletions tests/testthat/test-rt_opts.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
test_that("rt_opts returns expected default values", {
result <- rt_opts()

expect_s3_class(result, "rt_opts")
expect_equal(result$prior, list(mean = 1, sd = 1))
expect_true(result$use_rt)
expect_equal(result$rw, 0)
expect_true(result$use_breakpoints)
expect_equal(result$future, "latest")
expect_equal(result$pop, 0)
expect_equal(result$gp_on, "R_t-1")
})

test_that("rt_opts handles custom inputs correctly", {
result <- rt_opts(
prior = list(mean = 2, sd = 0.5),
use_rt = FALSE,
rw = 7,
use_breakpoints = FALSE,
future = "project",
gp_on = "R0",
pop = 1000000
)

expect_equal(result$prior, list(mean = 2, sd = 0.5))
expect_false(result$use_rt)
expect_equal(result$rw, 7)
expect_true(result$use_breakpoints) # Should be TRUE when rw > 0
expect_equal(result$future, "project")
expect_equal(result$pop, 1000000)
expect_equal(result$gp_on, "R0")
})

test_that("rt_opts sets use_breakpoints to TRUE when rw > 0", {
result <- rt_opts(rw = 3, use_breakpoints = FALSE)
expect_true(result$use_breakpoints)
})

test_that("rt_opts throws error for invalid prior", {
expect_error(rt_opts(prior = list(mean = 1)),
"prior must have both a mean and sd specified")
expect_error(rt_opts(prior = list(sd = 1)),
"prior must have both a mean and sd specified")
})

test_that("rt_opts validates gp_on argument", {
expect_error(rt_opts(gp_on = "invalid"), "must be one")
})

test_that("rt_opts returns object of correct class", {
result <- rt_opts()
expect_s3_class(result, "rt_opts")
expect_true("list" %in% class(result))
})

test_that("rt_opts handles edge cases correctly", {
result <- rt_opts(rw = 0.1, pop = -1)
expect_equal(result$rw, 0.1)
expect_equal(result$pop, -1)
expect_true(result$use_breakpoints)
})
12 changes: 6 additions & 6 deletions tests/testthat/test-stan-rt.R
Original file line number Diff line number Diff line change
Expand Up @@ -32,29 +32,29 @@ test_that("update_Rt works when Rt is fixed", {
})
test_that("update_Rt works when Rt is fixed but a breakpoint is present", {
expect_equal(
round(update_Rt(5, log(1.2), numeric(0), c(0, 0, 1, 0, 0), 0.1, 0), 2),
round(update_Rt(5, log(1.2), numeric(0), c(1, 1, 2, 2, 2), 0.1, 0), 2),
c(1.2, 1.2, rep(1.33, 3))
)
expect_equal(
round(update_Rt(5, log(1.2), numeric(0), c(0, 0, 1, 0, 0), 0.1, 1), 2),
round(update_Rt(5, log(1.2), numeric(0), c(1, 1, 2, 2, 2), 0.1, 1), 2),
c(1.2, 1.2, rep(1.33, 3))
)
expect_equal(
round(update_Rt(5, log(1.2), numeric(0), c(0, 1, 1, 0, 0), rep(0.1, 2), 0), 2),
round(update_Rt(5, log(1.2), numeric(0), c(1, 2, 3, 3, 3), rep(0.1, 2), 0), 2),
c(1.2, 1.33, rep(1.47, 3))
)
})
test_that("update_Rt works when Rt is variable and a breakpoint is present", {
expect_equal(
round(update_Rt(5, log(1.2), rep(0, 4), c(0, 0, 1, 0, 0), 0.1, 0), 2),
round(update_Rt(5, log(1.2), rep(0, 4), c(1, 1, 2, 2, 2), 0.1, 0), 2),
c(1.2, 1.2, rep(1.33, 3))
)
expect_equal(
round(update_Rt(5, log(1.2), rep(0, 5), c(0, 0, 1, 0, 0), 0.1, 1), 2),
round(update_Rt(5, log(1.2), rep(0, 5), c(1, 1, 2, 2, 2), 0.1, 1), 2),
c(1.2, 1.2, rep(1.33, 3))
)
expect_equal(
round(update_Rt(5, log(1.2), rep(0.1, 4), c(0, 0, 1, 0, 0), 0.1, 0), 2),
round(update_Rt(5, log(1.2), rep(0.1, 4), c(1, 1, 2, 2, 2), 0.1, 0), 2),
c(1.20, 1.33, 1.62, 1.79, 1.98)
)
})
Loading