Support trainable=False for updates with run_eagerly=True.

`add_update` can now be passed a zero-arg callable that returns the update op, so the
update can be turned off by setting `trainable=False` on a Layer of a Model compiled with
`run_eagerly=True`.

PiperOrigin-RevId: 242704465
Author: Thomas O'Malley, 2019-04-09 11:21:29 -07:00 (committed by TensorFlower Gardener)
Parent: 6257855a22
Commit: fceded9809
5 changed files with 101 additions and 66 deletions
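
To make the new `add_update` contract concrete, here is a minimal sketch of a custom layer that registers its update as a zero-arg callable. The layer (`RunningMean`) and its moving-average logic are purely illustrative and not part of this change; it assumes the `tf.keras` API from the era of this commit.

import tensorflow as tf

class RunningMean(tf.keras.layers.Layer):
  """Hypothetical layer that tracks a running mean of its inputs."""

  def build(self, input_shape):
    self.mean = self.add_weight(
        'mean', shape=input_shape[-1:], initializer='zeros', trainable=False)
    super(RunningMean, self).build(input_shape)

  def call(self, inputs):
    batch_mean = tf.reduce_mean(inputs, axis=0)
    # Passing a zero-arg callable instead of the update op itself lets Keras
    # skip the update when this layer is frozen (trainable=False) and the
    # model is executing eagerly (run_eagerly=True).
    self.add_update(
        lambda: self.mean.assign(0.9 * self.mean + 0.1 * batch_mean),
        inputs=True)
    return inputs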


@@ -1169,7 +1169,7 @@ tf_py_test(
"//third_party/py/numpy",
"//tensorflow/python:client_testlib",
],
- shard_count = 16,
+ shard_count = 20,
tags = ["notsan"],
)


@@ -564,7 +564,7 @@ class Layer(trackable.Trackable):
not base_layer_utils.is_in_call_context()):
self._clear_losses()
- with base_layer_utils.call_context():
+ with base_layer_utils.call_context(self):
# Check input assumptions set after layer building, e.g. input shape.
if build_graph:
# Symbolic execution on symbolic tensors. We will attempt to build
@@ -598,6 +598,8 @@ class Layer(trackable.Trackable):
# Explicitly pass the learning phase placeholder to `call` if
# the `training` argument was left unspecified by the user.
# This behavior is restricted to the managed Keras FuncGraph.
+ # TODO(omalleyt): Reconcile this with new `trainable` behavior
+ # when available.
learning_phase_passed_by_framework = False
if (self._expects_training_arg and
not base_layer_utils.training_arg_passed_to_call(
@@ -669,6 +671,7 @@ class Layer(trackable.Trackable):
self._initial_weights is not None):
self.set_weights(self._initial_weights)
del self._initial_weights
return outputs
@property
@@ -974,7 +977,10 @@ class Layer(trackable.Trackable):
execution).
Arguments:
- updates: Update op, or list/tuple of update ops.
+ updates: Update op, or list/tuple of update ops, or zero-arg callable
+ that returns an update op. A zero-arg callable should be passed in
+ order to disable running the updates by setting `trainable=False`
+ on this Layer, when executing in Eager mode.
inputs: If anything other than None is passed, it signals the updates
are conditional on some of the layer's inputs,
and thus they should only be run where these inputs are available.
@@ -984,10 +990,20 @@
have is available at runtime.
A step counter might fall into this category.
"""
+ updates = generic_utils.to_list(updates)
if context.executing_eagerly():
+ # Don't run callable updates if currently executing inside the `call`
+ # of a Layer/Model with `trainable=False`.
+ if not base_layer_utils.is_in_frozen_context():
+ for update in updates:
+ if callable(update):
+ update()
return # Updates already applied when in eager mode.
def process_update(x):
+ if callable(x):
+ x = x()
if isinstance(x, ops.Operation):
return x
elif hasattr(x, 'op'):
@@ -995,7 +1011,6 @@
else:
return ops.convert_to_tensor(x)
- updates = generic_utils.to_list(updates)
updates = [process_update(x) for x in updates]
self._updates += updates
if inputs is None:


@@ -339,6 +339,17 @@ def is_in_call_context():
return getattr(_call_context, 'in_call', False)
+ def is_in_frozen_context():
+ """Returns if currently executing inside a `call` of a frozen Layer.
+ A Layer is considered frozen if `layer.trainable=False`.
+ Returns:
+ Whether currently inside the `call` of a frozen Layer.
+ """
+ return getattr(_call_context, 'frozen', False)
def uses_keras_history(tensors):
"""Check if at least one Tensor originates from a `keras.Input`.
@@ -397,14 +408,18 @@ def mark_checked(tensors):
@tf_contextlib.contextmanager
- def call_context():
+ def call_context(layer):
"""Scope that marks when we are currently inside a Layer/Model's `call`."""
was_in_call = is_in_call_context()
+ was_frozen = is_in_frozen_context()
_call_context.in_call = True
+ if not layer.trainable:
+ _call_context.frozen = True
try:
yield
finally:
_call_context.in_call = was_in_call
+ _call_context.frozen = was_frozen
def training_arg_passed_to_call(argspec, args, kwargs):
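
The `frozen` flag set by `call_context` above is sticky: once a frozen layer's `call` is entered, everything invoked inside it also counts as frozen until the context exits. A self-contained sketch of the pattern, with plain `threading.local` and `contextlib` standing in for TensorFlow's internals and a hypothetical `FakeLayer` used only for the demo:

import contextlib
import threading

_call_context = threading.local()

def is_in_frozen_context():
  return getattr(_call_context, 'frozen', False)

@contextlib.contextmanager
def call_context(layer):
  was_frozen = is_in_frozen_context()
  if not layer.trainable:
    _call_context.frozen = True
  try:
    yield
  finally:
    # Restore the previous state so nested calls unwind correctly.
    _call_context.frozen = was_frozen

class FakeLayer(object):
  def __init__(self, trainable):
    self.trainable = trainable

with call_context(FakeLayer(trainable=False)):
  with call_context(FakeLayer(trainable=True)):
    # A trainable sublayer inside a frozen parent is still frozen, so its
    # callable updates would be skipped by add_update in eager mode.
    assert is_in_frozen_context()
assert not is_in_frozen_context()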


@@ -662,43 +662,43 @@ class TrainingTest(keras_parameterized.TestCase):
metrics=['accuracy'],
run_eagerly=testing_utils.should_run_eagerly())
+ @keras_parameterized.run_all_keras_modes
def test_that_trainable_disables_updates(self):
val_a = np.random.random((10, 4))
val_out = np.random.random((10, 4))
- with self.cached_session():
- a = keras.layers.Input(shape=(4,))
- layer = keras.layers.BatchNormalization(input_shape=(4,))
- b = layer(a)
- model = keras.Model(a, b)
+ a = keras.layers.Input(shape=(4,))
+ layer = keras.layers.BatchNormalization(input_shape=(4,))
+ b = layer(a)
+ model = keras.Model(a, b)
- model.trainable = False
- assert not model.updates
+ model.trainable = False
+ assert not model.updates
- model.compile('sgd', 'mse')
- assert not model.updates
+ model.compile('sgd', 'mse', run_eagerly=testing_utils.should_run_eagerly())
+ assert not model.updates
- x1 = model.predict(val_a)
- model.train_on_batch(val_a, val_out)
- x2 = model.predict(val_a)
- self.assertAllClose(x1, x2, atol=1e-7)
+ x1 = model.predict(val_a)
+ model.train_on_batch(val_a, val_out)
+ x2 = model.predict(val_a)
+ self.assertAllClose(x1, x2, atol=1e-7)
- model.trainable = True
- model.compile('sgd', 'mse')
- assert model.updates
+ model.trainable = True
+ model.compile('sgd', 'mse', run_eagerly=testing_utils.should_run_eagerly())
+ assert model.updates
- model.train_on_batch(val_a, val_out)
- x2 = model.predict(val_a)
- assert np.abs(np.sum(x1 - x2)) > 1e-5
+ model.train_on_batch(val_a, val_out)
+ x2 = model.predict(val_a)
+ assert np.abs(np.sum(x1 - x2)) > 1e-5
- layer.trainable = False
- model.compile('sgd', 'mse')
- assert not model.updates
+ layer.trainable = False
+ model.compile('sgd', 'mse', run_eagerly=testing_utils.should_run_eagerly())
+ assert not model.updates
- x1 = model.predict(val_a)
- model.train_on_batch(val_a, val_out)
- x2 = model.predict(val_a)
- self.assertAllClose(x1, x2, atol=1e-7)
+ x1 = model.predict(val_a)
+ model.train_on_batch(val_a, val_out)
+ x2 = model.predict(val_a)
+ self.assertAllClose(x1, x2, atol=1e-7)
def test_logs_passed_to_callbacks(self):
with self.cached_session():


@@ -19,7 +19,6 @@ from __future__ import division
from __future__ import print_function
from tensorflow.python.distribute import distribution_strategy_context
- from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
@@ -480,17 +479,25 @@ class BatchNormalizationBase(Layer):
if training_value or training_value is None:
if distribution_strategy_context.in_cross_replica_context():
strategy = distribution_strategy_context.get_strategy()
- mean_update = strategy.extended.update(
- self.moving_mean, self._assign_moving_average,
- (mean, self.momentum))
- variance_update = strategy.extended.update(
- self.moving_variance, self._assign_moving_average,
- (variance, self.momentum))
+ def mean_update():
+ return strategy.extended.update(self.moving_mean,
+ self._assign_moving_average,
+ (mean, self.momentum))
+ def variance_update():
+ return strategy.extended.update(self.moving_variance,
+ self._assign_moving_average,
+ (variance, self.momentum))
else:
- mean_update = self._assign_moving_average(self.moving_mean, mean,
- momentum)
- variance_update = self._assign_moving_average(self.moving_variance,
- variance, momentum)
+ def mean_update():
+ return self._assign_moving_average(self.moving_mean, mean, momentum)
+ def variance_update():
+ return self._assign_moving_average(self.moving_variance, variance,
+ momentum)
self.add_update(mean_update, inputs=True)
self.add_update(variance_update, inputs=True)
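
The key change in both branches above is that the moving-average assignment is no longer built at `call` time; it is wrapped in a zero-arg function, so nothing happens until `add_update` decides to invoke it. A rough standalone illustration of that deferral (the variable, values, and `assign_moving_average` helper below are made up for the example and differ from the layer's actual `_assign_moving_average`):

import tensorflow as tf

moving_mean = tf.Variable([0.0, 0.0], trainable=False)
batch_mean = tf.constant([1.0, 2.0])

def assign_moving_average(variable, value, momentum):
  return variable.assign(momentum * variable + (1.0 - momentum) * value)

def mean_update():
  return assign_moving_average(moving_mean, batch_mean, momentum=0.99)

# Building the closure has no side effect; `moving_mean` only changes when
# the closure is actually called. A frozen layer running eagerly simply
# never calls it, so the statistics stay fixed.
mean_update()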
@@ -569,7 +576,6 @@ class BatchNormalizationBase(Layer):
if training is None:
training = K.learning_phase()
- in_eager_mode = context.executing_eagerly()
if self.virtual_batch_size is not None:
# Virtual batches (aka ghost batches) can be simulated by reshaping the
# Tensor and reusing the existing batch norm implementation
@@ -676,39 +682,38 @@ class BatchNormalizationBase(Layer):
def _do_update(var, value):
"""Compute the updates for mean and variance."""
- if in_eager_mode and not self.trainable:
- return
return strategy.extended.update(
var, self._assign_moving_average, (value, self.momentum),
group=False)
# We need to unwrap the moving_mean or moving_variance in the case of
# training being false to match the output of true_fn and false_fn
# in the smart cond.
- mean_update = tf_utils.smart_cond(
- training,
- lambda: _do_update(self.moving_mean, new_mean),
- lambda: strategy.unwrap(self.moving_mean))
- variance_update = tf_utils.smart_cond(
- training,
- lambda: _do_update(self.moving_variance, new_variance),
- lambda: strategy.unwrap(self.moving_variance))
+ def mean_update():
+ true_branch = lambda: _do_update(self.moving_mean, new_mean)
+ false_branch = lambda: strategy.unwrap(self.moving_mean)
+ return tf_utils.smart_cond(training, true_branch, false_branch)
+ def variance_update():
+ return tf_utils.smart_cond(
+ training, lambda: _do_update(self.moving_variance, new_variance),
+ lambda: strategy.unwrap(self.moving_variance))
else:
def _do_update(var, value):
"""Compute the updates for mean and variance."""
- if in_eager_mode and not self.trainable:
- return
return self._assign_moving_average(var, value, self.momentum)
- mean_update = tf_utils.smart_cond(
- training,
- lambda: _do_update(self.moving_mean, new_mean),
- lambda: self.moving_mean)
- variance_update = tf_utils.smart_cond(
- training,
- lambda: _do_update(self.moving_variance, new_variance),
- lambda: self.moving_variance)
- if not context.executing_eagerly():
- self.add_update(mean_update, inputs=True)
- self.add_update(variance_update, inputs=True)
+ def mean_update():
+ true_branch = lambda: _do_update(self.moving_mean, new_mean)
+ false_branch = lambda: self.moving_mean
+ return tf_utils.smart_cond(training, true_branch, false_branch)
+ def variance_update():
+ true_branch = lambda: _do_update(self.moving_variance, new_variance)
+ false_branch = lambda: self.moving_variance
+ return tf_utils.smart_cond(training, true_branch, false_branch)
+ self.add_update(mean_update, inputs=True)
+ self.add_update(variance_update, inputs=True)
else:
mean, variance = self.moving_mean, self.moving_variance
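
Taken together with the `add_update` and frozen-context changes, the user-visible effect when BatchNormalization executes eagerly is roughly the following. This is a sketch of the expected behavior around the time of this commit, not code from the diff:

import numpy as np
import tensorflow as tf

x = np.random.random((8, 4)).astype(np.float32)
layer = tf.keras.layers.BatchNormalization()

layer.trainable = False
layer(x, training=True)
# The moving-average update was registered as a zero-arg callable but skipped,
# because the call ran inside a frozen context, so moving_mean is unchanged.
print(layer.moving_mean.numpy())

layer.trainable = True
layer(x, training=True)
# Now the callable is executed and the moving statistics move.
print(layer.moving_mean.numpy())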