fix skip test for adagrad_da_test
PiperOrigin-RevId: 321913659 Change-Id: I82c492522955911b9009bab006ea73cbffa1e404
This commit is contained in:
parent
bed7093667
commit
829256e314
@ -22,6 +22,7 @@ import numpy as np
|
|||||||
|
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.framework import test_util
|
from tensorflow.python.framework import test_util
|
||||||
from tensorflow.python.ops import embedding_ops
|
from tensorflow.python.ops import embedding_ops
|
||||||
from tensorflow.python.ops import math_ops
|
from tensorflow.python.ops import math_ops
|
||||||
@ -35,7 +36,7 @@ class AdagradDAOptimizerTest(test.TestCase):
|
|||||||
|
|
||||||
def doTestAdagradDAwithoutRegularizationBasic1(self, use_resource=False):
|
def doTestAdagradDAwithoutRegularizationBasic1(self, use_resource=False):
|
||||||
for dtype in [dtypes.float64, dtypes.float32]:
|
for dtype in [dtypes.float64, dtypes.float32]:
|
||||||
with self.cached_session() as sess:
|
with ops.Graph().as_default(), self.cached_session():
|
||||||
global_step = variables.Variable(0, dtype=dtypes.int64)
|
global_step = variables.Variable(0, dtype=dtypes.int64)
|
||||||
if use_resource:
|
if use_resource:
|
||||||
var0 = resource_variable_ops.ResourceVariable([0.0, 0.0], dtype=dtype)
|
var0 = resource_variable_ops.ResourceVariable([0.0, 0.0], dtype=dtype)
|
||||||
@ -74,15 +75,13 @@ class AdagradDAOptimizerTest(test.TestCase):
|
|||||||
self.assertAllCloseAccordingToType(
|
self.assertAllCloseAccordingToType(
|
||||||
np.array([-0.094821, -0.189358]), v1_val)
|
np.array([-0.094821, -0.189358]), v1_val)
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
|
||||||
def testAdagradDAWithoutRegularizationBasic1(self):
|
def testAdagradDAWithoutRegularizationBasic1(self):
|
||||||
self.doTestAdagradDAwithoutRegularizationBasic1()
|
self.doTestAdagradDAwithoutRegularizationBasic1()
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
|
||||||
def testResourceAdagradDAWithoutRegularizationBasic1(self):
|
def testResourceAdagradDAWithoutRegularizationBasic1(self):
|
||||||
self.doTestAdagradDAwithoutRegularizationBasic1(use_resource=True)
|
self.doTestAdagradDAwithoutRegularizationBasic1(use_resource=True)
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
@test_util.run_v1_only("loss needs to be callable in v2")
|
||||||
def testMinimizeSparseResourceVariable(self):
|
def testMinimizeSparseResourceVariable(self):
|
||||||
for dtype in [dtypes.float32, dtypes.float64]:
|
for dtype in [dtypes.float32, dtypes.float64]:
|
||||||
with self.cached_session():
|
with self.cached_session():
|
||||||
@ -104,10 +103,9 @@ class AdagradDAOptimizerTest(test.TestCase):
|
|||||||
self.evaluate(var0),
|
self.evaluate(var0),
|
||||||
rtol=0.01)
|
rtol=0.01)
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
|
||||||
def testAdagradDAwithoutRegularizationBasic2(self):
|
def testAdagradDAwithoutRegularizationBasic2(self):
|
||||||
for dtype in [dtypes.float64, dtypes.float32]:
|
for dtype in [dtypes.float64, dtypes.float32]:
|
||||||
with self.cached_session() as sess:
|
with ops.Graph().as_default(), self.cached_session():
|
||||||
global_step = variables.Variable(0, dtype=dtypes.int64)
|
global_step = variables.Variable(0, dtype=dtypes.int64)
|
||||||
var0 = variables.Variable([1.0, 2.0], dtype=dtype)
|
var0 = variables.Variable([1.0, 2.0], dtype=dtype)
|
||||||
var1 = variables.Variable([4.0, 3.0], dtype=dtype)
|
var1 = variables.Variable([4.0, 3.0], dtype=dtype)
|
||||||
@ -137,10 +135,9 @@ class AdagradDAOptimizerTest(test.TestCase):
|
|||||||
self.assertAllCloseAccordingToType(
|
self.assertAllCloseAccordingToType(
|
||||||
np.array([-0.094821, -0.189358]), v1_val)
|
np.array([-0.094821, -0.189358]), v1_val)
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
|
||||||
def testAdagradDAWithL1(self):
|
def testAdagradDAWithL1(self):
|
||||||
for dtype in [dtypes.float64, dtypes.float32]:
|
for dtype in [dtypes.float64, dtypes.float32]:
|
||||||
with self.cached_session() as sess:
|
with ops.Graph().as_default(), self.cached_session():
|
||||||
global_step = variables.Variable(0, dtype=dtypes.int64)
|
global_step = variables.Variable(0, dtype=dtypes.int64)
|
||||||
var0 = variables.Variable([1.0, 2.0], dtype=dtype)
|
var0 = variables.Variable([1.0, 2.0], dtype=dtype)
|
||||||
var1 = variables.Variable([4.0, 3.0], dtype=dtype)
|
var1 = variables.Variable([4.0, 3.0], dtype=dtype)
|
||||||
@ -170,10 +167,9 @@ class AdagradDAOptimizerTest(test.TestCase):
|
|||||||
self.assertAllCloseAccordingToType(
|
self.assertAllCloseAccordingToType(
|
||||||
np.array([-0.085339, -0.17989]), v1_val)
|
np.array([-0.085339, -0.17989]), v1_val)
|
||||||
|
|
||||||
@test_util.run_deprecated_v1
|
|
||||||
def testAdagradDAWithL1_L2(self):
|
def testAdagradDAWithL1_L2(self):
|
||||||
for dtype in [dtypes.float64, dtypes.float32]:
|
for dtype in [dtypes.float64, dtypes.float32]:
|
||||||
with self.cached_session() as sess:
|
with ops.Graph().as_default(), self.cached_session():
|
||||||
global_step = variables.Variable(0, dtype=dtypes.int64)
|
global_step = variables.Variable(0, dtype=dtypes.int64)
|
||||||
var0 = variables.Variable([1.0, 2.0], dtype=dtype)
|
var0 = variables.Variable([1.0, 2.0], dtype=dtype)
|
||||||
var1 = variables.Variable([4.0, 3.0], dtype=dtype)
|
var1 = variables.Variable([4.0, 3.0], dtype=dtype)
|
||||||
|
Loading…
Reference in New Issue
Block a user