fix skip test for adagrad_da_test
PiperOrigin-RevId: 321913659 Change-Id: I82c492522955911b9009bab006ea73cbffa1e404
This commit is contained in:
parent
bed7093667
commit
829256e314
@ -22,6 +22,7 @@ import numpy as np
|
||||
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import embedding_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
@ -35,7 +36,7 @@ class AdagradDAOptimizerTest(test.TestCase):
|
||||
|
||||
def doTestAdagradDAwithoutRegularizationBasic1(self, use_resource=False):
|
||||
for dtype in [dtypes.float64, dtypes.float32]:
|
||||
with self.cached_session() as sess:
|
||||
with ops.Graph().as_default(), self.cached_session():
|
||||
global_step = variables.Variable(0, dtype=dtypes.int64)
|
||||
if use_resource:
|
||||
var0 = resource_variable_ops.ResourceVariable([0.0, 0.0], dtype=dtype)
|
||||
@ -74,15 +75,13 @@ class AdagradDAOptimizerTest(test.TestCase):
|
||||
self.assertAllCloseAccordingToType(
|
||||
np.array([-0.094821, -0.189358]), v1_val)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testAdagradDAWithoutRegularizationBasic1(self):
|
||||
self.doTestAdagradDAwithoutRegularizationBasic1()
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testResourceAdagradDAWithoutRegularizationBasic1(self):
|
||||
self.doTestAdagradDAwithoutRegularizationBasic1(use_resource=True)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
@test_util.run_v1_only("loss needs to be callable in v2")
|
||||
def testMinimizeSparseResourceVariable(self):
|
||||
for dtype in [dtypes.float32, dtypes.float64]:
|
||||
with self.cached_session():
|
||||
@ -104,10 +103,9 @@ class AdagradDAOptimizerTest(test.TestCase):
|
||||
self.evaluate(var0),
|
||||
rtol=0.01)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testAdagradDAwithoutRegularizationBasic2(self):
|
||||
for dtype in [dtypes.float64, dtypes.float32]:
|
||||
with self.cached_session() as sess:
|
||||
with ops.Graph().as_default(), self.cached_session():
|
||||
global_step = variables.Variable(0, dtype=dtypes.int64)
|
||||
var0 = variables.Variable([1.0, 2.0], dtype=dtype)
|
||||
var1 = variables.Variable([4.0, 3.0], dtype=dtype)
|
||||
@ -137,10 +135,9 @@ class AdagradDAOptimizerTest(test.TestCase):
|
||||
self.assertAllCloseAccordingToType(
|
||||
np.array([-0.094821, -0.189358]), v1_val)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testAdagradDAWithL1(self):
|
||||
for dtype in [dtypes.float64, dtypes.float32]:
|
||||
with self.cached_session() as sess:
|
||||
with ops.Graph().as_default(), self.cached_session():
|
||||
global_step = variables.Variable(0, dtype=dtypes.int64)
|
||||
var0 = variables.Variable([1.0, 2.0], dtype=dtype)
|
||||
var1 = variables.Variable([4.0, 3.0], dtype=dtype)
|
||||
@ -170,10 +167,9 @@ class AdagradDAOptimizerTest(test.TestCase):
|
||||
self.assertAllCloseAccordingToType(
|
||||
np.array([-0.085339, -0.17989]), v1_val)
|
||||
|
||||
@test_util.run_deprecated_v1
|
||||
def testAdagradDAWithL1_L2(self):
|
||||
for dtype in [dtypes.float64, dtypes.float32]:
|
||||
with self.cached_session() as sess:
|
||||
with ops.Graph().as_default(), self.cached_session():
|
||||
global_step = variables.Variable(0, dtype=dtypes.int64)
|
||||
var0 = variables.Variable([1.0, 2.0], dtype=dtype)
|
||||
var1 = variables.Variable([4.0, 3.0], dtype=dtype)
|
||||
|
Loading…
Reference in New Issue
Block a user