Propagate `training` value from parent layer to child layer as the default in custom training loops.

The `training` argument will be passed to a Layer's `call` in this order of priority:

1) A value explicitly passed to `__call__`.
2) The value passed to the parent Layer (the layer that calls this Layer).
3) The currently active `learning_phase` value, if it has been set by the user or if using `fit`/`evaluate`/`predict`.
4) The default value in the Layer's `call`.

PiperOrigin-RevId: 252000169
parent 09d478868f
commit 89b1e717c0
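To make the priority order concrete, here is a minimal sketch of the resulting behavior, assuming TF 2.x-style eager execution and the `tf.keras.backend.learning_phase_scope` API touched by this change; `EchoTraining` and `Outer` are hypothetical helper layers that only record the `training` value the framework resolves:

import tensorflow as tf


class EchoTraining(tf.keras.layers.Layer):
  """Records whatever `training` value is resolved for `call`."""

  def call(self, inputs, training=None):
    self.last_training = training
    return inputs


class Outer(tf.keras.layers.Layer):
  """Calls a child layer without forwarding `training` explicitly."""

  def __init__(self):
    super(Outer, self).__init__()
    self.inner = EchoTraining()

  def call(self, inputs, training=None):
    return self.inner(inputs)


layer = EchoTraining()
outer = Outer()
x = tf.ones((1, 1))

layer(x, training=True)            # Priority 1: explicit `__call__` value.
assert layer.last_training is True

outer(x, training=True)            # Priority 2: the parent's value reaches the child.
assert outer.inner.last_training is True

with tf.keras.backend.learning_phase_scope(1):
  layer(x)                         # Priority 3: the active `learning_phase` (1).
assert layer.last_training == 1

layer(x)                           # Priority 4: the default in `call` (None).
assert layer.last_training is None

This mirrors the `test_training_arg_propagation`, `test_learning_phase_propagation`, and `test_training_arg_priorities` tests added in the diff below.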
@@ -293,6 +293,10 @@ def learning_phase():
     return symbolic_learning_phase()
 
 
+def global_learning_phase_is_set():
+  return _DUMMY_EAGER_GRAPH in _GRAPH_LEARNING_PHASES
+
+
 def symbolic_learning_phase():
   graph = get_graph()
   with graph.as_default():
@@ -325,18 +329,6 @@ def set_learning_phase(value):
     _GRAPH_LEARNING_PHASES[get_graph()] = value
 
 
-def set_eager_learning_phase(value):
-  """Internal utility that sets the learning phase in eager execution only.
-
-  Arguments:
-      value: Learning phase value, either 0 or 1 (integers).
-  """
-  global _GRAPH_LEARNING_PHASES  # pylint: disable=global-variable-not-assigned
-  assert value in {0, 1}
-  assert context.executing_eagerly()
-  _GRAPH_LEARNING_PHASES[_DUMMY_EAGER_GRAPH] = value
-
-
 @keras_export('keras.backend.learning_phase_scope')
 @tf_contextlib.contextmanager
 def learning_phase_scope(value):
@@ -381,9 +373,10 @@ def learning_phase_scope(value):
     elif graph in _GRAPH_LEARNING_PHASES:
       del _GRAPH_LEARNING_PHASES[graph]
 
 
 @tf_contextlib.contextmanager
 def eager_learning_phase_scope(value):
-  """Internal scope that sets the learning phase in eager execution only.
+  """Internal scope that sets the learning phase in eager / tf.function only.
 
   Arguments:
       value: Learning phase value, either 0 or 1 (integers).
@@ -397,13 +390,18 @@ def eager_learning_phase_scope(value):
   global _GRAPH_LEARNING_PHASES  # pylint: disable=global-variable-not-assigned
   assert value in {0, 1}
   assert ops.executing_eagerly_outside_functions()
-  previous_value = learning_phase()
+  global_learning_phase_was_set = global_learning_phase_is_set()
+  if global_learning_phase_was_set:
+    previous_value = learning_phase()
   try:
     _GRAPH_LEARNING_PHASES[_DUMMY_EAGER_GRAPH] = value
     yield
   finally:
-    # Restore learning phase to initial value.
-    _GRAPH_LEARNING_PHASES[_DUMMY_EAGER_GRAPH] = previous_value
+    # Restore learning phase to initial value or unset.
+    if global_learning_phase_was_set:
+      _GRAPH_LEARNING_PHASES[_DUMMY_EAGER_GRAPH] = previous_value
+    else:
+      del _GRAPH_LEARNING_PHASES[_DUMMY_EAGER_GRAPH]
 
 
 def _current_graph(op_input_list):
@@ -58,6 +58,16 @@ class LayerWithMetrics(keras.layers.Layer):
     return inputs
 
 
+class LayerWithTrainingArg(keras.layers.Layer):
+
+  def call(self, inputs, training=None):
+    self.training = training
+    if training:
+      return inputs
+    else:
+      return 0. * inputs
+
+
 def add_loss_step(defun):
   optimizer = keras.optimizer_v2.adam.Adam()
   model = testing_utils.get_model_from_layers([LayerWithLosses()],
@@ -142,6 +152,93 @@ class CustomTrainingLoopTest(keras_parameterized.TestCase):
     fn_result = train_step(defun=True)
     self.assertAllClose(eager_result, fn_result)
 
+  @parameterized.named_parameters(('eager', False), ('defun', True))
+  def test_training_arg_propagation(self, defun):
+
+    model = testing_utils.get_model_from_layers([LayerWithTrainingArg()],
+                                                input_shape=(1,))
+
+    def train_step(x):
+      return model(x), model(x, training=False), model(x, training=True)
+
+    if defun:
+      train_step = def_function.function(train_step)
+
+    x = array_ops.ones((1, 1))
+    results = train_step(x)
+    self.assertAllClose(results[0], array_ops.zeros((1, 1)))
+    self.assertAllClose(results[1], array_ops.zeros((1, 1)))
+    self.assertAllClose(results[2], array_ops.ones((1, 1)))
+
+  @parameterized.named_parameters(('eager', False), ('defun', True))
+  def test_learning_phase_propagation(self, defun):
+
+    class MyModel(keras.layers.Layer):
+
+      def __init__(self):
+        super(MyModel, self).__init__()
+        self.layer = LayerWithTrainingArg()
+
+      def call(self, inputs):
+        return self.layer(inputs)
+
+    model = MyModel()
+
+    def train_step(x):
+      no_learning_phase_out = model(x)
+      self.assertIsNone(model.layer.training)
+      with keras.backend.learning_phase_scope(0):
+        inf_learning_phase_out = model(x)
+      self.assertEqual(model.layer.training, 0)
+      with keras.backend.learning_phase_scope(1):
+        train_learning_phase_out = model(x)
+      self.assertEqual(model.layer.training, 1)
+      return [
+          no_learning_phase_out, inf_learning_phase_out,
+          train_learning_phase_out
+      ]
+
+    if defun:
+      train_step = def_function.function(train_step)
+
+    x = array_ops.ones((1, 1))
+    results = train_step(x)
+    self.assertAllClose(results[0], array_ops.zeros((1, 1)))
+    self.assertAllClose(results[1], array_ops.zeros((1, 1)))
+    self.assertAllClose(results[2], array_ops.ones((1, 1)))
+
+  @parameterized.named_parameters(('eager', False), ('defun', True))
+  def test_training_arg_priorities(self, defun):
+
+    class MyModel(keras.layers.Layer):
+
+      def __init__(self):
+        super(MyModel, self).__init__()
+        self.layer = LayerWithTrainingArg()
+
+      def call(self, inputs, training=False):
+        return self.layer(inputs)
+
+    model = MyModel()
+
+    def train_step(x):
+      explicit_out = model(x, training=True)
+      default_out = model(x)
+      with keras.backend.learning_phase_scope(1):
+        parent_out = model(x, training=False)
+        lr_out = model(x)
+      return [explicit_out, default_out, parent_out, lr_out]
+
+    if defun:
+      train_step = def_function.function(train_step)
+
+    x = array_ops.ones((1, 1))
+    results = train_step(x)
+    self.assertAllClose(results[0], array_ops.ones((1, 1)))
+    self.assertAllClose(results[1], array_ops.zeros((1, 1)))
+    self.assertAllClose(results[2], array_ops.zeros((1, 1)))
+    self.assertAllClose(results[3], array_ops.ones((1, 1)))
+
 
 if __name__ == '__main__':
   ops.enable_eager_execution()
@@ -212,8 +212,10 @@ class Layer(module.Module):
     self._outbound_nodes = []
 
     call_fn_args = self._call_fn_args
-    self._expects_training_arg = 'training' in call_fn_args
-    self._expects_mask_arg = 'mask' in call_fn_args
+    self._expects_training_arg = ('training' in call_fn_args or
+                                  self._call_accepts_kwargs)
+    self._expects_mask_arg = ('mask' in call_fn_args or
+                              self._call_accepts_kwargs)
 
     # Whether the `call` method can be used to build a TF graph without issues.
     self._dynamic = dynamic
@@ -563,8 +565,15 @@ class Layer(module.Module):
     Raises:
       ValueError: if the layer's `call` method returns None (an invalid value).
     """
+    call_context = base_layer_utils.call_context()
     input_list = nest.flatten(inputs)
 
+    # We will attempt to build a TF graph if & only if all inputs are symbolic.
+    # This is always the case in graph mode. It can also be the case in eager
+    # mode when all inputs can be traced back to `keras.Input()` (when building
+    # models using the functional API).
+    build_graph = tf_utils.are_all_symbolic_tensors(input_list)
+
     # Accept NumPy and scalar inputs by converting to Tensors.
     if any(isinstance(x, (np.ndarray, float, int)) for x in input_list):
       def _convert_non_tensor(x):
@@ -585,11 +594,31 @@ class Layer(module.Module):
         not self._call_arg_was_passed('mask', args, kwargs)):
       kwargs['mask'] = input_masks
 
-    # We will attempt to build a TF graph if & only if all inputs are symbolic.
-    # This is always the case in graph mode. It can also be the case in eager
-    # mode when all inputs can be traced back to `keras.Input()` (when building
-    # models using the functional API).
-    build_graph = tf_utils.are_all_symbolic_tensors(input_list)
+    # If `training` argument was not explicitly passed, propagate `training`
+    # value from this layer's calling layer.
+    training_arg_passed_by_framework = False
+    # Priority 1: `training` was explicitly passed.
+    if self._call_arg_was_passed('training', args, kwargs):
+      training_value = self._get_call_arg_value('training', args, kwargs)
+      if not self._expects_training_arg:
+        kwargs.pop('training')
+    else:
+      training_value = None
+      # Priority 2: `training` was passed to a parent layer.
+      if call_context.training is not None:
+        training_value = call_context.training
+      # Priority 3a: `learning_phase()` has been set.
+      elif backend.global_learning_phase_is_set():
+        training_value = backend.learning_phase()
+      # Priority 3b: Pass the `learning_phase()` if in the Keras FuncGraph.
+      elif build_graph:
+        with backend.get_graph().as_default():
+          if base_layer_utils.is_in_keras_graph():
+            training_value = backend.learning_phase()
+
+      if self._expects_training_arg and training_value is not None:
+        kwargs['training'] = training_value
+        training_arg_passed_by_framework = True
 
     # Only create Keras history if at least one tensor originates from a
     # `keras.Input`. Otherwise this Layer may be being used outside the Keras
@ -599,12 +628,12 @@ class Layer(module.Module):
|
|||||||
|
|
||||||
# Clear eager losses on top level model call.
|
# Clear eager losses on top level model call.
|
||||||
# We are clearing the losses only on the top level model call and not on
|
# We are clearing the losses only on the top level model call and not on
|
||||||
# every layer/mode call because layer/model may be reused.
|
# every layer/model call because layer/model may be reused.
|
||||||
if (base_layer_utils.is_in_eager_or_tf_function() and
|
if (base_layer_utils.is_in_eager_or_tf_function() and
|
||||||
not base_layer_utils.call_context().in_call):
|
not call_context.in_call):
|
||||||
self._clear_losses()
|
self._clear_losses()
|
||||||
|
|
||||||
with base_layer_utils.call_context().enter(self, inputs, build_graph):
|
with call_context.enter(self, inputs, build_graph, training_value):
|
||||||
# Check input assumptions set after layer building, e.g. input shape.
|
# Check input assumptions set after layer building, e.g. input shape.
|
||||||
if build_graph:
|
if build_graph:
|
||||||
# Symbolic execution on symbolic tensors. We will attempt to build
|
# Symbolic execution on symbolic tensors. We will attempt to build
|
||||||
@@ -636,21 +665,6 @@ class Layer(module.Module):
           else:
            call_fn = self.call
 
-          # Explicitly pass the learning phase placeholder to `call` if
-          # the `training` argument was left unspecified by the user.
-          # This behavior is restricted to the managed Keras FuncGraph.
-          # TODO(omalleyt): Reconcile this with new `trainable` behavior
-          # when available.
-          learning_phase_passed_by_framework = False
-          if (base_layer_utils.is_in_keras_graph() and
-              self._expects_training_arg):
-            training_arg = None
-            if self._call_arg_was_passed('training', args, kwargs):
-              training_arg = self._get_call_arg_value('training', args, kwargs)
-            if training_arg is None:
-              learning_phase_passed_by_framework = True
-              kwargs['training'] = backend.learning_phase()
-
           if not self.dynamic:
             try:
               with base_layer_utils.autocast_context_manager(
@@ -692,7 +706,7 @@ class Layer(module.Module):
                          'Tensor or a list of Tensors, not None '
                          '(layer: ' + self.name + ').')
       if base_layer_utils.have_all_keras_metadata(inputs):
-        if learning_phase_passed_by_framework:
+        if training_arg_passed_by_framework:
           kwargs.pop('training')
         inputs, outputs = self._set_connectivity_metadata_(
             inputs, outputs, args, kwargs)
@@ -2131,6 +2145,11 @@ class Layer(module.Module):
       return all_args[1:]
     return all_args
 
+  @property
+  @tracking.cached_per_instance
+  def _call_accepts_kwargs(self):
+    return tf_inspect.getfullargspec(self.call).varkw is not None
+
   @property
   @tracking.cached_per_instance
   def _should_compute_mask(self):
@@ -378,6 +378,7 @@ class CallContext(object):
     frozen: Whether currently executing inside a `Layer` with `trainable` set to
       `False`.
     in_call: Whether currently inside the `call` of a Layer.
+    training: Whether currently executing in training or inference mode.
     in_keras_graph: Whether executing inside the Keras Graph.
   """
 
@@ -386,25 +387,29 @@ class CallContext(object):
     self.inputs = None
     self.frozen = False
     self.in_call = False
+    self.training = None
     self._in_keras_graph = False
 
   @tf_contextlib.contextmanager
-  def enter(self, layer, inputs, build_graph):
+  def enter(self, layer, inputs, build_graph, training):
     """Push a Layer and its inputs and state onto the current call context."""
     prev_layer = self.layer
     prev_inputs = self.inputs
     prev_frozen = self.frozen
     prev_in_call = self.in_call
+    prev_training = self.training
     prev_in_keras_graph = self._in_keras_graph
 
     self.layer = layer
     self.inputs = inputs
     self.frozen = self.frozen or not layer.trainable
     self.in_call = True
+    self.training = training
     self._in_keras_graph = (
         self._in_keras_graph or
         (build_graph and
          getattr(backend.get_graph(), 'name', None) == 'keras_graph'))
 
     try:
       yield
     finally:
@@ -412,6 +417,7 @@ class CallContext(object):
       self.inputs = prev_inputs
       self.frozen = prev_frozen
       self.in_call = prev_in_call
+      self.training = prev_training
       self._in_keras_graph = prev_in_keras_graph
 
   @property
@@ -372,8 +372,10 @@ class Network(base_layer.Layer):
   def _init_subclassed_network(self, name=None, **kwargs):
     self._base_init(name=name, **kwargs)
     self._is_graph_network = False
-    self._expects_training_arg = 'training' in self._call_fn_args
-    self._expects_mask_arg = 'mask' in self._call_fn_args
+    self._expects_training_arg = ('training' in self._call_fn_args or
+                                  self._call_accepts_kwargs)
+    self._expects_mask_arg = ('mask' in self._call_fn_args or
+                              self._call_accepts_kwargs)
     call_argspec = tf_inspect.getfullargspec(self.call)
     self._call_convention = self._determine_call_convention(call_argspec)
     self.outputs = []
@@ -2554,7 +2554,17 @@ class Model(network.Network):
     inputs = self._set_input_attrs(inputs)
 
     if outputs is None:
-      kwargs = {'training': training} if self._expects_training_arg else {}
+      kwargs = {}
+      if self._expects_training_arg:
+        # In V2 mode, feeding `training=None` is not allowed because any value
+        # explicitly passed by the user is respected, even `None`, and in this
+        # case if the user has not passed a value in V2 we need to replace
+        # `None` with the `learning_phase()`. In V1, `training=None` is needed
+        # so that `Dropout` and `BatchNormalization` replace `None` values with
+        # the `learning_phase()` in their `call`.
+        if (training is not None or
+            not ops.executing_eagerly_outside_functions()):
+          kwargs['training'] = training
       try:
         outputs = self(inputs, **kwargs)
       except NotImplementedError:
@@ -188,8 +188,9 @@ def model_iteration(model,
 
   should_set_learning_phase = context.executing_eagerly() and model.run_eagerly
   if should_set_learning_phase:
-    old_learning_phase = backend.learning_phase()
-    backend.set_eager_learning_phase(1 if mode == ModeKeys.TRAIN else 0)
+    learning_phase_scope = backend.eager_learning_phase_scope(
+        1 if mode == ModeKeys.TRAIN else 0)
+    learning_phase_scope.__enter__()
 
   callbacks.model.stop_training = False
   callbacks._call_begin_hook(mode)
@@ -341,7 +342,7 @@ def model_iteration(model,
       enqueuer.stop()
 
   if should_set_learning_phase:
-    backend.set_eager_learning_phase(old_learning_phase)
+    learning_phase_scope.__exit__(None, None, None)
 
   if mode == ModeKeys.TRAIN:
     return model.history
@@ -610,7 +610,7 @@ class RNNTest(keras_parameterized.TestCase):
       update_2 = state_ops.assign_add(cells[0].kernel,
                                       array_ops.ones_like(cells[0].kernel))
       # TODO(b/128682878): Remove when RNNCells are __call__'d.
-      with base_layer_utils.call_context().enter(layer, x, True):
+      with base_layer_utils.call_context().enter(layer, x, True, None):
        cells[0].add_update(update_1, inputs=x)
        cells[0].add_update(update_2)
      self.assertEqual(len(layer.updates), 2)
@@ -611,7 +611,7 @@ class BidirectionalTest(test.TestCase, parameterized.TestCase):
     assert not layer.get_updates_for(None)
     assert not layer.get_updates_for(x)
     # TODO(b/128684069): Remove when Wrapper sublayers are __call__'d.
-    with base_layer_utils.call_context().enter(layer, x, True):
+    with base_layer_utils.call_context().enter(layer, x, True, None):
       layer.forward_layer.add_update(x_reachable_update, inputs=x)
       layer.forward_layer.add_update(1, inputs=None)
       layer.backward_layer.add_update(x_reachable_update, inputs=x)
@@ -843,7 +843,7 @@ def _wrap_layer_functions(layer, serialization_cache,
   # Manually trigger traces before restoring the overwritten functions. The
   # functions are traced within the layer call context to ensure that layer
   # functions (e.g. add_loss) behave as though running in graph mode.
-  with base_layer_utils.call_context().enter(layer, None, build_graph=True):
+  with base_layer_utils.call_context().enter(layer, None, True, None):
     for fn in fns.values():
       if fn is not None and fn.input_signature is not None:
         fn.get_concrete_function()