Unify logic for caching values depending on the graph or eager context.

The ContextValueCache class:
  - creates new cached tensors in different contexts
  - returns the value at the current context or closest parent graph

This class is used by:
  - all of the backend caches (learning phases, freezable variables, graph variables, and graph tf optimizers)
  - DropoutRNNCellMixin mask tensor

This resolves some bugs when saving RNN layers to SavedModel, but there are still issues with saving, so I'll address the rest in a different CL.
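A minimal sketch of the lookup behavior described above (a sketch, not code from this change), assuming TF 2.x eager execution and using the internal `keras.backend.ContextValueCache` symbol directly: a value cached against the eager context is found from inside a `tf.function`, because the function's graph falls back to its closest parent context.

```
from tensorflow.python import keras
from tensorflow.python.eager import def_function

# Sketch only: `list` stands in for the tensor factories used by the real
# backend caches (learning phase, dropout masks, ...).
cache = keras.backend.ContextValueCache(list)

# Stored under the eager-context key.
cache[None].append('mask')

@def_function.function
def lookup():
  # There is no entry for this FuncGraph, so the lookup walks up to the
  # closest parent context (the eager context) and returns the value
  # cached above.
  return len(cache[None])

assert int(lookup()) == 1
```

This mirrors the `test_cache_in_parent_graph` case added to the backend tests below.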

PiperOrigin-RevId: 297662393
Change-Id: I5f93f4e8458e8b137a22787907e9f565414b3b8b
Katherine Wu 2020-02-27 12:37:26 -08:00 committed by TensorFlower Gardener
parent b312ab962f
commit 260a840659
3 changed files with 229 additions and 97 deletions

tensorflow/python/keras/backend.py View File

@ -98,15 +98,6 @@ _CURRENT_SCRATCH_GRAPH = None
# used by Keras. It can be set manually via `set_session(sess)`.
_SESSION = threading.local()
# This dictionary holds a mapping {graph: learning_phase}.
# A learning phase is a bool tensor used to run Keras models in
# either train mode (learning_phase == 1) or test mode (learning_phase == 0).
_GRAPH_LEARNING_PHASES = weakref.WeakKeyDictionary()
# This dictionary holds a mapping {graph: set_of_freezable_variables}.
# Each set tracks objects created via `freezable_variable` in the graph.
_FREEZABLE_VARS = weakref.WeakKeyDictionary()
# _DUMMY_EAGER_GRAPH.key is used as a key in _GRAPH_LEARNING_PHASES.
# We keep a separate reference to it to make sure it does not get removed from
@ -154,14 +145,6 @@ _MANUAL_VAR_INIT = False
# We assume our devices don't change henceforth.
_LOCAL_DEVICES = None
# This dictionary holds a mapping between a graph and variables to initialize
# in the graph.
_GRAPH_VARIABLES = weakref.WeakKeyDictionary()
# This dictionary holds a mapping between a graph and TF optimizers created in
# the graph.
_GRAPH_TF_OPTIMIZERS = weakref.WeakKeyDictionary()
# The below functions are kept accessible from backend for compatibility.
epsilon = backend_config.epsilon
floatx = backend_config.floatx
@ -293,11 +276,9 @@ def clear_session():
_SESSION.session = None
graph = get_graph()
with graph.as_default():
with name_scope(''):
phase = array_ops.placeholder_with_default(
False, shape=(), name='keras_learning_phase')
_GRAPH_LEARNING_PHASES = {}
_GRAPH_LEARNING_PHASES[graph] = phase
_GRAPH_LEARNING_PHASES.clear()
# Create the learning phase placeholder in graph using the default factory.
_GRAPH_LEARNING_PHASES.setdefault(graph)
_GRAPH_VARIABLES.pop(graph, None)
_GRAPH_TF_OPTIMIZERS.pop(graph, None)
_FREEZABLE_VARS.pop(graph, None)
@ -336,24 +317,18 @@ def learning_phase():
# Don't enter an init_scope for the learning phase if eager execution
# is enabled but we're inside the Keras workspace graph.
learning_phase = symbolic_learning_phase()
_mark_func_graph_as_unsaveable(graph, learning_phase)
return learning_phase
with ops.init_scope():
# We always check & set the learning phase inside the init_scope,
# otherwise the wrong default_graph will be used to look up the learning
# phase inside of functions & defuns.
#
# This is because functions & defuns (both in graph & in eager mode)
# will always execute non-eagerly using a function-specific default
# subgraph.
if context.executing_eagerly():
if _DUMMY_EAGER_GRAPH.key not in _GRAPH_LEARNING_PHASES:
# Fallback to inference mode as default.
return 0
return _GRAPH_LEARNING_PHASES[_DUMMY_EAGER_GRAPH.key]
learning_phase = symbolic_learning_phase()
_mark_func_graph_as_unsaveable(graph, learning_phase)
return learning_phase
else:
with ops.init_scope():
# We always check & set the learning phase inside the init_scope,
# otherwise the wrong default_graph will be used to look up the learning
# phase inside of functions & defuns.
#
# This is because functions & defuns (both in graph & in eager mode)
# will always execute non-eagerly using a function-specific default
# subgraph.
learning_phase = _GRAPH_LEARNING_PHASES[None]
_mark_func_graph_as_unsaveable(graph, learning_phase)
return learning_phase
def global_learning_phase_is_set():
@ -382,14 +357,18 @@ def _mark_func_graph_as_unsaveable(graph, learning_phase):
def symbolic_learning_phase():
graph = get_graph()
with graph.as_default():
if graph not in _GRAPH_LEARNING_PHASES:
with name_scope(''):
phase = array_ops.placeholder_with_default(
False, shape=(), name='keras_learning_phase')
_GRAPH_LEARNING_PHASES[graph] = phase
return _GRAPH_LEARNING_PHASES[graph]
def _default_learning_phase():
if context.executing_eagerly():
return 0
else:
with name_scope(''):
return array_ops.placeholder_with_default(
False, shape=(), name='keras_learning_phase')
@keras_export('keras.backend.set_learning_phase')
def set_learning_phase(value):
"""Sets the learning phase to a fixed value.
@ -876,8 +855,7 @@ def track_tf_optimizer(tf_optimizer):
"""Tracks the given TF optimizer for initialization of its variables."""
if context.executing_eagerly():
return
graph = get_graph()
optimizers = _GRAPH_TF_OPTIMIZERS.setdefault(graph, weakref.WeakSet())
optimizers = _GRAPH_TF_OPTIMIZERS[None]
optimizers.add(tf_optimizer)
@ -886,8 +864,6 @@ def track_variable(v):
if context.executing_eagerly():
return
graph = v.graph if hasattr(v, 'graph') else get_graph()
if graph not in _GRAPH_VARIABLES:
_GRAPH_VARIABLES[graph] = object_identity.ObjectIdentityWeakSet()
_GRAPH_VARIABLES[graph].add(v)
@ -943,8 +919,8 @@ def unique_object_name(name,
def _get_variables(graph=None):
"""Returns variables corresponding to the given graph for initialization."""
assert not context.executing_eagerly()
variables = _GRAPH_VARIABLES.setdefault(graph, weakref.WeakSet())
for opt in _GRAPH_TF_OPTIMIZERS.get(graph, set()):
variables = _GRAPH_VARIABLES[graph]
for opt in _GRAPH_TF_OPTIMIZERS[graph]:
variables.update(opt.optimizer.variables())
return variables
@ -1174,8 +1150,6 @@ def freezable_variable(value, shape=None, name=None):
x.get_value = get_value
global _FREEZABLE_VARS
if graph not in _FREEZABLE_VARS:
_FREEZABLE_VARS[graph] = object_identity.ObjectIdentityWeakSet()
_FREEZABLE_VARS[graph].add(x)
return x
@ -6094,3 +6068,132 @@ def maybe_convert_to_ragged(is_ragged_input, output, nested_row_lengths):
return output
return ragged_tensor.RaggedTensor.from_tensor(output, nested_row_lengths)
class ContextValueCache(weakref.WeakKeyDictionary):
"""Container that caches (possibly tensor) values based on the context.
This class is similar to defaultdict, where values may be produced by the
default factory specified during initialization. This class also has a default
value for the key (when key is `None`) -- the key is set to the current
graph or eager context. The default factories for key and value are only used
in `__getitem__` and `setdefault`. The `.get()` behavior remains the same.
This object will return the value of the current graph or closest parent graph
if the current graph is a function. This is to reflect the fact that if a
tensor is created in eager/graph, child functions may capture that tensor.
The default factory method may accept keyword arguments (unlike defaultdict,
which only accepts callables with 0 arguments). To pass keyword arguments to
`default_factory`, use the `setdefault` method instead of `__getitem__`.
An example of how this class can be used in different contexts:
```
cache = ContextValueCache(list)

# Eager mode
cache[None].append(2)
cache[None].append(4)
assert cache[None] == [2, 4]

# Graph mode
with tf.Graph().as_default() as g:
  cache[None].append(5)
  cache[g].append(3)
assert cache[g] == [5, 3]
```
Example of a default factory with arguments:
```
cache = ContextValueCache(lambda x: x + 1)
g = tf.compat.v1.get_default_graph()
# Example with keyword argument.
value = cache.setdefault(key=g, kwargs={'x': 3})
assert cache[g] == 4
```
"""
def __init__(self, default_factory):
self.default_factory = default_factory
weakref.WeakKeyDictionary.__init__(self)
def _key(self):
if context.executing_eagerly():
return _DUMMY_EAGER_GRAPH.key
else:
return ops.get_default_graph()
def _get_parent_graph(self, graph):
"""Returns the parent graph or dummy eager object."""
# TODO(b/149317164): Currently FuncGraphs use ops.get_default_graph() as the
# outer graph. This results in outer_graph always being a Graph,
# even in eager mode (get_default_graph will create a new Graph if there
# isn't a default graph). Because of this bug, we have to specially set the
# key when eager execution is enabled.
parent_graph = graph.outer_graph
if (not isinstance(parent_graph, func_graph.FuncGraph) and
ops.executing_eagerly_outside_functions()):
return _DUMMY_EAGER_GRAPH.key
return parent_graph
def _get_recursive(self, key):
"""Gets the value at key or the closest parent graph."""
value = self.get(key)
if value is not None:
return value
# Since FuncGraphs are able to capture tensors and variables from their
# parent graphs, recursively search to see if there is a value stored for
# one of the parent graphs.
if isinstance(key, func_graph.FuncGraph):
return self._get_recursive(self._get_parent_graph(key))
return None
def __getitem__(self, key):
"""Gets the value at key (or current context), or sets default value.
Args:
key: May be `None` or a `Graph` object. When `None`, the key is set to the
current context.
Returns:
Either the cached or default value.
"""
if key is None:
key = self._key()
value = self._get_recursive(key)
if value is None:
value = self[key] = self.default_factory() # pylint:disable=not-callable
return value
def setdefault(self, key=None, default=None, kwargs=None):
"""Sets the default value if key is not in dict, and returns the value."""
if key is None:
key = self._key()
kwargs = kwargs or {}
if default is None and key not in self:
default = self.default_factory(**kwargs)
return weakref.WeakKeyDictionary.setdefault(self, key, default)
# This dictionary holds a mapping {graph: learning_phase}. In eager mode, a
# dummy object is used.
# A learning phase is a bool tensor used to run Keras models in
# either train mode (learning_phase == 1) or test mode (learning_phase == 0).
_GRAPH_LEARNING_PHASES = ContextValueCache(_default_learning_phase)
# This dictionary holds a mapping {graph: set_of_freezable_variables}.
# Each set tracks objects created via `freezable_variable` in the graph.
_FREEZABLE_VARS = ContextValueCache(object_identity.ObjectIdentityWeakSet)
# This dictionary holds a mapping between a graph and variables to initialize
# in the graph.
_GRAPH_VARIABLES = ContextValueCache(object_identity.ObjectIdentityWeakSet)
# This dictionary holds a mapping between a graph and TF optimizers created in
# the graph.
_GRAPH_TF_OPTIMIZERS = ContextValueCache(object_identity.ObjectIdentityWeakSet)

tensorflow/python/keras/backend_test.py View File

@ -17,6 +17,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import gc
from absl.testing import parameterized
import numpy as np
import scipy.sparse
@ -24,6 +26,7 @@ import scipy.sparse
from tensorflow.core.protobuf import config_pb2
from tensorflow.python import keras
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import config
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import ops
@ -2144,5 +2147,52 @@ class ControlOpsTests(test.TestCase):
keras.backend.switch(keras.backend.equal(x, x), false_func, true_func)
class ContextValueCacheTest(test.TestCase):
def test_cache(self):
cache = keras.backend.ContextValueCache(list)
graph1 = ops.Graph()
graph2 = ops.Graph()
cache[graph1].append(1)
with graph1.as_default():
cache[None].append(2)
with graph2.as_default():
cache[None].append(3)
cache[graph2].append(4)
self.assertAllEqual(cache[graph1], [1, 2])
self.assertAllEqual(cache[graph2], [3, 4])
with context.eager_mode():
cache[None].append(5)
cache[None].append(6)
self.assertAllEqual(cache[None], [5, 6])
self.assertLen(cache, 3)
del graph1
gc.collect()
self.assertLen(cache, 2)
def test_cache_in_parent_graph(self):
cache = keras.backend.ContextValueCache(int)
cache.setdefault(None, keras.backend.constant(5))
with ops.Graph().as_default() as g:
# g is not a child graph of the default test context, so the recursive
# lookup will create a new default value.
self.assertAllEqual(cache[g], 0)
@def_function.function
def fn():
# The function graph is a child of the default test context, so
# __getitem__ will return the previously saved value.
return cache[ops.get_default_graph()]
self.assertEqual(self.evaluate(fn()), 5)
if __name__ == '__main__':
test.main()

tensorflow/python/keras/layers/recurrent.py View File

@ -1088,10 +1088,9 @@ class DropoutRNNCellMixin(object):
# RNN could be created with `unroll=True`. In that case, the `cell.call()`
# function will be invoked multiple times, and we want to ensure same mask
# is used every time.
self._dropout_mask = None
self._recurrent_dropout_mask = None
self._eager_dropout_mask = None
self._eager_recurrent_dropout_mask = None
self._dropout_mask_cache = K.ContextValueCache(self._create_dropout_mask)
self._recurrent_dropout_mask_cache = K.ContextValueCache(
self._create_recurrent_dropout_mask)
super(DropoutRNNCellMixin, self).__init__(*args, **kwargs)
def reset_dropout_mask(self):
@ -1103,8 +1102,7 @@ class DropoutRNNCellMixin(object):
be cached between batches. Otherwise it will introduce unreasonable bias
against certain index of data within the batch.
"""
self._dropout_mask = None
self._eager_dropout_mask = None
self._dropout_mask_cache.clear()
def reset_recurrent_dropout_mask(self):
"""Reset the cached recurrent dropout masks if any.
@ -1115,8 +1113,21 @@ class DropoutRNNCellMixin(object):
be cached between batches. Otherwise it will introduce unreasonable bias
against certain index of data within the batch.
"""
self._recurrent_dropout_mask = None
self._eager_recurrent_dropout_mask = None
self._recurrent_dropout_mask_cache.clear()
def _create_dropout_mask(self, inputs, training, count=1):
return _generate_dropout_mask(
array_ops.ones_like(inputs),
self.dropout,
training=training,
count=count)
def _create_recurrent_dropout_mask(self, inputs, training, count=1):
return _generate_dropout_mask(
array_ops.ones_like(inputs),
self.recurrent_dropout,
training=training,
count=count)
def get_dropout_mask_for_cell(self, inputs, training, count=1):
"""Get the dropout mask for RNN cell's input.
@ -1136,23 +1147,8 @@ class DropoutRNNCellMixin(object):
"""
if self.dropout == 0:
return None
if (not context.executing_eagerly() and self._dropout_mask is None
or context.executing_eagerly() and self._eager_dropout_mask is None):
# Generate new mask and cache it based on context.
dp_mask = _generate_dropout_mask(
array_ops.ones_like(inputs),
self.dropout,
training=training,
count=count)
if context.executing_eagerly():
self._eager_dropout_mask = dp_mask
else:
self._dropout_mask = dp_mask
else:
# Reuse the existing mask.
dp_mask = (self._eager_dropout_mask
if context.executing_eagerly() else self._dropout_mask)
return dp_mask
init_kwargs = dict(inputs=inputs, training=training, count=count)
return self._dropout_mask_cache.setdefault(kwargs=init_kwargs)
def get_recurrent_dropout_mask_for_cell(self, inputs, training, count=1):
"""Get the recurrent dropout mask for RNN cell.
@ -1172,25 +1168,8 @@ class DropoutRNNCellMixin(object):
"""
if self.recurrent_dropout == 0:
return None
if (not context.executing_eagerly() and self._recurrent_dropout_mask is None
or context.executing_eagerly()
and self._eager_recurrent_dropout_mask is None):
# Generate new mask and cache it based on context.
rec_dp_mask = _generate_dropout_mask(
array_ops.ones_like(inputs),
self.recurrent_dropout,
training=training,
count=count)
if context.executing_eagerly():
self._eager_recurrent_dropout_mask = rec_dp_mask
else:
self._recurrent_dropout_mask = rec_dp_mask
else:
# Reuse the existing mask.
rec_dp_mask = (self._eager_recurrent_dropout_mask
if context.executing_eagerly()
else self._recurrent_dropout_mask)
return rec_dp_mask
init_kwargs = dict(inputs=inputs, training=training, count=count)
return self._recurrent_dropout_mask_cache.setdefault(kwargs=init_kwargs)
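A rough caller-side sketch of the new mask caching, assuming a built-in cell that uses `DropoutRNNCellMixin` (e.g. `tf.keras.layers.LSTMCell`) with a nonzero dropout rate under eager execution: repeated calls in the same context reuse the cached mask tensor, and `reset_dropout_mask()` clears it before the next batch.

```
import tensorflow as tf

cell = tf.keras.layers.LSTMCell(4, dropout=0.5)
inputs = tf.ones((2, 8))

# The first call creates the mask via the default factory and caches it
# under the current (eager) context; the second call reuses that tensor.
mask_a = cell.get_dropout_mask_for_cell(inputs, training=True)
mask_b = cell.get_dropout_mask_for_cell(inputs, training=True)
assert mask_a is mask_b

# Clearing the cache forces a fresh mask for the next batch.
cell.reset_dropout_mask()
assert cell.get_dropout_mask_for_cell(inputs, training=True) is not mask_a
```

With `unroll=True`, each unrolled step hits the same cache entry, which is the reuse guarantee described in the comment at the top of this class.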
@keras_export('keras.layers.SimpleRNNCell')