Skip to content

Commit

Permalink
Update SSGP's update method to update Z init
Browse files Browse the repository at this point in the history
  • Loading branch information
itskalvik committed Dec 30, 2024
1 parent 7370651 commit 3bd20b2
Showing 1 changed file with 14 additions and 1 deletion.
15 changes: 14 additions & 1 deletion sgptools/models/core/osgpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,16 @@ def __init__(self, data, kernel, mu_old, Su_old, Kaa_old, Z_old, Z, mean_functio
self.Kaa_old = tf.Variable(Kaa_old, shape=tf.TensorShape(None), trainable=False)
self.Z_old = tf.Variable(Z_old, shape=tf.TensorShape(None), trainable=False)

def update(self, data):
def init_Z(self):
M = self.inducing_variable.Z.shape[0]
M_old = int(0.7 * M)
M_new = M - M_old
old_Z = self.Z_old.numpy()[np.random.permutation(M)[0:M_old], :]
new_Z = self.X.numpy()[np.random.permutation(self.X.shape[0])[0:M_new], :]
Z = np.vstack((old_Z, new_Z))
return Z

def update(self, data, inducing_variable=None):
"""Configure the OSGPR to adapt to a new batch of data.
Note: The OSGPR needs to be trained using gradient-based approaches after update.
Expand All @@ -67,7 +76,11 @@ def update(self, data):
self.X, self.Y = self.data = gpflow.models.util.data_input_to_tensor(data)
self.num_data = self.X.shape[0]

# Update the inducing points
self.Z_old.assign(self.inducing_variable.Z.numpy())
if inducing_variable is None:
inducing_variable = self.init_Z()
self.inducing_variable.Z.assign(inducing_variable)

# Get posterior mean and covariance for the old inducing points
mu_old, Su_old = self.predict_f(self.Z_old, full_cov=True)
Expand Down

0 comments on commit 3bd20b2

Please sign in to comment.