Remove v1 only decorator

PiperOrigin-RevId: 323639834
Change-Id: Ie65dfb649898e138f5b2aad046fd9fc6d3f231c0
This commit is contained in:
Yanhua Sun 2020-07-28 13:20:19 -07:00 committed by TensorFlower Gardener
parent 68b5f1defd
commit 5198b44674

View File

@ -18,11 +18,9 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from tensorflow.python.eager import backprop
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradients
from tensorflow.python.ops import math_ops from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops from tensorflow.python.ops import random_ops
from tensorflow.python.platform import test from tensorflow.python.platform import test
@ -30,7 +28,6 @@ from tensorflow.python.platform import test
class ArrayOpTest(test.TestCase): class ArrayOpTest(test.TestCase):
@test_util.deprecated_graph_mode_only
def testGatherGradHasPartialStaticShape(self): def testGatherGradHasPartialStaticShape(self):
# Create a tensor with an unknown dim 1. # Create a tensor with an unknown dim 1.
x = random_ops.random_normal([4, 10, 10]) x = random_ops.random_normal([4, 10, 10])
@ -38,19 +35,22 @@ class ArrayOpTest(test.TestCase):
x, x,
array_ops.reshape(array_ops.where_v2(x[0, :, 0] > 0.5), [-1]), array_ops.reshape(array_ops.where_v2(x[0, :, 0] > 0.5), [-1]),
axis=1) axis=1)
self.assertAllEqual(x.shape.as_list(), [4, None, 10]) x.shape.assert_is_compatible_with([4, None, 10])
a = array_ops.gather(array_ops.gather(x, [0, 1]), [0, 1]) with backprop.GradientTape() as tape:
b = array_ops.gather(array_ops.gather(x, [2, 3], axis=2), [0, 1]) tape.watch(x)
grad_a = ops.convert_to_tensor(gradients.gradients(a, x)[0]) a = array_ops.gather(array_ops.gather(x, [0, 1]), [0, 1])
grad_b = ops.convert_to_tensor(gradients.gradients(b, x)[0]) grad_a = tape.gradient(a, x)
with backprop.GradientTape() as tape:
tape.watch(x)
b = array_ops.gather(array_ops.gather(x, [2, 3], axis=2), [0, 1])
grad_b = tape.gradient(b, x)
# We make sure that the representation of the shapes are correct; the shape # We make sure that the representation of the shapes are correct; the shape
# equality check will always eval to false due to the shapes being partial. # equality check will always eval to false due to the shapes being partial.
self.assertAllEqual(grad_a.shape.as_list(), [None, None, 10]) grad_a.shape.assert_is_compatible_with([None, None, 10])
self.assertAllEqual(grad_b.shape.as_list(), [4, None, 10]) grad_b.shape.assert_is_compatible_with([4, None, 10])
@test_util.deprecated_graph_mode_only
def testReshapeShapeInference(self): def testReshapeShapeInference(self):
# Create a tensor with an unknown dim 1. # Create a tensor with an unknown dim 1.
x = random_ops.random_normal([4, 10, 10]) x = random_ops.random_normal([4, 10, 10])
@ -58,11 +58,11 @@ class ArrayOpTest(test.TestCase):
x, x,
array_ops.reshape(array_ops.where_v2(x[0, :, 0] > 0.5), [-1]), array_ops.reshape(array_ops.where_v2(x[0, :, 0] > 0.5), [-1]),
axis=1) axis=1)
self.assertAllEqual(x.shape.as_list(), [4, None, 10]) x.shape.assert_is_compatible_with([4, None, 10])
a = array_ops.reshape(x, array_ops.shape(x)) a = array_ops.reshape(x, array_ops.shape(x))
self.assertAllEqual(a.shape.as_list(), [4, None, 10]) a.shape.assert_is_compatible_with([4, None, 10])
b = array_ops.reshape(x, math_ops.cast(array_ops.shape(x), dtypes.int64)) b = array_ops.reshape(x, math_ops.cast(array_ops.shape(x), dtypes.int64))
self.assertAllEqual(b.shape.as_list(), [4, None, 10]) b.shape.assert_is_compatible_with([4, None, 10])
# We do not shape-infer across a tf.cast into anything that's not tf.int32 # We do not shape-infer across a tf.cast into anything that's not tf.int32
# or tf.int64, since they might end up mangling the shape. # or tf.int64, since they might end up mangling the shape.
@ -70,7 +70,7 @@ class ArrayOpTest(test.TestCase):
x, x,
math_ops.cast( math_ops.cast(
math_ops.cast(array_ops.shape(x), dtypes.float32), dtypes.int32)) math_ops.cast(array_ops.shape(x), dtypes.float32), dtypes.int32))
self.assertAllEqual(c.shape.as_list(), [None, None, None]) c.shape.assert_is_compatible_with([None, None, None])
def testEmptyMeshgrid(self): def testEmptyMeshgrid(self):
self.assertEqual(array_ops.meshgrid(), []) self.assertEqual(array_ops.meshgrid(), [])