Fix Windows cwise_ops_test.py bfloat16 failure.

Test was broken by 696a4a76ce

PiperOrigin-RevId: 325523032
Change-Id: Ieabe77e2cf53e4e4e81530ededbb55a4b3120c6a
This commit is contained in:
Reed Wanderman-Milne 2020-08-07 15:41:34 -07:00 committed by TensorFlower Gardener
parent 6848f64097
commit f6cfee3dff

View File

@ -840,14 +840,16 @@ class MathOpsOverloadTest(test.TestCase):
return self.evaluate(z)
def _compareBinary(self, x, y, dtype, np_func, tf_func):
np_ans = np_func(x, y).astype(dtype.as_numpy_dtype)
if dtype == dtypes_lib.bfloat16:
# assertAllClose does not properly handle bfloat16 values
np_ans = np_ans.astype(np.float32)
# astype and assertAllClose do not properly handle bfloat16 values
np_ans = np_func(x, y).astype(np.float32 if dtype == dtypes_lib.bfloat16
else dtype.as_numpy_dtype)
rtol = 1e-2 if dtype == dtypes_lib.bfloat16 else 1e-6
self.assertAllClose(np_ans,
self._computeTensorAndLiteral(x, y, dtype, tf_func))
self._computeTensorAndLiteral(x, y, dtype, tf_func),
rtol=rtol)
self.assertAllClose(np_ans,
self._computeLiteralAndTensor(x, y, dtype, tf_func))
self._computeLiteralAndTensor(x, y, dtype, tf_func),
rtol=rtol)
def _compareUnary(self, x, dtype, np_func, tf_func):
np_ans = np_func(x).astype(dtype.as_numpy_dtype)