Remove @test_util.run_deprecated_v1 in batch_scatter_ops_test.py
PiperOrigin-RevId: 324159787 Change-Id: Id932872b231639aa6726e51e9346b8321d2c3c7c
This commit is contained in:
parent
2d21b21a91
commit
f2ba032e65
@ -23,7 +23,6 @@ import numpy as np
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import state_ops
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import test
|
||||
@ -72,14 +71,14 @@ class ScatterTest(test.TestCase):
|
||||
np_scatter(new, indices, updates)
|
||||
# Scatter via tensorflow
|
||||
ref = variables.Variable(old)
|
||||
ref.initializer.run()
|
||||
self.evaluate(variables.variables_initializer([ref]))
|
||||
|
||||
if method:
|
||||
ref.batch_scatter_update(ops.IndexedSlices(indices, updates))
|
||||
else:
|
||||
tf_scatter(ref, indices, updates).eval()
|
||||
self.evaluate(tf_scatter(ref, indices, updates))
|
||||
self.assertAllClose(ref, new)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testVariableRankUpdate(self):
|
||||
vtypes = [np.float32, np.float64]
|
||||
for vtype in vtypes:
|
||||
@ -87,42 +86,39 @@ class ScatterTest(test.TestCase):
|
||||
self._VariableRankTest(
|
||||
state_ops.batch_scatter_update, vtype, itype)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testBooleanScatterUpdate(self):
|
||||
with self.session(use_gpu=False) as session:
|
||||
var = variables.Variable([True, False])
|
||||
update0 = state_ops.batch_scatter_update(var, [1], [True])
|
||||
update1 = state_ops.batch_scatter_update(
|
||||
var, constant_op.constant(
|
||||
[0], dtype=dtypes.int64), [False])
|
||||
var.initializer.run()
|
||||
var = variables.Variable([True, False])
|
||||
update0 = state_ops.batch_scatter_update(var, [1], [True])
|
||||
update1 = state_ops.batch_scatter_update(
|
||||
var, constant_op.constant(
|
||||
[0], dtype=dtypes.int64), [False])
|
||||
self.evaluate(variables.variables_initializer([var]))
|
||||
|
||||
session.run([update0, update1])
|
||||
self.evaluate([update0, update1])
|
||||
|
||||
self.assertAllEqual([False, True], self.evaluate(var))
|
||||
self.assertAllEqual([False, True], self.evaluate(var))
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testScatterOutOfRange(self):
|
||||
params = np.array([1, 2, 3, 4, 5, 6]).astype(np.float32)
|
||||
updates = np.array([-3, -4, -5]).astype(np.float32)
|
||||
with self.session(use_gpu=False):
|
||||
ref = variables.Variable(params)
|
||||
ref.initializer.run()
|
||||
|
||||
# Indices all in range, no problem.
|
||||
indices = np.array([2, 0, 5])
|
||||
state_ops.batch_scatter_update(ref, indices, updates).eval()
|
||||
ref = variables.Variable(params)
|
||||
self.evaluate(variables.variables_initializer([ref]))
|
||||
|
||||
# Test some out of range errors.
|
||||
indices = np.array([-1, 0, 5])
|
||||
with self.assertRaisesOpError(
|
||||
r'indices\[0\] = \[-1\] does not index into shape \[6\]'):
|
||||
state_ops.batch_scatter_update(ref, indices, updates).eval()
|
||||
# Indices all in range, no problem.
|
||||
indices = np.array([2, 0, 5])
|
||||
self.evaluate(state_ops.batch_scatter_update(ref, indices, updates))
|
||||
|
||||
indices = np.array([2, 0, 6])
|
||||
with self.assertRaisesOpError(r'indices\[2\] = \[6\] does not index into '
|
||||
r'shape \[6\]'):
|
||||
state_ops.batch_scatter_update(ref, indices, updates).eval()
|
||||
# Test some out of range errors.
|
||||
indices = np.array([-1, 0, 5])
|
||||
with self.assertRaisesOpError(
|
||||
r'indices\[0\] = \[-1\] does not index into shape \[6\]'):
|
||||
self.evaluate(state_ops.batch_scatter_update(ref, indices, updates))
|
||||
|
||||
indices = np.array([2, 0, 6])
|
||||
with self.assertRaisesOpError(r'indices\[2\] = \[6\] does not index into '
|
||||
r'shape \[6\]'):
|
||||
self.evaluate(state_ops.batch_scatter_update(ref, indices, updates))
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
||||
|
Loading…
Reference in New Issue
Block a user