Add SupervisedSession class.
Extend the Monitor interface to support a post_step() callback which is passed the Session. Add a Scaffold object to hold onto the commonly needed graph pieces to run training. Based on a design by @ilblackdragon. Change: 127323605
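
For context, a minimal sketch (editorial, not part of this commit) of how the new pieces compose: a Scaffold gathers the graph pieces, monitors receive the new post_step() callback, and SupervisedSession drives the loop. The checkpoint directory below is made up; the rest mirrors the usage exercised in supervised_session_test.py.

    # Sketch only: names follow this commit; '/tmp/train_logs' is illustrative.
    import tensorflow as tf
    from tensorflow.contrib.learn.python.learn import supervised_session

    with tf.Graph().as_default():
      scaffold = supervised_session.Scaffold()  # init_op, saver, global step...
      train_op = tf.assign_add(scaffold.global_step_tensor, 1)
      monitors = [tf.contrib.learn.monitors.StopAtStep(num_steps=3)]
      with supervised_session.SupervisedSession(
          '', scaffold=scaffold, checkpoint_dir='/tmp/train_logs',
          monitors=monitors) as session:
        while not session.should_stop():
          session.run(train_op)  # Monitors see step_begin/step_end/post_step.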
Commit: 4f6e9efb40 (parent: 3219a1e939)
@@ -684,6 +684,26 @@ py_test(
    ],
)

py_test(
    name = "supervised_session_test",
    size = "small",
    srcs = ["python/learn/tests/supervised_session_test.py"],
    srcs_version = "PY2AND3",
    deps = [
        "//tensorflow:tensorflow_py",
    ],
)

py_test(
    name = "summary_writer_cache_test",
    size = "small",
    srcs = ["python/learn/tests/summary_writer_cache_test.py"],
    srcs_version = "PY2AND3",
    deps = [
        "//tensorflow:tensorflow_py",
    ],
)

py_binary(
    name = "inspect_checkpoint",
    srcs = [
@@ -31,7 +31,6 @@ from tensorflow.contrib.learn.python.learn import monitors
from tensorflow.contrib.learn.python.learn import ops
from tensorflow.contrib.learn.python.learn import preprocessing
from tensorflow.contrib.learn.python.learn import utils
-from tensorflow.contrib.learn.python.learn.coordinated_session import *
from tensorflow.contrib.learn.python.learn.dataframe import *
from tensorflow.contrib.learn.python.learn.estimators import *
from tensorflow.contrib.learn.python.learn.experiment import Experiment
@@ -42,6 +41,4 @@ from tensorflow.contrib.learn.python.learn.graph_actions import run_feeds
from tensorflow.contrib.learn.python.learn.graph_actions import run_n
from tensorflow.contrib.learn.python.learn.graph_actions import train
from tensorflow.contrib.learn.python.learn.io import *
-from tensorflow.contrib.learn.python.learn.recoverable_session import *
-from tensorflow.contrib.learn.python.learn.wrapped_session import *
# pylint: enable=wildcard-import
@@ -57,3 +57,8 @@ class CoordinatedSession(WrappedSession):
      self._coord.request_stop(e)
    if self._coord.should_stop():
      self._coord.join(self._coordinated_threads_to_join)
+
+# TODO(touts): Add a close() method that also joins the coordinator
+# but does not raise exceptions. This can only be done reliably when the
+# Coordinator keeps a pointer to the coordinated threads, otherwise we do not
+# know which threads to join.
@@ -23,7 +23,6 @@ import six

from tensorflow.contrib.learn.python.learn.wrapped_session import WrappedSession
from tensorflow.python.framework import ops
from tensorflow.python.platform import tf_logging as logging


class MonitoredSession(WrappedSession):
@@ -73,7 +72,6 @@ class MonitoredSession(WrappedSession):

    if self._last_step is None:
      self._last_step = WrappedSession.run(self, self._global_step_tensor)
      logging.info('Initialized step to: %d', self._last_step)

    monitors_step = self._last_step + 1
    monitor_fetches = []
@@ -104,12 +102,16 @@ class MonitoredSession(WrappedSession):
      induce_stop = monitor.step_end(monitors_step, monitor_outputs)
      self._should_stop = self._should_stop or induce_stop

+    # Call the post_step methods.
+    for monitor in self._monitors:
+      monitor.post_step(monitors_step, self._sess)
+
    return outputs['caller']


# TODO(ispir): Remove following logic after forcing monitors returns tensors.
def _as_graph_element(obj, graph):
-  """Retrieves Graph element from tensors or tensor names."""
+  """Retrieves Graph element."""
  graph = graph or ops.get_default_graph()
  if not isinstance(obj, six.string_types):
    if not hasattr(obj, 'graph') or obj.graph != graph:
@@ -69,13 +69,19 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import time

import numpy as np
import six

from tensorflow.contrib.learn.python.learn.summary_writer_cache import SummaryWriterCache
from tensorflow.contrib.learn.python.learn.utils import export
from tensorflow.core.framework.summary_pb2 import Summary
from tensorflow.core.util.event_pb2 import SessionLog
from tensorflow.python.framework import ops
from tensorflow.python.platform import tf_logging as logging
-from tensorflow.python.training import saver
+from tensorflow.python.training import saver as saver_lib
from tensorflow.python.training import summary_io


@@ -92,6 +98,7 @@ class BaseMonitor(object):
    self._current_epoch = None
    self._current_step = None
    self._max_steps = None
+    self._init_step = None
    self._estimator = None

  def set_estimator(self, estimator):
@@ -108,11 +115,14 @@ class BaseMonitor(object):
    # TODO(mdan): This should fail if called twice with the same estimator.
    self._estimator = estimator

-  def begin(self, max_steps=None):
+  def begin(self, max_steps=None, init_step=None):
    """Called at the beginning of training.

    When called, the default graph is the one we are executing.

    Args:
      max_steps: `int`, the maximum global step this training will run until.
+      init_step: `int`, step at which this training will start.

    Raises:
      ValueError: if we've already begun a run.
@@ -120,14 +130,19 @@ class BaseMonitor(object):
    if self._begun:
      raise ValueError("begin called twice without end.")
    self._max_steps = max_steps
+    self._init_step = init_step
    self._begun = True

-  def end(self):
+  def end(self, session=None):
    """Callback at the end of training/evaluation.

+    Args:
+      session: A `tf.Session` object that can be used to run ops.

    Raises:
      ValueError: if we've not begun a run.
    """
+    _ = session
    if not self._begun:
      raise ValueError("end called without begin.")
    self._max_steps = None
@@ -178,8 +193,6 @@ class BaseMonitor(object):
      ValueError: if we've already begun a step, or `step` < 0, or
          `step` > `max_steps`.
    """
-    if self._current_step is not None:
-      raise ValueError("step_begin called twice without step_end.")
    if (step < 0) or (
        (self._max_steps is not None) and (step > self._max_steps)):
      raise ValueError("Invalid step %s." % step)
@@ -196,6 +209,9 @@ class BaseMonitor(object):
    In addition, the callback has the opportunity to stop training by returning
    `True`. This is useful for early stopping, for example.

+    Note that this method is not called if the call to `Session.run()` that
+    followed the last call to `step_begin()` failed.

    Args:
      step: `int`, the current value of the global step.
      output: `dict` mapping `string` values representing tensor names to
@@ -214,13 +230,26 @@ class BaseMonitor(object):
    self._current_step = None
    return False

+  def post_step(self, step, session):  # pylint: disable=unused-argument
+    """Callback after the step is finished.
+
+    Called after step_end and receives session to perform extra session.run
+    calls. If failure occurred in the process, will be called as well.
+
+    Args:
+      step: `int`, global step of the model.
+      session: `Session` object.
+    """
+    _ = step, session


class EveryN(BaseMonitor):
-  """Base class for monitors that execute callbacks every n steps.
+  """Base class for monitors that execute callbacks every N steps.

-  This class adds two new callbacks:
+  This class adds three new callbacks:
  - every_n_step_begin
  - every_n_step_end
+  - every_n_post_step

  The callbacks are executed every n steps, or optionally every step for the
  first m steps, where m and n can both be user-specified.
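
For context, a minimal sketch (editorial, not part of this commit) of a monitor using the new post_step() hook; the ExtraRunMonitor name and extra_op are made up for illustration. Because post_step() receives the session itself, extra session.run calls are possible after the training step finishes, which step_end() cannot do.

    # Sketch only: a hypothetical monitor running one extra fetch per step.
    from tensorflow.contrib.learn.python.learn import monitors

    class ExtraRunMonitor(monitors.BaseMonitor):

      def __init__(self, extra_op):
        super(ExtraRunMonitor, self).__init__()
        self._extra_op = extra_op
        self.last_value = None

      def post_step(self, step, session):
        super(ExtraRunMonitor, self).post_step(step, session)
        # The step has finished; the raw session is available for extra runs.
        self.last_value = session.run(self._extra_op)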
@@ -234,11 +263,16 @@ class EveryN(BaseMonitor):

  Failing to call the super implementation will cause unpredictable behavior.

+  The `every_n_post_step()` callback is also called after the last step if it
+  was not already called through the regular conditions. Note that
+  `every_n_step_begin()` and `every_n_step_end()` do not receive that special
+  treatment.
+
  """
  # TODO(ipolosukhin): Add also every n seconds.

  def __init__(self, every_n_steps=100, first_n_steps=1):
-    """Initialized an `EveryN` monitor.
+    """Initializes an `EveryN` monitor.

    Args:
      every_n_steps: `int`, the number of steps to allow between callbacks.
@@ -249,18 +283,10 @@ class EveryN(BaseMonitor):
    super(EveryN, self).__init__()
    self._every_n_steps = every_n_steps
    self._first_n_steps = first_n_steps
-    self._last_step = 0
-    self._active_step = None
-
-  def begin(self, max_steps=None):
-    """Overrides `BaseMonitor.begin`.
-
-    When overriding this method, you must call the super implementation.
-
-    Args:
-      max_steps: `int`, the maximum global step value.
-    """
-    super(EveryN, self).begin(max_steps)
+    # Last step in the model.
+    self._last_step = None
+    # Last step at which we called one of the every_n methods
+    self._last_active_step = None

  def every_n_step_begin(self, step):  # pylint: disable=unused-argument
    """Callback before every n'th step begins.
@@ -294,6 +320,15 @@ class EveryN(BaseMonitor):
    """
    return False

+  def every_n_post_step(self, step, session):
+    """Callback after a step is finished or `end()` is called.
+
+    Args:
+      step: `int`, the current value of the global step.
+      session: `Session` object.
+    """
+    pass
+
  def step_begin(self, step):
    """Overrides `BaseMonitor.step_begin`.

@@ -309,13 +344,13 @@ class EveryN(BaseMonitor):
      ValueError: if called more than once during a step.
    """
    super(EveryN, self).step_begin(step)
+    self._last_step = step
+    if self._last_active_step is None:
+      self._last_active_step = step - 1
    if (step <= self._first_n_steps or
-        step >= (self._every_n_steps + self._last_step) or
-        step == self._max_steps):
-      if self._active_step is not None:
-        raise ValueError(
-            "Starting step %s, %s still active." % (step, self._active_step))
-      self._active_step = step
+        step >= (self._every_n_steps + self._last_active_step) or
+        step == self._max_steps):  # Note: max_steps can be None here.
+      self._last_active_step = step
      return self.every_n_step_begin(step)
    return []

@@ -334,12 +369,59 @@ class EveryN(BaseMonitor):
      or `False` otherwise.
    """
    super(EveryN, self).step_end(step, output)
-    to_stop = False
-    if (self._active_step is not None) and (self._active_step == step):
-      self._last_step = step
-      to_stop = self.every_n_step_end(step, output)
-      self._active_step = None
-    return to_stop
+    if self._last_active_step == step:
+      return self.every_n_step_end(step, output)
+    return False

  def post_step(self, step, session):
    super(EveryN, self).post_step(step, session)
    if self._last_active_step == step:
      self.every_n_post_step(step, session)

  def end(self, session=None):
    super(EveryN, self).end(session=session)
    if self._last_step != self._last_active_step:
      self.every_n_post_step(self._last_step, session)


class StopAtStep(BaseMonitor):
  """Monitor to request stop at a specified step."""

  def __init__(self, num_steps=None, last_step=None):
    """Create a StopAtStep monitor.

    This monitor requests stop after either a number of steps have been
    executed or a last step has been reached. Only one of the two options can
    be specified.

    If `num_steps` is specified, it indicates the number of steps to execute
    after `begin()` is called. If instead `last_step` is specified, it
    indicates the last step we want to execute, as passed to the `step_begin()`
    call.

    Args:
      num_steps: Number of steps to execute.
      last_step: Step after which to stop.

    Raises:
      ValueError: If one of the arguments is invalid.
    """
    super(StopAtStep, self).__init__()
    if num_steps is None and last_step is None:
      raise ValueError("One of num_steps or last_step must be specified.")
    if num_steps is not None and last_step is not None:
      raise ValueError("Only one of num_steps or last_step can be specified.")
    self._num_steps = num_steps
    self._last_step = last_step

  def begin(self, max_steps=None, init_step=None):
    super(StopAtStep, self).begin(max_steps=max_steps, init_step=init_step)
    if self._num_steps is not None:
      self._last_step = init_step + self._num_steps

  def step_end(self, step, output):
    super(StopAtStep, self).step_end(step, output)
    return step >= self._last_step


# TODO(ptucker): Rename to LoggingTensor since it's not writing to stdout.
@@ -460,8 +542,8 @@ class SummarySaver(EveryN):
      self._summary_writer.add_summary(summary_strs, step)
    return False

-  def end(self):
-    super(SummarySaver, self).end()
+  def end(self, session=None):
+    super(SummarySaver, self).end(session=session)
    if self._summary_writer:
      self._summary_writer.flush()

@@ -553,7 +635,7 @@ class ValidationMonitor(EveryN):
    if self._estimator is None:
      raise ValueError("Missing call to set_estimator.")
    # Check that we are not running evaluation on the same checkpoint.
-    latest_path = saver.latest_checkpoint(self._estimator.model_dir)
+    latest_path = saver_lib.latest_checkpoint(self._estimator.model_dir)
    if latest_path == self._latest_path:
      logging.info("Skipping evaluation due to same checkpoint %s for step %d "
                   "as for step %d.", latest_path, step, self._latest_path_step)
@@ -677,8 +759,8 @@ class GraphDump(BaseMonitor):
    self._ignore_ops = ignore_ops or GraphDump.IGNORE_OPS
    self._data = {}

-  def begin(self, max_steps):
-    super(GraphDump, self).begin(max_steps)
+  def begin(self, max_steps=None, init_step=None):
+    super(GraphDump, self).begin(max_steps=max_steps, init_step=init_step)
    self._tensors = []
    graph = ops.get_default_graph()
    graph_def = graph.as_graph_def()
@@ -782,9 +864,121 @@ class ExportMonitor(EveryN):
      logging.info("Skipping exporting for the same step. "
                   "Consider exporting less frequently.")

-  def end(self):
-    super(ExportMonitor, self).end()
+  def end(self, session=None):
+    super(ExportMonitor, self).end(session=session)
    export.export_estimator(self._estimator,
                            self.export_dir,
                            exports_to_keep=self.exports_to_keep,
                            signature_fn=self.signature_fn)


class CheckpointSaver(EveryN):
  """Saves checkpoints every N steps."""

  def __init__(self, every_n_steps, saver, checkpoint_dir,
               checkpoint_basename="model.ckpt",
               first_n_steps=-1):
    """Initialize CheckpointSaver monitor.

    Args:
      every_n_steps: `int`, save every N steps.
      saver: `Saver` object, used for saving.
      checkpoint_dir: `str`, base directory for the checkpoint files.
      checkpoint_basename: `str`, base name for the checkpoint files.
      first_n_steps: `int`, if positive, save every step during the
          first `first_n_steps` steps.
    """
    logging.info("Create CheckpointSaver")
    super(CheckpointSaver, self).__init__(every_n_steps=every_n_steps,
                                          first_n_steps=first_n_steps)
    self._saver = saver
    self._summary_writer = SummaryWriterCache.get(checkpoint_dir)
    self._save_path = os.path.join(checkpoint_dir, checkpoint_basename)

  def every_n_post_step(self, step, session):
    logging.info("Saving checkpoints for %d into %s." % (step, self._save_path))
    self._saver.save(session, self._save_path, global_step=step)
    if self._summary_writer:
      self._summary_writer.add_session_log(
          SessionLog(status=SessionLog.CHECKPOINT,
                     checkpoint_path=self._save_path),
          step)


class StepCounter(EveryN):
  """Steps per second monitor."""

  def __init__(self, every_n_steps=100, output_dir=None,
               summary_writer=None):
    super(StepCounter, self).__init__(every_n_steps=every_n_steps)
    self._summary_tag = "global_step/sec"
    self._last_reported_step = None
    self._last_reported_time = None
    self._summary_writer = summary_writer
    if summary_writer is None and output_dir:
      self._summary_writer = SummaryWriterCache.get(output_dir)

  def begin(self, init_step):
    super(StepCounter, self).begin(init_step)
    self._last_reported_step = self._init_step
    self._last_reported_time = time.time()

  def set_estimator(self, estimator):
    super(StepCounter, self).set_estimator(estimator)
    if self._summary_writer is None:
      self._summary_writer = SummaryWriterCache.get(estimator.model_dir)

  def every_n_step_end(self, current_step, outputs):
    current_time = time.time()
    if self._last_reported_time is not None and self._summary_writer:
      added_steps = current_step - self._last_reported_step
      elapsed_time = current_time - self._last_reported_time
      steps_per_sec = added_steps / elapsed_time
      summary = Summary(value=[Summary.Value(tag=self._summary_tag,
                                             simple_value=steps_per_sec)])
      self._summary_writer.add_summary(summary, current_step)
    self._last_reported_step = current_step
    self._last_reported_time = current_time


class NanLossDuringTrainingError(RuntimeError):

  def __str__(self):
    return "NaN loss during training."


class NanLoss(EveryN):
  """NaN Loss monitor.

  Monitors loss and stops training if loss is NaN.
  Can either fail with exception or just stop training.
  """

  def __init__(self, loss_tensor, every_n_steps=100, fail_on_nan_loss=True):
    """Initializes NanLoss monitor.

    Args:
      loss_tensor: `Tensor`, the loss tensor.
      every_n_steps: `int`, run check every this many steps.
      fail_on_nan_loss: `bool`, whether to raise exception when loss is NaN.
    """
    super(NanLoss, self).__init__(every_n_steps=every_n_steps)
    self._loss_tensor = loss_tensor
    self._fail_on_nan_loss = fail_on_nan_loss

  def every_n_step_begin(self, step):
    super(NanLoss, self).every_n_step_begin(step)
    return self._loss_tensor

  def every_n_step_end(self, step, outputs):
    super(NanLoss, self).every_n_step_end(step, outputs)
    if np.isnan(outputs):
      failure_message = "Model diverged with loss = NaN."
      if self._fail_on_nan_loss:
        logging.error(failure_message)
        raise NanLossDuringTrainingError
      else:
        logging.warning(failure_message)
        # We don't raise an error but we return "should stop" so we stop, but
        # without an exception.
        return True
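
For context, a short sketch (editorial, not part of this commit) of wiring the new CheckpointSaver into a graph; the directory name is made up. It saves from every_n_post_step(), so the checkpoint reflects the state after the step, and EveryN also calls every_n_post_step() once after the final step.

    # Sketch only: '/tmp/train_logs' is illustrative.
    import tensorflow as tf
    from tensorflow.contrib.learn.python.learn import monitors as monitors_lib

    with tf.Graph().as_default():
      gstep = tf.contrib.framework.get_or_create_global_step()
      saver = tf.train.Saver()
      # Save every 100 steps; also fires once after the last step.
      ckpt_monitor = monitors_lib.CheckpointSaver(100, saver, '/tmp/train_logs')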
@@ -0,0 +1,65 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Wrapper for a Session-like object that handles threads and recovery.

Based on an original design of Illia Polosukhin.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import threading

from tensorflow.python.framework import ops
from tensorflow.python.training import summary_io


class SummaryWriterCache(object):
  """Cache for summary writers.

  This class caches summary writers, one per directory.
  """
  # Cache, keyed by directory.
  _cache = {}

  # Lock protecting _SUMMARY_WRITERS.
  _lock = threading.RLock()

  @staticmethod
  def clear():
    """Clear cached summary writers. Currently only used for unit tests."""
    with SummaryWriterCache._lock:
      SummaryWriterCache._cache = {}

  @staticmethod
  def get(logdir):
    """Returns the SummaryWriter for the specified directory.

    Args:
      logdir: str, name of the directory.

    Returns:
      A `SummaryWriter`.
    """
    with SummaryWriterCache._lock:
      if logdir not in SummaryWriterCache._cache:
        SummaryWriterCache._cache[logdir] = summary_io.SummaryWriter(
            logdir, graph=ops.get_default_graph())
      return SummaryWriterCache._cache[logdir]


# Backward compatible interface. Remove?
clear_summary_writers = SummaryWriterCache.clear
get_summary_writer = SummaryWriterCache.get
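
For context (editorial, not part of this commit): repeated get() calls on one directory return the same cached writer, so monitors logging to a shared logdir write to a single event file. A minimal sketch with a made-up directory:

    # Sketch only: '/tmp/logs' is illustrative.
    from tensorflow.contrib.learn.python.learn import summary_writer_cache

    sw1 = summary_writer_cache.SummaryWriterCache.get('/tmp/logs')
    sw2 = summary_writer_cache.SummaryWriterCache.get('/tmp/logs')
    assert sw1 is sw2  # One SummaryWriter per directory.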
tensorflow/contrib/learn/python/learn/supervised_session.py (new file, 301 lines)
@@ -0,0 +1,301 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Wrapper for a Session-like object that handles threads and recovery.

Based on an original design of Illia Polosukhin.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.framework.python.ops import variables as contrib_variables
from tensorflow.contrib.learn.python.learn import coordinated_session
from tensorflow.contrib.learn.python.learn import monitored_session
from tensorflow.contrib.learn.python.learn import recoverable_session
from tensorflow.contrib.learn.python.learn import summary_writer_cache
from tensorflow.core.util.event_pb2 import SessionLog
from tensorflow.python.framework import ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import logging_ops
from tensorflow.python.ops import variables
from tensorflow.python.training import coordinator
from tensorflow.python.training import queue_runner
from tensorflow.python.training import saver as training_saver
from tensorflow.python.training import session_manager as sm
from tensorflow.python.training import training_util


# TODO(touts): Share that with the Supervisor.
class Scaffold(object):
  """Structure to create or gather pieces commonly needed to train a model.

  When you build a model for training you usually need ops to initialize
  variables, a `Saver` to checkpoint them, an op to collect summaries for
  the visualizer, and so on.

  Various libraries built on top of the core TensorFlow library take care of
  creating some or all of these pieces and storing them in well known
  collections in the graph. The `Scaffold` class helps pick these pieces from
  the graph collections, creating and adding them to the collections if needed.

  If you call the scaffold constructor without any arguments it will pick
  pieces from the collections, creating default ones if needed. You can pass
  arguments to the constructor to provide your own pieces. Pieces that you
  pass to the constructor are not added to the graph collections.

  The following pieces are directly accessible as attributes of the `Scaffold`
  object:

  * `saver`: A `tf.Saver` object taking care of saving the variables. Picked
    from and stored into the `SAVERS` collection in the graph.
  * `init_op`: An op to run to initialize the variables. Picked from and
    stored into the `INIT_OP` collection in the graph.
  * `ready_op`: An op to verify that the variables are initialized. Picked
    from and stored into the `READY_OP` collection in the graph.
  * `local_init_op`: An op to initialize the local variables. Picked
    from and stored into the `LOCAL_INIT_OP` collection in the graph.
  * `summary_op`: An op to run and merge the summaries in the graph. Picked
    from and stored into the `SUMMARY_OP` collection in the graph.
  * `global_step`: A tensor containing the global step counter. Picked
    from and stored into the `GLOBAL_STEP` collection in the graph.

  You can also pass the following additional pieces to the constructor:

  * `init_feed_dict`: A session feed dictionary that should be used when
    running the init op.
  * `init_fn`: A callable to run after the init op to perform additional
    initializations. The callable will be called as
    `init_fn(scaffold, session)`.

  """
  # TODO(touts): consider adding the output dir and summary writer (cached)?
  # TODO(touts): consider finalizing the graph? (If the graph is
  # modified later, the cached parts could be wrong.)
  # TODO(touts): I do not think we should pass keep_checkpoint_max here.
  # TODO(touts): Add individual static functions for init_op(), etc. that
  # implement the caching logic.

  def __init__(self,
               global_step_tensor=None,
               init_op=None,
               init_feed_dict=None,
               init_fn=None,
               ready_op=None,
               local_init_op=None,
               summary_op=None,
               saver=None,
               keep_checkpoint_max=5):
    """Create a scaffold.

    Args:
      global_step_tensor: Optional tensor to use as the global step counter.
      init_op: Optional op for initializing variables.
      init_feed_dict: Optional session feed dictionary to use when running the
        init_op.
      init_fn: Optional function to use to initialize the model after running
        the init_op. Will be called as `init_fn(scaffold, session)`.
      ready_op: Optional op to verify that the variables are initialized. Must
        return an empty scalar string tensor when the variables are
        initialized, or a non-empty one listing the names of the
        non-initialized variables.
      local_init_op: Optional op to initialize local variables.
      summary_op: Optional op to gather all summaries. Must return a scalar
        string tensor containing a serialized `Summary` proto.
      saver: Optional `tf.Saver` object to use to save and restore variables.
      keep_checkpoint_max: Optional parameter to use to construct a saver if
        none is already there in the graph.
    """
    if global_step_tensor is None:
      global_step_tensor = contrib_variables.get_or_create_global_step()
    self.global_step_tensor = global_step_tensor
    if init_op is None:
      init_op = Scaffold._get_or_default(
          ops.GraphKeys.INIT_OP, variables.initialize_all_variables)
    self.init_op = init_op
    self.init_feed_dict = init_feed_dict
    # NOTE(touts): modifying the init function to be passed the scaffold is a
    # hack to make it easy to find the saver. Is there a better way?
    if init_fn:
      self.init_fn = lambda sess: init_fn(self, sess)
    else:
      self.init_fn = None
    if ready_op is None:
      ready_op = Scaffold._get_or_default(
          ops.GraphKeys.READY_OP, variables.report_uninitialized_variables)
    self.ready_op = ready_op
    if local_init_op is None:
      local_init_op = Scaffold._get_or_default(
          ops.GraphKeys.LOCAL_INIT_OP, Scaffold._default_local_init_op)
    self.local_init_op = local_init_op
    if summary_op is None:
      summary_op = Scaffold._get_or_default(
          ops.GraphKeys.SUMMARY_OP, logging_ops.merge_all_summaries)
    # pylint: disable=g-long-lambda
    if saver is None:
      saver = Scaffold._get_or_default(
          ops.GraphKeys.SAVERS,
          lambda: training_saver.Saver(sharded=True,
                                       max_to_keep=keep_checkpoint_max))
    # pylint: enable=g-long-lambda
    self.saver = saver

  @staticmethod
  def _get_or_default(collection_key, default_constructor):
    elements = ops.get_collection(collection_key)
    if elements:
      return elements[0]
    op = default_constructor()
    if op is not None:
      ops.add_to_collection(collection_key, op)
    return op

  @staticmethod
  def _default_local_init_op():
    return control_flow_ops.group(variables.initialize_local_variables(),
                                  data_flow_ops.initialize_all_tables())


class SupervisedSession(object):
  """Session-like object that supports recovery and monitors."""

  def __init__(self, master, is_chief=True, checkpoint_dir=None,
               monitors=None, scaffold=None, config=None,
               clean_stop_exception_types=None):
    self._graph = ops.get_default_graph()
    self._master = master
    self._checkpoint_dir = checkpoint_dir
    self._is_chief = is_chief
    self._config = config
    self._clean_stop_exception_types = clean_stop_exception_types
    self._monitors = monitors or []
    self._scaffold = scaffold or Scaffold()
    # Finalize and write the graph.
    self._graph.finalize()
    # Create the session.
    self._session_manager = sm.SessionManager(
        local_init_op=self._scaffold.local_init_op,
        ready_op=self._scaffold.ready_op,
        graph=ops.get_default_graph())
    self._sess = recoverable_session.RecoverableSession(self._create_session)
    # Call the begin() method of monitors.
    self._init_step = self._tf_sess.run(self._scaffold.global_step_tensor)
    for monitor in self._monitors:
      monitor.begin(max_steps=None, init_step=self._init_step)
    # Write the graph out, note: this uses self._init_step.
    self.write_graph()

  def _create_session(self):
    """Factory for the RecoverableSession.

    Returns:
      A session, initialized or recovered as needed.
    """
    if self._is_chief:
      tf_sess = self._session_manager.prepare_session(
          self._master, saver=self._scaffold.saver,
          checkpoint_dir=self._checkpoint_dir, config=self._config,
          init_op=self._scaffold.init_op,
          init_feed_dict=self._scaffold.init_feed_dict,
          init_fn=self._scaffold.init_fn)
    else:
      tf_sess = self._session_manager.wait_for_session(
          self._master, config=self._config)
    # Keep the tf_sess for quick runs of global step when needed.
    self._tf_sess = tf_sess
    self._coord = coordinator.Coordinator(
        clean_stop_exception_types=self._clean_stop_exception_types)
    self._coordinated_threads_to_join = queue_runner.start_queue_runners(
        sess=tf_sess, coord=self._coord)
    return coordinated_session.CoordinatedSession(
        monitored_session.MonitoredSession(tf_sess, self._monitors,
                                           self._scaffold.global_step_tensor),
        self._coord, self._coordinated_threads_to_join)

  @property
  def scaffold(self):
    return self._scaffold

  @property
  def session(self):
    return self._tf_sess

  def run(self, fetches, feed_dict=None, options=None, run_metadata=None):
    """Run ops in the supervised session.

    This method is completely compatible with the `tf.Session.run()` method.

    Args:
      fetches: Same as `tf.Session.run()`.
      feed_dict: Same as `tf.Session.run()`.
      options: Same as `tf.Session.run()`.
      run_metadata: Same as `tf.Session.run()`.

    Returns:
      Same as `tf.Session.run()`.
    """
    return self._sess.run(fetches, feed_dict=feed_dict, options=options,
                          run_metadata=run_metadata)

  def should_stop(self):
    if self._sess:
      return self._sess.should_stop()
    return True

  def close(self):
    # Run the Monitor.end() methods.
    for monitor in self._monitors:
      monitor.end(self._tf_sess)
    self._sess.close()
    self._sess = None
    self._tf_sess = None

  def _is_closed(self):
    """Return True if the supervised session is closed. For tests only.

    Returns:
      A boolean.
    """
    return self._tf_sess is None

  def __enter__(self):
    return self

  def __exit__(self, exception_type, exception_value, traceback):
    if exception_type:
      self._coord.request_stop((exception_type, exception_value, traceback))
    else:
      self._coord.request_stop()
    try:
      self._coord.join(self._coordinated_threads_to_join)
      # If coord does not raise an exception, we return True to indicate
      # "no exception to raise".
      return True
    finally:
      self.close()

  def write_graph(self):
    """Saves current graph."""
    if self._checkpoint_dir is not None and self._is_chief:
      summary_writer = summary_writer_cache.SummaryWriterCache.get(
          self._checkpoint_dir)
      training_util.write_graph(self._graph.as_graph_def(add_shapes=True),
                                self._checkpoint_dir, 'graph.pbtxt')
      summary_writer.add_graph(self._graph)
      summary_writer.add_session_log(SessionLog(status=SessionLog.START),
                                     self._init_step)
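
For context (editorial, not part of this commit): a Scaffold built with no arguments pulls each piece from the graph collections and creates defaults as needed, which is why two Scaffolds on the same graph share ops. The sketch below mirrors test_defaults_empty_graph in the test file added by this commit.

    # Sketch only: default Scaffold pieces, cached in graph collections.
    import tensorflow as tf
    from tensorflow.contrib.learn.python.learn import supervised_session

    with tf.Graph().as_default():
      scaffold = supervised_session.Scaffold()
      with tf.Session() as sess:
        sess.run([scaffold.init_op, scaffold.local_init_op])
        # ready_op lists uninitialized variables; empty means ready.
        assert len(sess.run(scaffold.ready_op)) == 0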
@@ -24,6 +24,8 @@ import time

import tensorflow as tf

+from tensorflow.contrib.learn.python.learn import coordinated_session
+

def BusyWaitForCoordStop(coord):
  while not coord.should_stop():
@@ -37,7 +39,7 @@ class CoordinatedSessionTest(tf.test.TestCase):
    with self.test_session() as sess:
      tf.constant(0.0)
      coord = tf.train.Coordinator()
-      coord_sess = tf.contrib.learn.CoordinatedSession(sess, coord, [])
+      coord_sess = coordinated_session.CoordinatedSession(sess, coord, [])
      self.assertEquals(sess.graph, coord_sess.graph)
      self.assertEquals(sess.sess_str, coord_sess.sess_str)

@@ -46,13 +48,13 @@ class CoordinatedSessionTest(tf.test.TestCase):
      c = tf.constant(0)
      v = tf.identity(c)
      coord = tf.train.Coordinator()
-      coord_sess = tf.contrib.learn.CoordinatedSession(sess, coord, [])
+      coord_sess = coordinated_session.CoordinatedSession(sess, coord, [])
      self.assertEqual(42, coord_sess.run(v, feed_dict={c: 42}))

  def test_should_stop_on_close(self):
    with self.test_session() as sess:
      coord = tf.train.Coordinator()
-      coord_sess = tf.contrib.learn.CoordinatedSession(sess, coord, [])
+      coord_sess = coordinated_session.CoordinatedSession(sess, coord, [])
      self.assertFalse(coord_sess.should_stop())
      coord_sess.close()
      self.assertTrue(coord_sess.should_stop())
@@ -60,7 +62,7 @@ class CoordinatedSessionTest(tf.test.TestCase):
  def test_should_stop_on_coord_stop(self):
    with self.test_session() as sess:
      coord = tf.train.Coordinator()
-      coord_sess = tf.contrib.learn.CoordinatedSession(sess, coord, [])
+      coord_sess = coordinated_session.CoordinatedSession(sess, coord, [])
      self.assertFalse(coord_sess.should_stop())
      coord.request_stop()
      self.assertTrue(coord_sess.should_stop())
@@ -70,7 +72,7 @@ class CoordinatedSessionTest(tf.test.TestCase):
      c = tf.constant(0)
      v = tf.identity(c)
      coord = tf.train.Coordinator()
-      coord_sess = tf.contrib.learn.CoordinatedSession(sess, coord, [])
+      coord_sess = coordinated_session.CoordinatedSession(sess, coord, [])
      self.assertFalse(coord_sess.should_stop())
      self.assertEqual(0, coord_sess.run(c))
      self.assertEqual(1, coord_sess.run(v, feed_dict={c: 1}))
@@ -89,7 +91,7 @@ class CoordinatedSessionTest(tf.test.TestCase):
                 for _ in range(3)]
      for t in threads:
        t.start()
-      coord_sess = tf.contrib.learn.CoordinatedSession(sess, coord, threads)
+      coord_sess = coordinated_session.CoordinatedSession(sess, coord, threads)
      self.assertFalse(coord_sess.should_stop())
      for t in threads:
        self.assertTrue(t.is_alive())
@@ -47,18 +47,26 @@ class FakeMonitor(monitors.BaseMonitor):
    self.should_stop = False
    self.requested_tensors = []
    self.call_counter = Counter()
+    self.last_begin_step = None
+    self.last_end_step = None
+    self.last_post_step = None

  def step_begin(self, step):
    self.call_counter['step_begin'] += 1
-    self.begin_step = step
+    self.last_begin_step = step
    return self.requested_tensors

  def step_end(self, step, output):
    self.call_counter['step_end'] += 1
-    self.end_step = step
+    self.last_end_step = step
    self.output = output
    return self.should_stop

+  def post_step(self, step, session):
+    self.call_counter['post_step'] += 1
+    self.last_post_step = step
+    self.session = session


class MonitoredSessionTest(tf.test.TestCase):

@@ -83,7 +91,7 @@ class MonitoredSessionTest(tf.test.TestCase):
            'run_metadata': 'a_metadata'
        })

-  def testCallsMonitorsBeginAndEnd(self):
+  def testCallsMonitorsBeginEndAndPost(self):
    with tf.Graph().as_default(), tf.Session() as sess:
      global_step_tensor = tf.contrib.framework.create_global_step()
      mock_mon = FakeMonitor()
@@ -99,10 +107,12 @@ class MonitoredSessionTest(tf.test.TestCase):

      for mon in [mock_mon, mock_mon2]:
        self.assertEqual(mon.output, {})
-        self.assertEqual(mon.begin_step, 11)
-        self.assertEqual(mon.end_step, 11)
+        self.assertEqual(mon.last_begin_step, 11)
+        self.assertEqual(mon.last_end_step, 11)
+        self.assertEqual(mon.last_post_step, 11)
        self.assertEqual(mon.call_counter['step_end'], 1)
        self.assertEqual(mon.call_counter['step_begin'], 1)
+        self.assertEqual(mon.call_counter['post_step'], 1)

  def testCallsMonitorsWithLastStep(self):
    with tf.Graph().as_default(), tf.Session() as sess:
@@ -119,18 +129,21 @@ class MonitoredSessionTest(tf.test.TestCase):

      mon_sess.run(fetches=[inc_5])
      for mon in [mock_mon, mock_mon2]:
-        self.assertEqual(mon.begin_step, 1)
-        self.assertEqual(mon.end_step, 1)
+        self.assertEqual(mon.last_begin_step, 1)
+        self.assertEqual(mon.last_end_step, 1)
+        self.assertEqual(mon.last_post_step, 1)

      mon_sess.run(fetches=[inc_5])
      for mon in [mock_mon, mock_mon2]:
-        self.assertEqual(mon.begin_step, 6)
-        self.assertEqual(mon.end_step, 6)
+        self.assertEqual(mon.last_begin_step, 6)
+        self.assertEqual(mon.last_end_step, 6)
+        self.assertEqual(mon.last_post_step, 6)

      mon_sess.run(fetches=[inc_5])
      for mon in [mock_mon, mock_mon2]:
-        self.assertEqual(mon.begin_step, 11)
-        self.assertEqual(mon.end_step, 11)
+        self.assertEqual(mon.last_begin_step, 11)
+        self.assertEqual(mon.last_end_step, 11)
+        self.assertEqual(mon.last_post_step, 11)

  def testShouldStop(self):
    with tf.Graph().as_default(), tf.Session() as sess:
@@ -34,6 +34,7 @@ class _MyEveryN(learn.monitors.EveryN):
        every_n_steps=every_n_steps, first_n_steps=first_n_steps)
    self._steps_begun = []
    self._steps_ended = []
+    self._post_steps = []

  @property
  def steps_begun(self):
@@ -43,6 +44,10 @@ class _MyEveryN(learn.monitors.EveryN):
  def steps_ended(self):
    return self._steps_ended

+  @property
+  def post_steps(self):
+    return self._post_steps
+
  def every_n_step_begin(self, step):
    super(_MyEveryN, self).every_n_step_begin(step)
    self._steps_begun.append(step)
@@ -53,6 +58,11 @@ class _MyEveryN(learn.monitors.EveryN):
    self._steps_ended.append(step)
    return False

+  def every_n_post_step(self, step, session):
+    super(_MyEveryN, self).every_n_post_step(step, session)
+    self._post_steps.append(step)
+    return False
+

class MonitorsTest(tf.test.TestCase):
  """Monitors tests."""
@@ -71,8 +81,13 @@ class MonitorsTest(tf.test.TestCase):
  def tearDown(self):
    logging.info = self._actual_log

-  def _run_monitor(self, monitor, num_epochs=3, num_steps_per_epoch=10):
-    monitor.begin(max_steps=(num_epochs * num_steps_per_epoch) - 1)
+  def _run_monitor(self, monitor, num_epochs=3, num_steps_per_epoch=10,
+                   pass_max_steps=True):
+    if pass_max_steps:
+      max_steps = num_epochs * num_steps_per_epoch - 1
+    else:
+      max_steps = None
+    monitor.begin(max_steps=max_steps, init_step=0)
    for epoch in xrange(num_epochs):
      monitor.epoch_begin(epoch)
      should_stop = False
@@ -85,6 +100,7 @@ class MonitorsTest(tf.test.TestCase):
            [t.name if isinstance(t, tf.Tensor) else t for t in tensors],
            output))
        should_stop = monitor.step_end(step=step, output=output)
+        monitor.post_step(step=step, session=None)
        step += 1
      monitor.epoch_end(epoch)
    monitor.end()
@@ -100,6 +116,18 @@ class MonitorsTest(tf.test.TestCase):
    expected_steps = [0, 1, 2, 10, 18, 26, 29]
    self.assertEqual(expected_steps, monitor.steps_begun)
    self.assertEqual(expected_steps, monitor.steps_ended)
+    self.assertEqual(expected_steps, monitor.post_steps)
+
+  def test_every_n_no_max_steps(self):
+    monitor = _MyEveryN(every_n_steps=8, first_n_steps=2)
+    with tf.Graph().as_default() as g, self.test_session(g):
+      self._run_monitor(monitor, num_epochs=3, num_steps_per_epoch=10,
+                        pass_max_steps=False)
+      begin_end_steps = [0, 1, 2, 10, 18, 26]
+      post_steps = [0, 1, 2, 10, 18, 26, 29]
+      self.assertEqual(begin_end_steps, monitor.steps_begun)
+      self.assertEqual(begin_end_steps, monitor.steps_ended)
+      self.assertEqual(post_steps, monitor.post_steps)

  def test_print(self):
    with tf.Graph().as_default() as g, self.test_session(g):
@@ -21,6 +21,8 @@ from __future__ import print_function

import tensorflow as tf

+from tensorflow.contrib.learn.python.learn import recoverable_session
+

class AbortAtNSession(object):
  """A mock session that aborts at the N-th run call."""
@@ -42,7 +44,7 @@ class RecoverableSessionTest(tf.test.TestCase):
  def test_properties(self):
    with self.test_session() as sess:
      tf.constant(0.0)
-      recoverable_sess = tf.contrib.learn.RecoverableSession(lambda: sess)
+      recoverable_sess = recoverable_session.RecoverableSession(lambda: sess)
      self.assertEquals(sess.graph, recoverable_sess.graph)
      self.assertEquals(sess.sess_str, recoverable_sess.sess_str)

@@ -50,7 +52,7 @@ class RecoverableSessionTest(tf.test.TestCase):
    with self.test_session() as sess:
      c = tf.constant(0)
      v = tf.identity(c)
-      recoverable_sess = tf.contrib.learn.RecoverableSession(lambda: sess)
+      recoverable_sess = recoverable_session.RecoverableSession(lambda: sess)
      self.assertEqual(51, recoverable_sess.run(v, feed_dict={c: 51}))

  def test_recovery(self):
@@ -65,7 +67,7 @@ class RecoverableSessionTest(tf.test.TestCase):
      self.assertEqual(3, len(sessions_to_use))
      # Make the recoverable session use these 3 sessions in sequence by
      # passing a factory that pops from the session_to_use list.
-      recoverable_sess = tf.contrib.learn.RecoverableSession(
+      recoverable_sess = recoverable_session.RecoverableSession(
          lambda: sessions_to_use.pop(0))
      self.assertEqual(2, len(sessions_to_use))  # One session popped.
      # Using first session.
@@ -0,0 +1,68 @@
# pylint: disable=g-bad-file-header
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for SummaryWriterCache."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import glob
import os

import tensorflow as tf

from tensorflow.contrib.learn.python.learn import summary_writer_cache


class SummaryWriterCacheTest(tf.test.TestCase):
  """SummaryWriterCache tests."""

  def _test_dir(self, test_name):
    test_dir = os.path.join(self.get_temp_dir(), test_name)
    if os.path.isdir(test_dir):
      os.removedirs(test_dir)
    os.makedirs(test_dir)
    return test_dir

  def test_cache(self):
    with tf.Graph().as_default():
      dir1 = self._test_dir('test_cache_1')
      dir2 = self._test_dir('test_cache_2')
      sw1 = summary_writer_cache.SummaryWriterCache.get(dir1)
      sw2 = summary_writer_cache.SummaryWriterCache.get(dir2)
      sw3 = summary_writer_cache.SummaryWriterCache.get(dir1)
      self.assertEqual(sw1, sw3)
      self.assertFalse(sw1 == sw2)
      sw1.close()
      sw2.close()
      events1 = glob.glob(os.path.join(dir1, 'event*'))
      self.assertTrue(events1)
      events2 = glob.glob(os.path.join(dir2, 'event*'))
      self.assertTrue(events2)
      events3 = glob.glob(os.path.join('nowriter', 'event*'))
      self.assertFalse(events3)

  def test_clear(self):
    with tf.Graph().as_default():
      dir1 = self._test_dir('test_clear')
      sw1 = summary_writer_cache.SummaryWriterCache.get(dir1)
      summary_writer_cache.SummaryWriterCache.clear()
      sw2 = summary_writer_cache.SummaryWriterCache.get(dir1)
      self.assertFalse(sw1 == sw2)


if __name__ == '__main__':
  tf.test.main()
@ -0,0 +1,404 @@
|
||||
# pylint: disable=g-bad-file-header
|
||||
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for SupervisedSession."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
from tensorflow.contrib.learn.python.learn import supervised_session
|
||||
|
||||
|
||||
class ScaffoldTest(tf.test.TestCase):
|
||||
"""Scaffold tests."""
|
||||
|
||||
def test_defaults_empty_graph(self):
|
||||
with tf.Graph().as_default():
|
||||
scaffold = supervised_session.Scaffold()
|
||||
self.assertTrue(isinstance(scaffold.global_step_tensor, tf.Variable))
|
||||
self.assertTrue(isinstance(scaffold.init_op, tf.Operation))
|
||||
self.assertEqual(None, scaffold.init_feed_dict)
|
||||
self.assertEqual(None, scaffold.init_fn)
|
||||
self.assertTrue(isinstance(scaffold.ready_op, tf.Tensor))
|
||||
self.assertTrue(isinstance(scaffold.local_init_op, tf.Operation))
|
||||
self.assertTrue(isinstance(scaffold.saver, tf.train.Saver))
|
||||
with self.test_session() as sess:
|
||||
self.assertTrue(b'global_step' in sess.run(scaffold.ready_op))
|
||||
sess.run([scaffold.init_op, scaffold.local_init_op])
|
||||
self.assertEquals(0, len(sess.run(scaffold.ready_op)))
|
||||
self.assertEquals(0, sess.run(scaffold.global_step_tensor))
|
||||
|
||||
def test_caches_values(self):
|
||||
with tf.Graph().as_default():
|
||||
scaffold1 = supervised_session.Scaffold()
|
||||
scaffold2 = supervised_session.Scaffold()
|
||||
self.assertEqual(scaffold1.global_step_tensor,
|
||||
scaffold2.global_step_tensor)
|
||||
self.assertEqual(scaffold1.init_op, scaffold2.init_op)
|
||||
self.assertEqual(scaffold1.ready_op, scaffold2.ready_op)
|
||||
self.assertEqual(scaffold1.local_init_op, scaffold2.local_init_op)
|
||||
self.assertEqual(scaffold1.saver, scaffold2.saver)
|
||||
|
||||
def test_uses_passed_values(self):
|
||||
with tf.Graph().as_default():
|
||||
scaffold = supervised_session.Scaffold(global_step_tensor=1,
|
||||
init_op=2,
|
||||
init_feed_dict=3,
|
||||
init_fn=lambda scaffold, sess: 4,
|
||||
ready_op=5,
|
||||
local_init_op=6,
|
||||
saver=7)
|
||||
self.assertEqual(1, scaffold.global_step_tensor)
|
||||
self.assertEqual(2, scaffold.init_op)
|
||||
self.assertEqual(3, scaffold.init_feed_dict)
|
||||
self.assertTrue(callable(scaffold.init_fn))
|
||||
self.assertEqual(5, scaffold.ready_op)
|
||||
self.assertEqual(6, scaffold.local_init_op)
|
||||
self.assertEqual(7, scaffold.saver)
|
||||
|
||||
|
||||
class RaiseOnceAtStepN(tf.contrib.learn.monitors.BaseMonitor):
|
||||
"""Monitor that raises an Exception at step N."""
|
||||
|
||||
def __init__(self, n, ex):
|
||||
super(RaiseOnceAtStepN, self).__init__()
|
||||
self.n = n
|
||||
self.ex = ex
|
||||
self.raised = False
|
||||
|
||||
def step_begin(self, step):
|
||||
super(RaiseOnceAtStepN, self).step_begin(step)
|
||||
# Raise the first time we reach step N.
|
||||
if step == self.n and not self.raised:
|
||||
self.raised = True
|
||||
raise self.ex
|
||||
return []
|
||||
|
||||
|
||||
class SupervisedSessionTest(tf.test.TestCase):
|
||||
"""SupervisedSession tests."""
|
||||
|
||||
def _test_dir(self, test_name):
|
||||
"""Create an empty dir to use for tests.
|
||||
|
||||
Args:
|
||||
test_name: Name of the test.
|
||||
|
||||
Returns:
|
||||
Absolute path to the test directory.
|
||||
"""
|
||||
test_dir = os.path.join(self.get_temp_dir(), test_name)
|
||||
if os.path.isdir(test_dir):
|
||||
os.removedirs(test_dir)
|
||||
os.makedirs(test_dir)
|
||||
return test_dir
|
||||
|
||||
def test_defaults(self):
|
||||
with tf.Graph().as_default():
|
||||
with supervised_session.SupervisedSession('') as session:
|
||||
self.assertEqual(0, session.run(session.scaffold.global_step_tensor))
|
||||
|
||||
def test_last_step(self):
|
||||
logdir = self._test_dir('test_last_step')
|
||||
with tf.Graph().as_default():
|
||||
scaffold = supervised_session.Scaffold()
|
||||
gstep = scaffold.global_step_tensor
|
||||
do_step = tf.assign_add(gstep, 1)
|
||||
# Run till step 3 and save.
|
||||
monitors = [tf.contrib.learn.monitors.StopAtStep(last_step=3)]
|
||||
with supervised_session.SupervisedSession('', scaffold=scaffold,
|
||||
monitors=monitors) as session:
|
||||
self.assertEqual(0, session.run(gstep))
|
||||
self.assertFalse(session.should_stop())
|
||||
self.assertEqual(1, session.run(do_step))
|
||||
self.assertFalse(session.should_stop())
|
||||
self.assertEqual(2, session.run(do_step))
|
||||
self.assertFalse(session.should_stop())
|
||||
self.assertEqual(3, session.run(do_step))
|
||||
self.assertTrue(session.should_stop())
|
||||
save_path = scaffold.saver.save(session.session,
|
||||
os.path.join(logdir, 'step-3'))
|
||||
# Run till step 5 and save.
|
||||
def load_ckpt(scaffold, sess):
|
||||
scaffold.saver.restore(sess, save_path)
|
||||
scaffold = supervised_session.Scaffold(init_fn=load_ckpt)
|
||||
monitors = [tf.contrib.learn.monitors.StopAtStep(last_step=5)]
|
||||
with supervised_session.SupervisedSession('', scaffold=scaffold,
|
||||
monitors=monitors) as session:
|
||||
self.assertEqual(3, session.run(gstep))
|
||||
self.assertFalse(session.should_stop())
|
||||
self.assertEqual(4, session.run(do_step))
|
||||
self.assertFalse(session.should_stop())
|
||||
self.assertEqual(5, session.run(do_step))
|
||||
self.assertTrue(session.should_stop())
|
||||
|
||||
def test_num_steps(self):
|
||||
logdir = self._test_dir('test_num_steps')
|
||||
with tf.Graph().as_default():
|
||||
scaffold = supervised_session.Scaffold()
|
||||
gstep = scaffold.global_step_tensor
|
||||
do_step = tf.assign_add(gstep, 1)
|
||||
# Do 3 steps and save.
|
||||
monitors = [tf.contrib.learn.monitors.StopAtStep(num_steps=3)]
|
||||
with supervised_session.SupervisedSession('', scaffold=scaffold,
|
||||
monitors=monitors) as session:
|
||||
session.run(do_step)
|
||||
self.assertFalse(session.should_stop())
|
||||
session.run(do_step)
|
||||
self.assertFalse(session.should_stop())
|
||||
session.run(do_step)
|
||||
self.assertTrue(session.should_stop())
|
||||
save_path = scaffold.saver.save(session.session,
|
||||
os.path.join(logdir, 'step-3'))
|
||||
# Restore and do 4 steps.
|
||||
def load_ckpt(scaffold, sess):
|
||||
scaffold.saver.restore(sess, save_path)
|
||||
scaffold = supervised_session.Scaffold(init_fn=load_ckpt)
|
||||
monitors = [tf.contrib.learn.monitors.StopAtStep(num_steps=4)]
|
||||
with supervised_session.SupervisedSession('', scaffold=scaffold,
|
||||
monitors=monitors) as session:
|
||||
self.assertEqual(3, session.run(gstep))
|
||||
session.run(do_step)
|
||||
self.assertFalse(session.should_stop())
|
||||
session.run(do_step)
|
||||
self.assertFalse(session.should_stop())
|
||||
session.run(do_step)
|
||||
self.assertFalse(session.should_stop())
|
||||
session.run(do_step)
|
||||
self.assertTrue(session.should_stop())
|
||||
|
||||
# This set of tests, verifies the supervised session behavior when exceptions
|
||||
# are raised next to the innermost session run() call.
|
||||
|
||||

  def test_recovery(self):
    logdir = self._test_dir('test_recovery')
    with tf.Graph().as_default():
      scaffold = supervised_session.Scaffold()
      gstep = scaffold.global_step_tensor
      do_step = tf.assign_add(gstep, 1)
      # Use a monitor to save the model every 100 steps.  It also saves it
      # at the end.
      monitors = [tf.contrib.learn.monitors.CheckpointSaver(
          100, scaffold.saver, logdir)]
      with supervised_session.SupervisedSession('', scaffold=scaffold,
                                                checkpoint_dir=logdir,
                                                monitors=monitors) as session:
        self.assertEqual(0, session.run(gstep))
        self.assertEqual(1, session.run(do_step))
        self.assertEqual(2, session.run(do_step))
      # A restart will find the checkpoint and recover automatically.
      with supervised_session.SupervisedSession(
          '', scaffold=scaffold, checkpoint_dir=logdir) as session:
        self.assertEqual(2, session.run(gstep))
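
  # As the second session above demonstrates, checkpoint_dir is what makes
  # recovery automatic: on startup the session finds the checkpoint written
  # by the CheckpointSaver monitor and restores from it.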

  def test_retry_on_aborted_error(self):
    # Tests that we silently retry on abort.  Note that this does not test
    # recovery as we do not use a CheckpointSaver in this test.
    with tf.Graph().as_default():
      scaffold = supervised_session.Scaffold()
      gstep = scaffold.global_step_tensor
      do_step = tf.assign_add(gstep, 1)
      monitor = RaiseOnceAtStepN(3, tf.errors.AbortedError(None, None, 'Abort'))
      with supervised_session.SupervisedSession('', scaffold=scaffold,
                                                monitors=[monitor]) as session:
        self.assertEqual(0, session.run(gstep))
        self.assertEqual(1, session.run(do_step))
        self.assertEqual(2, session.run(do_step))
        self.assertFalse(session.should_stop())
        # Here at step 3, the monitor triggers and raises AbortedError.  The
        # SupervisedSession automatically retries and restarts from a freshly
        # initialized session, so the step is back to 0 and running do_step
        # moves it to 1.
        self.assertEqual(1, session.run(do_step))
        self.assertFalse(session.should_stop())
        self.assertTrue(monitor.raised)
        self.assertEqual(2, session.run(do_step))
        self.assertFalse(session.should_stop())
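
  # For reference, a plausible shape for the RaiseOnceAtStepN helper used in
  # these tests (its real definition appears earlier in this file and may
  # differ in detail):
  #
  #   class RaiseOnceAtStepN(tf.contrib.learn.monitors.BaseMonitor):
  #     """Raises the given exception once, the first time step N begins."""
  #
  #     def __init__(self, n, ex):
  #       super(RaiseOnceAtStepN, self).__init__()
  #       self.n = n
  #       self.ex = ex
  #       self.raised = False
  #
  #     def step_begin(self, step):
  #       super(RaiseOnceAtStepN, self).step_begin(step)
  #       # Raise only the first time step N is reached.
  #       if step == self.n and not self.raised:
  #         self.raised = True
  #         raise self.ex
  #       return []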

  def test_recover_and_retry_on_aborted_error(self):
    # Tests that we silently retry and recover on abort.  This test uses
    # a CheckpointSaver to have something to recover from.
    logdir = self._test_dir('test_recover_and_retry_on_aborted_error')
    with tf.Graph().as_default():
      scaffold = supervised_session.Scaffold()
      gstep = scaffold.global_step_tensor
      do_step = tf.assign_add(gstep, 1)
      abort_monitor = RaiseOnceAtStepN(
          3, tf.errors.AbortedError(None, None, 'Abort'))
      # Save after each step.
      ckpt_monitor = tf.contrib.learn.monitors.CheckpointSaver(
          1, scaffold.saver, logdir)
      monitors = [abort_monitor, ckpt_monitor]
      with supervised_session.SupervisedSession('', scaffold=scaffold,
                                                checkpoint_dir=logdir,
                                                monitors=monitors) as session:
        self.assertEqual(0, session.run(gstep))
        self.assertEqual(1, session.run(do_step))
        self.assertEqual(2, session.run(do_step))
        self.assertFalse(session.should_stop())
        # Here at step 3, the monitor triggers and raises AbortedError.  The
        # SupervisedSession automatically restores and retries.
        self.assertEqual(3, session.run(do_step))
        self.assertTrue(abort_monitor.raised)
        self.assertFalse(session.should_stop())
        self.assertEqual(4, session.run(do_step))
        self.assertFalse(session.should_stop())

  def test_stop_cleanly_on_out_of_range_exception(self):
    # Tests that we stop cleanly when OutOfRange is raised.
    with tf.Graph().as_default():
      scaffold = supervised_session.Scaffold()
      gstep = scaffold.global_step_tensor
      do_step = tf.assign_add(gstep, 1)
      monitor = RaiseOnceAtStepN(
          3, tf.errors.OutOfRangeError(None, None, 'EOI'))
      with supervised_session.SupervisedSession('', scaffold=scaffold,
                                                monitors=[monitor]) as session:
        self.assertEqual(0, session.run(gstep))
        self.assertEqual(1, session.run(do_step))
        self.assertEqual(2, session.run(do_step))
        self.assertFalse(session.should_stop())
        # Here at step 3, the monitor triggers and raises OutOfRange.  The
        # session should go into should_stop() mode.  We do not get a result
        # in that case.
        self.assertEqual(None, session.run(do_step))
        self.assertTrue(monitor.raised)
        self.assertTrue(session.should_stop())

  def test_stop_cleanly_on_custom_exception(self):
    # Tests that we stop cleanly when an exception type of
    # our choice is raised (StopIteration here).
    with tf.Graph().as_default():
      scaffold = supervised_session.Scaffold()
      gstep = scaffold.global_step_tensor
      do_step = tf.assign_add(gstep, 1)
      monitor = RaiseOnceAtStepN(3, StopIteration('I choose you'))
      exception_types = [tf.errors.OutOfRangeError, StopIteration]
      with supervised_session.SupervisedSession(
          '', scaffold=scaffold,
          monitors=[monitor],
          clean_stop_exception_types=exception_types) as session:
        self.assertEqual(0, session.run(gstep))
        self.assertEqual(1, session.run(do_step))
        self.assertEqual(2, session.run(do_step))
        self.assertFalse(session.should_stop())
        # Here at step 3, the monitor triggers and raises StopIteration.  The
        # session should go into should_stop() mode.  We do not get a result
        # in that case.
        self.assertEqual(None, session.run(do_step))
        self.assertTrue(monitor.raised)
        self.assertTrue(session.should_stop())
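
  # The exception_types list above replaces the default clean-stop set rather
  # than extending it; per the Coordinator change at the end of this diff the
  # default is just (errors.OutOfRangeError,), so OutOfRangeError must be
  # listed explicitly to keep end-of-input handling.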

  # This set of tests verifies the session behavior when exceptions are
  # raised from code inside a "with SupervisedSession:" context.

  def test_regular_exception_pass_through_in_with_body(self):
    # Tests that regular exceptions just pass through a "with
    # SupervisedSession" block and set the session in stop mode.
    with tf.Graph().as_default():
      scaffold = supervised_session.Scaffold()
      gstep = scaffold.global_step_tensor
      do_step = tf.assign_add(gstep, 1)
      monitor = RaiseOnceAtStepN(3, RuntimeError('regular exception'))
      session = supervised_session.SupervisedSession('', scaffold=scaffold,
                                                     monitors=[monitor])
      with self.assertRaisesRegexp(RuntimeError, 'regular exception'):
        with session:
          self.assertEqual(0, session.run(gstep))
          self.assertEqual(1, session.run(do_step))
          self.assertEqual(2, session.run(do_step))
          self.assertFalse(session.should_stop())
          # This triggers the monitor and raises the exception.
          session.run(do_step)
          # We should not hit this.
          self.assertFalse(True)
      self.assertTrue(monitor.raised)
      self.assertTrue(session.should_stop())

  def test_stop_cleanly_when_no_exception_in_with_body(self):
    # Tests that the session stops cleanly and closes when the "with" body
    # exits without raising.
    with tf.Graph().as_default():
      scaffold = supervised_session.Scaffold()
      gstep = scaffold.global_step_tensor
      do_step = tf.assign_add(gstep, 1)
      session = supervised_session.SupervisedSession('', scaffold=scaffold)
      with session:
        self.assertEqual(1, session.run(do_step))
        self.assertEqual(2, session.run(do_step))
        self.assertFalse(session.should_stop())
      # Should have closed.
      self.assertTrue(session.should_stop())
      self.assertTrue(session._is_closed())

  def test_stop_cleanly_on_out_of_range_exception_in_with_body(self):
    # Tests that an OutOfRangeError raised in the "with" body stops the
    # session cleanly instead of propagating.
    with tf.Graph().as_default():
      scaffold = supervised_session.Scaffold()
      gstep = scaffold.global_step_tensor
      do_step = tf.assign_add(gstep, 1)
      session = supervised_session.SupervisedSession('', scaffold=scaffold)
      with session:
        self.assertEqual(1, session.run(do_step))
        self.assertEqual(2, session.run(do_step))
        self.assertFalse(session.should_stop())
        raise tf.errors.OutOfRangeError(None, None, 'EOI')
      # Should have closed.
      self.assertTrue(session.should_stop())
      self.assertTrue(session._is_closed())

  def test_raises_regular_exceptions_in_with_body(self):
    # Tests that regular exceptions in "with body" are seen outside.
    with tf.Graph().as_default():
      scaffold = supervised_session.Scaffold()
      gstep = scaffold.global_step_tensor
      do_step = tf.assign_add(gstep, 1)
      session = supervised_session.SupervisedSession('', scaffold=scaffold)
      # We should see that exception.
      with self.assertRaisesRegexp(RuntimeError, 'regular exception'):
        with session:
          self.assertEqual(1, session.run(do_step))
          self.assertEqual(2, session.run(do_step))
          self.assertFalse(session.should_stop())
          # Will be visible outside the "with body".
          raise RuntimeError('regular exception')
      # Should have closed.
      self.assertTrue(session.should_stop())
      self.assertTrue(session._is_closed())

  def test_stop_cleanly_on_custom_exception_in_with_body(self):
    # Tests that an exception listed in clean_stop_exception_types, raised
    # in the "with" body, stops the session cleanly instead of propagating.
    with tf.Graph().as_default():
      scaffold = supervised_session.Scaffold()
      gstep = scaffold.global_step_tensor
      do_step = tf.assign_add(gstep, 1)
      exception_types = [tf.errors.OutOfRangeError, StopIteration]
      session = supervised_session.SupervisedSession(
          '', scaffold=scaffold, clean_stop_exception_types=exception_types)
      with session:
        self.assertEqual(1, session.run(do_step))
        self.assertEqual(2, session.run(do_step))
        self.assertFalse(session.should_stop())
        raise StopIteration('EOI')
      # Should have closed.
      self.assertTrue(session.should_stop())
      self.assertTrue(session._is_closed())


if __name__ == '__main__':
  tf.test.main()

@@ -21,8 +21,10 @@ from __future__ import print_function

import tensorflow as tf

from tensorflow.contrib.learn.python.learn import wrapped_session


class StopAtNSession(tf.contrib.learn.WrappedSession):
class StopAtNSession(wrapped_session.WrappedSession):
  """A wrapped session that stops at the N-th call to _check_stop."""

  def __init__(self, sess, n):
@@ -42,13 +44,13 @@ class WrappedSessionTest(tf.test.TestCase):

  def test_properties(self):
    with self.test_session() as sess:
      tf.constant(0.0)
      wrapped_sess = tf.contrib.learn.WrappedSession(sess)
      wrapped_sess = wrapped_session.WrappedSession(sess)
      self.assertEquals(sess.graph, wrapped_sess.graph)
      self.assertEquals(sess.sess_str, wrapped_sess.sess_str)

  def test_should_stop_on_close(self):
    with self.test_session() as sess:
      wrapped_sess = tf.contrib.learn.WrappedSession(sess)
      wrapped_sess = wrapped_session.WrappedSession(sess)
      self.assertFalse(wrapped_sess.should_stop())
      wrapped_sess.close()
      self.assertTrue(wrapped_sess.should_stop())
@@ -64,7 +66,7 @@ class WrappedSessionTest(tf.test.TestCase):

  def test_should_stop_delegates_to_wrapped_session(self):
    with self.test_session() as sess:
      wrapped_sess0 = StopAtNSession(sess, 4)
      wrapped_sess1 = tf.contrib.learn.WrappedSession(wrapped_sess0)
      wrapped_sess1 = wrapped_session.WrappedSession(wrapped_sess0)
      self.assertFalse(wrapped_sess1.should_stop())
      self.assertFalse(wrapped_sess1.should_stop())
      self.assertFalse(wrapped_sess1.should_stop())
@@ -73,7 +75,7 @@ class WrappedSessionTest(tf.test.TestCase):

  def test_close_twice(self):
    with self.test_session() as sess:
      wrapped_sess = tf.contrib.learn.WrappedSession(sess)
      wrapped_sess = wrapped_session.WrappedSession(sess)
      wrapped_sess.close()
      self.assertTrue(wrapped_sess.should_stop())
      wrapped_sess.close()
@@ -84,7 +86,7 @@ class WrappedSessionTest(tf.test.TestCase):
      c = tf.constant(0)
      v = tf.identity(c)
      self.assertEqual(42, sess.run(v, feed_dict={c: 42}))
      wrapped_sess = tf.contrib.learn.WrappedSession(sess)
      wrapped_sess = wrapped_session.WrappedSession(sess)
      self.assertEqual(51, wrapped_sess.run(v, feed_dict={c: 51}))
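
# Note: the edits above switch these tests to import WrappedSession from its
# defining module rather than the tf.contrib.learn namespace, which suggests
# the class is no longer re-exported as part of that package's public API.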

@@ -139,7 +139,7 @@ class Coordinator(object):
    """
    if clean_stop_exception_types is None:
      clean_stop_exception_types = (errors.OutOfRangeError,)
    self._clean_stop_exception_types = clean_stop_exception_types
    self._clean_stop_exception_types = tuple(clean_stop_exception_types)
    # Protects all attributes.
    self._lock = threading.Lock()
    # Event set when threads must stop.
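
# Why tuple() here: callers may pass any iterable of exception types (the
# tests above pass a list), but Python's "except" clauses and isinstance()
# checks require a tuple, so the value is normalized once at construction.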