From 89b1e717c04d9196ac5e53312aa8f57c030eae6a Mon Sep 17 00:00:00 2001
From: Thomas O'Malley
Date: Fri, 7 Jun 2019 00:07:50 -0700
Subject: [PATCH] Propagate `training` value from parent layer to child layer
 as the default in custom training loops.

The `training` argument will be passed to a Layer's `call` in this order
of priority:

1) A value explicitly passed to `__call__`.
2) The value passed to the parent Layer (the layer that calls this Layer).
3) The currently active `learning_phase` value, if it has been set by the
   user or if using `fit`/`evaluate`/`predict`.
4) The default value in the Layer's `call`.

PiperOrigin-RevId: 252000169
---
 tensorflow/python/keras/backend.py            | 30 +++---
 .../python/keras/custom_training_loop_test.py | 97 +++++++++++++++++++
 tensorflow/python/keras/engine/base_layer.py  | 71 +++++++++----
 .../python/keras/engine/base_layer_utils.py   |  8 +-
 tensorflow/python/keras/engine/network.py     |  6 +-
 tensorflow/python/keras/engine/training.py    | 12 ++-
 .../python/keras/engine/training_generator.py |  7 +-
 .../python/keras/layers/recurrent_test.py     |  2 +-
 .../python/keras/layers/wrappers_test.py      |  2 +-
 tensorflow/python/keras/saving/saved_model.py |  2 +-
 10 files changed, 185 insertions(+), 52 deletions(-)

diff --git a/tensorflow/python/keras/backend.py b/tensorflow/python/keras/backend.py
index 04fec4c84cc..8f407809479 100644
--- a/tensorflow/python/keras/backend.py
+++ b/tensorflow/python/keras/backend.py
@@ -293,6 +293,10 @@ def learning_phase():
   return symbolic_learning_phase()
 
 
+def global_learning_phase_is_set():
+  return _DUMMY_EAGER_GRAPH in _GRAPH_LEARNING_PHASES
+
+
 def symbolic_learning_phase():
   graph = get_graph()
   with graph.as_default():
@@ -325,18 +329,6 @@ def set_learning_phase(value):
     _GRAPH_LEARNING_PHASES[get_graph()] = value
 
 
-def set_eager_learning_phase(value):
-  """Internal utility that sets the learning phase in eager execution only.
-
-  Arguments:
-      value: Learning phase value, either 0 or 1 (integers).
-  """
-  global _GRAPH_LEARNING_PHASES  # pylint: disable=global-variable-not-assigned
-  assert value in {0, 1}
-  assert context.executing_eagerly()
-  _GRAPH_LEARNING_PHASES[_DUMMY_EAGER_GRAPH] = value
-
-
 @keras_export('keras.backend.learning_phase_scope')
 @tf_contextlib.contextmanager
 def learning_phase_scope(value):
@@ -381,9 +373,10 @@ def learning_phase_scope(value):
     elif graph in _GRAPH_LEARNING_PHASES:
       del _GRAPH_LEARNING_PHASES[graph]
 
+
 @tf_contextlib.contextmanager
 def eager_learning_phase_scope(value):
-  """Internal scope that sets the learning phase in eager execution only.
+  """Internal scope that sets the learning phase in eager / tf.function only.
 
   Arguments:
       value: Learning phase value, either 0 or 1 (integers).
@@ -397,13 +390,18 @@ def eager_learning_phase_scope(value):
   global _GRAPH_LEARNING_PHASES  # pylint: disable=global-variable-not-assigned
   assert value in {0, 1}
   assert ops.executing_eagerly_outside_functions()
-  previous_value = learning_phase()
+  global_learning_phase_was_set = global_learning_phase_is_set()
+  if global_learning_phase_was_set:
+    previous_value = learning_phase()
   try:
     _GRAPH_LEARNING_PHASES[_DUMMY_EAGER_GRAPH] = value
     yield
   finally:
-    # Restore learning phase to initial value.
-    _GRAPH_LEARNING_PHASES[_DUMMY_EAGER_GRAPH] = previous_value
+    # Restore learning phase to initial value or unset.
+    if global_learning_phase_was_set:
+      _GRAPH_LEARNING_PHASES[_DUMMY_EAGER_GRAPH] = previous_value
+    else:
+      del _GRAPH_LEARNING_PHASES[_DUMMY_EAGER_GRAPH]
 
 
 def _current_graph(op_input_list):
diff --git a/tensorflow/python/keras/custom_training_loop_test.py b/tensorflow/python/keras/custom_training_loop_test.py
index d2b82c8a55f..5b3310b2b40 100644
--- a/tensorflow/python/keras/custom_training_loop_test.py
+++ b/tensorflow/python/keras/custom_training_loop_test.py
@@ -58,6 +58,16 @@ class LayerWithMetrics(keras.layers.Layer):
     return inputs
 
 
+class LayerWithTrainingArg(keras.layers.Layer):
+
+  def call(self, inputs, training=None):
+    self.training = training
+    if training:
+      return inputs
+    else:
+      return 0. * inputs
+
+
 def add_loss_step(defun):
   optimizer = keras.optimizer_v2.adam.Adam()
   model = testing_utils.get_model_from_layers([LayerWithLosses()],
@@ -142,6 +152,93 @@ class CustomTrainingLoopTest(keras_parameterized.TestCase):
     fn_result = train_step(defun=True)
     self.assertAllClose(eager_result, fn_result)
 
+  @parameterized.named_parameters(('eager', False), ('defun', True))
+  def test_training_arg_propagation(self, defun):
+
+    model = testing_utils.get_model_from_layers([LayerWithTrainingArg()],
+                                                input_shape=(1,))
+
+    def train_step(x):
+      return model(x), model(x, training=False), model(x, training=True)
+
+    if defun:
+      train_step = def_function.function(train_step)
+
+    x = array_ops.ones((1, 1))
+    results = train_step(x)
+    self.assertAllClose(results[0], array_ops.zeros((1, 1)))
+    self.assertAllClose(results[1], array_ops.zeros((1, 1)))
+    self.assertAllClose(results[2], array_ops.ones((1, 1)))
+
+  @parameterized.named_parameters(('eager', False), ('defun', True))
+  def test_learning_phase_propagation(self, defun):
+
+    class MyModel(keras.layers.Layer):
+
+      def __init__(self):
+        super(MyModel, self).__init__()
+        self.layer = LayerWithTrainingArg()
+
+      def call(self, inputs):
+        return self.layer(inputs)
+
+    model = MyModel()
+
+    def train_step(x):
+      no_learning_phase_out = model(x)
+      self.assertIsNone(model.layer.training)
+      with keras.backend.learning_phase_scope(0):
+        inf_learning_phase_out = model(x)
+        self.assertEqual(model.layer.training, 0)
+      with keras.backend.learning_phase_scope(1):
+        train_learning_phase_out = model(x)
+        self.assertEqual(model.layer.training, 1)
+      return [
+          no_learning_phase_out, inf_learning_phase_out,
+          train_learning_phase_out
+      ]
+
+    if defun:
+      train_step = def_function.function(train_step)
+
+    x = array_ops.ones((1, 1))
+    results = train_step(x)
+    self.assertAllClose(results[0], array_ops.zeros((1, 1)))
+    self.assertAllClose(results[1], array_ops.zeros((1, 1)))
+    self.assertAllClose(results[2], array_ops.ones((1, 1)))
+
+  @parameterized.named_parameters(('eager', False), ('defun', True))
+  def test_training_arg_priorities(self, defun):
+
+    class MyModel(keras.layers.Layer):
+
+      def __init__(self):
+        super(MyModel, self).__init__()
+        self.layer = LayerWithTrainingArg()
+
+      def call(self, inputs, training=False):
+        return self.layer(inputs)
+
+    model = MyModel()
+
+    def train_step(x):
+      explicit_out = model(x, training=True)
+      default_out = model(x)
+      with keras.backend.learning_phase_scope(1):
+        parent_out = model(x, training=False)
+        lr_out = model(x)
+      return [explicit_out, default_out, parent_out, lr_out]
+
+    if defun:
+      train_step = def_function.function(train_step)
+
+    x = array_ops.ones((1, 1))
+    results = train_step(x)
+    self.assertAllClose(results[0], array_ops.ones((1, 1)))
+    self.assertAllClose(results[1], array_ops.zeros((1, 1)))
+    self.assertAllClose(results[2], array_ops.zeros((1, 1)))
+    self.assertAllClose(results[3], array_ops.ones((1, 1)))
+
 
 if __name__ == '__main__':
   ops.enable_eager_execution()
diff --git a/tensorflow/python/keras/engine/base_layer.py b/tensorflow/python/keras/engine/base_layer.py
index 274f32d8800..d33a38d1ebf 100644
--- a/tensorflow/python/keras/engine/base_layer.py
+++ b/tensorflow/python/keras/engine/base_layer.py
@@ -212,8 +212,10 @@ class Layer(module.Module):
     self._outbound_nodes = []
 
     call_fn_args = self._call_fn_args
-    self._expects_training_arg = 'training' in call_fn_args
-    self._expects_mask_arg = 'mask' in call_fn_args
+    self._expects_training_arg = ('training' in call_fn_args or
+                                  self._call_accepts_kwargs)
+    self._expects_mask_arg = ('mask' in call_fn_args or
+                              self._call_accepts_kwargs)
 
     # Whether the `call` method can be used to build a TF graph without issues.
     self._dynamic = dynamic
@@ -563,8 +565,15 @@ class Layer(module.Module):
     Raises:
       ValueError: if the layer's `call` method returns None (an invalid value).
     """
+    call_context = base_layer_utils.call_context()
     input_list = nest.flatten(inputs)
 
+    # We will attempt to build a TF graph if & only if all inputs are symbolic.
+    # This is always the case in graph mode. It can also be the case in eager
+    # mode when all inputs can be traced back to `keras.Input()` (when building
+    # models using the functional API).
+    build_graph = tf_utils.are_all_symbolic_tensors(input_list)
+
     # Accept NumPy and scalar inputs by converting to Tensors.
     if any(isinstance(x, (np.ndarray, float, int)) for x in input_list):
       def _convert_non_tensor(x):
@@ -585,11 +594,31 @@ class Layer(module.Module):
           not self._call_arg_was_passed('mask', args, kwargs)):
         kwargs['mask'] = input_masks
 
-    # We will attempt to build a TF graph if & only if all inputs are symbolic.
-    # This is always the case in graph mode. It can also be the case in eager
-    # mode when all inputs can be traced back to `keras.Input()` (when building
-    # models using the functional API).
-    build_graph = tf_utils.are_all_symbolic_tensors(input_list)
+    # If `training` argument was not explicitly passed, propagate `training`
+    # value from this layer's calling layer.
+    training_arg_passed_by_framework = False
+    # Priority 1: `training` was explicitly passed.
+    if self._call_arg_was_passed('training', args, kwargs):
+      training_value = self._get_call_arg_value('training', args, kwargs)
+      if not self._expects_training_arg:
+        kwargs.pop('training')
+    else:
+      training_value = None
+      # Priority 2: `training` was passed to a parent layer.
+      if call_context.training is not None:
+        training_value = call_context.training
+      # Priority 3a: `learning_phase()` has been set.
+      elif backend.global_learning_phase_is_set():
+        training_value = backend.learning_phase()
+      # Priority 3b: Pass the `learning_phase()` if in the Keras FuncGraph.
+      elif build_graph:
+        with backend.get_graph().as_default():
+          if base_layer_utils.is_in_keras_graph():
+            training_value = backend.learning_phase()
+
+    if self._expects_training_arg and training_value is not None:
+      kwargs['training'] = training_value
+      training_arg_passed_by_framework = True
 
     # Only create Keras history if at least one tensor originates from a
     # `keras.Input`. Otherwise this Layer may be being used outside the Keras
@@ -599,12 +628,12 @@ class Layer(module.Module):
 
     # Clear eager losses on top level model call.
     # We are clearing the losses only on the top level model call and not on
-    # every layer/mode call because layer/model may be reused.
+    # every layer/model call because layer/model may be reused.
     if (base_layer_utils.is_in_eager_or_tf_function() and
-        not base_layer_utils.call_context().in_call):
+        not call_context.in_call):
       self._clear_losses()
 
-    with base_layer_utils.call_context().enter(self, inputs, build_graph):
+    with call_context.enter(self, inputs, build_graph, training_value):
       # Check input assumptions set after layer building, e.g. input shape.
       if build_graph:
         # Symbolic execution on symbolic tensors. We will attempt to build
@@ -636,21 +665,6 @@ class Layer(module.Module):
         else:
           call_fn = self.call
 
-        # Explicitly pass the learning phase placeholder to `call` if
-        # the `training` argument was left unspecified by the user.
-        # This behavior is restricted to the managed Keras FuncGraph.
-        # TODO(omalleyt): Reconcile this with new `trainable` behavior
-        # when available.
-        learning_phase_passed_by_framework = False
-        if (base_layer_utils.is_in_keras_graph() and
-            self._expects_training_arg):
-          training_arg = None
-          if self._call_arg_was_passed('training', args, kwargs):
-            training_arg = self._get_call_arg_value('training', args, kwargs)
-          if training_arg is None:
-            learning_phase_passed_by_framework = True
-            kwargs['training'] = backend.learning_phase()
-
         if not self.dynamic:
           try:
             with base_layer_utils.autocast_context_manager(
@@ -692,7 +706,7 @@ class Layer(module.Module):
               'Tensor or a list of Tensors, not None '
               '(layer: ' + self.name + ').')
         if base_layer_utils.have_all_keras_metadata(inputs):
-          if learning_phase_passed_by_framework:
+          if training_arg_passed_by_framework:
             kwargs.pop('training')
           inputs, outputs = self._set_connectivity_metadata_(
               inputs, outputs, args, kwargs)
@@ -2131,6 +2145,11 @@ class Layer(module.Module):
       return all_args[1:]
     return all_args
 
+  @property
+  @tracking.cached_per_instance
+  def _call_accepts_kwargs(self):
+    return tf_inspect.getfullargspec(self.call).varkw is not None
+
   @property
   @tracking.cached_per_instance
   def _should_compute_mask(self):
diff --git a/tensorflow/python/keras/engine/base_layer_utils.py b/tensorflow/python/keras/engine/base_layer_utils.py
index 83dfa5ca46f..f0f0e495743 100644
--- a/tensorflow/python/keras/engine/base_layer_utils.py
+++ b/tensorflow/python/keras/engine/base_layer_utils.py
@@ -378,6 +378,7 @@ class CallContext(object):
     frozen: Whether currently executing inside a `Layer` with `trainable` set to
       `False`.
     in_call: Whether currently inside the `call` of a Layer.
+    training: Whether currently executing in training or inference mode.
     in_keras_graph: Whether executing inside the Keras Graph.
""" @@ -386,25 +387,29 @@ class CallContext(object): self.inputs = None self.frozen = False self.in_call = False + self.training = None self._in_keras_graph = False @tf_contextlib.contextmanager - def enter(self, layer, inputs, build_graph): + def enter(self, layer, inputs, build_graph, training): """Push a Layer and its inputs and state onto the current call context.""" prev_layer = self.layer prev_inputs = self.inputs prev_frozen = self.frozen prev_in_call = self.in_call + prev_training = self.training prev_in_keras_graph = self._in_keras_graph self.layer = layer self.inputs = inputs self.frozen = self.frozen or not layer.trainable self.in_call = True + self.training = training self._in_keras_graph = ( self._in_keras_graph or (build_graph and getattr(backend.get_graph(), 'name', None) == 'keras_graph')) + try: yield finally: @@ -412,6 +417,7 @@ class CallContext(object): self.inputs = prev_inputs self.frozen = prev_frozen self.in_call = prev_in_call + self.training = prev_training self._in_keras_graph = prev_in_keras_graph @property diff --git a/tensorflow/python/keras/engine/network.py b/tensorflow/python/keras/engine/network.py index 16dc11116a9..ea928986c1a 100644 --- a/tensorflow/python/keras/engine/network.py +++ b/tensorflow/python/keras/engine/network.py @@ -372,8 +372,10 @@ class Network(base_layer.Layer): def _init_subclassed_network(self, name=None, **kwargs): self._base_init(name=name, **kwargs) self._is_graph_network = False - self._expects_training_arg = 'training' in self._call_fn_args - self._expects_mask_arg = 'mask' in self._call_fn_args + self._expects_training_arg = ('training' in self._call_fn_args or + self._call_accepts_kwargs) + self._expects_mask_arg = ('mask' in self._call_fn_args or + self._call_accepts_kwargs) call_argspec = tf_inspect.getfullargspec(self.call) self._call_convention = self._determine_call_convention(call_argspec) self.outputs = [] diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py index 02df39ff418..35870e46d88 100644 --- a/tensorflow/python/keras/engine/training.py +++ b/tensorflow/python/keras/engine/training.py @@ -2554,7 +2554,17 @@ class Model(network.Network): inputs = self._set_input_attrs(inputs) if outputs is None: - kwargs = {'training': training} if self._expects_training_arg else {} + kwargs = {} + if self._expects_training_arg: + # In V2 mode, feeding `training=None` is not allowed because any value + # explicitly passed by the user is respected, even `None`, and in this + # case if the user has not passed a value in V2 we need to replace + # `None` with the `learning_phase()`. In V1, `training=None` is needed + # so that `Dropout` and `BatchNormalization` replace `None` values with + # the `learning_phase()` in their `call`. 
+        if (training is not None or
+            not ops.executing_eagerly_outside_functions()):
+          kwargs['training'] = training
       try:
         outputs = self(inputs, **kwargs)
       except NotImplementedError:
diff --git a/tensorflow/python/keras/engine/training_generator.py b/tensorflow/python/keras/engine/training_generator.py
index ce976b1847d..facbcd08e8c 100644
--- a/tensorflow/python/keras/engine/training_generator.py
+++ b/tensorflow/python/keras/engine/training_generator.py
@@ -188,8 +188,9 @@ def model_iteration(model,
 
   should_set_learning_phase = context.executing_eagerly() and model.run_eagerly
   if should_set_learning_phase:
-    old_learning_phase = backend.learning_phase()
-    backend.set_eager_learning_phase(1 if mode == ModeKeys.TRAIN else 0)
+    learning_phase_scope = backend.eager_learning_phase_scope(
+        1 if mode == ModeKeys.TRAIN else 0)
+    learning_phase_scope.__enter__()
 
   callbacks.model.stop_training = False
   callbacks._call_begin_hook(mode)
@@ -341,7 +342,7 @@ def model_iteration(model,
       enqueuer.stop()
 
   if should_set_learning_phase:
-    backend.set_eager_learning_phase(old_learning_phase)
+    learning_phase_scope.__exit__(None, None, None)
 
   if mode == ModeKeys.TRAIN:
     return model.history
diff --git a/tensorflow/python/keras/layers/recurrent_test.py b/tensorflow/python/keras/layers/recurrent_test.py
index 43ddb9d84e0..55233120a09 100644
--- a/tensorflow/python/keras/layers/recurrent_test.py
+++ b/tensorflow/python/keras/layers/recurrent_test.py
@@ -610,7 +610,7 @@ class RNNTest(keras_parameterized.TestCase):
     update_2 = state_ops.assign_add(cells[0].kernel,
                                     array_ops.ones_like(cells[0].kernel))
     # TODO(b/128682878): Remove when RNNCells are __call__'d.
-    with base_layer_utils.call_context().enter(layer, x, True):
+    with base_layer_utils.call_context().enter(layer, x, True, None):
       cells[0].add_update(update_1, inputs=x)
       cells[0].add_update(update_2)
     self.assertEqual(len(layer.updates), 2)
diff --git a/tensorflow/python/keras/layers/wrappers_test.py b/tensorflow/python/keras/layers/wrappers_test.py
index 466b6af2af8..c11211807bd 100644
--- a/tensorflow/python/keras/layers/wrappers_test.py
+++ b/tensorflow/python/keras/layers/wrappers_test.py
@@ -611,7 +611,7 @@ class BidirectionalTest(test.TestCase, parameterized.TestCase):
     assert not layer.get_updates_for(None)
     assert not layer.get_updates_for(x)
     # TODO(b/128684069): Remove when Wrapper sublayers are __call__'d.
-    with base_layer_utils.call_context().enter(layer, x, True):
+    with base_layer_utils.call_context().enter(layer, x, True, None):
       layer.forward_layer.add_update(x_reachable_update, inputs=x)
       layer.forward_layer.add_update(1, inputs=None)
       layer.backward_layer.add_update(x_reachable_update, inputs=x)
diff --git a/tensorflow/python/keras/saving/saved_model.py b/tensorflow/python/keras/saving/saved_model.py
index 92e53882b57..441cb02b555 100644
--- a/tensorflow/python/keras/saving/saved_model.py
+++ b/tensorflow/python/keras/saving/saved_model.py
@@ -843,7 +843,7 @@ def _wrap_layer_functions(layer, serialization_cache,
   # Manually trigger traces before restoring the overwritten functions. The
   # functions are traced within the layer call context to ensure that layer
   # functions (e.g. add_loss) behave as though running in graph mode.
-  with base_layer_utils.call_context().enter(layer, None, build_graph=True):
+  with base_layer_utils.call_context().enter(layer, None, True, None):
     for fn in fns.values():
       if fn is not None and fn.input_signature is not None:
         fn.get_concrete_function()
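
Reviewer note (not part of the patch): below is a minimal sketch of the new
priority order as seen from a custom training loop. `Inner` and `Outer` are
hypothetical layers written only for this illustration, and the sketch assumes
eager execution on a tf.keras build that includes this change.

import tensorflow as tf


class Inner(tf.keras.layers.Layer):
  """Child layer whose output depends on `training`."""

  def call(self, inputs, training=None):
    return inputs if training else tf.zeros_like(inputs)


class Outer(tf.keras.layers.Layer):
  """Parent layer that does not forward `training` explicitly."""

  def __init__(self):
    super(Outer, self).__init__()
    self.inner = Inner()

  def call(self, inputs, training=None):
    # With this change, `self.inner` receives the `training` value resolved
    # for `Outer` (priority 2) instead of its own `None` default.
    return self.inner(inputs)


layer = Outer()
x = tf.ones((1, 1))
print(layer(x, training=True))   # ones: explicit value propagates (1 -> 2)
print(layer(x, training=False))  # zeros: explicit value still wins
with tf.keras.backend.learning_phase_scope(1):
  print(layer(x))                # ones: `learning_phase()` is set (priority 3)
print(layer(x))                  # zeros: Inner's own `None` default (priority 4)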