From fceded98095d7ea7c4fba5442ac64d4fa505702f Mon Sep 17 00:00:00 2001
From: Thomas O'Malley
Date: Tue, 9 Apr 2019 11:21:29 -0700
Subject: [PATCH] Support `trainable=False` for updates with `run_eagerly=True`.

`add_update` can now be passed a zero-arg callable, which makes it possible
to turn off an update by setting `trainable=False` on a Layer of a Model
compiled with `run_eagerly=True`.

PiperOrigin-RevId: 242704465
---
 tensorflow/python/keras/BUILD                 |  2 +-
 tensorflow/python/keras/engine/base_layer.py  | 21 +++++-
 .../python/keras/engine/base_layer_utils.py   | 17 ++++-
 .../python/keras/engine/training_test.py      | 52 ++++++-------
 .../python/keras/layers/normalization.py      | 75 ++++++++++---------
 5 files changed, 101 insertions(+), 66 deletions(-)

diff --git a/tensorflow/python/keras/BUILD b/tensorflow/python/keras/BUILD
index 20723b5987c..5a2e399cd39 100755
--- a/tensorflow/python/keras/BUILD
+++ b/tensorflow/python/keras/BUILD
@@ -1169,7 +1169,7 @@ tf_py_test(
         "//third_party/py/numpy",
         "//tensorflow/python:client_testlib",
     ],
-    shard_count = 16,
+    shard_count = 20,
     tags = ["notsan"],
 )
 
diff --git a/tensorflow/python/keras/engine/base_layer.py b/tensorflow/python/keras/engine/base_layer.py
index fbbb843c38c..899b457de24 100644
--- a/tensorflow/python/keras/engine/base_layer.py
+++ b/tensorflow/python/keras/engine/base_layer.py
@@ -564,7 +564,7 @@ class Layer(trackable.Trackable):
         not base_layer_utils.is_in_call_context()):
       self._clear_losses()
 
-    with base_layer_utils.call_context():
+    with base_layer_utils.call_context(self):
       # Check input assumptions set after layer building, e.g. input shape.
       if build_graph:
         # Symbolic execution on symbolic tensors. We will attempt to build
@@ -598,6 +598,8 @@ class Layer(trackable.Trackable):
           # Explicitly pass the learning phase placeholder to `call` if
           # the `training` argument was left unspecified by the user.
           # This behavior is restricted to the managed Keras FuncGraph.
+          # TODO(omalleyt): Reconcile this with new `trainable` behavior
+          # when available.
           learning_phase_passed_by_framework = False
           if (self._expects_training_arg and
               not base_layer_utils.training_arg_passed_to_call(
@@ -669,6 +671,7 @@ class Layer(trackable.Trackable):
         self._initial_weights is not None):
       self.set_weights(self._initial_weights)
       del self._initial_weights
+
     return outputs
 
   @property
@@ -974,7 +977,10 @@ class Layer(trackable.Trackable):
       execution).
 
     Arguments:
-      updates: Update op, or list/tuple of update ops.
+      updates: Update op, or list/tuple of update ops, or zero-arg callable
+        that returns an update op. When executing eagerly, pass a zero-arg
+        callable so that the updates can be disabled by setting
+        `trainable=False` on this Layer.
      inputs: If anything other than None is passed, it signals the updates
        are conditional on some of the layer's inputs, and thus they should
        only be run where these inputs are available.
@@ -984,10 +990,20 @@
       have is available at runtime. A step counter might fall into this
       category.
    """
+    updates = generic_utils.to_list(updates)
+
    if context.executing_eagerly():
+      # Don't run callable updates if currently executing inside the `call`
+      # of a Layer/Model with `trainable=False`.
+      if not base_layer_utils.is_in_frozen_context():
+        for update in updates:
+          if callable(update):
+            update()
      return  # Updates already applied when in eager mode.
 
    def process_update(x):
+      if callable(x):
+        x = x()
      if isinstance(x, ops.Operation):
        return x
      elif hasattr(x, 'op'):
@@ -995,7 +1011,6 @@
      else:
        return ops.convert_to_tensor(x)
 
-    updates = generic_utils.to_list(updates)
    updates = [process_update(x) for x in updates]
    self._updates += updates
    if inputs is None:
diff --git a/tensorflow/python/keras/engine/base_layer_utils.py b/tensorflow/python/keras/engine/base_layer_utils.py
index e3d1aa30c74..61622bfa93c 100644
--- a/tensorflow/python/keras/engine/base_layer_utils.py
+++ b/tensorflow/python/keras/engine/base_layer_utils.py
@@ -339,6 +339,17 @@ def is_in_call_context():
   return getattr(_call_context, 'in_call', False)
 
 
+def is_in_frozen_context():
+  """Returns whether currently executing inside the `call` of a frozen Layer.
+
+  A Layer is considered frozen if `layer.trainable=False`.
+
+  Returns:
+    Whether currently inside the `call` of a frozen Layer.
+  """
+  return getattr(_call_context, 'frozen', False)
+
+
 def uses_keras_history(tensors):
   """Check if at least one Tensor originates from a `keras.Input`.
 
@@ -397,14 +408,18 @@ def mark_checked(tensors):
 
 
 @tf_contextlib.contextmanager
-def call_context():
+def call_context(layer):
   """Scope that marks when we are currently inside a Layer/Model's `call`."""
   was_in_call = is_in_call_context()
+  was_frozen = is_in_frozen_context()
   _call_context.in_call = True
+  if not layer.trainable:
+    _call_context.frozen = True
   try:
     yield
   finally:
     _call_context.in_call = was_in_call
+    _call_context.frozen = was_frozen
 
 
 def training_arg_passed_to_call(argspec, args, kwargs):
diff --git a/tensorflow/python/keras/engine/training_test.py b/tensorflow/python/keras/engine/training_test.py
index e4b87fa239a..5e7d828cce5 100644
--- a/tensorflow/python/keras/engine/training_test.py
+++ b/tensorflow/python/keras/engine/training_test.py
@@ -662,43 +662,43 @@ class TrainingTest(keras_parameterized.TestCase):
         metrics=['accuracy'],
         run_eagerly=testing_utils.should_run_eagerly())
 
+  @keras_parameterized.run_all_keras_modes
   def test_that_trainable_disables_updates(self):
     val_a = np.random.random((10, 4))
     val_out = np.random.random((10, 4))
 
-    with self.cached_session():
-      a = keras.layers.Input(shape=(4,))
-      layer = keras.layers.BatchNormalization(input_shape=(4,))
-      b = layer(a)
-      model = keras.Model(a, b)
+    a = keras.layers.Input(shape=(4,))
+    layer = keras.layers.BatchNormalization(input_shape=(4,))
+    b = layer(a)
+    model = keras.Model(a, b)
 
-      model.trainable = False
-      assert not model.updates
+    model.trainable = False
+    assert not model.updates
 
-      model.compile('sgd', 'mse')
-      assert not model.updates
+    model.compile('sgd', 'mse', run_eagerly=testing_utils.should_run_eagerly())
+    assert not model.updates
 
-      x1 = model.predict(val_a)
-      model.train_on_batch(val_a, val_out)
-      x2 = model.predict(val_a)
-      self.assertAllClose(x1, x2, atol=1e-7)
+    x1 = model.predict(val_a)
+    model.train_on_batch(val_a, val_out)
+    x2 = model.predict(val_a)
+    self.assertAllClose(x1, x2, atol=1e-7)
 
-      model.trainable = True
-      model.compile('sgd', 'mse')
-      assert model.updates
+    model.trainable = True
+    model.compile('sgd', 'mse', run_eagerly=testing_utils.should_run_eagerly())
+    assert model.updates
 
-      model.train_on_batch(val_a, val_out)
-      x2 = model.predict(val_a)
-      assert np.abs(np.sum(x1 - x2)) > 1e-5
+    model.train_on_batch(val_a, val_out)
+    x2 = model.predict(val_a)
+    assert np.abs(np.sum(x1 - x2)) > 1e-5
 
-      layer.trainable = False
-      model.compile('sgd', 'mse')
-      assert not model.updates
+    layer.trainable = False
+    model.compile('sgd', 'mse', run_eagerly=testing_utils.should_run_eagerly())
+    assert not model.updates
 
-      x1 = model.predict(val_a)
-      model.train_on_batch(val_a, val_out)
-      x2 = model.predict(val_a)
-      self.assertAllClose(x1, x2, atol=1e-7)
+    x1 = model.predict(val_a)
+    model.train_on_batch(val_a, val_out)
+    x2 = model.predict(val_a)
+    self.assertAllClose(x1, x2, atol=1e-7)
 
   def test_logs_passed_to_callbacks(self):
     with self.cached_session():
diff --git a/tensorflow/python/keras/layers/normalization.py b/tensorflow/python/keras/layers/normalization.py
index 1bd69e0162b..d27dc010b01 100644
--- a/tensorflow/python/keras/layers/normalization.py
+++ b/tensorflow/python/keras/layers/normalization.py
@@ -19,7 +19,6 @@ from __future__ import division
 from __future__ import print_function
 
 from tensorflow.python.distribute import distribution_strategy_context
-from tensorflow.python.eager import context
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
@@ -480,17 +479,25 @@ class BatchNormalizationBase(Layer):
     if training_value or training_value is None:
       if distribution_strategy_context.in_cross_replica_context():
         strategy = distribution_strategy_context.get_strategy()
-        mean_update = strategy.extended.update(
-            self.moving_mean, self._assign_moving_average,
-            (mean, self.momentum))
-        variance_update = strategy.extended.update(
-            self.moving_variance, self._assign_moving_average,
-            (variance, self.momentum))
+
+        def mean_update():
+          return strategy.extended.update(self.moving_mean,
+                                          self._assign_moving_average,
+                                          (mean, self.momentum))
+
+        def variance_update():
+          return strategy.extended.update(self.moving_variance,
+                                          self._assign_moving_average,
+                                          (variance, self.momentum))
       else:
-        mean_update = self._assign_moving_average(self.moving_mean, mean,
-                                                  momentum)
-        variance_update = self._assign_moving_average(self.moving_variance,
-                                                      variance, momentum)
+
+        def mean_update():
+          return self._assign_moving_average(self.moving_mean, mean, momentum)
+
+        def variance_update():
+          return self._assign_moving_average(self.moving_variance, variance,
+                                             momentum)
+
       self.add_update(mean_update, inputs=True)
       self.add_update(variance_update, inputs=True)
 
@@ -569,7 +576,6 @@
 
     if training is None:
       training = K.learning_phase()
-    in_eager_mode = context.executing_eagerly()
     if self.virtual_batch_size is not None:
       # Virtual batches (aka ghost batches) can be simulated by reshaping the
       # Tensor and reusing the existing batch norm implementation
@@ -676,39 +682,38 @@
 
         def _do_update(var, value):
           """Compute the updates for mean and variance."""
-          if in_eager_mode and not self.trainable:
-            return
           return strategy.extended.update(
               var, self._assign_moving_average, (value, self.momentum),
               group=False)
 
         # We need to unwrap the moving_mean or moving_variance in the case of
         # training being false to match the output of true_fn and false_fn
         # in the smart cond.
-        mean_update = tf_utils.smart_cond(
-            training,
-            lambda: _do_update(self.moving_mean, new_mean),
-            lambda: strategy.unwrap(self.moving_mean))
-        variance_update = tf_utils.smart_cond(
-            training,
-            lambda: _do_update(self.moving_variance, new_variance),
-            lambda: strategy.unwrap(self.moving_variance))
+        def mean_update():
+          true_branch = lambda: _do_update(self.moving_mean, new_mean)
+          false_branch = lambda: strategy.unwrap(self.moving_mean)
+          return tf_utils.smart_cond(training, true_branch, false_branch)
+
+        def variance_update():
+          true_branch = lambda: _do_update(self.moving_variance, new_variance)
+          false_branch = lambda: strategy.unwrap(self.moving_variance)
+          return tf_utils.smart_cond(training, true_branch, false_branch)
       else:
 
         def _do_update(var, value):
           """Compute the updates for mean and variance."""
-          if in_eager_mode and not self.trainable:
-            return
           return self._assign_moving_average(var, value, self.momentum)
 
-        mean_update = tf_utils.smart_cond(
-            training,
-            lambda: _do_update(self.moving_mean, new_mean),
-            lambda: self.moving_mean)
-        variance_update = tf_utils.smart_cond(
-            training,
-            lambda: _do_update(self.moving_variance, new_variance),
-            lambda: self.moving_variance)
-      if not context.executing_eagerly():
-        self.add_update(mean_update, inputs=True)
-        self.add_update(variance_update, inputs=True)
+
+        def mean_update():
+          true_branch = lambda: _do_update(self.moving_mean, new_mean)
+          false_branch = lambda: self.moving_mean
+          return tf_utils.smart_cond(training, true_branch, false_branch)
+
+        def variance_update():
+          true_branch = lambda: _do_update(self.moving_variance, new_variance)
+          false_branch = lambda: self.moving_variance
+          return tf_utils.smart_cond(training, true_branch, false_branch)
+
+      self.add_update(mean_update, inputs=True)
+      self.add_update(variance_update, inputs=True)
     else:
      mean, variance = self.moving_mean, self.moving_variance
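
In practice, the new `add_update` contract looks like this: a layer registers
its update as a zero-arg callable, and the framework decides whether to run it
based on the surrounding call context. A minimal sketch, assuming the public
`tf.keras` API at this revision with eager execution enabled; the `Counter`
layer and its `count` weight are illustrative names, not part of this patch:

    import tensorflow as tf

    class Counter(tf.keras.layers.Layer):
      """Toy layer (illustrative) that bumps a counter once per call."""

      def build(self, input_shape):
        self.count = self.add_weight(
            'count', shape=(), initializer='zeros', trainable=False)

      def call(self, inputs):
        # Passing a zero-arg callable, rather than an already-executed
        # assign op, is what lets `add_update` skip the update when this
        # layer is called while frozen (`trainable=False`) in eager mode.
        self.add_update(lambda: self.count.assign_add(1.))
        return inputs

    layer = Counter()
    x = tf.ones((2, 4))
    layer(x)                # `add_update` runs the callable: count -> 1.0.
    layer.trainable = False
    layer(x)                # Frozen context: callable skipped, count stays 1.0.

The first call runs the callable on the spot (eager updates are applied
immediately); once `trainable=False` is set, `call_context` marks the scope as
frozen and `add_update` skips the callable.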
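End to end, this is the behavior the updated `test_that_trainable_disables_updates`
asserts: with `run_eagerly=True`, freezing a `BatchNormalization` layer now also
stops its moving-average updates, so a train step leaves predictions unchanged.
A sketch of the same scenario, mirroring the test above against the public
`tf.keras` API:

    import numpy as np
    import tensorflow as tf

    val_a = np.random.random((10, 4)).astype('float32')
    val_out = np.random.random((10, 4)).astype('float32')

    inputs = tf.keras.layers.Input(shape=(4,))
    layer = tf.keras.layers.BatchNormalization()
    model = tf.keras.Model(inputs, layer(inputs))

    # Freeze the layer: the moving-mean/variance callables registered via
    # `add_update` are now skipped during the eager train step.
    layer.trainable = False
    model.compile('sgd', 'mse', run_eagerly=True)

    x1 = model.predict(val_a)
    model.train_on_batch(val_a, val_out)
    x2 = model.predict(val_a)
    np.testing.assert_allclose(x1, x2, atol=1e-7)  # Moving stats unchanged.

This moves the eager `trainable` check out of `BatchNormalizationBase` (the
old `in_eager_mode and not self.trainable` early-return in `_do_update`) and
into the generic `add_update`/`call_context` machinery, so any layer that
registers callable updates gets the same behavior.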