From e15f244593a03d4f975c37baa42c0c3f39eaa214 Mon Sep 17 00:00:00 2001 From: Thomas Colthurst Date: Thu, 16 May 2024 13:10:01 +0000 Subject: [PATCH] Add logp_score implementation to Normal, and add comments to base Distribution. --- cxx/distributions/base.hh | 11 +++++++++++ cxx/distributions/normal.hh | 15 +++++++++++++-- 2 files changed, 24 insertions(+), 2 deletions(-) diff --git a/cxx/distributions/base.hh b/cxx/distributions/base.hh index 8efd21d..d7de92a 100644 --- a/cxx/distributions/base.hh +++ b/cxx/distributions/base.hh @@ -2,20 +2,31 @@ class Distribution { // Abstract base class for probability distributions in HIRM. +// These probability distributions are generative, and so must come with +// a prior (usually a conjugate prior) over the parameters implied by +// their observed data. public: // N is the number of incorporated observations. int N = 0; + // Accumulate x. virtual void incorporate(double x) = 0; + + // Undo the accumulation of x. Should only be called with x's that + // have been previously passed to incorporate(). virtual void unincorporate(double x) = 0; // The log probability of x according to the distribution we have // accumulated so far. virtual double logp(double x) const = 0; + // The log probability of the data we have accumulated so far according + // to the prior. virtual double logp_score() const = 0; // A sample from the distribution we have accumulated so far. + // TODO(thomaswc): Consider refactoring so that this takes a + // PRNG parameter. virtual double sample() = 0; ~Distribution(){}; diff --git a/cxx/distributions/normal.hh b/cxx/distributions/normal.hh index cad80d5..62ebbcc 100644 --- a/cxx/distributions/normal.hh +++ b/cxx/distributions/normal.hh @@ -10,6 +10,15 @@ class Normal : public Distribution { public: + // Hyperparameters: + // The conjugate prior to a normal distribution is a + // normal-inverse-gamma distribution, which we parameterize following + // https://en.wikipedia.org/wiki/Normal-inverse-gamma_distribution . + double mu = 0; + double lambda = 1.0; + double alpha = 1.0; + double beta = 1.0; + // We use Welford's algorithm for computing the mean and variance // of streaming data in a numerically stable way. See Knuth's // Art of Computer Programming vol. 2, 3rd edition, page 232. @@ -43,8 +52,10 @@ public: } double logp_score() const { - // TODO(thomaswc): This. - return 0.0; + double y = mean - mu; + return 0.5 * log(lambda) - 0.5 * log(var) - 0.5 * log(M_2PI) + - alpha * log(beta) - lgamma(alpha) - (alpha + 1) * log(var) + - (2 * beta + lambda * y * y) / (2.0 * var); } double sample() {