From 4f6e9efb40b1fca70a4fdf547401eafcffda47fa Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 13 Jul 2016 08:10:45 -0800 Subject: [PATCH] Add SupervisedSession class. Extend the Monitor interface to support a post_step() callback which is passed the Session. Add a Scaffold object to hold onto the commonly needed graph pieces to run training. Based on a design by @ilblackdragon . Change: 127323605 --- tensorflow/contrib/learn/BUILD | 20 + .../contrib/learn/python/learn/__init__.py | 3 - .../learn/python/learn/coordinated_session.py | 5 + .../learn/python/learn/monitored_session.py | 8 +- .../contrib/learn/python/learn/monitors.py | 272 ++++++++++-- .../python/learn/summary_writer_cache.py | 65 +++ .../learn/python/learn/supervised_session.py | 301 +++++++++++++ .../learn/tests/coordinated_session_test.py | 14 +- .../learn/tests/monitored_session_test.py | 35 +- .../learn/python/learn/tests/monitors_test.py | 32 +- .../learn/tests/recoverable_session_test.py | 8 +- .../learn/tests/summary_writer_cache_test.py | 68 +++ .../learn/tests/supervised_session_test.py | 404 ++++++++++++++++++ .../learn/tests/wrapped_session_test.py | 14 +- tensorflow/python/training/coordinator.py | 2 +- 15 files changed, 1177 insertions(+), 74 deletions(-) create mode 100644 tensorflow/contrib/learn/python/learn/summary_writer_cache.py create mode 100644 tensorflow/contrib/learn/python/learn/supervised_session.py create mode 100644 tensorflow/contrib/learn/python/learn/tests/summary_writer_cache_test.py create mode 100644 tensorflow/contrib/learn/python/learn/tests/supervised_session_test.py diff --git a/tensorflow/contrib/learn/BUILD b/tensorflow/contrib/learn/BUILD index bad00b79526..7a4523e5755 100644 --- a/tensorflow/contrib/learn/BUILD +++ b/tensorflow/contrib/learn/BUILD @@ -684,6 +684,26 @@ py_test( ], ) +py_test( + name = "supervised_session_test", + size = "small", + srcs = ["python/learn/tests/supervised_session_test.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow:tensorflow_py", + ], +) + +py_test( + name = "summary_writer_cache_test", + size = "small", + srcs = ["python/learn/tests/summary_writer_cache_test.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow:tensorflow_py", + ], +) + py_binary( name = "inspect_checkpoint", srcs = [ diff --git a/tensorflow/contrib/learn/python/learn/__init__.py b/tensorflow/contrib/learn/python/learn/__init__.py index 9f635891871..375d90960d7 100644 --- a/tensorflow/contrib/learn/python/learn/__init__.py +++ b/tensorflow/contrib/learn/python/learn/__init__.py @@ -31,7 +31,6 @@ from tensorflow.contrib.learn.python.learn import monitors from tensorflow.contrib.learn.python.learn import ops from tensorflow.contrib.learn.python.learn import preprocessing from tensorflow.contrib.learn.python.learn import utils -from tensorflow.contrib.learn.python.learn.coordinated_session import * from tensorflow.contrib.learn.python.learn.dataframe import * from tensorflow.contrib.learn.python.learn.estimators import * from tensorflow.contrib.learn.python.learn.experiment import Experiment @@ -42,6 +41,4 @@ from tensorflow.contrib.learn.python.learn.graph_actions import run_feeds from tensorflow.contrib.learn.python.learn.graph_actions import run_n from tensorflow.contrib.learn.python.learn.graph_actions import train from tensorflow.contrib.learn.python.learn.io import * -from tensorflow.contrib.learn.python.learn.recoverable_session import * -from tensorflow.contrib.learn.python.learn.wrapped_session import * # pylint: enable=wildcard-import diff 
--git a/tensorflow/contrib/learn/python/learn/coordinated_session.py b/tensorflow/contrib/learn/python/learn/coordinated_session.py index 9f216e1fccd..e2ad2c1afff 100644 --- a/tensorflow/contrib/learn/python/learn/coordinated_session.py +++ b/tensorflow/contrib/learn/python/learn/coordinated_session.py @@ -57,3 +57,8 @@ class CoordinatedSession(WrappedSession): self._coord.request_stop(e) if self._coord.should_stop(): self._coord.join(self._coordinated_threads_to_join) + + # TODO(touts): Add a close() method that also joins the coordinator + # but does not raise exceptions. This can only be done reliably when the + # Coordinator keeps a pointer to the coordinated threads, otherwise we do not + # know which threads to join. diff --git a/tensorflow/contrib/learn/python/learn/monitored_session.py b/tensorflow/contrib/learn/python/learn/monitored_session.py index bcaec58ab5e..602c8a2db04 100644 --- a/tensorflow/contrib/learn/python/learn/monitored_session.py +++ b/tensorflow/contrib/learn/python/learn/monitored_session.py @@ -23,7 +23,6 @@ import six from tensorflow.contrib.learn.python.learn.wrapped_session import WrappedSession from tensorflow.python.framework import ops -from tensorflow.python.platform import tf_logging as logging class MonitoredSession(WrappedSession): @@ -73,7 +72,6 @@ class MonitoredSession(WrappedSession): if self._last_step is None: self._last_step = WrappedSession.run(self, self._global_step_tensor) - logging.info('Initialized step to: %d', self._last_step) monitors_step = self._last_step + 1 monitor_fetches = [] @@ -104,12 +102,16 @@ class MonitoredSession(WrappedSession): induce_stop = monitor.step_end(monitors_step, monitor_outputs) self._should_stop = self._should_stop or induce_stop + # Call the post_step methods. + for monitor in self._monitors: + monitor.post_step(monitors_step, self._sess) + return outputs['caller'] # TODO(ispir): Remove following logic after forcing monitors returns tensors. def _as_graph_element(obj, graph): - """Retrieves Graph element from tensors or tensor names.""" + """Retrieves Graph element.""" graph = graph or ops.get_default_graph() if not isinstance(obj, six.string_types): if not hasattr(obj, 'graph') or obj.graph != graph: diff --git a/tensorflow/contrib/learn/python/learn/monitors.py b/tensorflow/contrib/learn/python/learn/monitors.py index f6af7a6f8c1..e50e9136de3 100644 --- a/tensorflow/contrib/learn/python/learn/monitors.py +++ b/tensorflow/contrib/learn/python/learn/monitors.py @@ -69,13 +69,19 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import os +import time + import numpy as np import six +from tensorflow.contrib.learn.python.learn.summary_writer_cache import SummaryWriterCache from tensorflow.contrib.learn.python.learn.utils import export +from tensorflow.core.framework.summary_pb2 import Summary +from tensorflow.core.util.event_pb2 import SessionLog from tensorflow.python.framework import ops from tensorflow.python.platform import tf_logging as logging -from tensorflow.python.training import saver +from tensorflow.python.training import saver as saver_lib from tensorflow.python.training import summary_io @@ -92,6 +98,7 @@ class BaseMonitor(object): self._current_epoch = None self._current_step = None self._max_steps = None + self._init_step = None self._estimator = None def set_estimator(self, estimator): @@ -108,11 +115,14 @@ class BaseMonitor(object): # TODO(mdan): This should fail if called twice with the same estimator. 
    self._estimator = estimator
 
-  def begin(self, max_steps=None):
+  def begin(self, max_steps=None, init_step=None):
     """Called at the beginning of training.
 
+    When called, the default graph is the one we are executing.
+
     Args:
       max_steps: `int`, the maximum global step this training will run until.
+      init_step: `int`, step at which this training will start.
 
     Raises:
       ValueError: if we've already begun a run.
@@ -120,14 +130,19 @@
     if self._begun:
       raise ValueError("begin called twice without end.")
     self._max_steps = max_steps
+    self._init_step = init_step
     self._begun = True
 
-  def end(self):
+  def end(self, session=None):
     """Callback at the end of training/evaluation.
 
+    Args:
+      session: A `tf.Session` object that can be used to run ops.
+
     Raises:
       ValueError: if we've not begun a run.
     """
+    _ = session
     if not self._begun:
       raise ValueError("end called without begin.")
     self._max_steps = None
@@ -178,8 +193,6 @@
       ValueError: if we've already begun a step, or `step` < 0, or
           `step` > `max_steps`.
     """
-    if self._current_step is not None:
-      raise ValueError("step_begin called twice without step_end.")
     if (step < 0) or (
         (self._max_steps is not None) and (step > self._max_steps)):
       raise ValueError("Invalid step %s." % step)
@@ -196,6 +209,9 @@
     In addition, the callback has the opportunity to stop training by
     returning `True`. This is useful for early stopping, for example.
 
+    Note that this method is not called if the call to `Session.run()` that
+    followed the last call to `step_begin()` failed.
+
     Args:
       step: `int`, the current value of the global step.
       output: `dict` mapping `string` values representing tensor names to
@@ -214,13 +230,26 @@
     self._current_step = None
     return False
 
+  def post_step(self, step, session):  # pylint: disable=unused-argument
+    """Callback after the step is finished.
+
+    Called after step_end and receives the session, so the monitor can run
+    extra `session.run` calls. This method is called even if the step's
+    `Session.run()` call failed.
+
+    Args:
+      step: `int`, global step of the model.
+      session: `Session` object.
+    """
+    _ = step, session
+
 
 class EveryN(BaseMonitor):
-  """Base class for monitors that execute callbacks every n steps.
+  """Base class for monitors that execute callbacks every N steps.
 
-  This class adds two new callbacks:
+  This class adds three new callbacks:
     - every_n_step_begin
     - every_n_step_end
+    - every_n_post_step
 
   The callbacks are executed every n steps, or optionally every step for the
   first m steps, where m and n can both be user-specified.
@@ -234,11 +263,16 @@
   Failing to call the super implementation will cause unpredictable
   behavior.
 
+  The `every_n_post_step()` callback is also called after the last step if it
+  was not already called through the regular conditions. Note that
+  `every_n_step_begin()` and `every_n_step_end()` do not receive that special
+  treatment.
+
   """
   # TODO(ipolosukhin): Add also every n seconds.
 
   def __init__(self, every_n_steps=100, first_n_steps=1):
-    """Initialized an `EveryN` monitor.
+    """Initializes an `EveryN` monitor.
 
     Args:
       every_n_steps: `int`, the number of steps to allow between callbacks.
@@ -249,18 +283,10 @@
     super(EveryN, self).__init__()
     self._every_n_steps = every_n_steps
     self._first_n_steps = first_n_steps
-    self._last_step = 0
-    self._active_step = None
-
-  def begin(self, max_steps=None):
-
-    When overriding this method, you must call the super implementation.
-
-    Args:
-      max_steps: `int`, the maximum global step value.
-    """
-    super(EveryN, self).begin(max_steps)
+    # Last step in the model.
+    self._last_step = None
+    # Last step at which we called one of the every_n methods.
+    self._last_active_step = None
 
   def every_n_step_begin(self, step):  # pylint: disable=unused-argument
     """Callback before every n'th step begins.
@@ -294,6 +320,15 @@
     """
     return False
 
+  def every_n_post_step(self, step, session):
+    """Callback after a step is finished or `end()` is called.
+
+    Args:
+      step: `int`, the current value of the global step.
+      session: `Session` object.
+    """
+    pass
+
   def step_begin(self, step):
     """Overrides `BaseMonitor.step_begin`.
@@ -309,13 +344,13 @@
       ValueError: if called more than once during a step.
     """
     super(EveryN, self).step_begin(step)
+    self._last_step = step
+    if self._last_active_step is None:
+      self._last_active_step = step - 1
     if (step <= self._first_n_steps or
-        step >= (self._every_n_steps + self._last_step) or
-        step == self._max_steps):
-      if self._active_step is not None:
-        raise ValueError(
-            "Starting step %s, %s still active." % (step, self._active_step))
-      self._active_step = step
+        step >= (self._every_n_steps + self._last_active_step) or
+        step == self._max_steps):  # Note: max_steps can be None here.
+      self._last_active_step = step
       return self.every_n_step_begin(step)
     return []
@@ -334,12 +369,59 @@
       or `False` otherwise.
     """
     super(EveryN, self).step_end(step, output)
-    to_stop = False
-    if (self._active_step is not None) and (self._active_step == step):
-      self._last_step = step
-      to_stop = self.every_n_step_end(step, output)
-      self._active_step = None
-    return to_stop
+    if self._last_active_step == step:
+      return self.every_n_step_end(step, output)
+    return False
+
+  def post_step(self, step, session):
+    super(EveryN, self).post_step(step, session)
+    if self._last_active_step == step:
+      self.every_n_post_step(step, session)
+
+  def end(self, session=None):
+    super(EveryN, self).end(session=session)
+    if self._last_step != self._last_active_step:
+      self.every_n_post_step(self._last_step, session)
+
+
+class StopAtStep(BaseMonitor):
+  """Monitor to request stop at a specified step."""
+
+  def __init__(self, num_steps=None, last_step=None):
+    """Create a StopAtStep monitor.
+
+    This monitor requests stop after either a number of steps have been
+    executed or a last step has been reached. Only one of the two options
+    can be specified.
+
+    If `num_steps` is specified, it indicates the number of steps to execute
+    after `begin()` is called. If instead `last_step` is specified, it
+    indicates the last step we want to execute, as passed to the
+    `step_begin()` call.
+
+    Args:
+      num_steps: Number of steps to execute.
+      last_step: Step after which to stop.
+
+    Raises:
+      ValueError: If one of the arguments is invalid.
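+
+    For example, to run 100 more steps and then stop (a sketch, not part of
+    this patch's tests; it assumes a `SupervisedSession`-style loop and a
+    `train_op` defined elsewhere):
+
+      monitors = [StopAtStep(num_steps=100)]
+      with supervised_session.SupervisedSession(
+          '', monitors=monitors) as session:
+        while not session.should_stop():
+          session.run(train_op)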
+ """ + super(StopAtStep, self).__init__() + if num_steps is None and last_step is None: + raise ValueError("One of num_steps or last_step must be specified.") + if num_steps is not None and last_step is not None: + raise ValueError("Only one of num_steps or last_step can be specified.") + self._num_steps = num_steps + self._last_step = last_step + + def begin(self, max_steps=None, init_step=None): + super(StopAtStep, self).begin(max_steps=max_steps, init_step=init_step) + if self._num_steps is not None: + self._last_step = init_step + self._num_steps + + def step_end(self, step, output): + super(StopAtStep, self).step_end(step, output) + return step >= self._last_step # TODO(ptucker): Rename to LoggingTensor since it's not writing to stdout. @@ -460,8 +542,8 @@ class SummarySaver(EveryN): self._summary_writer.add_summary(summary_strs, step) return False - def end(self): - super(SummarySaver, self).end() + def end(self, session=None): + super(SummarySaver, self).end(session=session) if self._summary_writer: self._summary_writer.flush() @@ -553,7 +635,7 @@ class ValidationMonitor(EveryN): if self._estimator is None: raise ValueError("Missing call to set_estimator.") # Check that we are not running evaluation on the same checkpoint. - latest_path = saver.latest_checkpoint(self._estimator.model_dir) + latest_path = saver_lib.latest_checkpoint(self._estimator.model_dir) if latest_path == self._latest_path: logging.info("Skipping evaluation due to same checkpoint %s for step %d " "as for step %d.", latest_path, step, self._latest_path_step) @@ -677,8 +759,8 @@ class GraphDump(BaseMonitor): self._ignore_ops = ignore_ops or GraphDump.IGNORE_OPS self._data = {} - def begin(self, max_steps): - super(GraphDump, self).begin(max_steps) + def begin(self, max_steps=None, init_step=None): + super(GraphDump, self).begin(max_steps=max_steps, init_step=init_step) self._tensors = [] graph = ops.get_default_graph() graph_def = graph.as_graph_def() @@ -782,9 +864,121 @@ class ExportMonitor(EveryN): logging.info("Skipping exporting for the same step. " "Consider exporting less frequently.") - def end(self): - super(ExportMonitor, self).end() + def end(self, session=None): + super(ExportMonitor, self).end(session=session) export.export_estimator(self._estimator, self.export_dir, exports_to_keep=self.exports_to_keep, signature_fn=self.signature_fn) + + +class CheckpointSaver(EveryN): + """Saves checkpoints every N steps.""" + + def __init__(self, every_n_steps, saver, checkpoint_dir, + checkpoint_basename="model.ckpt", + first_n_steps=-1): + """Initialize CheckpointSaver monitor. + + Args: + every_n_steps: `int`, save every N steps. + saver: `Saver` object, used for saving. + checkpoint_dir: `str`, base directory for the checkpoint files. + checkpoint_basename: `str`, base name for the checkpoint files. + first_n_steps: `int`, if positive, save every step during the + first `first_n_steps` steps. + """ + logging.info("Create CheckpointSaver") + super(CheckpointSaver, self).__init__(every_n_steps=every_n_steps, + first_n_steps=first_n_steps) + self._saver = saver + self._summary_writer = SummaryWriterCache.get(checkpoint_dir) + self._save_path = os.path.join(checkpoint_dir, checkpoint_basename) + + def every_n_post_step(self, step, session): + logging.info("Saving checkpoints for %d into %s." 
                 % (step, self._save_path))
+    self._saver.save(session, self._save_path, global_step=step)
+    if self._summary_writer:
+      self._summary_writer.add_session_log(
+          SessionLog(status=SessionLog.CHECKPOINT,
+                     checkpoint_path=self._save_path),
+          step)
+
+
+class StepCounter(EveryN):
+  """Steps per second monitor."""
+
+  def __init__(self, every_n_steps=100, output_dir=None,
+               summary_writer=None):
+    super(StepCounter, self).__init__(every_n_steps=every_n_steps)
+    self._summary_tag = "global_step/sec"
+    self._last_reported_step = None
+    self._last_reported_time = None
+    self._summary_writer = summary_writer
+    if summary_writer is None and output_dir:
+      self._summary_writer = SummaryWriterCache.get(output_dir)
+
+  def begin(self, max_steps=None, init_step=None):
+    super(StepCounter, self).begin(max_steps=max_steps, init_step=init_step)
+    self._last_reported_step = self._init_step
+    self._last_reported_time = time.time()
+
+  def set_estimator(self, estimator):
+    super(StepCounter, self).set_estimator(estimator)
+    if self._summary_writer is None:
+      self._summary_writer = SummaryWriterCache.get(estimator.model_dir)
+
+  def every_n_step_end(self, current_step, outputs):
+    current_time = time.time()
+    if self._last_reported_time is not None and self._summary_writer:
+      added_steps = current_step - self._last_reported_step
+      elapsed_time = current_time - self._last_reported_time
+      steps_per_sec = added_steps / elapsed_time
+      summary = Summary(value=[Summary.Value(tag=self._summary_tag,
+                                             simple_value=steps_per_sec)])
+      self._summary_writer.add_summary(summary, current_step)
+    self._last_reported_step = current_step
+    self._last_reported_time = current_time
+
+
+class NanLossDuringTrainingError(RuntimeError):
+
+  def __str__(self):
+    return "NaN loss during training."
+
+
+class NanLoss(EveryN):
+  """NaN Loss monitor.
+
+  Monitors loss and stops training if loss is NaN.
+  Can either fail with exception or just stop training.
+  """
+
+  def __init__(self, loss_tensor, every_n_steps=100, fail_on_nan_loss=True):
+    """Initializes NanLoss monitor.
+
+    Args:
+      loss_tensor: `Tensor`, the loss tensor.
+      every_n_steps: `int`, run check every this many steps.
+      fail_on_nan_loss: `bool`, whether to raise exception when loss is NaN.
+    """
+    super(NanLoss, self).__init__(every_n_steps=every_n_steps)
+    self._loss_tensor = loss_tensor
+    self._fail_on_nan_loss = fail_on_nan_loss
+
+  def every_n_step_begin(self, step):
+    super(NanLoss, self).every_n_step_begin(step)
+    return self._loss_tensor
+
+  def every_n_step_end(self, step, outputs):
+    super(NanLoss, self).every_n_step_end(step, outputs)
+    if np.isnan(outputs):
+      failure_message = "Model diverged with loss = NaN."
+      if self._fail_on_nan_loss:
+        logging.error(failure_message)
+        raise NanLossDuringTrainingError
+      else:
+        logging.warning(failure_message)
+        # We don't raise an error but we return "should stop" so we stop, but
+        # without an exception.
+        return True
diff --git a/tensorflow/contrib/learn/python/learn/summary_writer_cache.py b/tensorflow/contrib/learn/python/learn/summary_writer_cache.py
new file mode 100644
index 00000000000..5ef3046a578
--- /dev/null
+++ b/tensorflow/contrib/learn/python/learn/summary_writer_cache.py
@@ -0,0 +1,65 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""A cache for summary writers.
+
+One writer is cached per directory and shared by all callers.
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import threading
+
+from tensorflow.python.framework import ops
+from tensorflow.python.training import summary_io
+
+
+class SummaryWriterCache(object):
+  """Cache for summary writers.
+
+  This class caches summary writers, one per directory.
+  """
+  # Cache, keyed by directory.
+  _cache = {}
+
+  # Lock protecting _cache.
+  _lock = threading.RLock()
+
+  @staticmethod
+  def clear():
+    """Clear cached summary writers. Currently only used for unit tests."""
+    with SummaryWriterCache._lock:
+      SummaryWriterCache._cache = {}
+
+  @staticmethod
+  def get(logdir):
+    """Returns the SummaryWriter for the specified directory.
+
+    Args:
+      logdir: str, name of the directory.
+
+    Returns:
+      A `SummaryWriter`.
+    """
+    with SummaryWriterCache._lock:
+      if logdir not in SummaryWriterCache._cache:
+        SummaryWriterCache._cache[logdir] = summary_io.SummaryWriter(
+            logdir, graph=ops.get_default_graph())
+      return SummaryWriterCache._cache[logdir]
+
+
+# Backward compatible interface. Remove?
+clear_summary_writers = SummaryWriterCache.clear
+get_summary_writer = SummaryWriterCache.get
diff --git a/tensorflow/contrib/learn/python/learn/supervised_session.py b/tensorflow/contrib/learn/python/learn/supervised_session.py
new file mode 100644
index 00000000000..6f9318a3351
--- /dev/null
+++ b/tensorflow/contrib/learn/python/learn/supervised_session.py
@@ -0,0 +1,301 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Wrapper for a Session-like object that handles threads and recovery.
+
+Based on an original design of Illia Polosukhin.
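+
+A rough usage sketch (an illustration; `build_model()` stands in for user
+code that adds training ops to the default graph):
+
+  with tf.Graph().as_default():
+    train_op = build_model()
+    with SupervisedSession(master='', checkpoint_dir='/tmp/mydir') as session:
+      while not session.should_stop():
+        session.run(train_op)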
+""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.framework.python.ops import variables as contrib_variables +from tensorflow.contrib.learn.python.learn import coordinated_session +from tensorflow.contrib.learn.python.learn import monitored_session +from tensorflow.contrib.learn.python.learn import recoverable_session +from tensorflow.contrib.learn.python.learn import summary_writer_cache +from tensorflow.core.util.event_pb2 import SessionLog +from tensorflow.python.framework import ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import data_flow_ops +from tensorflow.python.ops import logging_ops +from tensorflow.python.ops import variables +from tensorflow.python.training import coordinator +from tensorflow.python.training import queue_runner +from tensorflow.python.training import saver as training_saver +from tensorflow.python.training import session_manager as sm +from tensorflow.python.training import training_util + + +# TODO(touts): Share that with the Supervisor. +class Scaffold(object): + """Structure to create or gather pieces commonly needed to train a model. + + When you build a model for training you usually need ops to initialize + variables, a `Saver` to checkpoint them, an op to collect summaries for + the visualizer, and so on. + + Various libraries built on top of the core TensorFlow library take care of + creating some or all of these pieces and storing them in well known + collections in the graph. The `Scaffold` class helps pick these pieces from + the graph collections, creating and adding them to the collections if needed. + + If you call the scaffold constructor without any arguments it will pick + pieces from the collections, creating default ones if needed. You can pass + arguments to the constructor to provide your own pieces. Pieces that you + pass to the constructor are not added to the graph collections. + + The following pieces are directly accessible as attributes of the `Scaffold` + object: + + * `saver`: A `tf.Saver` object taking care of saving the variables. Picked + from and stored into the `SAVERS` collection in the graph. + * `init_op`: An op to run to initialize the variables. Picked from and + stored into the `INIT_OP` collection in the graph. + * `ready_op`: An op to verify that the variables are initialized. Picked + from and stored into the `READY_OP` collection in the graph. + * `local_init_op`: An op to initialize the local variables. Picked + from and stored into the `LOCAL_INIT_OP` collection in the graph. + * `summary_op`: An op to run and merge the summaries in the graph. Picked + from and stored into the `SUMMARY_OP` collection in the graph. + * `global_step`: A tensor containing the global step counter. Picked + from and stored into the `GLOBAL_STEP` collection in the graph. + + You can also pass the following additional pieces to the constructor: + + * `init_feed_dict`: A sessionn feed dictionary that should be used when + running the init op. + * `init_fn`: A callable to run run after the init op to perform additional + initializations. The callable will be called as + `init_fn(scaffold, session)`. + + """ + # TODO(touts): consider adding the output dir and summary writer (cached)? + # TODO(touts): consider finalizeing the graph? (If the graph is + # modified later, the cached parts could be wrong.) + # TODO(touts): I do not think we should pass keep_checkpoint_max here. 
+  # TODO(touts): Add individual static functions for init_op(), etc. that
+  # implement the caching logic.
+
+  def __init__(self,
+               global_step_tensor=None,
+               init_op=None,
+               init_feed_dict=None,
+               init_fn=None,
+               ready_op=None,
+               local_init_op=None,
+               summary_op=None,
+               saver=None,
+               keep_checkpoint_max=5):
+    """Create a scaffold.
+
+    Args:
+      global_step_tensor: Optional tensor to use as the global step counter.
+      init_op: Optional op for initializing variables.
+      init_feed_dict: Optional session feed dictionary to use when running the
+        init_op.
+      init_fn: Optional function to use to initialize the model after running
+        the init_op. Will be called as `init_fn(scaffold, session)`.
+      ready_op: Optional op to verify that the variables are initialized. Must
+        return an empty scalar string tensor when the variables are
+        initialized, or a non-empty one listing the names of the
+        non-initialized variables.
+      local_init_op: Optional op to initialize local variables.
+      summary_op: Optional op to gather all summaries. Must return a scalar
+        string tensor containing a serialized `Summary` proto.
+      saver: Optional `tf.Saver` object to use to save and restore variables.
+      keep_checkpoint_max: Optional parameter to use to construct a saver if
+        none is already in the graph.
+    """
+    if global_step_tensor is None:
+      global_step_tensor = contrib_variables.get_or_create_global_step()
+    self.global_step_tensor = global_step_tensor
+    if init_op is None:
+      init_op = Scaffold._get_or_default(
+          ops.GraphKeys.INIT_OP, variables.initialize_all_variables)
+    self.init_op = init_op
+    self.init_feed_dict = init_feed_dict
+    # NOTE(touts): modifying the init function to be passed the scaffold is a
+    # hack to make it easy to find the saver. Is there a better way?
+    if init_fn:
+      self.init_fn = lambda sess: init_fn(self, sess)
+    else:
+      self.init_fn = None
+    if ready_op is None:
+      ready_op = Scaffold._get_or_default(
+          ops.GraphKeys.READY_OP, variables.report_uninitialized_variables)
+    self.ready_op = ready_op
+    if local_init_op is None:
+      local_init_op = Scaffold._get_or_default(
+          ops.GraphKeys.LOCAL_INIT_OP, Scaffold._default_local_init_op)
+    self.local_init_op = local_init_op
+    if summary_op is None:
+      summary_op = Scaffold._get_or_default(
+          ops.GraphKeys.SUMMARY_OP, logging_ops.merge_all_summaries)
+    self.summary_op = summary_op
+    # pylint: disable=g-long-lambda
+    if saver is None:
+      saver = Scaffold._get_or_default(
+          ops.GraphKeys.SAVERS,
+          lambda: training_saver.Saver(sharded=True,
+                                       max_to_keep=keep_checkpoint_max))
+    # pylint: enable=g-long-lambda
+    self.saver = saver
+
+  @staticmethod
+  def _get_or_default(collection_key, default_constructor):
+    elements = ops.get_collection(collection_key)
+    if elements:
+      return elements[0]
+    op = default_constructor()
+    if op is not None:
+      ops.add_to_collection(collection_key, op)
+    return op
+
+  @staticmethod
+  def _default_local_init_op():
+    return control_flow_ops.group(variables.initialize_local_variables(),
+                                  data_flow_ops.initialize_all_tables())
+
+
+class SupervisedSession(object):
+  """Session-like object that supports recovery and monitors.
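+
+  A `SupervisedSession` creates the `Scaffold`, the `Coordinator` and the
+  queue runner threads, calls the monitors, and re-creates the underlying
+  session when it is aborted. A sketch of a typical training loop (assumes
+  `train_op` is defined in the default graph):
+
+    with SupervisedSession(master='') as session:
+      while not session.should_stop():
+        session.run(train_op)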
+ + + """ + + def __init__(self, master, is_chief=True, checkpoint_dir=None, + monitors=None, scaffold=None, config=None, + clean_stop_exception_types=None): + self._graph = ops.get_default_graph() + self._master = master + self._checkpoint_dir = checkpoint_dir + self._is_chief = is_chief + self._config = config + self._clean_stop_exception_types = clean_stop_exception_types + self._monitors = monitors or [] + self._scaffold = scaffold or Scaffold() + # Finalize and write the graph. + self._graph.finalize() + # Create the session. + self._session_manager = sm.SessionManager( + local_init_op=self._scaffold.local_init_op, + ready_op=self._scaffold.ready_op, + graph=ops.get_default_graph()) + self._sess = recoverable_session.RecoverableSession(self._create_session) + # Call the begin() method of monitors. + self._init_step = self._tf_sess.run(self._scaffold.global_step_tensor) + for monitor in self._monitors: + monitor.begin(max_steps=None, init_step=self._init_step) + # Write the graph out, note: this uses self._init_step. + self.write_graph() + + def _create_session(self): + """Factory for the RecoverableSession. + + Returns: + A session, initialized or recovered as needed. + """ + if self._is_chief: + tf_sess = self._session_manager.prepare_session( + self._master, saver=self._scaffold.saver, + checkpoint_dir=self._checkpoint_dir, config=self._config, + init_op=self._scaffold.init_op, + init_feed_dict=self._scaffold.init_feed_dict, + init_fn=self._scaffold.init_fn) + else: + tf_sess = self._session_manager.wait_for_session( + self._master, config=self._config) + # Keep the tf_sess for quick runs of global step when needed. + self._tf_sess = tf_sess + self._coord = coordinator.Coordinator( + clean_stop_exception_types=self._clean_stop_exception_types) + self._coordinated_threads_to_join = queue_runner.start_queue_runners( + sess=tf_sess, coord=self._coord) + return coordinated_session.CoordinatedSession( + monitored_session.MonitoredSession(tf_sess, self._monitors, + self._scaffold.global_step_tensor), + self._coord, self._coordinated_threads_to_join) + + @property + def scaffold(self): + return self._scaffold + + @property + def session(self): + return self._tf_sess + + def run(self, fetches, feed_dict=None, options=None, run_metadata=None): + """Run ops in the supervised session. + + This method is completely compatible with the `tf.Session.run()` method. + + Args: + fetches: Same as `tf.Session.run()`. + feed_dict: Same as `tf.Session.run()`. + options: Same as `tf.Session.run()`. + run_metadata: Same as `tf.Session.run()`. + + Returns: + Same as `tf.Session.run()`. + """ + return self._sess.run(fetches, feed_dict=feed_dict, options=options, + run_metadata=run_metadata) + + def should_stop(self): + if self._sess: + return self._sess.should_stop() + return True + + def close(self): + # Run the Monitor.end() methods. + for monitor in self._monitors: + monitor.end(self._tf_sess) + self._sess.close() + self._sess = None + self._tf_sess = None + + def _is_closed(self): + """Return True if the supervised session is closed. For tests only. + + Returns: + A boolean. + """ + return self._tf_sess is None + + def __enter__(self): + return self + + def __exit__(self, exception_type, exception_value, traceback): + if exception_type: + self._coord.request_stop((exception_type, exception_value, traceback)) + else: + self._coord.request_stop() + try: + self._coord.join(self._coordinated_threads_to_join) + # If coord does not raise an exception, we return True to indicate + # "no exception to raise". 
+ return True + finally: + self.close() + + def write_graph(self): + """Saves current graph.""" + if self._checkpoint_dir is not None and self._is_chief: + summary_writer = summary_writer_cache.SummaryWriterCache.get( + self._checkpoint_dir) + training_util.write_graph(self._graph.as_graph_def(add_shapes=True), + self._checkpoint_dir, 'graph.pbtxt') + summary_writer.add_graph(self._graph) + summary_writer.add_session_log(SessionLog(status=SessionLog.START), + self._init_step) diff --git a/tensorflow/contrib/learn/python/learn/tests/coordinated_session_test.py b/tensorflow/contrib/learn/python/learn/tests/coordinated_session_test.py index 89fa3f9e3d1..1584d6d760a 100644 --- a/tensorflow/contrib/learn/python/learn/tests/coordinated_session_test.py +++ b/tensorflow/contrib/learn/python/learn/tests/coordinated_session_test.py @@ -24,6 +24,8 @@ import time import tensorflow as tf +from tensorflow.contrib.learn.python.learn import coordinated_session + def BusyWaitForCoordStop(coord): while not coord.should_stop(): @@ -37,7 +39,7 @@ class CoordinatedSessionTest(tf.test.TestCase): with self.test_session() as sess: tf.constant(0.0) coord = tf.train.Coordinator() - coord_sess = tf.contrib.learn.CoordinatedSession(sess, coord, []) + coord_sess = coordinated_session.CoordinatedSession(sess, coord, []) self.assertEquals(sess.graph, coord_sess.graph) self.assertEquals(sess.sess_str, coord_sess.sess_str) @@ -46,13 +48,13 @@ class CoordinatedSessionTest(tf.test.TestCase): c = tf.constant(0) v = tf.identity(c) coord = tf.train.Coordinator() - coord_sess = tf.contrib.learn.CoordinatedSession(sess, coord, []) + coord_sess = coordinated_session.CoordinatedSession(sess, coord, []) self.assertEqual(42, coord_sess.run(v, feed_dict={c: 42})) def test_should_stop_on_close(self): with self.test_session() as sess: coord = tf.train.Coordinator() - coord_sess = tf.contrib.learn.CoordinatedSession(sess, coord, []) + coord_sess = coordinated_session.CoordinatedSession(sess, coord, []) self.assertFalse(coord_sess.should_stop()) coord_sess.close() self.assertTrue(coord_sess.should_stop()) @@ -60,7 +62,7 @@ class CoordinatedSessionTest(tf.test.TestCase): def test_should_stop_on_coord_stop(self): with self.test_session() as sess: coord = tf.train.Coordinator() - coord_sess = tf.contrib.learn.CoordinatedSession(sess, coord, []) + coord_sess = coordinated_session.CoordinatedSession(sess, coord, []) self.assertFalse(coord_sess.should_stop()) coord.request_stop() self.assertTrue(coord_sess.should_stop()) @@ -70,7 +72,7 @@ class CoordinatedSessionTest(tf.test.TestCase): c = tf.constant(0) v = tf.identity(c) coord = tf.train.Coordinator() - coord_sess = tf.contrib.learn.CoordinatedSession(sess, coord, []) + coord_sess = coordinated_session.CoordinatedSession(sess, coord, []) self.assertFalse(coord_sess.should_stop()) self.assertEqual(0, coord_sess.run(c)) self.assertEqual(1, coord_sess.run(v, feed_dict={c: 1})) @@ -89,7 +91,7 @@ class CoordinatedSessionTest(tf.test.TestCase): for _ in range(3)] for t in threads: t.start() - coord_sess = tf.contrib.learn.CoordinatedSession(sess, coord, threads) + coord_sess = coordinated_session.CoordinatedSession(sess, coord, threads) self.assertFalse(coord_sess.should_stop()) for t in threads: self.assertTrue(t.is_alive()) diff --git a/tensorflow/contrib/learn/python/learn/tests/monitored_session_test.py b/tensorflow/contrib/learn/python/learn/tests/monitored_session_test.py index 55646921fc0..cb9d3c7f3ee 100644 --- a/tensorflow/contrib/learn/python/learn/tests/monitored_session_test.py +++ 
b/tensorflow/contrib/learn/python/learn/tests/monitored_session_test.py @@ -47,18 +47,26 @@ class FakeMonitor(monitors.BaseMonitor): self.should_stop = False self.requested_tensors = [] self.call_counter = Counter() + self.last_begin_step = None + self.last_end_step = None + self.last_post_step = None def step_begin(self, step): self.call_counter['step_begin'] += 1 - self.begin_step = step + self.last_begin_step = step return self.requested_tensors def step_end(self, step, output): self.call_counter['step_end'] += 1 - self.end_step = step + self.last_end_step = step self.output = output return self.should_stop + def post_step(self, step, session): + self.call_counter['post_step'] += 1 + self.last_post_step = step + self.session = session + class MonitoredSessionTest(tf.test.TestCase): @@ -83,7 +91,7 @@ class MonitoredSessionTest(tf.test.TestCase): 'run_metadata': 'a_metadata' }) - def testCallsMonitorsBeginAndEnd(self): + def testCallsMonitorsBeginEndAndPost(self): with tf.Graph().as_default(), tf.Session() as sess: global_step_tensor = tf.contrib.framework.create_global_step() mock_mon = FakeMonitor() @@ -99,10 +107,12 @@ class MonitoredSessionTest(tf.test.TestCase): for mon in [mock_mon, mock_mon2]: self.assertEqual(mon.output, {}) - self.assertEqual(mon.begin_step, 11) - self.assertEqual(mon.end_step, 11) + self.assertEqual(mon.last_begin_step, 11) + self.assertEqual(mon.last_end_step, 11) + self.assertEqual(mon.last_post_step, 11) self.assertEqual(mon.call_counter['step_end'], 1) self.assertEqual(mon.call_counter['step_begin'], 1) + self.assertEqual(mon.call_counter['post_step'], 1) def testCallsMonitorsWithLastStep(self): with tf.Graph().as_default(), tf.Session() as sess: @@ -119,18 +129,21 @@ class MonitoredSessionTest(tf.test.TestCase): mon_sess.run(fetches=[inc_5]) for mon in [mock_mon, mock_mon2]: - self.assertEqual(mon.begin_step, 1) - self.assertEqual(mon.end_step, 1) + self.assertEqual(mon.last_begin_step, 1) + self.assertEqual(mon.last_end_step, 1) + self.assertEqual(mon.last_post_step, 1) mon_sess.run(fetches=[inc_5]) for mon in [mock_mon, mock_mon2]: - self.assertEqual(mon.begin_step, 6) - self.assertEqual(mon.end_step, 6) + self.assertEqual(mon.last_begin_step, 6) + self.assertEqual(mon.last_end_step, 6) + self.assertEqual(mon.last_post_step, 6) mon_sess.run(fetches=[inc_5]) for mon in [mock_mon, mock_mon2]: - self.assertEqual(mon.begin_step, 11) - self.assertEqual(mon.end_step, 11) + self.assertEqual(mon.last_begin_step, 11) + self.assertEqual(mon.last_end_step, 11) + self.assertEqual(mon.last_post_step, 11) def testShouldStop(self): with tf.Graph().as_default(), tf.Session() as sess: diff --git a/tensorflow/contrib/learn/python/learn/tests/monitors_test.py b/tensorflow/contrib/learn/python/learn/tests/monitors_test.py index 09555f11cbe..953ba2369ff 100644 --- a/tensorflow/contrib/learn/python/learn/tests/monitors_test.py +++ b/tensorflow/contrib/learn/python/learn/tests/monitors_test.py @@ -34,6 +34,7 @@ class _MyEveryN(learn.monitors.EveryN): every_n_steps=every_n_steps, first_n_steps=first_n_steps) self._steps_begun = [] self._steps_ended = [] + self._post_steps = [] @property def steps_begun(self): @@ -43,6 +44,10 @@ class _MyEveryN(learn.monitors.EveryN): def steps_ended(self): return self._steps_ended + @property + def post_steps(self): + return self._post_steps + def every_n_step_begin(self, step): super(_MyEveryN, self).every_n_step_begin(step) self._steps_begun.append(step) @@ -53,6 +58,11 @@ class _MyEveryN(learn.monitors.EveryN): self._steps_ended.append(step) 
     return False
 
+  def every_n_post_step(self, step, session):
+    super(_MyEveryN, self).every_n_post_step(step, session)
+    self._post_steps.append(step)
+    return False
+
 
 class MonitorsTest(tf.test.TestCase):
   """Monitors tests."""
@@ -71,8 +81,13 @@ class MonitorsTest(tf.test.TestCase):
   def tearDown(self):
     logging.info = self._actual_log
 
-  def _run_monitor(self, monitor, num_epochs=3, num_steps_per_epoch=10):
-    monitor.begin(max_steps=(num_epochs * num_steps_per_epoch) - 1)
+  def _run_monitor(self, monitor, num_epochs=3, num_steps_per_epoch=10,
+                   pass_max_steps=True):
+    if pass_max_steps:
+      max_steps = num_epochs * num_steps_per_epoch - 1
+    else:
+      max_steps = None
+    monitor.begin(max_steps=max_steps, init_step=0)
     for epoch in xrange(num_epochs):
       monitor.epoch_begin(epoch)
       should_stop = False
@@ -85,6 +100,7 @@
             [t.name if isinstance(t, tf.Tensor) else t for t in tensors],
             output))
         should_stop = monitor.step_end(step=step, output=output)
+        monitor.post_step(step=step, session=None)
         step += 1
       monitor.epoch_end(epoch)
     monitor.end()
@@ -100,6 +116,18 @@
     expected_steps = [0, 1, 2, 10, 18, 26, 29]
     self.assertEqual(expected_steps, monitor.steps_begun)
     self.assertEqual(expected_steps, monitor.steps_ended)
+    self.assertEqual(expected_steps, monitor.post_steps)
+
+  def test_every_n_no_max_steps(self):
+    monitor = _MyEveryN(every_n_steps=8, first_n_steps=2)
+    with tf.Graph().as_default() as g, self.test_session(g):
+      self._run_monitor(monitor, num_epochs=3, num_steps_per_epoch=10,
+                        pass_max_steps=False)
+      begin_end_steps = [0, 1, 2, 10, 18, 26]
+      post_steps = [0, 1, 2, 10, 18, 26, 29]
+      self.assertEqual(begin_end_steps, monitor.steps_begun)
+      self.assertEqual(begin_end_steps, monitor.steps_ended)
+      self.assertEqual(post_steps, monitor.post_steps)
 
   def test_print(self):
     with tf.Graph().as_default() as g, self.test_session(g):
diff --git a/tensorflow/contrib/learn/python/learn/tests/recoverable_session_test.py b/tensorflow/contrib/learn/python/learn/tests/recoverable_session_test.py
index d62953414c1..678e96d8b2b 100644
--- a/tensorflow/contrib/learn/python/learn/tests/recoverable_session_test.py
+++ b/tensorflow/contrib/learn/python/learn/tests/recoverable_session_test.py
@@ -21,6 +21,8 @@ from __future__ import print_function
 
 import tensorflow as tf
 
+from tensorflow.contrib.learn.python.learn import recoverable_session
+
 
 class AbortAtNSession(object):
   """A mock session that aborts at the N-th run call."""
@@ -42,7 +44,7 @@ class RecoverableSessionTest(tf.test.TestCase):
   def test_properties(self):
     with self.test_session() as sess:
       tf.constant(0.0)
-      recoverable_sess = tf.contrib.learn.RecoverableSession(lambda: sess)
+      recoverable_sess = recoverable_session.RecoverableSession(lambda: sess)
       self.assertEquals(sess.graph, recoverable_sess.graph)
       self.assertEquals(sess.sess_str, recoverable_sess.sess_str)
 
@@ -50,7 +52,7 @@
     with self.test_session() as sess:
       c = tf.constant(0)
       v = tf.identity(c)
-      recoverable_sess = tf.contrib.learn.RecoverableSession(lambda: sess)
+      recoverable_sess = recoverable_session.RecoverableSession(lambda: sess)
       self.assertEqual(51, recoverable_sess.run(v, feed_dict={c: 51}))
 
   def test_recovery(self):
@@ -65,7 +67,7 @@
       self.assertEqual(3, len(sessions_to_use))
       # Make the recoverable session use these 3 sessions in sequence by
       # passing a factory that pops from the session_to_use list.
-      recoverable_sess = tf.contrib.learn.RecoverableSession(
+      recoverable_sess = recoverable_session.RecoverableSession(
           lambda: sessions_to_use.pop(0))
       self.assertEqual(2, len(sessions_to_use))  # One session popped.
       # Using first session.
diff --git a/tensorflow/contrib/learn/python/learn/tests/summary_writer_cache_test.py b/tensorflow/contrib/learn/python/learn/tests/summary_writer_cache_test.py
new file mode 100644
index 00000000000..f37e179c219
--- /dev/null
+++ b/tensorflow/contrib/learn/python/learn/tests/summary_writer_cache_test.py
@@ -0,0 +1,68 @@
+# pylint: disable=g-bad-file-header
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for SummaryWriterCache."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import glob
+import os
+
+import tensorflow as tf
+
+from tensorflow.contrib.learn.python.learn import summary_writer_cache
+
+
+class SummaryWriterCacheTest(tf.test.TestCase):
+  """SummaryWriterCache tests."""
+
+  def _test_dir(self, test_name):
+    test_dir = os.path.join(self.get_temp_dir(), test_name)
+    if os.path.isdir(test_dir):
+      os.removedirs(test_dir)
+    os.makedirs(test_dir)
+    return test_dir
+
+  def test_cache(self):
+    with tf.Graph().as_default():
+      dir1 = self._test_dir('test_cache_1')
+      dir2 = self._test_dir('test_cache_2')
+      sw1 = summary_writer_cache.SummaryWriterCache.get(dir1)
+      sw2 = summary_writer_cache.SummaryWriterCache.get(dir2)
+      sw3 = summary_writer_cache.SummaryWriterCache.get(dir1)
+      self.assertEqual(sw1, sw3)
+      self.assertFalse(sw1 == sw2)
+      sw1.close()
+      sw2.close()
+      events1 = glob.glob(os.path.join(dir1, 'event*'))
+      self.assertTrue(events1)
+      events2 = glob.glob(os.path.join(dir2, 'event*'))
+      self.assertTrue(events2)
+      events3 = glob.glob(os.path.join('nowriter', 'event*'))
+      self.assertFalse(events3)
+
+  def test_clear(self):
+    with tf.Graph().as_default():
+      dir1 = self._test_dir('test_clear')
+      sw1 = summary_writer_cache.SummaryWriterCache.get(dir1)
+      summary_writer_cache.SummaryWriterCache.clear()
+      sw2 = summary_writer_cache.SummaryWriterCache.get(dir1)
+      self.assertFalse(sw1 == sw2)
+
+
+if __name__ == '__main__':
+  tf.test.main()
diff --git a/tensorflow/contrib/learn/python/learn/tests/supervised_session_test.py b/tensorflow/contrib/learn/python/learn/tests/supervised_session_test.py
new file mode 100644
index 00000000000..d096b68366e
--- /dev/null
+++ b/tensorflow/contrib/learn/python/learn/tests/supervised_session_test.py
@@ -0,0 +1,404 @@
+# pylint: disable=g-bad-file-header
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for SupervisedSession.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +import tensorflow as tf + +from tensorflow.contrib.learn.python.learn import supervised_session + + +class ScaffoldTest(tf.test.TestCase): + """Scaffold tests.""" + + def test_defaults_empty_graph(self): + with tf.Graph().as_default(): + scaffold = supervised_session.Scaffold() + self.assertTrue(isinstance(scaffold.global_step_tensor, tf.Variable)) + self.assertTrue(isinstance(scaffold.init_op, tf.Operation)) + self.assertEqual(None, scaffold.init_feed_dict) + self.assertEqual(None, scaffold.init_fn) + self.assertTrue(isinstance(scaffold.ready_op, tf.Tensor)) + self.assertTrue(isinstance(scaffold.local_init_op, tf.Operation)) + self.assertTrue(isinstance(scaffold.saver, tf.train.Saver)) + with self.test_session() as sess: + self.assertTrue(b'global_step' in sess.run(scaffold.ready_op)) + sess.run([scaffold.init_op, scaffold.local_init_op]) + self.assertEquals(0, len(sess.run(scaffold.ready_op))) + self.assertEquals(0, sess.run(scaffold.global_step_tensor)) + + def test_caches_values(self): + with tf.Graph().as_default(): + scaffold1 = supervised_session.Scaffold() + scaffold2 = supervised_session.Scaffold() + self.assertEqual(scaffold1.global_step_tensor, + scaffold2.global_step_tensor) + self.assertEqual(scaffold1.init_op, scaffold2.init_op) + self.assertEqual(scaffold1.ready_op, scaffold2.ready_op) + self.assertEqual(scaffold1.local_init_op, scaffold2.local_init_op) + self.assertEqual(scaffold1.saver, scaffold2.saver) + + def test_uses_passed_values(self): + with tf.Graph().as_default(): + scaffold = supervised_session.Scaffold(global_step_tensor=1, + init_op=2, + init_feed_dict=3, + init_fn=lambda scaffold, sess: 4, + ready_op=5, + local_init_op=6, + saver=7) + self.assertEqual(1, scaffold.global_step_tensor) + self.assertEqual(2, scaffold.init_op) + self.assertEqual(3, scaffold.init_feed_dict) + self.assertTrue(callable(scaffold.init_fn)) + self.assertEqual(5, scaffold.ready_op) + self.assertEqual(6, scaffold.local_init_op) + self.assertEqual(7, scaffold.saver) + + +class RaiseOnceAtStepN(tf.contrib.learn.monitors.BaseMonitor): + """Monitor that raises an Exception at step N.""" + + def __init__(self, n, ex): + super(RaiseOnceAtStepN, self).__init__() + self.n = n + self.ex = ex + self.raised = False + + def step_begin(self, step): + super(RaiseOnceAtStepN, self).step_begin(step) + # Raise the first time we reach step N. + if step == self.n and not self.raised: + self.raised = True + raise self.ex + return [] + + +class SupervisedSessionTest(tf.test.TestCase): + """SupervisedSession tests.""" + + def _test_dir(self, test_name): + """Create an empty dir to use for tests. + + Args: + test_name: Name of the test. + + Returns: + Absolute path to the test directory. 
+    """
+    test_dir = os.path.join(self.get_temp_dir(), test_name)
+    if os.path.isdir(test_dir):
+      os.removedirs(test_dir)
+    os.makedirs(test_dir)
+    return test_dir
+
+  def test_defaults(self):
+    with tf.Graph().as_default():
+      with supervised_session.SupervisedSession('') as session:
+        self.assertEqual(0, session.run(session.scaffold.global_step_tensor))
+
+  def test_last_step(self):
+    logdir = self._test_dir('test_last_step')
+    with tf.Graph().as_default():
+      scaffold = supervised_session.Scaffold()
+      gstep = scaffold.global_step_tensor
+      do_step = tf.assign_add(gstep, 1)
+      # Run till step 3 and save.
+      monitors = [tf.contrib.learn.monitors.StopAtStep(last_step=3)]
+      with supervised_session.SupervisedSession('', scaffold=scaffold,
+                                                monitors=monitors) as session:
+        self.assertEqual(0, session.run(gstep))
+        self.assertFalse(session.should_stop())
+        self.assertEqual(1, session.run(do_step))
+        self.assertFalse(session.should_stop())
+        self.assertEqual(2, session.run(do_step))
+        self.assertFalse(session.should_stop())
+        self.assertEqual(3, session.run(do_step))
+        self.assertTrue(session.should_stop())
+        save_path = scaffold.saver.save(session.session,
+                                        os.path.join(logdir, 'step-3'))
+      # Run till step 5 and save.
+      def load_ckpt(scaffold, sess):
+        scaffold.saver.restore(sess, save_path)
+      scaffold = supervised_session.Scaffold(init_fn=load_ckpt)
+      monitors = [tf.contrib.learn.monitors.StopAtStep(last_step=5)]
+      with supervised_session.SupervisedSession('', scaffold=scaffold,
+                                                monitors=monitors) as session:
+        self.assertEqual(3, session.run(gstep))
+        self.assertFalse(session.should_stop())
+        self.assertEqual(4, session.run(do_step))
+        self.assertFalse(session.should_stop())
+        self.assertEqual(5, session.run(do_step))
+        self.assertTrue(session.should_stop())
+
+  def test_num_steps(self):
+    logdir = self._test_dir('test_num_steps')
+    with tf.Graph().as_default():
+      scaffold = supervised_session.Scaffold()
+      gstep = scaffold.global_step_tensor
+      do_step = tf.assign_add(gstep, 1)
+      # Do 3 steps and save.
+      monitors = [tf.contrib.learn.monitors.StopAtStep(num_steps=3)]
+      with supervised_session.SupervisedSession('', scaffold=scaffold,
+                                                monitors=monitors) as session:
+        session.run(do_step)
+        self.assertFalse(session.should_stop())
+        session.run(do_step)
+        self.assertFalse(session.should_stop())
+        session.run(do_step)
+        self.assertTrue(session.should_stop())
+        save_path = scaffold.saver.save(session.session,
+                                        os.path.join(logdir, 'step-3'))
+      # Restore and do 4 steps.
+      def load_ckpt(scaffold, sess):
+        scaffold.saver.restore(sess, save_path)
+      scaffold = supervised_session.Scaffold(init_fn=load_ckpt)
+      monitors = [tf.contrib.learn.monitors.StopAtStep(num_steps=4)]
+      with supervised_session.SupervisedSession('', scaffold=scaffold,
+                                                monitors=monitors) as session:
+        self.assertEqual(3, session.run(gstep))
+        session.run(do_step)
+        self.assertFalse(session.should_stop())
+        session.run(do_step)
+        self.assertFalse(session.should_stop())
+        session.run(do_step)
+        self.assertFalse(session.should_stop())
+        session.run(do_step)
+        self.assertTrue(session.should_stop())
+
+  # This set of tests verifies the supervised session behavior when exceptions
+  # are raised next to the innermost session run() call.
+
+  def test_recovery(self):
+    logdir = self._test_dir('test_recovery')
+    with tf.Graph().as_default():
+      scaffold = supervised_session.Scaffold()
+      gstep = scaffold.global_step_tensor
+      do_step = tf.assign_add(gstep, 1)
+      # Use a monitor to save the model every 100 steps. It also saves it at
+      # the end.
+      monitors = [tf.contrib.learn.monitors.CheckpointSaver(
+          100, scaffold.saver, logdir)]
+      with supervised_session.SupervisedSession('', scaffold=scaffold,
+                                                checkpoint_dir=logdir,
+                                                monitors=monitors) as session:
+        self.assertEqual(0, session.run(gstep))
+        self.assertEqual(1, session.run(do_step))
+        self.assertEqual(2, session.run(do_step))
+      # A restart will find the checkpoint and recover automatically.
+      with supervised_session.SupervisedSession(
+          '', scaffold=scaffold, checkpoint_dir=logdir) as session:
+        self.assertEqual(2, session.run(gstep))
+
+  def test_retry_on_aborted_error(self):
+    # Tests that we silently retry on abort. Note that this does not test
+    # recovery as we do not use a CheckpointSaver in this test.
+    with tf.Graph().as_default():
+      scaffold = supervised_session.Scaffold()
+      gstep = scaffold.global_step_tensor
+      do_step = tf.assign_add(gstep, 1)
+      monitor = RaiseOnceAtStepN(3, tf.errors.AbortedError(None, None, 'Abort'))
+      with supervised_session.SupervisedSession('', scaffold=scaffold,
+                                                monitors=[monitor]) as session:
+        self.assertEqual(0, session.run(gstep))
+        self.assertEqual(1, session.run(do_step))
+        self.assertEqual(2, session.run(do_step))
+        self.assertFalse(session.should_stop())
+        # Here at step 3, the monitor triggers and raises AbortedError. The
+        # SupervisedSession automatically retries and restarts from a freshly
+        # initialized session, so the step is back to 0 and running do_step
+        # moves it to 1.
+        self.assertEqual(1, session.run(do_step))
+        self.assertFalse(session.should_stop())
+        self.assertTrue(monitor.raised)
+        self.assertEqual(2, session.run(do_step))
+        self.assertFalse(session.should_stop())
+
+  def test_recover_and_retry_on_aborted_error(self):
+    # Tests that we silently retry and recover on abort. This test uses
+    # a CheckpointSaver to have something to recover from.
+    logdir = self._test_dir('test_recover_and_retry_on_aborted_error')
+    with tf.Graph().as_default():
+      scaffold = supervised_session.Scaffold()
+      gstep = scaffold.global_step_tensor
+      do_step = tf.assign_add(gstep, 1)
+      abort_monitor = RaiseOnceAtStepN(
+          3, tf.errors.AbortedError(None, None, 'Abort'))
+      # Save after each step.
+      ckpt_monitor = tf.contrib.learn.monitors.CheckpointSaver(
+          1, scaffold.saver, logdir)
+      monitors = [abort_monitor, ckpt_monitor]
+      with supervised_session.SupervisedSession('', scaffold=scaffold,
+                                                checkpoint_dir=logdir,
+                                                monitors=monitors) as session:
+        self.assertEqual(0, session.run(gstep))
+        self.assertEqual(1, session.run(do_step))
+        self.assertEqual(2, session.run(do_step))
+        self.assertFalse(session.should_stop())
+        # Here at step 3, the monitor triggers and raises AbortedError. The
+        # SupervisedSession automatically restores and retries.
+        self.assertEqual(3, session.run(do_step))
+        self.assertTrue(abort_monitor.raised)
+        self.assertFalse(session.should_stop())
+        self.assertEqual(4, session.run(do_step))
+        self.assertFalse(session.should_stop())
+
+  def test_stop_cleanly_on_out_of_range_exception(self):
+    # Tests that we stop cleanly when OutOfRange is raised.
+    with tf.Graph().as_default():
+      scaffold = supervised_session.Scaffold()
+      gstep = scaffold.global_step_tensor
+      do_step = tf.assign_add(gstep, 1)
+      monitor = RaiseOnceAtStepN(
+          3, tf.errors.OutOfRangeError(None, None, 'EOI'))
+      with supervised_session.SupervisedSession('', scaffold=scaffold,
+                                                monitors=[monitor]) as session:
+        self.assertEqual(0, session.run(gstep))
+        self.assertEqual(1, session.run(do_step))
+        self.assertEqual(2, session.run(do_step))
+        self.assertFalse(session.should_stop())
+        # Here at step 3, the monitor triggers and raises OutOfRange. The
+        # session should go into should_stop() mode. We do not get a result
+        # in that case.
+        self.assertEqual(None, session.run(do_step))
+        self.assertTrue(monitor.raised)
+        self.assertTrue(session.should_stop())
+
+  def test_stop_cleanly_on_custom_exception(self):
+    # Tests that we stop cleanly when an exception type of
+    # our choice is raised (StopIteration here).
+    with tf.Graph().as_default():
+      scaffold = supervised_session.Scaffold()
+      gstep = scaffold.global_step_tensor
+      do_step = tf.assign_add(gstep, 1)
+      monitor = RaiseOnceAtStepN(3, StopIteration('I choose you'))
+      exception_types = [tf.errors.OutOfRangeError, StopIteration]
+      with supervised_session.SupervisedSession(
+          '', scaffold=scaffold,
+          monitors=[monitor],
+          clean_stop_exception_types=exception_types) as session:
+        self.assertEqual(0, session.run(gstep))
+        self.assertEqual(1, session.run(do_step))
+        self.assertEqual(2, session.run(do_step))
+        self.assertFalse(session.should_stop())
+        # Here at step 3, the monitor triggers and raises StopIteration. The
+        # session should go into should_stop() mode. We do not get a result
+        # in that case.
+        self.assertEqual(None, session.run(do_step))
+        self.assertTrue(monitor.raised)
+        self.assertTrue(session.should_stop())
+
+  # This set of tests verifies the session behavior when exceptions are raised
+  # from code inside a "with SupervisedSession:" context.
+
+  def test_regular_exception_pass_through_in_with_body(self):
+    # Tests that regular exceptions just pass through a "with
+    # SupervisedSession" block and set the session in stop mode.
+    with tf.Graph().as_default():
+      scaffold = supervised_session.Scaffold()
+      gstep = scaffold.global_step_tensor
+      do_step = tf.assign_add(gstep, 1)
+      monitor = RaiseOnceAtStepN(3, RuntimeError('regular exception'))
+      session = supervised_session.SupervisedSession('', scaffold=scaffold,
+                                                     monitors=[monitor])
+      with self.assertRaisesRegexp(RuntimeError, 'regular exception'):
+        with session:
+          self.assertEqual(0, session.run(gstep))
+          self.assertEqual(1, session.run(do_step))
+          self.assertEqual(2, session.run(do_step))
+          self.assertFalse(session.should_stop())
+          # This triggers the monitor and raises the exception.
+          session.run(do_step)
+          # We should not hit this line.
+          self.assertFalse(True)
+      self.assertTrue(monitor.raised)
+      self.assertTrue(session.should_stop())
+
+  def test_stop_cleanly_when_no_exception_in_with_body(self):
+    # Tests that the session stops cleanly when the "with" body exits
+    # without raising an exception.
+    with tf.Graph().as_default():
+      scaffold = supervised_session.Scaffold()
+      gstep = scaffold.global_step_tensor
+      do_step = tf.assign_add(gstep, 1)
+      session = supervised_session.SupervisedSession('', scaffold=scaffold)
+      with session:
+        self.assertEqual(1, session.run(do_step))
+        self.assertEqual(2, session.run(do_step))
+        self.assertFalse(session.should_stop())
+      # Should have closed.
+      self.assertTrue(session.should_stop())
+      self.assertTrue(session._is_closed())
+
+  def test_stop_cleanly_on_out_of_range_exception_in_with_body(self):
+    # Tests that an OutOfRangeError raised in the "with" body stops the
+    # session cleanly instead of propagating.
+    with tf.Graph().as_default():
+      scaffold = supervised_session.Scaffold()
+      gstep = scaffold.global_step_tensor
+      do_step = tf.assign_add(gstep, 1)
+      session = supervised_session.SupervisedSession('', scaffold=scaffold)
+      with session:
+        self.assertEqual(1, session.run(do_step))
+        self.assertEqual(2, session.run(do_step))
+        self.assertFalse(session.should_stop())
+        raise tf.errors.OutOfRangeError(None, None, 'EOI')
+      # Should have closed.
+      self.assertTrue(session.should_stop())
+      self.assertTrue(session._is_closed())
+
+  def test_raises_regular_exceptions_in_with_body(self):
+    # Tests that regular exceptions in the "with" body are seen outside.
+    with tf.Graph().as_default():
+      scaffold = supervised_session.Scaffold()
+      gstep = scaffold.global_step_tensor
+      do_step = tf.assign_add(gstep, 1)
+      session = supervised_session.SupervisedSession('', scaffold=scaffold)
+      # We should see that exception.
+      with self.assertRaisesRegexp(RuntimeError, 'regular exception'):
+        with session:
+          self.assertEqual(1, session.run(do_step))
+          self.assertEqual(2, session.run(do_step))
+          self.assertFalse(session.should_stop())
+          # Will be visible outside the "with" body.
+          raise RuntimeError('regular exception')
+      # Should have closed.
+      self.assertTrue(session.should_stop())
+      self.assertTrue(session._is_closed())
+
+  def test_stop_cleanly_on_custom_exception_in_with_body(self):
+    # Tests that an exception listed in clean_stop_exception_types raised
+    # in the "with" body stops the session cleanly instead of propagating.
+    with tf.Graph().as_default():
+      scaffold = supervised_session.Scaffold()
+      gstep = scaffold.global_step_tensor
+      do_step = tf.assign_add(gstep, 1)
+      exception_types = [tf.errors.OutOfRangeError, StopIteration]
+      session = supervised_session.SupervisedSession(
+          '', scaffold=scaffold, clean_stop_exception_types=exception_types)
+      with session:
+        self.assertEqual(1, session.run(do_step))
+        self.assertEqual(2, session.run(do_step))
+        self.assertFalse(session.should_stop())
+        raise StopIteration('EOI')
+      # Should have closed.
+      self.assertTrue(session.should_stop())
+      self.assertTrue(session._is_closed())
+
+
+if __name__ == '__main__':
+  tf.test.main()
diff --git a/tensorflow/contrib/learn/python/learn/tests/wrapped_session_test.py b/tensorflow/contrib/learn/python/learn/tests/wrapped_session_test.py
index 5b31bfdc23f..526f8413f9a 100644
--- a/tensorflow/contrib/learn/python/learn/tests/wrapped_session_test.py
+++ b/tensorflow/contrib/learn/python/learn/tests/wrapped_session_test.py
@@ -21,8 +21,10 @@ from __future__ import print_function
 
 import tensorflow as tf
 
+from tensorflow.contrib.learn.python.learn import wrapped_session
 
-class StopAtNSession(tf.contrib.learn.WrappedSession):
+
+class StopAtNSession(wrapped_session.WrappedSession):
   """A wrapped session that stops at the N-th call to _check_stop."""
 
   def __init__(self, sess, n):
@@ -42,13 +44,13 @@ class WrappedSessionTest(tf.test.TestCase):
   def test_properties(self):
     with self.test_session() as sess:
       tf.constant(0.0)
-      wrapped_sess = tf.contrib.learn.WrappedSession(sess)
+      wrapped_sess = wrapped_session.WrappedSession(sess)
       self.assertEquals(sess.graph, wrapped_sess.graph)
       self.assertEquals(sess.sess_str, wrapped_sess.sess_str)
 
   def test_should_stop_on_close(self):
     with self.test_session() as sess:
-      wrapped_sess = tf.contrib.learn.WrappedSession(sess)
+      wrapped_sess = wrapped_session.WrappedSession(sess)
       self.assertFalse(wrapped_sess.should_stop())
       wrapped_sess.close()
       self.assertTrue(wrapped_sess.should_stop())
@@ -64,7 +66,7 @@ class WrappedSessionTest(tf.test.TestCase):
   def test_should_stop_delegates_to_wrapped_session(self):
     with self.test_session() as sess:
       wrapped_sess0 = StopAtNSession(sess, 4)
-      wrapped_sess1 = tf.contrib.learn.WrappedSession(wrapped_sess0)
+      wrapped_sess1 = wrapped_session.WrappedSession(wrapped_sess0)
       self.assertFalse(wrapped_sess1.should_stop())
       self.assertFalse(wrapped_sess1.should_stop())
       self.assertFalse(wrapped_sess1.should_stop())
@@ -73,7 +75,7 @@ class WrappedSessionTest(tf.test.TestCase):
 
   def test_close_twice(self):
     with self.test_session() as sess:
-      wrapped_sess = tf.contrib.learn.WrappedSession(sess)
+      wrapped_sess = wrapped_session.WrappedSession(sess)
       wrapped_sess.close()
       self.assertTrue(wrapped_sess.should_stop())
       wrapped_sess.close()
@@ -84,7 +86,7 @@ class WrappedSessionTest(tf.test.TestCase):
     c = tf.constant(0)
     v = tf.identity(c)
     self.assertEqual(42, sess.run(v, feed_dict={c: 42}))
-    wrapped_sess = tf.contrib.learn.WrappedSession(sess)
+    wrapped_sess = wrapped_session.WrappedSession(sess)
     self.assertEqual(51, wrapped_sess.run(v, feed_dict={c: 51}))
 
diff --git a/tensorflow/python/training/coordinator.py b/tensorflow/python/training/coordinator.py
index 1785ff372b6..96a1bdb52b7 100644
--- a/tensorflow/python/training/coordinator.py
+++ b/tensorflow/python/training/coordinator.py
@@ -139,7 +139,7 @@ class Coordinator(object):
     """
     if clean_stop_exception_types is None:
       clean_stop_exception_types = (errors.OutOfRangeError,)
-    self._clean_stop_exception_types = clean_stop_exception_types
+    self._clean_stop_exception_types = tuple(clean_stop_exception_types)
     # Protects all attributes.
     self._lock = threading.Lock()
     # Event set when threads must stop.
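
The tests above all exercise the same training-loop pattern that SupervisedSession is built for. A minimal sketch of that loop, assembled only from the calls used in the tests; the checkpoint directory '/tmp/my_train_dir', the last_step of 1000, and the tf.assign_add stand-in for a real training op are illustrative and not part of this change:

    import tensorflow as tf

    from tensorflow.contrib.learn.python.learn import supervised_session

    with tf.Graph().as_default():
      scaffold = supervised_session.Scaffold()
      # Stand-in training op: it just increments the global step.
      train_op = tf.assign_add(scaffold.global_step_tensor, 1)
      monitors = [
          # Request a stop once the global step reaches 1000.
          tf.contrib.learn.monitors.StopAtStep(last_step=1000),
          # Checkpoint every 100 steps; a restart recovers from these.
          tf.contrib.learn.monitors.CheckpointSaver(100, scaffold.saver,
                                                    '/tmp/my_train_dir'),
      ]
      with supervised_session.SupervisedSession('', scaffold=scaffold,
                                                checkpoint_dir='/tmp/my_train_dir',
                                                monitors=monitors) as session:
        while not session.should_stop():
          session.run(train_op)

Because OutOfRangeError (and any type passed via clean_stop_exception_types) flips should_stop() instead of propagating, and AbortedError takes the retry/recover path shown in test_recovery, the loop body needs no try/except of its own.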