diff --git a/tests/trace/test_op_versioning.py b/tests/trace/test_op_versioning.py index a8ed18ce780..0235ef4c896 100644 --- a/tests/trace/test_op_versioning.py +++ b/tests/trace/test_op_versioning.py @@ -1,3 +1,4 @@ +import re import typing import numpy as np @@ -508,3 +509,46 @@ def some_d(v): print(saved_code) assert saved_code == EXPECTED_NO_REPEATS_CODE + + +EXPECTED_INSTANCE_CODE = """import weave + +instance = ".MyClass object at 0x000000000>" + +@weave.op() +def t(text: str): + print(instance._version) + return text +""" + + +def test_op_instance(client): + class MyClass: + _version: str + api_key: str + + def __init__(self, secret: str) -> None: + self._version = "1.0.0" + self.api_key = secret + + # We want to make sure this secret value is not saved in the code + instance = MyClass("sk-1234567890qwertyuiop") + + @weave.op() + def t(text: str): + print(instance._version) + return text + + t("hello") + + ref = weave.obj_ref(t) + assert ref is not None + + saved_code = get_saved_code(client, ref) + print("SAVED CODE") + print(saved_code) + + # Instance address expected to change each run + clean_saved_code = re.sub(r"0x[0-9a-fA-F]+", "0x000000000", saved_code) + + assert clean_saved_code == EXPECTED_INSTANCE_CODE