Properly print the msg attribute when assertNotEqual tests fail.
PiperOrigin-RevId: 298832152 Change-Id: I34650e05fbfc0252c3ed5591a8d471761b466819
This commit is contained in:
parent
b747c9142e
commit
34197e15ab
@ -2626,10 +2626,10 @@ class TensorFlowTestCase(googletest.TestCase):
|
||||
msg: Optional message to report on failure.
|
||||
"""
|
||||
try:
|
||||
self.assertAllEqual(a, b, msg)
|
||||
self.assertAllEqual(a, b)
|
||||
except AssertionError:
|
||||
return
|
||||
raise AssertionError("The two values are equal at all elements")
|
||||
raise AssertionError("The two values are equal at all elements. %s" % msg)
|
||||
|
||||
@py_func_if_in_function
|
||||
def assertAllGreater(self, a, comparison_target):
|
||||
|
||||
@ -467,6 +467,21 @@ class TestUtilTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
||||
with self.assertRaisesRegexp(AssertionError, r"not equal lhs"):
|
||||
self.assertAllEqual([0] * 3, k)
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def testAssertNotAllEqual(self):
|
||||
i = variables.Variable([100], dtype=dtypes.int32, name="i")
|
||||
j = constant_op.constant([20], dtype=dtypes.int32, name="j")
|
||||
k = math_ops.add(i, j, name="k")
|
||||
|
||||
self.evaluate(variables.global_variables_initializer())
|
||||
self.assertNotAllEqual([100] * 3, i)
|
||||
self.assertNotAllEqual([120] * 3, k)
|
||||
self.assertNotAllEqual([20] * 3, j)
|
||||
|
||||
with self.assertRaisesRegexp(
|
||||
AssertionError, r"two values are equal at all elements.*extra message"):
|
||||
self.assertNotAllEqual([120], k, msg="extra message")
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def testAssertNotAllClose(self):
|
||||
# Test with arrays
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user