# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains LossScale classes."""
|
|
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.framework import ops
from tensorflow.python.framework import smart_cond
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.training import optimizer
from tensorflow.python.training.experimental import loss_scale as loss_scale_module
from tensorflow.python.util.tf_export import tf_export


@tf_export(v1=['train.experimental.MixedPrecisionLossScaleOptimizer'])
class MixedPrecisionLossScaleOptimizer(optimizer.Optimizer):
  """An optimizer that applies loss scaling.

  Loss scaling is a process that multiplies the loss by a multiplier called the
  loss scale, and divides each gradient by the same multiplier. The pseudocode
  for this process is:

  ```
  loss = ...
  loss *= loss_scale
  grads = gradients(loss, vars)
  grads /= loss_scale
  ```

  Mathematically, loss scaling has no effect, but it can help avoid numerical
  underflow in intermediate gradients when float16 tensors are used for mixed
  precision training. Because the loss is multiplied before differentiation,
  every intermediate gradient has the same multiplier applied.

  The loss scale can either be a fixed constant, chosen by the user, or be
  dynamically determined. Dynamically determining the loss scale is convenient
  as a loss scale does not have to be explicitly chosen. However, it reduces
  performance.

  This optimizer wraps another optimizer and applies loss scaling to it via a
  `LossScale`. Loss scaling is applied whenever gradients are computed, such as
  through `minimize()`.
  """

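  # A minimal usage sketch, assuming a standard graph-mode `tf.compat.v1`
  # training setup (the surrounding optimizer and `loss` tensor are
  # placeholders, not part of this module):
  #
  #   opt = tf.compat.v1.train.GradientDescentOptimizer(0.1)
  #   opt = tf.compat.v1.train.experimental.MixedPrecisionLossScaleOptimizer(
  #       opt, loss_scale='dynamic')
  #   train_op = opt.minimize(loss)
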
  def __init__(self, opt, loss_scale):
    if not isinstance(opt, optimizer.Optimizer):
      raise ValueError('"opt" must be an instance of Optimizer, but got: %s' %
                       type(opt))
    self._optimizer = opt

    use_locking = opt._use_locking  # pylint: disable=protected-access
    name = opt.get_name()
    super(MixedPrecisionLossScaleOptimizer, self).__init__(use_locking, name)

    self._loss_scale = loss_scale_module.get(loss_scale)
    if self._loss_scale is None:
      raise ValueError('loss_scale cannot be None')
    self._track_trackable(self._optimizer, 'base_optimizer')
    self._track_trackable(self._loss_scale, 'loss_scale')

  def _doing_dynamic_loss_scaling(self):
    """Check if `_loss_scale` dynamically manages the loss scale."""
    return isinstance(self._loss_scale, loss_scale_module.DynamicLossScale)

  def compute_gradients(self,
                        loss,
                        var_list=None,
                        gate_gradients=optimizer.Optimizer.GATE_OP,
                        aggregation_method=None,
                        colocate_gradients_with_ops=False,
                        grad_loss=None):
    """Compute gradients of `loss` for the variables in `var_list`.

    This adjusts the dynamic range of the gradient evaluation by scaling up
    the `loss` value. The gradient values are then scaled back down by the
    reciprocal of the loss scale. This is useful in reduced precision training
    where small gradient values would otherwise underflow the representable
    range.

    Args:
      loss: A Tensor containing the value to minimize, or a callable taking no
        arguments which returns the value to minimize. When eager execution is
        enabled, it must be a callable.
      var_list: Optional list or tuple of `tf.Variable` to update to minimize
        `loss`. Defaults to the list of variables collected in the graph under
        the key `GraphKeys.TRAINABLE_VARIABLES`.
      gate_gradients: How to gate the computation of gradients. Can be
        `GATE_NONE`, `GATE_OP`, or `GATE_GRAPH`.
      aggregation_method: Specifies the method used to combine gradient terms.
        Valid values are defined in the class `AggregationMethod`.
      colocate_gradients_with_ops: If True, try colocating gradients with the
        corresponding op.
      grad_loss: Optional. A `Tensor` holding the gradient computed for `loss`.

    Returns:
      A list of (gradient, variable) pairs. Variable is always present, but
      gradient can be `None`.
    """
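    # Scale the loss up before computing gradients; the resulting gradients are
    # then multiplied by the reciprocal of the loss scale so that the returned
    # (gradient, variable) pairs correspond to the original, unscaled loss.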
    loss = self._scale_loss(loss)
    grads_and_vars = self._optimizer.compute_gradients(
        loss=loss,
        var_list=var_list,
        gate_gradients=gate_gradients,
        aggregation_method=aggregation_method,
        colocate_gradients_with_ops=colocate_gradients_with_ops,
        grad_loss=grad_loss)

    grads = [g for g, _ in grads_and_vars]
    variables = [v for _, v in grads_and_vars]
    unscaled_grads = self._unscale_grads(grads)
    return list(zip(unscaled_grads, variables))

  def _scale_loss(self, loss):
    loss_scale = self._loss_scale()
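    # `loss` may be a callable (required when eager execution is enabled); in
    # that case, defer the multiplication until the callable is evaluated.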
    if callable(loss):
      def new_loss():
        loss_val = loss()
        return loss_val * math_ops.cast(loss_scale, loss_val.dtype)
      return new_loss
    else:
      return loss * math_ops.cast(loss_scale, loss.dtype)

  def _unscale_grads(self, grads):
    loss_scale = self._loss_scale()
    loss_scale_reciprocal = 1 / loss_scale
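    # Gradients can be None (for example, for variables that the loss does not
    # depend on); those entries are passed through unchanged.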
    return [
        None if g is None else self._scale_grad(g, loss_scale_reciprocal)
        for g in grads
    ]

  def _scale_grad(self, grad, loss_scale_reciprocal):
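    # Sparse gradients arrive as IndexedSlices; only their values need to be
    # rescaled, while the indices and dense shape are preserved.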
    if isinstance(grad, ops.IndexedSlices):
      grad_vals = grad.values * loss_scale_reciprocal
      return ops.IndexedSlices(grad_vals, grad.indices, grad.dense_shape)
    return grad * loss_scale_reciprocal

  def apply_gradients(self, grads_and_vars, global_step=None, name=None):
    """Apply gradients to variables.

    This is the second part of `minimize()`. It returns an `Operation` that
    conditionally applies gradients if all gradient values are finite.
    Otherwise no update is performed (nor is `global_step` incremented).

    Args:
      grads_and_vars: List of (gradient, variable) pairs as returned by
        `compute_gradients()`.
      global_step: Optional `Variable` to increment by one after the variables
        have been updated.
      name: Optional name for the returned operation. Defaults to the name
        passed to the `Optimizer` constructor.

    Returns:
      An `Operation` that conditionally applies the specified gradients. If
      `global_step` was not None, that operation also increments `global_step`.

    Raises:
      ValueError: If called in a cross-replica context. Use
        `_distributed_apply()` instead.
    """
    if distribution_strategy_context.in_cross_replica_context():
      raise ValueError('apply_gradients() must be called in a replica context.')

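    # With a fixed loss scale there is nothing extra to update, so gradient
    # application can be delegated directly to the wrapped optimizer.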
    if not self._doing_dynamic_loss_scaling():
      return self._optimizer.apply_gradients(grads_and_vars, global_step, name)

    replica_context = distribution_strategy_context.get_replica_context()
    grads_and_vars = tuple(grads_and_vars)

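    # Dynamic loss scaling needs to run its loss scale update in cross-replica
    # context, so hand control to `_distributed_apply` via `merge_call`.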
    # TODO(nluehr) cleanup GraphKeys.TRAIN_OP
    return replica_context.merge_call(
        self._distributed_apply, args=(grads_and_vars, global_step, name))

  def _distributed_apply(self,
                         distribution,
                         grads_and_vars,
                         global_step=None,
                         name=None):
    """A version of `apply_gradients` for cross-replica context.

    When users are in a cross-replica context, they must call this rather than
    `apply_gradients()`.

    Args:
      distribution: a `DistributionStrategy` object.
      grads_and_vars: List of (gradient, variable) pairs as returned by
        `compute_gradients()` and then aggregated across replicas.
      global_step: Optional (mirrored) `Variable` to increment by one after the
        variables have been updated.
      name: Optional name for the returned operation. Defaults to the name
        passed to the `Optimizer` constructor.

    Returns:
      An `Operation` that applies the specified gradients across all
      replicas. If `global_step` was not None, that operation also
      increments `global_step`.
    """
    name = name if name is not None else self.get_name()
    grads = [g for g, _ in grads_and_vars]
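    # `update()` adjusts the loss scale based on the gradients and reports
    # whether they are all finite; the gradients are applied only in that case,
    # while the loss scale update op runs either way.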
    loss_scale_update_op, should_apply_grads = self._loss_scale.update(grads)

    def apply_fn():
      return self._apply_gradients(distribution, grads_and_vars, global_step,
                                   name + '-wrapped')

    maybe_apply_op = smart_cond.smart_cond(should_apply_grads, apply_fn,
                                           control_flow_ops.no_op)
    return control_flow_ops.group(
        maybe_apply_op, loss_scale_update_op, name=name)

  def _apply_gradients(self, distribution, grads_and_vars, global_step, name):
    """Unconditionally apply gradients in cross-replica context."""
    update_ops = distribution.extended.call_for_each_replica(
        self._optimizer.apply_gradients,
        args=(grads_and_vars, global_step, name))
    return distribution.group(update_ops)

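  # The slot-update hooks below are required by the `Optimizer` interface but
  # are never reached, because gradient application is always delegated to the
  # wrapped optimizer.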
  def _apply_sparse(self, grad, var):
    """This function should never be called."""
    raise RuntimeError('This function should never be called')

  def _apply_dense(self, grad, var):
    """This function should never be called."""
    raise RuntimeError('This function should never be called')

  def _resource_apply_sparse(self, grad, handle, indices):
    """This function should never be called."""
    raise RuntimeError('This function should never be called')

  def _resource_apply_dense(self, grad, handle):
    """This function should never be called."""
    raise RuntimeError('This function should never be called')

  def variables(self):
    """Returns the variables of the wrapped optimizer and the loss scale."""
    return (self._optimizer.variables() +
            list(self._loss_scale._weights.values()))  # pylint: disable=protected-access