Remove run_deprecated_v1 in proximal_gradient_descent_test.
All test cases have been updated to run in a graph context, since the API expects to run in a v1 graph context.

PiperOrigin-RevId: 321296244
Change-Id: I045584d2003febc0dd32b94abccc7382f07eb3d8
parent eb384a3ff6
commit 41339588d9
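
For context, a minimal sketch of the pattern this change applies (illustrative only; the class name, test name, and values below are hypothetical, not part of the patch): instead of relying on the @test_util.run_deprecated_v1 decorator to force v1 behavior, each test builds its graph explicitly with ops.Graph().as_default() combined with self.cached_session(), and self.evaluate() runs ops against the cached session of that graph, so no explicit session handle (ops.get_default_session()) is needed.

# Illustrative sketch only -- names and values here are hypothetical.
from tensorflow.python.framework import ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test


class GraphContextPatternTest(test.TestCase):

  def testRunsInV1GraphContext(self):
    # Build the graph explicitly instead of relying on
    # @test_util.run_deprecated_v1 to force v1 graph mode.
    with ops.Graph().as_default(), self.cached_session():
      var0 = variables.Variable([1.0, 2.0])
      self.evaluate(variables.global_variables_initializer())
      # self.evaluate() uses the cached session for the current graph,
      # so no explicit ops.get_default_session() is required.
      self.assertAllClose([1.0, 2.0], self.evaluate(var0))


if __name__ == "__main__":
  test.main()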
@@ -23,7 +23,6 @@ import numpy as np
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
-from tensorflow.python.framework import test_util
 from tensorflow.python.ops import embedding_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import resource_variable_ops
@@ -37,7 +36,7 @@ class ProximalGradientDescentOptimizerTest(test.TestCase):
 
   def doTestProximalGradientDescentwithoutRegularization(
       self, use_resource=False):
-    with self.cached_session() as sess:
+    with ops.Graph().as_default(), self.cached_session():
       if use_resource:
         var0 = resource_variable_ops.ResourceVariable([0.0, 0.0])
         var1 = resource_variable_ops.ResourceVariable([0.0, 0.0])
@@ -63,17 +62,14 @@ class ProximalGradientDescentOptimizerTest(test.TestCase):
       self.assertAllClose(np.array([-0.9, -1.8]), v0_val)
       self.assertAllClose(np.array([-0.09, -0.18]), v1_val)
 
-  @test_util.run_deprecated_v1
   def testProximalGradientDescentwithoutRegularization(self):
     self.doTestProximalGradientDescentwithoutRegularization(use_resource=False)
 
-  @test_util.run_deprecated_v1
   def testResourceProximalGradientDescentwithoutRegularization(self):
     self.doTestProximalGradientDescentwithoutRegularization(use_resource=True)
 
-  @test_util.run_deprecated_v1
   def testProximalGradientDescentwithoutRegularization2(self):
-    with self.cached_session() as sess:
+    with ops.Graph().as_default(), self.cached_session():
       var0 = variables.Variable([1.0, 2.0])
       var1 = variables.Variable([4.0, 3.0])
       grads0 = constant_op.constant([0.1, 0.2])
@@ -96,10 +92,9 @@ class ProximalGradientDescentOptimizerTest(test.TestCase):
       self.assertAllClose(np.array([0.1, 0.2]), v0_val)
       self.assertAllClose(np.array([3.91, 2.82]), v1_val)
 
-  @test_util.run_deprecated_v1
   def testMinimizeSparseResourceVariable(self):
     for dtype in [dtypes.float32, dtypes.float64]:
-      with self.cached_session():
+      with ops.Graph().as_default(), self.cached_session():
         var0 = resource_variable_ops.ResourceVariable([[1.0, 2.0]], dtype=dtype)
         x = constant_op.constant([[4.0], [5.0]], dtype=dtype)
         pred = math_ops.matmul(embedding_ops.embedding_lookup([var0], [0]), x)
@@ -116,9 +111,8 @@ class ProximalGradientDescentOptimizerTest(test.TestCase):
             self.evaluate(var0),
             atol=0.01)
 
-  @test_util.run_deprecated_v1
   def testProximalGradientDescentWithL1_L2(self):
-    with self.cached_session() as sess:
+    with ops.Graph().as_default(), self.cached_session():
       var0 = variables.Variable([1.0, 2.0])
       var1 = variables.Variable([4.0, 3.0])
       grads0 = constant_op.constant([0.1, 0.2])
@@ -164,7 +158,6 @@ class ProximalGradientDescentOptimizerTest(test.TestCase):
     update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
     self.evaluate(variables.global_variables_initializer())
 
-    sess = ops.get_default_session()
     v0_val, v1_val = self.evaluate([var0, var1])
     if is_sparse:
       self.assertAllClose([[1.0], [2.0]], v0_val)
@@ -180,9 +173,8 @@ class ProximalGradientDescentOptimizerTest(test.TestCase):
     v0_val, v1_val = self.evaluate([var0, var1])
     return v0_val, v1_val
 
-  @test_util.run_deprecated_v1
   def testEquivSparseGradientDescentwithoutRegularization(self):
-    with self.cached_session():
+    with ops.Graph().as_default(), self.cached_session():
       val0, val1 = self.applyOptimizer(
           proximal_gradient_descent.ProximalGradientDescentOptimizer(
               3.0,
@@ -190,23 +182,20 @@ class ProximalGradientDescentOptimizerTest(test.TestCase):
               l2_regularization_strength=0.0),
           is_sparse=True)
 
-    with self.cached_session():
       val2, val3 = self.applyOptimizer(
           gradient_descent.GradientDescentOptimizer(3.0), is_sparse=True)
 
     self.assertAllClose(val0, val2)
     self.assertAllClose(val1, val3)
 
-  @test_util.run_deprecated_v1
   def testEquivGradientDescentwithoutRegularization(self):
-    with self.cached_session():
+    with ops.Graph().as_default(), self.cached_session():
       val0, val1 = self.applyOptimizer(
           proximal_gradient_descent.ProximalGradientDescentOptimizer(
               3.0,
               l1_regularization_strength=0.0,
               l2_regularization_strength=0.0))
 
-    with self.cached_session():
       val2, val3 = self.applyOptimizer(
           gradient_descent.GradientDescentOptimizer(3.0))
 