diff --git a/tensorflow/python/kernel_tests/cwise_ops_test.py b/tensorflow/python/kernel_tests/cwise_ops_test.py index 8d628d448db..a7d8f841401 100644 --- a/tensorflow/python/kernel_tests/cwise_ops_test.py +++ b/tensorflow/python/kernel_tests/cwise_ops_test.py @@ -840,14 +840,16 @@ class MathOpsOverloadTest(test.TestCase): return self.evaluate(z) def _compareBinary(self, x, y, dtype, np_func, tf_func): - np_ans = np_func(x, y).astype(dtype.as_numpy_dtype) - if dtype == dtypes_lib.bfloat16: - # assertAllClose does not properly handle bfloat16 values - np_ans = np_ans.astype(np.float32) + # astype and assertAllClose do not properly handle bfloat16 values + np_ans = np_func(x, y).astype(np.float32 if dtype == dtypes_lib.bfloat16 + else dtype.as_numpy_dtype) + rtol = 1e-2 if dtype == dtypes_lib.bfloat16 else 1e-6 self.assertAllClose(np_ans, - self._computeTensorAndLiteral(x, y, dtype, tf_func)) + self._computeTensorAndLiteral(x, y, dtype, tf_func), + rtol=rtol) self.assertAllClose(np_ans, - self._computeLiteralAndTensor(x, y, dtype, tf_func)) + self._computeLiteralAndTensor(x, y, dtype, tf_func), + rtol=rtol) def _compareUnary(self, x, dtype, np_func, tf_func): np_ans = np_func(x).astype(dtype.as_numpy_dtype)