diff --git a/tensorflow/python/keras/engine/base_layer.py b/tensorflow/python/keras/engine/base_layer.py
index fbec5382a08..97eb0447a69 100644
--- a/tensorflow/python/keras/engine/base_layer.py
+++ b/tensorflow/python/keras/engine/base_layer.py
@@ -943,10 +943,14 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
       kwargs['mask'] = input_masks
 
     # Training mode for `Layer.call` is set via (in order of priority):
-    # (1) The `training` argument passed to this `Layer.call`.
+    # (1) The `training` argument passed to this `Layer.call`, if it is not None
     # (2) The training mode of an outer `Layer.call`.
-    # (3) The default mode set by `tf.keras.backed.set_learning_phase` (if set).
-    training_mode = self._set_training_mode(args, kwargs, call_context)
+    # (3) The default mode set by `tf.keras.backend.set_learning_phase` (if set)
+    # (4) Any non-None default value for `training` specified in the call
+    #     signature
+    # (5) False (treating the layer as if it's in inference mode)
+    args, kwargs, training_mode = self._set_training_mode(
+        args, kwargs, call_context)
 
     # Losses are cleared for all sublayers on the outermost `Layer.call`.
     # Losses are not cleared on inner `Layer.call`s, because sublayers can be
@@ -1020,7 +1024,7 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
       # propagate `training` value from this layer's calling layer.
       training_value = None
       training_arg_passed_by_framework = False
-      # Priority 1: `training` was explicitly passed.
+      # Priority 1: `training` was explicitly passed a non-None value.
       if self._call_arg_was_passed('training', args, kwargs):
         training_value = self._get_call_arg_value('training', args, kwargs)
         if not self._expects_training_arg:
@@ -1030,17 +1034,23 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
         # Priority 2: `training` was passed to a parent layer.
         if call_context.training is not None:
           training_value = call_context.training
-        # Priority 3a: `learning_phase()` has been set.
+        # Priority 3: `learning_phase()` has been set.
         elif backend.global_learning_phase_is_set():
           training_value = backend.learning_phase()
-
-        if self._expects_training_arg and training_value is not None:
           # Force the training_value to be bool type which matches to the contract
           # for layer/model call args.
           if tensor_util.is_tensor(training_value):
             training_value = math_ops.cast(training_value, dtypes.bool)
           else:
             training_value = bool(training_value)
+        # Priority 4: trace layer with the default training argument specified
+        # in the `call` signature (or in inference mode if the `call` signature
+        # specifies no non-None default).
+        else:
+          training_value = self._default_training_arg
+        # In cases (2), (3), (4) the training argument is passed automatically
+        # by the framework, and will not be hard-coded into the model.
+        if self._expects_training_arg:
           args, kwargs = self._set_call_arg_value('training', training_value,
                                                   args, kwargs)
           training_arg_passed_by_framework = True
@@ -1150,6 +1160,8 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
       # (1) `training` was passed to this `Layer.call`.
       if self._call_arg_was_passed('training', args, kwargs):
         training_mode = self._get_call_arg_value('training', args, kwargs)
+      # If no `training` arg was passed, or `None` was explicitly passed,
+      # the framework will decide what the training mode should be.
       if training_mode is None:
         call_ctx_training = call_context.training
         # (2) `training` mode is inferred from an outer `Layer.call`.
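The five-step priority order documented in the hunks above is the heart of this change. Here is a minimal, self-contained sketch of that resolution logic; the helper and its parameter names are hypothetical stand-ins for illustration, not Keras APIs:

```python
# Hypothetical model of the training-mode resolution order described in the
# comments above; illustrative only, not the actual Keras implementation.
def resolve_training_mode(explicit_arg=None, outer_call_mode=None,
                          global_learning_phase=None, call_default=None):
  if explicit_arg is not None:           # (1) passed to this `Layer.call`
    return bool(explicit_arg)
  if outer_call_mode is not None:        # (2) mode of an outer `Layer.call`
    return bool(outer_call_mode)
  if global_learning_phase is not None:  # (3) `set_learning_phase` default
    return bool(global_learning_phase)
  if call_default is not None:           # (4) non-None default in `call`
    return bool(call_default)
  return False                           # (5) fall back to inference


# (1) wins only when a non-None value was passed; an explicit `training=None`
# falls through to the framework-chosen mode.
assert resolve_training_mode(explicit_arg=False, call_default=True) is False
assert resolve_training_mode(call_default=True) is True
assert resolve_training_mode() is False
```

An explicit `training=None` therefore no longer pins the mode; it falls through to steps (2)-(5), which is exactly what the functional_test changes below exercise.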
@@ -1165,10 +1177,15 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
             training_mode = math_ops.cast(training_mode, dtypes.bool)
           else:
             training_mode = bool(training_mode)
+        # (4) We default to using `call`'s default value for `training`,
+        # or treating the layer as if it is in inference mode if no non-None
+        # default is specified in the `call` signature.
+        else:
+          training_mode = self._default_training_arg
 
-      # For case (2) or (3), `training` arg is passed by framework.
-      if training_mode is not None:
-        kwargs['training'] = training_mode
+      # For cases (2), (3), (4), the `training` arg is passed by the framework.
+      args, kwargs = self._set_call_arg_value('training', training_mode, args,
+                                              kwargs)
     else:
       if 'training' in kwargs:
         # `training` was passed to this `Layer` but is not needed for
@@ -1178,7 +1195,7 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
       # Grab the current `training` mode from any outer `Layer.call`.
       training_mode = call_context.training
 
-    return training_mode
+    return args, kwargs, training_mode
 
   def _autographed_call(self):
     # Wrapping `call` function in autograph to allow for dynamic control
@@ -2529,7 +2546,7 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
       if len(args) > arg_pos:
         args = list(args)
         args[arg_pos] = new_value
-        return args, kwargs
+        return tuple(args), kwargs
     if new_value is None and pop_kwarg_if_none:
       kwargs.pop(arg_name, None)
     else:
@@ -2873,6 +2890,10 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
     call_fn_args = self._call_fn_args
     self._expects_training_arg = ('training' in call_fn_args or
                                   self._call_accepts_kwargs)
+    # The default training arg will be any (non-None) default specified in the
+    # method signature, or `False` if no non-None default is specified.
+    self._default_training_arg = self._call_fn_arg_defaults.get(
+        'training') or False
     self._expects_mask_arg = ('mask' in call_fn_args or
                               self._call_accepts_kwargs)
 
@@ -2892,6 +2913,19 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
       return all_args[1:]
     return all_args
 
+  @property
+  @tracking.cached_per_instance
+  def _call_fn_arg_defaults(self):
+    call_fn_args = self._call_fn_args
+    call_fn_defaults = self._call_full_argspec.defaults or []
+    defaults = dict()
+
+    # The call arg defaults are an n-tuple of the last n elements of the args
+    # list (n = the number of arguments that have a default value).
+    for i in range(-1 * len(call_fn_defaults), 0):
+      defaults[call_fn_args[i]] = call_fn_defaults[i]
+    return defaults
+
   @property
   @tracking.cached_per_instance
   def _call_fn_arg_positions(self):
diff --git a/tensorflow/python/keras/engine/base_layer_test.py b/tensorflow/python/keras/engine/base_layer_test.py
index b861d7e4b5b..58a0799329a 100644
--- a/tensorflow/python/keras/engine/base_layer_test.py
+++ b/tensorflow/python/keras/engine/base_layer_test.py
@@ -629,6 +629,100 @@ class BaseLayerTest(keras_parameterized.TestCase):
     self.assertTrue(layer.built)
     self.assertEqual([None, 3], layer._build_input_shape.as_list())
 
+  @combinations.generate(combinations.combine(mode=['eager']))
+  def test_custom_layer_training_arg(self):
+    class CustomLayerNoTrainingArg(base_layer.Layer):
+
+      def __init__(self, nested_layer=None):
+        super(CustomLayerNoTrainingArg, self).__init__()
+        self._nested_layer = nested_layer or array_ops.identity
+
+      def call(self, inputs):
+        return self._nested_layer(inputs)
+
+    class CustomLayerDefaultTrainingMissing(base_layer.Layer):
+
+      def __init__(self, nested_layer=None):
+        super(CustomLayerDefaultTrainingMissing, self).__init__()
+        self._nested_layer = nested_layer or array_ops.identity
+
+      def call(self, inputs, training):
+        if training:
+          return self._nested_layer(inputs)
+        else:
+          return self._nested_layer(inputs) * 0.5
+
+    class CustomLayerDefaultTrainingFalse(base_layer.Layer):
+
+      def __init__(self, nested_layer=None):
+        super(CustomLayerDefaultTrainingFalse, self).__init__()
+        self._nested_layer = nested_layer or array_ops.identity
+
+      def call(self, inputs, training=False):
+        if training:
+          return self._nested_layer(inputs)
+        else:
+          return self._nested_layer(inputs) * 0.5
+
+    class CustomLayerDefaultTrainingTrue(base_layer.Layer):
+
+      def __init__(self, nested_layer=None):
+        super(CustomLayerDefaultTrainingTrue, self).__init__()
+        self._nested_layer = nested_layer or array_ops.identity
+
+      def call(self, inputs, training=True):
+        if training:
+          return self._nested_layer(inputs)
+        else:
+          return self._nested_layer(inputs) * 0.5
+
+    x = array_ops.ones(shape=(1, 1))
+
+    # If the layer signature doesn't specify a default training arg,
+    # run it in inference mode when no training arg is passed
+    # to __call__.
+    layer = CustomLayerDefaultTrainingMissing()
+    self.assertAllEqual(layer(x), x * 0.5)
+    self.assertAllEqual(layer(x, training=False), x * 0.5)
+    self.assertAllEqual(layer(x, training=True), x)
+
+    # If the layer signature specifies `False` as the default training arg,
+    # run it in inference mode when no training arg is passed
+    # to __call__.
+    layer = CustomLayerDefaultTrainingFalse()
+    self.assertAllEqual(layer(x), x * 0.5)
+    self.assertAllEqual(layer(x, training=False), x * 0.5)
+    self.assertAllEqual(layer(x, training=True), x)
+
+    # If the layer signature specifies `True` as the default training arg,
+    # explicitly run it in training mode when no training arg is passed
+    # to __call__.
+    layer = CustomLayerDefaultTrainingTrue()
+    self.assertAllEqual(layer(x), x)
+    self.assertAllEqual(layer(x, training=False), x * 0.5)
+    self.assertAllEqual(layer(x, training=True), x)
+
+    # Outer layers/models should set the training context implicitly for all
+    # nested layers, respecting whatever mode the outer layer was run with.
+    layer = CustomLayerDefaultTrainingTrue(CustomLayerDefaultTrainingFalse())
+    self.assertAllEqual(layer(x), x)
+    self.assertAllEqual(layer(x, training=False), x * 0.25)
+    self.assertAllEqual(layer(x, training=True), x)
+
+    layer = CustomLayerDefaultTrainingFalse(CustomLayerDefaultTrainingTrue())
+    self.assertAllEqual(layer(x), x * 0.25)
+    self.assertAllEqual(layer(x, training=False), x * 0.25)
+    self.assertAllEqual(layer(x, training=True), x)
+
+    # If the outer layer `call` doesn't take a training argument at all,
+    # it'll set the nested scope to inference when no training arg is passed
+    # in. If a training arg is passed in, it won't be used directly in
+    # `call`, but it will still set the nested training mode.
+    layer = CustomLayerNoTrainingArg(CustomLayerDefaultTrainingTrue())
+    self.assertAllEqual(layer(x), x * 0.5)
+    self.assertAllEqual(layer(x, training=False), x * 0.5)
+    self.assertAllEqual(layer(x, training=True), x)
+
   def test_activity_regularizer_string(self):
 
     class MyLayer(base_layer.Layer):
 
@@ -1387,6 +1481,7 @@ class DTypeTest(keras_parameterized.TestCase):
     class IdentityLayerWithArgs(base_layer.Layer):
 
       def call(self, inputs, *args, **kwargs):
+        kwargs.pop('training', None)
        return nest.flatten([inputs, args, kwargs])
 
     layer = IdentityLayerWithArgs(dtype='float64')
diff --git a/tensorflow/python/keras/engine/functional_test.py b/tensorflow/python/keras/engine/functional_test.py
index a7e314d4a49..3c14411deb9 100644
--- a/tensorflow/python/keras/engine/functional_test.py
+++ b/tensorflow/python/keras/engine/functional_test.py
@@ -2036,43 +2036,73 @@ class CacheCorrectnessTest(keras_parameterized.TestCase):
 
   def test_training_passed_during_construction(self):
 
+    def _call(inputs, training):
+      if training is None:
+        return inputs * -1.0
+      elif training:
+        return inputs
+      else:
+        return inputs * 0.0
+
     class MyLayer(base_layer.Layer):
 
-      def call(self, x, training=None):
-        if training is None:
-          return x * -1.0
-        elif training:
-          return x
-        else:
-          return x * 0.0
+      def call(self, inputs, training=True):
+        return _call(inputs, training)
 
     my_layer = MyLayer()
     x = np.ones((1, 10))
 
+    # Hard-coded `True` value passed during construction is respected.
     inputs = input_layer_lib.Input(10)
     outputs = my_layer(inputs, training=True)
     network = functional.Functional(inputs, outputs)
+    self.assertAllEqual(network(x, training=True), _call(x, True))
+    self.assertAllEqual(network(x, training=False), _call(x, True))
+    self.assertAllEqual(network(x), _call(x, True))
 
-    # Hard-coded value passed during construction is respected.
-    self.assertAllEqual(network(x, training=False), x)
-
+    # Hard-coded `False` value passed during construction is respected.
     inputs = input_layer_lib.Input(10)
     outputs = my_layer(inputs, training=False)
     network = functional.Functional(inputs, outputs)
+    self.assertAllEqual(network(x, training=True), _call(x, False))
+    self.assertAllEqual(network(x, training=False), _call(x, False))
+    self.assertAllEqual(network(x), _call(x, False))
 
-    network(x, training=True)
-    # Hard-coded value passed during construction is respected.
-    self.assertAllEqual(network(x, training=True), x * 0.0)
+    if context.executing_eagerly():
+      # In v2, construction still works when no `training` is specified.
+      # When no value is passed during construction, the runtime value is used.
+      inputs = input_layer_lib.Input(10)
+      outputs = my_layer(inputs)
+      network = functional.Functional(inputs, outputs)
+      self.assertAllEqual(network(x, training=True), _call(x, True))
+      self.assertAllEqual(network(x, training=False), _call(x, False))
+      self.assertAllEqual(network(x), _call(x, False))
+
+    # `None` value passed positionally during construction is ignored at runtime.
+    inputs = input_layer_lib.Input(10)
+    outputs = my_layer(inputs, None)
+    network = functional.Functional(inputs, outputs)
+    self.assertAllEqual(network(x, training=True), _call(x, True))
+    self.assertAllEqual(network(x, training=False), _call(x, False))
+    if context.executing_eagerly():
+      self.assertAllEqual(network(x), _call(x, False))
+    else:
+      # In v1, training would have defaulted to the `None` inside the layer
+      # if training is not passed at runtime.
+      self.assertAllEqual(network(x), _call(x, None))
+
+    # `None` value passed as kwarg during construction is ignored at runtime.
     inputs = input_layer_lib.Input(10)
     outputs = my_layer(inputs, training=None)
     network = functional.Functional(inputs, outputs)
-
-    # `None` value passed during construction is overridden.
-    self.assertAllEqual(network(x, training=True), x)
-    # `None` value passed during construction is overridden.
-    self.assertAllEqual(network(x, training=False), x * 0.0)
-
+    self.assertAllEqual(network(x, training=True), _call(x, True))
+    self.assertAllEqual(network(x, training=False), _call(x, False))
+    if context.executing_eagerly():
+      self.assertAllEqual(network(x), _call(x, False))
+    else:
+      # In v1, training would have defaulted to the `None` inside the layer
+      # if training is not passed at runtime.
+      self.assertAllEqual(network(x), _call(x, None))
 
 if __name__ == '__main__':
   test.main()
diff --git a/tensorflow/python/keras/layers/rnn_cell_wrapper_v2.py b/tensorflow/python/keras/layers/rnn_cell_wrapper_v2.py
index 4356244b292..d387a375aa2 100644
--- a/tensorflow/python/keras/layers/rnn_cell_wrapper_v2.py
+++ b/tensorflow/python/keras/layers/rnn_cell_wrapper_v2.py
@@ -27,6 +27,7 @@ from __future__ import print_function
 
 from tensorflow.python.keras.layers import recurrent
 from tensorflow.python.ops import rnn_cell_wrapper_impl
+from tensorflow.python.util import tf_inspect
 from tensorflow.python.util.tf_export import tf_export
 
 
@@ -41,6 +42,10 @@ class _RNNCellWrapperV2(recurrent.AbstractRNNCell):
   def __init__(self, cell, *args, **kwargs):
     super(_RNNCellWrapperV2, self).__init__(*args, **kwargs)
     self.cell = cell
+    cell_call_spec = tf_inspect.getfullargspec(cell.call)
+    self._expects_training_arg = ("training" in cell_call_spec.args) or (
+        cell_call_spec.varkw is not None
+    )
 
   def call(self, inputs, state, **kwargs):
     """Runs the RNN cell step computation.
diff --git a/tensorflow/python/keras/tests/custom_training_loop_test.py b/tensorflow/python/keras/tests/custom_training_loop_test.py
index 5b3310b2b40..6291933ac99 100644
--- a/tensorflow/python/keras/tests/custom_training_loop_test.py
+++ b/tensorflow/python/keras/tests/custom_training_loop_test.py
@@ -186,7 +186,7 @@ class CustomTrainingLoopTest(keras_parameterized.TestCase):
     def train_step(x):
       no_learning_phase_out = model(x)
-      self.assertIsNone(model.layer.training)
+      self.assertFalse(model.layer.training)
       with keras.backend.learning_phase_scope(0):
         inf_learning_phase_out = model(x)
         self.assertEqual(model.layer.training, 0)
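For reference, the `_call_fn_arg_defaults` property added in base_layer.py relies on `FullArgSpec.defaults` lining up with the *last* n positional arguments, which is why it indexes from the end of the args list. Below is a rough stdlib sketch of the same mapping; the `call_fn_arg_defaults` helper is hypothetical, not a Keras API:

```python
# Sketch of how `_call_fn_arg_defaults` pairs default values with argument
# names: `inspect.getfullargspec(fn).defaults` is an n-tuple aligned with the
# last n entries of `.args`, so it zips against the tail of the args list.
import inspect


def call_fn_arg_defaults(fn):
  spec = inspect.getfullargspec(fn)
  # Drop `self` the same way Keras' `_call_fn_args` does.
  args = spec.args[1:] if spec.args and spec.args[0] == 'self' else spec.args
  defaults = spec.defaults or ()
  return dict(zip(args[len(args) - len(defaults):], defaults))


def call(self, inputs, training=True, mask=None):
  return inputs


assert call_fn_arg_defaults(call) == {'training': True, 'mask': None}
# `_default_training_arg` is then `defaults.get('training') or False`, so a
# `call` signature with no default (or an explicit None default) for
# `training` yields False, i.e. inference mode.
```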