Unify logic for caching values depending on the graph or eager context.

The ContextValueCache class:
- creates new cached tensors in different contexts
- returns the value at the current context or closest parent graph

This class is used by:
- all of the backend caches (learning phases, freezable variables, graph
  variables, and graph TF optimizers)
- the DropoutRNNCellMixin mask tensors

This resolves some bugs when saving RNN layers to SavedModel, but there are
still issues with saving, so I'll address the rest in a different CL.

PiperOrigin-RevId: 297662393
Change-Id: I5f93f4e8458e8b137a22787907e9f565414b3b8b
parent: b312ab962f
commit: 260a840659
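Before the file diffs, a minimal framework-free sketch of the caching pattern this CL introduces may help. It is illustrative only (`ContextCacheSketch`, `parents`, and `lookup` are invented names, not the TensorFlow API): values are keyed by a context, a lookup falls back from a child context to its parent, and a missing entry is built by a default factory.

```python
# Illustrative sketch only; not part of TensorFlow.
class ContextCacheSketch(dict):

  def __init__(self, default_factory):
    super(ContextCacheSketch, self).__init__()
    self.default_factory = default_factory
    self.parents = {}  # child context -> parent context

  def lookup(self, ctx):
    # Walk up the parent chain, mirroring how a FuncGraph may capture values
    # created in its outer graph or in the eager context.
    cur = ctx
    while cur is not None:
      if cur in self:
        return self[cur]
      cur = self.parents.get(cur)
    # Nothing cached anywhere on the chain: build a fresh value.
    value = self[ctx] = self.default_factory()
    return value


cache = ContextCacheSketch(list)
cache.parents['fn_graph'] = 'outer_graph'
cache.lookup('outer_graph').append('mask')
assert cache.lookup('fn_graph') == ['mask']  # resolved via the parent graph
```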
tensorflow/python/keras/backend.py

@@ -98,15 +98,6 @@ _CURRENT_SCRATCH_GRAPH = None
 # used by Keras. It can be set manually via `set_session(sess)`.
 _SESSION = threading.local()
 
-# This dictionary holds a mapping {graph: learning_phase}.
-# A learning phase is a bool tensor used to run Keras models in
-# either train mode (learning_phase == 1) or test mode (learning_phase == 0).
-_GRAPH_LEARNING_PHASES = weakref.WeakKeyDictionary()
-
-# This dictionary holds a mapping {graph: set_of_freezable_variables}.
-# Each set tracks objects created via `freezable_variable` in the graph.
-_FREEZABLE_VARS = weakref.WeakKeyDictionary()
-
 
 # _DUMMY_EAGER_GRAPH.key is used as a key in _GRAPH_LEARNING_PHASES.
 # We keep a separate reference to it to make sure it does not get removed from
@@ -154,14 +145,6 @@ _MANUAL_VAR_INIT = False
 # We assume our devices don't change henceforth.
 _LOCAL_DEVICES = None
 
-# This dictionary holds a mapping between a graph and variables to initialize
-# in the graph.
-_GRAPH_VARIABLES = weakref.WeakKeyDictionary()
-
-# This dictionary holds a mapping between a graph and TF optimizers created in
-# the graph.
-_GRAPH_TF_OPTIMIZERS = weakref.WeakKeyDictionary()
-
 # The below functions are kept accessible from backend for compatibility.
 epsilon = backend_config.epsilon
 floatx = backend_config.floatx
@@ -293,11 +276,9 @@ def clear_session():
   _SESSION.session = None
   graph = get_graph()
   with graph.as_default():
-    with name_scope(''):
-      phase = array_ops.placeholder_with_default(
-          False, shape=(), name='keras_learning_phase')
-    _GRAPH_LEARNING_PHASES = {}
-    _GRAPH_LEARNING_PHASES[graph] = phase
+    _GRAPH_LEARNING_PHASES.clear()
+    # Create the learning phase placeholder in graph using the default factory.
+    _GRAPH_LEARNING_PHASES.setdefault(graph)
     _GRAPH_VARIABLES.pop(graph, None)
     _GRAPH_TF_OPTIMIZERS.pop(graph, None)
     _FREEZABLE_VARS.pop(graph, None)
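The `clear()` plus `setdefault(graph)` pair above relies on the cache rebuilding missing entries from its default factory. A toy version of that reset-then-repopulate contract, assuming nothing beyond plain Python (`FactoryDict` is an invented name):

```python
class FactoryDict(dict):
  """Toy stand-in for the factory-backed cache; not the TF implementation."""

  def __init__(self, factory):
    super(FactoryDict, self).__init__()
    self.factory = factory

  def setdefault(self, key, default=None):
    if default is None and key not in self:
      default = self.factory()
    return dict.setdefault(self, key, default)


phases = FactoryDict(lambda: 'fresh_placeholder')
phases['g'] = 'stale_placeholder'
phases.clear()                 # clear_session drops every cached phase...
assert phases.setdefault('g') == 'fresh_placeholder'  # ...then rebuilds lazily
```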
@@ -336,24 +317,18 @@ def learning_phase():
     # Don't enter an init_scope for the learning phase if eager execution
     # is enabled but we're inside the Keras workspace graph.
     learning_phase = symbolic_learning_phase()
-    _mark_func_graph_as_unsaveable(graph, learning_phase)
-    return learning_phase
-  with ops.init_scope():
-    # We always check & set the learning phase inside the init_scope,
-    # otherwise the wrong default_graph will be used to look up the learning
-    # phase inside of functions & defuns.
-    #
-    # This is because functions & defuns (both in graph & in eager mode)
-    # will always execute non-eagerly using a function-specific default
-    # subgraph.
-    if context.executing_eagerly():
-      if _DUMMY_EAGER_GRAPH.key not in _GRAPH_LEARNING_PHASES:
-        # Fallback to inference mode as default.
-        return 0
-      return _GRAPH_LEARNING_PHASES[_DUMMY_EAGER_GRAPH.key]
-    learning_phase = symbolic_learning_phase()
-  _mark_func_graph_as_unsaveable(graph, learning_phase)
-  return learning_phase
+  else:
+    with ops.init_scope():
+      # We always check & set the learning phase inside the init_scope,
+      # otherwise the wrong default_graph will be used to look up the learning
+      # phase inside of functions & defuns.
+      #
+      # This is because functions & defuns (both in graph & in eager mode)
+      # will always execute non-eagerly using a function-specific default
+      # subgraph.
+      learning_phase = _GRAPH_LEARNING_PHASES[None]
+  _mark_func_graph_as_unsaveable(graph, learning_phase)
+  return learning_phase
 
 
 def global_learning_phase_is_set():
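Note that the eager fallback (returning 0 when no phase was ever set) no longer lives inline in `learning_phase()`; it moves into the default factory, so the single `_GRAPH_LEARNING_PHASES[None]` lookup covers both branches. A simplified before/after of that consolidation, using plain dicts and invented names:

```python
def old_lookup(phases, in_eager, eager_key, graph):
  # Pre-change shape: the inference-mode fallback is inlined at the call site.
  if in_eager:
    return phases.get(eager_key, 0)
  return phases[graph]  # raises KeyError if the phase was never created


def new_lookup(phases, key, default_factory):
  # Post-change shape: one lookup; the fallback lives in the factory.
  if key not in phases:
    phases[key] = default_factory()
  return phases[key]


assert old_lookup({}, True, 'eager', None) == 0
assert new_lookup({}, 'eager', lambda: 0) == 0  # same default, one code path
```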
@@ -382,14 +357,18 @@ def _mark_func_graph_as_unsaveable(graph, learning_phase):
 def symbolic_learning_phase():
   graph = get_graph()
   with graph.as_default():
-    if graph not in _GRAPH_LEARNING_PHASES:
-      with name_scope(''):
-        phase = array_ops.placeholder_with_default(
-            False, shape=(), name='keras_learning_phase')
-      _GRAPH_LEARNING_PHASES[graph] = phase
     return _GRAPH_LEARNING_PHASES[graph]
 
 
+def _default_learning_phase():
+  if context.executing_eagerly():
+    return 0
+  else:
+    with name_scope(''):
+      return array_ops.placeholder_with_default(
+          False, shape=(), name='keras_learning_phase')
+
+
 @keras_export('keras.backend.set_learning_phase')
 def set_learning_phase(value):
   """Sets the learning phase to a fixed value.
@@ -876,8 +855,7 @@ def track_tf_optimizer(tf_optimizer):
   """Tracks the given TF optimizer for initialization of its variables."""
   if context.executing_eagerly():
     return
-  graph = get_graph()
-  optimizers = _GRAPH_TF_OPTIMIZERS.setdefault(graph, weakref.WeakSet())
+  optimizers = _GRAPH_TF_OPTIMIZERS[None]
   optimizers.add(tf_optimizer)
 
 
@@ -886,8 +864,6 @@ def track_variable(v):
   if context.executing_eagerly():
     return
   graph = v.graph if hasattr(v, 'graph') else get_graph()
-  if graph not in _GRAPH_VARIABLES:
-    _GRAPH_VARIABLES[graph] = object_identity.ObjectIdentityWeakSet()
   _GRAPH_VARIABLES[graph].add(v)
 
 
@@ -943,8 +919,8 @@ def unique_object_name(name,
 def _get_variables(graph=None):
   """Returns variables corresponding to the given graph for initialization."""
   assert not context.executing_eagerly()
-  variables = _GRAPH_VARIABLES.setdefault(graph, weakref.WeakSet())
-  for opt in _GRAPH_TF_OPTIMIZERS.get(graph, set()):
+  variables = _GRAPH_VARIABLES[graph]
+  for opt in _GRAPH_TF_OPTIMIZERS[graph]:
     variables.update(opt.optimizer.variables())
   return variables
 
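`_get_variables` can now index the caches directly because `__getitem__` builds a default entry on first access, which is what makes the old `.setdefault(graph, weakref.WeakSet())` and `.get(graph, set())` guards redundant. `collections.defaultdict` offers the same flavor of guarantee:

```python
from collections import defaultdict

variables_by_graph = defaultdict(set)        # toy analogue of the new cache
variables_by_graph['graph_1'].add('v0')      # no setdefault guard needed
assert variables_by_graph['graph_2'] == set()  # auto-created, empty
```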
@@ -1174,8 +1150,6 @@ def freezable_variable(value, shape=None, name=None):
   x.get_value = get_value
 
   global _FREEZABLE_VARS
-  if graph not in _FREEZABLE_VARS:
-    _FREEZABLE_VARS[graph] = object_identity.ObjectIdentityWeakSet()
   _FREEZABLE_VARS[graph].add(x)
   return x
 
@@ -6094,3 +6068,132 @@ def maybe_convert_to_ragged(is_ragged_input, output, nested_row_lengths):
     return output
 
   return ragged_tensor.RaggedTensor.from_tensor(output, nested_row_lengths)
+
+
+class ContextValueCache(weakref.WeakKeyDictionary):
+  """Container that caches (possibly tensor) values based on the context.
+
+  This class is similar to defaultdict, where values may be produced by the
+  default factory specified during initialization. This class also has a default
+  value for the key (when key is `None`) -- the key is set to the current
+  graph or eager context. The default factories for key and value are only used
+  in `__getitem__` and `setdefault`. The `.get()` behavior remains the same.
+
+  This object will return the value of the current graph or closest parent graph
+  if the current graph is a function. This is to reflect the fact that if a
+  tensor is created in eager/graph, child functions may capture that tensor.
+
+  The default factory method may accept keyword arguments (unlike defaultdict,
+  which only accepts callables with 0 arguments). To pass keyword arguments to
+  `default_factory`, use the `setdefault` method instead of `__getitem__`.
+
+  An example of how this class can be used in different contexts:
+
+  ```
+  cache = ContextValueCache(int)
+
+  # Eager mode
+  cache[None] += 2
+  cache[None] += 4
+  assert cache[None] == 6
+
+  # Graph mode
+  with tf.Graph().as_default() as g:
+    cache[None] += 5
+    cache[g] += 3
+  assert cache[g] == 8
+  ```
+
+  Example of a default factory with arguments:
+
+  ```
+  cache = ContextValueCache(lambda x: x + 1)
+  g = tf.get_default_graph()
+
+  # Example with keyword argument.
+  value = cache.setdefault(key=g, kwargs={'x': 3})
+  assert cache[g] == 4
+  ```
+  """
+
+  def __init__(self, default_factory):
+    self.default_factory = default_factory
+    weakref.WeakKeyDictionary.__init__(self)
+
+  def _key(self):
+    if context.executing_eagerly():
+      return _DUMMY_EAGER_GRAPH.key
+    else:
+      return ops.get_default_graph()
+
+  def _get_parent_graph(self, graph):
+    """Returns the parent graph or dummy eager object."""
+    # TODO(b/149317164): Currently FuncGraphs use ops.get_default_graph() as the
+    # outer graph. This results in outer_graph always being a Graph,
+    # even in eager mode (get_default_graph will create a new Graph if there
+    # isn't a default graph). Because of this bug, we have to specially set the
+    # key when eager execution is enabled.
+    parent_graph = graph.outer_graph
+    if (not isinstance(parent_graph, func_graph.FuncGraph) and
+        ops.executing_eagerly_outside_functions()):
+      return _DUMMY_EAGER_GRAPH.key
+    return parent_graph
+
+  def _get_recursive(self, key):
+    """Gets the value at key or the closest parent graph."""
+    value = self.get(key)
+    if value is not None:
+      return value
+
+    # Since FuncGraphs are able to capture tensors and variables from their
+    # parent graphs, recursively search to see if there is a value stored for
+    # one of the parent graphs.
+    if isinstance(key, func_graph.FuncGraph):
+      return self._get_recursive(self._get_parent_graph(key))
+    return None
+
+  def __getitem__(self, key):
+    """Gets the value at key (or current context), or sets default value.
+
+    Args:
+      key: May be `None` or a `Graph` object. When `None`, the key is set to
+        the current context.
+
+    Returns:
+      Either the cached or default value.
+    """
+    if key is None:
+      key = self._key()
+
+    value = self._get_recursive(key)
+    if value is None:
+      value = self[key] = self.default_factory()  # pylint:disable=not-callable
+    return value
+
+  def setdefault(self, key=None, default=None, kwargs=None):
+    """Sets the default value if key is not in dict, and returns the value."""
+    if key is None:
+      key = self._key()
+    kwargs = kwargs or {}
+
+    if default is None and key not in self:
+      default = self.default_factory(**kwargs)
+    return weakref.WeakKeyDictionary.setdefault(self, key, default)
+
+
+# This dictionary holds a mapping {graph: learning_phase}. In eager mode, a
+# dummy object is used.
+# A learning phase is a bool tensor used to run Keras models in
+# either train mode (learning_phase == 1) or test mode (learning_phase == 0).
+_GRAPH_LEARNING_PHASES = ContextValueCache(_default_learning_phase)
+
+# This dictionary holds a mapping {graph: set_of_freezable_variables}.
+# Each set tracks objects created via `freezable_variable` in the graph.
+_FREEZABLE_VARS = ContextValueCache(object_identity.ObjectIdentityWeakSet)
+
+# This dictionary holds a mapping between a graph and variables to initialize
+# in the graph.
+_GRAPH_VARIABLES = ContextValueCache(object_identity.ObjectIdentityWeakSet)
+
+# This dictionary holds a mapping between a graph and TF optimizers created in
+# the graph.
+_GRAPH_TF_OPTIMIZERS = ContextValueCache(object_identity.ObjectIdentityWeakSet)
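One consequence of subclassing `weakref.WeakKeyDictionary`, shown standalone below with an invented `FakeGraph` class: an entry disappears once its key graph is garbage-collected, which is exactly what the `del graph1; gc.collect()` assertions in the test file rely on.

```python
import gc
import weakref


class FakeGraph(object):  # stand-in for a Graph; instances are weak-referenceable
  pass


cache = weakref.WeakKeyDictionary()
g = FakeGraph()
cache[g] = ['cached', 'values']
assert len(cache) == 1

del g
gc.collect()
assert len(cache) == 0  # the entry was dropped along with its graph
```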
tensorflow/python/keras/backend_test.py

@@ -17,6 +17,8 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import gc
+
 from absl.testing import parameterized
 import numpy as np
 import scipy.sparse
@@ -24,6 +26,7 @@ import scipy.sparse
 from tensorflow.core.protobuf import config_pb2
 from tensorflow.python import keras
 from tensorflow.python.eager import context
+from tensorflow.python.eager import def_function
 from tensorflow.python.framework import config
 from tensorflow.python.framework import errors_impl
 from tensorflow.python.framework import ops
@@ -2144,5 +2147,52 @@ class ControlOpsTests(test.TestCase):
     keras.backend.switch(keras.backend.equal(x, x), false_func, true_func)
 
 
+class ContextValueCacheTest(test.TestCase):
+
+  def test_cache(self):
+    cache = keras.backend.ContextValueCache(list)
+    graph1 = ops.Graph()
+    graph2 = ops.Graph()
+
+    cache[graph1].append(1)
+    with graph1.as_default():
+      cache[None].append(2)
+
+    with graph2.as_default():
+      cache[None].append(3)
+    cache[graph2].append(4)
+
+    self.assertAllEqual(cache[graph1], [1, 2])
+    self.assertAllEqual(cache[graph2], [3, 4])
+
+    with context.eager_mode():
+      cache[None].append(5)
+      cache[None].append(6)
+      self.assertAllEqual(cache[None], [5, 6])
+
+    self.assertLen(cache, 3)
+
+    del graph1
+    gc.collect()
+    self.assertLen(cache, 2)
+
+  def test_cache_in_parent_graph(self):
+    cache = keras.backend.ContextValueCache(int)
+    cache.setdefault(None, keras.backend.constant(5))
+
+    with ops.Graph().as_default() as g:
+      # g is not a child graph of the default test context, so the recursive
+      # lookup will create a new default value.
+      self.assertAllEqual(cache[g], 0)
+
+    @def_function.function
+    def fn():
+      # The function graph is a child of the default test context, so
+      # __getitem__ will return the previously saved value.
+      return cache[ops.get_default_graph()]
+
+    self.assertEqual(self.evaluate(fn()), 5)
+
+
 if __name__ == '__main__':
   test.main()
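A toy version of the parent-graph fallback that `test_cache_in_parent_graph` exercises, in plain Python with invented names: a child function context finds the value cached in its parent, while an unrelated graph finds nothing and would fall through to the default factory.

```python
def get_recursive(cache, key, parents):
  # Mirrors ContextValueCache._get_recursive: check the key itself, then
  # recurse into its parent context if one exists.
  if key in cache:
    return cache[key]
  parent = parents.get(key)
  return get_recursive(cache, parent, parents) if parent else None


cache = {'eager': 5}
parents = {'fn_graph': 'eager'}
assert get_recursive(cache, 'fn_graph', parents) == 5        # captured from parent
assert get_recursive(cache, 'other_graph', parents) is None  # no fallback found
```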
tensorflow/python/keras/layers/recurrent.py

@@ -1088,10 +1088,9 @@ class DropoutRNNCellMixin(object):
     # RNN could be created with `unroll=True`. In that case, the `cell.call()`
     # function will be invoked multiple times, and we want to ensure same mask
     # is used every time.
-    self._dropout_mask = None
-    self._recurrent_dropout_mask = None
-    self._eager_dropout_mask = None
-    self._eager_recurrent_dropout_mask = None
+    self._dropout_mask_cache = K.ContextValueCache(self._create_dropout_mask)
+    self._recurrent_dropout_mask_cache = K.ContextValueCache(
+        self._create_recurrent_dropout_mask)
     super(DropoutRNNCellMixin, self).__init__(*args, **kwargs)
 
   def reset_dropout_mask(self):
@@ -1103,8 +1102,7 @@ class DropoutRNNCellMixin(object):
     be cached between batches. Otherwise it will introduce unreasonable bias
     against certain index of data within the batch.
     """
-    self._dropout_mask = None
-    self._eager_dropout_mask = None
+    self._dropout_mask_cache.clear()
 
   def reset_recurrent_dropout_mask(self):
     """Reset the cached recurrent dropout masks if any.
@@ -1115,8 +1113,21 @@ class DropoutRNNCellMixin(object):
     be cached between batches. Otherwise it will introduce unreasonable bias
     against certain index of data within the batch.
     """
-    self._recurrent_dropout_mask = None
-    self._eager_recurrent_dropout_mask = None
+    self._recurrent_dropout_mask_cache.clear()
+
+  def _create_dropout_mask(self, inputs, training, count=1):
+    return _generate_dropout_mask(
+        array_ops.ones_like(inputs),
+        self.dropout,
+        training=training,
+        count=count)
+
+  def _create_recurrent_dropout_mask(self, inputs, training, count=1):
+    return _generate_dropout_mask(
+        array_ops.ones_like(inputs),
+        self.recurrent_dropout,
+        training=training,
+        count=count)
 
   def get_dropout_mask_for_cell(self, inputs, training, count=1):
     """Get the dropout mask for RNN cell's input.
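The factories above run only when the current context has no cached mask yet; `setdefault(kwargs=...)` forwards the call's arguments to the factory on that first miss, so repeated `cell.call()` invocations under `unroll=True` reuse one mask. A hedged sketch of that contract (`KwargsCache` and `make_mask` are invented names, not the Keras API):

```python
class KwargsCache(dict):
  """Toy factory-backed cache that forwards kwargs on the first miss."""

  def __init__(self, factory):
    super(KwargsCache, self).__init__()
    self.factory = factory

  def setdefault(self, key, kwargs=None):
    if key not in self:
      self[key] = self.factory(**(kwargs or {}))
    return self[key]


calls = []


def make_mask(inputs, training, count=1):
  calls.append((inputs, training, count))
  return 'mask-for-%s' % inputs


cache = KwargsCache(make_mask)
m1 = cache.setdefault('graph_1', {'inputs': 'x', 'training': True})
m2 = cache.setdefault('graph_1', {'inputs': 'x', 'training': True})
assert m1 is m2 and len(calls) == 1  # factory ran once; the mask was reused
```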
@@ -1136,23 +1147,8 @@ class DropoutRNNCellMixin(object):
     """
     if self.dropout == 0:
       return None
-    if (not context.executing_eagerly() and self._dropout_mask is None
-        or context.executing_eagerly() and self._eager_dropout_mask is None):
-      # Generate new mask and cache it based on context.
-      dp_mask = _generate_dropout_mask(
-          array_ops.ones_like(inputs),
-          self.dropout,
-          training=training,
-          count=count)
-      if context.executing_eagerly():
-        self._eager_dropout_mask = dp_mask
-      else:
-        self._dropout_mask = dp_mask
-    else:
-      # Reuse the existing mask.
-      dp_mask = (self._eager_dropout_mask
-                 if context.executing_eagerly() else self._dropout_mask)
-    return dp_mask
+    init_kwargs = dict(inputs=inputs, training=training, count=count)
+    return self._dropout_mask_cache.setdefault(kwargs=init_kwargs)
 
   def get_recurrent_dropout_mask_for_cell(self, inputs, training, count=1):
     """Get the recurrent dropout mask for RNN cell.
@@ -1172,25 +1168,8 @@ class DropoutRNNCellMixin(object):
     """
     if self.recurrent_dropout == 0:
       return None
-    if (not context.executing_eagerly() and self._recurrent_dropout_mask is None
-        or context.executing_eagerly()
-        and self._eager_recurrent_dropout_mask is None):
-      # Generate new mask and cache it based on context.
-      rec_dp_mask = _generate_dropout_mask(
-          array_ops.ones_like(inputs),
-          self.recurrent_dropout,
-          training=training,
-          count=count)
-      if context.executing_eagerly():
-        self._eager_recurrent_dropout_mask = rec_dp_mask
-      else:
-        self._recurrent_dropout_mask = rec_dp_mask
-    else:
-      # Reuse the existing mask.
-      rec_dp_mask = (self._eager_recurrent_dropout_mask
-                     if context.executing_eagerly()
-                     else self._recurrent_dropout_mask)
-    return rec_dp_mask
+    init_kwargs = dict(inputs=inputs, training=training, count=count)
+    return self._recurrent_dropout_mask_cache.setdefault(kwargs=init_kwargs)
 
 
 @keras_export('keras.layers.SimpleRNNCell')