Fix a tfe.Network reference cycle, make it easier to check for reference cycles in unit tests.

PiperOrigin-RevId: 174917072
Allen Lavoie 2017-11-07 14:51:44 -08:00 committed by TensorFlower Gardener
parent f589bf0d4d
commit 09964d1e55
5 changed files with 132 additions and 65 deletions


@ -182,6 +182,40 @@ def _make_custom_getter_for_deferred_restorations():
return _custom_getter, deferred_restorations
def _make_prefix_stripping_map_fn(scope_name):
  """Closure for stripping the scope name of a Network.

  Implemented as a closure rather than a member function to avoid reference
  cycles in deferred restorations (this function should not have a reference to
  the Network which created it).

  Args:
    scope_name: The Network.scope_name to strip from variables.

  Returns:
    A scope_name-stripping default `map_fn` for the Network.
  """

  def _strip_variable_prefix(original_variable_name):
    """The default map_func for saving or restoring variables.

    Strips the variable prefix for the Network on which save/restore was called,
    and leaves other variable names fully qualified in the checkpoint.

    Args:
      original_variable_name: The _shared_name of the variable (no :0
        suffix) to map.

    Returns:
      The checkpoint name of the variable.
    """
    scope_name_with_slash = scope_name + "/"
    if original_variable_name.startswith(scope_name_with_slash):
      return original_variable_name[len(scope_name_with_slash):]
    else:
      return original_variable_name

  return _strip_variable_prefix
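For context on why this helper is a module-level closure rather than a method: a bound method holds a reference to its instance through `__self__`, so stashing `self._strip_variable_prefix` somewhere the Network also reaches creates a cycle, while a closure over just the scope-name string does not. A minimal sketch (the `Holder` class and the names below are hypothetical, not part of this commit):

```python
class Holder(object):
  """Hypothetical stand-in for an object that stores its own bound method."""

  def __init__(self):
    # Cycle: self -> self.__dict__ -> bound method -> __self__ -> self.
    self.map_fn = self._strip

  def _strip(self, name):
    return name


def make_strip_fn(prefix):
  """Factory mirroring the closure approach: captures only a string."""
  def _strip(name):
    return name[len(prefix):] if name.startswith(prefix) else name
  return _strip


h = Holder()
assert h.map_fn.__self__ is h                  # back-reference -> reference cycle
assert make_strip_fn("net/")("net/w") == "w"   # no back-reference -> no cycle
```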
class Network(base.Layer):
"""Represents the composition of a set of Layers.
@ -488,24 +522,6 @@ class Network(base.Layer):
"at https://github.com/tensorflow/tensorflow/issues/new if this is "
"important to you")
  def _strip_variable_prefix(self, original_variable_name):
    """The default map_func for saving or restoring variables.

    Strips the variable prefix for the Network on which save/restore was called,
    and leaves other variable names fully qualified in the checkpoint.

    Args:
      original_variable_name: The _shared_name of the variable (no :0
        suffix) to map.

    Returns:
      The checkpoint name of the variable.
    """
    scope_name_with_slash = self.scope_name + "/"
    if original_variable_name.startswith(scope_name_with_slash):
      return original_variable_name[len(scope_name_with_slash):]
    else:
      return original_variable_name

  def save(self, save_path, global_step=None, map_func=None):
    """Save variables from the Network to a checkpoint.
@ -543,7 +559,7 @@ class Network(base.Layer):
save_path = os.path.join(save_path, self.name)
user_map_func = map_func
if map_func is None:
map_func = self._strip_variable_prefix
map_func = _make_prefix_stripping_map_fn(self.scope_name)
variable_map = {}
for variable in self.variables:
mapped_name = map_func(variable._shared_name)
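A hedged illustration of what this default mapping does, assuming `_make_prefix_stripping_map_fn` from the hunk above is in scope; the scope and variable names ("my_network_1", "dense/kernel") are hypothetical:

```python
strip = _make_prefix_stripping_map_fn("my_network_1")

print(strip("my_network_1/dense/kernel"))   # "dense/kernel" -- prefix stripped
print(strip("other_network/dense/kernel"))  # unchanged; stays fully qualified
```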
@ -737,7 +753,7 @@ class Network(base.Layer):
save_path = os.path.join(save_path, self.name)
user_map_func = map_func
if map_func is None:
map_func = self._strip_variable_prefix
map_func = _make_prefix_stripping_map_fn(self.scope_name)
# Step one is to restore any existing variables from the checkpoint.
existing_variables_by_checkpoint_name = self._restore_existing_variables(
save_path=save_path,


@ -67,7 +67,7 @@ class NetworkTest(test.TestCase):
original_output,
self.evaluate(net(input_value)))
@test_util.run_in_graph_and_eager_modes()
@test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
def testTrainableAttribute(self):
net = network.Network()
self.assertTrue(net.trainable)
@ -75,7 +75,7 @@ class NetworkTest(test.TestCase):
net.trainable = False
self.assertTrue(net.trainable)
@test_util.run_in_graph_and_eager_modes()
@test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
def testNetworkCall(self):
net = MyNetwork(name="abcd")
net(constant_op.constant([[2.0]])) # Force variables to be created.
@ -85,6 +85,7 @@ class NetworkTest(test.TestCase):
result = net(constant_op.constant([[2.0]]))
self.assertEqual(34.0, self.evaluate(result))
# TODO(allenl): This test creates garbage in some Python versions
@test_util.run_in_graph_and_eager_modes()
def testNetworkSaveRestoreAlreadyBuilt(self):
net = MyNetwork(name="abcd")
@ -96,6 +97,7 @@ class NetworkTest(test.TestCase):
self._save_modify_load_network_built(net, global_step=None)
self._save_modify_load_network_built(net, global_step=10)
# TODO(allenl): This test creates garbage in some Python versions
@test_util.run_in_graph_and_eager_modes()
def testSaveRestoreDefaultGlobalStep(self):
net = MyNetwork(name="abcd")
@ -106,6 +108,7 @@ class NetworkTest(test.TestCase):
save_path = net.save(self.get_temp_dir())
self.assertIn("abcd-4242", save_path)
# TODO(allenl): This test creates garbage in some Python versions
@test_util.run_in_graph_and_eager_modes()
def testNetworkSaveAndRestoreIntoUnbuilt(self):
save_dir = self.get_temp_dir()
@ -377,25 +380,25 @@ class NetworkTest(test.TestCase):
gc.set_debug(previous_gc_debug_flags)
gc.enable()
@test_util.run_in_graph_and_eager_modes()
@test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
def testAnonymousNoNameInitially(self):
net = MyNetwork()
with self.assertRaisesRegexp(ValueError, "does not yet have a final name"):
net.name # pylint: disable=pointless-statement
@test_util.run_in_graph_and_eager_modes()
@test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
def testExplicitHasNameInitially(self):
net = MyNetwork(name="abcd")
self.assertEqual("abcd", net.name)
@test_util.run_in_graph_and_eager_modes()
@test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
def testUsingResourceVariables(self):
net = MyNetwork()
net(constant_op.constant([[0.]]))
self.assertIsInstance(net.trainable_weights[0],
resource_variable_ops.ResourceVariable)
@test_util.run_in_graph_and_eager_modes()
@test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
def testDuplicateNameError(self):
one = constant_op.constant([[1.]])
net = MyNetwork(name="foo")
@ -405,7 +408,7 @@ class NetworkTest(test.TestCase):
net1 = MyNetwork(name="foo")
net1(one)
@test_util.run_in_graph_and_eager_modes()
@test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
def testWrappingInVariableScope(self):
with variable_scope.variable_scope("outside_scope"):
net = MyNetwork()
@ -440,7 +443,7 @@ class NetworkTest(test.TestCase):
actual=net.trainable_weights[0].name)
self.assertEqual("explicit_name", net.first.name)
@test_util.run_in_graph_and_eager_modes()
@test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
def testWrappingInAnonymousVariableScope(self):
# Named outside variable_scopes are not supported at the moment. However,
# blank-named top level variable scopes do not change variable names, and so
@ -455,20 +458,20 @@ class NetworkTest(test.TestCase):
net(one)
self.assertTrue(was_called[0])
@test_util.run_in_graph_and_eager_modes()
@test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
def testReasonableSlashError(self):
with self.assertRaisesRegexp(
ValueError, "not allowed in Network names"):
MyNetwork(name="slash/slash")
@test_util.run_in_graph_and_eager_modes()
@test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
def testNoVariableScopeNames(self):
with self.assertRaisesRegexp(
ValueError, "VariableScopes are not valid Network names"):
with variable_scope.variable_scope("some_scope") as vs:
MyNetwork(name=vs)
@test_util.run_in_graph_and_eager_modes()
@test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
def testVariableScopeNameCollision(self):
with variable_scope.variable_scope("abcd"):
pass
@ -478,7 +481,7 @@ class NetworkTest(test.TestCase):
one = constant_op.constant([[1.]])
net(one)
@test_util.run_in_graph_and_eager_modes()
@test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
def testNetworkVariablesDoNotInterfere(self):
core.Dense(1, use_bias=True) # Should not interfere with naming.
net1 = MyNetwork()
@ -1007,6 +1010,7 @@ class NetworkTest(test.TestCase):
class SequentialTest(test.TestCase):
@test_util.assert_no_garbage_created
def testTwoLayers(self):
# Create a sequential network with one layer.
net = network.Sequential([core.Dense(1, use_bias=False)])
@ -1028,6 +1032,7 @@ class SequentialTest(test.TestCase):
l2.trainable_variables[0].assign([[11.0]])
self.assertEqual(231.0, net(constant_op.constant([[7.0]])).numpy())
@test_util.assert_no_garbage_created
def testFunctions(self):
# Create a sequential network with one function.
net = network.Sequential([nn_ops.relu])
@ -1038,6 +1043,7 @@ class SequentialTest(test.TestCase):
net.add(math_ops.negative)
self.assertEqual(-2.0, net(two).numpy())
@test_util.assert_no_garbage_created
def testTrainingLayer(self):
net = network.Sequential([core.Dropout(0.99999)])
two = constant_op.constant(2.0)
@ -1051,6 +1057,7 @@ class SequentialTest(test.TestCase):
# Should only fail spuriously 1 in 10^100 runs.
self.fail("Didn't see dropout happen after 20 tries.")
@test_util.assert_no_garbage_created
def testTrainingFunction(self):
# Output depends on value of "training".
def add_training(input_value, training=None):


@ -20,6 +20,7 @@ from __future__ import division
from __future__ import print_function
import contextlib
import gc
import math
import random
import re
@ -452,9 +453,43 @@ class IsolateTest(object):
type_arg, value_arg, traceback_arg)
def run_in_graph_and_eager_modes(__unused__=None, graph=None, config=None,
use_gpu=False, force_gpu=False,
reset_test=True):
def assert_no_garbage_created(f):
  """Test method decorator to assert that no garbage has been created.

  Note that this decorator sets DEBUG_SAVEALL, which in some Python interpreters
  cannot be un-set (i.e. will disable garbage collection for any other unit
  tests in the same file/shard).

  Args:
    f: The function to decorate.

  Returns:
    The decorated function.
  """

  def decorator(self, **kwargs):
    """Sets DEBUG_SAVEALL, runs the test, and checks for new garbage."""
    gc.disable()
    previous_debug_flags = gc.get_debug()
    gc.set_debug(gc.DEBUG_SAVEALL)
    gc.collect()
    previous_garbage = len(gc.garbage)
    f(self, **kwargs)
    gc.collect()
    # This will fail if any garbage has been created, typically because of a
    # reference cycle.
    self.assertEqual(previous_garbage, len(gc.garbage))
    # TODO(allenl): Figure out why this debug flag reset doesn't work. It would
    # be nice to be able to decorate arbitrary tests in a large test suite and
    # not hold on to every object in other tests.
    gc.set_debug(previous_debug_flags)
    gc.enable()

  return decorator
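A short sketch of the mechanism the decorator relies on (standard Python gc behavior, not code from this commit): with DEBUG_SAVEALL set, every object the collector frees is appended to `gc.garbage`, so a reference cycle created by the test grows `gc.garbage` between the two `collect()` calls.

```python
import gc

gc.disable()
gc.set_debug(gc.DEBUG_SAVEALL)
gc.collect()
before = len(gc.garbage)

a = []
a.append(a)  # reference cycle: the list contains itself
del a        # refcount never hits zero, so only the cycle collector frees it

gc.collect()
assert len(gc.garbage) > before  # this is what the assertEqual above would catch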
def run_in_graph_and_eager_modes(
__unused__=None, graph=None, config=None,
use_gpu=False, force_gpu=False,
reset_test=True, assert_no_eager_garbage=False):
"""Runs the test in both graph and eager modes.
Args:
@ -465,7 +500,14 @@ def run_in_graph_and_eager_modes(__unused__=None, graph=None, config=None,
use_gpu: If True, attempt to run as many ops as possible on GPU.
force_gpu: If True, pin all ops to `/device:GPU:0`.
reset_test: If True, tearDown and SetUp the test case again.
assert_no_eager_garbage: If True, sets DEBUG_SAVEALL on the garbage
collector and asserts that no extra garbage has been created when running
the test in eager mode. This will fail if there are reference cycles
(e.g. a = []; a.append(a)). Off by default because some tests may create
garbage for legitimate reasons (e.g. they define a class which inherits
from `object`), and because DEBUG_SAVEALL is sticky in some Python
interpreters (meaning that tests which rely on objects being collected
elsewhere in the unit test file will not work).
Returns:
Returns a decorator that will run the decorated test function
using both a graph and using eager execution.
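A minimal usage sketch (hypothetical test class and method names), mirroring the network_test changes shown earlier in this commit:

```python
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops


class NoGarbageTest(test_util.TensorFlowTestCase):

  @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
  def testIdentityLeavesNoGarbage(self):
    # In eager mode this runs under assert_no_garbage_created, so any
    # reference cycle created by the test body fails the assertion.
    x = constant_op.constant([[1.0]])
    self.assertAllEqual([[1.0]], self.evaluate(array_ops.identity(x)))
```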
@ -487,7 +529,7 @@ def run_in_graph_and_eager_modes(__unused__=None, graph=None, config=None,
self.tearDown()
self.setUp()
def run_eager_mode():
def run_eager_mode(self, **kwargs):
if force_gpu:
gpu_name = gpu_device_name()
if not gpu_name:
@ -501,9 +543,12 @@ def run_in_graph_and_eager_modes(__unused__=None, graph=None, config=None,
with context.device("/device:CPU:0"):
f(self, **kwargs)
if assert_no_eager_garbage:
run_eager_mode = assert_no_garbage_created(run_eager_mode)
with context.eager_mode():
with IsolateTest():
run_eager_mode()
run_eager_mode(self, **kwargs)
return decorated
return decorator


@ -329,6 +329,30 @@ class TestUtilTest(test_util.TensorFlowTestCase):
self.assertEqual(a_rand, b_rand)
class GarbageCollectionTest(test_util.TensorFlowTestCase):
  def test_no_reference_cycle_decorator(self):

    class ReferenceCycleTest(object):

      def __init__(inner_self):  # pylint: disable=no-self-argument
        inner_self.assertEqual = self.assertEqual  # pylint: disable=invalid-name

      @test_util.assert_no_garbage_created
      def test_has_cycle(self):
        a = []
        a.append(a)

      @test_util.assert_no_garbage_created
      def test_has_no_cycle(self):
        pass

    with self.assertRaises(AssertionError):
      ReferenceCycleTest().test_has_cycle()

    ReferenceCycleTest().test_has_no_cycle()
@test_util.with_c_api
class IsolationTest(test_util.TensorFlowTestCase):


@ -21,8 +21,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import gc
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_util
@ -31,37 +29,14 @@ from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.platform import test
def assert_no_garbage_created(f):
"""Test decorator to assert that no garbage has been created."""
def decorator(self):
"""Sets DEBUG_SAVEALL, runs the test, and checks for new garbage."""
gc.disable()
previous_debug_flags = gc.get_debug()
gc.set_debug(gc.DEBUG_SAVEALL)
gc.collect()
previous_garbage = len(gc.garbage)
f(self)
gc.collect()
# This will fail if any garbage has been created, typically because of a
# reference cycle.
self.assertEqual(previous_garbage, len(gc.garbage))
# TODO(allenl): Figure out why this debug flag reset doesn't work. It would
# be nice to be able to decorate arbitrary tests in a large test suite and
# not hold on to every object in other tests.
gc.set_debug(previous_debug_flags)
gc.enable()
return decorator
class NoReferenceCycleTests(test_util.TensorFlowTestCase):
@assert_no_garbage_created
@test_util.assert_no_garbage_created
def testEagerResourceVariables(self):
with context.eager_mode():
resource_variable_ops.ResourceVariable(1.0, name="a")
@assert_no_garbage_created
@test_util.assert_no_garbage_created
def testTensorArrays(self):
with context.eager_mode():
ta = tensor_array_ops.TensorArray(