Updates Keras layer `__call__` to always set `training`, with the following priority order:

# Training mode for `Layer.__call__` is set via (in order of priority):
    # (1) The `training` argument passed to this `Layer.__call__`, if it is not None
    # (2) The training mode of an outer `Layer.__call__`.
    # (3) The default mode set by `tf.keras.backed.set_learning_phase` (if set).
    # (4) Any non-None default value for `training` specified in the `call`
    #  signature
    # (5) False (treating the layer as if it's in inference)

Previously (4) and (5) were missing, leading to crashes for layers that do not provide a default argument for `training`.

Note that (4) is fragile to reflection issues, and may get confused by decorators.

PiperOrigin-RevId: 317709904
Change-Id: I58039a4d9e5106bcb27f4cfbf65e6762f1b40807
This commit is contained in:
Tomer Kaftan 2020-06-22 12:24:15 -07:00 committed by TensorFlower Gardener
parent 780c0a29fe
commit aac1dd5788
5 changed files with 192 additions and 32 deletions

View File

@ -943,10 +943,14 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
kwargs['mask'] = input_masks
# Training mode for `Layer.call` is set via (in order of priority):
# (1) The `training` argument passed to this `Layer.call`.
# (1) The `training` argument passed to this `Layer.call`, if it is not None
# (2) The training mode of an outer `Layer.call`.
# (3) The default mode set by `tf.keras.backed.set_learning_phase` (if set).
training_mode = self._set_training_mode(args, kwargs, call_context)
# (3) The default mode set by `tf.keras.backend.set_learning_phase` (if set)
# (4) Any non-None default value for `training` specified in the call
# signature
# (5) False (treating the layer as if it's in inference)
args, kwargs, training_mode = self._set_training_mode(
args, kwargs, call_context)
# Losses are cleared for all sublayers on the outermost `Layer.call`.
# Losses are not cleared on inner `Layer.call`s, because sublayers can be
@ -1020,7 +1024,7 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
# propagate `training` value from this layer's calling layer.
training_value = None
training_arg_passed_by_framework = False
# Priority 1: `training` was explicitly passed.
# Priority 1: `training` was explicitly passed a non-None value.
if self._call_arg_was_passed('training', args, kwargs):
training_value = self._get_call_arg_value('training', args, kwargs)
if not self._expects_training_arg:
@ -1030,17 +1034,23 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
# Priority 2: `training` was passed to a parent layer.
if call_context.training is not None:
training_value = call_context.training
# Priority 3a: `learning_phase()` has been set.
# Priority 3: `learning_phase()` has been set.
elif backend.global_learning_phase_is_set():
training_value = backend.learning_phase()
if self._expects_training_arg and training_value is not None:
# Force the training_value to be bool type which matches to the contract
# for layer/model call args.
if tensor_util.is_tensor(training_value):
training_value = math_ops.cast(training_value, dtypes.bool)
else:
training_value = bool(training_value)
# Priority 4: trace layer with the default training argument specified
# in the `call` signature (or in inference mode if the `call` signature
# specifies no non-None default).
else:
training_value = self._default_training_arg
# In cases (2), (3), (4) the training argument is passed automatically
# by the framework, and will not be hard-coded into the model.
if self._expects_training_arg:
args, kwargs = self._set_call_arg_value('training', training_value,
args, kwargs)
training_arg_passed_by_framework = True
@ -1150,6 +1160,8 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
# (1) `training` was passed to this `Layer.call`.
if self._call_arg_was_passed('training', args, kwargs):
training_mode = self._get_call_arg_value('training', args, kwargs)
# If no `training` arg was passed, or `None` was explicitly passed,
# the framework will make a decision about the training mode is.
if training_mode is None:
call_ctx_training = call_context.training
# (2) `training` mode is inferred from an outer `Layer.call`.
@ -1165,10 +1177,15 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
training_mode = math_ops.cast(training_mode, dtypes.bool)
else:
training_mode = bool(training_mode)
# (4) We default to using `call`'s default value for `training`,
# or treating the layer as if it is in inference if no non-None default
# is specified in the `call` signature.
else:
training_mode = self._default_training_arg
# For case (2) or (3), `training` arg is passed by framework.
if training_mode is not None:
kwargs['training'] = training_mode
# For case (2), (3), (4) `training` arg is passed by framework.
args, kwargs = self._set_call_arg_value('training', training_mode, args,
kwargs)
else:
if 'training' in kwargs:
# `training` was passed to this `Layer` but is not needed for
@ -1178,7 +1195,7 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
# Grab the current `training` mode from any outer `Layer.call`.
training_mode = call_context.training
return training_mode
return args, kwargs, training_mode
def _autographed_call(self):
# Wrapping `call` function in autograph to allow for dynamic control
@ -2529,7 +2546,7 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
if len(args) > arg_pos:
args = list(args)
args[arg_pos] = new_value
return args, kwargs
return tuple(args), kwargs
if new_value is None and pop_kwarg_if_none:
kwargs.pop(arg_name, None)
else:
@ -2873,6 +2890,10 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
call_fn_args = self._call_fn_args
self._expects_training_arg = ('training' in call_fn_args or
self._call_accepts_kwargs)
# The default training arg will be any (non-None) default specified in the
# method signature, or `False` if no non-None default is specified.
self._default_training_arg = self._call_fn_arg_defaults.get(
'training') or False
self._expects_mask_arg = ('mask' in call_fn_args or
self._call_accepts_kwargs)
@ -2892,6 +2913,19 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
return all_args[1:]
return all_args
@property
@tracking.cached_per_instance
def _call_fn_arg_defaults(self):
call_fn_args = self._call_fn_args
call_fn_defaults = self._call_full_argspec.defaults or []
defaults = dict()
# The call arg defaults are an n-tuple of the last n elements of the args
# list. (n = # of elements that have a default argument)
for i in range(-1 * len(call_fn_defaults), 0):
defaults[call_fn_args[i]] = call_fn_defaults[i]
return defaults
@property
@tracking.cached_per_instance
def _call_fn_arg_positions(self):

View File

@ -629,6 +629,96 @@ class BaseLayerTest(keras_parameterized.TestCase):
self.assertTrue(layer.built)
self.assertEqual([None, 3], layer._build_input_shape.as_list())
@combinations.generate(combinations.combine(mode=['eager']))
def custom_layer_training_arg(self):
class CustomLayerNoTrainingArg(base_layer.Layer):
def __init__(self, nested_layer=None):
self._nested_layer = nested_layer or array_ops.identity
def call(self, inputs):
return self._nested_layer(inputs)
class CustomLayerDefaultTrainingMissing(base_layer.Layer):
def __init__(self, nested_layer=None):
self._nested_layer = nested_layer or array_ops.identity
def call(self, inputs, training):
if training:
return self._nested_layer(inputs)
else:
return self._nested_layer(inputs) * 0.5
class CustomLayerDefaultTrainingFalse(base_layer.Layer):
def __init__(self, nested_layer=None):
self._nested_layer = nested_layer or array_ops.identity
def call(self, inputs, training=False):
if training:
return self._nested_layer(inputs)
else:
return self._nested_layer(inputs) * 0.5
class CustomLayerDefaultTrainingTrue(base_layer.Layer):
def __init__(self, nested_layer=None):
self._nested_layer = nested_layer or array_ops.identity
def call(self, inputs, training=True):
if training:
return self._nested_layer(inputs)
else:
return self._nested_layer(inputs) * 0.5
x = array_ops.ones(shape=(1, 1))
# If the layer signature doesn't specify a default training arg,
# run it in inference mode when to training arg is passed
# to __call__
layer = CustomLayerDefaultTrainingMissing()
self.assertAllEqual(layer(x), x * 0.5)
self.assertAllEqual(layer(x, training=False), x * 0.5)
self.assertAllEqual(layer(x, training=True), x)
# If the layer signature specifies `False` as the default training arg,
# run it in inference mode when no training arg is passed
# to __call__
layer = CustomLayerDefaultTrainingFalse()
self.assertAllEqual(layer(x), x * 0.5)
self.assertAllEqual(layer(x, training=False), x * 0.5)
self.assertAllEqual(layer(x, training=True), x)
# If the layer signature specifies `True` as the default training arg,
# explicitly run it in training mode when no training arg is passed
# to __call__
layer = CustomLayerDefaultTrainingTrue()
self.assertAllEqual(layer(x), x)
self.assertAllEqual(layer(x, training=False), x * 0.5)
self.assertAllEqual(layer(x, training=True), x)
# Outer layers/models should set the training context implicitly for all
# nested layers, respecting whatever mode the outer layer was run with.
layer = CustomLayerDefaultTrainingTrue(CustomLayerDefaultTrainingFalse())
self.assertAllEqual(layer(x), x)
self.assertAllEqual(layer(x, training=False), x * 0.25)
self.assertAllEqual(layer(x, training=True), x)
layer = CustomLayerDefaultTrainingFalse(CustomLayerDefaultTrainingTrue())
self.assertAllEqual(layer(x), x * 0.25)
self.assertAllEqual(layer(x, training=False), x * 0.25)
self.assertAllEqual(layer(x, training=True), x)
# If the outer layer `call` doesn't take a training argument at all,
# it'll set the nested scope as inference when no training arg is passed in.
# If a training arg is passed in it won't use it directly in `call`, but
# it will set the nested training mode.
layer = CustomLayerNoTrainingArg(CustomLayerDefaultTrainingTrue())
self.assertAllEqual(layer(x), x * 0.5)
self.assertAllEqual(layer(x, training=False), x * 0.5)
self.assertAllEqual(layer(x, training=True), x)
def test_activity_regularizer_string(self):
class MyLayer(base_layer.Layer):
@ -1387,6 +1477,7 @@ class DTypeTest(keras_parameterized.TestCase):
class IdentityLayerWithArgs(base_layer.Layer):
def call(self, inputs, *args, **kwargs):
kwargs.pop('training', None)
return nest.flatten([inputs, args, kwargs])
layer = IdentityLayerWithArgs(dtype='float64')

View File

@ -2036,43 +2036,73 @@ class CacheCorrectnessTest(keras_parameterized.TestCase):
def test_training_passed_during_construction(self):
def _call(inputs, training):
if training is None:
return inputs * -1.0
elif training:
return inputs
else:
return inputs * 0.0
class MyLayer(base_layer.Layer):
def call(self, x, training=None):
if training is None:
return x * -1.0
elif training:
return x
else:
return x * 0.0
def call(self, inputs, training=True):
return _call(inputs, training)
my_layer = MyLayer()
x = np.ones((1, 10))
# Hard-coded `true` value passed during construction is respected.
inputs = input_layer_lib.Input(10)
outputs = my_layer(inputs, training=True)
network = functional.Functional(inputs, outputs)
self.assertAllEqual(network(x, training=True), _call(x, True))
self.assertAllEqual(network(x, training=False), _call(x, True))
self.assertAllEqual(network(x), _call(x, True))
# Hard-coded value passed during construction is respected.
self.assertAllEqual(network(x, training=False), x)
# Hard-coded `false` value passed during construction is respected.
inputs = input_layer_lib.Input(10)
outputs = my_layer(inputs, training=False)
network = functional.Functional(inputs, outputs)
self.assertAllEqual(network(x, training=True), _call(x, False))
self.assertAllEqual(network(x, training=False), _call(x, False))
self.assertAllEqual(network(x), _call(x, False))
network(x, training=True)
# Hard-coded value passed during construction is respected.
self.assertAllEqual(network(x, training=True), x * 0.0)
if context.executing_eagerly():
# In v2, construction still works when no `training` is specified
# When no value passed during construction, it uses the runtime value.
inputs = input_layer_lib.Input(10)
outputs = my_layer(inputs)
network = functional.Functional(inputs, outputs)
self.assertAllEqual(network(x, training=True), _call(x, True))
self.assertAllEqual(network(x, training=False), _call(x, False))
self.assertAllEqual(network(x), _call(x, False))
# `None` value passed positionally during construction is ignored at runtime
inputs = input_layer_lib.Input(10)
outputs = my_layer(inputs, None)
network = functional.Functional(inputs, outputs)
self.assertAllEqual(network(x, training=True), _call(x, True))
self.assertAllEqual(network(x, training=False), _call(x, False))
if context.executing_eagerly():
self.assertAllEqual(network(x), _call(x, False))
else:
# in v1 training would have defaulted to using the `None` inside the layer
# if training is not passed at runtime
self.assertAllEqual(network(x), _call(x, None))
# `None` value passed as kwarg during construction is ignored at runtime.
inputs = input_layer_lib.Input(10)
outputs = my_layer(inputs, training=None)
network = functional.Functional(inputs, outputs)
# `None` value passed during construction is overridden.
self.assertAllEqual(network(x, training=True), x)
# `None` value passed during construction is overridden.
self.assertAllEqual(network(x, training=False), x * 0.0)
self.assertAllEqual(network(x, training=True), _call(x, True))
self.assertAllEqual(network(x, training=False), _call(x, False))
if context.executing_eagerly():
self.assertAllEqual(network(x), _call(x, False))
else:
# in v1 training would have defaulted to using the `None` inside the layer
# if training is not passed at runtime
self.assertAllEqual(network(x), _call(x, None))
if __name__ == '__main__':
test.main()

View File

@ -27,6 +27,7 @@ from __future__ import print_function
from tensorflow.python.keras.layers import recurrent
from tensorflow.python.ops import rnn_cell_wrapper_impl
from tensorflow.python.util import tf_inspect
from tensorflow.python.util.tf_export import tf_export
@ -41,6 +42,10 @@ class _RNNCellWrapperV2(recurrent.AbstractRNNCell):
def __init__(self, cell, *args, **kwargs):
super(_RNNCellWrapperV2, self).__init__(*args, **kwargs)
self.cell = cell
cell_call_spec = tf_inspect.getfullargspec(cell.call)
self._expects_training_arg = ("training" in cell_call_spec.args) or (
cell_call_spec.varkw is not None
)
def call(self, inputs, state, **kwargs):
"""Runs the RNN cell step computation.

View File

@ -186,7 +186,7 @@ class CustomTrainingLoopTest(keras_parameterized.TestCase):
def train_step(x):
no_learning_phase_out = model(x)
self.assertIsNone(model.layer.training)
self.assertFalse(model.layer.training)
with keras.backend.learning_phase_scope(0):
inf_learning_phase_out = model(x)
self.assertEqual(model.layer.training, 0)