Use WeakKeyDictionaries for global Keras {graph->...} maps
These globals were holding onto graphs, including FuncGraphs, which in turn held onto captured tensors and left garbage around. This change also adds a test to catch garbage like this in the future. To make the test work, I needed to manually break up some reference cycles caused by OrderedDicts. We should probably have a custom implementation of OrderedDict, similar to the one in Python 3, and avoid these issues. PiperOrigin-RevId: 212694290
This commit is contained in:
parent
5d1de24583
commit
52d9dbfa8e
@ -25,6 +25,7 @@ import sys
|
|||||||
import numpy
|
import numpy
|
||||||
|
|
||||||
from tensorflow.core.protobuf import config_pb2
|
from tensorflow.core.protobuf import config_pb2
|
||||||
|
from tensorflow.python import keras
|
||||||
from tensorflow.python.data.ops import iterator_ops
|
from tensorflow.python.data.ops import iterator_ops
|
||||||
from tensorflow.python.eager import backprop
|
from tensorflow.python.eager import backprop
|
||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
@ -38,6 +39,7 @@ from tensorflow.python.framework import random_seed
|
|||||||
from tensorflow.python.framework import tensor_shape
|
from tensorflow.python.framework import tensor_shape
|
||||||
from tensorflow.python.framework import tensor_spec
|
from tensorflow.python.framework import tensor_spec
|
||||||
from tensorflow.python.framework import test_util
|
from tensorflow.python.framework import test_util
|
||||||
|
from tensorflow.python.keras.engine import training as keras_training
|
||||||
from tensorflow.python.layers import convolutional
|
from tensorflow.python.layers import convolutional
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.ops import clip_ops
|
from tensorflow.python.ops import clip_ops
|
||||||
@ -57,6 +59,21 @@ from tensorflow.python.util import compat
|
|||||||
from tensorflow.python.util import nest
|
from tensorflow.python.util import nest
|
||||||
|
|
||||||
|
|
||||||
|
class MiniModel(keras_training.Model):
  """Minimal model for mnist, built from a single one-unit dense layer.

  Useful for testing and debugging on slow TPU simulators. Both the kernel and
  the bias of the dense layer are initialized to ones, so the output for a
  given input is easy to predict in tests.
  """

  def __init__(self):
    super(MiniModel, self).__init__(name='')
    # Deterministic "ones" initializers keep expected outputs trivial to
    # compute by hand.
    dense_options = dict(
        name='fc', kernel_initializer='ones', bias_initializer='ones')
    self.fc = keras.layers.Dense(1, **dense_options)

  def call(self, inputs, training=True):
    # `training` is accepted for signature compatibility only; the single
    # dense layer does not consult it.
    del training
    return self.fc(inputs)
|
||||||
|
|
||||||
|
|
||||||
@test_util.with_c_shapes
|
@test_util.with_c_shapes
|
||||||
class FunctionTest(test.TestCase):
|
class FunctionTest(test.TestCase):
|
||||||
|
|
||||||
@ -1005,6 +1022,7 @@ class FunctionTest(test.TestCase):
|
|||||||
with ops.get_default_graph().as_default():
|
with ops.get_default_graph().as_default():
|
||||||
create_variable()
|
create_variable()
|
||||||
|
|
||||||
|
@test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
|
||||||
def testLayerInDefun(self):
|
def testLayerInDefun(self):
|
||||||
conv = convolutional.Conv2D(
|
conv = convolutional.Conv2D(
|
||||||
filters=1,
|
filters=1,
|
||||||
@ -1018,7 +1036,34 @@ class FunctionTest(test.TestCase):
|
|||||||
|
|
||||||
x = array_ops.ones([1, 2, 2, 1])
|
x = array_ops.ones([1, 2, 2, 1])
|
||||||
y = model(x)
|
y = model(x)
|
||||||
self.assertAllEqual([[[[4.0]]]], y.numpy())
|
|
||||||
|
if not context.executing_eagerly():
|
||||||
|
self.evaluate(variables.global_variables_initializer())
|
||||||
|
|
||||||
|
self.assertAllEqual([[[[4.0]]]], self.evaluate(y))
|
||||||
|
|
||||||
|
# Remove reference cycles in model
|
||||||
|
test_util.dismantle_polymorphic_function(model)
|
||||||
|
|
||||||
|
@test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
def testDefunKerasModelCall(self):
  # Wrap the model's `call` in a defun and check that invoking the model
  # still produces the expected value without leaving garbage behind.
  model = MiniModel()
  model.call = function.defun(model.call)

  inputs = array_ops.ones([1, 2])
  outputs = model(inputs)

  # In graph mode the dense layer's variables need explicit initialization.
  if not context.executing_eagerly():
    self.evaluate(variables.global_variables_initializer())

  # Two unit inputs times a kernel of ones, plus a bias of one, gives 3.
  self.assertAllEqual([[3.0]], self.evaluate(outputs))

  # Remove reference cycles in defun.
  test_util.dismantle_polymorphic_function(model.call)
  # The model and the defun also reference each other:
  #   MiniModel --(through its `call` method)--> PolymorphicFunction
  #   PolymorphicFunction --(instancemethod on MiniModel)--> MiniModel
  # Dropping the instance attribute breaks that cycle.
  del model.call
|
||||||
|
|
||||||
# Note: The ConfigProto below unfortunately only configures graph
|
# Note: The ConfigProto below unfortunately only configures graph
|
||||||
# construction. Eager's configuration is controlled in `__main__`.
|
# construction. Eager's configuration is controlled in `__main__`.
|
||||||
|
@ -58,6 +58,7 @@ from tensorflow.python.util import decorator_utils
|
|||||||
from tensorflow.python.util import deprecation
|
from tensorflow.python.util import deprecation
|
||||||
from tensorflow.python.util import function_utils
|
from tensorflow.python.util import function_utils
|
||||||
from tensorflow.python.util import lock_util
|
from tensorflow.python.util import lock_util
|
||||||
|
from tensorflow.python.util import memory
|
||||||
from tensorflow.python.util import tf_contextlib
|
from tensorflow.python.util import tf_contextlib
|
||||||
from tensorflow.python.util import tf_stack
|
from tensorflow.python.util import tf_stack
|
||||||
from tensorflow.python.util.deprecation import deprecated_args
|
from tensorflow.python.util.deprecation import deprecated_args
|
||||||
@ -5824,23 +5825,11 @@ def dismantle_graph(graph):
|
|||||||
graph: A `Graph` object to destroy. Neither it nor any of its ops are usable
|
graph: A `Graph` object to destroy. Neither it nor any of its ops are usable
|
||||||
after this function runs.
|
after this function runs.
|
||||||
"""
|
"""
|
||||||
# pylint: disable=protected-access
|
memory.dismantle_ordered_dict(graph._functions) # pylint: disable=protected-access
|
||||||
# OrderedDict, constructed on Graph creation, makes a simple reference loop
|
|
||||||
# and hides it in an __attribute in some Python versions. We don't need to
|
|
||||||
# throw an error if we can't find it, but if we do find it we can break the
|
|
||||||
# loop to avoid creating work for the garbage collector.
|
|
||||||
graph_operations = graph.get_operations()
|
|
||||||
problematic_cycle = graph._functions.__dict__.get("_OrderedDict__root", None)
|
|
||||||
# pylint: enable=protected-access
|
|
||||||
if problematic_cycle:
|
|
||||||
try:
|
|
||||||
del problematic_cycle[0][:]
|
|
||||||
except TypeError:
|
|
||||||
# This is probably not one of the problematic Python versions. Continue
|
|
||||||
# with the rest of our cleanup.
|
|
||||||
pass
|
|
||||||
# Now clean up Operation<->Graph reference cycles by clearing all of the
|
# Now clean up Operation<->Graph reference cycles by clearing all of the
|
||||||
# attributes for the Graph and its ops.
|
# attributes for the Graph and its ops.
|
||||||
|
graph_operations = graph.get_operations()
|
||||||
for op in graph_operations:
|
for op in graph_operations:
|
||||||
op.__dict__ = {}
|
op.__dict__ = {}
|
||||||
graph.__dict__ = {}
|
graph.__dict__ = {}
|
||||||
|
@ -69,6 +69,7 @@ from tensorflow.python.platform import googletest
|
|||||||
from tensorflow.python.platform import tf_logging as logging
|
from tensorflow.python.platform import tf_logging as logging
|
||||||
from tensorflow.python.training import server_lib
|
from tensorflow.python.training import server_lib
|
||||||
from tensorflow.python.util import compat
|
from tensorflow.python.util import compat
|
||||||
|
from tensorflow.python.util import memory
|
||||||
from tensorflow.python.util import nest
|
from tensorflow.python.util import nest
|
||||||
from tensorflow.python.util import tf_inspect
|
from tensorflow.python.util import tf_inspect
|
||||||
from tensorflow.python.util.protobuf import compare
|
from tensorflow.python.util.protobuf import compare
|
||||||
@ -2008,3 +2009,42 @@ def set_producer_version(graph, producer_version):
|
|||||||
with graph.as_default():
|
with graph.as_default():
|
||||||
importer.import_graph_def(graph_def)
|
importer.import_graph_def(graph_def)
|
||||||
assert graph.graph_def_versions.producer, producer_version
|
assert graph.graph_def_versions.producer, producer_version
|
||||||
|
|
||||||
|
|
||||||
|
def dismantle_func_graph(func_graph):
  """Breaks the reference cycles held by the FuncGraph `func_graph`.

  Intended for code that must not leave work for the garbage collector once
  the FuncGraph goes out of scope, e.g. tests using defun together with
  @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True).

  Args:
    func_graph: The `FuncGraph` to destroy. It must not be used again after
      this function returns.
  """
  # TODO(b/115366440): Delete this method when a custom OrderedDict is added.
  captures = func_graph.captures
  # Entries are popped one by one because clearing captures with clear()
  # leaves some cycles around.
  while captures:
    captures.popitem()
  memory.dismantle_ordered_dict(captures)
  ops.dismantle_graph(func_graph)
|
||||||
|
|
||||||
|
|
||||||
|
def dismantle_polymorphic_function(func):
  """Breaks the reference cycles held by the PolymorphicFunction `func`.

  Intended for code that must not leave work for the garbage collector once
  the PolymorphicFunction goes out of scope, e.g. tests using defun together
  with @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True).

  Args:
    func: The `PolymorphicFunction` to destroy. It must not be used again
      after this function returns.
  """
  # TODO(b/115366440): Delete this method when a custom OrderedDict is added
  function_cache = func._function_cache  # pylint: disable=protected-access

  # First tear down the graph behind every cached concrete function.
  for cached_function in function_cache.values():
    dismantle_func_graph(cached_function.graph)

  # Then empty the cache entry by entry before dismantling the cache itself.
  while function_cache:
    function_cache.popitem()
  memory.dismantle_ordered_dict(function_cache)
|
||||||
|
@ -73,7 +73,16 @@ _SESSION = None
|
|||||||
# This dictionary holds a mapping {graph: learning_phase}.
|
# This dictionary holds a mapping {graph: learning_phase}.
|
||||||
# A learning phase is a bool tensor used to run Keras models in
|
# A learning phase is a bool tensor used to run Keras models in
|
||||||
# either train mode (learning_phase == 1) or test mode (learning_phase == 0).
|
# either train mode (learning_phase == 1) or test mode (learning_phase == 0).
|
||||||
_GRAPH_LEARNING_PHASES = {}
|
_GRAPH_LEARNING_PHASES = weakref.WeakKeyDictionary()
|
||||||
|
|
||||||
|
|
||||||
|
# _DUMMY_EAGER_GRAPH is the key used in _GRAPH_LEARNING_PHASES in place of a
# graph. A module-level reference to it is kept so that the entry is not
# removed from _GRAPH_LEARNING_PHASES. A dummy class is used rather than
# something like a string because strings are not weakly referenceable.
class _DummyEagerGraph(object):
  """Empty placeholder type whose instances can be weakly referenced."""


_DUMMY_EAGER_GRAPH = _DummyEagerGraph()
|
||||||
|
|
||||||
# This boolean flag can be set to True to leave variable initialization
|
# This boolean flag can be set to True to leave variable initialization
|
||||||
# up to the user.
|
# up to the user.
|
||||||
@ -96,11 +105,11 @@ _LOCAL_DEVICES = None
|
|||||||
|
|
||||||
# This dictionary holds a mapping between a graph and variables to initialize
|
# This dictionary holds a mapping between a graph and variables to initialize
|
||||||
# in the graph.
|
# in the graph.
|
||||||
_GRAPH_VARIABLES = {}
|
_GRAPH_VARIABLES = weakref.WeakKeyDictionary()
|
||||||
|
|
||||||
# This dictionary holds a mapping between a graph and TF optimizers created in
|
# This dictionary holds a mapping between a graph and TF optimizers created in
|
||||||
# the graph.
|
# the graph.
|
||||||
_GRAPH_TF_OPTIMIZERS = {}
|
_GRAPH_TF_OPTIMIZERS = weakref.WeakKeyDictionary()
|
||||||
|
|
||||||
|
|
||||||
@tf_export('keras.backend.backend')
|
@tf_export('keras.backend.backend')
|
||||||
@ -359,10 +368,10 @@ def learning_phase():
|
|||||||
Learning phase (scalar integer tensor or Python integer).
|
Learning phase (scalar integer tensor or Python integer).
|
||||||
"""
|
"""
|
||||||
if context.executing_eagerly():
|
if context.executing_eagerly():
|
||||||
if 'eager' not in _GRAPH_LEARNING_PHASES:
|
if _DUMMY_EAGER_GRAPH not in _GRAPH_LEARNING_PHASES:
|
||||||
# Fallback to inference mode as default.
|
# Fallback to inference mode as default.
|
||||||
return 0
|
return 0
|
||||||
return _GRAPH_LEARNING_PHASES['eager']
|
return _GRAPH_LEARNING_PHASES[_DUMMY_EAGER_GRAPH]
|
||||||
|
|
||||||
graph = ops.get_default_graph()
|
graph = ops.get_default_graph()
|
||||||
if graph not in _GRAPH_LEARNING_PHASES:
|
if graph not in _GRAPH_LEARNING_PHASES:
|
||||||
@ -386,7 +395,7 @@ def set_learning_phase(value):
|
|||||||
if value not in {0, 1}:
|
if value not in {0, 1}:
|
||||||
raise ValueError('Expected learning phase to be 0 or 1.')
|
raise ValueError('Expected learning phase to be 0 or 1.')
|
||||||
if context.executing_eagerly():
|
if context.executing_eagerly():
|
||||||
_GRAPH_LEARNING_PHASES['eager'] = value
|
_GRAPH_LEARNING_PHASES[_DUMMY_EAGER_GRAPH] = value
|
||||||
else:
|
else:
|
||||||
_GRAPH_LEARNING_PHASES[ops.get_default_graph()] = value
|
_GRAPH_LEARNING_PHASES[ops.get_default_graph()] = value
|
||||||
|
|
||||||
@ -415,7 +424,7 @@ def learning_phase_scope(value):
|
|||||||
finally:
|
finally:
|
||||||
# Restore learning phase to initial value.
|
# Restore learning phase to initial value.
|
||||||
if context.executing_eagerly():
|
if context.executing_eagerly():
|
||||||
_GRAPH_LEARNING_PHASES['eager'] = previous_value
|
_GRAPH_LEARNING_PHASES[_DUMMY_EAGER_GRAPH] = previous_value
|
||||||
else:
|
else:
|
||||||
_GRAPH_LEARNING_PHASES[ops.get_default_graph()] = previous_value
|
_GRAPH_LEARNING_PHASES[ops.get_default_graph()] = previous_value
|
||||||
|
|
||||||
|
45
tensorflow/python/util/memory.py
Normal file
45
tensorflow/python/util/memory.py
Normal file
@ -0,0 +1,45 @@
|
|||||||
|
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
"""Functions related to Python memory management."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
|
||||||
|
# TODO(b/115366440): Delete this function when a custom OrderedDict is added
|
||||||
|
def dismantle_ordered_dict(ordered_dict):
  """Remove reference cycle in OrderedDict `ordered_dict`.

  Helpful for making sure the garbage collector doesn't need to run after
  using an OrderedDict. This is a best-effort cleanup: if the hidden cycle
  cannot be found on `ordered_dict`, the function returns without raising.

  Args:
    ordered_dict: A `OrderedDict` object to destroy. This object is unusable
      after this function runs.
  """
  # OrderedDict makes a simple reference loop and hides it in an __attribute
  # in some Python versions. We don't need to throw an error if we can't find
  # it, but if we do find it we can break the loop to avoid creating work for
  # the garbage collector.
  #
  # Objects without an instance `__dict__` (for example a plain `dict`) have
  # nowhere to hide that attribute, so they are treated as "nothing to do"
  # instead of raising AttributeError.
  instance_attrs = getattr(ordered_dict, "__dict__", None)
  if not instance_attrs:
    return

  problematic_cycle = instance_attrs.get("_OrderedDict__root", None)
  if not problematic_cycle:
    return

  try:
    # Empty the list stored in the root's first slot in place.
    del problematic_cycle[0][:]
  except TypeError:
    # This is probably not one of the problematic Python versions. Continue
    # without breaking anything; the cleanup is deliberately best-effort.
    pass
|
Loading…
Reference in New Issue
Block a user