diff --git a/tensorflow/python/keras/layers/core.py b/tensorflow/python/keras/layers/core.py index d1528c7ba59..db9c47eca17 100644 --- a/tensorflow/python/keras/layers/core.py +++ b/tensorflow/python/keras/layers/core.py @@ -53,7 +53,7 @@ from tensorflow.python.ops import variable_scope from tensorflow.python.platform import tf_logging from tensorflow.python.training.tracking import base as trackable from tensorflow.python.util import nest -from tensorflow.python.util import tf_decorator +from tensorflow.python.util import tf_inspect from tensorflow.python.util.tf_export import keras_export @@ -738,8 +738,7 @@ class Lambda(Layer): models. `Lambda` layers are best suited for simple operations or quick experimentation. For more advanced use cases, follow [this guide](https://www.tensorflow.org/guide/keras/custom_layers_and_models) - for subclassing `tf.keras.layers.Layer`. (Do not subclass - `tf.keras.layers.Lamba`.) + for subclassing `tf.keras.layers.Layer`. The main reason to subclass `tf.keras.layers.Layer` instead of using a `Lambda` layer is saving and inspecting a Model. `Lambda` layers @@ -799,7 +798,8 @@ class Lambda(Layer): computation, but anything more complex should use a subclass Layer instead. Arguments: - function: The function to evaluate when the layer is called. + function: The function to be evaluated. Takes input tensor as first + argument. output_shape: Expected output shape from function. This argument can be inferred if not explicitly provided. Can be a tuple or function. If a tuple, it only specifies the first dimension onward; @@ -812,8 +812,8 @@ class Lambda(Layer): mask: Either None (indicating no masking) or a callable with the same signature as the `compute_mask` layer method, or a tensor that will be returned as output mask regardless of what the input is. - arguments: Optional dictionary of keyword arguments to pass by default to - the function when those arguments are not passed to the layer call. + arguments: Optional dictionary of keyword arguments to be passed to the + function. Input shape: Arbitrary. Use the keyword argument input_shape (tuple of integers, does not include the samples axis) when using this layer as the first layer in a model. @@ -823,16 +823,11 @@ class Lambda(Layer): @trackable.no_automatic_dependency_tracking def __init__(self, function, output_shape=None, mask=None, arguments=None, **kwargs): + super(Lambda, self).__init__(**kwargs) + self.arguments = arguments or {} self.function = function - # Decorate the function to produce this layer's call method - def _call_wrapper(*args, **kwargs): - return self._call_wrapper(*args, **kwargs) - self.call = tf_decorator.make_decorator(function, _call_wrapper) - - super(Lambda, self).__init__(**kwargs) - if mask is not None: self.supports_masking = True self.mask = mask @@ -841,8 +836,9 @@ class Lambda(Layer): # Warning on every invocation will be quite irksome in Eager mode. self._already_warned = False - self._expects_training_arg = 'training' in self._call_fn_args - self._expects_mask_arg = 'mask' in self._call_fn_args + function_args = tf_inspect.getfullargspec(function).args + self._fn_expects_training_arg = 'training' in function_args + self._fn_expects_mask_arg = 'mask' in function_args @tf_utils.shape_type_conversion def compute_output_shape(self, input_shape): @@ -873,22 +869,23 @@ class Lambda(Layer): output_shapes = tf_utils.convert_shapes(self._output_shape, to_tuples=False) return nest.map_structure(_add_batch, output_shapes) - def _call_wrapper(self, *args, **kwargs): + def call(self, inputs, mask=None, training=None): # We must copy for thread safety, but it only needs to be a shallow copy. - call_kwargs = {k: v for k, v in self.arguments.items()} - - # override default kwargs with the args passed to the layer call - call_kwargs.update(kwargs) + kwargs = {k: v for k, v in self.arguments.items()} + if self._fn_expects_mask_arg: + kwargs['mask'] = mask + if self._fn_expects_training_arg: + kwargs['training'] = training created_variables = [] - def _variable_creator(next_creator, **creator_kwargs): - var = next_creator(**creator_kwargs) + def _variable_creator(next_creator, **kwargs): + var = next_creator(**kwargs) created_variables.append(var) return var with backprop.GradientTape(watch_accessed_variables=True) as tape,\ variable_scope.variable_creator_scope(_variable_creator): - result = self.function(*args, **call_kwargs) + result = self.function(inputs, **kwargs) self._check_variables(created_variables, tape.watched_variables()) return result diff --git a/tensorflow/python/keras/layers/core_test.py b/tensorflow/python/keras/layers/core_test.py index aa1192e12fc..3daa187f1ce 100644 --- a/tensorflow/python/keras/layers/core_test.py +++ b/tensorflow/python/keras/layers/core_test.py @@ -139,26 +139,6 @@ class LambdaLayerTest(keras_parameterized.TestCase): out = ld([x1, x2]) self.assertAllEqual(out.shape, [3, 2]) - def test_lambda_multiple_args(self): - ld = keras.layers.Lambda(lambda x, y: x[0] + y) - x1 = np.ones([3, 2], np.float32) - x2 = np.ones([3, 5], np.float32) - - expected_result = x1 * 2 - self.assertAllEqual(ld([x1, x2], x1), expected_result) - self.assertAllEqual(ld([x1, x2], y=x1), expected_result) - self.assertAllEqual(ld(x=[x1, x2], y=x1), expected_result) - - def test_lambda_constructor_args_and_multiple_args(self): - x1 = np.ones([3, 2], np.float32) - x2 = np.ones([3, 5], np.float32) - ld = keras.layers.Lambda(lambda x, y: x[0] + y, arguments={'y': x1*2}) - - self.assertAllEqual(ld([x1, x2]), x1 * 3) - self.assertAllEqual(ld([x1, x2], y=x1), x1 * 2) - self.assertAllEqual(ld(x=[x1, x2]), x1 * 3) - self.assertAllEqual(ld(x=[x1, x2], y=x1), x1 * 2) - def test_lambda_output_shape(self): l = keras.layers.Lambda(lambda x: x + 1, output_shape=(1, 1)) l(keras.backend.variable(np.ones((1, 1)))) diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-lambda.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-lambda.pbtxt index d4dbe96d1ba..22fa730112f 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-lambda.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-lambda.pbtxt @@ -145,7 +145,7 @@ tf_class { } member_method { name: "call" - argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=kwargs, defaults=None" + argspec: "args=[\'self\', \'inputs\', \'mask\', \'training\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " } member_method { name: "compute_mask" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-lambda.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-lambda.pbtxt index d4dbe96d1ba..22fa730112f 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-lambda.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-lambda.pbtxt @@ -145,7 +145,7 @@ tf_class { } member_method { name: "call" - argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=kwargs, defaults=None" + argspec: "args=[\'self\', \'inputs\', \'mask\', \'training\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " } member_method { name: "compute_mask"