Update v1 only training/gradient_descent_test with graph scope.
PiperOrigin-RevId: 320497973 Change-Id: I222336a46900a79e0c224ebe3f9fd6585cdc4f89
This commit is contained in:
parent
ad35731ffd
commit
9937a3c6c9
@ -24,7 +24,6 @@ from tensorflow.python.eager import function
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import embedding_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import resource_variable_ops
|
||||
@ -36,10 +35,10 @@ from tensorflow.python.training import gradient_descent
|
||||
|
||||
class GradientDescentOptimizerTest(test.TestCase):
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testBasic(self):
|
||||
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
|
||||
with self.cached_session():
|
||||
# train.GradientDescentOptimizer is V1 only API.
|
||||
with ops.Graph().as_default(), self.cached_session():
|
||||
var0 = variables.Variable([1.0, 2.0], dtype=dtype)
|
||||
var1 = variables.Variable([3.0, 4.0], dtype=dtype)
|
||||
grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
|
||||
@ -60,10 +59,10 @@ class GradientDescentOptimizerTest(test.TestCase):
|
||||
self.evaluate(var1))
|
||||
self.assertEqual(0, len(optimizer.variables()))
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testBasicResourceVariable(self):
|
||||
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
|
||||
with self.cached_session():
|
||||
# train.GradientDescentOptimizer is V1 only API.
|
||||
with ops.Graph().as_default(), self.cached_session():
|
||||
var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
|
||||
var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
|
||||
grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
|
||||
@ -86,10 +85,10 @@ class GradientDescentOptimizerTest(test.TestCase):
|
||||
self.assertAllCloseAccordingToType([3.0 - 3.0 * 0.01, 4.0 - 3.0 * 0.01],
|
||||
self.evaluate(var1))
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testBasicCallableParams(self):
|
||||
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
|
||||
with self.cached_session():
|
||||
# train.GradientDescentOptimizer is V1 only API.
|
||||
with ops.Graph().as_default(), self.cached_session():
|
||||
var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
|
||||
var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
|
||||
grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
|
||||
@ -113,10 +112,10 @@ class GradientDescentOptimizerTest(test.TestCase):
|
||||
self.assertAllCloseAccordingToType([3.0 - 3.0 * 0.01, 4.0 - 3.0 * 0.01],
|
||||
self.evaluate(var1))
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testMinimizeResourceVariable(self):
|
||||
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
|
||||
with self.cached_session():
|
||||
# train.GradientDescentOptimizer is V1 only API.
|
||||
with ops.Graph().as_default(), self.cached_session():
|
||||
var0 = resource_variable_ops.ResourceVariable([[1.0, 2.0]], dtype=dtype)
|
||||
var1 = resource_variable_ops.ResourceVariable([3.0], dtype=dtype)
|
||||
x = constant_op.constant([[4.0], [5.0]], dtype=dtype)
|
||||
@ -140,10 +139,10 @@ class GradientDescentOptimizerTest(test.TestCase):
|
||||
[[1.0 - np_grad * 4.0, 2.0 - np_grad * 5.0]], self.evaluate(var0))
|
||||
self.assertAllCloseAccordingToType([3.0 - np_grad], self.evaluate(var1))
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testMinimizeSparseResourceVariable(self):
|
||||
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
|
||||
with self.cached_session():
|
||||
# train.GradientDescentOptimizer is V1 only API.
|
||||
with ops.Graph().as_default(), self.cached_session():
|
||||
var0 = resource_variable_ops.ResourceVariable([[1.0, 2.0]], dtype=dtype)
|
||||
var1 = resource_variable_ops.ResourceVariable([3.0], dtype=dtype)
|
||||
x = constant_op.constant([[4.0], [5.0]], dtype=dtype)
|
||||
@ -168,10 +167,10 @@ class GradientDescentOptimizerTest(test.TestCase):
|
||||
[[1.0 - np_grad * 4.0, 2.0 - np_grad * 5.0]], self.evaluate(var0))
|
||||
self.assertAllCloseAccordingToType([3.0 - np_grad], self.evaluate(var1))
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testTensorLearningRate(self):
|
||||
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
|
||||
with self.cached_session():
|
||||
# train.GradientDescentOptimizer is V1 only API.
|
||||
with ops.Graph().as_default(), self.cached_session():
|
||||
var0 = variables.Variable([1.0, 2.0], dtype=dtype)
|
||||
var1 = variables.Variable([3.0, 4.0], dtype=dtype)
|
||||
grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
|
||||
@ -191,10 +190,10 @@ class GradientDescentOptimizerTest(test.TestCase):
|
||||
self.assertAllCloseAccordingToType([3.0 - 3.0 * 0.01, 4.0 - 3.0 * 0.01],
|
||||
self.evaluate(var1))
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testGradWrtRef(self):
|
||||
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
|
||||
with self.cached_session():
|
||||
# train.GradientDescentOptimizer is V1 only API.
|
||||
with ops.Graph().as_default(), self.cached_session():
|
||||
opt = gradient_descent.GradientDescentOptimizer(3.0)
|
||||
values = [1.0, 3.0]
|
||||
vars_ = [variables.Variable([v], dtype=dtype) for v in values]
|
||||
@ -203,10 +202,10 @@ class GradientDescentOptimizerTest(test.TestCase):
|
||||
for grad, _ in grads_and_vars:
|
||||
self.assertAllCloseAccordingToType([1.0], self.evaluate(grad))
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testWithGlobalStep(self):
|
||||
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
|
||||
with self.cached_session():
|
||||
# train.GradientDescentOptimizer is V1 only API.
|
||||
with ops.Graph().as_default(), self.cached_session():
|
||||
global_step = variables.Variable(0, trainable=False)
|
||||
var0 = variables.Variable([1.0, 2.0], dtype=dtype)
|
||||
var1 = variables.Variable([3.0, 4.0], dtype=dtype)
|
||||
@ -227,10 +226,10 @@ class GradientDescentOptimizerTest(test.TestCase):
|
||||
self.evaluate(var1))
|
||||
self.assertAllCloseAccordingToType(1, self.evaluate(global_step))
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testSparseBasic(self):
|
||||
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
|
||||
with self.cached_session():
|
||||
# train.GradientDescentOptimizer is V1 only API.
|
||||
with ops.Graph().as_default(), self.cached_session():
|
||||
var0 = variables.Variable([[1.0], [2.0]], dtype=dtype)
|
||||
var1 = variables.Variable([[3.0], [4.0]], dtype=dtype)
|
||||
grads0 = ops.IndexedSlices(
|
||||
|
Loading…
x
Reference in New Issue
Block a user