Extend Keras Lambda layers to work with functions of any signature rather than only functions that take one argument.

Any *args and **kwargs passed when calling the lambda layer will be forwarded directly to the underlying lambda.
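
For illustration, a minimal sketch of the calling pattern described above, adapted from the test cases in the diff below (it assumes the multi-argument support this message describes; the public tf.keras API is used in place of the internal test imports):

    import numpy as np
    import tensorflow as tf

    # A lambda that takes a second argument in addition to the usual single input.
    ld = tf.keras.layers.Lambda(lambda x, y: x[0] + y)

    x1 = np.ones([3, 2], np.float32)
    x2 = np.ones([3, 5], np.float32)

    # Extra positional or keyword arguments given at call time are forwarded
    # to the wrapped function.
    out = ld([x1, x2], y=x1)  # same result as x1 + x1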

PiperOrigin-RevId: 311789009
Change-Id: Ic072d2252038330cc944d7f565f14806753d7436
A. Unique TensorFlower authored 2020-05-15 13:15:33 -07:00; committed by TensorFlower Gardener
parent 2bbf57217f
commit b3bf8bd856
4 changed files with 22 additions and 45 deletions


@@ -53,7 +53,7 @@ from tensorflow.python.ops import variable_scope
 from tensorflow.python.platform import tf_logging
 from tensorflow.python.training.tracking import base as trackable
 from tensorflow.python.util import nest
-from tensorflow.python.util import tf_decorator
+from tensorflow.python.util import tf_inspect
 from tensorflow.python.util.tf_export import keras_export
@@ -738,8 +738,7 @@ class Lambda(Layer):
   models. `Lambda` layers are best suited for simple operations or
   quick experimentation. For more advanced use cases, follow
   [this guide](https://www.tensorflow.org/guide/keras/custom_layers_and_models)
-  for subclassing `tf.keras.layers.Layer`. (Do not subclass
-  `tf.keras.layers.Lamba`.)
+  for subclassing `tf.keras.layers.Layer`.
 
   The main reason to subclass `tf.keras.layers.Layer` instead of using a
   `Lambda` layer is saving and inspecting a Model. `Lambda` layers
@@ -799,7 +798,8 @@ class Lambda(Layer):
   computation, but anything more complex should use a subclass Layer instead.
 
   Arguments:
-    function: The function to evaluate when the layer is called.
+    function: The function to be evaluated. Takes input tensor as first
+      argument.
     output_shape: Expected output shape from function. This argument can be
       inferred if not explicitly provided. Can be a tuple or function. If a
       tuple, it only specifies the first dimension onward;
@@ -812,8 +812,8 @@ class Lambda(Layer):
     mask: Either None (indicating no masking) or a callable with the same
       signature as the `compute_mask` layer method, or a tensor that will be
       returned as output mask regardless of what the input is.
-    arguments: Optional dictionary of keyword arguments to pass by default to
-      the function when those arguments are not passed to the layer call.
+    arguments: Optional dictionary of keyword arguments to be passed to the
+      function.
   Input shape: Arbitrary. Use the keyword argument input_shape (tuple of
     integers, does not include the samples axis) when using this layer as the
     first layer in a model.
@@ -823,16 +823,11 @@ class Lambda(Layer):
   @trackable.no_automatic_dependency_tracking
   def __init__(self, function, output_shape=None, mask=None, arguments=None,
                **kwargs):
+    super(Lambda, self).__init__(**kwargs)
     self.arguments = arguments or {}
     self.function = function
-
-    # Decorate the function to produce this layer's call method
-    def _call_wrapper(*args, **kwargs):
-      return self._call_wrapper(*args, **kwargs)
-    self.call = tf_decorator.make_decorator(function, _call_wrapper)
-    super(Lambda, self).__init__(**kwargs)
     if mask is not None:
       self.supports_masking = True
     self.mask = mask
@@ -841,8 +836,9 @@ class Lambda(Layer):
     # Warning on every invocation will be quite irksome in Eager mode.
     self._already_warned = False
 
-    self._expects_training_arg = 'training' in self._call_fn_args
-    self._expects_mask_arg = 'mask' in self._call_fn_args
+    function_args = tf_inspect.getfullargspec(function).args
+    self._fn_expects_training_arg = 'training' in function_args
+    self._fn_expects_mask_arg = 'mask' in function_args
 
   @tf_utils.shape_type_conversion
   def compute_output_shape(self, input_shape):
@@ -873,22 +869,23 @@ class Lambda(Layer):
     output_shapes = tf_utils.convert_shapes(self._output_shape, to_tuples=False)
     return nest.map_structure(_add_batch, output_shapes)
 
-  def _call_wrapper(self, *args, **kwargs):
+  def call(self, inputs, mask=None, training=None):
     # We must copy for thread safety, but it only needs to be a shallow copy.
-    call_kwargs = {k: v for k, v in self.arguments.items()}
-    # override default kwargs with the args passed to the layer call
-    call_kwargs.update(kwargs)
+    kwargs = {k: v for k, v in self.arguments.items()}
+    if self._fn_expects_mask_arg:
+      kwargs['mask'] = mask
+    if self._fn_expects_training_arg:
+      kwargs['training'] = training
 
     created_variables = []
 
-    def _variable_creator(next_creator, **creator_kwargs):
-      var = next_creator(**creator_kwargs)
+    def _variable_creator(next_creator, **kwargs):
+      var = next_creator(**kwargs)
       created_variables.append(var)
       return var
 
     with backprop.GradientTape(watch_accessed_variables=True) as tape,\
         variable_scope.variable_creator_scope(_variable_creator):
-      result = self.function(*args, **call_kwargs)
+      result = self.function(inputs, **kwargs)
     self._check_variables(created_variables, tape.watched_variables())
     return result
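
As background for the _variable_creator block above: a standalone sketch of the same variable-tracking pattern, written against the public tf.variable_creator_scope and tf.GradientTape APIs rather than the internal modules used in the diff (the creator name and the toy computation are illustrative only):

    import tensorflow as tf

    created_variables = []

    def tracking_creator(next_creator, **kwargs):
      # Delegate the actual creation, but record every variable that gets made.
      var = next_creator(**kwargs)
      created_variables.append(var)
      return var

    with tf.GradientTape(watch_accessed_variables=True) as tape, \
        tf.variable_creator_scope(tracking_creator):
      v = tf.Variable(2.0)  # stands in for a variable created inside the wrapped function
      result = v * 3.0      # reading the variable makes the tape watch it

    # The layer's _check_variables call above compares these two collections
    # to warn about variables it cannot track.
    print(created_variables)
    print(tape.watched_variables())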


@@ -139,26 +139,6 @@ class LambdaLayerTest(keras_parameterized.TestCase):
     out = ld([x1, x2])
     self.assertAllEqual(out.shape, [3, 2])
 
-  def test_lambda_multiple_args(self):
-    ld = keras.layers.Lambda(lambda x, y: x[0] + y)
-    x1 = np.ones([3, 2], np.float32)
-    x2 = np.ones([3, 5], np.float32)
-    expected_result = x1 * 2
-    self.assertAllEqual(ld([x1, x2], x1), expected_result)
-    self.assertAllEqual(ld([x1, x2], y=x1), expected_result)
-    self.assertAllEqual(ld(x=[x1, x2], y=x1), expected_result)
-
-  def test_lambda_constructor_args_and_multiple_args(self):
-    x1 = np.ones([3, 2], np.float32)
-    x2 = np.ones([3, 5], np.float32)
-    ld = keras.layers.Lambda(lambda x, y: x[0] + y, arguments={'y': x1*2})
-    self.assertAllEqual(ld([x1, x2]), x1 * 3)
-    self.assertAllEqual(ld([x1, x2], y=x1), x1 * 2)
-    self.assertAllEqual(ld(x=[x1, x2]), x1 * 3)
-    self.assertAllEqual(ld(x=[x1, x2], y=x1), x1 * 2)
-
   def test_lambda_output_shape(self):
     l = keras.layers.Lambda(lambda x: x + 1, output_shape=(1, 1))
     l(keras.backend.variable(np.ones((1, 1))))
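
The tests above also exercise the `arguments` constructor kwarg documented in the docstring diff earlier. A small usage sketch of that kwarg, plus a function that declares `training` (which the argspec inspection in the layer detects); the names and values here are illustrative only:

    import tensorflow as tf

    # `arguments` supplies default keyword arguments for the wrapped function.
    scaled = tf.keras.layers.Lambda(lambda x, scale: x * scale,
                                    arguments={'scale': 2.0})
    print(scaled(tf.ones([3, 2])))  # every element is 2.0

    # A function that declares `training` receives the layer's training flag.
    noisy = tf.keras.layers.Lambda(
        lambda x, training=None: x + tf.random.normal(tf.shape(x)) if training else x)
    y_train = noisy(tf.ones([2, 2]), training=True)
    y_infer = noisy(tf.ones([2, 2]), training=False)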


@@ -145,7 +145,7 @@ tf_class {
   }
   member_method {
     name: "call"
-    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=kwargs, defaults=None"
+    argspec: "args=[\'self\', \'inputs\', \'mask\', \'training\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
   }
   member_method {
     name: "compute_mask"


@@ -145,7 +145,7 @@ tf_class {
   }
   member_method {
     name: "call"
-    argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=kwargs, defaults=None"
+    argspec: "args=[\'self\', \'inputs\', \'mask\', \'training\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
   }
   member_method {
     name: "compute_mask"
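
For readers less familiar with the API golden format, the two argspec lines in the hunk above correspond roughly to the following Python signatures (the function names are placeholders, not part of the golden files):

    # args=['self', 'inputs'], varargs=None, keywords=kwargs, defaults=None:
    def call_via_kwargs(self, inputs, **kwargs):
      ...

    # args=['self', 'inputs', 'mask', 'training'], defaults=['None', 'None']:
    def call_explicit(self, inputs, mask=None, training=None):
      ...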