Fix unhelpful error message
For 99% of all usecases, if the expected shape differs from the actual shape, people will typically rerun with an additional print statement to see what the actual output was. PiperOrigin-RevId: 212303323
This commit is contained in:
parent
c5b14b334e
commit
a8b2dd9f72
@ -1327,9 +1327,17 @@ class TensorFlowTestCase(googletest.TestCase):
|
||||
def _assertArrayLikeAllClose(self, a, b, rtol=1e-6, atol=1e-6, msg=None):
|
||||
a = self._GetNdArray(a)
|
||||
b = self._GetNdArray(b)
|
||||
self.assertEqual(
|
||||
a.shape, b.shape,
|
||||
"Shape mismatch: expected %s, got %s." % (a.shape, b.shape))
|
||||
# When the array rank is small, print its contents. Numpy array printing is
|
||||
# implemented using inefficient recursion so prints can cause tests to
|
||||
# time out.
|
||||
if a.shape != b.shape and (b.ndim <= 3 or b.size < 500):
|
||||
shape_mismatch_msg = ("Shape mismatch: expected %s, got %s with contents "
|
||||
"%s.") % (a.shape, b.shape, b)
|
||||
else:
|
||||
shape_mismatch_msg = "Shape mismatch: expected %s, got %s." % (a.shape,
|
||||
b.shape)
|
||||
self.assertEqual(a.shape, b.shape, shape_mismatch_msg)
|
||||
|
||||
if not np.allclose(a, b, rtol=rtol, atol=atol):
|
||||
# Prints more details than np.testing.assert_allclose.
|
||||
#
|
||||
|
Loading…
Reference in New Issue
Block a user