Explicitly specify input array dtype
to TensorFlow's assertAllEqual()
test util function.
PiperOrigin-RevId: 313530667 Change-Id: Ia251459edaeb89318d012cf0ff54c9b43a56fe49
This commit is contained in:
parent
ad74ce7491
commit
3e7ba8944b
@ -247,8 +247,8 @@ class SliceTest(test.TestCase):
|
||||
slice_t = array_ops.slice(a, [0, 0], [2, 2])
|
||||
slice2_t = a[:2, :2]
|
||||
slice_val, slice2_val = self.evaluate([slice_t, slice2_t])
|
||||
self.assertAllEqual(slice_val, inp[:2, :2])
|
||||
self.assertAllEqual(slice2_val, inp[:2, :2])
|
||||
self.assertAllEqual(slice_val, np.array(inp[:2, :2], dtype=np.float32))
|
||||
self.assertAllEqual(slice2_val, np.array(inp[:2, :2], dtype=np.float32))
|
||||
self.assertEqual(slice_val.shape, slice_t.get_shape())
|
||||
self.assertEqual(slice2_val.shape, slice2_t.get_shape())
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user