Update gradient_checker_v2 to use a step size in the finite difference approximation that is exactly representable as a binary floating point number. This is an old trick that in some cases avoids polluting the finite difference approximation with rounding errors that cause false negatives in gradient tests.

PiperOrigin-RevId: 343348502
Change-Id: I3539ae7de7105177c5a1b9144b491f36369344f4
A. Unique TensorFlower 2020-11-19 12:41:29 -08:00 committed by TensorFlower Gardener
parent 5f9f2d21d2
commit 1427bfc12e
4 changed files with 21 additions and 43 deletions
RELEASE.md
tensorflow
python
tools/api/golden/v2
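As a quick aside, the trick described in the commit message is easy to see in a few lines of NumPy. The sketch below is an illustration only, not code from the commit: 1e-3 has no exact binary representation, whereas 1.0 / 1024 is a power of two and is stored exactly, so a central difference taken with it adds no rounding error of its own at the step.

from decimal import Decimal

import numpy as np

# Decimal(x) displays the exact binary value stored for a Python float.
print(Decimal(1e-3))        # not exact: 0.001000000000000000020816...
print(Decimal(1.0 / 1024))  # exact: 0.0009765625

def central_difference(f, x, delta):
  # Second-order central difference, the approximation used by gradient checkers.
  return (f(x + delta) - f(x - delta)) / (2 * delta)

x = np.float32(0.25)
# With the power-of-two step, x + delta and x - delta are exact in float32 and
# the estimated derivative of |x| comes out as exactly 1.0.
print(central_difference(np.abs, x, np.float32(1.0 / 1024)))
# With 1e-3 the perturbed points must be rounded, which can pollute the
# estimate at roughly the 1e-5 level in float32.
print(central_difference(np.abs, x, np.float32(1e-3)))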

@@ -54,6 +54,7 @@
* Corrected higher-order gradients of control flow constructs (`tf.cond`,
`tf.while_loop`, and compositions like `tf.foldl`) computed with
`tf.GradientTape` inside a `tf.function`.
* Changed the default step size in `gradient_checker_v2.compute_gradient` to be exactly representable as a binary floating point number. This avoids polluting gradient approximations needlessly, which in some cases leads to false negatives in op gradient tests.
* `tf.summary`:
* New `tf.summary.graph` allows manual write of TensorFlow graph

@@ -19,9 +19,7 @@ from __future__ import division
from __future__ import print_function
import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.python import tf2
from tensorflow.python.eager import backprop
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -29,7 +27,6 @@ from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import gradient_checker_v2
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variables
@@ -117,45 +114,19 @@ class ReluTest(test.TestCase):
order="F")
err = gradient_checker_v2.max_error(*gradient_checker_v2.compute_gradient(
nn_ops.relu, [x], delta=1.0 / 1024))
self.assertLess(err, 1e-4)
self.assertLess(err, 1e-6)
# The gradient for fp16 is inaccurate due to the low-precision.
# We compare the fp16 analytical gradient against their fp32 counterpart.
# The gradient test for ReLU is a bit tricky as the derivative is not well
# defined at around zero and we want to avoid that in terms of input values.
def testGradientFloat16(self):
def grad(x):
with backprop.GradientTape() as tape:
tape.watch(x)
y = nn_ops.l2_loss(nn_ops.relu(x))
return tape.gradient(y, x)
def f():
with test_util.use_gpu():
# Randomly construct a 1D shape from [1, 40)
shape = random_ops.random_uniform([1],
minval=1,
maxval=40,
dtype=dtypes.int32)
x32 = random_ops.random_uniform(shape, minval=-1, maxval=1)
x16 = math_ops.cast(x32, dtype=dtypes.float16)
return grad(x32), grad(x16)
# We're going to ensure that the fp16 and fp32 gradients
# are "close" to each other for ~100 random values.
#
# In TensorFlow 1.x, invoking f() (without eager execution enabled)
# would construct a graph. Instead of construct a graph with O(100) nodes,
# we construct a single graph to be executed ~100 times in a Session.
if not tf2.enabled():
d32_tensor, d16_tensor = f()
with self.cached_session() as sess:
f = lambda: sess.run([d32_tensor, d16_tensor])
# Repeat the experiment for 100 times. All tensor shapes and its tensor
# values are randomly generated for each run.
for _ in xrange(100):
d32, d16 = f()
self.assertAllClose(d32, d16, atol=3e-4)
with self.cached_session():
x = np.asarray(
[[-0.9, -0.7, -0.5, -0.3, -0.1], [0.1, 0.3, 0.5, 0.7, 0.9]],
dtype=np.float16,
order="F")
err = gradient_checker_v2.max_error(
*gradient_checker_v2.compute_gradient(nn_ops.relu, [x]))
self.assertLess(err, 1e-6)
def testGradientFloat64(self):
with self.cached_session():
@@ -165,7 +136,7 @@ class ReluTest(test.TestCase):
order="F")
err = gradient_checker_v2.max_error(*gradient_checker_v2.compute_gradient(
nn_ops.relu, [x], delta=1.0 / 1024))
self.assertLess(err, 1e-10)
self.assertLess(err, 1e-15)
def testGradGradFloat32(self):
with self.cached_session():

@@ -292,7 +292,7 @@ def _compute_gradient_list(f, xs, delta):
@tf_export("test.compute_gradient", v1=[])
def compute_gradient(f, x, delta=1e-3):
def compute_gradient(f, x, delta=None):
"""Computes the theoretical and numeric Jacobian of `f`.
With y = f(x), computes the theoretical and numeric Jacobian dy/dx.
@@ -329,6 +329,12 @@ def compute_gradient(f, x, delta=1e-3):
raise ValueError(
"`x` must be a list or tuple of values convertible to a Tensor "
"(arguments to `f`), not a %s" % type(x))
if delta is None:
# By default, we use a step size for the central finite difference
# approximation that is exactly representable as a binary floating
# point number, since this reduces the amount of noise due to rounding
# in the approximation of some functions.
delta = 1.0 / 1024
return _compute_gradient_list(f, x, delta)
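For context, here is a hedged usage sketch of the new default from the caller's side, via the public `tf.test.compute_gradient` wrapper (an illustration, not part of the diff; the input values are arbitrary and chosen away from relu's kink at zero):

import numpy as np
import tensorflow as tf

x = tf.constant([[-0.9, -0.5, -0.1], [0.1, 0.5, 0.9]])
# `delta` is omitted, so the checker now falls back to the exactly
# representable step 1.0 / 1024 instead of the old default of 1e-3.
theoretical, numeric = tf.test.compute_gradient(tf.nn.relu, [x])
# Compare the theoretical and numeric Jacobians of the single input directly.
err = np.max(np.abs(theoretical[0] - numeric[0]))
print(err)  # expected to be tiny for relu inputs away from zero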

@@ -18,7 +18,7 @@ tf_module {
}
member_method {
name: "compute_gradient"
argspec: "args=[\'f\', \'x\', \'delta\'], varargs=None, keywords=None, defaults=[\'0.001\'], "
argspec: "args=[\'f\', \'x\', \'delta\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "create_local_cluster"