Merge pull request #27655 from MattConley:loss_scaling_optimizer

PiperOrigin-RevId: 243664104
This commit is contained in:
TensorFlower Gardener 2019-04-15 12:37:36 -07:00
commit 447e512d33
10 changed files with 1229 additions and 1 deletions

View File

@ -2858,6 +2858,53 @@ py_library(
],
)
py_library(
name = "loss_scale",
srcs = ["training/experimental/loss_scale.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/python:framework",
"@absl_py//absl/testing:parameterized",
],
)
py_library(
name = "loss_scale_optimizer",
srcs = ["training/experimental/loss_scale_optimizer.py"],
srcs_version = "PY2AND3",
deps = [
":loss_scale",
"@absl_py//absl/testing:parameterized",
],
)
py_test(
name = "loss_scale_optimizer_test",
size = "small",
srcs = ["training/experimental/loss_scale_optimizer_test.py"],
deps = [
":loss_scale_optimizer",
"//tensorflow/python:client_testlib",
"//tensorflow/python/distribute:mirrored_strategy",
"//tensorflow/python/distribute:one_device_strategy",
"//tensorflow/python/keras/mixed_precision/experimental:test_util",
"@absl_py//absl/testing:parameterized",
],
)
py_test(
name = "loss_scale_test",
size = "small",
srcs = ["training/experimental/loss_scale_test.py"],
deps = [
":loss_scale",
"//tensorflow/python:client_testlib",
"//tensorflow/python/distribute:mirrored_strategy",
"//tensorflow/python/distribute:one_device_strategy",
"@absl_py//absl/testing:parameterized",
],
)
py_library(
name = "math_grad",
srcs = ["ops/math_grad.py"],
@ -3962,6 +4009,8 @@ py_library(
":io_ops",
":layers_util",
":lookup_ops",
":loss_scale",
":loss_scale_optimizer",
":math_ops",
":platform",
":pywrap_tensorflow",

View File

@ -0,0 +1,352 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains LossScale classes."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import abc
import six
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.distribute import reduce_util
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.ops import variable_scope
# TODO(reedwm): Merge this with tf.keras.mixed_precision.experimental.LossScale
@six.add_metaclass(abc.ABCMeta)
class LossScale(trackable.Trackable):
"""Loss scale base class.
Instances of this class represent a loss scale. Calling instances of this
class returns the loss scale as a scalar float32 tensor, while method
`update()` updates the loss scale depending on the values of the gradients.
Optimizers use instances of this class to scale loss and gradients.
Note: this LossScale class can only be used with a v1 optimizer wrapper,
tf.train.experimental.MixedPrecisionLossScaleOptimizer. For a v2
wrapper, tf.keras.mixed_precision.experimental.LossScaleOptimizer, a
tf.keras.mixed_precision.experimental.LossScale should be used instead.
"""
def __init__(self):
"""Initializes the loss scale class."""
self._weights = {}
@abc.abstractmethod
def __call__(self):
"""Returns the current loss scale as a scalar `float32` tensor."""
pass
@abc.abstractmethod
def update(self, grads):
"""Updates the value of the loss scale.
The loss scale will be potentially updated, based on the value of `grads`.
The tensor returned by calling this class is only updated when this function
is evaluated.
In eager mode, this directly updates the loss scale, so that calling
`__call__` will return the newly updated loss scale. In graph mode,
this returns an op that, when evaluated, updates the loss scale.
This function also returns a `should_apply_gradients` bool. If False,
gradients should not be applied to the variables that step, as nonfinite
gradients were found, and the loss scale has been be updated to reduce the
chance of finding nonfinite gradients in the next step. Some loss scale
classes will always return True, as they cannot adjust themselves in
response to nonfinite gradients.
When a DistributionStrategy is used, this function may only be called in a
cross-replica context.
Args:
grads: A list of unscaled gradients, each which is the gradient of the
loss with respect to a weight. The gradients should have already been
divided by the loss scale being before passed to this function.
Returns:
update_op: In eager mode, None. In graph mode, an op to update the loss
scale.
should_apply_gradients: Either a bool or a scalar boolean tensor. If
False, the caller should skip applying `grads` to the variables this
step.
"""
pass
def _add_weight(self, name, initial_value, dtype=None):
"""Adds a weight to this loss scale manager..
Args:
name: Variable name.
initial_value: The variable's initial value.
dtype: The type of the variable.
Returns:
A variable.
Raises:
RuntimeError: If a weight with `name` has already been added.
"""
variable = variable_scope.variable(
initial_value=initial_value,
name=name,
dtype=dtype,
trainable=False,
use_resource=True,
synchronization=variables.VariableSynchronization.AUTO,
# Set aggregation to NONE, as loss scaling variables should never be
# aggregated.
aggregation=variables.VariableAggregation.NONE)
if context.executing_eagerly():
graph_key = None
else:
graph = ops.get_default_graph()
graph_key = graph._graph_key # pylint: disable=protected-access
key = (name, graph_key)
if self._weights.get(key, None) is not None:
raise RuntimeError('Duplicate variables detected. {}'.format(key))
self._weights[key] = variable
self._handle_deferred_dependencies(name=name, trackable=variable)
return variable
@property
def _checkpoint_dependencies(self):
"""From Trackable. Gather graph-specific weights to save."""
if context.executing_eagerly():
graph_key = None
else:
graph = ops.get_default_graph()
graph_key = graph._graph_key # pylint: disable=protected-access
weights = []
for (name, g), v in sorted(self._weights.items(), key=lambda i: i[0][0]):
if g == graph_key:
weights.append(trackable.TrackableReference(name=name, ref=v))
return super(LossScale, self)._checkpoint_dependencies + weights
def _lookup_dependency(self, name):
"""From Trackable. Find a weight in the current graph."""
unconditional = super(LossScale, self)._lookup_dependency(name)
if unconditional is not None:
return unconditional
if context.executing_eagerly():
graph_key = None
else:
graph = ops.get_default_graph()
graph_key = graph._graph_key # pylint: disable=protected-access
return self._weights.get((name, graph_key), None)
class FixedLossScale(LossScale):
"""Loss scale class with a fixed value.
The loss scale is not updated for the lifetime of the class.
"""
def __init__(self, loss_scale_value):
"""Creates the fixed loss scale.
Args:
loss_scale_value: A Python float. Its ideal value varies depending on
models to run. Choosing a too small loss_scale might affect model
quality; a too big loss_scale might cause inf or nan. There is no single
right loss_scale to apply. There is no harm choosing a relatively big
number as long as no nan or inf is encountered in training.
Raises:
ValueError: If loss_scale is less than 1.
"""
super(FixedLossScale, self).__init__()
if not isinstance(loss_scale_value, six.integer_types + (float,)):
raise ValueError('loss_scale must be a Python int or float.')
if loss_scale_value < 1:
raise ValueError('loss scale must be at least 1.')
self._tensor_loss_scale = ops.convert_to_tensor(
loss_scale_value, dtype=dtypes.float32)
def __call__(self):
return self._tensor_loss_scale
def update(self, grads):
del grads
return control_flow_ops.no_op(), True
def _is_all_finite(grads):
"""Returns a scalar boolean tensor indicating if all gradients are finite."""
is_finite_per_grad = [
math_ops.reduce_all(math_ops.is_finite(g)) for g in grads
]
return math_ops.reduce_all(is_finite_per_grad)
def _op_in_graph_mode(tensor):
"""Returns the tensor's op in graph mode, or the tensor in eager mode.
This is useful because sometimes an op is needed in graph mode instead of a
tensor. In eager mode, there are no ops.
Args:
tensor: A tensor.
Returns:
The tensor's op in graph mode. The tensor in eager mode.
"""
if context.executing_eagerly():
return tensor
return tensor.op
def _assign_if_finite(var, value):
"""Assigns a value to a variable if the value is finite."""
return control_flow_ops.cond(
math_ops.is_finite(value), lambda: _op_in_graph_mode(var.assign(value)),
control_flow_ops.no_op)
class DynamicLossScale(LossScale):
"""Loss scale class that dynamically adjusts the loss scale.
Dynamic loss scaling works by adjusting the loss scale as training progresses.
The goal is to keep the loss scale as high as possible without overflowing the
gradients. As long as the gradients do not overflow, raising the loss scale
never hurts.
The algorithm starts by setting the loss scale to an initial value. Every N
steps that the gradients are finite, the loss scale is increased by some
factor. However, if a NaN or Inf gradient is found, the gradients for that
step are not applied, and the loss scale is decreased by the factor. This
process tends to keep the loss scale as high as possible without gradients
overflowing.
"""
def __init__(self,
initial_loss_scale=2**15,
increment_period=2000,
multiplier=2.):
"""Constructor of exponential-update loss scale class.
Args:
initial_loss_scale: A Python float. The loss scale to use at the
beginning. It's better to start this at a very high number, because a
loss scale that is too high gets lowered far more quickly than a loss
scale that is to low gets raised. The default is 2 ** 15, which is
approximately half the maximum float16 value.
increment_period: Increases loss scale every `increment_period`
consecutive steps that finite gradients are encountered. If a nonfinite
gradient is encountered, the count is reset back to zero.
multiplier: The multiplier to use when increasing or decreasing the loss
scale.
"""
super(DynamicLossScale, self).__init__()
self._initial_loss_scale = float(initial_loss_scale)
self._increment_period = int(increment_period)
self._multiplier = float(multiplier)
self._current_loss_scale = self._add_weight(
name='loss_scale',
dtype=dtypes.float32,
initial_value=self._initial_loss_scale)
# The number of consecutive steps with finite gradients since the last
# nonfinite gradient or change in loss scale.
self._num_good_steps = self._add_weight(
name='good_steps', dtype=dtypes.int64, initial_value=0)
@property
def initial_loss_scale(self):
return self._initial_loss_scale
@property
def increment_period(self):
return self._increment_period
@property
def multiplier(self):
return self._multiplier
def __call__(self):
return self._current_loss_scale
def update(self, grads):
"""Updates loss scale based on if gradients are finite in current step."""
if distribution_strategy_context.has_strategy():
distribution = distribution_strategy_context.get_cross_replica_context()
def get_is_finite(grads):
is_finite = _is_all_finite(grads)
# We cast to float, because we cannot reduce booleans with
# DistributionStrategy.
return math_ops.cast(is_finite, dtypes.float32)
is_finite_float = distribution.extended.call_for_each_replica(
get_is_finite, args=(grads,))
reduced_is_finite_float = distribution.reduce(reduce_util.ReduceOp.SUM,
is_finite_float, axis=None)
is_finite = math_ops.equal(reduced_is_finite_float,
distribution.num_replicas_in_sync)
else:
is_finite = _is_all_finite(grads)
def update_if_finite_grads():
"""Update assuming the gradients are finite."""
def incr_loss_scale():
new_loss_scale = self._current_loss_scale * self._multiplier
return control_flow_ops.group(
_assign_if_finite(self._current_loss_scale, new_loss_scale),
self._num_good_steps.assign(0))
return control_flow_ops.cond(
self._num_good_steps + 1 >= self._increment_period,
incr_loss_scale, lambda: _op_in_graph_mode(
self._num_good_steps.assign_add(1)))
def update_if_not_finite_grads():
"""Update assuming the gradients are nonfinite."""
new_loss_scale = math_ops.maximum(
self._current_loss_scale / self._multiplier, 1)
return control_flow_ops.group(
self._num_good_steps.assign(0),
self._current_loss_scale.assign(new_loss_scale))
update_op = control_flow_ops.cond(is_finite, update_if_finite_grads,
update_if_not_finite_grads)
should_apply_gradients = is_finite
return update_op, should_apply_gradients
def get(identifier):
"""Get a loss scale object."""
if isinstance(identifier, six.integer_types + (float,)):
return FixedLossScale(identifier)
if identifier == 'dynamic':
return DynamicLossScale()
if isinstance(identifier, LossScale):
return identifier
elif identifier is None:
return None
else:
raise ValueError('Could not interpret loss scale identifier: %s' %
identifier)

View File

@ -0,0 +1,237 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains LossScale classes."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.framework import ops
from tensorflow.python.framework import smart_cond
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.training import optimizer
from tensorflow.python.training.experimental import loss_scale as loss_scale_module
from tensorflow.python.util.tf_export import tf_export
@tf_export(v1=['train.experimental.MixedPrecisionLossScaleOptimizer'])
class MixedPrecisionLossScaleOptimizer(optimizer.Optimizer):
"""An optimizer that applies loss scaling.
Loss scaling is a process that multiplies the loss by a multiplier called the
loss scale, and divides each gradient by the same multiplier. The pseudocode
for this process is:
```
loss = ...
loss *= loss_scale
grads = gradients(loss, vars)
grads /= loss_scale
```
Mathematically, loss scaling has no effect, but can help avoid numerical
underflow in intermediate gradients when float16 tensors are used for mixed
precision training. By multiplying the loss, each intermediate gradient will
have the same multiplier applied.
The loss scale can either be a fixed constant, chosen by the user, or be
dynamically determined. Dynamically determining the loss scale is convenient
as a loss scale does not have to be explicitly chosen. However it reduces
performance.
This optimizer wraps another optimizer and applies loss scaling to it via a
`LossScale`. Loss scaling is applied whenever gradients are
computed, such as through `minimize()`.
"""
def __init__(self, opt, loss_scale):
if not isinstance(opt, optimizer.Optimizer):
raise ValueError('"opt" must be an instance of Optimizer, but got: %s' %
type(opt))
self._optimizer = opt
use_locking = opt._use_locking # pylint: disable=protected-access
name = opt.get_name()
super(MixedPrecisionLossScaleOptimizer, self).__init__(use_locking, name)
self._loss_scale = loss_scale_module.get(loss_scale)
self._track_trackable(self._optimizer, 'base_optimizer')
self._track_trackable(self._loss_scale, 'loss_scale')
def _doing_dynamic_loss_scaling(self):
"""Check if `_loss_scale` dynamically manages the loss scale."""
return isinstance(self._loss_scale, loss_scale_module.DynamicLossScale)
def compute_gradients(self,
loss,
var_list=None,
gate_gradients=optimizer.Optimizer.GATE_OP,
aggregation_method=None,
colocate_gradients_with_ops=False,
grad_loss=None):
"""Compute gradients of `loss` for the variables in `var_list`.
This adjusts the dynamic range of the gradient evaluation by scaling up
the `loss` value. The gradient values are then scaled back down by the
recipricol of the loss scale. This is useful in reduced precision training
where small gradient values would otherwise underflow the representable
range.
Args:
loss: A Tensor containing the value to minimize or a callable taking no
arguments which returns the value to minimize. When eager execution is
enabled it must be a callable.
var_list: Optional list or tuple of `tf.Variable` to update to minimize
`loss`. Defaults to the list of variables collected in the graph under
the key `GraphKeys.TRAINABLE_VARIABLES`.
gate_gradients: How to gate the computation of gradients. Can be
`GATE_NONE`, `GATE_OP`, or `GATE_GRAPH`.
aggregation_method: Specifies the method used to combine gradient terms.
Valid values are defined in the class `AggregationMethod`.
colocate_gradients_with_ops: If True, try colocating gradients with the
corresponding op.
grad_loss: Optional. A `Tensor` holding the gradient computed for `loss`.
Returns:
A list of (gradient, variable) pairs. Variable is always present, but
gradient can be `None`.
"""
loss = self._scale_loss(loss)
grads_and_vars = self._optimizer.compute_gradients(
loss=loss,
var_list=var_list,
gate_gradients=gate_gradients,
aggregation_method=aggregation_method,
colocate_gradients_with_ops=colocate_gradients_with_ops,
grad_loss=grad_loss)
grads = [g for g, _ in grads_and_vars]
variables = [v for _, v in grads_and_vars]
scaled_grads = self._scale_grads(grads)
return list(zip(scaled_grads, variables))
def _scale_loss(self, loss):
loss_scale = self._loss_scale()
if callable(loss):
return lambda: loss() * loss_scale
return loss * loss_scale
def _scale_grads(self, grads):
loss_scale = self._loss_scale()
loss_scale_reciprical = 1 / loss_scale
return [
None if g is None else self._scale_grad(g, loss_scale_reciprical)
for g in grads
]
def _scale_grad(self, grad, loss_scale_reciprical):
if isinstance(grad, ops.IndexedSlices):
grad_vals = grad.values * loss_scale_reciprical
return ops.IndexedSlices(grad_vals, grad.indices, grad.dense_shape)
return grad * loss_scale_reciprical
def apply_gradients(self, grads_and_vars, global_step=None, name=None):
"""Apply gradients to variables.
This is the second part of `minimize()`. It returns an `Operation` that
conditionally applies gradients if all gradient values are finite.
Otherwise no update is performed (nor is `global_step` incremented).
Args:
grads_and_vars: List of (gradient, variable) pairs as returned by
`compute_gradients()`.
global_step: Optional `Variable` to increment by one after the variables
have been updated.
name: Optional name for the returned operation. Default to the name
passed to the `Optimizer` constructor.
Returns:
An `Operation` that conditionally applies the specified gradients. If
`global_step` was not None, that operation also increments `global_step`.
Raises:
RuntimeError: If you should use `_distributed_apply()` instead.
"""
if distribution_strategy_context.in_cross_replica_context():
raise ValueError('apply_gradients() must be called in a replica context.')
if not self._doing_dynamic_loss_scaling():
return self._optimizer.apply_gradients(grads_and_vars, global_step, name)
replica_context = distribution_strategy_context.get_replica_context()
# TODO(nluehr) cleanup GraphKeys.TRAIN_OP
return replica_context.merge_call(
self._distributed_apply, args=(grads_and_vars, global_step, name))
def _distributed_apply(self,
distribution,
grads_and_vars,
global_step=None,
name=None):
"""A version of `apply_gradients` for cross replica context.
When users are in a cross replica strategy, they must call this rather than
`apply_gradients()`.
Args:
distribution: a `DistributionStrategy` object.
grads_and_vars: List of (gradient, variable) pairs as returned by
`compute_gradients()` and then aggregated across replicas.
global_step: Optional (mirrored) `Variable` to increment by one after the
variables have been updated.
name: Optional name for the returned operation. Default to the name passed
to the `Optimizer` constructor.
Returns:
An `Operation` that applies the specified gradients across all
replicas. If `global_step` was not None, that operation also
increments `global_step`
"""
name = name if name is not None else self.get_name()
grads = [g for g, _ in grads_and_vars]
loss_scale_update_op, should_apply_grads = (self._loss_scale.update(grads))
def apply_fn():
return self._apply_gradients(distribution, grads_and_vars, global_step,
name + '-wrapped')
maybe_apply_op = smart_cond.smart_cond(should_apply_grads, apply_fn,
control_flow_ops.no_op)
return control_flow_ops.group(
maybe_apply_op, loss_scale_update_op, name=name)
def _apply_gradients(self, distribution, grads_and_vars, global_step, name):
"""Unconditionally apply gradients in cross replica context."""
update_ops = distribution.extended.call_for_each_replica(
self._optimizer.apply_gradients,
args=(grads_and_vars, global_step, name))
return distribution.group(update_ops)
def _apply_sparse(self, grad, var):
"""This function should never be called."""
raise RuntimeError('This function should never be called')
def _apply_dense(self, grad, var):
"""This function should never be called."""
raise RuntimeError('This function should never be called')
def _resource_apply_sparse(self, grad, handle, indices):
"""This function should never be called."""
raise RuntimeError('This function should never be called')
def _resource_apply_dense(self, grad, handle):
"""This function should never be called."""
raise RuntimeError('This function should never be called')

View File

@ -0,0 +1,266 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for MixedPrecisionLossScaleOptimizer."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
from absl.testing import parameterized
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.distribute import mirrored_strategy
from tensorflow.python.eager import context
from tensorflow.python.framework import test_util
from tensorflow.python.keras.mixed_precision.experimental import test_util as mp_test_util
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training import gradient_descent
from tensorflow.python.training import momentum
from tensorflow.python.training.experimental import loss_scale as loss_scale_module
from tensorflow.python.training.experimental import loss_scale_optimizer
from tensorflow.python.training.tracking import util as trackable_utils
# If called outside any strategy.scope() calls, this will return the default
# strategy.
default_strategy_fn = distribution_strategy_context.get_strategy
def create_mirrored_strategy():
if context.num_gpus() >= 1:
return mirrored_strategy.MirroredStrategy(['cpu:0', 'gpu:0'])
else:
return mirrored_strategy.MirroredStrategy(['cpu:0'])
TESTCASES = ({
'testcase_name': 'Base',
'strategy_fn': default_strategy_fn
}, {
'testcase_name': 'Distribute',
'strategy_fn': create_mirrored_strategy
})
def get_gradients(opt, loss, params):
grads_and_vars = opt.compute_gradients(loss, params)
grads, _ = zip(*grads_and_vars)
return grads
class MixedPrecisionLossScaleOptimizerTest(test.TestCase,
parameterized.TestCase):
def _run_if_in_graph_mode(self, val):
# Running only in graph mode is useful, because optimizers sometimes return
# a value that, in Graph mode, is runnable with self.evaluate. But in Eager
# mode, the optimizer already does the computations and the return value
# cannot be run.
if not context.executing_eagerly():
self.evaluate(val)
def _run_fn_with_grad_check(self, strategy, var, opt, expected_grad):
grad_check_fn = mp_test_util.create_identity_with_grad_check_fn(
expected_grad)
loss = lambda: grad_check_fn(var) / strategy.num_replicas_in_sync
return lambda: opt.minimize(loss, var_list=[var])
@parameterized.named_parameters(*TESTCASES)
@test_util.run_in_graph_and_eager_modes
def testFixedLossScaleAppliedToLossWithMinimize(self, strategy_fn):
with strategy_fn().scope() as strategy:
var = variables.Variable([5.0])
opt = gradient_descent.GradientDescentOptimizer(2.0)
loss_scale = 10.
opt = loss_scale_optimizer.MixedPrecisionLossScaleOptimizer(
opt, loss_scale)
# We need num_replicas_in_sync to divide loss_scale, otherwise loss_scale
# / strategy.num_replicas_in_sync will not be exact, which could lead to
# assertion failures due to rounding issues.
self.assertEqual(loss_scale % strategy.num_replicas_in_sync, 0)
run_fn = self._run_fn_with_grad_check(
strategy, var, opt, loss_scale / strategy.num_replicas_in_sync)
run_op = strategy.experimental_run(run_fn)
self.evaluate(variables.global_variables_initializer())
self._run_if_in_graph_mode(run_op)
# The loss is the identity of the variable. Therefore the gradient is 1,
# and so the variable will be init_val - grad * lr == 5 - 1 * 2 == 3
self.assertAllClose([3.], self.evaluate(var))
@test_util.deprecated_graph_mode_only
def testFixedLossScaleAppliedToLossWithGetGradients(self):
var = variables.Variable([2.0])
opt = gradient_descent.GradientDescentOptimizer(1.0)
loss_scale = 10.
opt = loss_scale_optimizer.MixedPrecisionLossScaleOptimizer(opt, loss_scale)
grad_check_fn = mp_test_util.create_identity_with_grad_check_fn(loss_scale)
loss = grad_check_fn(var)
run_op = get_gradients(opt, loss, [var])
self.evaluate(variables.global_variables_initializer())
# This will cause an assertion to run, as
# mp_test_util.create_identity_with_grad_check_fn added an assertion op.
self.evaluate(run_op)
@parameterized.named_parameters(*TESTCASES)
@test_util.run_in_graph_and_eager_modes
def testDynamicLossScale(self, strategy_fn):
strategy = strategy_fn()
learning_rate = 2.
expected_gradient = resource_variable_ops.ResourceVariable(
learning_rate / strategy.num_replicas_in_sync)
with strategy.scope():
var = variables.Variable([5.0])
opt = gradient_descent.GradientDescentOptimizer(learning_rate)
loss_scale = loss_scale_module.DynamicLossScale(
initial_loss_scale=2, increment_period=1, multiplier=2)
opt = loss_scale_optimizer.MixedPrecisionLossScaleOptimizer(
opt, loss_scale)
self.assertEqual(
loss_scale.initial_loss_scale % strategy.num_replicas_in_sync, 0)
run_fn = self._run_fn_with_grad_check(strategy, var, opt,
expected_gradient)
run_op = strategy.experimental_run(run_fn)
self.evaluate(variables.global_variables_initializer())
self._run_if_in_graph_mode(run_op)
# The loss is the identity of the variable. Therefore the gradient is 1,
# and so the variable will be init_val - grad * lr == 5 - 1 * 2 == 3
self.assertAllClose([3.], self.evaluate(var))
# Loss scale will be double, so the expected gradient is also doubled.
self.evaluate(
expected_gradient.assign(2 * learning_rate /
strategy.num_replicas_in_sync))
run_op = strategy.experimental_run(run_fn)
self._run_if_in_graph_mode(run_op)
# As before, the 2 is subtracted from the variable, making it's new value
# 1.
self.assertAllClose([1.], self.evaluate(var))
@parameterized.named_parameters(*TESTCASES)
@test_util.run_in_graph_and_eager_modes
def testDynamicUpdate(self, strategy_fn):
with strategy_fn().scope() as strategy:
var = variables.Variable([1.0, 2.0])
opt = gradient_descent.GradientDescentOptimizer(1.0)
loss_scale = loss_scale_module.DynamicLossScale(
initial_loss_scale=2, increment_period=1, multiplier=2)
opt = loss_scale_optimizer.MixedPrecisionLossScaleOptimizer(
opt, loss_scale)
# Test optimizer with finite gradients
loss = lambda: var * 2.0 / strategy.num_replicas_in_sync
run_fn = lambda: opt.minimize(loss, var_list=[var])
run_op = strategy.experimental_run(run_fn)
self.evaluate(variables.global_variables_initializer())
self._run_if_in_graph_mode(run_op)
# Gradient is 2, so variable will have 2 subtracted from it
self.assertAllClose([-1.0, 0.0], self.evaluate(var))
# Loss scale has doubled from 2 to 4
self.assertEqual(4., self.evaluate(opt._loss_scale()))
# Test optimizer with NaN gradients
loss = lambda: var * float('NaN')
run_fn = lambda: opt.minimize(loss, var_list=[var])
run_op = strategy.experimental_run(run_fn)
self._run_if_in_graph_mode(run_op)
# Variable should not change from before, due to NaN gradients.
self.assertAllClose(self.evaluate(var), [-1.0, 0.0])
# Loss scale should half due to NaN gradients.
self.assertEqual(2., self.evaluate(opt._loss_scale()))
@parameterized.named_parameters(*TESTCASES)
@test_util.run_in_graph_and_eager_modes
def testDynamicLossScaleWithSlots(self, strategy_fn):
with strategy_fn().scope() as strategy:
var = variables.Variable([1.0, 2.0])
# An SGD optimizer with momentum has slot variables.
opt = momentum.MomentumOptimizer(1.0, momentum=1.)
initial_loss_scale = 2.
loss_scale = loss_scale_module.DynamicLossScale(
initial_loss_scale=initial_loss_scale,
increment_period=1,
multiplier=4)
opt = loss_scale_optimizer.MixedPrecisionLossScaleOptimizer(
opt, loss_scale)
loss = lambda: var / strategy.num_replicas_in_sync
run_fn = lambda: opt.minimize(loss, var_list=[var])
run_op = strategy.experimental_run(run_fn)
self.evaluate(variables.global_variables_initializer())
self._run_if_in_graph_mode(run_op)
# The momentum accumulator starts at 0 and the gradient is 1. The
# accumulator is incremented by the gradient, so it is now 1. Then the
# variable is subtracted by the accumulator, so the variable is subtracted
# by 1.
self.assertAllClose([0.0, 1.0], self.evaluate(var))
self.assertEqual(self.evaluate(opt._loss_scale()), initial_loss_scale * 4)
run_op = strategy.experimental_run(run_fn)
self._run_if_in_graph_mode(run_op)
# The momentum accumulator was 1 before this step and the gradient is 1.
# The accumulator is incremented by the gradient, so it is now 2. Then the
# variable is subtracted by the accumulator, so the variable is subtracted
# by 2.
self.assertAllClose([-2., -1.], self.evaluate(var))
self.assertEqual(
self.evaluate(opt._loss_scale()), initial_loss_scale * 16)
@parameterized.named_parameters(*TESTCASES)
@test_util.run_in_graph_and_eager_modes
def testCheckpoint(self, strategy_fn):
strategy = strategy_fn()
if (isinstance(strategy, mirrored_strategy.MirroredStrategy) and
not context.executing_eagerly()):
# TODO(b/121381184): Enable running the test in this case.
return
with self.test_session(), strategy.scope():
# Build and run a simple model.
var = variables.Variable([2.0])
loss_scale = loss_scale_module.DynamicLossScale(
initial_loss_scale=1., increment_period=2., multiplier=2.)
opt = momentum.MomentumOptimizer(1.0, momentum=1.)
opt = loss_scale_optimizer.MixedPrecisionLossScaleOptimizer(
opt, loss_scale)
run_fn = lambda: opt.minimize(lambda: var + 1., var_list=[var])
opt_op = strategy.experimental_run(run_fn)
self.evaluate(variables.global_variables_initializer())
self.evaluate(opt_op)
self.assertEqual(self.evaluate(loss_scale()), 1.)
self.assertEqual(self.evaluate(loss_scale._num_good_steps), 1)
# Save a checkpoint.
checkpoint = trackable_utils.Checkpoint(optimizer=opt)
prefix = os.path.join(self.get_temp_dir(), 'ckpt')
save_path = checkpoint.save(prefix)
# Run model again.
self.evaluate(strategy.experimental_run(run_fn))
self.assertEqual(self.evaluate(loss_scale()), 2.)
self.assertEqual(self.evaluate(loss_scale._num_good_steps), 0)
# Load checkpoint and ensure loss scale is back to it's original value.
status = checkpoint.restore(save_path)
status.assert_consumed()
status.run_restore_ops()
self.assertEqual(self.evaluate(loss_scale()), 1.)
self.assertEqual(self.evaluate(loss_scale._num_good_steps), 1)
if __name__ == '__main__':
test.main()

View File

@ -0,0 +1,265 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for LossScale classes.."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import parameterized
import numpy as np
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.distribute import mirrored_strategy
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training.experimental import loss_scale as loss_scale_module
# TODO(reedwm): Create test case using multiple graphs
# If called outside any strategy.scope() calls, this will return the default
# strategy.
default_strategy_fn = distribution_strategy_context.get_strategy
def create_mirrored_strategy():
if context.num_gpus() >= 1:
return mirrored_strategy.MirroredStrategy(['cpu:0', 'gpu:0'])
else:
return mirrored_strategy.MirroredStrategy(['cpu:0'])
TESTCASES = ({
'testcase_name': 'base',
'strategy_fn': default_strategy_fn
}, {
'testcase_name': 'distribute',
'strategy_fn': create_mirrored_strategy
})
class FixedLossScaleTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes
def test_basic(self):
loss_scale_value = 1000
loss_scale = loss_scale_module.FixedLossScale(loss_scale_value)
update_op, should_apply = loss_scale.update([constant_op.constant(0.)])
self.evaluate(update_op)
# should_apply should be a bool instead of a tensor, so that a tf.cond does
# not have to be built in the graph by the caller.
self.assertIsInstance(should_apply, bool)
self.assertTrue(should_apply)
self.assertEqual(loss_scale_value, self.evaluate(loss_scale()))
update_op, should_apply = loss_scale.update(
[constant_op.constant(float('NaN'))])
self.evaluate(update_op)
self.assertIsInstance(should_apply, bool)
self.assertTrue(should_apply)
self.assertEqual(loss_scale_value, self.evaluate(loss_scale()))
def _get_example_iter(inputs):
dataset = dataset_ops.Dataset.from_tensor_slices(inputs)
return dataset_ops.make_one_shot_iterator(dataset)
class DynamicLossScaleTest(test.TestCase, parameterized.TestCase):
def _get_tensor(self, is_finite):
tensor = control_flow_ops.cond(is_finite, lambda: 1., lambda: float('NaN'))
if not distribution_strategy_context.has_strategy():
return tensor
def get():
rep_id = (
distribution_strategy_context.get_replica_context()
.replica_id_in_sync_group)
return control_flow_ops.cond(
math_ops.equal(rep_id, 0), lambda: tensor, lambda: 1.)
distribution = distribution_strategy_context.get_strategy()
return distribution.extended.call_for_each_replica(get)
def _test_helper(self,
inputs,
expected_outputs,
initial_loss_scale=1.,
increment_period=2,
multiplier=2):
loss_scale = loss_scale_module.DynamicLossScale(
initial_loss_scale=initial_loss_scale,
increment_period=increment_period,
multiplier=multiplier)
itr = _get_example_iter(inputs)
def update():
is_finite = itr.get_next()
grad = self._get_tensor(is_finite)
update_op, should_apply_gradients = loss_scale.update([grad])
assert_op = check_ops.assert_equal(should_apply_gradients, is_finite)
if context.executing_eagerly():
return
with ops.control_dependencies([assert_op]):
return array_ops.identity(update_op)
actual_outputs = []
if not context.executing_eagerly():
update_op = update()
self.evaluate(variables.global_variables_initializer())
for _ in range(len(inputs)):
if context.executing_eagerly():
update()
else:
self.evaluate(update_op)
actual_outputs.append(self.evaluate(loss_scale()))
self.assertEqual(actual_outputs, expected_outputs)
@parameterized.named_parameters(*TESTCASES)
@test_util.run_in_graph_and_eager_modes
def test_increase(self, strategy_fn):
with strategy_fn().scope():
inputs = [True] * 6
expected_outputs = [1, 2, 2, 4, 4, 8]
self._test_helper(inputs, expected_outputs)
@parameterized.named_parameters(*TESTCASES)
@test_util.run_in_graph_and_eager_modes
def test_keep_increasing_until_capped(self, strategy_fn):
with strategy_fn().scope():
init_loss_scale = np.finfo(np.float32).max / 4
max_float = np.finfo(np.float32).max
inputs = [True] * 6
# Output is capped the 2nd time it doubles.
expected_outputs = [
init_loss_scale, init_loss_scale * 2, init_loss_scale * 2, max_float,
max_float, max_float
]
self._test_helper(inputs, expected_outputs, init_loss_scale)
@parameterized.named_parameters(*TESTCASES)
@test_util.run_in_graph_and_eager_modes
def test_decrease_every_step(self, strategy_fn):
with strategy_fn().scope():
inputs = [False] * 6
init_loss_scale = 1024
expected_outputs = [512, 256, 128, 64, 32, 16]
self._test_helper(inputs, expected_outputs, init_loss_scale)
@parameterized.named_parameters(*TESTCASES)
@test_util.run_in_graph_and_eager_modes
def test_keep_decreasing_until_one(self, strategy_fn):
with strategy_fn().scope():
inputs = [False] * 6
init_loss_scale = 16
expected_outputs = [8, 4, 2, 1, 1, 1]
self._test_helper(inputs, expected_outputs, init_loss_scale)
@parameterized.named_parameters(*TESTCASES)
@test_util.run_in_graph_and_eager_modes
def test_nan_clear_good_step(self, strategy_fn):
with strategy_fn().scope():
inputs = [True, True, True, False, True]
expected_outputs = [1, 2, 2, 1, 1]
self._test_helper(inputs, expected_outputs)
@parameterized.named_parameters(*TESTCASES)
@test_util.run_in_graph_and_eager_modes
def test_trigger_loss_scale_update_each_step(self, strategy_fn):
with strategy_fn().scope():
init_loss_scale = 1
increment_period = 1
inputs = [True] * 3 + [False, True, True]
expected_outputs = [2, 4, 8, 4, 8, 16]
self._test_helper(inputs, expected_outputs, init_loss_scale,
increment_period)
@parameterized.named_parameters(*TESTCASES)
@test_util.run_in_graph_and_eager_modes
def test_alternating_good_and_bad_gradients_trigger_each_step(
self, strategy_fn):
with strategy_fn().scope():
init_loss_scale = 1
increment_period = 1
inputs = [True, False] * 4 + [True]
expected_outputs = [2, 1, 2, 1, 2, 1, 2, 1, 2]
self._test_helper(inputs, expected_outputs, init_loss_scale,
increment_period)
@parameterized.named_parameters(*TESTCASES)
@test_util.run_in_graph_and_eager_modes
def test_alternating_good_and_bad_gradients_trigger_every_other_step(
self, strategy_fn):
with strategy_fn().scope():
init_loss_scale = 32
increment_period = 2
inputs = [True, False] * 3 + [True]
expected_outputs = [32, 16, 16, 8, 8, 4, 4]
self._test_helper(inputs, expected_outputs, init_loss_scale,
increment_period)
@parameterized.named_parameters(*TESTCASES)
@test_util.run_in_graph_and_eager_modes
def test_nondefault_multiplier(self, strategy_fn):
with strategy_fn().scope():
init_loss_scale = 4
multiplier = 3
inputs = [True, True, False, True, True]
expected_outputs = [4, 12, 4, 4, 12]
self._test_helper(
inputs, expected_outputs, init_loss_scale, multiplier=multiplier)
@parameterized.named_parameters(*TESTCASES)
@test_util.run_in_graph_and_eager_modes
def test_random_mix_good_and_bad_gradients(self, strategy_fn):
with strategy_fn().scope():
init_loss_scale = 4
inputs = [
False, True, True, True, False, True, False, True, True, True, False
]
expected_outputs = [2, 2, 4, 4, 2, 2, 1, 1, 2, 2, 1]
self._test_helper(inputs, expected_outputs, init_loss_scale)
@test_util.run_in_graph_and_eager_modes
def test_get(self):
scalar = loss_scale_module.get('dynamic')
scalar2 = loss_scale_module.DynamicLossScale()
self.assertEqual(scalar.initial_loss_scale, scalar2.initial_loss_scale)
self.assertEqual(scalar.increment_period, scalar2.increment_period)
self.assertEqual(scalar.multiplier, scalar2.multiplier)
if __name__ == '__main__':
test.main()

View File

@ -33,6 +33,7 @@ from tensorflow.python.training.adagrad_da import AdagradDAOptimizer
from tensorflow.python.training.proximal_adagrad import ProximalAdagradOptimizer
from tensorflow.python.training.adam import AdamOptimizer
from tensorflow.python.training.ftrl import FtrlOptimizer
from tensorflow.python.training.experimental.loss_scale_optimizer import MixedPrecisionLossScaleOptimizer
from tensorflow.python.training.momentum import MomentumOptimizer
from tensorflow.python.training.moving_averages import ExponentialMovingAverage
from tensorflow.python.training.optimizer import Optimizer
@ -143,4 +144,3 @@ tf_export(v1=["train.SaverDef"])(SaverDef)
tf_export("train.SequenceExample")(SequenceExample)
tf_export("train.ServerDef")(ServerDef)
# pylint: enable=undefined-variable

View File

@ -0,0 +1,51 @@
path: "tensorflow.train.experimental.MixedPrecisionLossScaleOptimizer"
tf_class {
is_instance: "<class \'tensorflow.python.training.experimental.loss_scale_optimizer.MixedPrecisionLossScaleOptimizer\'>"
is_instance: "<class \'tensorflow.python.training.optimizer.Optimizer\'>"
is_instance: "<class \'tensorflow.python.training.tracking.base.Trackable\'>"
is_instance: "<type \'object\'>"
member {
name: "GATE_GRAPH"
mtype: "<type \'int\'>"
}
member {
name: "GATE_NONE"
mtype: "<type \'int\'>"
}
member {
name: "GATE_OP"
mtype: "<type \'int\'>"
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'opt\', \'loss_scale\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "apply_gradients"
argspec: "args=[\'self\', \'grads_and_vars\', \'global_step\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
}
member_method {
name: "compute_gradients"
argspec: "args=[\'self\', \'loss\', \'var_list\', \'gate_gradients\', \'aggregation_method\', \'colocate_gradients_with_ops\', \'grad_loss\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'None\', \'False\', \'None\'], "
}
member_method {
name: "get_name"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_slot"
argspec: "args=[\'self\', \'var\', \'name\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "get_slot_names"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "minimize"
argspec: "args=[\'self\', \'loss\', \'global_step\', \'var_list\', \'gate_gradients\', \'aggregation_method\', \'colocate_gradients_with_ops\', \'name\', \'grad_loss\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'1\', \'None\', \'False\', \'None\', \'None\'], "
}
member_method {
name: "variables"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
}

View File

@ -1,5 +1,9 @@
path: "tensorflow.train.experimental"
tf_module {
member {
name: "MixedPrecisionLossScaleOptimizer"
mtype: "<type \'type\'>"
}
member {
name: "PythonState"
mtype: "<type \'type\'>"

View File

@ -1387,6 +1387,8 @@ renames = {
'tf.compat.v1.train.create_global_step',
'tf.train.do_quantize_training_on_graphdef':
'tf.compat.v1.train.do_quantize_training_on_graphdef',
'tf.train.experimental.MixedPrecisionLossScaleOptimizer':
'tf.compat.v1.train.experimental.MixedPrecisionLossScaleOptimizer',
'tf.train.exponential_decay':
'tf.compat.v1.train.exponential_decay',
'tf.train.export_meta_graph':

View File

@ -68,6 +68,8 @@ COMMON_PIP_DEPS = [
"//tensorflow/python/compiler:compiler",
"//tensorflow/python:cond_v2",
"//tensorflow/python:distributed_framework_test_lib",
"//tensorflow/python:loss_scale",
"//tensorflow/python:loss_scale_optimizer",
"//tensorflow/python:meta_graph_testdata",
"//tensorflow/python:spectral_ops_test_util",
"//tensorflow/python:util_example_parser_configuration",