Extend Keras Lambda layers to work with functions of any signature rather than only functions that take one argument.

Any *args and **kwargs passed when calling the lambda layer will be forwarded directly to the underlying lambda.

PiperOrigin-RevId: 311789009
Change-Id: Ic072d2252038330cc944d7f565f14806753d7436
This commit is contained in:
A. Unique TensorFlower 2020-05-15 13:15:33 -07:00 committed by TensorFlower Gardener
parent 2bbf57217f
commit b3bf8bd856
4 changed files with 22 additions and 45 deletions

View File

@ -53,7 +53,7 @@ from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import tf_logging from tensorflow.python.platform import tf_logging
from tensorflow.python.training.tracking import base as trackable from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.util import nest from tensorflow.python.util import nest
from tensorflow.python.util import tf_decorator from tensorflow.python.util import tf_inspect
from tensorflow.python.util.tf_export import keras_export from tensorflow.python.util.tf_export import keras_export
@ -738,8 +738,7 @@ class Lambda(Layer):
models. `Lambda` layers are best suited for simple operations or models. `Lambda` layers are best suited for simple operations or
quick experimentation. For more advanced use cases, follow quick experimentation. For more advanced use cases, follow
[this guide](https://www.tensorflow.org/guide/keras/custom_layers_and_models) [this guide](https://www.tensorflow.org/guide/keras/custom_layers_and_models)
for subclassing `tf.keras.layers.Layer`. (Do not subclass for subclassing `tf.keras.layers.Layer`.
`tf.keras.layers.Lamba`.)
The main reason to subclass `tf.keras.layers.Layer` instead of using a The main reason to subclass `tf.keras.layers.Layer` instead of using a
`Lambda` layer is saving and inspecting a Model. `Lambda` layers `Lambda` layer is saving and inspecting a Model. `Lambda` layers
@ -799,7 +798,8 @@ class Lambda(Layer):
computation, but anything more complex should use a subclass Layer instead. computation, but anything more complex should use a subclass Layer instead.
Arguments: Arguments:
function: The function to evaluate when the layer is called. function: The function to be evaluated. Takes input tensor as first
argument.
output_shape: Expected output shape from function. This argument can be output_shape: Expected output shape from function. This argument can be
inferred if not explicitly provided. Can be a tuple or function. If a inferred if not explicitly provided. Can be a tuple or function. If a
tuple, it only specifies the first dimension onward; tuple, it only specifies the first dimension onward;
@ -812,8 +812,8 @@ class Lambda(Layer):
mask: Either None (indicating no masking) or a callable with the same mask: Either None (indicating no masking) or a callable with the same
signature as the `compute_mask` layer method, or a tensor that will be signature as the `compute_mask` layer method, or a tensor that will be
returned as output mask regardless of what the input is. returned as output mask regardless of what the input is.
arguments: Optional dictionary of keyword arguments to pass by default to arguments: Optional dictionary of keyword arguments to be passed to the
the function when those arguments are not passed to the layer call. function.
Input shape: Arbitrary. Use the keyword argument input_shape (tuple of Input shape: Arbitrary. Use the keyword argument input_shape (tuple of
integers, does not include the samples axis) when using this layer as the integers, does not include the samples axis) when using this layer as the
first layer in a model. first layer in a model.
@ -823,16 +823,11 @@ class Lambda(Layer):
@trackable.no_automatic_dependency_tracking @trackable.no_automatic_dependency_tracking
def __init__(self, function, output_shape=None, mask=None, arguments=None, def __init__(self, function, output_shape=None, mask=None, arguments=None,
**kwargs): **kwargs):
super(Lambda, self).__init__(**kwargs)
self.arguments = arguments or {} self.arguments = arguments or {}
self.function = function self.function = function
# Decorate the function to produce this layer's call method
def _call_wrapper(*args, **kwargs):
return self._call_wrapper(*args, **kwargs)
self.call = tf_decorator.make_decorator(function, _call_wrapper)
super(Lambda, self).__init__(**kwargs)
if mask is not None: if mask is not None:
self.supports_masking = True self.supports_masking = True
self.mask = mask self.mask = mask
@ -841,8 +836,9 @@ class Lambda(Layer):
# Warning on every invocation will be quite irksome in Eager mode. # Warning on every invocation will be quite irksome in Eager mode.
self._already_warned = False self._already_warned = False
self._expects_training_arg = 'training' in self._call_fn_args function_args = tf_inspect.getfullargspec(function).args
self._expects_mask_arg = 'mask' in self._call_fn_args self._fn_expects_training_arg = 'training' in function_args
self._fn_expects_mask_arg = 'mask' in function_args
@tf_utils.shape_type_conversion @tf_utils.shape_type_conversion
def compute_output_shape(self, input_shape): def compute_output_shape(self, input_shape):
@ -873,22 +869,23 @@ class Lambda(Layer):
output_shapes = tf_utils.convert_shapes(self._output_shape, to_tuples=False) output_shapes = tf_utils.convert_shapes(self._output_shape, to_tuples=False)
return nest.map_structure(_add_batch, output_shapes) return nest.map_structure(_add_batch, output_shapes)
def _call_wrapper(self, *args, **kwargs): def call(self, inputs, mask=None, training=None):
# We must copy for thread safety, but it only needs to be a shallow copy. # We must copy for thread safety, but it only needs to be a shallow copy.
call_kwargs = {k: v for k, v in self.arguments.items()} kwargs = {k: v for k, v in self.arguments.items()}
if self._fn_expects_mask_arg:
# override default kwargs with the args passed to the layer call kwargs['mask'] = mask
call_kwargs.update(kwargs) if self._fn_expects_training_arg:
kwargs['training'] = training
created_variables = [] created_variables = []
def _variable_creator(next_creator, **creator_kwargs): def _variable_creator(next_creator, **kwargs):
var = next_creator(**creator_kwargs) var = next_creator(**kwargs)
created_variables.append(var) created_variables.append(var)
return var return var
with backprop.GradientTape(watch_accessed_variables=True) as tape,\ with backprop.GradientTape(watch_accessed_variables=True) as tape,\
variable_scope.variable_creator_scope(_variable_creator): variable_scope.variable_creator_scope(_variable_creator):
result = self.function(*args, **call_kwargs) result = self.function(inputs, **kwargs)
self._check_variables(created_variables, tape.watched_variables()) self._check_variables(created_variables, tape.watched_variables())
return result return result

View File

@ -139,26 +139,6 @@ class LambdaLayerTest(keras_parameterized.TestCase):
out = ld([x1, x2]) out = ld([x1, x2])
self.assertAllEqual(out.shape, [3, 2]) self.assertAllEqual(out.shape, [3, 2])
def test_lambda_multiple_args(self):
ld = keras.layers.Lambda(lambda x, y: x[0] + y)
x1 = np.ones([3, 2], np.float32)
x2 = np.ones([3, 5], np.float32)
expected_result = x1 * 2
self.assertAllEqual(ld([x1, x2], x1), expected_result)
self.assertAllEqual(ld([x1, x2], y=x1), expected_result)
self.assertAllEqual(ld(x=[x1, x2], y=x1), expected_result)
def test_lambda_constructor_args_and_multiple_args(self):
x1 = np.ones([3, 2], np.float32)
x2 = np.ones([3, 5], np.float32)
ld = keras.layers.Lambda(lambda x, y: x[0] + y, arguments={'y': x1*2})
self.assertAllEqual(ld([x1, x2]), x1 * 3)
self.assertAllEqual(ld([x1, x2], y=x1), x1 * 2)
self.assertAllEqual(ld(x=[x1, x2]), x1 * 3)
self.assertAllEqual(ld(x=[x1, x2], y=x1), x1 * 2)
def test_lambda_output_shape(self): def test_lambda_output_shape(self):
l = keras.layers.Lambda(lambda x: x + 1, output_shape=(1, 1)) l = keras.layers.Lambda(lambda x: x + 1, output_shape=(1, 1))
l(keras.backend.variable(np.ones((1, 1)))) l(keras.backend.variable(np.ones((1, 1))))

View File

@ -145,7 +145,7 @@ tf_class {
} }
member_method { member_method {
name: "call" name: "call"
argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=kwargs, defaults=None" argspec: "args=[\'self\', \'inputs\', \'mask\', \'training\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
} }
member_method { member_method {
name: "compute_mask" name: "compute_mask"

View File

@ -145,7 +145,7 @@ tf_class {
} }
member_method { member_method {
name: "call" name: "call"
argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=kwargs, defaults=None" argspec: "args=[\'self\', \'inputs\', \'mask\', \'training\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
} }
member_method { member_method {
name: "compute_mask" name: "compute_mask"