Skip to content

Commit

Permalink
New VAE Solver from pyod
Browse files Browse the repository at this point in the history
  • Loading branch information
Jad-yehya committed Jul 24, 2024
1 parent 32d2658 commit 0a1b3af
Showing 1 changed file with 77 additions and 0 deletions.
77 changes: 77 additions & 0 deletions solvers/vae.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
from benchopt import BaseSolver, safe_import_context

with safe_import_context() as import_ctx:
from pyod.models.vae import VAE
import numpy as np
import torch


class Solver(BaseSolver):
name = "VAE"

install_cmd = "conda"
requirements = ["pyod", "tqdm", "pip:torch"]

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

sampling_strategy = "run_once"

parameters = {
"contamination": [0.005, 0.05, 0.1, 0.2],
"n_epochs": [50],
"window": [False],
"window_size": [256],
"horizon": [0],
"stride": [1],
"batch_size": [128],
"preprocessing": [True, False],
"latent_dim": [2, 5, 10],
"batch_norm": [True, False],
"dropout_rate": [0.1, 0.2, 0.5],
}

def set_objective(self, X_train, y_test, X_test):
self.X_train = X_train
self.X_test, self.y_test = X_test, y_test
self.clf = VAE(contamination=self.contamination,
preprocessing=self.preprocessing,
batch_size=self.batch_size,
epoch_num=self.n_epochs,
device=self.device,
latent_dim=self.latent_dim,
batch_norm=self.batch_norm,
dropout_rate=self.dropout_rate,
)

if self.window:
self.Xw_train = np.lib.stride_tricks.sliding_window_view(
X_train,
window_shape=self.window_size+self.horizon,
axis=0
).transpose(0, 2, 1)

if self.X_test is not None:
self.Xw_test = np.lib.stride_tricks.sliding_window_view(
X_test,
window_shape=self.window_size+self.horizon,
axis=0
).transpose(0, 2, 1)

if self.y_test is not None:
self.yw_test = np.lib.stride_tricks.sliding_window_view(
self.y_test, window_shape=self.window_size, axis=0
)[::self.stride]

self.yw_test = torch.tensor(
self.yw_test, dtype=torch.float32
)
else:
self.Xw_train = X_train
self.Xw_test = X_test

def run(self, _):
self.clf.fit(self.Xw_train)
self.y_pred = self.clf.predict(self.Xw_test)

def get_result(self):
return dict(y_hat=self.y_pred)

0 comments on commit 0a1b3af

Please sign in to comment.