Update v1 only test with proper reason.
Also fix existing warning wrt deprecated assertion methods. PiperOrigin-RevId: 314196442 Change-Id: Ifab24cb9519b093bcf41c39726ed5a4fe6350576
This commit is contained in:
parent
bfc2553173
commit
15626c4e8d
@ -110,7 +110,7 @@ class SupervisorTest(test.TestCase):
|
||||
with ops.Graph().as_default():
|
||||
my_op = constant_op.constant(1.0)
|
||||
sv = supervisor.Supervisor(logdir=logdir)
|
||||
with sv.managed_session("") as sess:
|
||||
with sv.managed_session(""):
|
||||
for _ in xrange(10):
|
||||
self.evaluate(my_op)
|
||||
# Supervisor has been stopped.
|
||||
@ -170,7 +170,7 @@ class SupervisorTest(test.TestCase):
|
||||
"", close_summary_writer=True, start_standard_services=False) as sess:
|
||||
sv.summary_computed(sess, sess.run(summ))
|
||||
event_paths = sorted(glob.glob(os.path.join(logdir, "event*")))
|
||||
self.assertEquals(2, len(event_paths))
|
||||
self.assertEqual(2, len(event_paths))
|
||||
# The two event files should have the same contents.
|
||||
for path in event_paths:
|
||||
# The summary iterator should report the summary once as we closed the
|
||||
@ -178,7 +178,7 @@ class SupervisorTest(test.TestCase):
|
||||
rr = summary_iterator.summary_iterator(path)
|
||||
# The first event should list the file_version.
|
||||
ev = next(rr)
|
||||
self.assertEquals("brain.Event:2", ev.file_version)
|
||||
self.assertEqual("brain.Event:2", ev.file_version)
|
||||
|
||||
# The next one has the graph and metagraph.
|
||||
ev = next(rr)
|
||||
@ -198,7 +198,7 @@ class SupervisorTest(test.TestCase):
|
||||
|
||||
# The next one should be a stop message if we closed cleanly.
|
||||
ev = next(rr)
|
||||
self.assertEquals(event_pb2.SessionLog.STOP, ev.session_log.status)
|
||||
self.assertEqual(event_pb2.SessionLog.STOP, ev.session_log.status)
|
||||
|
||||
# We should be done.
|
||||
with self.assertRaises(StopIteration):
|
||||
@ -227,7 +227,7 @@ class SupervisorTest(test.TestCase):
|
||||
rr = _summary_iterator(logdir)
|
||||
# The first event should list the file_version.
|
||||
ev = next(rr)
|
||||
self.assertEquals("brain.Event:2", ev.file_version)
|
||||
self.assertEqual("brain.Event:2", ev.file_version)
|
||||
|
||||
# The next one has the graph.
|
||||
ev = next(rr)
|
||||
@ -360,7 +360,7 @@ class SupervisorTest(test.TestCase):
|
||||
|
||||
# The first event should list the file_version.
|
||||
ev = next(rr)
|
||||
self.assertEquals("brain.Event:2", ev.file_version)
|
||||
self.assertEqual("brain.Event:2", ev.file_version)
|
||||
|
||||
# The next one has the graph.
|
||||
ev = next(rr)
|
||||
@ -385,7 +385,7 @@ class SupervisorTest(test.TestCase):
|
||||
|
||||
# The next one should be a stop message if we closed cleanly.
|
||||
ev = next(rr)
|
||||
self.assertEquals(event_pb2.SessionLog.STOP, ev.session_log.status)
|
||||
self.assertEqual(event_pb2.SessionLog.STOP, ev.session_log.status)
|
||||
|
||||
# We should be done.
|
||||
self.assertRaises(StopIteration, lambda: next(rr))
|
||||
@ -421,7 +421,7 @@ class SupervisorTest(test.TestCase):
|
||||
with self.assertRaisesRegexp(RuntimeError, "requires a summary writer"):
|
||||
sv.summary_computed(sess, sess.run(summ))
|
||||
|
||||
@test_util.run_v1_only("b/120545219")
|
||||
@test_util.run_v1_only("train.Supervisor is for v1 only")
|
||||
def testLogdirButExplicitlyNoSummaryWriter(self):
|
||||
logdir = self._test_dir("explicit_no_summary_writer")
|
||||
with ops.Graph().as_default():
|
||||
@ -460,7 +460,7 @@ class SupervisorTest(test.TestCase):
|
||||
|
||||
# The first event should list the file_version.
|
||||
ev = next(rr)
|
||||
self.assertEquals("brain.Event:2", ev.file_version)
|
||||
self.assertEqual("brain.Event:2", ev.file_version)
|
||||
|
||||
# The next one has the graph.
|
||||
ev = next(rr)
|
||||
@ -486,7 +486,7 @@ class SupervisorTest(test.TestCase):
|
||||
|
||||
# The next one should be a stop message if we closed cleanly.
|
||||
ev = next(rr)
|
||||
self.assertEquals(event_pb2.SessionLog.STOP, ev.session_log.status)
|
||||
self.assertEqual(event_pb2.SessionLog.STOP, ev.session_log.status)
|
||||
|
||||
# We should be done.
|
||||
self.assertRaises(StopIteration, lambda: next(rr))
|
||||
@ -507,7 +507,7 @@ class SupervisorTest(test.TestCase):
|
||||
sv = supervisor.Supervisor(logdir="", session_manager=sm)
|
||||
sv.prepare_or_wait_for_session("")
|
||||
|
||||
@test_util.run_v1_only("b/120545219")
|
||||
@test_util.run_v1_only("train.Supervisor is for v1 only")
|
||||
def testInitOp(self):
|
||||
logdir = self._test_dir("default_init_op")
|
||||
with ops.Graph().as_default():
|
||||
@ -517,7 +517,7 @@ class SupervisorTest(test.TestCase):
|
||||
self.assertAllClose([1.0, 2.0, 3.0], sess.run(v))
|
||||
sv.stop()
|
||||
|
||||
@test_util.run_v1_only("b/120545219")
|
||||
@test_util.run_v1_only("train.Supervisor is for v1 only")
|
||||
def testInitFn(self):
|
||||
logdir = self._test_dir("default_init_op")
|
||||
with ops.Graph().as_default():
|
||||
@ -531,7 +531,7 @@ class SupervisorTest(test.TestCase):
|
||||
self.assertAllClose([1.0, 2.0, 3.0], sess.run(v))
|
||||
sv.stop()
|
||||
|
||||
@test_util.run_v1_only("b/120545219")
|
||||
@test_util.run_v1_only("train.Supervisor is for v1 only")
|
||||
def testInitOpWithFeedDict(self):
|
||||
logdir = self._test_dir("feed_dict_init_op")
|
||||
with ops.Graph().as_default():
|
||||
@ -545,7 +545,7 @@ class SupervisorTest(test.TestCase):
|
||||
self.assertAllClose([1.0, 2.0, 3.0], sess.run(v))
|
||||
sv.stop()
|
||||
|
||||
@test_util.run_v1_only("b/120545219")
|
||||
@test_util.run_v1_only("train.Supervisor is for v1 only")
|
||||
def testReadyForLocalInitOp(self):
|
||||
server = server_lib.Server.create_local_server()
|
||||
logdir = self._test_dir("default_ready_for_local_init_op")
|
||||
@ -588,7 +588,7 @@ class SupervisorTest(test.TestCase):
|
||||
sv0.stop()
|
||||
sv1.stop()
|
||||
|
||||
@test_util.run_v1_only("b/120545219")
|
||||
@test_util.run_v1_only("train.Supervisor is for v1 only")
|
||||
def testReadyForLocalInitOpRestoreFromCheckpoint(self):
|
||||
server = server_lib.Server.create_local_server()
|
||||
logdir = self._test_dir("ready_for_local_init_op_restore")
|
||||
@ -660,7 +660,7 @@ class SupervisorTest(test.TestCase):
|
||||
|
||||
# This shouldn't add a variable to the VARIABLES collection responsible
|
||||
# for variables that are saved/restored from checkpoints.
|
||||
self.assertEquals(len(variables.global_variables()), 0)
|
||||
self.assertEqual(len(variables.global_variables()), 0)
|
||||
|
||||
# Suppress normal variable inits to make sure the local one is
|
||||
# initialized via local_init_op.
|
||||
@ -681,7 +681,7 @@ class SupervisorTest(test.TestCase):
|
||||
collections=[ops.GraphKeys.LOCAL_VARIABLES])
|
||||
# This shouldn't add a variable to the VARIABLES collection responsible
|
||||
# for variables that are saved/restored from checkpoints.
|
||||
self.assertEquals(len(variables.global_variables()), 0)
|
||||
self.assertEqual(len(variables.global_variables()), 0)
|
||||
|
||||
# Suppress normal variable inits to make sure the local one is
|
||||
# initialized via local_init_op.
|
||||
@ -720,7 +720,7 @@ class SupervisorTest(test.TestCase):
|
||||
"Variables not initialized: w"):
|
||||
sv.prepare_or_wait_for_session(server.target)
|
||||
|
||||
@test_util.run_v1_only("b/120545219")
|
||||
@test_util.run_v1_only("train.Supervisor is for v1 only")
|
||||
def testSetupFail(self):
|
||||
logdir = self._test_dir("setup_fail")
|
||||
with ops.Graph().as_default():
|
||||
@ -731,17 +731,17 @@ class SupervisorTest(test.TestCase):
|
||||
variables.VariableV1([1.0, 2.0, 3.0], name="v")
|
||||
supervisor.Supervisor(logdir=logdir, is_chief=False)
|
||||
|
||||
@test_util.run_v1_only("b/120545219")
|
||||
@test_util.run_v1_only("train.Supervisor is for v1 only")
|
||||
def testDefaultGlobalStep(self):
|
||||
logdir = self._test_dir("default_global_step")
|
||||
with ops.Graph().as_default():
|
||||
variables.VariableV1(287, name="global_step")
|
||||
sv = supervisor.Supervisor(logdir=logdir)
|
||||
sess = sv.prepare_or_wait_for_session("")
|
||||
self.assertEquals(287, sess.run(sv.global_step))
|
||||
self.assertEqual(287, sess.run(sv.global_step))
|
||||
sv.stop()
|
||||
|
||||
@test_util.run_v1_only("b/120545219")
|
||||
@test_util.run_v1_only("train.Supervisor is for v1 only")
|
||||
def testRestoreFromMetaGraph(self):
|
||||
logdir = self._test_dir("restore_from_meta_graph")
|
||||
with ops.Graph().as_default():
|
||||
@ -756,14 +756,14 @@ class SupervisorTest(test.TestCase):
|
||||
self.assertIsNotNone(new_saver)
|
||||
sv2 = supervisor.Supervisor(logdir=logdir, saver=new_saver)
|
||||
sess = sv2.prepare_or_wait_for_session("")
|
||||
self.assertEquals(1, sess.run("v0:0"))
|
||||
self.assertEqual(1, sess.run("v0:0"))
|
||||
sv2.saver.save(sess, sv2.save_path)
|
||||
sv2.stop()
|
||||
|
||||
# This test is based on the fact that the standard services start
|
||||
# right away and get to run once before sv.stop() returns.
|
||||
# We still sleep a bit to make the test robust.
|
||||
@test_util.run_v1_only("b/120545219")
|
||||
@test_util.run_v1_only("train.Supervisor is for v1 only")
|
||||
def testStandardServicesWithoutGlobalStep(self):
|
||||
logdir = self._test_dir("standard_services_without_global_step")
|
||||
# Create a checkpoint.
|
||||
@ -784,7 +784,7 @@ class SupervisorTest(test.TestCase):
|
||||
# There should be an event file with a version number.
|
||||
rr = _summary_iterator(logdir)
|
||||
ev = next(rr)
|
||||
self.assertEquals("brain.Event:2", ev.file_version)
|
||||
self.assertEqual("brain.Event:2", ev.file_version)
|
||||
ev = next(rr)
|
||||
ev_graph = graph_pb2.GraphDef()
|
||||
ev_graph.ParseFromString(ev.graph_def)
|
||||
@ -802,7 +802,7 @@ class SupervisorTest(test.TestCase):
|
||||
self.assertProtoEquals("value { tag: 'v' simple_value: 1.0 }", ev.summary)
|
||||
|
||||
ev = next(rr)
|
||||
self.assertEquals(event_pb2.SessionLog.STOP, ev.session_log.status)
|
||||
self.assertEqual(event_pb2.SessionLog.STOP, ev.session_log.status)
|
||||
|
||||
self.assertRaises(StopIteration, lambda: next(rr))
|
||||
# There should be a checkpoint file with the variable "foo"
|
||||
@ -814,7 +814,7 @@ class SupervisorTest(test.TestCase):
|
||||
|
||||
# Same as testStandardServicesNoGlobalStep but with a global step.
|
||||
# We should get a summary about the step time.
|
||||
@test_util.run_v1_only("b/120545219")
|
||||
@test_util.run_v1_only("train.Supervisor is for v1 only")
|
||||
def testStandardServicesWithGlobalStep(self):
|
||||
logdir = self._test_dir("standard_services_with_global_step")
|
||||
# Create a checkpoint.
|
||||
@ -835,7 +835,7 @@ class SupervisorTest(test.TestCase):
|
||||
# There should be an event file with a version number.
|
||||
rr = _summary_iterator(logdir)
|
||||
ev = next(rr)
|
||||
self.assertEquals("brain.Event:2", ev.file_version)
|
||||
self.assertEqual("brain.Event:2", ev.file_version)
|
||||
ev = next(rr)
|
||||
ev_graph = graph_pb2.GraphDef()
|
||||
ev_graph.ParseFromString(ev.graph_def)
|
||||
@ -849,8 +849,8 @@ class SupervisorTest(test.TestCase):
|
||||
ev = next(rr)
|
||||
# It is actually undeterministic whether SessionLog.START gets written
|
||||
# before the summary or the checkpoint, but this works when run 10000 times.
|
||||
self.assertEquals(123, ev.step)
|
||||
self.assertEquals(event_pb2.SessionLog.START, ev.session_log.status)
|
||||
self.assertEqual(123, ev.step)
|
||||
self.assertEqual(event_pb2.SessionLog.START, ev.session_log.status)
|
||||
first = next(rr)
|
||||
second = next(rr)
|
||||
# It is undeterministic whether the value gets written before the checkpoint
|
||||
@ -858,17 +858,17 @@ class SupervisorTest(test.TestCase):
|
||||
if first.HasField("summary"):
|
||||
self.assertProtoEquals("""value { tag: 'global_step/sec'
|
||||
simple_value: 0.0 }""", first.summary)
|
||||
self.assertEquals(123, second.step)
|
||||
self.assertEquals(event_pb2.SessionLog.CHECKPOINT,
|
||||
second.session_log.status)
|
||||
self.assertEqual(123, second.step)
|
||||
self.assertEqual(event_pb2.SessionLog.CHECKPOINT,
|
||||
second.session_log.status)
|
||||
else:
|
||||
self.assertEquals(123, first.step)
|
||||
self.assertEquals(event_pb2.SessionLog.CHECKPOINT,
|
||||
first.session_log.status)
|
||||
self.assertEqual(123, first.step)
|
||||
self.assertEqual(event_pb2.SessionLog.CHECKPOINT,
|
||||
first.session_log.status)
|
||||
self.assertProtoEquals("""value { tag: 'global_step/sec'
|
||||
simple_value: 0.0 }""", second.summary)
|
||||
ev = next(rr)
|
||||
self.assertEquals(event_pb2.SessionLog.STOP, ev.session_log.status)
|
||||
self.assertEqual(event_pb2.SessionLog.STOP, ev.session_log.status)
|
||||
self.assertRaises(StopIteration, lambda: next(rr))
|
||||
# There should be a checkpoint file with the variable "foo"
|
||||
with ops.Graph().as_default(), self.cached_session() as sess:
|
||||
|
Loading…
Reference in New Issue
Block a user