Add methods to LossScaleOptimizer to manually do loss scaling.
These methods are useful for users who compute gradients with a GradientTape instead of calling LossScaleOptimizer.minimize.

PiperOrigin-RevId: 252511806
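A minimal end-to-end sketch of the new manual workflow, following the docstring example added in this change. The tiny model, data, and loss function below are illustrative placeholders, not part of the commit, and it assumes the eager-mode, `experimental` mixed precision API of this era:

```python
import tensorflow as tf

# Illustrative model and data; any Keras model works here.
model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
loss_fn = tf.keras.losses.MeanSquaredError()
x = tf.random.normal([8, 4])
y = tf.random.normal([8, 1])

opt = tf.keras.optimizers.SGD(0.1)
opt = tf.keras.mixed_precision.experimental.LossScaleOptimizer(opt, "dynamic")

with tf.GradientTape() as tape:
  loss = loss_fn(y, model(x))
  # Scale inside the tape so the scaled loss is what gets differentiated.
  scaled_loss = opt.get_scaled_loss(loss)
scaled_grads = tape.gradient(scaled_loss, model.trainable_weights)
# Undo the scaling before applying; apply_gradients also updates the loss scale.
grads = opt.get_unscaled_gradients(scaled_grads)
opt.apply_gradients(zip(grads, model.trainable_weights))
```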
commit 9912d064b1
parent 9f3cbdf0bc
@@ -240,13 +240,8 @@ def _process_single_batch(model,
         raise ValueError('The model cannot be run '
                          'because it has no loss to optimize.')
       if isinstance(model.optimizer, loss_scale_optimizer.LossScaleOptimizer):
-        # TODO(reedwm): Make loss_scale public instead of accessing private
-        # _loss_scale attribute.
-        loss_scale = model.optimizer._loss_scale()
-        scaled_total_loss = loss_scale_optimizer.scale_loss(total_loss,
-                                                            loss_scale)
+        scaled_total_loss = model.optimizer.get_scaled_loss(total_loss)
       else:
-        loss_scale = None
         scaled_total_loss = total_loss
     if training:
       if not model.trainable_weights:
@@ -255,8 +250,8 @@ def _process_single_batch(model,
                         'compiling the model.')
       else:
         grads = tape.gradient(scaled_total_loss, model.trainable_weights)
-        if loss_scale is not None:
-          grads = loss_scale_optimizer.unscale_grads(grads, loss_scale)
+        if isinstance(model.optimizer, loss_scale_optimizer.LossScaleOptimizer):
+          grads = model.optimizer.get_unscaled_gradients(grads)
         model.optimizer.apply_gradients(zip(grads, model.trainable_weights))
     model._set_trainable_state(current_trainable_state)
   return outs, total_loss, output_losses, masks

@@ -42,20 +42,6 @@ class _UnwrapPreventer(object):
     self.value = value


-def scale_loss(loss, loss_scale):
-  """Scales the loss by the loss scale."""
-  if callable(loss):
-    return lambda: loss() * loss_scale
-  else:
-    return loss * loss_scale
-
-
-def unscale_grads(grads, loss_scale):
-  """Unscales the gradients by the loss scale."""
-  loss_scale_reciprocal = 1. / loss_scale
-  return [g * loss_scale_reciprocal if g is not None else None for g in grads]
-
-
 @keras_export('keras.mixed_precision.experimental.LossScaleOptimizer')
 class LossScaleOptimizer(optimizer_v2.OptimizerV2):
   """An optimizer that applies loss scaling.
@@ -83,7 +69,34 @@ class LossScaleOptimizer(optimizer_v2.OptimizerV2):

   This optimizer wraps another optimizer and applies loss scaling to it via a
   `LossScale`. Loss scaling is applied whenever gradients are
-  computed, either through `minimize()` or `get_gradients()`.
+  computed, either through `minimize()` or `get_gradients()`. The loss scale is
+  updated via `LossScale.update()` whenever gradients are applied, either
+  through `minimize()` or `apply_gradients()`. For example:
+
+  ```python
+  opt = tf.keras.optimizers.SGD(0.1)
+  opt = tf.keras.mixed_precision.experimental.LossScaleOptimizer(opt, "dynamic")
+  # 'minimize' applies loss scaling to the loss and updates the loss scale.
+  opt.minimize(loss_fn)
+  ```
+
+  If a `tf.GradientTape` is used to compute gradients instead of
+  `LossScaleOptimizer.minimize` or `LossScaleOptimizer.get_gradients`, the loss
+  and gradients must be scaled manually. This can be done by calling
+  `LossScaleOptimizer.get_scaled_loss` before passing the loss to
+  `tf.GradientTape`, and `LossScaleOptimizer.get_unscaled_gradients` after
+  computing the gradients with `tf.GradientTape`. For example:
+
+  ```python
+  opt = tf.keras.mixed_precision.experimental.LossScaleOptimizer(...)
+  vars = ...
+  with tf.GradientTape() as tape:
+    loss = ...
+    scaled_loss = opt.get_scaled_loss(loss)
+  scaled_grads = tape.gradient(scaled_loss, vars)
+  grads = opt.get_unscaled_gradients(scaled_grads)
+  opt.apply_gradients(zip(grads, vars))  # Loss scale will be updated here
+  ```
   """

   def __init__(self, opt, loss_scale):
@@ -123,19 +136,75 @@ class LossScaleOptimizer(optimizer_v2.OptimizerV2):
     self._track_trackable(self._optimizer, 'base_optimizer')
     self._track_trackable(self._loss_scale, 'loss_scale')

+  @property
+  def loss_scale(self):
+    """The `LossScale` instance associated with this optimizer."""
+    return self._loss_scale
+
+  def get_scaled_loss(self, loss):
+    """Scales the loss by the loss scale.
+
+    This method is only needed if you compute gradients manually, e.g. with
+    `tf.GradientTape`. In that case, call this method to scale the loss before
+    passing the loss to `tf.GradientTape`. If you use
+    `LossScaleOptimizer.minimize` or `LossScaleOptimizer.get_gradients`, loss
+    scaling is automatically applied and this method is unneeded.
+
+    If this method is called, `get_unscaled_gradients` should also be called.
+    See the `tf.keras.mixed_precision.experimental.LossScaleOptimizer` doc for
+    an example.
+
+    Args:
+      loss: The loss, which will be multiplied by the loss scale. Can either be
+        a tensor or a callable returning a tensor.
+
+    Returns:
+      `loss` multiplied by `LossScaleOptimizer.loss_scale()`.
+    """
+    loss_scale = self._loss_scale()
+    if callable(loss):
+      return lambda: loss() * loss_scale
+    else:
+      return loss * loss_scale
+
+  def get_unscaled_gradients(self, grads):
+    """Unscales the gradients by the loss scale.
+
+    This method is only needed if you compute gradients manually, e.g. with
+    `tf.GradientTape`. In that case, call this method to unscale the gradients
+    after computing them with `tf.GradientTape`. If you use
+    `LossScaleOptimizer.minimize` or `LossScaleOptimizer.get_gradients`, loss
+    scaling is automatically applied and this method is unneeded.
+
+    If this method is called, `get_scaled_loss` should also be called. See
+    the `tf.keras.mixed_precision.experimental.LossScaleOptimizer` doc for an
+    example.
+
+    Args:
+      grads: A list of tensors, each of which will be divided by the loss
+        scale. Can have None values, which are ignored.
+
+    Returns:
+      A new list the same size as `grads`, where every non-None value in
+      `grads` is divided by `LossScaleOptimizer.loss_scale()`.
+    """
+    loss_scale = self._loss_scale()
+    loss_scale_reciprocal = 1. / loss_scale
+    return [g * loss_scale_reciprocal if g is not None else None for g in grads]
+
   def _compute_gradients(self, loss, var_list, grad_loss=None):
-    loss = scale_loss(loss, self._loss_scale())
+    loss = self.get_scaled_loss(loss)
     grads_and_vars = self._optimizer._compute_gradients(loss, var_list,  # pylint: disable=protected-access
                                                         grad_loss)
     grads = [g for g, _ in grads_and_vars]
     variables = [v for _, v in grads_and_vars]
-    unscaled_grads = unscale_grads(grads, self._loss_scale())
+    unscaled_grads = self.get_unscaled_gradients(grads)
     return list(zip(unscaled_grads, variables))

   def get_gradients(self, loss, params):
-    loss = scale_loss(loss, self._loss_scale())
+    loss = self.get_scaled_loss(loss)
     grads = self._optimizer.get_gradients(loss, params)
-    return unscale_grads(grads, self._loss_scale())
+    return self.get_unscaled_gradients(grads)

   def apply_gradients(self, grads_and_vars, name=None):
     if distribution_strategy_context.in_cross_replica_context():

@@ -109,6 +109,20 @@ class LossScaleOptimizerTest(test.TestCase, parameterized.TestCase):
       # mp_test_util.create_identity_with_grad_check_fn added an assertion op.
       self.evaluate(run_op)

+  @test_util.run_in_graph_and_eager_modes
+  def testGetScaledLoss(self):
+    opt = gradient_descent.SGD(2.0)
+    opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale=2.)
+    self.assertEqual(10., self.evaluate(opt.get_scaled_loss(5.)))
+
+  @test_util.run_in_graph_and_eager_modes
+  def testGetUnscaledGradients(self):
+    opt = gradient_descent.SGD(2.0)
+    opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale=2)
+    grads = opt.get_unscaled_gradients([3., None, -4.])
+    grads = [self.evaluate(g) if g is not None else g for g in grads]
+    self.assertEqual([1.5, None, -2.], grads)
+
   @parameterized.named_parameters(*TESTCASES)
   @test_util.run_in_graph_and_eager_modes
   def testDynamicLossScale(self, strategy_fn):
@@ -162,7 +176,7 @@ class LossScaleOptimizerTest(test.TestCase, parameterized.TestCase):
       # Gradient is 2, so variable will have 2 subtracted from it
       self.assertAllClose([-1.0, 0.0], self.evaluate(var))
       # Loss scale has doubled from 2 to 4
-      self.assertEqual(4., self.evaluate(opt._loss_scale()))
+      self.assertEqual(4., self.evaluate(opt.loss_scale()))

       # Test optimizer with NaN gradients
       loss = lambda: var * float('NaN')
@@ -172,7 +186,7 @@ class LossScaleOptimizerTest(test.TestCase, parameterized.TestCase):
       # Variable should not change from before, due to NaN gradients.
       self.assertAllClose(self.evaluate(var), [-1.0, 0.0])
       # Loss scale should half due to NaN gradients.
-      self.assertEqual(2., self.evaluate(opt._loss_scale()))
+      self.assertEqual(2., self.evaluate(opt.loss_scale()))

   @parameterized.named_parameters(*TESTCASES)
   @test_util.run_in_graph_and_eager_modes
@@ -196,7 +210,7 @@ class LossScaleOptimizerTest(test.TestCase, parameterized.TestCase):
       # variable is subtracted by the accumulator, so the variable is subtracted
       # by 1.
       self.assertAllClose([0.0, 1.0], self.evaluate(var))
-      self.assertEqual(self.evaluate(opt._loss_scale()), initial_loss_scale * 4)
+      self.assertEqual(self.evaluate(opt.loss_scale()), initial_loss_scale * 4)

       run_op = strategy.experimental_run(run_fn)
       self._run_if_in_graph_mode(run_op)
@@ -205,7 +219,7 @@ class LossScaleOptimizerTest(test.TestCase, parameterized.TestCase):
       # variable is subtracted by the accumulator, so the variable is subtracted
       # by 2.
       self.assertAllClose([-2., -1.], self.evaluate(var))
-      self.assertEqual(self.evaluate(opt._loss_scale()),
+      self.assertEqual(self.evaluate(opt.loss_scale()),
                        initial_loss_scale * 16)

   @test_util.run_in_graph_and_eager_modes

@@ -12,6 +12,10 @@ tf_class {
     name: "learning_rate"
     mtype: "<type \'property\'>"
   }
+  member {
+    name: "loss_scale"
+    mtype: "<type \'property\'>"
+  }
   member {
     name: "lr"
     mtype: "<type \'property\'>"
@@ -48,6 +52,10 @@ tf_class {
     name: "get_gradients"
     argspec: "args=[\'self\', \'loss\', \'params\'], varargs=None, keywords=None, defaults=None"
   }
+  member_method {
+    name: "get_scaled_loss"
+    argspec: "args=[\'self\', \'loss\'], varargs=None, keywords=None, defaults=None"
+  }
   member_method {
     name: "get_slot"
     argspec: "args=[\'self\', \'var\', \'slot_name\'], varargs=None, keywords=None, defaults=None"
@@ -56,6 +64,10 @@ tf_class {
     name: "get_slot_names"
     argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
   }
+  member_method {
+    name: "get_unscaled_gradients"
+    argspec: "args=[\'self\', \'grads\'], varargs=None, keywords=None, defaults=None"
+  }
   member_method {
     name: "get_updates"
     argspec: "args=[\'self\', \'loss\', \'params\'], varargs=None, keywords=None, defaults=None"

@@ -12,6 +12,10 @@ tf_class {
     name: "learning_rate"
     mtype: "<type \'property\'>"
   }
+  member {
+    name: "loss_scale"
+    mtype: "<type \'property\'>"
+  }
   member {
     name: "lr"
     mtype: "<type \'property\'>"
@@ -48,6 +52,10 @@ tf_class {
     name: "get_gradients"
     argspec: "args=[\'self\', \'loss\', \'params\'], varargs=None, keywords=None, defaults=None"
   }
+  member_method {
+    name: "get_scaled_loss"
+    argspec: "args=[\'self\', \'loss\'], varargs=None, keywords=None, defaults=None"
+  }
   member_method {
     name: "get_slot"
     argspec: "args=[\'self\', \'var\', \'slot_name\'], varargs=None, keywords=None, defaults=None"
@@ -56,6 +64,10 @@ tf_class {
     name: "get_slot_names"
     argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
  }
+  member_method {
+    name: "get_unscaled_gradients"
+    argspec: "args=[\'self\', \'grads\'], varargs=None, keywords=None, defaults=None"
+  }
   member_method {
     name: "get_updates"
     argspec: "args=[\'self\', \'loss\', \'params\'], varargs=None, keywords=None, defaults=None"
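For a quick sanity check of the two new methods in isolation, here is a sketch mirroring the `testGetScaledLoss` and `testGetUnscaledGradients` cases added above; the fixed loss scale of 2 and the SGD wrapper are illustrative choices, and eager execution is assumed:

```python
import tensorflow as tf

# Wrap any optimizer with a fixed loss scale of 2, as the new tests do.
opt = tf.keras.optimizers.SGD(2.0)
opt = tf.keras.mixed_precision.experimental.LossScaleOptimizer(opt, loss_scale=2.)

scaled_loss = opt.get_scaled_loss(5.)                   # 5. * 2  -> 10.0
unscaled = opt.get_unscaled_gradients([3., None, -4.])  # [1.5, None, -2.0]; None passes through
print(float(scaled_loss),
      [float(g) if g is not None else None for g in unscaled])
```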