From 1b2ac91b25aefeb445d1baa20ebcd9ba5c42f0e2 Mon Sep 17 00:00:00 2001 From: Kibeom Kim Date: Fri, 31 Jul 2020 16:55:05 -0700 Subject: [PATCH] Remove @test_util.deprecated_graph_mode_only in dense_update_ops_test.py PiperOrigin-RevId: 324311282 Change-Id: I7f8c4fc89e4982dc885a8a07c155010d3f6097fd --- .../python/kernel_tests/dense_update_ops_test.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/tensorflow/python/kernel_tests/dense_update_ops_test.py b/tensorflow/python/kernel_tests/dense_update_ops_test.py index 2d7eac10a12..b73f04b25d0 100644 --- a/tensorflow/python/kernel_tests/dense_update_ops_test.py +++ b/tensorflow/python/kernel_tests/dense_update_ops_test.py @@ -30,28 +30,28 @@ from tensorflow.python.platform import test class AssignOpTest(test.TestCase): - def _initAssignFetch(self, x, y, use_gpu=False): + def _initAssignFetch(self, x, y, use_gpu): """Initialize a param to init and update it with y.""" super(AssignOpTest, self).setUp() - with self.cached_session(use_gpu=use_gpu): + with test_util.device(use_gpu=use_gpu): p = variables.Variable(x) assign = state_ops.assign(p, y) self.evaluate(p.initializer) new_value = self.evaluate(assign) return self.evaluate(p), new_value - def _initAssignAddFetch(self, x, y, use_gpu=False): + def _initAssignAddFetch(self, x, y, use_gpu): """Initialize a param to init, and compute param += y.""" - with self.cached_session(use_gpu=use_gpu): + with test_util.device(use_gpu=use_gpu): p = variables.Variable(x) add = state_ops.assign_add(p, y) self.evaluate(p.initializer) new_value = self.evaluate(add) return self.evaluate(p), new_value - def _initAssignSubFetch(self, x, y, use_gpu=False): + def _initAssignSubFetch(self, x, y, use_gpu): """Initialize a param to init, and compute param -= y.""" - with self.cached_session(use_gpu=use_gpu): + with test_util.device(use_gpu=use_gpu): p = variables.Variable(x) sub = state_ops.assign_sub(p, y) self.evaluate(p.initializer) @@ -78,11 +78,10 @@ class AssignOpTest(test.TestCase): var_value, op_value = self._initAssignAddFetch(x, y, use_gpu=True) self.assertAllEqual(x + y, var_value) self.assertAllEqual(x + y, op_value) - var_value, op_value = self._initAssignSubFetch(x, y, use_gpu=False) + var_value, op_value = self._initAssignSubFetch(x, y, use_gpu=True) self.assertAllEqual(x - y, var_value) self.assertAllEqual(x - y, op_value) - @test_util.run_deprecated_v1 def testBasic(self): self._testTypes(np.arange(0, 20).reshape([4, 5]))