diff --git a/demo-project/src/demo_project/settings.py b/demo-project/src/demo_project/settings.py index 304f7d8c1..63533b8eb 100644 --- a/demo-project/src/demo_project/settings.py +++ b/demo-project/src/demo_project/settings.py @@ -2,19 +2,6 @@ # List the installed plugins for which to disable auto-registry # DISABLE_HOOKS_FOR_PLUGINS = ("kedro-viz",) -from pathlib import Path - -# Define where to store data from a KedroSession. Defaults to BaseSessionStore. -# from kedro.framework.session.store import ShelveStore -from kedro_viz.integrations.kedro.sqlite_store import SQLiteStore - -SESSION_STORE_CLASS = SQLiteStore -SESSION_STORE_ARGS = {"path": str(Path(__file__).parents[2] / "data")} - -# Setup for collaborative experiment tracking. -# SESSION_STORE_ARGS = {"path": str(Path(__file__).parents[2] / "data"), -# "remote_path": "s3://{path-to-session_store}" } - # Define custom context class. Defaults to `KedroContext` # CONTEXT_CLASS = KedroContext diff --git a/docs/source/experiment_tracking.md b/docs/source/experiment_tracking.md index c06ca5a20..3a6f42add 100644 --- a/docs/source/experiment_tracking.md +++ b/docs/source/experiment_tracking.md @@ -357,4 +357,4 @@ Additionally, you can monitor the changes to metrics over time from the pipeline Clicking on any `MetricsDataset` node opens a side panel displaying how the metric value has changed over time: -![](./images/pipeline_show_metrics.gif) +![](./images/pipeline_show_metrics.gif) \ No newline at end of file diff --git a/docs/source/kedro-viz_visualisation.md b/docs/source/kedro-viz_visualisation.md index 66c9cf651..18379f45d 100644 --- a/docs/source/kedro-viz_visualisation.md +++ b/docs/source/kedro-viz_visualisation.md @@ -326,4 +326,4 @@ Press `Cmd` + `Shift` + `P` (on macOS) or `Ctrl` + `Shift` + `P` (on Windows/Lin Type `kedro: Run Kedro Viz` and select the command. This will launch Kedro-Viz and display your pipeline visually within the extension. -![Kedro Viz in VSCode](./images/viz-in-vscode.gif) +![Kedro Viz in VSCode](./images/viz-in-vscode.gif) \ No newline at end of file diff --git a/package/README.md b/package/README.md index dc543fae1..781fff39c 100644 --- a/package/README.md +++ b/package/README.md @@ -307,4 +307,4 @@ Kedro-Viz is licensed under the [Apache 2.0](https://github.com/kedro-org/kedro- ## Citation -If you're an academic, Kedro-Viz can also help you, for example, as a tool to visualise how your publication's pipeline is structured. Find our citation reference on [Zenodo](https://doi.org/10.5281/zenodo.4277218). +If you're an academic, Kedro-Viz can also help you, for example, as a tool to visualise how your publication's pipeline is structured. Find our citation reference on [Zenodo](https://doi.org/10.5281/zenodo.4277218). 
\ No newline at end of file diff --git a/package/kedro_viz/database.py b/package/kedro_viz/database.py deleted file mode 100644 index 5a62e32bc..000000000 --- a/package/kedro_viz/database.py +++ /dev/null @@ -1,41 +0,0 @@ -"""Database management layer based on SQLAlchemy""" - -import os - -from sqlalchemy import create_engine, text -from sqlalchemy.orm import sessionmaker - -from kedro_viz.models.experiment_tracking import Base - - -def configure_wal_for_azure(engine): - """Applies WAL mode to SQLite if running in an Azure ML environment.""" - is_azure_ml = any( - var in os.environ - for var in [ - "AZUREML_ARM_SUBSCRIPTION", - "AZUREML_ARM_RESOURCEGROUP", - "AZUREML_RUN_ID", - ] - ) - if is_azure_ml: - with engine.connect() as conn: - conn.execute(text("PRAGMA journal_mode=WAL;")) - - -def make_db_session_factory(session_store_location: str) -> sessionmaker: - """SQLAlchemy connection to a SQLite DB""" - database_url = f"sqlite:///{session_store_location}" - engine = create_engine(database_url, connect_args={"check_same_thread": False}) - # TODO: making db session factory shouldn't depend on models. - # So want to move the table creation elsewhere ideally. - # But this means returning engine as well as session class. - - # Check if we are running in an Azure ML environment if so enable WAL mode. - configure_wal_for_azure(engine) - - # Create the database tables if they do not exist. - Base.metadata.create_all(bind=engine) - - # Return a session factory bound to the engine. - return sessionmaker(bind=engine) diff --git a/package/kedro_viz/integrations/kedro/data_loader.py b/package/kedro_viz/integrations/kedro/data_loader.py index 623227036..ec2d6671b 100644 --- a/package/kedro_viz/integrations/kedro/data_loader.py +++ b/package/kedro_viz/integrations/kedro/data_loader.py @@ -13,7 +13,6 @@ from kedro import __version__ from kedro.framework.project import configure_project, pipelines from kedro.framework.session import KedroSession -from kedro.framework.session.store import BaseSessionStore from kedro.framework.startup import bootstrap_project from kedro.io import DataCatalog from kedro.pipeline import Pipeline @@ -73,8 +72,7 @@ def _load_data_helper( configuration. is_lite: A flag to run Kedro-Viz in lite mode. Returns: - A tuple containing the data catalog, pipeline dictionary, session store - and dataset stats dictionary. + A tuple containing the data catalog, pipeline dictionary and dataset stats dictionary. """ with KedroSession.create( @@ -88,7 +86,6 @@ def _load_data_helper( session._hook_manager = _VizNullPluginManager() # type: ignore context = session.load_context() - session_store = session._store # patch the AbstractDataset class for a custom # implementation to handle kedro.io.core.DatasetError @@ -110,7 +107,7 @@ def _load_data_helper( # Useful for users who have `get_current_session` in their `register_pipelines()`. pipelines_dict = dict(pipelines) stats_dict = _get_dataset_stats(project_path) - return catalog, pipelines_dict, session_store, stats_dict + return catalog, pipelines_dict, stats_dict def load_data( @@ -120,7 +117,7 @@ def load_data( package_name: Optional[str] = None, extra_params: Optional[Dict[str, Any]] = None, is_lite: bool = False, -) -> Tuple[DataCatalog, Dict[str, Pipeline], BaseSessionStore, Dict]: +) -> Tuple[DataCatalog, Dict[str, Pipeline], Dict]: """Load data from a Kedro project. Args: project_path: the path where the Kedro project is located. @@ -134,8 +131,7 @@ def load_data( configuration. is_lite: A flag to run Kedro-Viz in lite mode. 
     Returns:
-        A tuple containing the data catalog, pipeline dictionary, session store
-        and dataset stats dictionary.
+        A tuple containing the data catalog, pipeline dictionary and dataset stats dictionary.
     """
     if package_name:
         configure_project(package_name)
diff --git a/package/kedro_viz/integrations/kedro/sqlite_store.py b/package/kedro_viz/integrations/kedro/sqlite_store.py
deleted file mode 100644
index 8ba1a5ac9..000000000
--- a/package/kedro_viz/integrations/kedro/sqlite_store.py
+++ /dev/null
@@ -1,201 +0,0 @@
-"""kedro_viz.intergrations.kedro.sqlite_store is a child of BaseSessionStore
-which stores sessions data in the SQLite database"""
-
-import getpass
-import json
-import logging
-import os
-from pathlib import Path
-from typing import Any, Optional
-
-import fsspec
-from kedro.framework.project import settings
-from kedro.framework.session.store import BaseSessionStore
-from kedro.io.core import get_protocol_and_path
-from sqlalchemy import create_engine, select
-from sqlalchemy.orm import Session
-
-from kedro_viz.constants import VIZ_SESSION_STORE_ARGS
-from kedro_viz.database import make_db_session_factory
-from kedro_viz.launchers.utils import _find_kedro_project
-from kedro_viz.models.experiment_tracking import RunModel
-
-logger = logging.getLogger(__name__)
-
-
-def _get_dbname():
-    return os.getenv("KEDRO_SQLITE_STORE_USERNAME", getpass.getuser()) + ".db"
-
-
-def _is_json_serializable(obj: Any):
-    try:
-        json.dumps(obj)
-        return True
-    except (TypeError, OverflowError):
-        return False
-
-
-def _get_session_path(session_path: str) -> str:
-    """Returns the session path by creating its parent directory
-    if unavailable.
-    """
-    session_file_path = Path(session_path)
-    session_file_path.parent.mkdir(parents=True, exist_ok=True)
-    return str(session_file_path)
-
-
-class SQLiteStore(BaseSessionStore):
-    """Stores the session data on the sqlite db."""
-
-    def __init__(self, *args, remote_path: Optional[str] = None, **kwargs):
-        """Initializes the SQLiteStore object."""
-        super().__init__(*args, **kwargs)
-        self._db_session_class = make_db_session_factory(self.location)
-        self._remote_path = remote_path
-
-        if self.remote_location:
-            protocol, _ = get_protocol_and_path(self.remote_location)
-            self._remote_fs = fsspec.filesystem(protocol)
-
-    @property
-    def location(self) -> str:
-        """Returns location of the sqlite_store database"""
-        if "path" not in settings.SESSION_STORE_ARGS:
-            kedro_project_path = _find_kedro_project(Path.cwd()) or self._path
-            return _get_session_path(
-                f"{kedro_project_path}/{VIZ_SESSION_STORE_ARGS['path']}/session_store.db"
-            )
-
-        return _get_session_path(f"{self._path}/session_store.db")
-
-    @property
-    def remote_location(self) -> Optional[str]:
-        """Returns the remote location of the sqlite_store database on the cloud"""
-        return self._remote_path
-
-    def _to_json(self) -> str:
-        """Returns session_store information in json format after converting PosixPath to string"""
-        session_dict = {}
-        for key, value in self.data.items():
-            if key == "git":
-                try:
-                    import git
-
-                    branch = git.Repo(search_parent_directories=True).active_branch
-                    value["branch"] = branch.name
-                except ImportError as exc:  # pragma: no cover
-                    logger.warning("%s:%s", exc.__class__.__name__, exc.msg)
-                except Exception as exc:  # pragma: no cover
-                    logger.warning("Something went wrong when fetching git metadata.")
-                    logger.warning(exc)
-
-            if _is_json_serializable(value):
-                session_dict[key] = value
-            else:
-                session_dict[key] = str(value)
-        return json.dumps(session_dict)
-
-
def save(self): - """Save the session store info on db and uploads it - to the cloud if a remote cloud path is provided .""" - with self._db_session_class.begin() as session: - session.add(RunModel(id=self._session_id, blob=self._to_json())) - if self.remote_location: - self._upload() - - def _upload(self): - """Uploads the session store database file to the specified - remote path on the cloud storage.""" - db_name = _get_dbname() - logger.debug( - """Uploading local session store to %s with name - %s...""", - self.remote_location, - db_name, - ) - try: - self._remote_fs.put(self.location, f"{self.remote_location}/{db_name}") - except Exception as exc: - logger.exception("Upload failed: %s ", exc) - - def _download(self): - """Downloads all the session store database files - from the specified remote path on the cloud storage - to your local project. - """ - try: - # In theory we should be able to do this as a single operation: - # self._remote_fs.get(f"{self.remote_location}/*.db", str(Path(self.location).parent)) - # but this does not seem to work correctly - maybe a bug in fsspec. So instead - # we do it in two steps. Also need to add os.sep so it works with older s3fs version. - # This is a known bug in s3fs - https://github.com/fsspec/s3fs/issues/717 - remote_dbs = self._remote_fs.glob(f"{self.remote_location}/*.db") - logger.debug( - "Downloading %s remote session stores to local...", len(remote_dbs) - ) - for remote_db in remote_dbs: - self._remote_fs.get(remote_db, str(Path(self.location).parent) + os.sep) - except Exception as exc: - logger.exception("Download failed: %s ", exc) - - def _merge(self): - """Merges all the session store databases stored at the - specified locations into the user's local session_store.db - - Notes: - - This method uses multiple SQLAlchemy engines to connect to the - user's session_store.db and to all the other downloaded dbs. - - It is assumed that all the databases share the same schema. - - In the Kedro-viz version 6.2.0 - we only merge the runs table which - contains all the experiments. - """ - - all_new_runs = [] - - with self._db_session_class() as session: - existing_run_ids = session.execute(select(RunModel.id)).scalars().all() - - # Look at all databases in the local session store directory - # that aren't the actual session_store.db itself. - downloaded_db_locations = set(Path(self.location).parent.glob("*.db")) - { - Path(self.location) - } - - logger.debug( - "Checking %s downloaded session stores for new runs...", - len(downloaded_db_locations), - ) - for downloaded_db_location in downloaded_db_locations: - engine = create_engine(f"sqlite:///{downloaded_db_location}") - with Session(engine) as session: - query = select(RunModel).where(RunModel.id.not_in(existing_run_ids)) - new_runs = session.execute(query).scalars().all() - - existing_run_ids.extend([run.id for run in new_runs]) - all_new_runs.extend(new_runs) - logger.debug( - "Found %s new runs in downloaded session store %s", - len(new_runs), - downloaded_db_location.name, - ) - - if all_new_runs: - logger.debug("Adding %s new runs to session store...", len(all_new_runs)) - with self._db_session_class.begin() as session: - for run in all_new_runs: - session.merge(run) - - def sync(self): - """ - Synchronizes the user's local session_store.db with - remote session_store.db stored on a cloud storage service. - """ - - if self.remote_location: - self._download() - # We don't want a failed merge to stop the whole kedro-viz process. 
- try: - self._merge() - except Exception as exc: - logger.exception("Merge failed on sync: %s", exc) - self._upload() diff --git a/package/kedro_viz/server.py b/package/kedro_viz/server.py index 8643bec73..5c6c48b7c 100644 --- a/package/kedro_viz/server.py +++ b/package/kedro_viz/server.py @@ -4,16 +4,13 @@ from pathlib import Path from typing import Any, Dict, Optional -from kedro.framework.session.store import BaseSessionStore from kedro.io import DataCatalog from kedro.pipeline import Pipeline from kedro_viz.autoreload_file_filter import AutoreloadFileFilter from kedro_viz.constants import DEFAULT_HOST, DEFAULT_PORT from kedro_viz.data_access import DataAccessManager, data_access_manager -from kedro_viz.database import make_db_session_factory from kedro_viz.integrations.kedro import data_loader as kedro_data_loader -from kedro_viz.integrations.kedro.sqlite_store import SQLiteStore from kedro_viz.launchers.utils import _check_viz_up, _wait_for, display_cli_message DEV_PORT = 4142 @@ -23,18 +20,12 @@ def populate_data( data_access_manager: DataAccessManager, catalog: DataCatalog, pipelines: Dict[str, Pipeline], - session_store: BaseSessionStore, stats_dict: Dict, ): """Populate data repositories. Should be called once on application start if creating an api app from project. """ - if isinstance(session_store, SQLiteStore): - session_store.sync() - session_class = make_db_session_factory(session_store.location) - data_access_manager.set_db_session(session_class) - data_access_manager.add_catalog(catalog, pipelines) # add dataset stats before adding pipelines as the data nodes @@ -56,7 +47,7 @@ def load_and_populate_data( """Loads underlying Kedro project data and populates Kedro Viz Repositories""" # Loads data from underlying Kedro Project - catalog, pipelines, session_store, stats_dict = kedro_data_loader.load_data( + catalog, pipelines, stats_dict = kedro_data_loader.load_data( path, env, include_hooks, package_name, extra_params, is_lite ) @@ -67,7 +58,7 @@ def load_and_populate_data( ) # Creates data repositories which are used by Kedro Viz Backend APIs - populate_data(data_access_manager, catalog, pipelines, session_store, stats_dict) + populate_data(data_access_manager, catalog, pipelines, stats_dict) def run_server( diff --git a/package/tests/conftest.py b/package/tests/conftest.py index ea25e94f7..a54000748 100644 --- a/package/tests/conftest.py +++ b/package/tests/conftest.py @@ -6,7 +6,6 @@ import pandas as pd import pytest from fastapi.testclient import TestClient -from kedro.framework.session.store import BaseSessionStore from kedro.io import DataCatalog, MemoryDataset, Version from kedro.pipeline import Pipeline, node from kedro.pipeline.modular_pipeline import pipeline @@ -20,7 +19,6 @@ ModularPipelinesRepository, ) from kedro_viz.integrations.kedro.hooks import DatasetStatsHook -from kedro_viz.integrations.kedro.sqlite_store import SQLiteStore from kedro_viz.models.flowchart.node_metadata import DataNodeMetadata from kedro_viz.models.flowchart.nodes import GraphNode from kedro_viz.server import populate_data @@ -38,16 +36,6 @@ def data_access_manager(): yield DataAccessManager() -@pytest.fixture -def session_store(): - yield BaseSessionStore("dummy_path", "dummy_session_id") - - -@pytest.fixture -def sqlite_session_store(tmp_path): - yield SQLiteStore(tmp_path, "dummy_session_id") - - @pytest.fixture def example_stats_dict(): yield { @@ -490,7 +478,6 @@ def example_api( data_access_manager: DataAccessManager, example_pipelines: Dict[str, Pipeline], example_catalog: 
DataCatalog, - session_store: BaseSessionStore, example_stats_dict: Dict, mocker, ): @@ -499,7 +486,6 @@ def example_api( data_access_manager, example_catalog, example_pipelines, - session_store, example_stats_dict, ) mocker.patch( @@ -518,14 +504,11 @@ def example_api_no_default_pipeline( data_access_manager: DataAccessManager, example_pipelines: Dict[str, Pipeline], example_catalog: DataCatalog, - session_store: BaseSessionStore, mocker, ): del example_pipelines["__default__"] api = apps.create_api_app_from_project(mock.MagicMock()) - populate_data( - data_access_manager, example_catalog, example_pipelines, session_store, {} - ) + populate_data(data_access_manager, example_catalog, example_pipelines, {}) mocker.patch( "kedro_viz.api.rest.responses.pipelines.data_access_manager", new=data_access_manager, @@ -542,7 +525,6 @@ def example_api_for_edge_case_pipelines( data_access_manager: DataAccessManager, edge_case_example_pipelines: Dict[str, Pipeline], example_catalog: DataCatalog, - session_store: BaseSessionStore, mocker, ): api = apps.create_api_app_from_project(mock.MagicMock()) @@ -558,7 +540,6 @@ def example_api_for_edge_case_pipelines( data_access_manager, example_catalog, edge_case_example_pipelines, - session_store, {}, ) mocker.patch( @@ -577,7 +558,6 @@ def example_api_for_pipelines_with_additional_tags( data_access_manager: DataAccessManager, example_pipelines_with_additional_tags: Dict[str, Pipeline], example_catalog: DataCatalog, - session_store: BaseSessionStore, mocker, ): api = apps.create_api_app_from_project(mock.MagicMock()) @@ -593,7 +573,6 @@ def example_api_for_pipelines_with_additional_tags( data_access_manager, example_catalog, example_pipelines_with_additional_tags, - session_store, {}, ) mocker.patch( @@ -612,7 +591,6 @@ def example_transcoded_api( data_access_manager: DataAccessManager, example_transcoded_pipelines: Dict[str, Pipeline], example_transcoded_catalog: DataCatalog, - session_store: BaseSessionStore, mocker, ): api = apps.create_api_app_from_project(mock.MagicMock()) @@ -620,7 +598,6 @@ def example_transcoded_api( data_access_manager, example_transcoded_catalog, example_transcoded_pipelines, - session_store, {}, ) mocker.patch( diff --git a/package/tests/test_api/test_graphql/__init__.py b/package/tests/test_api/test_graphql/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/package/tests/test_api/test_graphql/conftest.py b/package/tests/test_api/test_graphql/conftest.py deleted file mode 100644 index fb57f5aa5..000000000 --- a/package/tests/test_api/test_graphql/conftest.py +++ /dev/null @@ -1,246 +0,0 @@ -import base64 -import json -from pathlib import Path - -import pytest -from kedro.io import DataCatalog, Version -from kedro_datasets import matplotlib, pandas, plotly, tracking - -from kedro_viz.api.graphql.types import Run -from kedro_viz.database import make_db_session_factory -from kedro_viz.models.experiment_tracking import RunModel, UserRunDetailsModel - - -@pytest.fixture -def example_run_ids(): - yield ["2021-11-03T18.24.24.379Z", "2021-11-02T18.24.24.379Z"] - - -@pytest.fixture -def example_db_session(tmp_path): - session_store_location = Path(tmp_path / "session_store.db") - session_class = make_db_session_factory(session_store_location) - yield session_class - - -@pytest.fixture -def example_db_session_with_runs(example_db_session, example_run_ids): - with example_db_session.begin() as session: - for run_id in example_run_ids: - session_data = { - "package_name": "testsql", - "project_path": 
"/Users/Projects/testsql", - "session_id": run_id, - "cli": { - "args": [], - "params": { - "from_inputs": [], - "to_outputs": [], - "from_nodes": [], - "to_nodes": [], - "node_names": (), - "runner": None, - "parallel": False, - "is_async": False, - "env": None, - "tag": (), - "load_version": {}, - "pipeline": None, - "config": None, - "params": {}, - }, - "command_name": "run", - "command_path": "kedro run", - }, - } - run = RunModel(id=run_id, blob=json.dumps(session_data)) - user_run_details = UserRunDetailsModel(run_id=run.id, bookmark=True) - session.add(run) - session.add(user_run_details) - yield example_db_session - - -@pytest.fixture -def data_access_manager_with_no_run(data_access_manager, example_db_session, mocker): - data_access_manager.set_db_session(example_db_session) - mocker.patch( - "kedro_viz.api.graphql.schema.data_access_manager", data_access_manager - ) - yield data_access_manager - - -@pytest.fixture -def data_access_manager_with_runs( - data_access_manager, example_db_session_with_runs, mocker -): - data_access_manager.set_db_session(example_db_session_with_runs) - mocker.patch( - "kedro_viz.api.graphql.schema.data_access_manager", data_access_manager - ) - yield data_access_manager - - -@pytest.fixture -def save_version(example_run_ids): - yield example_run_ids[0] - - -@pytest.fixture -def example_tracking_catalog(example_run_ids, tmp_path): - example_run_id = example_run_ids[0] - metrics_dataset = tracking.MetricsDataset( - filepath=Path(tmp_path / "test.json").as_posix(), - version=Version(None, example_run_id), - ) - metrics_dataset.save({"col1": 1, "col2": 2, "col3": 3}) - - csv_dataset = pandas.CSVDataset( - filepath=Path(tmp_path / "metrics.csv").as_posix(), - version=Version(None, example_run_id), - ) - - more_metrics = tracking.MetricsDataset( - filepath=Path(tmp_path / "metrics.json").as_posix(), - version=Version(None, example_run_id), - ) - more_metrics.save({"col4": 4, "col5": 5, "col6": 6}) - - json_dataset = tracking.JSONDataset( - filepath=Path(tmp_path / "tracking.json").as_posix(), - version=Version(None, example_run_id), - ) - json_dataset.save({"col7": "column_seven", "col2": True, "col3": 3}) - - plotly_dataset = plotly.JSONDataset( - filepath=Path(tmp_path / "plotly.json").as_posix(), - version=Version(None, example_run_id), - ) - - class MockPlotlyData: - data = { - "data": [ - { - "x": ["giraffes", "orangutans", "monkeys"], - "y": [20, 14, 23], - "type": "bar", - } - ] - } - - @classmethod - def write_json(cls, fs_file, **kwargs): - json.dump(cls.data, fs_file, **kwargs) - - plotly_dataset.save(MockPlotlyData) - - matplotlib_dataset = matplotlib.MatplotlibWriter( - filepath=Path(tmp_path / "matplotlib.png").as_posix(), - version=Version(None, example_run_id), - ) - - class MockMatplotData: - data = base64.b64decode( - "iVBORw0KGgoAAAANSUhEUg" - "AAAAEAAAABCAQAAAC1HAwCAA" - "AAC0lEQVQYV2NgYAAAAAM" - "AAWgmWQ0AAAAASUVORK5CYII=" - ) - - @classmethod - def savefig(cls, bytes_buffer, **kwargs): - bytes_buffer.write(cls.data) - - matplotlib_dataset.save(MockMatplotData) - - catalog = DataCatalog( - datasets={ - "metrics": metrics_dataset, - "csv_dataset": csv_dataset, - "more_metrics": more_metrics, - "json_tracking": json_dataset, - "plotly_dataset": plotly_dataset, - "matplotlib_dataset": matplotlib_dataset, - } - ) - - yield catalog - - -@pytest.fixture -def example_multiple_run_tracking_catalog(example_run_ids, tmp_path): - new_metrics_dataset = tracking.MetricsDataset( - filepath=Path(tmp_path / "test.json").as_posix(), - 
version=Version(None, example_run_ids[1]), - ) - new_metrics_dataset.save({"col1": 1, "col3": 3}) - new_metrics_dataset = tracking.MetricsDataset( - filepath=Path(tmp_path / "test.json").as_posix(), - version=Version(None, example_run_ids[0]), - ) - new_data = {"col1": 3, "col2": 3.23} - new_metrics_dataset.save(new_data) - catalog = DataCatalog( - datasets={ - "new_metrics": new_metrics_dataset, - } - ) - - yield catalog - - -@pytest.fixture -def example_multiple_run_tracking_catalog_at_least_one_empty_run( - example_run_ids, tmp_path -): - new_metrics_dataset = tracking.MetricsDataset( - filepath=Path(tmp_path / "test.json").as_posix(), - version=Version(None, example_run_ids[1]), - ) - new_metrics_dataset.save({"col1": 1, "col3": 3}) - new_metrics_dataset = tracking.MetricsDataset( - filepath=Path(tmp_path / "test.json").as_posix(), - version=Version(None, example_run_ids[0]), - ) - catalog = DataCatalog( - datasets={ - "new_metrics": new_metrics_dataset, - } - ) - - yield catalog - - -@pytest.fixture -def example_multiple_run_tracking_catalog_all_empty_runs(example_run_ids, tmp_path): - new_metrics_dataset = tracking.MetricsDataset( - filepath=Path(tmp_path / "test.json").as_posix(), - version=Version(None, example_run_ids[1]), - ) - new_metrics_dataset = tracking.MetricsDataset( - filepath=Path(tmp_path / "test.json").as_posix(), - version=Version(None, example_run_ids[0]), - ) - catalog = DataCatalog( - datasets={ - "new_metrics": new_metrics_dataset, - } - ) - - yield catalog - - -@pytest.fixture -def example_runs(example_run_ids): - yield [ - Run( - id=run_id, - bookmark=False, - notes="Hello World", - title="Hello Kedro", - author="", - git_branch="", - git_sha="", - run_command="", - ) - for run_id in example_run_ids - ] diff --git a/package/tests/test_api/test_graphql/test_mutations.py b/package/tests/test_api/test_graphql/test_mutations.py deleted file mode 100644 index 5ff328538..000000000 --- a/package/tests/test_api/test_graphql/test_mutations.py +++ /dev/null @@ -1,232 +0,0 @@ -import json - -import pytest - -from kedro_viz.models.experiment_tracking import RunModel - - -@pytest.mark.usefixtures("data_access_manager_with_runs") -class TestGraphQLMutation: - @pytest.mark.parametrize( - "bookmark,notes,title", - [ - ( - False, - "new notes", - "new title", - ), - (True, "new notes", "new title"), - (True, "", ""), - ], - ) - def test_update_user_details_success( - self, - bookmark, - notes, - title, - client, - example_run_ids, - ): - example_run_id = example_run_ids[0] - query = f""" - mutation updateRun {{ - updateRunDetails( - runId: "{example_run_id}", - runInput: {{bookmark: {str(bookmark).lower()}, notes: "{notes}", title: "{title}"}} - ) {{ - __typename - ... on UpdateRunDetailsSuccess {{ - run {{ - id - title - bookmark - notes - }} - }} - ... on UpdateRunDetailsFailure {{ - id - errorMessage - }} - }} - }} - """ - response = client.post("/graphql", json={"query": query}) - assert response.json() == { - "data": { - "updateRunDetails": { - "__typename": "UpdateRunDetailsSuccess", - "run": { - "id": example_run_id, - "bookmark": bookmark, - "title": title if title != "" else example_run_id, - "notes": notes, - }, - } - } - } - - def test_update_user_details_only_bookmark( - self, - client, - example_run_ids, - ): - example_run_id = example_run_ids[0] - query = f""" - mutation updateRun {{ - updateRunDetails(runId: "{example_run_id}", runInput: {{bookmark: true}}) {{ - __typename - ... on UpdateRunDetailsSuccess {{ - run {{ - id - title - bookmark - notes - }} - }} - ... 
on UpdateRunDetailsFailure {{ - id - errorMessage - }} - }} - }} - """ - - response = client.post("/graphql", json={"query": query}) - assert response.json() == { - "data": { - "updateRunDetails": { - "__typename": "UpdateRunDetailsSuccess", - "run": { - "id": example_run_id, - "bookmark": True, - "title": example_run_id, - "notes": "", - }, - } - } - } - - def test_update_user_details_should_add_when_no_details_exist( - self, client, data_access_manager_with_no_run - ): - # add a new run - example_run_id = "test_id" - run = RunModel( - id=example_run_id, - blob=json.dumps( - {"session_id": example_run_id, "cli": {"command_path": "kedro run"}} - ), - ) - data_access_manager_with_no_run.runs.add_run(run) - - query = f""" - mutation updateRun {{ - updateRunDetails(runId: "{example_run_id}", runInput: {{bookmark: true}}) {{ - __typename - ... on UpdateRunDetailsSuccess {{ - run {{ - id - title - bookmark - notes - }} - }} - ... on UpdateRunDetailsFailure {{ - id - errorMessage - }} - }} - }} - """ - - response = client.post("/graphql", json={"query": query}) - assert response.json() == { - "data": { - "updateRunDetails": { - "__typename": "UpdateRunDetailsSuccess", - "run": { - "id": example_run_id, - "bookmark": True, - "title": example_run_id, - "notes": "", - }, - } - } - } - - def test_update_user_details_should_update_when_details_exist( - self, client, example_run_ids - ): - example_run_id = example_run_ids[0] - query = f""" - mutation updateRun {{ - updateRunDetails(runId: "{example_run_id}", runInput: {{title:"new title", notes: "new notes"}}) {{ - __typename - ... on UpdateRunDetailsSuccess {{ - run {{ - id - title - bookmark - notes - }} - }} - ... on UpdateRunDetailsFailure {{ - id - errorMessage - }} - }} - }} - """ - - response = client.post("/graphql", json={"query": query}) - assert response.json() == { - "data": { - "updateRunDetails": { - "__typename": "UpdateRunDetailsSuccess", - "run": { - "id": example_run_id, - "bookmark": True, - "title": "new title", - "notes": "new notes", - }, - } - } - } - - def test_update_user_details_should_fail_when_run_doesnt_exist(self, client): - response = client.post( - "/graphql", - json={ - "query": """ - mutation { - updateRunDetails( - runId: "I don't exist", - runInput: { bookmark: false, title: "Hello Kedro", notes: "There are notes"} - ) { - __typename - ... on UpdateRunDetailsSuccess { - run { - id - title - notes - bookmark - } - } - ... 
on UpdateRunDetailsFailure { - id - errorMessage - } - } - } - """ - }, - ) - assert response.json() == { - "data": { - "updateRunDetails": { - "__typename": "UpdateRunDetailsFailure", - "id": "I don't exist", - "errorMessage": "Given run_id: I don't exist doesn't exist", - } - } - } diff --git a/package/tests/test_api/test_graphql/test_queries.py b/package/tests/test_api/test_graphql/test_queries.py deleted file mode 100644 index 05dcf6fcd..000000000 --- a/package/tests/test_api/test_graphql/test_queries.py +++ /dev/null @@ -1,429 +0,0 @@ -import json - -import pytest -from packaging.version import parse - -from kedro_viz import __version__ - - -class TestQueryNoSessionStore: - def test_graphql_run_list_endpoint(self, client): - response = client.post("/graphql", json={"query": "{runsList {id bookmark}}"}) - assert response.json() == {"data": {"runsList": []}} - - def test_graphql_runs_metadata_endpoint(self, client): - response = client.post( - "/graphql", - json={"query": '{runMetadata(runIds: ["id"]) {id bookmark}}'}, - ) - assert response.json() == {"data": {"runMetadata": []}} - - -@pytest.mark.usefixtures("data_access_manager_with_no_run") -class TestQueryNoRun: - def test_graphql_run_list_endpoint(self, client): - response = client.post("/graphql", json={"query": "{runsList {id bookmark}}"}) - assert response.json() == {"data": {"runsList": []}} - - def test_graphql_runs_metadata_endpoint(self, client): - response = client.post( - "/graphql", - json={"query": '{runMetadata(runIds: ["invalid run id"]) {id bookmark}}'}, - ) - assert response.json() == {"data": {"runMetadata": []}} - - -@pytest.mark.usefixtures("data_access_manager_with_runs") -class TestQueryWithRuns: - def test_run_list_query( - self, - client, - example_run_ids, - ): - response = client.post("/graphql", json={"query": "{runsList {id bookmark}}"}) - assert response.json() == { - "data": { - "runsList": [ - {"id": run_id, "bookmark": True} for run_id in example_run_ids - ] - } - } - - def test_graphql_runs_metadata_endpoint(self, example_run_ids, client): - response = client.post( - "/graphql", - json={ - "query": f"""{{runMetadata(runIds: ["{ example_run_ids[0] }"]) {{id bookmark}}}}""" - }, - ) - assert response.json() == { - "data": {"runMetadata": [{"id": example_run_ids[0], "bookmark": True}]} - } - - def test_run_tracking_data_query( - self, - example_run_ids, - client, - example_tracking_catalog, - data_access_manager_with_runs, - example_pipelines, - ): - data_access_manager_with_runs.add_catalog( - example_tracking_catalog, example_pipelines - ) - example_run_id = example_run_ids[0] - - response = client.post( - "/graphql", - json={ - "query": f""" - {{ - metrics: runTrackingData(runIds:["{example_run_id}"],group:METRIC) - {{datasetName, datasetType, data}} - json: runTrackingData(runIds:["{example_run_id}"],group:JSON) - {{datasetName, datasetType, data}} - plots: runTrackingData(runIds:["{example_run_id}"],group:PLOT) - {{datasetName, datasetType, data}} - }} - """ - }, - ) - - expected_response = { - "data": { - "metrics": [ - { - "datasetName": "metrics", - "datasetType": "tracking.metrics_dataset.MetricsDataset", - "data": { - "col1": [{"runId": example_run_id, "value": 1.0}], - "col2": [{"runId": example_run_id, "value": 2.0}], - "col3": [{"runId": example_run_id, "value": 3.0}], - }, - }, - { - "datasetName": "more_metrics", - "datasetType": "tracking.metrics_dataset.MetricsDataset", - "data": { - "col4": [{"runId": example_run_id, "value": 4.0}], - "col5": [{"runId": example_run_id, "value": 5.0}], - 
"col6": [{"runId": example_run_id, "value": 6.0}], - }, - }, - ], - "json": [ - { - "datasetName": "json_tracking", - "datasetType": "tracking.json_dataset.JSONDataset", - "data": { - "col2": [{"runId": example_run_id, "value": True}], - "col3": [{"runId": example_run_id, "value": 3}], - "col7": [ - { - "runId": example_run_id, - "value": "column_seven", - } - ], - }, - }, - ], - "plots": [ - { - "datasetName": "plotly_dataset", - "datasetType": "plotly.json_dataset.JSONDataset", - "data": { - "plotly.json": [ - { - "runId": "2021-11-03T18.24.24.379Z", - "value": { - "data": [ - { - "x": [ - "giraffes", - "orangutans", - "monkeys", - ], - "y": [20, 14, 23], - "type": "bar", - } - ] - }, - } - ] - }, - }, - { - "datasetName": "matplotlib_dataset", - "datasetType": "matplotlib.matplotlib_writer.MatplotlibWriter", - "data": { - "matplotlib.png": [ - { - "runId": "2021-11-03T18.24.24.379Z", - "value": "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVQYV2NgYAAAAAMAAWgmWQ0AAAAASUVORK5CYII=", - } - ] - }, - }, - ], - } - } - - assert response.json() == expected_response - - def test_metrics_data( - self, - client, - example_tracking_catalog, - data_access_manager_with_runs, - example_pipelines, - ): - data_access_manager_with_runs.add_catalog( - example_tracking_catalog, example_pipelines - ) - - response = client.post( - "/graphql", - json={ - "query": "query MyQuery {\n runMetricsData(limit: 3) {\n data\n }\n}\n" - }, - ) - - expected = { - "data": { - "runMetricsData": { - "data": { - "metrics": { - "metrics.col1": [1.0, None], - "metrics.col2": [2.0, None], - "metrics.col3": [3.0, None], - "more_metrics.col4": [4.0, None], - "more_metrics.col5": [5.0, None], - "more_metrics.col6": [6.0, None], - }, - "runs": { - "2021-11-02T18.24.24.379Z": [ - None, - None, - None, - None, - None, - None, - ], - "2021-11-03T18.24.24.379Z": [1.0, 2.0, 3.0, 4.0, 5.0, 6.0], - }, - } - } - } - } - - assert response.json() == expected - - @pytest.mark.parametrize( - "show_diff,expected_response", - [ - ( - True, - { - "data": { - "runTrackingData": [ - { - "datasetName": "new_metrics", - "datasetType": "tracking.metrics_dataset.MetricsDataset", - "data": { - "col1": [ - { - "runId": "2021-11-03T18.24.24.379Z", - "value": 3.0, - }, - { - "runId": "2021-11-02T18.24.24.379Z", - "value": 1.0, - }, - ], - "col2": [ - { - "runId": "2021-11-03T18.24.24.379Z", - "value": 3.23, - }, - ], - "col3": [ - { - "runId": "2021-11-02T18.24.24.379Z", - "value": 3.0, - }, - ], - }, - } - ] - } - }, - ), - ( - False, - { - "data": { - "runTrackingData": [ - { - "datasetName": "new_metrics", - "datasetType": "tracking.metrics_dataset.MetricsDataset", - "data": { - "col1": [ - { - "runId": "2021-11-03T18.24.24.379Z", - "value": 3.0, - }, - { - "runId": "2021-11-02T18.24.24.379Z", - "value": 1.0, - }, - ], - }, - }, - ] - } - }, - ), - ], - ) - def test_graphql_run_tracking_data( - self, - example_run_ids, - client, - example_multiple_run_tracking_catalog, - data_access_manager_with_runs, - show_diff, - expected_response, - example_pipelines, - ): - data_access_manager_with_runs.add_catalog( - example_multiple_run_tracking_catalog, example_pipelines - ) - - response = client.post( - "/graphql", - json={ - "query": f"""{{runTrackingData - (group: METRIC runIds:{json.dumps(example_run_ids)}, showDiff: {json.dumps(show_diff)}) - {{datasetName, datasetType, data}}}}""" - }, - ) - assert response.json() == expected_response - - @pytest.mark.parametrize( - "show_diff,expected_response", - [ - ( - True, - { - "data": { - 
"runTrackingData": [ - { - "datasetName": "new_metrics", - "datasetType": "tracking.metrics_dataset.MetricsDataset", - "data": { - "col1": [ - { - "runId": "2021-11-02T18.24.24.379Z", - "value": 1.0, - }, - ], - "col3": [ - { - "runId": "2021-11-02T18.24.24.379Z", - "value": 3.0, - }, - ], - }, - } - ] - } - }, - ), - ( - False, - {"data": {"runTrackingData": []}}, - ), - ], - ) - def test_graphql_run_tracking_data_at_least_one_empty_run( - self, - example_run_ids, - client, - example_multiple_run_tracking_catalog_at_least_one_empty_run, - data_access_manager_with_runs, - show_diff, - expected_response, - example_pipelines, - ): - data_access_manager_with_runs.add_catalog( - example_multiple_run_tracking_catalog_at_least_one_empty_run, - example_pipelines, - ) - - response = client.post( - "/graphql", - json={ - "query": f"""{{runTrackingData - (group: METRIC runIds:{json.dumps(example_run_ids)}, showDiff: {json.dumps(show_diff)}) - {{datasetName, datasetType, data}}}}""" - }, - ) - assert response.json() == expected_response - - @pytest.mark.parametrize( - "show_diff,expected_response", - [ - ( - True, - {"data": {"runTrackingData": []}}, - ), - ( - False, - {"data": {"runTrackingData": []}}, - ), - ], - ) - def test_graphql_run_tracking_data_all_empty_runs( - self, - example_run_ids, - client, - example_multiple_run_tracking_catalog_all_empty_runs, - data_access_manager_with_runs, - show_diff, - expected_response, - example_pipelines, - ): - data_access_manager_with_runs.add_catalog( - example_multiple_run_tracking_catalog_all_empty_runs, example_pipelines - ) - - response = client.post( - "/graphql", - json={ - "query": f"""{{runTrackingData - (group: METRIC runIds:{json.dumps(example_run_ids)}, showDiff: {json.dumps(show_diff)}) - {{datasetName, datasetType, data}}}}""" - }, - ) - assert response.json() == expected_response - - -class TestQueryVersion: - def test_graphql_version_endpoint(self, client, mocker): - mocker.patch( - "kedro_viz.api.graphql.schema.get_latest_version", - return_value=parse("1.0.0"), - ) - response = client.post( - "/graphql", - json={"query": "{version {installed isOutdated latest}}"}, - ) - assert response.json() == { - "data": { - "version": { - "installed": __version__, - "isOutdated": False, - "latest": "1.0.0", - } - } - } diff --git a/package/tests/test_api/test_graphql/test_serializers.py b/package/tests/test_api/test_graphql/test_serializers.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/package/tests/test_integrations/test_sqlite_store.py b/package/tests/test_integrations/test_sqlite_store.py deleted file mode 100644 index 4f0cb6a00..000000000 --- a/package/tests/test_integrations/test_sqlite_store.py +++ /dev/null @@ -1,390 +0,0 @@ -import json -import os -from pathlib import Path - -import boto3 -import pytest -from moto import mock_aws -from sqlalchemy import create_engine, func, select, text -from sqlalchemy.orm import sessionmaker - -from kedro_viz.database import make_db_session_factory -from kedro_viz.integrations.kedro.sqlite_store import SQLiteStore, _get_dbname -from kedro_viz.models.experiment_tracking import Base, RunModel - -BUCKET_NAME = "test-bucket" - - -@pytest.fixture -def parametrize_session_store_args(request): - """Fixture to parameterize has_session_store_args.""" - - # This fixture sets a class attribute has_session_store_args - # based on the parameter passed - request.cls.has_session_store_args = request.param - - -@pytest.fixture -def mock_session_store_args(request, mocker, setup_kedro_project): - 
"""Fixture to mock SESSION_STORE_ARGS and _find_kedro_project.""" - - # This fixture uses the class attribute has_session_store_args - # to apply the appropriate mocks. - if request.cls.has_session_store_args: - mocker.patch.dict( - "kedro_viz.integrations.kedro.sqlite_store.settings.SESSION_STORE_ARGS", - {"path": "some_path"}, - clear=True, - ) - else: - mocker.patch( - "kedro_viz.integrations.kedro.sqlite_store._find_kedro_project", - return_value=setup_kedro_project, - ) - - -@pytest.fixture -def store_path(request, tmp_path, setup_kedro_project): - if request.cls.has_session_store_args: - return Path(tmp_path) - session_store_path = Path(tmp_path / setup_kedro_project / ".viz") - session_store_path.mkdir(parents=True, exist_ok=True) - return session_store_path - - -@pytest.fixture -def db_session_class(store_path): - engine = create_engine(f"sqlite:///{store_path}/session_store.db") - Base.metadata.create_all(engine) - Session = sessionmaker(bind=engine) - return Session - - -@pytest.fixture(scope="class") -def aws_credentials(): - """Mocked AWS credentials for moto""" - os.environ["AWS_ACCESS_KEY_ID"] = "testing" - os.environ["AWS_SECRET_ACCESS_KEY"] = "testing" - os.environ["AWS_SESSION_TOKEN"] = "testing" - - -@pytest.fixture(scope="class") -def mocked_s3_bucket(aws_credentials): - """S3 Mock Client""" - with mock_aws(): - conn = boto3.client("s3", region_name="us-east-1") - conn.create_bucket(Bucket=BUCKET_NAME) - yield conn - - -@pytest.fixture -def remote_path(): - return f"s3://{BUCKET_NAME}" - - -@pytest.fixture -def mock_db1(store_path): - database_loc = str(store_path / "db1.db") - with make_db_session_factory(database_loc).begin() as session: - session.add(RunModel(id="1", blob="blob1")) - yield Path(database_loc) - - -@pytest.fixture -def mock_db2(store_path): - database_loc = str(store_path / "db2.db") - with make_db_session_factory(database_loc).begin() as session: - session.add(RunModel(id="2", blob="blob2")) - yield Path(database_loc) - - -@pytest.fixture -def mock_db3_with_db2_data(store_path): - database_loc = str(store_path / "db3.db") - with make_db_session_factory(database_loc).begin() as session: - session.add(RunModel(id="2", blob="blob2")) - yield Path(database_loc) - - -def get_files_in_bucket(bucket_name): - s3 = boto3.client("s3") - response = s3.list_objects(Bucket=bucket_name) - files = [obj["Key"] for obj in response.get("Contents", [])] - return files - - -@pytest.fixture -def mocked_db_in_s3(mocked_s3_bucket, mock_db1, mock_db2): - # define the name of the S3 bucket and the database file names - db1_filename = "db1.db" - db2_filename = "db2.db" - - # upload each mock database file to the mocked S3 bucket - mocked_s3_bucket.put_object( - Bucket=BUCKET_NAME, Key=db1_filename, Body=mock_db1.read_bytes() - ) - mocked_s3_bucket.put_object( - Bucket=BUCKET_NAME, Key=db2_filename, Body=mock_db2.read_bytes() - ) - - return get_files_in_bucket(BUCKET_NAME) - - -@pytest.fixture -def mocked_db_in_s3_repeated_runs( - mocked_s3_bucket, mock_db1, mock_db2, mock_db3_with_db2_data -): - # define the name of the S3 bucket and the database file names - db1_filename = "db1.db" - db2_filename = "db2.db" - db3_filename = "db3.db" - - # upload each mock database file to the mocked S3 bucket - mocked_s3_bucket.put_object( - Bucket=BUCKET_NAME, Key=db1_filename, Body=mock_db1.read_bytes() - ) - mocked_s3_bucket.put_object( - Bucket=BUCKET_NAME, Key=db2_filename, Body=mock_db2.read_bytes() - ) - mocked_s3_bucket.put_object( - Bucket=BUCKET_NAME, Key=db3_filename, 
Body=mock_db3_with_db2_data.read_bytes() - ) - - return get_files_in_bucket(BUCKET_NAME) - - -def session_id(): - i = 0 - while True: - yield f"session_{i}" - i += 1 - - -def test_get_dbname_with_env_var(mocker): - mocker.patch.dict( - os.environ, {"KEDRO_SQLITE_STORE_USERNAME": "env_user_name"}, clear=True - ) - mocker.patch("getpass.getuser", return_value="computer_user_name") - dbname = _get_dbname() - assert dbname == "env_user_name.db" - - -def test_get_dbname_without_env_var(mocker): - mocker.patch.dict("os.environ", clear=True) - mocker.patch("getpass.getuser", return_value="computer_user_name") - dbname = _get_dbname() - assert dbname == "computer_user_name.db" - - -@pytest.mark.usefixtures("parametrize_session_store_args", "mock_session_store_args") -@pytest.mark.parametrize("parametrize_session_store_args", [True, False], indirect=True) -class TestSQLiteStore: - def test_empty(self, store_path): - sqlite_store = SQLiteStore(store_path, next(session_id())) - assert not sqlite_store - assert sqlite_store.location == str(Path(store_path) / "session_store.db") - - def test_save_single_run(self, store_path): - sqlite_store = SQLiteStore(store_path, next(session_id())) - sqlite_store.data = {"project_path": store_path, "project_name": "test"} - sqlite_store.save() - with sqlite_store._db_session_class() as session: - query = select(RunModel) - loaded_runs = session.execute(query).scalars().all() - assert len(loaded_runs) == 1 - assert json.loads(loaded_runs[0].blob) == { - "project_path": str(store_path), - "project_name": "test", - } - - def test_save_multiple_runs(self, store_path): - session = session_id() - sqlite_store = SQLiteStore(store_path, next(session)) - sqlite_store.save() - with sqlite_store._db_session_class() as db_session: - query = select(func.count()).select_from(RunModel) - assert db_session.execute(query).scalar() == 1 - # save another session - sqlite_store2 = SQLiteStore(store_path, next(session)) - sqlite_store2.save() - with sqlite_store2._db_session_class() as db_session: - query = select(func.count()).select_from(RunModel) - assert db_session.execute(query).scalar() == 2 - - def test_save_run_with_remote_path(self, mocker, store_path, remote_path): - mocker.patch("fsspec.filesystem") - sqlite_store = SQLiteStore( - store_path, next(session_id()), remote_path=remote_path - ) - sqlite_store.data = {"project_path": store_path, "project_name": "test"} - mock_upload = mocker.patch.object(sqlite_store, "_upload") - sqlite_store.save() - mock_upload.assert_called_once() - - def test_save_run_without_remote_path(self, mocker, store_path): - sqlite_store = SQLiteStore(store_path, next(session_id())) - sqlite_store.data = {"project_path": store_path, "project_name": "test"} - mock_upload = mocker.patch.object(sqlite_store, "_upload") - sqlite_store.save() - mock_upload.assert_not_called() - - def test_update_git_branch(self, store_path, mocker): - sqlite_store = SQLiteStore(store_path, next(session_id())) - sqlite_store.data = { - "project_path": store_path, - "git": {"commit_sha": "123456"}, - } - mocker.patch("git.Repo.active_branch").name = "test_branch" - - assert sqlite_store._to_json() == json.dumps( - { - "project_path": str(store_path), - "git": {"commit_sha": "123456", "branch": "test_branch"}, - } - ) - - def test_upload_to_s3_success(self, mocker, store_path, remote_path): - mocker.patch("fsspec.filesystem") - sqlite_store = SQLiteStore( - store_path, next(session_id()), remote_path=remote_path - ) - sqlite_store._upload() - 
sqlite_store._remote_fs.put.assert_called_once() - - def test_upload_to_s3_fail(self, mocker, store_path, remote_path, caplog): - mocker.patch("fsspec.filesystem") - sqlite_store = SQLiteStore( - store_path, next(session_id()), remote_path=remote_path - ) - sqlite_store._remote_fs.put.side_effect = ConnectionError("Connection error") - sqlite_store._upload() - assert "Upload failed: Connection error" in caplog.text - - def test_download_from_s3_success( - self, - mocker, - store_path, - remote_path, - mocked_db_in_s3, - ): - mocker.patch("fsspec.filesystem") - sqlite_store = SQLiteStore( - store_path, next(session_id()), remote_path=remote_path - ) - sqlite_store._remote_fs.glob.return_value = mocked_db_in_s3 - sqlite_store._download() - - assert set(file.name for file in Path(store_path).glob("*.db")) == { - "db1.db", - "db2.db", - "session_store.db", - } - - def test_download_from_s3_failure(self, mocker, store_path, remote_path, caplog): - mocker.patch("fsspec.filesystem") - sqlite_store = SQLiteStore( - store_path, next(session_id()), remote_path=remote_path - ) - sqlite_store._remote_fs.glob.side_effect = ConnectionError("Connection error") - sqlite_store._download() - # assert that downloaded dbs are not downloaded - assert set(file.name for file in Path(store_path).glob("*.db")) == { - "session_store.db" - } - assert "Download failed: Connection error" in caplog.text - - def test_merge_databases( - self, - mocker, - store_path, - remote_path, - mocked_db_in_s3, - ): - mocker.patch("fsspec.filesystem") - sqlite_store = SQLiteStore( - store_path, next(session_id()), remote_path=remote_path - ) - sqlite_store._remote_fs.glob.return_value = mocked_db_in_s3 - sqlite_store._download() - sqlite_store._merge() - db_session = sqlite_store._db_session_class - with db_session() as session: - assert session.execute(select(RunModel.id)).scalars().all() == ["1", "2"] - - def test_merge_databases_with_repeated_runs( - self, - mocker, - store_path, - remote_path, - mocked_db_in_s3_repeated_runs, - ): - mocker.patch("fsspec.filesystem") - sqlite_store = SQLiteStore( - store_path, next(session_id()), remote_path=remote_path - ) - sqlite_store._remote_fs.glob.return_value = mocked_db_in_s3_repeated_runs - sqlite_store._download() - sqlite_store._merge() - db_session = sqlite_store._db_session_class - with db_session() as session: - assert session.execute(select(RunModel.id)).scalars().all() == ["1", "2"] - - def test_sync(self, mocker, store_path, remote_path, mocked_db_in_s3): - mocker.patch("fsspec.filesystem") - sqlite_store = SQLiteStore( - store_path, next(session_id()), remote_path=remote_path - ) - sqlite_store._remote_fs.glob.return_value = mocked_db_in_s3 - mock_download = mocker.patch.object(sqlite_store, "_download") - mock_merge = mocker.patch.object(sqlite_store, "_merge") - mock_upload = mocker.patch.object(sqlite_store, "_upload") - sqlite_store.sync() - mock_download.assert_called_once() - mock_merge.assert_called_once() - mock_upload.assert_called_once() - - def test_sync_without_remote_path(self, mocker, store_path): - mocker.patch("fsspec.filesystem") - sqlite_store = SQLiteStore(store_path, next(session_id())) - mock_download = mocker.patch.object(sqlite_store, "_download") - mock_merge = mocker.patch.object(sqlite_store, "_merge") - mock_upload = mocker.patch.object(sqlite_store, "_upload") - sqlite_store.sync() - mock_download.assert_not_called() - mock_merge.assert_not_called() - mock_upload.assert_not_called() - - def test_sync_with_merge_error(self, mocker, store_path, 
remote_path, caplog): - mocker.patch("fsspec.filesystem") - sqlite_store = SQLiteStore( - store_path, next(session_id()), remote_path=remote_path - ) - mock_download = mocker.patch.object(sqlite_store, "_download") - mock_merge = mocker.patch.object( - sqlite_store, "_merge", side_effect=Exception("Merge failed") - ) - mock_upload = mocker.patch.object(sqlite_store, "_upload") - sqlite_store.sync() - mock_download.assert_called_once() - mock_merge.assert_called_once() - mock_upload.assert_called_once() - assert "Merge failed on sync: Merge failed" in caplog.text - - def test_make_db_session_factory_with_azure_env_var(self, mocker, tmp_path): - """Test that WAL mode is enabled when running in an Azure environment.""" - mocker.patch.dict( - os.environ, - { - "AZUREML_ARM_SUBSCRIPTION": "dummy_value", - "AZUREML_ARM_RESOURCEGROUP": "dummy_value", - }, - ) - db_location = str(tmp_path / "test_session_store.db") - session_class = make_db_session_factory(db_location) - - # Ensure that the session can be created without issues. - with session_class() as session: - assert session is not None - # Check if the database is using WAL mode by querying the PRAGMA - result = session.execute(text("PRAGMA journal_mode;")).scalar() - assert result == "wal" diff --git a/package/tests/test_server.py b/package/tests/test_server.py index ca8d19a2c..1c0960c1c 100644 --- a/package/tests/test_server.py +++ b/package/tests/test_server.py @@ -31,30 +31,12 @@ def patched_create_api_app_from_file(mocker): @pytest.fixture(autouse=True) -def patched_load_data( - mocker, example_catalog, example_pipelines, session_store, example_stats_dict -): +def patched_load_data(mocker, example_catalog, example_pipelines, example_stats_dict): yield mocker.patch( "kedro_viz.server.kedro_data_loader.load_data", return_value=( example_catalog, example_pipelines, - session_store, - example_stats_dict, - ), - ) - - -@pytest.fixture -def patched_load_data_with_sqlite_session_store( - mocker, example_catalog, example_pipelines, sqlite_session_store, example_stats_dict -): - yield mocker.patch( - "kedro_viz.server.kedro_data_loader.load_data", - return_value=( - example_catalog, - example_pipelines, - sqlite_session_store, example_stats_dict, ), ) @@ -84,31 +66,6 @@ def test_run_server_from_project( # an uvicorn server is launched patched_uvicorn_run.assert_called_once() - def test_run_server_from_project_with_sqlite_store( - self, - patched_create_api_app_from_project, - patched_data_access_manager, - patched_uvicorn_run, - patched_load_data_with_sqlite_session_store, - example_catalog, - example_pipelines, - ): - run_server() - # assert that when running server, data are added correctly to the data access manager - patched_data_access_manager.add_catalog.assert_called_once_with( - example_catalog, example_pipelines - ) - patched_data_access_manager.add_pipelines.assert_called_once_with( - example_pipelines - ) - patched_data_access_manager.set_db_session.assert_called_once() - - # correct api app is created - patched_create_api_app_from_project.assert_called_once() - - # an uvicorn server is launched - patched_uvicorn_run.assert_called_once() - def test_specific_pipeline( self, patched_data_access_manager,