Skip to content

Commit

Permalink
Avoid uploading default params several times
Browse files Browse the repository at this point in the history
  • Loading branch information
Hartorn committed Oct 26, 2023
1 parent 6a7f344 commit 4744997
Showing 1 changed file with 17 additions and 20 deletions.
37 changes: 17 additions & 20 deletions giskard/core/suite.py
Original file line number Diff line number Diff line change
@@ -1,33 +1,24 @@
from typing import Any, Dict, List, Optional, Tuple, Union

import inspect
import logging
import traceback
from dataclasses import dataclass
from functools import singledispatchmethod
from typing import List, Any, Union, Dict, Optional, Tuple

from mlflow import MlflowClient

from giskard.client.dtos import TestSuiteDTO, TestInputDTO, SuiteTestDTO
from giskard.client.dtos import SuiteTestDTO, TestInputDTO, TestSuiteDTO
from giskard.client.giskard_client import GiskardClient
from giskard.core.core import TestFunctionMeta
from giskard.datasets.base import Dataset
from giskard.ml_worker.core.savable import Artifact
from giskard.ml_worker.exceptions.IllegalArgumentError import IllegalArgumentError
from giskard.ml_worker.testing.registry.giskard_test import (
GiskardTest,
Test,
GiskardTestMethod,
)
from giskard.ml_worker.testing.registry.giskard_test import GiskardTest, GiskardTestMethod, Test
from giskard.ml_worker.testing.registry.registry import tests_registry
from giskard.ml_worker.testing.registry.slicing_function import SlicingFunction
from giskard.ml_worker.testing.registry.transformation_function import (
TransformationFunction,
)
from giskard.ml_worker.testing.test_result import (
TestResult,
TestMessage,
TestMessageLevel,
)
from giskard.ml_worker.testing.registry.transformation_function import TransformationFunction
from giskard.ml_worker.testing.test_result import TestMessage, TestMessageLevel, TestResult
from giskard.models.base import BaseModel

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -97,6 +88,7 @@ def _repr_html_(self):

def to_mlflow(self, mlflow_client: MlflowClient = None, mlflow_run_id: str = None):
import mlflow

from giskard.integrations.mlflow.giskard_evaluator_utils import process_text

metrics = dict()
Expand All @@ -122,8 +114,10 @@ def to_wandb(self, **kwargs) -> None:
Additional keyword arguments
(see https://docs.wandb.ai/ref/python/init) to be added to the active WandB run.
"""
from giskard.integrations.wandb.wandb_utils import wandb_run, _parse_test_name
import wandb

from giskard.integrations.wandb.wandb_utils import _parse_test_name, wandb_run

from ..utils.analytics_collector import analytics

with wandb_run(**kwargs) as run:
Expand Down Expand Up @@ -426,22 +420,25 @@ def upload(self, client: GiskardClient, project_key: str):
"""
if self.name is None:
self.name = "Unnamed test suite"


uploaded_uuids: List[str] = []
# Upload the default parameters if they are model or dataset
for arg in self.default_params.values():
if isinstance(arg, BaseModel) or isinstance(arg, Dataset):
arg.upload(client, project_key)
uploaded_uuids.append(str(arg.id))

self.id = client.save_test_suite(self.to_dto(client, project_key))
self.id = client.save_test_suite(self.to_dto(client, project_key, uploaded_uuids))
project_id = client.get_project(project_key).project_id
print(f"Test suite has been saved: {client.host_url}/main/projects/{project_id}/test-suite/{self.id}/overview")
return self

def to_dto(self, client: GiskardClient, project_key: str):
def to_dto(self, client: GiskardClient, project_key: str, uploaded_uuids: Optional[List[str]] = None):
suite_tests: List[SuiteTestDTO] = list()

# Avoid to upload the same artifacts several times
uploaded_uuids: List[str] = []
if uploaded_uuids is None:
uploaded_uuids = []

for t in self.tests:
suite_tests.append(
Expand Down

0 comments on commit 4744997

Please sign in to comment.