Update session_manager_test wrt run_v1_only annotation.
PiperOrigin-RevId: 321304049 Change-Id: I0dd2b297bac8b0257a5a1dcac7ff1124714ae531
This commit is contained in:
parent
000ee10461
commit
f6f6b66d50
@ -28,6 +28,7 @@ from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import variable_scope
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import gfile
|
||||
from tensorflow.python.platform import test
|
||||
@ -39,6 +40,11 @@ from tensorflow.python.training import session_manager
|
||||
|
||||
class SessionManagerTest(test.TestCase):
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
super(SessionManagerTest, cls).setUpClass()
|
||||
variable_scope.disable_resource_variables()
|
||||
|
||||
def testPrepareSessionSucceeds(self):
|
||||
with ops.Graph().as_default():
|
||||
v = variables.VariableV1([1.0, 2.0, 3.0], name="v")
|
||||
@ -81,7 +87,6 @@ class SessionManagerTest(test.TestCase):
|
||||
sess = sm.prepare_session("")
|
||||
self.assertAllClose([1.0, 2.0, 3.0], sess.run(v))
|
||||
|
||||
@test_util.run_v1_only("b/120545219")
|
||||
def testPrepareSessionFails(self):
|
||||
checkpoint_dir = os.path.join(self.get_temp_dir(), "prepare_session")
|
||||
checkpoint_dir2 = os.path.join(self.get_temp_dir(), "prepare_session2")
|
||||
@ -166,7 +171,6 @@ class SessionManagerTest(test.TestCase):
|
||||
sess.graph.get_tensor_by_name("v:0")).eval(session=sess))
|
||||
self.assertEqual(1, sess.run(v))
|
||||
|
||||
@test_util.run_v1_only("b/120545219")
|
||||
def testRecoverSession(self):
|
||||
# Create a checkpoint.
|
||||
checkpoint_dir = os.path.join(self.get_temp_dir(), "recover_session")
|
||||
@ -199,7 +203,6 @@ class SessionManagerTest(test.TestCase):
|
||||
checkpoint_filename_with_path=checkpoint_management.latest_checkpoint(
|
||||
checkpoint_dir))
|
||||
|
||||
@test_util.run_v1_only("b/120545219")
|
||||
def testWaitForSessionReturnsNoneAfterTimeout(self):
|
||||
with ops.Graph().as_default():
|
||||
variables.VariableV1(1, name="v")
|
||||
@ -222,7 +225,6 @@ class SessionManagerTest(test.TestCase):
|
||||
variables.global_variables()),
|
||||
local_init_op=None)
|
||||
|
||||
@test_util.run_v1_only("b/120545219")
|
||||
def testRecoverSessionWithReadyForLocalInitOp(self):
|
||||
# Create a checkpoint.
|
||||
checkpoint_dir = os.path.join(self.get_temp_dir(),
|
||||
@ -276,7 +278,6 @@ class SessionManagerTest(test.TestCase):
|
||||
self.assertEqual(1, sess.run(v))
|
||||
self.assertEqual(1, sess.run(w))
|
||||
|
||||
@test_util.run_v1_only("b/120545219")
|
||||
def testRecoverSessionWithReadyForLocalInitOpFailsToReadyLocal(self):
|
||||
# We use ready_for_local_init_op=report_uninitialized_variables(),
|
||||
# which causes recover_session to not run local_init_op, and to return
|
||||
@ -333,7 +334,6 @@ class SessionManagerTest(test.TestCase):
|
||||
sess.graph.get_tensor_by_name("w:0")).eval(session=sess))
|
||||
self.assertEqual(1, sess.run(v))
|
||||
|
||||
@test_util.run_v1_only("b/120545219")
|
||||
def testRecoverSessionNoChkptStillRunsLocalInitOp(self):
|
||||
# This test checks for backwards compatibility.
|
||||
# In particular, we continue to ensure that recover_session will execute
|
||||
@ -362,7 +362,6 @@ class SessionManagerTest(test.TestCase):
|
||||
sess.graph.get_tensor_by_name("w:0")).eval(session=sess))
|
||||
self.assertEqual(1, sess.run(w))
|
||||
|
||||
@test_util.run_v1_only("b/120545219")
|
||||
def testRecoverSessionFailsStillRunsLocalInitOp(self):
|
||||
# Create a checkpoint.
|
||||
checkpoint_dir = os.path.join(
|
||||
@ -406,7 +405,6 @@ class SessionManagerTest(test.TestCase):
|
||||
sess.graph.get_tensor_by_name("w:0")).eval(session=sess))
|
||||
self.assertEqual(1, sess.run(w))
|
||||
|
||||
@test_util.run_v1_only("b/120545219")
|
||||
def testWaitForSessionLocalInit(self):
|
||||
server = server_lib.Server.create_local_server()
|
||||
with ops.Graph().as_default() as graph:
|
||||
@ -458,7 +456,7 @@ class SessionManagerTest(test.TestCase):
|
||||
# because of overly restrictive ready_for_local_init_op
|
||||
sm.wait_for_session("", max_wait_secs=3)
|
||||
|
||||
@test_util.run_v1_only("b/120545219")
|
||||
@test_util.run_v1_only("Requires TF V1 variable behavior.")
|
||||
def testWaitForSessionInsufficientReadyForLocalInitCheck(self):
|
||||
with ops.Graph().as_default() as graph:
|
||||
v = variables.VariableV1(1, name="v")
|
||||
@ -476,7 +474,6 @@ class SessionManagerTest(test.TestCase):
|
||||
"Session was not ready after waiting.*"):
|
||||
sm.wait_for_session("", max_wait_secs=3)
|
||||
|
||||
@test_util.run_v1_only("b/120545219")
|
||||
def testPrepareSessionWithReadyForLocalInitOp(self):
|
||||
with ops.Graph().as_default():
|
||||
v = variables.VariableV1(1, name="v")
|
||||
@ -516,7 +513,7 @@ class SessionManagerTest(test.TestCase):
|
||||
self.assertEqual(1, sess.run(w))
|
||||
self.assertEqual(3, sess.run(x))
|
||||
|
||||
@test_util.run_v1_only("b/120545219")
|
||||
@test_util.run_v1_only("Requires TF V1 variable behavior.")
|
||||
def testPrepareSessionWithPartialInitOp(self):
|
||||
with ops.Graph().as_default():
|
||||
v = variables.VariableV1(1, name="v")
|
||||
@ -583,7 +580,6 @@ class SessionManagerTest(test.TestCase):
|
||||
self.assertEqual(1, sess.run(w_res))
|
||||
self.assertEqual(3, sess.run(x_res))
|
||||
|
||||
@test_util.run_v1_only("b/120545219")
|
||||
def testPrepareSessionWithCyclicInitializer(self):
|
||||
# Regression test. Previously Variable._build_initializer_expr would enter
|
||||
# into an infinite recursion when the variable's initial_value involved
|
||||
@ -657,7 +653,7 @@ class SessionManagerTest(test.TestCase):
|
||||
"Init operations did not make model ready for local_init"):
|
||||
sm2.prepare_session("", init_op=None)
|
||||
|
||||
@test_util.run_v1_only("b/120545219")
|
||||
@test_util.run_v1_only("Requires TF V1 variable behavior.")
|
||||
def testPrepareSessionWithInsufficientReadyForLocalInitCheck(self):
|
||||
with ops.Graph().as_default():
|
||||
v = variables.VariableV1(1, name="v")
|
||||
@ -680,6 +676,11 @@ class SessionManagerTest(test.TestCase):
|
||||
|
||||
class ObsoleteSessionManagerTest(test.TestCase):
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
super(ObsoleteSessionManagerTest, cls).setUpClass()
|
||||
variable_scope.disable_resource_variables()
|
||||
|
||||
def testPrepareSessionSucceeds(self):
|
||||
with ops.Graph().as_default():
|
||||
v = variables.VariableV1([1.0, 2.0, 3.0], name="v")
|
||||
@ -710,7 +711,6 @@ class ObsoleteSessionManagerTest(test.TestCase):
|
||||
"", init_fn=lambda sess: sess.run(v.initializer))
|
||||
self.assertAllClose([125], sess.run(v))
|
||||
|
||||
@test_util.run_v1_only("b/120545219")
|
||||
def testPrepareSessionFails(self):
|
||||
checkpoint_dir = os.path.join(self.get_temp_dir(), "prepare_session")
|
||||
checkpoint_dir2 = os.path.join(self.get_temp_dir(), "prepare_session2")
|
||||
@ -772,7 +772,6 @@ class ObsoleteSessionManagerTest(test.TestCase):
|
||||
variables.is_variable_initialized(
|
||||
sess.graph.get_tensor_by_name("v:0")).eval(session=sess))
|
||||
|
||||
@test_util.run_v1_only("b/120545219")
|
||||
def testRecoverSession(self):
|
||||
# Create a checkpoint.
|
||||
checkpoint_dir = os.path.join(self.get_temp_dir(), "recover_session")
|
||||
@ -811,7 +810,6 @@ class ObsoleteSessionManagerTest(test.TestCase):
|
||||
sess.graph.get_tensor_by_name("v:0")).eval(session=sess))
|
||||
self.assertEqual(1, sess.run(v))
|
||||
|
||||
@test_util.run_v1_only("b/120545219")
|
||||
def testWaitForSessionReturnsNoneAfterTimeout(self):
|
||||
with ops.Graph().as_default():
|
||||
variables.VariableV1(1, name="v")
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user