TensorFlower Gardener 6cbb2f3405 Merge pull request #37965 from zhuzilin:keras-amp-variables-dev
PiperOrigin-RevId: 305096445
Change-Id: I2001cc63b77e99ea6aff819e94e8c7f366282da2
2020-04-06 12:54:33 -07:00

251 lines
10 KiB
Python

# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains LossScale classes."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.framework import ops
from tensorflow.python.framework import smart_cond
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.training import optimizer
from tensorflow.python.training.experimental import loss_scale as loss_scale_module
from tensorflow.python.util.tf_export import tf_export
@tf_export(v1=['train.experimental.MixedPrecisionLossScaleOptimizer'])
class MixedPrecisionLossScaleOptimizer(optimizer.Optimizer):
  """An optimizer that applies loss scaling.

  Loss scaling is a process that multiplies the loss by a multiplier called the
  loss scale, and divides each gradient by the same multiplier. The pseudocode
  for this process is:

  ```
  loss = ...
  loss *= loss_scale
  grads = gradients(loss, vars)
  grads /= loss_scale
  ```

  Mathematically, loss scaling has no effect, but can help avoid numerical
  underflow in intermediate gradients when float16 tensors are used for mixed
  precision training. By multiplying the loss, each intermediate gradient will
  have the same multiplier applied.

  The loss scale can either be a fixed constant, chosen by the user, or be
  dynamically determined. Dynamically determining the loss scale is convenient
  as a loss scale does not have to be explicitly chosen. However it reduces
  performance.

  This optimizer wraps another optimizer and applies loss scaling to it via a
  `LossScale`. Loss scaling is applied whenever gradients are
  computed, such as through `minimize()`.
  """

  def __init__(self, opt, loss_scale):
    """Wraps `opt` so that loss scaling is applied around gradient computation.

    Args:
      opt: The `Optimizer` instance to wrap. Its locking flag and name are
        reused for this wrapper.
      loss_scale: Identifier resolved through `loss_scale_module.get()`. The
        resolved value must not be None.

    Raises:
      ValueError: If `opt` is not an `Optimizer`, or if `loss_scale` resolves
        to None.
    """
    if not isinstance(opt, optimizer.Optimizer):
      raise ValueError('"opt" must be an instance of Optimizer, but got: %s' %
                       type(opt))
    self._optimizer = opt

    use_locking = opt._use_locking  # pylint: disable=protected-access
    name = opt.get_name()
    super(MixedPrecisionLossScaleOptimizer, self).__init__(use_locking, name)

    self._loss_scale = loss_scale_module.get(loss_scale)
    if self._loss_scale is None:
      raise ValueError('loss_scale cannot be None')
    # Register both wrapped objects as trackable dependencies of this
    # optimizer under fixed names.
    self._track_trackable(self._optimizer, 'base_optimizer')
    self._track_trackable(self._loss_scale, 'loss_scale')

  def _doing_dynamic_loss_scaling(self):
    """Check if `_loss_scale` dynamically manages the loss scale."""
    return isinstance(self._loss_scale, loss_scale_module.DynamicLossScale)

  def compute_gradients(self,
                        loss,
                        var_list=None,
                        gate_gradients=optimizer.Optimizer.GATE_OP,
                        aggregation_method=None,
                        colocate_gradients_with_ops=False,
                        grad_loss=None):
    """Compute gradients of `loss` for the variables in `var_list`.

    This adjusts the dynamic range of the gradient evaluation by scaling up
    the `loss` value. The gradient values are then scaled back down by the
    reciprocal of the loss scale. This is useful in reduced precision training
    where small gradient values would otherwise underflow the representable
    range.

    Args:
      loss: A Tensor containing the value to minimize or a callable taking no
        arguments which returns the value to minimize. When eager execution is
        enabled it must be a callable.
      var_list: Optional list or tuple of `tf.Variable` to update to minimize
        `loss`. Defaults to the list of variables collected in the graph under
        the key `GraphKeys.TRAINABLE_VARIABLES`.
      gate_gradients: How to gate the computation of gradients. Can be
        `GATE_NONE`, `GATE_OP`, or `GATE_GRAPH`.
      aggregation_method: Specifies the method used to combine gradient terms.
        Valid values are defined in the class `AggregationMethod`.
      colocate_gradients_with_ops: If True, try colocating gradients with the
        corresponding op.
      grad_loss: Optional. A `Tensor` holding the gradient computed for `loss`.

    Returns:
      A list of (gradient, variable) pairs. Variable is always present, but
      gradient can be `None`.
    """
    loss = self._scale_loss(loss)
    grads_and_vars = self._optimizer.compute_gradients(
        loss=loss,
        var_list=var_list,
        gate_gradients=gate_gradients,
        aggregation_method=aggregation_method,
        colocate_gradients_with_ops=colocate_gradients_with_ops,
        grad_loss=grad_loss)

    grads = [g for g, _ in grads_and_vars]
    variables = [v for _, v in grads_and_vars]
    unscaled_grads = self._unscale_grads(grads)
    return list(zip(unscaled_grads, variables))

  def _scale_loss(self, loss):
    """Multiplies `loss` by the current loss scale.

    Args:
      loss: A loss value, or a callable taking no arguments that returns one.

    Returns:
      The scaled loss if `loss` is a value; otherwise a new callable that
      evaluates `loss()` and returns the scaled result. In both cases the loss
      scale is cast to the dtype of the loss before multiplying.
    """
    loss_scale = self._loss_scale()
    if callable(loss):
      def new_loss():
        loss_val = loss()
        return loss_val * math_ops.cast(loss_scale, loss_val.dtype)
      return new_loss
    else:
      return loss * math_ops.cast(loss_scale, loss.dtype)

  def _unscale_grads(self, grads):
    """Multiplies each non-None gradient by the loss scale reciprocal.

    Args:
      grads: A list of gradients; entries may be None.

    Returns:
      A list of the same length where None entries are preserved and every
      other entry has been passed through `_scale_grad`.
    """
    loss_scale = self._loss_scale()
    # Computed once and shared by every gradient in the list.
    loss_scale_reciprocal = 1 / loss_scale
    return [
        None if g is None else self._scale_grad(g, loss_scale_reciprocal)
        for g in grads
    ]

  def _scale_grad(self, grad, loss_scale_reciprocal):
    """Scales a single gradient by `loss_scale_reciprocal`.

    The reciprocal is cast to the gradient's dtype before multiplying, the
    same way `_scale_loss` casts the loss scale to the loss dtype, so that the
    multiplication never mixes dtypes.

    Args:
      grad: A dense gradient or an `IndexedSlices` gradient.
      loss_scale_reciprocal: The multiplier to apply to the gradient.

    Returns:
      The scaled gradient. For `IndexedSlices` only the values are scaled; the
      indices and dense shape are passed through unchanged.
    """
    if isinstance(grad, ops.IndexedSlices):
      multiplier = math_ops.cast(loss_scale_reciprocal, grad.values.dtype)
      grad_vals = grad.values * multiplier
      return ops.IndexedSlices(grad_vals, grad.indices, grad.dense_shape)
    return grad * math_ops.cast(loss_scale_reciprocal, grad.dtype)

  def apply_gradients(self, grads_and_vars, global_step=None, name=None):
    """Apply gradients to variables.

    This is the second part of `minimize()`. When dynamic loss scaling is in
    use, it returns an `Operation` that conditionally applies gradients if all
    gradient values are finite. Otherwise no update is performed (nor is
    `global_step` incremented). When the loss scale is not dynamic, the call
    is forwarded directly to the wrapped optimizer.

    Args:
      grads_and_vars: List of (gradient, variable) pairs as returned by
        `compute_gradients()`.
      global_step: Optional `Variable` to increment by one after the variables
        have been updated.
      name: Optional name for the returned operation. Default to the name
        passed to the `Optimizer` constructor.

    Returns:
      An `Operation` that conditionally applies the specified gradients. If
      `global_step` was not None, that operation also increments `global_step`.

    Raises:
      ValueError: If called in a cross-replica context.
    """
    if distribution_strategy_context.in_cross_replica_context():
      raise ValueError('apply_gradients() must be called in a replica context.')

    if not self._doing_dynamic_loss_scaling():
      return self._optimizer.apply_gradients(grads_and_vars, global_step, name)

    replica_context = distribution_strategy_context.get_replica_context()
    grads_and_vars = tuple(grads_and_vars)

    # TODO(nluehr) cleanup GraphKeys.TRAIN_OP
    return replica_context.merge_call(
        self._distributed_apply, args=(grads_and_vars, global_step, name))

  def _distributed_apply(self,
                         distribution,
                         grads_and_vars,
                         global_step=None,
                         name=None):
    """A version of `apply_gradients` for cross replica context.

    When users are in a cross replica strategy, they must call this rather than
    `apply_gradients()`.

    Args:
      distribution: a `DistributionStrategy` object.
      grads_and_vars: List of (gradient, variable) pairs as returned by
        `compute_gradients()` and then aggregated across replicas.
      global_step: Optional (mirrored) `Variable` to increment by one after the
        variables have been updated.
      name: Optional name for the returned operation. Default to the name passed
        to the `Optimizer` constructor.

    Returns:
      An `Operation` that applies the specified gradients across all
      replicas. If `global_step` was not None, that operation also
      increments `global_step`
    """
    name = name if name is not None else self.get_name()
    grads = [g for g, _ in grads_and_vars]
    # The loss scale decides, from the gradients, whether this step's update
    # should be applied, and returns the op that updates its own state.
    loss_scale_update_op, should_apply_grads = (self._loss_scale.update(grads))

    def apply_fn():
      return self._apply_gradients(distribution, grads_and_vars, global_step,
                                   name + '-wrapped')

    # Skip the variable update (no-op) when the loss scale rejects the step.
    maybe_apply_op = smart_cond.smart_cond(should_apply_grads, apply_fn,
                                           control_flow_ops.no_op)
    return control_flow_ops.group(
        maybe_apply_op, loss_scale_update_op, name=name)

  def _apply_gradients(self, distribution, grads_and_vars, global_step, name):
    """Unconditionally apply gradients in cross replica context."""
    update_ops = distribution.extended.call_for_each_replica(
        self._optimizer.apply_gradients,
        args=(grads_and_vars, global_step, name))
    return distribution.group(update_ops)

  def _apply_sparse(self, grad, var):
    """This function should never be called."""
    raise RuntimeError('This function should never be called')

  def _apply_dense(self, grad, var):
    """This function should never be called."""
    raise RuntimeError('This function should never be called')

  def _resource_apply_sparse(self, grad, handle, indices):
    """This function should never be called."""
    raise RuntimeError('This function should never be called')

  def _resource_apply_dense(self, grad, handle):
    """This function should never be called."""
    raise RuntimeError('This function should never be called')

  def variables(self):
    """Returns the wrapped optimizer's variables plus the loss scale weights."""
    return (self._optimizer.variables() +
            list(self._loss_scale._weights.values()))  # pylint: disable=protected-access