diff --git a/tests/test_tensor/test_optimizers.py b/tests/test_tensor/test_optimizers.py index b2bb3f97..327b8cbd 100644 --- a/tests/test_tensor/test_optimizers.py +++ b/tests/test_tensor/test_optimizers.py @@ -16,6 +16,11 @@ found_jax = importlib.util.find_spec("jax") is not None found_tensorflow = importlib.util.find_spec("tensorflow") is not None +if found_tensorflow: + import tensorflow.experimental.numpy as tnp + + tnp.experimental_enable_numpy_behavior() + jax_case = pytest.param( "jax", marks=pytest.mark.skipif(not found_jax, reason="jax not installed") )