From 52d9dbfa8ed7bc8b91f1a1be706cf77314b1c687 Mon Sep 17 00:00:00 2001
From: Igor Ganichev <iga@google.com>
Date: Wed, 12 Sep 2018 13:32:04 -0700
Subject: [PATCH] Use WeakKeyDictionaries for global Keras {graph->...} maps

These globals were holding onto graphs, including FuncGraphs, which
held onto captured tensors, leaving garbage around.

This change also adds a test to catch garbage like this in the future.

To make the test work, I needed to manually break up some reference
cycles caused by OrderedDicts. We should probably have a custom
implementation of OrderedDict, similar to the one in Python 3, to
avoid these issues.

PiperOrigin-RevId: 212694290
---
 tensorflow/python/eager/function_test.py | 47 +++++++++++++++++++++++-
 tensorflow/python/framework/ops.py       | 19 ++--------
 tensorflow/python/framework/test_util.py | 40 ++++++++++++++++++++
 tensorflow/python/keras/backend.py       | 23 ++++++++----
 tensorflow/python/util/memory.py         | 45 +++++++++++++++++++++++
 5 files changed, 151 insertions(+), 23 deletions(-)
 create mode 100644 tensorflow/python/util/memory.py

diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py
index e6a49b66cf7..d2b1d9c8a7b 100644
--- a/tensorflow/python/eager/function_test.py
+++ b/tensorflow/python/eager/function_test.py
@@ -25,6 +25,7 @@ import sys
 import numpy
 
 from tensorflow.core.protobuf import config_pb2
+from tensorflow.python import keras
 from tensorflow.python.data.ops import iterator_ops
 from tensorflow.python.eager import backprop
 from tensorflow.python.eager import context
@@ -38,6 +39,7 @@ from tensorflow.python.framework import random_seed
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.framework import tensor_spec
 from tensorflow.python.framework import test_util
+from tensorflow.python.keras.engine import training as keras_training
 from tensorflow.python.layers import convolutional
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import clip_ops
@@ -57,6 +59,21 @@ from tensorflow.python.util import compat
 from tensorflow.python.util import nest
 
 
+class MiniModel(keras_training.Model):
+  """Minimal model for mnist.
+
+  Useful for testing and debugging on slow TPU simulators.
+  """
+
+  def __init__(self):
+    super(MiniModel, self).__init__(name='')
+    self.fc = keras.layers.Dense(1, name='fc', kernel_initializer='ones',
+                                 bias_initializer='ones')
+
+  def call(self, inputs, training=True):
+    return self.fc(inputs)
+
+
 @test_util.with_c_shapes
 class FunctionTest(test.TestCase):
 
@@ -1005,6 +1022,7 @@ class FunctionTest(test.TestCase):
       with ops.get_default_graph().as_default():
         create_variable()
 
+  @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
   def testLayerInDefun(self):
     conv = convolutional.Conv2D(
         filters=1,
@@ -1018,7 +1036,34 @@ class FunctionTest(test.TestCase):
 
     x = array_ops.ones([1, 2, 2, 1])
     y = model(x)
-    self.assertAllEqual([[[[4.0]]]], y.numpy())
+
+    if not context.executing_eagerly():
+      self.evaluate(variables.global_variables_initializer())
+
+    self.assertAllEqual([[[[4.0]]]], self.evaluate(y))
+
+    # Remove reference cycles in model
+    test_util.dismantle_polymorphic_function(model)
+
+  @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
+  def testDefunKerasModelCall(self):
+    model = MiniModel()
+    model.call = function.defun(model.call)
+
+    x = array_ops.ones([1, 2])
+    y = model(x)
+
+    if not context.executing_eagerly():
+      self.evaluate(variables.global_variables_initializer())
+
+    self.assertAllEqual([[3.0]], self.evaluate(y))
+
+    # Remove reference cycles in defun.
+    test_util.dismantle_polymorphic_function(model.call)
+    # Break the reference cycle between the MiniModel and the defun:
+    # MiniModel --(through its `call` method)--> PolymorphicFunction
+    # PolymorphicFunction --(instancemethod on MiniModel)--> MiniModel
+    del model.call
 
   # Note: The ConfigProto below unfortunately only configures graph
   # construction. Eager's configuration is controlled in `__main__`.
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
index 75678cbc016..343f52fe8f3 100644
--- a/tensorflow/python/framework/ops.py
+++ b/tensorflow/python/framework/ops.py
@@ -58,6 +58,7 @@ from tensorflow.python.util import decorator_utils
 from tensorflow.python.util import deprecation
 from tensorflow.python.util import function_utils
 from tensorflow.python.util import lock_util
+from tensorflow.python.util import memory
 from tensorflow.python.util import tf_contextlib
 from tensorflow.python.util import tf_stack
 from tensorflow.python.util.deprecation import deprecated_args
@@ -5824,23 +5825,11 @@ def dismantle_graph(graph):
     graph: A `Graph` object to destroy. Neither it nor any of its ops are
       usable after this function runs.
   """
-  # pylint: disable=protected-access
-  # OrderedDict, constructed on Graph creation, makes a simple reference loop
-  # and hides it in an __attribute in some Python versions. We don't need to
-  # throw an error if we can't find it, but if we do find it we can break the
-  # loop to avoid creating work for the garbage collector.
-  graph_operations = graph.get_operations()
-  problematic_cycle = graph._functions.__dict__.get("_OrderedDict__root", None)
-  # pylint: enable=protected-access
-  if problematic_cycle:
-    try:
-      del problematic_cycle[0][:]
-    except TypeError:
-      # This is probably not one of the problematic Python versions. Continue
-      # with the rest of our cleanup.
-      pass
+  memory.dismantle_ordered_dict(graph._functions)  # pylint: disable=protected-access
+
   # Now clean up Operation<->Graph reference cycles by clearing all of the
   # attributes for the Graph and its ops.
+  graph_operations = graph.get_operations()
   for op in graph_operations:
     op.__dict__ = {}
   graph.__dict__ = {}
diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py
index 6a2c897f3f4..1cc3bb4628b 100644
--- a/tensorflow/python/framework/test_util.py
+++ b/tensorflow/python/framework/test_util.py
@@ -69,6 +69,7 @@ from tensorflow.python.platform import googletest
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.training import server_lib
 from tensorflow.python.util import compat
+from tensorflow.python.util import memory
 from tensorflow.python.util import nest
 from tensorflow.python.util import tf_inspect
 from tensorflow.python.util.protobuf import compare
@@ -2008,3 +2009,42 @@ def set_producer_version(graph, producer_version):
   with graph.as_default():
     importer.import_graph_def(graph_def)
   assert graph.graph_def_versions.producer, producer_version
+
+
+def dismantle_func_graph(func_graph):
+  """Removes reference cycles in `func_graph` FuncGraph.
+
+  Helpful for making sure the garbage collector doesn't need to run when
+  the FuncGraph goes out of scope, e.g. in tests using defun with
+  @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True).
+
+  Args:
+    func_graph: A `FuncGraph` object to destroy. `func_graph` is unusable
+      after this function.
+  """
+  # TODO(b/115366440): Delete this method when a custom OrderedDict is added.
+  # Clearing captures using clear() leaves some cycles around.
+  while func_graph.captures:
+    func_graph.captures.popitem()
+  memory.dismantle_ordered_dict(func_graph.captures)
+  ops.dismantle_graph(func_graph)
+
+
+def dismantle_polymorphic_function(func):
+  """Removes reference cycles in PolymorphicFunction `func`.
+
+  Helpful for making sure the garbage collector doesn't need to run when
+  PolymorphicFunction goes out of scope, e.g. in tests using defun with
+  @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True).
+
+  Args:
+    func: A `PolymorphicFunction` object to destroy. `func` is unusable
+      after this function.
+  """
+  # TODO(b/115366440): Delete this method when a custom OrderedDict is added
+  cache = func._function_cache  # pylint: disable=protected-access
+  for concrete_func in cache.values():
+    dismantle_func_graph(concrete_func.graph)
+  while cache:
+    cache.popitem()
+  memory.dismantle_ordered_dict(cache)
diff --git a/tensorflow/python/keras/backend.py b/tensorflow/python/keras/backend.py
index 7768caeaf05..529b07dc12a 100644
--- a/tensorflow/python/keras/backend.py
+++ b/tensorflow/python/keras/backend.py
@@ -73,7 +73,16 @@ _SESSION = None
 # This dictionary holds a mapping {graph: learning_phase}.
 # A learning phase is a bool tensor used to run Keras models in
 # either train mode (learning_phase == 1) or test mode (learning_phase == 0).
-_GRAPH_LEARNING_PHASES = {}
+_GRAPH_LEARNING_PHASES = weakref.WeakKeyDictionary()
+
+
+# _DUMMY_EAGER_GRAPH is used as a key in _GRAPH_LEARNING_PHASES.
+# We keep a separate reference to it to make sure it does not get removed from
+# _GRAPH_LEARNING_PHASES. We use a dummy class instead of something like a
+# string because strings are not weakly-referencable.
+class _DummyEagerGraph(object):
+  pass
+_DUMMY_EAGER_GRAPH = _DummyEagerGraph()
 
 # This boolean flag can be set to True to leave variable initialization
 # up to the user.
@@ -96,11 +105,11 @@ _LOCAL_DEVICES = None
 
 # This dictionary holds a mapping between a graph and variables to initialize
 # in the graph.
-_GRAPH_VARIABLES = {}
+_GRAPH_VARIABLES = weakref.WeakKeyDictionary()
 
 # This dictionary holds a mapping between a graph and TF optimizers created in
 # the graph.
-_GRAPH_TF_OPTIMIZERS = {}
+_GRAPH_TF_OPTIMIZERS = weakref.WeakKeyDictionary()
 
 
 @tf_export('keras.backend.backend')
@@ -359,10 +368,10 @@ def learning_phase():
     Learning phase (scalar integer tensor or Python integer).
   """
   if context.executing_eagerly():
-    if 'eager' not in _GRAPH_LEARNING_PHASES:
+    if _DUMMY_EAGER_GRAPH not in _GRAPH_LEARNING_PHASES:
       # Fallback to inference mode as default.
       return 0
-    return _GRAPH_LEARNING_PHASES['eager']
+    return _GRAPH_LEARNING_PHASES[_DUMMY_EAGER_GRAPH]
 
   graph = ops.get_default_graph()
   if graph not in _GRAPH_LEARNING_PHASES:
@@ -386,7 +395,7 @@ def set_learning_phase(value):
   if value not in {0, 1}:
     raise ValueError('Expected learning phase to be 0 or 1.')
   if context.executing_eagerly():
-    _GRAPH_LEARNING_PHASES['eager'] = value
+    _GRAPH_LEARNING_PHASES[_DUMMY_EAGER_GRAPH] = value
   else:
     _GRAPH_LEARNING_PHASES[ops.get_default_graph()] = value
 
@@ -415,7 +424,7 @@ def learning_phase_scope(value):
   finally:
     # Restore learning phase to initial value.
     if context.executing_eagerly():
-      _GRAPH_LEARNING_PHASES['eager'] = previous_value
+      _GRAPH_LEARNING_PHASES[_DUMMY_EAGER_GRAPH] = previous_value
     else:
       _GRAPH_LEARNING_PHASES[ops.get_default_graph()] = previous_value
 
diff --git a/tensorflow/python/util/memory.py b/tensorflow/python/util/memory.py
new file mode 100644
index 00000000000..e78f6d509a4
--- /dev/null
+++ b/tensorflow/python/util/memory.py
@@ -0,0 +1,45 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Functions related to Python memory management."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+# TODO(b/115366440): Delete this function when a custom OrderedDict is added
+def dismantle_ordered_dict(ordered_dict):
+  """Remove reference cycle in OrderedDict `ordered_dict`.
+
+  Helpful for making sure the garbage collector doesn't need to run after
+  using an OrderedDict.
+
+  Args:
+    ordered_dict: An `OrderedDict` object to destroy. This object is unusable
+      after this function runs.
+  """
+  # OrderedDict makes a simple reference loop
+  # and hides it in an __attribute in some Python versions. We don't need to
+  # throw an error if we can't find it, but if we do find it we can break the
+  # loop to avoid creating work for the garbage collector.
+  problematic_cycle = ordered_dict.__dict__.get("_OrderedDict__root", None)  # pylint: disable=protected-access
+  if problematic_cycle:
+    try:
+      del problematic_cycle[0][:]
+    except TypeError:
+      # This is probably not one of the problematic Python versions. Continue
+      # with the rest of our cleanup.
+      pass
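
The idea behind the backend.py changes, as a minimal standalone sketch (not part of the patch; FakeGraph and _DummyKey are hypothetical stand-ins, not TensorFlow APIs): a weakref.WeakKeyDictionary drops an entry as soon as its key object loses its last strong reference, so a per-graph map can no longer keep dead graphs, and the tensors they capture, alive. A plain object serves as the eager-mode key because str instances such as 'eager' cannot be weakly referenced.

    import weakref


    class FakeGraph(object):
      """Stands in for a tf.Graph/FuncGraph; only needs to be weakly referenceable."""


    class _DummyKey(object):
      """Plays the role of _DummyEagerGraph: an object key, since str cannot be weakly referenced."""


    GRAPH_LEARNING_PHASES = weakref.WeakKeyDictionary()
    DUMMY_EAGER_KEY = _DummyKey()  # a module-level reference keeps this key's entry alive

    graph = FakeGraph()
    GRAPH_LEARNING_PHASES[graph] = 1      # "train mode" for this graph
    GRAPH_LEARNING_PHASES[DUMMY_EAGER_KEY] = 0

    print(len(GRAPH_LEARNING_PHASES))  # 2
    del graph  # the last strong reference to the graph goes away
    print(len(GRAPH_LEARNING_PHASES))  # 1: the entry vanished with the graph (immediate under CPython refcounting)

With a plain dict, the second print would still report 2, and the graph would live as long as the module-level map does, which is exactly the leak the commit message describes.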
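
A second sketch, again with hypothetical classes and no TensorFlow, of the cycle that testDefunKerasModelCall has to break with `del model.call`: storing a wrapper around a bound method back on the instance creates instance -> wrapper -> bound method -> instance, which reference counting alone cannot reclaim, so a test running under assert_no_eager_garbage=True must dismantle it by hand.

    import gc
    import weakref

    gc.disable()  # keep the demonstration deterministic; only explicit collections run


    class Wrapper(object):
      """Stands in for a defun/PolymorphicFunction wrapping a bound method."""

      def __init__(self, fn):
        self._fn = fn  # the bound method points back at its instance


    class Model(object):

      def call(self, x):
        return x


    m = Model()
    m.call = Wrapper(m.call)  # instance attribute now shadows the class method
    probe = weakref.ref(m)

    del m                       # the cycle keeps the instance alive
    print(probe() is not None)  # True: unreachable by name, held by the cycle
    print(gc.collect() > 0)     # True: the cycle collector finds and frees it
    print(probe() is None)      # True: gone only after an explicit gc pass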