STT-tensorflow/tensorflow/python/training/supervisor_test.py
Gaurav Jain f618ab4955 Move away from deprecated asserts
- assertEquals -> assertEqual
- assertRaisesRegexp -> assertRegexpMatches
- assertRegexpMatches -> assertRegex

PiperOrigin-RevId: 319118081
Change-Id: Ieb457128522920ab55d6b69a7f244ab798a7d689
2020-06-30 16:10:22 -07:00

925 lines
35 KiB
Python

# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for supervisor.py."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import glob
import os
import shutil
import time
import uuid
from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.core.framework import graph_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.core.util import event_pb2
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import meta_graph
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import io_ops
from tensorflow.python.ops import parsing_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
from tensorflow.python.summary import summary
from tensorflow.python.summary import summary_iterator
from tensorflow.python.summary.writer import writer
from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import input as input_lib
from tensorflow.python.training import saver as saver_lib
from tensorflow.python.training import server_lib
from tensorflow.python.training import session_manager as session_manager_lib
from tensorflow.python.training import supervisor
def _summary_iterator(test_dir):
"""Reads events from test_dir/events.
Args:
test_dir: Name of the test directory.
Returns:
A summary_iterator
"""
event_paths = sorted(glob.glob(os.path.join(test_dir, "event*")))
return summary_iterator.summary_iterator(event_paths[-1])
class SupervisorTest(test.TestCase):
def _test_dir(self, test_name):
test_dir = os.path.join(self.get_temp_dir(), test_name)
if os.path.exists(test_dir):
shutil.rmtree(test_dir)
return test_dir
def _wait_for_glob(self, pattern, timeout_secs, for_checkpoint=True):
"""Wait for a checkpoint file to appear.
Args:
pattern: A string.
timeout_secs: How long to wait for in seconds.
for_checkpoint: whether we're globbing for checkpoints.
"""
end_time = time.time() + timeout_secs
while time.time() < end_time:
if for_checkpoint:
if checkpoint_management.checkpoint_exists(pattern):
return
else:
if len(gfile.Glob(pattern)) >= 1:
return
time.sleep(0.05)
self.assertFalse(True, "Glob never matched any file: %s" % pattern)
# This test does not test much.
def testBasics(self):
logdir = self._test_dir("basics")
with ops.Graph().as_default():
my_op = constant_op.constant(1.0)
sv = supervisor.Supervisor(logdir=logdir)
sess = sv.prepare_or_wait_for_session("")
for _ in xrange(10):
self.evaluate(my_op)
sess.close()
sv.stop()
def testManagedSession(self):
logdir = self._test_dir("managed_session")
with ops.Graph().as_default():
my_op = constant_op.constant(1.0)
sv = supervisor.Supervisor(logdir=logdir)
with sv.managed_session(""):
for _ in xrange(10):
self.evaluate(my_op)
# Supervisor has been stopped.
self.assertTrue(sv.should_stop())
def testManagedSessionUserError(self):
logdir = self._test_dir("managed_user_error")
with ops.Graph().as_default():
my_op = constant_op.constant(1.0)
sv = supervisor.Supervisor(logdir=logdir)
last_step = None
with self.assertRaisesRegex(RuntimeError, "failing here"):
with sv.managed_session("") as sess:
for step in xrange(10):
last_step = step
if step == 1:
raise RuntimeError("failing here")
else:
self.evaluate(my_op)
# Supervisor has been stopped.
self.assertTrue(sv.should_stop())
self.assertEqual(1, last_step)
def testManagedSessionIgnoreOutOfRangeError(self):
logdir = self._test_dir("managed_out_of_range")
with ops.Graph().as_default():
my_op = constant_op.constant(1.0)
sv = supervisor.Supervisor(logdir=logdir)
last_step = None
with sv.managed_session("") as sess:
for step in xrange(10):
last_step = step
if step == 3:
raise errors_impl.OutOfRangeError(my_op.op.node_def, my_op.op,
"all done")
else:
self.evaluate(my_op)
# Supervisor has been stopped. OutOfRangeError was not thrown.
self.assertTrue(sv.should_stop())
self.assertEqual(3, last_step)
def testManagedSessionDoNotKeepSummaryWriter(self):
logdir = self._test_dir("managed_not_keep_summary_writer")
with ops.Graph().as_default():
summary.scalar("c1", constant_op.constant(1))
summary.scalar("c2", constant_op.constant(2))
summary.scalar("c3", constant_op.constant(3))
summ = summary.merge_all()
sv = supervisor.Supervisor(logdir=logdir, summary_op=None)
with sv.managed_session(
"", close_summary_writer=True, start_standard_services=False) as sess:
sv.summary_computed(sess, sess.run(summ))
# Sleep 1.2s to make sure that the next event file has a different name
# than the current one.
time.sleep(1.2)
with sv.managed_session(
"", close_summary_writer=True, start_standard_services=False) as sess:
sv.summary_computed(sess, sess.run(summ))
event_paths = sorted(glob.glob(os.path.join(logdir, "event*")))
self.assertEqual(2, len(event_paths))
# The two event files should have the same contents.
for path in event_paths:
# The summary iterator should report the summary once as we closed the
# summary writer across the 2 sessions.
rr = summary_iterator.summary_iterator(path)
# The first event should list the file_version.
ev = next(rr)
self.assertEqual("brain.Event:2", ev.file_version)
# The next one has the graph and metagraph.
ev = next(rr)
self.assertTrue(ev.graph_def)
ev = next(rr)
self.assertTrue(ev.meta_graph_def)
# The next one should have the values from the summary.
# But only once.
ev = next(rr)
self.assertProtoEquals("""
value { tag: 'c1' simple_value: 1.0 }
value { tag: 'c2' simple_value: 2.0 }
value { tag: 'c3' simple_value: 3.0 }
""", ev.summary)
# The next one should be a stop message if we closed cleanly.
ev = next(rr)
self.assertEqual(event_pb2.SessionLog.STOP, ev.session_log.status)
# We should be done.
with self.assertRaises(StopIteration):
next(rr)
def testManagedSessionKeepSummaryWriter(self):
logdir = self._test_dir("managed_keep_summary_writer")
with ops.Graph().as_default():
summary.scalar("c1", constant_op.constant(1))
summary.scalar("c2", constant_op.constant(2))
summary.scalar("c3", constant_op.constant(3))
summ = summary.merge_all()
sv = supervisor.Supervisor(logdir=logdir)
with sv.managed_session(
"", close_summary_writer=False,
start_standard_services=False) as sess:
sv.summary_computed(sess, sess.run(summ))
with sv.managed_session(
"", close_summary_writer=False,
start_standard_services=False) as sess:
sv.summary_computed(sess, sess.run(summ))
# Now close the summary writer to flush the events.
sv.summary_writer.close()
# The summary iterator should report the summary twice as we reused
# the same summary writer across the 2 sessions.
rr = _summary_iterator(logdir)
# The first event should list the file_version.
ev = next(rr)
self.assertEqual("brain.Event:2", ev.file_version)
# The next one has the graph.
ev = next(rr)
self.assertTrue(ev.graph_def)
ev = next(rr)
self.assertTrue(ev.meta_graph_def)
# The next one should have the values from the summary.
ev = next(rr)
self.assertProtoEquals("""
value { tag: 'c1' simple_value: 1.0 }
value { tag: 'c2' simple_value: 2.0 }
value { tag: 'c3' simple_value: 3.0 }
""", ev.summary)
# The next one should also have the values from the summary.
ev = next(rr)
self.assertProtoEquals("""
value { tag: 'c1' simple_value: 1.0 }
value { tag: 'c2' simple_value: 2.0 }
value { tag: 'c3' simple_value: 3.0 }
""", ev.summary)
# We should be done.
self.assertRaises(StopIteration, lambda: next(rr))
def _csv_data(self, logdir):
# Create a small data file with 3 CSV records.
data_path = os.path.join(logdir, "data.csv")
with open(data_path, "w") as f:
f.write("1,2,3\n")
f.write("4,5,6\n")
f.write("7,8,9\n")
return data_path
def testManagedEndOfInputOneQueue(self):
# Tests that the supervisor finishes without an error when using
# a fixed number of epochs, reading from a single queue.
logdir = self._test_dir("managed_end_of_input_one_queue")
os.makedirs(logdir)
data_path = self._csv_data(logdir)
with ops.Graph().as_default():
# Create an input pipeline that reads the file 3 times.
filename_queue = input_lib.string_input_producer(
[data_path], num_epochs=3)
reader = io_ops.TextLineReader()
_, csv = reader.read(filename_queue)
rec = parsing_ops.decode_csv(csv, record_defaults=[[1], [1], [1]])
sv = supervisor.Supervisor(logdir=logdir)
with sv.managed_session("") as sess:
while not sv.should_stop():
sess.run(rec)
def testManagedEndOfInputTwoQueues(self):
# Tests that the supervisor finishes without an error when using
# a fixed number of epochs, reading from two queues, the second
# one producing a batch from the first one.
logdir = self._test_dir("managed_end_of_input_two_queues")
os.makedirs(logdir)
data_path = self._csv_data(logdir)
with ops.Graph().as_default():
# Create an input pipeline that reads the file 3 times.
filename_queue = input_lib.string_input_producer(
[data_path], num_epochs=3)
reader = io_ops.TextLineReader()
_, csv = reader.read(filename_queue)
rec = parsing_ops.decode_csv(csv, record_defaults=[[1], [1], [1]])
shuff_rec = input_lib.shuffle_batch(rec, 1, 6, 4)
sv = supervisor.Supervisor(logdir=logdir)
with sv.managed_session("") as sess:
while not sv.should_stop():
sess.run(shuff_rec)
def testManagedMainErrorTwoQueues(self):
# Tests that the supervisor correctly raises a main loop
# error even when using multiple queues for input.
logdir = self._test_dir("managed_main_error_two_queues")
os.makedirs(logdir)
data_path = self._csv_data(logdir)
with self.assertRaisesRegex(RuntimeError, "fail at step 3"):
with ops.Graph().as_default():
# Create an input pipeline that reads the file 3 times.
filename_queue = input_lib.string_input_producer(
[data_path], num_epochs=3)
reader = io_ops.TextLineReader()
_, csv = reader.read(filename_queue)
rec = parsing_ops.decode_csv(csv, record_defaults=[[1], [1], [1]])
shuff_rec = input_lib.shuffle_batch(rec, 1, 6, 4)
sv = supervisor.Supervisor(logdir=logdir)
with sv.managed_session("") as sess:
for step in range(9):
if sv.should_stop():
break
elif step == 3:
raise RuntimeError("fail at step 3")
else:
sess.run(shuff_rec)
def testSessionConfig(self):
logdir = self._test_dir("session_config")
with ops.Graph().as_default():
with ops.device("/cpu:1"):
my_op = constant_op.constant([1.0])
sv = supervisor.Supervisor(logdir=logdir)
sess = sv.prepare_or_wait_for_session(
"", config=config_pb2.ConfigProto(device_count={"CPU": 2}))
for _ in xrange(10):
self.evaluate(my_op)
sess.close()
sv.stop()
def testChiefCanWriteEvents(self):
logdir = self._test_dir("can_write")
with ops.Graph().as_default():
summary.scalar("c1", constant_op.constant(1))
summary.scalar("c2", constant_op.constant(2))
summary.scalar("c3", constant_op.constant(3))
summ = summary.merge_all()
sv = supervisor.Supervisor(is_chief=True, logdir=logdir, summary_op=None)
meta_graph_def = meta_graph.create_meta_graph_def()
sess = sv.prepare_or_wait_for_session("")
sv.summary_computed(sess, sess.run(summ))
sess.close()
# Wait to make sure everything is written to file before stopping.
time.sleep(1)
sv.stop()
rr = _summary_iterator(logdir)
# The first event should list the file_version.
ev = next(rr)
self.assertEqual("brain.Event:2", ev.file_version)
# The next one has the graph.
ev = next(rr)
ev_graph = graph_pb2.GraphDef()
ev_graph.ParseFromString(ev.graph_def)
self.assertProtoEquals(sess.graph.as_graph_def(add_shapes=True), ev_graph)
# Stored MetaGraphDef
ev = next(rr)
ev_meta_graph = meta_graph_pb2.MetaGraphDef()
ev_meta_graph.ParseFromString(ev.meta_graph_def)
self.assertProtoEquals(meta_graph_def, ev_meta_graph)
self.assertProtoEquals(
sess.graph.as_graph_def(add_shapes=True), ev_meta_graph.graph_def)
# The next one should have the values from the summary.
ev = next(rr)
self.assertProtoEquals("""
value { tag: 'c1' simple_value: 1.0 }
value { tag: 'c2' simple_value: 2.0 }
value { tag: 'c3' simple_value: 3.0 }
""", ev.summary)
# The next one should be a stop message if we closed cleanly.
ev = next(rr)
self.assertEqual(event_pb2.SessionLog.STOP, ev.session_log.status)
# We should be done.
self.assertRaises(StopIteration, lambda: next(rr))
def testNonChiefCannotWriteEvents(self):
def _summary_computed():
with ops.Graph().as_default():
sv = supervisor.Supervisor(is_chief=False)
sess = sv.prepare_or_wait_for_session("")
summary.scalar("c1", constant_op.constant(1))
summary.scalar("c2", constant_op.constant(2))
summ = summary.merge_all()
sv.summary_computed(sess, sess.run(summ))
def _start_standard_services():
with ops.Graph().as_default():
sv = supervisor.Supervisor(is_chief=False)
sess = sv.prepare_or_wait_for_session("")
sv.start_standard_services(sess)
self.assertRaises(RuntimeError, _summary_computed)
self.assertRaises(RuntimeError, _start_standard_services)
def testNoLogdirButWantSummary(self):
with ops.Graph().as_default():
summary.scalar("c1", constant_op.constant(1))
summary.scalar("c2", constant_op.constant(2))
summary.scalar("c3", constant_op.constant(3))
summ = summary.merge_all()
sv = supervisor.Supervisor(logdir="", summary_op=None)
sess = sv.prepare_or_wait_for_session("")
with self.assertRaisesRegex(RuntimeError, "requires a summary writer"):
sv.summary_computed(sess, sess.run(summ))
@test_util.run_v1_only("train.Supervisor is for v1 only")
def testLogdirButExplicitlyNoSummaryWriter(self):
logdir = self._test_dir("explicit_no_summary_writer")
with ops.Graph().as_default():
variables.VariableV1([1.0], name="foo")
summary.scalar("c1", constant_op.constant(1))
summary.scalar("c2", constant_op.constant(2))
summary.scalar("c3", constant_op.constant(3))
summ = summary.merge_all()
sv = supervisor.Supervisor(logdir=logdir, summary_writer=None)
sess = sv.prepare_or_wait_for_session("")
# Check that a checkpoint is still be generated.
self._wait_for_glob(sv.save_path, 3.0)
# Check that we cannot write a summary
with self.assertRaisesRegex(RuntimeError, "requires a summary writer"):
sv.summary_computed(sess, sess.run(summ))
def testNoLogdirButExplicitSummaryWriter(self):
logdir = self._test_dir("explicit_summary_writer")
with ops.Graph().as_default():
summary.scalar("c1", constant_op.constant(1))
summary.scalar("c2", constant_op.constant(2))
summary.scalar("c3", constant_op.constant(3))
summ = summary.merge_all()
sw = writer.FileWriter(logdir)
sv = supervisor.Supervisor(logdir="", summary_op=None, summary_writer=sw)
meta_graph_def = meta_graph.create_meta_graph_def()
sess = sv.prepare_or_wait_for_session("")
sv.summary_computed(sess, sess.run(summ))
sess.close()
# Wait to make sure everything is written to file before stopping.
time.sleep(1)
sv.stop()
# Check the summary was written to 'logdir'
rr = _summary_iterator(logdir)
# The first event should list the file_version.
ev = next(rr)
self.assertEqual("brain.Event:2", ev.file_version)
# The next one has the graph.
ev = next(rr)
ev_graph = graph_pb2.GraphDef()
ev_graph.ParseFromString(ev.graph_def)
self.assertProtoEquals(sess.graph.as_graph_def(add_shapes=True), ev_graph)
# Stored MetaGraphDef
ev = next(rr)
ev_meta_graph = meta_graph_pb2.MetaGraphDef()
ev_meta_graph.ParseFromString(ev.meta_graph_def)
self.assertProtoEquals(meta_graph_def, ev_meta_graph)
self.assertProtoEquals(
sess.graph.as_graph_def(add_shapes=True), ev_meta_graph.graph_def)
# The next one should have the values from the summary.
ev = next(rr)
self.assertProtoEquals("""
value { tag: 'c1' simple_value: 1.0 }
value { tag: 'c2' simple_value: 2.0 }
value { tag: 'c3' simple_value: 3.0 }
""", ev.summary)
# The next one should be a stop message if we closed cleanly.
ev = next(rr)
self.assertEqual(event_pb2.SessionLog.STOP, ev.session_log.status)
# We should be done.
self.assertRaises(StopIteration, lambda: next(rr))
def testNoLogdirSucceeds(self):
with ops.Graph().as_default():
variables.VariableV1([1.0, 2.0, 3.0])
sv = supervisor.Supervisor(logdir="", summary_op=None)
sess = sv.prepare_or_wait_for_session("")
sess.close()
sv.stop()
def testUseSessionManager(self):
with ops.Graph().as_default():
variables.VariableV1([1.0, 2.0, 3.0])
sm = session_manager_lib.SessionManager()
# Pass in session_manager. The additional init_op is ignored.
sv = supervisor.Supervisor(logdir="", session_manager=sm)
sv.prepare_or_wait_for_session("")
@test_util.run_v1_only("train.Supervisor is for v1 only")
def testInitOp(self):
logdir = self._test_dir("default_init_op")
with ops.Graph().as_default():
v = variables.VariableV1([1.0, 2.0, 3.0])
sv = supervisor.Supervisor(logdir=logdir)
sess = sv.prepare_or_wait_for_session("")
self.assertAllClose([1.0, 2.0, 3.0], sess.run(v))
sv.stop()
@test_util.run_v1_only("train.Supervisor is for v1 only")
def testInitFn(self):
logdir = self._test_dir("default_init_op")
with ops.Graph().as_default():
v = variables.VariableV1([1.0, 2.0, 3.0])
def _init_fn(sess):
sess.run(v.initializer)
sv = supervisor.Supervisor(logdir=logdir, init_op=None, init_fn=_init_fn)
sess = sv.prepare_or_wait_for_session("")
self.assertAllClose([1.0, 2.0, 3.0], sess.run(v))
sv.stop()
@test_util.run_v1_only("train.Supervisor is for v1 only")
def testInitOpWithFeedDict(self):
logdir = self._test_dir("feed_dict_init_op")
with ops.Graph().as_default():
p = array_ops.placeholder(dtypes.float32, shape=(3,))
v = variables.VariableV1(p, name="v")
sv = supervisor.Supervisor(
logdir=logdir,
init_op=variables.global_variables_initializer(),
init_feed_dict={p: [1.0, 2.0, 3.0]})
sess = sv.prepare_or_wait_for_session("")
self.assertAllClose([1.0, 2.0, 3.0], sess.run(v))
sv.stop()
@test_util.run_v1_only("train.Supervisor is for v1 only")
def testReadyForLocalInitOp(self):
server = server_lib.Server.create_local_server()
logdir = self._test_dir("default_ready_for_local_init_op")
uid = uuid.uuid4().hex
def get_session(is_chief):
g = ops.Graph()
with g.as_default():
with ops.device("/job:localhost"):
v = variables.VariableV1(
1, name="default_ready_for_local_init_op_v_" + str(uid))
vadd = v.assign_add(1)
w = variables.VariableV1(
v,
trainable=False,
collections=[ops.GraphKeys.LOCAL_VARIABLES],
name="default_ready_for_local_init_op_w_" + str(uid))
ready_for_local_init_op = variables.report_uninitialized_variables(
variables.global_variables())
sv = supervisor.Supervisor(
logdir=logdir,
is_chief=is_chief,
graph=g,
recovery_wait_secs=1,
init_op=v.initializer,
ready_for_local_init_op=ready_for_local_init_op)
sess = sv.prepare_or_wait_for_session(server.target)
return sv, sess, v, vadd, w
sv0, sess0, v0, _, w0 = get_session(True)
sv1, sess1, _, vadd1, w1 = get_session(False)
self.assertEqual(1, sess0.run(w0))
self.assertEqual(2, sess1.run(vadd1))
self.assertEqual(1, sess1.run(w1))
self.assertEqual(2, sess0.run(v0))
sv0.stop()
sv1.stop()
@test_util.run_v1_only("train.Supervisor is for v1 only")
def testReadyForLocalInitOpRestoreFromCheckpoint(self):
server = server_lib.Server.create_local_server()
logdir = self._test_dir("ready_for_local_init_op_restore")
uid = uuid.uuid4().hex
# Create a checkpoint.
with ops.Graph().as_default():
v = variables.VariableV1(
10.0, name="ready_for_local_init_op_restore_v_" + str(uid))
summary.scalar("ready_for_local_init_op_restore_v_" + str(uid), v)
sv = supervisor.Supervisor(logdir=logdir)
sv.prepare_or_wait_for_session(server.target)
save_path = sv.save_path
self._wait_for_glob(save_path, 3.0)
self._wait_for_glob(
os.path.join(logdir, "*events*"), 3.0, for_checkpoint=False)
# Wait to make sure everything is written to file before stopping.
time.sleep(1)
sv.stop()
def get_session(is_chief):
g = ops.Graph()
with g.as_default():
with ops.device("/job:localhost"):
v = variables.VariableV1(
1.0, name="ready_for_local_init_op_restore_v_" + str(uid))
vadd = v.assign_add(1)
w = variables.VariableV1(
v,
trainable=False,
collections=[ops.GraphKeys.LOCAL_VARIABLES],
name="ready_for_local_init_op_restore_w_" + str(uid))
ready_for_local_init_op = variables.report_uninitialized_variables(
variables.global_variables())
sv = supervisor.Supervisor(
logdir=logdir,
is_chief=is_chief,
graph=g,
recovery_wait_secs=1,
ready_for_local_init_op=ready_for_local_init_op)
sess = sv.prepare_or_wait_for_session(server.target)
return sv, sess, v, vadd, w
sv0, sess0, v0, _, w0 = get_session(True)
sv1, sess1, _, vadd1, w1 = get_session(False)
self.assertEqual(10, sess0.run(w0))
self.assertEqual(11, sess1.run(vadd1))
self.assertEqual(10, sess1.run(w1))
self.assertEqual(11, sess0.run(v0))
sv0.stop()
sv1.stop()
def testLocalInitOp(self):
logdir = self._test_dir("default_local_init_op")
with ops.Graph().as_default():
# A local variable.
v = variables.VariableV1(
[1.0, 2.0, 3.0],
trainable=False,
collections=[ops.GraphKeys.LOCAL_VARIABLES])
# An entity which is initialized through a TABLE_INITIALIZER.
w = variables.VariableV1([4, 5, 6], trainable=False, collections=[])
ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, w.initializer)
# This shouldn't add a variable to the VARIABLES collection responsible
# for variables that are saved/restored from checkpoints.
self.assertEqual(len(variables.global_variables()), 0)
# Suppress normal variable inits to make sure the local one is
# initialized via local_init_op.
sv = supervisor.Supervisor(logdir=logdir, init_op=None)
sess = sv.prepare_or_wait_for_session("")
self.assertAllClose([1.0, 2.0, 3.0], sess.run(v))
self.assertAllClose([4, 5, 6], sess.run(w))
sv.stop()
def testLocalInitOpForNonChief(self):
logdir = self._test_dir("default_local_init_op_non_chief")
with ops.Graph().as_default():
with ops.device("/job:localhost"):
# A local variable.
v = variables.VariableV1(
[1.0, 2.0, 3.0],
trainable=False,
collections=[ops.GraphKeys.LOCAL_VARIABLES])
# This shouldn't add a variable to the VARIABLES collection responsible
# for variables that are saved/restored from checkpoints.
self.assertEqual(len(variables.global_variables()), 0)
# Suppress normal variable inits to make sure the local one is
# initialized via local_init_op.
sv = supervisor.Supervisor(logdir=logdir, init_op=None, is_chief=False)
sess = sv.prepare_or_wait_for_session("")
self.assertAllClose([1.0, 2.0, 3.0], sess.run(v))
sv.stop()
def testInitOpFails(self):
server = server_lib.Server.create_local_server()
logdir = self._test_dir("default_init_op_fails")
with ops.Graph().as_default():
v = variables.VariableV1([1.0, 2.0, 3.0], name="v")
variables.VariableV1([4.0, 5.0, 6.0], name="w")
# w will not be initialized.
sv = supervisor.Supervisor(logdir=logdir, init_op=v.initializer)
with self.assertRaisesRegex(RuntimeError, "Variables not initialized: w"):
sv.prepare_or_wait_for_session(server.target)
def testInitOpFailsForTransientVariable(self):
server = server_lib.Server.create_local_server()
logdir = self._test_dir("default_init_op_fails_for_local_variable")
with ops.Graph().as_default():
v = variables.VariableV1(
[1.0, 2.0, 3.0],
name="v",
collections=[ops.GraphKeys.LOCAL_VARIABLES])
variables.VariableV1(
[1.0, 2.0, 3.0],
name="w",
collections=[ops.GraphKeys.LOCAL_VARIABLES])
# w will not be initialized.
sv = supervisor.Supervisor(logdir=logdir, local_init_op=v.initializer)
with self.assertRaisesRegex(RuntimeError, "Variables not initialized: w"):
sv.prepare_or_wait_for_session(server.target)
@test_util.run_v1_only("train.Supervisor is for v1 only")
def testSetupFail(self):
logdir = self._test_dir("setup_fail")
with ops.Graph().as_default():
variables.VariableV1([1.0, 2.0, 3.0], name="v")
with self.assertRaisesRegex(ValueError, "must have their device set"):
supervisor.Supervisor(logdir=logdir, is_chief=False)
with ops.Graph().as_default(), ops.device("/job:ps"):
variables.VariableV1([1.0, 2.0, 3.0], name="v")
supervisor.Supervisor(logdir=logdir, is_chief=False)
@test_util.run_v1_only("train.Supervisor is for v1 only")
def testDefaultGlobalStep(self):
logdir = self._test_dir("default_global_step")
with ops.Graph().as_default():
variables.VariableV1(287, name="global_step")
sv = supervisor.Supervisor(logdir=logdir)
sess = sv.prepare_or_wait_for_session("")
self.assertEqual(287, sess.run(sv.global_step))
sv.stop()
@test_util.run_v1_only("train.Supervisor is for v1 only")
def testRestoreFromMetaGraph(self):
logdir = self._test_dir("restore_from_meta_graph")
with ops.Graph().as_default():
variables.VariableV1(1, name="v0")
sv = supervisor.Supervisor(logdir=logdir)
sess = sv.prepare_or_wait_for_session("")
filename = sv.saver.save(sess, sv.save_path)
sv.stop()
# Create a new Graph and Supervisor and recover.
with ops.Graph().as_default():
new_saver = saver_lib.import_meta_graph(".".join([filename, "meta"]))
self.assertIsNotNone(new_saver)
sv2 = supervisor.Supervisor(logdir=logdir, saver=new_saver)
sess = sv2.prepare_or_wait_for_session("")
self.assertEqual(1, sess.run("v0:0"))
sv2.saver.save(sess, sv2.save_path)
sv2.stop()
# This test is based on the fact that the standard services start
# right away and get to run once before sv.stop() returns.
# We still sleep a bit to make the test robust.
@test_util.run_v1_only("train.Supervisor is for v1 only")
def testStandardServicesWithoutGlobalStep(self):
logdir = self._test_dir("standard_services_without_global_step")
# Create a checkpoint.
with ops.Graph().as_default():
v = variables.VariableV1([1.0], name="foo")
summary.scalar("v", v[0])
sv = supervisor.Supervisor(logdir=logdir)
meta_graph_def = meta_graph.create_meta_graph_def(
saver_def=sv.saver.saver_def)
sess = sv.prepare_or_wait_for_session("")
save_path = sv.save_path
self._wait_for_glob(save_path, 3.0)
self._wait_for_glob(
os.path.join(logdir, "*events*"), 3.0, for_checkpoint=False)
# Wait to make sure everything is written to file before stopping.
time.sleep(1)
sv.stop()
# There should be an event file with a version number.
rr = _summary_iterator(logdir)
ev = next(rr)
self.assertEqual("brain.Event:2", ev.file_version)
ev = next(rr)
ev_graph = graph_pb2.GraphDef()
ev_graph.ParseFromString(ev.graph_def)
self.assertProtoEquals(sess.graph.as_graph_def(add_shapes=True), ev_graph)
# Stored MetaGraphDef
ev = next(rr)
ev_meta_graph = meta_graph_pb2.MetaGraphDef()
ev_meta_graph.ParseFromString(ev.meta_graph_def)
self.assertProtoEquals(meta_graph_def, ev_meta_graph)
self.assertProtoEquals(
sess.graph.as_graph_def(add_shapes=True), ev_meta_graph.graph_def)
ev = next(rr)
self.assertProtoEquals("value { tag: 'v' simple_value: 1.0 }", ev.summary)
ev = next(rr)
self.assertEqual(event_pb2.SessionLog.STOP, ev.session_log.status)
self.assertRaises(StopIteration, lambda: next(rr))
# There should be a checkpoint file with the variable "foo"
with ops.Graph().as_default(), self.cached_session() as sess:
v = variables.VariableV1([10.10], name="foo")
sav = saver_lib.Saver([v])
sav.restore(sess, save_path)
self.assertEqual(1.0, self.evaluate(v)[0])
# Same as testStandardServicesNoGlobalStep but with a global step.
# We should get a summary about the step time.
@test_util.run_v1_only("train.Supervisor is for v1 only")
def testStandardServicesWithGlobalStep(self):
logdir = self._test_dir("standard_services_with_global_step")
# Create a checkpoint.
with ops.Graph().as_default():
v = variables.VariableV1([123], name="global_step")
sv = supervisor.Supervisor(logdir=logdir)
meta_graph_def = meta_graph.create_meta_graph_def(
saver_def=sv.saver.saver_def)
sess = sv.prepare_or_wait_for_session("")
# This is where the checkpoint will appear, with step number 123.
save_path = "%s-123" % sv.save_path
self._wait_for_glob(save_path, 3.0)
self._wait_for_glob(
os.path.join(logdir, "*events*"), 3.0, for_checkpoint=False)
# Wait to make sure everything is written to file before stopping.
time.sleep(1)
sv.stop()
# There should be an event file with a version number.
rr = _summary_iterator(logdir)
ev = next(rr)
self.assertEqual("brain.Event:2", ev.file_version)
ev = next(rr)
ev_graph = graph_pb2.GraphDef()
ev_graph.ParseFromString(ev.graph_def)
self.assertProtoEquals(sess.graph.as_graph_def(add_shapes=True), ev_graph)
ev = next(rr)
ev_meta_graph = meta_graph_pb2.MetaGraphDef()
ev_meta_graph.ParseFromString(ev.meta_graph_def)
self.assertProtoEquals(meta_graph_def, ev_meta_graph)
self.assertProtoEquals(
sess.graph.as_graph_def(add_shapes=True), ev_meta_graph.graph_def)
ev = next(rr)
# It is actually undeterministic whether SessionLog.START gets written
# before the summary or the checkpoint, but this works when run 10000 times.
self.assertEqual(123, ev.step)
self.assertEqual(event_pb2.SessionLog.START, ev.session_log.status)
first = next(rr)
second = next(rr)
# It is undeterministic whether the value gets written before the checkpoint
# since they are on separate threads, so we check for both conditions.
if first.HasField("summary"):
self.assertProtoEquals("""value { tag: 'global_step/sec'
simple_value: 0.0 }""", first.summary)
self.assertEqual(123, second.step)
self.assertEqual(event_pb2.SessionLog.CHECKPOINT,
second.session_log.status)
else:
self.assertEqual(123, first.step)
self.assertEqual(event_pb2.SessionLog.CHECKPOINT,
first.session_log.status)
self.assertProtoEquals("""value { tag: 'global_step/sec'
simple_value: 0.0 }""", second.summary)
ev = next(rr)
self.assertEqual(event_pb2.SessionLog.STOP, ev.session_log.status)
self.assertRaises(StopIteration, lambda: next(rr))
# There should be a checkpoint file with the variable "foo"
with ops.Graph().as_default(), self.cached_session() as sess:
v = variables.VariableV1([-12], name="global_step")
sav = saver_lib.Saver([v])
sav.restore(sess, save_path)
self.assertEqual(123, self.evaluate(v)[0])
def testNoQueueRunners(self):
with ops.Graph().as_default(), self.cached_session() as sess:
sv = supervisor.Supervisor(logdir=self._test_dir("no_queue_runners"))
self.assertEqual(0, len(sv.start_queue_runners(sess)))
sv.stop()
def testPrepareSessionAfterStopForChief(self):
logdir = self._test_dir("prepare_after_stop_chief")
with ops.Graph().as_default():
sv = supervisor.Supervisor(logdir=logdir, is_chief=True)
# Create a first session and then stop.
sess = sv.prepare_or_wait_for_session("")
sv.stop()
sess.close()
self.assertTrue(sv.should_stop())
# Now create a second session and test that we don't stay stopped, until
# we ask to stop again.
sess2 = sv.prepare_or_wait_for_session("")
self.assertFalse(sv.should_stop())
sv.stop()
sess2.close()
self.assertTrue(sv.should_stop())
def testPrepareSessionAfterStopForNonChief(self):
logdir = self._test_dir("prepare_after_stop_nonchief")
with ops.Graph().as_default():
sv = supervisor.Supervisor(logdir=logdir, is_chief=False)
# Create a first session and then stop.
sess = sv.prepare_or_wait_for_session("")
sv.stop()
sess.close()
self.assertTrue(sv.should_stop())
# Now create a second session and test that we don't stay stopped, until
# we ask to stop again.
sess2 = sv.prepare_or_wait_for_session("")
self.assertFalse(sv.should_stop())
sv.stop()
sess2.close()
self.assertTrue(sv.should_stop())
if __name__ == "__main__":
test.main()