Add option for functions in Lambda
layers to have training
argument.
PiperOrigin-RevId: 239016189
This commit is contained in:
parent
47f78f1f5a
commit
5b2c54de04
@ -193,10 +193,7 @@ class Layer(trackable.Trackable):
|
||||
self._outbound_nodes = []
|
||||
|
||||
call_argspec = tf_inspect.getfullargspec(self.call)
|
||||
if 'training' in call_argspec.args:
|
||||
self._expects_training_arg = True
|
||||
else:
|
||||
self._expects_training_arg = False
|
||||
self._expects_training_arg = 'training' in call_argspec.args
|
||||
|
||||
# Whether the `call` method can be used to build a TF graph without issues.
|
||||
self._dynamic = dynamic
|
||||
|
@ -48,6 +48,7 @@ from tensorflow.python.ops import nn_ops
|
||||
from tensorflow.python.ops import standard_ops
|
||||
from tensorflow.python.ops import variable_scope
|
||||
from tensorflow.python.util import nest
|
||||
from tensorflow.python.util import tf_inspect
|
||||
from tensorflow.python.util.tf_export import keras_export
|
||||
|
||||
|
||||
@ -738,6 +739,10 @@ class Lambda(Layer):
|
||||
self._trainable_weights = []
|
||||
self._non_trainable_weights = []
|
||||
|
||||
function_args = tf_inspect.getfullargspec(self.function).args
|
||||
self._fn_expects_training_arg = 'training' in function_args
|
||||
self._fn_expects_mask_arg = 'mask' in function_args
|
||||
|
||||
@tf_utils.shape_type_conversion
|
||||
def compute_output_shape(self, input_shape):
|
||||
if self._output_shape is None:
|
||||
@ -767,10 +772,12 @@ class Lambda(Layer):
|
||||
output_shapes = tf_utils.convert_shapes(self._output_shape, to_tuples=False)
|
||||
return nest.map_structure(_add_batch, output_shapes)
|
||||
|
||||
def call(self, inputs, mask=None):
|
||||
def call(self, inputs, mask=None, training=None):
|
||||
arguments = self.arguments
|
||||
if generic_utils.has_arg(self.function, 'mask'):
|
||||
if self._fn_expects_mask_arg:
|
||||
arguments['mask'] = mask
|
||||
if self._fn_expects_training_arg:
|
||||
arguments['training'] = training
|
||||
with variable_scope.variable_creator_scope(self._variable_creator):
|
||||
return self.function(inputs, **arguments)
|
||||
|
||||
|
@ -204,6 +204,19 @@ class LambdaLayerTest(keras_parameterized.TestCase):
|
||||
self.assertLen(layer.trainable_weights, 1)
|
||||
self.assertEqual(layer.trainable_weights[0].name, 'lambda/multiplier:0')
|
||||
|
||||
def test_lambda_with_training_arg(self):
|
||||
|
||||
def fn(x, training=True):
|
||||
return keras.backend.in_train_phase(x, 2 * x, training=training)
|
||||
|
||||
layer = keras.layers.Lambda(fn)
|
||||
x = keras.backend.ones(())
|
||||
train_out = layer(x, training=True)
|
||||
eval_out = layer(x, training=False)
|
||||
|
||||
self.assertEqual(keras.backend.get_value(train_out), 1.)
|
||||
self.assertEqual(keras.backend.get_value(eval_out), 2.)
|
||||
|
||||
|
||||
class TestStatefulLambda(keras_parameterized.TestCase):
|
||||
|
||||
|
@ -118,7 +118,7 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "call"
|
||||
argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
argspec: "args=[\'self\', \'inputs\', \'mask\', \'training\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "compute_mask"
|
||||
|
@ -118,7 +118,7 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "call"
|
||||
argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
argspec: "args=[\'self\', \'inputs\', \'mask\', \'training\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "compute_mask"
|
||||
|
Loading…
Reference in New Issue
Block a user