Update session_manager_test wrt run_v1_only annotation.

PiperOrigin-RevId: 321304049
Change-Id: I0dd2b297bac8b0257a5a1dcac7ff1124714ae531
This commit is contained in:
Scott Zhu 2020-07-14 23:10:32 -07:00 committed by TensorFlower Gardener
parent 000ee10461
commit f6f6b66d50

View File

@ -28,6 +28,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
@ -39,6 +40,11 @@ from tensorflow.python.training import session_manager
class SessionManagerTest(test.TestCase):
@classmethod
def setUpClass(cls):
super(SessionManagerTest, cls).setUpClass()
variable_scope.disable_resource_variables()
def testPrepareSessionSucceeds(self):
with ops.Graph().as_default():
v = variables.VariableV1([1.0, 2.0, 3.0], name="v")
@ -81,7 +87,6 @@ class SessionManagerTest(test.TestCase):
sess = sm.prepare_session("")
self.assertAllClose([1.0, 2.0, 3.0], sess.run(v))
@test_util.run_v1_only("b/120545219")
def testPrepareSessionFails(self):
checkpoint_dir = os.path.join(self.get_temp_dir(), "prepare_session")
checkpoint_dir2 = os.path.join(self.get_temp_dir(), "prepare_session2")
@ -166,7 +171,6 @@ class SessionManagerTest(test.TestCase):
sess.graph.get_tensor_by_name("v:0")).eval(session=sess))
self.assertEqual(1, sess.run(v))
@test_util.run_v1_only("b/120545219")
def testRecoverSession(self):
# Create a checkpoint.
checkpoint_dir = os.path.join(self.get_temp_dir(), "recover_session")
@ -199,7 +203,6 @@ class SessionManagerTest(test.TestCase):
checkpoint_filename_with_path=checkpoint_management.latest_checkpoint(
checkpoint_dir))
@test_util.run_v1_only("b/120545219")
def testWaitForSessionReturnsNoneAfterTimeout(self):
with ops.Graph().as_default():
variables.VariableV1(1, name="v")
@ -222,7 +225,6 @@ class SessionManagerTest(test.TestCase):
variables.global_variables()),
local_init_op=None)
@test_util.run_v1_only("b/120545219")
def testRecoverSessionWithReadyForLocalInitOp(self):
# Create a checkpoint.
checkpoint_dir = os.path.join(self.get_temp_dir(),
@ -276,7 +278,6 @@ class SessionManagerTest(test.TestCase):
self.assertEqual(1, sess.run(v))
self.assertEqual(1, sess.run(w))
@test_util.run_v1_only("b/120545219")
def testRecoverSessionWithReadyForLocalInitOpFailsToReadyLocal(self):
# We use ready_for_local_init_op=report_uninitialized_variables(),
# which causes recover_session to not run local_init_op, and to return
@ -333,7 +334,6 @@ class SessionManagerTest(test.TestCase):
sess.graph.get_tensor_by_name("w:0")).eval(session=sess))
self.assertEqual(1, sess.run(v))
@test_util.run_v1_only("b/120545219")
def testRecoverSessionNoChkptStillRunsLocalInitOp(self):
# This test checks for backwards compatibility.
# In particular, we continue to ensure that recover_session will execute
@ -362,7 +362,6 @@ class SessionManagerTest(test.TestCase):
sess.graph.get_tensor_by_name("w:0")).eval(session=sess))
self.assertEqual(1, sess.run(w))
@test_util.run_v1_only("b/120545219")
def testRecoverSessionFailsStillRunsLocalInitOp(self):
# Create a checkpoint.
checkpoint_dir = os.path.join(
@ -406,7 +405,6 @@ class SessionManagerTest(test.TestCase):
sess.graph.get_tensor_by_name("w:0")).eval(session=sess))
self.assertEqual(1, sess.run(w))
@test_util.run_v1_only("b/120545219")
def testWaitForSessionLocalInit(self):
server = server_lib.Server.create_local_server()
with ops.Graph().as_default() as graph:
@ -458,7 +456,7 @@ class SessionManagerTest(test.TestCase):
# because of overly restrictive ready_for_local_init_op
sm.wait_for_session("", max_wait_secs=3)
@test_util.run_v1_only("b/120545219")
@test_util.run_v1_only("Requires TF V1 variable behavior.")
def testWaitForSessionInsufficientReadyForLocalInitCheck(self):
with ops.Graph().as_default() as graph:
v = variables.VariableV1(1, name="v")
@ -476,7 +474,6 @@ class SessionManagerTest(test.TestCase):
"Session was not ready after waiting.*"):
sm.wait_for_session("", max_wait_secs=3)
@test_util.run_v1_only("b/120545219")
def testPrepareSessionWithReadyForLocalInitOp(self):
with ops.Graph().as_default():
v = variables.VariableV1(1, name="v")
@ -516,7 +513,7 @@ class SessionManagerTest(test.TestCase):
self.assertEqual(1, sess.run(w))
self.assertEqual(3, sess.run(x))
@test_util.run_v1_only("b/120545219")
@test_util.run_v1_only("Requires TF V1 variable behavior.")
def testPrepareSessionWithPartialInitOp(self):
with ops.Graph().as_default():
v = variables.VariableV1(1, name="v")
@ -583,7 +580,6 @@ class SessionManagerTest(test.TestCase):
self.assertEqual(1, sess.run(w_res))
self.assertEqual(3, sess.run(x_res))
@test_util.run_v1_only("b/120545219")
def testPrepareSessionWithCyclicInitializer(self):
# Regression test. Previously Variable._build_initializer_expr would enter
# into an infinite recursion when the variable's initial_value involved
@ -657,7 +653,7 @@ class SessionManagerTest(test.TestCase):
"Init operations did not make model ready for local_init"):
sm2.prepare_session("", init_op=None)
@test_util.run_v1_only("b/120545219")
@test_util.run_v1_only("Requires TF V1 variable behavior.")
def testPrepareSessionWithInsufficientReadyForLocalInitCheck(self):
with ops.Graph().as_default():
v = variables.VariableV1(1, name="v")
@ -680,6 +676,11 @@ class SessionManagerTest(test.TestCase):
class ObsoleteSessionManagerTest(test.TestCase):
@classmethod
def setUpClass(cls):
super(ObsoleteSessionManagerTest, cls).setUpClass()
variable_scope.disable_resource_variables()
def testPrepareSessionSucceeds(self):
with ops.Graph().as_default():
v = variables.VariableV1([1.0, 2.0, 3.0], name="v")
@ -710,7 +711,6 @@ class ObsoleteSessionManagerTest(test.TestCase):
"", init_fn=lambda sess: sess.run(v.initializer))
self.assertAllClose([125], sess.run(v))
@test_util.run_v1_only("b/120545219")
def testPrepareSessionFails(self):
checkpoint_dir = os.path.join(self.get_temp_dir(), "prepare_session")
checkpoint_dir2 = os.path.join(self.get_temp_dir(), "prepare_session2")
@ -772,7 +772,6 @@ class ObsoleteSessionManagerTest(test.TestCase):
variables.is_variable_initialized(
sess.graph.get_tensor_by_name("v:0")).eval(session=sess))
@test_util.run_v1_only("b/120545219")
def testRecoverSession(self):
# Create a checkpoint.
checkpoint_dir = os.path.join(self.get_temp_dir(), "recover_session")
@ -811,7 +810,6 @@ class ObsoleteSessionManagerTest(test.TestCase):
sess.graph.get_tensor_by_name("v:0")).eval(session=sess))
self.assertEqual(1, sess.run(v))
@test_util.run_v1_only("b/120545219")
def testWaitForSessionReturnsNoneAfterTimeout(self):
with ops.Graph().as_default():
variables.VariableV1(1, name="v")