From 34197e15ab2bef074583441c0c57797bd30ba184 Mon Sep 17 00:00:00 2001 From: Zachary Garrett Date: Wed, 4 Mar 2020 05:58:25 -0800 Subject: [PATCH] Properly print the `msg` attribute when assertNotEqual tests fail. PiperOrigin-RevId: 298832152 Change-Id: I34650e05fbfc0252c3ed5591a8d471761b466819 --- tensorflow/python/framework/test_util.py | 4 ++-- tensorflow/python/framework/test_util_test.py | 15 +++++++++++++++ 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py index 01c455f7102..84651dec152 100644 --- a/tensorflow/python/framework/test_util.py +++ b/tensorflow/python/framework/test_util.py @@ -2626,10 +2626,10 @@ class TensorFlowTestCase(googletest.TestCase): msg: Optional message to report on failure. """ try: - self.assertAllEqual(a, b, msg) + self.assertAllEqual(a, b) except AssertionError: return - raise AssertionError("The two values are equal at all elements") + raise AssertionError("The two values are equal at all elements. %s" % msg) @py_func_if_in_function def assertAllGreater(self, a, comparison_target): diff --git a/tensorflow/python/framework/test_util_test.py b/tensorflow/python/framework/test_util_test.py index 96f7d600713..b5cb903c666 100644 --- a/tensorflow/python/framework/test_util_test.py +++ b/tensorflow/python/framework/test_util_test.py @@ -467,6 +467,21 @@ class TestUtilTest(test_util.TensorFlowTestCase, parameterized.TestCase): with self.assertRaisesRegexp(AssertionError, r"not equal lhs"): self.assertAllEqual([0] * 3, k) + @test_util.run_in_graph_and_eager_modes + def testAssertNotAllEqual(self): + i = variables.Variable([100], dtype=dtypes.int32, name="i") + j = constant_op.constant([20], dtype=dtypes.int32, name="j") + k = math_ops.add(i, j, name="k") + + self.evaluate(variables.global_variables_initializer()) + self.assertNotAllEqual([100] * 3, i) + self.assertNotAllEqual([120] * 3, k) + self.assertNotAllEqual([20] * 3, j) + + with self.assertRaisesRegexp( + AssertionError, r"two values are equal at all elements.*extra message"): + self.assertNotAllEqual([120], k, msg="extra message") + @test_util.run_in_graph_and_eager_modes def testAssertNotAllClose(self): # Test with arrays