Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GSK-1911 Improve performance of the push feature & make CTA almost instant #1512

Closed
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions giskard/ml_worker/utils/cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import hashlib


class SimpleCache:
    """Bounded in-memory cache keyed by the MD5 digest of ``repr(obj)``.

    Eviction is LRU (least recently used): reading an entry through
    ``get_result`` refreshes it, so only genuinely stale entries are
    dropped once the cache exceeds ``max_results``.
    """

    def __init__(self, max_results: int = 20) -> None:
        # Insertion-ordered dict (guaranteed since Python 3.7): the first key
        # is the least recently used, the last key the most recently used.
        # This replaces the previous separate "order" list, which could hold
        # duplicate hashes when the same object was added twice and then raise
        # KeyError on the second eviction of that hash.
        self.results: dict = {}
        self.max_results = max_results

    @staticmethod
    def _hash_key(obj) -> str:
        # NOTE: assumes obj has a stable, deterministic repr(). MD5 is used
        # for a compact key, not for security.
        return hashlib.md5(repr(obj).encode("utf-8")).hexdigest()

    def add_result(self, obj, result) -> None:
        """Store ``result`` for ``obj``, evicting the LRU entry when full."""
        obj_hash = self._hash_key(obj)
        # Drop any existing entry first so re-insertion moves the key to the
        # most-recently-used (last) position.
        self.results.pop(obj_hash, None)
        self.results[obj_hash] = result
        if len(self.results) > self.max_results:
            # Evict the least recently used entry (first key in order).
            del self.results[next(iter(self.results))]

    def get_result(self, obj):
        """Return the cached result for ``obj`` (refreshing its recency), or None."""
        obj_hash = self._hash_key(obj)
        if obj_hash not in self.results:
            return None
        # Re-insert to mark the entry as most recently used.
        result = self.results.pop(obj_hash)
        self.results[obj_hash] = result
        return result
98 changes: 56 additions & 42 deletions giskard/ml_worker/websocket/listener.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from giskard.ml_worker.testing.registry.giskard_test import GiskardTest
from giskard.ml_worker.testing.registry.slicing_function import SlicingFunction
from giskard.ml_worker.testing.registry.transformation_function import TransformationFunction
from giskard.ml_worker.utils.cache import SimpleCache
from giskard.ml_worker.utils.file_utils import get_file_name
from giskard.ml_worker.websocket import CallToActionKind, GetInfoParam, PushKind
from giskard.ml_worker.websocket.action import MLWorkerAction
Expand All @@ -54,11 +55,9 @@
from giskard.utils.analytics_collector import analytics

logger = logging.getLogger(__name__)


push_cache = SimpleCache(max_results=20)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you rename it to PUSH_CACHE, since module-level constants use upper case?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, 20 is not enough I think. Maybe 128 ?

MAX_STOMP_ML_WORKER_REPLY_SIZE = 1500


@dataclass
class MLWorkerInfo:
id: str
Expand Down Expand Up @@ -681,46 +680,13 @@ def get_push(
object_uuid = ""
object_params = {}
project_key = params.model.project_key
try:
model = BaseModel.download(client, params.model.project_key, params.model.id)
dataset = Dataset.download(client, params.dataset.project_key, params.dataset.id)

df = pd.DataFrame.from_records([r.columns for r in params.dataframe.rows])
if params.column_dtypes:
for missing_column in [
column_name for column_name in params.column_dtypes.keys() if column_name not in df.columns
]:
df[missing_column] = np.nan
df = Dataset.cast_column_to_dtypes(df, params.column_dtypes)

except ValueError as e:
if "unsupported pickle protocol" in str(e):
raise ValueError(
"Unable to unpickle object, "
"Make sure that Python version of client code is the same as the Python version in ML Worker."
"To change Python version, please refer to https://docs.giskard.ai/start/guides/configuration"
f"\nOriginal Error: {e}"
) from e
raise e
except ModuleNotFoundError as e:
raise GiskardException(
f"Failed to import '{e.name}'. "
f"Make sure it's installed in the ML Worker environment."
"To have more information on ML Worker, please see: https://docs.giskard.ai/start/guides/installation/ml-worker"
) from e

# if df is empty, return early
if df.empty:
return

from giskard.push.contribution import create_contribution_push
from giskard.push.perturbation import create_perturbation_push
from giskard.push.prediction import create_borderline_push, create_overconfidence_push

contribs = create_contribution_push(model, dataset, df)
perturbs = create_perturbation_push(model, dataset, df)
overconf = create_overconfidence_push(model, dataset, df)
borderl = create_borderline_push(model, dataset, df)
cached_result_tuple = push_cache.get_result(params)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not use it on the overall get_push instead?

if cached_result_tuple is None:
contribs, perturbs, overconf, borderl = get_push_objects(client, params)
push_cache.add_result(params, (contribs, perturbs, overconf, borderl))
else:
contribs, perturbs, overconf, borderl = cached_result_tuple

contrib_ws = push_to_ws(contribs)
perturb_ws = push_to_ws(perturbs)
Expand Down Expand Up @@ -790,3 +756,51 @@ def get_push(

def push_to_ws(push: Push):
    """Convert a Push object to its websocket representation; None stays None."""
    if push is None:
        return None
    return push.to_ws()


def get_push_objects(client: Optional[GiskardClient], params: websocket.GetPushParam):
    """Download the model and dataset, rebuild the row dataframe and compute pushes.

    Args:
        client: Giskard client used to download the model/dataset artifacts.
        params: websocket parameters carrying the model/dataset refs, the raw
            dataframe rows, optional column dtypes and the requested push kind.

    Returns:
        A 4-tuple ``(contribution, perturbation, overconfidence, borderline)``.
        Entries not requested by ``params.push_kind`` are None; all four are
        None when the incoming dataframe is empty.

    Raises:
        ValueError: when unpickling fails due to a Python version mismatch.
        GiskardException: when a module required by the artifacts is missing.
    """
    try:
        model = BaseModel.download(client, params.model.project_key, params.model.id)
        dataset = Dataset.download(client, params.dataset.project_key, params.dataset.id)

        df = pd.DataFrame.from_records([r.columns for r in params.dataframe.rows])
        if params.column_dtypes:
            # Columns declared in the dtype mapping but absent from the payload
            # are materialized as NaN so the cast below cannot fail on them.
            for missing_column in [
                column_name for column_name in params.column_dtypes.keys() if column_name not in df.columns
            ]:
                df[missing_column] = np.nan
            df = Dataset.cast_column_to_dtypes(df, params.column_dtypes)

    except ValueError as e:
        if "unsupported pickle protocol" in str(e):
            raise ValueError(
                "Unable to unpickle object, "
                "Make sure that Python version of client code is the same as the Python version in ML Worker."
                "To change Python version, please refer to https://docs.giskard.ai/start/guides/configuration"
                f"\nOriginal Error: {e}"
            ) from e
        raise e
    except ModuleNotFoundError as e:
        raise GiskardException(
            f"Failed to import '{e.name}'. "
            f"Make sure it's installed in the ML Worker environment."
            "To have more information on ML Worker, please see: https://docs.giskard.ai/start/guides/installation/ml-worker"
        ) from e

    # Nothing to analyse: keep the return shape identical to the computed case.
    if df.empty:
        return None, None, None, None

    # Imported lazily: the push modules are heavy and only needed from here on.
    from giskard.push.contribution import create_contribution_push
    from giskard.push.perturbation import create_perturbation_push
    from giskard.push.prediction import create_borderline_push, create_overconfidence_push

    # Map each push kind to the builders to run; the None key (no specific
    # kind requested) computes all four.
    push_functions = {
        None: (create_contribution_push, create_perturbation_push, create_overconfidence_push, create_borderline_push),
        PushKind.CONTRIBUTION: (create_contribution_push, None, None, None),
        PushKind.PERTURBATION: (None, create_perturbation_push, None, None),
        PushKind.OVERCONFIDENCE: (None, None, create_overconfidence_push, None),
        PushKind.BORDERLINE: (None, None, None, create_borderline_push),
    }
    # Unknown kinds fall back to computing everything instead of raising KeyError.
    selected = push_functions.get(params.push_kind, push_functions[None])

    # Return a real tuple (the previous generator expression was single-use and
    # inconsistent with the empty-df tuple return above), so callers can unpack,
    # cache and re-read the result safely.
    return tuple(func(model, dataset, df) if func else None for func in selected)
5 changes: 3 additions & 2 deletions giskard/push/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,8 +236,9 @@ def _borderline(self):
"push_title": "This example was predicted with very low confidence",
"details": [
{
"action": "Generate a one-sample to automatically test the underconfidence",
"explanation": "This may help you ensure this example is not predicted with low confidence for a new model",
"action": "Generate a one-sample to check for underconfidence",
"explanation": "This may help you ensure this specific example is not predicted with low "
"confidence for a new model",
"button": "Create one-sample test",
"cta": CallToActionKind.CREATE_TEST,
},
Expand Down