-
-
Notifications
You must be signed in to change notification settings - Fork 282
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
GSK-1911 Improve performance of the push feature & make CTA almost instant #1512
Changes from 3 commits
0d81c3b
88b2100
5291254
838911d
b8eff96
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
import hashlib | ||
|
||
|
||
class SimpleCache: | ||
def __init__(self, max_results=20): | ||
self.results = {} | ||
self.order = [] # To maintain the order in which results were added | ||
self.max_results = max_results | ||
|
||
def add_result(self, obj, result): | ||
# Calculate the hash of the object | ||
obj_hash = hashlib.md5(repr(obj).encode()).hexdigest() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you add encoding explicitely to encode pls ? |
||
|
||
# Store the result with the object's hash as the key | ||
self.results[obj_hash] = result | ||
|
||
# Add the object's hash to the order list and remove the oldest result if necessary | ||
self.order.append(obj_hash) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think you should use LRU (least recently used) algorithm here, instead of oldest. |
||
if len(self.order) > self.max_results: | ||
oldest_obj_hash = self.order.pop(0) | ||
del self.results[oldest_obj_hash] | ||
|
||
def get_result(self, obj): | ||
# Calculate the hash of the object | ||
obj_hash = hashlib.md5(repr(obj).encode()).hexdigest() | ||
|
||
# Return the result for the given object if it exists | ||
return self.results.get(obj_hash) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -31,6 +31,7 @@ | |
from giskard.ml_worker.testing.registry.giskard_test import GiskardTest | ||
from giskard.ml_worker.testing.registry.slicing_function import SlicingFunction | ||
from giskard.ml_worker.testing.registry.transformation_function import TransformationFunction | ||
from giskard.ml_worker.utils.cache import SimpleCache | ||
from giskard.ml_worker.utils.file_utils import get_file_name | ||
from giskard.ml_worker.websocket import CallToActionKind, GetInfoParam, PushKind | ||
from giskard.ml_worker.websocket.action import MLWorkerAction | ||
|
@@ -54,11 +55,9 @@ | |
from giskard.utils.analytics_collector import analytics | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
push_cache = SimpleCache(max_results=20) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could you write it as PUSH_CACHE instead ? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Also, 20 is not enough I think. Maybe 128 ? |
||
MAX_STOMP_ML_WORKER_REPLY_SIZE = 1500 | ||
|
||
|
||
@dataclass | ||
class MLWorkerInfo: | ||
id: str | ||
|
@@ -681,46 +680,13 @@ def get_push( | |
object_uuid = "" | ||
object_params = {} | ||
project_key = params.model.project_key | ||
try: | ||
model = BaseModel.download(client, params.model.project_key, params.model.id) | ||
dataset = Dataset.download(client, params.dataset.project_key, params.dataset.id) | ||
|
||
df = pd.DataFrame.from_records([r.columns for r in params.dataframe.rows]) | ||
if params.column_dtypes: | ||
for missing_column in [ | ||
column_name for column_name in params.column_dtypes.keys() if column_name not in df.columns | ||
]: | ||
df[missing_column] = np.nan | ||
df = Dataset.cast_column_to_dtypes(df, params.column_dtypes) | ||
|
||
except ValueError as e: | ||
if "unsupported pickle protocol" in str(e): | ||
raise ValueError( | ||
"Unable to unpickle object, " | ||
"Make sure that Python version of client code is the same as the Python version in ML Worker." | ||
"To change Python version, please refer to https://docs.giskard.ai/start/guides/configuration" | ||
f"\nOriginal Error: {e}" | ||
) from e | ||
raise e | ||
except ModuleNotFoundError as e: | ||
raise GiskardException( | ||
f"Failed to import '{e.name}'. " | ||
f"Make sure it's installed in the ML Worker environment." | ||
"To have more information on ML Worker, please see: https://docs.giskard.ai/start/guides/installation/ml-worker" | ||
) from e | ||
|
||
# if df is empty, return early | ||
if df.empty: | ||
return | ||
|
||
from giskard.push.contribution import create_contribution_push | ||
from giskard.push.perturbation import create_perturbation_push | ||
from giskard.push.prediction import create_borderline_push, create_overconfidence_push | ||
|
||
contribs = create_contribution_push(model, dataset, df) | ||
perturbs = create_perturbation_push(model, dataset, df) | ||
overconf = create_overconfidence_push(model, dataset, df) | ||
borderl = create_borderline_push(model, dataset, df) | ||
cached_result_tuple = push_cache.get_result(params) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why not using it on the overall get_push instead ? |
||
if cached_result_tuple is None: | ||
contribs, perturbs, overconf, borderl = get_push_objects(client, params) | ||
push_cache.add_result(params, (contribs, perturbs, overconf, borderl)) | ||
else: | ||
contribs, perturbs, overconf, borderl = cached_result_tuple | ||
|
||
contrib_ws = push_to_ws(contribs) | ||
perturb_ws = push_to_ws(perturbs) | ||
|
@@ -790,3 +756,51 @@ def get_push( | |
|
||
def push_to_ws(push: Push): | ||
return push.to_ws() if push is not None else None | ||
|
||
|
||
def get_push_objects(client: Optional[GiskardClient], params: websocket.GetPushParam): | ||
try: | ||
model = BaseModel.download(client, params.model.project_key, params.model.id) | ||
dataset = Dataset.download(client, params.dataset.project_key, params.dataset.id) | ||
|
||
df = pd.DataFrame.from_records([r.columns for r in params.dataframe.rows]) | ||
if params.column_dtypes: | ||
for missing_column in [ | ||
column_name for column_name in params.column_dtypes.keys() if column_name not in df.columns | ||
]: | ||
df[missing_column] = np.nan | ||
df = Dataset.cast_column_to_dtypes(df, params.column_dtypes) | ||
|
||
except ValueError as e: | ||
if "unsupported pickle protocol" in str(e): | ||
raise ValueError( | ||
"Unable to unpickle object, " | ||
"Make sure that Python version of client code is the same as the Python version in ML Worker." | ||
"To change Python version, please refer to https://docs.giskard.ai/start/guides/configuration" | ||
f"\nOriginal Error: {e}" | ||
) from e | ||
raise e | ||
except ModuleNotFoundError as e: | ||
raise GiskardException( | ||
f"Failed to import '{e.name}'. " | ||
f"Make sure it's installed in the ML Worker environment." | ||
"To have more information on ML Worker, please see: https://docs.giskard.ai/start/guides/installation/ml-worker" | ||
) from e | ||
|
||
# if df is empty, return early | ||
if df.empty: | ||
return None, None, None, None | ||
|
||
from giskard.push.contribution import create_contribution_push | ||
from giskard.push.perturbation import create_perturbation_push | ||
from giskard.push.prediction import create_borderline_push, create_overconfidence_push | ||
|
||
push_functions = { | ||
None: (create_contribution_push, create_perturbation_push, create_overconfidence_push, create_borderline_push), | ||
PushKind.CONTRIBUTION: (create_contribution_push, None, None, None), | ||
PushKind.PERTURBATION: (None, create_perturbation_push, None, None), | ||
PushKind.OVERCONFIDENCE: (None, None, create_overconfidence_push, None), | ||
PushKind.BORDERLINE: (None, None, None, create_borderline_push) | ||
} | ||
|
||
return (func(model, dataset, df) if func else None for func in push_functions[params.push_kind]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you add typing hint here, for better readability ?