Add methods to LossScaleOptimizer to manually do loss scaling.
This can be used for users who use a GradientTape to compute gradients instead of LossScaleOptimizer.minimize. PiperOrigin-RevId: 252511806
This commit is contained in:
parent
9f3cbdf0bc
commit
9912d064b1
@ -240,13 +240,8 @@ def _process_single_batch(model,
|
||||
raise ValueError('The model cannot be run '
|
||||
'because it has no loss to optimize.')
|
||||
if isinstance(model.optimizer, loss_scale_optimizer.LossScaleOptimizer):
|
||||
# TODO(reedwm): Make loss_scale public instead of accessing private
|
||||
# _loss_scale attribute.
|
||||
loss_scale = model.optimizer._loss_scale()
|
||||
scaled_total_loss = loss_scale_optimizer.scale_loss(total_loss,
|
||||
loss_scale)
|
||||
scaled_total_loss = model.optimizer.get_scaled_loss(total_loss)
|
||||
else:
|
||||
loss_scale = None
|
||||
scaled_total_loss = total_loss
|
||||
if training:
|
||||
if not model.trainable_weights:
|
||||
@ -255,8 +250,8 @@ def _process_single_batch(model,
|
||||
'compiling the model.')
|
||||
else:
|
||||
grads = tape.gradient(scaled_total_loss, model.trainable_weights)
|
||||
if loss_scale is not None:
|
||||
grads = loss_scale_optimizer.unscale_grads(grads, loss_scale)
|
||||
if isinstance(model.optimizer, loss_scale_optimizer.LossScaleOptimizer):
|
||||
grads = model.optimizer.get_unscaled_gradients(grads)
|
||||
model.optimizer.apply_gradients(zip(grads, model.trainable_weights))
|
||||
model._set_trainable_state(current_trainable_state)
|
||||
return outs, total_loss, output_losses, masks
|
||||
|
@ -42,20 +42,6 @@ class _UnwrapPreventer(object):
|
||||
self.value = value
|
||||
|
||||
|
||||
def scale_loss(loss, loss_scale):
|
||||
"""Scales the loss by the loss scale."""
|
||||
if callable(loss):
|
||||
return lambda: loss() * loss_scale
|
||||
else:
|
||||
return loss * loss_scale
|
||||
|
||||
|
||||
def unscale_grads(grads, loss_scale):
|
||||
"""Unscales the gradients by the loss scale."""
|
||||
loss_scale_reciprocal = 1. / loss_scale
|
||||
return [g * loss_scale_reciprocal if g is not None else None for g in grads]
|
||||
|
||||
|
||||
@keras_export('keras.mixed_precision.experimental.LossScaleOptimizer')
|
||||
class LossScaleOptimizer(optimizer_v2.OptimizerV2):
|
||||
"""An optimizer that applies loss scaling.
|
||||
@ -83,7 +69,34 @@ class LossScaleOptimizer(optimizer_v2.OptimizerV2):
|
||||
|
||||
This optimizer wraps another optimizer and applies loss scaling to it via a
|
||||
`LossScale`. Loss scaling is applied whenever gradients are
|
||||
computed, either through `minimize()` or `get_gradients()`.
|
||||
computed, either through `minimize()` or `get_gradients()`. The loss scale is
|
||||
updated via `LossScale.update()` whenever gradients are applied, either
|
||||
through `minimize()` or `apply_gradients()`. For example:
|
||||
|
||||
```python
|
||||
opt = tf.keras.optimizers.SGD(0.1)
|
||||
opt = tf.keras.mixed_precision.experimental.LossScaleOptimizer(opt, "dynamic")
|
||||
# 'minimize' applies loss scaling to the loss and updates the loss scale.
|
||||
opt.minimize(loss_fn)
|
||||
```
|
||||
|
||||
If a `tf.GradientTape` is used to compute gradients instead of
|
||||
`LossScaleOptimizer.minimize` or `LossScaleOptimizer.get_gradients`, the loss
|
||||
and gradients must be scaled manually. This can be done by calling
|
||||
`LossScaleOptimizer.get_scaled_loss` before passing the loss to
|
||||
`tf.GradientTape`, and `LossScaleOptimizer.get_unscaled_gradients` after
|
||||
computing the gradients with `tf.GradientTape`. For example:
|
||||
|
||||
```python
|
||||
opt = tf.keras.mixed_precision.experimental.LossScaleOptimizer(...)
|
||||
vars = ...
|
||||
with tf.GradientTape() as tape:
|
||||
loss = ...
|
||||
scaled_loss = opt.get_scaled_loss(loss)
|
||||
scaled_grads = tape.gradient(scaled_loss, vars)
|
||||
grads = opt.get_unscaled_gradients(scaled_grads)
|
||||
opt.apply_gradients(zip(grads, vars)) # Loss scale will be updated here
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self, opt, loss_scale):
|
||||
@ -123,19 +136,75 @@ class LossScaleOptimizer(optimizer_v2.OptimizerV2):
|
||||
self._track_trackable(self._optimizer, 'base_optimizer')
|
||||
self._track_trackable(self._loss_scale, 'loss_scale')
|
||||
|
||||
@property
|
||||
def loss_scale(self):
|
||||
"""The `LossScale` instance associated with this optimizer."""
|
||||
return self._loss_scale
|
||||
|
||||
def get_scaled_loss(self, loss):
|
||||
"""Scales the loss by the loss scale.
|
||||
|
||||
This method is only needed if you compute gradients manually, e.g. with
|
||||
`tf.GradientTape`. In that case, call this method to scale the loss before
|
||||
passing the loss to `tf.GradientTape`. If you use
|
||||
`LossScaleOptimizer.minimize` or `LossScaleOptimizer.get_gradients`, loss
|
||||
scaling is automatically applied and this method is unneeded.
|
||||
|
||||
If this method is called, `get_unscaled_gradients` should also be called.
|
||||
See the `tf.keras.mixed_precision.experimental.LossScaleOptimizer` doc for
|
||||
an example.
|
||||
|
||||
Args:
|
||||
loss: The loss, which will be multiplied by the loss scale. Can either be
|
||||
a tensor or a callable returning a tensor.
|
||||
|
||||
Returns:
|
||||
`loss` multiplied by `LossScaleOptimizer.loss_scale()`.
|
||||
"""
|
||||
loss_scale = self._loss_scale()
|
||||
if callable(loss):
|
||||
return lambda: loss() * loss_scale
|
||||
else:
|
||||
return loss * loss_scale
|
||||
|
||||
def get_unscaled_gradients(self, grads):
|
||||
"""Unscales the gradients by the loss scale.
|
||||
|
||||
This method is only needed if you compute gradients manually, e.g. with
|
||||
`tf.GradientTape`. In that case, call this method to unscale the gradients
|
||||
after computing them with `tf.GradientTape`. If you use
|
||||
`LossScaleOptimizer.minimize` or `LossScaleOptimizer.get_gradients`, loss
|
||||
scaling is automatically applied and this method is unneeded.
|
||||
|
||||
If this method is called, `get_scaled_loss` should also be called. See
|
||||
the `tf.keras.mixed_precision.experimental.LossScaleOptimizer` doc for an
|
||||
example.
|
||||
|
||||
Args:
|
||||
grads: A list of tensors, each which will be divided by the loss scale.
|
||||
Can have None values, which are ignored.
|
||||
|
||||
Returns:
|
||||
A new list the same size as `grads`, where every non-None value in `grads`
|
||||
is divided by `LossScaleOptimizer.loss_scale()`.
|
||||
"""
|
||||
loss_scale = self._loss_scale()
|
||||
loss_scale_reciprocal = 1. / loss_scale
|
||||
return [g * loss_scale_reciprocal if g is not None else None for g in grads]
|
||||
|
||||
def _compute_gradients(self, loss, var_list, grad_loss=None):
|
||||
loss = scale_loss(loss, self._loss_scale())
|
||||
loss = self.get_scaled_loss(loss)
|
||||
grads_and_vars = self._optimizer._compute_gradients(loss, var_list, # pylint: disable=protected-access
|
||||
grad_loss)
|
||||
grads = [g for g, _ in grads_and_vars]
|
||||
variables = [v for _, v in grads_and_vars]
|
||||
unscaled_grads = unscale_grads(grads, self._loss_scale())
|
||||
unscaled_grads = self.get_unscaled_gradients(grads)
|
||||
return list(zip(unscaled_grads, variables))
|
||||
|
||||
def get_gradients(self, loss, params):
|
||||
loss = scale_loss(loss, self._loss_scale())
|
||||
loss = self.get_scaled_loss(loss)
|
||||
grads = self._optimizer.get_gradients(loss, params)
|
||||
return unscale_grads(grads, self._loss_scale())
|
||||
return self.get_unscaled_gradients(grads)
|
||||
|
||||
def apply_gradients(self, grads_and_vars, name=None):
|
||||
if distribution_strategy_context.in_cross_replica_context():
|
||||
|
@ -109,6 +109,20 @@ class LossScaleOptimizerTest(test.TestCase, parameterized.TestCase):
|
||||
# mp_test_util.create_identity_with_grad_check_fn added an assertion op.
|
||||
self.evaluate(run_op)
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def testGetScaledLoss(self):
|
||||
opt = gradient_descent.SGD(2.0)
|
||||
opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale=2.)
|
||||
self.assertEqual(10., self.evaluate(opt.get_scaled_loss(5.)))
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def testGetUnscaledGradients(self):
|
||||
opt = gradient_descent.SGD(2.0)
|
||||
opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale=2)
|
||||
grads = opt.get_unscaled_gradients([3., None, -4.])
|
||||
grads = [self.evaluate(g) if g is not None else g for g in grads]
|
||||
self.assertEqual([1.5, None, -2.], grads)
|
||||
|
||||
@parameterized.named_parameters(*TESTCASES)
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def testDynamicLossScale(self, strategy_fn):
|
||||
@ -162,7 +176,7 @@ class LossScaleOptimizerTest(test.TestCase, parameterized.TestCase):
|
||||
# Gradient is 2, so variable will have 2 subtracted from it
|
||||
self.assertAllClose([-1.0, 0.0], self.evaluate(var))
|
||||
# Loss scale has doubled from 2 to 4
|
||||
self.assertEqual(4., self.evaluate(opt._loss_scale()))
|
||||
self.assertEqual(4., self.evaluate(opt.loss_scale()))
|
||||
|
||||
# Test optimizer with NaN gradients
|
||||
loss = lambda: var * float('NaN')
|
||||
@ -172,7 +186,7 @@ class LossScaleOptimizerTest(test.TestCase, parameterized.TestCase):
|
||||
# Variable should not change from before, due to NaN gradients.
|
||||
self.assertAllClose(self.evaluate(var), [-1.0, 0.0])
|
||||
# Loss scale should be halved due to NaN gradients.
|
||||
self.assertEqual(2., self.evaluate(opt._loss_scale()))
|
||||
self.assertEqual(2., self.evaluate(opt.loss_scale()))
|
||||
|
||||
@parameterized.named_parameters(*TESTCASES)
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
@ -196,7 +210,7 @@ class LossScaleOptimizerTest(test.TestCase, parameterized.TestCase):
|
||||
# variable is subtracted by the accumulator, so the variable is subtracted
|
||||
# by 1.
|
||||
self.assertAllClose([0.0, 1.0], self.evaluate(var))
|
||||
self.assertEqual(self.evaluate(opt._loss_scale()), initial_loss_scale * 4)
|
||||
self.assertEqual(self.evaluate(opt.loss_scale()), initial_loss_scale * 4)
|
||||
|
||||
run_op = strategy.experimental_run(run_fn)
|
||||
self._run_if_in_graph_mode(run_op)
|
||||
@ -205,7 +219,7 @@ class LossScaleOptimizerTest(test.TestCase, parameterized.TestCase):
|
||||
# variable is subtracted by the accumulator, so the variable is subtracted
|
||||
# by 2.
|
||||
self.assertAllClose([-2., -1.], self.evaluate(var))
|
||||
self.assertEqual(self.evaluate(opt._loss_scale()),
|
||||
self.assertEqual(self.evaluate(opt.loss_scale()),
|
||||
initial_loss_scale * 16)
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
|
@ -12,6 +12,10 @@ tf_class {
|
||||
name: "learning_rate"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "loss_scale"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "lr"
|
||||
mtype: "<type \'property\'>"
|
||||
@ -48,6 +52,10 @@ tf_class {
|
||||
name: "get_gradients"
|
||||
argspec: "args=[\'self\', \'loss\', \'params\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "get_scaled_loss"
|
||||
argspec: "args=[\'self\', \'loss\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "get_slot"
|
||||
argspec: "args=[\'self\', \'var\', \'slot_name\'], varargs=None, keywords=None, defaults=None"
|
||||
@ -56,6 +64,10 @@ tf_class {
|
||||
name: "get_slot_names"
|
||||
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "get_unscaled_gradients"
|
||||
argspec: "args=[\'self\', \'grads\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "get_updates"
|
||||
argspec: "args=[\'self\', \'loss\', \'params\'], varargs=None, keywords=None, defaults=None"
|
||||
|
@ -12,6 +12,10 @@ tf_class {
|
||||
name: "learning_rate"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "loss_scale"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "lr"
|
||||
mtype: "<type \'property\'>"
|
||||
@ -48,6 +52,10 @@ tf_class {
|
||||
name: "get_gradients"
|
||||
argspec: "args=[\'self\', \'loss\', \'params\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "get_scaled_loss"
|
||||
argspec: "args=[\'self\', \'loss\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "get_slot"
|
||||
argspec: "args=[\'self\', \'var\', \'slot_name\'], varargs=None, keywords=None, defaults=None"
|
||||
@ -56,6 +64,10 @@ tf_class {
|
||||
name: "get_slot_names"
|
||||
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "get_unscaled_gradients"
|
||||
argspec: "args=[\'self\', \'grads\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "get_updates"
|
||||
argspec: "args=[\'self\', \'loss\', \'params\'], varargs=None, keywords=None, defaults=None"
|
||||
|
Loading…
x
Reference in New Issue
Block a user