Removes the special case for bfloat16 in gradient_checker_v2. Rebinding f to a lambda that calls f causes infinite recursion, and the cast doesn't help anyway: if the function returns bfloat16, the precision has already been lost.

PiperOrigin-RevId: 314217630
Change-Id: I24ab5decb81baab183099b93b1d8c9ed65bffed3
Peng Wang 2020-06-01 15:40:04 -07:00 committed by TensorFlower Gardener
parent 625da41a12
commit e0b56ace77
2 changed files with 11 additions and 9 deletions
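
The infinite recursion comes from how the removed special case rebinds f: the lambda's body looks up the name f at call time, finds the lambda itself, and never reaches the original function. A minimal sketch of the failure mode in plain Python, using a hypothetical make_checker helper rather than the real _compute_numeric_jacobian:

    def make_checker(f):
      # Mirrors the removed pattern (hypothetical helper, not TF code): after this
      # assignment the name "f" refers to the lambda, and the lambda's body reads
      # that same name at call time -- so it calls itself, not the original function.
      f = lambda *xs: float(f(*xs))  # stand-in for math_ops.cast(f(*xs), float32)
      return f

    checker = make_checker(lambda x: x + 1.0)
    # checker(2.0) raises RecursionError instead of returning 3.0.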

tensorflow/python/ops/gradient_checker_v2.py

@@ -28,7 +28,6 @@ from tensorflow.python.eager import context
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import math_ops
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.util.tf_export import tf_export
@@ -217,14 +216,8 @@ def _compute_numeric_jacobian(f, y_size, y_dtype, xs, param, delta):
     and "x_size" columns where "x_size" is the number of elements in xs[param]
     and "y_size" is the number of elements in the result.
   """
-  # bfloat16 doesn't have enough bits to represent high precision numbers such
-  # as delta. Convert to float32 here. Since numeric_jacobian is expected to
-  # be the groundtruth to compare against, it shouldn't lose any information.
   x_shape = xs[param].shape
   x_dtype = xs[param].dtype
-  if y_dtype == dtypes.bfloat16:
-    f = lambda *xs: math_ops.cast(f(*xs), dtypes.float32)
-    y_dtype = dtypes.float32
   # To compute the jacobian, we treat x and y as one-dimensional vectors
   x_size = _product(x_shape) * (2 if x_dtype.is_complex else 1)
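
The removed comment's premise about delta is worth spelling out: bfloat16 carries only about 8 bits of precision (7 stored mantissa bits), so near 2.0 adjacent representable values are 0.015625 apart and a small finite-difference step disappears entirely. A quick sketch, assuming TF2 eager execution and nothing beyond tf.constant:

    import tensorflow as tf

    x = tf.constant(2.0, dtype=tf.bfloat16)
    delta = tf.constant(1e-3, dtype=tf.bfloat16)
    # Near 2.0 the bfloat16 spacing is 2 * 2**-7 = 0.015625, so a 1e-3 step
    # rounds away and the central difference would measure no perturbation at all.
    print(float(x + delta) == float(x))  # True: 2.0 + 0.001 rounds back to 2.0
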
@@ -292,10 +285,10 @@ def _compute_gradient_list(f, xs, delta):
   xs_shapes = [x.shape for x in xs]
   f_temp = _prepare(f, xs_dtypes, xs_shapes)
   y = f_temp(*xs)
-  return zip(*[
+  return tuple(zip(*[
       _compute_gradient(f, y.shape, dtypes.as_dtype(y.dtype), xs, i, delta)
       for i in range(len(xs))
-  ])
+  ]))
 @tf_export("test.compute_gradient", v1=[])

tensorflow/python/ops/gradient_checker_v2_test.py

@@ -97,6 +97,15 @@ class GradientCheckerTest(test.TestCase):
     tf_logging.info("x1 error = %f", error)
     self.assertLess(error, 1e-4)

+  def testBfloat16(self):
+    x1 = constant_op.constant(2.0, dtype="bfloat16")
+    x2 = constant_op.constant(3.0, dtype="bfloat16")
+    # bfloat16 is very imprecise, so we use very large delta and error bar here.
+    error = gradient_checker.max_error(*gradient_checker.compute_gradient(
+        lambda x1: math_ops.add(x1, x2), [x1], delta=0.1))
+    tf_logging.info("x1 error = %f", error)
+    self.assertLess(error, 0.07)
+
   def testAddCustomized(self):
     size = (2, 3)
     x1 = constant_op.constant(2.0, shape=size, dtype=dtypes.float64, name="x1")
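
The new test drives these entry points end to end; roughly the equivalent check outside the test harness is sketched below, reusing the module import the test itself uses (compute_gradient is the function exported as tf.test.compute_gradient in the hunk above; the wide delta and loose bound mirror the test):

    import tensorflow as tf
    from tensorflow.python.ops import gradient_checker_v2 as gradient_checker

    x1 = tf.constant(2.0, dtype=tf.bfloat16)
    x2 = tf.constant(3.0, dtype=tf.bfloat16)

    # compute_gradient returns (theoretical_jacobians, numeric_jacobians);
    # after the change above this is a concrete tuple rather than a one-shot zip object.
    theoretical, numeric = gradient_checker.compute_gradient(
        lambda x1: tf.math.add(x1, x2), [x1], delta=0.1)

    # A coarse delta keeps the finite-difference step representable in bfloat16,
    # so the acceptable error is correspondingly loose.
    error = gradient_checker.max_error(theoretical, numeric)
    print(error < 0.07)  # expected: True, matching the bound asserted in the test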