Extend Keras Lambda layers to work with functions of any signature rather than only functions that take one argument.
Any *args and **kwargs passed when calling the lambda layer will be forwarded directly to the underlying lambda. PiperOrigin-RevId: 311789009 Change-Id: Ic072d2252038330cc944d7f565f14806753d7436
This commit is contained in:
parent
2bbf57217f
commit
b3bf8bd856
@ -53,7 +53,7 @@ from tensorflow.python.ops import variable_scope
|
|||||||
from tensorflow.python.platform import tf_logging
|
from tensorflow.python.platform import tf_logging
|
||||||
from tensorflow.python.training.tracking import base as trackable
|
from tensorflow.python.training.tracking import base as trackable
|
||||||
from tensorflow.python.util import nest
|
from tensorflow.python.util import nest
|
||||||
from tensorflow.python.util import tf_decorator
|
from tensorflow.python.util import tf_inspect
|
||||||
from tensorflow.python.util.tf_export import keras_export
|
from tensorflow.python.util.tf_export import keras_export
|
||||||
|
|
||||||
|
|
||||||
@ -738,8 +738,7 @@ class Lambda(Layer):
|
|||||||
models. `Lambda` layers are best suited for simple operations or
|
models. `Lambda` layers are best suited for simple operations or
|
||||||
quick experimentation. For more advanced use cases, follow
|
quick experimentation. For more advanced use cases, follow
|
||||||
[this guide](https://www.tensorflow.org/guide/keras/custom_layers_and_models)
|
[this guide](https://www.tensorflow.org/guide/keras/custom_layers_and_models)
|
||||||
for subclassing `tf.keras.layers.Layer`. (Do not subclass
|
for subclassing `tf.keras.layers.Layer`.
|
||||||
`tf.keras.layers.Lamba`.)
|
|
||||||
|
|
||||||
The main reason to subclass `tf.keras.layers.Layer` instead of using a
|
The main reason to subclass `tf.keras.layers.Layer` instead of using a
|
||||||
`Lambda` layer is saving and inspecting a Model. `Lambda` layers
|
`Lambda` layer is saving and inspecting a Model. `Lambda` layers
|
||||||
@ -799,7 +798,8 @@ class Lambda(Layer):
|
|||||||
computation, but anything more complex should use a subclass Layer instead.
|
computation, but anything more complex should use a subclass Layer instead.
|
||||||
|
|
||||||
Arguments:
|
Arguments:
|
||||||
function: The function to evaluate when the layer is called.
|
function: The function to be evaluated. Takes input tensor as first
|
||||||
|
argument.
|
||||||
output_shape: Expected output shape from function. This argument can be
|
output_shape: Expected output shape from function. This argument can be
|
||||||
inferred if not explicitly provided. Can be a tuple or function. If a
|
inferred if not explicitly provided. Can be a tuple or function. If a
|
||||||
tuple, it only specifies the first dimension onward;
|
tuple, it only specifies the first dimension onward;
|
||||||
@ -812,8 +812,8 @@ class Lambda(Layer):
|
|||||||
mask: Either None (indicating no masking) or a callable with the same
|
mask: Either None (indicating no masking) or a callable with the same
|
||||||
signature as the `compute_mask` layer method, or a tensor that will be
|
signature as the `compute_mask` layer method, or a tensor that will be
|
||||||
returned as output mask regardless of what the input is.
|
returned as output mask regardless of what the input is.
|
||||||
arguments: Optional dictionary of keyword arguments to pass by default to
|
arguments: Optional dictionary of keyword arguments to be passed to the
|
||||||
the function when those arguments are not passed to the layer call.
|
function.
|
||||||
Input shape: Arbitrary. Use the keyword argument input_shape (tuple of
|
Input shape: Arbitrary. Use the keyword argument input_shape (tuple of
|
||||||
integers, does not include the samples axis) when using this layer as the
|
integers, does not include the samples axis) when using this layer as the
|
||||||
first layer in a model.
|
first layer in a model.
|
||||||
@ -823,16 +823,11 @@ class Lambda(Layer):
|
|||||||
@trackable.no_automatic_dependency_tracking
|
@trackable.no_automatic_dependency_tracking
|
||||||
def __init__(self, function, output_shape=None, mask=None, arguments=None,
|
def __init__(self, function, output_shape=None, mask=None, arguments=None,
|
||||||
**kwargs):
|
**kwargs):
|
||||||
|
super(Lambda, self).__init__(**kwargs)
|
||||||
|
|
||||||
self.arguments = arguments or {}
|
self.arguments = arguments or {}
|
||||||
self.function = function
|
self.function = function
|
||||||
|
|
||||||
# Decorate the function to produce this layer's call method
|
|
||||||
def _call_wrapper(*args, **kwargs):
|
|
||||||
return self._call_wrapper(*args, **kwargs)
|
|
||||||
self.call = tf_decorator.make_decorator(function, _call_wrapper)
|
|
||||||
|
|
||||||
super(Lambda, self).__init__(**kwargs)
|
|
||||||
|
|
||||||
if mask is not None:
|
if mask is not None:
|
||||||
self.supports_masking = True
|
self.supports_masking = True
|
||||||
self.mask = mask
|
self.mask = mask
|
||||||
@ -841,8 +836,9 @@ class Lambda(Layer):
|
|||||||
# Warning on every invocation will be quite irksome in Eager mode.
|
# Warning on every invocation will be quite irksome in Eager mode.
|
||||||
self._already_warned = False
|
self._already_warned = False
|
||||||
|
|
||||||
self._expects_training_arg = 'training' in self._call_fn_args
|
function_args = tf_inspect.getfullargspec(function).args
|
||||||
self._expects_mask_arg = 'mask' in self._call_fn_args
|
self._fn_expects_training_arg = 'training' in function_args
|
||||||
|
self._fn_expects_mask_arg = 'mask' in function_args
|
||||||
|
|
||||||
@tf_utils.shape_type_conversion
|
@tf_utils.shape_type_conversion
|
||||||
def compute_output_shape(self, input_shape):
|
def compute_output_shape(self, input_shape):
|
||||||
@ -873,22 +869,23 @@ class Lambda(Layer):
|
|||||||
output_shapes = tf_utils.convert_shapes(self._output_shape, to_tuples=False)
|
output_shapes = tf_utils.convert_shapes(self._output_shape, to_tuples=False)
|
||||||
return nest.map_structure(_add_batch, output_shapes)
|
return nest.map_structure(_add_batch, output_shapes)
|
||||||
|
|
||||||
def _call_wrapper(self, *args, **kwargs):
|
def call(self, inputs, mask=None, training=None):
|
||||||
# We must copy for thread safety, but it only needs to be a shallow copy.
|
# We must copy for thread safety, but it only needs to be a shallow copy.
|
||||||
call_kwargs = {k: v for k, v in self.arguments.items()}
|
kwargs = {k: v for k, v in self.arguments.items()}
|
||||||
|
if self._fn_expects_mask_arg:
|
||||||
# override default kwargs with the args passed to the layer call
|
kwargs['mask'] = mask
|
||||||
call_kwargs.update(kwargs)
|
if self._fn_expects_training_arg:
|
||||||
|
kwargs['training'] = training
|
||||||
|
|
||||||
created_variables = []
|
created_variables = []
|
||||||
def _variable_creator(next_creator, **creator_kwargs):
|
def _variable_creator(next_creator, **kwargs):
|
||||||
var = next_creator(**creator_kwargs)
|
var = next_creator(**kwargs)
|
||||||
created_variables.append(var)
|
created_variables.append(var)
|
||||||
return var
|
return var
|
||||||
|
|
||||||
with backprop.GradientTape(watch_accessed_variables=True) as tape,\
|
with backprop.GradientTape(watch_accessed_variables=True) as tape,\
|
||||||
variable_scope.variable_creator_scope(_variable_creator):
|
variable_scope.variable_creator_scope(_variable_creator):
|
||||||
result = self.function(*args, **call_kwargs)
|
result = self.function(inputs, **kwargs)
|
||||||
self._check_variables(created_variables, tape.watched_variables())
|
self._check_variables(created_variables, tape.watched_variables())
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
@ -139,26 +139,6 @@ class LambdaLayerTest(keras_parameterized.TestCase):
|
|||||||
out = ld([x1, x2])
|
out = ld([x1, x2])
|
||||||
self.assertAllEqual(out.shape, [3, 2])
|
self.assertAllEqual(out.shape, [3, 2])
|
||||||
|
|
||||||
def test_lambda_multiple_args(self):
|
|
||||||
ld = keras.layers.Lambda(lambda x, y: x[0] + y)
|
|
||||||
x1 = np.ones([3, 2], np.float32)
|
|
||||||
x2 = np.ones([3, 5], np.float32)
|
|
||||||
|
|
||||||
expected_result = x1 * 2
|
|
||||||
self.assertAllEqual(ld([x1, x2], x1), expected_result)
|
|
||||||
self.assertAllEqual(ld([x1, x2], y=x1), expected_result)
|
|
||||||
self.assertAllEqual(ld(x=[x1, x2], y=x1), expected_result)
|
|
||||||
|
|
||||||
def test_lambda_constructor_args_and_multiple_args(self):
|
|
||||||
x1 = np.ones([3, 2], np.float32)
|
|
||||||
x2 = np.ones([3, 5], np.float32)
|
|
||||||
ld = keras.layers.Lambda(lambda x, y: x[0] + y, arguments={'y': x1*2})
|
|
||||||
|
|
||||||
self.assertAllEqual(ld([x1, x2]), x1 * 3)
|
|
||||||
self.assertAllEqual(ld([x1, x2], y=x1), x1 * 2)
|
|
||||||
self.assertAllEqual(ld(x=[x1, x2]), x1 * 3)
|
|
||||||
self.assertAllEqual(ld(x=[x1, x2], y=x1), x1 * 2)
|
|
||||||
|
|
||||||
def test_lambda_output_shape(self):
|
def test_lambda_output_shape(self):
|
||||||
l = keras.layers.Lambda(lambda x: x + 1, output_shape=(1, 1))
|
l = keras.layers.Lambda(lambda x: x + 1, output_shape=(1, 1))
|
||||||
l(keras.backend.variable(np.ones((1, 1))))
|
l(keras.backend.variable(np.ones((1, 1))))
|
||||||
|
@ -145,7 +145,7 @@ tf_class {
|
|||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "call"
|
name: "call"
|
||||||
argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=kwargs, defaults=None"
|
argspec: "args=[\'self\', \'inputs\', \'mask\', \'training\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
|
||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "compute_mask"
|
name: "compute_mask"
|
||||||
|
@ -145,7 +145,7 @@ tf_class {
|
|||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "call"
|
name: "call"
|
||||||
argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=kwargs, defaults=None"
|
argspec: "args=[\'self\', \'inputs\', \'mask\', \'training\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
|
||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "compute_mask"
|
name: "compute_mask"
|
||||||
|
Loading…
Reference in New Issue
Block a user