Propagate training value from parent layer to child layer as the default in custom training loops.

The `training` argument will be passed to a Layer's `call` in this order of priority (a short sketch illustrating the behavior follows the list):

1) A value explicitly passed to `__call__`.
2) The value passed to the parent Layer (the layer that calls this Layer).
3) The currently active `learning_phase` value, if it has been set by the user
or if using `fit`/`evaluate`/`predict`.
4) The default value in the Layer's `call`.
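
A minimal sketch of the resulting behavior in a custom training loop, assuming TF 2.x eager execution; `EchoTraining` and `Outer` are toy layers made up for illustration and are not part of this change:

import tensorflow as tf

class EchoTraining(tf.keras.layers.Layer):
  """Toy layer: returns the inputs in training mode, zeros otherwise."""

  def call(self, inputs, training=None):
    if training:
      return inputs
    return 0. * inputs

class Outer(tf.keras.layers.Layer):

  def __init__(self):
    super(Outer, self).__init__()
    self.inner = EchoTraining()

  def call(self, inputs, training=False):
    # No explicit `training` is passed here; the value resolved for `Outer`
    # reaches `self.inner` through the call context. Note that the `False`
    # default in this signature does not propagate on its own.
    return self.inner(inputs)

model = Outer()
x = tf.ones((1, 1))
model(x, training=True)     # Priority 1: explicit value, inner runs in training mode.
model(x)                    # Nothing passed or set, inner falls back to its own default (None).
with tf.keras.backend.learning_phase_scope(1):
  model(x, training=False)  # Priority 2 beats 3: the value passed to the parent wins.
  model(x)                  # Priority 3: the active learning_phase (1) is used.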

PiperOrigin-RevId: 252000169
Author: Thomas O'Malley (2019-06-07 00:07:50 -07:00), committed by TensorFlower Gardener
Parent: 09d478868f
Commit: 89b1e717c0
10 changed files with 185 additions and 52 deletions


@@ -293,6 +293,10 @@ def learning_phase():
return symbolic_learning_phase()
def global_learning_phase_is_set():
return _DUMMY_EAGER_GRAPH in _GRAPH_LEARNING_PHASES
def symbolic_learning_phase():
graph = get_graph()
with graph.as_default():
@@ -325,18 +329,6 @@ def set_learning_phase(value):
_GRAPH_LEARNING_PHASES[get_graph()] = value
def set_eager_learning_phase(value):
"""Internal utility that sets the learning phase in eager execution only.
Arguments:
value: Learning phase value, either 0 or 1 (integers).
"""
global _GRAPH_LEARNING_PHASES # pylint: disable=global-variable-not-assigned
assert value in {0, 1}
assert context.executing_eagerly()
_GRAPH_LEARNING_PHASES[_DUMMY_EAGER_GRAPH] = value
@keras_export('keras.backend.learning_phase_scope')
@tf_contextlib.contextmanager
def learning_phase_scope(value):
@@ -381,9 +373,10 @@ def learning_phase_scope(value):
elif graph in _GRAPH_LEARNING_PHASES:
del _GRAPH_LEARNING_PHASES[graph]
@tf_contextlib.contextmanager
def eager_learning_phase_scope(value):
"""Internal scope that sets the learning phase in eager execution only.
"""Internal scope that sets the learning phase in eager / tf.function only.
Arguments:
value: Learning phase value, either 0 or 1 (integers).
@@ -397,13 +390,18 @@ def eager_learning_phase_scope(value):
global _GRAPH_LEARNING_PHASES # pylint: disable=global-variable-not-assigned
assert value in {0, 1}
assert ops.executing_eagerly_outside_functions()
previous_value = learning_phase()
global_learning_phase_was_set = global_learning_phase_is_set()
if global_learning_phase_was_set:
previous_value = learning_phase()
try:
_GRAPH_LEARNING_PHASES[_DUMMY_EAGER_GRAPH] = value
yield
finally:
# Restore learning phase to initial value.
_GRAPH_LEARNING_PHASES[_DUMMY_EAGER_GRAPH] = previous_value
# Restore learning phase to initial value or unset.
if global_learning_phase_was_set:
_GRAPH_LEARNING_PHASES[_DUMMY_EAGER_GRAPH] = previous_value
else:
del _GRAPH_LEARNING_PHASES[_DUMMY_EAGER_GRAPH]
def _current_graph(op_input_list):
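
With this change the scope leaves the global learning phase unset on exit when nothing had been set beforehand, instead of pinning it to a stale default. A rough sketch of the intended behavior, assuming eager execution; these helpers live in the private `tensorflow.python.keras.backend` module, so the import path is internal:

from tensorflow.python.keras import backend

assert not backend.global_learning_phase_is_set()
with backend.eager_learning_phase_scope(1):
  # Inside the scope the phase is globally visible to layers such as Dropout.
  assert backend.learning_phase() == 1
# Previously the phase stayed globally set to a "previous" value here even
# though the user never set one; now it is simply unset again.
assert not backend.global_learning_phase_is_set()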


@@ -58,6 +58,16 @@ class LayerWithMetrics(keras.layers.Layer):
return inputs
class LayerWithTrainingArg(keras.layers.Layer):
def call(self, inputs, training=None):
self.training = training
if training:
return inputs
else:
return 0. * inputs
def add_loss_step(defun):
optimizer = keras.optimizer_v2.adam.Adam()
model = testing_utils.get_model_from_layers([LayerWithLosses()],
@@ -142,6 +152,93 @@ class CustomTrainingLoopTest(keras_parameterized.TestCase):
fn_result = train_step(defun=True)
self.assertAllClose(eager_result, fn_result)
@parameterized.named_parameters(('eager', False), ('defun', True))
def test_training_arg_propagation(self, defun):
model = testing_utils.get_model_from_layers([LayerWithTrainingArg()],
input_shape=(1,))
def train_step(x):
return model(x), model(x, training=False), model(x, training=True)
if defun:
train_step = def_function.function(train_step)
x = array_ops.ones((1, 1))
results = train_step(x)
self.assertAllClose(results[0], array_ops.zeros((1, 1)))
self.assertAllClose(results[1], array_ops.zeros((1, 1)))
self.assertAllClose(results[2], array_ops.ones((1, 1)))
@parameterized.named_parameters(('eager', False), ('defun', True))
def test_learning_phase_propagation(self, defun):
class MyModel(keras.layers.Layer):
def __init__(self):
super(MyModel, self).__init__()
self.layer = LayerWithTrainingArg()
def call(self, inputs):
return self.layer(inputs)
model = MyModel()
def train_step(x):
no_learning_phase_out = model(x)
self.assertIsNone(model.layer.training)
with keras.backend.learning_phase_scope(0):
inf_learning_phase_out = model(x)
self.assertEqual(model.layer.training, 0)
with keras.backend.learning_phase_scope(1):
train_learning_phase_out = model(x)
self.assertEqual(model.layer.training, 1)
return [
no_learning_phase_out, inf_learning_phase_out,
train_learning_phase_out
]
if defun:
train_step = def_function.function(train_step)
x = array_ops.ones((1, 1))
results = train_step(x)
self.assertAllClose(results[0], array_ops.zeros((1, 1)))
self.assertAllClose(results[1], array_ops.zeros((1, 1)))
self.assertAllClose(results[2], array_ops.ones((1, 1)))
@parameterized.named_parameters(('eager', False), ('defun', True))
def test_training_arg_priorities(self, defun):
class MyModel(keras.layers.Layer):
def __init__(self):
super(MyModel, self).__init__()
self.layer = LayerWithTrainingArg()
def call(self, inputs, training=False):
return self.layer(inputs)
model = MyModel()
def train_step(x):
explicit_out = model(x, training=True)
default_out = model(x)
with keras.backend.learning_phase_scope(1):
parent_out = model(x, training=False)
lr_out = model(x)
return [explicit_out, default_out, parent_out, lr_out]
if defun:
train_step = def_function.function(train_step)
x = array_ops.ones((1, 1))
results = train_step(x)
self.assertAllClose(results[0], array_ops.ones((1, 1)))
self.assertAllClose(results[1], array_ops.zeros((1, 1)))
self.assertAllClose(results[2], array_ops.zeros((1, 1)))
self.assertAllClose(results[3], array_ops.ones((1, 1)))
if __name__ == '__main__':
ops.enable_eager_execution()


@@ -212,8 +212,10 @@ class Layer(module.Module):
self._outbound_nodes = []
call_fn_args = self._call_fn_args
self._expects_training_arg = 'training' in call_fn_args
self._expects_mask_arg = 'mask' in call_fn_args
self._expects_training_arg = ('training' in call_fn_args or
self._call_accepts_kwargs)
self._expects_mask_arg = ('mask' in call_fn_args or
self._call_accepts_kwargs)
# Whether the `call` method can be used to build a TF graph without issues.
self._dynamic = dynamic
@@ -563,8 +565,15 @@ class Layer(module.Module):
Raises:
ValueError: if the layer's `call` method returns None (an invalid value).
"""
call_context = base_layer_utils.call_context()
input_list = nest.flatten(inputs)
# We will attempt to build a TF graph if & only if all inputs are symbolic.
# This is always the case in graph mode. It can also be the case in eager
# mode when all inputs can be traced back to `keras.Input()` (when building
# models using the functional API).
build_graph = tf_utils.are_all_symbolic_tensors(input_list)
# Accept NumPy and scalar inputs by converting to Tensors.
if any(isinstance(x, (np.ndarray, float, int)) for x in input_list):
def _convert_non_tensor(x):
@@ -585,11 +594,31 @@
not self._call_arg_was_passed('mask', args, kwargs)):
kwargs['mask'] = input_masks
# We will attempt to build a TF graph if & only if all inputs are symbolic.
# This is always the case in graph mode. It can also be the case in eager
# mode when all inputs can be traced back to `keras.Input()` (when building
# models using the functional API).
build_graph = tf_utils.are_all_symbolic_tensors(input_list)
# If `training` argument was not explicitly passed, propagate `training`
# value from this layer's calling layer.
training_arg_passed_by_framework = False
# Priority 1: `training` was explicitly passed.
if self._call_arg_was_passed('training', args, kwargs):
training_value = self._get_call_arg_value('training', args, kwargs)
if not self._expects_training_arg:
kwargs.pop('training')
else:
training_value = None
# Priority 2: `training` was passed to a parent layer.
if call_context.training is not None:
training_value = call_context.training
# Priority 3a: `learning_phase()` has been set.
elif backend.global_learning_phase_is_set():
training_value = backend.learning_phase()
# Priority 3b: Pass the `learning_phase()` if in the Keras FuncGraph.
elif build_graph:
with backend.get_graph().as_default():
if base_layer_utils.is_in_keras_graph():
training_value = backend.learning_phase()
if self._expects_training_arg and training_value is not None:
kwargs['training'] = training_value
training_arg_passed_by_framework = True
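
The cascade above reads as a small pure function; a simplified sketch of the same priority order (the helper and its argument names are illustrative only and do not exist in the codebase):

def resolve_training_value(explicitly_passed, explicit_value, parent_training,
                           learning_phase_is_set, learning_phase_value,
                           in_keras_graph):
  """Mirrors the priority order used in Layer.__call__ (illustration only)."""
  if explicitly_passed:            # Priority 1: respected even if it is None.
    return explicit_value
  if parent_training is not None:  # Priority 2: value from the calling layer.
    return parent_training
  if learning_phase_is_set:        # Priority 3a: user-set learning_phase().
    return learning_phase_value
  if in_keras_graph:               # Priority 3b: learning_phase placeholder in
    return learning_phase_value    # the managed Keras FuncGraph.
  return None                      # Priority 4: let `call`'s own default apply.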
# Only create Keras history if at least one tensor originates from a
# `keras.Input`. Otherwise this Layer may be being used outside the Keras
@@ -599,12 +628,12 @@ class Layer(module.Module):
# Clear eager losses on top level model call.
# We are clearing the losses only on the top level model call and not on
# every layer/mode call because layer/model may be reused.
# every layer/model call because layer/model may be reused.
if (base_layer_utils.is_in_eager_or_tf_function() and
not base_layer_utils.call_context().in_call):
not call_context.in_call):
self._clear_losses()
with base_layer_utils.call_context().enter(self, inputs, build_graph):
with call_context.enter(self, inputs, build_graph, training_value):
# Check input assumptions set after layer building, e.g. input shape.
if build_graph:
# Symbolic execution on symbolic tensors. We will attempt to build
@@ -636,21 +665,6 @@
else:
call_fn = self.call
# Explicitly pass the learning phase placeholder to `call` if
# the `training` argument was left unspecified by the user.
# This behavior is restricted to the managed Keras FuncGraph.
# TODO(omalleyt): Reconcile this with new `trainable` behavior
# when available.
learning_phase_passed_by_framework = False
if (base_layer_utils.is_in_keras_graph() and
self._expects_training_arg):
training_arg = None
if self._call_arg_was_passed('training', args, kwargs):
training_arg = self._get_call_arg_value('training', args, kwargs)
if training_arg is None:
learning_phase_passed_by_framework = True
kwargs['training'] = backend.learning_phase()
if not self.dynamic:
try:
with base_layer_utils.autocast_context_manager(
@@ -692,7 +706,7 @@ class Layer(module.Module):
'Tensor or a list of Tensors, not None '
'(layer: ' + self.name + ').')
if base_layer_utils.have_all_keras_metadata(inputs):
if learning_phase_passed_by_framework:
if training_arg_passed_by_framework:
kwargs.pop('training')
inputs, outputs = self._set_connectivity_metadata_(
inputs, outputs, args, kwargs)
@@ -2131,6 +2145,11 @@ class Layer(module.Module):
return all_args[1:]
return all_args
@property
@tracking.cached_per_instance
def _call_accepts_kwargs(self):
return tf_inspect.getfullargspec(self.call).varkw is not None
@property
@tracking.cached_per_instance
def _should_compute_mask(self):
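
`_call_accepts_kwargs` hinges on `inspect` reporting a `**kwargs` catch-all; a quick standalone illustration (plain `inspect` here, TensorFlow goes through its `tf_inspect` wrapper):

import inspect

def call_with_kwargs(self, inputs, **kwargs):
  return inputs

def call_without_kwargs(self, inputs, training=None):
  return inputs

# A non-None `varkw` means the signature has a `**kwargs` catch-all, so such a
# layer is now treated as if it expects the `training` and `mask` arguments.
print(inspect.getfullargspec(call_with_kwargs).varkw)     # 'kwargs'
print(inspect.getfullargspec(call_without_kwargs).varkw)  # None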


@@ -378,6 +378,7 @@ class CallContext(object):
frozen: Whether currently executing inside a `Layer` with `trainable` set to
`False`.
in_call: Whether currently inside the `call` of a Layer.
training: Whether currently executing in training or inference mode.
in_keras_graph: Whether executing inside the Keras Graph.
"""
@@ -386,25 +387,29 @@ class CallContext(object):
self.inputs = None
self.frozen = False
self.in_call = False
self.training = None
self._in_keras_graph = False
@tf_contextlib.contextmanager
def enter(self, layer, inputs, build_graph):
def enter(self, layer, inputs, build_graph, training):
"""Push a Layer and its inputs and state onto the current call context."""
prev_layer = self.layer
prev_inputs = self.inputs
prev_frozen = self.frozen
prev_in_call = self.in_call
prev_training = self.training
prev_in_keras_graph = self._in_keras_graph
self.layer = layer
self.inputs = inputs
self.frozen = self.frozen or not layer.trainable
self.in_call = True
self.training = training
self._in_keras_graph = (
self._in_keras_graph or
(build_graph and
getattr(backend.get_graph(), 'name', None) == 'keras_graph'))
try:
yield
finally:
@@ -412,6 +417,7 @@ class CallContext(object):
self.inputs = prev_inputs
self.frozen = prev_frozen
self.in_call = prev_in_call
self.training = prev_training
self._in_keras_graph = prev_in_keras_graph
@property
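
`enter` follows a push/restore discipline so that nested layer calls always see their caller's state and that state is rolled back afterwards; a stripped-down sketch of the pattern (`MiniCallContext` is illustrative, not the real class):

import contextlib

class MiniCallContext(object):
  """Illustrative stand-in for the thread-local CallContext."""

  def __init__(self):
    self.training = None

  @contextlib.contextmanager
  def enter(self, training):
    previous = self.training
    self.training = training
    try:
      yield
    finally:
      self.training = previous  # Always restore the caller's value.

ctx = MiniCallContext()
with ctx.enter(True):
  # A child layer called here would read ctx.training == True (Priority 2)
  # and push the value it resolved for its own nested calls.
  with ctx.enter(True):
    pass
print(ctx.training)  # None: state is restored once the outer call exits.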


@@ -372,8 +372,10 @@ class Network(base_layer.Layer):
def _init_subclassed_network(self, name=None, **kwargs):
self._base_init(name=name, **kwargs)
self._is_graph_network = False
self._expects_training_arg = 'training' in self._call_fn_args
self._expects_mask_arg = 'mask' in self._call_fn_args
self._expects_training_arg = ('training' in self._call_fn_args or
self._call_accepts_kwargs)
self._expects_mask_arg = ('mask' in self._call_fn_args or
self._call_accepts_kwargs)
call_argspec = tf_inspect.getfullargspec(self.call)
self._call_convention = self._determine_call_convention(call_argspec)
self.outputs = []


@@ -2554,7 +2554,17 @@ class Model(network.Network):
inputs = self._set_input_attrs(inputs)
if outputs is None:
kwargs = {'training': training} if self._expects_training_arg else {}
kwargs = {}
if self._expects_training_arg:
# In V2 mode, feeding `training=None` is not allowed because any value
# explicitly passed by the user is respected, even `None`, and in this
# case if the user has not passed a value in V2 we need to replace
# `None` with the `learning_phase()`. In V1, `training=None` is needed
# so that `Dropout` and `BatchNormalization` replace `None` values with
# the `learning_phase()` in their `call`.
if (training is not None or
not ops.executing_eagerly_outside_functions()):
kwargs['training'] = training
try:
outputs = self(inputs, **kwargs)
except NotImplementedError:
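
Condensed, the decision above looks roughly like this (the helper name is made up for illustration; the point is that V2 omits the key entirely rather than feeding `None`):

def _training_kwargs(training, expects_training_arg, v2_behavior):
  """Decide whether to feed `training` when tracing the Model's outputs."""
  if not expects_training_arg:
    return {}
  if training is None and v2_behavior:
    # In V2 an explicit `None` would be respected as-is and would suppress
    # the learning_phase() fallback, so do not pass the argument at all.
    return {}
  # In V1, Dropout/BatchNormalization replace `None` with learning_phase().
  return {'training': training}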


@@ -188,8 +188,9 @@ def model_iteration(model,
should_set_learning_phase = context.executing_eagerly() and model.run_eagerly
if should_set_learning_phase:
old_learning_phase = backend.learning_phase()
backend.set_eager_learning_phase(1 if mode == ModeKeys.TRAIN else 0)
learning_phase_scope = backend.eager_learning_phase_scope(
1 if mode == ModeKeys.TRAIN else 0)
learning_phase_scope.__enter__()
callbacks.model.stop_training = False
callbacks._call_begin_hook(mode)
@@ -341,7 +342,7 @@ def model_iteration(model,
enqueuer.stop()
if should_set_learning_phase:
backend.set_eager_learning_phase(old_learning_phase)
learning_phase_scope.__exit__(None, None, None)
if mode == ModeKeys.TRAIN:
return model.history
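
The scope is entered and exited by hand because the setup and cleanup points sit at opposite ends of `model_iteration`; the generic Python pattern, shown in isolation with a dummy scope:

import contextlib

@contextlib.contextmanager
def phase_scope(value):
  print('learning phase set to', value)
  try:
    yield
  finally:
    print('learning phase restored')

scope = phase_scope(1)
scope.__enter__()                  # setup near the top of the loop function
# ... run the training / evaluation loop ...
scope.__exit__(None, None, None)   # cleanup once the loop has finished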


@@ -610,7 +610,7 @@ class RNNTest(keras_parameterized.TestCase):
update_2 = state_ops.assign_add(cells[0].kernel,
array_ops.ones_like(cells[0].kernel))
# TODO(b/128682878): Remove when RNNCells are __call__'d.
with base_layer_utils.call_context().enter(layer, x, True):
with base_layer_utils.call_context().enter(layer, x, True, None):
cells[0].add_update(update_1, inputs=x)
cells[0].add_update(update_2)
self.assertEqual(len(layer.updates), 2)


@@ -611,7 +611,7 @@ class BidirectionalTest(test.TestCase, parameterized.TestCase):
assert not layer.get_updates_for(None)
assert not layer.get_updates_for(x)
# TODO(b/128684069): Remove when Wrapper sublayers are __call__'d.
with base_layer_utils.call_context().enter(layer, x, True):
with base_layer_utils.call_context().enter(layer, x, True, None):
layer.forward_layer.add_update(x_reachable_update, inputs=x)
layer.forward_layer.add_update(1, inputs=None)
layer.backward_layer.add_update(x_reachable_update, inputs=x)


@@ -843,7 +843,7 @@ def _wrap_layer_functions(layer, serialization_cache,
# Manually trigger traces before restoring the overwritten functions. The
# functions are traced within the layer call context to ensure that layer
# functions (e.g. add_loss) behave as though running in graph mode.
with base_layer_utils.call_context().enter(layer, None, build_graph=True):
with base_layer_utils.call_context().enter(layer, None, True, None):
for fn in fns.values():
if fn is not None and fn.input_signature is not None:
fn.get_concrete_function()