Fix issue where loss scaling was not applied in some cases.
When cloning=False was passed to Model.compile(), loss scaling was previously not applied. In general, instances of LossScaleOptimizer do not scale the loss if tf.GradientTape/tf.gradients is used instead of calling optimizer.compute_gradients. In the future, we should probably find a way to warn or error out if the user creates a LossScaleOptimizer, then uses GradientTape without scaling the loss and gradients.

Also rename scale_grads to unscale_grads.

PiperOrigin-RevId: 247548487
Parent: fc61ca2d4f
Commit: 4224db98ef
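For illustration only (not part of the diff below): a minimal sketch of the manual scaling this change wires into Keras's eager training path for cloning=False. A plain tf.GradientTape loop bypasses LossScaleOptimizer's compute_gradients, so the loss must be scaled before differentiation and the gradients unscaled before apply_gradients. The standalone setup and variable names here are assumptions for the sketch, not TensorFlow source; it runs in eager mode.

import tensorflow as tf

loss_scale = 128.0                    # a fixed loss scale, as in the new test below
v = tf.Variable(1.0)
opt = tf.keras.optimizers.SGD(1.0)

with tf.GradientTape() as tape:
  loss = 3.0 * v                      # d(loss)/dv == 3
  scaled_loss = loss * loss_scale     # what scale_loss(loss, loss_scale) does

scaled_grads = tape.gradient(scaled_loss, [v])   # [3 * loss_scale]
grads = [g / loss_scale for g in scaled_grads]   # what unscale_grads(grads, loss_scale) does
opt.apply_gradients(zip(grads, [v]))             # v: 1.0 -> -2.0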
@@ -197,6 +197,7 @@ py_library(
         "//tensorflow/python/eager:monitoring",
         "//tensorflow/python/keras/distribute",
         "//tensorflow/python/keras/mixed_precision/experimental:autocast_variable",
+        "//tensorflow/python/keras/mixed_precision/experimental:loss_scale_optimizer",
         "//tensorflow/python/keras/mixed_precision/experimental:policy",
         "//tensorflow/python/module",
         "//tensorflow/python/training/tracking:data_structures",
@@ -28,6 +28,7 @@ from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_util
 from tensorflow.python.keras import backend
 from tensorflow.python.keras.engine import training_utils
+from tensorflow.python.keras.mixed_precision.experimental import loss_scale_optimizer
 from tensorflow.python.keras.utils import losses_utils
 from tensorflow.python.ops import math_ops
 from tensorflow.python.platform import tf_logging as logging
@@ -234,13 +235,24 @@ def _process_single_batch(model,
       if total_loss is None:
         raise ValueError('The model cannot be run '
                          'because it has no loss to optimize.')
+      if isinstance(model.optimizer, loss_scale_optimizer.LossScaleOptimizer):
+        # TODO(reedwm): Make loss_scale public instead of accessing private
+        # _loss_scale attribute.
+        loss_scale = model.optimizer._loss_scale()
+        scaled_total_loss = loss_scale_optimizer.scale_loss(total_loss,
+                                                            loss_scale)
+      else:
+        loss_scale = None
+        scaled_total_loss = total_loss
     if training:
       if not model.trainable_weights:
         logging.warning('The list of trainable weights is empty. Make sure that'
                         ' you are not setting model.trainable to False before '
                         'compiling the model.')
       else:
-        grads = tape.gradient(total_loss, model.trainable_weights)
+        grads = tape.gradient(scaled_total_loss, model.trainable_weights)
+        if loss_scale is not None:
+          grads = loss_scale_optimizer.unscale_grads(grads, loss_scale)
         model.optimizer.apply_gradients(zip(grads,
                                             model.trainable_weights))
     return outs, total_loss, output_losses, masks
@@ -293,9 +293,14 @@ class KerasModelTest(test.TestCase, parameterized.TestCase):
       'testcase_name': 'regularizer',
       'strategy_fn': create_mirrored_strategy,
       'use_regularizer': True
+  }, {
+      'testcase_name': 'nocloning',
+      'strategy_fn': create_mirrored_strategy,
+      'cloning': False
   })
   @test_util.run_in_graph_and_eager_modes
-  def test_model(self, strategy_fn, use_operator=False, use_regularizer=False):
+  def test_model(self, strategy_fn, use_operator=False, use_regularizer=False,
+                 cloning=True):
     regularizer = IdentityRegularizer() if use_regularizer else None
     with strategy_fn().scope():
       with policy.policy_scope('infer_float32_vars'):
@@ -314,7 +319,7 @@ class KerasModelTest(test.TestCase, parameterized.TestCase):
         # the variable will not change. So this tests the learning rate not
         # applied to a float16 value, but instead the float32 variable.
         opt = gradient_descent.SGD(2 ** -14)
-        model.compile(opt, loss=loss_fn)
+        model.compile(opt, loss=loss_fn, cloning=cloning)
 
     self.assertEqual(backend.eval(layer.v), 1)
     x = np.ones((2, 1))
@@ -329,6 +334,53 @@ class KerasModelTest(test.TestCase, parameterized.TestCase):
       expected -= 2 ** -14
     self.assertEqual(backend.eval(layer.v), expected)
 
+  @parameterized.named_parameters({
+      'testcase_name': 'base',
+      'strategy_fn': default_strategy_fn
+  }, {
+      'testcase_name': 'distribute',
+      'strategy_fn': create_mirrored_strategy,
+  }, {
+      'testcase_name': 'nocloning',
+      'strategy_fn': create_mirrored_strategy,
+      'cloning': False,
+  })
+  @test_util.run_in_graph_and_eager_modes
+  def test_fixed_loss_scaling(self, strategy_fn, cloning=True):
+    # Note: We do not test mixed precision in this method, only loss scaling.
+    loss_scale = 8.
+    batch_size = 4
+    with strategy_fn().scope():
+      x = layers.Input(shape=(1,), batch_size=batch_size)
+      layer = AddLayer()
+      y = layer(x)
+
+      # The gradient of 'y' at this point is 1. With loss scaling, the gradient
+      # is 'loss_scale'. We divide by the batch size since the loss is averaged
+      # across batch elements.
+      expected_gradient = loss_scale / batch_size
+      identity_with_grad_check_fn = (
+          mp_test_util.create_identity_with_grad_check_fn([expected_gradient]))
+      y = core.Lambda(identity_with_grad_check_fn)(y)
+      model = models.Model(inputs=x, outputs=y)
+
+      def loss_fn(y_true, y_pred):
+        del y_true
+        return math_ops.reduce_mean(y_pred)
+
+      opt = gradient_descent.SGD(1.)
+      opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale)
+      model.compile(opt, loss=loss_fn, cloning=cloning)
+
+    self.assertEqual(backend.eval(layer.v), 1)
+    x = np.ones((batch_size, 1))
+    y = np.ones((batch_size, 1))
+    dataset = dataset_ops.Dataset.from_tensor_slices((x, y)).batch(batch_size)
+    model.fit(dataset)
+    # Variable starts at 1, and should have gradient of 1 subtracted from it.
+    expected = 0
+    self.assertEqual(backend.eval(layer.v), expected)
+
   @parameterized.named_parameters({
       'testcase_name': 'base',
       'strategy_fn': default_strategy_fn
@@ -413,9 +465,13 @@ class KerasModelTest(test.TestCase, parameterized.TestCase):
   }, {
       'testcase_name': 'distribute',
       'strategy_fn': create_mirrored_strategy,
+  }, {
+      'testcase_name': 'nocloning',
+      'strategy_fn': create_mirrored_strategy,
+      'cloning': False,
   })
   @test_util.run_in_graph_and_eager_modes
-  def test_dynamic_loss_scaling(self, strategy_fn):
+  def test_dynamic_loss_scaling(self, strategy_fn, cloning=True):
     strategy = strategy_fn()
     initial_loss_scale = 2.
     batch_size = 4
@@ -449,12 +505,12 @@ class KerasModelTest(test.TestCase, parameterized.TestCase):
       loss_scale = loss_scale_module.DynamicLossScale(
           initial_loss_scale=initial_loss_scale, increment_period=2)
       opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale)
-      model.compile(opt, loss=loss_fn)
+      model.compile(opt, loss=loss_fn, cloning=cloning)
 
     self.assertEqual(backend.eval(layer.v), 1)
-    x = np.ones((2, 1))
-    y = np.ones((2, 1))
-    dataset = dataset_ops.Dataset.from_tensor_slices((x, y)).batch(2)
+    x = np.ones((batch_size, 1))
+    y = np.ones((batch_size, 1))
+    dataset = dataset_ops.Dataset.from_tensor_slices((x, y)).batch(batch_size)
     model.fit(dataset)
     # The variables starts with 1 and has a gradient of 1, so will go down by 1
     # each step.
@@ -41,6 +41,20 @@ class _UnwrapPreventer(object):
     self.value = value
 
 
+def scale_loss(loss, loss_scale):
+  """Scales the loss by the loss scale."""
+  if callable(loss):
+    return lambda: loss() * loss_scale
+  else:
+    return loss * loss_scale
+
+
+def unscale_grads(grads, loss_scale):
+  """Unscales the gradients by the loss scale."""
+  loss_scale_reciprocal = 1. / loss_scale
+  return [g * loss_scale_reciprocal if g is not None else None for g in grads]
+
+
 @keras_export('keras.mixed_precision.experimental.LossScaleOptimizer')
 class LossScaleOptimizer(optimizer_v2.OptimizerV2):
   """An optimizer that applies loss scaling.
@@ -101,31 +115,18 @@ class LossScaleOptimizer(optimizer_v2.OptimizerV2):
     self._track_trackable(self._loss_scale, 'loss_scale')
 
   def _compute_gradients(self, loss, var_list, grad_loss=None):
-    loss = self._scale_loss(loss)
+    loss = scale_loss(loss, self._loss_scale())
     grads_and_vars = self._optimizer._compute_gradients(loss, var_list,  # pylint: disable=protected-access
                                                         grad_loss)
     grads = [g for g, _ in grads_and_vars]
     variables = [v for _, v in grads_and_vars]
-    scaled_grads = self._scale_grads(grads)
-    return list(zip(scaled_grads, variables))
+    unscaled_grads = unscale_grads(grads, self._loss_scale())
+    return list(zip(unscaled_grads, variables))
 
   def get_gradients(self, loss, params):
-    loss = self._scale_loss(loss)
+    loss = scale_loss(loss, self._loss_scale())
     grads = self._optimizer.get_gradients(loss, params)
-    return self._scale_grads(grads)
-
-  def _scale_loss(self, loss):
-    # The loss is callable for `_compute_gradients`, but not `get_gradients`.
-    loss_scale = self._loss_scale()
-    if callable(loss):
-      return lambda: loss() * loss_scale
-    else:
-      return loss * loss_scale
-
-  def _scale_grads(self, grads):
-    loss_scale = self._loss_scale()
-    loss_scale_reciprocal = 1 / loss_scale
-    return [None if g is None else g * loss_scale_reciprocal for g in grads]
+    return unscale_grads(grads, self._loss_scale())
 
   def apply_gradients(self, grads_and_vars, name=None):
     if distribution_strategy_context.in_cross_replica_context():
|
@ -119,8 +119,8 @@ class MixedPrecisionLossScaleOptimizer(optimizer.Optimizer):
|
|||||||
|
|
||||||
grads = [g for g, _ in grads_and_vars]
|
grads = [g for g, _ in grads_and_vars]
|
||||||
variables = [v for _, v in grads_and_vars]
|
variables = [v for _, v in grads_and_vars]
|
||||||
scaled_grads = self._scale_grads(grads)
|
unscaled_grads = self._unscale_grads(grads)
|
||||||
return list(zip(scaled_grads, variables))
|
return list(zip(unscaled_grads, variables))
|
||||||
|
|
||||||
def _scale_loss(self, loss):
|
def _scale_loss(self, loss):
|
||||||
loss_scale = self._loss_scale()
|
loss_scale = self._loss_scale()
|
||||||
@@ -128,7 +128,7 @@ class MixedPrecisionLossScaleOptimizer(optimizer.Optimizer):
       return lambda: loss() * loss_scale
     return loss * loss_scale
 
-  def _scale_grads(self, grads):
+  def _unscale_grads(self, grads):
     loss_scale = self._loss_scale()
     loss_scale_reciprical = 1 / loss_scale
     return [
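Usage note (illustrative, not part of the diff): the module-level helpers added in loss_scale_optimizer are meant to round-trip — scale the loss by loss_scale, then divide the gradients of the scaled loss by the same factor to recover the unscaled gradients, passing None entries through. A minimal sketch, assuming only the module path already shown in the imports above; the concrete values are made up:

from tensorflow.python.keras.mixed_precision.experimental import loss_scale_optimizer

loss_scale = 8.
scaled_loss_fn = loss_scale_optimizer.scale_loss(lambda: 2.0, loss_scale)
assert scaled_loss_fn() == 16.0        # callable losses are wrapped, not evaluated

grads = [4.0 * loss_scale, None]       # gradients of the *scaled* loss
unscaled = loss_scale_optimizer.unscale_grads(grads, loss_scale)
assert unscaled == [4.0, None]         # None gradients are preserved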
|
Loading…
Reference in New Issue
Block a user