diff --git a/giskard/core/core.py b/giskard/core/core.py index a0fb40cf79..b589ac0614 100644 --- a/giskard/core/core.py +++ b/giskard/core/core.py @@ -169,26 +169,35 @@ def __init__( self.tags = self.populate_tags(tags) parameters = self.extract_parameters(callable_obj) + for param in parameters: + param.default = serialize_parameter(param.default) - self.args = { - parameter.name: FunctionArgument( - name=parameter.name, - type=extract_optional(parameter.annotation).__qualname__, - optional=parameter.default != inspect.Parameter.empty, - default=serialize_parameter(parameter.default), - argOrder=idx, - ) - for idx, parameter in enumerate(parameters.values()) - if name != "self" - } + self.args = {param.name: param for param in parameters} - def extract_parameters(self, callable_obj): + def extract_parameters(self, callable_obj) -> List[FunctionArgument]: if inspect.isclass(callable_obj): parameters = list(inspect.signature(callable_obj.__init__).parameters.values())[1:] else: parameters = list(inspect.signature(callable_obj).parameters.values()) - return parameters + return [ + FunctionArgument( + name=parameter.name, + type=self.extract_parameter_type_name(parameter), + optional=parameter.default != inspect.Parameter.empty, + default=parameter.default, + argOrder=idx, + ) + for idx, parameter in enumerate(parameters) + ] + + @staticmethod + def extract_parameter_type_name(parameter): + return ( + extract_optional(parameter.annotation).__qualname__ + if hasattr(extract_optional(parameter.annotation), "__qualname__") + else None + ) @staticmethod def extract_module_doc(func_doc): @@ -293,10 +302,8 @@ def __init__( super().__init__(callable_obj, name, tags, version, type) self.debug_description = debug_description - def extract_parameters(self, callable_obj): - parameters = unknown_annotations_to_kwargs(CallableMeta.extract_parameters(self, callable_obj)) - - return {p.name: p for p in parameters} + def extract_parameters(self, callable_obj) -> List[FunctionArgument]: + return unknown_annotations_to_kwargs(CallableMeta.extract_parameters(self, callable_obj)) def to_json(self): json = super().to_json() @@ -346,10 +353,8 @@ def __init__( else: self.column_type = None - def extract_parameters(self, callable_obj): - parameters = unknown_annotations_to_kwargs(CallableMeta.extract_parameters(self, callable_obj)[1:]) - - return {p.name: p for p in parameters} + def extract_parameters(self, callable_obj) -> List[FunctionArgument]: + return unknown_annotations_to_kwargs(CallableMeta.extract_parameters(self, callable_obj)[1:]) def to_json(self): json = super().to_json() @@ -373,25 +378,37 @@ def init_from_json(self, json: Dict[str, Any]): SMT = TypeVar("SMT", bound=SavableMeta) -def unknown_annotations_to_kwargs(parameters: List[inspect.Parameter]) -> List[inspect.Parameter]: +def unknown_annotations_to_kwargs(parameters: List[FunctionArgument]) -> List[FunctionArgument]: from giskard.models.base import BaseModel from giskard.datasets.base import Dataset from giskard.ml_worker.testing.registry.slicing_function import SlicingFunction from giskard.ml_worker.testing.registry.transformation_function import TransformationFunction allowed_types = [str, bool, int, float, BaseModel, Dataset, SlicingFunction, TransformationFunction] - allowed_types = allowed_types + list(map(lambda x: Optional[x], allowed_types)) + allowed_types = list(map(lambda x: x.__qualname__, allowed_types)) + + kwargs = [param for param in parameters if not any([param.type == allowed_type for allowed_type in allowed_types])] - has_kwargs = any( - [param for param in parameters if not any([param.annotation == allowed_type for allowed_type in allowed_types])] - ) + parameters = [param for param in parameters if any([param.type == allowed_type for allowed_type in allowed_types])] - parameters = [ - param for param in parameters if any([param.annotation == allowed_type for allowed_type in allowed_types]) - ] + for idx, parameter in enumerate(parameters): + parameter.argOrder = idx - if has_kwargs: - parameters.append(inspect.Parameter(name="kwargs", kind=4, annotation=Kwargs)) + if any(kwargs) > 0: + kwargs_with_default = [param for param in kwargs if param.default != inspect.Parameter.empty] + default_value = ( + dict({param.name: param.default for param in kwargs_with_default}) if any(kwargs_with_default) else None + ) + + parameters.append( + FunctionArgument( + name="kwargs", + type="Kwargs", + default=default_value, + optional=len(kwargs_with_default) == len(kwargs), + argOrder=len(parameters), + ) + ) return parameters diff --git a/giskard/testing/tests/performance.py b/giskard/testing/tests/performance.py index c94a561381..9edaf3bdc8 100644 --- a/giskard/testing/tests/performance.py +++ b/giskard/testing/tests/performance.py @@ -1,7 +1,8 @@ """Performance tests""" -import inspect from typing import Optional +import inspect + import numpy as np import pandas as pd from sklearn.metrics import ( @@ -22,7 +23,10 @@ from giskard.ml_worker.testing.utils import Direction, check_slice_not_empty from giskard.models.base import BaseModel from giskard.models.utils import np_type_to_native -from giskard.testing.tests.debug_slicing_functions import incorrect_rows_slicing_fn, nlargest_abs_err_rows_slicing_fn +from giskard.testing.tests.debug_slicing_functions import ( + incorrect_rows_slicing_fn, + nlargest_abs_err_rows_slicing_fn, +) from . import debug_description_prefix, debug_prefix diff --git a/giskard/utils/artifacts.py b/giskard/utils/artifacts.py index a535d15b1e..a87d718eb5 100644 --- a/giskard/utils/artifacts.py +++ b/giskard/utils/artifacts.py @@ -1,6 +1,7 @@ import inspect import uuid -from typing import Any, Optional, Union +from enum import Enum +from typing import Any, Optional, Union, Dict try: from types import NoneType @@ -18,6 +19,13 @@ def _serialize_artifact(artifact, artifact_uuid: Optional[Union[str, uuid.UUID]] return str(artifact_uuid) +def repr_parameter(value: Any) -> str: + if isinstance(value, Enum): + return repr(value.value) + + return repr(value) + + def serialize_parameter(default_value: Any) -> PRIMITIVES: if default_value == inspect.Parameter.empty: return None @@ -25,6 +33,9 @@ def serialize_parameter(default_value: Any) -> PRIMITIVES: if isinstance(default_value, PRIMITIVES.__args__): return default_value + if isinstance(default_value, Dict): + return "\n".join(f"kwargs[{repr(key)}] = {repr_parameter(value)}" for key, value in default_value.items()) + from ..ml_worker.core.savable import Artifact if isinstance(default_value, Artifact):