Fix a critical breakage in `training` argument default value in inference for

layers with a default of `training=True` called in e.g. a Sequential container.

PiperOrigin-RevId: 318145694
Change-Id: I1af5286824e3a45e1a7d1b8a4fadd7ec223895dc
This commit is contained in:
Francois Chollet 2020-06-24 14:49:53 -07:00 committed by Geeta Chavan
parent 890eae3e88
commit 91bdadc08e
4 changed files with 64 additions and 10 deletions

View File

@ -2891,9 +2891,9 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
self._expects_training_arg = ('training' in call_fn_args or
self._call_accepts_kwargs)
# The default training arg will be any (non-None) default specified in the
# method signature, or `False` if no non-None default is specified.
# method signature, or None if no value is specified.
self._default_training_arg = self._call_fn_arg_defaults.get(
'training') or False
'training')
self._expects_mask_arg = ('mask' in call_fn_args or
self._call_accepts_kwargs)

View File

@ -650,6 +650,17 @@ class BaseLayerTest(keras_parameterized.TestCase):
else:
return self._nested_layer(inputs) * 0.5
class CustomLayerDefaultTrainingNone(base_layer.Layer):
def __init__(self, nested_layer=None):
self._nested_layer = nested_layer or array_ops.identity
def call(self, inputs, training=None):
if training:
return self._nested_layer(inputs)
else:
return self._nested_layer(inputs) * 0.5
class CustomLayerDefaultTrainingFalse(base_layer.Layer):
def __init__(self, nested_layer=None):
@ -701,21 +712,30 @@ class BaseLayerTest(keras_parameterized.TestCase):
# Outer layers/models should set the training context implicitly for all
# nested layers, respecting whatever mode the outer layer was run with.
layer = CustomLayerDefaultTrainingTrue(CustomLayerDefaultTrainingFalse())
self.assertAllEqual(layer(x), x)
# No outer value passed: use local defaults
self.assertAllEqual(layer(x), x * 0.25) # Use local default False
# Outer value passed: override local defaults
self.assertAllEqual(layer(x, training=False), x * 0.25)
self.assertAllEqual(layer(x, training=True), x)
layer = CustomLayerDefaultTrainingFalse(CustomLayerDefaultTrainingTrue())
self.assertAllEqual(layer(x), x * 0.25)
# No outer value passed: use local defaults
self.assertAllEqual(layer(x), x) # Use local default True
# Outer value passed: override local defaults
self.assertAllEqual(layer(x, training=False), x * 0.25)
self.assertAllEqual(layer(x, training=True), x)
# If the outer layer `call` doesn't take a training argument at all,
# it'll set the nested scope as inference when no training arg is passed in.
# it'll set the nested scope as None when no training arg is passed in.
# If a training arg is passed in it won't use it directly in `call`, but
# it will set the nested training mode.
layer = CustomLayerNoTrainingArg(CustomLayerDefaultTrainingTrue())
self.assertAllEqual(layer(x), x * 0.5)
self.assertAllEqual(layer(x), x) # Use local default True
self.assertAllEqual(layer(x, training=False), x * 0.5)
self.assertAllEqual(layer(x, training=True), x)
layer = CustomLayerDefaultTrainingNone(CustomLayerDefaultTrainingTrue())
self.assertAllEqual(layer(x), x) # Use local default True
self.assertAllEqual(layer(x, training=False), x * 0.5)
self.assertAllEqual(layer(x, training=True), x)

View File

@ -2116,13 +2116,13 @@ class CacheCorrectnessTest(keras_parameterized.TestCase):
if context.executing_eagerly():
# In v2, construction still works when no `training` is specified
# When no value passed during construction, it uses the runtime value.
# When no value passed during construction, it uses the local default.
inputs = input_layer_lib.Input(10)
outputs = my_layer(inputs)
network = functional.Functional(inputs, outputs)
self.assertAllEqual(network(x, training=True), _call(x, True))
self.assertAllEqual(network(x, training=False), _call(x, False))
self.assertAllEqual(network(x), _call(x, False))
self.assertAllEqual(network(x), _call(x, True)) # Use local default
# `None` value passed positionally during construction is ignored at runtime
inputs = input_layer_lib.Input(10)
@ -2131,7 +2131,7 @@ class CacheCorrectnessTest(keras_parameterized.TestCase):
self.assertAllEqual(network(x, training=True), _call(x, True))
self.assertAllEqual(network(x, training=False), _call(x, False))
if context.executing_eagerly():
self.assertAllEqual(network(x), _call(x, False))
self.assertAllEqual(network(x), _call(x, True)) # Use local default
else:
# in v1 training would have defaulted to using the `None` inside the layer
# if training is not passed at runtime
@ -2144,7 +2144,7 @@ class CacheCorrectnessTest(keras_parameterized.TestCase):
self.assertAllEqual(network(x, training=True), _call(x, True))
self.assertAllEqual(network(x, training=False), _call(x, False))
if context.executing_eagerly():
self.assertAllEqual(network(x), _call(x, False))
self.assertAllEqual(network(x), _call(x, True)) # Use local default
else:
# in v1 training would have defaulted to using the `None` inside the layer
# if training is not passed at runtime

View File

@ -26,6 +26,7 @@ from tensorflow.python.framework import errors
from tensorflow.python.framework import test_util as tf_test_util
from tensorflow.python.keras import keras_parameterized
from tensorflow.python.keras import testing_utils
from tensorflow.python.keras.engine import sequential
from tensorflow.python.keras.layers.preprocessing import image_preprocessing
from tensorflow.python.keras.utils.generic_utils import CustomObjectScope
from tensorflow.python.ops import gen_stateful_random_ops
@ -1273,5 +1274,38 @@ class RandomWidthTest(keras_parameterized.TestCase):
self.assertEqual(layer_1.name, layer.name)
@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
class LearningPhaseTest(keras_parameterized.TestCase):
def test_plain_call(self):
layer = image_preprocessing.RandomWidth(.5, seed=123)
shape = (12, 12, 3)
img = np.random.random((12,) + shape)
out = layer(img) # Default to training=True
self.assertNotEqual(tuple(int(i) for i in out.shape[1:]), shape)
out = layer(img, training=True)
self.assertNotEqual(tuple(int(i) for i in out.shape[1:]), shape)
out = layer(img, training=False)
self.assertEqual(tuple(int(i) for i in out.shape[1:]), shape)
def test_call_in_container(self):
layer1 = image_preprocessing.RandomWidth(.5, seed=123)
layer2 = image_preprocessing.RandomHeight(.5, seed=123)
seq = sequential.Sequential([layer1, layer2])
shape = (12, 12, 3)
img = np.random.random((12,) + shape)
out = seq(img) # Default to training=True
self.assertNotEqual(tuple(int(i) for i in out.shape[1:]), shape)
out = seq(img, training=True)
self.assertNotEqual(tuple(int(i) for i in out.shape[1:]), shape)
out = seq(img, training=False)
self.assertEqual(tuple(int(i) for i in out.shape[1:]), shape)
if __name__ == '__main__':
test.main()