Remove run_deprecated_v1 in proximal_gradient_descent_test.

All the test case has been updated to run with graph context, since the API expect to run in v1 graph context.

PiperOrigin-RevId: 321296244
Change-Id: I045584d2003febc0dd32b94abccc7382f07eb3d8
This commit is contained in:
Scott Zhu 2020-07-14 21:47:49 -07:00 committed by TensorFlower Gardener
parent eb384a3ff6
commit 41339588d9

View File

@ -23,7 +23,6 @@ import numpy as np
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
@ -37,7 +36,7 @@ class ProximalGradientDescentOptimizerTest(test.TestCase):
def doTestProximalGradientDescentwithoutRegularization(
self, use_resource=False):
with self.cached_session() as sess:
with ops.Graph().as_default(), self.cached_session():
if use_resource:
var0 = resource_variable_ops.ResourceVariable([0.0, 0.0])
var1 = resource_variable_ops.ResourceVariable([0.0, 0.0])
@ -63,17 +62,14 @@ class ProximalGradientDescentOptimizerTest(test.TestCase):
self.assertAllClose(np.array([-0.9, -1.8]), v0_val)
self.assertAllClose(np.array([-0.09, -0.18]), v1_val)
@test_util.run_deprecated_v1
def testProximalGradientDescentwithoutRegularization(self):
self.doTestProximalGradientDescentwithoutRegularization(use_resource=False)
@test_util.run_deprecated_v1
def testResourceProximalGradientDescentwithoutRegularization(self):
self.doTestProximalGradientDescentwithoutRegularization(use_resource=True)
@test_util.run_deprecated_v1
def testProximalGradientDescentwithoutRegularization2(self):
with self.cached_session() as sess:
with ops.Graph().as_default(), self.cached_session():
var0 = variables.Variable([1.0, 2.0])
var1 = variables.Variable([4.0, 3.0])
grads0 = constant_op.constant([0.1, 0.2])
@ -96,10 +92,9 @@ class ProximalGradientDescentOptimizerTest(test.TestCase):
self.assertAllClose(np.array([0.1, 0.2]), v0_val)
self.assertAllClose(np.array([3.91, 2.82]), v1_val)
@test_util.run_deprecated_v1
def testMinimizeSparseResourceVariable(self):
for dtype in [dtypes.float32, dtypes.float64]:
with self.cached_session():
with ops.Graph().as_default(), self.cached_session():
var0 = resource_variable_ops.ResourceVariable([[1.0, 2.0]], dtype=dtype)
x = constant_op.constant([[4.0], [5.0]], dtype=dtype)
pred = math_ops.matmul(embedding_ops.embedding_lookup([var0], [0]), x)
@ -116,9 +111,8 @@ class ProximalGradientDescentOptimizerTest(test.TestCase):
self.evaluate(var0),
atol=0.01)
@test_util.run_deprecated_v1
def testProximalGradientDescentWithL1_L2(self):
with self.cached_session() as sess:
with ops.Graph().as_default(), self.cached_session():
var0 = variables.Variable([1.0, 2.0])
var1 = variables.Variable([4.0, 3.0])
grads0 = constant_op.constant([0.1, 0.2])
@ -164,7 +158,6 @@ class ProximalGradientDescentOptimizerTest(test.TestCase):
update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
self.evaluate(variables.global_variables_initializer())
sess = ops.get_default_session()
v0_val, v1_val = self.evaluate([var0, var1])
if is_sparse:
self.assertAllClose([[1.0], [2.0]], v0_val)
@ -180,9 +173,8 @@ class ProximalGradientDescentOptimizerTest(test.TestCase):
v0_val, v1_val = self.evaluate([var0, var1])
return v0_val, v1_val
@test_util.run_deprecated_v1
def testEquivSparseGradientDescentwithoutRegularization(self):
with self.cached_session():
with ops.Graph().as_default(), self.cached_session():
val0, val1 = self.applyOptimizer(
proximal_gradient_descent.ProximalGradientDescentOptimizer(
3.0,
@ -190,23 +182,20 @@ class ProximalGradientDescentOptimizerTest(test.TestCase):
l2_regularization_strength=0.0),
is_sparse=True)
with self.cached_session():
val2, val3 = self.applyOptimizer(
gradient_descent.GradientDescentOptimizer(3.0), is_sparse=True)
self.assertAllClose(val0, val2)
self.assertAllClose(val1, val3)
@test_util.run_deprecated_v1
def testEquivGradientDescentwithoutRegularization(self):
with self.cached_session():
with ops.Graph().as_default(), self.cached_session():
val0, val1 = self.applyOptimizer(
proximal_gradient_descent.ProximalGradientDescentOptimizer(
3.0,
l1_regularization_strength=0.0,
l2_regularization_strength=0.0))
with self.cached_session():
val2, val3 = self.applyOptimizer(
gradient_descent.GradientDescentOptimizer(3.0))