diff --git a/tensorflow/python/ops/array_ops_test.py b/tensorflow/python/ops/array_ops_test.py
index d8e2dcd0fb3..87c05b47455 100644
--- a/tensorflow/python/ops/array_ops_test.py
+++ b/tensorflow/python/ops/array_ops_test.py
@@ -18,11 +18,9 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+from tensorflow.python.eager import backprop
 from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import ops
-from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import gradients
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import random_ops
 from tensorflow.python.platform import test
@@ -30,7 +28,6 @@ from tensorflow.python.platform import test
 
 class ArrayOpTest(test.TestCase):
 
-  @test_util.deprecated_graph_mode_only
   def testGatherGradHasPartialStaticShape(self):
     # Create a tensor with an unknown dim 1.
     x = random_ops.random_normal([4, 10, 10])
@@ -38,19 +35,22 @@ class ArrayOpTest(test.TestCase):
         x,
         array_ops.reshape(array_ops.where_v2(x[0, :, 0] > 0.5), [-1]),
         axis=1)
-    self.assertAllEqual(x.shape.as_list(), [4, None, 10])
+    x.shape.assert_is_compatible_with([4, None, 10])
 
-    a = array_ops.gather(array_ops.gather(x, [0, 1]), [0, 1])
-    b = array_ops.gather(array_ops.gather(x, [2, 3], axis=2), [0, 1])
-    grad_a = ops.convert_to_tensor(gradients.gradients(a, x)[0])
-    grad_b = ops.convert_to_tensor(gradients.gradients(b, x)[0])
+    with backprop.GradientTape() as tape:
+      tape.watch(x)
+      a = array_ops.gather(array_ops.gather(x, [0, 1]), [0, 1])
+    grad_a = tape.gradient(a, x)
+    with backprop.GradientTape() as tape:
+      tape.watch(x)
+      b = array_ops.gather(array_ops.gather(x, [2, 3], axis=2), [0, 1])
+    grad_b = tape.gradient(b, x)
 
     # We make sure that the representation of the shapes are correct; the shape
     # equality check will always eval to false due to the shapes being partial.
-    self.assertAllEqual(grad_a.shape.as_list(), [None, None, 10])
-    self.assertAllEqual(grad_b.shape.as_list(), [4, None, 10])
+    grad_a.shape.assert_is_compatible_with([None, None, 10])
+    grad_b.shape.assert_is_compatible_with([4, None, 10])
 
-  @test_util.deprecated_graph_mode_only
   def testReshapeShapeInference(self):
     # Create a tensor with an unknown dim 1.
     x = random_ops.random_normal([4, 10, 10])
@@ -58,11 +58,11 @@ class ArrayOpTest(test.TestCase):
         x,
         array_ops.reshape(array_ops.where_v2(x[0, :, 0] > 0.5), [-1]),
         axis=1)
-    self.assertAllEqual(x.shape.as_list(), [4, None, 10])
+    x.shape.assert_is_compatible_with([4, None, 10])
     a = array_ops.reshape(x, array_ops.shape(x))
-    self.assertAllEqual(a.shape.as_list(), [4, None, 10])
+    a.shape.assert_is_compatible_with([4, None, 10])
     b = array_ops.reshape(x, math_ops.cast(array_ops.shape(x), dtypes.int64))
-    self.assertAllEqual(b.shape.as_list(), [4, None, 10])
+    b.shape.assert_is_compatible_with([4, None, 10])
 
     # We do not shape-infer across a tf.cast into anything that's not tf.int32
     # or tf.int64, since they might end up mangling the shape.
@@ -70,7 +70,7 @@
         x,
         math_ops.cast(
             math_ops.cast(array_ops.shape(x), dtypes.float32), dtypes.int32))
-    self.assertAllEqual(c.shape.as_list(), [None, None, None])
+    c.shape.assert_is_compatible_with([None, None, None])
 
   def testEmptyMeshgrid(self):
     self.assertEqual(array_ops.meshgrid(), [])
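
Not part of the patch: a minimal standalone sketch of the two idioms the tests are migrated to, written against the public tf.* API rather than the internal backprop/array_ops test modules used above. backprop.GradientTape is the same class exposed as tf.GradientTape, and TensorShape.assert_is_compatible_with treats None as an unknown dimension, so it passes for both partial graph-mode shapes and fully defined eager shapes where an exact shape.as_list() comparison would not.

import tensorflow as tf

x = tf.random.normal([4, 10, 10])

# tf.GradientTape replaces graph-mode gradients.gradients(); watch() is
# needed because x is a plain tensor rather than a tf.Variable.
with tf.GradientTape() as tape:
  tape.watch(x)
  y = tf.gather(tf.gather(x, [0, 1]), [0, 1])
grad = tape.gradient(y, x)

# assert_is_compatible_with treats None as "unknown", so the fully defined
# eager gradient shape [4, 10, 10] still satisfies [4, None, 10], while an
# exact equality check against [4, None, 10] would fail here.
grad.shape.assert_is_compatible_with([4, None, 10])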