Add methods to LossScaleOptimizer to manually do loss scaling.

This is useful for users who compute gradients with a GradientTape instead of calling LossScaleOptimizer.minimize.
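A minimal sketch of the usage this enables (mirroring the example added to the class docstring below; the variable and loss here are only illustrative):

```python
import tensorflow as tf

opt = tf.keras.optimizers.SGD(0.1)
opt = tf.keras.mixed_precision.experimental.LossScaleOptimizer(opt, "dynamic")
var = tf.Variable([1.0, 2.0])

with tf.GradientTape() as tape:
  loss = tf.reduce_sum(var * var)
  # Scale the loss so that small gradients do not underflow in float16.
  scaled_loss = opt.get_scaled_loss(loss)
scaled_grads = tape.gradient(scaled_loss, [var])
# Undo the scaling before the gradients are applied.
grads = opt.get_unscaled_gradients(scaled_grads)
opt.apply_gradients(zip(grads, [var]))  # The loss scale is updated here.
```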

PiperOrigin-RevId: 252511806
Reed Wanderman-Milne 2019-06-10 16:39:29 -07:00 committed by TensorFlower Gardener
parent 9f3cbdf0bc
commit 9912d064b1
5 changed files with 133 additions and 31 deletions


@@ -240,13 +240,8 @@ def _process_single_batch(model,
raise ValueError('The model cannot be run '
'because it has no loss to optimize.')
if isinstance(model.optimizer, loss_scale_optimizer.LossScaleOptimizer):
# TODO(reedwm): Make loss_scale public instead of accessing private
# _loss_scale attribute.
loss_scale = model.optimizer._loss_scale()
scaled_total_loss = loss_scale_optimizer.scale_loss(total_loss,
loss_scale)
scaled_total_loss = model.optimizer.get_scaled_loss(total_loss)
else:
loss_scale = None
scaled_total_loss = total_loss
if training:
if not model.trainable_weights:
@@ -255,8 +250,8 @@ def _process_single_batch(model,
'compiling the model.')
else:
grads = tape.gradient(scaled_total_loss, model.trainable_weights)
if loss_scale is not None:
grads = loss_scale_optimizer.unscale_grads(grads, loss_scale)
if isinstance(model.optimizer, loss_scale_optimizer.LossScaleOptimizer):
grads = model.optimizer.get_unscaled_gradients(grads)
model.optimizer.apply_gradients(zip(grads, model.trainable_weights))
model._set_trainable_state(current_trainable_state)
return outs, total_loss, output_losses, masks
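In isolation, the refactored flow above amounts to the following (a simplified sketch with an illustrative helper name, not the actual Keras engine code):

```python
from tensorflow.python.keras.mixed_precision.experimental import loss_scale_optimizer


def _compute_unscaled_grads(model, total_loss, tape):
  """Sketch: scale the loss, differentiate, then unscale the gradients."""
  is_lso = isinstance(model.optimizer, loss_scale_optimizer.LossScaleOptimizer)
  scaled_total_loss = (model.optimizer.get_scaled_loss(total_loss)
                       if is_lso else total_loss)
  grads = tape.gradient(scaled_total_loss, model.trainable_weights)
  if is_lso:
    # Gradients were computed from the scaled loss, so divide the scale back out.
    grads = model.optimizer.get_unscaled_gradients(grads)
  return grads
```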


@@ -42,20 +42,6 @@ class _UnwrapPreventer(object):
self.value = value
def scale_loss(loss, loss_scale):
"""Scales the loss by the loss scale."""
if callable(loss):
return lambda: loss() * loss_scale
else:
return loss * loss_scale
def unscale_grads(grads, loss_scale):
"""Unscales the gradients by the loss scale."""
loss_scale_reciprocal = 1. / loss_scale
return [g * loss_scale_reciprocal if g is not None else None for g in grads]
@keras_export('keras.mixed_precision.experimental.LossScaleOptimizer')
class LossScaleOptimizer(optimizer_v2.OptimizerV2):
"""An optimizer that applies loss scaling.
@@ -83,7 +69,34 @@ class LossScaleOptimizer(optimizer_v2.OptimizerV2):
This optimizer wraps another optimizer and applies loss scaling to it via a
`LossScale`. Loss scaling is applied whenever gradients are
computed, either through `minimize()` or `get_gradients()`.
computed, either through `minimize()` or `get_gradients()`. The loss scale is
updated via `LossScale.update()` whenever gradients are applied, either
through `minimize()` or `apply_gradients()`. For example:
```python
opt = tf.keras.optimizers.SGD(0.1)
opt = tf.keras.mixed_precision.experimental.LossScaleOptimizer(opt, "dynamic")
# 'minimize' applies loss scaling to the loss and updates the loss scale.
opt.minimize(loss_fn)
```
If a `tf.GradientTape` is used to compute gradients instead of
`LossScaleOptimizer.minimize` or `LossScaleOptimizer.get_gradients`, the loss
and gradients must be scaled manually. This can be done by calling
`LossScaleOptimizer.get_scaled_loss` before passing the loss to
`tf.GradientTape`, and `LossScaleOptimizer.get_unscaled_gradients` after
computing the gradients with `tf.GradientTape`. For example:
```python
opt = tf.keras.mixed_precision.experimental.LossScaleOptimizer(...)
vars = ...
with tf.GradientTape() as tape:
loss = ...
scaled_loss = opt.get_scaled_loss(loss)
scaled_grads = tape.gradient(scaled_loss, vars)
grads = opt.get_unscaled_gradients(scaled_grads)
opt.apply_gradients(zip(grads, vars)) # Loss scale will be updated here
```
"""
def __init__(self, opt, loss_scale):
@@ -123,19 +136,75 @@ class LossScaleOptimizer(optimizer_v2.OptimizerV2):
self._track_trackable(self._optimizer, 'base_optimizer')
self._track_trackable(self._loss_scale, 'loss_scale')
@property
def loss_scale(self):
"""The `LossScale` instance associated with this optimizer."""
return self._loss_scale
def get_scaled_loss(self, loss):
"""Scales the loss by the loss scale.
This method is only needed if you compute gradients manually, e.g. with
`tf.GradientTape`. In that case, call this method to scale the loss before
passing the loss to `tf.GradientTape`. If you use
`LossScaleOptimizer.minimize` or `LossScaleOptimizer.get_gradients`, loss
scaling is automatically applied and this method is unneeded.
If this method is called, `get_unscaled_gradients` should also be called.
See the `tf.keras.mixed_precision.experimental.LossScaleOptimizer` doc for
an example.
Args:
loss: The loss, which will be multiplied by the loss scale. Can either be
a tensor or a callable returning a tensor.
Returns:
`loss` multiplied by `LossScaleOptimizer.loss_scale()`.
"""
loss_scale = self._loss_scale()
if callable(loss):
return lambda: loss() * loss_scale
else:
return loss * loss_scale
def get_unscaled_gradients(self, grads):
"""Unscales the gradients by the loss scale.
This method is only needed if you compute gradients manually, e.g. with
`tf.GradientTape`. In that case, call this method to unscale the gradients
after computing them with `tf.GradientTape`. If you use
`LossScaleOptimizer.minimize` or `LossScaleOptimizer.get_gradients`, loss
scaling is automatically applied and this method is unneeded.
If this method is called, `get_scaled_loss` should also be called. See
the `tf.keras.mixed_precision.experimental.LossScaleOptimizer` doc for an
example.
Args:
grads: A list of tensors, each of which will be divided by the loss scale.
Can have None values, which are ignored.
Returns:
A new list of the same size as `grads`, where every non-None value in `grads`
is divided by `LossScaleOptimizer.loss_scale()`.
"""
loss_scale = self._loss_scale()
loss_scale_reciprocal = 1. / loss_scale
return [g * loss_scale_reciprocal if g is not None else None for g in grads]
def _compute_gradients(self, loss, var_list, grad_loss=None):
loss = scale_loss(loss, self._loss_scale())
loss = self.get_scaled_loss(loss)
grads_and_vars = self._optimizer._compute_gradients(loss, var_list, # pylint: disable=protected-access
grad_loss)
grads = [g for g, _ in grads_and_vars]
variables = [v for _, v in grads_and_vars]
unscaled_grads = unscale_grads(grads, self._loss_scale())
unscaled_grads = self.get_unscaled_gradients(grads)
return list(zip(unscaled_grads, variables))
def get_gradients(self, loss, params):
loss = scale_loss(loss, self._loss_scale())
loss = self.get_scaled_loss(loss)
grads = self._optimizer.get_gradients(loss, params)
return unscale_grads(grads, self._loss_scale())
return self.get_unscaled_gradients(grads)
def apply_gradients(self, grads_and_vars, name=None):
if distribution_strategy_context.in_cross_replica_context():
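`get_scaled_loss` also accepts a callable, which is the form `minimize()` passes through `_compute_gradients`; a small sketch of that path with a fixed loss scale of 8 (the loss function here is only illustrative):

```python
import tensorflow as tf

opt = tf.keras.mixed_precision.experimental.LossScaleOptimizer(
    tf.keras.optimizers.SGD(1.0), loss_scale=8.)
var = tf.Variable(2.0)

loss_fn = lambda: 3. * var                     # callable loss, as used by minimize()
scaled_loss_fn = opt.get_scaled_loss(loss_fn)  # wraps the callable; nothing runs yet
print(scaled_loss_fn())                        # 3. * 2. * 8. = 48.0

scaled_grads = [tf.constant(24.)]              # gradient of the scaled loss w.r.t. var
print(opt.get_unscaled_gradients(scaled_grads))  # [3.0], the true gradient
```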


@@ -109,6 +109,20 @@ class LossScaleOptimizerTest(test.TestCase, parameterized.TestCase):
# mp_test_util.create_identity_with_grad_check_fn added an assertion op.
self.evaluate(run_op)
@test_util.run_in_graph_and_eager_modes
def testGetScaledLoss(self):
opt = gradient_descent.SGD(2.0)
opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale=2.)
self.assertEqual(10., self.evaluate(opt.get_scaled_loss(5.)))
@test_util.run_in_graph_and_eager_modes
def testGetUnscaledGradients(self):
opt = gradient_descent.SGD(2.0)
opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale=2)
grads = opt.get_unscaled_gradients([3., None, -4.])
grads = [self.evaluate(g) if g is not None else g for g in grads]
self.assertEqual([1.5, None, -2.], grads)
@parameterized.named_parameters(*TESTCASES)
@test_util.run_in_graph_and_eager_modes
def testDynamicLossScale(self, strategy_fn):
@@ -162,7 +176,7 @@ class LossScaleOptimizerTest(test.TestCase, parameterized.TestCase):
# Gradient is 2, so variable will have 2 subtracted from it
self.assertAllClose([-1.0, 0.0], self.evaluate(var))
# Loss scale has doubled from 2 to 4
self.assertEqual(4., self.evaluate(opt._loss_scale()))
self.assertEqual(4., self.evaluate(opt.loss_scale()))
# Test optimizer with NaN gradients
loss = lambda: var * float('NaN')
@@ -172,7 +186,7 @@ class LossScaleOptimizerTest(test.TestCase, parameterized.TestCase):
# Variable should not change from before, due to NaN gradients.
self.assertAllClose(self.evaluate(var), [-1.0, 0.0])
# Loss scale should half due to NaN gradients.
self.assertEqual(2., self.evaluate(opt._loss_scale()))
self.assertEqual(2., self.evaluate(opt.loss_scale()))
@parameterized.named_parameters(*TESTCASES)
@test_util.run_in_graph_and_eager_modes
@@ -196,7 +210,7 @@ class LossScaleOptimizerTest(test.TestCase, parameterized.TestCase):
# variable is subtracted by the accumulator, so the variable is subtracted
# by 1.
self.assertAllClose([0.0, 1.0], self.evaluate(var))
self.assertEqual(self.evaluate(opt._loss_scale()), initial_loss_scale * 4)
self.assertEqual(self.evaluate(opt.loss_scale()), initial_loss_scale * 4)
run_op = strategy.experimental_run(run_fn)
self._run_if_in_graph_mode(run_op)
@@ -205,7 +219,7 @@ class LossScaleOptimizerTest(test.TestCase, parameterized.TestCase):
# variable is subtracted by the accumulator, so the variable is subtracted
# by 2.
self.assertAllClose([-2., -1.], self.evaluate(var))
self.assertEqual(self.evaluate(opt._loss_scale()),
self.assertEqual(self.evaluate(opt.loss_scale()),
initial_loss_scale * 16)
@test_util.run_in_graph_and_eager_modes


@@ -12,6 +12,10 @@ tf_class {
name: "learning_rate"
mtype: "<type \'property\'>"
}
member {
name: "loss_scale"
mtype: "<type \'property\'>"
}
member {
name: "lr"
mtype: "<type \'property\'>"
@@ -48,6 +52,10 @@ tf_class {
name: "get_gradients"
argspec: "args=[\'self\', \'loss\', \'params\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_scaled_loss"
argspec: "args=[\'self\', \'loss\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_slot"
argspec: "args=[\'self\', \'var\', \'slot_name\'], varargs=None, keywords=None, defaults=None"
@@ -56,6 +64,10 @@ tf_class {
name: "get_slot_names"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_unscaled_gradients"
argspec: "args=[\'self\', \'grads\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_updates"
argspec: "args=[\'self\', \'loss\', \'params\'], varargs=None, keywords=None, defaults=None"


@@ -12,6 +12,10 @@ tf_class {
name: "learning_rate"
mtype: "<type \'property\'>"
}
member {
name: "loss_scale"
mtype: "<type \'property\'>"
}
member {
name: "lr"
mtype: "<type \'property\'>"
@@ -48,6 +52,10 @@ tf_class {
name: "get_gradients"
argspec: "args=[\'self\', \'loss\', \'params\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_scaled_loss"
argspec: "args=[\'self\', \'loss\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_slot"
argspec: "args=[\'self\', \'var\', \'slot_name\'], varargs=None, keywords=None, defaults=None"
@@ -56,6 +64,10 @@ tf_class {
name: "get_slot_names"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_unscaled_gradients"
argspec: "args=[\'self\', \'grads\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_updates"
argspec: "args=[\'self\', \'loss\', \'params\'], varargs=None, keywords=None, defaults=None"