From 8c71b23c348e103ae80adb5b70ae6bbbecfd20b8 Mon Sep 17 00:00:00 2001 From: Jad-yehya Date: Thu, 21 Nov 2024 16:28:46 +0100 Subject: [PATCH] Reverted to manually checking data in swat --- datasets/swat.py | 39 +++++++++++++++++++++++++++++++++------ 1 file changed, 33 insertions(+), 6 deletions(-) diff --git a/datasets/swat.py b/datasets/swat.py index ab17fa6..d06477a 100644 --- a/datasets/swat.py +++ b/datasets/swat.py @@ -1,14 +1,24 @@ from benchopt import BaseDataset, safe_import_context from benchopt.config import get_data_path -from benchmark_utils import check_data with safe_import_context() as import_ctx: import pandas as pd - # Checking if the data is available + # Temporary : Checks if the data is available for the tests path = get_data_path(key="SWaT") - check_data(path, "SWaT", "train") - check_data(path, "SWaT", "test") + + if ( + not (path / "swat_train2.csv").exists() + ) or ( + not (path / "swat2.csv").exists() + ): + raise ImportError( + "Test data not found. Please download the data " + "from the Google Drive " + "https://drive.google.com/drive/folders/" + "1xhcYqh6okRs98QJomFWBKNLw4d1T4Q0w" + f" and place it in {path}" + ) class Dataset(BaseDataset): @@ -23,12 +33,29 @@ class Dataset(BaseDataset): } def get_data(self): - # To get the data, you need to ask for access to the dataset # at the following link: # https://drive.google.com/drive/folders/1xhcYqh6okRs98QJomFWBKNLw4d1T4Q0w - # path = get_data_path(key="SWaT") + path = get_data_path(key="SWaT") + + if not (path / "swat_train2.csv").exists(): + raise FileNotFoundError( + "Train data not found. Please download the data " + "from the Google Drive " + "https://drive.google.com/drive/folders/" + "1xhcYqh6okRs98QJomFWBKNLw4d1T4Q0w" + f" and place it in {path}" + ) + + if not (path / "swat2.csv").exists(): + raise FileNotFoundError( + "Test data not found. Please download the data " + "from the Google Drive " + "https://drive.google.com/drive/folders/" + "1xhcYqh6okRs98QJomFWBKNLw4d1T4Q0w" + f" and place it in {path}" + ) # Load the data X_train = pd.read_csv(path / "swat_train2.csv")