Merge pull request #34979 from robieta/cherrypicks_Q614H

[r2.1 Cherrypick] Remove name-based Variable handling in keras Lambda layers, and add detailed exceptions and warnings for unsafe corner cases.
This commit is contained in:
Goldie Gadde 2019-12-10 09:23:12 -08:00 committed by GitHub
commit ca587c0975
2 changed files with 171 additions and 61 deletions
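In practical terms, a pattern like the following (a minimal sketch of the now-disallowed corner case, not code taken from this diff) previously relied on name-based Variable reuse and now fails fast when the layer is called:

```python
import tensorflow as tf

def fn(x):
  # A Variable created inside the Lambda's function is no longer reused by
  # name; after this change the layer raises a ValueError when it is called.
  return x * tf.Variable(2., name='multiplier')

layer = tf.keras.layers.Lambda(fn)
# layer(tf.ones((2, 2)))  # ValueError: Variables created within a Lambda layer ...
```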


@@ -20,11 +20,13 @@ from __future__ import print_function
import copy
import sys
import textwrap
import types as python_types
import warnings
import numpy as np
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -47,6 +49,8 @@ from tensorflow.python.ops import nn
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import standard_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import tf_logging
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.util import nest
from tensorflow.python.util import tf_inspect
from tensorflow.python.util.tf_export import keras_export
@@ -690,7 +694,7 @@ class Lambda(Layer):
can be used when constructing `Sequential` and Functional API
models. `Lambda` layers are best suited for simple operations or
quick experimentation. For more advanced use cases, follow
-  [this guide](https://www.tensorflow.org/alpha/guide/keras/custom_layers_and_models)
  [this guide](https://www.tensorflow.org/guide/keras/custom_layers_and_models)
for subclassing `tf.keras.layers.Layer`.
The main reason to subclass `tf.keras.layers.Layer` instead of using a
@@ -721,30 +725,34 @@ class Lambda(Layer):
model.add(Lambda(antirectifier))
```
-  Variables can be created within a `Lambda` layer. Like with
-  other layers, these variables will be created only once and reused
-  if the `Lambda` layer is called on new inputs. If creating more
-  than one variable in a given `Lambda` instance, be sure to use
-  a different name for each variable. Note that calling sublayers
-  from within a `Lambda` is not supported.
-
-  Example of variable creation:
-
-  ```python
-  def linear_transform(x):
-    v1 = tf.Variable(1., name='multiplier')
-    v2 = tf.Variable(0., name='bias')
-    return x*v1 + v2
-
-  linear_layer = Lambda(linear_transform)
-  model.add(linear_layer)
-  model.add(keras.layers.Dense(10, activation='relu'))
-  model.add(linear_layer)  # Reuses existing Variables
-  ```
-
-  Note that creating two instances of `Lambda` using the same function
-  will *not* share Variables between the two instances. Each instance of
-  `Lambda` will create and manage its own weights.

  Variables:
    While it is possible to use Variables with Lambda layers, this practice is
    discouraged as it can easily lead to bugs. For instance, consider the
    following layer:

    ```python
    scale = tf.Variable(1.)
    scale_layer = tf.keras.layers.Lambda(lambda x: x * scale)
    ```

    Because scale_layer does not directly track the `scale` variable, it will
    not appear in `scale_layer.trainable_weights` and will therefore not be
    trained if `scale_layer` is used in a Model.

    A better pattern is to write a subclassed Layer:

    ```python
    class ScaleLayer(tf.keras.layers.Layer):
      def __init__(self):
        super(ScaleLayer, self).__init__()
        self.scale = tf.Variable(1.)

      def call(self, inputs):
        return inputs * self.scale
    ```

    In general, Lambda layers can be convenient for simple stateless
    computation, but anything more complex should use a subclass Layer instead.
Arguments:
function: The function to be evaluated. Takes input tensor as first
@@ -769,22 +777,24 @@ class Lambda(Layer):
Output shape: Specified by `output_shape` argument
"""
  @trackable.no_automatic_dependency_tracking
  def __init__(self, function, output_shape=None, mask=None, arguments=None,
               **kwargs):
    super(Lambda, self).__init__(**kwargs)
-   self.arguments = arguments or {}
    self.function = function
    self.arguments = arguments if arguments else {}
    if mask is not None:
      self.supports_masking = True
    self.mask = mask
    self._supports_ragged_inputs = True
    self._output_shape = output_shape
-   self._variable_dict = {}
-   # These attributes are inherited from `Layer`.
-   self._trainable_weights = []
-   self._non_trainable_weights = []
-   function_args = tf_inspect.getfullargspec(self.function).args
    # Warning on every invocation will be quite irksome in Eager mode.
    self._already_warned = False
    function_args = tf_inspect.getfullargspec(function).args
    self._fn_expects_training_arg = 'training' in function_args
    self._fn_expects_mask_arg = 'mask' in function_args
@@ -818,26 +828,69 @@ class Lambda(Layer):
    return nest.map_structure(_add_batch, output_shapes)

  def call(self, inputs, mask=None, training=None):
-   arguments = self.arguments
    # We must copy for thread safety, but it only needs to be a shallow copy.
    kwargs = {k: v for k, v in self.arguments.items()}
    if self._fn_expects_mask_arg:
-     arguments['mask'] = mask
      kwargs['mask'] = mask
    if self._fn_expects_training_arg:
-     arguments['training'] = training
-   with variable_scope.variable_creator_scope(self._variable_creator):
-     return self.function(inputs, **arguments)
      kwargs['training'] = training

- def _variable_creator(self, next_creator, **kwargs):
-   name = kwargs['name']
-   if name in self._variable_dict:
-     return self._variable_dict[name]
-   var = next_creator(**kwargs)
-   self._variable_dict[name] = var
-   if var.trainable:
-     self._trainable_weights.append(var)
-   else:
-     self._non_trainable_weights.append(var)
-   K.track_variable(var)
-   return var

    created_variables = []

    def _variable_creator(next_creator, **kwargs):
      var = next_creator(**kwargs)
      created_variables.append(var)
      return var

    with backprop.GradientTape(watch_accessed_variables=True) as tape,\
        variable_scope.variable_creator_scope(_variable_creator):
      result = self.function(inputs, **kwargs)
    self._check_variables(created_variables, tape.watched_variables())
    return result
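The rewritten `call` relies on two public mechanisms: `tf.variable_creator_scope` intercepts every Variable constructed while the wrapped function runs, and `GradientTape.watched_variables()` reports the trainable Variables the function read. A standalone sketch of the same idea outside of Keras (the helper name is hypothetical):

```python
import tensorflow as tf

def run_and_inspect(fn, *args):
  created = []

  def creator(next_creator, **kwargs):
    var = next_creator(**kwargs)
    created.append(var)  # record every Variable built inside the scope
    return var

  with tf.GradientTape(watch_accessed_variables=True) as tape, \
      tf.variable_creator_scope(creator):
    result = fn(*args)
  return result, created, tape.watched_variables()

v = tf.Variable(3.)
out, created, accessed = run_and_inspect(lambda x: x * v, tf.constant(2.))
# created == [] (nothing new was built), while `accessed` contains v.
```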
  def _check_variables(self, created_variables, accessed_variables):
    if not created_variables and not accessed_variables:
      # In the common case that a Lambda layer does not touch a Variable, we
      # don't want to incur the runtime cost of assembling any state used for
      # checking only to immediately discard it.
      return

    tracked_weights = set(v.experimental_ref() for v in self.weights)
    untracked_new_vars = [v for v in created_variables
                          if v.experimental_ref() not in tracked_weights]
    if untracked_new_vars:
      variable_str = '\n'.join(['  {}'.format(i) for i in untracked_new_vars])
      error_str = textwrap.dedent(
          '''
          The following Variables were created within a Lambda layer ({name})
          but are not tracked by said layer:
          {variable_str}
          The layer cannot safely ensure proper Variable reuse across multiple
          calls, and consequently this behavior is disallowed for safety. Lambda
          layers are not well suited to stateful computation; instead, writing a
          subclassed Layer is the recommended way to define layers with
          Variables.'''
      ).format(name=self.name, variable_str=variable_str)
      raise ValueError(error_str)

    untracked_used_vars = [v for v in accessed_variables
                           if v.experimental_ref() not in tracked_weights]
    if untracked_used_vars and not self._already_warned:
      variable_str = '\n'.join(['  {}'.format(i) for i in untracked_used_vars])
      self._warn(textwrap.dedent(
          '''
          The following Variables were used in a Lambda layer's call ({name}), but
          are not present in its tracked objects:
          {variable_str}
          It is possible that this is intended behavior, but it is more likely
          an omission. This is a strong indication that this layer should be
          formulated as a subclassed Layer rather than a Lambda layer.'''
      ).format(name=self.name, variable_str=variable_str))
      self._already_warned = True

  def _warn(self, msg):
    # This method will be overridden in a unit test to raise an error, because
    # self.assertWarns is not universally implemented.
    return tf_logging.warn(msg)
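`_check_variables` above keys Variables by `experimental_ref()` (later exposed as `tf.Variable.ref()`) because in TF 2.x Variables compare elementwise and are not safely hashable for set membership; a tiny illustration:

```python
import tensorflow as tf

v = tf.Variable(1.)
w = tf.Variable(1.)

tracked = {v.experimental_ref()}             # hashable, identity-based reference
assert v.experimental_ref() in tracked       # same Variable -> same reference
assert w.experimental_ref() not in tracked   # equal value, different object
```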
  def compute_mask(self, inputs, mask=None):
    if callable(self.mask):


@@ -18,6 +18,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import textwrap
import numpy as np
from tensorflow.python import keras
@@ -225,17 +227,6 @@ class LambdaLayerTest(keras_parameterized.TestCase):
    self.assertAllEqual(layer._output_shape, (1, 1))
    self.assertAllEqual(layer.mask(1, True), True)

-  def test_lambda_with_variable(self):
-    def fn(x):
-      return x * variables.Variable(2., name='multiplier')
-
-    layer = keras.layers.Lambda(fn)
-    for _ in range(10):
-      layer(np.ones((10, 10), 'float32'))
-    self.assertLen(layer.trainable_weights, 1)
-    self.assertEqual(layer.trainable_weights[0].name, 'lambda/multiplier:0')

  def test_lambda_with_training_arg(self):

    def fn(x, training=True):
@@ -283,19 +274,25 @@ class LambdaLayerTest(keras_parameterized.TestCase):
    expected_out = ragged_factory_ops.constant([[2.0], [3.0, 4.0]])
    self.assertAllClose(out, expected_out)
class TestStatefulLambda(keras_parameterized.TestCase):

  @keras_parameterized.run_all_keras_modes
  @keras_parameterized.run_with_all_model_types
  def test_lambda_with_variable_in_model(self):
-   def lambda_fn(x):
-     # Variable will only get created once.
-     v = variables.Variable(1., trainable=True)
    v = variables.Variable(1., trainable=True)

    def lambda_fn(x, v):
      return x * v

-   model = testing_utils.get_model_from_layers(
-       [keras.layers.Lambda(lambda_fn)], input_shape=(10,))
    # While it is generally not advised to mix Variables with Lambda layers, if
    # the variables are explicitly set as attributes then they are still
    # tracked. This is consistent with the base Layer behavior.
    layer = keras.layers.Lambda(lambda_fn, arguments={'v': v})
    self.assertLen(layer.trainable_weights, 0)
    layer.v = v
    self.assertLen(layer.trainable_weights, 1)

    model = testing_utils.get_model_from_layers([layer], input_shape=(10,))
    model.compile(
        keras.optimizer_v2.gradient_descent.SGD(0.1),
        'mae',
@@ -306,6 +303,66 @@ class TestStatefulLambda(keras_parameterized.TestCase):
    self.assertLen(model.trainable_weights, 1)
    self.assertAllClose(keras.backend.get_value(model.trainable_weights[0]), 2.)
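The same attribute-tracking approach works when the Variable is supplied through the `arguments` dict rather than captured by closure, which is what the test above asserts; a compact sketch:

```python
import tensorflow as tf

v = tf.Variable(1., trainable=True)
layer = tf.keras.layers.Lambda(lambda x, v: x * v, arguments={'v': v})

layer.v = v  # `arguments` alone does not track the Variable; attach it as well.
assert len(layer.trainable_weights) == 1
```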
  @keras_parameterized.run_all_keras_modes
  @keras_parameterized.run_with_all_model_types
  def test_creation_inside_lambda(self):
    def lambda_fn(x):
      scale = variables.Variable(1., trainable=True, name='scale')
      shift = variables.Variable(1., trainable=True, name='shift')
      return x * scale + shift

    expected_error = textwrap.dedent(r'''
    ( )?The following Variables were created within a Lambda layer \(shift_and_scale\)
    ( )?but are not tracked by said layer:
    ( )?  <tf.Variable \'.*shift_and_scale/scale:0\'.+
    ( )?  <tf.Variable \'.*shift_and_scale/shift:0\'.+
    ( )?The layer cannot safely ensure proper Variable reuse.+''')

    with self.assertRaisesRegexp(ValueError, expected_error):
      layer = keras.layers.Lambda(lambda_fn, name='shift_and_scale')
      model = testing_utils.get_model_from_layers([layer], input_shape=(1,))
      model(array_ops.ones((4, 1)))
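The fix for a case like `shift_and_scale` is the subclassed-Layer pattern recommended in the docstring above; a sketch of an equivalent layer (the class name is hypothetical):

```python
import tensorflow as tf

class ShiftAndScale(tf.keras.layers.Layer):

  def __init__(self, **kwargs):
    super(ShiftAndScale, self).__init__(**kwargs)
    # Created once in __init__ and tracked automatically by the base Layer.
    self.scale = tf.Variable(1., trainable=True, name='scale')
    self.shift = tf.Variable(1., trainable=True, name='shift')

  def call(self, inputs):
    return inputs * self.scale + self.shift
```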
  @keras_parameterized.run_all_keras_modes
  @keras_parameterized.run_with_all_model_types
  def test_transitive_variable_creation(self):
    dense = keras.layers.Dense(1, use_bias=False, kernel_initializer='ones')

    def bad_lambda_fn(x):
      return dense(x + 1)  # Dense layer is built on first call

    expected_error = textwrap.dedent(r'''
    ( )?The following Variables were created within a Lambda layer \(bias_dense\)
    ( )?but are not tracked by said layer:
    ( )?  <tf.Variable \'.*bias_dense/dense/kernel:0\'.+
    ( )?The layer cannot safely ensure proper Variable reuse.+''')

    with self.assertRaisesRegexp(ValueError, expected_error):
      layer = keras.layers.Lambda(bad_lambda_fn, name='bias_dense')
      model = testing_utils.get_model_from_layers([layer], input_shape=(1,))
      model(array_ops.ones((4, 1)))
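For the transitive case, the stateful sublayer should be a regular layer of the model (or of a subclassed Layer) instead of being called from inside a Lambda; one safe arrangement as a sketch:

```python
import tensorflow as tf

# The addition stays in a stateless Lambda; the Dense layer is added to the
# model directly, so its kernel is created and tracked in the usual way.
model = tf.keras.Sequential([
    tf.keras.layers.Lambda(lambda x: x + 1),
    tf.keras.layers.Dense(1, use_bias=False, kernel_initializer='ones'),
])
```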
  @keras_parameterized.run_all_keras_modes
  @keras_parameterized.run_with_all_model_types
  def test_warns_on_variable_capture(self):
    v = variables.Variable(1., trainable=True)

    def lambda_fn(x):
      return x * v

    expected_warning = textwrap.dedent(r'''
    ( )?The following Variables were used in a Lambda layer\'s call \(lambda\), but
    ( )?are not present in its tracked objects:
    ( )?  <tf.Variable \'.*Variable:0\'.+
    ( )?It is possible that this is intended behavior.+''')

    layer = keras.layers.Lambda(lambda_fn)

    def patched_warn(msg):
      raise ValueError(msg)

    layer._warn = patched_warn

    with self.assertRaisesRegexp(ValueError, expected_warning):
      model = testing_utils.get_model_from_layers([layer], input_shape=(1,))
      model(array_ops.ones((4, 1)))
@keras_parameterized.run_all_keras_modes
class CoreLayersTest(keras_parameterized.TestCase):