STT-tensorflow/tensorflow/python/training/monitored_session_test.py
2020-02-11 15:09:21 +09:00

2317 lines
93 KiB
Python

# pylint: disable=g-bad-file-header
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for monitored_session."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import glob
import os
import shutil
import sys
import threading
import time
import traceback

from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import debug_pb2
from tensorflow.python.client import session as session_lib
from tensorflow.python.distribute import collective_all_reduce_strategy
from tensorflow.python.distribute import distribute_coordinator
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.summary import summary
from tensorflow.python.training import basic_session_run_hooks
from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import coordinator
from tensorflow.python.training import monitored_session
from tensorflow.python.training import saver as saver_lib
from tensorflow.python.training import session_run_hook
from tensorflow.python.training import summary_io
from tensorflow.python.training import training_util
def latest_summaries(base_dir):
  """Return summary events from the newest events file under `base_dir`.

  Args:
    base_dir: Directory to scan for `events.*` files.

  Returns:
    A list of events carrying a `summary` field; empty if no events file
    exists.
  """
  candidates = sorted(glob.glob(os.path.join(base_dir, 'events.*')))
  if not candidates:
    return []
  events = summary_io.summary_iterator(candidates[-1])
  return [event for event in events if event.HasField('summary')]
class ScaffoldTest(test.TestCase):
  """Scaffold tests."""

  def test_nothing_created_before_finalize(self):
    """All deferred Scaffold fields stay None until finalize() is called."""
    with ops.Graph().as_default():
      scaffold = monitored_session.Scaffold()
      self.assertEqual(None, scaffold.init_op)
      self.assertEqual(None, scaffold.init_feed_dict)
      self.assertEqual(None, scaffold.init_fn)
      self.assertEqual(None, scaffold.ready_op)
      self.assertEqual(None, scaffold.ready_for_local_init_op)
      self.assertEqual(None, scaffold.local_init_op)
      self.assertEqual(None, scaffold.saver)

  def test_defaults_empty_graph(self):
    """finalize() builds default init/ready ops and a saver; running them

    initializes global then local variables.
    """
    with ops.Graph().as_default():
      scaffold = monitored_session.Scaffold()
      variables.VariableV1(1, name='my_var')
      variables.VariableV1(
          2, name='my_local_var', collections=[ops.GraphKeys.LOCAL_VARIABLES])
      scaffold.finalize()
      self.assertTrue(isinstance(scaffold.init_op, ops.Operation))
      self.assertEqual(None, scaffold.init_feed_dict)
      self.assertEqual(None, scaffold.init_fn)
      self.assertTrue(isinstance(scaffold.ready_op, ops.Tensor))
      self.assertTrue(isinstance(scaffold.ready_for_local_init_op, ops.Tensor))
      self.assertTrue(isinstance(scaffold.local_init_op, ops.Operation))
      self.assertEqual(None, scaffold.local_init_feed_dict)
      self.assertTrue(isinstance(scaffold.saver, saver_lib.Saver))
      with self.cached_session() as sess:
        # Before any init: both variables show up as not-ready.
        self.assertItemsEqual([b'my_var', b'my_local_var'],
                              sess.run(scaffold.ready_op))
        self.assertItemsEqual([b'my_var'],
                              sess.run(scaffold.ready_for_local_init_op))
        sess.run(scaffold.init_op)
        # Global init satisfies ready_for_local_init_op ...
        self.assertEqual(0, len(sess.run(scaffold.ready_for_local_init_op)))
        sess.run(scaffold.local_init_op)
        # ... and local init satisfies the full ready_op.
        self.assertEqual(0, len(sess.run(scaffold.ready_op)))

  def test_defaults_no_variables(self):
    """Defaults are still created for a graph containing no variables."""
    with ops.Graph().as_default():
      scaffold = monitored_session.Scaffold()
      constant_op.constant(1, name='my_const')
      scaffold.finalize()
      self.assertTrue(isinstance(scaffold.init_op, ops.Operation))
      self.assertEqual(None, scaffold.init_feed_dict)
      self.assertEqual(None, scaffold.init_fn)
      self.assertTrue(isinstance(scaffold.ready_op, ops.Tensor))
      self.assertTrue(isinstance(scaffold.ready_for_local_init_op, ops.Tensor))
      self.assertTrue(isinstance(scaffold.local_init_op, ops.Operation))
      self.assertEqual(None, scaffold.local_init_feed_dict)
      self.assertTrue(isinstance(scaffold.saver, saver_lib.Saver))

  def test_caches_values(self):
    """Two Scaffolds finalized on the same graph share the same default ops."""
    with ops.Graph().as_default():
      variables.VariableV1([1])
      scaffold1 = monitored_session.Scaffold()
      scaffold1.finalize()
      scaffold2 = monitored_session.Scaffold()
      scaffold2.finalize()
      self.assertEqual(scaffold1.init_op, scaffold2.init_op)
      self.assertEqual(scaffold1.ready_op, scaffold2.ready_op)
      self.assertEqual(scaffold1.ready_for_local_init_op,
                       scaffold2.ready_for_local_init_op)
      self.assertEqual(scaffold1.local_init_op, scaffold2.local_init_op)
      self.assertEqual(scaffold1.saver, scaffold2.saver)

  def test_raise_error_if_more_than_one_cached_item(self):
    """finalize() refuses a SAVERS collection with more than one saver."""
    with ops.Graph().as_default():
      variables.VariableV1([1])
      ops.add_to_collection(ops.GraphKeys.SAVERS, saver_lib.Saver())
      ops.add_to_collection(ops.GraphKeys.SAVERS, saver_lib.Saver())
      with self.assertRaisesRegexp(RuntimeError, 'More than one item'):
        monitored_session.Scaffold().finalize()

  def test_uses_passed_values(self):
    """Explicitly supplied ops/values are kept verbatim by finalize()."""
    with ops.Graph().as_default():
      variables.VariableV1([1])
      saver = saver_lib.Saver()
      scaffold = monitored_session.Scaffold(
          init_op=2,
          init_feed_dict=3,
          init_fn=lambda scaffold, sess: 4,
          ready_op=5,
          ready_for_local_init_op=6,
          local_init_op=7,
          local_init_feed_dict=8,
          saver=saver)
      scaffold.finalize()
      self.assertEqual(2, scaffold.init_op)
      self.assertEqual(3, scaffold.init_feed_dict)
      self.assertTrue(callable(scaffold.init_fn))
      self.assertEqual(5, scaffold.ready_op)
      self.assertEqual(6, scaffold.ready_for_local_init_op)
      self.assertEqual(7, scaffold.local_init_op)
      self.assertEqual(8, scaffold.local_init_feed_dict)
      self.assertEqual(saver, scaffold.saver)

  def test_graph_is_finalized(self):
    """After Scaffold.finalize() the graph rejects new ops."""
    with ops.Graph().as_default():
      variables.VariableV1([1])
      monitored_session.Scaffold().finalize()
      with self.assertRaisesRegexp(RuntimeError,
                                   'Graph is finalized and cannot be modified'):
        constant_op.constant([0])

  def test_new_scaffold_from_default_scaffold(self):
    """copy_from_scaffold with a default Scaffold: explicit args win."""
    scaffold1 = monitored_session.Scaffold()
    with ops.Graph().as_default():
      variables.VariableV1([1])
      saver = saver_lib.Saver()
      scaffold2 = monitored_session.Scaffold(
          init_op=2,
          init_feed_dict=3,
          init_fn=lambda scaffold, sess: 4,
          ready_op=5,
          ready_for_local_init_op=6,
          local_init_op=7,
          local_init_feed_dict=8,
          saver=saver,
          copy_from_scaffold=scaffold1)
      scaffold2.finalize()
      self.assertEqual(2, scaffold2.init_op)
      self.assertEqual(3, scaffold2.init_feed_dict)
      self.assertTrue(callable(scaffold2.init_fn))
      self.assertEqual(5, scaffold2.ready_op)
      self.assertEqual(6, scaffold2.ready_for_local_init_op)
      self.assertEqual(7, scaffold2.local_init_op)
      self.assertEqual(8, scaffold2.local_init_feed_dict)
      self.assertEqual(saver, scaffold2.saver)

  def test_new_scaffold_from_existing_scaffold(self):
    """copy_from_scaffold with a populated Scaffold: explicit args still win."""
    with ops.Graph().as_default():
      variables.VariableV1([1])
      saver = saver_lib.Saver()
      scaffold1 = monitored_session.Scaffold(
          init_op=2,
          init_feed_dict=3,
          init_fn=lambda scaffold, sess: 4,
          ready_op=5,
          ready_for_local_init_op=6,
          local_init_op=7,
          local_init_feed_dict=8,
          saver=saver)
      scaffold2 = monitored_session.Scaffold(
          init_op=4,
          init_feed_dict=6,
          init_fn=lambda scaffold, sess: 8,
          ready_op=10,
          ready_for_local_init_op=12,
          local_init_op=14,
          local_init_feed_dict=15,
          saver=saver,
          copy_from_scaffold=scaffold1)
      scaffold2.finalize()
      self.assertEqual(4, scaffold2.init_op)
      self.assertEqual(6, scaffold2.init_feed_dict)
      self.assertTrue(callable(scaffold2.init_fn))
      self.assertEqual(10, scaffold2.ready_op)
      self.assertEqual(12, scaffold2.ready_for_local_init_op)
      self.assertEqual(14, scaffold2.local_init_op)
      self.assertEqual(15, scaffold2.local_init_feed_dict)
      self.assertEqual(saver, scaffold2.saver)

  def test_copy_from_scaffold_is_scaffold(self):
    """copy_from_scaffold must be a Scaffold instance, not an arbitrary value."""
    with ops.Graph().as_default():
      with self.assertRaisesRegexp(
          TypeError, 'copy_from_scaffold is not a Scaffold instance'):
        monitored_session.Scaffold(copy_from_scaffold=1)
def _test_dir(temp_dir, test_name):
"""Create an empty dir to use for tests.
Args:
temp_dir: Tmp directory path.
test_name: Name of the test.
Returns:
Absolute path to the test directory.
"""
test_dir = os.path.join(temp_dir, test_name)
if os.path.isdir(test_dir):
for f in glob.glob('%s/*' % test_dir):
os.remove(f)
else:
os.makedirs(test_dir)
return test_dir
class FakeHook(session_run_hook.SessionRunHook):
  """Session hook that tallies and records every callback it receives.

  Tests inspect `call_counter` to verify how often each hook method ran,
  `last_run_context`/`last_run_values` for the most recent arguments, and can
  flip `should_stop` to make `after_run` request a stop.
  """

  def __init__(self):
    # When True, after_run() asks the run context to stop.
    self.should_stop = False
    # Value handed back from before_run().
    self.request = None
    self.call_counter = collections.Counter()
    self.last_run_context = None
    self.last_run_values = None

  def _record(self, callback_name):
    """Tally one invocation of the named callback."""
    self.call_counter[callback_name] += 1

  def begin(self):
    self._record('begin')

  def after_create_session(self, session, coord):  # pylint: disable=unused-argument
    self._record('after_create_session')

  def before_run(self, run_context):
    self._record('before_run')
    self.last_run_context = run_context
    return self.request

  def after_run(self, run_context, run_values):
    self._record('after_run')
    self.last_run_values = run_values
    if self.should_stop:
      run_context.request_stop()

  def end(self, session):
    self._record('end')
class MonitoredTrainingSessionTest(test.TestCase):
  """Tests MonitoredTrainingSession."""

  def test_saving_restoring_checkpoint(self):
    """A chief session checkpoints; a restart restores the global step."""
    logdir = _test_dir(self.get_temp_dir(), 'test_saving_restoring_checkpoint')
    with ops.Graph().as_default():
      gstep = training_util.get_or_create_global_step()
      do_step = state_ops.assign_add(gstep, 1)
      with monitored_session.MonitoredTrainingSession(
          is_chief=True, checkpoint_dir=logdir) as session:
        self.assertEqual(0, session.run(gstep))
        self.assertEqual(1, session.run(do_step))
        self.assertEqual(2, session.run(do_step))
      # A restart will find the checkpoint and recover automatically.
      with monitored_session.MonitoredTrainingSession(
          is_chief=True, checkpoint_dir=logdir) as session:
        self.assertEqual(2, session.run(gstep))

  def test_save_checkpoint_steps(self):
    """save_checkpoint_steps triggers a checkpoint the restart can recover."""
    logdir = _test_dir(self.get_temp_dir(), 'test_save_checkpoint_steps')
    with ops.Graph().as_default():
      gstep = training_util.get_or_create_global_step()
      new_gstep = state_ops.assign_add(gstep, 1)
      with monitored_session.MonitoredTrainingSession(
          is_chief=True,
          checkpoint_dir=logdir,
          save_checkpoint_steps=100,
          log_step_count_steps=10) as session:
        for _ in range(100):
          session.run(new_gstep)
      # A restart will find the checkpoint and recover automatically.
      with monitored_session.MonitoredTrainingSession(
          is_chief=True, checkpoint_dir=logdir) as session:
        self.assertEqual(100, session.run(gstep))

  def test_save_checkpoint_secs(self):
    """save_checkpoint_secs triggers time-based checkpoints."""
    logdir = _test_dir(self.get_temp_dir(), 'test_save_checkpoint_secs')
    with ops.Graph().as_default():
      gstep = training_util.get_or_create_global_step()
      new_gstep = state_ops.assign_add(gstep, 1)
      with monitored_session.MonitoredTrainingSession(
          is_chief=True,
          checkpoint_dir=logdir,
          save_checkpoint_secs=0.1,
          log_step_count_steps=10) as session:
        session.run(new_gstep)
        # Sleep past the save interval so the next run()s are checkpointed.
        time.sleep(0.2)
        for _ in range(10):
          session.run(new_gstep)
      # A restart will find the checkpoint and recover automatically.
      with monitored_session.MonitoredTrainingSession(
          is_chief=True, checkpoint_dir=logdir) as session:
        self.assertEqual(11, session.run(gstep))

  def test_summaries_steps(self):
    """save_summaries_steps writes both custom and step/sec summaries."""
    logdir = _test_dir(self.get_temp_dir(), 'test_summaries_steps')
    with ops.Graph().as_default():
      gstep = training_util.get_or_create_global_step()
      new_gstep = state_ops.assign_add(gstep, 1)
      summary.scalar('my_summary_tag', new_gstep * 2)
      with monitored_session.MonitoredTrainingSession(
          is_chief=True,
          checkpoint_dir=logdir,
          save_summaries_steps=100,
          log_step_count_steps=10) as session:
        for _ in range(101):
          session.run(new_gstep)
    summaries = latest_summaries(logdir)
    tags = [s.summary.value[0].tag for s in summaries]
    self.assertIn('my_summary_tag', tags)
    self.assertIn('global_step/sec', tags)

  def test_summaries_secs(self):
    """save_summaries_secs writes summaries on a time interval."""
    logdir = _test_dir(self.get_temp_dir(), 'test_summaries_secs')
    with ops.Graph().as_default():
      gstep = training_util.get_or_create_global_step()
      new_gstep = state_ops.assign_add(gstep, 1)
      summary.scalar('my_summary_tag', new_gstep * 2)
      with monitored_session.MonitoredTrainingSession(
          is_chief=True,
          checkpoint_dir=logdir,
          save_summaries_steps=None,
          save_summaries_secs=0.1,
          log_step_count_steps=10) as session:
        session.run(new_gstep)
        # Sleep past the save interval so later run()s emit summaries.
        time.sleep(0.2)
        for _ in range(101):
          session.run(new_gstep)
    summaries = latest_summaries(logdir)
    tags = [s.summary.value[0].tag for s in summaries]
    self.assertIn('my_summary_tag', tags)
    self.assertIn('global_step/sec', tags)

  def test_custom_saving(self):
    """With saving disabled (save_checkpoint_secs=0), chief-only hooks still

    run, and a restart starts from scratch.
    """
    logdir = _test_dir(self.get_temp_dir(), 'test_saving_restoring_checkpoint')
    fake_hook = FakeHook()
    with ops.Graph().as_default():
      gstep = training_util.get_or_create_global_step()
      do_step = state_ops.assign_add(gstep, 1)
      with monitored_session.MonitoredTrainingSession(
          is_chief=True,
          checkpoint_dir=logdir,
          chief_only_hooks=[fake_hook],
          save_checkpoint_secs=0) as session:
        self.assertEqual(0, session.run(gstep))
        self.assertEqual(1, session.run(do_step))
        self.assertEqual(2, session.run(do_step))
      # Check whether custom hook called or not
      self.assertEqual(1, fake_hook.call_counter['begin'])
      # A restart will not find the checkpoint, since we didn't save.
      with monitored_session.MonitoredTrainingSession(
          is_chief=True, checkpoint_dir=logdir) as session:
        self.assertEqual(0, session.run(gstep))

  def test_save_graph_def(self):
    """save_graph_def=True writes graph.pbtxt and .meta files alongside

    checkpoints.
    """
    logdir = _test_dir(self.get_temp_dir(), 'test_save_graph_def')
    with ops.Graph().as_default():
      gstep = training_util.get_or_create_global_step()
      new_gstep = state_ops.assign_add(gstep, 1)
      with monitored_session.MonitoredTrainingSession(
          is_chief=True,
          checkpoint_dir=logdir,
          save_checkpoint_steps=1,
          save_graph_def=True) as session:
        self.assertIn('graph.pbtxt', os.listdir(logdir))
        self.assertLen(glob.glob(os.path.join(logdir, '*.meta')), 1)
        session.run(new_gstep)
        self.assertLen(glob.glob(os.path.join(logdir, '*.meta')), 2)

  def test_save_graph_def_false(self):
    """save_graph_def=False suppresses graph.pbtxt and .meta files."""
    logdir = _test_dir(self.get_temp_dir(), 'test_save_graph_def')
    with ops.Graph().as_default():
      gstep = training_util.get_or_create_global_step()
      new_gstep = state_ops.assign_add(gstep, 1)
      with monitored_session.MonitoredTrainingSession(
          is_chief=True,
          checkpoint_dir=logdir,
          save_checkpoint_steps=1,
          save_graph_def=False) as session:
        self.assertNotIn('graph.pbtxt', os.listdir(logdir))
        self.assertEmpty(glob.glob(os.path.join(logdir, '*.meta')))
        session.run(new_gstep)
        self.assertEmpty(glob.glob(os.path.join(logdir, '*.meta')))
class MockExtended(object):
  """Minimal stand-in for a distribution strategy's `extended` object.

  Carries only the flags that MonitoredTrainingSession consults when
  deciding on init, checkpointing, and summary saving.
  """

  def __init__(self, between_graph, should_init, should_checkpoint,
               should_save_summary):
    (self.experimental_between_graph, self.experimental_should_init,
     self.should_checkpoint, self.should_save_summary) = (
         between_graph, should_init, should_checkpoint, should_save_summary)
class MockStrategy(object):
  """Minimal stand-in for a distribution strategy.

  Exposes only an `extended` attribute built from the given flags.
  """

  def __init__(self,
               between_graph=False,
               should_init=True,
               should_checkpoint=None,
               should_save_summary=None):
    extended = MockExtended(between_graph, should_init, should_checkpoint,
                            should_save_summary)
    self.extended = extended
class MonitoredTrainingSessionWithDistributeCoordinatorTest(test.TestCase):
  """Test distribute coordinator controls summary saving and checkpointing."""

  def test_summary_hook_enabled(self):
    """should_save_summary=True: summaries are written as usual."""
    context = distribute_coordinator._WorkerContext(
        MockStrategy(should_save_summary=True), None, None, None)

    logdir = _test_dir(self.get_temp_dir(), 'test_summaries_enabled')
    with ops.Graph().as_default():
      gstep = training_util.get_or_create_global_step()
      new_gstep = state_ops.assign_add(gstep, 1)
      summary.scalar('my_summary_tag', new_gstep * 2)
      with context, monitored_session.MonitoredTrainingSession(
          checkpoint_dir=logdir,
          save_summaries_steps=100,
          log_step_count_steps=10) as session:
        for _ in range(101):
          session.run(new_gstep)

    summaries = latest_summaries(logdir)
    tags = [s.summary.value[0].tag for s in summaries]
    self.assertIn('my_summary_tag', tags)
    self.assertIn('global_step/sec', tags)

  def test_summary_hook_disabled(self):
    """should_save_summary=False: the summary hook is suppressed entirely."""
    context = distribute_coordinator._WorkerContext(
        MockStrategy(should_save_summary=False), None, None, None)

    logdir = _test_dir(self.get_temp_dir(), 'test_summaries_disabled')
    with ops.Graph().as_default():
      gstep = training_util.get_or_create_global_step()
      new_gstep = state_ops.assign_add(gstep, 1)
      summary.scalar('my_summary_tag', new_gstep * 2)
      with context, monitored_session.MonitoredTrainingSession(
          checkpoint_dir=logdir,
          save_summaries_steps=100,
          log_step_count_steps=10) as session:
        for _ in range(101):
          session.run(new_gstep)

    # No summary is saved.
    summaries = latest_summaries(logdir)
    self.assertEqual(len(summaries), 0)

  def test_checkpoint_hook_enabled(self):
    """should_checkpoint=True: a checkpoint is written and recoverable."""
    context = distribute_coordinator._WorkerContext(
        MockStrategy(should_checkpoint=True), None, None, None)

    logdir = _test_dir(self.get_temp_dir(), 'test_save_checkpoint_enabled')
    with ops.Graph().as_default():
      gstep = training_util.get_or_create_global_step()
      new_gstep = state_ops.assign_add(gstep, 1)
      with context, monitored_session.MonitoredTrainingSession(
          checkpoint_dir=logdir,
          save_checkpoint_steps=100,
          log_step_count_steps=10) as session:
        for _ in range(100):
          session.run(new_gstep)

      # A restart will find the checkpoint and recover automatically.
      with monitored_session.MonitoredTrainingSession(
          is_chief=True, checkpoint_dir=logdir) as session:
        self.assertEqual(100, session.run(gstep))

  def test_checkpoint_hook_disabled(self):
    """should_checkpoint=False: no checkpoint is written at all."""
    context = distribute_coordinator._WorkerContext(
        MockStrategy(should_checkpoint=False), None, None, None)

    logdir = _test_dir(self.get_temp_dir(), 'test_save_checkpoint_disabled')
    with ops.Graph().as_default():
      gstep = training_util.get_or_create_global_step()
      new_gstep = state_ops.assign_add(gstep, 1)
      with context, monitored_session.MonitoredTrainingSession(
          checkpoint_dir=logdir,
          save_checkpoint_steps=100,
          log_step_count_steps=10) as session:
        for _ in range(100):
          session.run(new_gstep)

    # No checkpoint is saved.
    checkpoint = checkpoint_management.latest_checkpoint(logdir)
    self.assertIsNone(checkpoint)

  def test_checkpoint_hook_enable_on_non_chief_with_collective_ops(self):
    """A non-chief worker with collective ops checkpoints to a tmp subdir,

    not to the main checkpoint_dir.
    """
    strategy = collective_all_reduce_strategy.CollectiveAllReduceStrategy()
    strategy.extended._is_chief = False
    context = distribute_coordinator._WorkerContext(strategy, None, 'worker', 1)

    logdir = _test_dir(self.get_temp_dir(), 'test_save_checkpoint_disabled')
    with ops.Graph().as_default():
      gstep = training_util.get_or_create_global_step()
      new_gstep = state_ops.assign_add(gstep, 1)
      with context, monitored_session.MonitoredTrainingSession(
          checkpoint_dir=logdir,
          save_checkpoint_steps=100,
          log_step_count_steps=10) as session:
        for _ in range(100):
          session.run(new_gstep)

    # No checkpoint is saved.
    checkpoint = checkpoint_management.latest_checkpoint(logdir)
    self.assertIsNone(checkpoint)

    # But saved to a temporary directory.
    checkpoint = checkpoint_management.latest_checkpoint(
        os.path.join(logdir, 'tmp_worker_1'))
    self.assertIsNotNone(checkpoint)
class StopAtNSession(monitored_session._WrappedSession):
  """A wrapped session that stops at the N-th call to _check_stop."""

  def __init__(self, sess, n):
    super(StopAtNSession, self).__init__(sess)
    self._count = n

  def _check_stop(self):
    """Return True once the countdown has been exhausted; count down

    otherwise.
    """
    should_stop = self._count == 0
    if not should_stop:
      self._count -= 1
    return should_stop
class WrappedSessionTest(test.TestCase):
  """_WrappedSession tests."""

  @test_util.run_deprecated_v1
  def test_properties(self):
    """graph and sess_str are delegated to the wrapped session."""
    with self.cached_session() as sess:
      constant_op.constant(0.0)
      wrapped_sess = monitored_session._WrappedSession(sess)
      self.assertEquals(sess.graph, wrapped_sess.graph)
      self.assertEquals(sess.sess_str, wrapped_sess.sess_str)

  @test_util.run_deprecated_v1
  def test_should_stop_on_close(self):
    """close() makes should_stop() flip to True."""
    with self.cached_session() as sess:
      wrapped_sess = monitored_session._WrappedSession(sess)
      self.assertFalse(wrapped_sess.should_stop())
      wrapped_sess.close()
      self.assertTrue(wrapped_sess.should_stop())

  @test_util.run_deprecated_v1
  def test_should_stop_uses_check_stop(self):
    """should_stop() consults the subclass's _check_stop()."""
    with self.cached_session() as sess:
      wrapped_sess = StopAtNSession(sess, 3)
      self.assertFalse(wrapped_sess.should_stop())
      self.assertFalse(wrapped_sess.should_stop())
      self.assertFalse(wrapped_sess.should_stop())
      self.assertTrue(wrapped_sess.should_stop())

  @test_util.run_deprecated_v1
  def test_should_stop_delegates_to_wrapped_session(self):
    """A wrapper around a wrapper delegates should_stop() inward."""
    with self.cached_session() as sess:
      wrapped_sess0 = StopAtNSession(sess, 4)
      wrapped_sess1 = monitored_session._WrappedSession(wrapped_sess0)
      self.assertFalse(wrapped_sess1.should_stop())
      self.assertFalse(wrapped_sess1.should_stop())
      self.assertFalse(wrapped_sess1.should_stop())
      self.assertFalse(wrapped_sess1.should_stop())
      self.assertTrue(wrapped_sess1.should_stop())

  @test_util.run_deprecated_v1
  def test_close_twice(self):
    """Closing twice is harmless and keeps should_stop() True."""
    with self.cached_session() as sess:
      wrapped_sess = monitored_session._WrappedSession(sess)
      wrapped_sess.close()
      self.assertTrue(wrapped_sess.should_stop())
      wrapped_sess.close()
      self.assertTrue(wrapped_sess.should_stop())

  @test_util.run_deprecated_v1
  def test_run(self):
    """run() passes fetches and feeds through to the wrapped session."""
    with self.cached_session() as sess:
      c = constant_op.constant(0)
      v = array_ops.identity(c)
      self.assertEqual(42, sess.run(v, feed_dict={c: 42}))
      wrapped_sess = monitored_session._WrappedSession(sess)
      self.assertEqual(51, wrapped_sess.run(v, feed_dict={c: 51}))
def busy_wait_for_coord_stop(coord):
  """Block (spinning with short sleeps) until `coord` reports a stop."""
  while True:
    if coord.should_stop():
      return
    time.sleep(0.001)
class CoordinatedSessionTest(test.TestCase):
  """_CoordinatedSession tests."""

  @test_util.run_deprecated_v1
  def test_properties(self):
    """graph and sess_str are delegated to the wrapped session."""
    with self.cached_session() as sess:
      constant_op.constant(0.0)
      coord = coordinator.Coordinator()
      coord_sess = monitored_session._CoordinatedSession(sess, coord)
      self.assertEquals(sess.graph, coord_sess.graph)
      self.assertEquals(sess.sess_str, coord_sess.sess_str)

  @test_util.run_deprecated_v1
  def test_run(self):
    """run() passes fetches and feeds through to the wrapped session."""
    with self.cached_session() as sess:
      c = constant_op.constant(0)
      v = array_ops.identity(c)
      coord = coordinator.Coordinator()
      coord_sess = monitored_session._CoordinatedSession(sess, coord)
      self.assertEqual(42, coord_sess.run(v, feed_dict={c: 42}))

  @test_util.run_deprecated_v1
  def test_should_stop_on_close(self):
    """Closing the session makes should_stop() return True."""
    with self.cached_session() as sess:
      coord = coordinator.Coordinator()
      coord_sess = monitored_session._CoordinatedSession(sess, coord)
      self.assertFalse(coord_sess.should_stop())
      coord_sess.close()
      self.assertTrue(coord_sess.should_stop())

  @test_util.run_deprecated_v1
  def test_should_stop_on_coord_stop(self):
    """A coordinator stop request also makes should_stop() return True."""
    with self.cached_session() as sess:
      coord = coordinator.Coordinator()
      coord_sess = monitored_session._CoordinatedSession(sess, coord)
      self.assertFalse(coord_sess.should_stop())
      coord.request_stop()
      self.assertTrue(coord_sess.should_stop())

  @test_util.run_deprecated_v1
  def test_dont_request_stop_on_exception_in_main_thread(self):
    """An exception raised in the main thread does not stop the coordinator."""
    with self.cached_session() as sess:
      c = constant_op.constant(0)
      v = array_ops.identity(c)
      coord = coordinator.Coordinator()
      coord_sess = monitored_session._CoordinatedSession(sess, coord)
      self.assertFalse(coord_sess.should_stop())
      self.assertEqual(0, coord_sess.run(c))
      self.assertEqual(1, coord_sess.run(v, feed_dict={c: 1}))
      with self.assertRaisesRegexp(TypeError, 'None has invalid type'):
        coord_sess.run([None], feed_dict={c: 2})
      # The failed run() above must not have requested a stop.
      self.assertFalse(coord.should_stop())
      self.assertFalse(coord_sess.should_stop())

  @test_util.run_deprecated_v1
  def test_stop_threads_on_close_after_exception(self):
    """close() after a failed run() still stops all registered threads."""
    with self.cached_session() as sess:
      c = constant_op.constant(0)
      v = array_ops.identity(c)
      coord = coordinator.Coordinator()
      threads = [
          threading.Thread(
              target=busy_wait_for_coord_stop, args=(coord,)) for _ in range(3)
      ]
      for t in threads:
        coord.register_thread(t)
        t.start()
      coord_sess = monitored_session._CoordinatedSession(sess, coord)
      self.assertFalse(coord_sess.should_stop())
      for t in threads:
        self.assertTrue(t.is_alive())
      self.assertEqual(0, coord_sess.run(c))
      for t in threads:
        self.assertTrue(t.is_alive())
      self.assertEqual(1, coord_sess.run(v, feed_dict={c: 1}))
      for t in threads:
        self.assertTrue(t.is_alive())
      with self.assertRaisesRegexp(TypeError, 'None has invalid type'):
        coord_sess.run([None], feed_dict={c: 2})
      coord_sess.close()
      # close() requests a stop and joins the threads.
      for t in threads:
        self.assertFalse(t.is_alive())
      self.assertTrue(coord.should_stop())
      self.assertTrue(coord_sess.should_stop())

  def test_stop_threads_on_close(self):
    """close() stops and joins all threads registered with the coordinator."""
    with self.cached_session() as sess:
      coord = coordinator.Coordinator()
      threads = [
          threading.Thread(
              target=busy_wait_for_coord_stop, args=(coord,)) for _ in range(3)
      ]
      for t in threads:
        coord.register_thread(t)
        t.start()
      coord_sess = monitored_session._CoordinatedSession(sess, coord)
      coord_sess.close()
      for t in threads:
        self.assertFalse(t.is_alive())
      self.assertTrue(coord.should_stop())
      self.assertTrue(coord_sess.should_stop())

  @test_util.run_deprecated_v1
  def test_propagates_exception_trace(self):
    """An op failure surfaces with the original session.py stack trace."""
    assertion = control_flow_ops.Assert(False, ['This should fail.'])
    with self.cached_session() as sess:
      coord = coordinator.Coordinator(clean_stop_exception_types=())
      coord_sess = monitored_session._CoordinatedSession(sess, coord)
      try:
        coord_sess.run([assertion])
        self.fail('No exception was raised by assertion.')
      except errors_impl.InvalidArgumentError:
        # Extract the name of the file where the exception was first raised.
        _, _, exc_traceback = sys.exc_info()
        tb = traceback.extract_tb(exc_traceback)
        exc_source_file = tb[-1][0]
        exc_source_basename = os.path.basename(exc_source_file)
        # If it's monitored_session.py then the original stack trace was not
        # correctly propagated.
        self.assertIn(
            exc_source_basename, ['session.py', 'monitored_session.py'],
            'The exception was raised from an unrecognized file. This unit '
            'test probably needs to be updated. Traceback:\n%s\n' % tb)
        self.assertEqual(
            exc_source_basename, 'session.py',
            'Original stack trace was not propagated by MonitoredSession. '
            'Traceback:\n%s' % tb)
class AbortAtNSession(object):
  """A mock session that aborts at the N-th run call."""

  def __init__(self, sess, n):
    self._sess = sess
    self._count = n

  def close(self):
    # Nothing to release for the mock.
    pass

  def run(self, *args, **kwargs):
    """Delegate to the wrapped session until the countdown hits zero,

    then raise AbortedError on every subsequent call.
    """
    if not self._count:
      raise errors_impl.AbortedError('Aborted at N', None, None)
    self._count -= 1
    return self._sess.run(*args, **kwargs)
class StopCoordinatorWithException(session_run_hook.SessionRunHook):
  """With this hook Coordinator throws an exception after N-runs."""

  def __init__(self, calls_before_stopping, exception_to_raise=None):
    # Guards against spawning a second side thread on session re-creation.
    self._started_the_side_thread_already = False
    # Protects _calls_before_stopping, shared with the side thread.
    self._lock = threading.Lock()
    # Set by the side thread once it has stored its exception in the
    # coordinator; after_run() waits on it so the stop is visible in time.
    self._stored_exception_event = threading.Event()
    self._calls_before_stopping = calls_before_stopping
    self._exception_to_raise = (exception_to_raise or errors_impl.AbortedError(
        None, None, 'Aborted at N'))

  def _maybe_stop_with_exception(self, coord):
    """Side-thread body: spin until the countdown reaches zero, then stop

    `coord` with the configured exception and signal the event.
    """
    while True:
      with self._lock:
        if self._calls_before_stopping == 0:
          # Raise-and-catch so the coordinator stores a real exception with
          # a traceback attached.
          try:
            raise self._exception_to_raise
          except Exception as e:  # pylint: disable=broad-except
            coord.request_stop(e)
          self._stored_exception_event.set()
          break

  def after_create_session(self, session, coord):
    if self._started_the_side_thread_already:
      return

    separate_thread = threading.Thread(
        target=self._maybe_stop_with_exception, args=(coord,))

    coord.register_thread(separate_thread)
    separate_thread.start()
    self._started_the_side_thread_already = True
    # Coordinator will take care of joining `separate_thread`.

  def after_run(self, run_context, run_values):
    """Count down one run; on the last one, block until the side thread has

    stored its exception in the coordinator.
    """
    stopping_now = False
    with self._lock:
      self._calls_before_stopping -= 1
      if self._calls_before_stopping == 0:
        stopping_now = True

    if stopping_now:
      self._stored_exception_event.wait()
class FailTrainingAfterCoordinatorStopped(StopCoordinatorWithException):
  """With this hook training encounters an exception after N-runs."""

  def __init__(self, calls_before_stopping):
    StopCoordinatorWithException.__init__(self, calls_before_stopping)
    # Captured in after_create_session so after_run can query it.
    self._coord = None

  def after_create_session(self, session, coord):
    self._coord = coord
    return StopCoordinatorWithException.after_create_session(
        self, session, coord)

  def after_run(self, run_context, run_values):
    """Re-raise the coordinator's stored exception as a CancelledError,

    simulating the main thread failing after the coordinator stopped.
    """
    StopCoordinatorWithException.after_run(self, run_context, run_values)
    try:
      # After a `run`, an exception could have been stored inside the
      # coordinator.
      self._coord.raise_requested_exception()
    except errors_impl.AbortedError:
      # In real world, the main thread may or may not know about the exception
      # that stopped the coordinator. Because the coordinator has stopped, the
      # main thread could have gotten stuck as well (for example, the
      # coordinator was supposed to execute `FIFOQueue.enqueue` while the main
      # thread is executing a blocking `FIFOQueue.dequeue`). After it got stuck,
      # the session is going to get garbage collected after some time with:
      raise errors_impl.CancelledError(None, None,
                                       'Session got garbage-collected.')
class CountingSessionCreator(object):
  """A creator that counts the number of created sessions."""

  def __init__(self, session):
    self._shared_session = session
    # We only have one session per test case. We can't re-create it, thus
    # it shouldn't be closed.
    self._shared_session.close = lambda *args: None
    self._num_created = 0

  @property
  def number_of_sessions_created(self):
    """Number of times create_session() has been called so far."""
    return self._num_created

  def create_session(self):
    """Hand back the single shared session, tallying each request."""
    self._num_created += 1
    return self._shared_session
class RecoverableSessionTest(test.TestCase):
"""_RecoverableSession tests."""
  class _SessionReturner(object):
    """Trivial session creator that always returns the given session."""

    def __init__(self, sess):
      self._sess = sess

    def create_session(self):
      return self._sess
  @test_util.run_deprecated_v1
  def test_properties(self):
    """graph and sess_str are delegated to the underlying session."""
    with self.cached_session() as sess:
      constant_op.constant(0.0)
      recoverable_sess = monitored_session._RecoverableSession(
          self._SessionReturner(sess))
      self.assertEquals(sess.graph, recoverable_sess.graph)
      self.assertEquals(sess.sess_str, recoverable_sess.sess_str)
  @test_util.run_deprecated_v1
  def test_run(self):
    """run() passes fetches and feeds through to the underlying session."""
    with self.cached_session() as sess:
      c = constant_op.constant(0)
      v = array_ops.identity(c)
      recoverable_sess = monitored_session._RecoverableSession(
          self._SessionReturner(sess))
      self.assertEqual(51, recoverable_sess.run(v, feed_dict={c: 51}))
  @test_util.run_deprecated_v1
  def test_recovery(self):
    """On AbortedError, _RecoverableSession silently creates a new session

    from the factory and retries the run.
    """
    with self.cached_session() as sess:

      class StackSessionCreator(object):
        """Creator that pops pre-built aborting sessions off a stack."""

        def __init__(self, sess):
          self.sessions_to_use = [
              AbortAtNSession(sess, x + 1) for x in range(3)
          ]

        def create_session(self):
          return self.sessions_to_use.pop(0)

      c = constant_op.constant(0)
      v = array_ops.identity(c)
      session_creator = StackSessionCreator(sess)
      # List of 3 sessions to use for recovery. The first one aborts
      # after 1 run() call, the second after 2 run calls, the third
      # after 3 run calls.
      self.assertEqual(3, len(session_creator.sessions_to_use))
      # Make the recoverable session uses these 3 sessions in sequence by
      # passing a factory that pops from the session_to_use list.
      recoverable_sess = monitored_session._RecoverableSession(session_creator)
      self.assertEqual(
          2, len(session_creator.sessions_to_use))  # One session popped.
      # Using first session.
      self.assertEqual(51, recoverable_sess.run(v, feed_dict={c: 51}))
      self.assertEqual(
          2, len(session_creator.sessions_to_use))  # Still 2 sessions available
      # This will fail and recover by picking up the second session.
      self.assertEqual(42, recoverable_sess.run(v, feed_dict={c: 42}))
      self.assertEqual(
          1, len(session_creator.sessions_to_use))  # Still 1 session available
      self.assertEqual(33, recoverable_sess.run(v, feed_dict={c: 33}))
      self.assertEqual(
          1, len(session_creator.sessions_to_use))  # Still 1 session available
      # This will fail and recover by picking up the last session.
      self.assertEqual(24, recoverable_sess.run(v, feed_dict={c: 24}))
      self.assertEqual(
          0, len(session_creator.sessions_to_use))  # All sessions used.
      self.assertEqual(11, recoverable_sess.run(v, feed_dict={c: 11}))
      self.assertEqual(0, recoverable_sess.run(v, feed_dict={c: 0}))
      # This will fail and throw a real error as the pop() will fail.
      with self.assertRaisesRegexp(IndexError, 'pop from empty list'):
        recoverable_sess.run(v, feed_dict={c: -12})
@test_util.run_deprecated_v1
def test_recovery_from_coordinator_exception(self):
  """A preemption raised via the coordinator triggers session re-creation."""
  with self.cached_session() as test_session:
    session_creator = CountingSessionCreator(test_session)
    session = monitored_session.MonitoredSession(
        session_creator,
        [StopCoordinatorWithException(calls_before_stopping=2)])
    self.assertEqual(1, session_creator.number_of_sessions_created)
    self.assertFalse(session.should_stop())
    c = constant_op.constant(0)
    v = array_ops.identity(c)
    # The coordinator will not abort during this call, since it's the call
    # number 0.
    self.assertEqual(51, session.run(v, feed_dict={c: 51}))
    self.assertFalse(session.should_stop())
    # The coordinator will abort during the next call, since it's the call
    # number 1.
    self.assertEqual(42, session.run(v, feed_dict={c: 42}))
    # Even though the coordinator was asked to stop, the underlying session is
    # recreated and is to be continued.
    self.assertFalse(session.should_stop())
    self.assertEqual(2, session_creator.number_of_sessions_created)
@test_util.run_deprecated_v1
def test_recovery_from_non_preemption_in_coordinator(self):
  """A non-preemption coordinator error stops training for good."""
  with self.cached_session() as test_session:
    session_creator = CountingSessionCreator(test_session)
    hook = StopCoordinatorWithException(
        calls_before_stopping=2,
        exception_to_raise=errors_impl.UnknownError(
            None, None, 'Some fatal exception inside the coordinator.'))
    session = monitored_session.MonitoredSession(session_creator, [hook])
    self.assertEqual(1, session_creator.number_of_sessions_created)
    self.assertFalse(session.should_stop())
    c = constant_op.constant(0)
    v = array_ops.identity(c)
    # The coordinator will not abort during this call, since it's the call
    # number 0.
    self.assertEqual(51, session.run(v, feed_dict={c: 51}))
    self.assertFalse(session.should_stop())
    # The coordinator will abort during the next call, since it's the call
    # number 1.
    self.assertEqual(42, session.run(v, feed_dict={c: 42}))
    # The coordinator was asked to stop due to non-redeemable error. Training
    # should stop and the session should not be recreated.
    self.assertTrue(session.should_stop())
    self.assertEqual(1, session_creator.number_of_sessions_created)
    # Closing the session surfaces the fatal coordinator exception.
    with self.assertRaises(errors_impl.UnknownError):
      session.close()
@test_util.run_deprecated_v1
def test_recovery_from_session_getting_stuck(self):
  """Training failure after the coordinator stops leads to re-creation."""
  with self.cached_session() as test_session:
    session_creator = CountingSessionCreator(test_session)
    session = monitored_session.MonitoredSession(
        session_creator,
        [FailTrainingAfterCoordinatorStopped(calls_before_stopping=2)])
    self.assertEqual(1, session_creator.number_of_sessions_created)
    self.assertFalse(session.should_stop())
    c = constant_op.constant(0)
    v = array_ops.identity(c)
    # Training will not fail, since it's the call number 0.
    self.assertEqual(51, session.run(v, feed_dict={c: 51}))
    self.assertFalse(session.should_stop())
    # Training will fail during the next call, since it's the call
    # number 1.
    self.assertEqual(42, session.run(v, feed_dict={c: 42}))
    # Even though the coordinator stopped and training failed, the
    # underlying session is recreated and training is to be continued.
    self.assertFalse(session.should_stop())
    self.assertEqual(2, session_creator.number_of_sessions_created)
@test_util.run_deprecated_v1
def test_step_fn_recovery_from_coordinator_exception_when_run_hooks(self):
  """run_step_fn with run_with_hooks recovers from a coordinator preemption."""
  with self.cached_session() as test_session:
    session_creator = CountingSessionCreator(test_session)
    session = monitored_session.MonitoredSession(
        session_creator,
        [StopCoordinatorWithException(calls_before_stopping=2)])
    self.assertEqual(1, session_creator.number_of_sessions_created)
    self.assertFalse(session.should_stop())
    c = constant_op.constant(0)
    v = array_ops.identity(c)

    def feed_step_fn(value):
      """Builds a step_fn that feeds `value` through the hooked run."""
      def step_fn(step_context):
        return step_context.run_with_hooks(fetches=v, feed_dict={c: value})
      return step_fn

    # The coordinator will not abort during this call, since it's the call
    # number 0.
    self.assertEqual(51, session.run_step_fn(feed_step_fn(51)))
    self.assertFalse(session.should_stop())
    # The coordinator will abort during the next call, since it's the call
    # number 1.
    self.assertEqual(42, session.run_step_fn(feed_step_fn(42)))
    # Even though the coordinator was asked to stop, the underlying session is
    # recreated and is to be continued.
    self.assertFalse(session.should_stop())
    self.assertEqual(2, session_creator.number_of_sessions_created)
@test_util.run_deprecated_v1
def test_recovery_from_non_preemption_in_coordinator_when_run_hooks(self):
  """run_step_fn variant: a fatal coordinator error stops training."""
  with self.cached_session() as test_session:
    session_creator = CountingSessionCreator(test_session)
    hook = StopCoordinatorWithException(
        calls_before_stopping=2,
        exception_to_raise=errors_impl.UnknownError(
            None, None, 'Some fatal exception inside the coordinator.'))
    session = monitored_session.MonitoredSession(session_creator, [hook])
    self.assertEqual(1, session_creator.number_of_sessions_created)
    self.assertFalse(session.should_stop())
    c = constant_op.constant(0)
    v = array_ops.identity(c)

    def feed_step_fn(value):
      """Builds a step_fn that feeds `value` through the hooked run."""
      def step_fn(step_context):
        return step_context.run_with_hooks(fetches=v, feed_dict={c: value})
      return step_fn

    # The coordinator will not abort during this call, since it's the call
    # number 0.
    self.assertEqual(51, session.run_step_fn(feed_step_fn(51)))
    self.assertFalse(session.should_stop())
    # The coordinator will abort during the next call, since it's the call
    # number 1.
    self.assertEqual(42, session.run_step_fn(feed_step_fn(42)))
    # The coordinator was asked to stop due to non-redeemable error. Training
    # should stop and the session should not be recreated.
    self.assertTrue(session.should_stop())
    self.assertEqual(1, session_creator.number_of_sessions_created)
    # Closing the session surfaces the fatal coordinator exception.
    with self.assertRaises(errors_impl.UnknownError):
      session.close()
@test_util.run_deprecated_v1
def test_recovery_from_session_getting_stuck_when_run_hooks(self):
  """run_step_fn variant: training failure after stop is recovered."""
  with self.cached_session() as test_session:
    session_creator = CountingSessionCreator(test_session)
    session = monitored_session.MonitoredSession(
        session_creator,
        [FailTrainingAfterCoordinatorStopped(calls_before_stopping=2)])
    self.assertEqual(1, session_creator.number_of_sessions_created)
    self.assertFalse(session.should_stop())
    c = constant_op.constant(0)
    v = array_ops.identity(c)

    def feed_step_fn(value):
      """Builds a step_fn that feeds `value` through the hooked run."""
      def step_fn(step_context):
        return step_context.run_with_hooks(fetches=v, feed_dict={c: value})
      return step_fn

    # Training will not fail, since it's the call number 0.
    self.assertEqual(51, session.run_step_fn(feed_step_fn(51)))
    self.assertFalse(session.should_stop())
    # Training will fail during the next call, since it's the call
    # number 1.
    self.assertEqual(42, session.run_step_fn(feed_step_fn(42)))
    # Even though the coordinator stopped and training failed, the
    # underlying session is recreated and training is to be continued.
    self.assertFalse(session.should_stop())
    self.assertEqual(2, session_creator.number_of_sessions_created)
def create_raw_session_with_failing_coordinator(self, session_creator, hook):
  """Return MonitoredSession that triggers coordinator failures.

  Args:
    session_creator: A session creator for the underlying session.
    hook: A hook (e.g. `StopCoordinatorWithException`) that makes the
      coordinator fail while the raw session is being used.

  Returns:
    A `MonitoredSession` whose `_tf_sess()` is patched to expose the raw
    underlying session, bypassing the hooked/coordinated wrappers.
  """
  session = monitored_session.MonitoredSession(session_creator, [hook])
  # We would like to test a situation where during fetches through the
  # raw session, the coordinator fails with an exception. To do that, we
  # are going to use (raw_session + StopCoordinatorWithException) hook
  # combination that is stored in
  # `MonitoredSession._RecoverableSession._CoordinatedSession._sess`
  # at this point:
  session._tf_sess = lambda: session._sess._sess._sess
  # `run()` on such a session is equivalent to `run()` on the raw session
  # with separate coordinator threads independently stopping with an
  # exception.
  return session
@test_util.run_deprecated_v1
def test_step_fn_recovery_from_coordinator_exception_with_raw_session(self):
  """Raw-session step_fn recovers from a coordinator preemption."""
  with self.cached_session() as test_session:
    session_creator = CountingSessionCreator(test_session)
    session = self.create_raw_session_with_failing_coordinator(
        session_creator,
        StopCoordinatorWithException(calls_before_stopping=2))
    self.assertEqual(1, session_creator.number_of_sessions_created)
    self.assertFalse(session.should_stop())
    c = constant_op.constant(0)
    v = array_ops.identity(c)

    def feed_step_fn(value):
      """Builds a step_fn that runs directly against the raw session."""
      def step_fn(step_context):
        return step_context.session.run(fetches=v, feed_dict={c: value})
      return step_fn

    # The coordinator will not abort during this call, since it's the call
    # number 0.
    self.assertEqual(51, session.run_step_fn(feed_step_fn(51)))
    self.assertFalse(session.should_stop())
    # The coordinator will abort during the next call, since it's the call
    # number 1.
    self.assertEqual(42, session.run_step_fn(feed_step_fn(42)))
    # Even though the coordinator was asked to stop, the underlying session is
    # recreated and is to be continued.
    self.assertFalse(session.should_stop())
    self.assertEqual(2, session_creator.number_of_sessions_created)
@test_util.run_deprecated_v1
def test_recovery_from_non_preemption_in_coordinator_with_raw_session(self):
  """Raw-session variant: a fatal coordinator error stops training."""
  with self.cached_session() as test_session:
    session_creator = CountingSessionCreator(test_session)
    session = self.create_raw_session_with_failing_coordinator(
        session_creator,
        StopCoordinatorWithException(
            calls_before_stopping=2,
            exception_to_raise=errors_impl.UnknownError(
                None, None, 'Some fatal exception inside the coordinator.')))
    self.assertEqual(1, session_creator.number_of_sessions_created)
    self.assertFalse(session.should_stop())
    c = constant_op.constant(0)
    v = array_ops.identity(c)

    def feed_step_fn(value):
      """Builds a step_fn that feeds `value` through the hooked run."""
      def step_fn(step_context):
        return step_context.run_with_hooks(fetches=v, feed_dict={c: value})
      return step_fn

    # The coordinator will not abort during this call, since it's the call
    # number 0.
    self.assertEqual(51, session.run_step_fn(feed_step_fn(51)))
    self.assertFalse(session.should_stop())
    # The coordinator will abort during the next call, since it's the call
    # number 1.
    self.assertEqual(42, session.run_step_fn(feed_step_fn(42)))
    # The coordinator was asked to stop due to non-redeemable error. Training
    # should stop and the session should not be recreated.
    self.assertTrue(session.should_stop())
    self.assertEqual(1, session_creator.number_of_sessions_created)
    # Closing the session surfaces the fatal coordinator exception.
    with self.assertRaises(errors_impl.UnknownError):
      session.close()
@test_util.run_deprecated_v1
def test_recovery_from_session_getting_stuck_with_raw_session(self):
  """Raw-session variant: training failure after stop is recovered."""
  with self.cached_session() as test_session:
    session_creator = CountingSessionCreator(test_session)
    session = self.create_raw_session_with_failing_coordinator(
        session_creator,
        FailTrainingAfterCoordinatorStopped(calls_before_stopping=2))
    self.assertEqual(1, session_creator.number_of_sessions_created)
    self.assertFalse(session.should_stop())
    c = constant_op.constant(0)
    v = array_ops.identity(c)

    def feed_step_fn(value):
      """Builds a step_fn that feeds `value` through the hooked run."""
      def step_fn(step_context):
        return step_context.run_with_hooks(fetches=v, feed_dict={c: value})
      return step_fn

    # Training will not fail, since it's the call number 0.
    self.assertEqual(51, session.run_step_fn(feed_step_fn(51)))
    self.assertFalse(session.should_stop())
    # Training will fail during the next call, since it's the call
    # number 1.
    self.assertEqual(42, session.run_step_fn(feed_step_fn(42)))
    # Even though the coordinator stopped and training failed, the
    # underlying session is recreated and training is to be continued.
    self.assertFalse(session.should_stop())
    self.assertEqual(2, session_creator.number_of_sessions_created)
class FakeSession(monitored_session._WrappedSession):
  """A _WrappedSession that captures the keyword arguments of each run()."""

  def __init__(self, sess):
    monitored_session._WrappedSession.__init__(self, sess)
    # Dict of kwargs seen by the most recent run() call; tests inspect it.
    self.args_called = {}

  def run(self, fetches, **kwargs):
    # Record everything that was passed alongside the fetches.
    self.args_called = {key: value for key, value in kwargs.items()}
    # Only forward the fetches; the remaining arguments were recorded above
    # and are asserted on directly by the tests.
    return monitored_session._WrappedSession.run(self, fetches)
class HookedSessionTest(test.TestCase):
  """Tests of _HookedSession."""

  def testRunPassesAllArguments(self):
    """feed_dict/options/run_metadata kwargs reach the wrapped session."""
    with ops.Graph().as_default(), session_lib.Session() as sess:
      mock_run = FakeSession(sess)
      mon_sess = monitored_session._HookedSession(sess=mock_run, hooks=[])
      a_tensor = constant_op.constant([0], name='a_tensor')
      self.evaluate(variables.global_variables_initializer())
      output = mon_sess.run(fetches=a_tensor,
                            feed_dict='a_feed',
                            options='an_option',
                            run_metadata='a_metadata')
      self.assertEqual(output, [0])
      self.assertEqual(mock_run.args_called, {
          'feed_dict': 'a_feed',
          'options': 'an_option',
          'run_metadata': 'a_metadata'
      })

  def testCallsHooksBeginEnd(self):
    """before_run/after_run fire once per run(); begin hooks do not fire."""
    with ops.Graph().as_default(), session_lib.Session() as sess:
      mock_hook = FakeHook()
      mock_hook2 = FakeHook()
      mon_sess = monitored_session._HookedSession(
          sess=sess, hooks=[mock_hook, mock_hook2])
      a_tensor = constant_op.constant([0], name='a_tensor')
      self.evaluate(variables.global_variables_initializer())
      mon_sess.run(a_tensor)
      for hook in [mock_hook, mock_hook2]:
        self.assertEqual(
            hook.last_run_values,
            session_run_hook.SessionRunValues(
                results=None,
                options=config_pb2.RunOptions(),
                run_metadata=config_pb2.RunMetadata()))
        self.assertEqual(hook.last_run_context.original_args,
                         session_run_hook.SessionRunArgs(a_tensor))
        self.assertEqual(hook.last_run_context.session, sess)
        # _HookedSession itself never calls begin()/after_create_session().
        self.assertEqual(hook.call_counter['begin'], 0)
        self.assertEqual(hook.call_counter['after_create_session'], 0)
        self.assertEqual(hook.call_counter['before_run'], 1)
        self.assertEqual(hook.call_counter['after_run'], 1)

  def testShouldStop(self):
    """should_stop() flips to True once any hook requests a stop."""
    with ops.Graph().as_default(), session_lib.Session() as sess:
      mock_hook = FakeHook()
      mock_hook2 = FakeHook()
      mon_sess = monitored_session._HookedSession(
          sess=sess, hooks=[mock_hook, mock_hook2])
      constant_op.constant([0], name='a_tensor')
      self.evaluate(variables.global_variables_initializer())
      mon_sess.run(fetches='a_tensor')
      self.assertFalse(mon_sess.should_stop())
      mock_hook.should_stop = True
      mon_sess.run(fetches='a_tensor')
      self.assertTrue(mon_sess.should_stop())

  def testFetchesHookRequests(self):
    """Extra fetches requested by hooks are evaluated and handed back."""
    with ops.Graph().as_default(), session_lib.Session() as sess:
      mock_hook = FakeHook()
      mock_hook2 = FakeHook()
      mon_sess = monitored_session._HookedSession(
          sess=sess, hooks=[mock_hook, mock_hook2])
      a_tensor = constant_op.constant([0], name='a_tensor')
      another_tensor = constant_op.constant([5], name='another_tensor')
      third_tensor = constant_op.constant([10], name='third_tensor')
      mock_hook.request = session_run_hook.SessionRunArgs([another_tensor])
      mock_hook2.request = session_run_hook.SessionRunArgs([third_tensor])
      self.evaluate(variables.global_variables_initializer())
      output = mon_sess.run(fetches=a_tensor)
      self.assertEqual(output, [0])
      # Each hook observes the result of its own requested fetch.
      self.assertEqual(mock_hook.last_run_values.results, [5])
      self.assertEqual(mock_hook2.last_run_values.results, [10])

  def testOnlyHooksHaveFeeds(self):
    """Feeds supplied only by hooks are merged into the run."""
    with ops.Graph().as_default(), session_lib.Session() as sess:
      mock_hook = FakeHook()
      mock_hook2 = FakeHook()
      mon_sess = monitored_session._HookedSession(
          sess=sess, hooks=[mock_hook, mock_hook2])
      a_tensor = constant_op.constant([0], name='a_tensor')
      b_tensor = constant_op.constant([0], name='b_tensor')
      add_tensor = a_tensor + b_tensor
      mock_hook.request = session_run_hook.SessionRunArgs(
          None, feed_dict={a_tensor: [5]})
      mock_hook2.request = session_run_hook.SessionRunArgs(
          None, feed_dict={b_tensor: [10]})
      self.evaluate(variables.global_variables_initializer())
      self.assertEqual(mon_sess.run(fetches=add_tensor), [15])

  def testBothHooksAndUserHaveFeeds(self):
    """Hook feeds and user feeds are merged; the user's dict is untouched."""
    with ops.Graph().as_default(), session_lib.Session() as sess:
      mock_hook = FakeHook()
      mock_hook2 = FakeHook()
      mon_sess = monitored_session._HookedSession(
          sess=sess, hooks=[mock_hook, mock_hook2])
      a_tensor = constant_op.constant([0], name='a_tensor')
      b_tensor = constant_op.constant([0], name='b_tensor')
      c_tensor = constant_op.constant([0], name='c_tensor')
      add_tensor = a_tensor + b_tensor + c_tensor
      mock_hook.request = session_run_hook.SessionRunArgs(
          None, feed_dict={a_tensor: [5]})
      mock_hook2.request = session_run_hook.SessionRunArgs(
          None, feed_dict={b_tensor: [10]})
      self.evaluate(variables.global_variables_initializer())
      feed_dict = {c_tensor: [20]}
      self.assertEqual(
          mon_sess.run(fetches=add_tensor, feed_dict=feed_dict), [35])
      # User feed_dict should not be changed
      self.assertEqual(len(feed_dict), 1)

  def testHooksFeedConflicts(self):
    """Two hooks feeding the same tensor raise RuntimeError."""
    with ops.Graph().as_default(), session_lib.Session() as sess:
      mock_hook = FakeHook()
      mock_hook2 = FakeHook()
      mon_sess = monitored_session._HookedSession(
          sess=sess, hooks=[mock_hook, mock_hook2])
      a_tensor = constant_op.constant([0], name='a_tensor')
      b_tensor = constant_op.constant([0], name='b_tensor')
      add_tensor = a_tensor + b_tensor
      mock_hook.request = session_run_hook.SessionRunArgs(
          None, feed_dict={a_tensor: [5]})
      mock_hook2.request = session_run_hook.SessionRunArgs(
          None, feed_dict={a_tensor: [10]})
      self.evaluate(variables.global_variables_initializer())
      with self.assertRaisesRegexp(RuntimeError, 'Same tensor is fed'):
        mon_sess.run(fetches=add_tensor)

  def testHooksAndUserFeedConflicts(self):
    """A hook and the user feeding the same tensor raise RuntimeError."""
    with ops.Graph().as_default(), session_lib.Session() as sess:
      mock_hook = FakeHook()
      mock_hook2 = FakeHook()
      mon_sess = monitored_session._HookedSession(
          sess=sess, hooks=[mock_hook, mock_hook2])
      a_tensor = constant_op.constant([0], name='a_tensor')
      b_tensor = constant_op.constant([0], name='b_tensor')
      add_tensor = a_tensor + b_tensor
      mock_hook.request = session_run_hook.SessionRunArgs(
          None, feed_dict={a_tensor: [5]})
      mock_hook2.request = session_run_hook.SessionRunArgs(
          None, feed_dict={b_tensor: [10]})
      self.evaluate(variables.global_variables_initializer())
      with self.assertRaisesRegexp(RuntimeError, 'Same tensor is fed'):
        mon_sess.run(fetches=add_tensor, feed_dict={b_tensor: [10]})
class RaiseOnceAtCountN(session_run_hook.SessionRunHook):
  """Hook that raises the given exception once, upon reaching step N."""

  def __init__(self, n, ex):
    self.n = n
    self.ex = ex
    # Flipped to True when the exception has been raised; read by tests.
    self.raised = False

  def before_run(self, run_context):
    # Count down towards step N; raise only the first time we reach it.
    self.n -= 1
    if self.n != 0 or self.raised:
      return None
    self.raised = True
    raise self.ex
class RunOptionsMetadataHook(session_run_hook.SessionRunHook):
  """A hook that observes & optionally modifies RunOptions and RunMetadata."""

  def __init__(self, trace_level, timeout_in_ms, output_partition_graphs,
               debug_tensor_watch, report_tensor_allocations_upon_oom):
    # Settings injected into every run via before_run().
    self._trace_level = trace_level
    self._timeout_in_ms = timeout_in_ms
    self._output_partition_graphs = output_partition_graphs
    self._debug_tensor_watch = debug_tensor_watch
    self._report_tensor_allocations_upon_oom = (
        report_tensor_allocations_upon_oom)
    # Observed values, appended by after_run() and inspected by tests.
    self.run_options_list = []
    self.run_metadata_list = []

  def before_run(self, run_context):
    """Requests RunOptions built from the constructor settings."""
    requested = config_pb2.RunOptions(
        trace_level=self._trace_level,
        timeout_in_ms=self._timeout_in_ms,
        output_partition_graphs=self._output_partition_graphs,
        report_tensor_allocations_upon_oom=(
            self._report_tensor_allocations_upon_oom))
    requested.debug_options.debug_tensor_watch_opts.append(
        self._debug_tensor_watch)
    return session_run_hook.SessionRunArgs(None, None, options=requested)

  def after_run(self, run_context, run_values):
    """Records the options and metadata observed for this run."""
    self.run_options_list.append(run_values.options)
    self.run_metadata_list.append(run_values.run_metadata)
class MonitoredSessionTest(test.TestCase):
"""MonitoredSession tests."""
def test_defaults(self):
  """A default-constructed MonitoredSession initializes variables and runs."""
  with ops.Graph().as_default():
    zero_var = variables.VariableV1(0)
    with monitored_session.MonitoredSession() as session:
      self.assertEqual(0, session.run(zero_var))
def test_last_step(self):
  """StopAtStepHook(last_step=N) requests a stop exactly at global step N."""
  logdir = _test_dir(self.get_temp_dir(), 'test_last_step')
  with ops.Graph().as_default():
    gstep = training_util.get_or_create_global_step()
    do_step = state_ops.assign_add(gstep, 1)
    # Run till step 3 and save.
    hooks = [basic_session_run_hooks.StopAtStepHook(last_step=3)]
    with monitored_session.MonitoredSession(hooks=hooks) as session:
      self.assertEqual(0, session.run(gstep))
      self.assertFalse(session.should_stop())
      self.assertEqual(1, session.run(do_step))
      self.assertFalse(session.should_stop())
      self.assertEqual(2, session.run(do_step))
      self.assertFalse(session.should_stop())
      self.assertEqual(3, session.run(do_step))
      self.assertTrue(session.should_stop())
      save_path = saver_lib._get_saver_or_default().save(
          session._coordinated_creator.tf_sess,
          os.path.join(logdir, 'step-3'))
    # Restore and run till step 5.
    def load_ckpt(scaffold, sess):
      # Restores the variables saved at step 3 above.
      scaffold.saver.restore(sess, save_path)

    session_creator = monitored_session.ChiefSessionCreator(
        monitored_session.Scaffold(init_fn=load_ckpt))
    hooks = [basic_session_run_hooks.StopAtStepHook(last_step=5)]
    with monitored_session.MonitoredSession(
        hooks=hooks, session_creator=session_creator) as session:
      self.assertEqual(3, session.run(gstep))
      self.assertFalse(session.should_stop())
      self.assertEqual(4, session.run(do_step))
      self.assertFalse(session.should_stop())
      self.assertEqual(5, session.run(do_step))
      self.assertTrue(session.should_stop())
def test_num_steps(self):
  """StopAtStepHook(num_steps=N) stops N steps after session creation."""
  logdir = _test_dir(self.get_temp_dir(), 'test_num_steps')
  with ops.Graph().as_default():
    gstep = training_util.get_or_create_global_step()
    do_step = state_ops.assign_add(gstep, 1)
    # Do 3 steps and save.
    hooks = [basic_session_run_hooks.StopAtStepHook(num_steps=3)]
    with monitored_session.MonitoredSession(hooks=hooks) as session:
      session.run(do_step)
      self.assertFalse(session.should_stop())
      session.run(do_step)
      self.assertFalse(session.should_stop())
      session.run(do_step)
      self.assertTrue(session.should_stop())
      save_path = saver_lib._get_saver_or_default().save(
          session._coordinated_creator.tf_sess,
          os.path.join(logdir, 'step-3'))
    # Restore and do 4 steps.
    def load_ckpt(scaffold, sess):
      # Restores the variables saved at step 3 above.
      scaffold.saver.restore(sess, save_path)

    session_creator = monitored_session.ChiefSessionCreator(
        scaffold=monitored_session.Scaffold(init_fn=load_ckpt))
    hooks = [basic_session_run_hooks.StopAtStepHook(num_steps=4)]
    with monitored_session.MonitoredSession(
        hooks=hooks, session_creator=session_creator) as session:
      self.assertEqual(4, session.run(do_step))
      self.assertFalse(session.should_stop())
      session.run(do_step)
      self.assertFalse(session.should_stop())
      session.run(do_step)
      self.assertFalse(session.should_stop())
      session.run(do_step)
      self.assertTrue(session.should_stop())
# This set of tests verifies the supervised session behavior when exceptions
# are raised next to the innermost session run() call.
@test_util.run_deprecated_v1
def test_recovery(self):
  """A restarted MonitoredSession recovers state from the checkpoint dir."""
  logdir = _test_dir(self.get_temp_dir(), 'test_recovery')
  with ops.Graph().as_default():
    gstep = training_util.get_or_create_global_step()
    do_step = state_ops.assign_add(gstep, 1)
    scaffold = monitored_session.Scaffold()
    # Use a hook to save the model every step. It also saves it at
    # the end.
    hooks = [
        basic_session_run_hooks.CheckpointSaverHook(
            logdir, save_steps=1, scaffold=scaffold)
    ]
    with monitored_session.MonitoredSession(
        session_creator=monitored_session.ChiefSessionCreator(
            scaffold, checkpoint_dir=logdir),
        hooks=hooks) as session:
      self.assertEqual(0, session.run(gstep))
      self.assertEqual(1, session.run(do_step))
      self.assertEqual(2, session.run(do_step))
    # A restart will find the checkpoint and recover automatically.
    with monitored_session.MonitoredSession(
        session_creator=monitored_session.ChiefSessionCreator(
            scaffold, checkpoint_dir=logdir)) as session:
      self.assertEqual(2, session.run(gstep))
    # A restart will find the checkpoint and recover automatically.
    with monitored_session.MonitoredSession(
        session_creator=monitored_session.ChiefSessionCreator(
            scaffold,
            checkpoint_filename_with_path=checkpoint_management.
            latest_checkpoint(logdir))) as session:
      self.assertEqual(2, session.run(gstep))
def test_retry_initialization_on_aborted_error(self):
  """An AbortedError raised during initialization is silently retried."""
  # Tests that we silently retry on abort during initialization.
  with ops.Graph().as_default():
    gstep = training_util.get_or_create_global_step()
    self.init_raised_aborted_error = False

    def _init_fn(scaffold, session):
      _, _ = scaffold, session
      # Abort only on the first initialization attempt.
      if not self.init_raised_aborted_error:
        self.init_raised_aborted_error = True
        raise errors_impl.AbortedError(None, None, 'Abort')

    with monitored_session.MonitoredSession(
        session_creator=monitored_session.ChiefSessionCreator(
            scaffold=monitored_session.Scaffold(
                init_fn=_init_fn))) as session:
      self.assertFalse(session.should_stop())
      self.assertEqual(0, session.run(gstep))
    self.assertTrue(self.init_raised_aborted_error)
def _retry_test(self, ex):
  """Checks that raising `ex` at step 3 causes a silent retry from scratch.

  Note that this does not test recovery as we do not use a CheckpointSaver
  in this test.

  Args:
    ex: The exception the hook raises once, at step 3.
  """
  with ops.Graph().as_default():
    gstep = training_util.get_or_create_global_step()
    do_step = state_ops.assign_add(gstep, 1)
    hook = RaiseOnceAtCountN(4, ex)
    with monitored_session.MonitoredSession(hooks=[hook]) as session:
      self.assertEqual(0, session.run(gstep))
      self.assertEqual(1, session.run(do_step))
      self.assertEqual(2, session.run(do_step))
      self.assertFalse(session.should_stop())
      # Here at step 3, the hook triggers and raises the given exception.
      # The MonitoredSession automatically retries and restart from a freshly
      # initialized session, so the step is back to 0 and running do_step
      # moves it to 1.
      self.assertEqual(1, session.run(do_step))
      self.assertFalse(session.should_stop())
      self.assertTrue(hook.raised)
      self.assertEqual(2, session.run(do_step))
      self.assertFalse(session.should_stop())
def test_retry_on_aborted_error(self):
  """An AbortedError raised by a hook is retried silently."""
  aborted = errors_impl.AbortedError(None, None, 'Abort')
  self._retry_test(aborted)
def test_retry_on_unavailable_error(self):
  """An UnavailableError raised by a hook is retried silently."""
  unavailable = errors_impl.UnavailableError(None, None, 'Unavailable')
  self._retry_test(unavailable)
def test_recover_and_retry_on_aborted_error(self):
  """An AbortedError with a checkpoint available restores and continues."""
  # Tests that we silently retry and recover on abort. This test uses
  # a CheckpointSaver to have something to recover from.
  logdir = _test_dir(self.get_temp_dir(),
                     'test_recover_and_retry_on_aborted_error')
  with ops.Graph().as_default():
    gstep = training_util.get_or_create_global_step()
    do_step = state_ops.assign_add(gstep, 1)
    scaffold = monitored_session.Scaffold()
    abort_hook = RaiseOnceAtCountN(
        4, errors_impl.AbortedError(None, None, 'Abort'))
    # Save after each step.
    ckpt_hook = basic_session_run_hooks.CheckpointSaverHook(
        logdir, save_steps=1, scaffold=scaffold)
    hooks = [abort_hook, ckpt_hook]
    with monitored_session.MonitoredSession(
        session_creator=monitored_session.ChiefSessionCreator(
            scaffold, checkpoint_dir=logdir),
        hooks=hooks) as session:
      self.assertEqual(0, session.run(gstep))
      self.assertEqual(1, session.run(do_step))
      self.assertEqual(2, session.run(do_step))
      self.assertFalse(session.should_stop())
      # Here at step 3, the hook triggers and raises AbortedError. The
      # MonitoredSession automatically restores and retries.
      self.assertEqual(3, session.run(do_step))
      self.assertTrue(abort_hook.raised)
      self.assertFalse(session.should_stop())
      self.assertEqual(4, session.run(do_step))
      self.assertFalse(session.should_stop())
def test_exit_cleanly_on_out_of_range_exception(self):
  """Tests that the session stops cleanly when OutOfRangeError is raised."""
  with ops.Graph().as_default():
    gstep = training_util.get_or_create_global_step()
    do_step = state_ops.assign_add(gstep, 1)
    hook = RaiseOnceAtCountN(2, errors_impl.OutOfRangeError(None, None,
                                                            'EOI'))
    session = monitored_session.MonitoredSession(hooks=[hook])
    # session should cleanly exit from the context.
    with session:
      self.assertEqual(0, session.run(gstep))
      self.assertFalse(session.should_stop())
      # Here at step 1, the hook triggers and raises OutOfRange. The
      # session should go into should_stop() mode. It should raise the
      # exception. So next step should not be executed.
      session.run(do_step)
      # Unreachable: run() above must raise. self.fail() reports a clear
      # message, unlike the opaque assertTrue(False) it replaces.
      self.fail('OutOfRangeError was expected to be raised by run().')
    self.assertTrue(session.should_stop())
def test_exit_cleanly_on_stop_iteration_exception(self):
  """Tests that the session stops cleanly when StopIteration is raised."""
  with ops.Graph().as_default():
    gstep = training_util.get_or_create_global_step()
    do_step = state_ops.assign_add(gstep, 1)
    hook = RaiseOnceAtCountN(2, StopIteration)
    session = monitored_session.MonitoredSession(hooks=[hook])
    # session should cleanly exit from the context.
    with session:
      self.assertEqual(0, session.run(gstep))
      self.assertFalse(session.should_stop())
      # Here at step 1, the hook triggers and raises StopIteration. The
      # session should go into should_stop() mode. It should raise the
      # exception. So next step should not be executed.
      session.run(do_step)
      # Unreachable: run() above must raise. self.fail() reports a clear
      # message, unlike the opaque assertTrue(False) it replaces.
      self.fail('StopIteration was expected to be raised by run().')
    self.assertTrue(session.should_stop())
def test_regular_exception_pass_through_run(self):
  """Regular exceptions propagate out of the with-block and stop the session."""
  # Tests that regular exceptions just pass through a "with
  # MonitoredSession" block and set the session in stop mode.
  with ops.Graph().as_default():
    gstep = training_util.get_or_create_global_step()
    do_step = state_ops.assign_add(gstep, 1)
    hook = RaiseOnceAtCountN(4, RuntimeError('regular exception'))
    session = monitored_session.MonitoredSession(hooks=[hook])
    with self.assertRaisesRegexp(RuntimeError, 'regular exception'):
      with session:
        self.assertEqual(0, session.run(gstep))
        self.assertEqual(1, session.run(do_step))
        self.assertEqual(2, session.run(do_step))
        self.assertFalse(session.should_stop())
        # This triggers the hook and raises the exception
        session.run(do_step)
        # We should not hit this: run() above must raise. self.fail()
        # reports a clear message, unlike the opaque assertFalse(True).
        self.fail('RuntimeError was expected to be raised by run().')
    self.assertTrue(hook.raised)
    self.assertTrue(session.should_stop())
def test_regular_exception_reported_to_coord_pass_through_run(self):
  """An exception reported to the coordinator surfaces on block exit."""
  # Tests that regular exceptions reported to the coordinator from a thread
  # passes through a "run()" call within a "with MonitoredSession" block and
  # set the session in stop mode.
  with ops.Graph().as_default():
    gstep = training_util.get_or_create_global_step()
    session = monitored_session.MonitoredSession()
    run_performed_without_error = False
    with self.assertRaisesRegexp(RuntimeError, 'a thread wants to stop'):
      with session:
        self.assertEqual(0, session.run(gstep))
        # Report an exception through the coordinator.
        try:
          raise RuntimeError('a thread wants to stop')
        except RuntimeError as e:
          session._coordinated_creator.coord.request_stop(e)
        # Call run() which should perform normally.
        self.assertEqual(0, session.run(gstep))
        run_performed_without_error = True
    self.assertTrue(run_performed_without_error)
def test_regular_exception_reported_to_coord_pass_through_return(self):
  """An exception reported to the coordinator is re-raised on block exit."""
  # Tests that regular exceptions reported to the coordinator from a thread
  # passes through returning from a "with MonitoredSession" block and
  # set the session in stop mode.
  with ops.Graph().as_default():
    gstep = training_util.get_or_create_global_step()
    session = monitored_session.MonitoredSession()
    with self.assertRaisesRegexp(RuntimeError, 'a thread wants to stop'):
      with session:
        self.assertEqual(0, session.run(gstep))
        # Report an exception through the coordinator.
        try:
          raise RuntimeError('a thread wants to stop')
        except RuntimeError as e:
          session._coordinated_creator.coord.request_stop(e)
    self.assertTrue(session.should_stop())
# This set of tests verifies the session behavior when exceptions are raised
# from code inside a "with MonitoredSession:" context.
def test_stop_cleanly_when_no_exception_in_with_body(self):
  """Exiting the with-block without an exception closes the session."""
  # Tests that the session stops and closes cleanly when the with-body
  # completes without raising.
  with ops.Graph().as_default():
    gstep = training_util.get_or_create_global_step()
    do_step = state_ops.assign_add(gstep, 1)
    session = monitored_session.MonitoredSession()
    with session:
      self.assertEqual(1, session.run(do_step))
      self.assertEqual(2, session.run(do_step))
      self.assertFalse(session.should_stop())
    # Should have closed.
    self.assertTrue(session.should_stop())
    self.assertTrue(session._is_closed())
def test_raises_regular_exceptions_in_with_body(self):
  """An exception raised in the with-body is visible to the caller."""
  # Tests that regular exceptions in "with body" are seen outside.
  with ops.Graph().as_default():
    gstep = training_util.get_or_create_global_step()
    do_step = state_ops.assign_add(gstep, 1)
    session = monitored_session.MonitoredSession()
    # We should see that exception.
    with self.assertRaisesRegexp(RuntimeError, 'regular exception'):
      with session:
        self.assertEqual(1, session.run(do_step))
        self.assertEqual(2, session.run(do_step))
        self.assertFalse(session.should_stop())
        # Will be visible outside the "with body".
        raise RuntimeError('regular exception')
    # Should have closed.
    self.assertTrue(session.should_stop())
    self.assertTrue(session._is_closed())
def test_graph(self):
  """MonitoredSession exposes the graph it was created under."""
  with ops.Graph().as_default() as expected_graph:
    with monitored_session.MonitoredSession() as session:
      self.assertEqual(expected_graph, session.graph)
def test_graph_finalized_during_run_unfinalized_after_exit(self):
with ops.Graph().as_default() as g:
a_var = variables.VariableV1(0)
with monitored_session.MonitoredSession() as session:
self.assertEqual(0, session.run(a_var))
self.assertTrue(g.finalized)
self.assertFalse(g.finalized)
def test_keep_finalized_graph_as_finalized(self):
with ops.Graph().as_default() as g:
a_var = variables.VariableV1(0)
monitored_session.Scaffold().finalize()
with monitored_session.MonitoredSession() as session:
self.assertEqual(0, session.run(a_var))
self.assertTrue(g.finalized)
self.assertTrue(g.finalized)
def test_merge_run_options_from_hooks(self):
"""Test for rewriting RunOptions and observing RunMetadata with hooks."""
with ops.Graph().as_default():
my_const = constant_op.constant(42, name='my_const')
_ = constant_op.constant(24, name='my_const_2')
watch_a = debug_pb2.DebugTensorWatch(
node_name='my_const',
output_slot=0,
debug_ops=['DebugIdentity'],
debug_urls=[])
hook_a = RunOptionsMetadataHook(2, 30000, False, watch_a, False)
watch_b = debug_pb2.DebugTensorWatch(
node_name='my_const_2',
output_slot=0,
debug_ops=['DebugIdentity'],
debug_urls=[])
hook_b = RunOptionsMetadataHook(3, 60000, True, watch_b, True)
with monitored_session.MonitoredSession(
hooks=[hook_a, hook_b]) as session:
self.assertEqual(42, session.run(my_const))
# trace_level=3 should have overridden trace_level=2;
# timeout_in_ms=60000 should have overridden 30000;
# output_partition_graphs=True should have overridden False.
# The two debug tensor watches should have been merged.
self.assertEqual([
config_pb2.RunOptions(
trace_level=3,
timeout_in_ms=60000,
output_partition_graphs=True,
debug_options=debug_pb2.DebugOptions(
debug_tensor_watch_opts=[watch_a, watch_b]),
report_tensor_allocations_upon_oom=True),
], hook_b.run_options_list)
self.assertEqual(1, len(hook_b.run_metadata_list))
self.assertTrue(
isinstance(hook_b.run_metadata_list[0], config_pb2.RunMetadata))
self.assertGreater(len(hook_b.run_metadata_list[0].partition_graphs), 0)
def test_merge_caller_and_hook_run_options(self):
"""Test that RunOptions from caller and hooks can be merged properly."""
with ops.Graph().as_default():
my_const = constant_op.constant(42, name='my_const')
_ = constant_op.constant(24, name='my_const_2')
hook_watch = debug_pb2.DebugTensorWatch(
node_name='my_const_2',
output_slot=0,
debug_ops=['DebugIdentity'],
debug_urls=[])
hook = RunOptionsMetadataHook(2, 60000, False, hook_watch, False)
with monitored_session.MonitoredSession(hooks=[hook]) as session:
caller_watch = debug_pb2.DebugTensorWatch(
node_name='my_const',
output_slot=0,
debug_ops=['DebugIdentity'],
debug_urls=[])
caller_options = config_pb2.RunOptions(
trace_level=3,
timeout_in_ms=30000,
output_partition_graphs=True,
report_tensor_allocations_upon_oom=True)
caller_options.debug_options.debug_tensor_watch_opts.extend(
[caller_watch])
self.assertEqual(42, session.run(my_const, options=caller_options))
# trace_level=3 from the caller should override 2 from the hook.
# timeout_in_ms=60000 from the hook should override from the caller.
# output_partition_graph=True from the caller should override False
# from the hook.
# The two debug watches from the caller and the hook should be merged,
# in that order.
self.assertEqual([
config_pb2.RunOptions(
trace_level=3,
timeout_in_ms=60000,
output_partition_graphs=True,
debug_options=debug_pb2.DebugOptions(
debug_tensor_watch_opts=[caller_watch, hook_watch]),
report_tensor_allocations_upon_oom=True),
], hook.run_options_list)
self.assertEqual(1, len(hook.run_metadata_list))
self.assertTrue(
isinstance(hook.run_metadata_list[0], config_pb2.RunMetadata))
self.assertGreater(len(hook.run_metadata_list[0].partition_graphs), 0)
@test_util.run_deprecated_v1
def test_with_statement_and_close(self):
# Test case for https://github.com/tensorflow/tensorflow/issues/12224
# where close() inside the with should have a better error message.
with self.assertRaisesRegexp(RuntimeError, 'Session is already closed'):
with monitored_session.MonitoredSession() as session:
session.close()
def test_step_fn_example(self):
with ops.Graph().as_default():
c = array_ops.placeholder(dtypes.float32)
v = array_ops.identity(c)
def step_fn(step_context):
value = step_context.run_with_hooks(fetches=v, feed_dict={c: 3.2})
return value
with monitored_session.MonitoredSession() as session:
self.assertNear(3.2, session.run_step_fn(step_fn), 0.1)
def test_step_function_stops(self):
with ops.Graph().as_default():
def step_fn(step_context):
step_context.request_stop()
with monitored_session.MonitoredSession() as session:
self.assertEqual(None, session.run_step_fn(step_fn))
self.assertTrue(session.should_stop())
def test_step_request_stop_without_a_with_block(self):
with ops.Graph().as_default():
was_stop_iteration_raised = False
def step_fn(step_context):
step_context.request_stop()
session = monitored_session.MonitoredSession()
try:
self.assertEqual(None, session.run_step_fn(step_fn))
except StopIteration:
was_stop_iteration_raised = True
self.assertTrue(was_stop_iteration_raised)
self.assertFalse(session.should_stop())
  def test_step_request_stop_in_a_loop(self):
    # Inside a `with` block, `request_stop()` from a step_fn is expected to
    # raise out of `run_step_fn`, so the `self.fail` line below must be
    # unreachable.
    with ops.Graph().as_default():
      def step_fn(step_context):
        step_context.request_stop()
      with monitored_session.MonitoredSession() as session:
        while not session.should_stop():
          _ = session.run_step_fn(step_fn)
        self.fail('An exception should be raised on the line above.')
def test_step_request_stop_with_returning_a_type(self):
with ops.Graph().as_default():
def step_fn(step_context):
del step_context
return 'a type'
with monitored_session.MonitoredSession() as session:
self.assertEqual('a type', session.run_step_fn(step_fn))
def test_step_with_extra_arguments(self):
with ops.Graph().as_default():
def step_fn(step_context, extra_foo):
del step_context, extra_foo
with monitored_session.MonitoredSession() as session:
with self.assertRaisesRegexp(
ValueError,
'`step_fn` may either have one `step_context` argument'):
self.assertEqual(None, session.run_step_fn(step_fn))
def test_step_fn_belongs_to_a_class(self):
with ops.Graph().as_default():
c = array_ops.placeholder(dtypes.float32)
v = array_ops.identity(c)
class Model(object):
def step_fn(self, step_context):
return step_context.run_with_hooks(fetches=v, feed_dict={c: 3.2})
with monitored_session.MonitoredSession() as session:
model = Model()
self.assertNear(3.2, session.run_step_fn(model.step_fn), 0.1)
def test_step_fn_belongs_to_a_class_and_has_extra_methods(self):
with ops.Graph().as_default():
class Model(object):
def step_fn(self, step_context, extra_foo):
del step_context, extra_foo
with monitored_session.MonitoredSession() as session:
with self.assertRaisesRegexp(
ValueError,
'`step_fn` may either have one `step_context` argument'):
model = Model()
self.assertEqual(None, session.run_step_fn(model.step_fn))
  def test_step_fn_with_hooks(self):
    with ops.Graph().as_default():
      var = resource_variable_ops.ResourceVariable(0.0)
      # This test highlights the interaction of hooks with
      # `Monitoredsession.run_step_fn`. The order of execution of operations
      # below is:
      # 0. stage_0
      # 1. stage_1_0 or stage_1_1 in an undefined order
      # 2. stage_2
      stage_0 = state_ops.assign_add(var, 0.3)
      stage_1_0 = state_ops.assign_add(var, 0.7)
      # The order of `stage_1_0` and `stage_1_1` is undefined by
      # `MonitoredSession`, but we should be able to assert when both of them
      # are complete. To obtain a consistent result of adding two different
      # constants to `var`, we rely on a control dependency and
      # `ResourceVariable`. Otherwise, it is possible that one of the
      # additions overwrites the result of the other addition.
      with ops.control_dependencies([stage_1_0]):
        stage_1_1 = state_ops.assign_add(var, 0.5)
      stage_2 = state_ops.assign_add(var, 1.1)
      class Hook(session_run_hook.SessionRunHook):
        def __init__(self, testing):
          self._testing = testing
        def before_run(self, run_context):
          # Adds `stage_1_0` to the fetches of every `run_with_hooks` call.
          return session_run_hook.SessionRunArgs(fetches=stage_1_0)
        def after_run(self, run_context, run_values):
          # By this point both stage-1 additions have run: 0.3 + 0.7 + 0.5.
          self._testing.assertNear(0.3 + 0.5 + 0.7,
                                   run_context.session.run(var), 0.1)
          self._testing.assertNear(0.3 + 0.5 + 0.7 + 1.1,
                                   run_context.session.run(stage_2), 0.1)
      def step_fn(step_context):
        # `step_context.session.run` does not trigger the hooks (otherwise
        # `var` would already include the 0.7 from `stage_1_0` here).
        self.assertNear(0.3, step_context.session.run(stage_0), 0.1)
        return step_context.run_with_hooks(fetches=stage_1_1)
      with monitored_session.MonitoredSession(hooks=[Hook(self)]) as session:
        self.assertEqual(0.3 + 0.5 + 0.7, session.run_step_fn(step_fn))
  def test_step_fn_has_the_same_hooks_behavior_without_recovery(self):
    # Mirrors `test_step_fn_with_hooks`, but with `SingularMonitoredSession`
    # (which performs no error recovery) to check that the hook interaction
    # with `run_step_fn` is identical.
    with ops.Graph().as_default():
      var = resource_variable_ops.ResourceVariable(0.0)
      stage_0 = state_ops.assign_add(var, 0.3)
      stage_1_0 = state_ops.assign_add(var, 0.7)
      # The control dependency forces `stage_1_0` before `stage_1_1`, making
      # the final value of `var` deterministic.
      with ops.control_dependencies([stage_1_0]):
        stage_1_1 = state_ops.assign_add(var, 0.5)
      stage_2 = state_ops.assign_add(var, 1.1)
      class Hook(session_run_hook.SessionRunHook):
        def __init__(self, testing):
          self._testing = testing
        def before_run(self, run_context):
          # Adds `stage_1_0` to the fetches of every `run_with_hooks` call.
          return session_run_hook.SessionRunArgs(fetches=stage_1_0)
        def after_run(self, run_context, run_values):
          self._testing.assertNear(0.3 + 0.5 + 0.7,
                                   run_context.session.run(var), 0.1)
          self._testing.assertNear(0.3 + 0.5 + 0.7 + 1.1,
                                   run_context.session.run(stage_2), 0.1)
      def step_fn(step_context):
        self.assertNear(0.3, step_context.session.run(stage_0), 0.1)
        return step_context.run_with_hooks(fetches=stage_1_1)
      with monitored_session.SingularMonitoredSession(
          hooks=[Hook(self)]) as session:
        self.assertEqual(0.3 + 0.5 + 0.7, session.run_step_fn(step_fn))
def test_step_fn_with_hooks_and_request_stop(self):
with ops.Graph().as_default():
trace_the_hook = {'before_run': False, 'after_run': False}
class Hook(session_run_hook.SessionRunHook):
def before_run(self, run_context):
trace_the_hook['before_run'] = True
def after_run(self, run_context, run_values):
trace_the_hook['after_run'] = True
def step_fn(step_context):
step_context.request_stop()
with monitored_session.MonitoredSession(hooks=[Hook()]) as session:
self.assertEqual(None, session.run_step_fn(step_fn))
self.assertTrue(session.should_stop())
# `step_context.request_stop()` in a step_fn interrupts the flow of
# running the hooks.
self.assertFalse(trace_the_hook['before_run'])
self.assertFalse(trace_the_hook['after_run'])
def test_recovers_from_an_exception_in_step_fn(self):
trace_the_exception = {'run_already': False}
with ops.Graph().as_default():
c = array_ops.placeholder(dtypes.float32)
v = array_ops.identity(c)
def step_fn(step_context):
if not trace_the_exception['run_already']:
trace_the_exception['run_already'] = True
raise errors_impl.AbortedError(None, None, 'Abort')
return step_context.run_with_hooks(fetches=v, feed_dict={c: 3.2})
with monitored_session.MonitoredSession() as session:
self.assertNear(3.2, session.run_step_fn(step_fn), 0.1)
self.assertTrue(trace_the_exception['run_already'])
  def test_recovers_from_an_exception_in_step_fn_after_hooks(self):
    # When the step_fn raises AbortedError *after* `run_with_hooks` has
    # already executed, `MonitoredSession` re-runs the entire step_fn body,
    # including its session-side effects (so `graph_side_effect` runs twice).
    trace_the_exception = {'run_already': False, 'side_effect_counter': 0}
    with ops.Graph().as_default():
      c = array_ops.placeholder(dtypes.float32)
      v = array_ops.identity(c)
      graph_state = variables.VariableV1(0.0)
      graph_side_effect = state_ops.assign_add(graph_state, 0.31)
      def step_fn(step_context):
        trace_the_exception['side_effect_counter'] += 1
        step_context.session.run(graph_side_effect)
        value = step_context.run_with_hooks(fetches=v, feed_dict={c: 3.2})
        # Abort only on the first invocation; the retry returns normally.
        if not trace_the_exception['run_already']:
          trace_the_exception['run_already'] = True
          raise errors_impl.AbortedError(None, None, 'Abort')
        return value
      with self.cached_session() as test_session:
        with monitored_session.MonitoredSession(
            CountingSessionCreator(test_session)) as session:
          session.run(variables.global_variables_initializer())
          self.assertNear(3.2, session.run_step_fn(step_fn), 0.1)
          self.assertTrue(trace_the_exception['run_already'])
          # Make sure the rest of the body of the step_fn is re-executed upon
          # AbortedError:
          self.assertEqual(2, trace_the_exception['side_effect_counter'])
          self.assertNear(0.62, session.run(graph_state), 0.1)
def test_step_fn_doesnt_recover_when_it_wasnt_asked_to(self):
trace_the_exception = {'run_already': False}
with ops.Graph().as_default():
c = array_ops.placeholder(dtypes.float32)
v = array_ops.identity(c)
def step_fn(step_context):
if not trace_the_exception['run_already']:
trace_the_exception['run_already'] = True
raise errors_impl.AbortedError(None, None, 'Abort')
value = step_context.run_with_hooks(fetches=v, feed_dict={c: 3.2})
return value
with monitored_session.SingularMonitoredSession() as session:
with self.assertRaisesRegexp(errors_impl.AbortedError, 'Abort'):
self.assertNear(3.2, session.run_step_fn(step_fn), 0.1)
self.fail()
self.assertTrue(trace_the_exception['run_already'])
  def test_step_fn_exception_from_before_run(self):
    # An AbortedError raised from a hook's `before_run` (i.e. from inside
    # `run_with_hooks`) also triggers a full re-run of the step_fn, so the
    # graph side effect executes twice.
    trace_the_exception = {'run_already': False, 'side_effect_counter': 0}
    with ops.Graph().as_default():
      c = array_ops.placeholder(dtypes.float32)
      v = array_ops.identity(c)
      vv = constant_op.constant(3.2)
      graph_state = variables.VariableV1(0.0)
      graph_side_effect = state_ops.assign_add(graph_state, 0.31)
      class Hook(session_run_hook.SessionRunHook):
        def __init__(self, testing):
          self._testing = testing
        def before_run(self, run_context):
          # Abort only on the first call; the retry fetches `vv` normally.
          if not trace_the_exception['run_already']:
            trace_the_exception['run_already'] = True
            raise errors_impl.AbortedError(None, None, 'Abort')
          return session_run_hook.SessionRunArgs(fetches=vv)
        def after_run(self, run_context, run_values):
          self._testing.assertNear(3.2, run_values.results, 0.1)
      def step_fn(step_context):
        trace_the_exception['side_effect_counter'] += 1
        step_context.session.run(graph_side_effect)
        return step_context.run_with_hooks(fetches=v, feed_dict={c: 1.3})
      with self.cached_session() as test_session:
        with monitored_session.MonitoredSession(
            CountingSessionCreator(test_session),
            hooks=[Hook(self)]) as session:
          test_session.run(variables.global_variables_initializer())
          self.assertNear(1.3, session.run_step_fn(step_fn), 0.1)
          # The step_fn body (including the side effect) ran twice: once
          # aborted by the hook, once to completion.
          self.assertEqual(2, trace_the_exception['side_effect_counter'])
          self.assertNear(0.62, session.run(graph_state), 0.1)
class SingularMonitoredSessionTest(test.TestCase):
  """Tests SingularMonitoredSession."""

  def test_handles_initialization(self):
    with ops.Graph().as_default():
      a_var = variables.VariableV1(0)
      with monitored_session.SingularMonitoredSession() as session:
        # If it's not initialized, following statement raises an error.
        self.assertEqual(0, session.run(a_var))

  def test_do_not_handle_aborted_error(self):
    # SingularMonitoredSession propagates AbortedError instead of recovering.
    with ops.Graph().as_default():
      gstep = training_util.get_or_create_global_step()

      class _RaiseAbortedHook(session_run_hook.SessionRunHook):

        def before_run(self, run_context):
          raise errors_impl.AbortedError(None, None, 'Abort')

      with monitored_session.SingularMonitoredSession(
          hooks=[_RaiseAbortedHook()]) as session:
        with self.assertRaises(errors_impl.AbortedError):
          self.assertEqual(0, session.run(gstep))
      with self.assertRaises(errors_impl.AbortedError):
        with monitored_session.SingularMonitoredSession(
            hooks=[_RaiseAbortedHook()]) as session:
          self.assertEqual(0, session.run(gstep))

  def test_exit_cleanly_on_out_of_range_exception(self):
    # Tests that we stop cleanly when OutOfRange is raised.
    with ops.Graph().as_default():
      gstep = training_util.get_or_create_global_step()
      do_step = state_ops.assign_add(gstep, 1)
      hook = RaiseOnceAtCountN(2, errors_impl.OutOfRangeError(None, None,
                                                              'EOI'))
      session = monitored_session.SingularMonitoredSession(hooks=[hook])
      # session should cleanly exit from the context.
      with session:
        self.assertEqual(0, session.run(gstep))
        self.assertFalse(session.should_stop())
        # Here at step 1, the hook triggers and raises OutOfRange. The
        # session should go into should_stop() mode. It should raise the
        # exception. So next step should not be executed.
        session.run(do_step)
        # self.fail is clearer than the former assertTrue(False).
        self.fail('OutOfRange should have been raised on the line above.')
      self.assertTrue(session.should_stop())

  def test_regular_exception_reported_to_coord_pass_through_run(self):
    # Tests that regular exceptions reported to the coordinator from a thread
    # passes through a "run()" call within a "with MonitoredSession" block and
    # set the session in stop mode.
    with ops.Graph().as_default():
      gstep = training_util.get_or_create_global_step()
      session = monitored_session.SingularMonitoredSession()
      run_performed_without_error = False
      # assertRaisesRegex replaces the deprecated assertRaisesRegexp alias.
      with self.assertRaisesRegex(RuntimeError, 'a thread wants to stop'):
        with session:
          self.assertEqual(0, session.run(gstep))
          # Report an exception through the coordinator.
          try:
            raise RuntimeError('a thread wants to stop')
          except RuntimeError as e:
            session._coordinated_creator.coord.request_stop(e)
          # Call run() which should perform normally.
          self.assertEqual(0, session.run(gstep))
          run_performed_without_error = True
    self.assertTrue(run_performed_without_error)

  def test_stop_cleanly_when_no_exception_in_with_body(self):
    # Tests that regular exceptions pass through
    with ops.Graph().as_default():
      gstep = training_util.get_or_create_global_step()
      do_step = state_ops.assign_add(gstep, 1)
      session = monitored_session.SingularMonitoredSession()
      with session:
        self.assertEqual(1, session.run(do_step))
        self.assertEqual(2, session.run(do_step))
        self.assertFalse(session.should_stop())
      # Should have closed.
      self.assertTrue(session.should_stop())
      # assertIsNone is the idiomatic form of assertEqual(None, ...).
      self.assertIsNone(session.raw_session())

  def test_graph(self):
    with ops.Graph().as_default() as g:
      with monitored_session.SingularMonitoredSession() as session:
        self.assertEqual(g, session.graph)

  def test_raw_session(self):
    with ops.Graph().as_default():
      with monitored_session.SingularMonitoredSession() as session:
        # assertIsInstance reports the actual type on failure.
        self.assertIsInstance(session.raw_session(), session_lib.Session)
if __name__ == '__main__':
  # Run all test cases in this module when executed directly.
  test.main()