From 2ba59dab2c22a592cb47660ecdb12463e457139c Mon Sep 17 00:00:00 2001
From: Thomas O'Malley
Date: Wed, 17 Jun 2020 10:56:38 -0700
Subject: [PATCH] Simplify Layer.add_update in v2 and update version_selector
 to use v1 inside a tf.compat.v1.wrap_function.

No longer track unused Layer.updates in v2.

PiperOrigin-RevId: 316921838
Change-Id: I4698a0c925528594f402f824705d66b8a1ae7b72
---
 tensorflow/python/keras/engine/base_layer.py  | 53 +++----------------
 .../python/keras/engine/sequential_test.py    | 33 +++++-------
 .../python/keras/layers/normalization_test.py | 19 ++-----
 .../python/keras/layers/wrappers_test.py      |  3 --
 .../python/keras/utils/version_utils.py       | 33 ++++++++----
 5 files changed, 50 insertions(+), 91 deletions(-)

diff --git a/tensorflow/python/keras/engine/base_layer.py b/tensorflow/python/keras/engine/base_layer.py
index a0ee25417c0..5ddce951491 100644
--- a/tensorflow/python/keras/engine/base_layer.py
+++ b/tensorflow/python/keras/engine/base_layer.py
@@ -1733,54 +1733,15 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
       inputs: Deprecated, will be automatically inferred.
     """
     call_context = base_layer_utils.call_context()
-
-    if (ds_context.has_strategy() and
-        ds_context.in_cross_replica_context() and
-        # When saving the model, the distribution strategy context should be
-        # ignored, following the default path for adding updates.
-        not call_context.saving):
-      # Updates don't need to be run in a cross-replica context.
+    # No need to run updates during Functional API construction.
+    if call_context.in_keras_graph:
       return
 
-    updates = generic_utils.to_list(updates)
-
-    # All updates can be run immediately in Eager or in a tf.function.
-    if base_layer_utils.is_in_eager_or_tf_function():
-      if not call_context.frozen:
-        for update in updates:
-          if callable(update):
-            update()
-      return
-
-    def process_update(x):
-      """Standardize update ops.
-
-      Arguments:
-        x: Tensor, op, or callable.
-
-      Returns:
-        An update op.
-      """
-      if callable(x):
-        update = lambda: process_update(x())
-        if not ops.executing_eagerly_outside_functions():
-          # In V1 mode, call the callable right away and process. This is needed
-          # for TPU strategy.
-          return update()
-      elif isinstance(x, ops.Operation):
-        update = x
-      elif hasattr(x, 'op'):
-        update = x.op
-      else:
-        update = ops.convert_to_tensor_v2(x)
-      return update
-
-    updates = [process_update(x) for x in updates]
-    # Non-callable Updates are run automatically inside `call` in V2, so
-    # they do not need to be tracked later.
-    if ops.executing_eagerly_outside_functions() and call_context.in_call:
-      updates = [u for u in updates if callable(u)]
-    self._updates.extend(updates)
+    # Callable updates are disabled by setting `trainable=False`.
+    if not call_context.frozen:
+      for update in nest.flatten(updates):
+        if callable(update):
+          update()
 
   def set_weights(self, weights):
     """Sets the weights of the layer, from Numpy arrays.
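
Note on the new `add_update` path above: in v2 it reduces to "run callable
updates immediately, unless the layer is frozen". A minimal sketch of the
resulting behavior (the `CounterLayer` below is illustrative only, not part
of this patch; assumes a TF 2.x build from around this change, with eager
execution):

    import tensorflow as tf

    class CounterLayer(tf.keras.layers.Layer):
      """Illustrative layer that bumps a counter via `add_update`."""

      def build(self, input_shape):
        self.calls = self.add_weight(
            name="calls", shape=(), dtype=tf.int64,
            initializer="zeros", trainable=False)

      def call(self, inputs):
        # A callable update runs immediately in eager mode or a tf.function,
        # but is skipped while the layer is frozen (`trainable = False`).
        self.add_update(lambda: self.calls.assign_add(1))
        return inputs

    layer = CounterLayer()
    layer(tf.ones((2, 2)))
    assert int(layer.calls.numpy()) == 1
    layer.trainable = False  # Freezes the layer...
    layer(tf.ones((2, 2)))   # ...so the callable update is skipped.
    assert int(layer.calls.numpy()) == 1

These are the same semantics the updated tests below rely on: setting
`trainable = False` stops BatchNormalization's moving-statistics updates.
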
diff --git a/tensorflow/python/keras/engine/sequential_test.py b/tensorflow/python/keras/engine/sequential_test.py index 9589d24fc57..773ce003656 100644 --- a/tensorflow/python/keras/engine/sequential_test.py +++ b/tensorflow/python/keras/engine/sequential_test.py @@ -231,33 +231,28 @@ class TestSequential(keras_parameterized.TestCase): inner_model.trainable = True self.assertEqual(len(model.trainable_weights), 4) + @keras_parameterized.run_all_keras_modes def test_sequential_update_disabling(self): val_a = np.random.random((10, 4)) val_out = np.random.random((10, 4)) - with self.cached_session(): - model = keras.models.Sequential() - model.add(keras.layers.BatchNormalization(input_shape=(4,))) - assert model.updates + model = keras.models.Sequential() + model.add(keras.layers.BatchNormalization(input_shape=(4,))) - model.trainable = False - assert not model.updates + model.trainable = False + model.compile('sgd', 'mse') - model.compile('sgd', 'mse') - assert not model.updates + x1 = model.predict(val_a) + model.train_on_batch(val_a, val_out) + x2 = model.predict(val_a) + self.assertAllClose(x1, x2, atol=1e-7) - x1 = model.predict(val_a) - model.train_on_batch(val_a, val_out) - x2 = model.predict(val_a) - self.assertAllClose(x1, x2, atol=1e-7) + model.trainable = True + model.compile('sgd', 'mse') - model.trainable = True - model.compile('sgd', 'mse') - assert model.updates - - model.train_on_batch(val_a, val_out) - x2 = model.predict(val_a) - assert np.abs(np.sum(x1 - x2)) > 1e-5 + model.train_on_batch(val_a, val_out) + x2 = model.predict(val_a) + assert np.abs(np.sum(x1 - x2)) > 1e-5 @keras_parameterized.run_all_keras_modes def test_sequential_deferred_build_serialization(self): diff --git a/tensorflow/python/keras/layers/normalization_test.py b/tensorflow/python/keras/layers/normalization_test.py index ef43bcf5d22..39992f7580a 100644 --- a/tensorflow/python/keras/layers/normalization_test.py +++ b/tensorflow/python/keras/layers/normalization_test.py @@ -325,18 +325,18 @@ class BatchNormalizationV2Test(keras_parameterized.TestCase): norm(inp) def test_updates_in_wrap_function(self): - layer = normalization.BatchNormalization() def my_func(): + layer = normalization.BatchNormalization() x = array_ops.ones((10, 1)) - return layer(x, training=True) + y = layer(x, training=True) + # Updates should be tracked in a `wrap_function`. + self.assertLen(layer.updates, 2) + return y wrapped_fn = wrap_function.wrap_function(my_func, []) wrapped_fn() - # Updates should be tracked in a `wrap_function`. 
- self.assertLen(layer.updates, 2) - @keras_parameterized.run_all_keras_modes def test_basic_batchnorm_v2_none_shape_and_virtual_batch_size(self): # Test case for GitHub issue for 32380 @@ -392,15 +392,11 @@ class NormalizationLayersGraphModeOnlyTest( model.compile(gradient_descent.GradientDescentOptimizer(0.01), 'mse') model.train_on_batch(x, x) - self.assertLen(bn.updates, 4) - # Test model-level reuse x3 = keras.layers.Input(shape=(10,)) y3 = model(x3) new_model = keras.models.Model(x3, y3, name='new_model') - self.assertLen(new_model.updates, 6) - self.assertLen(model.updates, 6) new_model.compile(gradient_descent.GradientDescentOptimizer(0.01), 'mse') new_model.train_on_batch(x, x) @@ -415,10 +411,7 @@ class NormalizationLayersGraphModeOnlyTest( model = keras.models.Model(a, b) model.trainable = False - assert not model.updates - model.compile(gradient_descent.GradientDescentOptimizer(0.01), 'mse') - assert not model.updates x1 = model.predict(val_a) model.train_on_batch(val_a, val_out) @@ -427,7 +420,6 @@ class NormalizationLayersGraphModeOnlyTest( model.trainable = True model.compile(gradient_descent.GradientDescentOptimizer(0.01), 'mse') - assert model.updates model.train_on_batch(val_a, val_out) x2 = model.predict(val_a) @@ -435,7 +427,6 @@ class NormalizationLayersGraphModeOnlyTest( layer.trainable = False model.compile(gradient_descent.GradientDescentOptimizer(0.01), 'mse') - assert not model.updates x1 = model.predict(val_a) model.train_on_batch(val_a, val_out) diff --git a/tensorflow/python/keras/layers/wrappers_test.py b/tensorflow/python/keras/layers/wrappers_test.py index a73177fff12..5ee794dd1ef 100644 --- a/tensorflow/python/keras/layers/wrappers_test.py +++ b/tensorflow/python/keras/layers/wrappers_test.py @@ -234,13 +234,10 @@ class TimeDistributedTest(keras_parameterized.TestCase): x = keras.layers.Input(shape=(3, 2)) layer = keras.layers.TimeDistributed(keras.layers.BatchNormalization()) _ = layer(x) - self.assertEqual(len(layer.updates), 2) self.assertEqual(len(layer.trainable_weights), 2) layer.trainable = False - assert not layer.updates assert not layer.trainable_weights layer.trainable = True - assert len(layer.updates) == 2 assert len(layer.trainable_weights) == 2 def test_TimeDistributed_with_masked_embedding_and_unspecified_shape(self): diff --git a/tensorflow/python/keras/utils/version_utils.py b/tensorflow/python/keras/utils/version_utils.py index 551a07d2422..d3796dcbf92 100644 --- a/tensorflow/python/keras/utils/version_utils.py +++ b/tensorflow/python/keras/utils/version_utils.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.python.eager import context from tensorflow.python.framework import ops from tensorflow.python.util import lazy_loader @@ -51,8 +52,8 @@ class ModelVersionSelector(object): """Chooses between Keras v1 and v2 Model class.""" def __new__(cls, *args, **kwargs): # pylint: disable=unused-argument - eager_enabled = ops.executing_eagerly_outside_functions() - cls = swap_class(cls, training.Model, training_v1.Model, eager_enabled) + use_v2 = should_use_v2() + cls = swap_class(cls, training.Model, training_v1.Model, use_v2) # pylint: disable=self-cls-assignment return super(ModelVersionSelector, cls).__new__(cls) @@ -60,8 +61,8 @@ class LayerVersionSelector(object): """Chooses between Keras v1 and v2 Layer class.""" def __new__(cls, *args, **kwargs): # pylint: disable=unused-argument - eager_enabled = ops.executing_eagerly_outside_functions() - 
cls = swap_class(cls, base_layer.Layer, base_layer_v1.Layer, eager_enabled)
+    use_v2 = should_use_v2()
+    cls = swap_class(cls, base_layer.Layer, base_layer_v1.Layer, use_v2)  # pylint: disable=self-cls-assignment
     return super(LayerVersionSelector, cls).__new__(cls)
 
 
@@ -69,10 +70,10 @@ class TensorBoardVersionSelector(object):
   """Chooses between Keras v1 and v2 TensorBoard callback class."""
 
   def __new__(cls, *args, **kwargs):  # pylint: disable=unused-argument
-    eager_enabled = ops.executing_eagerly_outside_functions()
+    use_v2 = should_use_v2()
     start_cls = cls
     cls = swap_class(start_cls, callbacks.TensorBoard, callbacks_v1.TensorBoard,
-                     eager_enabled)
+                     use_v2)
     if start_cls == callbacks_v1.TensorBoard and cls == callbacks.TensorBoard:
       # Since the v2 class is not a subclass of the v1 class, __init__ has to
       # be called manually.
@@ -80,19 +81,33 @@ class TensorBoardVersionSelector(object):
     return super(TensorBoardVersionSelector, cls).__new__(cls)
 
 
-def swap_class(cls, v2_cls, v1_cls, eager_enabled):
+def should_use_v2():
+  """Determine if v1 or v2 version should be used."""
+  if context.executing_eagerly():
+    return True
+  elif ops.executing_eagerly_outside_functions():
+    # Check for a v1 `wrap_function` FuncGraph.
+    # Code inside a `wrap_function` is treated like v1 code.
+    graph = ops.get_default_graph()
+    if (getattr(graph, "name", False) and
+        graph.name.startswith("wrapped_function")):
+      return False
+    return True
+  else:
+    return False
+
+
+def swap_class(cls, v2_cls, v1_cls, use_v2):
   """Swaps in v2_cls or v1_cls depending on graph mode."""
   if cls == object:
     return cls
   if cls in (v2_cls, v1_cls):
-    if eager_enabled:
+    if use_v2:
       return v2_cls
     return v1_cls
   # Recursively search superclasses to swap in the right Keras class.
   cls.__bases__ = tuple(
-      swap_class(base, v2_cls, v1_cls, eager_enabled) for base in cls.__bases__)
+      swap_class(base, v2_cls, v1_cls, use_v2) for base in cls.__bases__)
   return cls
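
Note on `should_use_v2` above: code traced inside a `tf.compat.v1.wrap_function`
sees a FuncGraph whose name starts with "wrapped_function", so the v1 `Layer`
class is selected there and `layer.updates` is still tracked, while plain eager
code gets the v2 class and callable updates simply run. A sketch of the
behavior asserted by `test_updates_in_wrap_function` in this patch (assumes a
TF 2.x build containing this change):

    import tensorflow as tf

    def make_bn_and_call():
      # Traced inside `wrap_function`, so `should_use_v2()` returns False and
      # the layer is built from the v1 base class, which tracks updates.
      layer = tf.keras.layers.BatchNormalization()
      outputs = layer(tf.ones((10, 1)), training=True)
      assert len(layer.updates) == 2  # Moving mean + moving variance.
      return outputs

    wrapped_fn = tf.compat.v1.wrap_function(make_bn_and_call, [])
    wrapped_fn()

In eager mode, by contrast, `context.executing_eagerly()` is True, so
`should_use_v2()` returns True, the v2 `Layer` is used, and no update list is
kept.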