Skip to content

Commit

Permalink
Convert SSGP updates to use assign method
Browse files Browse the repository at this point in the history
  • Loading branch information
itskalvik committed Dec 30, 2024
1 parent 6e243e3 commit 7370651
Showing 1 changed file with 5 additions and 7 deletions.
12 changes: 5 additions & 7 deletions sgptools/models/core/osgpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,18 +67,16 @@ def update(self, data):
self.X, self.Y = self.data = gpflow.models.util.data_input_to_tensor(data)
self.num_data = self.X.shape[0]

self.Z_old = tf.Variable(self.inducing_variable.Z.numpy(),
shape=tf.TensorShape(None),
trainable=False)
self.Z_old.assign(self.inducing_variable.Z.numpy())

# Get posterior mean and covariance for the old inducing points
mu_old, Su_old = self.predict_f(self.Z_old, full_cov=True)
self.mu_old = tf.Variable(mu_old, shape=tf.TensorShape(None), trainable=False)
self.Su_old = tf.Variable(Su_old, shape=tf.TensorShape(None), trainable=False)
self.mu_old.assign(mu_old.numpy())
self.Su_old.assign(Su_old.numpy())

# Get the prior covariance matrix for the old inducing points
Kaa_old = self.kernel(self.Z_old)
self.Kaa_old = tf.Variable(Kaa_old, shape=tf.TensorShape(None), trainable=False)
self.Kaa_old.assign(Kaa_old.numpy())

def _common_terms(self):
Mb = self.inducing_variable.num_inducing
Expand Down

0 comments on commit 7370651

Please sign in to comment.