tfdbg-tflearn integration: add dedicated doc + a kwarg renaming
Renaming the newly added "evaluate_hooks" kwarg to Experiment.__init__() to "eval_hooks", to be 1) more succinct and 2) consistent with other eval-related kwargs in Experiment.__init__(). Change: 143841895
This commit is contained in:
parent
e7ce8471c3
commit
48ca775da6
@ -68,7 +68,7 @@ class Experiment(object):
|
||||
train_steps=None,
|
||||
eval_steps=100,
|
||||
train_monitors=None,
|
||||
evaluate_hooks=None,
|
||||
eval_hooks=None,
|
||||
local_eval_frequency=None,
|
||||
eval_delay_secs=120,
|
||||
continuous_eval_throttle_secs=60,
|
||||
@ -95,7 +95,7 @@ class Experiment(object):
|
||||
is raised), or for `eval_steps` steps, if specified.
|
||||
train_monitors: A list of monitors to pass to the `Estimator`'s `fit`
|
||||
function.
|
||||
evaluate_hooks: A list of `SessionRunHook` hooks to pass to the
|
||||
eval_hooks: A list of `SessionRunHook` hooks to pass to the
|
||||
`Estimator`'s `evaluate` function.
|
||||
local_eval_frequency: Frequency of running eval in steps,
|
||||
when running locally. If `None`, runs evaluation only at the end of
|
||||
@ -134,7 +134,7 @@ class Experiment(object):
|
||||
self._delay_workers_by_global_step = delay_workers_by_global_step
|
||||
# Mutable fields, using the setters.
|
||||
self.train_monitors = train_monitors
|
||||
self.evaluate_hooks = evaluate_hooks
|
||||
self.eval_hooks = eval_hooks
|
||||
self.export_strategies = export_strategies
|
||||
|
||||
@property
|
||||
@ -170,12 +170,12 @@ class Experiment(object):
|
||||
self._train_monitors = value or []
|
||||
|
||||
@property
|
||||
def evaluate_hooks(self):
|
||||
return self._evaluate_hooks
|
||||
def eval_hooks(self):
|
||||
return self._eval_hooks
|
||||
|
||||
@evaluate_hooks.setter
|
||||
def evaluate_hooks(self, value):
|
||||
self._evaluate_hooks = value or []
|
||||
@eval_hooks.setter
|
||||
def eval_hooks(self, value):
|
||||
self._eval_hooks = value or []
|
||||
|
||||
@property
|
||||
def local_eval_frequency(self):
|
||||
@ -288,7 +288,7 @@ class Experiment(object):
|
||||
steps=self._eval_steps,
|
||||
metrics=self._eval_metrics,
|
||||
name="one_pass",
|
||||
hooks=self._evaluate_hooks)
|
||||
hooks=self._eval_hooks)
|
||||
|
||||
@deprecated(
|
||||
"2016-10-23",
|
||||
|
@ -108,6 +108,8 @@ running the command `lt` after you executed `run`.) This is called the
|
||||
|
||||

|
||||
|
||||
### tfdbg CLI Frequently-Used Commands
|
||||
|
||||
Try the following commands at the `tfdbg>` prompt (referencing the code at
|
||||
`tensorflow/python/debug/examples/debug_mnist.py`):
|
||||
|
||||
@ -262,38 +264,12 @@ stuck. Success!
|
||||
|
||||
## Debugging tf-learn Estimators
|
||||
|
||||
In the tutorial above, we described how to use `tfdbg` if you are managing your
|
||||
own [`tf.Session`](https://tensorflow.org/api_docs/python/client.html#Session)
|
||||
objects. However, many users find
|
||||
[`tf.contrib.learn`](https://tensorflow.org/tutorials/tflearn/index.html)
|
||||
`Estimator`s to be a convenient higher level API for creating and using models
|
||||
in TensorFlow. Part of the convenience is that `Estimator`s manage Sessions
|
||||
internally. Fortunately, you can still use `tfdbg` with `Estimator`s by adding a
|
||||
special hook.
|
||||
For documentation on **tfdbg** to debug
|
||||
[tf.contrib.learn](https://tensorflow.org/tutorials/tflearn/index.html)
|
||||
`Estimator`s and `Experiment`s, please see
|
||||
[How to Use TensorFlow Debugger (tfdbg) with tf.contrib.learn](tfdbg-tflearn.md).
|
||||
|
||||
Currently, `tfdbg` can only debug the `fit()` method of tf-learn
|
||||
`Estimator`s. Support for debugging `evaluate()` will come soon. To debug
|
||||
`Estimator.fit()`, create a monitor and supply it as an argument. For example:
|
||||
|
||||
```python
|
||||
from tensorflow.python import debug as tf_debug
|
||||
|
||||
# Create a local CLI debug hook and use it as a monitor when calling fit().
|
||||
classifier.fit(x=training_set.data,
|
||||
y=training_set.target,
|
||||
steps=1000,
|
||||
monitors=[tf_debug.LocalCLIDebugHook()])
|
||||
```
|
||||
|
||||
For a detailed [example](https://www.tensorflow.org/code/tensorflow/python/debug/examples/debug_tflearn_iris.py) based on
|
||||
[tf-learn's iris tutorial](../../../g3doc/tutorials/tflearn/index.md),
|
||||
run:
|
||||
|
||||
```none
|
||||
python $(python -c "import tensorflow as tf; import os; print(os.path.dirname(tf.__file__));")/python/debug/examples/debug_tflearn_iris.py --debug
|
||||
```
|
||||
|
||||
## Offline Debugging of Remotely-running Sessions
|
||||
## Offline Debugging of Remotely-Running Sessions
|
||||
|
||||
Oftentimes, your model is running in a remote machine or process that you don't
|
||||
have terminal access to. To perform model debugging in such cases, you can use
|
||||
|
88
tensorflow/g3doc/how_tos/debugger/tfdbg-tflearn.md
Normal file
88
tensorflow/g3doc/how_tos/debugger/tfdbg-tflearn.md
Normal file
@ -0,0 +1,88 @@
|
||||
# How to Use TensorFlow Debugger (tfdbg) with tf.contrib.learn
|
||||
|
||||
[TOC]
|
||||
|
||||
In [a previous tutorial](index.md), we described how to use TensorFlow Debugger (**tfdbg**)
|
||||
to debug TensorFlow graphs running in
|
||||
[`tf.Session`](https://tensorflow.org/api_docs/python/client.html#Session)
|
||||
objects managed by yourself. However, many users find
|
||||
[`tf.contrib.learn`](https://tensorflow.org/tutorials/tflearn/index.html)
|
||||
[Estimator](https://tensorflow.org/api_docs/python/contrib.learn.html?cl=head#Estimator)s
|
||||
to be a convenient higher-level API for creating and using models
|
||||
in TensorFlow. Part of the convenience is that `Estimator`s manage `Session`s
|
||||
internally. Fortunately, you can still use `tfdbg` with `Estimator`s by adding
|
||||
special hooks.
|
||||
|
||||
## Debugging tf.contrib.learn Estimators
|
||||
|
||||
Currently, **tfdbg** can debug the
|
||||
[fit()](https://tensorflow.org/api_docs/python/contrib.learn.html#BaseEstimator.fit)
|
||||
and
|
||||
[evaluate()](https://tensorflow.org/api_docs/python/contrib.learn.html#BaseEstimator.evaluate)
|
||||
methods of tf-learn `Estimator`s. To debug `Estimator.fit()`,
|
||||
create a `LocalCLIDebugHook` and supply it as the `monitors` argument. For example:
|
||||
|
||||
```python
|
||||
# First, let your BUILD target depend on "//tensorflow/python/debug:debug_py"
|
||||
from tensorflow.python import debug as tf_debug
|
||||
|
||||
hooks = [tf_debug.LocalCLIDebugHook()]
|
||||
|
||||
# Create a local CLI debug hook and use it as a monitor when calling fit().
|
||||
classifier.fit(x=training_set.data,
|
||||
y=training_set.target,
|
||||
steps=1000,
|
||||
monitors=hooks)
|
||||
```
|
||||
|
||||
To debug `Estimator.evaluate()`, you can follow the example below:
|
||||
|
||||
```python
|
||||
accuracy_score = classifier.evaluate(x=test_set.data,
|
||||
y=test_set.target,
|
||||
hooks=hooks)["accuracy"]
|
||||
```
|
||||
|
||||
|
||||
For a detailed [example](https://www.tensorflow.org/code/tensorflow/python/debug/examples/debug_tflearn_iris.py) based on
|
||||
[tf-learn's iris tutorial](../../../g3doc/tutorials/tflearn/index.md),
|
||||
run:
|
||||
|
||||
```none
|
||||
python $(python -c "import tensorflow as tf; import os; print(os.path.dirname(tf.__file__));")/python/debug/examples/debug_tflearn_iris.py --debug
|
||||
```
|
||||
|
||||
## Debugging tf.contrib.learn Experiments
|
||||
|
||||
`Experiment` is a construct in `tf.contrib.learn` at a higher level than
|
||||
`Estimator`.
|
||||
It provides a single interface for training and evaluating a model. To debug
|
||||
the `train()` and `evaluate()` calls to an `Experiment` object, you can
|
||||
use the keyword arguments `train_monitors` and `eval_hooks`, respectively, when
|
||||
calling its constructor. For example:
|
||||
|
||||
```python
|
||||
# First, let your BUILD target depend on "//tensorflow/python/debug:debug_py"
|
||||
from tensorflow.python import debug as tf_debug
|
||||
|
||||
hooks = [tf_debug.LocalCLIDebugHook()]
|
||||
|
||||
ex = experiment.Experiment(classifier,
|
||||
train_input_fn=iris_input_fn,
|
||||
eval_input_fn=iris_input_fn,
|
||||
train_steps=FLAGS.train_steps,
|
||||
eval_delay_secs=0,
|
||||
eval_steps=1,
|
||||
train_monitors=hooks,
|
||||
eval_hooks=hooks)
|
||||
|
||||
ex.train()
|
||||
accuracy_score = ex.evaluate()["accuracy"]
|
||||
```
|
||||
|
||||
To see the `debug_tflearn_iris` example run in the `Experiment` mode, do:
|
||||
|
||||
```none
|
||||
python $(python -c "import tensorflow as tf; import os; print(os.path.dirname(tf.__file__));")/python/debug/examples/debug_tflearn_iris.py \
|
||||
--use_experiment --debug
|
||||
```
|
@ -26,6 +26,8 @@ import numpy as np
|
||||
from six.moves import urllib
|
||||
import tensorflow as tf
|
||||
|
||||
from tensorflow.contrib.learn.python.learn import experiment
|
||||
from tensorflow.contrib.learn.python.learn.datasets import base
|
||||
from tensorflow.python import debug as tf_debug
|
||||
|
||||
|
||||
@ -67,6 +69,16 @@ def maybe_download_data(data_dir):
|
||||
return training_data_path, test_data_path
|
||||
|
||||
|
||||
_IRIS_INPUT_DIM = 4
|
||||
|
||||
|
||||
def iris_input_fn():
|
||||
iris = base.load_iris()
|
||||
features = tf.reshape(tf.constant(iris.data), [-1, _IRIS_INPUT_DIM])
|
||||
labels = tf.reshape(tf.constant(iris.target), [-1])
|
||||
return features, labels
|
||||
|
||||
|
||||
def main(_):
|
||||
training_data_path, test_data_path = maybe_download_data(FLAGS.data_dir)
|
||||
|
||||
@ -93,16 +105,28 @@ def main(_):
|
||||
hooks = ([tf_debug.LocalCLIDebugHook(ui_type=FLAGS.ui_type)] if FLAGS.debug
|
||||
else None)
|
||||
|
||||
# Fit model.
|
||||
classifier.fit(x=training_set.data,
|
||||
y=training_set.target,
|
||||
steps=FLAGS.train_steps,
|
||||
monitors=hooks)
|
||||
if not FLAGS.use_experiment:
|
||||
# Fit model.
|
||||
classifier.fit(x=training_set.data,
|
||||
y=training_set.target,
|
||||
steps=FLAGS.train_steps,
|
||||
monitors=hooks)
|
||||
|
||||
# Evaluate accuracy.
|
||||
accuracy_score = classifier.evaluate(x=test_set.data,
|
||||
y=test_set.target,
|
||||
hooks=hooks)["accuracy"]
|
||||
# Evaluate accuracy.
|
||||
accuracy_score = classifier.evaluate(x=test_set.data,
|
||||
y=test_set.target,
|
||||
hooks=hooks)["accuracy"]
|
||||
else:
|
||||
ex = experiment.Experiment(classifier,
|
||||
train_input_fn=iris_input_fn,
|
||||
eval_input_fn=iris_input_fn,
|
||||
train_steps=FLAGS.train_steps,
|
||||
eval_delay_secs=0,
|
||||
eval_steps=1,
|
||||
train_monitors=hooks,
|
||||
eval_hooks=hooks)
|
||||
ex.train()
|
||||
accuracy_score = ex.evaluate()["accuracy"]
|
||||
|
||||
print("After training %d steps, Accuracy = %f" %
|
||||
(FLAGS.train_steps, accuracy_score))
|
||||
@ -126,6 +150,13 @@ if __name__ == "__main__":
|
||||
type=int,
|
||||
default=10,
|
||||
help="Number of steps to run trainer.")
|
||||
parser.add_argument(
|
||||
"--use_experiment",
|
||||
type="bool",
|
||||
nargs="?",
|
||||
const=True,
|
||||
default=False,
|
||||
help="Use tf.contrib.learn Experiment to run training and evaluation")
|
||||
parser.add_argument(
|
||||
"--ui_type",
|
||||
type=str,
|
||||
|
Loading…
Reference in New Issue
Block a user