From 09964d1e55eb8ff699db4cf1aa38808cf825c2df Mon Sep 17 00:00:00 2001
From: Allen Lavoie <allenl@google.com>
Date: Tue, 7 Nov 2017 14:51:44 -0800
Subject: [PATCH] Fix a tfe.Network reference cycle, make it easier to check
 for reference cycles in unit tests.

PiperOrigin-RevId: 174917072
---
 tensorflow/contrib/eager/python/network.py    | 56 +++++++++++-------
 .../contrib/eager/python/network_test.py      | 31 ++++++----
 tensorflow/python/framework/test_util.py      | 57 +++++++++++++++++--
 tensorflow/python/framework/test_util_test.py | 24 ++++++++
 .../kernel_tests/garbage_collection_test.py   | 29 +---------
 5 files changed, 132 insertions(+), 65 deletions(-)

diff --git a/tensorflow/contrib/eager/python/network.py b/tensorflow/contrib/eager/python/network.py
index 97feaec30ed..c6e628b074e 100644
--- a/tensorflow/contrib/eager/python/network.py
+++ b/tensorflow/contrib/eager/python/network.py
@@ -182,6 +182,40 @@ def _make_custom_getter_for_deferred_restorations():
   return _custom_getter, deferred_restorations
 
 
+def _make_prefix_stripping_map_fn(scope_name):
+  """Closure for stripping the scope name of a Network.
+
+  Implemented as a closure rather than a member function to avoid reference
+  cycles in deferred restorations (this function should not have a reference to
+  the Network which created it).
+
+  Args:
+    scope_name: The Network.scope_name to strip from variables.
+  Returns:
+    A scope_name-stripping default `map_func` for the Network.
+  """
+
+  def _strip_variable_prefix(original_variable_name):
+    """The default map_func for saving or restoring variables.
+
+    Strips the variable prefix for the Network on which save/restore was called,
+    and leaves other variable names fully qualified in the checkpoint.
+
+    Args:
+      original_variable_name: The _shared_name of the variable (no :0
+        suffix) to map.
+    Returns:
+      The checkpoint name of the variable.
+    """
+    scope_name_with_slash = scope_name + "/"
+    if original_variable_name.startswith(scope_name_with_slash):
+      return original_variable_name[len(scope_name_with_slash):]
+    else:
+      return original_variable_name
+
+  return _strip_variable_prefix
+
+
 class Network(base.Layer):
   """Represents the composition of a set of Layers.
 
@@ -488,24 +522,6 @@ class Network(base.Layer):
         "at https://github.com/tensorflow/tensorflow/issues/new if this is "
         "important to you")
 
-  def _strip_variable_prefix(self, original_variable_name):
-    """The default map_func for saving or restoring variables.
-
-    Strips the variable prefix for the Network on which save/restore was called,
-    and leaves other variable names fully qualified in the checkpoint.
-
-    Args:
-      original_variable_name: The _shared_name of the variable (no :0
-        suffix) to map.
-    Returns:
-      The checkpoint name of the variable.
-    """
-    scope_name_with_slash = self.scope_name + "/"
-    if original_variable_name.startswith(scope_name_with_slash):
-      return original_variable_name[len(scope_name_with_slash):]
-    else:
-      return original_variable_name
-
   def save(self, save_path, global_step=None, map_func=None):
     """Save variables from the Network to a checkpoint.
 
@@ -543,7 +559,7 @@ class Network(base.Layer):
       save_path = os.path.join(save_path, self.name)
     user_map_func = map_func
     if map_func is None:
-      map_func = self._strip_variable_prefix
+      map_func = _make_prefix_stripping_map_fn(self.scope_name)
     variable_map = {}
     for variable in self.variables:
       mapped_name = map_func(variable._shared_name)
@@ -737,7 +753,7 @@ class Network(base.Layer):
       save_path = os.path.join(save_path, self.name)
     user_map_func = map_func
     if map_func is None:
-      map_func = self._strip_variable_prefix
+      map_func = _make_prefix_stripping_map_fn(self.scope_name)
     # Step one is to restore any existing variables from the checkpoint.
     existing_variables_by_checkpoint_name = self._restore_existing_variables(
         save_path=save_path,
diff --git a/tensorflow/contrib/eager/python/network_test.py b/tensorflow/contrib/eager/python/network_test.py
index c621f527c28..14adbafe573 100644
--- a/tensorflow/contrib/eager/python/network_test.py
+++ b/tensorflow/contrib/eager/python/network_test.py
@@ -67,7 +67,7 @@ class NetworkTest(test.TestCase):
         original_output,
         self.evaluate(net(input_value)))
 
-  @test_util.run_in_graph_and_eager_modes()
+  @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
   def testTrainableAttribute(self):
     net = network.Network()
     self.assertTrue(net.trainable)
@@ -75,7 +75,7 @@ class NetworkTest(test.TestCase):
       net.trainable = False
     self.assertTrue(net.trainable)
 
-  @test_util.run_in_graph_and_eager_modes()
+  @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
   def testNetworkCall(self):
     net = MyNetwork(name="abcd")
     net(constant_op.constant([[2.0]]))  # Force variables to be created.
@@ -85,6 +85,7 @@ class NetworkTest(test.TestCase):
     result = net(constant_op.constant([[2.0]]))
     self.assertEqual(34.0, self.evaluate(result))
 
+  # TODO(allenl): This test creates garbage in some Python versions
   @test_util.run_in_graph_and_eager_modes()
   def testNetworkSaveRestoreAlreadyBuilt(self):
     net = MyNetwork(name="abcd")
@@ -96,6 +97,7 @@ class NetworkTest(test.TestCase):
     self._save_modify_load_network_built(net, global_step=None)
     self._save_modify_load_network_built(net, global_step=10)
 
+  # TODO(allenl): This test creates garbage in some Python versions
   @test_util.run_in_graph_and_eager_modes()
   def testSaveRestoreDefaultGlobalStep(self):
     net = MyNetwork(name="abcd")
@@ -106,6 +108,7 @@ class NetworkTest(test.TestCase):
     save_path = net.save(self.get_temp_dir())
     self.assertIn("abcd-4242", save_path)
 
+  # TODO(allenl): This test creates garbage in some Python versions
   @test_util.run_in_graph_and_eager_modes()
   def testNetworkSaveAndRestoreIntoUnbuilt(self):
     save_dir = self.get_temp_dir()
@@ -377,25 +380,25 @@ class NetworkTest(test.TestCase):
     gc.set_debug(previous_gc_debug_flags)
     gc.enable()
 
-  @test_util.run_in_graph_and_eager_modes()
+  @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
   def testAnonymousNoNameInitially(self):
     net = MyNetwork()
     with self.assertRaisesRegexp(ValueError, "does not yet have a final name"):
       net.name  # pylint: disable=pointless-statement
 
-  @test_util.run_in_graph_and_eager_modes()
+  @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
   def testExplicitHasNameInitially(self):
     net = MyNetwork(name="abcd")
     self.assertEqual("abcd", net.name)
 
-  @test_util.run_in_graph_and_eager_modes()
+  @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
   def testUsingResourceVariables(self):
     net = MyNetwork()
     net(constant_op.constant([[0.]]))
     self.assertIsInstance(net.trainable_weights[0],
                           resource_variable_ops.ResourceVariable)
 
-  @test_util.run_in_graph_and_eager_modes()
+  @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
   def testDuplicateNameError(self):
     one = constant_op.constant([[1.]])
     net = MyNetwork(name="foo")
@@ -405,7 +408,7 @@ class NetworkTest(test.TestCase):
       net1 = MyNetwork(name="foo")
       net1(one)
 
-  @test_util.run_in_graph_and_eager_modes()
+  @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
   def testWrappingInVariableScope(self):
     with variable_scope.variable_scope("outside_scope"):
       net = MyNetwork()
@@ -440,7 +443,7 @@ class NetworkTest(test.TestCase):
                           actual=net.trainable_weights[0].name)
     self.assertEqual("explicit_name", net.first.name)
 
-  @test_util.run_in_graph_and_eager_modes()
+  @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
   def testWrappingInAnonymousVariableScope(self):
     # Named outside variable_scopes are not supported at the moment. However,
     # blank-named top level variable scopes do not change variable names, and so
@@ -455,20 +458,20 @@ class NetworkTest(test.TestCase):
       net(one)
     self.assertTrue(was_called[0])
 
-  @test_util.run_in_graph_and_eager_modes()
+  @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
   def testReasonableSlashError(self):
     with self.assertRaisesRegexp(
         ValueError, "not allowed in Network names"):
       MyNetwork(name="slash/slash")
 
-  @test_util.run_in_graph_and_eager_modes()
+  @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
   def testNoVariableScopeNames(self):
     with self.assertRaisesRegexp(
         ValueError, "VariableScopes are not valid Network names"):
       with variable_scope.variable_scope("some_scope") as vs:
         MyNetwork(name=vs)
 
-  @test_util.run_in_graph_and_eager_modes()
+  @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
   def testVariableScopeNameCollision(self):
     with variable_scope.variable_scope("abcd"):
       pass
@@ -478,7 +481,7 @@ class NetworkTest(test.TestCase):
       one = constant_op.constant([[1.]])
       net(one)
 
-  @test_util.run_in_graph_and_eager_modes()
+  @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
   def testNetworkVariablesDoNotInterfere(self):
     core.Dense(1, use_bias=True)  # Should not interfere with naming.
     net1 = MyNetwork()
@@ -1007,6 +1010,7 @@ class NetworkTest(test.TestCase):
 
 class SequentialTest(test.TestCase):
 
+  @test_util.assert_no_garbage_created
   def testTwoLayers(self):
     # Create a sequential network with one layer.
     net = network.Sequential([core.Dense(1, use_bias=False)])
@@ -1028,6 +1032,7 @@ class SequentialTest(test.TestCase):
     l2.trainable_variables[0].assign([[11.0]])
     self.assertEqual(231.0, net(constant_op.constant([[7.0]])).numpy())
 
+  @test_util.assert_no_garbage_created
   def testFunctions(self):
     # Create a sequential network with one function.
     net = network.Sequential([nn_ops.relu])
@@ -1038,6 +1043,7 @@ class SequentialTest(test.TestCase):
     net.add(math_ops.negative)
     self.assertEqual(-2.0, net(two).numpy())
 
+  @test_util.assert_no_garbage_created
   def testTrainingLayer(self):
     net = network.Sequential([core.Dropout(0.99999)])
     two = constant_op.constant(2.0)
@@ -1051,6 +1057,7 @@ class SequentialTest(test.TestCase):
     # Should only fail spuriously 1 in 10^100 runs.
     self.fail("Didn't see dropout happen after 20 tries.")
 
+  @test_util.assert_no_garbage_created
   def testTrainingFunction(self):
     # Output depends on value of "training".
     def add_training(input_value, training=None):
diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py
index dbe9a2421c9..6e3a35af3cd 100644
--- a/tensorflow/python/framework/test_util.py
+++ b/tensorflow/python/framework/test_util.py
@@ -20,6 +20,7 @@ from __future__ import division
 from __future__ import print_function
 
 import contextlib
+import gc
 import math
 import random
 import re
@@ -452,9 +453,43 @@ class IsolateTest(object):
         type_arg, value_arg, traceback_arg)
 
 
-def run_in_graph_and_eager_modes(__unused__=None, graph=None, config=None,
-                                 use_gpu=False, force_gpu=False,
-                                 reset_test=True):
+def assert_no_garbage_created(f):
+  """Test method decorator to assert that no garbage has been created.
+
+  Note that this decorator sets DEBUG_SAVEALL, which in some Python interpreters
+  cannot be un-set (i.e. objects the collector finds will be kept in gc.garbage
+  rather than freed for any other unit tests in the same file/shard).
+
+  Args:
+    f: The function to decorate.
+  Returns:
+    The decorated function.
+  """
+
+  def decorator(self, **kwargs):
+    """Sets DEBUG_SAVEALL, runs the test, and checks for new garbage."""
+    gc.disable()
+    previous_debug_flags = gc.get_debug()
+    gc.set_debug(gc.DEBUG_SAVEALL)
+    gc.collect()
+    previous_garbage = len(gc.garbage)
+    f(self, **kwargs)
+    gc.collect()
+    # This will fail if any garbage has been created, typically because of a
+    # reference cycle.
+    self.assertEqual(previous_garbage, len(gc.garbage))
+    # TODO(allenl): Figure out why this debug flag reset doesn't work. It would
+    # be nice to be able to decorate arbitrary tests in a large test suite and
+    # not hold on to every object in other tests.
+    gc.set_debug(previous_debug_flags)
+    gc.enable()
+  return decorator
+
+
+def run_in_graph_and_eager_modes(
+    __unused__=None, graph=None, config=None,
+    use_gpu=False, force_gpu=False,
+    reset_test=True, assert_no_eager_garbage=False):
   """Runs the test in both graph and eager modes.
 
   Args:
@@ -465,7 +500,14 @@ def run_in_graph_and_eager_modes(__unused__=None, graph=None, config=None,
     use_gpu: If True, attempt to run as many ops as possible on GPU.
     force_gpu: If True, pin all ops to `/device:GPU:0`.
     reset_test: If True, tearDown and SetUp the test case again.
-
+    assert_no_eager_garbage: If True, sets DEBUG_SAVEALL on the garbage
+      collector and asserts that no extra garbage has been created when running
+      the test in eager mode. This will fail if there are reference cycles
+      (e.g. a = []; a.append(a)). Off by default because some tests may create
+      garbage for legitimate reasons (e.g. they define a class which inherits
+      from `object`), and because DEBUG_SAVEALL is sticky in some Python
+      interpreters (meaning that tests which rely on objects being collected
+      elsewhere in the unit test file will not work).
   Returns:
     Returns a decorator that will run the decorated test function
         using both a graph and using eager execution.
@@ -487,7 +529,7 @@ def run_in_graph_and_eager_modes(__unused__=None, graph=None, config=None,
         self.tearDown()
         self.setUp()
 
-      def run_eager_mode():
+      def run_eager_mode(self, **kwargs):
         if force_gpu:
           gpu_name = gpu_device_name()
           if not gpu_name:
@@ -501,9 +543,12 @@ def run_in_graph_and_eager_modes(__unused__=None, graph=None, config=None,
           with context.device("/device:CPU:0"):
             f(self, **kwargs)
 
+      if assert_no_eager_garbage:
+        run_eager_mode = assert_no_garbage_created(run_eager_mode)
+
       with context.eager_mode():
         with IsolateTest():
-          run_eager_mode()
+          run_eager_mode(self, **kwargs)
 
     return decorated
   return decorator
diff --git a/tensorflow/python/framework/test_util_test.py b/tensorflow/python/framework/test_util_test.py
index b2f8d62095f..1c5db945005 100644
--- a/tensorflow/python/framework/test_util_test.py
+++ b/tensorflow/python/framework/test_util_test.py
@@ -329,6 +329,30 @@ class TestUtilTest(test_util.TensorFlowTestCase):
     self.assertEqual(a_rand, b_rand)
 
 
+class GarbageCollectionTest(test_util.TensorFlowTestCase):
+
+  def test_no_reference_cycle_decorator(self):
+
+    class ReferenceCycleTest(object):
+
+      def __init__(inner_self):  # pylint: disable=no-self-argument
+        inner_self.assertEqual = self.assertEqual  # pylint: disable=invalid-name
+
+      @test_util.assert_no_garbage_created
+      def test_has_cycle(self):
+        a = []
+        a.append(a)
+
+      @test_util.assert_no_garbage_created
+      def test_has_no_cycle(self):
+        pass
+
+    with self.assertRaises(AssertionError):
+      ReferenceCycleTest().test_has_cycle()
+
+    ReferenceCycleTest().test_has_no_cycle()
+
+
 @test_util.with_c_api
 class IsolationTest(test_util.TensorFlowTestCase):
 
diff --git a/tensorflow/python/kernel_tests/garbage_collection_test.py b/tensorflow/python/kernel_tests/garbage_collection_test.py
index 24a6ee74c56..39f936fbc92 100644
--- a/tensorflow/python/kernel_tests/garbage_collection_test.py
+++ b/tensorflow/python/kernel_tests/garbage_collection_test.py
@@ -21,8 +21,6 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-import gc
-
 from tensorflow.python.eager import context
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import test_util
@@ -31,37 +29,14 @@ from tensorflow.python.ops import tensor_array_ops
 from tensorflow.python.platform import test
 
 
-def assert_no_garbage_created(f):
-  """Test decorator to assert that no garbage has been created."""
-
-  def decorator(self):
-    """Sets DEBUG_SAVEALL, runs the test, and checks for new garbage."""
-    gc.disable()
-    previous_debug_flags = gc.get_debug()
-    gc.set_debug(gc.DEBUG_SAVEALL)
-    gc.collect()
-    previous_garbage = len(gc.garbage)
-    f(self)
-    gc.collect()
-    # This will fail if any garbage has been created, typically because of a
-    # reference cycle.
-    self.assertEqual(previous_garbage, len(gc.garbage))
-    # TODO(allenl): Figure out why this debug flag reset doesn't work. It would
-    # be nice to be able to decorate arbitrary tests in a large test suite and
-    # not hold on to every object in other tests.
-    gc.set_debug(previous_debug_flags)
-    gc.enable()
-  return decorator
-
-
 class NoReferenceCycleTests(test_util.TensorFlowTestCase):
 
-  @assert_no_garbage_created
+  @test_util.assert_no_garbage_created
   def testEagerResourceVariables(self):
     with context.eager_mode():
       resource_variable_ops.ResourceVariable(1.0, name="a")
 
-  @assert_no_garbage_created
+  @test_util.assert_no_garbage_created
   def testTensorArrays(self):
     with context.eager_mode():
       ta = tensor_array_ops.TensorArray(