From 91bdadc08e958165262bdaf3e00005194c466715 Mon Sep 17 00:00:00 2001 From: Francois Chollet Date: Wed, 24 Jun 2020 14:49:53 -0700 Subject: [PATCH] Fix a critical breakage in `training` argument default value in inference for layers with a default of `training=True` called in e.g. a Sequential container. PiperOrigin-RevId: 318145694 Change-Id: I1af5286824e3a45e1a7d1b8a4fadd7ec223895dc --- tensorflow/python/keras/engine/base_layer.py | 4 +-- .../python/keras/engine/base_layer_test.py | 28 ++++++++++++--- .../python/keras/engine/functional_test.py | 8 ++--- .../preprocessing/image_preprocessing_test.py | 34 +++++++++++++++++++ 4 files changed, 64 insertions(+), 10 deletions(-) diff --git a/tensorflow/python/keras/engine/base_layer.py b/tensorflow/python/keras/engine/base_layer.py index 97eb0447a69..1cd28a7a6e4 100644 --- a/tensorflow/python/keras/engine/base_layer.py +++ b/tensorflow/python/keras/engine/base_layer.py @@ -2891,9 +2891,9 @@ class Layer(module.Module, version_utils.LayerVersionSelector): self._expects_training_arg = ('training' in call_fn_args or self._call_accepts_kwargs) # The default training arg will be any (non-None) default specified in the - # method signature, or `False` if no non-None default is specified. + # method signature, or None if no value is specified. self._default_training_arg = self._call_fn_arg_defaults.get( - 'training') or False + 'training') self._expects_mask_arg = ('mask' in call_fn_args or self._call_accepts_kwargs) diff --git a/tensorflow/python/keras/engine/base_layer_test.py b/tensorflow/python/keras/engine/base_layer_test.py index 58a0799329a..559e927d603 100644 --- a/tensorflow/python/keras/engine/base_layer_test.py +++ b/tensorflow/python/keras/engine/base_layer_test.py @@ -650,6 +650,17 @@ class BaseLayerTest(keras_parameterized.TestCase): else: return self._nested_layer(inputs) * 0.5 + class CustomLayerDefaultTrainingNone(base_layer.Layer): + + def __init__(self, nested_layer=None): + self._nested_layer = nested_layer or array_ops.identity + + def call(self, inputs, training=None): + if training: + return self._nested_layer(inputs) + else: + return self._nested_layer(inputs) * 0.5 + class CustomLayerDefaultTrainingFalse(base_layer.Layer): def __init__(self, nested_layer=None): @@ -701,21 +712,30 @@ class BaseLayerTest(keras_parameterized.TestCase): # Outer layers/models should set the training context implicitly for all # nested layers, respecting whatever mode the outer layer was run with. layer = CustomLayerDefaultTrainingTrue(CustomLayerDefaultTrainingFalse()) - self.assertAllEqual(layer(x), x) + # No outer value passed: use local defaults + self.assertAllEqual(layer(x), x * 0.25) # Use local default False + # Outer value passed: override local defaults self.assertAllEqual(layer(x, training=False), x * 0.25) self.assertAllEqual(layer(x, training=True), x) layer = CustomLayerDefaultTrainingFalse(CustomLayerDefaultTrainingTrue()) - self.assertAllEqual(layer(x), x * 0.25) + # No outer value passed: use local defaults + self.assertAllEqual(layer(x), x) # Use local default True + # Outer value passed: override local defaults self.assertAllEqual(layer(x, training=False), x * 0.25) self.assertAllEqual(layer(x, training=True), x) # If the outer layer `call` doesn't take a training argument at all, - # it'll set the nested scope as inference when no training arg is passed in. + # it'll set the nested scope as None when no training arg is passed in. # If a training arg is passed in it won't use it directly in `call`, but # it will set the nested training mode. layer = CustomLayerNoTrainingArg(CustomLayerDefaultTrainingTrue()) - self.assertAllEqual(layer(x), x * 0.5) + self.assertAllEqual(layer(x), x) # Use local default True + self.assertAllEqual(layer(x, training=False), x * 0.5) + self.assertAllEqual(layer(x, training=True), x) + + layer = CustomLayerDefaultTrainingNone(CustomLayerDefaultTrainingTrue()) + self.assertAllEqual(layer(x), x) # Use local default True self.assertAllEqual(layer(x, training=False), x * 0.5) self.assertAllEqual(layer(x, training=True), x) diff --git a/tensorflow/python/keras/engine/functional_test.py b/tensorflow/python/keras/engine/functional_test.py index 0e82d95d3de..24b0e147b97 100644 --- a/tensorflow/python/keras/engine/functional_test.py +++ b/tensorflow/python/keras/engine/functional_test.py @@ -2116,13 +2116,13 @@ class CacheCorrectnessTest(keras_parameterized.TestCase): if context.executing_eagerly(): # In v2, construction still works when no `training` is specified - # When no value passed during construction, it uses the runtime value. + # When no value passed during construction, it uses the local default. inputs = input_layer_lib.Input(10) outputs = my_layer(inputs) network = functional.Functional(inputs, outputs) self.assertAllEqual(network(x, training=True), _call(x, True)) self.assertAllEqual(network(x, training=False), _call(x, False)) - self.assertAllEqual(network(x), _call(x, False)) + self.assertAllEqual(network(x), _call(x, True)) # Use local default # `None` value passed positionally during construction is ignored at runtime inputs = input_layer_lib.Input(10) @@ -2131,7 +2131,7 @@ class CacheCorrectnessTest(keras_parameterized.TestCase): self.assertAllEqual(network(x, training=True), _call(x, True)) self.assertAllEqual(network(x, training=False), _call(x, False)) if context.executing_eagerly(): - self.assertAllEqual(network(x), _call(x, False)) + self.assertAllEqual(network(x), _call(x, True)) # Use local default else: # in v1 training would have defaulted to using the `None` inside the layer # if training is not passed at runtime @@ -2144,7 +2144,7 @@ class CacheCorrectnessTest(keras_parameterized.TestCase): self.assertAllEqual(network(x, training=True), _call(x, True)) self.assertAllEqual(network(x, training=False), _call(x, False)) if context.executing_eagerly(): - self.assertAllEqual(network(x), _call(x, False)) + self.assertAllEqual(network(x), _call(x, True)) # Use local default else: # in v1 training would have defaulted to using the `None` inside the layer # if training is not passed at runtime diff --git a/tensorflow/python/keras/layers/preprocessing/image_preprocessing_test.py b/tensorflow/python/keras/layers/preprocessing/image_preprocessing_test.py index 5cb7cec5b7b..f5210589b82 100644 --- a/tensorflow/python/keras/layers/preprocessing/image_preprocessing_test.py +++ b/tensorflow/python/keras/layers/preprocessing/image_preprocessing_test.py @@ -26,6 +26,7 @@ from tensorflow.python.framework import errors from tensorflow.python.framework import test_util as tf_test_util from tensorflow.python.keras import keras_parameterized from tensorflow.python.keras import testing_utils +from tensorflow.python.keras.engine import sequential from tensorflow.python.keras.layers.preprocessing import image_preprocessing from tensorflow.python.keras.utils.generic_utils import CustomObjectScope from tensorflow.python.ops import gen_stateful_random_ops @@ -1273,5 +1274,38 @@ class RandomWidthTest(keras_parameterized.TestCase): self.assertEqual(layer_1.name, layer.name) +@keras_parameterized.run_all_keras_modes(always_skip_v1=True) +class LearningPhaseTest(keras_parameterized.TestCase): + + def test_plain_call(self): + layer = image_preprocessing.RandomWidth(.5, seed=123) + shape = (12, 12, 3) + img = np.random.random((12,) + shape) + out = layer(img) # Default to training=True + self.assertNotEqual(tuple(int(i) for i in out.shape[1:]), shape) + + out = layer(img, training=True) + self.assertNotEqual(tuple(int(i) for i in out.shape[1:]), shape) + + out = layer(img, training=False) + self.assertEqual(tuple(int(i) for i in out.shape[1:]), shape) + + def test_call_in_container(self): + layer1 = image_preprocessing.RandomWidth(.5, seed=123) + layer2 = image_preprocessing.RandomHeight(.5, seed=123) + seq = sequential.Sequential([layer1, layer2]) + + shape = (12, 12, 3) + img = np.random.random((12,) + shape) + out = seq(img) # Default to training=True + self.assertNotEqual(tuple(int(i) for i in out.shape[1:]), shape) + + out = seq(img, training=True) + self.assertNotEqual(tuple(int(i) for i in out.shape[1:]), shape) + + out = seq(img, training=False) + self.assertEqual(tuple(int(i) for i in out.shape[1:]), shape) + + if __name__ == '__main__': test.main()