Explicitly specify input array dtype to TensorFlow's assertAllEqual() test util function.

PiperOrigin-RevId: 313530667
Change-Id: Ia251459edaeb89318d012cf0ff54c9b43a56fe49
This commit is contained in:
Hye Soo Yang 2020-05-27 22:55:00 -07:00 committed by TensorFlower Gardener
parent ad74ce7491
commit 3e7ba8944b

View File

@ -247,8 +247,8 @@ class SliceTest(test.TestCase):
slice_t = array_ops.slice(a, [0, 0], [2, 2])
slice2_t = a[:2, :2]
slice_val, slice2_val = self.evaluate([slice_t, slice2_t])
self.assertAllEqual(slice_val, inp[:2, :2])
self.assertAllEqual(slice2_val, inp[:2, :2])
self.assertAllEqual(slice_val, np.array(inp[:2, :2], dtype=np.float32))
self.assertAllEqual(slice2_val, np.array(inp[:2, :2], dtype=np.float32))
self.assertEqual(slice_val.shape, slice_t.get_shape())
self.assertEqual(slice2_val.shape, slice2_t.get_shape())