diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 5ef53050a5d..ea081c3ea4e 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -2858,6 +2858,53 @@ py_library( ], ) +py_library( + name = "loss_scale", + srcs = ["training/experimental/loss_scale.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:framework", + "@absl_py//absl/testing:parameterized", + ], +) + +py_library( + name = "loss_scale_optimizer", + srcs = ["training/experimental/loss_scale_optimizer.py"], + srcs_version = "PY2AND3", + deps = [ + ":loss_scale", + "@absl_py//absl/testing:parameterized", + ], +) + +py_test( + name = "loss_scale_optimizer_test", + size = "small", + srcs = ["training/experimental/loss_scale_optimizer_test.py"], + deps = [ + ":loss_scale_optimizer", + "//tensorflow/python:client_testlib", + "//tensorflow/python/distribute:mirrored_strategy", + "//tensorflow/python/distribute:one_device_strategy", + "//tensorflow/python/keras/mixed_precision/experimental:test_util", + "@absl_py//absl/testing:parameterized", + ], +) + +py_test( + name = "loss_scale_test", + size = "small", + srcs = ["training/experimental/loss_scale_test.py"], + deps = [ + ":loss_scale", + "//tensorflow/python:client_testlib", + "//tensorflow/python/distribute:mirrored_strategy", + "//tensorflow/python/distribute:one_device_strategy", + "@absl_py//absl/testing:parameterized", + ], +) + py_library( name = "math_grad", srcs = ["ops/math_grad.py"], @@ -3962,6 +4009,8 @@ py_library( ":io_ops", ":layers_util", ":lookup_ops", + ":loss_scale", + ":loss_scale_optimizer", ":math_ops", ":platform", ":pywrap_tensorflow", diff --git a/tensorflow/python/training/experimental/loss_scale.py b/tensorflow/python/training/experimental/loss_scale.py new file mode 100644 index 00000000000..c833241e26a --- /dev/null +++ b/tensorflow/python/training/experimental/loss_scale.py @@ -0,0 +1,352 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Contains LossScale classes.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import abc + +import six + +from tensorflow.python.distribute import distribution_strategy_context +from tensorflow.python.distribute import reduce_util +from tensorflow.python.eager import context +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import variables +from tensorflow.python.training.tracking import base as trackable +from tensorflow.python.ops import variable_scope + + +# TODO(reedwm): Merge this with tf.keras.mixed_precision.experimental.LossScale +@six.add_metaclass(abc.ABCMeta) +class LossScale(trackable.Trackable): + """Loss scale base class. + + Instances of this class represent a loss scale. Calling instances of this + class returns the loss scale as a scalar float32 tensor, while method + `update()` updates the loss scale depending on the values of the gradients. + Optimizers use instances of this class to scale loss and gradients. + + Note: this LossScale class can only be used with a v1 optimizer wrapper, + tf.train.experimental.MixedPrecisionLossScaleOptimizer. 
For a v2 + wrapper, tf.keras.mixed_precision.experimental.LossScaleOptimizer, a + tf.keras.mixed_precision.experimental.LossScale should be used instead. + """ + + def __init__(self): + """Initializes the loss scale class.""" + self._weights = {} + + @abc.abstractmethod + def __call__(self): + """Returns the current loss scale as a scalar `float32` tensor.""" + pass + + @abc.abstractmethod + def update(self, grads): + """Updates the value of the loss scale. + + The loss scale will be potentially updated, based on the value of `grads`. + The tensor returned by calling this class is only updated when this function + is evaluated. + + In eager mode, this directly updates the loss scale, so that calling + `__call__` will return the newly updated loss scale. In graph mode, + this returns an op that, when evaluated, updates the loss scale. + + This function also returns a `should_apply_gradients` bool. If False, + gradients should not be applied to the variables that step, as nonfinite + gradients were found, and the loss scale has been updated to reduce the + chance of finding nonfinite gradients in the next step. Some loss scale + classes will always return True, as they cannot adjust themselves in + response to nonfinite gradients. + + When a DistributionStrategy is used, this function may only be called in a + cross-replica context. + + Args: + grads: A list of unscaled gradients, each of which is the gradient of the + loss with respect to a weight. The gradients should have already been + divided by the loss scale before being passed to this function. + + Returns: + update_op: In eager mode, None. In graph mode, an op to update the loss + scale. + should_apply_gradients: Either a bool or a scalar boolean tensor. If + False, the caller should skip applying `grads` to the variables this + step. + """ + pass + + def _add_weight(self, name, initial_value, dtype=None): + """Adds a weight to this loss scale manager. + + Args: + name: Variable name. 
+ initial_value: The variable's initial value. + dtype: The type of the variable. + + Returns: + A variable. + + Raises: + RuntimeError: If a weight with `name` has already been added. + """ + variable = variable_scope.variable( + initial_value=initial_value, + name=name, + dtype=dtype, + trainable=False, + use_resource=True, + synchronization=variables.VariableSynchronization.AUTO, + # Set aggregation to NONE, as loss scaling variables should never be + # aggregated. + aggregation=variables.VariableAggregation.NONE) + if context.executing_eagerly(): + graph_key = None + else: + graph = ops.get_default_graph() + graph_key = graph._graph_key # pylint: disable=protected-access + + key = (name, graph_key) + if self._weights.get(key, None) is not None: + raise RuntimeError('Duplicate variables detected. {}'.format(key)) + self._weights[key] = variable + self._handle_deferred_dependencies(name=name, trackable=variable) + return variable + + @property + def _checkpoint_dependencies(self): + """From Trackable. Gather graph-specific weights to save.""" + if context.executing_eagerly(): + graph_key = None + else: + graph = ops.get_default_graph() + graph_key = graph._graph_key # pylint: disable=protected-access + weights = [] + for (name, g), v in sorted(self._weights.items(), key=lambda i: i[0][0]): + if g == graph_key: + weights.append(trackable.TrackableReference(name=name, ref=v)) + return super(LossScale, self)._checkpoint_dependencies + weights + + def _lookup_dependency(self, name): + """From Trackable. Find a weight in the current graph.""" + unconditional = super(LossScale, self)._lookup_dependency(name) + if unconditional is not None: + return unconditional + if context.executing_eagerly(): + graph_key = None + else: + graph = ops.get_default_graph() + graph_key = graph._graph_key # pylint: disable=protected-access + return self._weights.get((name, graph_key), None) + + +class FixedLossScale(LossScale): + """Loss scale class with a fixed value. 
+ + The loss scale is not updated for the lifetime of the class. + """ + + def __init__(self, loss_scale_value): + """Creates the fixed loss scale. + + Args: + loss_scale_value: A Python float. Its ideal value varies depending on + models to run. Choosing a too small loss_scale might affect model + quality; a too big loss_scale might cause inf or nan. There is no single + right loss_scale to apply. There is no harm choosing a relatively big + number as long as no nan or inf is encountered in training. + + Raises: + ValueError: If loss_scale is less than 1. + """ + super(FixedLossScale, self).__init__() + if not isinstance(loss_scale_value, six.integer_types + (float,)): + raise ValueError('loss_scale must be a Python int or float.') + if loss_scale_value < 1: + raise ValueError('loss scale must be at least 1.') + self._tensor_loss_scale = ops.convert_to_tensor( + loss_scale_value, dtype=dtypes.float32) + + def __call__(self): + return self._tensor_loss_scale + + def update(self, grads): + del grads + return control_flow_ops.no_op(), True + + +def _is_all_finite(grads): + """Returns a scalar boolean tensor indicating if all gradients are finite.""" + is_finite_per_grad = [ + math_ops.reduce_all(math_ops.is_finite(g)) for g in grads + ] + return math_ops.reduce_all(is_finite_per_grad) + + +def _op_in_graph_mode(tensor): + """Returns the tensor's op in graph mode, or the tensor in eager mode. + + This is useful because sometimes an op is needed in graph mode instead of a + tensor. In eager mode, there are no ops. + + Args: + tensor: A tensor. + + Returns: + The tensor's op in graph mode. The tensor in eager mode. 
+ """ + if context.executing_eagerly(): + return tensor + return tensor.op + + +def _assign_if_finite(var, value): + """Assigns a value to a variable if the value is finite.""" + return control_flow_ops.cond( + math_ops.is_finite(value), lambda: _op_in_graph_mode(var.assign(value)), + control_flow_ops.no_op) + + +class DynamicLossScale(LossScale): + """Loss scale class that dynamically adjusts the loss scale. + + Dynamic loss scaling works by adjusting the loss scale as training progresses. + The goal is to keep the loss scale as high as possible without overflowing the + gradients. As long as the gradients do not overflow, raising the loss scale + never hurts. + + The algorithm starts by setting the loss scale to an initial value. Every N + steps that the gradients are finite, the loss scale is increased by some + factor. However, if a NaN or Inf gradient is found, the gradients for that + step are not applied, and the loss scale is decreased by the factor. This + process tends to keep the loss scale as high as possible without gradients + overflowing. + """ + + def __init__(self, + initial_loss_scale=2**15, + increment_period=2000, + multiplier=2.): + """Constructor of exponential-update loss scale class. + + Args: + initial_loss_scale: A Python float. The loss scale to use at the + beginning. It's better to start this at a very high number, because a + loss scale that is too high gets lowered far more quickly than a loss + scale that is too low gets raised. The default is 2 ** 15, which is + approximately half the maximum float16 value. + increment_period: Increases loss scale every `increment_period` + consecutive steps that finite gradients are encountered. If a nonfinite + gradient is encountered, the count is reset back to zero. + multiplier: The multiplier to use when increasing or decreasing the loss + scale. 
+ """ + super(DynamicLossScale, self).__init__() + self._initial_loss_scale = float(initial_loss_scale) + self._increment_period = int(increment_period) + self._multiplier = float(multiplier) + + self._current_loss_scale = self._add_weight( + name='loss_scale', + dtype=dtypes.float32, + initial_value=self._initial_loss_scale) + # The number of consecutive steps with finite gradients since the last + # nonfinite gradient or change in loss scale. + self._num_good_steps = self._add_weight( + name='good_steps', dtype=dtypes.int64, initial_value=0) + + @property + def initial_loss_scale(self): + return self._initial_loss_scale + + @property + def increment_period(self): + return self._increment_period + + @property + def multiplier(self): + return self._multiplier + + def __call__(self): + return self._current_loss_scale + + def update(self, grads): + """Updates loss scale based on if gradients are finite in current step.""" + if distribution_strategy_context.has_strategy(): + distribution = distribution_strategy_context.get_cross_replica_context() + + def get_is_finite(grads): + is_finite = _is_all_finite(grads) + # We cast to float, because we cannot reduce booleans with + # DistributionStrategy. 
+ return math_ops.cast(is_finite, dtypes.float32) + + is_finite_float = distribution.extended.call_for_each_replica( + get_is_finite, args=(grads,)) + reduced_is_finite_float = distribution.reduce(reduce_util.ReduceOp.SUM, + is_finite_float, axis=None) + is_finite = math_ops.equal(reduced_is_finite_float, + distribution.num_replicas_in_sync) + else: + is_finite = _is_all_finite(grads) + + def update_if_finite_grads(): + """Update assuming the gradients are finite.""" + + def incr_loss_scale(): + new_loss_scale = self._current_loss_scale * self._multiplier + return control_flow_ops.group( + _assign_if_finite(self._current_loss_scale, new_loss_scale), + self._num_good_steps.assign(0)) + + return control_flow_ops.cond( + self._num_good_steps + 1 >= self._increment_period, + incr_loss_scale, lambda: _op_in_graph_mode( + self._num_good_steps.assign_add(1))) + + def update_if_not_finite_grads(): + """Update assuming the gradients are nonfinite.""" + + new_loss_scale = math_ops.maximum( + self._current_loss_scale / self._multiplier, 1) + return control_flow_ops.group( + self._num_good_steps.assign(0), + self._current_loss_scale.assign(new_loss_scale)) + + update_op = control_flow_ops.cond(is_finite, update_if_finite_grads, + update_if_not_finite_grads) + should_apply_gradients = is_finite + return update_op, should_apply_gradients + + +def get(identifier): + """Get a loss scale object.""" + if isinstance(identifier, six.integer_types + (float,)): + return FixedLossScale(identifier) + if identifier == 'dynamic': + return DynamicLossScale() + if isinstance(identifier, LossScale): + return identifier + elif identifier is None: + return None + else: + raise ValueError('Could not interpret loss scale identifier: %s' % + identifier) diff --git a/tensorflow/python/training/experimental/loss_scale_optimizer.py b/tensorflow/python/training/experimental/loss_scale_optimizer.py new file mode 100644 index 00000000000..b0d101fd6d5 --- /dev/null +++ 
b/tensorflow/python/training/experimental/loss_scale_optimizer.py @@ -0,0 +1,237 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Contains the MixedPrecisionLossScaleOptimizer class.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.distribute import distribution_strategy_context +from tensorflow.python.framework import ops +from tensorflow.python.framework import smart_cond +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.training import optimizer +from tensorflow.python.training.experimental import loss_scale as loss_scale_module +from tensorflow.python.util.tf_export import tf_export + + +@tf_export(v1=['train.experimental.MixedPrecisionLossScaleOptimizer']) +class MixedPrecisionLossScaleOptimizer(optimizer.Optimizer): + """An optimizer that applies loss scaling. + + Loss scaling is a process that multiplies the loss by a multiplier called the + loss scale, and divides each gradient by the same multiplier. The pseudocode + for this process is: + + ``` + loss = ... + loss *= loss_scale + grads = gradients(loss, vars) + grads /= loss_scale + ``` + + Mathematically, loss scaling has no effect, but can help avoid numerical + underflow in intermediate gradients when float16 tensors are used for mixed + precision training. 
By multiplying the loss, each intermediate gradient will + have the same multiplier applied. + + The loss scale can either be a fixed constant, chosen by the user, or be + dynamically determined. Dynamically determining the loss scale is convenient + as a loss scale does not have to be explicitly chosen. However, it reduces + performance. + + This optimizer wraps another optimizer and applies loss scaling to it via a + `LossScale`. Loss scaling is applied whenever gradients are + computed, such as through `minimize()`. + """ + + def __init__(self, opt, loss_scale): + if not isinstance(opt, optimizer.Optimizer): + raise ValueError('"opt" must be an instance of Optimizer, but got: %s' % + type(opt)) + self._optimizer = opt + + use_locking = opt._use_locking # pylint: disable=protected-access + name = opt.get_name() + super(MixedPrecisionLossScaleOptimizer, self).__init__(use_locking, name) + + self._loss_scale = loss_scale_module.get(loss_scale) + self._track_trackable(self._optimizer, 'base_optimizer') + self._track_trackable(self._loss_scale, 'loss_scale') + + def _doing_dynamic_loss_scaling(self): + """Check if `_loss_scale` dynamically manages the loss scale.""" + return isinstance(self._loss_scale, loss_scale_module.DynamicLossScale) + + def compute_gradients(self, + loss, + var_list=None, + gate_gradients=optimizer.Optimizer.GATE_OP, + aggregation_method=None, + colocate_gradients_with_ops=False, + grad_loss=None): + """Compute gradients of `loss` for the variables in `var_list`. + + This adjusts the dynamic range of the gradient evaluation by scaling up + the `loss` value. The gradient values are then scaled back down by the + reciprocal of the loss scale. This is useful in reduced precision training + where small gradient values would otherwise underflow the representable + range. + + Args: + loss: A Tensor containing the value to minimize or a callable taking no + arguments which returns the value to minimize. 
When eager execution is + enabled it must be a callable. + var_list: Optional list or tuple of `tf.Variable` to update to minimize + `loss`. Defaults to the list of variables collected in the graph under + the key `GraphKeys.TRAINABLE_VARIABLES`. + gate_gradients: How to gate the computation of gradients. Can be + `GATE_NONE`, `GATE_OP`, or `GATE_GRAPH`. + aggregation_method: Specifies the method used to combine gradient terms. + Valid values are defined in the class `AggregationMethod`. + colocate_gradients_with_ops: If True, try colocating gradients with the + corresponding op. + grad_loss: Optional. A `Tensor` holding the gradient computed for `loss`. + + Returns: + A list of (gradient, variable) pairs. Variable is always present, but + gradient can be `None`. + """ + loss = self._scale_loss(loss) + grads_and_vars = self._optimizer.compute_gradients( + loss=loss, + var_list=var_list, + gate_gradients=gate_gradients, + aggregation_method=aggregation_method, + colocate_gradients_with_ops=colocate_gradients_with_ops, + grad_loss=grad_loss) + + grads = [g for g, _ in grads_and_vars] + variables = [v for _, v in grads_and_vars] + scaled_grads = self._scale_grads(grads) + return list(zip(scaled_grads, variables)) + + def _scale_loss(self, loss): + loss_scale = self._loss_scale() + if callable(loss): + return lambda: loss() * loss_scale + return loss * loss_scale + + def _scale_grads(self, grads): + loss_scale = self._loss_scale() + loss_scale_reciprical = 1 / loss_scale + return [ + None if g is None else self._scale_grad(g, loss_scale_reciprical) + for g in grads + ] + + def _scale_grad(self, grad, loss_scale_reciprical): + if isinstance(grad, ops.IndexedSlices): + grad_vals = grad.values * loss_scale_reciprical + return ops.IndexedSlices(grad_vals, grad.indices, grad.dense_shape) + return grad * loss_scale_reciprical + + def apply_gradients(self, grads_and_vars, global_step=None, name=None): + """Apply gradients to variables. 
+ + This is the second part of `minimize()`. It returns an `Operation` that + conditionally applies gradients if all gradient values are finite. + Otherwise no update is performed (nor is `global_step` incremented). + + Args: + grads_and_vars: List of (gradient, variable) pairs as returned by + `compute_gradients()`. + global_step: Optional `Variable` to increment by one after the variables + have been updated. + name: Optional name for the returned operation. Defaults to the name + passed to the `Optimizer` constructor. + + Returns: + An `Operation` that conditionally applies the specified gradients. If + `global_step` was not None, that operation also increments `global_step`. + + Raises: + ValueError: If called in a cross-replica context. + """ + if distribution_strategy_context.in_cross_replica_context(): + raise ValueError('apply_gradients() must be called in a replica context.') + + if not self._doing_dynamic_loss_scaling(): + return self._optimizer.apply_gradients(grads_and_vars, global_step, name) + + replica_context = distribution_strategy_context.get_replica_context() + + # TODO(nluehr) cleanup GraphKeys.TRAIN_OP + return replica_context.merge_call( + self._distributed_apply, args=(grads_and_vars, global_step, name)) + + def _distributed_apply(self, + distribution, + grads_and_vars, + global_step=None, + name=None): + """A version of `apply_gradients` for cross replica context. + + When users are in a cross replica strategy, they must call this rather than + `apply_gradients()`. + + Args: + distribution: a `DistributionStrategy` object. + grads_and_vars: List of (gradient, variable) pairs as returned by + `compute_gradients()` and then aggregated across replicas. + global_step: Optional (mirrored) `Variable` to increment by one after the + variables have been updated. + name: Optional name for the returned operation. Defaults to the name passed + to the `Optimizer` constructor. 
+ + Returns: + An `Operation` that applies the specified gradients across all + replicas. If `global_step` was not None, that operation also + increments `global_step` + """ + name = name if name is not None else self.get_name() + grads = [g for g, _ in grads_and_vars] + loss_scale_update_op, should_apply_grads = (self._loss_scale.update(grads)) + + def apply_fn(): + return self._apply_gradients(distribution, grads_and_vars, global_step, + name + '-wrapped') + + maybe_apply_op = smart_cond.smart_cond(should_apply_grads, apply_fn, + control_flow_ops.no_op) + return control_flow_ops.group( + maybe_apply_op, loss_scale_update_op, name=name) + + def _apply_gradients(self, distribution, grads_and_vars, global_step, name): + """Unconditionally apply gradients in cross replica context.""" + update_ops = distribution.extended.call_for_each_replica( + self._optimizer.apply_gradients, + args=(grads_and_vars, global_step, name)) + return distribution.group(update_ops) + + def _apply_sparse(self, grad, var): + """This function should never be called.""" + raise RuntimeError('This function should never be called') + + def _apply_dense(self, grad, var): + """This function should never be called.""" + raise RuntimeError('This function should never be called') + + def _resource_apply_sparse(self, grad, handle, indices): + """This function should never be called.""" + raise RuntimeError('This function should never be called') + + def _resource_apply_dense(self, grad, handle): + """This function should never be called.""" + raise RuntimeError('This function should never be called') diff --git a/tensorflow/python/training/experimental/loss_scale_optimizer_test.py b/tensorflow/python/training/experimental/loss_scale_optimizer_test.py new file mode 100644 index 00000000000..c2259cd7ed2 --- /dev/null +++ b/tensorflow/python/training/experimental/loss_scale_optimizer_test.py @@ -0,0 +1,266 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for MixedPrecisionLossScaleOptimizer.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +from absl.testing import parameterized + +from tensorflow.python.distribute import distribution_strategy_context +from tensorflow.python.distribute import mirrored_strategy +from tensorflow.python.eager import context +from tensorflow.python.framework import test_util +from tensorflow.python.keras.mixed_precision.experimental import test_util as mp_test_util +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import test +from tensorflow.python.training import gradient_descent +from tensorflow.python.training import momentum +from tensorflow.python.training.experimental import loss_scale as loss_scale_module +from tensorflow.python.training.experimental import loss_scale_optimizer +from tensorflow.python.training.tracking import util as trackable_utils + +# If called outside any strategy.scope() calls, this will return the default +# strategy. 
+default_strategy_fn = distribution_strategy_context.get_strategy + + +def create_mirrored_strategy(): + if context.num_gpus() >= 1: + return mirrored_strategy.MirroredStrategy(['cpu:0', 'gpu:0']) + else: + return mirrored_strategy.MirroredStrategy(['cpu:0']) + + +TESTCASES = ({ + 'testcase_name': 'Base', + 'strategy_fn': default_strategy_fn +}, { + 'testcase_name': 'Distribute', + 'strategy_fn': create_mirrored_strategy +}) + + +def get_gradients(opt, loss, params): + grads_and_vars = opt.compute_gradients(loss, params) + grads, _ = zip(*grads_and_vars) + return grads + + +class MixedPrecisionLossScaleOptimizerTest(test.TestCase, + parameterized.TestCase): + + def _run_if_in_graph_mode(self, val): + # Running only in graph mode is useful, because optimizers sometimes return + # a value that, in Graph mode, is runnable with self.evaluate. But in Eager + # mode, the optimizer already does the computations and the return value + # cannot be run. + if not context.executing_eagerly(): + self.evaluate(val) + + def _run_fn_with_grad_check(self, strategy, var, opt, expected_grad): + grad_check_fn = mp_test_util.create_identity_with_grad_check_fn( + expected_grad) + loss = lambda: grad_check_fn(var) / strategy.num_replicas_in_sync + return lambda: opt.minimize(loss, var_list=[var]) + + @parameterized.named_parameters(*TESTCASES) + @test_util.run_in_graph_and_eager_modes + def testFixedLossScaleAppliedToLossWithMinimize(self, strategy_fn): + with strategy_fn().scope() as strategy: + var = variables.Variable([5.0]) + opt = gradient_descent.GradientDescentOptimizer(2.0) + loss_scale = 10. + opt = loss_scale_optimizer.MixedPrecisionLossScaleOptimizer( + opt, loss_scale) + # We need num_replicas_in_sync to divide loss_scale, otherwise loss_scale + # / strategy.num_replicas_in_sync will not be exact, which could lead to + # assertion failures due to rounding issues. 
+ self.assertEqual(loss_scale % strategy.num_replicas_in_sync, 0) + run_fn = self._run_fn_with_grad_check( + strategy, var, opt, loss_scale / strategy.num_replicas_in_sync) + run_op = strategy.experimental_run(run_fn) + self.evaluate(variables.global_variables_initializer()) + self._run_if_in_graph_mode(run_op) + # The loss is the identity of the variable. Therefore the gradient is 1, + # and so the variable will be init_val - grad * lr == 5 - 1 * 2 == 3 + self.assertAllClose([3.], self.evaluate(var)) + + @test_util.deprecated_graph_mode_only + def testFixedLossScaleAppliedToLossWithGetGradients(self): + var = variables.Variable([2.0]) + opt = gradient_descent.GradientDescentOptimizer(1.0) + loss_scale = 10. + opt = loss_scale_optimizer.MixedPrecisionLossScaleOptimizer(opt, loss_scale) + grad_check_fn = mp_test_util.create_identity_with_grad_check_fn(loss_scale) + loss = grad_check_fn(var) + run_op = get_gradients(opt, loss, [var]) + self.evaluate(variables.global_variables_initializer()) + # This will cause an assertion to run, as + # mp_test_util.create_identity_with_grad_check_fn added an assertion op. + self.evaluate(run_op) + + @parameterized.named_parameters(*TESTCASES) + @test_util.run_in_graph_and_eager_modes + def testDynamicLossScale(self, strategy_fn): + strategy = strategy_fn() + learning_rate = 2. 
+ expected_gradient = resource_variable_ops.ResourceVariable( + learning_rate / strategy.num_replicas_in_sync) + with strategy.scope(): + var = variables.Variable([5.0]) + opt = gradient_descent.GradientDescentOptimizer(learning_rate) + loss_scale = loss_scale_module.DynamicLossScale( + initial_loss_scale=2, increment_period=1, multiplier=2) + opt = loss_scale_optimizer.MixedPrecisionLossScaleOptimizer( + opt, loss_scale) + self.assertEqual( + loss_scale.initial_loss_scale % strategy.num_replicas_in_sync, 0) + + run_fn = self._run_fn_with_grad_check(strategy, var, opt, + expected_gradient) + run_op = strategy.experimental_run(run_fn) + self.evaluate(variables.global_variables_initializer()) + self._run_if_in_graph_mode(run_op) + # The loss is the identity of the variable. Therefore the gradient is 1, + # and so the variable will be init_val - grad * lr == 5 - 1 * 2 == 3 + self.assertAllClose([3.], self.evaluate(var)) + + # Loss scale will be double, so the expected gradient is also doubled. + self.evaluate( + expected_gradient.assign(2 * learning_rate / + strategy.num_replicas_in_sync)) + run_op = strategy.experimental_run(run_fn) + self._run_if_in_graph_mode(run_op) + # As before, the 2 is subtracted from the variable, making it's new value + # 1. 
+      self.assertAllClose([1.], self.evaluate(var))
+
+  @parameterized.named_parameters(*TESTCASES)
+  @test_util.run_in_graph_and_eager_modes
+  def testDynamicUpdate(self, strategy_fn):
+    with strategy_fn().scope() as strategy:
+      var = variables.Variable([1.0, 2.0])
+      opt = gradient_descent.GradientDescentOptimizer(1.0)
+      loss_scale = loss_scale_module.DynamicLossScale(
+          initial_loss_scale=2, increment_period=1, multiplier=2)
+      opt = loss_scale_optimizer.MixedPrecisionLossScaleOptimizer(
+          opt, loss_scale)
+
+      # Test optimizer with finite gradients
+      loss = lambda: var * 2.0 / strategy.num_replicas_in_sync
+      run_fn = lambda: opt.minimize(loss, var_list=[var])
+      run_op = strategy.experimental_run(run_fn)
+      self.evaluate(variables.global_variables_initializer())
+      self._run_if_in_graph_mode(run_op)
+      # Gradient is 2, so variable will have 2 subtracted from it
+      self.assertAllClose([-1.0, 0.0], self.evaluate(var))
+      # Loss scale has doubled from 2 to 4
+      self.assertEqual(4., self.evaluate(opt._loss_scale()))
+
+      # Test optimizer with NaN gradients
+      loss = lambda: var * float('NaN')
+      run_fn = lambda: opt.minimize(loss, var_list=[var])
+      run_op = strategy.experimental_run(run_fn)
+      self._run_if_in_graph_mode(run_op)
+      # Variable should not change from before, due to NaN gradients.
+      self.assertAllClose(self.evaluate(var), [-1.0, 0.0])
+      # Loss scale should halve due to NaN gradients.
+      self.assertEqual(2., self.evaluate(opt._loss_scale()))
+
+  @parameterized.named_parameters(*TESTCASES)
+  @test_util.run_in_graph_and_eager_modes
+  def testDynamicLossScaleWithSlots(self, strategy_fn):
+    with strategy_fn().scope() as strategy:
+      var = variables.Variable([1.0, 2.0])
+      # An SGD optimizer with momentum has slot variables.
+      opt = momentum.MomentumOptimizer(1.0, momentum=1.)
+      initial_loss_scale = 2.
+      loss_scale = loss_scale_module.DynamicLossScale(
+          initial_loss_scale=initial_loss_scale,
+          increment_period=1,
+          multiplier=4)
+      opt = loss_scale_optimizer.MixedPrecisionLossScaleOptimizer(
+          opt, loss_scale)
+      loss = lambda: var / strategy.num_replicas_in_sync
+      run_fn = lambda: opt.minimize(loss, var_list=[var])
+      run_op = strategy.experimental_run(run_fn)
+      self.evaluate(variables.global_variables_initializer())
+      self._run_if_in_graph_mode(run_op)
+      # The momentum accumulator starts at 0 and the gradient is 1. The
+      # accumulator is incremented by the gradient, so it is now 1. Then the
+      # accumulator is subtracted from the variable, so the variable decreases
+      # by 1.
+      self.assertAllClose([0.0, 1.0], self.evaluate(var))
+      self.assertEqual(self.evaluate(opt._loss_scale()), initial_loss_scale * 4)
+
+      run_op = strategy.experimental_run(run_fn)
+      self._run_if_in_graph_mode(run_op)
+      # The momentum accumulator was 1 before this step and the gradient is 1.
+      # The accumulator is incremented by the gradient, so it is now 2. Then the
+      # accumulator is subtracted from the variable, so the variable decreases
+      # by 2.
+      self.assertAllClose([-2., -1.], self.evaluate(var))
+      self.assertEqual(
+          self.evaluate(opt._loss_scale()), initial_loss_scale * 16)
+
+  @parameterized.named_parameters(*TESTCASES)
+  @test_util.run_in_graph_and_eager_modes
+  def testCheckpoint(self, strategy_fn):
+    strategy = strategy_fn()
+    if (isinstance(strategy, mirrored_strategy.MirroredStrategy) and
+        not context.executing_eagerly()):
+      # TODO(b/121381184): Enable running the test in this case.
+      return
+
+    with self.test_session(), strategy.scope():
+      # Build and run a simple model.
+      var = variables.Variable([2.0])
+      loss_scale = loss_scale_module.DynamicLossScale(
+          initial_loss_scale=1., increment_period=2., multiplier=2.)
+      opt = momentum.MomentumOptimizer(1.0, momentum=1.)
+      opt = loss_scale_optimizer.MixedPrecisionLossScaleOptimizer(
+          opt, loss_scale)
+      run_fn = lambda: opt.minimize(lambda: var + 1., var_list=[var])
+      opt_op = strategy.experimental_run(run_fn)
+      self.evaluate(variables.global_variables_initializer())
+      self.evaluate(opt_op)
+      self.assertEqual(self.evaluate(loss_scale()), 1.)
+      self.assertEqual(self.evaluate(loss_scale._num_good_steps), 1)
+
+      # Save a checkpoint.
+      checkpoint = trackable_utils.Checkpoint(optimizer=opt)
+      prefix = os.path.join(self.get_temp_dir(), 'ckpt')
+      save_path = checkpoint.save(prefix)
+
+      # Run model again.
+      self.evaluate(strategy.experimental_run(run_fn))
+      self.assertEqual(self.evaluate(loss_scale()), 2.)
+      self.assertEqual(self.evaluate(loss_scale._num_good_steps), 0)
+
+      # Load checkpoint and ensure loss scale is back to its original value.
+      status = checkpoint.restore(save_path)
+      status.assert_consumed()
+      status.run_restore_ops()
+      self.assertEqual(self.evaluate(loss_scale()), 1.)
+      self.assertEqual(self.evaluate(loss_scale._num_good_steps), 1)
+
+
+if __name__ == '__main__':
+  test.main()
diff --git a/tensorflow/python/training/experimental/loss_scale_test.py b/tensorflow/python/training/experimental/loss_scale_test.py
new file mode 100644
index 00000000000..f135161de07
--- /dev/null
+++ b/tensorflow/python/training/experimental/loss_scale_test.py
@@ -0,0 +1,265 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for LossScale classes."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl.testing import parameterized
+import numpy as np
+
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.distribute import distribution_strategy_context
+from tensorflow.python.distribute import mirrored_strategy
+from tensorflow.python.eager import context
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import check_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+from tensorflow.python.training.experimental import loss_scale as loss_scale_module
+
+# TODO(reedwm): Create test case using multiple graphs
+
+# If called outside any strategy.scope() calls, this will return the default
+# strategy.
+default_strategy_fn = distribution_strategy_context.get_strategy
+
+
+def create_mirrored_strategy():
+  if context.num_gpus() >= 1:
+    return mirrored_strategy.MirroredStrategy(['cpu:0', 'gpu:0'])
+  else:
+    return mirrored_strategy.MirroredStrategy(['cpu:0'])
+
+
+TESTCASES = ({
+    'testcase_name': 'base',
+    'strategy_fn': default_strategy_fn
+}, {
+    'testcase_name': 'distribute',
+    'strategy_fn': create_mirrored_strategy
+})
+
+
+class FixedLossScaleTest(test.TestCase):
+
+  @test_util.run_in_graph_and_eager_modes
+  def test_basic(self):
+    loss_scale_value = 1000
+    loss_scale = loss_scale_module.FixedLossScale(loss_scale_value)
+
+    update_op, should_apply = loss_scale.update([constant_op.constant(0.)])
+    self.evaluate(update_op)
+    # should_apply should be a bool instead of a tensor, so that a tf.cond does
+    # not have to be built in the graph by the caller.
+    self.assertIsInstance(should_apply, bool)
+    self.assertTrue(should_apply)
+    self.assertEqual(loss_scale_value, self.evaluate(loss_scale()))
+
+    update_op, should_apply = loss_scale.update(
+        [constant_op.constant(float('NaN'))])
+    self.evaluate(update_op)
+    self.assertIsInstance(should_apply, bool)
+    self.assertTrue(should_apply)
+    self.assertEqual(loss_scale_value, self.evaluate(loss_scale()))
+
+
+def _get_example_iter(inputs):
+  dataset = dataset_ops.Dataset.from_tensor_slices(inputs)
+  return dataset_ops.make_one_shot_iterator(dataset)
+
+
+class DynamicLossScaleTest(test.TestCase, parameterized.TestCase):
+
+  def _get_tensor(self, is_finite):
+    tensor = control_flow_ops.cond(is_finite, lambda: 1., lambda: float('NaN'))
+
+    if not distribution_strategy_context.has_strategy():
+      return tensor
+
+    def get():
+      rep_id = (
+          distribution_strategy_context.get_replica_context()
+          .replica_id_in_sync_group)
+      return control_flow_ops.cond(
+          math_ops.equal(rep_id, 0), lambda: tensor, lambda: 1.)
+
+    distribution = distribution_strategy_context.get_strategy()
+    return distribution.extended.call_for_each_replica(get)
+
+  def _test_helper(self,
+                   inputs,
+                   expected_outputs,
+                   initial_loss_scale=1.,
+                   increment_period=2,
+                   multiplier=2):
+    loss_scale = loss_scale_module.DynamicLossScale(
+        initial_loss_scale=initial_loss_scale,
+        increment_period=increment_period,
+        multiplier=multiplier)
+    itr = _get_example_iter(inputs)
+
+    def update():
+      is_finite = itr.get_next()
+      grad = self._get_tensor(is_finite)
+      update_op, should_apply_gradients = loss_scale.update([grad])
+      assert_op = check_ops.assert_equal(should_apply_gradients, is_finite)
+      if context.executing_eagerly():
+        return
+      with ops.control_dependencies([assert_op]):
+        return array_ops.identity(update_op)
+
+    actual_outputs = []
+
+    if not context.executing_eagerly():
+      update_op = update()
+      self.evaluate(variables.global_variables_initializer())
+    for _ in range(len(inputs)):
+      if context.executing_eagerly():
+        update()
+      else:
+        self.evaluate(update_op)
+      actual_outputs.append(self.evaluate(loss_scale()))
+    self.assertEqual(actual_outputs, expected_outputs)
+
+  @parameterized.named_parameters(*TESTCASES)
+  @test_util.run_in_graph_and_eager_modes
+  def test_increase(self, strategy_fn):
+    with strategy_fn().scope():
+      inputs = [True] * 6
+      expected_outputs = [1, 2, 2, 4, 4, 8]
+      self._test_helper(inputs, expected_outputs)
+
+  @parameterized.named_parameters(*TESTCASES)
+  @test_util.run_in_graph_and_eager_modes
+  def test_keep_increasing_until_capped(self, strategy_fn):
+    with strategy_fn().scope():
+      init_loss_scale = np.finfo(np.float32).max / 4
+      max_float = np.finfo(np.float32).max
+
+      inputs = [True] * 6
+      # The 2nd doubling reaches max_float; any later doubling is capped there.
+      expected_outputs = [
+          init_loss_scale, init_loss_scale * 2, init_loss_scale * 2, max_float,
+          max_float, max_float
+      ]
+
+      self._test_helper(inputs, expected_outputs, init_loss_scale)
+
+  @parameterized.named_parameters(*TESTCASES)
+  @test_util.run_in_graph_and_eager_modes
+  def test_decrease_every_step(self, strategy_fn):
+    with strategy_fn().scope():
+      inputs = [False] * 6
+      init_loss_scale = 1024
+      expected_outputs = [512, 256, 128, 64, 32, 16]
+
+      self._test_helper(inputs, expected_outputs, init_loss_scale)
+
+  @parameterized.named_parameters(*TESTCASES)
+  @test_util.run_in_graph_and_eager_modes
+  def test_keep_decreasing_until_one(self, strategy_fn):
+    with strategy_fn().scope():
+      inputs = [False] * 6
+      init_loss_scale = 16
+      expected_outputs = [8, 4, 2, 1, 1, 1]
+
+      self._test_helper(inputs, expected_outputs, init_loss_scale)
+
+  @parameterized.named_parameters(*TESTCASES)
+  @test_util.run_in_graph_and_eager_modes
+  def test_nan_clear_good_step(self, strategy_fn):
+    with strategy_fn().scope():
+      inputs = [True, True, True, False, True]
+      expected_outputs = [1, 2, 2, 1, 1]
+      self._test_helper(inputs, expected_outputs)
+
+  @parameterized.named_parameters(*TESTCASES)
+  @test_util.run_in_graph_and_eager_modes
+  def test_trigger_loss_scale_update_each_step(self, strategy_fn):
+    with strategy_fn().scope():
+      init_loss_scale = 1
+      increment_period = 1
+
+      inputs = [True] * 3 + [False, True, True]
+      expected_outputs = [2, 4, 8, 4, 8, 16]
+
+      self._test_helper(inputs, expected_outputs, init_loss_scale,
+                        increment_period)
+
+  @parameterized.named_parameters(*TESTCASES)
+  @test_util.run_in_graph_and_eager_modes
+  def test_alternating_good_and_bad_gradients_trigger_each_step(
+      self, strategy_fn):
+    with strategy_fn().scope():
+      init_loss_scale = 1
+      increment_period = 1
+
+      inputs = [True, False] * 4 + [True]
+      expected_outputs = [2, 1, 2, 1, 2, 1, 2, 1, 2]
+      self._test_helper(inputs, expected_outputs, init_loss_scale,
+                        increment_period)
+
+  @parameterized.named_parameters(*TESTCASES)
+  @test_util.run_in_graph_and_eager_modes
+  def test_alternating_good_and_bad_gradients_trigger_every_other_step(
+      self, strategy_fn):
+    with strategy_fn().scope():
+      init_loss_scale = 32
+      increment_period = 2
+
+      inputs = [True, False] * 3 + [True]
+      expected_outputs = [32, 16, 16, 8, 8, 4, 4]
+      self._test_helper(inputs, expected_outputs, init_loss_scale,
+                        increment_period)
+
+  @parameterized.named_parameters(*TESTCASES)
+  @test_util.run_in_graph_and_eager_modes
+  def test_nondefault_multiplier(self, strategy_fn):
+    with strategy_fn().scope():
+      init_loss_scale = 4
+      multiplier = 3
+      inputs = [True, True, False, True, True]
+      expected_outputs = [4, 12, 4, 4, 12]
+      self._test_helper(
+          inputs, expected_outputs, init_loss_scale, multiplier=multiplier)
+
+  @parameterized.named_parameters(*TESTCASES)
+  @test_util.run_in_graph_and_eager_modes
+  def test_random_mix_good_and_bad_gradients(self, strategy_fn):
+    with strategy_fn().scope():
+      init_loss_scale = 4
+      inputs = [
+          False, True, True, True, False, True, False, True, True, True, False
+      ]
+      expected_outputs = [2, 2, 4, 4, 2, 2, 1, 1, 2, 2, 1]
+      self._test_helper(inputs, expected_outputs, init_loss_scale)
+
+  @test_util.run_in_graph_and_eager_modes
+  def test_get(self):
+    scalar = loss_scale_module.get('dynamic')
+    scalar2 = loss_scale_module.DynamicLossScale()
+    self.assertEqual(scalar.initial_loss_scale, scalar2.initial_loss_scale)
+    self.assertEqual(scalar.increment_period, scalar2.increment_period)
+    self.assertEqual(scalar.multiplier, scalar2.multiplier)
+
+
+if __name__ == '__main__':
+  test.main()
diff --git a/tensorflow/python/training/training.py b/tensorflow/python/training/training.py
index 9f509ae0a38..f67df24f36f 100644
--- a/tensorflow/python/training/training.py
+++ b/tensorflow/python/training/training.py
@@ -33,6 +33,7 @@ from tensorflow.python.training.adagrad_da import AdagradDAOptimizer
 from tensorflow.python.training.proximal_adagrad import
ProximalAdagradOptimizer from tensorflow.python.training.adam import AdamOptimizer from tensorflow.python.training.ftrl import FtrlOptimizer +from tensorflow.python.training.experimental.loss_scale_optimizer import MixedPrecisionLossScaleOptimizer from tensorflow.python.training.momentum import MomentumOptimizer from tensorflow.python.training.moving_averages import ExponentialMovingAverage from tensorflow.python.training.optimizer import Optimizer @@ -143,4 +144,3 @@ tf_export(v1=["train.SaverDef"])(SaverDef) tf_export("train.SequenceExample")(SequenceExample) tf_export("train.ServerDef")(ServerDef) # pylint: enable=undefined-variable - diff --git a/tensorflow/tools/api/golden/v1/tensorflow.train.experimental.-mixed-precision-loss-scale-optimizer.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.train.experimental.-mixed-precision-loss-scale-optimizer.pbtxt new file mode 100644 index 00000000000..4b1700fceb1 --- /dev/null +++ b/tensorflow/tools/api/golden/v1/tensorflow.train.experimental.-mixed-precision-loss-scale-optimizer.pbtxt @@ -0,0 +1,51 @@ +path: "tensorflow.train.experimental.MixedPrecisionLossScaleOptimizer" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "GATE_GRAPH" + mtype: "" + } + member { + name: "GATE_NONE" + mtype: "" + } + member { + name: "GATE_OP" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'opt\', \'loss_scale\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "apply_gradients" + argspec: "args=[\'self\', \'grads_and_vars\', \'global_step\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "compute_gradients" + argspec: "args=[\'self\', \'loss\', \'var_list\', \'gate_gradients\', \'aggregation_method\', \'colocate_gradients_with_ops\', \'grad_loss\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'None\', \'False\', \'None\'], " + } + member_method { + name: 
"get_name" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_slot" + argspec: "args=[\'self\', \'var\', \'name\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_slot_names" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "minimize" + argspec: "args=[\'self\', \'loss\', \'global_step\', \'var_list\', \'gate_gradients\', \'aggregation_method\', \'colocate_gradients_with_ops\', \'name\', \'grad_loss\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'1\', \'None\', \'False\', \'None\', \'None\'], " + } + member_method { + name: "variables" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v1/tensorflow.train.experimental.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.train.experimental.pbtxt index 2761b489b96..6c4cd668bed 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.train.experimental.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.train.experimental.pbtxt @@ -1,5 +1,9 @@ path: "tensorflow.train.experimental" tf_module { + member { + name: "MixedPrecisionLossScaleOptimizer" + mtype: "" + } member { name: "PythonState" mtype: "" diff --git a/tensorflow/tools/compatibility/renames_v2.py b/tensorflow/tools/compatibility/renames_v2.py index dd6db74baf6..2be49df93cc 100644 --- a/tensorflow/tools/compatibility/renames_v2.py +++ b/tensorflow/tools/compatibility/renames_v2.py @@ -1387,6 +1387,8 @@ renames = { 'tf.compat.v1.train.create_global_step', 'tf.train.do_quantize_training_on_graphdef': 'tf.compat.v1.train.do_quantize_training_on_graphdef', + 'tf.train.experimental.MixedPrecisionLossScaleOptimizer': + 'tf.compat.v1.train.experimental.MixedPrecisionLossScaleOptimizer', 'tf.train.exponential_decay': 'tf.compat.v1.train.exponential_decay', 'tf.train.export_meta_graph': diff --git a/tensorflow/tools/pip_package/BUILD 
b/tensorflow/tools/pip_package/BUILD index 19c09ce0c1e..2ab804bc876 100644 --- a/tensorflow/tools/pip_package/BUILD +++ b/tensorflow/tools/pip_package/BUILD @@ -68,6 +68,8 @@ COMMON_PIP_DEPS = [ "//tensorflow/python/compiler:compiler", "//tensorflow/python:cond_v2", "//tensorflow/python:distributed_framework_test_lib", + "//tensorflow/python:loss_scale", + "//tensorflow/python:loss_scale_optimizer", "//tensorflow/python:meta_graph_testdata", "//tensorflow/python:spectral_ops_test_util", "//tensorflow/python:util_example_parser_configuration",