tfdbg-tflearn integration: add dedicated doc + a kwarg renaming

Renaming the newly added "evaluate_hooks" kwarg of Experiment.__init__() to "eval_hooks", to be 1) more succinct and 2) consistent with the other eval-related kwargs of Experiment.__init__().
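
For illustration, a minimal sketch of the renamed kwarg in use; `classifier` and `iris_input_fn` are assumed to exist, as in the example script changed below:

```python
from tensorflow.contrib.learn.python.learn import experiment
from tensorflow.python import debug as tf_debug

hooks = [tf_debug.LocalCLIDebugHook()]

# Before this change: experiment.Experiment(..., evaluate_hooks=hooks)
# After this change:  experiment.Experiment(..., eval_hooks=hooks)
ex = experiment.Experiment(classifier,
                           train_input_fn=iris_input_fn,
                           eval_input_fn=iris_input_fn,
                           train_monitors=hooks,  # hooks for fit()
                           eval_hooks=hooks)      # hooks for evaluate()
```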
Change: 143841895
Shanqing Cai 2017-01-06 20:17:51 -08:00 committed by TensorFlower Gardener
parent e7ce8471c3
commit 48ca775da6
4 changed files with 144 additions and 49 deletions

tensorflow/contrib/learn/python/learn/experiment.py

@@ -68,7 +68,7 @@ class Experiment(object):
                train_steps=None,
                eval_steps=100,
                train_monitors=None,
-               evaluate_hooks=None,
+               eval_hooks=None,
                local_eval_frequency=None,
                eval_delay_secs=120,
                continuous_eval_throttle_secs=60,
@@ -95,7 +95,7 @@
         is raised), or for `eval_steps` steps, if specified.
       train_monitors: A list of monitors to pass to the `Estimator`'s `fit`
         function.
-      evaluate_hooks: A list of `SessionRunHook` hooks to pass to the
+      eval_hooks: A list of `SessionRunHook` hooks to pass to the
         `Estimator`'s `evaluate` function.
       local_eval_frequency: Frequency of running eval in steps,
         when running locally. If `None`, runs evaluation only at the end of
@@ -134,7 +134,7 @@
     self._delay_workers_by_global_step = delay_workers_by_global_step
     # Mutable fields, using the setters.
     self.train_monitors = train_monitors
-    self.evaluate_hooks = evaluate_hooks
+    self.eval_hooks = eval_hooks
     self.export_strategies = export_strategies
 
   @property
@@ -170,12 +170,12 @@
     self._train_monitors = value or []
 
   @property
-  def evaluate_hooks(self):
-    return self._evaluate_hooks
+  def eval_hooks(self):
+    return self._eval_hooks
 
-  @evaluate_hooks.setter
-  def evaluate_hooks(self, value):
-    self._evaluate_hooks = value or []
+  @eval_hooks.setter
+  def eval_hooks(self, value):
+    self._eval_hooks = value or []
 
   @property
   def local_eval_frequency(self):
@@ -288,7 +288,7 @@
         steps=self._eval_steps,
         metrics=self._eval_metrics,
         name="one_pass",
-        hooks=self._evaluate_hooks)
+        hooks=self._eval_hooks)
 
   @deprecated(
       "2016-10-23",

tensorflow/g3doc/how_tos/debugger/index.md

@@ -108,6 +108,8 @@ running the command `lt` after you executed `run`.) This is called the
 
+![tfdbg run-end UI: accuracy](tfdbg_screenshot_run_end_accuracy.png)
+
 ### tfdbg CLI Frequently-Used Commands
 
 Try the following commands at the `tfdbg>` prompt (referencing the code at
 `tensorflow/python/debug/examples/debug_mnist.py`):
@@ -262,38 +264,12 @@ stuck. Success!
 
 ## Debugging tf-learn Estimators
 
-In the tutorial above, we described how to use `tfdbg` if you are managing your
-own [`tf.Session`](https://tensorflow.org/api_docs/python/client.html#Session)
-objects. However, many users find
-[`tf.contrib.learn`](https://tensorflow.org/tutorials/tflearn/index.html)
-`Estimator`s to be a convenient higher level API for creating and using models
-in TensorFlow. Part of the convenience is that `Estimator`s manage Sessions
-internally. Fortunately, you can still use `tfdbg` with `Estimator`s by adding a
-special hook.
+For documentation on using **tfdbg** to debug
+[tf.contrib.learn](https://tensorflow.org/tutorials/tflearn/index.html)
+`Estimator`s and `Experiment`s, please see
+[How to Use TensorFlow Debugger (tfdbg) with tf.contrib.learn](tfdbg-tflearn.md).
 
-Currently, `tfdbg` can only debug the `fit()` method of tf-learn
-`Estimator`s. Support for debugging `evaluate()` will come soon. To debug
-`Estimator.fit()`, create a monitor and supply it as an argument. For example:
-
-```python
-from tensorflow.python import debug as tf_debug
-
-# Create a local CLI debug hook and use it as a monitor when calling fit().
-classifier.fit(x=training_set.data,
-               y=training_set.target,
-               steps=1000,
-               monitors=[tf_debug.LocalCLIDebugHook()])
-```
-
-For a detailed [example](https://www.tensorflow.org/code/tensorflow/python/debug/examples/debug_tflearn_iris.py) based on
-[tf-learn's iris tutorial](../../../g3doc/tutorials/tflearn/index.md),
-run:
-
-```none
-python $(python -c "import tensorflow as tf; import os; print(os.path.dirname(tf.__file__));")/python/debug/examples/debug_tflearn_iris.py --debug
-```
-
-## Offline Debugging of Remotely-running Sessions
+## Offline Debugging of Remotely-Running Sessions
 
 Oftentimes, your model is running in a remote machine or process that you don't
 have terminal access to. To perform model debugging in such cases, you can use

tensorflow/g3doc/how_tos/debugger/tfdbg-tflearn.md

@@ -0,0 +1,88 @@
# How to Use TensorFlow Debugger (tfdbg) with tf.contrib.learn

[TOC]

In [a previous tutorial](index.md), we described how to use TensorFlow Debugger
(**tfdbg**) to debug TensorFlow graphs running in
[`tf.Session`](https://tensorflow.org/api_docs/python/client.html#Session)
objects that you manage yourself. However, many users find
[`tf.contrib.learn`](https://tensorflow.org/tutorials/tflearn/index.html)
[Estimator](https://tensorflow.org/api_docs/python/contrib.learn.html?cl=head#Estimator)s
to be a convenient higher-level API for creating and using models
in TensorFlow. Part of the convenience is that `Estimator`s manage `Session`s
internally. Fortunately, you can still use `tfdbg` with `Estimator`s by adding
special hooks.
## Debugging tf.contrib.learn Estimators

Currently, **tfdbg** can debug the
[fit()](https://tensorflow.org/api_docs/python/contrib.learn.html#BaseEstimator.fit)
and
[evaluate()](https://tensorflow.org/api_docs/python/contrib.learn.html#BaseEstimator.evaluate)
methods of tf-learn `Estimator`s. To debug `Estimator.fit()`, create a
`LocalCLIDebugHook` and supply it as the `monitors` argument. For example:

```python
# First, let your BUILD target depend on "//tensorflow/python/debug:debug_py"
from tensorflow.python import debug as tf_debug

# Create a local CLI debug hook.
hooks = [tf_debug.LocalCLIDebugHook()]

# Use the hook as a monitor when calling fit().
classifier.fit(x=training_set.data,
               y=training_set.target,
               steps=1000,
               monitors=hooks)
```
To debug `Estimator.evaluate()`, you can follow the example below:

```python
accuracy_score = classifier.evaluate(x=test_set.data,
                                     y=test_set.target,
                                     hooks=hooks)["accuracy"]
```

For a detailed [example](https://www.tensorflow.org/code/tensorflow/python/debug/examples/debug_tflearn_iris.py) based on
[tf-learn's iris tutorial](../../../g3doc/tutorials/tflearn/index.md),
run:

```none
python $(python -c "import tensorflow as tf; import os; print(os.path.dirname(tf.__file__));")/python/debug/examples/debug_tflearn_iris.py --debug
```
## Debugging tf.contrib.learn Experiments

`Experiment` is a construct in `tf.contrib.learn` at a higher level than
`Estimator`: it provides a single interface for training and evaluating a model.
To debug the `train()` and `evaluate()` calls to an `Experiment` object, you can
use the keyword arguments `train_monitors` and `eval_hooks`, respectively, when
calling its constructor. For example:

```python
# First, let your BUILD target depend on "//tensorflow/python/debug:debug_py"
from tensorflow.python import debug as tf_debug

hooks = [tf_debug.LocalCLIDebugHook()]

ex = experiment.Experiment(classifier,
                           train_input_fn=iris_input_fn,
                           eval_input_fn=iris_input_fn,
                           train_steps=FLAGS.train_steps,
                           eval_delay_secs=0,
                           eval_steps=1,
                           train_monitors=hooks,
                           eval_hooks=hooks)

ex.train()
accuracy_score = ex.evaluate()["accuracy"]
```

To see the `debug_tflearn_iris` example run in `Experiment` mode, do:

```none
python $(python -c "import tensorflow as tf; import os; print(os.path.dirname(tf.__file__));")/python/debug/examples/debug_tflearn_iris.py \
    --use_experiment --debug
```
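
The `LocalCLIDebugHook` used throughout this guide also accepts a `ui_type`
keyword argument, as the `--ui_type` flag of the `debug_tflearn_iris` example
below suggests. A minimal sketch; the `"readline"` value here is an assumption
based on that flag, so check the flag's help text in your TensorFlow build for
the supported values:

```python
from tensorflow.python import debug as tf_debug

# Select a non-default debugger UI. "readline" is assumed here to be a
# supported alternative to the default curses-based UI.
hooks = [tf_debug.LocalCLIDebugHook(ui_type="readline")]
```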

tensorflow/python/debug/examples/debug_tflearn_iris.py

@@ -26,6 +26,8 @@ import numpy as np
 from six.moves import urllib
 
 import tensorflow as tf
 
+from tensorflow.contrib.learn.python.learn import experiment
+from tensorflow.contrib.learn.python.learn.datasets import base
 from tensorflow.python import debug as tf_debug
@@ -67,6 +69,16 @@ def maybe_download_data(data_dir):
   return training_data_path, test_data_path
 
 
+_IRIS_INPUT_DIM = 4
+
+
+def iris_input_fn():
+  iris = base.load_iris()
+  features = tf.reshape(tf.constant(iris.data), [-1, _IRIS_INPUT_DIM])
+  labels = tf.reshape(tf.constant(iris.target), [-1])
+  return features, labels
+
+
 def main(_):
   training_data_path, test_data_path = maybe_download_data(FLAGS.data_dir)
@@ -93,16 +105,28 @@ def main(_):
   hooks = ([tf_debug.LocalCLIDebugHook(ui_type=FLAGS.ui_type)] if FLAGS.debug
            else None)
 
-  # Fit model.
-  classifier.fit(x=training_set.data,
-                 y=training_set.target,
-                 steps=FLAGS.train_steps,
-                 monitors=hooks)
+  if not FLAGS.use_experiment:
+    # Fit model.
+    classifier.fit(x=training_set.data,
+                   y=training_set.target,
+                   steps=FLAGS.train_steps,
+                   monitors=hooks)
 
-  # Evaluate accuracy.
-  accuracy_score = classifier.evaluate(x=test_set.data,
-                                       y=test_set.target,
-                                       hooks=hooks)["accuracy"]
+    # Evaluate accuracy.
+    accuracy_score = classifier.evaluate(x=test_set.data,
+                                         y=test_set.target,
+                                         hooks=hooks)["accuracy"]
+  else:
+    ex = experiment.Experiment(classifier,
+                               train_input_fn=iris_input_fn,
+                               eval_input_fn=iris_input_fn,
+                               train_steps=FLAGS.train_steps,
+                               eval_delay_secs=0,
+                               eval_steps=1,
+                               train_monitors=hooks,
+                               eval_hooks=hooks)
+    ex.train()
+    accuracy_score = ex.evaluate()["accuracy"]
 
   print("After training %d steps, Accuracy = %f" %
         (FLAGS.train_steps, accuracy_score))
@@ -126,6 +150,13 @@ if __name__ == "__main__":
       type=int,
       default=10,
       help="Number of steps to run trainer.")
+  parser.add_argument(
+      "--use_experiment",
+      type="bool",
+      nargs="?",
+      const=True,
+      default=False,
+      help="Use tf.contrib.learn Experiment to run training and evaluation")
   parser.add_argument(
       "--ui_type",
       type=str,
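
A note on `type="bool"` in the `--use_experiment` argument above: argparse has no built-in type named `"bool"`, so this works only if the script registers a converter under that name. The registration line lies outside the hunks shown; the sketch below reconstructs the presumed pattern used by TensorFlow's example scripts:

```python
import argparse

parser = argparse.ArgumentParser()
# Presumed registration that makes type="bool" resolvable.
parser.register("type", "bool", lambda v: v.lower() == "true")
parser.add_argument(
    "--use_experiment",
    type="bool",
    nargs="?",
    const=True,   # A bare --use_experiment flag means True.
    default=False,
    help="Use tf.contrib.learn Experiment to run training and evaluation")

print(parser.parse_args(["--use_experiment"]).use_experiment)           # True
print(parser.parse_args(["--use_experiment", "false"]).use_experiment)  # False
print(parser.parse_args([]).use_experiment)                             # False
```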