Add SupervisedSession class.

Extend the Monitor interface to support a post_step() callback which is passed the Session.

Add a Scaffold object to hold onto the graph pieces commonly needed to run training.

Based on a design by @ilblackdragon.
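
A minimal usage sketch of how the three pieces fit together, mirroring the tests added in this commit (the empty master string and the assign_add step op come straight from the test code):

import tensorflow as tf
from tensorflow.contrib.learn.python.learn import supervised_session

with tf.Graph().as_default():
  scaffold = supervised_session.Scaffold()  # gathers init_op, saver, global_step, ...
  gstep = scaffold.global_step_tensor
  do_step = tf.assign_add(gstep, 1)
  monitors = [tf.contrib.learn.monitors.StopAtStep(last_step=3)]
  with supervised_session.SupervisedSession('', scaffold=scaffold,
                                            monitors=monitors) as session:
    while not session.should_stop():
      session.run(do_step)  # monitors observe each step; StopAtStep requests stop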
Change: 127323605
A. Unique TensorFlower 2016-07-13 08:10:45 -08:00 committed by TensorFlower Gardener
parent 3219a1e939
commit 4f6e9efb40
15 changed files with 1177 additions and 74 deletions

View File

@ -684,6 +684,26 @@ py_test(
],
)
py_test(
name = "supervised_session_test",
size = "small",
srcs = ["python/learn/tests/supervised_session_test.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow:tensorflow_py",
],
)
py_test(
name = "summary_writer_cache_test",
size = "small",
srcs = ["python/learn/tests/summary_writer_cache_test.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow:tensorflow_py",
],
)
py_binary(
name = "inspect_checkpoint",
srcs = [

View File

@ -31,7 +31,6 @@ from tensorflow.contrib.learn.python.learn import monitors
from tensorflow.contrib.learn.python.learn import ops
from tensorflow.contrib.learn.python.learn import preprocessing
from tensorflow.contrib.learn.python.learn import utils
from tensorflow.contrib.learn.python.learn.coordinated_session import *
from tensorflow.contrib.learn.python.learn.dataframe import *
from tensorflow.contrib.learn.python.learn.estimators import *
from tensorflow.contrib.learn.python.learn.experiment import Experiment
@ -42,6 +41,4 @@ from tensorflow.contrib.learn.python.learn.graph_actions import run_feeds
from tensorflow.contrib.learn.python.learn.graph_actions import run_n
from tensorflow.contrib.learn.python.learn.graph_actions import train
from tensorflow.contrib.learn.python.learn.io import *
from tensorflow.contrib.learn.python.learn.recoverable_session import *
from tensorflow.contrib.learn.python.learn.wrapped_session import *
# pylint: enable=wildcard-import

View File

@ -57,3 +57,8 @@ class CoordinatedSession(WrappedSession):
self._coord.request_stop(e)
if self._coord.should_stop():
self._coord.join(self._coordinated_threads_to_join)
# TODO(touts): Add a close() method that also joins the coordinator
# but does not raise exceptions. This can only be done reliably when the
# Coordinator keeps a pointer to the coordinated threads, otherwise we do not
# know which threads to join.

View File

@ -23,7 +23,6 @@ import six
from tensorflow.contrib.learn.python.learn.wrapped_session import WrappedSession
from tensorflow.python.framework import ops
from tensorflow.python.platform import tf_logging as logging
class MonitoredSession(WrappedSession):
@ -73,7 +72,6 @@ class MonitoredSession(WrappedSession):
if self._last_step is None:
self._last_step = WrappedSession.run(self, self._global_step_tensor)
logging.info('Initialized step to: %d', self._last_step)
monitors_step = self._last_step + 1
monitor_fetches = []
@ -104,12 +102,16 @@ class MonitoredSession(WrappedSession):
induce_stop = monitor.step_end(monitors_step, monitor_outputs)
self._should_stop = self._should_stop or induce_stop
# Call the post_step methods.
for monitor in self._monitors:
monitor.post_step(monitors_step, self._sess)
return outputs['caller']
# TODO(ispir): Remove the following logic after forcing monitors to return tensors.
def _as_graph_element(obj, graph):
"""Retrieves Graph element from tensors or tensor names."""
"""Retrieves Graph element."""
graph = graph or ops.get_default_graph()
if not isinstance(obj, six.string_types):
if not hasattr(obj, 'graph') or obj.graph != graph:

View File

@ -69,13 +69,19 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import time
import numpy as np
import six
from tensorflow.contrib.learn.python.learn.summary_writer_cache import SummaryWriterCache
from tensorflow.contrib.learn.python.learn.utils import export
from tensorflow.core.framework.summary_pb2 import Summary
from tensorflow.core.util.event_pb2 import SessionLog
from tensorflow.python.framework import ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import saver
from tensorflow.python.training import saver as saver_lib
from tensorflow.python.training import summary_io
@ -92,6 +98,7 @@ class BaseMonitor(object):
self._current_epoch = None
self._current_step = None
self._max_steps = None
self._init_step = None
self._estimator = None
def set_estimator(self, estimator):
@ -108,11 +115,14 @@ class BaseMonitor(object):
# TODO(mdan): This should fail if called twice with the same estimator.
self._estimator = estimator
def begin(self, max_steps=None):
def begin(self, max_steps=None, init_step=None):
"""Called at the beginning of training.
When called, the default graph is the one we are executing.
Args:
max_steps: `int`, the maximum global step this training will run until.
init_step: `int`, step at which this training will start.
Raises:
ValueError: if we've already begun a run.
@ -120,14 +130,19 @@ class BaseMonitor(object):
if self._begun:
raise ValueError("begin called twice without end.")
self._max_steps = max_steps
self._init_step = init_step
self._begun = True
def end(self):
def end(self, session=None):
"""Callback at the end of training/evaluation.
Args:
session: A `tf.Session` object that can be used to run ops.
Raises:
ValueError: if we've not begun a run.
"""
_ = session
if not self._begun:
raise ValueError("end called without begin.")
self._max_steps = None
@ -178,8 +193,6 @@ class BaseMonitor(object):
ValueError: if we've already begun a step, or `step` < 0, or
`step` > `max_steps`.
"""
if self._current_step is not None:
raise ValueError("step_begin called twice without step_end.")
if (step < 0) or (
(self._max_steps is not None) and (step > self._max_steps)):
raise ValueError("Invalid step %s." % step)
@ -196,6 +209,9 @@ class BaseMonitor(object):
In addition, the callback has the opportunity to stop training by returning
`True`. This is useful for early stopping, for example.
Note that this method is not called if the call to `Session.run()` that
followed the last call to `step_begin()` failed.
Args:
step: `int`, the current value of the global step.
output: `dict` mapping `string` values representing tensor names to
@ -214,13 +230,26 @@ class BaseMonitor(object):
self._current_step = None
return False
def post_step(self, step, session): # pylint: disable=unused-argument
"""Callback after the step is finished.
Called after step_end and receives the session, so the monitor can make
extra session.run calls. It is also called if a failure occurred while
processing the step.
Args:
step: `int`, global step of the model.
session: `Session` object.
"""
_ = step, session
class EveryN(BaseMonitor):
"""Base class for monitors that execute callbacks every n steps.
"""Base class for monitors that execute callbacks every N steps.
This class adds two new callbacks:
This class adds three new callbacks:
- every_n_step_begin
- every_n_step_end
- every_n_pos_step
The callbacks are executed every n steps, or optionally every step for the
first m steps, where m and n can both be user-specified.
@ -234,11 +263,16 @@ class EveryN(BaseMonitor):
Failing to call the super implementation will cause unpredictable behavior.
The `every_n_post_step()` callback is also called after the last step if it
was not already called through the regular conditions. Note that
`every_n_step_begin()` and `every_n_step_end()` do not receive that special
treatment.
"""
# TODO(ipolosukhin): Add also every n seconds.
def __init__(self, every_n_steps=100, first_n_steps=1):
"""Initialized an `EveryN` monitor.
"""Initializes an `EveryN` monitor.
Args:
every_n_steps: `int`, the number of steps to allow between callbacks.
@ -249,18 +283,10 @@ class EveryN(BaseMonitor):
super(EveryN, self).__init__()
self._every_n_steps = every_n_steps
self._first_n_steps = first_n_steps
self._last_step = 0
self._active_step = None
def begin(self, max_steps=None):
"""Overrides `BaseMonitor.begin`.
When overriding this method, you must call the super implementation.
Args:
max_steps: `int`, the maximum global step value.
"""
super(EveryN, self).begin(max_steps)
# Last step in the model.
self._last_step = None
# Last step at which we called one of the every_n methods
self._last_active_step = None
def every_n_step_begin(self, step): # pylint: disable=unused-argument
"""Callback before every n'th step begins.
@ -294,6 +320,15 @@ class EveryN(BaseMonitor):
"""
return False
def every_n_post_step(self, step, session):
"""Callback after a step is finished or `end()` is called.
Args:
step: `int`, the current value of the global step.
session: `Session` object.
"""
pass
def step_begin(self, step):
"""Overrides `BaseMonitor.step_begin`.
@ -309,13 +344,13 @@ class EveryN(BaseMonitor):
ValueError: if called more than once during a step.
"""
super(EveryN, self).step_begin(step)
self._last_step = step
if self._last_active_step is None:
self._last_active_step = step - 1
if (step <= self._first_n_steps or
step >= (self._every_n_steps + self._last_step) or
step == self._max_steps):
if self._active_step is not None:
raise ValueError(
"Starting step %s, %s still active." % (step, self._active_step))
self._active_step = step
step >= (self._every_n_steps + self._last_active_step) or
step == self._max_steps): # Note: max_steps can be None here.
self._last_active_step = step
return self.every_n_step_begin(step)
return []
@ -334,12 +369,59 @@ class EveryN(BaseMonitor):
or `False` otherwise.
"""
super(EveryN, self).step_end(step, output)
to_stop = False
if (self._active_step is not None) and (self._active_step == step):
self._last_step = step
to_stop = self.every_n_step_end(step, output)
self._active_step = None
return to_stop
if self._last_active_step == step:
return self.every_n_step_end(step, output)
return False
def post_step(self, step, session):
super(EveryN, self).post_step(step, session)
if self._last_active_step == step:
self.every_n_post_step(step, session)
def end(self, session=None):
super(EveryN, self).end(session=session)
if self._last_step != self._last_active_step:
self.every_n_post_step(self._last_step, session)
class StopAtStep(BaseMonitor):
"""Monitor to request stop at a specified step."""
def __init__(self, num_steps=None, last_step=None):
"""Create a StopAtStep monitor.
This monitor requests stop after either a number of steps have been
executed or a last step has been reached. Only one of the two options can be
specified.
If `num_steps` is specified, it indicates the number of steps to execute
after `begin()` is called. If instead `last_step` is specified, it
indicates the last step we want to execute, as passed to the `step_begin()`
call.
Args:
num_steps: Number of steps to execute.
last_step: Step after which to stop.
Raises:
ValueError: If one of the arguments is invalid.
"""
super(StopAtStep, self).__init__()
if num_steps is None and last_step is None:
raise ValueError("One of num_steps or last_step must be specified.")
if num_steps is not None and last_step is not None:
raise ValueError("Only one of num_steps or last_step can be specified.")
self._num_steps = num_steps
self._last_step = last_step
def begin(self, max_steps=None, init_step=None):
super(StopAtStep, self).begin(max_steps=max_steps, init_step=init_step)
if self._num_steps is not None:
self._last_step = init_step + self._num_steps
def step_end(self, step, output):
super(StopAtStep, self).step_end(step, output)
return step >= self._last_step
# TODO(ptucker): Rename to LoggingTensor since it's not writing to stdout.
@ -460,8 +542,8 @@ class SummarySaver(EveryN):
self._summary_writer.add_summary(summary_strs, step)
return False
def end(self):
super(SummarySaver, self).end()
def end(self, session=None):
super(SummarySaver, self).end(session=session)
if self._summary_writer:
self._summary_writer.flush()
@ -553,7 +635,7 @@ class ValidationMonitor(EveryN):
if self._estimator is None:
raise ValueError("Missing call to set_estimator.")
# Check that we are not running evaluation on the same checkpoint.
latest_path = saver.latest_checkpoint(self._estimator.model_dir)
latest_path = saver_lib.latest_checkpoint(self._estimator.model_dir)
if latest_path == self._latest_path:
logging.info("Skipping evaluation due to same checkpoint %s for step %d "
"as for step %d.", latest_path, step, self._latest_path_step)
@ -677,8 +759,8 @@ class GraphDump(BaseMonitor):
self._ignore_ops = ignore_ops or GraphDump.IGNORE_OPS
self._data = {}
def begin(self, max_steps):
super(GraphDump, self).begin(max_steps)
def begin(self, max_steps=None, init_step=None):
super(GraphDump, self).begin(max_steps=max_steps, init_step=init_step)
self._tensors = []
graph = ops.get_default_graph()
graph_def = graph.as_graph_def()
@ -782,9 +864,121 @@ class ExportMonitor(EveryN):
logging.info("Skipping exporting for the same step. "
"Consider exporting less frequently.")
def end(self):
super(ExportMonitor, self).end()
def end(self, session=None):
super(ExportMonitor, self).end(session=session)
export.export_estimator(self._estimator,
self.export_dir,
exports_to_keep=self.exports_to_keep,
signature_fn=self.signature_fn)
class CheckpointSaver(EveryN):
"""Saves checkpoints every N steps."""
def __init__(self, every_n_steps, saver, checkpoint_dir,
checkpoint_basename="model.ckpt",
first_n_steps=-1):
"""Initialize CheckpointSaver monitor.
Args:
every_n_steps: `int`, save every N steps.
saver: `Saver` object, used for saving.
checkpoint_dir: `str`, base directory for the checkpoint files.
checkpoint_basename: `str`, base name for the checkpoint files.
first_n_steps: `int`, if positive, save every step during the
first `first_n_steps` steps.
"""
logging.info("Create CheckpointSaver")
super(CheckpointSaver, self).__init__(every_n_steps=every_n_steps,
first_n_steps=first_n_steps)
self._saver = saver
self._summary_writer = SummaryWriterCache.get(checkpoint_dir)
self._save_path = os.path.join(checkpoint_dir, checkpoint_basename)
def every_n_post_step(self, step, session):
logging.info("Saving checkpoints for %d into %s.", step, self._save_path)
self._saver.save(session, self._save_path, global_step=step)
if self._summary_writer:
self._summary_writer.add_session_log(
SessionLog(status=SessionLog.CHECKPOINT,
checkpoint_path=self._save_path),
step)
class StepCounter(EveryN):
"""Steps per second monitor."""
def __init__(self, every_n_steps=100, output_dir=None,
summary_writer=None):
super(StepCounter, self).__init__(every_n_steps=every_n_steps)
self._summary_tag = "global_step/sec"
self._last_reported_step = None
self._last_reported_time = None
self._summary_writer = None
if summary_writer is None and output_dir:
self._summary_writer = SummaryWriterCache.get(output_dir)
def begin(self, max_steps=None, init_step=None):
super(StepCounter, self).begin(max_steps=max_steps, init_step=init_step)
self._last_reported_step = self._init_step
self._last_reported_time = time.time()
def set_estimator(self, estimator):
super(StepCounter, self).set_estimator(estimator)
if self._summary_writer is None:
self._summary_writer = SummaryWriterCache.get(estimator.model_dir)
def every_n_step_end(self, current_step, outputs):
current_time = time.time()
if self._last_reported_time is not None and self._summary_writer:
added_steps = current_step - self._last_reported_step
elapsed_time = current_time - self._last_reported_time
steps_per_sec = added_steps / elapsed_time
summary = Summary(value=[Summary.Value(tag=self._summary_tag,
simple_value=steps_per_sec)])
self._summary_writer.add_summary(summary, current_step)
self._last_reported_step = current_step
self._last_reported_time = current_time
class NanLossDuringTrainingError(RuntimeError):
def __str__(self):
return "NaN loss during training."
class NanLoss(EveryN):
"""NaN Loss monitor.
Monitors loss and stops training if loss is NaN.
Can either fail with an exception or just stop training.
"""
def __init__(self, loss_tensor, every_n_steps=100, fail_on_nan_loss=True):
"""Initializes NanLoss monitor.
Args:
loss_tensor: `Tensor`, the loss tensor.
every_n_steps: `int`, run the check once every this many steps.
fail_on_nan_loss: `bool`, whether to raise exception when loss is NaN.
"""
super(NanLoss, self).__init__(every_n_steps=every_n_steps)
self._loss_tensor = loss_tensor
self._fail_on_nan_loss = fail_on_nan_loss
def every_n_step_begin(self, step):
super(NanLoss, self).every_n_step_begin(step)
return self._loss_tensor
def every_n_step_end(self, step, outputs):
super(NanLoss, self).every_n_step_end(step, outputs)
if np.isnan(outputs):
failure_message = "Model diverged with loss = NaN."
if self._fail_on_nan_loss:
logging.error(failure_message)
raise NanLossDuringTrainingError
else:
logging.warning(failure_message)
# We don't raise an error but we return "should stop" so we stop, but
# without an exception.
return True
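
The new `every_n_post_step()` hook is what gives monitors such as `CheckpointSaver` safe access to the live session after a step finishes. A minimal custom monitor in the same style (the class name and the fetched tensor are illustrative, not part of this commit):

from tensorflow.contrib.learn.python.learn.monitors import EveryN
from tensorflow.python.platform import tf_logging as logging

class LogTensorEveryN(EveryN):
  """Illustrative monitor: fetches and logs a tensor every N steps."""

  def __init__(self, tensor, every_n_steps=100):
    super(LogTensorEveryN, self).__init__(every_n_steps=every_n_steps)
    self._tensor = tensor

  def every_n_post_step(self, step, session):
    super(LogTensorEveryN, self).every_n_post_step(step, session)
    # post_step callbacks receive the session, so extra run() calls are allowed.
    value = session.run(self._tensor)
    logging.info("Step %d: %s", step, value)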

View File

@ -0,0 +1,65 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A cache of summary writers, one per log directory.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import threading
from tensorflow.python.framework import ops
from tensorflow.python.training import summary_io
class SummaryWriterCache(object):
"""Cache for summary writers.
This class caches summary writers, one per directory.
"""
# Cache, keyed by directory.
_cache = {}
# Lock protecting _cache.
_lock = threading.RLock()
@staticmethod
def clear():
"""Clear cached summary writers. Currently only used for unit tests."""
with SummaryWriterCache._lock:
SummaryWriterCache._cache = {}
@staticmethod
def get(logdir):
"""Returns the SummaryWriter for the specified directory.
Args:
logdir: str, name of the directory.
Returns:
A `SummaryWriter`.
"""
with SummaryWriterCache._lock:
if logdir not in SummaryWriterCache._cache:
SummaryWriterCache._cache[logdir] = summary_io.SummaryWriter(
logdir, graph=ops.get_default_graph())
return SummaryWriterCache._cache[logdir]
# Backward compatible interface. Remove?
clear_summary_writers = SummaryWriterCache.clear
get_summary_writer = SummaryWriterCache.get
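
A short sketch of the cache semantics, mirroring the identity checks in the test file added below (the log directory is illustrative):

from tensorflow.contrib.learn.python.learn.summary_writer_cache import SummaryWriterCache

writer_a = SummaryWriterCache.get('/tmp/train_logs')  # creates the writer
writer_b = SummaryWriterCache.get('/tmp/train_logs')  # returns the cached instance
assert writer_a is writer_b
SummaryWriterCache.clear()  # drops all cached writers; intended for unit tests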

View File

@ -0,0 +1,301 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Wrapper for a Session-like object that handles threads and recovery.
Based on an original design by Illia Polosukhin.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.framework.python.ops import variables as contrib_variables
from tensorflow.contrib.learn.python.learn import coordinated_session
from tensorflow.contrib.learn.python.learn import monitored_session
from tensorflow.contrib.learn.python.learn import recoverable_session
from tensorflow.contrib.learn.python.learn import summary_writer_cache
from tensorflow.core.util.event_pb2 import SessionLog
from tensorflow.python.framework import ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import logging_ops
from tensorflow.python.ops import variables
from tensorflow.python.training import coordinator
from tensorflow.python.training import queue_runner
from tensorflow.python.training import saver as training_saver
from tensorflow.python.training import session_manager as sm
from tensorflow.python.training import training_util
# TODO(touts): Share that with the Supervisor.
class Scaffold(object):
"""Structure to create or gather pieces commonly needed to train a model.
When you build a model for training you usually need ops to initialize
variables, a `Saver` to checkpoint them, an op to collect summaries for
the visualizer, and so on.
Various libraries built on top of the core TensorFlow library take care of
creating some or all of these pieces and storing them in well known
collections in the graph. The `Scaffold` class helps pick these pieces from
the graph collections, creating and adding them to the collections if needed.
If you call the scaffold constructor without any arguments it will pick
pieces from the collections, creating default ones if needed. You can pass
arguments to the constructor to provide your own pieces. Pieces that you
pass to the constructor are not added to the graph collections.
The following pieces are directly accessible as attributes of the `Scaffold`
object:
* `saver`: A `tf.Saver` object taking care of saving the variables. Picked
from and stored into the `SAVERS` collection in the graph.
* `init_op`: An op to run to initialize the variables. Picked from and
stored into the `INIT_OP` collection in the graph.
* `ready_op`: An op to verify that the variables are initialized. Picked
from and stored into the `READY_OP` collection in the graph.
* `local_init_op`: An op to initialize the local variables. Picked
from and stored into the `LOCAL_INIT_OP` collection in the graph.
* `summary_op`: An op to run and merge the summaries in the graph. Picked
from and stored into the `SUMMARY_OP` collection in the graph.
* `global_step`: A tensor containing the global step counter. Picked
from and stored into the `GLOBAL_STEP` collection in the graph.
You can also pass the following additional pieces to the constructor:
* `init_feed_dict`: A session feed dictionary that should be used when
running the init op.
* `init_fn`: A callable to run after the init op to perform additional
initializations. The callable will be called as
`init_fn(scaffold, session)`.
"""
# TODO(touts): consider adding the output dir and summary writer (cached)?
# TODO(touts): consider finalizing the graph? (If the graph is
# modified later, the cached parts could be wrong.)
# TODO(touts): I do not think we should pass keep_checkpoint_max here.
# TODO(touts): Add individual static functions for init_op(), etc. that
# implement the caching logic.
def __init__(self,
global_step_tensor=None,
init_op=None,
init_feed_dict=None,
init_fn=None,
ready_op=None,
local_init_op=None,
summary_op=None,
saver=None,
keep_checkpoint_max=5):
"""Create a scaffold.
Args:
global_step_tensor: Optional tensor to use as the global step counter.
init_op: Optional op for initializing variables.
init_feed_dict: Optional session feed dictionary to use when running the
init_op.
init_fn: Optional function to use to initialize the model after running
the init_op. Will be called as `init_fn(scaffold, session)`.
ready_op: Optional op to verify that the variables are initialized. Must
return an empty scalar string tensor when the variables are
initialized, or a non-empty one listing the names of the
non-initialized variables.
local_init_op: Optional op to initialize local variables.
summary_op: Optional op to gather all summaries. Must return a scalar
string tensor containing a serialized `Summary` proto.
saver: Optional `tf.Saver` object to use to save and restore variables.
keep_checkpoint_max: Optional parameter to use to construct a saver if
none already exists in the graph.
"""
if global_step_tensor is None:
global_step_tensor = contrib_variables.get_or_create_global_step()
self.global_step_tensor = global_step_tensor
if init_op is None:
init_op = Scaffold._get_or_default(
ops.GraphKeys.INIT_OP, variables.initialize_all_variables)
self.init_op = init_op
self.init_feed_dict = init_feed_dict
# NOTE(touts): modifying the init function to be passed the scaffold is a
# hack to make it easy to find the saver. Is there a better way?
if init_fn:
self.init_fn = lambda sess: init_fn(self, sess)
else:
self.init_fn = None
if ready_op is None:
ready_op = Scaffold._get_or_default(
ops.GraphKeys.READY_OP, variables.report_uninitialized_variables)
self.ready_op = ready_op
if local_init_op is None:
local_init_op = Scaffold._get_or_default(
ops.GraphKeys.LOCAL_INIT_OP, Scaffold._default_local_init_op)
self.local_init_op = local_init_op
if summary_op is None:
summary_op = Scaffold._get_or_default(
ops.GraphKeys.SUMMARY_OP, logging_ops.merge_all_summaries)
self.summary_op = summary_op
# pylint: disable=g-long-lambda
if saver is None:
saver = Scaffold._get_or_default(
ops.GraphKeys.SAVERS,
lambda: training_saver.Saver(sharded=True,
max_to_keep=keep_checkpoint_max))
# pylint: enable=g-long-lambda
self.saver = saver
@staticmethod
def _get_or_default(collection_key, default_constructor):
elements = ops.get_collection(collection_key)
if elements:
return elements[0]
op = default_constructor()
if op is not None:
ops.add_to_collection(collection_key, op)
return op
@staticmethod
def _default_local_init_op():
return control_flow_ops.group(variables.initialize_local_variables(),
data_flow_ops.initialize_all_tables())
class SupervisedSession(object):
"""Session-like object that supports recovery and monitors.
"""
def __init__(self, master, is_chief=True, checkpoint_dir=None,
monitors=None, scaffold=None, config=None,
clean_stop_exception_types=None):
self._graph = ops.get_default_graph()
self._master = master
self._checkpoint_dir = checkpoint_dir
self._is_chief = is_chief
self._config = config
self._clean_stop_exception_types = clean_stop_exception_types
self._monitors = monitors or []
self._scaffold = scaffold or Scaffold()
# Finalize and write the graph.
self._graph.finalize()
# Create the session.
self._session_manager = sm.SessionManager(
local_init_op=self._scaffold.local_init_op,
ready_op=self._scaffold.ready_op,
graph=ops.get_default_graph())
self._sess = recoverable_session.RecoverableSession(self._create_session)
# Call the begin() method of monitors.
self._init_step = self._tf_sess.run(self._scaffold.global_step_tensor)
for monitor in self._monitors:
monitor.begin(max_steps=None, init_step=self._init_step)
# Write the graph out, note: this uses self._init_step.
self.write_graph()
def _create_session(self):
"""Factory for the RecoverableSession.
Returns:
A session, initialized or recovered as needed.
"""
if self._is_chief:
tf_sess = self._session_manager.prepare_session(
self._master, saver=self._scaffold.saver,
checkpoint_dir=self._checkpoint_dir, config=self._config,
init_op=self._scaffold.init_op,
init_feed_dict=self._scaffold.init_feed_dict,
init_fn=self._scaffold.init_fn)
else:
tf_sess = self._session_manager.wait_for_session(
self._master, config=self._config)
# Keep the tf_sess for quick runs of global step when needed.
self._tf_sess = tf_sess
self._coord = coordinator.Coordinator(
clean_stop_exception_types=self._clean_stop_exception_types)
self._coordinated_threads_to_join = queue_runner.start_queue_runners(
sess=tf_sess, coord=self._coord)
return coordinated_session.CoordinatedSession(
monitored_session.MonitoredSession(tf_sess, self._monitors,
self._scaffold.global_step_tensor),
self._coord, self._coordinated_threads_to_join)
@property
def scaffold(self):
return self._scaffold
@property
def session(self):
return self._tf_sess
def run(self, fetches, feed_dict=None, options=None, run_metadata=None):
"""Run ops in the supervised session.
This method is completely compatible with the `tf.Session.run()` method.
Args:
fetches: Same as `tf.Session.run()`.
feed_dict: Same as `tf.Session.run()`.
options: Same as `tf.Session.run()`.
run_metadata: Same as `tf.Session.run()`.
Returns:
Same as `tf.Session.run()`.
"""
return self._sess.run(fetches, feed_dict=feed_dict, options=options,
run_metadata=run_metadata)
def should_stop(self):
if self._sess:
return self._sess.should_stop()
return True
def close(self):
# Run the Monitor.end() methods.
for monitor in self._monitors:
monitor.end(self._tf_sess)
self._sess.close()
self._sess = None
self._tf_sess = None
def _is_closed(self):
"""Return True if the supervised session is closed. For tests only.
Returns:
A boolean.
"""
return self._tf_sess is None
def __enter__(self):
return self
def __exit__(self, exception_type, exception_value, traceback):
if exception_type:
self._coord.request_stop((exception_type, exception_value, traceback))
else:
self._coord.request_stop()
try:
self._coord.join(self._coordinated_threads_to_join)
# If coord does not raise an exception, we return True to indicate
# "no exception to raise".
return True
finally:
self.close()
def write_graph(self):
"""Saves current graph."""
if self._checkpoint_dir is not None and self._is_chief:
summary_writer = summary_writer_cache.SummaryWriterCache.get(
self._checkpoint_dir)
training_util.write_graph(self._graph.as_graph_def(add_shapes=True),
self._checkpoint_dir, 'graph.pbtxt')
summary_writer.add_graph(self._graph)
summary_writer.add_session_log(SessionLog(status=SessionLog.START),
self._init_step)
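
Because `Scaffold` stores each piece it creates into the corresponding graph collection, constructing a second scaffold in the same graph resolves to the same ops instead of duplicating them; the `test_caches_values` test below relies on this. A minimal sketch:

import tensorflow as tf
from tensorflow.contrib.learn.python.learn import supervised_session

with tf.Graph().as_default():
  scaffold1 = supervised_session.Scaffold()
  scaffold2 = supervised_session.Scaffold()
  # Both scaffolds pick up the pieces cached in the graph collections.
  assert scaffold1.init_op is scaffold2.init_op
  assert scaffold1.saver is scaffold2.saver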

View File

@ -24,6 +24,8 @@ import time
import tensorflow as tf
from tensorflow.contrib.learn.python.learn import coordinated_session
def BusyWaitForCoordStop(coord):
while not coord.should_stop():
@ -37,7 +39,7 @@ class CoordinatedSessionTest(tf.test.TestCase):
with self.test_session() as sess:
tf.constant(0.0)
coord = tf.train.Coordinator()
coord_sess = tf.contrib.learn.CoordinatedSession(sess, coord, [])
coord_sess = coordinated_session.CoordinatedSession(sess, coord, [])
self.assertEquals(sess.graph, coord_sess.graph)
self.assertEquals(sess.sess_str, coord_sess.sess_str)
@ -46,13 +48,13 @@ class CoordinatedSessionTest(tf.test.TestCase):
c = tf.constant(0)
v = tf.identity(c)
coord = tf.train.Coordinator()
coord_sess = tf.contrib.learn.CoordinatedSession(sess, coord, [])
coord_sess = coordinated_session.CoordinatedSession(sess, coord, [])
self.assertEqual(42, coord_sess.run(v, feed_dict={c: 42}))
def test_should_stop_on_close(self):
with self.test_session() as sess:
coord = tf.train.Coordinator()
coord_sess = tf.contrib.learn.CoordinatedSession(sess, coord, [])
coord_sess = coordinated_session.CoordinatedSession(sess, coord, [])
self.assertFalse(coord_sess.should_stop())
coord_sess.close()
self.assertTrue(coord_sess.should_stop())
@ -60,7 +62,7 @@ class CoordinatedSessionTest(tf.test.TestCase):
def test_should_stop_on_coord_stop(self):
with self.test_session() as sess:
coord = tf.train.Coordinator()
coord_sess = tf.contrib.learn.CoordinatedSession(sess, coord, [])
coord_sess = coordinated_session.CoordinatedSession(sess, coord, [])
self.assertFalse(coord_sess.should_stop())
coord.request_stop()
self.assertTrue(coord_sess.should_stop())
@ -70,7 +72,7 @@ class CoordinatedSessionTest(tf.test.TestCase):
c = tf.constant(0)
v = tf.identity(c)
coord = tf.train.Coordinator()
coord_sess = tf.contrib.learn.CoordinatedSession(sess, coord, [])
coord_sess = coordinated_session.CoordinatedSession(sess, coord, [])
self.assertFalse(coord_sess.should_stop())
self.assertEqual(0, coord_sess.run(c))
self.assertEqual(1, coord_sess.run(v, feed_dict={c: 1}))
@ -89,7 +91,7 @@ class CoordinatedSessionTest(tf.test.TestCase):
for _ in range(3)]
for t in threads:
t.start()
coord_sess = tf.contrib.learn.CoordinatedSession(sess, coord, threads)
coord_sess = coordinated_session.CoordinatedSession(sess, coord, threads)
self.assertFalse(coord_sess.should_stop())
for t in threads:
self.assertTrue(t.is_alive())

View File

@ -47,18 +47,26 @@ class FakeMonitor(monitors.BaseMonitor):
self.should_stop = False
self.requested_tensors = []
self.call_counter = Counter()
self.last_begin_step = None
self.last_end_step = None
self.last_post_step = None
def step_begin(self, step):
self.call_counter['step_begin'] += 1
self.begin_step = step
self.last_begin_step = step
return self.requested_tensors
def step_end(self, step, output):
self.call_counter['step_end'] += 1
self.end_step = step
self.last_end_step = step
self.output = output
return self.should_stop
def post_step(self, step, session):
self.call_counter['post_step'] += 1
self.last_post_step = step
self.session = session
class MonitoredSessionTest(tf.test.TestCase):
@ -83,7 +91,7 @@ class MonitoredSessionTest(tf.test.TestCase):
'run_metadata': 'a_metadata'
})
def testCallsMonitorsBeginAndEnd(self):
def testCallsMonitorsBeginEndAndPost(self):
with tf.Graph().as_default(), tf.Session() as sess:
global_step_tensor = tf.contrib.framework.create_global_step()
mock_mon = FakeMonitor()
@ -99,10 +107,12 @@ class MonitoredSessionTest(tf.test.TestCase):
for mon in [mock_mon, mock_mon2]:
self.assertEqual(mon.output, {})
self.assertEqual(mon.begin_step, 11)
self.assertEqual(mon.end_step, 11)
self.assertEqual(mon.last_begin_step, 11)
self.assertEqual(mon.last_end_step, 11)
self.assertEqual(mon.last_post_step, 11)
self.assertEqual(mon.call_counter['step_end'], 1)
self.assertEqual(mon.call_counter['step_begin'], 1)
self.assertEqual(mon.call_counter['post_step'], 1)
def testCallsMonitorsWithLastStep(self):
with tf.Graph().as_default(), tf.Session() as sess:
@ -119,18 +129,21 @@ class MonitoredSessionTest(tf.test.TestCase):
mon_sess.run(fetches=[inc_5])
for mon in [mock_mon, mock_mon2]:
self.assertEqual(mon.begin_step, 1)
self.assertEqual(mon.end_step, 1)
self.assertEqual(mon.last_begin_step, 1)
self.assertEqual(mon.last_end_step, 1)
self.assertEqual(mon.last_post_step, 1)
mon_sess.run(fetches=[inc_5])
for mon in [mock_mon, mock_mon2]:
self.assertEqual(mon.begin_step, 6)
self.assertEqual(mon.end_step, 6)
self.assertEqual(mon.last_begin_step, 6)
self.assertEqual(mon.last_end_step, 6)
self.assertEqual(mon.last_post_step, 6)
mon_sess.run(fetches=[inc_5])
for mon in [mock_mon, mock_mon2]:
self.assertEqual(mon.begin_step, 11)
self.assertEqual(mon.end_step, 11)
self.assertEqual(mon.last_begin_step, 11)
self.assertEqual(mon.last_end_step, 11)
self.assertEqual(mon.last_post_step, 11)
def testShouldStop(self):
with tf.Graph().as_default(), tf.Session() as sess:

View File

@ -34,6 +34,7 @@ class _MyEveryN(learn.monitors.EveryN):
every_n_steps=every_n_steps, first_n_steps=first_n_steps)
self._steps_begun = []
self._steps_ended = []
self._post_steps = []
@property
def steps_begun(self):
@ -43,6 +44,10 @@ class _MyEveryN(learn.monitors.EveryN):
def steps_ended(self):
return self._steps_ended
@property
def post_steps(self):
return self._post_steps
def every_n_step_begin(self, step):
super(_MyEveryN, self).every_n_step_begin(step)
self._steps_begun.append(step)
@ -53,6 +58,11 @@ class _MyEveryN(learn.monitors.EveryN):
self._steps_ended.append(step)
return False
def every_n_post_step(self, step, session):
super(_MyEveryN, self).every_n_post_step(step, session)
self._post_steps.append(step)
return False
class MonitorsTest(tf.test.TestCase):
"""Monitors tests."""
@ -71,8 +81,13 @@ class MonitorsTest(tf.test.TestCase):
def tearDown(self):
logging.info = self._actual_log
def _run_monitor(self, monitor, num_epochs=3, num_steps_per_epoch=10):
monitor.begin(max_steps=(num_epochs * num_steps_per_epoch) - 1)
def _run_monitor(self, monitor, num_epochs=3, num_steps_per_epoch=10,
pass_max_steps=True):
if pass_max_steps:
max_steps = num_epochs * num_steps_per_epoch - 1
else:
max_steps = None
monitor.begin(max_steps=max_steps, init_step=0)
for epoch in xrange(num_epochs):
monitor.epoch_begin(epoch)
should_stop = False
@ -85,6 +100,7 @@ class MonitorsTest(tf.test.TestCase):
[t.name if isinstance(t, tf.Tensor) else t for t in tensors],
output))
should_stop = monitor.step_end(step=step, output=output)
monitor.post_step(step=step, session=None)
step += 1
monitor.epoch_end(epoch)
monitor.end()
@ -100,6 +116,18 @@ class MonitorsTest(tf.test.TestCase):
expected_steps = [0, 1, 2, 10, 18, 26, 29]
self.assertEqual(expected_steps, monitor.steps_begun)
self.assertEqual(expected_steps, monitor.steps_ended)
self.assertEqual(expected_steps, monitor.post_steps)
def test_every_n_no_max_steps(self):
monitor = _MyEveryN(every_n_steps=8, first_n_steps=2)
with tf.Graph().as_default() as g, self.test_session(g):
self._run_monitor(monitor, num_epochs=3, num_steps_per_epoch=10,
pass_max_steps=False)
begin_end_steps = [0, 1, 2, 10, 18, 26]
post_steps = [0, 1, 2, 10, 18, 26, 29]
self.assertEqual(begin_end_steps, monitor.steps_begun)
self.assertEqual(begin_end_steps, monitor.steps_ended)
self.assertEqual(post_steps, monitor.post_steps)
def test_print(self):
with tf.Graph().as_default() as g, self.test_session(g):

View File

@ -21,6 +21,8 @@ from __future__ import print_function
import tensorflow as tf
from tensorflow.contrib.learn.python.learn import recoverable_session
class AbortAtNSession(object):
"""A mock session that aborts at the N-th run call."""
@ -42,7 +44,7 @@ class RecoverableSessionTest(tf.test.TestCase):
def test_properties(self):
with self.test_session() as sess:
tf.constant(0.0)
recoverable_sess = tf.contrib.learn.RecoverableSession(lambda: sess)
recoverable_sess = recoverable_session.RecoverableSession(lambda: sess)
self.assertEquals(sess.graph, recoverable_sess.graph)
self.assertEquals(sess.sess_str, recoverable_sess.sess_str)
@ -50,7 +52,7 @@ class RecoverableSessionTest(tf.test.TestCase):
with self.test_session() as sess:
c = tf.constant(0)
v = tf.identity(c)
recoverable_sess = tf.contrib.learn.RecoverableSession(lambda: sess)
recoverable_sess = recoverable_session.RecoverableSession(lambda: sess)
self.assertEqual(51, recoverable_sess.run(v, feed_dict={c: 51}))
def test_recovery(self):
@ -65,7 +67,7 @@ class RecoverableSessionTest(tf.test.TestCase):
self.assertEqual(3, len(sessions_to_use))
# Make the recoverable session uses these 3 sessions in sequence by
# passing a factory that pops from the session_to_use list.
recoverable_sess = tf.contrib.learn.RecoverableSession(
recoverable_sess = recoverable_session.RecoverableSession(
lambda: sessions_to_use.pop(0))
self.assertEqual(2, len(sessions_to_use)) # One session popped.
# Using first session.

View File

@ -0,0 +1,68 @@
# pylint: disable=g-bad-file-header
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for SummaryWriterCache."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import glob
import os
import tensorflow as tf
from tensorflow.contrib.learn.python.learn import summary_writer_cache
class SummaryWriterCacheTest(tf.test.TestCase):
"""SummaryWriterCache tests."""
def _test_dir(self, test_name):
test_dir = os.path.join(self.get_temp_dir(), test_name)
if os.path.isdir(test_dir):
os.removedirs(test_dir)
os.makedirs(test_dir)
return test_dir
def test_cache(self):
with tf.Graph().as_default():
dir1 = self._test_dir('test_cache_1')
dir2 = self._test_dir('test_cache_2')
sw1 = summary_writer_cache.SummaryWriterCache.get(dir1)
sw2 = summary_writer_cache.SummaryWriterCache.get(dir2)
sw3 = summary_writer_cache.SummaryWriterCache.get(dir1)
self.assertEqual(sw1, sw3)
self.assertFalse(sw1 == sw2)
sw1.close()
sw2.close()
events1 = glob.glob(os.path.join(dir1, 'event*'))
self.assertTrue(events1)
events2 = glob.glob(os.path.join(dir2, 'event*'))
self.assertTrue(events2)
events3 = glob.glob(os.path.join('nowriter', 'event*'))
self.assertFalse(events3)
def test_clear(self):
with tf.Graph().as_default():
dir1 = self._test_dir('test_clear')
sw1 = summary_writer_cache.SummaryWriterCache.get(dir1)
summary_writer_cache.SummaryWriterCache.clear()
sw2 = summary_writer_cache.SummaryWriterCache.get(dir1)
self.assertFalse(sw1 == sw2)
if __name__ == '__main__':
tf.test.main()

View File

@ -0,0 +1,404 @@
# pylint: disable=g-bad-file-header
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for SupervisedSession."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import tensorflow as tf
from tensorflow.contrib.learn.python.learn import supervised_session
class ScaffoldTest(tf.test.TestCase):
"""Scaffold tests."""
def test_defaults_empty_graph(self):
with tf.Graph().as_default():
scaffold = supervised_session.Scaffold()
self.assertTrue(isinstance(scaffold.global_step_tensor, tf.Variable))
self.assertTrue(isinstance(scaffold.init_op, tf.Operation))
self.assertEqual(None, scaffold.init_feed_dict)
self.assertEqual(None, scaffold.init_fn)
self.assertTrue(isinstance(scaffold.ready_op, tf.Tensor))
self.assertTrue(isinstance(scaffold.local_init_op, tf.Operation))
self.assertTrue(isinstance(scaffold.saver, tf.train.Saver))
with self.test_session() as sess:
self.assertTrue(b'global_step' in sess.run(scaffold.ready_op))
sess.run([scaffold.init_op, scaffold.local_init_op])
self.assertEquals(0, len(sess.run(scaffold.ready_op)))
self.assertEquals(0, sess.run(scaffold.global_step_tensor))
def test_caches_values(self):
with tf.Graph().as_default():
scaffold1 = supervised_session.Scaffold()
scaffold2 = supervised_session.Scaffold()
self.assertEqual(scaffold1.global_step_tensor,
scaffold2.global_step_tensor)
self.assertEqual(scaffold1.init_op, scaffold2.init_op)
self.assertEqual(scaffold1.ready_op, scaffold2.ready_op)
self.assertEqual(scaffold1.local_init_op, scaffold2.local_init_op)
self.assertEqual(scaffold1.saver, scaffold2.saver)
def test_uses_passed_values(self):
with tf.Graph().as_default():
scaffold = supervised_session.Scaffold(global_step_tensor=1,
init_op=2,
init_feed_dict=3,
init_fn=lambda scaffold, sess: 4,
ready_op=5,
local_init_op=6,
saver=7)
self.assertEqual(1, scaffold.global_step_tensor)
self.assertEqual(2, scaffold.init_op)
self.assertEqual(3, scaffold.init_feed_dict)
self.assertTrue(callable(scaffold.init_fn))
self.assertEqual(5, scaffold.ready_op)
self.assertEqual(6, scaffold.local_init_op)
self.assertEqual(7, scaffold.saver)
class RaiseOnceAtStepN(tf.contrib.learn.monitors.BaseMonitor):
"""Monitor that raises an Exception at step N."""
def __init__(self, n, ex):
super(RaiseOnceAtStepN, self).__init__()
self.n = n
self.ex = ex
self.raised = False
def step_begin(self, step):
super(RaiseOnceAtStepN, self).step_begin(step)
# Raise the first time we reach step N.
if step == self.n and not self.raised:
self.raised = True
raise self.ex
return []
class SupervisedSessionTest(tf.test.TestCase):
"""SupervisedSession tests."""
def _test_dir(self, test_name):
"""Create an empty dir to use for tests.
Args:
test_name: Name of the test.
Returns:
Absolute path to the test directory.
"""
test_dir = os.path.join(self.get_temp_dir(), test_name)
if os.path.isdir(test_dir):
os.removedirs(test_dir)
os.makedirs(test_dir)
return test_dir
def test_defaults(self):
with tf.Graph().as_default():
with supervised_session.SupervisedSession('') as session:
self.assertEqual(0, session.run(session.scaffold.global_step_tensor))
def test_last_step(self):
logdir = self._test_dir('test_last_step')
with tf.Graph().as_default():
scaffold = supervised_session.Scaffold()
gstep = scaffold.global_step_tensor
do_step = tf.assign_add(gstep, 1)
# Run till step 3 and save.
monitors = [tf.contrib.learn.monitors.StopAtStep(last_step=3)]
with supervised_session.SupervisedSession('', scaffold=scaffold,
monitors=monitors) as session:
self.assertEqual(0, session.run(gstep))
self.assertFalse(session.should_stop())
self.assertEqual(1, session.run(do_step))
self.assertFalse(session.should_stop())
self.assertEqual(2, session.run(do_step))
self.assertFalse(session.should_stop())
self.assertEqual(3, session.run(do_step))
self.assertTrue(session.should_stop())
save_path = scaffold.saver.save(session.session,
os.path.join(logdir, 'step-3'))
# Run till step 5 and save.
def load_ckpt(scaffold, sess):
scaffold.saver.restore(sess, save_path)
scaffold = supervised_session.Scaffold(init_fn=load_ckpt)
monitors = [tf.contrib.learn.monitors.StopAtStep(last_step=5)]
with supervised_session.SupervisedSession('', scaffold=scaffold,
monitors=monitors) as session:
self.assertEqual(3, session.run(gstep))
self.assertFalse(session.should_stop())
self.assertEqual(4, session.run(do_step))
self.assertFalse(session.should_stop())
self.assertEqual(5, session.run(do_step))
self.assertTrue(session.should_stop())
def test_num_steps(self):
logdir = self._test_dir('test_num_steps')
with tf.Graph().as_default():
scaffold = supervised_session.Scaffold()
gstep = scaffold.global_step_tensor
do_step = tf.assign_add(gstep, 1)
# Do 3 steps and save.
monitors = [tf.contrib.learn.monitors.StopAtStep(num_steps=3)]
with supervised_session.SupervisedSession('', scaffold=scaffold,
monitors=monitors) as session:
session.run(do_step)
self.assertFalse(session.should_stop())
session.run(do_step)
self.assertFalse(session.should_stop())
session.run(do_step)
self.assertTrue(session.should_stop())
save_path = scaffold.saver.save(session.session,
os.path.join(logdir, 'step-3'))
# Restore and do 4 steps.
def load_ckpt(scaffold, sess):
scaffold.saver.restore(sess, save_path)
scaffold = supervised_session.Scaffold(init_fn=load_ckpt)
monitors = [tf.contrib.learn.monitors.StopAtStep(num_steps=4)]
with supervised_session.SupervisedSession('', scaffold=scaffold,
monitors=monitors) as session:
self.assertEqual(3, session.run(gstep))
session.run(do_step)
self.assertFalse(session.should_stop())
session.run(do_step)
self.assertFalse(session.should_stop())
session.run(do_step)
self.assertFalse(session.should_stop())
session.run(do_step)
self.assertTrue(session.should_stop())
# This set of tests verifies the supervised session behavior when exceptions
# are raised around the innermost session run() call.
def test_recovery(self):
logdir = self._test_dir('test_recovery')
with tf.Graph().as_default():
scaffold = supervised_session.Scaffold()
gstep = scaffold.global_step_tensor
do_step = tf.assign_add(gstep, 1)
# Use a monitor to save the model every 100 steps. It also saves it at
# the end.
monitors = [tf.contrib.learn.monitors.CheckpointSaver(
100, scaffold.saver, logdir)]
with supervised_session.SupervisedSession('', scaffold=scaffold,
checkpoint_dir=logdir,
monitors=monitors) as session:
self.assertEqual(0, session.run(gstep))
self.assertEqual(1, session.run(do_step))
self.assertEqual(2, session.run(do_step))
# A restart will find the checkpoint and recover automatically.
with supervised_session.SupervisedSession(
'', scaffold=scaffold, checkpoint_dir=logdir) as session:
self.assertEqual(2, session.run(gstep))
def test_retry_on_aborted_error(self):
# Tests that we silently retry on abort. Note that this does not test
# recovery as we do not use a CheckpointSaver in this test.
with tf.Graph().as_default():
scaffold = supervised_session.Scaffold()
gstep = scaffold.global_step_tensor
do_step = tf.assign_add(gstep, 1)
monitor = RaiseOnceAtStepN(3, tf.errors.AbortedError(None, None, 'Abort'))
with supervised_session.SupervisedSession('', scaffold=scaffold,
monitors=[monitor]) as session:
self.assertEqual(0, session.run(gstep))
self.assertEqual(1, session.run(do_step))
self.assertEqual(2, session.run(do_step))
self.assertFalse(session.should_stop())
# Here at step 3, the monitor triggers and raises AbortedError. The
# SupervisedSession automatically retries and restarts from a freshly
# initialized session, so the step is back to 0 and running do_step
# moves it to 1.
self.assertEqual(1, session.run(do_step))
self.assertFalse(session.should_stop())
self.assertTrue(monitor.raised)
self.assertEqual(2, session.run(do_step))
self.assertFalse(session.should_stop())
def test_recover_and_retry_on_aborted_error(self):
# Tests that we silently retry and recover on abort. This test uses
# a CheckpointSaver to have something to recover from.
logdir = self._test_dir('test_recover_and_retry_on_aborted_error')
with tf.Graph().as_default():
scaffold = supervised_session.Scaffold()
gstep = scaffold.global_step_tensor
do_step = tf.assign_add(gstep, 1)
abort_monitor = RaiseOnceAtStepN(
3, tf.errors.AbortedError(None, None, 'Abort'))
# Save after each step.
ckpt_monitor = tf.contrib.learn.monitors.CheckpointSaver(
1, scaffold.saver, logdir)
monitors = [abort_monitor, ckpt_monitor]
with supervised_session.SupervisedSession('', scaffold=scaffold,
checkpoint_dir=logdir,
monitors=monitors) as session:
self.assertEqual(0, session.run(gstep))
self.assertEqual(1, session.run(do_step))
self.assertEqual(2, session.run(do_step))
self.assertFalse(session.should_stop())
# Here at step 3, the monitor triggers and raises AbortedError. The
# SupervisedSession automatically restores and retries.
self.assertEqual(3, session.run(do_step))
self.assertTrue(abort_monitor.raised)
self.assertFalse(session.should_stop())
self.assertEqual(4, session.run(do_step))
self.assertFalse(session.should_stop())
def test_stop_cleanly_on_out_of_range_exception(self):
# Tests that we stop cleanly when OutOfRange is raised.
with tf.Graph().as_default():
scaffold = supervised_session.Scaffold()
gstep = scaffold.global_step_tensor
do_step = tf.assign_add(gstep, 1)
monitor = RaiseOnceAtStepN(
3, tf.errors.OutOfRangeError(None, None, 'EOI'))
with supervised_session.SupervisedSession('', scaffold=scaffold,
monitors=[monitor]) as session:
self.assertEqual(0, session.run(gstep))
self.assertEqual(1, session.run(do_step))
self.assertEqual(2, session.run(do_step))
self.assertFalse(session.should_stop())
# Here at step 3, the monitor triggers and raises OutOfRange. The
# session should go into should_stop() mode. We do not get a result
# in that case.
self.assertEqual(None, session.run(do_step))
self.assertTrue(monitor.raised)
self.assertTrue(session.should_stop())
def test_stop_cleanly_on_custom_exception(self):
# Tests that we stop cleanly when an exception type of
# our choice is raised (StopIteration here.)
with tf.Graph().as_default():
scaffold = supervised_session.Scaffold()
gstep = scaffold.global_step_tensor
do_step = tf.assign_add(gstep, 1)
monitor = RaiseOnceAtStepN(3, StopIteration('I choose you'))
exception_types = [tf.errors.OutOfRangeError, StopIteration]
with supervised_session.SupervisedSession(
'', scaffold=scaffold,
monitors=[monitor],
clean_stop_exception_types=exception_types) as session:
self.assertEqual(0, session.run(gstep))
self.assertEqual(1, session.run(do_step))
self.assertEqual(2, session.run(do_step))
self.assertFalse(session.should_stop())
# Here at step 3, the monitor triggers and raises StopIteration. The
# session should go into should_stop() mode. We do not get a result
# in that case.
self.assertEqual(None, session.run(do_step))
self.assertTrue(monitor.raised)
self.assertTrue(session.should_stop())
# This set of tests verifies the session behavior when exceptions are raised
# from code inside a "with SupervisedSession:" context.
def test_regular_exception_pass_through_in_with_body(self):
# Tests that regular exceptions just pass through a "with
# SupervisedSession" block and set the session in stop mode.
with tf.Graph().as_default():
scaffold = supervised_session.Scaffold()
gstep = scaffold.global_step_tensor
do_step = tf.assign_add(gstep, 1)
monitor = RaiseOnceAtStepN(3, RuntimeError('regular exception'))
session = supervised_session.SupervisedSession('', scaffold=scaffold,
monitors=[monitor])
with self.assertRaisesRegexp(RuntimeError, 'regular exception'):
with session:
self.assertEqual(0, session.run(gstep))
self.assertEqual(1, session.run(do_step))
self.assertEqual(2, session.run(do_step))
self.assertFalse(session.should_stop())
# This triggers the monitor and raises the exception
session.run(do_step)
# We should not hit this
self.assertFalse(True)
self.assertTrue(monitor.raised)
self.assertTrue(session.should_stop())
def test_stop_cleanly_when_no_exception_in_with_body(self):
# Tests that the session stops cleanly when the with body raises no exception.
with tf.Graph().as_default():
scaffold = supervised_session.Scaffold()
gstep = scaffold.global_step_tensor
do_step = tf.assign_add(gstep, 1)
session = supervised_session.SupervisedSession('', scaffold=scaffold)
with session:
self.assertEqual(1, session.run(do_step))
self.assertEqual(2, session.run(do_step))
self.assertFalse(session.should_stop())
# Should have closed.
self.assertTrue(session.should_stop())
self.assertTrue(session._is_closed())
def test_stop_cleanly_on_out_of_range_exception_in_with_body(self):
# Tests that an OutOfRangeError raised in the with body stops the session cleanly.
with tf.Graph().as_default():
scaffold = supervised_session.Scaffold()
gstep = scaffold.global_step_tensor
do_step = tf.assign_add(gstep, 1)
session = supervised_session.SupervisedSession('', scaffold=scaffold)
with session:
self.assertEqual(1, session.run(do_step))
self.assertEqual(2, session.run(do_step))
self.assertFalse(session.should_stop())
raise tf.errors.OutOfRangeError(None, None, 'EOI')
# Should have closed.
self.assertTrue(session.should_stop())
self.assertTrue(session._is_closed())
def test_raises_regular_exceptions_in_with_body(self):
# Tests that regular exceptions in "with body" are seen outside.
with tf.Graph().as_default():
scaffold = supervised_session.Scaffold()
gstep = scaffold.global_step_tensor
do_step = tf.assign_add(gstep, 1)
session = supervised_session.SupervisedSession('', scaffold=scaffold)
# We should see that exception.
with self.assertRaisesRegexp(RuntimeError, 'regular exception'):
with session:
self.assertEqual(1, session.run(do_step))
self.assertEqual(2, session.run(do_step))
self.assertFalse(session.should_stop())
# Will be visible outside the "with body".
raise RuntimeError('regular exception')
# Should have closed.
self.assertTrue(session.should_stop())
self.assertTrue(session._is_closed())
def test_stop_cleanly_on_custom_exception_in_with_body(self):
with tf.Graph().as_default():
scaffold = supervised_session.Scaffold()
gstep = scaffold.global_step_tensor
do_step = tf.assign_add(gstep, 1)
exception_types = [tf.errors.OutOfRangeError, StopIteration]
session = supervised_session.SupervisedSession(
'', scaffold=scaffold, clean_stop_exception_types=exception_types)
with session:
self.assertEqual(1, session.run(do_step))
self.assertEqual(2, session.run(do_step))
self.assertFalse(session.should_stop())
raise StopIteration('EOI')
# Should have closed.
self.assertTrue(session.should_stop())
self.assertTrue(session._is_closed())
if __name__ == '__main__':
tf.test.main()

View File

@ -21,8 +21,10 @@ from __future__ import print_function
import tensorflow as tf
from tensorflow.contrib.learn.python.learn import wrapped_session
class StopAtNSession(tf.contrib.learn.WrappedSession):
class StopAtNSession(wrapped_session.WrappedSession):
"""A wrapped session that stops at the N-th call to _check_stop."""
def __init__(self, sess, n):
@ -42,13 +44,13 @@ class WrappedSessionTest(tf.test.TestCase):
def test_properties(self):
with self.test_session() as sess:
tf.constant(0.0)
wrapped_sess = tf.contrib.learn.WrappedSession(sess)
wrapped_sess = wrapped_session.WrappedSession(sess)
self.assertEquals(sess.graph, wrapped_sess.graph)
self.assertEquals(sess.sess_str, wrapped_sess.sess_str)
def test_should_stop_on_close(self):
with self.test_session() as sess:
wrapped_sess = tf.contrib.learn.WrappedSession(sess)
wrapped_sess = wrapped_session.WrappedSession(sess)
self.assertFalse(wrapped_sess.should_stop())
wrapped_sess.close()
self.assertTrue(wrapped_sess.should_stop())
@ -64,7 +66,7 @@ class WrappedSessionTest(tf.test.TestCase):
def test_should_stop_delegates_to_wrapped_session(self):
with self.test_session() as sess:
wrapped_sess0 = StopAtNSession(sess, 4)
wrapped_sess1 = tf.contrib.learn.WrappedSession(wrapped_sess0)
wrapped_sess1 = wrapped_session.WrappedSession(wrapped_sess0)
self.assertFalse(wrapped_sess1.should_stop())
self.assertFalse(wrapped_sess1.should_stop())
self.assertFalse(wrapped_sess1.should_stop())
@ -73,7 +75,7 @@ class WrappedSessionTest(tf.test.TestCase):
def test_close_twice(self):
with self.test_session() as sess:
wrapped_sess = tf.contrib.learn.WrappedSession(sess)
wrapped_sess = wrapped_session.WrappedSession(sess)
wrapped_sess.close()
self.assertTrue(wrapped_sess.should_stop())
wrapped_sess.close()
@ -84,7 +86,7 @@ class WrappedSessionTest(tf.test.TestCase):
c = tf.constant(0)
v = tf.identity(c)
self.assertEqual(42, sess.run(v, feed_dict={c: 42}))
wrapped_sess = tf.contrib.learn.WrappedSession(sess)
wrapped_sess = wrapped_session.WrappedSession(sess)
self.assertEqual(51, wrapped_sess.run(v, feed_dict={c: 51}))

View File

@ -139,7 +139,7 @@ class Coordinator(object):
"""
if clean_stop_exception_types is None:
clean_stop_exception_types = (errors.OutOfRangeError,)
self._clean_stop_exception_types = clean_stop_exception_types
self._clean_stop_exception_types = tuple(clean_stop_exception_types)
# Protects all attributes.
self._lock = threading.Lock()
# Event set when threads must stop.
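
For context on the `tuple()` conversion above: Python's `except` clause matches against a tuple of exception classes but not a list, so converting once at construction lets callers pass any sequence. A tiny generic illustration (not from the commit):

exception_types = [ValueError, KeyError]  # a caller may pass a list...
try:
  raise KeyError('k')
except tuple(exception_types):  # ...but `except` needs a tuple to match
  pass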