Add option for functions in Lambda layers to have training argument.

PiperOrigin-RevId: 239016189
This commit is contained in:
Thomas O'Malley 2019-03-18 10:45:34 -07:00 committed by TensorFlower Gardener
parent 47f78f1f5a
commit 5b2c54de04
5 changed files with 25 additions and 8 deletions

View File

@ -193,10 +193,7 @@ class Layer(trackable.Trackable):
self._outbound_nodes = [] self._outbound_nodes = []
call_argspec = tf_inspect.getfullargspec(self.call) call_argspec = tf_inspect.getfullargspec(self.call)
if 'training' in call_argspec.args: self._expects_training_arg = 'training' in call_argspec.args
self._expects_training_arg = True
else:
self._expects_training_arg = False
# Whether the `call` method can be used to build a TF graph without issues. # Whether the `call` method can be used to build a TF graph without issues.
self._dynamic = dynamic self._dynamic = dynamic

View File

@ -48,6 +48,7 @@ from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import standard_ops from tensorflow.python.ops import standard_ops
from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variable_scope
from tensorflow.python.util import nest from tensorflow.python.util import nest
from tensorflow.python.util import tf_inspect
from tensorflow.python.util.tf_export import keras_export from tensorflow.python.util.tf_export import keras_export
@ -738,6 +739,10 @@ class Lambda(Layer):
self._trainable_weights = [] self._trainable_weights = []
self._non_trainable_weights = [] self._non_trainable_weights = []
function_args = tf_inspect.getfullargspec(self.function).args
self._fn_expects_training_arg = 'training' in function_args
self._fn_expects_mask_arg = 'mask' in function_args
@tf_utils.shape_type_conversion @tf_utils.shape_type_conversion
def compute_output_shape(self, input_shape): def compute_output_shape(self, input_shape):
if self._output_shape is None: if self._output_shape is None:
@ -767,10 +772,12 @@ class Lambda(Layer):
output_shapes = tf_utils.convert_shapes(self._output_shape, to_tuples=False) output_shapes = tf_utils.convert_shapes(self._output_shape, to_tuples=False)
return nest.map_structure(_add_batch, output_shapes) return nest.map_structure(_add_batch, output_shapes)
def call(self, inputs, mask=None): def call(self, inputs, mask=None, training=None):
arguments = self.arguments arguments = self.arguments
if generic_utils.has_arg(self.function, 'mask'): if self._fn_expects_mask_arg:
arguments['mask'] = mask arguments['mask'] = mask
if self._fn_expects_training_arg:
arguments['training'] = training
with variable_scope.variable_creator_scope(self._variable_creator): with variable_scope.variable_creator_scope(self._variable_creator):
return self.function(inputs, **arguments) return self.function(inputs, **arguments)

View File

@ -204,6 +204,19 @@ class LambdaLayerTest(keras_parameterized.TestCase):
self.assertLen(layer.trainable_weights, 1) self.assertLen(layer.trainable_weights, 1)
self.assertEqual(layer.trainable_weights[0].name, 'lambda/multiplier:0') self.assertEqual(layer.trainable_weights[0].name, 'lambda/multiplier:0')
def test_lambda_with_training_arg(self):
def fn(x, training=True):
return keras.backend.in_train_phase(x, 2 * x, training=training)
layer = keras.layers.Lambda(fn)
x = keras.backend.ones(())
train_out = layer(x, training=True)
eval_out = layer(x, training=False)
self.assertEqual(keras.backend.get_value(train_out), 1.)
self.assertEqual(keras.backend.get_value(eval_out), 2.)
class TestStatefulLambda(keras_parameterized.TestCase): class TestStatefulLambda(keras_parameterized.TestCase):

View File

@ -118,7 +118,7 @@ tf_class {
} }
member_method { member_method {
name: "call" name: "call"
argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " argspec: "args=[\'self\', \'inputs\', \'mask\', \'training\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
} }
member_method { member_method {
name: "compute_mask" name: "compute_mask"

View File

@ -118,7 +118,7 @@ tf_class {
} }
member_method { member_method {
name: "call" name: "call"
argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " argspec: "args=[\'self\', \'inputs\', \'mask\', \'training\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
} }
member_method { member_method {
name: "compute_mask" name: "compute_mask"