Merge pull request #106 from gaurav-arya/ag-devdocs
Fixes to devdocs
gaurav-arya authored Oct 30, 2023
2 parents e13ee20 + cd2ee12 commit 3638b3b
Showing 1 changed file with 22 additions and 10 deletions.
docs/src/devdocs.md (32 changes: 22 additions & 10 deletions)
@@ -16,16 +16,18 @@ If a function does not meet the conditions of `StochasticAD.propagate` and is no
dispatch may be necessary. For example, consider the following function which manually implements a geometric random variable:

```@example rule
-import Random # hide
+import Random
+Random.seed!(1234) # hide
using Distributions
-function mygeometric(p)
+# make rng input explicit
+function mygeometric(rng, p)
x = 0
-while !(rand(Bernoulli(p)))
+while !(rand(rng, Bernoulli(p)))
x += 1
end
return x
end
+mygeometric(p) = mygeometric(Random.default_rng(), p)
```
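
As an aside (not part of the commit): since `mygeometric` counts failures before the first success, its pmf is ``P(X = x) = p(1-p)^x``, and the weights appearing in the dispatch rule further down can be sanity-checked against the identity

```math
\frac{d}{dp} \mathbb{E}[f(X)] = \mathbb{E}\left[\frac{X}{p(1-p)} \left(f(X-1) - f(X)\right)\right] = \mathbb{E}\left[-\frac{X+1}{p} \left(f(X+1) - f(X)\right)\right]
```

i.e. the right stochastic derivative perturbs ``X`` to ``X - 1`` with weight ``X / (p(1-p))`` (vanishing at ``X = 0``), and the left one perturbs ``X`` to ``X + 1`` with weight ``-(X+1)/p``. The file derives these expressions in a part of the diff that is collapsed here.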

This is equivalent to `rand(Geometric(p))` which is already supported, but for pedagogical purposes we will
@@ -43,27 +45,28 @@ Using these expressions, we can now write the dispatch rule for stochastic triples:
```@example rule
using StochasticAD
import StochasticAD: StochasticTriple, similar_new, similar_empty, combine
-function mygeometric(p_st::StochasticTriple{T}) where {T}
+function mygeometric(rng, p_st::StochasticTriple{T}) where {T}
p = p_st.value
-x = mygeometric(p)
+rng_copy = copy(rng) # save a copy for coupling later
+x = mygeometric(rng, p)
# Form the new discrete perturbations (combinations of weight w and perturbation Y - X)
Δs1 = if p_st.δ > 0
# right stochastic derivative
-w = x / (p * (1 - p))
+w = p_st.δ * x / (p * (1 - p))
x > 0 ? similar_new(p_st.Δs, -1, w) : similar_empty(p_st.Δs, Int)
elseif p_st.δ < 0
# left stochastic derivative
-w = (x + 1) / p # positive since the negativity of p_st.δ cancels out the negativity of w_L
+w = -p_st.δ * (x + 1) / p # positive since the negativity of p_st.δ cancels out the negativity of w_L
similar_new(p_st.Δs, 1, w)
else
similar_empty(p_st.Δs, Int)
end
# Propagate any existing perturbations to p through the function
function map_func(Δ)
-# Sample mygeometric(p + Δ) independently. (A better strategy would be to couple to the original sample.)
-mygeometric(p + Δ) - x
+# Couple the samples by using the same RNG. (A simpler strategy would have been independent sampling, i.e. mygeometric(p + Δ) - x)
+mygeometric(copy(rng_copy), p + Δ) - x
end
Δs2 = map(map_func, p_st.Δs)
@@ -79,7 +82,16 @@ We can test out our rule:
@show stochastic_triple(mygeometric, 0.1)
# try feeding an input that already has a perturbation
-f(x) = mygeometric(x + 0.4 * rand(Bernoulli(x)))
+f(x) = mygeometric(2 * x + 0.1 * rand(Bernoulli(x)))^2
@show stochastic_triple(f, 0.1)
+# verify against black-box finite differences
+N = 1000000
+samples_stochad = [derivative_estimate(f, 0.1) for i in 1:N]
+samples_fd = [(f(0.105) - f(0.095)) / 0.01 for i in 1:N]
+println("Stochastic AD: $(mean(samples_stochad)) ± $(std(samples_stochad) / sqrt(N))")
+println("Finite differences: $(mean(samples_fd)) ± $(std(samples_fd) / sqrt(N))")
+nothing # hide
```
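
A side note on the coupling used in `map_func` above (again not part of the commit): copying an RNG snapshots its state, so replaying the stream with the perturbed parameter `p + Δ` produces a sample that shares its randomness with the original `x`, which keeps the variance of the difference `mygeometric(copy(rng_copy), p + Δ) - x` low. A minimal illustration of that replay behaviour, using a `MersenneTwister` purely as an example:

```julia
import Random

rng = Random.MersenneTwister(1)
saved = copy(rng)        # snapshot the state before any sampling
x = rand(rng)            # advances rng
x2 = rand(copy(saved))   # replay the same stream from the snapshot
x == x2                  # true: the two draws are identical, i.e. perfectly coupled
```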
