Add missing gradient_checker_v2.max_error() checks in tests.

PiperOrigin-RevId: 324095184
Change-Id: I9e3f1e5fc4f5f22b3b3cf4e8196a9b2e28840ed8
This commit is contained in:
Kibeom Kim 2020-07-30 15:23:56 -07:00 committed by TensorFlower Gardener
parent 26c01423d9
commit 431bb29171
3 changed files with 13 additions and 6 deletions

View File

@ -189,10 +189,13 @@ class BroadcastSimpleTest(test.TestCase):
return math_ops.truediv(x1, x2) * math_ops.cast(1.1, dtype=x2.dtype)
with self.cached_session():
gradient_checker_v2.compute_gradient(
div_x1, [x1], self._GRAD_TOL[dtypes.as_dtype(x1.dtype)])
gradient_checker_v2.compute_gradient(
div_x2, [x2], self._GRAD_TOL[dtypes.as_dtype(x2.dtype)])
err = gradient_checker_v2.max_error(*gradient_checker_v2.compute_gradient(
div_x1, [x1]))
self.assertLess(err, self._GRAD_TOL[dtypes.as_dtype(x1.dtype)])
err = gradient_checker_v2.max_error(*gradient_checker_v2.compute_gradient(
div_x2, [x2]))
self.assertLess(err, self._GRAD_TOL[dtypes.as_dtype(x2.dtype)])
self._compareGpu(x1, x2, np.true_divide, math_ops.truediv)
self._compareGpu(x1, x2 + 0.1, np.floor_divide, math_ops.floordiv)

View File

@ -180,7 +180,9 @@ class CastOpTest(test.TestCase):
x = math_ops.cast(x, dst_t)
return x
gradient_checker_v2.compute_gradient(cast, [x])
err = gradient_checker_v2.max_error(
*gradient_checker_v2.compute_gradient(cast, [x]))
self.assertLess(err, 1e-3)
class SparseTensorCastTest(test.TestCase):

View File

@ -126,7 +126,9 @@ class ReshapeTest(test.TestCase):
return array_ops.reshape(x, [1, 8, 3])
with self.cached_session():
gradient_checker_v2.compute_gradient(reshape, [input_tensor], 1e-3)
err = gradient_checker_v2.max_error(
*gradient_checker_v2.compute_gradient(reshape, [input_tensor]))
self.assertLess(err, 1e-3)
def testFloatEmpty(self):
x = np.empty((0, 0, 0, 0), dtype=np.float32)