From a8b2dd9f72fe78cca59d525230f5358430fec45c Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 10 Sep 2018 11:35:24 -0700 Subject: [PATCH] Fix unhelpful error message For 99% of all usecases, if the expected shape differs from the actual shape, people will typically rerun with an additional print statement to see what the actual output was. PiperOrigin-RevId: 212303323 --- tensorflow/python/framework/test_util.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py index 4bece9e25e8..d63abd7f018 100644 --- a/tensorflow/python/framework/test_util.py +++ b/tensorflow/python/framework/test_util.py @@ -1327,9 +1327,17 @@ class TensorFlowTestCase(googletest.TestCase): def _assertArrayLikeAllClose(self, a, b, rtol=1e-6, atol=1e-6, msg=None): a = self._GetNdArray(a) b = self._GetNdArray(b) - self.assertEqual( - a.shape, b.shape, - "Shape mismatch: expected %s, got %s." % (a.shape, b.shape)) + # When the array rank is small, print its contents. Numpy array printing is + # implemented using inefficient recursion so prints can cause tests to + # time out. + if a.shape != b.shape and (b.ndim <= 3 or b.size < 500): + shape_mismatch_msg = ("Shape mismatch: expected %s, got %s with contents " + "%s.") % (a.shape, b.shape, b) + else: + shape_mismatch_msg = "Shape mismatch: expected %s, got %s." % (a.shape, + b.shape) + self.assertEqual(a.shape, b.shape, shape_mismatch_msg) + if not np.allclose(a, b, rtol=rtol, atol=atol): # Prints more details than np.testing.assert_allclose. #