Add Normal::logp_score() #3

Closed · wants to merge 2 commits
34 changes: 34 additions & 0 deletions cxx/distributions/base.hh
@@ -0,0 +1,34 @@
#pragma once

class Distribution {
// Abstract base class for probability distributions in HIRM.
// These probability distributions are generative, and so must come with
// a prior (usually a conjugate prior) over their parameters, which are
// inferred from the observed data.
public:
// N is the number of incorporated observations.
int N = 0;

// Accumulate x.
virtual void incorporate(double x) = 0;

// Undo the accumulation of x. Should only be called with x's that
// have been previously passed to incorporate().
virtual void unincorporate(double x) = 0;

// The log probability of x according to the distribution we have
// accumulated so far.
virtual double logp(double x) const = 0;

// The log probability of the data we have accumulated so far according
// to the prior.
virtual double logp_score() const = 0;

// A sample from the distribution we have accumulated so far.
// TODO(thomaswc): Consider refactoring so that this takes a
// PRNG parameter.
virtual double sample() = 0;

virtual ~Distribution() {}
};
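
The intended usage of this interface is to incorporate observations one at a time, query predictive log probabilities with logp(), and score the accumulated data with logp_score(). A minimal sketch of that flow, using the BetaBernoulli subclass added below in this diff (the demo function name is hypothetical, and a PRNG is assumed to be available as elsewhere in hirm):

#include "distributions/beta_bernoulli.hh"

// Sketch only: exercise a Distribution subclass end to end.
double demo(PRNG *prng) {
  BetaBernoulli dist(prng);
  dist.incorporate(1.0);             // Observe a 1.
  dist.incorporate(0.0);             // Observe a 0.
  double lp = dist.logp(1.0);        // Predictive log probability of a new 1.
  double score = dist.logp_score();  // Log marginal likelihood of {1, 0}.
  dist.unincorporate(0.0);           // Undo the second observation.
  return lp + score;
}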

51 changes: 51 additions & 0 deletions cxx/distributions/beta_bernoulli.hh
@@ -0,0 +1,51 @@
// Copyright 2024
// See LICENSE.txt

#pragma once
#include <cassert>
#include <cmath>
#include <vector>

#include "base.hh"

class BetaBernoulli : public Distribution {
public:
double alpha = 1; // hyperparameter
double beta = 1; // hyperparameter
int s = 0; // sum of observed values
PRNG *prng;

BetaBernoulli(PRNG *prng) {
this->prng = prng;
}
void incorporate(double x){
assert(x == 0 || x == 1);
N += 1;
s += x;
}
void unincorporate(double x) {
assert(x == 0 || x == 1);
N -= 1;
s -= x;
assert(0 <= s);
assert(0 <= N);
}
double logp(double x) const {
double log_denom = log(N + alpha + beta);
if (x == 1) { return log(s + alpha) - log_denom; }
if (x == 0) { return log(N - s + beta) - log_denom; }
assert(false);
}
double logp_score() const {
double v1 = lbeta(s + alpha, N - s + beta);
double v2 = lbeta(alpha, beta);
return v1 - v2;
}
double sample() {
double p = exp(logp(1));
vector<int> items {0, 1};
vector<double> weights {1-p, p};
auto idx = choice(weights, prng);
return items[idx];
}

// Disable copying.
BetaBernoulli & operator=(const BetaBernoulli&) = delete;
BetaBernoulli(const BetaBernoulli&) = delete;
};
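
For reference, logp() and logp_score() above are the standard Beta-Bernoulli closed forms: the posterior predictive P(x = 1 | data) = (s + alpha) / (N + alpha + beta), and the log marginal likelihood log B(s + alpha, N - s + beta) - log B(alpha, beta). A self-contained sketch of the same formulas written against <cmath> only (the function names are illustrative, not part of the diff):

#include <cmath>

// lbeta(a, b) = lgamma(a) + lgamma(b) - lgamma(a + b).
double lbeta_ref(double a, double b) {
  return std::lgamma(a) + std::lgamma(b) - std::lgamma(a + b);
}

// Posterior predictive log P(x = 1) after observing s ones in N trials.
double beta_bernoulli_logp1(int N, int s, double alpha, double beta) {
  return std::log(s + alpha) - std::log(N + alpha + beta);
}

// Log marginal likelihood of the N observations under the Beta(alpha, beta) prior.
double beta_bernoulli_logp_score(int N, int s, double alpha, double beta) {
  return lbeta_ref(s + alpha, N - s + beta) - lbeta_ref(alpha, beta);
}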
70 changes: 70 additions & 0 deletions cxx/distributions/normal.hh
@@ -0,0 +1,70 @@
// Copyright 2024
// See LICENSE.txt

#pragma once
#include <cmath>
#include <random>

#include "base.hh"

#ifndef M_2PI
#define M_2PI 6.28318530717958647692528676655
#endif

class Normal : public Distribution {
public:
// Hyperparameters:
// The conjugate prior to a normal distribution is a
// normal-inverse-gamma distribution, which we parameterize following
// https://en.wikipedia.org/wiki/Normal-inverse-gamma_distribution .
double mu = 0;
double lambda = 1.0;
double alpha = 1.0;
double beta = 1.0;

// We use Welford's algorithm for computing the mean and variance
// of streaming data in a numerically stable way. See Knuth's
// Art of Computer Programming vol. 2, 3rd edition, page 232.
double mean = 0;  // Running mean of observed values.
double var = 0;   // Sum of squared deviations from the mean (Welford's M2).

PRNG *prng;

Normal(PRNG *prng) {
this->prng = prng;
}

void incorporate(double x){
N += 1;
double old_mean = mean;
mean += (x - mean) / N;
var += (x - mean) * (x - old_mean);
}

void unincorporate(double x) {
int old_N = N;
N -= 1;
double old_mean = mean;
mean = (mean * old_N - x) / N;
var -= (x - mean) * (x - old_mean);
}

double logp(double x) const {
// Log density of a normal with the running mean and variance var / N.
double v = var / N;
double y = x - mean;
return -0.5 * (y * y / v + log(v) + log(M_2PI));
}

double logp_score() const {
// Log density of the point estimates (mean, var / N) under the
// normal-inverse-gamma prior parameterized by (mu, lambda, alpha, beta).
double v = var / N;
double y = mean - mu;
return 0.5 * log(lambda) - 0.5 * log(v) - 0.5 * log(M_2PI)
    + alpha * log(beta) - lgamma(alpha) - (alpha + 1) * log(v)
    - (2 * beta + lambda * y * y) / (2.0 * v);
}

double sample() {
// std::normal_distribution takes a standard deviation, not a variance.
std::normal_distribution<double> d(mean, std::sqrt(var / N));
return d(*prng);
}

// Disable copying.
Normal & operator=(const Normal&) = delete;
Normal(const Normal&) = delete;
};
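
For reference, the statistics maintained above (N, the running mean, and the Welford M2 sum of squared deviations) are exactly the sufficient statistics consumed by the standard normal-inverse-gamma conjugate update, in the Wikipedia parameterization cited in the header. A self-contained sketch of that update (the struct and function names are hypothetical, not part of the diff):

#include <cmath>

struct NIGParams { double mu, lambda, alpha, beta; };

// Standard conjugate update for a normal likelihood under a
// normal-inverse-gamma prior, given N observations summarized by their
// sample mean and sum of squared deviations about that mean (Welford's M2).
NIGParams nig_posterior(const NIGParams& prior, int N, double sample_mean,
                        double sum_sq_dev) {
  double lambda_n = prior.lambda + N;
  double mu_n = (prior.lambda * prior.mu + N * sample_mean) / lambda_n;
  double alpha_n = prior.alpha + N / 2.0;
  double beta_n = prior.beta + 0.5 * sum_sq_dev +
                  prior.lambda * N * std::pow(sample_mean - prior.mu, 2) /
                      (2.0 * lambda_n);
  return {mu_n, lambda_n, alpha_n, beta_n};
}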

59 changes: 2 additions & 57 deletions cxx/hirm.hh
@@ -5,68 +5,13 @@
#include "globals.hh"
#include "util_hash.hh"
#include "util_math.hh"
#include "distributions/base.hh"
#include "distributions/beta_bernoulli.hh"

typedef int T_item;
typedef vector<T_item> T_items;
typedef VectorIntHash H_items;

class Distribution {
public:
int N = 0;
virtual void incorporate(double x) = 0;
virtual void unincorporate(double x) = 0;
virtual double logp(double x) const = 0;
virtual double logp_score() const = 0;
virtual double sample() = 0;
~Distribution(){};
};

class BetaBernoulli : public Distribution {
public:
double alpha = 1; // hyperparameter
double beta = 1; // hyperparameter
int s = 0; // sum of observed values
PRNG *prng;

BetaBernoulli(PRNG *prng) {
this->prng = prng;
}
void incorporate(double x){
assert(x == 0 || x == 1);
N += 1;
s += x;
}
void unincorporate(double x) {
assert(x == 0 || x ==1);
N -= 1;
s -= x;
assert(0 <= s);
assert(0 <= N);
}
double logp(double x) const {
double log_denom = log(N + alpha + beta);
if (x == 1) { return log(s + alpha) - log_denom; }
if (x == 0) { return log(N - s + beta) - log_denom; }
assert(false);
}
double logp_score() const {
double v1 = lbeta(s + alpha, N - s + beta);
double v2 = lbeta(alpha, beta);
return v1 - v2;
}
double sample() {
double p = exp(logp(1));
vector<int> items {0, 1};
vector<double> weights {1-p, p};
auto idx = choice(weights, prng);
return items[idx];
}

// Disable copying.
BetaBernoulli & operator=(const BetaBernoulli&) = delete;
BetaBernoulli(const BetaBernoulli&) = delete;
};

class CRP {
public:
double alpha = 1; // concentration parameter