Skip to content

Commit

Permalink
Made kwargs default value persistent in the Hub
Browse files Browse the repository at this point in the history
  • Loading branch information
kevinmessiaen authored and Hartorn committed Oct 31, 2023
1 parent 949c1a8 commit 9130b3a
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 37 deletions.
71 changes: 40 additions & 31 deletions giskard/core/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,26 +169,27 @@ def __init__(
self.tags = self.populate_tags(tags)

parameters = self.extract_parameters(callable_obj)
for param in parameters:
param.default = serialize_parameter(param.default)

self.args = {
parameter.name: FunctionArgument(
name=parameter.name,
type=extract_optional(parameter.annotation).__qualname__,
optional=parameter.default != inspect.Parameter.empty,
default=serialize_parameter(parameter.default),
argOrder=idx,
)
for idx, parameter in enumerate(parameters.values())
if name != "self"
}
self.args = {param.name: param for param in parameters}

def extract_parameters(self, callable_obj):
def extract_parameters(self, callable_obj) -> List[FunctionArgument]:
if inspect.isclass(callable_obj):
parameters = list(inspect.signature(callable_obj.__init__).parameters.values())[1:]
else:
parameters = list(inspect.signature(callable_obj).parameters.values())

return parameters
return [
FunctionArgument(
name=parameter.name,
type=extract_optional(parameter.annotation).__qualname__,
optional=parameter.default != inspect.Parameter.empty,
default=parameter.default,
argOrder=idx,
)
for idx, parameter in enumerate(parameters)
]

@staticmethod
def extract_module_doc(func_doc):
Expand Down Expand Up @@ -293,10 +294,8 @@ def __init__(
super().__init__(callable_obj, name, tags, version, type)
self.debug_description = debug_description

def extract_parameters(self, callable_obj):
parameters = unknown_annotations_to_kwargs(CallableMeta.extract_parameters(self, callable_obj))

return {p.name: p for p in parameters}
def extract_parameters(self, callable_obj) -> List[FunctionArgument]:
return unknown_annotations_to_kwargs(CallableMeta.extract_parameters(self, callable_obj))

def to_json(self):
json = super().to_json()
Expand Down Expand Up @@ -346,10 +345,8 @@ def __init__(
else:
self.column_type = None

def extract_parameters(self, callable_obj):
parameters = unknown_annotations_to_kwargs(CallableMeta.extract_parameters(self, callable_obj)[1:])

return {p.name: p for p in parameters}
def extract_parameters(self, callable_obj) -> List[FunctionArgument]:
return unknown_annotations_to_kwargs(CallableMeta.extract_parameters(self, callable_obj)[1:])

def to_json(self):
json = super().to_json()
Expand All @@ -373,25 +370,37 @@ def init_from_json(self, json: Dict[str, Any]):
SMT = TypeVar("SMT", bound=SavableMeta)


def unknown_annotations_to_kwargs(parameters: List[inspect.Parameter]) -> List[inspect.Parameter]:
def unknown_annotations_to_kwargs(parameters: List[FunctionArgument]) -> List[FunctionArgument]:
from giskard.models.base import BaseModel
from giskard.datasets.base import Dataset
from giskard.ml_worker.testing.registry.slicing_function import SlicingFunction
from giskard.ml_worker.testing.registry.transformation_function import TransformationFunction

allowed_types = [str, bool, int, float, BaseModel, Dataset, SlicingFunction, TransformationFunction]
allowed_types = allowed_types + list(map(lambda x: Optional[x], allowed_types))
allowed_types = list(map(lambda x: x.__qualname__, allowed_types))

has_kwargs = any(
[param for param in parameters if not any([param.annotation == allowed_type for allowed_type in allowed_types])]
)
kwargs = [param for param in parameters if not any([param.type == allowed_type for allowed_type in allowed_types])]

parameters = [
param for param in parameters if any([param.annotation == allowed_type for allowed_type in allowed_types])
]
parameters = [param for param in parameters if any([param.type == allowed_type for allowed_type in allowed_types])]

if has_kwargs:
parameters.append(inspect.Parameter(name="kwargs", kind=4, annotation=Kwargs))
for idx, parameter in enumerate(parameters):
parameter.argOrder = idx

if any(kwargs) > 0:
kwargs_with_default = [param for param in kwargs if param.default != inspect.Parameter.empty]
default_value = (
dict({param.name: param.default for param in kwargs_with_default}) if any(kwargs_with_default) else None
)

parameters.append(
FunctionArgument(
name="kwargs",
type="Kwargs",
default=default_value,
optional=len(kwargs_with_default) == len(kwargs),
argOrder=len(parameters),
)
)

return parameters

Expand Down
14 changes: 9 additions & 5 deletions giskard/testing/tests/performance.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
"""Performance tests"""
import inspect
from typing import Optional

import inspect

import numpy as np
import pandas as pd
from sklearn.metrics import (
Expand All @@ -22,7 +23,10 @@
from giskard.ml_worker.testing.utils import Direction, check_slice_not_empty
from giskard.models.base import BaseModel
from giskard.models.utils import np_type_to_native
from giskard.testing.tests.debug_slicing_functions import incorrect_rows_slicing_fn, nlargest_abs_err_rows_slicing_fn
from giskard.testing.tests.debug_slicing_functions import (
incorrect_rows_slicing_fn,
nlargest_abs_err_rows_slicing_fn,
)

from . import debug_description_prefix, debug_prefix

Expand Down Expand Up @@ -149,11 +153,11 @@ def _test_diff_prediction(
" reference_dataset is equal to zero"
)

if direction == Direction.Invariant:
if direction == Direction.Invariant or direction == Direction.Invariant.value:
passed = abs(rel_change) < threshold
elif direction == Direction.Decreasing:
elif direction == Direction.Decreasing or direction == Direction.Decreasing.value:
passed = rel_change < threshold
elif direction == Direction.Increasing:
elif direction == Direction.Increasing or direction == Direction.Increasing.value:
passed = rel_change > threshold
else:
raise ValueError(f"Invalid direction: {direction}")
Expand Down
13 changes: 12 additions & 1 deletion giskard/utils/artifacts.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import inspect
import uuid
from typing import Any, Optional, Union
from enum import Enum
from typing import Any, Optional, Union, Dict

try:
from types import NoneType
Expand All @@ -18,13 +19,23 @@ def _serialize_artifact(artifact, artifact_uuid: Optional[Union[str, uuid.UUID]]
return str(artifact_uuid)


def repr_parameter(value: Any) -> str:
if isinstance(value, Enum):
return repr(value.value)

return repr(value)


def serialize_parameter(default_value: Any) -> PRIMITIVES:
if default_value == inspect.Parameter.empty:
return None

if isinstance(default_value, PRIMITIVES.__args__):
return default_value

if isinstance(default_value, Dict):
return "\n".join(f"kwargs[{repr(key)}] = {repr_parameter(value)}" for key, value in default_value.items())

from ..ml_worker.core.savable import Artifact

if isinstance(default_value, Artifact):
Expand Down

0 comments on commit 9130b3a

Please sign in to comment.