Update v1 only test with proper reason.
Also fix existing warning wrt deprecated assertion methods. PiperOrigin-RevId: 314196442 Change-Id: Ifab24cb9519b093bcf41c39726ed5a4fe6350576
This commit is contained in:
parent
bfc2553173
commit
15626c4e8d
@ -110,7 +110,7 @@ class SupervisorTest(test.TestCase):
|
|||||||
with ops.Graph().as_default():
|
with ops.Graph().as_default():
|
||||||
my_op = constant_op.constant(1.0)
|
my_op = constant_op.constant(1.0)
|
||||||
sv = supervisor.Supervisor(logdir=logdir)
|
sv = supervisor.Supervisor(logdir=logdir)
|
||||||
with sv.managed_session("") as sess:
|
with sv.managed_session(""):
|
||||||
for _ in xrange(10):
|
for _ in xrange(10):
|
||||||
self.evaluate(my_op)
|
self.evaluate(my_op)
|
||||||
# Supervisor has been stopped.
|
# Supervisor has been stopped.
|
||||||
@ -170,7 +170,7 @@ class SupervisorTest(test.TestCase):
|
|||||||
"", close_summary_writer=True, start_standard_services=False) as sess:
|
"", close_summary_writer=True, start_standard_services=False) as sess:
|
||||||
sv.summary_computed(sess, sess.run(summ))
|
sv.summary_computed(sess, sess.run(summ))
|
||||||
event_paths = sorted(glob.glob(os.path.join(logdir, "event*")))
|
event_paths = sorted(glob.glob(os.path.join(logdir, "event*")))
|
||||||
self.assertEquals(2, len(event_paths))
|
self.assertEqual(2, len(event_paths))
|
||||||
# The two event files should have the same contents.
|
# The two event files should have the same contents.
|
||||||
for path in event_paths:
|
for path in event_paths:
|
||||||
# The summary iterator should report the summary once as we closed the
|
# The summary iterator should report the summary once as we closed the
|
||||||
@ -178,7 +178,7 @@ class SupervisorTest(test.TestCase):
|
|||||||
rr = summary_iterator.summary_iterator(path)
|
rr = summary_iterator.summary_iterator(path)
|
||||||
# The first event should list the file_version.
|
# The first event should list the file_version.
|
||||||
ev = next(rr)
|
ev = next(rr)
|
||||||
self.assertEquals("brain.Event:2", ev.file_version)
|
self.assertEqual("brain.Event:2", ev.file_version)
|
||||||
|
|
||||||
# The next one has the graph and metagraph.
|
# The next one has the graph and metagraph.
|
||||||
ev = next(rr)
|
ev = next(rr)
|
||||||
@ -198,7 +198,7 @@ class SupervisorTest(test.TestCase):
|
|||||||
|
|
||||||
# The next one should be a stop message if we closed cleanly.
|
# The next one should be a stop message if we closed cleanly.
|
||||||
ev = next(rr)
|
ev = next(rr)
|
||||||
self.assertEquals(event_pb2.SessionLog.STOP, ev.session_log.status)
|
self.assertEqual(event_pb2.SessionLog.STOP, ev.session_log.status)
|
||||||
|
|
||||||
# We should be done.
|
# We should be done.
|
||||||
with self.assertRaises(StopIteration):
|
with self.assertRaises(StopIteration):
|
||||||
@ -227,7 +227,7 @@ class SupervisorTest(test.TestCase):
|
|||||||
rr = _summary_iterator(logdir)
|
rr = _summary_iterator(logdir)
|
||||||
# The first event should list the file_version.
|
# The first event should list the file_version.
|
||||||
ev = next(rr)
|
ev = next(rr)
|
||||||
self.assertEquals("brain.Event:2", ev.file_version)
|
self.assertEqual("brain.Event:2", ev.file_version)
|
||||||
|
|
||||||
# The next one has the graph.
|
# The next one has the graph.
|
||||||
ev = next(rr)
|
ev = next(rr)
|
||||||
@ -360,7 +360,7 @@ class SupervisorTest(test.TestCase):
|
|||||||
|
|
||||||
# The first event should list the file_version.
|
# The first event should list the file_version.
|
||||||
ev = next(rr)
|
ev = next(rr)
|
||||||
self.assertEquals("brain.Event:2", ev.file_version)
|
self.assertEqual("brain.Event:2", ev.file_version)
|
||||||
|
|
||||||
# The next one has the graph.
|
# The next one has the graph.
|
||||||
ev = next(rr)
|
ev = next(rr)
|
||||||
@ -385,7 +385,7 @@ class SupervisorTest(test.TestCase):
|
|||||||
|
|
||||||
# The next one should be a stop message if we closed cleanly.
|
# The next one should be a stop message if we closed cleanly.
|
||||||
ev = next(rr)
|
ev = next(rr)
|
||||||
self.assertEquals(event_pb2.SessionLog.STOP, ev.session_log.status)
|
self.assertEqual(event_pb2.SessionLog.STOP, ev.session_log.status)
|
||||||
|
|
||||||
# We should be done.
|
# We should be done.
|
||||||
self.assertRaises(StopIteration, lambda: next(rr))
|
self.assertRaises(StopIteration, lambda: next(rr))
|
||||||
@ -421,7 +421,7 @@ class SupervisorTest(test.TestCase):
|
|||||||
with self.assertRaisesRegexp(RuntimeError, "requires a summary writer"):
|
with self.assertRaisesRegexp(RuntimeError, "requires a summary writer"):
|
||||||
sv.summary_computed(sess, sess.run(summ))
|
sv.summary_computed(sess, sess.run(summ))
|
||||||
|
|
||||||
@test_util.run_v1_only("b/120545219")
|
@test_util.run_v1_only("train.Supervisor is for v1 only")
|
||||||
def testLogdirButExplicitlyNoSummaryWriter(self):
|
def testLogdirButExplicitlyNoSummaryWriter(self):
|
||||||
logdir = self._test_dir("explicit_no_summary_writer")
|
logdir = self._test_dir("explicit_no_summary_writer")
|
||||||
with ops.Graph().as_default():
|
with ops.Graph().as_default():
|
||||||
@ -460,7 +460,7 @@ class SupervisorTest(test.TestCase):
|
|||||||
|
|
||||||
# The first event should list the file_version.
|
# The first event should list the file_version.
|
||||||
ev = next(rr)
|
ev = next(rr)
|
||||||
self.assertEquals("brain.Event:2", ev.file_version)
|
self.assertEqual("brain.Event:2", ev.file_version)
|
||||||
|
|
||||||
# The next one has the graph.
|
# The next one has the graph.
|
||||||
ev = next(rr)
|
ev = next(rr)
|
||||||
@ -486,7 +486,7 @@ class SupervisorTest(test.TestCase):
|
|||||||
|
|
||||||
# The next one should be a stop message if we closed cleanly.
|
# The next one should be a stop message if we closed cleanly.
|
||||||
ev = next(rr)
|
ev = next(rr)
|
||||||
self.assertEquals(event_pb2.SessionLog.STOP, ev.session_log.status)
|
self.assertEqual(event_pb2.SessionLog.STOP, ev.session_log.status)
|
||||||
|
|
||||||
# We should be done.
|
# We should be done.
|
||||||
self.assertRaises(StopIteration, lambda: next(rr))
|
self.assertRaises(StopIteration, lambda: next(rr))
|
||||||
@ -507,7 +507,7 @@ class SupervisorTest(test.TestCase):
|
|||||||
sv = supervisor.Supervisor(logdir="", session_manager=sm)
|
sv = supervisor.Supervisor(logdir="", session_manager=sm)
|
||||||
sv.prepare_or_wait_for_session("")
|
sv.prepare_or_wait_for_session("")
|
||||||
|
|
||||||
@test_util.run_v1_only("b/120545219")
|
@test_util.run_v1_only("train.Supervisor is for v1 only")
|
||||||
def testInitOp(self):
|
def testInitOp(self):
|
||||||
logdir = self._test_dir("default_init_op")
|
logdir = self._test_dir("default_init_op")
|
||||||
with ops.Graph().as_default():
|
with ops.Graph().as_default():
|
||||||
@ -517,7 +517,7 @@ class SupervisorTest(test.TestCase):
|
|||||||
self.assertAllClose([1.0, 2.0, 3.0], sess.run(v))
|
self.assertAllClose([1.0, 2.0, 3.0], sess.run(v))
|
||||||
sv.stop()
|
sv.stop()
|
||||||
|
|
||||||
@test_util.run_v1_only("b/120545219")
|
@test_util.run_v1_only("train.Supervisor is for v1 only")
|
||||||
def testInitFn(self):
|
def testInitFn(self):
|
||||||
logdir = self._test_dir("default_init_op")
|
logdir = self._test_dir("default_init_op")
|
||||||
with ops.Graph().as_default():
|
with ops.Graph().as_default():
|
||||||
@ -531,7 +531,7 @@ class SupervisorTest(test.TestCase):
|
|||||||
self.assertAllClose([1.0, 2.0, 3.0], sess.run(v))
|
self.assertAllClose([1.0, 2.0, 3.0], sess.run(v))
|
||||||
sv.stop()
|
sv.stop()
|
||||||
|
|
||||||
@test_util.run_v1_only("b/120545219")
|
@test_util.run_v1_only("train.Supervisor is for v1 only")
|
||||||
def testInitOpWithFeedDict(self):
|
def testInitOpWithFeedDict(self):
|
||||||
logdir = self._test_dir("feed_dict_init_op")
|
logdir = self._test_dir("feed_dict_init_op")
|
||||||
with ops.Graph().as_default():
|
with ops.Graph().as_default():
|
||||||
@ -545,7 +545,7 @@ class SupervisorTest(test.TestCase):
|
|||||||
self.assertAllClose([1.0, 2.0, 3.0], sess.run(v))
|
self.assertAllClose([1.0, 2.0, 3.0], sess.run(v))
|
||||||
sv.stop()
|
sv.stop()
|
||||||
|
|
||||||
@test_util.run_v1_only("b/120545219")
|
@test_util.run_v1_only("train.Supervisor is for v1 only")
|
||||||
def testReadyForLocalInitOp(self):
|
def testReadyForLocalInitOp(self):
|
||||||
server = server_lib.Server.create_local_server()
|
server = server_lib.Server.create_local_server()
|
||||||
logdir = self._test_dir("default_ready_for_local_init_op")
|
logdir = self._test_dir("default_ready_for_local_init_op")
|
||||||
@ -588,7 +588,7 @@ class SupervisorTest(test.TestCase):
|
|||||||
sv0.stop()
|
sv0.stop()
|
||||||
sv1.stop()
|
sv1.stop()
|
||||||
|
|
||||||
@test_util.run_v1_only("b/120545219")
|
@test_util.run_v1_only("train.Supervisor is for v1 only")
|
||||||
def testReadyForLocalInitOpRestoreFromCheckpoint(self):
|
def testReadyForLocalInitOpRestoreFromCheckpoint(self):
|
||||||
server = server_lib.Server.create_local_server()
|
server = server_lib.Server.create_local_server()
|
||||||
logdir = self._test_dir("ready_for_local_init_op_restore")
|
logdir = self._test_dir("ready_for_local_init_op_restore")
|
||||||
@ -660,7 +660,7 @@ class SupervisorTest(test.TestCase):
|
|||||||
|
|
||||||
# This shouldn't add a variable to the VARIABLES collection responsible
|
# This shouldn't add a variable to the VARIABLES collection responsible
|
||||||
# for variables that are saved/restored from checkpoints.
|
# for variables that are saved/restored from checkpoints.
|
||||||
self.assertEquals(len(variables.global_variables()), 0)
|
self.assertEqual(len(variables.global_variables()), 0)
|
||||||
|
|
||||||
# Suppress normal variable inits to make sure the local one is
|
# Suppress normal variable inits to make sure the local one is
|
||||||
# initialized via local_init_op.
|
# initialized via local_init_op.
|
||||||
@ -681,7 +681,7 @@ class SupervisorTest(test.TestCase):
|
|||||||
collections=[ops.GraphKeys.LOCAL_VARIABLES])
|
collections=[ops.GraphKeys.LOCAL_VARIABLES])
|
||||||
# This shouldn't add a variable to the VARIABLES collection responsible
|
# This shouldn't add a variable to the VARIABLES collection responsible
|
||||||
# for variables that are saved/restored from checkpoints.
|
# for variables that are saved/restored from checkpoints.
|
||||||
self.assertEquals(len(variables.global_variables()), 0)
|
self.assertEqual(len(variables.global_variables()), 0)
|
||||||
|
|
||||||
# Suppress normal variable inits to make sure the local one is
|
# Suppress normal variable inits to make sure the local one is
|
||||||
# initialized via local_init_op.
|
# initialized via local_init_op.
|
||||||
@ -720,7 +720,7 @@ class SupervisorTest(test.TestCase):
|
|||||||
"Variables not initialized: w"):
|
"Variables not initialized: w"):
|
||||||
sv.prepare_or_wait_for_session(server.target)
|
sv.prepare_or_wait_for_session(server.target)
|
||||||
|
|
||||||
@test_util.run_v1_only("b/120545219")
|
@test_util.run_v1_only("train.Supervisor is for v1 only")
|
||||||
def testSetupFail(self):
|
def testSetupFail(self):
|
||||||
logdir = self._test_dir("setup_fail")
|
logdir = self._test_dir("setup_fail")
|
||||||
with ops.Graph().as_default():
|
with ops.Graph().as_default():
|
||||||
@ -731,17 +731,17 @@ class SupervisorTest(test.TestCase):
|
|||||||
variables.VariableV1([1.0, 2.0, 3.0], name="v")
|
variables.VariableV1([1.0, 2.0, 3.0], name="v")
|
||||||
supervisor.Supervisor(logdir=logdir, is_chief=False)
|
supervisor.Supervisor(logdir=logdir, is_chief=False)
|
||||||
|
|
||||||
@test_util.run_v1_only("b/120545219")
|
@test_util.run_v1_only("train.Supervisor is for v1 only")
|
||||||
def testDefaultGlobalStep(self):
|
def testDefaultGlobalStep(self):
|
||||||
logdir = self._test_dir("default_global_step")
|
logdir = self._test_dir("default_global_step")
|
||||||
with ops.Graph().as_default():
|
with ops.Graph().as_default():
|
||||||
variables.VariableV1(287, name="global_step")
|
variables.VariableV1(287, name="global_step")
|
||||||
sv = supervisor.Supervisor(logdir=logdir)
|
sv = supervisor.Supervisor(logdir=logdir)
|
||||||
sess = sv.prepare_or_wait_for_session("")
|
sess = sv.prepare_or_wait_for_session("")
|
||||||
self.assertEquals(287, sess.run(sv.global_step))
|
self.assertEqual(287, sess.run(sv.global_step))
|
||||||
sv.stop()
|
sv.stop()
|
||||||
|
|
||||||
@test_util.run_v1_only("b/120545219")
|
@test_util.run_v1_only("train.Supervisor is for v1 only")
|
||||||
def testRestoreFromMetaGraph(self):
|
def testRestoreFromMetaGraph(self):
|
||||||
logdir = self._test_dir("restore_from_meta_graph")
|
logdir = self._test_dir("restore_from_meta_graph")
|
||||||
with ops.Graph().as_default():
|
with ops.Graph().as_default():
|
||||||
@ -756,14 +756,14 @@ class SupervisorTest(test.TestCase):
|
|||||||
self.assertIsNotNone(new_saver)
|
self.assertIsNotNone(new_saver)
|
||||||
sv2 = supervisor.Supervisor(logdir=logdir, saver=new_saver)
|
sv2 = supervisor.Supervisor(logdir=logdir, saver=new_saver)
|
||||||
sess = sv2.prepare_or_wait_for_session("")
|
sess = sv2.prepare_or_wait_for_session("")
|
||||||
self.assertEquals(1, sess.run("v0:0"))
|
self.assertEqual(1, sess.run("v0:0"))
|
||||||
sv2.saver.save(sess, sv2.save_path)
|
sv2.saver.save(sess, sv2.save_path)
|
||||||
sv2.stop()
|
sv2.stop()
|
||||||
|
|
||||||
# This test is based on the fact that the standard services start
|
# This test is based on the fact that the standard services start
|
||||||
# right away and get to run once before sv.stop() returns.
|
# right away and get to run once before sv.stop() returns.
|
||||||
# We still sleep a bit to make the test robust.
|
# We still sleep a bit to make the test robust.
|
||||||
@test_util.run_v1_only("b/120545219")
|
@test_util.run_v1_only("train.Supervisor is for v1 only")
|
||||||
def testStandardServicesWithoutGlobalStep(self):
|
def testStandardServicesWithoutGlobalStep(self):
|
||||||
logdir = self._test_dir("standard_services_without_global_step")
|
logdir = self._test_dir("standard_services_without_global_step")
|
||||||
# Create a checkpoint.
|
# Create a checkpoint.
|
||||||
@ -784,7 +784,7 @@ class SupervisorTest(test.TestCase):
|
|||||||
# There should be an event file with a version number.
|
# There should be an event file with a version number.
|
||||||
rr = _summary_iterator(logdir)
|
rr = _summary_iterator(logdir)
|
||||||
ev = next(rr)
|
ev = next(rr)
|
||||||
self.assertEquals("brain.Event:2", ev.file_version)
|
self.assertEqual("brain.Event:2", ev.file_version)
|
||||||
ev = next(rr)
|
ev = next(rr)
|
||||||
ev_graph = graph_pb2.GraphDef()
|
ev_graph = graph_pb2.GraphDef()
|
||||||
ev_graph.ParseFromString(ev.graph_def)
|
ev_graph.ParseFromString(ev.graph_def)
|
||||||
@ -802,7 +802,7 @@ class SupervisorTest(test.TestCase):
|
|||||||
self.assertProtoEquals("value { tag: 'v' simple_value: 1.0 }", ev.summary)
|
self.assertProtoEquals("value { tag: 'v' simple_value: 1.0 }", ev.summary)
|
||||||
|
|
||||||
ev = next(rr)
|
ev = next(rr)
|
||||||
self.assertEquals(event_pb2.SessionLog.STOP, ev.session_log.status)
|
self.assertEqual(event_pb2.SessionLog.STOP, ev.session_log.status)
|
||||||
|
|
||||||
self.assertRaises(StopIteration, lambda: next(rr))
|
self.assertRaises(StopIteration, lambda: next(rr))
|
||||||
# There should be a checkpoint file with the variable "foo"
|
# There should be a checkpoint file with the variable "foo"
|
||||||
@ -814,7 +814,7 @@ class SupervisorTest(test.TestCase):
|
|||||||
|
|
||||||
# Same as testStandardServicesNoGlobalStep but with a global step.
|
# Same as testStandardServicesNoGlobalStep but with a global step.
|
||||||
# We should get a summary about the step time.
|
# We should get a summary about the step time.
|
||||||
@test_util.run_v1_only("b/120545219")
|
@test_util.run_v1_only("train.Supervisor is for v1 only")
|
||||||
def testStandardServicesWithGlobalStep(self):
|
def testStandardServicesWithGlobalStep(self):
|
||||||
logdir = self._test_dir("standard_services_with_global_step")
|
logdir = self._test_dir("standard_services_with_global_step")
|
||||||
# Create a checkpoint.
|
# Create a checkpoint.
|
||||||
@ -835,7 +835,7 @@ class SupervisorTest(test.TestCase):
|
|||||||
# There should be an event file with a version number.
|
# There should be an event file with a version number.
|
||||||
rr = _summary_iterator(logdir)
|
rr = _summary_iterator(logdir)
|
||||||
ev = next(rr)
|
ev = next(rr)
|
||||||
self.assertEquals("brain.Event:2", ev.file_version)
|
self.assertEqual("brain.Event:2", ev.file_version)
|
||||||
ev = next(rr)
|
ev = next(rr)
|
||||||
ev_graph = graph_pb2.GraphDef()
|
ev_graph = graph_pb2.GraphDef()
|
||||||
ev_graph.ParseFromString(ev.graph_def)
|
ev_graph.ParseFromString(ev.graph_def)
|
||||||
@ -849,8 +849,8 @@ class SupervisorTest(test.TestCase):
|
|||||||
ev = next(rr)
|
ev = next(rr)
|
||||||
# It is actually undeterministic whether SessionLog.START gets written
|
# It is actually undeterministic whether SessionLog.START gets written
|
||||||
# before the summary or the checkpoint, but this works when run 10000 times.
|
# before the summary or the checkpoint, but this works when run 10000 times.
|
||||||
self.assertEquals(123, ev.step)
|
self.assertEqual(123, ev.step)
|
||||||
self.assertEquals(event_pb2.SessionLog.START, ev.session_log.status)
|
self.assertEqual(event_pb2.SessionLog.START, ev.session_log.status)
|
||||||
first = next(rr)
|
first = next(rr)
|
||||||
second = next(rr)
|
second = next(rr)
|
||||||
# It is undeterministic whether the value gets written before the checkpoint
|
# It is undeterministic whether the value gets written before the checkpoint
|
||||||
@ -858,17 +858,17 @@ class SupervisorTest(test.TestCase):
|
|||||||
if first.HasField("summary"):
|
if first.HasField("summary"):
|
||||||
self.assertProtoEquals("""value { tag: 'global_step/sec'
|
self.assertProtoEquals("""value { tag: 'global_step/sec'
|
||||||
simple_value: 0.0 }""", first.summary)
|
simple_value: 0.0 }""", first.summary)
|
||||||
self.assertEquals(123, second.step)
|
self.assertEqual(123, second.step)
|
||||||
self.assertEquals(event_pb2.SessionLog.CHECKPOINT,
|
self.assertEqual(event_pb2.SessionLog.CHECKPOINT,
|
||||||
second.session_log.status)
|
second.session_log.status)
|
||||||
else:
|
else:
|
||||||
self.assertEquals(123, first.step)
|
self.assertEqual(123, first.step)
|
||||||
self.assertEquals(event_pb2.SessionLog.CHECKPOINT,
|
self.assertEqual(event_pb2.SessionLog.CHECKPOINT,
|
||||||
first.session_log.status)
|
first.session_log.status)
|
||||||
self.assertProtoEquals("""value { tag: 'global_step/sec'
|
self.assertProtoEquals("""value { tag: 'global_step/sec'
|
||||||
simple_value: 0.0 }""", second.summary)
|
simple_value: 0.0 }""", second.summary)
|
||||||
ev = next(rr)
|
ev = next(rr)
|
||||||
self.assertEquals(event_pb2.SessionLog.STOP, ev.session_log.status)
|
self.assertEqual(event_pb2.SessionLog.STOP, ev.session_log.status)
|
||||||
self.assertRaises(StopIteration, lambda: next(rr))
|
self.assertRaises(StopIteration, lambda: next(rr))
|
||||||
# There should be a checkpoint file with the variable "foo"
|
# There should be a checkpoint file with the variable "foo"
|
||||||
with ops.Graph().as_default(), self.cached_session() as sess:
|
with ops.Graph().as_default(), self.cached_session() as sess:
|
||||||
|
Loading…
Reference in New Issue
Block a user