diff --git a/giskard/ml_worker/testing/utils.py b/giskard/ml_worker/testing/utils.py index 793528b6a8..71c715e97c 100644 --- a/giskard/ml_worker/testing/utils.py +++ b/giskard/ml_worker/testing/utils.py @@ -1,13 +1,14 @@ +from typing import Optional + +import numbers from enum import Enum from functools import wraps -import numbers -from typing import Optional -from giskard.datasets.base import Dataset from giskard.core.core import SupportedModelTypes +from giskard.datasets.base import Dataset -class Direction(Enum): +class Direction(int, Enum): Invariant = 0 Increasing = 1 Decreasing = -1 diff --git a/tests/communications/test_listener_utils.py b/tests/communications/test_listener_utils.py index 3b9ba9626f..812aca8db6 100644 --- a/tests/communications/test_listener_utils.py +++ b/tests/communications/test_listener_utils.py @@ -2,7 +2,11 @@ from giskard.ml_worker import websocket from giskard.ml_worker.exceptions.IllegalArgumentError import IllegalArgumentError -from giskard.ml_worker.websocket.listener import extract_debug_info, function_argument_to_ws, parse_function_arguments +from giskard.ml_worker.websocket.listener import ( + extract_debug_info, + function_argument_to_ws, + parse_function_arguments, +) TEST_PROJECT_KEY = "123" TEST_MODEL_ID = "231" @@ -30,7 +34,6 @@ def test_extract_debug_info(): def test_function_argument_to_ws(): - # Domain classes creation should be tested somewhere else, do not test them. # "dataset": Dataset, # "model": BaseModel, @@ -76,4 +79,4 @@ def test_parse_function_arguments(): assert "int" in kwargs.keys() and kwargs["int"] == TEST_FUNC_ARGUMENT_INT assert "str" in kwargs.keys() and kwargs["str"] == TEST_FUNC_ARGUMENT_STR assert "bool" in kwargs.keys() and kwargs["bool"] == TEST_FUNC_ARGUMENT_BOOL - assert "bool1" in kwargs.keys() and kwargs["bool1"] == TEST_FUNC_ARGUMENT_BOOL + assert "kwargs" in kwargs.keys() and kwargs["kwargs"]["bool1"] == TEST_FUNC_ARGUMENT_BOOL