nested class test
Jad-yehya committed Aug 9, 2024
1 parent 8bf39b5 · commit b32088e
Showing 1 changed file, solvers/AR.py, with 25 additions and 26 deletions.
@@ -9,31 +9,6 @@
 from tqdm import tqdm
 
 
-class AR_model(torch.nn.Module):
-    """
-    Single linear layer for an autoregressive model.
-    Takes as input a window of size window_size and
-    outputs a window of size horizon.
-    input : (batch_size, window_size, n_features)
-    output : (batch_size, horizon, n_features)
-    """
-
-    def __init__(self, window_size: int, n_features: int, horizon: int):
-        super(AR_model, self).__init__()
-        self.window_size = window_size
-        self.n_features = n_features
-        self.horizon = horizon
-        self.linear = torch.nn.Linear(
-            window_size * n_features, horizon * n_features
-        )
-
-    def forward(self, x):
-        x = x.reshape(x.size(0), -1)
-        x = self.linear(x)
-        x = x.reshape(x.size(0), -1, self.n_features)
-        return x
-
-
 class Solver(BaseSolver):
     name = "AR"
 
@@ -59,7 +34,7 @@ def set_objective(self, X_train, y_test, X_test):
         self.X_test, self.y_test = X_test, y_test
         self.n_features = X_train.shape[1]
 
-        self.model = AR_model(
+        self.model = Solver.AR_model(
            self.window_size,
            self.n_features,
            self.horizon
@@ -172,3 +147,27 @@ def skip(self, X_train, X_test, y_test):
 
     def get_result(self):
         return dict(y_hat=self.predictions)
+
+    class AR_model(torch.nn.Module):
+        """
+        Single linear layer for an autoregressive model.
+        Takes as input a window of size window_size and
+        outputs a window of size horizon.
+        input : (batch_size, window_size, n_features)
+        output : (batch_size, horizon, n_features)
+        """
+
+        def __init__(self, window_size: int, n_features: int, horizon: int):
+            super(Solver.AR_model, self).__init__()
+            self.window_size = window_size
+            self.n_features = n_features
+            self.horizon = horizon
+            self.linear = torch.nn.Linear(
+                window_size * n_features, horizon * n_features
+            )
+
+        def forward(self, x):
+            x = x.reshape(x.size(0), -1)
+            x = self.linear(x)
+            x = x.reshape(x.size(0), -1, self.n_features)
+            return x
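
For reference, a minimal shape check of the nested class could look like the sketch below. It is only an illustration of the docstring contract: the import path and all numeric values (batch size, window_size, horizon, n_features) are assumptions, not values taken from the benchmark or this commit.

import torch
from solvers.AR import Solver  # assumed import path, for illustration only

window_size, horizon, n_features, batch_size = 16, 4, 3, 8
model = Solver.AR_model(window_size, n_features, horizon)

x = torch.randn(batch_size, window_size, n_features)  # (batch_size, window_size, n_features)
y = model(x)                                           # (batch_size, horizon, n_features)
assert y.shape == (batch_size, horizon, n_features)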
