Raising ImportError when data is not available for testing #27

Merged
merged 9 commits
Nov 26, 2024
50 changes: 50 additions & 0 deletions benchmark_utils/__init__.py
@@ -4,6 +4,7 @@
# the usual import syntax

from benchopt import safe_import_context
from pathlib import Path

with safe_import_context() as import_ctx:
    import numpy as np
@@ -43,3 +44,52 @@ def mean_overlaping_pred(predictions, stride):
    averaged_predictions = accumulated / counts

    return averaged_predictions


def check_data(data_path, dataset, data_type):
    """
    Checks if the data is present in the specified path.

    Args:
        data_path: str
            The path to the data directory.
        dataset: str
            The name of the dataset, either 'WADI' or 'SWaT'.
        data_type: str
            The type of data, either 'train' or 'test'.

    Raises:
        ImportError: If the required data files are not found.
    """
    if dataset == "WADI":
        if data_type == "train":
            required_files = ["WADI_14days_new.csv"]
        elif data_type == "test":
            required_files = ["WADI_attackdataLABLE.csv"]
        else:
            raise ValueError("data_type must be either 'train' or 'test'")
    elif dataset == "SWaT":
        if data_type == "train":
            required_files = ["swat_train2.csv"]
        elif data_type == "test":
            required_files = ["swat2.csv"]
        else:
            raise ValueError("data_type must be either 'train' or 'test'")
    else:
        raise ValueError("dataset must be either 'WADI' or 'SWaT'")

    official_repo = {
        "WADI": "https://itrust.sutd.edu.sg/itrust-labs_datasets/"
                "dataset_info/",
        "SWaT": "https://drive.google.com/drive/folders/"
                "1xhcYqh6okRs98QJomFWBKNLw4d1T4Q0w",
    }

    for file in required_files:
        if not Path(data_path, file).exists():
            raise ImportError(
                f"{data_type.capitalize()} data not found for {dataset}. "
                "Please download the data from the official repository "
                f"{official_repo[dataset]} "
                f"and place it in {data_path}"
            )
I was thinking more of putting this function in each dataset, but it can be ok like this if you prefer.
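
A quick way to exercise the new helper would be a small test like the sketch below. It is hypothetical and not part of this PR; it assumes benchopt is installed so that benchmark_utils imports cleanly, and uses pytest's tmp_path fixture as an empty data directory.

import pytest

from benchmark_utils import check_data


def test_check_data_missing_files(tmp_path):
    # The temporary directory contains no CSV files, so the helper
    # should raise ImportError carrying the download instructions.
    with pytest.raises(ImportError, match="data not found for SWaT"):
        check_data(tmp_path, "SWaT", "train")

    # Unknown dataset or data_type values are rejected with ValueError.
    with pytest.raises(ValueError):
        check_data(tmp_path, "SWaT", "validation")
    with pytest.raises(ValueError):
        check_data(tmp_path, "MSL", "train")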

6 changes: 3 additions & 3 deletions datasets/msl.py
@@ -48,9 +48,9 @@ def get_data(self):
        with open(path / "MSL_test_label.npy", "wb") as f:
            f.write(response.content)

        X_train = np.load(path / "MSL_train.npy")
        X_test = np.load(path / "MSL_test.npy")
        y_test = np.load(path / "MSL_test_label.npy")
        X_train = np.load(path / "MSL_train.npy", allow_pickle=True)
        X_test = np.load(path / "MSL_test.npy", allow_pickle=True)
        y_test = np.load(path / "MSL_test_label.npy", allow_pickle=True)

        # Limiting the size of the dataset for testing purposes
        if self.debug:
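
For context on the allow_pickle change above: np.load refuses to unpickle object arrays by default, so if any of the MSL .npy files store object-dtype data, loading them without allow_pickle=True fails. A standalone illustration (the example array and file name are made up, not from this benchmark):

import numpy as np

# Saving an object-dtype array forces NumPy to pickle it on disk.
arr = np.array([{"sensor": 1.0}, {"sensor": 2.0}], dtype=object)
np.save("example.npy", arr)

try:
    np.load("example.npy")  # allow_pickle defaults to False
except ValueError as err:
    print(err)  # "Object arrays cannot be loaded when allow_pickle=False"

loaded = np.load("example.npy", allow_pickle=True)  # succeeds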
30 changes: 8 additions & 22 deletions datasets/swat.py
@@ -1,9 +1,15 @@
from benchopt import BaseDataset, safe_import_context
from benchopt.config import get_data_path
from benchmark_utils import check_data

with safe_import_context() as import_ctx:
    import pandas as pd

# Checking if the data is available
PATH = get_data_path(key="SWaT")
check_data(PATH, "SWaT", "train")
check_data(PATH, "SWaT", "test")


class Dataset(BaseDataset):
name = "SWaT"
@@ -21,29 +27,9 @@ def get_data(self):
        # at the following link:
        # https://drive.google.com/drive/folders/1xhcYqh6okRs98QJomFWBKNLw4d1T4Q0w

        path = get_data_path(key="SWaT")

        if not (path / "swat_train2.csv").exists():
            raise FileNotFoundError(
                "Train data not found. Please download the data "
                "from the Google Drive "
                "https://drive.google.com/drive/folders/"
                "1xhcYqh6okRs98QJomFWBKNLw4d1T4Q0w"
                f" and place it in {path}"
            )

        if not (path / "swat2.csv").exists():
            raise FileNotFoundError(
                "Test data not found. Please download the data "
                "from the Google Drive "
                "https://drive.google.com/drive/folders/"
                "1xhcYqh6okRs98QJomFWBKNLw4d1T4Q0w"
                f" and place it in {path}"
            )

        # Load the data
        X_train = pd.read_csv(path / "swat_train2.csv")
        X_test = pd.read_csv(path / "swat2.csv")
        X_train = pd.read_csv(PATH / "swat_train2.csv")
        X_test = pd.read_csv(PATH / "swat2.csv")

        # Extract the target
        y_test = X_test["Normal/Attack"].values
30 changes: 9 additions & 21 deletions datasets/wadi.py
@@ -1,9 +1,15 @@
from benchopt import BaseDataset, safe_import_context
from benchopt.config import get_data_path
from benchmark_utils import check_data

with safe_import_context() as import_ctx:
    import pandas as pd

# Checking if the data is available
PATH = get_data_path(key="WADI")
check_data(PATH, "WADI", "train")
check_data(PATH, "WADI", "test")


class Dataset(BaseDataset):
name = "WADI"
@@ -21,27 +27,9 @@ def get_data(self):
        # at the following link:
        # https://itrust.sutd.edu.sg/itrust-labs_datasets/dataset_info/

        path = get_data_path(key="WADI")

        if not (path / "WADI_14days_new.csv").exists():
            raise FileNotFoundError(
                "Train data not found. Please download the data "
                "from the official repository"
                "https://itrust.sutd.edu.sg/itrust-labs_datasets/dataset_info/"
                f"and place it in {path}"
            )

        if not (path / "WADI_attackdataLABLE.csv").exists():
            raise FileNotFoundError(
                "Test data not found. Please download the data "
                "from the official repository"
                "https://itrust.sutd.edu.sg/itrust-labs_datasets/dataset_info/"
                f"and place it in {path}"
            )

        # Load the data
        X_train = pd.read_csv(path / "WADI_14days_new.csv")
        X_test = pd.read_csv(path / "WADI_attackdataLABLE.csv", header=1)
        X_train = pd.read_csv(PATH / "WADI_14days_new.csv")
        X_test = pd.read_csv(PATH / "WADI_attackdataLABLE.csv", header=1)

        # Data processing
        # Dropping the following columns because more than 50% of the values
@@ -64,7 +52,7 @@ def get_data(self):
        y_test = X_test["Attack LABLE (1:No Attack, -1:Attack)"].values
        X_test.drop(
            columns=todrop + [
                "Attack LABLE (1:No Attack, -1:Attack)"],
                "Attack LABLE (1:No Attack, -1:Attack)"],
            inplace=True
        )
        # Using ffill to fill the missing values because
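
For reference, the forward fill mentioned in the comment above behaves roughly as follows. This is a toy pandas example, not the benchmark's actual columns:

import numpy as np
import pandas as pd

# ffill propagates the last observed sensor reading into gaps, which is a
# reasonable default for slowly varying industrial process signals.
X = pd.DataFrame({"sensor": [1.0, np.nan, np.nan, 4.0]})
X = X.ffill()
print(X["sensor"].tolist())  # [1.0, 1.0, 1.0, 4.0]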