Fix unhelpful error message

For 99% of all usecases, if the expected shape differs from the actual shape, people will typically rerun with an additional print statement to see what the actual output was.

PiperOrigin-RevId: 212303323
This commit is contained in:
A. Unique TensorFlower 2018-09-10 11:35:24 -07:00 committed by TensorFlower Gardener
parent c5b14b334e
commit a8b2dd9f72

View File

@ -1327,9 +1327,17 @@ class TensorFlowTestCase(googletest.TestCase):
def _assertArrayLikeAllClose(self, a, b, rtol=1e-6, atol=1e-6, msg=None):
a = self._GetNdArray(a)
b = self._GetNdArray(b)
self.assertEqual(
a.shape, b.shape,
"Shape mismatch: expected %s, got %s." % (a.shape, b.shape))
# When the array rank is small, print its contents. Numpy array printing is
# implemented using inefficient recursion so prints can cause tests to
# time out.
if a.shape != b.shape and (b.ndim <= 3 or b.size < 500):
shape_mismatch_msg = ("Shape mismatch: expected %s, got %s with contents "
"%s.") % (a.shape, b.shape, b)
else:
shape_mismatch_msg = "Shape mismatch: expected %s, got %s." % (a.shape,
b.shape)
self.assertEqual(a.shape, b.shape, shape_mismatch_msg)
if not np.allclose(a, b, rtol=rtol, atol=atol):
# Prints more details than np.testing.assert_allclose.
#