Support trainable=False for updates with run_eagerly=True.

`add_update` can now be passed a zero-arg callable that returns an update op, so that the update can be disabled by setting `trainable=False` on a Layer of a Model compiled with `run_eagerly=True`.

PiperOrigin-RevId: 242704465
commit fceded9809
parent 6257855a22
Thomas O'Malley, 2019-04-09 11:21:29 -07:00; committed by TensorFlower Gardener
5 changed files with 101 additions and 66 deletions
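
To illustrate the new API, here is a minimal sketch (an editor's example, not from the commit; the `BatchCounter` layer is hypothetical). Passing the update op itself would run it immediately under eager execution; passing a zero-arg callable instead lets the framework skip it when the layer or model is frozen:

import tensorflow as tf


class BatchCounter(tf.keras.layers.Layer):
  """Hypothetical layer that counts the batches it has processed."""

  def build(self, input_shape):
    self.seen = self.add_weight(
        name='seen', shape=(), initializer='zeros', trainable=False)
    super(BatchCounter, self).build(input_shape)

  def call(self, inputs):
    # The zero-arg callable defers the assign, so `trainable=False` on this
    # layer (or an enclosing model) can suppress it even when the model is
    # compiled with `run_eagerly=True`.
    self.add_update(lambda: self.seen.assign_add(1.))
    return inputs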


@@ -1169,7 +1169,7 @@ tf_py_test(
         "//third_party/py/numpy",
         "//tensorflow/python:client_testlib",
     ],
-    shard_count = 16,
+    shard_count = 20,
     tags = ["notsan"],
 )


@@ -564,7 +564,7 @@ class Layer(trackable.Trackable):
         not base_layer_utils.is_in_call_context()):
       self._clear_losses()

-    with base_layer_utils.call_context():
+    with base_layer_utils.call_context(self):
       # Check input assumptions set after layer building, e.g. input shape.
       if build_graph:
         # Symbolic execution on symbolic tensors. We will attempt to build
@@ -598,6 +598,8 @@ class Layer(trackable.Trackable):
         # Explicitly pass the learning phase placeholder to `call` if
         # the `training` argument was left unspecified by the user.
         # This behavior is restricted to the managed Keras FuncGraph.
+        # TODO(omalleyt): Reconcile this with new `trainable` behavior
+        # when available.
         learning_phase_passed_by_framework = False
         if (self._expects_training_arg and
             not base_layer_utils.training_arg_passed_to_call(
@@ -669,6 +671,7 @@ class Layer(trackable.Trackable):
           self._initial_weights is not None):
         self.set_weights(self._initial_weights)
         del self._initial_weights
+
     return outputs

   @property
@@ -974,7 +977,10 @@ class Layer(trackable.Trackable):
       execution).

     Arguments:
-      updates: Update op, or list/tuple of update ops.
+      updates: Update op, or list/tuple of update ops, or zero-arg callable
+        that returns an update op. A zero-arg callable should be passed in
+        order to disable running the updates by setting `trainable=False`
+        on this Layer, when executing in Eager mode.
       inputs: If anything other than None is passed, it signals the updates
         are conditional on some of the layer's inputs,
         and thus they should only be run where these inputs are available.
@@ -984,10 +990,20 @@ class Layer(trackable.Trackable):
         have is available at runtime.
         A step counter might fall into this category.
     """
+    updates = generic_utils.to_list(updates)
+
     if context.executing_eagerly():
+      # Don't run callable updates if currently executing inside the `call`
+      # of a Layer/Model with `trainable=False`.
+      if not base_layer_utils.is_in_frozen_context():
+        for update in updates:
+          if callable(update):
+            update()
       return  # Updates already applied when in eager mode.

     def process_update(x):
+      if callable(x):
+        x = x()
       if isinstance(x, ops.Operation):
         return x
       elif hasattr(x, 'op'):
@@ -995,7 +1011,6 @@ class Layer(trackable.Trackable):
       else:
         return ops.convert_to_tensor(x)

-    updates = generic_utils.to_list(updates)
     updates = [process_update(x) for x in updates]
     self._updates += updates
     if inputs is None:

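The eager-mode dispatch above can be summarized in a standalone sketch (simplified and illustrative only; the real method consults TensorFlow's eager context and the thread-local call state shown in the next file):

def add_update_sketch(updates, executing_eagerly, in_frozen_context):
  """Simplified model of the `add_update` control flow above."""
  updates = updates if isinstance(updates, (list, tuple)) else [updates]
  if executing_eagerly:
    # Eager: run deferred updates now, unless a frozen `call` is active.
    if not in_frozen_context:
      for update in updates:
        if callable(update):
          update()
    return []  # Nothing to collect; the updates were already applied.
  # Graph mode: invoke callables once so each result can be treated as an op.
  return [u() if callable(u) else u for u in updates]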

@@ -339,6 +339,17 @@ def is_in_call_context():
   return getattr(_call_context, 'in_call', False)


+def is_in_frozen_context():
+  """Returns if currently executing inside a `call` of a frozen Layer.
+
+  A Layer is considered frozen if `layer.trainable=False`.
+
+  Returns:
+    Whether currently inside the `call` of a frozen Layer.
+  """
+  return getattr(_call_context, 'frozen', False)
+
+
 def uses_keras_history(tensors):
   """Check if at least one Tensor originates from a `keras.Input`.
@@ -397,14 +408,18 @@ def mark_checked(tensors):

 @tf_contextlib.contextmanager
-def call_context():
+def call_context(layer):
   """Scope that marks when we are currently inside a Layer/Model's `call`."""
   was_in_call = is_in_call_context()
+  was_frozen = is_in_frozen_context()
   _call_context.in_call = True
+  if not layer.trainable:
+    _call_context.frozen = True
   try:
     yield
   finally:
     _call_context.in_call = was_in_call
+    _call_context.frozen = was_frozen


 def training_arg_passed_to_call(argspec, args, kwargs):

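The frozen flag rides on the same thread-local state as `in_call`; a self-contained sketch of the pattern (simplified names, no TensorFlow dependency):

import contextlib
import threading

_call_context = threading.local()


@contextlib.contextmanager
def call_context(trainable):
  """Marks a `call` scope, tracking whether any enclosing layer is frozen."""
  was_frozen = getattr(_call_context, 'frozen', False)
  if not trainable:
    _call_context.frozen = True
  try:
    yield
  finally:
    # Restore on exit so freezing only covers the dynamic extent of `call`.
    _call_context.frozen = was_frozen


# A frozen outer layer freezes every nested call; an inner trainable layer
# cannot un-freeze the scope, since the flag is only ever set, never cleared.
with call_context(trainable=False):
  with call_context(trainable=True):
    assert getattr(_call_context, 'frozen', False)
assert not getattr(_call_context, 'frozen', False)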

@@ -662,11 +662,11 @@ class TrainingTest(keras_parameterized.TestCase):
         metrics=['accuracy'],
         run_eagerly=testing_utils.should_run_eagerly())

+  @keras_parameterized.run_all_keras_modes
   def test_that_trainable_disables_updates(self):
     val_a = np.random.random((10, 4))
     val_out = np.random.random((10, 4))

-    with self.cached_session():
     a = keras.layers.Input(shape=(4,))
     layer = keras.layers.BatchNormalization(input_shape=(4,))
     b = layer(a)
@@ -675,7 +675,7 @@ class TrainingTest(keras_parameterized.TestCase):
     model.trainable = False
     assert not model.updates

-    model.compile('sgd', 'mse')
+    model.compile('sgd', 'mse', run_eagerly=testing_utils.should_run_eagerly())
     assert not model.updates

     x1 = model.predict(val_a)
@@ -684,7 +684,7 @@ class TrainingTest(keras_parameterized.TestCase):
     self.assertAllClose(x1, x2, atol=1e-7)

     model.trainable = True
-    model.compile('sgd', 'mse')
+    model.compile('sgd', 'mse', run_eagerly=testing_utils.should_run_eagerly())
     assert model.updates

     model.train_on_batch(val_a, val_out)
@@ -692,7 +692,7 @@ class TrainingTest(keras_parameterized.TestCase):
     assert np.abs(np.sum(x1 - x2)) > 1e-5

     layer.trainable = False
-    model.compile('sgd', 'mse')
+    model.compile('sgd', 'mse', run_eagerly=testing_utils.should_run_eagerly())
     assert not model.updates

     x1 = model.predict(val_a)


@@ -19,7 +19,6 @@ from __future__ import division
 from __future__ import print_function

 from tensorflow.python.distribute import distribution_strategy_context
-from tensorflow.python.eager import context
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
@@ -480,17 +479,25 @@ class BatchNormalizationBase(Layer):
     if training_value or training_value is None:
       if distribution_strategy_context.in_cross_replica_context():
         strategy = distribution_strategy_context.get_strategy()
-        mean_update = strategy.extended.update(
-            self.moving_mean, self._assign_moving_average,
-            (mean, self.momentum))
-        variance_update = strategy.extended.update(
-            self.moving_variance, self._assign_moving_average,
-            (variance, self.momentum))
+
+        def mean_update():
+          return strategy.extended.update(self.moving_mean,
+                                          self._assign_moving_average,
+                                          (mean, self.momentum))
+
+        def variance_update():
+          return strategy.extended.update(self.moving_variance,
+                                          self._assign_moving_average,
+                                          (variance, self.momentum))
       else:
-        mean_update = self._assign_moving_average(self.moving_mean, mean,
-                                                  momentum)
-        variance_update = self._assign_moving_average(self.moving_variance,
-                                                      variance, momentum)
+
+        def mean_update():
+          return self._assign_moving_average(self.moving_mean, mean, momentum)
+
+        def variance_update():
+          return self._assign_moving_average(self.moving_variance, variance,
+                                             momentum)
+
       self.add_update(mean_update, inputs=True)
       self.add_update(variance_update, inputs=True)
@@ -569,7 +576,6 @@ class BatchNormalizationBase(Layer):
     if training is None:
       training = K.learning_phase()

-    in_eager_mode = context.executing_eagerly()
     if self.virtual_batch_size is not None:
       # Virtual batches (aka ghost batches) can be simulated by reshaping the
       # Tensor and reusing the existing batch norm implementation
@@ -676,37 +682,36 @@ class BatchNormalizationBase(Layer):
         def _do_update(var, value):
           """Compute the updates for mean and variance."""
-          if in_eager_mode and not self.trainable:
-            return
           return strategy.extended.update(
               var, self._assign_moving_average, (value, self.momentum),
               group=False)
+
         # We need to unwrap the moving_mean or moving_variance in the case of
         # training being false to match the output of true_fn and false_fn
         # in the smart cond.
-        mean_update = tf_utils.smart_cond(
-            training,
-            lambda: _do_update(self.moving_mean, new_mean),
-            lambda: strategy.unwrap(self.moving_mean))
-        variance_update = tf_utils.smart_cond(
-            training,
-            lambda: _do_update(self.moving_variance, new_variance),
-            lambda: strategy.unwrap(self.moving_variance))
+        def mean_update():
+          true_branch = lambda: _do_update(self.moving_mean, new_mean)
+          false_branch = lambda: strategy.unwrap(self.moving_mean)
+          return tf_utils.smart_cond(training, true_branch, false_branch)
+
+        def variance_update():
+          return tf_utils.smart_cond(
+              training, lambda: _do_update(self.moving_variance, new_variance),
+              lambda: strategy.unwrap(self.moving_variance))
       else:
+
         def _do_update(var, value):
           """Compute the updates for mean and variance."""
-          if in_eager_mode and not self.trainable:
-            return
           return self._assign_moving_average(var, value, self.momentum)
-        mean_update = tf_utils.smart_cond(
-            training,
-            lambda: _do_update(self.moving_mean, new_mean),
-            lambda: self.moving_mean)
-        variance_update = tf_utils.smart_cond(
-            training,
-            lambda: _do_update(self.moving_variance, new_variance),
-            lambda: self.moving_variance)
-      if not context.executing_eagerly():
-        self.add_update(mean_update, inputs=True)
-        self.add_update(variance_update, inputs=True)
+
+        def mean_update():
+          true_branch = lambda: _do_update(self.moving_mean, new_mean)
+          false_branch = lambda: self.moving_mean
+          return tf_utils.smart_cond(training, true_branch, false_branch)
+
+        def variance_update():
+          true_branch = lambda: _do_update(self.moving_variance, new_variance)
+          false_branch = lambda: self.moving_variance
+          return tf_utils.smart_cond(training, true_branch, false_branch)
+
+      self.add_update(mean_update, inputs=True)
+      self.add_update(variance_update, inputs=True)
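
For reference, the update these callables defer to is an exponential moving average; a sketch, assuming `_assign_moving_average` follows the usual Keras rule `var -= (var - value) * (1 - momentum)`:

import numpy as np


def assign_moving_average(var, value, momentum):
  """var <- momentum * var + (1 - momentum) * value."""
  var -= (var - value) * (1.0 - momentum)
  return var


moving_mean = np.zeros(4)
batch_mean = np.array([1.0, 2.0, 3.0, 4.0])
moving_mean = assign_moving_average(moving_mean, batch_mean, momentum=0.99)
# Each batch nudges the moving statistic 1% of the way toward the batch value.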