Fix error with scalar values in assertAllInRange() input.
PiperOrigin-RevId: 334882384 Change-Id: If1bf4e3da2df22460786b849416735054f482f38
This commit is contained in:
parent
8991d372bf
commit
2e69a2dc42
@ -2904,6 +2904,8 @@ class TensorFlowTestCase(googletest.TestCase):
|
||||
lines = []
|
||||
subscripts = np.transpose(subscripts)
|
||||
prefix = " " * indent
|
||||
if np.ndim(value) == 0:
|
||||
return [prefix + "[0] : " + str(value)]
|
||||
for subscript in itertools.islice(subscripts, limit):
|
||||
lines.append(prefix + str(subscript) + " : " +
|
||||
str(value[tuple(subscript)]))
|
||||
|
@ -596,6 +596,19 @@ class TestUtilTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
||||
self.assertAllInRange(
|
||||
x, 10, 15, open_lower_bound=True, open_upper_bound=True)
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def testAssertAllInRangeScalar(self):
|
||||
x = constant_op.constant(10.0, name="x")
|
||||
nan = constant_op.constant(np.nan, name="nan")
|
||||
self.assertAllInRange(x, 5, 15)
|
||||
with self.assertRaises(AssertionError):
|
||||
self.assertAllInRange(nan, 5, 15)
|
||||
|
||||
with self.assertRaises(AssertionError):
|
||||
self.assertAllInRange(x, 10, 15, open_lower_bound=True)
|
||||
with self.assertRaises(AssertionError):
|
||||
self.assertAllInRange(x, 1, 2)
|
||||
|
||||
@test_util.run_in_graph_and_eager_modes
|
||||
def testAssertAllInRangeErrorMessageEllipses(self):
|
||||
x_init = np.array([[10.0, 15.0]] * 12)
|
||||
|
Loading…
x
Reference in New Issue
Block a user