Remove @test_util.deprecated_graph_mode_only in dense_update_ops_test.py
PiperOrigin-RevId: 324311282 Change-Id: I7f8c4fc89e4982dc885a8a07c155010d3f6097fd
This commit is contained in:
parent
68bedc248d
commit
1b2ac91b25
@ -30,28 +30,28 @@ from tensorflow.python.platform import test
|
||||
|
||||
class AssignOpTest(test.TestCase):
|
||||
|
||||
def _initAssignFetch(self, x, y, use_gpu=False):
|
||||
def _initAssignFetch(self, x, y, use_gpu):
|
||||
"""Initialize a param to init and update it with y."""
|
||||
super(AssignOpTest, self).setUp()
|
||||
with self.cached_session(use_gpu=use_gpu):
|
||||
with test_util.device(use_gpu=use_gpu):
|
||||
p = variables.Variable(x)
|
||||
assign = state_ops.assign(p, y)
|
||||
self.evaluate(p.initializer)
|
||||
new_value = self.evaluate(assign)
|
||||
return self.evaluate(p), new_value
|
||||
|
||||
def _initAssignAddFetch(self, x, y, use_gpu=False):
|
||||
def _initAssignAddFetch(self, x, y, use_gpu):
|
||||
"""Initialize a param to init, and compute param += y."""
|
||||
with self.cached_session(use_gpu=use_gpu):
|
||||
with test_util.device(use_gpu=use_gpu):
|
||||
p = variables.Variable(x)
|
||||
add = state_ops.assign_add(p, y)
|
||||
self.evaluate(p.initializer)
|
||||
new_value = self.evaluate(add)
|
||||
return self.evaluate(p), new_value
|
||||
|
||||
def _initAssignSubFetch(self, x, y, use_gpu=False):
|
||||
def _initAssignSubFetch(self, x, y, use_gpu):
|
||||
"""Initialize a param to init, and compute param -= y."""
|
||||
with self.cached_session(use_gpu=use_gpu):
|
||||
with test_util.device(use_gpu=use_gpu):
|
||||
p = variables.Variable(x)
|
||||
sub = state_ops.assign_sub(p, y)
|
||||
self.evaluate(p.initializer)
|
||||
@ -78,11 +78,10 @@ class AssignOpTest(test.TestCase):
|
||||
var_value, op_value = self._initAssignAddFetch(x, y, use_gpu=True)
|
||||
self.assertAllEqual(x + y, var_value)
|
||||
self.assertAllEqual(x + y, op_value)
|
||||
var_value, op_value = self._initAssignSubFetch(x, y, use_gpu=False)
|
||||
var_value, op_value = self._initAssignSubFetch(x, y, use_gpu=True)
|
||||
self.assertAllEqual(x - y, var_value)
|
||||
self.assertAllEqual(x - y, op_value)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testBasic(self):
|
||||
self._testTypes(np.arange(0, 20).reshape([4, 5]))
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user