Properly print the msg attribute when assertNotEqual tests fail.

PiperOrigin-RevId: 298832152
Change-Id: I34650e05fbfc0252c3ed5591a8d471761b466819
This commit is contained in:
Zachary Garrett 2020-03-04 05:58:25 -08:00 committed by TensorFlower Gardener
parent b747c9142e
commit 34197e15ab
2 changed files with 17 additions and 2 deletions

View File

@ -2626,10 +2626,10 @@ class TensorFlowTestCase(googletest.TestCase):
msg: Optional message to report on failure.
"""
try:
self.assertAllEqual(a, b, msg)
self.assertAllEqual(a, b)
except AssertionError:
return
raise AssertionError("The two values are equal at all elements")
raise AssertionError("The two values are equal at all elements. %s" % msg)
@py_func_if_in_function
def assertAllGreater(self, a, comparison_target):

View File

@ -467,6 +467,21 @@ class TestUtilTest(test_util.TensorFlowTestCase, parameterized.TestCase):
with self.assertRaisesRegexp(AssertionError, r"not equal lhs"):
self.assertAllEqual([0] * 3, k)
@test_util.run_in_graph_and_eager_modes
def testAssertNotAllEqual(self):
i = variables.Variable([100], dtype=dtypes.int32, name="i")
j = constant_op.constant([20], dtype=dtypes.int32, name="j")
k = math_ops.add(i, j, name="k")
self.evaluate(variables.global_variables_initializer())
self.assertNotAllEqual([100] * 3, i)
self.assertNotAllEqual([120] * 3, k)
self.assertNotAllEqual([20] * 3, j)
with self.assertRaisesRegexp(
AssertionError, r"two values are equal at all elements.*extra message"):
self.assertNotAllEqual([120], k, msg="extra message")
@test_util.run_in_graph_and_eager_modes
def testAssertNotAllClose(self):
# Test with arrays