Remove @test_util.run_deprecated_v1 in bcast_ops_test.py
PiperOrigin-RevId: 324083460 Change-Id: I654fe4379a1f94e5ff23a81480394a818b54357a
This commit is contained in:
parent
a9740221ea
commit
161867c062
@ -20,7 +20,6 @@ from __future__ import print_function
|
||||
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops.gen_array_ops import broadcast_args
|
||||
from tensorflow.python.ops.gen_array_ops import broadcast_gradient_args
|
||||
from tensorflow.python.platform import test
|
||||
@ -29,14 +28,11 @@ from tensorflow.python.platform import test
|
||||
class BcastOpsTest(test.TestCase):
|
||||
|
||||
def _GetBroadcastShape(self, xs, ys):
|
||||
with self.cached_session() as sess:
|
||||
return sess.run(broadcast_args(xs, ys))
|
||||
return self.evaluate(broadcast_args(xs, ys))
|
||||
|
||||
def _GetGradientArgs(self, xs, ys):
|
||||
with self.cached_session() as sess:
|
||||
return sess.run(broadcast_gradient_args(xs, ys))
|
||||
return self.evaluate(broadcast_gradient_args(xs, ys))
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testBasic(self):
|
||||
r = self._GetBroadcastShape([2, 3, 5], [1])
|
||||
self.assertAllEqual(r, [2, 3, 5])
|
||||
@ -68,7 +64,6 @@ class BcastOpsTest(test.TestCase):
|
||||
r = self._GetBroadcastShape([3, 1], [2, 1, 5])
|
||||
self.assertAllEqual(r, [2, 3, 5])
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testBasicGradient(self):
|
||||
r0, r1 = self._GetGradientArgs([2, 3, 5], [1])
|
||||
self.assertAllEqual(r0, [])
|
||||
@ -110,7 +105,6 @@ class BcastOpsTest(test.TestCase):
|
||||
self.assertAllEqual(r0, [0, 2])
|
||||
self.assertAllEqual(r1, [1])
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testZeroDims(self):
|
||||
r = self._GetBroadcastShape([2, 0, 3, 0, 5], [3, 0, 5])
|
||||
self.assertAllEqual(r, [2, 0, 3, 0, 5])
|
||||
@ -124,7 +118,6 @@ class BcastOpsTest(test.TestCase):
|
||||
r = self._GetBroadcastShape([3, 1, 5], [2, 0, 3, 0, 5])
|
||||
self.assertAllEqual(r, [2, 0, 3, 0, 5])
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testZeroDimsGradient(self):
|
||||
r0, r1 = self._GetGradientArgs([2, 0, 3, 0, 5], [3, 0, 5])
|
||||
self.assertAllEqual(r0, [])
|
||||
@ -142,7 +135,6 @@ class BcastOpsTest(test.TestCase):
|
||||
self.assertAllEqual(r0, [0, 1, 3])
|
||||
self.assertAllEqual(r1, [])
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testDataTypes(self):
|
||||
for dtype in [dtypes.int32, dtypes.int64]:
|
||||
r = self._GetBroadcastShape(
|
||||
|
Loading…
Reference in New Issue
Block a user