Upstream loss scaling gradient tape

PiperOrigin-RevId: 264692056
Loren Maggiore 2019-08-21 14:24:58 -07:00 committed by TensorFlower Gardener
parent ce6364764f
commit fb27b20844
10 changed files with 430 additions and 0 deletions
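In short, this change upstreams `tf.mixed_precision.experimental.LossScalingGradientTape`, a `tf.GradientTape` variant that scales the target by a `LossScale` before differentiating and unscales the resulting gradients. A minimal sketch of the intended usage (eager TF 2.0 behavior assumed; the toy variable, learning rate, and loss are illustrative only, not part of this commit):

```
import tensorflow as tf

# DynamicLossScale shrinks when non-finite gradients appear and grows back
# once gradients stay finite; see training/experimental/loss_scale.py.
loss_scale = tf.train.experimental.DynamicLossScale(initial_loss_scale=2 ** 15)
opt = tf.keras.optimizers.SGD(0.1)
v = tf.Variable(2.0)

with tf.mixed_precision.experimental.LossScalingGradientTape(loss_scale) as tape:
  loss = v * v  # the tape scales this target internally before differentiating
grads = tape.gradient(loss, [v])  # gradients are returned already unscaled
opt.apply_gradients(zip(grads, [v]))
```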

@@ -3200,6 +3200,36 @@ cuda_py_test(
    xla_enable_strict_auto_jit = True,
)

py_library(
    name = "loss_scaling_gradient_tape",
    srcs = ["training/experimental/loss_scaling_gradient_tape.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":array_ops",
        ":loss_scale",
        ":unconnected_gradients",
        ":util",
        "//tensorflow/python/eager:backprop",
    ],
)

py_test(
    name = "loss_scaling_gradient_tape_test",
    size = "medium",
    srcs = ["training/experimental/loss_scaling_gradient_tape_test.py"],
    python_version = "PY3",
    deps = [
        ":client_testlib",
        ":constant_op",
        ":loss_scale",
        ":loss_scaling_gradient_tape",
        "//tensorflow/python/compat:v2_compat",
        "//tensorflow/python/eager:def_function",
        "//third_party/py/numpy",
        "@absl_py//absl/testing:parameterized",
    ],
)

py_library(
    name = "math_grad",
    srcs = ["ops/math_grad.py"],
@@ -3762,6 +3792,7 @@ py_library(
        ":linalg_ops",
        ":logging_ops",
        ":lookup_ops",
        ":loss_scaling_gradient_tape",
        ":manip_grad",
        ":manip_ops",
        ":math_grad",

@@ -28,6 +28,9 @@ DoNotConvert = config_lib.DoNotConvert
# This list is evaluated in order and stops at the first rule that tests True
# for a definitely_convert or definitely_bypass call.
CONVERSION_RULES = (
    # Known packages
    Convert('tensorflow.python.training.experimental'),

    # Builtin modules
    DoNotConvert('collections'),
    DoNotConvert('copy'),
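The new Convert rule appears to be what lets the tape work under `tf.function`: AutoGraph will now convert Python code in `tensorflow.python.training.experimental`, which includes the loss-rescaling `while` loop in `LossScalingGradientTape.gradient`. A rough sketch of the pattern this enables, mirroring the graph-mode test added below (the toy function and values are illustrative, not part of this change):

```
import tensorflow as tf
from tensorflow.python.training.experimental import loss_scale as loss_scale_module
from tensorflow.python.training.experimental import loss_scaling_gradient_tape as lsgt

loss_scale = loss_scale_module.DynamicLossScale()

@tf.function  # AutoGraph converts the tape's internal rescaling loop here.
def square_grad(x):
  with lsgt.LossScalingGradientTape(loss_scale) as g:
    g.watch(x)
    y = x * x
  return g.gradient(y, x)

print(square_grad(tf.constant(3.0)))  # expect ~6.0
```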

@@ -23,6 +23,7 @@ from __future__ import print_function
import sys as _sys

from tensorflow.python import autograph
from tensorflow.python.training.experimental import loss_scaling_gradient_tape

# pylint: disable=g-bad-import-order
# Imports the following modules so that @RegisterGradient gets executed.

@@ -37,6 +37,8 @@ TENSORFLOW_API_INIT_FILES = [
    "lookup/__init__.py",
    "lookup/experimental/__init__.py",
    "math/__init__.py",
    "mixed_precision/__init__.py",
    "mixed_precision/experimental/__init__.py",
    "nest/__init__.py",
    "nn/__init__.py",
    "quantization/__init__.py",

@@ -0,0 +1,154 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains Loss Scaling Gradient Tape."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.eager import backprop
from tensorflow.python.ops import array_ops
from tensorflow.python.ops.unconnected_gradients import UnconnectedGradients
from tensorflow.python.training.experimental import loss_scale as loss_scale_module
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export
@tf_export("mixed_precision.experimental.LossScalingGradientTape", v1=[])
class LossScalingGradientTape(backprop.GradientTape):
"""A gradient tape that scales losses and unscales resulting gradients.
Operates as a normal gradient tape, but takes in a
`tf.train.experimental.LossScale` object. Losses are scaled up by some amount
before the gradients are calculated and the resulting gradients are scaled
down by the same amount.
This has no net mathematical effect, but can be used to prevent vanishing
gradients, for example in the case of mixed precision training.
If a DynamicLossScale object is used and non-finite gradients are encountered,
the loss scale will be updated and the gradients recomputed until either
finite gradients are encountered or the loss scale becomes 1.
This class should *not* be used with a LossScaleOptimizer, as both classes
update the LossScale object. Use a non-loss scaling optimizer instead.
Usage:
```
opt = tf.keras.optimizers.SGD(1.0)
model_loss_scale = tf.train.experimental.DynamicLossScale()
for step in training_steps:
with LossScalingGradientTape(model_loss_scale) as tape:
logits = ... # Run model and get logits
loss = tf.nn.softmax_cross_entropy_with_logits(logits=logits,
labels=labels)
loss = tf.reduce_mean(loss)
vars = tape.watched_variables()
grads = tape.gradient(loss, vars)
opt.apply_gradients(zip(grads, vars))
```
"""
def __init__(self,
loss_scale,
persistent=False,
watch_accessed_variables=True):
"""Creates a new LossScalingGradientTape.
Args:
loss_scale: `tf.train.experimental.LossScale` object that
manages what quantity to scale by. This is typically either a
FixedLossScale object with a constant scalar or a
`tf.train.experimental.DynamicLossScale` object that will
adjust the scalar appropriately if any non-finite gradients are
encountered.
persistent: Boolean controlling whether a persistent gradient tape is
created. False by default, which means at most one call can be made to
the gradient() method on this object.
watch_accessed_variables: Boolean controlling whether the tape will
automatically `watch` any (trainable) variables accessed while the tape
is active. Defaults to True meaning gradients can be requested from any
result computed in the tape derived from reading a trainable `Variable`.
If False users must explicitly `watch` any `Variable`s they want to
request gradients from.
"""
if not isinstance(loss_scale, loss_scale_module.LossScale):
raise ValueError("`loss_scale` must be an instance of LossScale.")
# always make a persistent tape to loop over loss scaling
super(LossScalingGradientTape, self).__init__(True,
watch_accessed_variables)
self._outer_persistent = persistent
self._loss_scale = loss_scale
def gradient(self,
target,
sources,
output_gradients=None,
unconnected_gradients=UnconnectedGradients.NONE):
"""Computes the gradient using operations recorded in context of this tape.
Uses the `LossScale` object provided in the constructor to scale `target`
and then to unscale the resulting gradients.
Args:
target: a list or nested structure of Tensors or Variables to be
differentiated.
sources: a list or nested structure of Tensors or Variables. `target` will
be differentiated against elements in `sources`.
output_gradients: a list of gradients, one for each element of target.
Defaults to None.
unconnected_gradients: a value which can either hold 'none' or 'zero' and
alters the value which will be returned if the target and sources are
unconnected. The possible values and effects are detailed in
'UnconnectedGradients' and it defaults to 'none'.
Returns:
a list or nested structure of Tensors (or IndexedSlices, or None),
one for each element in `sources`. Returned structure is the same as
the structure of `sources`. If non-finite gradients are encountered
after dynamic scaling, the loss scale will be updated and the gradients
recomputed until either finite gradients are encountered or the loss scale
becomes 1.
Raises:
RuntimeError: if called inside the context of the tape, or if called more
than once on a non-persistent tape.
ValueError: if the target is a variable or if unconnected gradients is
called with an unknown value.
"""
if self._tape is None: # pylint: disable=access-member-before-definition
raise RuntimeError("GradientTape.gradient can only be called once on "
"non-persistent tapes.")
ready_to_update = False
grads = nest.map_structure(array_ops.zeros_like, sources)
while not ready_to_update and self._loss_scale() > 1:
with self: # re-enter the gradient tape so it sees the loss scaling
loss_scale = self._loss_scale()
scaled_target = nest.map_structure(lambda t: t * loss_scale, target)
old_grads = super(LossScalingGradientTape, self).gradient(
scaled_target, sources, output_gradients, unconnected_gradients)
inv_loss_scale = 1.0 / self._loss_scale()
grads = nest.map_structure(lambda g: inv_loss_scale * g, old_grads)
# Check for non-finite gradients possibly resulting from scaling
_, ready_to_update = self._loss_scale.update(grads)
if not self._outer_persistent:
self._tape = None # free up resources if a persistent tape was not needed
return grads
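The `while` loop above is the heart of dynamic loss scaling: if scaling produced non-finite gradients, the `DynamicLossScale` is reduced and the gradients are recomputed. A small standalone sketch of that behavior, using the same experimental modules as the tests that follow (the 3e37 constant matches the `test_dynamic_loss_scaling_down_loop` case, chosen so the scaled gradient overflows float32 until the scale drops to 8):

```
import numpy as np
import tensorflow as tf
from tensorflow.python.training.experimental import loss_scale as loss_scale_module
from tensorflow.python.training.experimental import loss_scaling_gradient_tape as lsgt

loss_scale = loss_scale_module.DynamicLossScale(initial_loss_scale=32)
x = tf.constant(1.0)
with lsgt.LossScalingGradientTape(loss_scale) as g:
  g.watch(x)
  y = x * 3.0e37  # dy/dx = 3e37; scaled by 32 it overflows float32
dy_dx = g.gradient(y, x)  # retries at scale 16, then 8, where it is finite

print(float(loss_scale()))              # expected to settle at 8.0
print(np.isclose(float(dy_dx), 3.0e37))  # True: the returned gradient is unscaled
```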

@@ -0,0 +1,183 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for lsgt.LossScalingGradientTape."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import parameterized
import numpy as np
from tensorflow.python.compat import v2_compat
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.platform import test
from tensorflow.python.training.experimental import loss_scale as loss_scale_module
from tensorflow.python.training.experimental import loss_scaling_gradient_tape as lsgt
class LossScalingGradientTapeTest(test.TestCase, parameterized.TestCase):
@parameterized.parameters(loss_scale_module.FixedLossScale,
loss_scale_module.DynamicLossScale)
def test_basic_tapes_eager_mode(self, loss_scale):
x = constant_op.constant(3.0)
with lsgt.LossScalingGradientTape(loss_scale(32)) as g:
g.watch(x)
y = x * x
dy_dx = g.gradient(y, x)
self.assertEqual(self.evaluate(dy_dx), 6.0)
@parameterized.parameters(loss_scale_module.FixedLossScale,
loss_scale_module.DynamicLossScale)
def test_basic_tapes_graph_mode(self, loss_scale):
loss_scale = loss_scale(32)
@def_function.function
def _inner_test():
x = constant_op.constant(3.0)
with lsgt.LossScalingGradientTape(loss_scale) as g:
g.watch(x)
y = x * x
return g.gradient(y, x)
self.assertEqual(self.evaluate(_inner_test()), 6.0)
@parameterized.parameters(loss_scale_module.FixedLossScale,
loss_scale_module.DynamicLossScale)
def test_nested_tapes(self, loss_scale):
x = constant_op.constant(3.0)
with lsgt.LossScalingGradientTape(loss_scale(32)) as g:
g.watch(x)
with lsgt.LossScalingGradientTape(loss_scale(32)) as gg:
gg.watch(x)
y = x * x
dy_dx = gg.gradient(y, x)
self.assertEqual(self.evaluate(dy_dx), 6.0)
d2y_dx2 = g.gradient(dy_dx, x)
self.assertEqual(self.evaluate(d2y_dx2), 2.0)
@parameterized.parameters(loss_scale_module.FixedLossScale,
loss_scale_module.DynamicLossScale)
def test_non_persistent_tapes_error(self, loss_scale):
x = constant_op.constant(3.0)
with lsgt.LossScalingGradientTape(loss_scale(32), persistent=False) as g:
g.watch(x)
y = x * x
z = y * y
g.gradient(z, x)
with self.assertRaisesRegexp(RuntimeError, 'persistent'):
g.gradient(y, x)
@parameterized.parameters(loss_scale_module.FixedLossScale,
loss_scale_module.DynamicLossScale)
def test_persistent_tapes(self, loss_scale):
x = constant_op.constant(3.0)
with lsgt.LossScalingGradientTape(loss_scale(32), persistent=True) as g:
g.watch(x)
y = x * x
z = y * y
dz_dx = g.gradient(z, x)
self.assertEqual(self.evaluate(dz_dx), 108.0)
dy_dx = g.gradient(y, x)
self.assertEqual(self.evaluate(dy_dx), 6.0)
@parameterized.parameters(loss_scale_module.FixedLossScale,
loss_scale_module.DynamicLossScale)
def test_nested_sources(self, loss_scale):
x = (constant_op.constant(19.0), (constant_op.constant(8.),
constant_op.constant(9.)))
with lsgt.LossScalingGradientTape(loss_scale(32)) as g:
g.watch(x)
y = x * 13
dy_dx = g.gradient(y, x)
self.assertEqual(self.evaluate(dy_dx), (13., (13., 13.)))
@parameterized.parameters(loss_scale_module.FixedLossScale,
loss_scale_module.DynamicLossScale)
def test_nested_targets(self, loss_scale):
w = constant_op.constant(3.0)
with lsgt.LossScalingGradientTape(loss_scale(32)) as g:
g.watch(w)
x = w * 5
y = w * 7
z = w * 11
grad = g.gradient([x, (y, z)], w)
self.assertEqual(self.evaluate(grad), 23)
@parameterized.parameters(loss_scale_module.FixedLossScale,
loss_scale_module.DynamicLossScale)
def test_scaling_inf_gradient(self, loss_scale):
x = constant_op.constant(1.0)
with lsgt.LossScalingGradientTape(loss_scale(32)) as g:
g.watch(x)
y = x * np.inf
dy_dx = g.gradient(y, x)
self.assertEqual(self.evaluate(dy_dx), np.inf)
@parameterized.parameters(loss_scale_module.FixedLossScale,
loss_scale_module.DynamicLossScale)
def test_scaling_nan_gradient(self, loss_scale):
x = constant_op.constant(1.0)
with lsgt.LossScalingGradientTape(loss_scale(32)) as g:
g.watch(x)
y = x * np.nan
dy_dx = g.gradient(y, x)
self.assertTrue(np.isnan(self.evaluate(dy_dx)))
@parameterized.parameters(np.inf, np.nan)
def test_dynamic_scale_to_one_on_non_finite_gradient(self, non_finite_term):
loss_scale = loss_scale_module.DynamicLossScale(initial_loss_scale=32)
x = constant_op.constant(1.0)
with lsgt.LossScalingGradientTape(loss_scale) as g:
g.watch(x)
y = x * non_finite_term
g.gradient(y, x)
self.assertEqual(self.evaluate(loss_scale()), 1.0)
@parameterized.parameters([np.inf, np.isposinf], [np.nan, np.isnan])
def test_fixed_scaling_no_change_non_finite_gradient(self, non_finite_term,
is_non_finite):
loss_scale = loss_scale_module.FixedLossScale(32)
x = constant_op.constant(1.0)
with lsgt.LossScalingGradientTape(loss_scale) as g:
g.watch(x)
y = x * non_finite_term
dy_dx = g.gradient(y, x)
self.assertTrue(is_non_finite(self.evaluate(dy_dx)))
self.assertEqual(self.evaluate(loss_scale()), 32.0)
def test_dynamic_loss_scaling_down_loop(self):
loss_scale = loss_scale_module.DynamicLossScale(initial_loss_scale=32)
x = constant_op.constant(1.0)
with lsgt.LossScalingGradientTape(loss_scale) as g:
g.watch(x)
y = x * (3.0 * (10**37)) # grad will be inf after scaling
dy_dx = g.gradient(y, x)
self.assertEqual(self.evaluate(loss_scale()), 8.0)
self.assertAllClose(self.evaluate(dy_dx), (3.0 * (10**37)), atol=1e-06)
def test_dynamic_loss_scaling_inf_target_post_scale(self):
loss_scale = loss_scale_module.DynamicLossScale(initial_loss_scale=32.0)
x = constant_op.constant(3.0 * (10**37))
with lsgt.LossScalingGradientTape(loss_scale) as g:
g.watch(x)
y = x * 3.0 # target will be inf after scaling
dy_dx = g.gradient(y, x)
self.assertAllClose(self.evaluate(dy_dx), 3.0)
self.assertEqual(self.evaluate(loss_scale()), 32.0)
if __name__ == '__main__':
v2_compat.enable_v2_behavior()
test.main()

@@ -0,0 +1,38 @@
path: "tensorflow.mixed_precision.experimental.LossScalingGradientTape"
tf_class {
  is_instance: "<class \'tensorflow.python.training.experimental.loss_scaling_gradient_tape.LossScalingGradientTape\'>"
  is_instance: "<class \'tensorflow.python.eager.backprop.GradientTape\'>"
  is_instance: "<type \'object\'>"
  member_method {
    name: "__init__"
    argspec: "args=[\'self\', \'loss_scale\', \'persistent\', \'watch_accessed_variables\'], varargs=None, keywords=None, defaults=[\'False\', \'True\'], "
  }
  member_method {
    name: "batch_jacobian"
    argspec: "args=[\'self\', \'target\', \'source\', \'unconnected_gradients\', \'parallel_iterations\', \'experimental_use_pfor\'], varargs=None, keywords=None, defaults=[\'UnconnectedGradients.NONE\', \'None\', \'True\'], "
  }
  member_method {
    name: "gradient"
    argspec: "args=[\'self\', \'target\', \'sources\', \'output_gradients\', \'unconnected_gradients\'], varargs=None, keywords=None, defaults=[\'None\', \'UnconnectedGradients.NONE\'], "
  }
  member_method {
    name: "jacobian"
    argspec: "args=[\'self\', \'target\', \'sources\', \'unconnected_gradients\', \'parallel_iterations\', \'experimental_use_pfor\'], varargs=None, keywords=None, defaults=[\'UnconnectedGradients.NONE\', \'None\', \'True\'], "
  }
  member_method {
    name: "reset"
    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "stop_recording"
    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "watch"
    argspec: "args=[\'self\', \'tensor\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "watched_variables"
    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
  }
}

@@ -0,0 +1,7 @@
path: "tensorflow.mixed_precision.experimental"
tf_module {
  member {
    name: "LossScalingGradientTape"
    mtype: "<type \'type\'>"
  }
}

@@ -0,0 +1,7 @@
path: "tensorflow.mixed_precision"
tf_module {
  member {
    name: "experimental"
    mtype: "<type \'module\'>"
  }
}

@@ -256,6 +256,10 @@ tf_module {
    name: "metrics"
    mtype: "<type \'module\'>"
  }
  member {
    name: "mixed_precision"
    mtype: "<type \'module\'>"
  }
  member {
    name: "name_scope"
    mtype: "<type \'type\'>"