Skip to content

Commit

Permalink
Add captured leaked tensor case.
Browse files Browse the repository at this point in the history
Hybrid case. See #281.
  • Loading branch information
khatchad committed Nov 2, 2023
1 parent 230f276 commit eae5ac4
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# From https://www.tensorflow.org/guide/function#all_outputs_of_a_tffunction_must_be_return_values

import tensorflow as tf
from nose.tools import assert_raises


@tf.function
def leaky_function(a):
global x
x = a + 1 # Bad - leaks local tensor
return x # Good - uses local tensor


correct_a = leaky_function(tf.constant(1))

print(correct_a.numpy()) # Good - value obtained from function's returns
try:
x.numpy() # Bad - tensor leaked from inside the function, cannot be used here
except AttributeError as expected:
print(expected)


@tf.function
def captures_leaked_tensor(b):
b += x # Bad - `x` is leaked from `leaky_function`
return b


with assert_raises(TypeError):
captures_leaked_tensor(tf.constant(2))
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
tensorflow==2.9.3
nose==1.3.7
Original file line number Diff line number Diff line change
Expand Up @@ -5231,6 +5231,56 @@ public void testPythonSideEffects46() throws Exception {
assertEquals(Collections.emptySet(), function.getTransformations());
}

@Test
public void testPythonSideEffects47() throws Exception {
Function leakyFunction = getFunction("leaky_function");

assertTrue(leakyFunction.isHybrid());
assertTrue(leakyFunction.getLikelyHasTensorParameter());
assertTrue(leakyFunction.getHasPythonSideEffects());

assertFalse("P2 \"failure.\"", leakyFunction.getStatus().isOK());
assertEquals(
"Should have one warning and one error. The warning is for running a hybrid function that has side-effects. The error is that it is already \"optimal\".",
2, leakyFunction.getStatus().getEntries().length);
assertEquals(RefactoringStatus.ERROR, leakyFunction.getStatus().getEntryWithHighestSeverity().getSeverity());
assertEquals(PreconditionFailure.ALREADY_OPTIMAL.getCode(), leakyFunction.getStatus().getEntryWithHighestSeverity().getCode());
assertNotNull(leakyFunction.getStatus().getEntryMatchingSeverity(RefactoringStatus.WARNING));

assertEquals(Refactoring.OPTIMIZE_HYBRID_FUNCTION, leakyFunction.getRefactoring());
assertNull(leakyFunction.getPassingPrecondition());
assertTrue(leakyFunction.getTransformations().isEmpty());

Function capturesLeakedTensor = getFunction("captures_leaked_tensor");

assertTrue(capturesLeakedTensor.isHybrid());
assertTrue(capturesLeakedTensor.getLikelyHasTensorParameter());

// NOTE: This function doesn't have Python side-effects, but it does capture a "leaky" tensor. See
// https://github.com/ponder-lab/Hybridize-Functions-Refactoring/issues/281.
assertFalse(capturesLeakedTensor.getHasPythonSideEffects());

assertFalse(capturesLeakedTensor.getStatus().isOK());
assertTrue(capturesLeakedTensor.getStatus().hasError());
assertFalse(capturesLeakedTensor.getStatus().hasFatalError());
RefactoringStatusEntry error = capturesLeakedTensor.getStatus().getEntryMatchingSeverity(RefactoringStatus.ERROR);
assertEquals(PreconditionFailure.ALREADY_OPTIMAL.getCode(), error.getCode());

assertNotNull(capturesLeakedTensor.getRefactoring());
assertEquals("P2 \"failure.\"", Refactoring.OPTIMIZE_HYBRID_FUNCTION, capturesLeakedTensor.getRefactoring());
assertNull(capturesLeakedTensor.getPassingPrecondition());
assertTrue(capturesLeakedTensor.getTransformations().isEmpty());

// NOTE: Change to assertTrue when https://github.com/ponder-lab/Hybridize-Functions-Refactoring/issues/281 is fixed.
assertFalse("We should warn about this.", capturesLeakedTensor.getStatus().hasWarning());

RefactoringStatusEntry warning = capturesLeakedTensor.getStatus().getEntryMatchingSeverity(RefactoringStatus.WARNING);
// NOTE: Change to assertNotNull when https://github.com/ponder-lab/Hybridize-Functions-Refactoring/issues/281 is fixed.
// NOTE: Add assertEquals(RefactoringStatus.WARNING, entry.getSeverity()) when
// https://github.com/ponder-lab/Hybridize-Functions-Refactoring/issues/281 is fixed.
assertNull("Warn about a hybrid function that leaks as a potential tensor.", warning);
}

@Test
public void testPythonSideEffects49() throws Exception {
Function function = getFunction("leaky_function");
Expand Down

0 comments on commit eae5ac4

Please sign in to comment.