Fix a critical breakage in `training` argument default value in inference for
layers with a default of `training=True` called in e.g. a Sequential container. PiperOrigin-RevId: 318145694 Change-Id: I1af5286824e3a45e1a7d1b8a4fadd7ec223895dc
This commit is contained in:
parent
890eae3e88
commit
91bdadc08e
|
@ -2891,9 +2891,9 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
|
|||
self._expects_training_arg = ('training' in call_fn_args or
|
||||
self._call_accepts_kwargs)
|
||||
# The default training arg will be any (non-None) default specified in the
|
||||
# method signature, or `False` if no non-None default is specified.
|
||||
# method signature, or None if no value is specified.
|
||||
self._default_training_arg = self._call_fn_arg_defaults.get(
|
||||
'training') or False
|
||||
'training')
|
||||
self._expects_mask_arg = ('mask' in call_fn_args or
|
||||
self._call_accepts_kwargs)
|
||||
|
||||
|
|
|
@ -650,6 +650,17 @@ class BaseLayerTest(keras_parameterized.TestCase):
|
|||
else:
|
||||
return self._nested_layer(inputs) * 0.5
|
||||
|
||||
class CustomLayerDefaultTrainingNone(base_layer.Layer):
|
||||
|
||||
def __init__(self, nested_layer=None):
|
||||
self._nested_layer = nested_layer or array_ops.identity
|
||||
|
||||
def call(self, inputs, training=None):
|
||||
if training:
|
||||
return self._nested_layer(inputs)
|
||||
else:
|
||||
return self._nested_layer(inputs) * 0.5
|
||||
|
||||
class CustomLayerDefaultTrainingFalse(base_layer.Layer):
|
||||
|
||||
def __init__(self, nested_layer=None):
|
||||
|
@ -701,21 +712,30 @@ class BaseLayerTest(keras_parameterized.TestCase):
|
|||
# Outer layers/models should set the training context implicitly for all
|
||||
# nested layers, respecting whatever mode the outer layer was run with.
|
||||
layer = CustomLayerDefaultTrainingTrue(CustomLayerDefaultTrainingFalse())
|
||||
self.assertAllEqual(layer(x), x)
|
||||
# No outer value passed: use local defaults
|
||||
self.assertAllEqual(layer(x), x * 0.25) # Use local default False
|
||||
# Outer value passed: override local defaults
|
||||
self.assertAllEqual(layer(x, training=False), x * 0.25)
|
||||
self.assertAllEqual(layer(x, training=True), x)
|
||||
|
||||
layer = CustomLayerDefaultTrainingFalse(CustomLayerDefaultTrainingTrue())
|
||||
self.assertAllEqual(layer(x), x * 0.25)
|
||||
# No outer value passed: use local defaults
|
||||
self.assertAllEqual(layer(x), x) # Use local default True
|
||||
# Outer value passed: override local defaults
|
||||
self.assertAllEqual(layer(x, training=False), x * 0.25)
|
||||
self.assertAllEqual(layer(x, training=True), x)
|
||||
|
||||
# If the outer layer `call` doesn't take a training argument at all,
|
||||
# it'll set the nested scope as inference when no training arg is passed in.
|
||||
# it'll set the nested scope as None when no training arg is passed in.
|
||||
# If a training arg is passed in it won't use it directly in `call`, but
|
||||
# it will set the nested training mode.
|
||||
layer = CustomLayerNoTrainingArg(CustomLayerDefaultTrainingTrue())
|
||||
self.assertAllEqual(layer(x), x * 0.5)
|
||||
self.assertAllEqual(layer(x), x) # Use local default True
|
||||
self.assertAllEqual(layer(x, training=False), x * 0.5)
|
||||
self.assertAllEqual(layer(x, training=True), x)
|
||||
|
||||
layer = CustomLayerDefaultTrainingNone(CustomLayerDefaultTrainingTrue())
|
||||
self.assertAllEqual(layer(x), x) # Use local default True
|
||||
self.assertAllEqual(layer(x, training=False), x * 0.5)
|
||||
self.assertAllEqual(layer(x, training=True), x)
|
||||
|
||||
|
|
|
@ -2116,13 +2116,13 @@ class CacheCorrectnessTest(keras_parameterized.TestCase):
|
|||
|
||||
if context.executing_eagerly():
|
||||
# In v2, construction still works when no `training` is specified
|
||||
# When no value passed during construction, it uses the runtime value.
|
||||
# When no value passed during construction, it uses the local default.
|
||||
inputs = input_layer_lib.Input(10)
|
||||
outputs = my_layer(inputs)
|
||||
network = functional.Functional(inputs, outputs)
|
||||
self.assertAllEqual(network(x, training=True), _call(x, True))
|
||||
self.assertAllEqual(network(x, training=False), _call(x, False))
|
||||
self.assertAllEqual(network(x), _call(x, False))
|
||||
self.assertAllEqual(network(x), _call(x, True)) # Use local default
|
||||
|
||||
# `None` value passed positionally during construction is ignored at runtime
|
||||
inputs = input_layer_lib.Input(10)
|
||||
|
@ -2131,7 +2131,7 @@ class CacheCorrectnessTest(keras_parameterized.TestCase):
|
|||
self.assertAllEqual(network(x, training=True), _call(x, True))
|
||||
self.assertAllEqual(network(x, training=False), _call(x, False))
|
||||
if context.executing_eagerly():
|
||||
self.assertAllEqual(network(x), _call(x, False))
|
||||
self.assertAllEqual(network(x), _call(x, True)) # Use local default
|
||||
else:
|
||||
# in v1 training would have defaulted to using the `None` inside the layer
|
||||
# if training is not passed at runtime
|
||||
|
@ -2144,7 +2144,7 @@ class CacheCorrectnessTest(keras_parameterized.TestCase):
|
|||
self.assertAllEqual(network(x, training=True), _call(x, True))
|
||||
self.assertAllEqual(network(x, training=False), _call(x, False))
|
||||
if context.executing_eagerly():
|
||||
self.assertAllEqual(network(x), _call(x, False))
|
||||
self.assertAllEqual(network(x), _call(x, True)) # Use local default
|
||||
else:
|
||||
# in v1 training would have defaulted to using the `None` inside the layer
|
||||
# if training is not passed at runtime
|
||||
|
|
|
@ -26,6 +26,7 @@ from tensorflow.python.framework import errors
|
|||
from tensorflow.python.framework import test_util as tf_test_util
|
||||
from tensorflow.python.keras import keras_parameterized
|
||||
from tensorflow.python.keras import testing_utils
|
||||
from tensorflow.python.keras.engine import sequential
|
||||
from tensorflow.python.keras.layers.preprocessing import image_preprocessing
|
||||
from tensorflow.python.keras.utils.generic_utils import CustomObjectScope
|
||||
from tensorflow.python.ops import gen_stateful_random_ops
|
||||
|
@ -1273,5 +1274,38 @@ class RandomWidthTest(keras_parameterized.TestCase):
|
|||
self.assertEqual(layer_1.name, layer.name)
|
||||
|
||||
|
||||
@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
|
||||
class LearningPhaseTest(keras_parameterized.TestCase):
|
||||
|
||||
def test_plain_call(self):
|
||||
layer = image_preprocessing.RandomWidth(.5, seed=123)
|
||||
shape = (12, 12, 3)
|
||||
img = np.random.random((12,) + shape)
|
||||
out = layer(img) # Default to training=True
|
||||
self.assertNotEqual(tuple(int(i) for i in out.shape[1:]), shape)
|
||||
|
||||
out = layer(img, training=True)
|
||||
self.assertNotEqual(tuple(int(i) for i in out.shape[1:]), shape)
|
||||
|
||||
out = layer(img, training=False)
|
||||
self.assertEqual(tuple(int(i) for i in out.shape[1:]), shape)
|
||||
|
||||
def test_call_in_container(self):
|
||||
layer1 = image_preprocessing.RandomWidth(.5, seed=123)
|
||||
layer2 = image_preprocessing.RandomHeight(.5, seed=123)
|
||||
seq = sequential.Sequential([layer1, layer2])
|
||||
|
||||
shape = (12, 12, 3)
|
||||
img = np.random.random((12,) + shape)
|
||||
out = seq(img) # Default to training=True
|
||||
self.assertNotEqual(tuple(int(i) for i in out.shape[1:]), shape)
|
||||
|
||||
out = seq(img, training=True)
|
||||
self.assertNotEqual(tuple(int(i) for i in out.shape[1:]), shape)
|
||||
|
||||
out = seq(img, training=False)
|
||||
self.assertEqual(tuple(int(i) for i in out.shape[1:]), shape)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
||||
|
|
Loading…
Reference in New Issue