Redo the TensorBoard tutorial code (mnist_with_summaries.py).

Goals:
- Have enough of each summary type that tag grouping is useful
  (wound up recording, e.g., mean, stddev, and min/max for each variable).
- Use every summary type (this adds image summaries).
- Write to multiple directories so there are several "runs" (see the sketch below).
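The tag-grouping and multiple-run ideas boil down to slash-prefixed summary tags and one SummaryWriter per subdirectory. Below is a minimal sketch using the TF 0.8-era API this commit targets (tf.scalar_summary, tf.train.SummaryWriter); the variable name and log directories are illustrative, not taken from the commit:

```python
import tensorflow as tf

# Illustrative variable to monitor (not from the commit).
var = tf.Variable(tf.truncated_normal([784, 10], stddev=0.1), name='var')

# Tag grouping: a shared slash-separated prefix (mean/, max/, ...) makes
# TensorBoard group the resulting scalar charts together.
tf.scalar_summary('mean/var', tf.reduce_mean(var))
tf.scalar_summary('max/var', tf.reduce_max(var))
tf.scalar_summary('min/var', tf.reduce_min(var))

sess = tf.InteractiveSession()
merged = tf.merge_all_summaries()

# Multiple "runs": one SummaryWriter per subdirectory; pointing TensorBoard at
# the parent directory shows each subdirectory as a separate run.
train_writer = tf.train.SummaryWriter('/tmp/mnist_logs/train', sess.graph)
test_writer = tf.train.SummaryWriter('/tmp/mnist_logs/test')

tf.initialize_all_variables().run()
summary = sess.run(merged)
train_writer.add_summary(summary, 0)
test_writer.add_summary(summary, 0)
```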
Change: 119585022
Author: Dan Mané (2016-04-11 15:40:04 -08:00), committed by TensorFlower Gardener
parent 3c59c1ed08
commit a77499c87d
2 changed files with 166 additions and 105 deletions:
- tensorflow/examples/tutorials/mnist/mnist_with_summaries.py
- tensorflow/g3doc/how_tos/summaries_and_tensorboard/index.md

tensorflow/examples/tutorials/mnist/mnist_with_summaries.py

@@ -1,29 +1,25 @@
# Copyright 2015 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the 'License');
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an 'AS IS' BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A simple MNIST classifier which displays summaries in TensorBoard.

This is an unimpressive MNIST model, but it is a good example of using
tf.name_scope to make a graph legible in the TensorBoard graph explorer, and of
naming summary tags so that they are grouped meaningfully in TensorBoard.

It demonstrates the functionality of every TensorBoard dashboard.
"""
from __future__ import absolute_import
from __future__ import division

@@ -39,72 +35,132 @@ FLAGS = flags.FLAGS
flags.DEFINE_boolean('fake_data', False, 'If true, uses fake data '
                     'for unit testing.')
flags.DEFINE_integer('max_steps', 1000, 'Number of steps to run trainer.')
flags.DEFINE_float('learning_rate', 0.001, 'Initial learning rate.')
flags.DEFINE_float('dropout', 0.9, 'Keep probability for training dropout.')
flags.DEFINE_string('data_dir', '/tmp/data', 'Directory for storing data')
flags.DEFINE_string('summaries_dir', '/tmp/mnist_logs', 'Summaries directory')


def train():
  # Import data
  mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True,
                                    fake_data=FLAGS.fake_data)

  sess = tf.InteractiveSession()

  # Create a multilayer model.

  # Input placeholders
  with tf.name_scope('input'):
    x = tf.placeholder(tf.float32, [None, 784], name='x-input')
    image_shaped_input = tf.reshape(x, [-1, 28, 28, 1])
    tf.image_summary('input', image_shaped_input, 10)
    y_ = tf.placeholder(tf.float32, [None, 10], name='y-input')
    keep_prob = tf.placeholder(tf.float32)
    tf.scalar_summary('dropout_keep_probability', keep_prob)

  # We can't initialize these variables to 0 - the network will get stuck.
  def weight_variable(shape):
    """Create a weight variable with appropriate initialization."""
    initial = tf.truncated_normal(shape, stddev=0.1)
    return tf.Variable(initial)

  def bias_variable(shape):
    """Create a bias variable with appropriate initialization."""
    initial = tf.constant(0.1, shape=shape)
    return tf.Variable(initial)

  def variable_summaries(var, name):
    """Attach a lot of summaries to a Tensor."""
    with tf.name_scope('summaries'):
      mean = tf.reduce_mean(var)
      tf.scalar_summary('mean/' + name, mean)
      with tf.name_scope('stddev'):
        stddev = tf.sqrt(tf.reduce_sum(tf.square(var - mean)))
      tf.scalar_summary('stddev/' + name, stddev)
      tf.scalar_summary('max/' + name, tf.reduce_max(var))
      tf.scalar_summary('min/' + name, tf.reduce_min(var))
      tf.histogram_summary(name, var)

  def nn_layer(input_tensor, input_dim, output_dim, layer_name, act=tf.nn.relu):
    """Reusable code for making a simple neural net layer.

    It does a matrix multiply, bias add, and then uses relu to nonlinearize.
    It also sets up name scoping so that the resultant graph is easy to read,
    and adds a number of summary ops.
    """
    # Adding a name scope ensures logical grouping of the layers in the graph.
    with tf.name_scope(layer_name):
      # This Variable will hold the state of the weights for the layer
      with tf.name_scope('weights'):
        weights = weight_variable([input_dim, output_dim])
        variable_summaries(weights, layer_name + '/weights')
      with tf.name_scope('biases'):
        biases = bias_variable([output_dim])
        variable_summaries(biases, layer_name + '/biases')
      with tf.name_scope('Wx_plus_b'):
        preactivate = tf.matmul(input_tensor, weights) + biases
        tf.histogram_summary(layer_name + '/pre_activations', preactivate)
      activations = act(preactivate, 'activation')
      tf.histogram_summary(layer_name + '/activations', activations)
      return activations

  hidden1 = nn_layer(x, 784, 500, 'layer1')
  dropped = tf.nn.dropout(hidden1, keep_prob)
  y = nn_layer(dropped, 500, 10, 'layer2', act=tf.nn.softmax)

  with tf.name_scope('cross_entropy'):
    diff = y_ * tf.log(y)
    with tf.name_scope('total'):
      cross_entropy = -tf.reduce_mean(diff)
    tf.scalar_summary('cross entropy', cross_entropy)

  with tf.name_scope('train'):
    train_step = tf.train.AdamOptimizer(
        FLAGS.learning_rate).minimize(cross_entropy)

  with tf.name_scope('accuracy'):
    with tf.name_scope('correct_prediction'):
      correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
    with tf.name_scope('accuracy'):
      accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    tf.scalar_summary('accuracy', accuracy)

  # Merge all the summaries and write them out to /tmp/mnist_logs (by default)
  merged = tf.merge_all_summaries()
  train_writer = tf.train.SummaryWriter(FLAGS.summaries_dir + '/train',
                                        sess.graph)
  test_writer = tf.train.SummaryWriter(FLAGS.summaries_dir + '/test')
  tf.initialize_all_variables().run()

  # Train the model, and also write summaries.
  # Every 10th step, measure test-set accuracy, and write test summaries.
  # All other steps, run train_step on training data, & add training summaries.

  def feed_dict(train):
    """Make a TensorFlow feed_dict: maps data onto Tensor placeholders."""
    if train or FLAGS.fake_data:
      xs, ys = mnist.train.next_batch(100, fake_data=FLAGS.fake_data)
      k = FLAGS.dropout
    else:
      xs, ys = mnist.test.images, mnist.test.labels
      k = 1.0
    return {x: xs, y_: ys, keep_prob: k}

  for i in range(FLAGS.max_steps):
    if i % 10 == 0:  # Record summaries and test-set accuracy
      summary, acc = sess.run([merged, accuracy], feed_dict=feed_dict(False))
      test_writer.add_summary(summary, i)
      print('Accuracy at step %s: %s' % (i, acc))
    else:  # Record train set summaries, and train
      summary, _ = sess.run([merged, train_step], feed_dict=feed_dict(True))
      train_writer.add_summary(summary, i)


def main(_):
  if tf.gfile.Exists(FLAGS.summaries_dir):
    tf.gfile.DeleteRecursively(FLAGS.summaries_dir)
  tf.gfile.MakeDirs(FLAGS.summaries_dir)
  train()


if __name__ == '__main__':
  tf.app.run()
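A possible way to run the updated tutorial (a sketch; it assumes a TF 0.8-era install where the script is invoked directly, and uses the flags defined above):

```bash
# Writes train summaries to /tmp/mnist_logs/train and test summaries to
# /tmp/mnist_logs/test (the flag defaults above).
python mnist_with_summaries.py --max_steps=1000 --learning_rate=0.001
```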

tensorflow/g3doc/how_tos/summaries_and_tensorboard/index.md

@@ -8,7 +8,8 @@
your TensorFlow graph, plot quantitative metrics about the execution of your
graph, and show additional data like images that pass through it. When
TensorBoard is fully configured, it looks like this:

[![MNIST TensorBoard](../../images/mnist_tensorboard.png "MNIST TensorBoard")](http://tensorflow.org/tensorboard)

[*Click to try a TensorBoard with data from this tutorial!*](http://tensorflow.org/tensorboard)

## Serializing the data
@@ -75,56 +76,70 @@
statistics, such as how the weights or accuracy varied during training.

The code below is an excerpt; full source is [here](https://www.tensorflow.org/code/tensorflow/examples/tutorials/mnist/mnist_with_summaries.py).

```python
def variable_summaries(var, name):
  with tf.name_scope("summaries"):
    mean = tf.reduce_mean(var)
    tf.scalar_summary('mean/' + name, mean)
    with tf.name_scope('stddev'):
      stddev = tf.sqrt(tf.reduce_sum(tf.square(var - mean)))
    tf.scalar_summary('stddev/' + name, stddev)
    tf.scalar_summary('max/' + name, tf.reduce_max(var))
    tf.scalar_summary('min/' + name, tf.reduce_min(var))
    tf.histogram_summary(name, var)

def nn_layer(input_tensor, input_dim, output_dim, layer_name):
  """Reusable code for making a simple neural net layer.

  It does a matrix multiply, bias add, and then uses relu to nonlinearize.
  It also sets up name scoping so that the resultant graph is easy to read, and
  adds a number of summary ops.
  """
  # Adding a name scope ensures logical grouping of the layers in the graph.
  with tf.name_scope(layer_name):
    # This Variable will hold the state of the weights for the layer
    with tf.name_scope("weights"):
      weights = weight_variable([input_dim, output_dim])
      variable_summaries(weights, layer_name + '/weights')
    with tf.name_scope("biases"):
      biases = bias_variable([output_dim])
      variable_summaries(biases, layer_name + '/biases')
    with tf.name_scope('Wx_plus_b'):
      activations = tf.matmul(input_tensor, weights) + biases
      tf.histogram_summary(layer_name + '/activations', activations)
    relu = tf.nn.relu(activations, 'relu')
    tf.histogram_summary(layer_name + '/activations_relu', relu)
    return tf.nn.dropout(relu, keep_prob)

layer1 = nn_layer(x, 784, 50, 'layer1')
layer2 = nn_layer(layer1, 50, 10, 'layer2')
y = tf.nn.softmax(layer2, 'predictions')

with tf.name_scope('cross_entropy'):
  diff = y_ * tf.log(y)
  with tf.name_scope('total'):
    cross_entropy = -tf.reduce_sum(diff)
  with tf.name_scope('normalized'):
    normalized_cross_entropy = -tf.reduce_mean(diff)
  tf.scalar_summary('cross entropy', normalized_cross_entropy)

with tf.name_scope('train'):
  train_step = tf.train.AdamOptimizer(
      FLAGS.learning_rate).minimize(cross_entropy)

with tf.name_scope('accuracy'):
  with tf.name_scope('correct_prediction'):
    correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
  with tf.name_scope('accuracy'):
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
  tf.scalar_summary('accuracy', accuracy)

# Merge all the summaries and write them out to /tmp/mnist_logs (by default)
merged = tf.merge_all_summaries()
train_writer = tf.train.SummaryWriter(FLAGS.summaries_dir + '/train', sess.graph)
test_writer = tf.train.SummaryWriter(FLAGS.summaries_dir + '/test')
tf.initialize_all_variables().run()
```
You're now all set to visualize this data using TensorBoard.
@@ -135,7 +150,7 @@
To run TensorBoard, use the command

```bash
tensorboard --logdir=path/to/log-directory
```

where `logdir` points to the directory where the `SummaryWriter` serialized its
@@ -144,18 +159,8 @@
serialized data from separate runs, then TensorBoard will visualize the data
from all of those runs. Once TensorBoard is running, navigate your web browser
to `localhost:6006` to view the TensorBoard.
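For instance, after running the updated mnist_with_summaries.py with its default flags, the log directory might look roughly like this (the event-file names below are illustrative), and a single `--logdir` picks up both runs:

```bash
$ ls -R /tmp/mnist_logs
/tmp/mnist_logs/train:
events.out.tfevents.1460000000.myhost   # "train" run (illustrative file name)

/tmp/mnist_logs/test:
events.out.tfevents.1460000001.myhost   # "test" run (illustrative file name)

$ tensorboard --logdir=/tmp/mnist_logs   # shows both runs: "train" and "test"
```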
When looking at TensorBoard, you will see the navigation tabs in the top right
corner. Each tab represents a set of serialized data that can be visualized.
For in-depth information on how to use the *graph* tab to visualize your graph,
see [TensorBoard: Graph Visualization](../../how_tos/graph_viz/index.md).