Simplify Layer.add_update in v2 and update version_selector to use v1 inside a tf.compat.v1.wrap_function.

No longer track unused Layer.updates in v2.

PiperOrigin-RevId: 316921838
Change-Id: I4698a0c925528594f402f824705d66b8a1ae7b72
Thomas O'Malley 2020-06-17 10:56:38 -07:00 committed by TensorFlower Gardener
parent 4747be646b
commit 2ba59dab2c
5 changed files with 50 additions and 91 deletions
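
For context, the simplified v2 `Layer.add_update` runs callable updates immediately in eager mode and skips them when the layer is frozen; nothing is tracked for later execution. A minimal sketch of that behavior (the `CounterLayer` here is hypothetical, written only to illustrate the new semantics):

import tensorflow as tf

class CounterLayer(tf.keras.layers.Layer):
  """Hypothetical layer whose call() registers a callable update."""

  def build(self, input_shape):
    self.calls = self.add_weight(
        name='calls', shape=(), dtype=tf.int64,
        initializer='zeros', trainable=False)

  def call(self, inputs):
    # In v2 eager, a callable passed to add_update runs immediately,
    # unless the layer is frozen via trainable=False.
    self.add_update(lambda: self.calls.assign_add(1))
    return inputs

layer = CounterLayer()
layer(tf.ones((2, 2)))
print(layer.calls.numpy())  # 1: the update ran during call()

layer.trainable = False
layer(tf.ones((2, 2)))
print(layer.calls.numpy())  # still 1: frozen layers skip callable updates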


@@ -1733,54 +1733,15 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
         inputs: Deprecated, will be automatically inferred.
     """
     call_context = base_layer_utils.call_context()
-
-    if (ds_context.has_strategy() and
-        ds_context.in_cross_replica_context() and
-        # When saving the model, the distribution strategy context should be
-        # ignored, following the default path for adding updates.
-        not call_context.saving):
-      # Updates don't need to be run in a cross-replica context.
+    # No need to run updates during Functional API construction.
+    if call_context.in_keras_graph:
       return
 
-    updates = generic_utils.to_list(updates)
-
-    # All updates can be run immediately in Eager or in a tf.function.
-    if base_layer_utils.is_in_eager_or_tf_function():
-      if not call_context.frozen:
-        for update in updates:
-          if callable(update):
-            update()
-      return
-
-    def process_update(x):
-      """Standardize update ops.
-
-      Arguments:
-        x: Tensor, op, or callable.
-
-      Returns:
-        An update op.
-      """
-      if callable(x):
-        update = lambda: process_update(x())
-        if not ops.executing_eagerly_outside_functions():
-          # In V1 mode, call the callable right away and process. This is needed
-          # for TPU strategy.
-          return update()
-      elif isinstance(x, ops.Operation):
-        update = x
-      elif hasattr(x, 'op'):
-        update = x.op
-      else:
-        update = ops.convert_to_tensor_v2(x)
-      return update
-
-    updates = [process_update(x) for x in updates]
-
-    # Non-callable Updates are run automatically inside `call` in V2, so
-    # they do not need to be tracked later.
-    if ops.executing_eagerly_outside_functions() and call_context.in_call:
-      updates = [u for u in updates if callable(u)]
-
-    self._updates.extend(updates)
+    # Callable updates are disabled by setting `trainable=False`.
+    if not call_context.frozen:
+      for update in nest.flatten(updates):
+        if callable(update):
+          update()
 
   def set_weights(self, weights):
     """Sets the weights of the layer, from Numpy arrays.


@@ -231,33 +231,28 @@ class TestSequential(keras_parameterized.TestCase):
     inner_model.trainable = True
     self.assertEqual(len(model.trainable_weights), 4)
 
+  @keras_parameterized.run_all_keras_modes
   def test_sequential_update_disabling(self):
     val_a = np.random.random((10, 4))
     val_out = np.random.random((10, 4))
 
-    with self.cached_session():
-      model = keras.models.Sequential()
-      model.add(keras.layers.BatchNormalization(input_shape=(4,)))
-      assert model.updates
+    model = keras.models.Sequential()
+    model.add(keras.layers.BatchNormalization(input_shape=(4,)))
 
-      model.trainable = False
-      assert not model.updates
+    model.trainable = False
+    model.compile('sgd', 'mse')
 
-      model.compile('sgd', 'mse')
-      assert not model.updates
+    x1 = model.predict(val_a)
+    model.train_on_batch(val_a, val_out)
+    x2 = model.predict(val_a)
+    self.assertAllClose(x1, x2, atol=1e-7)
 
-      x1 = model.predict(val_a)
-      model.train_on_batch(val_a, val_out)
-      x2 = model.predict(val_a)
-      self.assertAllClose(x1, x2, atol=1e-7)
+    model.trainable = True
+    model.compile('sgd', 'mse')
 
-      model.trainable = True
-      model.compile('sgd', 'mse')
-      assert model.updates
-
-      model.train_on_batch(val_a, val_out)
-      x2 = model.predict(val_a)
-      assert np.abs(np.sum(x1 - x2)) > 1e-5
+    model.train_on_batch(val_a, val_out)
+    x2 = model.predict(val_a)
+    assert np.abs(np.sum(x1 - x2)) > 1e-5
 
   @keras_parameterized.run_all_keras_modes
   def test_sequential_deferred_build_serialization(self):


@@ -325,18 +325,18 @@ class BatchNormalizationV2Test(keras_parameterized.TestCase):
       norm(inp)
 
   def test_updates_in_wrap_function(self):
-    layer = normalization.BatchNormalization()
 
     def my_func():
+      layer = normalization.BatchNormalization()
       x = array_ops.ones((10, 1))
-      return layer(x, training=True)
+      y = layer(x, training=True)
+      # Updates should be tracked in a `wrap_function`.
+      self.assertLen(layer.updates, 2)
+      return y
 
     wrapped_fn = wrap_function.wrap_function(my_func, [])
     wrapped_fn()
 
-    # Updates should be tracked in a `wrap_function`.
-    self.assertLen(layer.updates, 2)
-
   @keras_parameterized.run_all_keras_modes
   def test_basic_batchnorm_v2_none_shape_and_virtual_batch_size(self):
     # Test case for GitHub issue for 32380
@@ -392,15 +392,11 @@ class NormalizationLayersGraphModeOnlyTest(
     model.compile(gradient_descent.GradientDescentOptimizer(0.01), 'mse')
     model.train_on_batch(x, x)
 
-    self.assertLen(bn.updates, 4)
-
     # Test model-level reuse
     x3 = keras.layers.Input(shape=(10,))
     y3 = model(x3)
     new_model = keras.models.Model(x3, y3, name='new_model')
 
-    self.assertLen(new_model.updates, 6)
-    self.assertLen(model.updates, 6)
     new_model.compile(gradient_descent.GradientDescentOptimizer(0.01), 'mse')
     new_model.train_on_batch(x, x)
@@ -415,10 +411,7 @@ class NormalizationLayersGraphModeOnlyTest(
     model = keras.models.Model(a, b)
 
     model.trainable = False
-    assert not model.updates
-
     model.compile(gradient_descent.GradientDescentOptimizer(0.01), 'mse')
-    assert not model.updates
 
     x1 = model.predict(val_a)
     model.train_on_batch(val_a, val_out)
@@ -427,7 +420,6 @@ class NormalizationLayersGraphModeOnlyTest(
     model.trainable = True
     model.compile(gradient_descent.GradientDescentOptimizer(0.01), 'mse')
-    assert model.updates
 
     model.train_on_batch(val_a, val_out)
     x2 = model.predict(val_a)
@@ -435,7 +427,6 @@ class NormalizationLayersGraphModeOnlyTest(
     layer.trainable = False
     model.compile(gradient_descent.GradientDescentOptimizer(0.01), 'mse')
-    assert not model.updates
 
     x1 = model.predict(val_a)
     model.train_on_batch(val_a, val_out)
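
The deleted assertions reflect that in v2 a BatchNormalization layer applies its moving-statistics updates in place during call(), so there are no pending update ops to track on the layer or model. A minimal sketch of the eager behavior this relies on (assuming TF 2.x):

import numpy as np
import tensorflow as tf

bn = tf.keras.layers.BatchNormalization(momentum=0.5)
x = tf.constant(np.random.random((10, 4)), dtype=tf.float32)

bn(x, training=True)
m1 = bn.moving_mean.numpy()
bn(x, training=True)
m2 = bn.moving_mean.numpy()

# The moving mean changed as a side effect of call();
# no separate update op needed to be fetched and run.
assert not np.allclose(m1, m2)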


@@ -234,13 +234,10 @@ class TimeDistributedTest(keras_parameterized.TestCase):
       x = keras.layers.Input(shape=(3, 2))
       layer = keras.layers.TimeDistributed(keras.layers.BatchNormalization())
       _ = layer(x)
-      self.assertEqual(len(layer.updates), 2)
       self.assertEqual(len(layer.trainable_weights), 2)
       layer.trainable = False
-      assert not layer.updates
       assert not layer.trainable_weights
       layer.trainable = True
-      assert len(layer.updates) == 2
       assert len(layer.trainable_weights) == 2
 
   def test_TimeDistributed_with_masked_embedding_and_unspecified_shape(self):


@@ -18,6 +18,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+from tensorflow.python.eager import context
 from tensorflow.python.framework import ops
 from tensorflow.python.util import lazy_loader
 
@@ -51,8 +52,8 @@ class ModelVersionSelector(object):
   """Chooses between Keras v1 and v2 Model class."""
 
   def __new__(cls, *args, **kwargs):  # pylint: disable=unused-argument
-    eager_enabled = ops.executing_eagerly_outside_functions()
-    cls = swap_class(cls, training.Model, training_v1.Model, eager_enabled)
+    use_v2 = should_use_v2()
+    cls = swap_class(cls, training.Model, training_v1.Model, use_v2)  # pylint: disable=self-cls-assignment
     return super(ModelVersionSelector, cls).__new__(cls)
 
@@ -60,8 +61,8 @@ class LayerVersionSelector(object):
   """Chooses between Keras v1 and v2 Layer class."""
 
   def __new__(cls, *args, **kwargs):  # pylint: disable=unused-argument
-    eager_enabled = ops.executing_eagerly_outside_functions()
-    cls = swap_class(cls, base_layer.Layer, base_layer_v1.Layer, eager_enabled)
+    use_v2 = should_use_v2()
+    cls = swap_class(cls, base_layer.Layer, base_layer_v1.Layer, use_v2)  # pylint: disable=self-cls-assignment
     return super(LayerVersionSelector, cls).__new__(cls)
 
@@ -69,10 +70,10 @@ class TensorBoardVersionSelector(object):
   """Chooses between Keras v1 and v2 TensorBoard callback class."""
 
   def __new__(cls, *args, **kwargs):  # pylint: disable=unused-argument
-    eager_enabled = ops.executing_eagerly_outside_functions()
+    use_v2 = should_use_v2()
     start_cls = cls
     cls = swap_class(start_cls, callbacks.TensorBoard, callbacks_v1.TensorBoard,
-                     eager_enabled)
+                     use_v2)
     if start_cls == callbacks_v1.TensorBoard and cls == callbacks.TensorBoard:
       # Since the v2 class is not a subclass of the v1 class, __init__ has to
       # be called manually.
@@ -80,19 +81,33 @@ class TensorBoardVersionSelector(object):
     return super(TensorBoardVersionSelector, cls).__new__(cls)
 
 
-def swap_class(cls, v2_cls, v1_cls, eager_enabled):
+def should_use_v2():
+  """Determine if v1 or v2 version should be used."""
+  if context.executing_eagerly():
+    return True
+  elif ops.executing_eagerly_outside_functions():
+    # Check for a v1 `wrap_function` FuncGraph.
+    # Code inside a `wrap_function` is treated like v1 code.
+    graph = ops.get_default_graph()
+    if (getattr(graph, "name", False) and
+        graph.name.startswith("wrapped_function")):
+      return False
+    return True
+  else:
+    return False
+
+
+def swap_class(cls, v2_cls, v1_cls, use_v2):
   """Swaps in v2_cls or v1_cls depending on graph mode."""
   if cls == object:
     return cls
   if cls in (v2_cls, v1_cls):
-    if eager_enabled:
+    if use_v2:
       return v2_cls
     return v1_cls
 
   # Recursively search superclasses to swap in the right Keras class.
   cls.__bases__ = tuple(
-      swap_class(base, v2_cls, v1_cls, eager_enabled) for base in cls.__bases__)
+      swap_class(base, v2_cls, v1_cls, use_v2) for base in cls.__bases__)
   return cls
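
The net effect of `should_use_v2` is that code traced inside a `tf.compat.v1.wrap_function` graph resolves to the v1 Layer/Model classes, so update ops are tracked there just as in TF1 graphs, while everything else in TF2 gets the v2 classes. A minimal sketch mirroring the updated `test_updates_in_wrap_function` above (assuming TF 2.x; behavior in later releases may differ):

import tensorflow as tf

def my_func():
  layer = tf.keras.layers.BatchNormalization()
  x = tf.ones((10, 1))
  y = layer(x, training=True)
  # Inside wrap_function the layer is the v1 class, so its two
  # moving-statistic update ops are tracked on layer.updates.
  assert len(layer.updates) == 2
  return y

wrapped_fn = tf.compat.v1.wrap_function(my_func, [])
wrapped_fn()  # executes the traced function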