From 52d9dbfa8ed7bc8b91f1a1be706cf77314b1c687 Mon Sep 17 00:00:00 2001
From: Igor Ganichev <iga@google.com>
Date: Wed, 12 Sep 2018 13:32:04 -0700
Subject: [PATCH] Use WeakKeyDictionaries for global Keras {graph->...} maps

These globals were holding onto graphs including FuncGraphs, which
held onto captured tensors leaving garbage around.

This change also adds a test to catch garbage like this in the future.
To make the test work, I needed to manually breakup some reference
cycles caused by OrderedDicts. We should probably have a custom impl
of OrderedDict similar to the one in Python3 and avoid these issues.

PiperOrigin-RevId: 212694290
---
 tensorflow/python/eager/function_test.py | 47 +++++++++++++++++++++++-
 tensorflow/python/framework/ops.py       | 19 ++--------
 tensorflow/python/framework/test_util.py | 40 ++++++++++++++++++++
 tensorflow/python/keras/backend.py       | 23 ++++++++----
 tensorflow/python/util/memory.py         | 45 +++++++++++++++++++++++
 5 files changed, 151 insertions(+), 23 deletions(-)
 create mode 100644 tensorflow/python/util/memory.py

diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py
index e6a49b66cf7..d2b1d9c8a7b 100644
--- a/tensorflow/python/eager/function_test.py
+++ b/tensorflow/python/eager/function_test.py
@@ -25,6 +25,7 @@ import sys
 import numpy
 
 from tensorflow.core.protobuf import config_pb2
+from tensorflow.python import keras
 from tensorflow.python.data.ops import iterator_ops
 from tensorflow.python.eager import backprop
 from tensorflow.python.eager import context
@@ -38,6 +39,7 @@ from tensorflow.python.framework import random_seed
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.framework import tensor_spec
 from tensorflow.python.framework import test_util
+from tensorflow.python.keras.engine import training as keras_training
 from tensorflow.python.layers import convolutional
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import clip_ops
@@ -57,6 +59,21 @@ from tensorflow.python.util import compat
 from tensorflow.python.util import nest
 
 
+class MiniModel(keras_training.Model):
+  """Minimal model for mnist.
+
+  Useful for testing and debugging on slow TPU simulators.
+  """
+
+  def __init__(self):
+    super(MiniModel, self).__init__(name='')
+    self.fc = keras.layers.Dense(1, name='fc', kernel_initializer='ones',
+                                 bias_initializer='ones')
+
+  def call(self, inputs, training=True):
+    return self.fc(inputs)
+
+
 @test_util.with_c_shapes
 class FunctionTest(test.TestCase):
 
@@ -1005,6 +1022,7 @@ class FunctionTest(test.TestCase):
       with ops.get_default_graph().as_default():
         create_variable()
 
+  @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
   def testLayerInDefun(self):
     conv = convolutional.Conv2D(
         filters=1,
@@ -1018,7 +1036,34 @@ class FunctionTest(test.TestCase):
 
     x = array_ops.ones([1, 2, 2, 1])
     y = model(x)
-    self.assertAllEqual([[[[4.0]]]], y.numpy())
+
+    if not context.executing_eagerly():
+      self.evaluate(variables.global_variables_initializer())
+
+    self.assertAllEqual([[[[4.0]]]], self.evaluate(y))
+
+    # Remove reference cycles in model
+    test_util.dismantle_polymorphic_function(model)
+
+  @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
+  def testDefunKerasModelCall(self):
+    model = MiniModel()
+    model.call = function.defun(model.call)
+
+    x = array_ops.ones([1, 2])
+    y = model(x)
+
+    if not context.executing_eagerly():
+      self.evaluate(variables.global_variables_initializer())
+
+    self.assertAllEqual([[3.0]], self.evaluate(y))
+
+    # Remove reference cycles in defun.
+    test_util.dismantle_polymorphic_function(model.call)
+    # Break the reference cycle between the MiniModel and the defun:
+    # MiniModel --(through its `call` method)--> PolymorphicFunction
+    # PolymorphicFunction --(instancemethod on MiniModel)--> MiniModel
+    del model.call
 
   # Note: The ConfigProto below unfortunately only configures graph
   # construction. Eager's configuration is controlled in `__main__`.
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
index 75678cbc016..343f52fe8f3 100644
--- a/tensorflow/python/framework/ops.py
+++ b/tensorflow/python/framework/ops.py
@@ -58,6 +58,7 @@ from tensorflow.python.util import decorator_utils
 from tensorflow.python.util import deprecation
 from tensorflow.python.util import function_utils
 from tensorflow.python.util import lock_util
+from tensorflow.python.util import memory
 from tensorflow.python.util import tf_contextlib
 from tensorflow.python.util import tf_stack
 from tensorflow.python.util.deprecation import deprecated_args
@@ -5824,23 +5825,11 @@ def dismantle_graph(graph):
     graph: A `Graph` object to destroy. Neither it nor any of its ops are usable
       after this function runs.
   """
-  # pylint: disable=protected-access
-  # OrderedDict, constructed on Graph creation, makes a simple reference loop
-  # and hides it in an __attribute in some Python versions. We don't need to
-  # throw an error if we can't find it, but if we do find it we can break the
-  # loop to avoid creating work for the garbage collector.
-  graph_operations = graph.get_operations()
-  problematic_cycle = graph._functions.__dict__.get("_OrderedDict__root", None)
-  # pylint: enable=protected-access
-  if problematic_cycle:
-    try:
-      del problematic_cycle[0][:]
-    except TypeError:
-      # This is probably not one of the problematic Python versions. Continue
-      # with the rest of our cleanup.
-      pass
+  memory.dismantle_ordered_dict(graph._functions)  # pylint: disable=protected-access
+
   # Now clean up Operation<->Graph reference cycles by clearing all of the
   # attributes for the Graph and its ops.
+  graph_operations = graph.get_operations()
   for op in graph_operations:
     op.__dict__ = {}
   graph.__dict__ = {}
diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py
index 6a2c897f3f4..1cc3bb4628b 100644
--- a/tensorflow/python/framework/test_util.py
+++ b/tensorflow/python/framework/test_util.py
@@ -69,6 +69,7 @@ from tensorflow.python.platform import googletest
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.training import server_lib
 from tensorflow.python.util import compat
+from tensorflow.python.util import memory
 from tensorflow.python.util import nest
 from tensorflow.python.util import tf_inspect
 from tensorflow.python.util.protobuf import compare
@@ -2008,3 +2009,42 @@ def set_producer_version(graph, producer_version):
   with graph.as_default():
     importer.import_graph_def(graph_def)
   assert graph.graph_def_versions.producer, producer_version
+
+
+def dismantle_func_graph(func_graph):
+  """Removes reference cycles in `func_graph` FuncGraph.
+
+  Helpful for making sure the garbage collector doesn't need to run when
+  the FuncGraph goes out of scope, e.g. in tests using defun with
+  @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True).
+
+  Args:
+    func_graph: A `FuncGraph` object to destroy. `func_graph` is unusable
+      after this function.
+  """
+  # TODO(b/115366440): Delete this method when a custom OrderedDict is added.
+  # Clearing captures using clear() leaves some cycles around.
+  while func_graph.captures:
+    func_graph.captures.popitem()
+  memory.dismantle_ordered_dict(func_graph.captures)
+  ops.dismantle_graph(func_graph)
+
+
+def dismantle_polymorphic_function(func):
+  """Removes reference cycles in PolymorphicFunction `func`.
+
+  Helpful for making sure the garbage collector doesn't need to run when
+  PolymorphicFunction goes out of scope, e.g. in tests using defun with
+  @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True).
+
+  Args:
+    func: A `PolymorphicFunction` object to destroy. `func` is unusable
+      after this function.
+  """
+  # TODO(b/115366440): Delete this method when a custom OrderedDict is added
+  cache = func._function_cache  # pylint: disable=protected-access
+  for concrete_func in cache.values():
+    dismantle_func_graph(concrete_func.graph)
+  while cache:
+    cache.popitem()
+  memory.dismantle_ordered_dict(cache)
diff --git a/tensorflow/python/keras/backend.py b/tensorflow/python/keras/backend.py
index 7768caeaf05..529b07dc12a 100644
--- a/tensorflow/python/keras/backend.py
+++ b/tensorflow/python/keras/backend.py
@@ -73,7 +73,16 @@ _SESSION = None
 # This dictionary holds a mapping {graph: learning_phase}.
 # A learning phase is a bool tensor used to run Keras models in
 # either train mode (learning_phase == 1) or test mode (learning_phase == 0).
-_GRAPH_LEARNING_PHASES = {}
+_GRAPH_LEARNING_PHASES = weakref.WeakKeyDictionary()
+
+
+# _DUMMY_EAGER_GRAPH is used as a key in _GRAPH_LEARNING_PHASES.
+# We keep a separate reference to it to make sure it does not get removed from
+# _GRAPH_LEARNING_PHASES. We use a dummy class instead of something like a
+# string because strings are not weakly-referencable.
+class _DummyEagerGraph(object):
+  pass
+_DUMMY_EAGER_GRAPH = _DummyEagerGraph()
 
 # This boolean flag can be set to True to leave variable initialization
 # up to the user.
@@ -96,11 +105,11 @@ _LOCAL_DEVICES = None
 
 # This dictionary holds a mapping between a graph and variables to initialize
 # in the graph.
-_GRAPH_VARIABLES = {}
+_GRAPH_VARIABLES = weakref.WeakKeyDictionary()
 
 # This dictionary holds a mapping between a graph and TF optimizers created in
 # the graph.
-_GRAPH_TF_OPTIMIZERS = {}
+_GRAPH_TF_OPTIMIZERS = weakref.WeakKeyDictionary()
 
 
 @tf_export('keras.backend.backend')
@@ -359,10 +368,10 @@ def learning_phase():
       Learning phase (scalar integer tensor or Python integer).
   """
   if context.executing_eagerly():
-    if 'eager' not in _GRAPH_LEARNING_PHASES:
+    if _DUMMY_EAGER_GRAPH not in _GRAPH_LEARNING_PHASES:
       # Fallback to inference mode as default.
       return 0
-    return _GRAPH_LEARNING_PHASES['eager']
+    return _GRAPH_LEARNING_PHASES[_DUMMY_EAGER_GRAPH]
 
   graph = ops.get_default_graph()
   if graph not in _GRAPH_LEARNING_PHASES:
@@ -386,7 +395,7 @@ def set_learning_phase(value):
   if value not in {0, 1}:
     raise ValueError('Expected learning phase to be 0 or 1.')
   if context.executing_eagerly():
-    _GRAPH_LEARNING_PHASES['eager'] = value
+    _GRAPH_LEARNING_PHASES[_DUMMY_EAGER_GRAPH] = value
   else:
     _GRAPH_LEARNING_PHASES[ops.get_default_graph()] = value
 
@@ -415,7 +424,7 @@ def learning_phase_scope(value):
   finally:
     # Restore learning phase to initial value.
     if context.executing_eagerly():
-      _GRAPH_LEARNING_PHASES['eager'] = previous_value
+      _GRAPH_LEARNING_PHASES[_DUMMY_EAGER_GRAPH] = previous_value
     else:
       _GRAPH_LEARNING_PHASES[ops.get_default_graph()] = previous_value
 
diff --git a/tensorflow/python/util/memory.py b/tensorflow/python/util/memory.py
new file mode 100644
index 00000000000..e78f6d509a4
--- /dev/null
+++ b/tensorflow/python/util/memory.py
@@ -0,0 +1,45 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Functions related to Python memory management."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+# TODO(b/115366440): Delete this function when a custom OrderedDict is added
+def dismantle_ordered_dict(ordered_dict):
+  """Remove reference cycle in OrderedDict `ordered_dict`.
+
+  Helpful for making sure the garbage collector doesn't need to run after
+  using an OrderedDict.
+
+  Args:
+    ordered_dict: A `OrderedDict` object to destroy. This object is unusable
+      after this function runs.
+  """
+  # OrderedDict, makes a simple reference loop
+  # and hides it in an __attribute in some Python versions. We don't need to
+  # throw an error if we can't find it, but if we do find it we can break the
+  # loop to avoid creating work for the garbage collector.
+  problematic_cycle = ordered_dict.__dict__.get("_OrderedDict__root", None)  # pylint: disable=protected-access
+  if problematic_cycle:
+    try:
+      del problematic_cycle[0][:]
+    except TypeError:
+      # This is probably not one of the problematic Python versions. Continue
+      # with the rest of our cleanup.
+      pass