Skip to content

Commit

Permalink
Convert SSGP to handle ndim data
Browse files Browse the repository at this point in the history
  • Loading branch information
itskalvik committed Dec 30, 2024
1 parent f5ebd8e commit 6e243e3
Showing 1 changed file with 7 additions and 5 deletions.
12 changes: 7 additions & 5 deletions sgptools/models/core/osgpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def update(self, data):
Note: The OSGPR needs to be trained using gradient-based approaches after update.
Args:
data (tuple): (X, y) ndarrays with new batch of inputs (n, d) and labels (n, 1)
data (tuple): (X, y) ndarrays with new batch of inputs (n, d) and labels (n, ndim)
"""
self.X, self.Y = self.data = gpflow.models.util.data_input_to_tensor(data)
self.num_data = self.X.shape[0]
Expand Down Expand Up @@ -228,7 +228,8 @@ def init_osgpr(X_train,
lengthscales=1.0,
variance=1.0,
noise_variance=0.001,
kernel=None):
kernel=None,
ndim=1):
"""Initialize a VFE OSGPR model with an RBF kernel with
unit variance and lengthcales, and 0.001 noise variance.
Used in the Online Continuous SGP approach.
Expand All @@ -243,6 +244,7 @@ def init_osgpr(X_train,
variance (float): Kernel variance
noise_variance (float): Data noise variance
kernel (gpflow.kernels.Kernel): gpflow kernel function
ndim (int): Number of output dimensions
Returns:
online_param (OSGPR_VFE): Initialized online sparse Gaussian process model
Expand All @@ -252,7 +254,7 @@ def init_osgpr(X_train,
kernel = gpflow.kernels.SquaredExponential(lengthscales=lengthscales,
variance=variance)

y_train = np.zeros((len(X_train), 1), dtype=X_train.dtype)
y_train = np.zeros((len(X_train), ndim), dtype=X_train.dtype)
Z_init = get_inducing_pts(X_train, num_inducing)
init_param = gpflow.models.SGPR((X_train, y_train),
kernel,
Expand All @@ -261,8 +263,8 @@ def init_osgpr(X_train,

# Initialize the OSGPR model using the parameters from the SGPR model
# The X_train and y_train here will be overwritten in the online phase
X_train = np.array([[0., 0.], [0., 0.]], dtype=X_train.dtype)
y_train = np.array([0., 0.], dtype=y_train.dtype).reshape(-1, 1)
X_train = np.zeros([2, X_train.shape[-1]], dtype=X_train.dtype)
y_train = np.zeros([2, ndim], dtype=X_train.dtype)
Zopt = init_param.inducing_variable.Z.numpy()
mu, Su = init_param.predict_f(Zopt, full_cov=True)
Kaa = init_param.kernel(Zopt)
Expand Down

0 comments on commit 6e243e3

Please sign in to comment.