From 09964d1e55eb8ff699db4cf1aa38808cf825c2df Mon Sep 17 00:00:00 2001 From: Allen Lavoie <allenl@google.com> Date: Tue, 7 Nov 2017 14:51:44 -0800 Subject: [PATCH] Fix a tfe.Network reference cycle, make it easier to check for reference cycles in unit tests. PiperOrigin-RevId: 174917072 --- tensorflow/contrib/eager/python/network.py | 56 +++++++++++------- .../contrib/eager/python/network_test.py | 31 ++++++---- tensorflow/python/framework/test_util.py | 57 +++++++++++++++++-- tensorflow/python/framework/test_util_test.py | 24 ++++++++ .../kernel_tests/garbage_collection_test.py | 29 +--------- 5 files changed, 132 insertions(+), 65 deletions(-) diff --git a/tensorflow/contrib/eager/python/network.py b/tensorflow/contrib/eager/python/network.py index 97feaec30ed..c6e628b074e 100644 --- a/tensorflow/contrib/eager/python/network.py +++ b/tensorflow/contrib/eager/python/network.py @@ -182,6 +182,40 @@ def _make_custom_getter_for_deferred_restorations(): return _custom_getter, deferred_restorations +def _make_prefix_stripping_map_fn(scope_name): + """Closure for stripping the scope name of a Network. + + Implemented as a closure rather than a member function to avoid reference + cycles in deferred restorations (this function should not have a reference to + the Network which created it). + + Args: + scope_name: The Network.scope_name to strip from variables. + Returns: + A scope_name-stripping default `map_fn` for the Network. + """ + + def _strip_variable_prefix(original_variable_name): + """The default map_func for saving or restoring variables. + + Strips the variable prefix for the Network on which save/restore was called, + and leaves other variable names fully qualified in the checkpoint. + + Args: + original_variable_name: The _shared_name of the variable (no :0 + suffix) to map. + Returns: + The checkpoint name of the variable. + """ + scope_name_with_slash = scope_name + "/" + if original_variable_name.startswith(scope_name_with_slash): + return original_variable_name[len(scope_name_with_slash):] + else: + return original_variable_name + + return _strip_variable_prefix + + class Network(base.Layer): """Represents the composition of a set of Layers. @@ -488,24 +522,6 @@ class Network(base.Layer): "at https://github.com/tensorflow/tensorflow/issues/new if this is " "important to you") - def _strip_variable_prefix(self, original_variable_name): - """The default map_func for saving or restoring variables. - - Strips the variable prefix for the Network on which save/restore was called, - and leaves other variable names fully qualified in the checkpoint. - - Args: - original_variable_name: The _shared_name of the variable (no :0 - suffix) to map. - Returns: - The checkpoint name of the variable. - """ - scope_name_with_slash = self.scope_name + "/" - if original_variable_name.startswith(scope_name_with_slash): - return original_variable_name[len(scope_name_with_slash):] - else: - return original_variable_name - def save(self, save_path, global_step=None, map_func=None): """Save variables from the Network to a checkpoint. @@ -543,7 +559,7 @@ class Network(base.Layer): save_path = os.path.join(save_path, self.name) user_map_func = map_func if map_func is None: - map_func = self._strip_variable_prefix + map_func = _make_prefix_stripping_map_fn(self.scope_name) variable_map = {} for variable in self.variables: mapped_name = map_func(variable._shared_name) @@ -737,7 +753,7 @@ class Network(base.Layer): save_path = os.path.join(save_path, self.name) user_map_func = map_func if map_func is None: - map_func = self._strip_variable_prefix + map_func = _make_prefix_stripping_map_fn(self.scope_name) # Step one is to restore any existing variables from the checkpoint. existing_variables_by_checkpoint_name = self._restore_existing_variables( save_path=save_path, diff --git a/tensorflow/contrib/eager/python/network_test.py b/tensorflow/contrib/eager/python/network_test.py index c621f527c28..14adbafe573 100644 --- a/tensorflow/contrib/eager/python/network_test.py +++ b/tensorflow/contrib/eager/python/network_test.py @@ -67,7 +67,7 @@ class NetworkTest(test.TestCase): original_output, self.evaluate(net(input_value))) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) def testTrainableAttribute(self): net = network.Network() self.assertTrue(net.trainable) @@ -75,7 +75,7 @@ class NetworkTest(test.TestCase): net.trainable = False self.assertTrue(net.trainable) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) def testNetworkCall(self): net = MyNetwork(name="abcd") net(constant_op.constant([[2.0]])) # Force variables to be created. @@ -85,6 +85,7 @@ class NetworkTest(test.TestCase): result = net(constant_op.constant([[2.0]])) self.assertEqual(34.0, self.evaluate(result)) + # TODO(allenl): This test creates garbage in some Python versions @test_util.run_in_graph_and_eager_modes() def testNetworkSaveRestoreAlreadyBuilt(self): net = MyNetwork(name="abcd") @@ -96,6 +97,7 @@ class NetworkTest(test.TestCase): self._save_modify_load_network_built(net, global_step=None) self._save_modify_load_network_built(net, global_step=10) + # TODO(allenl): This test creates garbage in some Python versions @test_util.run_in_graph_and_eager_modes() def testSaveRestoreDefaultGlobalStep(self): net = MyNetwork(name="abcd") @@ -106,6 +108,7 @@ class NetworkTest(test.TestCase): save_path = net.save(self.get_temp_dir()) self.assertIn("abcd-4242", save_path) + # TODO(allenl): This test creates garbage in some Python versions @test_util.run_in_graph_and_eager_modes() def testNetworkSaveAndRestoreIntoUnbuilt(self): save_dir = self.get_temp_dir() @@ -377,25 +380,25 @@ class NetworkTest(test.TestCase): gc.set_debug(previous_gc_debug_flags) gc.enable() - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) def testAnonymousNoNameInitially(self): net = MyNetwork() with self.assertRaisesRegexp(ValueError, "does not yet have a final name"): net.name # pylint: disable=pointless-statement - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) def testExplicitHasNameInitially(self): net = MyNetwork(name="abcd") self.assertEqual("abcd", net.name) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) def testUsingResourceVariables(self): net = MyNetwork() net(constant_op.constant([[0.]])) self.assertIsInstance(net.trainable_weights[0], resource_variable_ops.ResourceVariable) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) def testDuplicateNameError(self): one = constant_op.constant([[1.]]) net = MyNetwork(name="foo") @@ -405,7 +408,7 @@ class NetworkTest(test.TestCase): net1 = MyNetwork(name="foo") net1(one) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) def testWrappingInVariableScope(self): with variable_scope.variable_scope("outside_scope"): net = MyNetwork() @@ -440,7 +443,7 @@ class NetworkTest(test.TestCase): actual=net.trainable_weights[0].name) self.assertEqual("explicit_name", net.first.name) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) def testWrappingInAnonymousVariableScope(self): # Named outside variable_scopes are not supported at the moment. However, # blank-named top level variable scopes do not change variable names, and so @@ -455,20 +458,20 @@ class NetworkTest(test.TestCase): net(one) self.assertTrue(was_called[0]) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) def testReasonableSlashError(self): with self.assertRaisesRegexp( ValueError, "not allowed in Network names"): MyNetwork(name="slash/slash") - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) def testNoVariableScopeNames(self): with self.assertRaisesRegexp( ValueError, "VariableScopes are not valid Network names"): with variable_scope.variable_scope("some_scope") as vs: MyNetwork(name=vs) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) def testVariableScopeNameCollision(self): with variable_scope.variable_scope("abcd"): pass @@ -478,7 +481,7 @@ class NetworkTest(test.TestCase): one = constant_op.constant([[1.]]) net(one) - @test_util.run_in_graph_and_eager_modes() + @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) def testNetworkVariablesDoNotInterfere(self): core.Dense(1, use_bias=True) # Should not interfere with naming. net1 = MyNetwork() @@ -1007,6 +1010,7 @@ class NetworkTest(test.TestCase): class SequentialTest(test.TestCase): + @test_util.assert_no_garbage_created def testTwoLayers(self): # Create a sequential network with one layer. net = network.Sequential([core.Dense(1, use_bias=False)]) @@ -1028,6 +1032,7 @@ class SequentialTest(test.TestCase): l2.trainable_variables[0].assign([[11.0]]) self.assertEqual(231.0, net(constant_op.constant([[7.0]])).numpy()) + @test_util.assert_no_garbage_created def testFunctions(self): # Create a sequential network with one function. net = network.Sequential([nn_ops.relu]) @@ -1038,6 +1043,7 @@ class SequentialTest(test.TestCase): net.add(math_ops.negative) self.assertEqual(-2.0, net(two).numpy()) + @test_util.assert_no_garbage_created def testTrainingLayer(self): net = network.Sequential([core.Dropout(0.99999)]) two = constant_op.constant(2.0) @@ -1051,6 +1057,7 @@ class SequentialTest(test.TestCase): # Should only fail spuriously 1 in 10^100 runs. self.fail("Didn't see dropout happen after 20 tries.") + @test_util.assert_no_garbage_created def testTrainingFunction(self): # Output depends on value of "training". def add_training(input_value, training=None): diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py index dbe9a2421c9..6e3a35af3cd 100644 --- a/tensorflow/python/framework/test_util.py +++ b/tensorflow/python/framework/test_util.py @@ -20,6 +20,7 @@ from __future__ import division from __future__ import print_function import contextlib +import gc import math import random import re @@ -452,9 +453,43 @@ class IsolateTest(object): type_arg, value_arg, traceback_arg) -def run_in_graph_and_eager_modes(__unused__=None, graph=None, config=None, - use_gpu=False, force_gpu=False, - reset_test=True): +def assert_no_garbage_created(f): + """Test method decorator to assert that no garbage has been created. + + Note that this decorator sets DEBUG_SAVEALL, which in some Python interpreters + cannot be un-set (i.e. will disable garbage collection for any other unit + tests in the same file/shard). + + Args: + f: The function to decorate. + Returns: + The decorated function. + """ + + def decorator(self, **kwargs): + """Sets DEBUG_SAVEALL, runs the test, and checks for new garbage.""" + gc.disable() + previous_debug_flags = gc.get_debug() + gc.set_debug(gc.DEBUG_SAVEALL) + gc.collect() + previous_garbage = len(gc.garbage) + f(self, **kwargs) + gc.collect() + # This will fail if any garbage has been created, typically because of a + # reference cycle. + self.assertEqual(previous_garbage, len(gc.garbage)) + # TODO(allenl): Figure out why this debug flag reset doesn't work. It would + # be nice to be able to decorate arbitrary tests in a large test suite and + # not hold on to every object in other tests. + gc.set_debug(previous_debug_flags) + gc.enable() + return decorator + + +def run_in_graph_and_eager_modes( + __unused__=None, graph=None, config=None, + use_gpu=False, force_gpu=False, + reset_test=True, assert_no_eager_garbage=False): """Runs the test in both graph and eager modes. Args: @@ -465,7 +500,14 @@ def run_in_graph_and_eager_modes(__unused__=None, graph=None, config=None, use_gpu: If True, attempt to run as many ops as possible on GPU. force_gpu: If True, pin all ops to `/device:GPU:0`. reset_test: If True, tearDown and SetUp the test case again. - + assert_no_eager_garbage: If True, sets DEBUG_SAVEALL on the garbage + collector and asserts that no extra garbage has been created when running + the test in eager mode. This will fail if there are reference cycles + (e.g. a = []; a.append(a)). Off by default because some tests may create + garbage for legitimate reasons (e.g. they define a class which inherits + from `object`), and because DEBUG_SAVEALL is sticky in some Python + interpreters (meaning that tests which rely on objects being collected + elsewhere in the unit test file will not work). Returns: Returns a decorator that will run the decorated test function using both a graph and using eager execution. @@ -487,7 +529,7 @@ def run_in_graph_and_eager_modes(__unused__=None, graph=None, config=None, self.tearDown() self.setUp() - def run_eager_mode(): + def run_eager_mode(self, **kwargs): if force_gpu: gpu_name = gpu_device_name() if not gpu_name: @@ -501,9 +543,12 @@ def run_in_graph_and_eager_modes(__unused__=None, graph=None, config=None, with context.device("/device:CPU:0"): f(self, **kwargs) + if assert_no_eager_garbage: + run_eager_mode = assert_no_garbage_created(run_eager_mode) + with context.eager_mode(): with IsolateTest(): - run_eager_mode() + run_eager_mode(self, **kwargs) return decorated return decorator diff --git a/tensorflow/python/framework/test_util_test.py b/tensorflow/python/framework/test_util_test.py index b2f8d62095f..1c5db945005 100644 --- a/tensorflow/python/framework/test_util_test.py +++ b/tensorflow/python/framework/test_util_test.py @@ -329,6 +329,30 @@ class TestUtilTest(test_util.TensorFlowTestCase): self.assertEqual(a_rand, b_rand) +class GarbageCollectionTest(test_util.TensorFlowTestCase): + + def test_no_reference_cycle_decorator(self): + + class ReferenceCycleTest(object): + + def __init__(inner_self): # pylint: disable=no-self-argument + inner_self.assertEqual = self.assertEqual # pylint: disable=invalid-name + + @test_util.assert_no_garbage_created + def test_has_cycle(self): + a = [] + a.append(a) + + @test_util.assert_no_garbage_created + def test_has_no_cycle(self): + pass + + with self.assertRaises(AssertionError): + ReferenceCycleTest().test_has_cycle() + + ReferenceCycleTest().test_has_no_cycle() + + @test_util.with_c_api class IsolationTest(test_util.TensorFlowTestCase): diff --git a/tensorflow/python/kernel_tests/garbage_collection_test.py b/tensorflow/python/kernel_tests/garbage_collection_test.py index 24a6ee74c56..39f936fbc92 100644 --- a/tensorflow/python/kernel_tests/garbage_collection_test.py +++ b/tensorflow/python/kernel_tests/garbage_collection_test.py @@ -21,8 +21,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import gc - from tensorflow.python.eager import context from tensorflow.python.framework import dtypes from tensorflow.python.framework import test_util @@ -31,37 +29,14 @@ from tensorflow.python.ops import tensor_array_ops from tensorflow.python.platform import test -def assert_no_garbage_created(f): - """Test decorator to assert that no garbage has been created.""" - - def decorator(self): - """Sets DEBUG_SAVEALL, runs the test, and checks for new garbage.""" - gc.disable() - previous_debug_flags = gc.get_debug() - gc.set_debug(gc.DEBUG_SAVEALL) - gc.collect() - previous_garbage = len(gc.garbage) - f(self) - gc.collect() - # This will fail if any garbage has been created, typically because of a - # reference cycle. - self.assertEqual(previous_garbage, len(gc.garbage)) - # TODO(allenl): Figure out why this debug flag reset doesn't work. It would - # be nice to be able to decorate arbitrary tests in a large test suite and - # not hold on to every object in other tests. - gc.set_debug(previous_debug_flags) - gc.enable() - return decorator - - class NoReferenceCycleTests(test_util.TensorFlowTestCase): - @assert_no_garbage_created + @test_util.assert_no_garbage_created def testEagerResourceVariables(self): with context.eager_mode(): resource_variable_ops.ResourceVariable(1.0, name="a") - @assert_no_garbage_created + @test_util.assert_no_garbage_created def testTensorArrays(self): with context.eager_mode(): ta = tensor_array_ops.TensorArray(