Use WeakKeyDictionaries for global Keras {graph->...} maps

These globals were holding onto graphs, including FuncGraphs, which in turn
held onto captured tensors, leaving garbage around.

This change also adds a test to catch garbage like this in the future.
To make the test work, I needed to manually break up some reference
cycles caused by OrderedDicts. We should probably have a custom implementation
of OrderedDict, similar to the one in Python 3, to avoid these issues.

PiperOrigin-RevId: 212694290
This commit is contained in:
Igor Ganichev 2018-09-12 13:32:04 -07:00 committed by TensorFlower Gardener
parent 5d1de24583
commit 52d9dbfa8e
5 changed files with 151 additions and 23 deletions

View File

@ -25,6 +25,7 @@ import sys
import numpy import numpy
from tensorflow.core.protobuf import config_pb2 from tensorflow.core.protobuf import config_pb2
from tensorflow.python import keras
from tensorflow.python.data.ops import iterator_ops from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.eager import backprop from tensorflow.python.eager import backprop
from tensorflow.python.eager import context from tensorflow.python.eager import context
@ -38,6 +39,7 @@ from tensorflow.python.framework import random_seed
from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import test_util from tensorflow.python.framework import test_util
from tensorflow.python.keras.engine import training as keras_training
from tensorflow.python.layers import convolutional from tensorflow.python.layers import convolutional
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
from tensorflow.python.ops import clip_ops from tensorflow.python.ops import clip_ops
@ -57,6 +59,21 @@ from tensorflow.python.util import compat
from tensorflow.python.util import nest from tensorflow.python.util import nest
class MiniModel(keras_training.Model):
  """Keras model wrapping a single one-unit `Dense` layer.

  Both the kernel and the bias of the layer are initialized to ones, so the
  output is fully determined by the input. Described by its author as a minimal
  model for mnist, useful for testing and debugging on slow TPU simulators.
  """

  def __init__(self):
    """Creates the model with an empty name and its only layer, `fc`."""
    super(MiniModel, self).__init__(name='')
    self.fc = keras.layers.Dense(
        1,
        name='fc',
        kernel_initializer='ones',
        bias_initializer='ones')

  def call(self, inputs, training=True):
    """Applies the dense layer to `inputs`.

    Args:
      inputs: Input passed straight to the `fc` layer.
      training: Accepted for signature compatibility; not read by this method.

    Returns:
      The output of `self.fc(inputs)`.
    """
    dense_output = self.fc(inputs)
    return dense_output
@test_util.with_c_shapes @test_util.with_c_shapes
class FunctionTest(test.TestCase): class FunctionTest(test.TestCase):
@ -1005,6 +1022,7 @@ class FunctionTest(test.TestCase):
with ops.get_default_graph().as_default(): with ops.get_default_graph().as_default():
create_variable() create_variable()
@test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
def testLayerInDefun(self): def testLayerInDefun(self):
conv = convolutional.Conv2D( conv = convolutional.Conv2D(
filters=1, filters=1,
@ -1018,7 +1036,34 @@ class FunctionTest(test.TestCase):
x = array_ops.ones([1, 2, 2, 1]) x = array_ops.ones([1, 2, 2, 1])
y = model(x) y = model(x)
self.assertAllEqual([[[[4.0]]]], y.numpy())
if not context.executing_eagerly():
self.evaluate(variables.global_variables_initializer())
self.assertAllEqual([[[[4.0]]]], self.evaluate(y))
# Remove reference cycles in model
test_util.dismantle_polymorphic_function(model)
@test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
def testDefunKerasModelCall(self):
  # Wrap the model's `call` in a defun and check that invoking the model still
  # produces the expected value without leaving garbage behind.
  mini_model = MiniModel()
  mini_model.call = function.defun(mini_model.call)

  ones_input = array_ops.ones([1, 2])
  output = mini_model(ones_input)

  # In graph mode the variables must be initialized before evaluation.
  if not context.executing_eagerly():
    self.evaluate(variables.global_variables_initializer())

  # Two inputs of 1.0 times a kernel of ones, plus a bias of one.
  self.assertAllEqual([[3.0]], self.evaluate(output))

  # Remove reference cycles in defun.
  test_util.dismantle_polymorphic_function(mini_model.call)

  # Break the reference cycle between the MiniModel and the defun:
  # MiniModel --(through its `call` method)--> PolymorphicFunction
  # PolymorphicFunction --(instancemethod on MiniModel)--> MiniModel
  del mini_model.call
# Note: The ConfigProto below unfortunately only configures graph # Note: The ConfigProto below unfortunately only configures graph
# construction. Eager's configuration is controlled in `__main__`. # construction. Eager's configuration is controlled in `__main__`.

View File

@ -58,6 +58,7 @@ from tensorflow.python.util import decorator_utils
from tensorflow.python.util import deprecation from tensorflow.python.util import deprecation
from tensorflow.python.util import function_utils from tensorflow.python.util import function_utils
from tensorflow.python.util import lock_util from tensorflow.python.util import lock_util
from tensorflow.python.util import memory
from tensorflow.python.util import tf_contextlib from tensorflow.python.util import tf_contextlib
from tensorflow.python.util import tf_stack from tensorflow.python.util import tf_stack
from tensorflow.python.util.deprecation import deprecated_args from tensorflow.python.util.deprecation import deprecated_args
@ -5824,23 +5825,11 @@ def dismantle_graph(graph):
graph: A `Graph` object to destroy. Neither it nor any of its ops are usable graph: A `Graph` object to destroy. Neither it nor any of its ops are usable
after this function runs. after this function runs.
""" """
# pylint: disable=protected-access memory.dismantle_ordered_dict(graph._functions) # pylint: disable=protected-access
# OrderedDict, constructed on Graph creation, makes a simple reference loop
# and hides it in an __attribute in some Python versions. We don't need to
# throw an error if we can't find it, but if we do find it we can break the
# loop to avoid creating work for the garbage collector.
graph_operations = graph.get_operations()
problematic_cycle = graph._functions.__dict__.get("_OrderedDict__root", None)
# pylint: enable=protected-access
if problematic_cycle:
try:
del problematic_cycle[0][:]
except TypeError:
# This is probably not one of the problematic Python versions. Continue
# with the rest of our cleanup.
pass
# Now clean up Operation<->Graph reference cycles by clearing all of the # Now clean up Operation<->Graph reference cycles by clearing all of the
# attributes for the Graph and its ops. # attributes for the Graph and its ops.
graph_operations = graph.get_operations()
for op in graph_operations: for op in graph_operations:
op.__dict__ = {} op.__dict__ = {}
graph.__dict__ = {} graph.__dict__ = {}

View File

@ -69,6 +69,7 @@ from tensorflow.python.platform import googletest
from tensorflow.python.platform import tf_logging as logging from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import server_lib from tensorflow.python.training import server_lib
from tensorflow.python.util import compat from tensorflow.python.util import compat
from tensorflow.python.util import memory
from tensorflow.python.util import nest from tensorflow.python.util import nest
from tensorflow.python.util import tf_inspect from tensorflow.python.util import tf_inspect
from tensorflow.python.util.protobuf import compare from tensorflow.python.util.protobuf import compare
@ -2008,3 +2009,42 @@ def set_producer_version(graph, producer_version):
with graph.as_default(): with graph.as_default():
importer.import_graph_def(graph_def) importer.import_graph_def(graph_def)
assert graph.graph_def_versions.producer, producer_version assert graph.graph_def_versions.producer, producer_version
def dismantle_func_graph(func_graph):
  """Breaks the reference cycles held by the FuncGraph `func_graph`.

  Intended for situations where the garbage collector must not be needed once
  the FuncGraph goes out of scope, e.g. in tests using defun with
  @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True).

  Args:
    func_graph: A `FuncGraph` object to destroy. `func_graph` is unusable
      after this function.
  """
  # TODO(b/115366440): Delete this method when a custom OrderedDict is added.
  captures = func_graph.captures
  # Entries are popped one at a time because clearing captures using clear()
  # leaves some cycles around.
  while captures:
    captures.popitem()
  memory.dismantle_ordered_dict(captures)
  # With the captures emptied, tear down the graph itself.
  ops.dismantle_graph(func_graph)
def dismantle_polymorphic_function(func):
  """Breaks the reference cycles held by the PolymorphicFunction `func`.

  Intended for situations where the garbage collector must not be needed once
  the PolymorphicFunction goes out of scope, e.g. in tests using defun with
  @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True).

  Args:
    func: A `PolymorphicFunction` object to destroy. `func` is unusable
      after this function.
  """
  # TODO(b/115366440): Delete this method when a custom OrderedDict is added
  function_cache = func._function_cache  # pylint: disable=protected-access

  # Dismantle the graph of every cached concrete function first.
  for cached_function in function_cache.values():
    dismantle_func_graph(cached_function.graph)

  # Then empty the cache entry by entry and break the cycle inside the cache
  # container itself.
  while function_cache:
    function_cache.popitem()
  memory.dismantle_ordered_dict(function_cache)

View File

@ -73,7 +73,16 @@ _SESSION = None
# This dictionary holds a mapping {graph: learning_phase}. # This dictionary holds a mapping {graph: learning_phase}.
# A learning phase is a bool tensor used to run Keras models in # A learning phase is a bool tensor used to run Keras models in
# either train mode (learning_phase == 1) or test mode (learning_phase == 0). # either train mode (learning_phase == 1) or test mode (learning_phase == 0).
_GRAPH_LEARNING_PHASES = {} _GRAPH_LEARNING_PHASES = weakref.WeakKeyDictionary()
class _DummyEagerGraph(object):
  """Placeholder type whose single instance keys `_GRAPH_LEARNING_PHASES`.

  A dedicated class is used instead of something like a string because strings
  are not weakly referenceable.
  """


# _DUMMY_EAGER_GRAPH is used as a key in _GRAPH_LEARNING_PHASES.
# This module-level name is a separate reference to the key, kept to make sure
# the entry does not get removed from _GRAPH_LEARNING_PHASES.
_DUMMY_EAGER_GRAPH = _DummyEagerGraph()
# This boolean flag can be set to True to leave variable initialization # This boolean flag can be set to True to leave variable initialization
# up to the user. # up to the user.
@ -96,11 +105,11 @@ _LOCAL_DEVICES = None
# This dictionary holds a mapping between a graph and variables to initialize # This dictionary holds a mapping between a graph and variables to initialize
# in the graph. # in the graph.
_GRAPH_VARIABLES = {} _GRAPH_VARIABLES = weakref.WeakKeyDictionary()
# This dictionary holds a mapping between a graph and TF optimizers created in # This dictionary holds a mapping between a graph and TF optimizers created in
# the graph. # the graph.
_GRAPH_TF_OPTIMIZERS = {} _GRAPH_TF_OPTIMIZERS = weakref.WeakKeyDictionary()
@tf_export('keras.backend.backend') @tf_export('keras.backend.backend')
@ -359,10 +368,10 @@ def learning_phase():
Learning phase (scalar integer tensor or Python integer). Learning phase (scalar integer tensor or Python integer).
""" """
if context.executing_eagerly(): if context.executing_eagerly():
if 'eager' not in _GRAPH_LEARNING_PHASES: if _DUMMY_EAGER_GRAPH not in _GRAPH_LEARNING_PHASES:
# Fallback to inference mode as default. # Fallback to inference mode as default.
return 0 return 0
return _GRAPH_LEARNING_PHASES['eager'] return _GRAPH_LEARNING_PHASES[_DUMMY_EAGER_GRAPH]
graph = ops.get_default_graph() graph = ops.get_default_graph()
if graph not in _GRAPH_LEARNING_PHASES: if graph not in _GRAPH_LEARNING_PHASES:
@ -386,7 +395,7 @@ def set_learning_phase(value):
if value not in {0, 1}: if value not in {0, 1}:
raise ValueError('Expected learning phase to be 0 or 1.') raise ValueError('Expected learning phase to be 0 or 1.')
if context.executing_eagerly(): if context.executing_eagerly():
_GRAPH_LEARNING_PHASES['eager'] = value _GRAPH_LEARNING_PHASES[_DUMMY_EAGER_GRAPH] = value
else: else:
_GRAPH_LEARNING_PHASES[ops.get_default_graph()] = value _GRAPH_LEARNING_PHASES[ops.get_default_graph()] = value
@ -415,7 +424,7 @@ def learning_phase_scope(value):
finally: finally:
# Restore learning phase to initial value. # Restore learning phase to initial value.
if context.executing_eagerly(): if context.executing_eagerly():
_GRAPH_LEARNING_PHASES['eager'] = previous_value _GRAPH_LEARNING_PHASES[_DUMMY_EAGER_GRAPH] = previous_value
else: else:
_GRAPH_LEARNING_PHASES[ops.get_default_graph()] = previous_value _GRAPH_LEARNING_PHASES[ops.get_default_graph()] = previous_value

View File

@ -0,0 +1,45 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functions related to Python memory management."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# TODO(b/115366440): Delete this function when a custom OrderedDict is added
def dismantle_ordered_dict(ordered_dict):
  """Remove reference cycle in OrderedDict `ordered_dict`.

  Helpful for making sure the garbage collector doesn't need to run after
  using an OrderedDict.

  This is a best-effort cleanup: if the object carries no instance `__dict__`
  (e.g. a plain `dict`), has no `_OrderedDict__root` entry, or that entry does
  not have the expected list-like layout, the function silently does nothing.

  Args:
    ordered_dict: A `OrderedDict` object to destroy. This object is unusable
      after this function runs.
  """
  # OrderedDict makes a simple reference loop and hides it in an __attribute
  # in some Python versions. We don't need to throw an error if we can't find
  # it, but if we do find it we can break the loop to avoid creating work for
  # the garbage collector.
  #
  # Not every mapping exposes an instance `__dict__`, so look it up tolerantly
  # instead of letting an AttributeError escape from a best-effort helper.
  instance_attrs = getattr(ordered_dict, "__dict__", None)
  if not instance_attrs:
    return

  problematic_cycle = instance_attrs.get("_OrderedDict__root", None)
  if not problematic_cycle:
    return

  try:
    # Empty the list stored in slot 0 of the root entry, in place.
    del problematic_cycle[0][:]
  except TypeError:
    # This is probably not one of the problematic Python versions. Continue
    # with the rest of our cleanup.
    pass