Simplify Layer.add_udpate in v2 and update version_selector to use v1 inside a
tf.compat.v1.wrap_function. No longer track unused Layer.updates in v2. PiperOrigin-RevId: 316921838 Change-Id: I4698a0c925528594f402f824705d66b8a1ae7b72
This commit is contained in:
parent
4747be646b
commit
2ba59dab2c
|
@ -1733,54 +1733,15 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
|
|||
inputs: Deprecated, will be automatically inferred.
|
||||
"""
|
||||
call_context = base_layer_utils.call_context()
|
||||
|
||||
if (ds_context.has_strategy() and
|
||||
ds_context.in_cross_replica_context() and
|
||||
# When saving the model, the distribution strategy context should be
|
||||
# ignored, following the default path for adding updates.
|
||||
not call_context.saving):
|
||||
# Updates don't need to be run in a cross-replica context.
|
||||
# No need to run updates during Functional API construction.
|
||||
if call_context.in_keras_graph:
|
||||
return
|
||||
|
||||
updates = generic_utils.to_list(updates)
|
||||
|
||||
# All updates can be run immediately in Eager or in a tf.function.
|
||||
if base_layer_utils.is_in_eager_or_tf_function():
|
||||
if not call_context.frozen:
|
||||
for update in updates:
|
||||
if callable(update):
|
||||
update()
|
||||
return
|
||||
|
||||
def process_update(x):
|
||||
"""Standardize update ops.
|
||||
|
||||
Arguments:
|
||||
x: Tensor, op, or callable.
|
||||
|
||||
Returns:
|
||||
An update op.
|
||||
"""
|
||||
if callable(x):
|
||||
update = lambda: process_update(x())
|
||||
if not ops.executing_eagerly_outside_functions():
|
||||
# In V1 mode, call the callable right away and process. This is needed
|
||||
# for TPU strategy.
|
||||
return update()
|
||||
elif isinstance(x, ops.Operation):
|
||||
update = x
|
||||
elif hasattr(x, 'op'):
|
||||
update = x.op
|
||||
else:
|
||||
update = ops.convert_to_tensor_v2(x)
|
||||
return update
|
||||
|
||||
updates = [process_update(x) for x in updates]
|
||||
# Non-callable Updates are run automatically inside `call` in V2, so
|
||||
# they do not need to be tracked later.
|
||||
if ops.executing_eagerly_outside_functions() and call_context.in_call:
|
||||
updates = [u for u in updates if callable(u)]
|
||||
self._updates.extend(updates)
|
||||
# Callable updates are disabled by setting `trainable=False`.
|
||||
if not call_context.frozen:
|
||||
for update in nest.flatten(updates):
|
||||
if callable(update):
|
||||
update()
|
||||
|
||||
def set_weights(self, weights):
|
||||
"""Sets the weights of the layer, from Numpy arrays.
|
||||
|
|
|
@ -231,33 +231,28 @@ class TestSequential(keras_parameterized.TestCase):
|
|||
inner_model.trainable = True
|
||||
self.assertEqual(len(model.trainable_weights), 4)
|
||||
|
||||
@keras_parameterized.run_all_keras_modes
|
||||
def test_sequential_update_disabling(self):
|
||||
val_a = np.random.random((10, 4))
|
||||
val_out = np.random.random((10, 4))
|
||||
|
||||
with self.cached_session():
|
||||
model = keras.models.Sequential()
|
||||
model.add(keras.layers.BatchNormalization(input_shape=(4,)))
|
||||
assert model.updates
|
||||
model = keras.models.Sequential()
|
||||
model.add(keras.layers.BatchNormalization(input_shape=(4,)))
|
||||
|
||||
model.trainable = False
|
||||
assert not model.updates
|
||||
model.trainable = False
|
||||
model.compile('sgd', 'mse')
|
||||
|
||||
model.compile('sgd', 'mse')
|
||||
assert not model.updates
|
||||
x1 = model.predict(val_a)
|
||||
model.train_on_batch(val_a, val_out)
|
||||
x2 = model.predict(val_a)
|
||||
self.assertAllClose(x1, x2, atol=1e-7)
|
||||
|
||||
x1 = model.predict(val_a)
|
||||
model.train_on_batch(val_a, val_out)
|
||||
x2 = model.predict(val_a)
|
||||
self.assertAllClose(x1, x2, atol=1e-7)
|
||||
model.trainable = True
|
||||
model.compile('sgd', 'mse')
|
||||
|
||||
model.trainable = True
|
||||
model.compile('sgd', 'mse')
|
||||
assert model.updates
|
||||
|
||||
model.train_on_batch(val_a, val_out)
|
||||
x2 = model.predict(val_a)
|
||||
assert np.abs(np.sum(x1 - x2)) > 1e-5
|
||||
model.train_on_batch(val_a, val_out)
|
||||
x2 = model.predict(val_a)
|
||||
assert np.abs(np.sum(x1 - x2)) > 1e-5
|
||||
|
||||
@keras_parameterized.run_all_keras_modes
|
||||
def test_sequential_deferred_build_serialization(self):
|
||||
|
|
|
@ -325,18 +325,18 @@ class BatchNormalizationV2Test(keras_parameterized.TestCase):
|
|||
norm(inp)
|
||||
|
||||
def test_updates_in_wrap_function(self):
|
||||
layer = normalization.BatchNormalization()
|
||||
|
||||
def my_func():
|
||||
layer = normalization.BatchNormalization()
|
||||
x = array_ops.ones((10, 1))
|
||||
return layer(x, training=True)
|
||||
y = layer(x, training=True)
|
||||
# Updates should be tracked in a `wrap_function`.
|
||||
self.assertLen(layer.updates, 2)
|
||||
return y
|
||||
|
||||
wrapped_fn = wrap_function.wrap_function(my_func, [])
|
||||
wrapped_fn()
|
||||
|
||||
# Updates should be tracked in a `wrap_function`.
|
||||
self.assertLen(layer.updates, 2)
|
||||
|
||||
@keras_parameterized.run_all_keras_modes
|
||||
def test_basic_batchnorm_v2_none_shape_and_virtual_batch_size(self):
|
||||
# Test case for GitHub issue for 32380
|
||||
|
@ -392,15 +392,11 @@ class NormalizationLayersGraphModeOnlyTest(
|
|||
model.compile(gradient_descent.GradientDescentOptimizer(0.01), 'mse')
|
||||
model.train_on_batch(x, x)
|
||||
|
||||
self.assertLen(bn.updates, 4)
|
||||
|
||||
# Test model-level reuse
|
||||
x3 = keras.layers.Input(shape=(10,))
|
||||
y3 = model(x3)
|
||||
new_model = keras.models.Model(x3, y3, name='new_model')
|
||||
|
||||
self.assertLen(new_model.updates, 6)
|
||||
self.assertLen(model.updates, 6)
|
||||
new_model.compile(gradient_descent.GradientDescentOptimizer(0.01), 'mse')
|
||||
new_model.train_on_batch(x, x)
|
||||
|
||||
|
@ -415,10 +411,7 @@ class NormalizationLayersGraphModeOnlyTest(
|
|||
model = keras.models.Model(a, b)
|
||||
|
||||
model.trainable = False
|
||||
assert not model.updates
|
||||
|
||||
model.compile(gradient_descent.GradientDescentOptimizer(0.01), 'mse')
|
||||
assert not model.updates
|
||||
|
||||
x1 = model.predict(val_a)
|
||||
model.train_on_batch(val_a, val_out)
|
||||
|
@ -427,7 +420,6 @@ class NormalizationLayersGraphModeOnlyTest(
|
|||
|
||||
model.trainable = True
|
||||
model.compile(gradient_descent.GradientDescentOptimizer(0.01), 'mse')
|
||||
assert model.updates
|
||||
|
||||
model.train_on_batch(val_a, val_out)
|
||||
x2 = model.predict(val_a)
|
||||
|
@ -435,7 +427,6 @@ class NormalizationLayersGraphModeOnlyTest(
|
|||
|
||||
layer.trainable = False
|
||||
model.compile(gradient_descent.GradientDescentOptimizer(0.01), 'mse')
|
||||
assert not model.updates
|
||||
|
||||
x1 = model.predict(val_a)
|
||||
model.train_on_batch(val_a, val_out)
|
||||
|
|
|
@ -234,13 +234,10 @@ class TimeDistributedTest(keras_parameterized.TestCase):
|
|||
x = keras.layers.Input(shape=(3, 2))
|
||||
layer = keras.layers.TimeDistributed(keras.layers.BatchNormalization())
|
||||
_ = layer(x)
|
||||
self.assertEqual(len(layer.updates), 2)
|
||||
self.assertEqual(len(layer.trainable_weights), 2)
|
||||
layer.trainable = False
|
||||
assert not layer.updates
|
||||
assert not layer.trainable_weights
|
||||
layer.trainable = True
|
||||
assert len(layer.updates) == 2
|
||||
assert len(layer.trainable_weights) == 2
|
||||
|
||||
def test_TimeDistributed_with_masked_embedding_and_unspecified_shape(self):
|
||||
|
|
|
@ -18,6 +18,7 @@ from __future__ import absolute_import
|
|||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.util import lazy_loader
|
||||
|
||||
|
@ -51,8 +52,8 @@ class ModelVersionSelector(object):
|
|||
"""Chooses between Keras v1 and v2 Model class."""
|
||||
|
||||
def __new__(cls, *args, **kwargs): # pylint: disable=unused-argument
|
||||
eager_enabled = ops.executing_eagerly_outside_functions()
|
||||
cls = swap_class(cls, training.Model, training_v1.Model, eager_enabled)
|
||||
use_v2 = should_use_v2()
|
||||
cls = swap_class(cls, training.Model, training_v1.Model, use_v2) # pylint: disable=self-cls-assignment
|
||||
return super(ModelVersionSelector, cls).__new__(cls)
|
||||
|
||||
|
||||
|
@ -60,8 +61,8 @@ class LayerVersionSelector(object):
|
|||
"""Chooses between Keras v1 and v2 Layer class."""
|
||||
|
||||
def __new__(cls, *args, **kwargs): # pylint: disable=unused-argument
|
||||
eager_enabled = ops.executing_eagerly_outside_functions()
|
||||
cls = swap_class(cls, base_layer.Layer, base_layer_v1.Layer, eager_enabled)
|
||||
use_v2 = should_use_v2()
|
||||
cls = swap_class(cls, base_layer.Layer, base_layer_v1.Layer, use_v2) # pylint: disable=self-cls-assignment
|
||||
return super(LayerVersionSelector, cls).__new__(cls)
|
||||
|
||||
|
||||
|
@ -69,10 +70,10 @@ class TensorBoardVersionSelector(object):
|
|||
"""Chooses between Keras v1 and v2 TensorBoard callback class."""
|
||||
|
||||
def __new__(cls, *args, **kwargs): # pylint: disable=unused-argument
|
||||
eager_enabled = ops.executing_eagerly_outside_functions()
|
||||
use_v2 = should_use_v2()
|
||||
start_cls = cls
|
||||
cls = swap_class(start_cls, callbacks.TensorBoard, callbacks_v1.TensorBoard,
|
||||
eager_enabled)
|
||||
use_v2)
|
||||
if start_cls == callbacks_v1.TensorBoard and cls == callbacks.TensorBoard:
|
||||
# Since the v2 class is not a subclass of the v1 class, __init__ has to
|
||||
# be called manually.
|
||||
|
@ -80,19 +81,33 @@ class TensorBoardVersionSelector(object):
|
|||
return super(TensorBoardVersionSelector, cls).__new__(cls)
|
||||
|
||||
|
||||
def swap_class(cls, v2_cls, v1_cls, eager_enabled):
|
||||
def should_use_v2():
|
||||
"""Determine if v1 or v2 version should be used."""
|
||||
if context.executing_eagerly():
|
||||
return True
|
||||
elif ops.executing_eagerly_outside_functions():
|
||||
# Check for a v1 `wrap_function` FuncGraph.
|
||||
# Code inside a `wrap_function` is treated like v1 code.
|
||||
graph = ops.get_default_graph()
|
||||
if (getattr(graph, "name", False) and
|
||||
graph.name.startswith("wrapped_function")):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def swap_class(cls, v2_cls, v1_cls, use_v2):
|
||||
"""Swaps in v2_cls or v1_cls depending on graph mode."""
|
||||
if cls == object:
|
||||
return cls
|
||||
|
||||
if cls in (v2_cls, v1_cls):
|
||||
if eager_enabled:
|
||||
if use_v2:
|
||||
return v2_cls
|
||||
return v1_cls
|
||||
|
||||
# Recursively search superclasses to swap in the right Keras class.
|
||||
cls.__bases__ = tuple(
|
||||
swap_class(base, v2_cls, v1_cls, eager_enabled) for base in cls.__bases__)
|
||||
swap_class(base, v2_cls, v1_cls, use_v2) for base in cls.__bases__)
|
||||
return cls
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue