From 83bfd7e23cbcd16403bef3ebc0787fdb655d93f6 Mon Sep 17 00:00:00 2001 From: Matt Conley Date: Mon, 8 Apr 2019 13:09:36 -0700 Subject: [PATCH 01/10] Add loss scale optimizer for v1 optimizers Co-authored-by: Nathan Luehr Co-authored-by: Ben Barsdell --- tensorflow/python/BUILD | 48 ++ tensorflow/python/training/loss_scale.py | 412 ++++++++++++++++++ .../python/training/loss_scale_optimizer.py | 235 ++++++++++ .../training/loss_scale_optimizer_test.py | 258 +++++++++++ tensorflow/python/training/loss_scale_test.py | 261 +++++++++++ tensorflow/python/training/training.py | 2 + 6 files changed, 1216 insertions(+) create mode 100644 tensorflow/python/training/loss_scale.py create mode 100644 tensorflow/python/training/loss_scale_optimizer.py create mode 100644 tensorflow/python/training/loss_scale_optimizer_test.py create mode 100644 tensorflow/python/training/loss_scale_test.py diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 962f5f4c63e..84953c43825 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -2780,6 +2780,54 @@ py_library( ], ) +py_library( + name = "loss_scale", + srcs = ["training/loss_scale.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:framework", + ":keras_lib", + "@absl_py//absl/testing:parameterized", + ], +) + +py_library( + name = "loss_scale_optimizer", + srcs = ["training/loss_scale_optimizer.py"], + srcs_version = "PY2AND3", + deps = [ + ":loss_scale", + "@absl_py//absl/testing:parameterized", + ], +) + +py_test( + name = "loss_scale_optimizer_test", + size = "small", + srcs = ["training/loss_scale_optimizer_test.py"], + deps = [ + ":loss_scale_optimizer", + "//tensorflow/python/keras/mixed_precision/experimental:test_util", + "//tensorflow/python:client_testlib", + "//tensorflow/python/distribute:mirrored_strategy", + "//tensorflow/python/distribute:one_device_strategy", + "@absl_py//absl/testing:parameterized", + ], +) + +py_test( + name = "loss_scale_test", + size = "small", + srcs = ["training/loss_scale_test.py"], + deps = [ + ":loss_scale", + "//tensorflow/python:client_testlib", + "//tensorflow/python/distribute:mirrored_strategy", + "//tensorflow/python/distribute:one_device_strategy", + "@absl_py//absl/testing:parameterized", + ], +) + py_library( name = "math_grad", srcs = ["ops/math_grad.py"], diff --git a/tensorflow/python/training/loss_scale.py b/tensorflow/python/training/loss_scale.py new file mode 100644 index 00000000000..94265e4974a --- /dev/null +++ b/tensorflow/python/training/loss_scale.py @@ -0,0 +1,412 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Contains LossScaler classes.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import abc +import six + +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.keras import initializers +from tensorflow.python.keras.engine import base_layer_utils +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import variables +from tensorflow.python.training.tracking import base as trackable +from tensorflow.python.ops import math_ops +from tensorflow.python.distribute import distribution_strategy_context +from tensorflow.python.distribute import reduce_util +from tensorflow.python.eager import context +from tensorflow.python.util.tf_export import tf_export + + +@six.add_metaclass(abc.ABCMeta) +@tf_export(v1=['train.LossScale']) +class LossScale(trackable.Trackable): + """Base class to compute the loss scale. + + Loss scaling is a process that: + + 1) Applies a multiplier on the loss before computing gradients, and + 2) Applies the reciprocal of the multiplier on the gradients before they are + applied on variables. + + Mathematically, loss scaling has no effect, but can help avoid numerical + underflow when float16 tensors are used. By multiplying the loss, each + gradient will have the same multiplier applied. + + Instances of this class compute the loss scale. Method `get_loss_scale()` + returns the current loss scale, while method `update()` updates the loss scale + depending on the values of the gradients. Optimizers use instances of this + class to scale loss and gradients. + """ + + def __init__(self): + """Initializes the loss scale class. + + Note subclasses should create variables in build() instead of in the + constructor. This is because callers might choose to place variables on + a certain device by calling build() under a tf.device() scope. + """ + self.built = False + self._weights = {} + + def build(self): + """Builds the weights of the loss scale class. + + If weights are needed, subclasses should build weights by calling + `self.add_weight(...)`, then call the super's build to set self.built = + True. + """ + self.built = True + + def __call__(self): + """Returns the current loss scale as a scalar `float32` tensor.""" + if not self.built: + self.build() + return self._get_loss_scale() + + @abc.abstractmethod + def _get_loss_scale(self): + """Returns the loss scale without calling build(). + + Subclasses must implement this. Subclasses should not override the public + `__call__` method, which calls this method. + """ + pass + + def update(self, grads): + """Updates the value of the loss scale. + + The loss scale tensor will be potentially updated, based on the value of + `grads`. The tensor returned by `get_loss_scale` is only + updated when this function is evaluated. + + In eager mode, this directly updates the loss scale, so that calling + `get_loss_scale` will return the newly updated loss scale. In graph mode, + this returns an op that, when evaluated, updates the loss scale. + + This function also returns a `should_apply_gradients` bool. If False, + gradients should not be applied to the variables that step, as nonfinite + gradients were found, and the loss scale controller can update the loss + scale to reduce the chance of finding nonfinite gradients in the next step. 
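+
+    For example, here is a sketch of how a training step might drive a
+    `LossScale` instance (the names `loss`, `var_list` and `loss_scale` are
+    illustrative, not part of this API):
+
+    ```
+    scaled_loss = loss * loss_scale()
+    scaled_grads = tf.gradients(scaled_loss, var_list)
+    grads = [g / loss_scale() for g in scaled_grads]  # unscale before update
+    update_op, should_apply = loss_scale.update(grads)
+    # Apply `grads` only if `should_apply` is True, then run `update_op`.
+    ```
+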
+ Some loss scale controllers will always return True, as they cannot adjust + the loss scale in response to nonfinite gradients. + + When a DistributionStrategy is used, this function may only be called in a + cross-replica context. + + Args: + grads: A list of unscaled gradients, each which is the gradient of the + loss with respect to a weight. The gradients should have already been + divided by the loss scale being before passed to this function. + + Returns: + update_op: In eager mode, None. In graph mode, an op to update the loss + scale. + should_apply_gradients: Either a bool or a scalar boolean tensor. If + False, the caller should skip applying `grads` to the variables this + step. + """ + if not self.built: + self.build() + return self._update(grads) + + @abc.abstractmethod + def _update(self, grads): + """Updates the value of the loss scale without calling build(). + + Subclasses must implement this. Subclasses should not override the public + `update_loss_scale` method, which calls this method. + + Args: + grads: A list of unscaled gradients. See `update_loss_scale`. + + Returns: + In eager mode, None. In graph mode, an op to update the loss scale. + """ + pass + + + def add_weight(self, + name, + shape=(), + dtype=None, + initializer='zeros'): + """Adds a weight to this loss scale manager.. + + This should be called by subclasses in `build()` to build the weights of the + loss scale class. + + Args: + name: Variable name. + shape: Variable shape. + dtype: The type of the variable. + initializer: The initializer to use. + + Returns: + A variable. + """ + if isinstance(initializer, six.string_types) or callable(initializer): + initializer = initializers.get(initializer) + variable = self._add_variable_with_custom_getter( + name=name, + shape=shape, + getter=base_layer_utils.make_variable, + overwrite=True, + initializer=initializer, + dtype=dtype, + trainable=False, + use_resource=True, + synchronization=variables.VariableSynchronization.AUTO, + # Set aggregation to NONE, as loss scaling variables should never be + # aggregated. + aggregation=variables.VariableAggregation.NONE) + if context.executing_eagerly(): + graph_key = None + else: + graph = ops.get_default_graph() + graph_key = graph._graph_key # pylint: disable=protected-access + + key = (graph_key, name) + if self._weights.get(key, None) is not None: + raise RuntimeError('Duplicate variables detected. {}'.format(key)) + self._weights[key] = variable + self._handle_deferred_dependencies(name=name, trackable=variable) + return variable + + @property + def _checkpoint_dependencies(self): + """From Trackable. Gather graph-specific weights to save.""" + if context.executing_eagerly(): + graph_key = None + else: + graph = ops.get_default_graph() + graph_key = graph._graph_key # pylint: disable=protected-access + weights = [trackable.TrackableReference(name=name, ref=v) + for (g, name), v in sorted( + self._weights.items(), key=lambda i: i[0][1]) + if g == graph_key] + return super(LossScale, self)._checkpoint_dependencies + weights + + def _lookup_dependency(self, name): + """From Trackable. 
Find a weight in the current graph.""" + unconditional = super(LossScale, self)._lookup_dependency(name) + if unconditional is not None: + return unconditional + if context.executing_eagerly(): + graph_key = None + else: + graph = ops.get_default_graph() + graph_key = graph._graph_key # pylint: disable=protected-access + return self._weights.get((graph_key, name), None) + +@tf_export(v1=['train.FixedLossScale']) +class FixedLossScale(LossScale): + """Loss scale class with a fixed value. + + The loss scale is not updated for the lifetime of the class. + """ + + def __init__(self, loss_scale_value): + """Creates the fixed loss scale. + + Args: + loss_scale: A Python float. Its ideal value varies depending on models to + run. Choosing a too small loss_scale might affect model quality; a too + big loss_scale might cause inf or nan. There is no single right + loss_scale to apply. There is no harm choosing a relatively big number + as long as no nan or inf is encountered in training. + + Raises: + ValueError: If loss_scale is less than 1. + """ + super(FixedLossScale, self).__init__() + if not isinstance(loss_scale_value, six.integer_types + (float,)): + raise ValueError('loss_scale must be a Python int or float.') + if loss_scale_value < 1: + raise ValueError('loss scale must be at least 1.') + self._tensor_loss_scale = ops.convert_to_tensor(loss_scale_value, + dtype=dtypes.float32) + + def _get_loss_scale(self): + return self._tensor_loss_scale + + def _update(self, grads): + del grads + return control_flow_ops.no_op(), True + + +def _is_all_finite(grads): + """Returns a scalar boolean tensor indicating if all gradients are finite.""" + is_finite_per_grad = [math_ops.reduce_all(math_ops.is_finite(g)) + for g in grads] + return math_ops.reduce_all(is_finite_per_grad) + + +def _op_in_graph_mode(tensor): + """Returns the tensor's op in graph mode, or the tensor in eager mode. + + This is useful because sometimes an op is needed in graph mode instead of a + tensor. In eager mode, there are no ops. + + Args: + tensor: A tensor. + + Returns: + The tensor's op in graph mode. The tensor in eager mode. + """ + if context.executing_eagerly(): + return tensor + return tensor.op + + +def _assign_if_finite(var, value): + """Assigns a value to a variable if the value is finite.""" + return control_flow_ops.cond( + math_ops.is_finite(value), + lambda: _op_in_graph_mode(var.assign(value)), + control_flow_ops.no_op) + + +@tf_export(v1=['train.DynamicLossScale']) +class DynamicLossScale(LossScale): + """Loss scale class that dynamically adjusts the loss scale. + + Dynamic loss scaling works by adjusting the loss scale as training progresses. + The goal is to keep the loss scale as high as possible without overflowing the + gradients. As long as the gradients do not overflow, raising the loss scale + never hurts. + + The algorithm starts by setting the loss scale to an initial value. Every N + steps that the gradients are finite, the loss scale is increased by some + factor. However, if a NaN or Inf gradient is found, the gradients for that + step are not applied, and the loss scale is decreased by the factor. This + process tends to keep the loss scale as high as possible without gradients + overflowing. + """ + + def __init__(self, + initial_loss_scale=2 ** 15, + increment_period=2000, + multiplier=2.): + """Constructor of exponential-update loss scale class. + + Args: + initial_loss_scale: A Python float. The loss scale to use at the + beginning. 
It's better to start this at a very high number, because a + loss scale that is too high gets lowered far more quickly than a loss + scale that is to low gets raised. The default is 2 ** 15, which is + approximately half the maximum float16 value. + incr_every_n_steps: Increases loss scale every `incr_every_n_steps` + consecutive steps that finite gradients are encountered. If a nonfinite + gradient is encountered, the count is reset back to zero. + loss_scale_multiplier: The multiplier to use when increasing or decreasing + the loss scale. + """ + super(DynamicLossScale, self).__init__() + self._initial_loss_scale = float(initial_loss_scale) + self._increment_period = int(increment_period) + self._multiplier = float(multiplier) + + + @property + def initial_loss_scale(self): + return self._initial_loss_scale + + @property + def increment_period(self): + return self._increment_period + + @property + def multiplier(self): + return self._multiplier + + def build(self): + self._current_loss_scale = self.add_weight( + name='loss_scale', + dtype=dtypes.float32, + initializer=self._initial_loss_scale) + self._num_good_steps = self.add_weight( + name='good_steps', dtype=dtypes.int64, initializer='zeros') + self.built = True + + def _get_loss_scale(self): + return self._current_loss_scale + + def _update(self, grads): + """Updates loss scale based on if gradients are finite in current step.""" + if distribution_strategy_context.has_strategy(): + distribution = distribution_strategy_context.get_cross_replica_context() + def get_is_finite(grads): + is_finite = _is_all_finite(grads) + # We cast to float, because we cannot reduce booleans with + # DistributionStrategy. + return math_ops.cast(is_finite, dtypes.float32) + is_finite_float = distribution.extended.call_for_each_replica( + get_is_finite, args=(grads,)) + reduced_is_finite_float = distribution.reduce(reduce_util.ReduceOp.SUM, + is_finite_float) + is_finite = math_ops.equal(reduced_is_finite_float, + distribution.num_replicas_in_sync) + else: + is_finite = _is_all_finite(grads) + + def update_if_finite_grads(): + """Update assuming the gradients are finite.""" + + def incr_loss_scale(): + new_loss_scale = self._current_loss_scale * self._multiplier + return control_flow_ops.group( + _assign_if_finite(self._current_loss_scale, new_loss_scale), + self._num_good_steps.assign(0)) + + return control_flow_ops.cond( + self._num_good_steps + 1 >= self._increment_period, + incr_loss_scale, + lambda: _op_in_graph_mode(self._num_good_steps.assign_add(1))) + + def update_if_not_finite_grads(): + """Update assuming the gradients are nonfinite.""" + + new_loss_scale = math_ops.maximum( + self._current_loss_scale / self._multiplier, 1) + return control_flow_ops.group( + self._num_good_steps.assign(0), + self._current_loss_scale.assign(new_loss_scale) + ) + + + update_op = control_flow_ops.cond(is_finite, update_if_finite_grads, + update_if_not_finite_grads) + should_apply_gradients = is_finite + return update_op, should_apply_gradients + + +def get(identifier): + """Get a loss scale object.""" + if isinstance(identifier, six.integer_types + (float,)): + return FixedLossScale(identifier) + if identifier == 'dynamic': + return DynamicLossScale() + if isinstance(identifier, LossScale): + return identifier + elif identifier is None: + return None + else: + raise ValueError('Could not interpret loss scale identifier: %s' + % identifier) diff --git a/tensorflow/python/training/loss_scale_optimizer.py b/tensorflow/python/training/loss_scale_optimizer.py new file 
mode 100644 index 00000000000..d70e19d8b5b --- /dev/null +++ b/tensorflow/python/training/loss_scale_optimizer.py @@ -0,0 +1,235 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Contains LossScale classes.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import ops +from tensorflow.python.framework import smart_cond +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.training import loss_scale as loss_scale_module +from tensorflow.python.training import optimizer +from tensorflow.python.distribute import distribution_strategy_context +from tensorflow.python.util.tf_export import tf_export + + +@tf_export(v1=['train.LossScaleOptimizer']) +class LossScaleOptimizer(optimizer.Optimizer): + """An optimizer that applies loss scaling. + + The loss scale can either be a fixed constant, chosen by the user, or be + dynamically determined. Dynamically determining the loss scale is convenient + as a loss scale does not have to be explicitly chosen. However it reduces + performance. + + This optimizer wraps another optimizer and applies loss scaling to it. Loss + scaling is applied whenever gradients are computed. + + Args: + opt: The Optimizer instance to wrap. + loss_scale: The loss scale or LossScale class to scale the loss and + gradients. This can either be an int/float to use a fixed loss scale, + the string "dynamic" to use dynamic loss scaling, or an instance of a + LossScale class. The string "dynamic" is equivalent to passing + `DynamicLossScale()`, and passing an int/float is equivalent + to passing a FixedLossScale instance with the given loss scale. + """ + def __init__(self, opt, loss_scale): + if not isinstance(opt, optimizer.Optimizer): + raise ValueError('"opt" must be an instance of Optimizer, but got: %s' % + type(opt)) + self._optimizer = opt + + use_locking = opt._use_locking # pylint: disable=protected-access + name = opt.get_name() + super(LossScaleOptimizer, self).__init__(use_locking, name) + + self._loss_scale = loss_scale_module.get(loss_scale) + self._track_trackable(self._optimizer, 'base_optimizer') + self._track_trackable(self._loss_scale, 'loss_scale') + + + def _doing_dynamic_loss_scaling(self): + """Check if `_loss_scale` dynamically manages the loss scale.""" + return isinstance(self._loss_scale, + loss_scale_module.DynamicLossScale) + + + def compute_gradients(self, loss, var_list=None, + gate_gradients=optimizer.Optimizer.GATE_OP, + aggregation_method=None, + colocate_gradients_with_ops=False, + grad_loss=None): + """Compute gradients of `loss` for the variables in `var_list`. + + This adjusts the dynamic range of the gradient evalutaion by scaling up + the `loss` value. The gradient values are then scaled back down by the + recipricol of the loss scale. 
This is useful in reduced precision training + where small gradient values would otherwise underflow the representable + range. + + Args: + loss: A Tensor containing the value to minimize or a callable taking + no arguments which returns the value to minimize. When eager execution + is enabled it must be a callable. + var_list: Optional list or tuple of `tf.Variable` to update to minimize + `loss`. Defaults to the list of variables collected in the graph + under the key `GraphKeys.TRAINABLE_VARIABLES`. + gate_gradients: How to gate the computation of gradients. Can be + `GATE_NONE`, `GATE_OP`, or `GATE_GRAPH`. + aggregation_method: Specifies the method used to combine gradient terms. + Valid values are defined in the class `AggregationMethod`. + colocate_gradients_with_ops: If True, try colocating gradients with + the corresponding op. + grad_loss: Optional. A `Tensor` holding the gradient computed for `loss`. + + Returns: + A list of (gradient, variable) pairs. Variable is always present, but + gradient can be `None`. + """ + loss = self._scale_loss(loss) + grads_and_vars = self._optimizer.compute_gradients( + loss=loss, var_list=var_list, gate_gradients=gate_gradients, + aggregation_method=aggregation_method, + colocate_gradients_with_ops=colocate_gradients_with_ops, + grad_loss=grad_loss) + + grads = [g for g, _ in grads_and_vars] + variables = [v for _, v in grads_and_vars] + scaled_grads = self._scale_grads(grads) + return list(zip(scaled_grads, variables)) + + def _scale_loss(self, loss): + # The loss is callable for `_compute_gradients`, but not `get_gradients`. + loss_scale = self._loss_scale() + if callable(loss): + return lambda: loss() * loss_scale + return loss * loss_scale + + def _scale_grads(self, grads): + loss_scale = self._loss_scale() + loss_scale_reciprical = 1 / loss_scale + return [None if g is None else self._indexed_slices( + g, loss_scale_reciprical) for g in grads] + + def _indexed_slices(self, grad, loss_scale_reciprical): + if isinstance(grad, ops.IndexedSlices): + grad_vals = grad.values * loss_scale_reciprical + return ops.IndexedSlices(grad_vals, grad.indices, grad.dense_shape) + return grad * loss_scale_reciprical + + def apply_gradients(self, grads_and_vars, global_step=None, name=None): + """Apply gradients to variables. + + This is the second part of `minimize()`. It returns an `Operation` that + conditionally applies gradients if all gradient values are finite. + Otherwise no update is performed (nor is `global_step` incremented). + + Args: + grads_and_vars: List of (gradient, variable) pairs as returned by + `compute_gradients()`. + global_step: Optional `Variable` to increment by one after the + variables have been updated. + name: Optional name for the returned operation. Default to the name + passed to the `Optimizer` constructor. + + Returns: + An `Operation` that conditionally applies the specified gradients. If + `global_step` was not None, that operation also increments `global_step`. + + Raises: + RuntimeError: If you should use `_distributed_apply()` instead. 
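+
+    A minimal usage sketch (the wrapped `GradientDescentOptimizer` and the
+    names `loss` and `var_list` are illustrative):
+
+    ```
+    opt = GradientDescentOptimizer(1.0)
+    opt = LossScaleOptimizer(opt, loss_scale='dynamic')
+    grads_and_vars = opt.compute_gradients(loss, var_list)
+    train_op = opt.apply_gradients(grads_and_vars)
+    # If any gradient is nonfinite, `train_op` only updates the loss scale
+    # and leaves the variables untouched.
+    ```
+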
+ """ + if distribution_strategy_context.in_cross_replica_context(): + raise ValueError('apply_gradients() must be called in a replica context.') + + if not self._doing_dynamic_loss_scaling(): + return self._optimizer.apply_gradients(grads_and_vars, global_step, name) + + replica_context = distribution_strategy_context.get_replica_context() + + # TODO(nluehr) cleanup GraphKeys.TRAIN_OP + return replica_context.merge_call( + self._maybe_apply_gradients_cross_replica, + args=(grads_and_vars, global_step, name)) + + def _distributed_apply(self, + distribution, + grads_and_vars, + global_step=None, + name=None): + """A version of `apply_gradients` for cross replica context. + + When users are in a cross replica strategy, they must call this rather than + `apply_gradients()`. + + Args: + distribution: a `DistributionStrategy` object. + grads_and_vars: List of (gradient, variable) pairs as returned by + `compute_gradients()` and then aggregated across replicas. + global_step: Optional (mirrored) `Variable` to increment by one + after the variables have been updated. + name: Optional name for the returned operation. Default to the name + passed to the `Optimizer` constructor. + + Returns: + An `Operation` that applies the specified gradients across all + replicas. If `global_step` was not None, that operation also + increments `global_step` + """ + self._maybe_apply_gradients_cross_replica(distribution, grads_and_vars, + global_step, name) + + def _maybe_apply_gradients_cross_replica(self, distribution, grads_and_vars, + global_step, name): + """Conditionally apply gradients in cross replica context.""" + name = name if name is not None else self.get_name() + grads = [g for g, _ in grads_and_vars] + loss_scale_update_op, should_apply_grads = ( + self._loss_scale.update(grads)) + maybe_apply_op = smart_cond.smart_cond( + should_apply_grads, + lambda: self._apply_gradients_cross_replica(distribution, + grads_and_vars, + global_step, + name+'-wrapped'), + control_flow_ops.no_op) + return control_flow_ops.group(maybe_apply_op, loss_scale_update_op, + name=name) + + def _apply_gradients_cross_replica(self, distribution, grads_and_vars, + global_step, name): + """Unconditionally apply gradients in cross replica context.""" + update_ops = distribution.extended.call_for_each_replica( + self._optimizer.apply_gradients, + args=(grads_and_vars, global_step, name)) + return distribution.group(update_ops) + + def _apply_sparse(self, grad, var): + """This function should never be called""" + raise RuntimeError("This function should never be called") + + def _apply_dense(self, grad, var): + """This function should never be called""" + raise RuntimeError("This function should never be called") + + def _resource_apply_sparse(self, grad, handle, indices): + """This function should never be called""" + raise RuntimeError("This function should never be called") + + def _resource_apply_dense(self, grad, handle): + """This function should never be called""" + raise RuntimeError("This function should never be called") diff --git a/tensorflow/python/training/loss_scale_optimizer_test.py b/tensorflow/python/training/loss_scale_optimizer_test.py new file mode 100644 index 00000000000..3ee38cb87c5 --- /dev/null +++ b/tensorflow/python/training/loss_scale_optimizer_test.py @@ -0,0 +1,258 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for LossScaleOptimizer.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +from absl.testing import parameterized + +from tensorflow.python.distribute import distribution_strategy_context +from tensorflow.python.distribute import mirrored_strategy +from tensorflow.python.eager import context +from tensorflow.python.framework import test_util +from tensorflow.python.training import loss_scale as loss_scale_module +from tensorflow.python.training import loss_scale_optimizer +from tensorflow.python.keras.mixed_precision.experimental import test_util as mp_test_util +from tensorflow.python.training import gradient_descent +from tensorflow.python.training import momentum +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import test +from tensorflow.python.training.tracking import util as trackable_utils + + +# If called outside any strategy.scope() calls, this will return the default +# strategy. +default_strategy_fn = distribution_strategy_context.get_strategy + + +def create_mirrored_strategy(): + if context.num_gpus() >= 1: + return mirrored_strategy.MirroredStrategy(['cpu:0', 'gpu:0']) + else: + return mirrored_strategy.MirroredStrategy(['cpu:0']) + + +TESTCASES = ({ + 'testcase_name': 'Base', + 'strategy_fn': default_strategy_fn +}, { + 'testcase_name': 'Distribute', + 'strategy_fn': create_mirrored_strategy +}) + +def get_gradients(opt, loss, params): + grads_and_vars = opt.compute_gradients(loss, params) + grads, _ = zip(*grads_and_vars) + return grads + +class LossScaleOptimizerTest(test.TestCase, parameterized.TestCase): + + def _run_if_in_graph_mode(self, val): + # Running only in graph mode is useful, because optimizers sometimes return + # a value that, in Graph mode, is runnable with self.evaluate. But in Eager + # mode, the optimizer already does the computations and the return value + # cannot be run. + if not context.executing_eagerly(): + self.evaluate(val) + + def _run_fn_with_grad_check(self, strategy, var, opt, expected_grad): + grad_check_fn = mp_test_util.create_identity_with_grad_check_fn( + expected_grad) + loss = lambda: grad_check_fn(var) / strategy.num_replicas_in_sync + return lambda: opt.minimize(loss, var_list=[var]) + + @parameterized.named_parameters(*TESTCASES) + @test_util.run_in_graph_and_eager_modes + def testFixedLossScaleAppliedToLossWithMinimize(self, strategy_fn): + with strategy_fn().scope() as strategy: + var = variables.Variable([5.0]) + opt = gradient_descent.GradientDescentOptimizer(2.0) + loss_scale = 10. + opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale) + # We need num_replicas_in_sync to divide loss_scale, otherwise loss_scale + # / strategy.num_replicas_in_sync will not be exact, which could lead to + # assertion failures due to rounding issues. 
+ self.assertEqual(loss_scale % strategy.num_replicas_in_sync, 0) + run_fn = self._run_fn_with_grad_check( + strategy, var, opt, loss_scale / strategy.num_replicas_in_sync) + run_op = strategy.experimental_run(run_fn) + self.evaluate(variables.global_variables_initializer()) + self._run_if_in_graph_mode(run_op) + # The loss is the identity of the variable. Therefore the gradient is 1, + # and so the variable will be init_val - grad * lr == 5 - 1 * 2 == 3 + self.assertAllClose([3.], self.evaluate(var)) + + @test_util.deprecated_graph_mode_only + def testFixedLossScaleAppliedToLossWithGetGradients(self): + var = variables.Variable([2.0]) + opt = gradient_descent.GradientDescentOptimizer(1.0) + loss_scale = 10. + opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale) + grad_check_fn = mp_test_util.create_identity_with_grad_check_fn(loss_scale) + loss = grad_check_fn(var) + run_op = get_gradients(opt, loss, [var]) + self.evaluate(variables.global_variables_initializer()) + # This will cause an assertion to run, as + # mp_test_util.create_identity_with_grad_check_fn added an assertion op. + self.evaluate(run_op) + + @parameterized.named_parameters(*TESTCASES) + @test_util.run_in_graph_and_eager_modes + def testDynamicLossScale(self, strategy_fn): + strategy = strategy_fn() + learning_rate = 2. + expected_gradient = resource_variable_ops.ResourceVariable( + learning_rate / strategy.num_replicas_in_sync) + with strategy.scope(): + var = variables.Variable([5.0]) + opt = gradient_descent.GradientDescentOptimizer(learning_rate) + loss_scale = loss_scale_module.DynamicLossScale( + initial_loss_scale=2, increment_period=1, multiplier=2) + opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale) + self.assertEqual( + loss_scale.initial_loss_scale % strategy.num_replicas_in_sync, 0) + + run_fn = self._run_fn_with_grad_check(strategy, var, opt, + expected_gradient) + run_op = strategy.experimental_run(run_fn) + self.evaluate(variables.global_variables_initializer()) + self._run_if_in_graph_mode(run_op) + # The loss is the identity of the variable. Therefore the gradient is 1, + # and so the variable will be init_val - grad * lr == 5 - 1 * 2 == 3 + self.assertAllClose([3.], self.evaluate(var)) + + # Loss scale will be double, so the expected gradient is also doubled. + self.evaluate(expected_gradient.assign( + 2 * learning_rate / strategy.num_replicas_in_sync)) + run_op = strategy.experimental_run(run_fn) + self._run_if_in_graph_mode(run_op) + # As before, the 2 is subtracted from the variable, making it's new value + # 1. 
+ self.assertAllClose([1.], self.evaluate(var)) + + @parameterized.named_parameters(*TESTCASES) + @test_util.run_in_graph_and_eager_modes + def testDynamicUpdate(self, strategy_fn): + with strategy_fn().scope() as strategy: + var = variables.Variable([1.0, 2.0]) + opt = gradient_descent.GradientDescentOptimizer(1.0) + loss_scale = loss_scale_module.DynamicLossScale( + initial_loss_scale=2, increment_period=1, multiplier=2) + opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale) + + # Test optimizer with finite gradients + loss = lambda: var * 2.0 / strategy.num_replicas_in_sync + run_fn = lambda: opt.minimize(loss, var_list=[var]) + run_op = strategy.experimental_run(run_fn) + self.evaluate(variables.global_variables_initializer()) + self._run_if_in_graph_mode(run_op) + # Gradient is 2, so variable will have 2 subtracted from it + self.assertAllClose([-1.0, 0.0], self.evaluate(var)) + # Loss scale has doubled from 2 to 4 + self.assertEqual(4., self.evaluate(opt._loss_scale())) + + # Test optimizer with NaN gradients + loss = lambda: var * float('NaN') + run_fn = lambda: opt.minimize(loss, var_list=[var]) + run_op = strategy.experimental_run(run_fn) + self._run_if_in_graph_mode(run_op) + # Variable should not change from before, due to NaN gradients. + self.assertAllClose(self.evaluate(var), [-1.0, 0.0]) + # Loss scale should half due to NaN gradients. + self.assertEqual(2., self.evaluate(opt._loss_scale())) + + @parameterized.named_parameters(*TESTCASES) + @test_util.run_in_graph_and_eager_modes + def testDynamicLossScaleWithSlots(self, strategy_fn): + with strategy_fn().scope() as strategy: + var = variables.Variable([1.0, 2.0]) + # An SGD optimizer with momentum has slot variables. + opt = momentum.MomentumOptimizer(1.0, momentum=1.) + initial_loss_scale = 2. + loss_scale = loss_scale_module.DynamicLossScale( + initial_loss_scale=initial_loss_scale, increment_period=1, + multiplier=4) + opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale) + loss = lambda: var / strategy.num_replicas_in_sync + run_fn = lambda: opt.minimize(loss, var_list=[var]) + run_op = strategy.experimental_run(run_fn) + self.evaluate(variables.global_variables_initializer()) + self._run_if_in_graph_mode(run_op) + # The momentum accumulator starts at 0 and the gradient is 1. The + # accumulator is incremented by the gradient, so it is now 1. Then the + # variable is subtracted by the accumulator, so the variable is subtracted + # by 1. + self.assertAllClose([0.0, 1.0], self.evaluate(var)) + self.assertEqual(self.evaluate(opt._loss_scale()), initial_loss_scale * 4) + + run_op = strategy.experimental_run(run_fn) + self._run_if_in_graph_mode(run_op) + # The momentum accumulator was 1 before this step and the gradient is 1. + # The accumulator is incremented by the gradient, so it is now 2. Then the + # variable is subtracted by the accumulator, so the variable is subtracted + # by 2. + self.assertAllClose([-2., -1.], self.evaluate(var)) + self.assertEqual(self.evaluate(opt._loss_scale()), + initial_loss_scale * 16) + + @parameterized.named_parameters(*TESTCASES) + @test_util.run_in_graph_and_eager_modes + def testCheckpoint(self, strategy_fn): + strategy = strategy_fn() + if (isinstance(strategy, mirrored_strategy.MirroredStrategy) and + not context.executing_eagerly()): + # TODO(b/121381184): Enable running the test in this case. + return + + with self.test_session(), strategy.scope(): + # Build and run a simple model. 
+ var = variables.Variable([2.0]) + loss_scale = loss_scale_module.DynamicLossScale( + initial_loss_scale=1., increment_period=2., + multiplier=2.) + opt = momentum.MomentumOptimizer(1.0, momentum=1.) + opt = loss_scale_optimizer.LossScaleOptimizer(opt, loss_scale) + run_fn = lambda: opt.minimize(lambda: var + 1., var_list=[var]) + opt_op = strategy.experimental_run(run_fn) + self.evaluate(variables.global_variables_initializer()) + self.evaluate(opt_op) + self.assertEqual(self.evaluate(loss_scale()), 1.) + self.assertEqual(self.evaluate(loss_scale._num_good_steps), 1) + + # Save a checkpoint. + checkpoint = trackable_utils.Checkpoint(optimizer=opt) + prefix = os.path.join(self.get_temp_dir(), 'ckpt') + save_path = checkpoint.save(prefix) + + # Run model again. + self.evaluate(strategy.experimental_run(run_fn)) + self.assertEqual(self.evaluate(loss_scale()), 2.) + self.assertEqual(self.evaluate(loss_scale._num_good_steps), 0) + + # Load checkpoint and ensure loss scale is back to it's original value. + status = checkpoint.restore(save_path) + status.assert_consumed() + status.run_restore_ops() + self.assertEqual(self.evaluate(loss_scale()), 1.) + self.assertEqual(self.evaluate(loss_scale._num_good_steps), 1) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/python/training/loss_scale_test.py b/tensorflow/python/training/loss_scale_test.py new file mode 100644 index 00000000000..edf062cf762 --- /dev/null +++ b/tensorflow/python/training/loss_scale_test.py @@ -0,0 +1,261 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for LossScale classes..""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from absl.testing import parameterized +import numpy as np + +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.distribute import distribution_strategy_context +from tensorflow.python.distribute import mirrored_strategy +from tensorflow.python.eager import context +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util +from tensorflow.python.training import loss_scale as loss_scale_module +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import test + + +# If called outside any strategy.scope() calls, this will return the default +# strategy. 
+default_strategy_fn = distribution_strategy_context.get_strategy + + +def create_mirrored_strategy(): + if context.num_gpus() >= 1: + return mirrored_strategy.MirroredStrategy(['cpu:0', 'gpu:0']) + else: + return mirrored_strategy.MirroredStrategy(['cpu:0']) + + +TESTCASES = ({ + 'testcase_name': 'base', + 'strategy_fn': default_strategy_fn +}, { + 'testcase_name': 'distribute', + 'strategy_fn': create_mirrored_strategy +}) + + +class FixedLossScaleTest(test.TestCase): + + @test_util.run_in_graph_and_eager_modes + def test_basic(self): + loss_scale_value = 1000 + loss_scale = loss_scale_module.FixedLossScale(loss_scale_value) + + update_op, should_apply = loss_scale.update([constant_op.constant(0.)]) + self.evaluate(update_op) + # should_apply should be a bool instead of a tensor, so that a tf.cond does + # not have to be built in the graph by the caller. + self.assertIsInstance(should_apply, bool) + self.assertTrue(should_apply) + self.assertEqual(loss_scale_value, self.evaluate(loss_scale())) + + update_op, should_apply = loss_scale.update( + [constant_op.constant(float('NaN'))]) + self.evaluate(update_op) + self.assertIsInstance(should_apply, bool) + self.assertTrue(should_apply) + self.assertEqual(loss_scale_value, self.evaluate(loss_scale())) + + +def _get_example_iter(inputs): + dataset = dataset_ops.Dataset.from_tensor_slices(inputs) + return dataset_ops.make_one_shot_iterator(dataset) + + +class DynamicLossScaleTest(test.TestCase, parameterized.TestCase): + + def _get_tensor(self, is_finite): + tensor = control_flow_ops.cond(is_finite, lambda: 1., lambda: float('NaN')) + + if not distribution_strategy_context.has_strategy(): + return tensor + def get(): + rep_id = (distribution_strategy_context.get_replica_context() + .replica_id_in_sync_group) + return control_flow_ops.cond(math_ops.equal(rep_id, 0), lambda: tensor, + lambda: 1.) 
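+    # Only replica 0 returns the possibly-nonfinite `tensor`; every other
+    # replica returns a finite 1., so a single bad replica is enough to
+    # exercise the cross-replica reduction in DynamicLossScale's update.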
+ distribution = distribution_strategy_context.get_strategy() + return distribution.extended.call_for_each_replica(get) + + def _test_helper(self, + inputs, + expected_outputs, + initial_loss_scale=1., + increment_period=2, + multiplier=2): + loss_scale = loss_scale_module.DynamicLossScale( + initial_loss_scale=initial_loss_scale, + increment_period=increment_period, + multiplier=multiplier) + itr = _get_example_iter(inputs) + + def update(): + is_finite = itr.get_next() + grad = self._get_tensor(is_finite) + update_op, should_apply_gradients = loss_scale.update([grad]) + assert_op = check_ops.assert_equal(should_apply_gradients, is_finite) + if context.executing_eagerly(): + return + with ops.control_dependencies([assert_op]): + return array_ops.identity(update_op) + + actual_outputs = [] + + if not context.executing_eagerly(): + update_op = update() + self.evaluate(variables.global_variables_initializer()) + for _ in range(len(inputs)): + if context.executing_eagerly(): + update() + else: + self.evaluate(update_op) + actual_outputs.append(self.evaluate(loss_scale())) + self.assertEqual(actual_outputs, expected_outputs) + + @parameterized.named_parameters(*TESTCASES) + @test_util.run_in_graph_and_eager_modes + def test_increase(self, strategy_fn): + with strategy_fn().scope(): + inputs = [True] * 6 + expected_outputs = [1, 2, 2, 4, 4, 8] + self._test_helper(inputs, expected_outputs) + + @parameterized.named_parameters(*TESTCASES) + @test_util.run_in_graph_and_eager_modes + def test_keep_increasing_until_capped(self, strategy_fn): + with strategy_fn().scope(): + init_loss_scale = np.finfo(np.float32).max / 4 + max_float = np.finfo(np.float32).max + + inputs = [True] * 6 + # Output is capped the 2nd time it doubles. + expected_outputs = [ + init_loss_scale, init_loss_scale * 2, init_loss_scale * 2, max_float, + max_float, max_float + ] + + self._test_helper(inputs, expected_outputs, init_loss_scale) + + @parameterized.named_parameters(*TESTCASES) + @test_util.run_in_graph_and_eager_modes + def test_decrease_every_step(self, strategy_fn): + with strategy_fn().scope(): + inputs = [False] * 6 + init_loss_scale = 1024 + expected_outputs = [512, 256, 128, 64, 32, 16] + + self._test_helper(inputs, expected_outputs, init_loss_scale) + + @parameterized.named_parameters(*TESTCASES) + @test_util.run_in_graph_and_eager_modes + def test_keep_decreasing_until_one(self, strategy_fn): + with strategy_fn().scope(): + inputs = [False] * 6 + init_loss_scale = 16 + expected_outputs = [8, 4, 2, 1, 1, 1] + + self._test_helper(inputs, expected_outputs, init_loss_scale) + + @parameterized.named_parameters(*TESTCASES) + @test_util.run_in_graph_and_eager_modes + def test_nan_clear_good_step(self, strategy_fn): + with strategy_fn().scope(): + inputs = [True, True, True, False, True] + expected_outputs = [1, 2, 2, 1, 1] + self._test_helper(inputs, expected_outputs) + + @parameterized.named_parameters(*TESTCASES) + @test_util.run_in_graph_and_eager_modes + def test_trigger_loss_scale_update_each_step(self, strategy_fn): + with strategy_fn().scope(): + init_loss_scale = 1 + increment_period = 1 + + inputs = [True] * 3 + [False, True, True] + expected_outputs = [2, 4, 8, 4, 8, 16] + + self._test_helper(inputs, expected_outputs, init_loss_scale, + increment_period) + + @parameterized.named_parameters(*TESTCASES) + @test_util.run_in_graph_and_eager_modes + def test_alternating_good_and_bad_gradients_trigger_each_step(self, + strategy_fn): + with strategy_fn().scope(): + init_loss_scale = 1 + increment_period = 1 + + 
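+      # With increment_period=1, each finite step doubles the loss scale and
+      # each nonfinite step halves it, so the scale below is expected to
+      # alternate between 2 and 1.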
inputs = [True, False] * 4 + [True] + expected_outputs = [2, 1, 2, 1, 2, 1, 2, 1, 2] + self._test_helper(inputs, expected_outputs, init_loss_scale, + increment_period) + + @parameterized.named_parameters(*TESTCASES) + @test_util.run_in_graph_and_eager_modes + def test_alternating_good_and_bad_gradients_trigger_every_other_step( + self, strategy_fn): + with strategy_fn().scope(): + init_loss_scale = 32 + increment_period = 2 + + inputs = [True, False] * 3 + [True] + expected_outputs = [32, 16, 16, 8, 8, 4, 4] + self._test_helper(inputs, expected_outputs, init_loss_scale, + increment_period) + + @parameterized.named_parameters(*TESTCASES) + @test_util.run_in_graph_and_eager_modes + def test_nondefault_multiplier(self, strategy_fn): + with strategy_fn().scope(): + init_loss_scale = 4 + multiplier = 3 + inputs = [True, True, False, True, True] + expected_outputs = [4, 12, 4, 4, 12] + self._test_helper(inputs, expected_outputs, init_loss_scale, + multiplier=multiplier) + + @parameterized.named_parameters(*TESTCASES) + @test_util.run_in_graph_and_eager_modes + def test_random_mix_good_and_bad_gradients(self, strategy_fn): + with strategy_fn().scope(): + init_loss_scale = 4 + inputs = [ + False, True, True, True, False, True, False, True, True, True, False + ] + expected_outputs = [2, 2, 4, 4, 2, 2, 1, 1, 2, 2, 1] + self._test_helper(inputs, expected_outputs, init_loss_scale) + + @test_util.run_in_graph_and_eager_modes + def test_get(self): + scalar = loss_scale_module.get('dynamic') + scalar2 = loss_scale_module.DynamicLossScale() + self.assertEqual(scalar.initial_loss_scale, scalar2.initial_loss_scale) + self.assertEqual(scalar.increment_period, scalar2.increment_period) + self.assertEqual(scalar.multiplier, + scalar2.multiplier) + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/python/training/training.py b/tensorflow/python/training/training.py index 9f509ae0a38..368dab300be 100644 --- a/tensorflow/python/training/training.py +++ b/tensorflow/python/training/training.py @@ -33,6 +33,8 @@ from tensorflow.python.training.adagrad_da import AdagradDAOptimizer from tensorflow.python.training.proximal_adagrad import ProximalAdagradOptimizer from tensorflow.python.training.adam import AdamOptimizer from tensorflow.python.training.ftrl import FtrlOptimizer +from tensorflow.python.training.loss_scale import FixedLossScale, DynamicLossScale +from tensorflow.python.training.loss_scale_optimizer import LossScaleOptimizer from tensorflow.python.training.momentum import MomentumOptimizer from tensorflow.python.training.moving_averages import ExponentialMovingAverage from tensorflow.python.training.optimizer import Optimizer From bc68bc0da746a048332a5b38a1a6d0edf5ad7221 Mon Sep 17 00:00:00 2001 From: Matt Conley Date: Tue, 9 Apr 2019 10:09:56 -0700 Subject: [PATCH 02/10] Update docstring and comments --- .../python/training/loss_scale_optimizer.py | 31 ++++++++++++------- 1 file changed, 19 insertions(+), 12 deletions(-) diff --git a/tensorflow/python/training/loss_scale_optimizer.py b/tensorflow/python/training/loss_scale_optimizer.py index d70e19d8b5b..9e85cc61cb5 100644 --- a/tensorflow/python/training/loss_scale_optimizer.py +++ b/tensorflow/python/training/loss_scale_optimizer.py @@ -30,22 +30,30 @@ from tensorflow.python.util.tf_export import tf_export class LossScaleOptimizer(optimizer.Optimizer): """An optimizer that applies loss scaling. + Loss scaling is a process that multiplies the loss by a multiplier called the + loss scale, and divides each gradient by the same multiplier. 
The pseudocode + for this process is: + + ``` + loss = ... + loss *= loss_scale + grads = gradients(loss, vars) + grads /= loss_scale + ``` + + Mathematically, loss scaling has no effect, but can help avoid numerical + underflow in intermediate gradients when float16 tensors are used. By + multiplying the loss, each intermediate gradient will have the same multiplier + applied. + The loss scale can either be a fixed constant, chosen by the user, or be dynamically determined. Dynamically determining the loss scale is convenient as a loss scale does not have to be explicitly chosen. However it reduces performance. - This optimizer wraps another optimizer and applies loss scaling to it. Loss - scaling is applied whenever gradients are computed. - - Args: - opt: The Optimizer instance to wrap. - loss_scale: The loss scale or LossScale class to scale the loss and - gradients. This can either be an int/float to use a fixed loss scale, - the string "dynamic" to use dynamic loss scaling, or an instance of a - LossScale class. The string "dynamic" is equivalent to passing - `DynamicLossScale()`, and passing an int/float is equivalent - to passing a FixedLossScale instance with the given loss scale. + This optimizer wraps another optimizer and applies loss scaling to it via a + `LossScale`. Loss scaling is applied whenever gradients are + computed, such as through `minimize()`. """ def __init__(self, opt, loss_scale): if not isinstance(opt, optimizer.Optimizer): @@ -113,7 +121,6 @@ class LossScaleOptimizer(optimizer.Optimizer): return list(zip(scaled_grads, variables)) def _scale_loss(self, loss): - # The loss is callable for `_compute_gradients`, but not `get_gradients`. loss_scale = self._loss_scale() if callable(loss): return lambda: loss() * loss_scale From cd4bd9d1deb7bbf274c0ea7b116cc04c25743003 Mon Sep 17 00:00:00 2001 From: Matt Conley Date: Tue, 9 Apr 2019 10:13:47 -0700 Subject: [PATCH 03/10] Update function names for V2 consistency --- tensorflow/python/training/loss_scale_optimizer.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tensorflow/python/training/loss_scale_optimizer.py b/tensorflow/python/training/loss_scale_optimizer.py index 9e85cc61cb5..bf62633bc7a 100644 --- a/tensorflow/python/training/loss_scale_optimizer.py +++ b/tensorflow/python/training/loss_scale_optimizer.py @@ -129,10 +129,10 @@ class LossScaleOptimizer(optimizer.Optimizer): def _scale_grads(self, grads): loss_scale = self._loss_scale() loss_scale_reciprical = 1 / loss_scale - return [None if g is None else self._indexed_slices( + return [None if g is None else self._scale_grad( g, loss_scale_reciprical) for g in grads] - def _indexed_slices(self, grad, loss_scale_reciprical): + def _scale_grad(self, grad, loss_scale_reciprical): if isinstance(grad, ops.IndexedSlices): grad_vals = grad.values * loss_scale_reciprical return ops.IndexedSlices(grad_vals, grad.indices, grad.dense_shape) @@ -170,7 +170,7 @@ class LossScaleOptimizer(optimizer.Optimizer): # TODO(nluehr) cleanup GraphKeys.TRAIN_OP return replica_context.merge_call( - self._maybe_apply_gradients_cross_replica, + self._apply_gradients_cross_replica, args=(grads_and_vars, global_step, name)) def _distributed_apply(self, @@ -197,10 +197,10 @@ class LossScaleOptimizer(optimizer.Optimizer): replicas. 
If `global_step` was not None, that operation also increments `global_step` """ - self._maybe_apply_gradients_cross_replica(distribution, grads_and_vars, + self._apply_gradients_cross_replica(distribution, grads_and_vars, global_step, name) - def _maybe_apply_gradients_cross_replica(self, distribution, grads_and_vars, + def _apply_gradients_cross_replica(self, distribution, grads_and_vars, global_step, name): """Conditionally apply gradients in cross replica context.""" name = name if name is not None else self.get_name() From c76cef535562d9b2c6d790ddfc20eb0df06a031b Mon Sep 17 00:00:00 2001 From: Matt Conley Date: Tue, 9 Apr 2019 17:45:12 -0700 Subject: [PATCH 04/10] Change path to loss scaling code And revert _maybe_apply_gradients_cross_replica function rename --- tensorflow/python/BUILD | 8 ++++---- .../python/tools/api/generator/api_init_files_v1.bzl | 2 ++ .../{ => mixed_precision/experimental}/loss_scale.py | 6 +++--- .../experimental}/loss_scale_optimizer.py | 10 +++++----- .../experimental}/loss_scale_optimizer_test.py | 4 ++-- .../experimental}/loss_scale_test.py | 2 +- tensorflow/python/training/training.py | 2 -- 7 files changed, 17 insertions(+), 17 deletions(-) rename tensorflow/python/training/{ => mixed_precision/experimental}/loss_scale.py (98%) rename tensorflow/python/training/{ => mixed_precision/experimental}/loss_scale_optimizer.py (96%) rename tensorflow/python/training/{ => mixed_precision/experimental}/loss_scale_optimizer_test.py (98%) rename tensorflow/python/training/{ => mixed_precision/experimental}/loss_scale_test.py (98%) diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 84953c43825..56fcf82fe80 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -2782,7 +2782,7 @@ py_library( py_library( name = "loss_scale", - srcs = ["training/loss_scale.py"], + srcs = ["training/mixed_precision/experimental/loss_scale.py"], srcs_version = "PY2AND3", deps = [ "//tensorflow/python:framework", @@ -2793,7 +2793,7 @@ py_library( py_library( name = "loss_scale_optimizer", - srcs = ["training/loss_scale_optimizer.py"], + srcs = ["training/mixed_precision/experimental/loss_scale_optimizer.py"], srcs_version = "PY2AND3", deps = [ ":loss_scale", @@ -2804,7 +2804,7 @@ py_library( py_test( name = "loss_scale_optimizer_test", size = "small", - srcs = ["training/loss_scale_optimizer_test.py"], + srcs = ["training/mixed_precision/experimental/loss_scale_optimizer_test.py"], deps = [ ":loss_scale_optimizer", "//tensorflow/python/keras/mixed_precision/experimental:test_util", @@ -2818,7 +2818,7 @@ py_test( py_test( name = "loss_scale_test", size = "small", - srcs = ["training/loss_scale_test.py"], + srcs = ["training/mixed_precision/experimental/loss_scale_test.py"], deps = [ ":loss_scale", "//tensorflow/python:client_testlib", diff --git a/tensorflow/python/tools/api/generator/api_init_files_v1.bzl b/tensorflow/python/tools/api/generator/api_init_files_v1.bzl index 9be2b2daf97..b711a908e1b 100644 --- a/tensorflow/python/tools/api/generator/api_init_files_v1.bzl +++ b/tensorflow/python/tools/api/generator/api_init_files_v1.bzl @@ -80,6 +80,8 @@ TENSORFLOW_API_INIT_FILES_V1 = [ "tpu/__init__.py", "train/__init__.py", "train/experimental/__init__.py", + "train/mixed_precision/__init__.py", + "train/mixed_precision/experimental/__init__.py", "train/queue_runner/__init__.py", "user_ops/__init__.py", "version/__init__.py", diff --git a/tensorflow/python/training/loss_scale.py b/tensorflow/python/training/mixed_precision/experimental/loss_scale.py similarity 
index 98% rename from tensorflow/python/training/loss_scale.py rename to tensorflow/python/training/mixed_precision/experimental/loss_scale.py index 94265e4974a..5ee6f9232b0 100644 --- a/tensorflow/python/training/loss_scale.py +++ b/tensorflow/python/training/mixed_precision/experimental/loss_scale.py @@ -35,7 +35,7 @@ from tensorflow.python.util.tf_export import tf_export @six.add_metaclass(abc.ABCMeta) -@tf_export(v1=['train.LossScale']) +@tf_export(v1=['train.mixed_precision.experimental.LossScale']) class LossScale(trackable.Trackable): """Base class to compute the loss scale. @@ -215,7 +215,7 @@ class LossScale(trackable.Trackable): graph_key = graph._graph_key # pylint: disable=protected-access return self._weights.get((graph_key, name), None) -@tf_export(v1=['train.FixedLossScale']) +@tf_export(v1=['train.mixed_precision.experimental.FixedLossScale']) class FixedLossScale(LossScale): """Loss scale class with a fixed value. @@ -283,7 +283,7 @@ def _assign_if_finite(var, value): control_flow_ops.no_op) -@tf_export(v1=['train.DynamicLossScale']) +@tf_export(v1=['train.mixed_precision.experimental.DynamicLossScale']) class DynamicLossScale(LossScale): """Loss scale class that dynamically adjusts the loss scale. diff --git a/tensorflow/python/training/loss_scale_optimizer.py b/tensorflow/python/training/mixed_precision/experimental/loss_scale_optimizer.py similarity index 96% rename from tensorflow/python/training/loss_scale_optimizer.py rename to tensorflow/python/training/mixed_precision/experimental/loss_scale_optimizer.py index bf62633bc7a..80c2087eb7f 100644 --- a/tensorflow/python/training/loss_scale_optimizer.py +++ b/tensorflow/python/training/mixed_precision/experimental/loss_scale_optimizer.py @@ -20,13 +20,13 @@ from __future__ import print_function from tensorflow.python.framework import ops from tensorflow.python.framework import smart_cond from tensorflow.python.ops import control_flow_ops -from tensorflow.python.training import loss_scale as loss_scale_module +from tensorflow.python.training.mixed_precision.experimental import loss_scale as loss_scale_module from tensorflow.python.training import optimizer from tensorflow.python.distribute import distribution_strategy_context from tensorflow.python.util.tf_export import tf_export -@tf_export(v1=['train.LossScaleOptimizer']) +@tf_export(v1=['train.mixed_precision.experimental.LossScaleOptimizer']) class LossScaleOptimizer(optimizer.Optimizer): """An optimizer that applies loss scaling. @@ -170,7 +170,7 @@ class LossScaleOptimizer(optimizer.Optimizer): # TODO(nluehr) cleanup GraphKeys.TRAIN_OP return replica_context.merge_call( - self._apply_gradients_cross_replica, + self._maybe_apply_gradients_cross_replica, args=(grads_and_vars, global_step, name)) def _distributed_apply(self, @@ -197,10 +197,10 @@ class LossScaleOptimizer(optimizer.Optimizer): replicas. 
If `global_step` was not None, that operation also increments `global_step` """ - self._apply_gradients_cross_replica(distribution, grads_and_vars, + self._maybe_apply_gradients_cross_replica(distribution, grads_and_vars, global_step, name) - def _apply_gradients_cross_replica(self, distribution, grads_and_vars, + def _maybe_apply_gradients_cross_replica(self, distribution, grads_and_vars, global_step, name): """Conditionally apply gradients in cross replica context.""" name = name if name is not None else self.get_name() diff --git a/tensorflow/python/training/loss_scale_optimizer_test.py b/tensorflow/python/training/mixed_precision/experimental/loss_scale_optimizer_test.py similarity index 98% rename from tensorflow/python/training/loss_scale_optimizer_test.py rename to tensorflow/python/training/mixed_precision/experimental/loss_scale_optimizer_test.py index 3ee38cb87c5..d208011c756 100644 --- a/tensorflow/python/training/loss_scale_optimizer_test.py +++ b/tensorflow/python/training/mixed_precision/experimental/loss_scale_optimizer_test.py @@ -26,8 +26,8 @@ from tensorflow.python.distribute import distribution_strategy_context from tensorflow.python.distribute import mirrored_strategy from tensorflow.python.eager import context from tensorflow.python.framework import test_util -from tensorflow.python.training import loss_scale as loss_scale_module -from tensorflow.python.training import loss_scale_optimizer +from tensorflow.python.training.mixed_precision.experimental import loss_scale as loss_scale_module +from tensorflow.python.training.mixed_precision.experimental import loss_scale_optimizer from tensorflow.python.keras.mixed_precision.experimental import test_util as mp_test_util from tensorflow.python.training import gradient_descent from tensorflow.python.training import momentum diff --git a/tensorflow/python/training/loss_scale_test.py b/tensorflow/python/training/mixed_precision/experimental/loss_scale_test.py similarity index 98% rename from tensorflow/python/training/loss_scale_test.py rename to tensorflow/python/training/mixed_precision/experimental/loss_scale_test.py index edf062cf762..9e3acb5817a 100644 --- a/tensorflow/python/training/loss_scale_test.py +++ b/tensorflow/python/training/mixed_precision/experimental/loss_scale_test.py @@ -27,7 +27,7 @@ from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops from tensorflow.python.framework import test_util -from tensorflow.python.training import loss_scale as loss_scale_module +from tensorflow.python.training.mixed_precision.experimental import loss_scale as loss_scale_module from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops diff --git a/tensorflow/python/training/training.py b/tensorflow/python/training/training.py index 368dab300be..9f509ae0a38 100644 --- a/tensorflow/python/training/training.py +++ b/tensorflow/python/training/training.py @@ -33,8 +33,6 @@ from tensorflow.python.training.adagrad_da import AdagradDAOptimizer from tensorflow.python.training.proximal_adagrad import ProximalAdagradOptimizer from tensorflow.python.training.adam import AdamOptimizer from tensorflow.python.training.ftrl import FtrlOptimizer -from tensorflow.python.training.loss_scale import FixedLossScale, DynamicLossScale -from tensorflow.python.training.loss_scale_optimizer import LossScaleOptimizer from tensorflow.python.training.momentum import MomentumOptimizer from 
tensorflow.python.training.moving_averages import ExponentialMovingAverage from tensorflow.python.training.optimizer import Optimizer From add240c14d5bc08d2d79c8bf4dfad98cc0da3bd3 Mon Sep 17 00:00:00 2001 From: Matt Conley Date: Tue, 9 Apr 2019 22:09:09 -0700 Subject: [PATCH 05/10] Properly rename _apply_gradients functions --- .../experimental/loss_scale_optimizer.py | 17 ++++------------- 1 file changed, 4 insertions(+), 13 deletions(-) diff --git a/tensorflow/python/training/mixed_precision/experimental/loss_scale_optimizer.py b/tensorflow/python/training/mixed_precision/experimental/loss_scale_optimizer.py index 80c2087eb7f..795b85d4caa 100644 --- a/tensorflow/python/training/mixed_precision/experimental/loss_scale_optimizer.py +++ b/tensorflow/python/training/mixed_precision/experimental/loss_scale_optimizer.py @@ -170,7 +170,7 @@ class LossScaleOptimizer(optimizer.Optimizer): # TODO(nluehr) cleanup GraphKeys.TRAIN_OP return replica_context.merge_call( - self._maybe_apply_gradients_cross_replica, + self._distributed_apply, args=(grads_and_vars, global_step, name)) def _distributed_apply(self, @@ -197,28 +197,19 @@ class LossScaleOptimizer(optimizer.Optimizer): replicas. If `global_step` was not None, that operation also increments `global_step` """ - self._maybe_apply_gradients_cross_replica(distribution, grads_and_vars, - global_step, name) - - def _maybe_apply_gradients_cross_replica(self, distribution, grads_and_vars, - global_step, name): - """Conditionally apply gradients in cross replica context.""" name = name if name is not None else self.get_name() grads = [g for g, _ in grads_and_vars] loss_scale_update_op, should_apply_grads = ( self._loss_scale.update(grads)) maybe_apply_op = smart_cond.smart_cond( should_apply_grads, - lambda: self._apply_gradients_cross_replica(distribution, - grads_and_vars, - global_step, - name+'-wrapped'), + lambda: self._apply_gradients(distribution, grads_and_vars, + global_step, name+'-wrapped'), control_flow_ops.no_op) return control_flow_ops.group(maybe_apply_op, loss_scale_update_op, name=name) - def _apply_gradients_cross_replica(self, distribution, grads_and_vars, - global_step, name): + def _apply_gradients(self, distribution, grads_and_vars, global_step, name): """Unconditionally apply gradients in cross replica context.""" update_ops = distribution.extended.call_for_each_replica( self._optimizer.apply_gradients, From b866bf9c93230a37dec71de06b7eacffcaf8fc92 Mon Sep 17 00:00:00 2001 From: Matt Conley Date: Wed, 10 Apr 2019 17:19:08 -0700 Subject: [PATCH 06/10] Add golden API files --- tensorflow/python/training/training.py | 3 ++ ...ion.experimental.-dynamic-loss-scale.pbtxt | 35 +++++++++++++ ...ision.experimental.-fixed-loss-scale.pbtxt | 23 +++++++++ ...n.experimental.-loss-scale-optimizer.pbtxt | 51 +++++++++++++++++++ ...d_precision.experimental.-loss-scale.pbtxt | 22 ++++++++ ...w.train.mixed_precision.experimental.pbtxt | 19 +++++++ .../v1/tensorflow.train.mixed_precision.pbtxt | 7 +++ .../api/golden/v1/tensorflow.train.pbtxt | 4 ++ 8 files changed, 164 insertions(+) create mode 100644 tensorflow/tools/api/golden/v1/tensorflow.train.mixed_precision.experimental.-dynamic-loss-scale.pbtxt create mode 100644 tensorflow/tools/api/golden/v1/tensorflow.train.mixed_precision.experimental.-fixed-loss-scale.pbtxt create mode 100644 tensorflow/tools/api/golden/v1/tensorflow.train.mixed_precision.experimental.-loss-scale-optimizer.pbtxt create mode 100644 
tensorflow/tools/api/golden/v1/tensorflow.train.mixed_precision.experimental.-loss-scale.pbtxt create mode 100644 tensorflow/tools/api/golden/v1/tensorflow.train.mixed_precision.experimental.pbtxt create mode 100644 tensorflow/tools/api/golden/v1/tensorflow.train.mixed_precision.pbtxt diff --git a/tensorflow/python/training/training.py b/tensorflow/python/training/training.py index 9f509ae0a38..6a963b85ed1 100644 --- a/tensorflow/python/training/training.py +++ b/tensorflow/python/training/training.py @@ -33,6 +33,9 @@ from tensorflow.python.training.adagrad_da import AdagradDAOptimizer from tensorflow.python.training.proximal_adagrad import ProximalAdagradOptimizer from tensorflow.python.training.adam import AdamOptimizer from tensorflow.python.training.ftrl import FtrlOptimizer +from tensorflow.python.training.mixed_precision.experimental.loss_scale import DynamicLossScale +from tensorflow.python.training.mixed_precision.experimental.loss_scale import FixedLossScale +from tensorflow.python.training.mixed_precision.experimental.loss_scale_optimizer import LossScaleOptimizer from tensorflow.python.training.momentum import MomentumOptimizer from tensorflow.python.training.moving_averages import ExponentialMovingAverage from tensorflow.python.training.optimizer import Optimizer diff --git a/tensorflow/tools/api/golden/v1/tensorflow.train.mixed_precision.experimental.-dynamic-loss-scale.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.train.mixed_precision.experimental.-dynamic-loss-scale.pbtxt new file mode 100644 index 00000000000..ab55e3543a5 --- /dev/null +++ b/tensorflow/tools/api/golden/v1/tensorflow.train.mixed_precision.experimental.-dynamic-loss-scale.pbtxt @@ -0,0 +1,35 @@ +path: "tensorflow.train.mixed_precision.experimental.DynamicLossScale" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "increment_period" + mtype: "" + } + member { + name: "initial_loss_scale" + mtype: "" + } + member { + name: "multiplier" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'initial_loss_scale\', \'increment_period\', \'multiplier\'], varargs=None, keywords=None, defaults=[\'32768\', \'2000\', \'2.0\'], " + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\'], varargs=None, keywords=None, defaults=[\'()\', \'None\', \'zeros\'], " + } + member_method { + name: "build" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "update" + argspec: "args=[\'self\', \'grads\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v1/tensorflow.train.mixed_precision.experimental.-fixed-loss-scale.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.train.mixed_precision.experimental.-fixed-loss-scale.pbtxt new file mode 100644 index 00000000000..d2670e4018b --- /dev/null +++ b/tensorflow/tools/api/golden/v1/tensorflow.train.mixed_precision.experimental.-fixed-loss-scale.pbtxt @@ -0,0 +1,23 @@ +path: "tensorflow.train.mixed_precision.experimental.FixedLossScale" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'loss_scale_value\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\'], varargs=None, keywords=None, defaults=[\'()\', \'None\', \'zeros\'], " + } + 
member_method { + name: "build" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "update" + argspec: "args=[\'self\', \'grads\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v1/tensorflow.train.mixed_precision.experimental.-loss-scale-optimizer.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.train.mixed_precision.experimental.-loss-scale-optimizer.pbtxt new file mode 100644 index 00000000000..1fffa498a72 --- /dev/null +++ b/tensorflow/tools/api/golden/v1/tensorflow.train.mixed_precision.experimental.-loss-scale-optimizer.pbtxt @@ -0,0 +1,51 @@ +path: "tensorflow.train.mixed_precision.experimental.LossScaleOptimizer" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "GATE_GRAPH" + mtype: "" + } + member { + name: "GATE_NONE" + mtype: "" + } + member { + name: "GATE_OP" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'opt\', \'loss_scale\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "apply_gradients" + argspec: "args=[\'self\', \'grads_and_vars\', \'global_step\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "compute_gradients" + argspec: "args=[\'self\', \'loss\', \'var_list\', \'gate_gradients\', \'aggregation_method\', \'colocate_gradients_with_ops\', \'grad_loss\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'None\', \'False\', \'None\'], " + } + member_method { + name: "get_name" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_slot" + argspec: "args=[\'self\', \'var\', \'name\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "get_slot_names" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "minimize" + argspec: "args=[\'self\', \'loss\', \'global_step\', \'var_list\', \'gate_gradients\', \'aggregation_method\', \'colocate_gradients_with_ops\', \'name\', \'grad_loss\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'1\', \'None\', \'False\', \'None\', \'None\'], " + } + member_method { + name: "variables" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v1/tensorflow.train.mixed_precision.experimental.-loss-scale.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.train.mixed_precision.experimental.-loss-scale.pbtxt new file mode 100644 index 00000000000..4fee7027064 --- /dev/null +++ b/tensorflow/tools/api/golden/v1/tensorflow.train.mixed_precision.experimental.-loss-scale.pbtxt @@ -0,0 +1,22 @@ +path: "tensorflow.train.mixed_precision.experimental.LossScale" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "add_weight" + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\'], varargs=None, keywords=None, defaults=[\'()\', \'None\', \'zeros\'], " + } + member_method { + name: "build" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "update" + argspec: "args=[\'self\', \'grads\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v1/tensorflow.train.mixed_precision.experimental.pbtxt 
b/tensorflow/tools/api/golden/v1/tensorflow.train.mixed_precision.experimental.pbtxt new file mode 100644 index 00000000000..aa9cdbfa9c1 --- /dev/null +++ b/tensorflow/tools/api/golden/v1/tensorflow.train.mixed_precision.experimental.pbtxt @@ -0,0 +1,19 @@ +path: "tensorflow.train.mixed_precision.experimental" +tf_module { + member { + name: "DynamicLossScale" + mtype: "" + } + member { + name: "FixedLossScale" + mtype: "" + } + member { + name: "LossScale" + mtype: "" + } + member { + name: "LossScaleOptimizer" + mtype: "" + } +} diff --git a/tensorflow/tools/api/golden/v1/tensorflow.train.mixed_precision.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.train.mixed_precision.pbtxt new file mode 100644 index 00000000000..52f2fd2f2b4 --- /dev/null +++ b/tensorflow/tools/api/golden/v1/tensorflow.train.mixed_precision.pbtxt @@ -0,0 +1,7 @@ +path: "tensorflow.train.mixed_precision" +tf_module { + member { + name: "experimental" + mtype: "" + } +} diff --git a/tensorflow/tools/api/golden/v1/tensorflow.train.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.train.pbtxt index 551fda2eacd..8c9f71e4746 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.train.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.train.pbtxt @@ -244,6 +244,10 @@ tf_module { name: "experimental" mtype: "" } + member { + name: "mixed_precision" + mtype: "" + } member { name: "queue_runner" mtype: "" From 56291a96ab97aeb7667fa575e68c0441df3579fe Mon Sep 17 00:00:00 2001 From: Matt Conley Date: Thu, 11 Apr 2019 13:58:25 -0700 Subject: [PATCH 07/10] Address requested changes to v1 LossScale class --- .../experimental/loss_scale.py | 138 ++++++------------ .../experimental/loss_scale_test.py | 1 + ...ion.experimental.-dynamic-loss-scale.pbtxt | 8 - ...ision.experimental.-fixed-loss-scale.pbtxt | 8 - ...d_precision.experimental.-loss-scale.pbtxt | 8 - 5 files changed, 47 insertions(+), 116 deletions(-) diff --git a/tensorflow/python/training/mixed_precision/experimental/loss_scale.py b/tensorflow/python/training/mixed_precision/experimental/loss_scale.py index 5ee6f9232b0..fee6cbc2b55 100644 --- a/tensorflow/python/training/mixed_precision/experimental/loss_scale.py +++ b/tensorflow/python/training/mixed_precision/experimental/loss_scale.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Contains LossScaler classes.""" +"""Contains LossScale classes.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function @@ -34,78 +34,50 @@ from tensorflow.python.eager import context from tensorflow.python.util.tf_export import tf_export +# TODO(reedwm): Merge this with tf.keras.mixed_precision.experimental.LossScale @six.add_metaclass(abc.ABCMeta) @tf_export(v1=['train.mixed_precision.experimental.LossScale']) class LossScale(trackable.Trackable): - """Base class to compute the loss scale. + """Loss scale base class. - Loss scaling is a process that: + Instances of this class represent a loss scale. Calling instances of this + class returns the loss scale as a scalar float32 tensor, while method + `update()` updates the loss scale depending on the values of the gradients. + Optimizers use instances of this class to scale loss and gradients. - 1) Applies a multiplier on the loss before computing gradients, and - 2) Applies the reciprocal of the multiplier on the gradients before they are - applied on variables. 
- - Mathematically, loss scaling has no effect, but can help avoid numerical - underflow when float16 tensors are used. By multiplying the loss, each - gradient will have the same multiplier applied. - - Instances of this class compute the loss scale. Method `get_loss_scale()` - returns the current loss scale, while method `update()` updates the loss scale - depending on the values of the gradients. Optimizers use instances of this - class to scale loss and gradients. + Note: this LossScale class can only be used with a v1 optimizer wrapper, + tf.train.mixed_precision.experimental.LossScaleOptimizer. For a v2 + wrapper, tf.keras.mixed_precision.experimental.LossScaleOptimizer, a + tf.keras.mixed_precision.experimental.LossScale should be used instead. """ def __init__(self): - """Initializes the loss scale class. - - Note subclasses should create variables in build() instead of in the - constructor. This is because callers might choose to place variables on - a certain device by calling build() under a tf.device() scope. - """ - self.built = False + """Initializes the loss scale class.""" self._weights = {} - def build(self): - """Builds the weights of the loss scale class. - - If weights are needed, subclasses should build weights by calling - `self.add_weight(...)`, then call the super's build to set self.built = - True. - """ - self.built = True - + @abc.abstractmethod def __call__(self): """Returns the current loss scale as a scalar `float32` tensor.""" - if not self.built: - self.build() - return self._get_loss_scale() - - @abc.abstractmethod - def _get_loss_scale(self): - """Returns the loss scale without calling build(). - - Subclasses must implement this. Subclasses should not override the public - `__call__` method, which calls this method. - """ pass + @abc.abstractmethod def update(self, grads): """Updates the value of the loss scale. - The loss scale tensor will be potentially updated, based on the value of - `grads`. The tensor returned by `get_loss_scale` is only - updated when this function is evaluated. + The loss scale will be potentially updated, based on the value of `grads`. + The tensor returned by calling this class is only updated when this function + is evaluated. In eager mode, this directly updates the loss scale, so that calling - `get_loss_scale` will return the newly updated loss scale. In graph mode, + `__call__` will return the newly updated loss scale. In graph mode, this returns an op that, when evaluated, updates the loss scale. This function also returns a `should_apply_gradients` bool. If False, gradients should not be applied to the variables that step, as nonfinite - gradients were found, and the loss scale controller can update the loss - scale to reduce the chance of finding nonfinite gradients in the next step. - Some loss scale controllers will always return True, as they cannot adjust - the loss scale in response to nonfinite gradients. + gradients were found, and the loss scale has been updated to reduce the + chance of finding nonfinite gradients in the next step. Some loss scale + classes will always return True, as they cannot adjust themselves in + response to nonfinite gradients. When a DistributionStrategy is used, this function may only be called in a cross-replica context. @@ -122,31 +94,13 @@ class LossScale(trackable.Trackable): False, the caller should skip applying `grads` to the variables this step.
""" - if not self.built: - self.build() - return self._update(grads) - - @abc.abstractmethod - def _update(self, grads): - """Updates the value of the loss scale without calling build(). - - Subclasses must implement this. Subclasses should not override the public - `update_loss_scale` method, which calls this method. - - Args: - grads: A list of unscaled gradients. See `update_loss_scale`. - - Returns: - In eager mode, None. In graph mode, an op to update the loss scale. - """ pass - - def add_weight(self, - name, - shape=(), - dtype=None, - initializer='zeros'): + def _add_weight(self, + name, + shape=(), + dtype=None, + initializer='zeros'): """Adds a weight to this loss scale manager.. This should be called by subclasses in `build()` to build the weights of the @@ -182,7 +136,7 @@ class LossScale(trackable.Trackable): graph = ops.get_default_graph() graph_key = graph._graph_key # pylint: disable=protected-access - key = (graph_key, name) + key = (name, graph_key) if self._weights.get(key, None) is not None: raise RuntimeError('Duplicate variables detected. {}'.format(key)) self._weights[key] = variable @@ -198,8 +152,8 @@ class LossScale(trackable.Trackable): graph = ops.get_default_graph() graph_key = graph._graph_key # pylint: disable=protected-access weights = [trackable.TrackableReference(name=name, ref=v) - for (g, name), v in sorted( - self._weights.items(), key=lambda i: i[0][1]) + for (name, g), v in sorted( + self._weights.items(), key=lambda i: i[0][0]) if g == graph_key] return super(LossScale, self)._checkpoint_dependencies + weights @@ -213,7 +167,7 @@ class LossScale(trackable.Trackable): else: graph = ops.get_default_graph() graph_key = graph._graph_key # pylint: disable=protected-access - return self._weights.get((graph_key, name), None) + return self._weights.get((name, graph_key), None) @tf_export(v1=['train.mixed_precision.experimental.FixedLossScale']) class FixedLossScale(LossScale): @@ -243,10 +197,10 @@ class FixedLossScale(LossScale): self._tensor_loss_scale = ops.convert_to_tensor(loss_scale_value, dtype=dtypes.float32) - def _get_loss_scale(self): + def __call__(self): return self._tensor_loss_scale - def _update(self, grads): + def update(self, grads): del grads return control_flow_ops.no_op(), True @@ -312,10 +266,10 @@ class DynamicLossScale(LossScale): loss scale that is too high gets lowered far more quickly than a loss scale that is to low gets raised. The default is 2 ** 15, which is approximately half the maximum float16 value. - incr_every_n_steps: Increases loss scale every `incr_every_n_steps` + increment_period: Increases loss scale every `increment_period` consecutive steps that finite gradients are encountered. If a nonfinite gradient is encountered, the count is reset back to zero. - loss_scale_multiplier: The multiplier to use when increasing or decreasing + multiplier: The multiplier to use when increasing or decreasing the loss scale. """ super(DynamicLossScale, self).__init__() @@ -323,6 +277,15 @@ class DynamicLossScale(LossScale): self._increment_period = int(increment_period) self._multiplier = float(multiplier) + self._current_loss_scale = self._add_weight( + name='loss_scale', + dtype=dtypes.float32, + initializer=self._initial_loss_scale) + # The number of consecutive steps with finite gradients since the last + # nonfinite gradient or change in loss scale. 
+ self._num_good_steps = self._add_weight( + name='good_steps', dtype=dtypes.int64, initializer='zeros') + @property def initial_loss_scale(self): @@ -336,19 +299,10 @@ class DynamicLossScale(LossScale): def multiplier(self): return self._multiplier - def build(self): - self._current_loss_scale = self.add_weight( - name='loss_scale', - dtype=dtypes.float32, - initializer=self._initial_loss_scale) - self._num_good_steps = self.add_weight( - name='good_steps', dtype=dtypes.int64, initializer='zeros') - self.built = True - - def _get_loss_scale(self): + def __call__(self): return self._current_loss_scale - def _update(self, grads): + def update(self, grads): """Updates loss scale based on if gradients are finite in current step.""" if distribution_strategy_context.has_strategy(): distribution = distribution_strategy_context.get_cross_replica_context() diff --git a/tensorflow/python/training/mixed_precision/experimental/loss_scale_test.py b/tensorflow/python/training/mixed_precision/experimental/loss_scale_test.py index 9e3acb5817a..68479ee4b2c 100644 --- a/tensorflow/python/training/mixed_precision/experimental/loss_scale_test.py +++ b/tensorflow/python/training/mixed_precision/experimental/loss_scale_test.py @@ -35,6 +35,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import variables from tensorflow.python.platform import test +# TODO: Create test case using multiple graphs # If called outside any strategy.scope() calls, this will return the default # strategy. diff --git a/tensorflow/tools/api/golden/v1/tensorflow.train.mixed_precision.experimental.-dynamic-loss-scale.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.train.mixed_precision.experimental.-dynamic-loss-scale.pbtxt index ab55e3543a5..763bcca764b 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.train.mixed_precision.experimental.-dynamic-loss-scale.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.train.mixed_precision.experimental.-dynamic-loss-scale.pbtxt @@ -20,14 +20,6 @@ tf_class { name: "__init__" argspec: "args=[\'self\', \'initial_loss_scale\', \'increment_period\', \'multiplier\'], varargs=None, keywords=None, defaults=[\'32768\', \'2000\', \'2.0\'], " } - member_method { - name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\'], varargs=None, keywords=None, defaults=[\'()\', \'None\', \'zeros\'], " - } - member_method { - name: "build" - argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" - } member_method { name: "update" argspec: "args=[\'self\', \'grads\'], varargs=None, keywords=None, defaults=None" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.train.mixed_precision.experimental.-fixed-loss-scale.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.train.mixed_precision.experimental.-fixed-loss-scale.pbtxt index d2670e4018b..565406fcd93 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.train.mixed_precision.experimental.-fixed-loss-scale.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.train.mixed_precision.experimental.-fixed-loss-scale.pbtxt @@ -8,14 +8,6 @@ tf_class { name: "__init__" argspec: "args=[\'self\', \'loss_scale_value\'], varargs=None, keywords=None, defaults=None" } - member_method { - name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\'], varargs=None, keywords=None, defaults=[\'()\', \'None\', \'zeros\'], " - } - member_method { - name: "build" - argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" - } member_method { name: 
"update" argspec: "args=[\'self\', \'grads\'], varargs=None, keywords=None, defaults=None" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.train.mixed_precision.experimental.-loss-scale.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.train.mixed_precision.experimental.-loss-scale.pbtxt index 4fee7027064..5140463ec16 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.train.mixed_precision.experimental.-loss-scale.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.train.mixed_precision.experimental.-loss-scale.pbtxt @@ -7,14 +7,6 @@ tf_class { name: "__init__" argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" } - member_method { - name: "add_weight" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\'], varargs=None, keywords=None, defaults=[\'()\', \'None\', \'zeros\'], " - } - member_method { - name: "build" - argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" - } member_method { name: "update" argspec: "args=[\'self\', \'grads\'], varargs=None, keywords=None, defaults=None" From 26fa178b02827f27084938adb3af87fc1b74e4c2 Mon Sep 17 00:00:00 2001 From: Matt Conley Date: Thu, 11 Apr 2019 14:02:07 -0700 Subject: [PATCH 08/10] Remove old comment --- .../python/training/mixed_precision/experimental/loss_scale.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tensorflow/python/training/mixed_precision/experimental/loss_scale.py b/tensorflow/python/training/mixed_precision/experimental/loss_scale.py index fee6cbc2b55..5d1adafee80 100644 --- a/tensorflow/python/training/mixed_precision/experimental/loss_scale.py +++ b/tensorflow/python/training/mixed_precision/experimental/loss_scale.py @@ -103,9 +103,6 @@ class LossScale(trackable.Trackable): initializer='zeros'): """Adds a weight to this loss scale manager.. - This should be called by subclasses in `build()` to build the weights of the - loss scale class. - Args: name: Variable name. shape: Variable shape. 
From 490dfd87e70a61d3a3f8d31fd30f7275c54c4822 Mon Sep 17 00:00:00 2001 From: Matt Conley Date: Thu, 11 Apr 2019 17:44:11 -0700 Subject: [PATCH 09/10] Add pip package dependencies --- tensorflow/tools/pip_package/BUILD | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tensorflow/tools/pip_package/BUILD b/tensorflow/tools/pip_package/BUILD index 19c09ce0c1e..2ab804bc876 100644 --- a/tensorflow/tools/pip_package/BUILD +++ b/tensorflow/tools/pip_package/BUILD @@ -68,6 +68,8 @@ COMMON_PIP_DEPS = [ "//tensorflow/python/compiler:compiler", "//tensorflow/python:cond_v2", "//tensorflow/python:distributed_framework_test_lib", + "//tensorflow/python:loss_scale", + "//tensorflow/python:loss_scale_optimizer", "//tensorflow/python:meta_graph_testdata", "//tensorflow/python:spectral_ops_test_util", "//tensorflow/python:util_example_parser_configuration", From 0f3d3df0a7c0787d2f8586a5f7cb446aa5c47775 Mon Sep 17 00:00:00 2001 From: Matt Conley Date: Thu, 11 Apr 2019 18:02:44 -0700 Subject: [PATCH 10/10] Correct BUILD dependency order --- tensorflow/python/BUILD | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 56fcf82fe80..89a266c1b72 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -2785,8 +2785,8 @@ py_library( srcs = ["training/mixed_precision/experimental/loss_scale.py"], srcs_version = "PY2AND3", deps = [ - "//tensorflow/python:framework", ":keras_lib", + "//tensorflow/python:framework", "@absl_py//absl/testing:parameterized", ], ) @@ -2807,10 +2807,10 @@ py_test( srcs = ["training/mixed_precision/experimental/loss_scale_optimizer_test.py"], deps = [ ":loss_scale_optimizer", - "//tensorflow/python/keras/mixed_precision/experimental:test_util", "//tensorflow/python:client_testlib", "//tensorflow/python/distribute:mirrored_strategy", "//tensorflow/python/distribute:one_device_strategy", + "//tensorflow/python/keras/mixed_precision/experimental:test_util", "@absl_py//absl/testing:parameterized", ], )
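To summarize the behavior these patches implement, the dynamic loss scaling policy of `DynamicLossScale` (raise the scale after `increment_period` consecutive steps with finite gradients; lower it and skip the step when nonfinite gradients appear) can be written as a small framework-free sketch. This is illustrative Python only, not the patch's TensorFlow code: the real class stores the scale and the good-step counter in variables created through `_add_weight` and returns graph ops from `update()`, whereas this toy takes a precomputed `grads_are_finite` flag, and the lower bound of 1.0 on the scale is an assumption of the sketch.

    class ToyDynamicLossScale(object):
      """Plain-Python illustration of the dynamic loss scaling policy."""

      def __init__(self, initial_loss_scale=2.0 ** 15, increment_period=2000,
                   multiplier=2.0):
        self.loss_scale = float(initial_loss_scale)
        self.increment_period = int(increment_period)
        self.multiplier = float(multiplier)
        self.num_good_steps = 0

      def update(self, grads_are_finite):
        """Adjusts the scale and returns should_apply_gradients for this step."""
        if grads_are_finite:
          self.num_good_steps += 1
          if self.num_good_steps >= self.increment_period:
            # Enough consecutive finite steps: try a larger scale.
            self.loss_scale *= self.multiplier
            self.num_good_steps = 0
          return True
        # Nonfinite gradients: reset the streak, lower the scale, skip the step.
        self.num_good_steps = 0
        self.loss_scale = max(1.0, self.loss_scale / self.multiplier)
        return False

For example, after 2000 consecutive finite steps the toy scale doubles from 32768 to 65536, while a single nonfinite step halves it and returns False, mirroring the `should_apply_gradients` contract described in the LossScale docstrings.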