Skip to content

Commit

Permalink
Add eager case for #281.
Browse files Browse the repository at this point in the history
  • Loading branch information
khatchad committed Nov 2, 2023
1 parent eae5ac4 commit 699177b
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# From https://www.tensorflow.org/guide/function#all_outputs_of_a_tffunction_must_be_return_values

import tensorflow as tf
from nose.tools import assert_raises


@tf.function
def leaky_function(a):
global x
x = a + 1 # Bad - leaks local tensor
return x # Good - uses local tensor


correct_a = leaky_function(tf.constant(1))

print(correct_a.numpy()) # Good - value obtained from function's returns
try:
x.numpy() # Bad - tensor leaked from inside the function, cannot be used here
except AttributeError as expected:
print(expected)


# @tf.function
def captures_leaked_tensor(b):
b += x # Bad - `x` is leaked from `leaky_function`
return b


with assert_raises(TypeError):
captures_leaked_tensor(tf.constant(2))
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
tensorflow==2.9.3
nose==1.3.7
Original file line number Diff line number Diff line change
Expand Up @@ -5266,6 +5266,9 @@ public void testPythonSideEffects47() throws Exception {
RefactoringStatusEntry error = capturesLeakedTensor.getStatus().getEntryMatchingSeverity(RefactoringStatus.ERROR);
assertEquals(PreconditionFailure.ALREADY_OPTIMAL.getCode(), error.getCode());

// NOTE: Change to assertTrue once https://github.com/ponder-lab/Hybridize-Functions-Refactoring/issues/281 is fixed.
assertFalse("We should warn that the hybrid function is capturing leaked tensors.", capturesLeakedTensor.getStatus().hasWarning());

assertNotNull(capturesLeakedTensor.getRefactoring());
assertEquals("P2 \"failure.\"", Refactoring.OPTIMIZE_HYBRID_FUNCTION, capturesLeakedTensor.getRefactoring());
assertNull(capturesLeakedTensor.getPassingPrecondition());
Expand All @@ -5281,6 +5284,53 @@ public void testPythonSideEffects47() throws Exception {
assertNull("Warn about a hybrid function that leaks as a potential tensor.", warning);
}

@Test
public void testPythonSideEffects48() throws Exception {
Function leakyFunction = getFunction("leaky_function");

assertTrue(leakyFunction.isHybrid());
assertTrue(leakyFunction.getLikelyHasTensorParameter());
assertTrue(leakyFunction.getHasPythonSideEffects());

// P2 "failure."
assertFalse(leakyFunction.getStatus().isOK());
assertEquals(
"Should have one warning and one error. The warning is for running a hybrid function that has side-effects. The error is that it is already \"optimal\".",
2, leakyFunction.getStatus().getEntries().length);
assertEquals(RefactoringStatus.ERROR, leakyFunction.getStatus().getEntryWithHighestSeverity().getSeverity());
assertEquals(PreconditionFailure.ALREADY_OPTIMAL.getCode(), leakyFunction.getStatus().getEntryWithHighestSeverity().getCode());
assertNotNull(leakyFunction.getStatus().getEntryMatchingSeverity(RefactoringStatus.WARNING));

Function capturesLeakedTensor = getFunction("captures_leaked_tensor");

assertFalse(capturesLeakedTensor.isHybrid());
assertTrue(capturesLeakedTensor.getLikelyHasTensorParameter());

// NOTE: This function doesn't have Python side-effects, but it does capture a "leaky" tensor. See
// https://github.com/ponder-lab/Hybridize-Functions-Refactoring/issues/281.
assertFalse(capturesLeakedTensor.getHasPythonSideEffects());

// NOTE: Change to assertFalse once https://github.com/ponder-lab/Hybridize-Functions-Refactoring/issues/281 is fixed.
assertTrue("Passes P1.", capturesLeakedTensor.getStatus().isOK());

assertFalse(capturesLeakedTensor.getStatus().hasWarning());
// NOTE: Change to assertTrue once https://github.com/ponder-lab/Hybridize-Functions-Refactoring/issues/281 is fixed.
assertFalse(capturesLeakedTensor.getStatus().hasError());

assertNotNull(capturesLeakedTensor.getRefactoring());
assertEquals("We shouldn't refactor this but we do currently. Nevertheless, the refactoring kind should remain intact.",
Refactoring.CONVERT_EAGER_FUNCTION_TO_HYBRID, capturesLeakedTensor.getRefactoring());

// NOTE: Change to assertNull once https://github.com/ponder-lab/Hybridize-Functions-Refactoring/issues/281 is fixed.
assertNotNull(capturesLeakedTensor.getPassingPrecondition());
assertEquals("We really shouldn't refactor this.", capturesLeakedTensor.getPassingPrecondition(), PreconditionSuccess.P1);

// NOTE: Change to assertTrue once https://github.com/ponder-lab/Hybridize-Functions-Refactoring/issues/281 is fixed.
assertFalse(capturesLeakedTensor.getTransformations().isEmpty());
assertEquals("We really shouldn't transform this.", Collections.singleton(Transformation.CONVERT_TO_HYBRID),
capturesLeakedTensor.getTransformations());
}

@Test
public void testPythonSideEffects49() throws Exception {
Function function = getFunction("leaky_function");
Expand Down

0 comments on commit 699177b

Please sign in to comment.