Fix issue where loss scaling was not applied in some cases.
When cloning=False was passed to Model.compile(), loss scaling was previously not applied. In general, instances of LossScaleOptimizer do not scale the loss if tf.GradientTape/tf.gradients is used instead of calling optimizer.compute_gradients. In the future, we should probably find a way to warn or error out if the user creates a LossScaleOptimizer, then uses GradientTape without scaling the loss and gradients. Also rename scale_grads to unscale_grads.

PiperOrigin-RevId: 247548487
parent fc61ca2d4f
commit 4224db98ef
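For context, the situation described in the message can be illustrated with a custom training loop. The following is a minimal sketch, not part of this commit: it assumes the module-level scale_loss/unscale_grads helpers added in the diff below, reaches into the private _loss_scale attribute exactly as the diff's TODO notes, and uses a hypothetical train_step function. It shows the manual scaling a user must do when driving training with tf.GradientTape and a LossScaleOptimizer.

# Hypothetical custom training step (not from this commit). With a plain
# tf.GradientTape, the LossScaleOptimizer does not scale the loss for you,
# so the loss must be scaled and the gradients unscaled by hand.
import tensorflow as tf
from tensorflow.python.keras.mixed_precision.experimental import loss_scale_optimizer

def train_step(model, x, y, loss_fn):
  opt = model.optimizer  # assumed to be a LossScaleOptimizer
  loss_scale = opt._loss_scale()  # private attribute; see the TODO in the diff
  with tf.GradientTape() as tape:
    loss = loss_fn(y, model(x))
    scaled_loss = loss_scale_optimizer.scale_loss(loss, loss_scale)
  grads = tape.gradient(scaled_loss, model.trainable_weights)
  grads = loss_scale_optimizer.unscale_grads(grads, loss_scale)
  opt.apply_gradients(zip(grads, model.trainable_weights))
  return loss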
@@ -197,6 +197,7 @@ py_library(
         "//tensorflow/python/eager:monitoring",
         "//tensorflow/python/keras/distribute",
         "//tensorflow/python/keras/mixed_precision/experimental:autocast_variable",
+        "//tensorflow/python/keras/mixed_precision/experimental:loss_scale_optimizer",
         "//tensorflow/python/keras/mixed_precision/experimental:policy",
         "//tensorflow/python/module",
         "//tensorflow/python/training/tracking:data_structures",
@@ -28,6 +28,7 @@ from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_util
 from tensorflow.python.keras import backend
 from tensorflow.python.keras.engine import training_utils
+from tensorflow.python.keras.mixed_precision.experimental import loss_scale_optimizer
 from tensorflow.python.keras.utils import losses_utils
 from tensorflow.python.ops import math_ops
 from tensorflow.python.platform import tf_logging as logging
@@ -234,13 +235,24 @@ def _process_single_batch(model,
       if total_loss is None:
         raise ValueError('The model cannot be run '
                          'because it has no loss to optimize.')
+      if isinstance(model.optimizer, loss_scale_optimizer.LossScaleOptimizer):
+        # TODO(reedwm): Make loss_scale public instead of accessing private
+        # _loss_scale attribute.
+        loss_scale = model.optimizer._loss_scale()
+        scaled_total_loss = loss_scale_optimizer.scale_loss(total_loss,
+                                                            loss_scale)
+      else:
+        loss_scale = None
+        scaled_total_loss = total_loss
     if training:
       if not model.trainable_weights:
         logging.warning('The list of trainable weights is empty. Make sure that'
                         ' you are not setting model.trainable to False before '
                         'compiling the model.')
       else:
-        grads = tape.gradient(total_loss, model.trainable_weights)
+        grads = tape.gradient(scaled_total_loss, model.trainable_weights)
+        if loss_scale is not None:
+          grads = loss_scale_optimizer.unscale_grads(grads, loss_scale)
         model.optimizer.apply_gradients(zip(grads,
                                             model.trainable_weights))
     return outs, total_loss, output_losses, masks
@@ -293,9 +293,14 @@ class KerasModelTest(test.TestCase, parameterized.TestCase):
       'testcase_name': 'regularizer',
       'strategy_fn': create_mirrored_strategy,
       'use_regularizer': True
+  }, {
+      'testcase_name': 'nocloning',
+      'strategy_fn': create_mirrored_strategy,
+      'cloning': False
   })
   @test_util.run_in_graph_and_eager_modes
-  def test_model(self, strategy_fn, use_operator=False, use_regularizer=False):
+  def test_model(self, strategy_fn, use_operator=False, use_regularizer=False,
+                 cloning=True):
     regularizer = IdentityRegularizer() if use_regularizer else None
     with strategy_fn().scope():
       with policy.policy_scope('infer_float32_vars'):
@@ -314,7 +319,7 @@ class KerasModelTest(test.TestCase, parameterized.TestCase):
         # the variable will not change. So this tests the learning rate not
         # applied to a float16 value, but instead the float32 variable.
         opt = gradient_descent.SGD(2 ** -14)
-        model.compile(opt, loss=loss_fn)
+        model.compile(opt, loss=loss_fn, cloning=cloning)

     self.assertEqual(backend.eval(layer.v), 1)
     x = np.ones((2, 1))
@@ -329,6 +334,53 @@ class KerasModelTest(test.TestCase, parameterized.TestCase):
       expected -= 2 ** -14
     self.assertEqual(backend.eval(layer.v), expected)

+  @parameterized.named_parameters({
+      'testcase_name': 'base',
+      'strategy_fn': default_strategy_fn
+  }, {
+      'testcase_name': 'distribute',
+      'strategy_fn': create_mirrored_strategy,
+  }, {
+      'testcase_name': 'nocloning',
+      'strategy_fn': create_mirrored_strategy,
+      'cloning': False,
+  })
+  @test_util.run_in_graph_and_eager_modes
+  def test_fixed_loss_scaling(self, strategy_fn, cloning=True):
+    # Note: We do not test mixed precision in this method, only loss scaling.
+    loss_scale = 8.
+    batch_size = 4
+    with strategy_fn().scope():
+      x = layers.Input(shape=(1,), batch_size=batch_size)
+      layer = AddLayer()
+      y = layer(x)
+
+      # The gradient of 'y' at this point is 1. With loss scaling, the gradient
+      # is 'loss_scale'. We divide by the batch size since the loss is averaged
+      # across batch elements.
+      expected_gradient = loss_scale / batch_size
+      identity_with_grad_check_fn = (
+          mp_test_util.create_identity_with_grad_check_fn([expected_gradient]))
+      y = core.Lambda(identity_with_grad_check_fn)(y)
+      model = models.Model(inputs=x, outputs=y)
+
+      def loss_fn(y_true, y_pred):
+        del y_true
+        return math_ops.reduce_mean(y_pred)
+
+      opt = gradient_descent.SGD(1.)
+      opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale)
+      model.compile(opt, loss=loss_fn, cloning=cloning)
+
+    self.assertEqual(backend.eval(layer.v), 1)
+    x = np.ones((batch_size, 1))
+    y = np.ones((batch_size, 1))
+    dataset = dataset_ops.Dataset.from_tensor_slices((x, y)).batch(batch_size)
+    model.fit(dataset)
+    # Variable starts at 1, and should have gradient of 1 subtracted from it.
+    expected = 0
+    self.assertEqual(backend.eval(layer.v), expected)
+
   @parameterized.named_parameters({
       'testcase_name': 'base',
       'strategy_fn': default_strategy_fn
@@ -413,9 +465,13 @@ class KerasModelTest(test.TestCase, parameterized.TestCase):
   }, {
       'testcase_name': 'distribute',
       'strategy_fn': create_mirrored_strategy,
+  }, {
+      'testcase_name': 'nocloning',
+      'strategy_fn': create_mirrored_strategy,
+      'cloning': False,
   })
   @test_util.run_in_graph_and_eager_modes
-  def test_dynamic_loss_scaling(self, strategy_fn):
+  def test_dynamic_loss_scaling(self, strategy_fn, cloning=True):
     strategy = strategy_fn()
     initial_loss_scale = 2.
     batch_size = 4
@@ -449,12 +505,12 @@ class KerasModelTest(test.TestCase, parameterized.TestCase):
       loss_scale = loss_scale_module.DynamicLossScale(
           initial_loss_scale=initial_loss_scale, increment_period=2)
       opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale)
-      model.compile(opt, loss=loss_fn)
+      model.compile(opt, loss=loss_fn, cloning=cloning)

     self.assertEqual(backend.eval(layer.v), 1)
-    x = np.ones((2, 1))
-    y = np.ones((2, 1))
-    dataset = dataset_ops.Dataset.from_tensor_slices((x, y)).batch(2)
+    x = np.ones((batch_size, 1))
+    y = np.ones((batch_size, 1))
+    dataset = dataset_ops.Dataset.from_tensor_slices((x, y)).batch(batch_size)
     model.fit(dataset)
     # The variables starts with 1 and has a gradient of 1, so will go down by 1
     # each step.
@@ -41,6 +41,20 @@ class _UnwrapPreventer(object):
     self.value = value


+def scale_loss(loss, loss_scale):
+  """Scales the loss by the loss scale."""
+  if callable(loss):
+    return lambda: loss() * loss_scale
+  else:
+    return loss * loss_scale
+
+
+def unscale_grads(grads, loss_scale):
+  """Unscales the gradients by the loss scale."""
+  loss_scale_reciprocal = 1. / loss_scale
+  return [g * loss_scale_reciprocal if g is not None else None for g in grads]
+
+
 @keras_export('keras.mixed_precision.experimental.LossScaleOptimizer')
 class LossScaleOptimizer(optimizer_v2.OptimizerV2):
   """An optimizer that applies loss scaling.
@@ -101,31 +115,18 @@ class LossScaleOptimizer(optimizer_v2.OptimizerV2):
     self._track_trackable(self._loss_scale, 'loss_scale')

   def _compute_gradients(self, loss, var_list, grad_loss=None):
-    loss = self._scale_loss(loss)
+    loss = scale_loss(loss, self._loss_scale())
     grads_and_vars = self._optimizer._compute_gradients(loss, var_list, # pylint: disable=protected-access
                                                         grad_loss)
     grads = [g for g, _ in grads_and_vars]
     variables = [v for _, v in grads_and_vars]
-    scaled_grads = self._scale_grads(grads)
-    return list(zip(scaled_grads, variables))
+    unscaled_grads = unscale_grads(grads, self._loss_scale())
+    return list(zip(unscaled_grads, variables))

   def get_gradients(self, loss, params):
-    loss = self._scale_loss(loss)
+    loss = scale_loss(loss, self._loss_scale())
     grads = self._optimizer.get_gradients(loss, params)
-    return self._scale_grads(grads)
-
-  def _scale_loss(self, loss):
-    # The loss is callable for `_compute_gradients`, but not `get_gradients`.
-    loss_scale = self._loss_scale()
-    if callable(loss):
-      return lambda: loss() * loss_scale
-    else:
-      return loss * loss_scale
-
-  def _scale_grads(self, grads):
-    loss_scale = self._loss_scale()
-    loss_scale_reciprocal = 1 / loss_scale
-    return [None if g is None else g * loss_scale_reciprocal for g in grads]
+    return unscale_grads(grads, self._loss_scale())

   def apply_gradients(self, grads_and_vars, name=None):
     if distribution_strategy_context.in_cross_replica_context():
@@ -119,8 +119,8 @@ class MixedPrecisionLossScaleOptimizer(optimizer.Optimizer):

     grads = [g for g, _ in grads_and_vars]
     variables = [v for _, v in grads_and_vars]
-    scaled_grads = self._scale_grads(grads)
-    return list(zip(scaled_grads, variables))
+    unscaled_grads = self._unscale_grads(grads)
+    return list(zip(unscaled_grads, variables))

   def _scale_loss(self, loss):
     loss_scale = self._loss_scale()
@@ -128,7 +128,7 @@ class MixedPrecisionLossScaleOptimizer(optimizer.Optimizer):
       return lambda: loss() * loss_scale
     return loss * loss_scale

-  def _scale_grads(self, grads):
+  def _unscale_grads(self, grads):
     loss_scale = self._loss_scale()
     loss_scale_reciprical = 1 / loss_scale
     return [