Remove @test_util.run_deprecated_v1 in batch_scatter_ops_test.py

PiperOrigin-RevId: 324159787
Change-Id: Id932872b231639aa6726e51e9346b8321d2c3c7c
This commit is contained in:
Kibeom Kim 2020-07-30 23:38:45 -07:00 committed by TensorFlower Gardener
parent 2d21b21a91
commit f2ba032e65

View File

@ -23,7 +23,6 @@ import numpy as np
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
@ -72,14 +71,14 @@ class ScatterTest(test.TestCase):
np_scatter(new, indices, updates)
# Scatter via tensorflow
ref = variables.Variable(old)
ref.initializer.run()
self.evaluate(variables.variables_initializer([ref]))
if method:
ref.batch_scatter_update(ops.IndexedSlices(indices, updates))
else:
tf_scatter(ref, indices, updates).eval()
self.evaluate(tf_scatter(ref, indices, updates))
self.assertAllClose(ref, new)
@test_util.run_deprecated_v1
def testVariableRankUpdate(self):
vtypes = [np.float32, np.float64]
for vtype in vtypes:
@ -87,42 +86,39 @@ class ScatterTest(test.TestCase):
self._VariableRankTest(
state_ops.batch_scatter_update, vtype, itype)
@test_util.run_deprecated_v1
def testBooleanScatterUpdate(self):
with self.session(use_gpu=False) as session:
var = variables.Variable([True, False])
update0 = state_ops.batch_scatter_update(var, [1], [True])
update1 = state_ops.batch_scatter_update(
var, constant_op.constant(
[0], dtype=dtypes.int64), [False])
var.initializer.run()
var = variables.Variable([True, False])
update0 = state_ops.batch_scatter_update(var, [1], [True])
update1 = state_ops.batch_scatter_update(
var, constant_op.constant(
[0], dtype=dtypes.int64), [False])
self.evaluate(variables.variables_initializer([var]))
session.run([update0, update1])
self.evaluate([update0, update1])
self.assertAllEqual([False, True], self.evaluate(var))
self.assertAllEqual([False, True], self.evaluate(var))
@test_util.run_deprecated_v1
def testScatterOutOfRange(self):
params = np.array([1, 2, 3, 4, 5, 6]).astype(np.float32)
updates = np.array([-3, -4, -5]).astype(np.float32)
with self.session(use_gpu=False):
ref = variables.Variable(params)
ref.initializer.run()
# Indices all in range, no problem.
indices = np.array([2, 0, 5])
state_ops.batch_scatter_update(ref, indices, updates).eval()
ref = variables.Variable(params)
self.evaluate(variables.variables_initializer([ref]))
# Test some out of range errors.
indices = np.array([-1, 0, 5])
with self.assertRaisesOpError(
r'indices\[0\] = \[-1\] does not index into shape \[6\]'):
state_ops.batch_scatter_update(ref, indices, updates).eval()
# Indices all in range, no problem.
indices = np.array([2, 0, 5])
self.evaluate(state_ops.batch_scatter_update(ref, indices, updates))
indices = np.array([2, 0, 6])
with self.assertRaisesOpError(r'indices\[2\] = \[6\] does not index into '
r'shape \[6\]'):
state_ops.batch_scatter_update(ref, indices, updates).eval()
# Test some out of range errors.
indices = np.array([-1, 0, 5])
with self.assertRaisesOpError(
r'indices\[0\] = \[-1\] does not index into shape \[6\]'):
self.evaluate(state_ops.batch_scatter_update(ref, indices, updates))
indices = np.array([2, 0, 6])
with self.assertRaisesOpError(r'indices\[2\] = \[6\] does not index into '
r'shape \[6\]'):
self.evaluate(state_ops.batch_scatter_update(ref, indices, updates))
if __name__ == '__main__':
test.main()