Upstream loss scaling gradient tape
PiperOrigin-RevId: 264692056
parent ce6364764f
commit fb27b20844
@@ -3200,6 +3200,36 @@ cuda_py_test(
    xla_enable_strict_auto_jit = True,
)

py_library(
    name = "loss_scaling_gradient_tape",
    srcs = ["training/experimental/loss_scaling_gradient_tape.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":array_ops",
        ":loss_scale",
        ":unconnected_gradients",
        ":util",
        "//tensorflow/python/eager:backprop",
    ],
)

py_test(
    name = "loss_scaling_gradient_tape_test",
    size = "medium",
    srcs = ["training/experimental/loss_scaling_gradient_tape_test.py"],
    python_version = "PY3",
    deps = [
        ":client_testlib",
        ":constant_op",
        ":loss_scale",
        ":loss_scaling_gradient_tape",
        "//tensorflow/python/compat:v2_compat",
        "//tensorflow/python/eager:def_function",
        "//third_party/py/numpy",
        "@absl_py//absl/testing:parameterized",
    ],
)

py_library(
    name = "math_grad",
    srcs = ["ops/math_grad.py"],
@@ -3762,6 +3792,7 @@ py_library(
        ":linalg_ops",
        ":logging_ops",
        ":lookup_ops",
        ":loss_scaling_gradient_tape",
        ":manip_grad",
        ":manip_ops",
        ":math_grad",
@@ -28,6 +28,9 @@ DoNotConvert = config_lib.DoNotConvert

# This list is evaluated in order and stops at the first rule that tests True
# for a definitely_convert or definitely_bypass call.
CONVERSION_RULES = (
    # Known packages
    Convert('tensorflow.python.training.experimental'),

    # Builtin modules
    DoNotConvert('collections'),
    DoNotConvert('copy'),
@@ -23,6 +23,7 @@ from __future__ import print_function

import sys as _sys

from tensorflow.python import autograph
from tensorflow.python.training.experimental import loss_scaling_gradient_tape

# pylint: disable=g-bad-import-order
# Imports the following modules so that @RegisterGradient get executed.
@@ -37,6 +37,8 @@ TENSORFLOW_API_INIT_FILES = [
    "lookup/__init__.py",
    "lookup/experimental/__init__.py",
    "math/__init__.py",
    "mixed_precision/__init__.py",
    "mixed_precision/experimental/__init__.py",
    "nest/__init__.py",
    "nn/__init__.py",
    "quantization/__init__.py",
@@ -0,0 +1,154 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains Loss Scaling Gradient Tape."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.eager import backprop
from tensorflow.python.ops import array_ops
from tensorflow.python.ops.unconnected_gradients import UnconnectedGradients
from tensorflow.python.training.experimental import loss_scale as loss_scale_module
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export


@tf_export("mixed_precision.experimental.LossScalingGradientTape", v1=[])
class LossScalingGradientTape(backprop.GradientTape):
  """A gradient tape that scales losses and unscales resulting gradients.

  Operates as a normal gradient tape, but takes in a
  `tf.train.experimental.LossScale` object. Losses are scaled up by some amount
  before the gradients are calculated and the resulting gradients are scaled
  down by the same amount.

  This has no net mathematical effect, but can be used to prevent vanishing
  gradients, for example in the case of mixed precision training.

  If a DynamicLossScale object is used and non-finite gradients are encountered,
  the loss scale will be updated and the gradients recomputed until either
  finite gradients are encountered or the loss scale becomes 1.

  This class should *not* be used with a LossScaleOptimizer, as both classes
  update the LossScale object. Use a non-loss scaling optimizer instead.

  Usage:
  ```
  opt = tf.keras.optimizers.SGD(1.0)
  model_loss_scale = tf.train.experimental.DynamicLossScale()

  for step in training_steps:
    with LossScalingGradientTape(model_loss_scale) as tape:
      logits = ...  # Run model and get logits
      loss = tf.nn.softmax_cross_entropy_with_logits(logits=logits,
                                                     labels=labels)
      loss = tf.reduce_mean(loss)
    vars = tape.watched_variables()
    grads = tape.gradient(loss, vars)
    opt.apply_gradients(zip(grads, vars))
  ```
  """

  def __init__(self,
               loss_scale,
               persistent=False,
               watch_accessed_variables=True):
    """Creates a new LossScalingGradientTape.

    Args:
      loss_scale: `tf.train.experimental.LossScale` object that
        manages what quantity to scale by. This is typically either a
        FixedLossScale object with a constant scalar or a
        `tf.train.experimental.DynamicLossScale` object that will
        adjust the scalar appropriately if any non-finite gradients are
        encountered.
      persistent: Boolean controlling whether a persistent gradient tape is
        created. False by default, which means at most one call can be made to
        the gradient() method on this object.
      watch_accessed_variables: Boolean controlling whether the tape will
        automatically `watch` any (trainable) variables accessed while the tape
        is active. Defaults to True meaning gradients can be requested from any
        result computed in the tape derived from reading a trainable `Variable`.
        If False users must explicitly `watch` any `Variable`s they want to
        request gradients from.
    """
    if not isinstance(loss_scale, loss_scale_module.LossScale):
      raise ValueError("`loss_scale` must be an instance of LossScale.")

    # always make a persistent tape to loop over loss scaling
    super(LossScalingGradientTape, self).__init__(True,
                                                  watch_accessed_variables)
    self._outer_persistent = persistent
    self._loss_scale = loss_scale

  def gradient(self,
               target,
               sources,
               output_gradients=None,
               unconnected_gradients=UnconnectedGradients.NONE):
    """Computes the gradient using operations recorded in context of this tape.

    Uses the `LossScale` object provided in the constructor to scale `target`
    and then to unscale the resulting gradients.

    Args:
      target: a list or nested structure of Tensors or Variables to be
        differentiated.
      sources: a list or nested structure of Tensors or Variables. `target` will
        be differentiated against elements in `sources`.
      output_gradients: a list of gradients, one for each element of target.
        Defaults to None.
      unconnected_gradients: a value which can either hold 'none' or 'zero' and
        alters the value which will be returned if the target and sources are
        unconnected. The possible values and effects are detailed in
        'UnconnectedGradients' and it defaults to 'none'.

    Returns:
      a list or nested structure of Tensors (or IndexedSlices, or None),
      one for each element in `sources`. Returned structure is the same as
      the structure of `sources`. If non-finite gradients are encountered
      after dynamic scaling, the loss scale will be updated and the gradients
      recomputed until either finite gradients are encountered or the loss scale
      becomes 1.

    Raises:
      RuntimeError: if called inside the context of the tape, or if called more
        than once on a non-persistent tape.
      ValueError: if the target is a variable or if unconnected gradients is
        called with an unknown value.
    """
    if self._tape is None:  # pylint: disable=access-member-before-definition
      raise RuntimeError("GradientTape.gradient can only be called once on "
                         "non-persistent tapes.")

    ready_to_update = False
    grads = nest.map_structure(array_ops.zeros_like, sources)

    while not ready_to_update and self._loss_scale() > 1:
      with self:  # re-enter the gradient tape so it sees the loss scaling
        loss_scale = self._loss_scale()
        scaled_target = nest.map_structure(lambda t: t * loss_scale, target)

      old_grads = super(LossScalingGradientTape, self).gradient(
          scaled_target, sources, output_gradients, unconnected_gradients)
      inv_loss_scale = 1.0 / self._loss_scale()
      grads = nest.map_structure(lambda g: inv_loss_scale * g, old_grads)
      # Check for non-finite gradients possibly resulting from scaling
      _, ready_to_update = self._loss_scale.update(grads)

    if not self._outer_persistent:
      self._tape = None  # free up resources if a persistent tape was not needed
    return grads
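For readers skimming the diff, here is a minimal usage sketch of the dynamic re-scaling loop in `gradient()` above (an editor's illustration, not part of this commit). It mirrors the `test_dynamic_loss_scaling_down_loop` test added below and assumes a TF 2.x build where the `tf.mixed_precision.experimental.LossScalingGradientTape` symbol exported by this change and the pre-existing `tf.train.experimental.DynamicLossScale` are available:

```python
import tensorflow as tf

# Toy example: scaling the target by the initial scale of 32 overflows
# float32, so the tape lowers the loss scale and recomputes the gradients
# until they come back finite.
loss_scale = tf.train.experimental.DynamicLossScale(initial_loss_scale=32)
x = tf.constant(1.0)

with tf.mixed_precision.experimental.LossScalingGradientTape(loss_scale) as tape:
  tape.watch(x)
  y = x * (3.0 * 10**37)  # 32 * 3e37 is inf in float32

dy_dx = tape.gradient(y, x)   # recomputed with a reduced scale; finite
print(float(loss_scale()))    # 8.0 -- the scale was halved twice
print(float(dy_dx))           # ~3e37, the unscaled gradient
```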
@@ -0,0 +1,183 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for lsgt.LossScalingGradientTape."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl.testing import parameterized
import numpy as np
from tensorflow.python.compat import v2_compat
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.platform import test
from tensorflow.python.training.experimental import loss_scale as loss_scale_module
from tensorflow.python.training.experimental import loss_scaling_gradient_tape as lsgt


class LossScalingGradientTapeTest(test.TestCase, parameterized.TestCase):

  @parameterized.parameters(loss_scale_module.FixedLossScale,
                            loss_scale_module.DynamicLossScale)
  def test_basic_tapes_eager_mode(self, loss_scale):
    x = constant_op.constant(3.0)
    with lsgt.LossScalingGradientTape(loss_scale(32)) as g:
      g.watch(x)
      y = x * x
    dy_dx = g.gradient(y, x)
    self.assertEqual(self.evaluate(dy_dx), 6.0)

  @parameterized.parameters(loss_scale_module.FixedLossScale,
                            loss_scale_module.DynamicLossScale)
  def test_basic_tapes_graph_mode(self, loss_scale):
    loss_scale = loss_scale(32)

    @def_function.function
    def _inner_test():
      x = constant_op.constant(3.0)
      with lsgt.LossScalingGradientTape(loss_scale) as g:
        g.watch(x)
        y = x * x
      return g.gradient(y, x)
    self.assertEqual(self.evaluate(_inner_test()), 6.0)

  @parameterized.parameters(loss_scale_module.FixedLossScale,
                            loss_scale_module.DynamicLossScale)
  def test_nested_tapes(self, loss_scale):
    x = constant_op.constant(3.0)
    with lsgt.LossScalingGradientTape(loss_scale(32)) as g:
      g.watch(x)
      with lsgt.LossScalingGradientTape(loss_scale(32)) as gg:
        gg.watch(x)
        y = x * x
      dy_dx = gg.gradient(y, x)
      self.assertEqual(self.evaluate(dy_dx), 6.0)
    d2y_dx2 = g.gradient(dy_dx, x)
    self.assertEqual(self.evaluate(d2y_dx2), 2.0)

  @parameterized.parameters(loss_scale_module.FixedLossScale,
                            loss_scale_module.DynamicLossScale)
  def test_non_persistent_tapes_error(self, loss_scale):
    x = constant_op.constant(3.0)
    with lsgt.LossScalingGradientTape(loss_scale(32), persistent=False) as g:
      g.watch(x)
      y = x * x
      z = y * y
    g.gradient(z, x)
    with self.assertRaisesRegexp(RuntimeError, 'persistent'):
      g.gradient(y, x)

  @parameterized.parameters(loss_scale_module.FixedLossScale,
                            loss_scale_module.DynamicLossScale)
  def test_persistent_tapes(self, loss_scale):
    x = constant_op.constant(3.0)
    with lsgt.LossScalingGradientTape(loss_scale(32), persistent=True) as g:
      g.watch(x)
      y = x * x
      z = y * y
    dz_dx = g.gradient(z, x)
    self.assertEqual(self.evaluate(dz_dx), 108.0)
    dy_dx = g.gradient(y, x)
    self.assertEqual(self.evaluate(dy_dx), 6.0)

  @parameterized.parameters(loss_scale_module.FixedLossScale,
                            loss_scale_module.DynamicLossScale)
  def test_nested_sources(self, loss_scale):
    x = (constant_op.constant(19.0), (constant_op.constant(8.),
                                      constant_op.constant(9.)))
    with lsgt.LossScalingGradientTape(loss_scale(32)) as g:
      g.watch(x)
      y = x * 13
    dy_dx = g.gradient(y, x)
    self.assertEqual(self.evaluate(dy_dx), (13., (13., 13.)))

  @parameterized.parameters(loss_scale_module.FixedLossScale,
                            loss_scale_module.DynamicLossScale)
  def test_nested_targets(self, loss_scale):
    w = constant_op.constant(3.0)
    with lsgt.LossScalingGradientTape(loss_scale(32)) as g:
      g.watch(w)
      x = w * 5
      y = w * 7
      z = w * 11
    grad = g.gradient([x, (y, z)], w)
    self.assertEqual(self.evaluate(grad), 23)

  @parameterized.parameters(loss_scale_module.FixedLossScale,
                            loss_scale_module.DynamicLossScale)
  def test_scaling_inf_gradient(self, loss_scale):
    x = constant_op.constant(1.0)
    with lsgt.LossScalingGradientTape(loss_scale(32)) as g:
      g.watch(x)
      y = x * np.inf
    dy_dx = g.gradient(y, x)
    self.assertEqual(self.evaluate(dy_dx), np.inf)

  @parameterized.parameters(loss_scale_module.FixedLossScale,
                            loss_scale_module.DynamicLossScale)
  def test_scaling_nan_gradient(self, loss_scale):
    x = constant_op.constant(1.0)
    with lsgt.LossScalingGradientTape(loss_scale(32)) as g:
      g.watch(x)
      y = x * np.nan
    dy_dx = g.gradient(y, x)
    self.assertTrue(np.isnan(self.evaluate(dy_dx)))

  @parameterized.parameters(np.inf, np.nan)
  def test_dynamic_scale_to_one_on_non_finite_gradient(self, non_finite_term):
    loss_scale = loss_scale_module.DynamicLossScale(initial_loss_scale=32)
    x = constant_op.constant(1.0)
    with lsgt.LossScalingGradientTape(loss_scale) as g:
      g.watch(x)
      y = x * non_finite_term
    g.gradient(y, x)
    self.assertEqual(self.evaluate(loss_scale()), 1.0)

  @parameterized.parameters([np.inf, np.isposinf], [np.nan, np.isnan])
  def test_fixed_scaling_no_change_non_finite_gradient(self, non_finite_term,
                                                       is_non_finite):
    loss_scale = loss_scale_module.FixedLossScale(32)
    x = constant_op.constant(1.0)
    with lsgt.LossScalingGradientTape(loss_scale) as g:
      g.watch(x)
      y = x * non_finite_term
    dy_dx = g.gradient(y, x)
    self.assertTrue(is_non_finite(self.evaluate(dy_dx)))
    self.assertEqual(self.evaluate(loss_scale()), 32.0)

  def test_dynamic_loss_scaling_down_loop(self):
    loss_scale = loss_scale_module.DynamicLossScale(initial_loss_scale=32)
    x = constant_op.constant(1.0)
    with lsgt.LossScalingGradientTape(loss_scale) as g:
      g.watch(x)
      y = x * (3.0 * (10**37))  # grad will be inf after scaling
    dy_dx = g.gradient(y, x)
    self.assertEqual(self.evaluate(loss_scale()), 8.0)
    self.assertAllClose(self.evaluate(dy_dx), (3.0 * (10**37)), atol=1e-06)

  def test_dynamic_loss_scaling_inf_target_post_scale(self):
    loss_scale = loss_scale_module.DynamicLossScale(initial_loss_scale=32.0)
    x = constant_op.constant(3.0 * (10**37))
    with lsgt.LossScalingGradientTape(loss_scale) as g:
      g.watch(x)
      y = x * 3.0  # target will be inf after scaling
    dy_dx = g.gradient(y, x)
    self.assertAllClose(self.evaluate(dy_dx), 3.0)
    self.assertEqual(self.evaluate(loss_scale()), 32.0)


if __name__ == '__main__':
  v2_compat.enable_v2_behavior()
  test.main()
@@ -0,0 +1,38 @@
path: "tensorflow.mixed_precision.experimental.LossScalingGradientTape"
tf_class {
  is_instance: "<class \'tensorflow.python.training.experimental.loss_scaling_gradient_tape.LossScalingGradientTape\'>"
  is_instance: "<class \'tensorflow.python.eager.backprop.GradientTape\'>"
  is_instance: "<type \'object\'>"
  member_method {
    name: "__init__"
    argspec: "args=[\'self\', \'loss_scale\', \'persistent\', \'watch_accessed_variables\'], varargs=None, keywords=None, defaults=[\'False\', \'True\'], "
  }
  member_method {
    name: "batch_jacobian"
    argspec: "args=[\'self\', \'target\', \'source\', \'unconnected_gradients\', \'parallel_iterations\', \'experimental_use_pfor\'], varargs=None, keywords=None, defaults=[\'UnconnectedGradients.NONE\', \'None\', \'True\'], "
  }
  member_method {
    name: "gradient"
    argspec: "args=[\'self\', \'target\', \'sources\', \'output_gradients\', \'unconnected_gradients\'], varargs=None, keywords=None, defaults=[\'None\', \'UnconnectedGradients.NONE\'], "
  }
  member_method {
    name: "jacobian"
    argspec: "args=[\'self\', \'target\', \'sources\', \'unconnected_gradients\', \'parallel_iterations\', \'experimental_use_pfor\'], varargs=None, keywords=None, defaults=[\'UnconnectedGradients.NONE\', \'None\', \'True\'], "
  }
  member_method {
    name: "reset"
    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "stop_recording"
    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "watch"
    argspec: "args=[\'self\', \'tensor\'], varargs=None, keywords=None, defaults=None"
  }
  member_method {
    name: "watched_variables"
    argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
  }
}

@@ -0,0 +1,7 @@
path: "tensorflow.mixed_precision.experimental"
tf_module {
  member {
    name: "LossScalingGradientTape"
    mtype: "<type \'type\'>"
  }
}

@@ -0,0 +1,7 @@
path: "tensorflow.mixed_precision"
tf_module {
  member {
    name: "experimental"
    mtype: "<type \'module\'>"
  }
}

@@ -256,6 +256,10 @@ tf_module {
    name: "metrics"
    mtype: "<type \'module\'>"
  }
  member {
    name: "mixed_precision"
    mtype: "<type \'module\'>"
  }
  member {
    name: "name_scope"
    mtype: "<type \'type\'>"
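The golden files above pin the new public surface: the tape is exported only under the TF 2.x namespace as `tf.mixed_precision.experimental.LossScalingGradientTape` and keeps the full `GradientTape` method set (`jacobian`, `batch_jacobian`, `watch`, `watched_variables`, ...). As a quick illustration through the public names (an editor's sketch, not part of this commit; it mirrors `test_basic_tapes_eager_mode` and assumes the pre-existing `tf.train.experimental.FixedLossScale`):

```python
import tensorflow as tf

# With a fixed scale, scaling is mathematically neutral: the target is scaled
# by 32 before backprop and the gradients are divided by 32 afterwards.
loss_scale = tf.train.experimental.FixedLossScale(32)
x = tf.constant(3.0)

with tf.mixed_precision.experimental.LossScalingGradientTape(loss_scale) as tape:
  tape.watch(x)
  y = x * x

dy_dx = tape.gradient(y, x)
print(float(dy_dx))  # 6.0 -- same as an ordinary tf.GradientTape
```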