Fix a critical breakage in `training` argument default value in inference for
layers with a default of `training=True` called in e.g. a Sequential container. PiperOrigin-RevId: 318145694 Change-Id: I1af5286824e3a45e1a7d1b8a4fadd7ec223895dc
This commit is contained in:
parent
890eae3e88
commit
91bdadc08e
|
@ -2891,9 +2891,9 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
|
||||||
self._expects_training_arg = ('training' in call_fn_args or
|
self._expects_training_arg = ('training' in call_fn_args or
|
||||||
self._call_accepts_kwargs)
|
self._call_accepts_kwargs)
|
||||||
# The default training arg will be any (non-None) default specified in the
|
# The default training arg will be any (non-None) default specified in the
|
||||||
# method signature, or `False` if no non-None default is specified.
|
# method signature, or None if no value is specified.
|
||||||
self._default_training_arg = self._call_fn_arg_defaults.get(
|
self._default_training_arg = self._call_fn_arg_defaults.get(
|
||||||
'training') or False
|
'training')
|
||||||
self._expects_mask_arg = ('mask' in call_fn_args or
|
self._expects_mask_arg = ('mask' in call_fn_args or
|
||||||
self._call_accepts_kwargs)
|
self._call_accepts_kwargs)
|
||||||
|
|
||||||
|
|
|
@ -650,6 +650,17 @@ class BaseLayerTest(keras_parameterized.TestCase):
|
||||||
else:
|
else:
|
||||||
return self._nested_layer(inputs) * 0.5
|
return self._nested_layer(inputs) * 0.5
|
||||||
|
|
||||||
|
class CustomLayerDefaultTrainingNone(base_layer.Layer):
|
||||||
|
|
||||||
|
def __init__(self, nested_layer=None):
|
||||||
|
self._nested_layer = nested_layer or array_ops.identity
|
||||||
|
|
||||||
|
def call(self, inputs, training=None):
|
||||||
|
if training:
|
||||||
|
return self._nested_layer(inputs)
|
||||||
|
else:
|
||||||
|
return self._nested_layer(inputs) * 0.5
|
||||||
|
|
||||||
class CustomLayerDefaultTrainingFalse(base_layer.Layer):
|
class CustomLayerDefaultTrainingFalse(base_layer.Layer):
|
||||||
|
|
||||||
def __init__(self, nested_layer=None):
|
def __init__(self, nested_layer=None):
|
||||||
|
@ -701,21 +712,30 @@ class BaseLayerTest(keras_parameterized.TestCase):
|
||||||
# Outer layers/models should set the training context implicitly for all
|
# Outer layers/models should set the training context implicitly for all
|
||||||
# nested layers, respecting whatever mode the outer layer was run with.
|
# nested layers, respecting whatever mode the outer layer was run with.
|
||||||
layer = CustomLayerDefaultTrainingTrue(CustomLayerDefaultTrainingFalse())
|
layer = CustomLayerDefaultTrainingTrue(CustomLayerDefaultTrainingFalse())
|
||||||
self.assertAllEqual(layer(x), x)
|
# No outer value passed: use local defaults
|
||||||
|
self.assertAllEqual(layer(x), x * 0.25) # Use local default False
|
||||||
|
# Outer value passed: override local defaults
|
||||||
self.assertAllEqual(layer(x, training=False), x * 0.25)
|
self.assertAllEqual(layer(x, training=False), x * 0.25)
|
||||||
self.assertAllEqual(layer(x, training=True), x)
|
self.assertAllEqual(layer(x, training=True), x)
|
||||||
|
|
||||||
layer = CustomLayerDefaultTrainingFalse(CustomLayerDefaultTrainingTrue())
|
layer = CustomLayerDefaultTrainingFalse(CustomLayerDefaultTrainingTrue())
|
||||||
self.assertAllEqual(layer(x), x * 0.25)
|
# No outer value passed: use local defaults
|
||||||
|
self.assertAllEqual(layer(x), x) # Use local default True
|
||||||
|
# Outer value passed: override local defaults
|
||||||
self.assertAllEqual(layer(x, training=False), x * 0.25)
|
self.assertAllEqual(layer(x, training=False), x * 0.25)
|
||||||
self.assertAllEqual(layer(x, training=True), x)
|
self.assertAllEqual(layer(x, training=True), x)
|
||||||
|
|
||||||
# If the outer layer `call` doesn't take a training argument at all,
|
# If the outer layer `call` doesn't take a training argument at all,
|
||||||
# it'll set the nested scope as inference when no training arg is passed in.
|
# it'll set the nested scope as None when no training arg is passed in.
|
||||||
# If a training arg is passed in it won't use it directly in `call`, but
|
# If a training arg is passed in it won't use it directly in `call`, but
|
||||||
# it will set the nested training mode.
|
# it will set the nested training mode.
|
||||||
layer = CustomLayerNoTrainingArg(CustomLayerDefaultTrainingTrue())
|
layer = CustomLayerNoTrainingArg(CustomLayerDefaultTrainingTrue())
|
||||||
self.assertAllEqual(layer(x), x * 0.5)
|
self.assertAllEqual(layer(x), x) # Use local default True
|
||||||
|
self.assertAllEqual(layer(x, training=False), x * 0.5)
|
||||||
|
self.assertAllEqual(layer(x, training=True), x)
|
||||||
|
|
||||||
|
layer = CustomLayerDefaultTrainingNone(CustomLayerDefaultTrainingTrue())
|
||||||
|
self.assertAllEqual(layer(x), x) # Use local default True
|
||||||
self.assertAllEqual(layer(x, training=False), x * 0.5)
|
self.assertAllEqual(layer(x, training=False), x * 0.5)
|
||||||
self.assertAllEqual(layer(x, training=True), x)
|
self.assertAllEqual(layer(x, training=True), x)
|
||||||
|
|
||||||
|
|
|
@ -2116,13 +2116,13 @@ class CacheCorrectnessTest(keras_parameterized.TestCase):
|
||||||
|
|
||||||
if context.executing_eagerly():
|
if context.executing_eagerly():
|
||||||
# In v2, construction still works when no `training` is specified
|
# In v2, construction still works when no `training` is specified
|
||||||
# When no value passed during construction, it uses the runtime value.
|
# When no value passed during construction, it uses the local default.
|
||||||
inputs = input_layer_lib.Input(10)
|
inputs = input_layer_lib.Input(10)
|
||||||
outputs = my_layer(inputs)
|
outputs = my_layer(inputs)
|
||||||
network = functional.Functional(inputs, outputs)
|
network = functional.Functional(inputs, outputs)
|
||||||
self.assertAllEqual(network(x, training=True), _call(x, True))
|
self.assertAllEqual(network(x, training=True), _call(x, True))
|
||||||
self.assertAllEqual(network(x, training=False), _call(x, False))
|
self.assertAllEqual(network(x, training=False), _call(x, False))
|
||||||
self.assertAllEqual(network(x), _call(x, False))
|
self.assertAllEqual(network(x), _call(x, True)) # Use local default
|
||||||
|
|
||||||
# `None` value passed positionally during construction is ignored at runtime
|
# `None` value passed positionally during construction is ignored at runtime
|
||||||
inputs = input_layer_lib.Input(10)
|
inputs = input_layer_lib.Input(10)
|
||||||
|
@ -2131,7 +2131,7 @@ class CacheCorrectnessTest(keras_parameterized.TestCase):
|
||||||
self.assertAllEqual(network(x, training=True), _call(x, True))
|
self.assertAllEqual(network(x, training=True), _call(x, True))
|
||||||
self.assertAllEqual(network(x, training=False), _call(x, False))
|
self.assertAllEqual(network(x, training=False), _call(x, False))
|
||||||
if context.executing_eagerly():
|
if context.executing_eagerly():
|
||||||
self.assertAllEqual(network(x), _call(x, False))
|
self.assertAllEqual(network(x), _call(x, True)) # Use local default
|
||||||
else:
|
else:
|
||||||
# in v1 training would have defaulted to using the `None` inside the layer
|
# in v1 training would have defaulted to using the `None` inside the layer
|
||||||
# if training is not passed at runtime
|
# if training is not passed at runtime
|
||||||
|
@ -2144,7 +2144,7 @@ class CacheCorrectnessTest(keras_parameterized.TestCase):
|
||||||
self.assertAllEqual(network(x, training=True), _call(x, True))
|
self.assertAllEqual(network(x, training=True), _call(x, True))
|
||||||
self.assertAllEqual(network(x, training=False), _call(x, False))
|
self.assertAllEqual(network(x, training=False), _call(x, False))
|
||||||
if context.executing_eagerly():
|
if context.executing_eagerly():
|
||||||
self.assertAllEqual(network(x), _call(x, False))
|
self.assertAllEqual(network(x), _call(x, True)) # Use local default
|
||||||
else:
|
else:
|
||||||
# in v1 training would have defaulted to using the `None` inside the layer
|
# in v1 training would have defaulted to using the `None` inside the layer
|
||||||
# if training is not passed at runtime
|
# if training is not passed at runtime
|
||||||
|
|
|
@ -26,6 +26,7 @@ from tensorflow.python.framework import errors
|
||||||
from tensorflow.python.framework import test_util as tf_test_util
|
from tensorflow.python.framework import test_util as tf_test_util
|
||||||
from tensorflow.python.keras import keras_parameterized
|
from tensorflow.python.keras import keras_parameterized
|
||||||
from tensorflow.python.keras import testing_utils
|
from tensorflow.python.keras import testing_utils
|
||||||
|
from tensorflow.python.keras.engine import sequential
|
||||||
from tensorflow.python.keras.layers.preprocessing import image_preprocessing
|
from tensorflow.python.keras.layers.preprocessing import image_preprocessing
|
||||||
from tensorflow.python.keras.utils.generic_utils import CustomObjectScope
|
from tensorflow.python.keras.utils.generic_utils import CustomObjectScope
|
||||||
from tensorflow.python.ops import gen_stateful_random_ops
|
from tensorflow.python.ops import gen_stateful_random_ops
|
||||||
|
@ -1273,5 +1274,38 @@ class RandomWidthTest(keras_parameterized.TestCase):
|
||||||
self.assertEqual(layer_1.name, layer.name)
|
self.assertEqual(layer_1.name, layer.name)
|
||||||
|
|
||||||
|
|
||||||
|
@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
|
||||||
|
class LearningPhaseTest(keras_parameterized.TestCase):
|
||||||
|
|
||||||
|
def test_plain_call(self):
|
||||||
|
layer = image_preprocessing.RandomWidth(.5, seed=123)
|
||||||
|
shape = (12, 12, 3)
|
||||||
|
img = np.random.random((12,) + shape)
|
||||||
|
out = layer(img) # Default to training=True
|
||||||
|
self.assertNotEqual(tuple(int(i) for i in out.shape[1:]), shape)
|
||||||
|
|
||||||
|
out = layer(img, training=True)
|
||||||
|
self.assertNotEqual(tuple(int(i) for i in out.shape[1:]), shape)
|
||||||
|
|
||||||
|
out = layer(img, training=False)
|
||||||
|
self.assertEqual(tuple(int(i) for i in out.shape[1:]), shape)
|
||||||
|
|
||||||
|
def test_call_in_container(self):
|
||||||
|
layer1 = image_preprocessing.RandomWidth(.5, seed=123)
|
||||||
|
layer2 = image_preprocessing.RandomHeight(.5, seed=123)
|
||||||
|
seq = sequential.Sequential([layer1, layer2])
|
||||||
|
|
||||||
|
shape = (12, 12, 3)
|
||||||
|
img = np.random.random((12,) + shape)
|
||||||
|
out = seq(img) # Default to training=True
|
||||||
|
self.assertNotEqual(tuple(int(i) for i in out.shape[1:]), shape)
|
||||||
|
|
||||||
|
out = seq(img, training=True)
|
||||||
|
self.assertNotEqual(tuple(int(i) for i in out.shape[1:]), shape)
|
||||||
|
|
||||||
|
out = seq(img, training=False)
|
||||||
|
self.assertEqual(tuple(int(i) for i in out.shape[1:]), shape)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
test.main()
|
test.main()
|
||||||
|
|
Loading…
Reference in New Issue