Updates Keras layer `__call__` to always set `training`, with the following priority order:
# Training mode for `Layer.__call__` is set via (in order of priority):
# (1) The `training` argument passed to this `Layer.__call__`, if it is not None
# (2) The training mode of an outer `Layer.__call__`.
# (3) The default mode set by `tf.keras.backend.set_learning_phase` (if set).
# (4) Any non-None default value for `training` specified in the `call`
#     signature
# (5) False (treating the layer as if it's in inference)

Previously (4) and (5) were missing, leading to crashes for layers that do not provide a default argument for `training`. Note that (4) is fragile to reflection issues, and may get confused by decorators.

PiperOrigin-RevId: 317709904
Change-Id: I58039a4d9e5106bcb27f4cfbf65e6762f1b40807
Parent: 780c0a29fe
Commit: aac1dd5788
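For illustration only (not part of this change), a minimal sketch of the resulting behavior, assuming TF 2.x eager execution; the layer below is hypothetical and mirrors `CustomLayerDefaultTrainingMissing` from the test added in this commit:

import tensorflow as tf

class HalveInInference(tf.keras.layers.Layer):

  def call(self, inputs, training):  # note: no default value for `training`
    return inputs if training else inputs * 0.5

layer = HalveInInference()
x = tf.ones((1, 1))

layer(x)                 # no `training` passed: previously crashed, now priority (5) applies -> inference
layer(x, training=True)  # priority (1): an explicit non-None value always wins
tf.keras.backend.set_learning_phase(1)
layer(x)                 # priority (3): the global learning phase sets the training mode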
@@ -943,10 +943,14 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
       kwargs['mask'] = input_masks

     # Training mode for `Layer.call` is set via (in order of priority):
-    # (1) The `training` argument passed to this `Layer.call`.
+    # (1) The `training` argument passed to this `Layer.call`, if it is not None
     # (2) The training mode of an outer `Layer.call`.
-    # (3) The default mode set by `tf.keras.backed.set_learning_phase` (if set).
-    training_mode = self._set_training_mode(args, kwargs, call_context)
+    # (3) The default mode set by `tf.keras.backend.set_learning_phase` (if set)
+    # (4) Any non-None default value for `training` specified in the call
+    #     signature
+    # (5) False (treating the layer as if it's in inference)
+    args, kwargs, training_mode = self._set_training_mode(
+        args, kwargs, call_context)

     # Losses are cleared for all sublayers on the outermost `Layer.call`.
     # Losses are not cleared on inner `Layer.call`s, because sublayers can be
@@ -1020,7 +1024,7 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
     # propagate `training` value from this layer's calling layer.
     training_value = None
     training_arg_passed_by_framework = False
-    # Priority 1: `training` was explicitly passed.
+    # Priority 1: `training` was explicitly passed a non-None value.
     if self._call_arg_was_passed('training', args, kwargs):
       training_value = self._get_call_arg_value('training', args, kwargs)
       if not self._expects_training_arg:
@@ -1030,17 +1034,23 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
     # Priority 2: `training` was passed to a parent layer.
     if call_context.training is not None:
       training_value = call_context.training
-    # Priority 3a: `learning_phase()` has been set.
+    # Priority 3: `learning_phase()` has been set.
     elif backend.global_learning_phase_is_set():
       training_value = backend.learning_phase()
-
-    if self._expects_training_arg and training_value is not None:
       # Force the training_value to be bool type which matches to the contract
       # for layer/model call args.
       if tensor_util.is_tensor(training_value):
         training_value = math_ops.cast(training_value, dtypes.bool)
       else:
         training_value = bool(training_value)
+    # Priority 4: trace layer with the default training argument specified
+    # in the `call` signature (or in inference mode if the `call` signature
+    # specifies no non-None default).
+    else:
+      training_value = self._default_training_arg
+    # In cases (2), (3), (4) the training argument is passed automatically
+    # by the framework, and will not be hard-coded into the model.
+    if self._expects_training_arg:
       args, kwargs = self._set_call_arg_value('training', training_value,
                                               args, kwargs)
       training_arg_passed_by_framework = True
@@ -1150,6 +1160,8 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
       # (1) `training` was passed to this `Layer.call`.
       if self._call_arg_was_passed('training', args, kwargs):
         training_mode = self._get_call_arg_value('training', args, kwargs)
+      # If no `training` arg was passed, or `None` was explicitly passed,
+      # the framework will make a decision about the training mode is.
       if training_mode is None:
         call_ctx_training = call_context.training
         # (2) `training` mode is inferred from an outer `Layer.call`.
@@ -1165,10 +1177,15 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
            training_mode = math_ops.cast(training_mode, dtypes.bool)
          else:
            training_mode = bool(training_mode)
+        # (4) We default to using `call`'s default value for `training`,
+        # or treating the layer as if it is in inference if no non-None default
+        # is specified in the `call` signature.
+        else:
+          training_mode = self._default_training_arg

-        # For case (2) or (3), `training` arg is passed by framework.
-        if training_mode is not None:
-          kwargs['training'] = training_mode
+        # For case (2), (3), (4) `training` arg is passed by framework.
+        args, kwargs = self._set_call_arg_value('training', training_mode, args,
+                                                kwargs)
     else:
       if 'training' in kwargs:
         # `training` was passed to this `Layer` but is not needed for
@@ -1178,7 +1195,7 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
         # Grab the current `training` mode from any outer `Layer.call`.
         training_mode = call_context.training

-    return training_mode
+    return args, kwargs, training_mode

   def _autographed_call(self):
     # Wrapping `call` function in autograph to allow for dynamic control
@@ -2529,7 +2546,7 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
     if len(args) > arg_pos:
       args = list(args)
       args[arg_pos] = new_value
-      return args, kwargs
+      return tuple(args), kwargs
     if new_value is None and pop_kwarg_if_none:
       kwargs.pop(arg_name, None)
     else:
@@ -2873,6 +2890,10 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
     call_fn_args = self._call_fn_args
     self._expects_training_arg = ('training' in call_fn_args or
                                   self._call_accepts_kwargs)
+    # The default training arg will be any (non-None) default specified in the
+    # method signature, or `False` if no non-None default is specified.
+    self._default_training_arg = self._call_fn_arg_defaults.get(
+        'training') or False
     self._expects_mask_arg = ('mask' in call_fn_args or
                               self._call_accepts_kwargs)

@@ -2892,6 +2913,19 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
       return all_args[1:]
     return all_args

+  @property
+  @tracking.cached_per_instance
+  def _call_fn_arg_defaults(self):
+    call_fn_args = self._call_fn_args
+    call_fn_defaults = self._call_full_argspec.defaults or []
+    defaults = dict()
+
+    # The call arg defaults are an n-tuple of the last n elements of the args
+    # list. (n = # of elements that have a default argument)
+    for i in range(-1 * len(call_fn_defaults), 0):
+      defaults[call_fn_args[i]] = call_fn_defaults[i]
+    return defaults
+
   @property
   @tracking.cached_per_instance
   def _call_fn_arg_positions(self):
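As an aside (not part of the diff), the default-mapping logic added above can be sketched with plain `inspect` in place of `tf_inspect`: argspec defaults always line up with the tail of the argument list, so the last len(defaults) names give a name-to-default mapping, and `.get('training') or False` collapses a missing or None default to False. The `call` function below is a hypothetical example signature:

import inspect

def call(self, inputs, mask=None, training=None):
  return inputs

spec = inspect.getfullargspec(call)
defaults = dict(zip(spec.args[-len(spec.defaults or ()):], spec.defaults or ()))
print(defaults)                           # {'mask': None, 'training': None}
print(defaults.get('training') or False)  # False -> the layer is traced in inference mode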
@@ -629,6 +629,96 @@ class BaseLayerTest(keras_parameterized.TestCase):
     self.assertTrue(layer.built)
     self.assertEqual([None, 3], layer._build_input_shape.as_list())

+  @combinations.generate(combinations.combine(mode=['eager']))
+  def custom_layer_training_arg(self):
+    class CustomLayerNoTrainingArg(base_layer.Layer):
+
+      def __init__(self, nested_layer=None):
+        self._nested_layer = nested_layer or array_ops.identity
+
+      def call(self, inputs):
+        return self._nested_layer(inputs)
+
+    class CustomLayerDefaultTrainingMissing(base_layer.Layer):
+
+      def __init__(self, nested_layer=None):
+        self._nested_layer = nested_layer or array_ops.identity
+
+      def call(self, inputs, training):
+        if training:
+          return self._nested_layer(inputs)
+        else:
+          return self._nested_layer(inputs) * 0.5
+
+    class CustomLayerDefaultTrainingFalse(base_layer.Layer):
+
+      def __init__(self, nested_layer=None):
+        self._nested_layer = nested_layer or array_ops.identity
+
+      def call(self, inputs, training=False):
+        if training:
+          return self._nested_layer(inputs)
+        else:
+          return self._nested_layer(inputs) * 0.5
+
+    class CustomLayerDefaultTrainingTrue(base_layer.Layer):
+
+      def __init__(self, nested_layer=None):
+        self._nested_layer = nested_layer or array_ops.identity
+
+      def call(self, inputs, training=True):
+        if training:
+          return self._nested_layer(inputs)
+        else:
+          return self._nested_layer(inputs) * 0.5
+
+    x = array_ops.ones(shape=(1, 1))
+
+    # If the layer signature doesn't specify a default training arg,
+    # run it in inference mode when to training arg is passed
+    # to __call__
+    layer = CustomLayerDefaultTrainingMissing()
+    self.assertAllEqual(layer(x), x * 0.5)
+    self.assertAllEqual(layer(x, training=False), x * 0.5)
+    self.assertAllEqual(layer(x, training=True), x)
+
+    # If the layer signature specifies `False` as the default training arg,
+    # run it in inference mode when no training arg is passed
+    # to __call__
+    layer = CustomLayerDefaultTrainingFalse()
+    self.assertAllEqual(layer(x), x * 0.5)
+    self.assertAllEqual(layer(x, training=False), x * 0.5)
+    self.assertAllEqual(layer(x, training=True), x)
+
+    # If the layer signature specifies `True` as the default training arg,
+    # explicitly run it in training mode when no training arg is passed
+    # to __call__
+    layer = CustomLayerDefaultTrainingTrue()
+    self.assertAllEqual(layer(x), x)
+    self.assertAllEqual(layer(x, training=False), x * 0.5)
+    self.assertAllEqual(layer(x, training=True), x)
+
+    # Outer layers/models should set the training context implicitly for all
+    # nested layers, respecting whatever mode the outer layer was run with.
+    layer = CustomLayerDefaultTrainingTrue(CustomLayerDefaultTrainingFalse())
+    self.assertAllEqual(layer(x), x)
+    self.assertAllEqual(layer(x, training=False), x * 0.25)
+    self.assertAllEqual(layer(x, training=True), x)
+
+    layer = CustomLayerDefaultTrainingFalse(CustomLayerDefaultTrainingTrue())
+    self.assertAllEqual(layer(x), x * 0.25)
+    self.assertAllEqual(layer(x, training=False), x * 0.25)
+    self.assertAllEqual(layer(x, training=True), x)
+
+    # If the outer layer `call` doesn't take a training argument at all,
+    # it'll set the nested scope as inference when no training arg is passed in.
+    # If a training arg is passed in it won't use it directly in `call`, but
+    # it will set the nested training mode.
+    layer = CustomLayerNoTrainingArg(CustomLayerDefaultTrainingTrue())
+    self.assertAllEqual(layer(x), x * 0.5)
+    self.assertAllEqual(layer(x, training=False), x * 0.5)
+    self.assertAllEqual(layer(x, training=True), x)
+
   def test_activity_regularizer_string(self):

     class MyLayer(base_layer.Layer):
@@ -1387,6 +1477,7 @@ class DTypeTest(keras_parameterized.TestCase):
     class IdentityLayerWithArgs(base_layer.Layer):

       def call(self, inputs, *args, **kwargs):
+        kwargs.pop('training', None)
         return nest.flatten([inputs, args, kwargs])

     layer = IdentityLayerWithArgs(dtype='float64')
@@ -2036,43 +2036,73 @@ class CacheCorrectnessTest(keras_parameterized.TestCase):
   def test_training_passed_during_construction(self):

+    def _call(inputs, training):
+      if training is None:
+        return inputs * -1.0
+      elif training:
+        return inputs
+      else:
+        return inputs * 0.0
+
     class MyLayer(base_layer.Layer):

-      def call(self, x, training=None):
-        if training is None:
-          return x * -1.0
-        elif training:
-          return x
-        else:
-          return x * 0.0
+      def call(self, inputs, training=True):
+        return _call(inputs, training)

     my_layer = MyLayer()
     x = np.ones((1, 10))

+    # Hard-coded `true` value passed during construction is respected.
     inputs = input_layer_lib.Input(10)
     outputs = my_layer(inputs, training=True)
     network = functional.Functional(inputs, outputs)
+    self.assertAllEqual(network(x, training=True), _call(x, True))
+    self.assertAllEqual(network(x, training=False), _call(x, True))
+    self.assertAllEqual(network(x), _call(x, True))

-    # Hard-coded value passed during construction is respected.
-    self.assertAllEqual(network(x, training=False), x)
-
+    # Hard-coded `false` value passed during construction is respected.
     inputs = input_layer_lib.Input(10)
     outputs = my_layer(inputs, training=False)
     network = functional.Functional(inputs, outputs)
+    self.assertAllEqual(network(x, training=True), _call(x, False))
+    self.assertAllEqual(network(x, training=False), _call(x, False))
+    self.assertAllEqual(network(x), _call(x, False))

-    network(x, training=True)
-    # Hard-coded value passed during construction is respected.
-    self.assertAllEqual(network(x, training=True), x * 0.0)
     if context.executing_eagerly():
       # In v2, construction still works when no `training` is specified
+      # When no value passed during construction, it uses the runtime value.
+      inputs = input_layer_lib.Input(10)
+      outputs = my_layer(inputs)
+      network = functional.Functional(inputs, outputs)
+      self.assertAllEqual(network(x, training=True), _call(x, True))
+      self.assertAllEqual(network(x, training=False), _call(x, False))
+      self.assertAllEqual(network(x), _call(x, False))

+    # `None` value passed positionally during construction is ignored at runtime
+    inputs = input_layer_lib.Input(10)
+    outputs = my_layer(inputs, None)
+    network = functional.Functional(inputs, outputs)
+    self.assertAllEqual(network(x, training=True), _call(x, True))
+    self.assertAllEqual(network(x, training=False), _call(x, False))
+    if context.executing_eagerly():
+      self.assertAllEqual(network(x), _call(x, False))
+    else:
+      # in v1 training would have defaulted to using the `None` inside the layer
+      # if training is not passed at runtime
+      self.assertAllEqual(network(x), _call(x, None))
+
+    # `None` value passed as kwarg during construction is ignored at runtime.
     inputs = input_layer_lib.Input(10)
     outputs = my_layer(inputs, training=None)
     network = functional.Functional(inputs, outputs)
-
-    # `None` value passed during construction is overridden.
-    self.assertAllEqual(network(x, training=True), x)
-    # `None` value passed during construction is overridden.
-    self.assertAllEqual(network(x, training=False), x * 0.0)

+    self.assertAllEqual(network(x, training=True), _call(x, True))
+    self.assertAllEqual(network(x, training=False), _call(x, False))
+    if context.executing_eagerly():
+      self.assertAllEqual(network(x), _call(x, False))
+    else:
+      # in v1 training would have defaulted to using the `None` inside the layer
+      # if training is not passed at runtime
+      self.assertAllEqual(network(x), _call(x, None))
+
 if __name__ == '__main__':
   test.main()
@@ -27,6 +27,7 @@ from __future__ import print_function
 from tensorflow.python.keras.layers import recurrent
 from tensorflow.python.ops import rnn_cell_wrapper_impl
+from tensorflow.python.util import tf_inspect
 from tensorflow.python.util.tf_export import tf_export


@@ -41,6 +42,10 @@ class _RNNCellWrapperV2(recurrent.AbstractRNNCell):
   def __init__(self, cell, *args, **kwargs):
     super(_RNNCellWrapperV2, self).__init__(*args, **kwargs)
     self.cell = cell
+    cell_call_spec = tf_inspect.getfullargspec(cell.call)
+    self._expects_training_arg = ("training" in cell_call_spec.args) or (
+        cell_call_spec.varkw is not None
+    )

   def call(self, inputs, state, **kwargs):
     """Runs the RNN cell step computation.
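A hedged sketch (not from the diff) of the "fragile to reflection issues" caveat in the commit message: signature probes like the `getfullargspec` call above only see what the outermost callable exposes, so a decorator that does not preserve the wrapped signature hides the `training` argument and its default. The `passthrough` decorator and `call` function below are hypothetical:

import inspect

def passthrough(fn):
  def wrapper(*args, **kwargs):  # the original signature is not preserved here
    return fn(*args, **kwargs)
  return wrapper

def call(self, inputs, training=False):
  return inputs

print(inspect.getfullargspec(call).args)               # ['self', 'inputs', 'training']
print(inspect.getfullargspec(passthrough(call)).args)  # [] -> `training` and its default go undetected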
@@ -186,7 +186,7 @@ class CustomTrainingLoopTest(keras_parameterized.TestCase):

     def train_step(x):
       no_learning_phase_out = model(x)
-      self.assertIsNone(model.layer.training)
+      self.assertFalse(model.layer.training)
       with keras.backend.learning_phase_scope(0):
         inf_learning_phase_out = model(x)
       self.assertEqual(model.layer.training, 0)