Use WeakKeyDictionaries for global Keras {graph->...} maps
These globals were holding onto graphs including FuncGraphs, which held onto captured tensors leaving garbage around. This change also adds a test to catch garbage like this in the future. To make the test work, I needed to manually breakup some reference cycles caused by OrderedDicts. We should probably have a custom impl of OrderedDict similar to the one in Python3 and avoid these issues. PiperOrigin-RevId: 212694290
This commit is contained in:
parent
5d1de24583
commit
52d9dbfa8e
tensorflow/python
@ -25,6 +25,7 @@ import sys
|
||||
import numpy
|
||||
|
||||
from tensorflow.core.protobuf import config_pb2
|
||||
from tensorflow.python import keras
|
||||
from tensorflow.python.data.ops import iterator_ops
|
||||
from tensorflow.python.eager import backprop
|
||||
from tensorflow.python.eager import context
|
||||
@ -38,6 +39,7 @@ from tensorflow.python.framework import random_seed
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.framework import tensor_spec
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.keras.engine import training as keras_training
|
||||
from tensorflow.python.layers import convolutional
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import clip_ops
|
||||
@ -57,6 +59,21 @@ from tensorflow.python.util import compat
|
||||
from tensorflow.python.util import nest
|
||||
|
||||
|
||||
class MiniModel(keras_training.Model):
|
||||
"""Minimal model for mnist.
|
||||
|
||||
Useful for testing and debugging on slow TPU simulators.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super(MiniModel, self).__init__(name='')
|
||||
self.fc = keras.layers.Dense(1, name='fc', kernel_initializer='ones',
|
||||
bias_initializer='ones')
|
||||
|
||||
def call(self, inputs, training=True):
|
||||
return self.fc(inputs)
|
||||
|
||||
|
||||
@test_util.with_c_shapes
|
||||
class FunctionTest(test.TestCase):
|
||||
|
||||
@ -1005,6 +1022,7 @@ class FunctionTest(test.TestCase):
|
||||
with ops.get_default_graph().as_default():
|
||||
create_variable()
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
|
||||
def testLayerInDefun(self):
|
||||
conv = convolutional.Conv2D(
|
||||
filters=1,
|
||||
@ -1018,7 +1036,34 @@ class FunctionTest(test.TestCase):
|
||||
|
||||
x = array_ops.ones([1, 2, 2, 1])
|
||||
y = model(x)
|
||||
self.assertAllEqual([[[[4.0]]]], y.numpy())
|
||||
|
||||
if not context.executing_eagerly():
|
||||
self.evaluate(variables.global_variables_initializer())
|
||||
|
||||
self.assertAllEqual([[[[4.0]]]], self.evaluate(y))
|
||||
|
||||
# Remove reference cycles in model
|
||||
test_util.dismantle_polymorphic_function(model)
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
|
||||
def testDefunKerasModelCall(self):
|
||||
model = MiniModel()
|
||||
model.call = function.defun(model.call)
|
||||
|
||||
x = array_ops.ones([1, 2])
|
||||
y = model(x)
|
||||
|
||||
if not context.executing_eagerly():
|
||||
self.evaluate(variables.global_variables_initializer())
|
||||
|
||||
self.assertAllEqual([[3.0]], self.evaluate(y))
|
||||
|
||||
# Remove reference cycles in defun.
|
||||
test_util.dismantle_polymorphic_function(model.call)
|
||||
# Break the reference cycle between the MiniModel and the defun:
|
||||
# MiniModel --(through its `call` method)--> PolymorphicFunction
|
||||
# PolymorphicFunction --(instancemethod on MiniModel)--> MiniModel
|
||||
del model.call
|
||||
|
||||
# Note: The ConfigProto below unfortunately only configures graph
|
||||
# construction. Eager's configuration is controlled in `__main__`.
|
||||
|
@ -58,6 +58,7 @@ from tensorflow.python.util import decorator_utils
|
||||
from tensorflow.python.util import deprecation
|
||||
from tensorflow.python.util import function_utils
|
||||
from tensorflow.python.util import lock_util
|
||||
from tensorflow.python.util import memory
|
||||
from tensorflow.python.util import tf_contextlib
|
||||
from tensorflow.python.util import tf_stack
|
||||
from tensorflow.python.util.deprecation import deprecated_args
|
||||
@ -5824,23 +5825,11 @@ def dismantle_graph(graph):
|
||||
graph: A `Graph` object to destroy. Neither it nor any of its ops are usable
|
||||
after this function runs.
|
||||
"""
|
||||
# pylint: disable=protected-access
|
||||
# OrderedDict, constructed on Graph creation, makes a simple reference loop
|
||||
# and hides it in an __attribute in some Python versions. We don't need to
|
||||
# throw an error if we can't find it, but if we do find it we can break the
|
||||
# loop to avoid creating work for the garbage collector.
|
||||
graph_operations = graph.get_operations()
|
||||
problematic_cycle = graph._functions.__dict__.get("_OrderedDict__root", None)
|
||||
# pylint: enable=protected-access
|
||||
if problematic_cycle:
|
||||
try:
|
||||
del problematic_cycle[0][:]
|
||||
except TypeError:
|
||||
# This is probably not one of the problematic Python versions. Continue
|
||||
# with the rest of our cleanup.
|
||||
pass
|
||||
memory.dismantle_ordered_dict(graph._functions) # pylint: disable=protected-access
|
||||
|
||||
# Now clean up Operation<->Graph reference cycles by clearing all of the
|
||||
# attributes for the Graph and its ops.
|
||||
graph_operations = graph.get_operations()
|
||||
for op in graph_operations:
|
||||
op.__dict__ = {}
|
||||
graph.__dict__ = {}
|
||||
|
@ -69,6 +69,7 @@ from tensorflow.python.platform import googletest
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.training import server_lib
|
||||
from tensorflow.python.util import compat
|
||||
from tensorflow.python.util import memory
|
||||
from tensorflow.python.util import nest
|
||||
from tensorflow.python.util import tf_inspect
|
||||
from tensorflow.python.util.protobuf import compare
|
||||
@ -2008,3 +2009,42 @@ def set_producer_version(graph, producer_version):
|
||||
with graph.as_default():
|
||||
importer.import_graph_def(graph_def)
|
||||
assert graph.graph_def_versions.producer, producer_version
|
||||
|
||||
|
||||
def dismantle_func_graph(func_graph):
|
||||
"""Removes reference cycles in `func_graph` FuncGraph.
|
||||
|
||||
Helpful for making sure the garbage collector doesn't need to run when
|
||||
the FuncGraph goes out of scope, e.g. in tests using defun with
|
||||
@test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True).
|
||||
|
||||
Args:
|
||||
func_graph: A `FuncGraph` object to destroy. `func_graph` is unusable
|
||||
after this function.
|
||||
"""
|
||||
# TODO(b/115366440): Delete this method when a custom OrderedDict is added.
|
||||
# Clearing captures using clear() leaves some cycles around.
|
||||
while func_graph.captures:
|
||||
func_graph.captures.popitem()
|
||||
memory.dismantle_ordered_dict(func_graph.captures)
|
||||
ops.dismantle_graph(func_graph)
|
||||
|
||||
|
||||
def dismantle_polymorphic_function(func):
|
||||
"""Removes reference cycles in PolymorphicFunction `func`.
|
||||
|
||||
Helpful for making sure the garbage collector doesn't need to run when
|
||||
PolymorphicFunction goes out of scope, e.g. in tests using defun with
|
||||
@test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True).
|
||||
|
||||
Args:
|
||||
func: A `PolymorphicFunction` object to destroy. `func` is unusable
|
||||
after this function.
|
||||
"""
|
||||
# TODO(b/115366440): Delete this method when a custom OrderedDict is added
|
||||
cache = func._function_cache # pylint: disable=protected-access
|
||||
for concrete_func in cache.values():
|
||||
dismantle_func_graph(concrete_func.graph)
|
||||
while cache:
|
||||
cache.popitem()
|
||||
memory.dismantle_ordered_dict(cache)
|
||||
|
@ -73,7 +73,16 @@ _SESSION = None
|
||||
# This dictionary holds a mapping {graph: learning_phase}.
|
||||
# A learning phase is a bool tensor used to run Keras models in
|
||||
# either train mode (learning_phase == 1) or test mode (learning_phase == 0).
|
||||
_GRAPH_LEARNING_PHASES = {}
|
||||
_GRAPH_LEARNING_PHASES = weakref.WeakKeyDictionary()
|
||||
|
||||
|
||||
# _DUMMY_EAGER_GRAPH is used as a key in _GRAPH_LEARNING_PHASES.
|
||||
# We keep a separate reference to it to make sure it does not get removed from
|
||||
# _GRAPH_LEARNING_PHASES. We use a dummy class instead of something like a
|
||||
# string because strings are not weakly-referencable.
|
||||
class _DummyEagerGraph(object):
|
||||
pass
|
||||
_DUMMY_EAGER_GRAPH = _DummyEagerGraph()
|
||||
|
||||
# This boolean flag can be set to True to leave variable initialization
|
||||
# up to the user.
|
||||
@ -96,11 +105,11 @@ _LOCAL_DEVICES = None
|
||||
|
||||
# This dictionary holds a mapping between a graph and variables to initialize
|
||||
# in the graph.
|
||||
_GRAPH_VARIABLES = {}
|
||||
_GRAPH_VARIABLES = weakref.WeakKeyDictionary()
|
||||
|
||||
# This dictionary holds a mapping between a graph and TF optimizers created in
|
||||
# the graph.
|
||||
_GRAPH_TF_OPTIMIZERS = {}
|
||||
_GRAPH_TF_OPTIMIZERS = weakref.WeakKeyDictionary()
|
||||
|
||||
|
||||
@tf_export('keras.backend.backend')
|
||||
@ -359,10 +368,10 @@ def learning_phase():
|
||||
Learning phase (scalar integer tensor or Python integer).
|
||||
"""
|
||||
if context.executing_eagerly():
|
||||
if 'eager' not in _GRAPH_LEARNING_PHASES:
|
||||
if _DUMMY_EAGER_GRAPH not in _GRAPH_LEARNING_PHASES:
|
||||
# Fallback to inference mode as default.
|
||||
return 0
|
||||
return _GRAPH_LEARNING_PHASES['eager']
|
||||
return _GRAPH_LEARNING_PHASES[_DUMMY_EAGER_GRAPH]
|
||||
|
||||
graph = ops.get_default_graph()
|
||||
if graph not in _GRAPH_LEARNING_PHASES:
|
||||
@ -386,7 +395,7 @@ def set_learning_phase(value):
|
||||
if value not in {0, 1}:
|
||||
raise ValueError('Expected learning phase to be 0 or 1.')
|
||||
if context.executing_eagerly():
|
||||
_GRAPH_LEARNING_PHASES['eager'] = value
|
||||
_GRAPH_LEARNING_PHASES[_DUMMY_EAGER_GRAPH] = value
|
||||
else:
|
||||
_GRAPH_LEARNING_PHASES[ops.get_default_graph()] = value
|
||||
|
||||
@ -415,7 +424,7 @@ def learning_phase_scope(value):
|
||||
finally:
|
||||
# Restore learning phase to initial value.
|
||||
if context.executing_eagerly():
|
||||
_GRAPH_LEARNING_PHASES['eager'] = previous_value
|
||||
_GRAPH_LEARNING_PHASES[_DUMMY_EAGER_GRAPH] = previous_value
|
||||
else:
|
||||
_GRAPH_LEARNING_PHASES[ops.get_default_graph()] = previous_value
|
||||
|
||||
|
45
tensorflow/python/util/memory.py
Normal file
45
tensorflow/python/util/memory.py
Normal file
@ -0,0 +1,45 @@
|
||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
"""Functions related to Python memory management."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
|
||||
# TODO(b/115366440): Delete this function when a custom OrderedDict is added
|
||||
def dismantle_ordered_dict(ordered_dict):
|
||||
"""Remove reference cycle in OrderedDict `ordered_dict`.
|
||||
|
||||
Helpful for making sure the garbage collector doesn't need to run after
|
||||
using an OrderedDict.
|
||||
|
||||
Args:
|
||||
ordered_dict: A `OrderedDict` object to destroy. This object is unusable
|
||||
after this function runs.
|
||||
"""
|
||||
# OrderedDict, makes a simple reference loop
|
||||
# and hides it in an __attribute in some Python versions. We don't need to
|
||||
# throw an error if we can't find it, but if we do find it we can break the
|
||||
# loop to avoid creating work for the garbage collector.
|
||||
problematic_cycle = ordered_dict.__dict__.get("_OrderedDict__root", None) # pylint: disable=protected-access
|
||||
if problematic_cycle:
|
||||
try:
|
||||
del problematic_cycle[0][:]
|
||||
except TypeError:
|
||||
# This is probably not one of the problematic Python versions. Continue
|
||||
# with the rest of our cleanup.
|
||||
pass
|
Loading…
Reference in New Issue
Block a user