Unify logic for caching values depending on the graph or eager context.

The ContextValueCache class:
- creates new cached tensors in different contexts
- returns the value at the current context or closest parent graph

This class is used by:
- all of the backend caches (learning phases, freezable variables, graph variables, and graph TF optimizers)
- the DropoutRNNCellMixin mask tensors

This resolves some bugs when saving RNN layers to SavedModel, but there are still issues with saving, so I'll address the rest in a different CL.

PiperOrigin-RevId: 297662393
Change-Id: I5f93f4e8458e8b137a22787907e9f565414b3b8b
commit 260a840659
parent b312ab962f
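Conceptually, the new class is a context-keyed defaultdict whose lookups fall back to the parent context. Below is a minimal standalone sketch of that idea in plain Python; it is not the TF implementation, and `TinyContextCache` plus the explicit `parents` map are illustrative stand-ins for TF's graph nesting:

```python
# Minimal sketch of the caching idea, independent of TensorFlow. The real
# ContextValueCache keys on tf.Graph objects / an eager dummy key and derives
# the parent from FuncGraph.outer_graph; here a `parents` dict stands in.

class TinyContextCache(dict):

  def __init__(self, default_factory, parents=None):
    super().__init__()
    self.default_factory = default_factory
    self.parents = parents or {}  # child context -> parent context

  def __getitem__(self, key):
    node = key
    while node is not None:
      if node in self:                 # found a value cached at this context
        return dict.__getitem__(self, node)
      node = self.parents.get(node)    # otherwise walk up to the parent
    value = self.default_factory()     # nothing found: build a fresh default
    dict.__setitem__(self, key, value)
    return value

parents = {'child_fn': 'outer_graph'}
cache = TinyContextCache(list, parents)
cache['outer_graph'].append('mask')
assert cache['child_fn'] == ['mask']   # child context sees the parent's value
assert cache['other_graph'] == []      # unrelated context gets a new default
```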
tensorflow/python/keras/backend.py
@@ -98,15 +98,6 @@ _CURRENT_SCRATCH_GRAPH = None
 # used by Keras. It can be set manually via `set_session(sess)`.
 _SESSION = threading.local()
 
-# This dictionary holds a mapping {graph: learning_phase}.
-# A learning phase is a bool tensor used to run Keras models in
-# either train mode (learning_phase == 1) or test mode (learning_phase == 0).
-_GRAPH_LEARNING_PHASES = weakref.WeakKeyDictionary()
-
-# This dictionary holds a mapping {graph: set_of_freezable_variables}.
-# Each set tracks objects created via `freezable_variable` in the graph.
-_FREEZABLE_VARS = weakref.WeakKeyDictionary()
-
 
 # _DUMMY_EAGER_GRAPH.key is used as a key in _GRAPH_LEARNING_PHASES.
 # We keep a separate reference to it to make sure it does not get removed from
@@ -154,14 +145,6 @@ _MANUAL_VAR_INIT = False
 # We assume our devices don't change henceforth.
 _LOCAL_DEVICES = None
 
-# This dictionary holds a mapping between a graph and variables to initialize
-# in the graph.
-_GRAPH_VARIABLES = weakref.WeakKeyDictionary()
-
-# This dictionary holds a mapping between a graph and TF optimizers created in
-# the graph.
-_GRAPH_TF_OPTIMIZERS = weakref.WeakKeyDictionary()
-
 # The below functions are kept accessible from backend for compatibility.
 epsilon = backend_config.epsilon
 floatx = backend_config.floatx
@@ -293,11 +276,9 @@ def clear_session():
   _SESSION.session = None
   graph = get_graph()
   with graph.as_default():
-    with name_scope(''):
-      phase = array_ops.placeholder_with_default(
-          False, shape=(), name='keras_learning_phase')
-    _GRAPH_LEARNING_PHASES = {}
-    _GRAPH_LEARNING_PHASES[graph] = phase
+    _GRAPH_LEARNING_PHASES.clear()
+    # Create the learning phase placeholder in graph using the default factory.
+    _GRAPH_LEARNING_PHASES.setdefault(graph)
     _GRAPH_VARIABLES.pop(graph, None)
     _GRAPH_TF_OPTIMIZERS.pop(graph, None)
     _FREEZABLE_VARS.pop(graph, None)
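The two new calls lean entirely on ContextValueCache semantics: `clear()` (inherited from WeakKeyDictionary) empties every context, and `setdefault(graph)` re-runs the zero-argument factory for that one graph. A hedged sketch of the same pattern, using the class from this CL with a `list` factory instead of the real learning-phase factory:

```python
# Sketch of the clear-then-repopulate pattern clear_session now uses.
from tensorflow.python import keras
from tensorflow.python.framework import ops

cache = keras.backend.ContextValueCache(list)
g = ops.Graph()
cache[g].append('state')   # factory ran once, then we mutated the value
cache.clear()              # drop the cached values for every context
cache.setdefault(g)        # recreate the default for this graph only
assert cache[g] == []
```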
@@ -336,24 +317,18 @@ def learning_phase():
     # Don't enter an init_scope for the learning phase if eager execution
     # is enabled but we're inside the Keras workspace graph.
     learning_phase = symbolic_learning_phase()
-    _mark_func_graph_as_unsaveable(graph, learning_phase)
-    return learning_phase
-  with ops.init_scope():
-    # We always check & set the learning phase inside the init_scope,
-    # otherwise the wrong default_graph will be used to look up the learning
-    # phase inside of functions & defuns.
-    #
-    # This is because functions & defuns (both in graph & in eager mode)
-    # will always execute non-eagerly using a function-specific default
-    # subgraph.
-    if context.executing_eagerly():
-      if _DUMMY_EAGER_GRAPH.key not in _GRAPH_LEARNING_PHASES:
-        # Fallback to inference mode as default.
-        return 0
-      return _GRAPH_LEARNING_PHASES[_DUMMY_EAGER_GRAPH.key]
-    learning_phase = symbolic_learning_phase()
-    _mark_func_graph_as_unsaveable(graph, learning_phase)
-    return learning_phase
+  else:
+    with ops.init_scope():
+      # We always check & set the learning phase inside the init_scope,
+      # otherwise the wrong default_graph will be used to look up the learning
+      # phase inside of functions & defuns.
+      #
+      # This is because functions & defuns (both in graph & in eager mode)
+      # will always execute non-eagerly using a function-specific default
+      # subgraph.
+      learning_phase = _GRAPH_LEARNING_PHASES[None]
+  _mark_func_graph_as_unsaveable(graph, learning_phase)
+  return learning_phase


 def global_learning_phase_is_set():
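After this change, the eager and graph branches funnel through one `_GRAPH_LEARNING_PHASES[None]` lookup: passing `None` re-keys the access to whatever context is active. A small illustration with the cache class itself (a `list` factory is used here purely for visibility):

```python
# `None` as a key resolves to the current context: the eager dummy key when
# executing eagerly, otherwise the current default graph.
from tensorflow.python import keras
from tensorflow.python.framework import ops

cache = keras.backend.ContextValueCache(list)
with ops.Graph().as_default() as g:
  cache[None].append('phase')   # stored under g, not under None
assert cache[g] == ['phase']
```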
@@ -382,14 +357,18 @@ def _mark_func_graph_as_unsaveable(graph, learning_phase):
 def symbolic_learning_phase():
   graph = get_graph()
   with graph.as_default():
-    if graph not in _GRAPH_LEARNING_PHASES:
-      with name_scope(''):
-        phase = array_ops.placeholder_with_default(
-            False, shape=(), name='keras_learning_phase')
-      _GRAPH_LEARNING_PHASES[graph] = phase
     return _GRAPH_LEARNING_PHASES[graph]


+def _default_learning_phase():
+  if context.executing_eagerly():
+    return 0
+  else:
+    with name_scope(''):
+      return array_ops.placeholder_with_default(
+          False, shape=(), name='keras_learning_phase')
+
+
 @keras_export('keras.backend.set_learning_phase')
 def set_learning_phase(value):
   """Sets the learning phase to a fixed value.
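The factory gives each context an appropriate default: the plain int `0` in eager mode, and a scalar `keras_learning_phase` placeholder in graph mode, built lazily the first time `__getitem__` misses. Roughly, using the internal backend functions from this diff (a sketch assuming TF 2.x eager-by-default and no previously set global learning phase):

```python
from tensorflow.python import keras
from tensorflow.python.framework import ops

K = keras.backend
assert K.learning_phase() == 0         # eager context: factory returned 0

with ops.Graph().as_default():
  phase = K.symbolic_learning_phase()  # graph context: factory built the
  assert phase.dtype == 'bool'         # scalar bool placeholder_with_default
```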
@@ -876,8 +855,7 @@ def track_tf_optimizer(tf_optimizer):
   """Tracks the given TF optimizer for initialization of its variables."""
   if context.executing_eagerly():
     return
-  graph = get_graph()
-  optimizers = _GRAPH_TF_OPTIMIZERS.setdefault(graph, weakref.WeakSet())
+  optimizers = _GRAPH_TF_OPTIMIZERS[None]
   optimizers.add(tf_optimizer)

@@ -886,8 +864,6 @@ def track_variable(v):
   if context.executing_eagerly():
     return
   graph = v.graph if hasattr(v, 'graph') else get_graph()
-  if graph not in _GRAPH_VARIABLES:
-    _GRAPH_VARIABLES[graph] = object_identity.ObjectIdentityWeakSet()
   _GRAPH_VARIABLES[graph].add(v)

@@ -943,8 +919,8 @@ def unique_object_name(name,
 def _get_variables(graph=None):
   """Returns variables corresponding to the given graph for initialization."""
   assert not context.executing_eagerly()
-  variables = _GRAPH_VARIABLES.setdefault(graph, weakref.WeakSet())
-  for opt in _GRAPH_TF_OPTIMIZERS.get(graph, set()):
+  variables = _GRAPH_VARIABLES[graph]
+  for opt in _GRAPH_TF_OPTIMIZERS[graph]:
     variables.update(opt.optimizer.variables())
   return variables

@@ -1174,8 +1150,6 @@ def freezable_variable(value, shape=None, name=None):
   x.get_value = get_value

-  global _FREEZABLE_VARS
-  if graph not in _FREEZABLE_VARS:
-    _FREEZABLE_VARS[graph] = object_identity.ObjectIdentityWeakSet()
   _FREEZABLE_VARS[graph].add(x)
   return x

@@ -6094,3 +6068,132 @@ def maybe_convert_to_ragged(is_ragged_input, output, nested_row_lengths):
     return output
 
   return ragged_tensor.RaggedTensor.from_tensor(output, nested_row_lengths)
+
+
+class ContextValueCache(weakref.WeakKeyDictionary):
+  """Container that caches (possibly tensor) values based on the context.
+
+  This class is similar to defaultdict, where values may be produced by the
+  default factory specified during initialization. This class also has a
+  default value for the key (when the key is `None`) -- the key is set to the
+  current graph or eager context. The default factories for the key and value
+  are only used in `__getitem__` and `setdefault`. The `.get()` behavior
+  remains the same.
+
+  This object will return the value of the current graph or closest parent
+  graph if the current graph is a function. This reflects the fact that if a
+  tensor is created in eager mode or in a graph, child functions may capture
+  that tensor.
+
+  The default factory method may accept keyword arguments (unlike defaultdict,
+  which only accepts callables with 0 arguments). To pass keyword arguments to
+  `default_factory`, use the `setdefault` method instead of `__getitem__`.
+
+  An example of how this class can be used in different contexts:
+
+  ```
+  cache = ContextValueCache(list)
+
+  # Eager mode
+  cache[None].append(2)
+  cache[None].append(4)
+  assert cache[None] == [2, 4]
+
+  # Graph mode
+  with tf.Graph().as_default() as g:
+    cache[None].append(5)  # `None` resolves to the default graph `g`.
+    cache[g].append(3)
+  assert cache[g] == [5, 3]
+  ```
+
+  Example of a default factory with arguments:
+
+  ```
+  cache = ContextValueCache(lambda x: x + 1)
+  g = tf.get_default_graph()
+
+  # Example with keyword argument.
+  value = cache.setdefault(key=g, kwargs={'x': 3})
+  assert cache[g] == 4
+  ```
+  """
+
+  def __init__(self, default_factory):
+    self.default_factory = default_factory
+    weakref.WeakKeyDictionary.__init__(self)
+
+  def _key(self):
+    if context.executing_eagerly():
+      return _DUMMY_EAGER_GRAPH.key
+    else:
+      return ops.get_default_graph()
+
+  def _get_parent_graph(self, graph):
+    """Returns the parent graph or dummy eager object."""
+    # TODO(b/149317164): Currently FuncGraphs use ops.get_default_graph() as
+    # the outer graph. This results in outer_graph always being a Graph, even
+    # in eager mode (get_default_graph will create a new Graph if there isn't
+    # a default graph). Because of this bug, we have to specially set the key
+    # when eager execution is enabled.
+    parent_graph = graph.outer_graph
+    if (not isinstance(parent_graph, func_graph.FuncGraph) and
+        ops.executing_eagerly_outside_functions()):
+      return _DUMMY_EAGER_GRAPH.key
+    return parent_graph
+
+  def _get_recursive(self, key):
+    """Gets the value at key or the closest parent graph."""
+    value = self.get(key)
+    if value is not None:
+      return value
+
+    # Since FuncGraphs are able to capture tensors and variables from their
+    # parent graphs, recursively search to see if there is a value stored for
+    # one of the parent graphs.
+    if isinstance(key, func_graph.FuncGraph):
+      return self._get_recursive(self._get_parent_graph(key))
+    return None
+
+  def __getitem__(self, key):
+    """Gets the value at key (or current context), or sets default value.
+
+    Args:
+      key: May be `None` or a `Graph` object. When `None`, the key is set to
+        the current context.
+
+    Returns:
+      Either the cached or default value.
+    """
+    if key is None:
+      key = self._key()
+
+    value = self._get_recursive(key)
+    if value is None:
+      value = self[key] = self.default_factory()  # pylint:disable=not-callable
+    return value
+
+  def setdefault(self, key=None, default=None, kwargs=None):
+    """Sets the default value if key is not in dict, and returns the value."""
+    if key is None:
+      key = self._key()
+    kwargs = kwargs or {}
+
+    if default is None and key not in self:
+      default = self.default_factory(**kwargs)
+    return weakref.WeakKeyDictionary.setdefault(self, key, default)
+
+
+# This dictionary holds a mapping {graph: learning_phase}. In eager mode, a
+# dummy object is used.
+# A learning phase is a bool tensor used to run Keras models in
+# either train mode (learning_phase == 1) or test mode (learning_phase == 0).
+_GRAPH_LEARNING_PHASES = ContextValueCache(_default_learning_phase)
+
+# This dictionary holds a mapping {graph: set_of_freezable_variables}.
+# Each set tracks objects created via `freezable_variable` in the graph.
+_FREEZABLE_VARS = ContextValueCache(object_identity.ObjectIdentityWeakSet)
+
+# This dictionary holds a mapping between a graph and variables to initialize
+# in the graph.
+_GRAPH_VARIABLES = ContextValueCache(object_identity.ObjectIdentityWeakSet)
+
+# This dictionary holds a mapping between a graph and TF optimizers created in
+# the graph.
+_GRAPH_TF_OPTIMIZERS = ContextValueCache(object_identity.ObjectIdentityWeakSet)
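Tying it together: the parent-graph walk is what lets a value cached in the outer eager context be found from inside a `tf.function`'s FuncGraph (compare `test_cache_in_parent_graph` in the test file below). A hedged sketch, assuming TF 2.x eager-by-default:

```python
# A value cached in the eager context is visible from a nested FuncGraph,
# because _get_recursive() follows outer_graph links back to the eager key.
from tensorflow.python import keras
from tensorflow.python.eager import def_function
from tensorflow.python.framework import ops

cache = keras.backend.ContextValueCache(list)
cache[None].append('outer')               # cached under the eager dummy key

@def_function.function
def fn():
  inner = cache[ops.get_default_graph()]  # lookup from inside the FuncGraph
  assert inner == ['outer']               # found via the parent-graph walk
  return 1

fn()
```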
tensorflow/python/keras/backend_test.py
@@ -17,6 +17,8 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import gc
+
 from absl.testing import parameterized
 import numpy as np
 import scipy.sparse
@@ -24,6 +26,7 @@ import scipy.sparse
 from tensorflow.core.protobuf import config_pb2
 from tensorflow.python import keras
 from tensorflow.python.eager import context
+from tensorflow.python.eager import def_function
 from tensorflow.python.framework import config
 from tensorflow.python.framework import errors_impl
 from tensorflow.python.framework import ops
@@ -2144,5 +2147,52 @@ class ControlOpsTests(test.TestCase):
     keras.backend.switch(keras.backend.equal(x, x), false_func, true_func)
 
 
+class ContextValueCacheTest(test.TestCase):
+
+  def test_cache(self):
+    cache = keras.backend.ContextValueCache(list)
+    graph1 = ops.Graph()
+    graph2 = ops.Graph()
+
+    cache[graph1].append(1)
+    with graph1.as_default():
+      cache[None].append(2)
+
+    with graph2.as_default():
+      cache[None].append(3)
+      cache[graph2].append(4)
+
+    self.assertAllEqual(cache[graph1], [1, 2])
+    self.assertAllEqual(cache[graph2], [3, 4])
+
+    with context.eager_mode():
+      cache[None].append(5)
+      cache[None].append(6)
+      self.assertAllEqual(cache[None], [5, 6])
+
+    self.assertLen(cache, 3)
+
+    del graph1
+    gc.collect()
+    self.assertLen(cache, 2)
+
+  def test_cache_in_parent_graph(self):
+    cache = keras.backend.ContextValueCache(int)
+    cache.setdefault(None, keras.backend.constant(5))
+
+    with ops.Graph().as_default() as g:
+      # g is not a child graph of the default test context, so the recursive
+      # lookup will create a new default value.
+      self.assertAllEqual(cache[g], 0)
+
+    @def_function.function
+    def fn():
+      # The function graph is a child of the default test context, so
+      # __getitem__ will return the previously saved value.
+      return cache[ops.get_default_graph()]
+
+    self.assertEqual(self.evaluate(fn()), 5)
+
+
 if __name__ == '__main__':
   test.main()
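The `del graph1` / `gc.collect()` assertions above work because ContextValueCache inherits WeakKeyDictionary semantics: an entry dies with its graph key, so stale graphs cannot leak cached tensors. The same behavior in plain Python:

```python
# WeakKeyDictionary entries disappear once the key has no strong references.
import gc
import weakref

class Key(object):  # stand-in for a tf.Graph (graphs are weak-referenceable)
  pass

cache = weakref.WeakKeyDictionary()
k = Key()
cache[k] = ['cached value']
assert len(cache) == 1
del k               # drop the only strong reference to the key
gc.collect()        # CPython frees it immediately; collect() for safety
assert len(cache) == 0
```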
tensorflow/python/keras/layers/recurrent.py
@@ -1088,10 +1088,9 @@ class DropoutRNNCellMixin(object):
     # RNN could be created with `unroll=True`. In that case, the `cell.call()`
     # function will be invoked multiple times, and we want to ensure same mask
     # is used every time.
-    self._dropout_mask = None
-    self._recurrent_dropout_mask = None
-    self._eager_dropout_mask = None
-    self._eager_recurrent_dropout_mask = None
+    self._dropout_mask_cache = K.ContextValueCache(self._create_dropout_mask)
+    self._recurrent_dropout_mask_cache = K.ContextValueCache(
+        self._create_recurrent_dropout_mask)
     super(DropoutRNNCellMixin, self).__init__(*args, **kwargs)

   def reset_dropout_mask(self):
@@ -1103,8 +1102,7 @@ class DropoutRNNCellMixin(object):
     be cached between batches. Otherwise it will introduce unreasonable bias
     against certain index of data within the batch.
     """
-    self._dropout_mask = None
-    self._eager_dropout_mask = None
+    self._dropout_mask_cache.clear()

   def reset_recurrent_dropout_mask(self):
     """Reset the cached recurrent dropout masks if any.
@@ -1115,8 +1113,21 @@ class DropoutRNNCellMixin(object):
     be cached between batches. Otherwise it will introduce unreasonable bias
     against certain index of data within the batch.
     """
-    self._recurrent_dropout_mask = None
-    self._eager_recurrent_dropout_mask = None
+    self._recurrent_dropout_mask_cache.clear()
+
+  def _create_dropout_mask(self, inputs, training, count=1):
+    return _generate_dropout_mask(
+        array_ops.ones_like(inputs),
+        self.dropout,
+        training=training,
+        count=count)
+
+  def _create_recurrent_dropout_mask(self, inputs, training, count=1):
+    return _generate_dropout_mask(
+        array_ops.ones_like(inputs),
+        self.recurrent_dropout,
+        training=training,
+        count=count)

   def get_dropout_mask_for_cell(self, inputs, training, count=1):
     """Get the dropout mask for RNN cell's input.
@@ -1136,23 +1147,8 @@ class DropoutRNNCellMixin(object):
     """
     if self.dropout == 0:
       return None
-    if (not context.executing_eagerly() and self._dropout_mask is None
-        or context.executing_eagerly() and self._eager_dropout_mask is None):
-      # Generate new mask and cache it based on context.
-      dp_mask = _generate_dropout_mask(
-          array_ops.ones_like(inputs),
-          self.dropout,
-          training=training,
-          count=count)
-      if context.executing_eagerly():
-        self._eager_dropout_mask = dp_mask
-      else:
-        self._dropout_mask = dp_mask
-    else:
-      # Reuse the existing mask.
-      dp_mask = (self._eager_dropout_mask
-                 if context.executing_eagerly() else self._dropout_mask)
-    return dp_mask
+    init_kwargs = dict(inputs=inputs, training=training, count=count)
+    return self._dropout_mask_cache.setdefault(kwargs=init_kwargs)

   def get_recurrent_dropout_mask_for_cell(self, inputs, training, count=1):
     """Get the recurrent dropout mask for RNN cell.
@@ -1172,25 +1168,8 @@ class DropoutRNNCellMixin(object):
     """
     if self.recurrent_dropout == 0:
       return None
-    if (not context.executing_eagerly() and self._recurrent_dropout_mask is None
-        or context.executing_eagerly()
-        and self._eager_recurrent_dropout_mask is None):
-      # Generate new mask and cache it based on context.
-      rec_dp_mask = _generate_dropout_mask(
-          array_ops.ones_like(inputs),
-          self.recurrent_dropout,
-          training=training,
-          count=count)
-      if context.executing_eagerly():
-        self._eager_recurrent_dropout_mask = rec_dp_mask
-      else:
-        self._recurrent_dropout_mask = rec_dp_mask
-    else:
-      # Reuse the existing mask.
-      rec_dp_mask = (self._eager_recurrent_dropout_mask
-                     if context.executing_eagerly()
-                     else self._recurrent_dropout_mask)
-    return rec_dp_mask
+    init_kwargs = dict(inputs=inputs, training=training, count=count)
+    return self._recurrent_dropout_mask_cache.setdefault(kwargs=init_kwargs)


 @keras_export('keras.layers.SimpleRNNCell')
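The net effect for RNN cells: within one context, repeated `cell.call` invocations get the identical mask object (the cache's `setdefault` runs the factory once per key), and the `reset_*` methods now just clear the cache. A hedged sketch of the resulting behavior, assuming TF 2.x eager execution; `LSTMCell` is simply one cell that uses `DropoutRNNCellMixin`:

```python
# setdefault() runs the factory once per context, so the mask is reused.
import numpy as np
from tensorflow.python import keras

cell = keras.layers.LSTMCell(3, dropout=0.5)
x = np.ones((2, 4), dtype='float32')

m1 = cell.get_dropout_mask_for_cell(x, training=True)
m2 = cell.get_dropout_mask_for_cell(x, training=True)
assert m1 is m2                  # same cached tensor within the context

cell.reset_dropout_mask()        # clears the cache for every context
m3 = cell.get_dropout_mask_for_cell(x, training=True)
assert m3 is not m1              # a fresh mask after the reset
```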