diff --git a/sgptools/models/core/osgpr.py b/sgptools/models/core/osgpr.py index 34f7f25..5078fc0 100644 --- a/sgptools/models/core/osgpr.py +++ b/sgptools/models/core/osgpr.py @@ -57,7 +57,16 @@ def __init__(self, data, kernel, mu_old, Su_old, Kaa_old, Z_old, Z, mean_functio self.Kaa_old = tf.Variable(Kaa_old, shape=tf.TensorShape(None), trainable=False) self.Z_old = tf.Variable(Z_old, shape=tf.TensorShape(None), trainable=False) - def update(self, data): + def init_Z(self): + M = self.inducing_variable.Z.shape[0] + M_old = int(0.7 * M) + M_new = M - M_old + old_Z = self.Z_old.numpy()[np.random.permutation(M)[0:M_old], :] + new_Z = self.X.numpy()[np.random.permutation(self.X.shape[0])[0:M_new], :] + Z = np.vstack((old_Z, new_Z)) + return Z + + def update(self, data, inducing_variable=None): """Configure the OSGPR to adapt to a new batch of data. Note: The OSGPR needs to be trained using gradient-based approaches after update. @@ -67,7 +76,11 @@ def update(self, data): self.X, self.Y = self.data = gpflow.models.util.data_input_to_tensor(data) self.num_data = self.X.shape[0] + # Update the inducing points self.Z_old.assign(self.inducing_variable.Z.numpy()) + if inducing_variable is None: + inducing_variable = self.init_Z() + self.inducing_variable.Z.assign(inducing_variable) # Get posterior mean and covariance for the old inducing points mu_old, Su_old = self.predict_f(self.Z_old, full_cov=True)