Skip to content

Commit

Permalink
Add logp_score implementation to Normal, and add comments to base Dis…
Browse files Browse the repository at this point in the history
…tribution.
  • Loading branch information
ThomasColthurst committed May 16, 2024
1 parent 9c8802a commit e15f244
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 2 deletions.
11 changes: 11 additions & 0 deletions cxx/distributions/base.hh
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,31 @@

class Distribution {
// Abstract base class for probability distributions in HIRM.
// These probability distributions are generative, and so must come with
// a prior (usually a conjugate prior) over the parameters implied by
// their observed data.
public:
// N is the number of incorporated observations.
int N = 0;

// Accumulate x.
virtual void incorporate(double x) = 0;

// Undo the accumulation of x. Should only be called with x's that
// have been previously passed to incorporate().
virtual void unincorporate(double x) = 0;

// The log probability of x according to the distribution we have
// accumulated so far.
virtual double logp(double x) const = 0;

// The log probability of the data we have accumulated so far according
// to the prior.
virtual double logp_score() const = 0;

// A sample from the distribution we have accumulated so far.
// TODO(thomaswc): Consider refactoring so that this takes a
// PRNG parameter.
virtual double sample() = 0;

~Distribution(){};
Expand Down
15 changes: 13 additions & 2 deletions cxx/distributions/normal.hh
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,15 @@

class Normal : public Distribution {
public:
// Hyperparameters:
// The conjugate prior to a normal distribution is a
// normal-inverse-gamma distribution, which we parameterize following
// https://en.wikipedia.org/wiki/Normal-inverse-gamma_distribution .
double mu = 0;
double lambda = 1.0;
double alpha = 1.0;
double beta = 1.0;

// We use Welford's algorithm for computing the mean and variance
// of streaming data in a numerically stable way. See Knuth's
// Art of Computer Programming vol. 2, 3rd edition, page 232.
Expand Down Expand Up @@ -43,8 +52,10 @@ public:
}

double logp_score() const {
// TODO(thomaswc): This.
return 0.0;
double y = mean - mu;
return 0.5 * log(lambda) - 0.5 * log(var) - 0.5 * log(M_2PI)
- alpha * log(beta) - lgamma(alpha) - (alpha + 1) * log(var)
- (2 * beta + lambda * y * y) / (2.0 * var);
}

double sample() {
Expand Down

0 comments on commit e15f244

Please sign in to comment.