Fix a tfe.Network reference cycle, make it easier to check for reference cycles
in unit tests. PiperOrigin-RevId: 174917072
This commit is contained in:
parent
f589bf0d4d
commit
09964d1e55
tensorflow
contrib/eager/python
python
@ -182,6 +182,40 @@ def _make_custom_getter_for_deferred_restorations():
|
||||
return _custom_getter, deferred_restorations
|
||||
|
||||
|
||||
def _make_prefix_stripping_map_fn(scope_name):
|
||||
"""Closure for stripping the scope name of a Network.
|
||||
|
||||
Implemented as a closure rather than a member function to avoid reference
|
||||
cycles in deferred restorations (this function should not have a reference to
|
||||
the Network which created it).
|
||||
|
||||
Args:
|
||||
scope_name: The Network.scope_name to strip from variables.
|
||||
Returns:
|
||||
A scope_name-stripping default `map_fn` for the Network.
|
||||
"""
|
||||
|
||||
def _strip_variable_prefix(original_variable_name):
|
||||
"""The default map_func for saving or restoring variables.
|
||||
|
||||
Strips the variable prefix for the Network on which save/restore was called,
|
||||
and leaves other variable names fully qualified in the checkpoint.
|
||||
|
||||
Args:
|
||||
original_variable_name: The _shared_name of the variable (no :0
|
||||
suffix) to map.
|
||||
Returns:
|
||||
The checkpoint name of the variable.
|
||||
"""
|
||||
scope_name_with_slash = scope_name + "/"
|
||||
if original_variable_name.startswith(scope_name_with_slash):
|
||||
return original_variable_name[len(scope_name_with_slash):]
|
||||
else:
|
||||
return original_variable_name
|
||||
|
||||
return _strip_variable_prefix
|
||||
|
||||
|
||||
class Network(base.Layer):
|
||||
"""Represents the composition of a set of Layers.
|
||||
|
||||
@ -488,24 +522,6 @@ class Network(base.Layer):
|
||||
"at https://github.com/tensorflow/tensorflow/issues/new if this is "
|
||||
"important to you")
|
||||
|
||||
def _strip_variable_prefix(self, original_variable_name):
|
||||
"""The default map_func for saving or restoring variables.
|
||||
|
||||
Strips the variable prefix for the Network on which save/restore was called,
|
||||
and leaves other variable names fully qualified in the checkpoint.
|
||||
|
||||
Args:
|
||||
original_variable_name: The _shared_name of the variable (no :0
|
||||
suffix) to map.
|
||||
Returns:
|
||||
The checkpoint name of the variable.
|
||||
"""
|
||||
scope_name_with_slash = self.scope_name + "/"
|
||||
if original_variable_name.startswith(scope_name_with_slash):
|
||||
return original_variable_name[len(scope_name_with_slash):]
|
||||
else:
|
||||
return original_variable_name
|
||||
|
||||
def save(self, save_path, global_step=None, map_func=None):
|
||||
"""Save variables from the Network to a checkpoint.
|
||||
|
||||
@ -543,7 +559,7 @@ class Network(base.Layer):
|
||||
save_path = os.path.join(save_path, self.name)
|
||||
user_map_func = map_func
|
||||
if map_func is None:
|
||||
map_func = self._strip_variable_prefix
|
||||
map_func = _make_prefix_stripping_map_fn(self.scope_name)
|
||||
variable_map = {}
|
||||
for variable in self.variables:
|
||||
mapped_name = map_func(variable._shared_name)
|
||||
@ -737,7 +753,7 @@ class Network(base.Layer):
|
||||
save_path = os.path.join(save_path, self.name)
|
||||
user_map_func = map_func
|
||||
if map_func is None:
|
||||
map_func = self._strip_variable_prefix
|
||||
map_func = _make_prefix_stripping_map_fn(self.scope_name)
|
||||
# Step one is to restore any existing variables from the checkpoint.
|
||||
existing_variables_by_checkpoint_name = self._restore_existing_variables(
|
||||
save_path=save_path,
|
||||
|
@ -67,7 +67,7 @@ class NetworkTest(test.TestCase):
|
||||
original_output,
|
||||
self.evaluate(net(input_value)))
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes()
|
||||
@test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
|
||||
def testTrainableAttribute(self):
|
||||
net = network.Network()
|
||||
self.assertTrue(net.trainable)
|
||||
@ -75,7 +75,7 @@ class NetworkTest(test.TestCase):
|
||||
net.trainable = False
|
||||
self.assertTrue(net.trainable)
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes()
|
||||
@test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
|
||||
def testNetworkCall(self):
|
||||
net = MyNetwork(name="abcd")
|
||||
net(constant_op.constant([[2.0]])) # Force variables to be created.
|
||||
@ -85,6 +85,7 @@ class NetworkTest(test.TestCase):
|
||||
result = net(constant_op.constant([[2.0]]))
|
||||
self.assertEqual(34.0, self.evaluate(result))
|
||||
|
||||
# TODO(allenl): This test creates garbage in some Python versions
|
||||
@test_util.run_in_graph_and_eager_modes()
|
||||
def testNetworkSaveRestoreAlreadyBuilt(self):
|
||||
net = MyNetwork(name="abcd")
|
||||
@ -96,6 +97,7 @@ class NetworkTest(test.TestCase):
|
||||
self._save_modify_load_network_built(net, global_step=None)
|
||||
self._save_modify_load_network_built(net, global_step=10)
|
||||
|
||||
# TODO(allenl): This test creates garbage in some Python versions
|
||||
@test_util.run_in_graph_and_eager_modes()
|
||||
def testSaveRestoreDefaultGlobalStep(self):
|
||||
net = MyNetwork(name="abcd")
|
||||
@ -106,6 +108,7 @@ class NetworkTest(test.TestCase):
|
||||
save_path = net.save(self.get_temp_dir())
|
||||
self.assertIn("abcd-4242", save_path)
|
||||
|
||||
# TODO(allenl): This test creates garbage in some Python versions
|
||||
@test_util.run_in_graph_and_eager_modes()
|
||||
def testNetworkSaveAndRestoreIntoUnbuilt(self):
|
||||
save_dir = self.get_temp_dir()
|
||||
@ -377,25 +380,25 @@ class NetworkTest(test.TestCase):
|
||||
gc.set_debug(previous_gc_debug_flags)
|
||||
gc.enable()
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes()
|
||||
@test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
|
||||
def testAnonymousNoNameInitially(self):
|
||||
net = MyNetwork()
|
||||
with self.assertRaisesRegexp(ValueError, "does not yet have a final name"):
|
||||
net.name # pylint: disable=pointless-statement
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes()
|
||||
@test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
|
||||
def testExplicitHasNameInitially(self):
|
||||
net = MyNetwork(name="abcd")
|
||||
self.assertEqual("abcd", net.name)
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes()
|
||||
@test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
|
||||
def testUsingResourceVariables(self):
|
||||
net = MyNetwork()
|
||||
net(constant_op.constant([[0.]]))
|
||||
self.assertIsInstance(net.trainable_weights[0],
|
||||
resource_variable_ops.ResourceVariable)
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes()
|
||||
@test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
|
||||
def testDuplicateNameError(self):
|
||||
one = constant_op.constant([[1.]])
|
||||
net = MyNetwork(name="foo")
|
||||
@ -405,7 +408,7 @@ class NetworkTest(test.TestCase):
|
||||
net1 = MyNetwork(name="foo")
|
||||
net1(one)
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes()
|
||||
@test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
|
||||
def testWrappingInVariableScope(self):
|
||||
with variable_scope.variable_scope("outside_scope"):
|
||||
net = MyNetwork()
|
||||
@ -440,7 +443,7 @@ class NetworkTest(test.TestCase):
|
||||
actual=net.trainable_weights[0].name)
|
||||
self.assertEqual("explicit_name", net.first.name)
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes()
|
||||
@test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
|
||||
def testWrappingInAnonymousVariableScope(self):
|
||||
# Named outside variable_scopes are not supported at the moment. However,
|
||||
# blank-named top level variable scopes do not change variable names, and so
|
||||
@ -455,20 +458,20 @@ class NetworkTest(test.TestCase):
|
||||
net(one)
|
||||
self.assertTrue(was_called[0])
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes()
|
||||
@test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
|
||||
def testReasonableSlashError(self):
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError, "not allowed in Network names"):
|
||||
MyNetwork(name="slash/slash")
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes()
|
||||
@test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
|
||||
def testNoVariableScopeNames(self):
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError, "VariableScopes are not valid Network names"):
|
||||
with variable_scope.variable_scope("some_scope") as vs:
|
||||
MyNetwork(name=vs)
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes()
|
||||
@test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
|
||||
def testVariableScopeNameCollision(self):
|
||||
with variable_scope.variable_scope("abcd"):
|
||||
pass
|
||||
@ -478,7 +481,7 @@ class NetworkTest(test.TestCase):
|
||||
one = constant_op.constant([[1.]])
|
||||
net(one)
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes()
|
||||
@test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
|
||||
def testNetworkVariablesDoNotInterfere(self):
|
||||
core.Dense(1, use_bias=True) # Should not interfere with naming.
|
||||
net1 = MyNetwork()
|
||||
@ -1007,6 +1010,7 @@ class NetworkTest(test.TestCase):
|
||||
|
||||
class SequentialTest(test.TestCase):
|
||||
|
||||
@test_util.assert_no_garbage_created
|
||||
def testTwoLayers(self):
|
||||
# Create a sequential network with one layer.
|
||||
net = network.Sequential([core.Dense(1, use_bias=False)])
|
||||
@ -1028,6 +1032,7 @@ class SequentialTest(test.TestCase):
|
||||
l2.trainable_variables[0].assign([[11.0]])
|
||||
self.assertEqual(231.0, net(constant_op.constant([[7.0]])).numpy())
|
||||
|
||||
@test_util.assert_no_garbage_created
|
||||
def testFunctions(self):
|
||||
# Create a sequential network with one function.
|
||||
net = network.Sequential([nn_ops.relu])
|
||||
@ -1038,6 +1043,7 @@ class SequentialTest(test.TestCase):
|
||||
net.add(math_ops.negative)
|
||||
self.assertEqual(-2.0, net(two).numpy())
|
||||
|
||||
@test_util.assert_no_garbage_created
|
||||
def testTrainingLayer(self):
|
||||
net = network.Sequential([core.Dropout(0.99999)])
|
||||
two = constant_op.constant(2.0)
|
||||
@ -1051,6 +1057,7 @@ class SequentialTest(test.TestCase):
|
||||
# Should only fail spuriously 1 in 10^100 runs.
|
||||
self.fail("Didn't see dropout happen after 20 tries.")
|
||||
|
||||
@test_util.assert_no_garbage_created
|
||||
def testTrainingFunction(self):
|
||||
# Output depends on value of "training".
|
||||
def add_training(input_value, training=None):
|
||||
|
@ -20,6 +20,7 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import contextlib
|
||||
import gc
|
||||
import math
|
||||
import random
|
||||
import re
|
||||
@ -452,9 +453,43 @@ class IsolateTest(object):
|
||||
type_arg, value_arg, traceback_arg)
|
||||
|
||||
|
||||
def run_in_graph_and_eager_modes(__unused__=None, graph=None, config=None,
|
||||
use_gpu=False, force_gpu=False,
|
||||
reset_test=True):
|
||||
def assert_no_garbage_created(f):
|
||||
"""Test method decorator to assert that no garbage has been created.
|
||||
|
||||
Note that this decorator sets DEBUG_SAVEALL, which in some Python interpreters
|
||||
cannot be un-set (i.e. will disable garbage collection for any other unit
|
||||
tests in the same file/shard).
|
||||
|
||||
Args:
|
||||
f: The function to decorate.
|
||||
Returns:
|
||||
The decorated function.
|
||||
"""
|
||||
|
||||
def decorator(self, **kwargs):
|
||||
"""Sets DEBUG_SAVEALL, runs the test, and checks for new garbage."""
|
||||
gc.disable()
|
||||
previous_debug_flags = gc.get_debug()
|
||||
gc.set_debug(gc.DEBUG_SAVEALL)
|
||||
gc.collect()
|
||||
previous_garbage = len(gc.garbage)
|
||||
f(self, **kwargs)
|
||||
gc.collect()
|
||||
# This will fail if any garbage has been created, typically because of a
|
||||
# reference cycle.
|
||||
self.assertEqual(previous_garbage, len(gc.garbage))
|
||||
# TODO(allenl): Figure out why this debug flag reset doesn't work. It would
|
||||
# be nice to be able to decorate arbitrary tests in a large test suite and
|
||||
# not hold on to every object in other tests.
|
||||
gc.set_debug(previous_debug_flags)
|
||||
gc.enable()
|
||||
return decorator
|
||||
|
||||
|
||||
def run_in_graph_and_eager_modes(
|
||||
__unused__=None, graph=None, config=None,
|
||||
use_gpu=False, force_gpu=False,
|
||||
reset_test=True, assert_no_eager_garbage=False):
|
||||
"""Runs the test in both graph and eager modes.
|
||||
|
||||
Args:
|
||||
@ -465,7 +500,14 @@ def run_in_graph_and_eager_modes(__unused__=None, graph=None, config=None,
|
||||
use_gpu: If True, attempt to run as many ops as possible on GPU.
|
||||
force_gpu: If True, pin all ops to `/device:GPU:0`.
|
||||
reset_test: If True, tearDown and SetUp the test case again.
|
||||
|
||||
assert_no_eager_garbage: If True, sets DEBUG_SAVEALL on the garbage
|
||||
collector and asserts that no extra garbage has been created when running
|
||||
the test in eager mode. This will fail if there are reference cycles
|
||||
(e.g. a = []; a.append(a)). Off by default because some tests may create
|
||||
garbage for legitimate reasons (e.g. they define a class which inherits
|
||||
from `object`), and because DEBUG_SAVEALL is sticky in some Python
|
||||
interpreters (meaning that tests which rely on objects being collected
|
||||
elsewhere in the unit test file will not work).
|
||||
Returns:
|
||||
Returns a decorator that will run the decorated test function
|
||||
using both a graph and using eager execution.
|
||||
@ -487,7 +529,7 @@ def run_in_graph_and_eager_modes(__unused__=None, graph=None, config=None,
|
||||
self.tearDown()
|
||||
self.setUp()
|
||||
|
||||
def run_eager_mode():
|
||||
def run_eager_mode(self, **kwargs):
|
||||
if force_gpu:
|
||||
gpu_name = gpu_device_name()
|
||||
if not gpu_name:
|
||||
@ -501,9 +543,12 @@ def run_in_graph_and_eager_modes(__unused__=None, graph=None, config=None,
|
||||
with context.device("/device:CPU:0"):
|
||||
f(self, **kwargs)
|
||||
|
||||
if assert_no_eager_garbage:
|
||||
run_eager_mode = assert_no_garbage_created(run_eager_mode)
|
||||
|
||||
with context.eager_mode():
|
||||
with IsolateTest():
|
||||
run_eager_mode()
|
||||
run_eager_mode(self, **kwargs)
|
||||
|
||||
return decorated
|
||||
return decorator
|
||||
|
@ -329,6 +329,30 @@ class TestUtilTest(test_util.TensorFlowTestCase):
|
||||
self.assertEqual(a_rand, b_rand)
|
||||
|
||||
|
||||
class GarbageCollectionTest(test_util.TensorFlowTestCase):
|
||||
|
||||
def test_no_reference_cycle_decorator(self):
|
||||
|
||||
class ReferenceCycleTest(object):
|
||||
|
||||
def __init__(inner_self): # pylint: disable=no-self-argument
|
||||
inner_self.assertEqual = self.assertEqual # pylint: disable=invalid-name
|
||||
|
||||
@test_util.assert_no_garbage_created
|
||||
def test_has_cycle(self):
|
||||
a = []
|
||||
a.append(a)
|
||||
|
||||
@test_util.assert_no_garbage_created
|
||||
def test_has_no_cycle(self):
|
||||
pass
|
||||
|
||||
with self.assertRaises(AssertionError):
|
||||
ReferenceCycleTest().test_has_cycle()
|
||||
|
||||
ReferenceCycleTest().test_has_no_cycle()
|
||||
|
||||
|
||||
@test_util.with_c_api
|
||||
class IsolationTest(test_util.TensorFlowTestCase):
|
||||
|
||||
|
@ -21,8 +21,6 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import gc
|
||||
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import test_util
|
||||
@ -31,37 +29,14 @@ from tensorflow.python.ops import tensor_array_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
def assert_no_garbage_created(f):
|
||||
"""Test decorator to assert that no garbage has been created."""
|
||||
|
||||
def decorator(self):
|
||||
"""Sets DEBUG_SAVEALL, runs the test, and checks for new garbage."""
|
||||
gc.disable()
|
||||
previous_debug_flags = gc.get_debug()
|
||||
gc.set_debug(gc.DEBUG_SAVEALL)
|
||||
gc.collect()
|
||||
previous_garbage = len(gc.garbage)
|
||||
f(self)
|
||||
gc.collect()
|
||||
# This will fail if any garbage has been created, typically because of a
|
||||
# reference cycle.
|
||||
self.assertEqual(previous_garbage, len(gc.garbage))
|
||||
# TODO(allenl): Figure out why this debug flag reset doesn't work. It would
|
||||
# be nice to be able to decorate arbitrary tests in a large test suite and
|
||||
# not hold on to every object in other tests.
|
||||
gc.set_debug(previous_debug_flags)
|
||||
gc.enable()
|
||||
return decorator
|
||||
|
||||
|
||||
class NoReferenceCycleTests(test_util.TensorFlowTestCase):
|
||||
|
||||
@assert_no_garbage_created
|
||||
@test_util.assert_no_garbage_created
|
||||
def testEagerResourceVariables(self):
|
||||
with context.eager_mode():
|
||||
resource_variable_ops.ResourceVariable(1.0, name="a")
|
||||
|
||||
@assert_no_garbage_created
|
||||
@test_util.assert_no_garbage_created
|
||||
def testTensorArrays(self):
|
||||
with context.eager_mode():
|
||||
ta = tensor_array_ops.TensorArray(
|
||||
|
Loading…
Reference in New Issue
Block a user