Removes the special case for bfloat16 in gradient_checker_v2, which causes an infinite loop and provides no benefit: if the function returns bfloat16, the precision has already been lost by the time the checker sees the result.
PiperOrigin-RevId: 314217630 Change-Id: I24ab5decb81baab183099b93b1d8c9ed65bffed3
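Note on the "infinite loop": the removed wrapper rebinds the name `f` to a lambda whose body calls `f`, and since Python resolves that name at call time, the lambda ends up calling itself instead of the original function, recursing until RecursionError. A minimal pure-Python sketch of this late-binding pitfall (the names here are illustrative, not from the TensorFlow code):

```python
def original(x):
    return x + 1.0

f = original
# BUG: the lambda body looks up `f` when it is *called*, and by then `f`
# is the lambda itself, so f(2.0) recurses until RecursionError.
f = lambda *xs: float(f(*xs))

# Safe variant: capture the wrapped callable under a different name so the
# lambda does not refer back to itself.
wrapped = original
f = lambda *xs: float(wrapped(*xs))
print(f(2.0))  # 3.0
```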
This commit is contained in:
Parent: 625da41a12
Commit: e0b56ace77
tensorflow/python/ops/gradient_checker_v2.py
@@ -28,7 +28,6 @@ from tensorflow.python.eager import context
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import math_ops
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.util.tf_export import tf_export
 
@@ -217,14 +216,8 @@ def _compute_numeric_jacobian(f, y_size, y_dtype, xs, param, delta):
     and "x_size" columns where "x_size" is the number of elements in xs[param]
     and "y_size" is the number of elements in the result.
   """
-  # bfloat16 doesn't have enough bits to represent high precision numbers such
-  # as delta. Convert to float32 here. Since numeric_jacobian is expected to
-  # be the groundtruth to compare against, it shouldn't lose any information.
   x_shape = xs[param].shape
   x_dtype = xs[param].dtype
-  if y_dtype == dtypes.bfloat16:
-    f = lambda *xs: math_ops.cast(f(*xs), dtypes.float32)
-    y_dtype = dtypes.float32
 
   # To compute the jacobian, we treat x and y as one-dimensional vectors
   x_size = _product(x_shape) * (2 if x_dtype.is_complex else 1)
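For context, _compute_numeric_jacobian estimates the Jacobian by finite differences: each element of the flattened input is perturbed by delta, f is re-evaluated, and the resulting change in the flattened output fills one column. A simplified NumPy sketch of that idea, written for this note (it uses a central difference and skips the complex-dtype handling; it is not the TensorFlow implementation):

```python
import numpy as np

def numeric_jacobian(f, x, delta=1e-3):
    """Finite-difference Jacobian of f at x: y_size rows, x_size columns."""
    x = np.array(x, dtype=np.float64)
    flat = x.reshape(-1)                    # view: writes show up in x
    y0 = np.asarray(f(x)).reshape(-1)
    jac = np.zeros((y0.size, flat.size))
    for j in range(flat.size):
        saved = flat[j]
        flat[j] = saved + delta
        y_plus = np.asarray(f(x)).reshape(-1)
        flat[j] = saved - delta
        y_minus = np.asarray(f(x)).reshape(-1)
        flat[j] = saved
        jac[:, j] = (y_plus - y_minus) / (2.0 * delta)
    return jac

# Example: f(x) = x**2 at x = [1, 2] has Jacobian diag([2, 4]).
print(numeric_jacobian(lambda x: x**2, [1.0, 2.0]))
```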
@@ -292,10 +285,10 @@ def _compute_gradient_list(f, xs, delta):
   xs_shapes = [x.shape for x in xs]
   f_temp = _prepare(f, xs_dtypes, xs_shapes)
   y = f_temp(*xs)
-  return zip(*[
+  return tuple(zip(*[
       _compute_gradient(f, y.shape, dtypes.as_dtype(y.dtype), xs, i, delta)
       for i in range(len(xs))
-  ])
+  ]))
 
 
 @tf_export("test.compute_gradient", v1=[])
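Aside on the `zip` → `tuple(zip(...))` change above: in Python 3, `zip` returns a lazy, one-shot iterator, so wrapping it in `tuple` presumably lets the result of `compute_gradient` be indexed and iterated more than once (the new test below unpacks it with `*`). A small illustration of the difference, independent of TensorFlow:

```python
pairs = zip([1, 2], ["a", "b"])   # Python 3: a lazy, one-shot iterator
print(list(pairs))                 # [(1, 'a'), (2, 'b')]
print(list(pairs))                 # [] -- the iterator is already exhausted

pairs = tuple(zip([1, 2], ["a", "b"]))
print(pairs[0])                    # (1, 'a') -- indexable
print(list(pairs))                 # [(1, 'a'), (2, 'b')] -- reusable
```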
tensorflow/python/ops/gradient_checker_v2_test.py

@@ -97,6 +97,15 @@ class GradientCheckerTest(test.TestCase):
     tf_logging.info("x1 error = %f", error)
     self.assertLess(error, 1e-4)
 
+  def testBfloat16(self):
+    x1 = constant_op.constant(2.0, dtype="bfloat16")
+    x2 = constant_op.constant(3.0, dtype="bfloat16")
+    # bfloat16 is very imprecise, so we use very large delta and error bar here.
+    error = gradient_checker.max_error(*gradient_checker.compute_gradient(
+        lambda x1: math_ops.add(x1, x2), [x1], delta=0.1))
+    tf_logging.info("x1 error = %f", error)
+    self.assertLess(error, 0.07)
+
   def testAddCustomized(self):
     size = (2, 3)
     x1 = constant_op.constant(2.0, shape=size, dtype=dtypes.float64, name="x1")
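Why testBfloat16 uses delta=0.1 and such a loose 0.07 bound: bfloat16 carries only 8 significant bits, so near 2.0 adjacent representable values are 2**-6 = 0.015625 apart; a conventional tiny step like 1e-3 rounds back to the same bfloat16 value and the numeric gradient would come out as zero. A quick sanity check of that spacing (my own illustration, not part of this commit):

```python
import tensorflow as tf

x = tf.constant(2.0)
# A step far below the bfloat16 spacing near 2.0 (2**-6 = 0.015625) is
# rounded away: x + 1e-3 and x map to the same bfloat16 value.
print(bool(tf.cast(x + 1e-3, tf.bfloat16) == tf.cast(x, tf.bfloat16)))  # True
# delta = 0.1 spans several bfloat16 steps, so the perturbation survives.
print(bool(tf.cast(x + 0.1, tf.bfloat16) == tf.cast(x, tf.bfloat16)))   # False
```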