Fix Windows cwise_ops_test.py bfloat16 failure.
Test was broken by 696a4a76ce
PiperOrigin-RevId: 325523032
Change-Id: Ieabe77e2cf53e4e4e81530ededbb55a4b3120c6a
This commit is contained in:
parent
6848f64097
commit
f6cfee3dff
@ -840,14 +840,16 @@ class MathOpsOverloadTest(test.TestCase):
|
||||
return self.evaluate(z)
|
||||
|
||||
def _compareBinary(self, x, y, dtype, np_func, tf_func):
|
||||
np_ans = np_func(x, y).astype(dtype.as_numpy_dtype)
|
||||
if dtype == dtypes_lib.bfloat16:
|
||||
# assertAllClose does not properly handle bfloat16 values
|
||||
np_ans = np_ans.astype(np.float32)
|
||||
# astype and assertAllClose do not properly handle bfloat16 values
|
||||
np_ans = np_func(x, y).astype(np.float32 if dtype == dtypes_lib.bfloat16
|
||||
else dtype.as_numpy_dtype)
|
||||
rtol = 1e-2 if dtype == dtypes_lib.bfloat16 else 1e-6
|
||||
self.assertAllClose(np_ans,
|
||||
self._computeTensorAndLiteral(x, y, dtype, tf_func))
|
||||
self._computeTensorAndLiteral(x, y, dtype, tf_func),
|
||||
rtol=rtol)
|
||||
self.assertAllClose(np_ans,
|
||||
self._computeLiteralAndTensor(x, y, dtype, tf_func))
|
||||
self._computeLiteralAndTensor(x, y, dtype, tf_func),
|
||||
rtol=rtol)
|
||||
|
||||
def _compareUnary(self, x, dtype, np_func, tf_func):
|
||||
np_ans = np_func(x).astype(dtype.as_numpy_dtype)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user