Support trainable=False for updates with run_eagerly=True.

`add_update` can now be passed a zero-arg callable in order to support turning off the update when setting `trainable=False` on a Layer of a Model compiled with `run_eagerly=True`.

PiperOrigin-RevId: 242704465
commit fceded9809
parent 6257855a22
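To illustrate the behavior this change enables, here is a minimal usage sketch (not part of the diff; the `StepCounter` layer and its `steps` variable are hypothetical). A layer passes a zero-arg callable to `add_update`; when the layer sits inside a model that is frozen with `trainable = False` and compiled with `run_eagerly=True`, the update is skipped instead of being executed on every call.

import tensorflow as tf


class StepCounter(tf.keras.layers.Layer):
  """Hypothetical layer used only to illustrate the new `add_update` API."""

  def build(self, input_shape):
    self.steps = self.add_weight(
        name='steps', shape=(), initializer='zeros', trainable=False)

  def call(self, inputs):
    # Passing a zero-arg callable instead of the update op itself lets Keras
    # decide whether to run the update; it is skipped when this layer is
    # inside a frozen (`trainable=False`) layer/model executing eagerly.
    self.add_update(lambda: self.steps.assign_add(1.))
    return inputs
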
@@ -1169,7 +1169,7 @@ tf_py_test(
         "//third_party/py/numpy",
         "//tensorflow/python:client_testlib",
     ],
-    shard_count = 16,
+    shard_count = 20,
     tags = ["notsan"],
 )

@@ -564,7 +564,7 @@ class Layer(trackable.Trackable):
         not base_layer_utils.is_in_call_context()):
       self._clear_losses()

-    with base_layer_utils.call_context():
+    with base_layer_utils.call_context(self):
       # Check input assumptions set after layer building, e.g. input shape.
       if build_graph:
         # Symbolic execution on symbolic tensors. We will attempt to build
@@ -598,6 +598,8 @@ class Layer(trackable.Trackable):
         # Explicitly pass the learning phase placeholder to `call` if
         # the `training` argument was left unspecified by the user.
         # This behavior is restricted to the managed Keras FuncGraph.
+        # TODO(omalleyt): Reconcile this with new `trainable` behavior
+        # when available.
         learning_phase_passed_by_framework = False
         if (self._expects_training_arg and
             not base_layer_utils.training_arg_passed_to_call(
@@ -669,6 +671,7 @@ class Layer(trackable.Trackable):
           self._initial_weights is not None):
         self.set_weights(self._initial_weights)
         del self._initial_weights
+
     return outputs

   @property
@@ -974,7 +977,10 @@ class Layer(trackable.Trackable):
     execution).

     Arguments:
-      updates: Update op, or list/tuple of update ops.
+      updates: Update op, or list/tuple of update ops, or zero-arg callable
+        that returns an update op. A zero-arg callable should be passed in
+        order to disable running the updates by setting `trainable=False`
+        on this Layer, when executing in Eager mode.
       inputs: If anything other than None is passed, it signals the updates
         are conditional on some of the layer's inputs,
         and thus they should only be run where these inputs are available.
@@ -984,10 +990,20 @@ class Layer(trackable.Trackable):
       have is available at runtime.
       A step counter might fall into this category.
     """
+    updates = generic_utils.to_list(updates)
+
     if context.executing_eagerly():
+      # Don't run callable updates if currently executing inside the `call`
+      # of a Layer/Model with `trainable=False`.
+      if not base_layer_utils.is_in_frozen_context():
+        for update in updates:
+          if callable(update):
+            update()
       return  # Updates already applied when in eager mode.

     def process_update(x):
+      if callable(x):
+        x = x()
       if isinstance(x, ops.Operation):
         return x
       elif hasattr(x, 'op'):
@@ -995,7 +1011,6 @@ class Layer(trackable.Trackable):
       else:
         return ops.convert_to_tensor(x)

-    updates = generic_utils.to_list(updates)
     updates = [process_update(x) for x in updates]
     self._updates += updates
     if inputs is None:

@@ -339,6 +339,17 @@ def is_in_call_context():
   return getattr(_call_context, 'in_call', False)


+def is_in_frozen_context():
+  """Returns if currently executing inside a `call` of a frozen Layer.
+
+  A Layer is considered frozen if `layer.trainable=False`.
+
+  Returns:
+    Whether currently inside the `call` of a frozen Layer.
+  """
+  return getattr(_call_context, 'frozen', False)
+
+
 def uses_keras_history(tensors):
   """Check if at least one Tensor originates from a `keras.Input`.

@@ -397,14 +408,18 @@ def mark_checked(tensors):


 @tf_contextlib.contextmanager
-def call_context():
+def call_context(layer):
   """Scope that marks when we are currently inside a Layer/Model's `call`."""
   was_in_call = is_in_call_context()
+  was_frozen = is_in_frozen_context()
   _call_context.in_call = True
+  if not layer.trainable:
+    _call_context.frozen = True
   try:
     yield
   finally:
     _call_context.in_call = was_in_call
+    _call_context.frozen = was_frozen


 def training_arg_passed_to_call(argspec, args, kwargs):

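For readers unfamiliar with the module-level `_call_context` state used above, the following is a simplified standalone sketch of the pattern; the real `_call_context` in base_layer_utils.py is presumably a thread-local object, and this snippet is illustrative only, not the TensorFlow implementation.

import contextlib
import threading

_call_context = threading.local()


def is_in_frozen_context():
  # Default to False when no `call` has set the flag on this thread.
  return getattr(_call_context, 'frozen', False)


@contextlib.contextmanager
def call_context(layer):
  was_frozen = is_in_frozen_context()
  if not layer.trainable:
    _call_context.frozen = True
  try:
    yield
  finally:
    # Restore the previous state so nested `call`s unwind correctly.
    _call_context.frozen = was_frozen
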
@@ -662,43 +662,43 @@ class TrainingTest(keras_parameterized.TestCase):
         metrics=['accuracy'],
         run_eagerly=testing_utils.should_run_eagerly())

+  @keras_parameterized.run_all_keras_modes
   def test_that_trainable_disables_updates(self):
     val_a = np.random.random((10, 4))
     val_out = np.random.random((10, 4))

-    with self.cached_session():
-      a = keras.layers.Input(shape=(4,))
-      layer = keras.layers.BatchNormalization(input_shape=(4,))
-      b = layer(a)
-      model = keras.Model(a, b)
+    a = keras.layers.Input(shape=(4,))
+    layer = keras.layers.BatchNormalization(input_shape=(4,))
+    b = layer(a)
+    model = keras.Model(a, b)

-      model.trainable = False
-      assert not model.updates
+    model.trainable = False
+    assert not model.updates

-      model.compile('sgd', 'mse')
-      assert not model.updates
+    model.compile('sgd', 'mse', run_eagerly=testing_utils.should_run_eagerly())
+    assert not model.updates

-      x1 = model.predict(val_a)
-      model.train_on_batch(val_a, val_out)
-      x2 = model.predict(val_a)
-      self.assertAllClose(x1, x2, atol=1e-7)
+    x1 = model.predict(val_a)
+    model.train_on_batch(val_a, val_out)
+    x2 = model.predict(val_a)
+    self.assertAllClose(x1, x2, atol=1e-7)

-      model.trainable = True
-      model.compile('sgd', 'mse')
-      assert model.updates
+    model.trainable = True
+    model.compile('sgd', 'mse', run_eagerly=testing_utils.should_run_eagerly())
+    assert model.updates

-      model.train_on_batch(val_a, val_out)
-      x2 = model.predict(val_a)
-      assert np.abs(np.sum(x1 - x2)) > 1e-5
+    model.train_on_batch(val_a, val_out)
+    x2 = model.predict(val_a)
+    assert np.abs(np.sum(x1 - x2)) > 1e-5

-      layer.trainable = False
-      model.compile('sgd', 'mse')
-      assert not model.updates
+    layer.trainable = False
+    model.compile('sgd', 'mse', run_eagerly=testing_utils.should_run_eagerly())
+    assert not model.updates

-      x1 = model.predict(val_a)
-      model.train_on_batch(val_a, val_out)
-      x2 = model.predict(val_a)
-      self.assertAllClose(x1, x2, atol=1e-7)
+    x1 = model.predict(val_a)
+    model.train_on_batch(val_a, val_out)
+    x2 = model.predict(val_a)
+    self.assertAllClose(x1, x2, atol=1e-7)

   def test_logs_passed_to_callbacks(self):
     with self.cached_session():

@@ -19,7 +19,6 @@ from __future__ import division
 from __future__ import print_function

 from tensorflow.python.distribute import distribution_strategy_context
-from tensorflow.python.eager import context
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
@@ -480,17 +479,25 @@ class BatchNormalizationBase(Layer):
     if training_value or training_value is None:
       if distribution_strategy_context.in_cross_replica_context():
         strategy = distribution_strategy_context.get_strategy()
-        mean_update = strategy.extended.update(
-            self.moving_mean, self._assign_moving_average,
-            (mean, self.momentum))
-        variance_update = strategy.extended.update(
-            self.moving_variance, self._assign_moving_average,
-            (variance, self.momentum))
+
+        def mean_update():
+          return strategy.extended.update(self.moving_mean,
+                                          self._assign_moving_average,
+                                          (mean, self.momentum))
+
+        def variance_update():
+          return strategy.extended.update(self.moving_variance,
+                                          self._assign_moving_average,
+                                          (variance, self.momentum))
       else:
-        mean_update = self._assign_moving_average(self.moving_mean, mean,
-                                                  momentum)
-        variance_update = self._assign_moving_average(self.moving_variance,
-                                                      variance, momentum)
+
+        def mean_update():
+          return self._assign_moving_average(self.moving_mean, mean, momentum)
+
+        def variance_update():
+          return self._assign_moving_average(self.moving_variance, variance,
+                                             momentum)
+
       self.add_update(mean_update, inputs=True)
       self.add_update(variance_update, inputs=True)

@@ -569,7 +576,6 @@ class BatchNormalizationBase(Layer):
     if training is None:
       training = K.learning_phase()

-    in_eager_mode = context.executing_eagerly()
     if self.virtual_batch_size is not None:
       # Virtual batches (aka ghost batches) can be simulated by reshaping the
       # Tensor and reusing the existing batch norm implementation
@@ -676,39 +682,38 @@ class BatchNormalizationBase(Layer):

         def _do_update(var, value):
           """Compute the updates for mean and variance."""
-          if in_eager_mode and not self.trainable:
-            return
           return strategy.extended.update(
               var, self._assign_moving_average, (value, self.momentum),
               group=False)
         # We need to unwrap the moving_mean or moving_variance in the case of
         # training being false to match the output of true_fn and false_fn
         # in the smart cond.
-        mean_update = tf_utils.smart_cond(
-            training,
-            lambda: _do_update(self.moving_mean, new_mean),
-            lambda: strategy.unwrap(self.moving_mean))
-        variance_update = tf_utils.smart_cond(
-            training,
-            lambda: _do_update(self.moving_variance, new_variance),
-            lambda: strategy.unwrap(self.moving_variance))
+        def mean_update():
+          true_branch = lambda: _do_update(self.moving_mean, new_mean)
+          false_branch = lambda: strategy.unwrap(self.moving_mean)
+          return tf_utils.smart_cond(training, true_branch, false_branch)
+
+        def variance_update():
+          return tf_utils.smart_cond(
+              training, lambda: _do_update(self.moving_variance, new_variance),
+              lambda: strategy.unwrap(self.moving_variance))
       else:
         def _do_update(var, value):
           """Compute the updates for mean and variance."""
-          if in_eager_mode and not self.trainable:
-            return
           return self._assign_moving_average(var, value, self.momentum)
-        mean_update = tf_utils.smart_cond(
-            training,
-            lambda: _do_update(self.moving_mean, new_mean),
-            lambda: self.moving_mean)
-        variance_update = tf_utils.smart_cond(
-            training,
-            lambda: _do_update(self.moving_variance, new_variance),
-            lambda: self.moving_variance)
-      if not context.executing_eagerly():
-        self.add_update(mean_update, inputs=True)
-        self.add_update(variance_update, inputs=True)
+
+        def mean_update():
+          true_branch = lambda: _do_update(self.moving_mean, new_mean)
+          false_branch = lambda: self.moving_mean
+          return tf_utils.smart_cond(training, true_branch, false_branch)
+
+        def variance_update():
+          true_branch = lambda: _do_update(self.moving_variance, new_variance)
+          false_branch = lambda: self.moving_variance
+          return tf_utils.smart_cond(training, true_branch, false_branch)
+
+      self.add_update(mean_update, inputs=True)
+      self.add_update(variance_update, inputs=True)

     else:
       mean, variance = self.moving_mean, self.moving_variance

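The updates that these zero-arg functions defer are the usual exponential moving averages of the batch statistics. A minimal sketch of the pattern follows; the helper name `make_moving_average_update` is hypothetical and not part of this commit, and the decay formula shown is the standard one rather than a quote of `_assign_moving_average`.

def make_moving_average_update(variable, value, momentum):
  """Wraps a moving-average update in a zero-arg callable for `add_update`."""

  def update():
    # Standard exponential moving average:
    #   variable <- variable - (variable - value) * (1 - momentum)
    return variable.assign_sub((variable - value) * (1.0 - momentum))

  return update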