Remove @test_util.run_deprecated_v1 in bcast_ops_test.py

PiperOrigin-RevId: 324083460
Change-Id: I654fe4379a1f94e5ff23a81480394a818b54357a
This commit is contained in:
Kibeom Kim 2020-07-30 14:28:52 -07:00 committed by TensorFlower Gardener
parent a9740221ea
commit 161867c062

View File

@ -20,7 +20,6 @@ from __future__ import print_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_util
from tensorflow.python.ops.gen_array_ops import broadcast_args
from tensorflow.python.ops.gen_array_ops import broadcast_gradient_args
from tensorflow.python.platform import test
@ -29,14 +28,11 @@ from tensorflow.python.platform import test
class BcastOpsTest(test.TestCase):
def _GetBroadcastShape(self, xs, ys):
with self.cached_session() as sess:
return sess.run(broadcast_args(xs, ys))
return self.evaluate(broadcast_args(xs, ys))
def _GetGradientArgs(self, xs, ys):
with self.cached_session() as sess:
return sess.run(broadcast_gradient_args(xs, ys))
return self.evaluate(broadcast_gradient_args(xs, ys))
@test_util.run_deprecated_v1
def testBasic(self):
r = self._GetBroadcastShape([2, 3, 5], [1])
self.assertAllEqual(r, [2, 3, 5])
@ -68,7 +64,6 @@ class BcastOpsTest(test.TestCase):
r = self._GetBroadcastShape([3, 1], [2, 1, 5])
self.assertAllEqual(r, [2, 3, 5])
@test_util.run_deprecated_v1
def testBasicGradient(self):
r0, r1 = self._GetGradientArgs([2, 3, 5], [1])
self.assertAllEqual(r0, [])
@ -110,7 +105,6 @@ class BcastOpsTest(test.TestCase):
self.assertAllEqual(r0, [0, 2])
self.assertAllEqual(r1, [1])
@test_util.run_deprecated_v1
def testZeroDims(self):
r = self._GetBroadcastShape([2, 0, 3, 0, 5], [3, 0, 5])
self.assertAllEqual(r, [2, 0, 3, 0, 5])
@ -124,7 +118,6 @@ class BcastOpsTest(test.TestCase):
r = self._GetBroadcastShape([3, 1, 5], [2, 0, 3, 0, 5])
self.assertAllEqual(r, [2, 0, 3, 0, 5])
@test_util.run_deprecated_v1
def testZeroDimsGradient(self):
r0, r1 = self._GetGradientArgs([2, 0, 3, 0, 5], [3, 0, 5])
self.assertAllEqual(r0, [])
@ -142,7 +135,6 @@ class BcastOpsTest(test.TestCase):
self.assertAllEqual(r0, [0, 1, 3])
self.assertAllEqual(r1, [])
@test_util.run_deprecated_v1
def testDataTypes(self):
for dtype in [dtypes.int32, dtypes.int64]:
r = self._GetBroadcastShape(