diff --git a/tensorflow/python/keras/backend.py b/tensorflow/python/keras/backend.py
index 05f97256d76..f83ed74c2f8 100644
--- a/tensorflow/python/keras/backend.py
+++ b/tensorflow/python/keras/backend.py
@@ -98,15 +98,6 @@ _CURRENT_SCRATCH_GRAPH = None
 # used by Keras. It can be set manually via `set_session(sess)`.
 _SESSION = threading.local()
 
-# This dictionary holds a mapping {graph: learning_phase}.
-# A learning phase is a bool tensor used to run Keras models in
-# either train mode (learning_phase == 1) or test mode (learning_phase == 0).
-_GRAPH_LEARNING_PHASES = weakref.WeakKeyDictionary()
-
-# This dictionary holds a mapping {graph: set_of_freezable_variables}.
-# Each set tracks objects created via `freezable_variable` in the graph.
-_FREEZABLE_VARS = weakref.WeakKeyDictionary()
-
 # _DUMMY_EAGER_GRAPH.key is used as a key in _GRAPH_LEARNING_PHASES.
 # We keep a separate reference to it to make sure it does not get removed from
@@ -154,14 +145,6 @@ _MANUAL_VAR_INIT = False
 # We assume our devices don't change henceforth.
 _LOCAL_DEVICES = None
 
-# This dictionary holds a mapping between a graph and variables to initialize
-# in the graph.
-_GRAPH_VARIABLES = weakref.WeakKeyDictionary()
-
-# This dictionary holds a mapping between a graph and TF optimizers created in
-# the graph.
-_GRAPH_TF_OPTIMIZERS = weakref.WeakKeyDictionary()
-
 # The below functions are kept accessible from backend for compatibility.
 epsilon = backend_config.epsilon
 floatx = backend_config.floatx
@@ -293,11 +276,9 @@ def clear_session():
     _SESSION.session = None
     graph = get_graph()
     with graph.as_default():
-      with name_scope(''):
-        phase = array_ops.placeholder_with_default(
-            False, shape=(), name='keras_learning_phase')
-      _GRAPH_LEARNING_PHASES = {}
-      _GRAPH_LEARNING_PHASES[graph] = phase
+      _GRAPH_LEARNING_PHASES.clear()
+      # Create the learning phase placeholder in the graph using the default factory.
+      _GRAPH_LEARNING_PHASES.setdefault(graph)
       _GRAPH_VARIABLES.pop(graph, None)
       _GRAPH_TF_OPTIMIZERS.pop(graph, None)
       _FREEZABLE_VARS.pop(graph, None)
@@ -336,24 +317,18 @@ def learning_phase():
     # Don't enter an init_scope for the learning phase if eager execution
     # is enabled but we're inside the Keras workspace graph.
     learning_phase = symbolic_learning_phase()
-    _mark_func_graph_as_unsaveable(graph, learning_phase)
-    return learning_phase
-  with ops.init_scope():
-    # We always check & set the learning phase inside the init_scope,
-    # otherwise the wrong default_graph will be used to look up the learning
-    # phase inside of functions & defuns.
-    #
-    # This is because functions & defuns (both in graph & in eager mode)
-    # will always execute non-eagerly using a function-specific default
-    # subgraph.
-    if context.executing_eagerly():
-      if _DUMMY_EAGER_GRAPH.key not in _GRAPH_LEARNING_PHASES:
-        # Fallback to inference mode as default.
-        return 0
-      return _GRAPH_LEARNING_PHASES[_DUMMY_EAGER_GRAPH.key]
-    learning_phase = symbolic_learning_phase()
-    _mark_func_graph_as_unsaveable(graph, learning_phase)
-    return learning_phase
+  else:
+    with ops.init_scope():
+      # We always check & set the learning phase inside the init_scope,
+      # otherwise the wrong default_graph will be used to look up the learning
+      # phase inside of functions & defuns.
+      #
+      # This is because functions & defuns (both in graph & in eager mode)
+      # will always execute non-eagerly using a function-specific default
+      # subgraph.
+      learning_phase = _GRAPH_LEARNING_PHASES[None]
+  _mark_func_graph_as_unsaveable(graph, learning_phase)
+  return learning_phase
 
 
 def global_learning_phase_is_set():
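Reviewer note: the rewritten `learning_phase()` collapses the eager/graph branching into a single `_GRAPH_LEARNING_PHASES[None]` lookup. A minimal sketch of the `None`-key convention this relies on, assuming a simplified stand-in (`ContextCache` below is illustrative only, not the `ContextValueCache` this diff adds near the bottom of `backend.py`):

```
class ContextCache(dict):
  """Illustrative only: `None` resolves to the "current context" key."""

  def __init__(self, default_factory, current_context):
    dict.__init__(self)
    self.default_factory = default_factory
    self.current_context = current_context  # stands in for graph/eager lookup

  def __getitem__(self, key):
    key = self.current_context if key is None else key
    if key not in self:
      dict.__setitem__(self, key, self.default_factory())
    return dict.__getitem__(self, key)

phases = ContextCache(lambda: 0, current_context='eager')
assert phases[None] == 0  # first lookup builds the inference-mode default
```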
@@ -382,14 +357,18 @@ def _mark_func_graph_as_unsaveable(graph, learning_phase):
 def symbolic_learning_phase():
   graph = get_graph()
   with graph.as_default():
-    if graph not in _GRAPH_LEARNING_PHASES:
-      with name_scope(''):
-        phase = array_ops.placeholder_with_default(
-            False, shape=(), name='keras_learning_phase')
-      _GRAPH_LEARNING_PHASES[graph] = phase
     return _GRAPH_LEARNING_PHASES[graph]
 
 
+def _default_learning_phase():
+  if context.executing_eagerly():
+    return 0
+  else:
+    with name_scope(''):
+      return array_ops.placeholder_with_default(
+          False, shape=(), name='keras_learning_phase')
+
+
 @keras_export('keras.backend.set_learning_phase')
 def set_learning_phase(value):
   """Sets the learning phase to a fixed value.
@@ -876,8 +855,7 @@ def track_tf_optimizer(tf_optimizer):
   """Tracks the given TF optimizer for initialization of its variables."""
   if context.executing_eagerly():
     return
-  graph = get_graph()
-  optimizers = _GRAPH_TF_OPTIMIZERS.setdefault(graph, weakref.WeakSet())
+  optimizers = _GRAPH_TF_OPTIMIZERS[None]
   optimizers.add(tf_optimizer)
 
 
@@ -886,8 +864,6 @@ def track_variable(v):
   if context.executing_eagerly():
     return
   graph = v.graph if hasattr(v, 'graph') else get_graph()
-  if graph not in _GRAPH_VARIABLES:
-    _GRAPH_VARIABLES[graph] = object_identity.ObjectIdentityWeakSet()
   _GRAPH_VARIABLES[graph].add(v)
 
 
@@ -943,8 +919,8 @@ def unique_object_name(name,
 def _get_variables(graph=None):
   """Returns variables corresponding to the given graph for initialization."""
   assert not context.executing_eagerly()
-  variables = _GRAPH_VARIABLES.setdefault(graph, weakref.WeakSet())
-  for opt in _GRAPH_TF_OPTIMIZERS.get(graph, set()):
+  variables = _GRAPH_VARIABLES[graph]
+  for opt in _GRAPH_TF_OPTIMIZERS[graph]:
     variables.update(opt.optimizer.variables())
   return variables
 
@@ -1174,8 +1150,6 @@ def freezable_variable(value, shape=None, name=None):
     x.get_value = get_value
 
     global _FREEZABLE_VARS
-    if graph not in _FREEZABLE_VARS:
-      _FREEZABLE_VARS[graph] = object_identity.ObjectIdentityWeakSet()
     _FREEZABLE_VARS[graph].add(x)
   return x
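Reviewer note: each hunk above replaces the same hand-rolled "check membership, then insert a default" guard with a plain indexed lookup. A standalone sketch of that default-factory-on-miss behavior over weak keys (class names here are illustrative, not TF APIs):

```
import weakref

class Graph(object):  # hypothetical stand-in for a weak-referenceable tf.Graph
  pass

class WeakDefaultDict(weakref.WeakKeyDictionary):
  """Sketch: WeakKeyDictionary plus a default factory run on a miss."""

  def __init__(self, default_factory):
    weakref.WeakKeyDictionary.__init__(self)
    self.default_factory = default_factory

  def __getitem__(self, key):
    value = self.get(key)
    if value is None:
      value = self.default_factory()
      self[key] = value
    return value

variables = WeakDefaultDict(set)
g = Graph()
variables[g].add('v1')  # no `if g not in variables` guard at the call site
assert variables[g] == {'v1'}
```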
@@ -6094,3 +6068,132 @@ def maybe_convert_to_ragged(is_ragged_input, output, nested_row_lengths):
     return output
 
   return ragged_tensor.RaggedTensor.from_tensor(output, nested_row_lengths)
+
+
+class ContextValueCache(weakref.WeakKeyDictionary):
+  """Container that caches (possibly tensor) values based on the context.
+
+  This class is similar to defaultdict, where values may be produced by the
+  default factory specified during initialization. This class also has a default
+  value for the key (when key is `None`) -- the key is set to the current
+  graph or eager context. The default factories for key and value are only used
+  in `__getitem__` and `setdefault`. The `.get()` behavior remains the same.
+
+  This object will return the value of the current graph or closest parent graph
+  if the current graph is a function. This is to reflect the fact that if a
+  tensor is created in eager/graph, child functions may capture that tensor.
+
+  The default factory method may accept keyword arguments (unlike defaultdict,
+  which only accepts callables with 0 arguments). To pass keyword arguments to
+  `default_factory`, use the `setdefault` method instead of `__getitem__`.
+
+  An example of how this class can be used in different contexts:
+
+  ```
+  cache = ContextValueCache(list)
+
+  # Eager mode
+  cache[None].append(2)
+  cache[None].append(4)
+  assert cache[None] == [2, 4]
+
+  # Graph mode
+  with tf.Graph().as_default() as g:
+    cache[None].append(5)
+    cache[g].append(3)
+  assert cache[g] == [5, 3]
+  ```
+
+  Example of a default factory with arguments:
+
+  ```
+  cache = ContextValueCache(lambda x: x + 1)
+  g = tf.get_default_graph()
+
+  # Example with keyword argument.
+  value = cache.setdefault(key=g, kwargs={'x': 3})
+  assert cache[g] == 4
+  ```
+  """
+
+  def __init__(self, default_factory):
+    self.default_factory = default_factory
+    weakref.WeakKeyDictionary.__init__(self)
+
+  def _key(self):
+    if context.executing_eagerly():
+      return _DUMMY_EAGER_GRAPH.key
+    else:
+      return ops.get_default_graph()
+
+  def _get_parent_graph(self, graph):
+    """Returns the parent graph or dummy eager object."""
+    # TODO(b/149317164): Currently FuncGraphs use ops.get_default_graph() as the
+    # outer graph. This results in outer_graph always being a Graph,
+    # even in eager mode (get_default_graph will create a new Graph if there
+    # isn't a default graph). Because of this bug, we have to specially set the
+    # key when eager execution is enabled.
+    parent_graph = graph.outer_graph
+    if (not isinstance(parent_graph, func_graph.FuncGraph) and
+        ops.executing_eagerly_outside_functions()):
+      return _DUMMY_EAGER_GRAPH.key
+    return parent_graph
+
+  def _get_recursive(self, key):
+    """Gets the value at key or the closest parent graph."""
+    value = self.get(key)
+    if value is not None:
+      return value
+
+    # Since FuncGraphs are able to capture tensors and variables from their
+    # parent graphs, recursively search to see if there is a value stored for
+    # one of the parent graphs.
+    if isinstance(key, func_graph.FuncGraph):
+      return self._get_recursive(self._get_parent_graph(key))
+    return None
+
+  def __getitem__(self, key):
+    """Gets the value at key (or current context), or sets default value.
+
+    Args:
+      key: May be `None` or a `Graph` object. When `None`, the key is set to
+        the current context.
+
+    Returns:
+      Either the cached or default value.
+    """
+    if key is None:
+      key = self._key()
+
+    value = self._get_recursive(key)
+    if value is None:
+      value = self[key] = self.default_factory()  # pylint:disable=not-callable
+    return value
+
+  def setdefault(self, key=None, default=None, kwargs=None):
+    """Sets the default value if key is not in dict, and returns the value."""
+    if key is None:
+      key = self._key()
+    kwargs = kwargs or {}
+
+    if default is None and key not in self:
+      default = self.default_factory(**kwargs)
+    return weakref.WeakKeyDictionary.setdefault(self, key, default)
+
+# This dictionary holds a mapping {graph: learning_phase}. In eager mode, a
+# dummy object is used.
+# A learning phase is a bool tensor used to run Keras models in
+# either train mode (learning_phase == 1) or test mode (learning_phase == 0).
+_GRAPH_LEARNING_PHASES = ContextValueCache(_default_learning_phase)
+
+# This dictionary holds a mapping {graph: set_of_freezable_variables}.
+# Each set tracks objects created via `freezable_variable` in the graph.
+_FREEZABLE_VARS = ContextValueCache(object_identity.ObjectIdentityWeakSet)
+
+# This dictionary holds a mapping between a graph and variables to initialize
+# in the graph.
+_GRAPH_VARIABLES = ContextValueCache(object_identity.ObjectIdentityWeakSet)
+
+# This dictionary holds a mapping between a graph and TF optimizers created in
+# the graph.
+_GRAPH_TF_OPTIMIZERS = ContextValueCache(object_identity.ObjectIdentityWeakSet)
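Reviewer note: the docstring example was adjusted above to use in-place `.append` mutation; the class does not override `__setitem__`, so augmented assignment through a `None` key (`cache[None] += 2`) would try to weak-reference `None` and raise a `TypeError`. Separately, the subtle part of `ContextValueCache` is `_get_recursive`: a `FuncGraph` can capture tensors from its parent, so a cache miss falls back to the outer graph. A self-contained sketch of that fallback, with plain classes standing in for `tf.Graph`/`FuncGraph` (illustrative, not TF APIs):

```
import weakref

class Graph(object):
  pass

class FuncGraph(Graph):
  def __init__(self, outer_graph):
    self.outer_graph = outer_graph

class Cache(weakref.WeakKeyDictionary):

  def lookup(self, key):
    value = self.get(key)
    if value is not None:
      return value
    # A FuncGraph can capture values from its parent, so fall back to it.
    if isinstance(key, FuncGraph):
      return self.lookup(key.outer_graph)
    return None

outer = Graph()
inner = FuncGraph(outer_graph=outer)
cache = Cache()
cache[outer] = 'captured-value'
assert cache.lookup(inner) == 'captured-value'  # resolved via the parent
```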
diff --git a/tensorflow/python/keras/backend_test.py b/tensorflow/python/keras/backend_test.py
index fe7d0ea4d67..5ab4f32f684 100644
--- a/tensorflow/python/keras/backend_test.py
+++ b/tensorflow/python/keras/backend_test.py
@@ -17,6 +17,8 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import gc
+
 from absl.testing import parameterized
 import numpy as np
 import scipy.sparse
@@ -24,6 +26,7 @@ import scipy.sparse
 from tensorflow.core.protobuf import config_pb2
 from tensorflow.python import keras
 from tensorflow.python.eager import context
+from tensorflow.python.eager import def_function
 from tensorflow.python.framework import config
 from tensorflow.python.framework import errors_impl
 from tensorflow.python.framework import ops
@@ -2144,5 +2147,52 @@ class ControlOpsTests(test.TestCase):
       keras.backend.switch(keras.backend.equal(x, x), false_func, true_func)
 
 
+class ContextValueCacheTest(test.TestCase):
+
+  def test_cache(self):
+    cache = keras.backend.ContextValueCache(list)
+    graph1 = ops.Graph()
+    graph2 = ops.Graph()
+
+    cache[graph1].append(1)
+    with graph1.as_default():
+      cache[None].append(2)
+
+    with graph2.as_default():
+      cache[None].append(3)
+    cache[graph2].append(4)
+
+    self.assertAllEqual(cache[graph1], [1, 2])
+    self.assertAllEqual(cache[graph2], [3, 4])
+
+    with context.eager_mode():
+      cache[None].append(5)
+      cache[None].append(6)
+      self.assertAllEqual(cache[None], [5, 6])
+
+    self.assertLen(cache, 3)
+
+    del graph1
+    gc.collect()
+    self.assertLen(cache, 2)
+
+  def test_cache_in_parent_graph(self):
+    cache = keras.backend.ContextValueCache(int)
+    cache.setdefault(None, keras.backend.constant(5))
+
+    with ops.Graph().as_default() as g:
+      # g is not a child graph of the default test context, so the recursive
+      # lookup will create a new default value.
+      self.assertAllEqual(cache[g], 0)
+
+    @def_function.function
+    def fn():
+      # The function graph is a child of the default test context, so
+      # __getitem__ will return the previously saved value.
+      return cache[ops.get_default_graph()]
+
+    self.assertEqual(self.evaluate(fn()), 5)
+
+
 if __name__ == '__main__':
   test.main()
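Reviewer note: the `assertLen` checks in `test_cache` lean on `WeakKeyDictionary` semantics -- once the last strong reference to `graph1` is dropped, its cache entry can disappear after collection. A TF-free illustration of that mechanism (names are illustrative):

```
import gc
import weakref

class Graph(object):  # hypothetical stand-in for ops.Graph
  pass

cache = weakref.WeakKeyDictionary()
g = Graph()
cache[g] = [1, 2]
assert len(cache) == 1

del g          # drop the last strong reference to the key
gc.collect()   # make collection deterministic for the assertion
assert len(cache) == 0
```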
diff --git a/tensorflow/python/keras/layers/recurrent.py b/tensorflow/python/keras/layers/recurrent.py
index e30527a73ca..1b52b477170 100644
--- a/tensorflow/python/keras/layers/recurrent.py
+++ b/tensorflow/python/keras/layers/recurrent.py
@@ -1088,10 +1088,9 @@ class DropoutRNNCellMixin(object):
     # RNN could be created with `unroll=True`. In that case, the `cell.call()`
     # function will be invoked multiple times, and we want to ensure same mask
     # is used every time.
-    self._dropout_mask = None
-    self._recurrent_dropout_mask = None
-    self._eager_dropout_mask = None
-    self._eager_recurrent_dropout_mask = None
+    self._dropout_mask_cache = K.ContextValueCache(self._create_dropout_mask)
+    self._recurrent_dropout_mask_cache = K.ContextValueCache(
+        self._create_recurrent_dropout_mask)
     super(DropoutRNNCellMixin, self).__init__(*args, **kwargs)
 
   def reset_dropout_mask(self):
@@ -1103,8 +1102,7 @@ class DropoutRNNCellMixin(object):
     be cached between batches. Otherwise it will introduce unreasonable bias
     against certain index of data within the batch.
     """
-    self._dropout_mask = None
-    self._eager_dropout_mask = None
+    self._dropout_mask_cache.clear()
 
   def reset_recurrent_dropout_mask(self):
     """Reset the cached recurrent dropout masks if any.
@@ -1115,8 +1113,21 @@ class DropoutRNNCellMixin(object):
     be cached between batches. Otherwise it will introduce unreasonable bias
     against certain index of data within the batch.
     """
-    self._recurrent_dropout_mask = None
-    self._eager_recurrent_dropout_mask = None
+    self._recurrent_dropout_mask_cache.clear()
+
+  def _create_dropout_mask(self, inputs, training, count=1):
+    return _generate_dropout_mask(
+        array_ops.ones_like(inputs),
+        self.dropout,
+        training=training,
+        count=count)
+
+  def _create_recurrent_dropout_mask(self, inputs, training, count=1):
+    return _generate_dropout_mask(
+        array_ops.ones_like(inputs),
+        self.recurrent_dropout,
+        training=training,
+        count=count)
 
   def get_dropout_mask_for_cell(self, inputs, training, count=1):
     """Get the dropout mask for RNN cell's input.
@@ -1136,23 +1147,8 @@ class DropoutRNNCellMixin(object):
     """
     if self.dropout == 0:
       return None
-    if (not context.executing_eagerly() and self._dropout_mask is None
-        or context.executing_eagerly() and self._eager_dropout_mask is None):
-      # Generate new mask and cache it based on context.
-      dp_mask = _generate_dropout_mask(
-          array_ops.ones_like(inputs),
-          self.dropout,
-          training=training,
-          count=count)
-      if context.executing_eagerly():
-        self._eager_dropout_mask = dp_mask
-      else:
-        self._dropout_mask = dp_mask
-    else:
-      # Reuse the existing mask.
-      dp_mask = (self._eager_dropout_mask
-                 if context.executing_eagerly() else self._dropout_mask)
-    return dp_mask
+    init_kwargs = dict(inputs=inputs, training=training, count=count)
+    return self._dropout_mask_cache.setdefault(kwargs=init_kwargs)
 
   def get_recurrent_dropout_mask_for_cell(self, inputs, training, count=1):
     """Get the recurrent dropout mask for RNN cell.
@@ -1172,25 +1168,8 @@ class DropoutRNNCellMixin(object):
     """
     if self.recurrent_dropout == 0:
       return None
-    if (not context.executing_eagerly() and self._recurrent_dropout_mask is None
-        or context.executing_eagerly()
-        and self._eager_recurrent_dropout_mask is None):
-      # Generate new mask and cache it based on context.
-      rec_dp_mask = _generate_dropout_mask(
-          array_ops.ones_like(inputs),
-          self.recurrent_dropout,
-          training=training,
-          count=count)
-      if context.executing_eagerly():
-        self._eager_recurrent_dropout_mask = rec_dp_mask
-      else:
-        self._recurrent_dropout_mask = rec_dp_mask
-    else:
-      # Reuse the existing mask.
-      rec_dp_mask = (self._eager_recurrent_dropout_mask
-                     if context.executing_eagerly()
-                     else self._recurrent_dropout_mask)
-    return rec_dp_mask
+    init_kwargs = dict(inputs=inputs, training=training, count=count)
+    return self._recurrent_dropout_mask_cache.setdefault(kwargs=init_kwargs)
 
 
 @keras_export('keras.layers.SimpleRNNCell')
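Reviewer note: the net effect for `DropoutRNNCellMixin` is "one `_generate_dropout_mask` call per context until reset". A standalone sketch of that contract, with a counting factory in place of the real mask generator (`MaskCache` below mimics `ContextValueCache.setdefault(kwargs=...)` for a single context and is illustrative only):

```
class MaskCache(dict):
  """Illustrative stand-in for K.ContextValueCache keyed by one context."""

  def __init__(self, default_factory):
    dict.__init__(self)
    self.default_factory = default_factory

  def setdefault(self, key='ctx', kwargs=None):
    if key not in self:
      self[key] = self.default_factory(**(kwargs or {}))
    return self[key]

calls = []

def make_mask(inputs, training, count=1):  # stands in for _generate_dropout_mask
  calls.append((inputs, training, count))
  return 'mask-%d' % len(calls)

cache = MaskCache(make_mask)
init_kwargs = dict(inputs='x', training=True, count=1)
assert cache.setdefault(kwargs=init_kwargs) == 'mask-1'  # generated once
assert cache.setdefault(kwargs=init_kwargs) == 'mask-1'  # reused on later steps
cache.clear()  # what reset_dropout_mask() now does between batches
assert cache.setdefault(kwargs=init_kwargs) == 'mask-2'  # fresh mask next batch
```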