Simplify Layer.add_update in v2 and update version_selector to use v1 inside a
tf.compat.v1.wrap_function. No longer track unused Layer.updates in v2.

PiperOrigin-RevId: 316921838
Change-Id: I4698a0c925528594f402f824705d66b8a1ae7b72
Parent: 4747be646b
Commit: 2ba59dab2c
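
In practice, this routes code inside a `tf.compat.v1.wrap_function` to the
v1 `Layer` class, which still tracks `layer.updates`, while plain v2 eager
code runs callable updates immediately and tracks nothing. A minimal sketch
of that contract, mirroring the `test_updates_in_wrap_function` test below
(assumes a TF build with this change applied):

    import tensorflow as tf

    def fn():
      # Inside `wrap_function` the version selector picks the v1 Layer base,
      # so the moving-mean/variance update ops are tracked on the layer.
      layer = tf.keras.layers.BatchNormalization()
      x = tf.ones((10, 1))
      y = layer(x, training=True)
      assert len(layer.updates) == 2
      return y

    wrapped = tf.compat.v1.wrap_function(fn, [])
    wrapped()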
@@ -1733,54 +1733,15 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
       inputs: Deprecated, will be automatically inferred.
     """
     call_context = base_layer_utils.call_context()
-    if (ds_context.has_strategy() and
-        ds_context.in_cross_replica_context() and
-        # When saving the model, the distribution strategy context should be
-        # ignored, following the default path for adding updates.
-        not call_context.saving):
-      # Updates don't need to be run in a cross-replica context.
+    # No need to run updates during Functional API construction.
+    if call_context.in_keras_graph:
       return
 
-    updates = generic_utils.to_list(updates)
-
-    # All updates can be run immediately in Eager or in a tf.function.
-    if base_layer_utils.is_in_eager_or_tf_function():
-      if not call_context.frozen:
-        for update in updates:
-          if callable(update):
-            update()
-      return
-
-    def process_update(x):
-      """Standardize update ops.
-
-      Arguments:
-        x: Tensor, op, or callable.
-
-      Returns:
-        An update op.
-      """
-      if callable(x):
-        update = lambda: process_update(x())
-        if not ops.executing_eagerly_outside_functions():
-          # In V1 mode, call the callable right away and process. This is needed
-          # for TPU strategy.
-          return update()
-      elif isinstance(x, ops.Operation):
-        update = x
-      elif hasattr(x, 'op'):
-        update = x.op
-      else:
-        update = ops.convert_to_tensor_v2(x)
-      return update
-
-    updates = [process_update(x) for x in updates]
-    # Non-callable Updates are run automatically inside `call` in V2, so
-    # they do not need to be tracked later.
-    if ops.executing_eagerly_outside_functions() and call_context.in_call:
-      updates = [u for u in updates if callable(u)]
-    self._updates.extend(updates)
+    # Callable updates are disabled by setting `trainable=False`.
+    if not call_context.frozen:
+      for update in nest.flatten(updates):
+        if callable(update):
+          update()
 
   def set_weights(self, weights):
     """Sets the weights of the layer, from Numpy arrays.
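
The simplified v2 path above does nothing during Functional API graph
construction and otherwise runs callable updates immediately, unless the
layer is frozen. A hedged sketch of what this means for user code in eager
execution (the `Counter` layer is hypothetical, for illustration only):

    import tensorflow as tf

    class Counter(tf.keras.layers.Layer):
      """Hypothetical layer that counts its calls via `add_update`."""

      def build(self, input_shape):
        self.calls = self.add_weight(
            name='calls', initializer='zeros', trainable=False)

      def call(self, inputs):
        # A zero-arg callable, so `trainable = False` can disable it, per
        # the "Callable updates are disabled..." comment in the diff above.
        self.add_update(lambda: self.calls.assign_add(1.))
        return inputs

    layer = Counter()
    layer(tf.ones((2, 2)))
    assert layer.calls.numpy() == 1.
    layer.trainable = False
    layer(tf.ones((2, 2)))  # Frozen: the callable update is skipped.
    assert layer.calls.numpy() == 1.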
@@ -231,33 +231,28 @@ class TestSequential(keras_parameterized.TestCase):
     inner_model.trainable = True
     self.assertEqual(len(model.trainable_weights), 4)
 
+  @keras_parameterized.run_all_keras_modes
   def test_sequential_update_disabling(self):
     val_a = np.random.random((10, 4))
     val_out = np.random.random((10, 4))
 
-    with self.cached_session():
-      model = keras.models.Sequential()
-      model.add(keras.layers.BatchNormalization(input_shape=(4,)))
-      assert model.updates
-
-      model.trainable = False
-      assert not model.updates
-
-      model.compile('sgd', 'mse')
-      assert not model.updates
-
-      x1 = model.predict(val_a)
-      model.train_on_batch(val_a, val_out)
-      x2 = model.predict(val_a)
-      self.assertAllClose(x1, x2, atol=1e-7)
-
-      model.trainable = True
-      model.compile('sgd', 'mse')
-      assert model.updates
-
-      model.train_on_batch(val_a, val_out)
-      x2 = model.predict(val_a)
-      assert np.abs(np.sum(x1 - x2)) > 1e-5
+    model = keras.models.Sequential()
+    model.add(keras.layers.BatchNormalization(input_shape=(4,)))
+
+    model.trainable = False
+    model.compile('sgd', 'mse')
+
+    x1 = model.predict(val_a)
+    model.train_on_batch(val_a, val_out)
+    x2 = model.predict(val_a)
+    self.assertAllClose(x1, x2, atol=1e-7)
+
+    model.trainable = True
+    model.compile('sgd', 'mse')
+
+    model.train_on_batch(val_a, val_out)
+    x2 = model.predict(val_a)
+    assert np.abs(np.sum(x1 - x2)) > 1e-5
 
   @keras_parameterized.run_all_keras_modes
   def test_sequential_deferred_build_serialization(self):
@@ -325,18 +325,18 @@ class BatchNormalizationV2Test(keras_parameterized.TestCase):
     norm(inp)
 
   def test_updates_in_wrap_function(self):
-    layer = normalization.BatchNormalization()
 
     def my_func():
+      layer = normalization.BatchNormalization()
       x = array_ops.ones((10, 1))
-      return layer(x, training=True)
+      y = layer(x, training=True)
+      # Updates should be tracked in a `wrap_function`.
+      self.assertLen(layer.updates, 2)
+      return y
 
     wrapped_fn = wrap_function.wrap_function(my_func, [])
     wrapped_fn()
 
-    # Updates should be tracked in a `wrap_function`.
-    self.assertLen(layer.updates, 2)
-
   @keras_parameterized.run_all_keras_modes
   def test_basic_batchnorm_v2_none_shape_and_virtual_batch_size(self):
     # Test case for GitHub issue for 32380
@@ -392,15 +392,11 @@ class NormalizationLayersGraphModeOnlyTest(
     model.compile(gradient_descent.GradientDescentOptimizer(0.01), 'mse')
     model.train_on_batch(x, x)
 
-    self.assertLen(bn.updates, 4)
-
     # Test model-level reuse
     x3 = keras.layers.Input(shape=(10,))
     y3 = model(x3)
     new_model = keras.models.Model(x3, y3, name='new_model')
 
-    self.assertLen(new_model.updates, 6)
-    self.assertLen(model.updates, 6)
     new_model.compile(gradient_descent.GradientDescentOptimizer(0.01), 'mse')
     new_model.train_on_batch(x, x)
@@ -415,10 +411,7 @@ class NormalizationLayersGraphModeOnlyTest(
     model = keras.models.Model(a, b)
 
     model.trainable = False
-    assert not model.updates
-
     model.compile(gradient_descent.GradientDescentOptimizer(0.01), 'mse')
-    assert not model.updates
 
     x1 = model.predict(val_a)
     model.train_on_batch(val_a, val_out)
@@ -427,7 +420,6 @@ class NormalizationLayersGraphModeOnlyTest(
 
     model.trainable = True
     model.compile(gradient_descent.GradientDescentOptimizer(0.01), 'mse')
-    assert model.updates
 
     model.train_on_batch(val_a, val_out)
     x2 = model.predict(val_a)
@@ -435,7 +427,6 @@ class NormalizationLayersGraphModeOnlyTest(
 
     layer.trainable = False
     model.compile(gradient_descent.GradientDescentOptimizer(0.01), 'mse')
-    assert not model.updates
 
     x1 = model.predict(val_a)
     model.train_on_batch(val_a, val_out)
@@ -234,13 +234,10 @@ class TimeDistributedTest(keras_parameterized.TestCase):
     x = keras.layers.Input(shape=(3, 2))
     layer = keras.layers.TimeDistributed(keras.layers.BatchNormalization())
     _ = layer(x)
-    self.assertEqual(len(layer.updates), 2)
     self.assertEqual(len(layer.trainable_weights), 2)
     layer.trainable = False
-    assert not layer.updates
     assert not layer.trainable_weights
     layer.trainable = True
-    assert len(layer.updates) == 2
     assert len(layer.trainable_weights) == 2
 
   def test_TimeDistributed_with_masked_embedding_and_unspecified_shape(self):
@@ -18,6 +18,7 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+from tensorflow.python.eager import context
 from tensorflow.python.framework import ops
 from tensorflow.python.util import lazy_loader
@@ -51,8 +52,8 @@ class ModelVersionSelector(object):
   """Chooses between Keras v1 and v2 Model class."""
 
   def __new__(cls, *args, **kwargs):  # pylint: disable=unused-argument
-    eager_enabled = ops.executing_eagerly_outside_functions()
-    cls = swap_class(cls, training.Model, training_v1.Model, eager_enabled)
+    use_v2 = should_use_v2()
+    cls = swap_class(cls, training.Model, training_v1.Model, use_v2)  # pylint: disable=self-cls-assignment
     return super(ModelVersionSelector, cls).__new__(cls)
|
@ -60,8 +61,8 @@ class LayerVersionSelector(object):
|
||||||
"""Chooses between Keras v1 and v2 Layer class."""
|
"""Chooses between Keras v1 and v2 Layer class."""
|
||||||
|
|
||||||
def __new__(cls, *args, **kwargs): # pylint: disable=unused-argument
|
def __new__(cls, *args, **kwargs): # pylint: disable=unused-argument
|
||||||
eager_enabled = ops.executing_eagerly_outside_functions()
|
use_v2 = should_use_v2()
|
||||||
cls = swap_class(cls, base_layer.Layer, base_layer_v1.Layer, eager_enabled)
|
cls = swap_class(cls, base_layer.Layer, base_layer_v1.Layer, use_v2) # pylint: disable=self-cls-assignment
|
||||||
return super(LayerVersionSelector, cls).__new__(cls)
|
return super(LayerVersionSelector, cls).__new__(cls)
|
||||||
|
|
||||||
|
|
||||||
|
@ -69,10 +70,10 @@ class TensorBoardVersionSelector(object):
|
||||||
"""Chooses between Keras v1 and v2 TensorBoard callback class."""
|
"""Chooses between Keras v1 and v2 TensorBoard callback class."""
|
||||||
|
|
||||||
def __new__(cls, *args, **kwargs): # pylint: disable=unused-argument
|
def __new__(cls, *args, **kwargs): # pylint: disable=unused-argument
|
||||||
eager_enabled = ops.executing_eagerly_outside_functions()
|
use_v2 = should_use_v2()
|
||||||
start_cls = cls
|
start_cls = cls
|
||||||
cls = swap_class(start_cls, callbacks.TensorBoard, callbacks_v1.TensorBoard,
|
cls = swap_class(start_cls, callbacks.TensorBoard, callbacks_v1.TensorBoard,
|
||||||
eager_enabled)
|
use_v2)
|
||||||
if start_cls == callbacks_v1.TensorBoard and cls == callbacks.TensorBoard:
|
if start_cls == callbacks_v1.TensorBoard and cls == callbacks.TensorBoard:
|
||||||
# Since the v2 class is not a subclass of the v1 class, __init__ has to
|
# Since the v2 class is not a subclass of the v1 class, __init__ has to
|
||||||
# be called manually.
|
# be called manually.
|
||||||
|
@@ -80,19 +81,33 @@ class TensorBoardVersionSelector(object):
     return super(TensorBoardVersionSelector, cls).__new__(cls)
 
 
-def swap_class(cls, v2_cls, v1_cls, eager_enabled):
+def should_use_v2():
+  """Determine if v1 or v2 version should be used."""
+  if context.executing_eagerly():
+    return True
+  elif ops.executing_eagerly_outside_functions():
+    # Check for a v1 `wrap_function` FuncGraph.
+    # Code inside a `wrap_function` is treated like v1 code.
+    graph = ops.get_default_graph()
+    if (getattr(graph, "name", False) and
+        graph.name.startswith("wrapped_function")):
+      return False
+    return True
+
+
+def swap_class(cls, v2_cls, v1_cls, use_v2):
   """Swaps in v2_cls or v1_cls depending on graph mode."""
   if cls == object:
     return cls
 
   if cls in (v2_cls, v1_cls):
-    if eager_enabled:
+    if use_v2:
       return v2_cls
     return v1_cls
 
   # Recursively search superclasses to swap in the right Keras class.
   cls.__bases__ = tuple(
-      swap_class(base, v2_cls, v1_cls, eager_enabled) for base in cls.__bases__)
+      swap_class(base, v2_cls, v1_cls, use_v2) for base in cls.__bases__)
   return cls
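
The selection machinery works by rewriting a subclass's `__bases__`, so an
already-defined user class is re-pointed at the v1 or v2 base on every
instantiation. A standalone sketch of that mechanic using hypothetical
stand-in classes (not the real Keras bases):

    class LayerV1(object):  # stand-in for base_layer_v1.Layer
      pass

    class LayerV2(object):  # stand-in for base_layer.Layer
      pass

    def swap(cls, v2_cls, v1_cls, use_v2):
      """Mirrors `swap_class`: recursively re-point bases at v1 or v2."""
      if cls is object:
        return cls
      if cls in (v2_cls, v1_cls):
        return v2_cls if use_v2 else v1_cls
      cls.__bases__ = tuple(
          swap(base, v2_cls, v1_cls, use_v2) for base in cls.__bases__)
      return cls

    class MyLayer(LayerV2):
      pass

    swap(MyLayer, LayerV2, LayerV1, use_v2=False)
    assert MyLayer.__bases__ == (LayerV1,)  # now a v1 subclass

Because the selectors' `__new__` re-runs this swap on every instantiation,
the same user-defined subclass can serve both v1 and v2 callers in one
process.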