From 5b2c54de04f4c1c50a7b1f25585447a907da10e9 Mon Sep 17 00:00:00 2001 From: Thomas O'Malley Date: Mon, 18 Mar 2019 10:45:34 -0700 Subject: [PATCH] Add option for functions in `Lambda` layers to have `training` argument. PiperOrigin-RevId: 239016189 --- tensorflow/python/keras/engine/base_layer.py | 5 +---- tensorflow/python/keras/layers/core.py | 11 +++++++++-- tensorflow/python/keras/layers/core_test.py | 13 +++++++++++++ .../golden/v1/tensorflow.keras.layers.-lambda.pbtxt | 2 +- .../golden/v2/tensorflow.keras.layers.-lambda.pbtxt | 2 +- 5 files changed, 25 insertions(+), 8 deletions(-) diff --git a/tensorflow/python/keras/engine/base_layer.py b/tensorflow/python/keras/engine/base_layer.py index cf70c28e75e..f6686d6216f 100644 --- a/tensorflow/python/keras/engine/base_layer.py +++ b/tensorflow/python/keras/engine/base_layer.py @@ -193,10 +193,7 @@ class Layer(trackable.Trackable): self._outbound_nodes = [] call_argspec = tf_inspect.getfullargspec(self.call) - if 'training' in call_argspec.args: - self._expects_training_arg = True - else: - self._expects_training_arg = False + self._expects_training_arg = 'training' in call_argspec.args # Whether the `call` method can be used to build a TF graph without issues. self._dynamic = dynamic diff --git a/tensorflow/python/keras/layers/core.py b/tensorflow/python/keras/layers/core.py index 11f78e8b2e4..477d6d9730c 100644 --- a/tensorflow/python/keras/layers/core.py +++ b/tensorflow/python/keras/layers/core.py @@ -48,6 +48,7 @@ from tensorflow.python.ops import nn_ops from tensorflow.python.ops import standard_ops from tensorflow.python.ops import variable_scope from tensorflow.python.util import nest +from tensorflow.python.util import tf_inspect from tensorflow.python.util.tf_export import keras_export @@ -738,6 +739,10 @@ class Lambda(Layer): self._trainable_weights = [] self._non_trainable_weights = [] + function_args = tf_inspect.getfullargspec(self.function).args + self._fn_expects_training_arg = 'training' in function_args + self._fn_expects_mask_arg = 'mask' in function_args + @tf_utils.shape_type_conversion def compute_output_shape(self, input_shape): if self._output_shape is None: @@ -767,10 +772,12 @@ class Lambda(Layer): output_shapes = tf_utils.convert_shapes(self._output_shape, to_tuples=False) return nest.map_structure(_add_batch, output_shapes) - def call(self, inputs, mask=None): + def call(self, inputs, mask=None, training=None): arguments = self.arguments - if generic_utils.has_arg(self.function, 'mask'): + if self._fn_expects_mask_arg: arguments['mask'] = mask + if self._fn_expects_training_arg: + arguments['training'] = training with variable_scope.variable_creator_scope(self._variable_creator): return self.function(inputs, **arguments) diff --git a/tensorflow/python/keras/layers/core_test.py b/tensorflow/python/keras/layers/core_test.py index 92ddaa9ee96..9f818a54da6 100644 --- a/tensorflow/python/keras/layers/core_test.py +++ b/tensorflow/python/keras/layers/core_test.py @@ -204,6 +204,19 @@ class LambdaLayerTest(keras_parameterized.TestCase): self.assertLen(layer.trainable_weights, 1) self.assertEqual(layer.trainable_weights[0].name, 'lambda/multiplier:0') + def test_lambda_with_training_arg(self): + + def fn(x, training=True): + return keras.backend.in_train_phase(x, 2 * x, training=training) + + layer = keras.layers.Lambda(fn) + x = keras.backend.ones(()) + train_out = layer(x, training=True) + eval_out = layer(x, training=False) + + self.assertEqual(keras.backend.get_value(train_out), 1.) + self.assertEqual(keras.backend.get_value(eval_out), 2.) + class TestStatefulLambda(keras_parameterized.TestCase): diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-lambda.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-lambda.pbtxt index 88f1f8b06d1..b833a4216af 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-lambda.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-lambda.pbtxt @@ -118,7 +118,7 @@ tf_class { } member_method { name: "call" - argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + argspec: "args=[\'self\', \'inputs\', \'mask\', \'training\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " } member_method { name: "compute_mask" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-lambda.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-lambda.pbtxt index 88f1f8b06d1..b833a4216af 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-lambda.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-lambda.pbtxt @@ -118,7 +118,7 @@ tf_class { } member_method { name: "call" - argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + argspec: "args=[\'self\', \'inputs\', \'mask\', \'training\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " } member_method { name: "compute_mask"