Redo the TensorBoard tutorial code (mnist_with_summaries.py).
Goals: - Have enough of each summary type that tag grouping is useful. (Wound up recording e.g. mean and stddev and min/max for each variable) - Use every summary type (adds images) - Write to multiple directories so there are several "runs" Change: 119585022
This commit is contained in:
parent
3c59c1ed08
commit
a77499c87d
tensorflow
@ -1,29 +1,25 @@
|
|||||||
# Copyright 2015 Google Inc. All Rights Reserved.
|
# Copyright 2015 Google Inc. All Rights Reserved.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the 'License');
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
# You may obtain a copy of the License at
|
# You may obtain a copy of the License at
|
||||||
#
|
#
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
#
|
#
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
# distributed under the License is distributed on an 'AS IS' BASIS,
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
"""A very simple MNIST classifier, modified to display data in TensorBoard.
|
"""A simple MNIST classifier which displays summaries in TensorBoard.
|
||||||
|
|
||||||
See extensive documentation for the original model at
|
This is an unimpressive MNIST model, but it is a good example of using
|
||||||
http://tensorflow.org/tutorials/mnist/beginners/index.md
|
tf.name_scope to make a graph legible in the TensorBoard graph explorer, and of
|
||||||
|
naming summary tags so that they are grouped meaningfully in TensorBoard.
|
||||||
See documentation on the TensorBoard specific pieces at
|
|
||||||
http://tensorflow.org/how_tos/summaries_and_tensorboard/index.md
|
|
||||||
|
|
||||||
If you modify this file, please update the excerpt in
|
|
||||||
how_tos/summaries_and_tensorboard/index.md.
|
|
||||||
|
|
||||||
|
It demonstrates the functionality of every TensorBoard dashboard.
|
||||||
"""
|
"""
|
||||||
from __future__ import absolute_import
|
from __future__ import absolute_import
|
||||||
from __future__ import division
|
from __future__ import division
|
||||||
@ -39,72 +35,132 @@ FLAGS = flags.FLAGS
|
|||||||
flags.DEFINE_boolean('fake_data', False, 'If true, uses fake data '
|
flags.DEFINE_boolean('fake_data', False, 'If true, uses fake data '
|
||||||
'for unit testing.')
|
'for unit testing.')
|
||||||
flags.DEFINE_integer('max_steps', 1000, 'Number of steps to run trainer.')
|
flags.DEFINE_integer('max_steps', 1000, 'Number of steps to run trainer.')
|
||||||
flags.DEFINE_float('learning_rate', 0.01, 'Initial learning rate.')
|
flags.DEFINE_float('learning_rate', 0.001, 'Initial learning rate.')
|
||||||
|
flags.DEFINE_float('dropout', 0.9, 'Keep probability for training dropout.')
|
||||||
flags.DEFINE_string('data_dir', '/tmp/data', 'Directory for storing data')
|
flags.DEFINE_string('data_dir', '/tmp/data', 'Directory for storing data')
|
||||||
flags.DEFINE_string('summaries_dir', '/tmp/mnist_logs', 'Summaries directory')
|
flags.DEFINE_string('summaries_dir', '/tmp/mnist_logs', 'Summaries directory')
|
||||||
|
|
||||||
|
|
||||||
def main(_):
|
def train():
|
||||||
# Import data
|
# Import data
|
||||||
mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True,
|
mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True,
|
||||||
fake_data=FLAGS.fake_data)
|
fake_data=FLAGS.fake_data)
|
||||||
|
|
||||||
sess = tf.InteractiveSession()
|
sess = tf.InteractiveSession()
|
||||||
|
|
||||||
# Create the model
|
# Create a multilayer model.
|
||||||
x = tf.placeholder(tf.float32, [None, 784], name='x-input')
|
|
||||||
W = tf.Variable(tf.zeros([784, 10]), name='weights')
|
|
||||||
b = tf.Variable(tf.zeros([10]), name='bias')
|
|
||||||
|
|
||||||
# Use a name scope to organize nodes in the graph visualizer
|
# Input placehoolders
|
||||||
with tf.name_scope('Wx_b'):
|
with tf.name_scope('input'):
|
||||||
y = tf.nn.softmax(tf.matmul(x, W) + b)
|
x = tf.placeholder(tf.float32, [None, 784], name='x-input')
|
||||||
|
image_shaped_input = tf.reshape(x, [-1, 28, 28, 1])
|
||||||
|
tf.image_summary('input', image_shaped_input, 10)
|
||||||
|
y_ = tf.placeholder(tf.float32, [None, 10], name='y-input')
|
||||||
|
keep_prob = tf.placeholder(tf.float32)
|
||||||
|
tf.scalar_summary('dropout_keep_probability', keep_prob)
|
||||||
|
|
||||||
# Add summary ops to collect data
|
# We can't initialize these variables to 0 - the network will get stuck.
|
||||||
tf.histogram_summary('weights', W)
|
def weight_variable(shape):
|
||||||
tf.histogram_summary('biases', b)
|
"""Create a weight variable with appropriate initialization."""
|
||||||
tf.histogram_summary('y', y)
|
initial = tf.truncated_normal(shape, stddev=0.1)
|
||||||
|
return tf.Variable(initial)
|
||||||
|
|
||||||
# Define loss and optimizer
|
def bias_variable(shape):
|
||||||
y_ = tf.placeholder(tf.float32, [None, 10], name='y-input')
|
"""Create a bias variable with appropriate initialization."""
|
||||||
# More name scopes will clean up the graph representation
|
initial = tf.constant(0.1, shape=shape)
|
||||||
with tf.name_scope('xent'):
|
return tf.Variable(initial)
|
||||||
cross_entropy = -tf.reduce_sum(y_ * tf.log(y))
|
|
||||||
|
def variable_summaries(var, name):
|
||||||
|
"""Attach a lot of summaries to a Tensor."""
|
||||||
|
with tf.name_scope('summaries'):
|
||||||
|
mean = tf.reduce_mean(var)
|
||||||
|
tf.scalar_summary('mean/' + name, mean)
|
||||||
|
with tf.name_scope('stddev'):
|
||||||
|
stddev = tf.sqrt(tf.reduce_sum(tf.square(var - mean)))
|
||||||
|
tf.scalar_summary('sttdev/' + name, stddev)
|
||||||
|
tf.scalar_summary('max/' + name, tf.reduce_max(var))
|
||||||
|
tf.scalar_summary('min/' + name, tf.reduce_min(var))
|
||||||
|
tf.histogram_summary(name, var)
|
||||||
|
|
||||||
|
def nn_layer(input_tensor, input_dim, output_dim, layer_name, act=tf.nn.relu):
|
||||||
|
"""Reusable code for making a simple neural net layer.
|
||||||
|
|
||||||
|
It does a matrix multiply, bias add, and then uses relu to nonlinearize.
|
||||||
|
It also sets up name scoping so that the resultant graph is easy to read, and
|
||||||
|
adds a number of summary ops.
|
||||||
|
"""
|
||||||
|
# Adding a name scope ensures logical grouping of the layers in the graph.
|
||||||
|
with tf.name_scope(layer_name):
|
||||||
|
# This Variable will hold the state of the weights for the layer
|
||||||
|
with tf.name_scope('weights'):
|
||||||
|
weights = weight_variable([input_dim, output_dim])
|
||||||
|
variable_summaries(weights, layer_name + '/weights')
|
||||||
|
with tf.name_scope('biases'):
|
||||||
|
biases = bias_variable([output_dim])
|
||||||
|
variable_summaries(biases, layer_name + '/biases')
|
||||||
|
with tf.name_scope('Wx_plus_b'):
|
||||||
|
preactivate = tf.matmul(input_tensor, weights) + biases
|
||||||
|
tf.histogram_summary(layer_name + '/pre_activations', preactivate)
|
||||||
|
activations = act(preactivate, 'activation')
|
||||||
|
tf.histogram_summary(layer_name + '/activations', activations)
|
||||||
|
return activations
|
||||||
|
|
||||||
|
hidden1 = nn_layer(x, 784, 500, 'layer1')
|
||||||
|
dropped = tf.nn.dropout(hidden1, keep_prob)
|
||||||
|
y = nn_layer(dropped, 500, 10, 'layer2', act=tf.nn.softmax)
|
||||||
|
|
||||||
|
|
||||||
|
with tf.name_scope('cross_entropy'):
|
||||||
|
diff = y_ * tf.log(y)
|
||||||
|
with tf.name_scope('total'):
|
||||||
|
cross_entropy = -tf.reduce_mean(diff)
|
||||||
tf.scalar_summary('cross entropy', cross_entropy)
|
tf.scalar_summary('cross entropy', cross_entropy)
|
||||||
|
|
||||||
with tf.name_scope('train'):
|
with tf.name_scope('train'):
|
||||||
train_step = tf.train.GradientDescentOptimizer(
|
train_step = tf.train.AdamOptimizer(
|
||||||
FLAGS.learning_rate).minimize(cross_entropy)
|
FLAGS.learning_rate).minimize(cross_entropy)
|
||||||
|
|
||||||
with tf.name_scope('test'):
|
with tf.name_scope('accuracy'):
|
||||||
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
|
with tf.name_scope('correct_prediction'):
|
||||||
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
|
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
|
||||||
|
with tf.name_scope('accuracy'):
|
||||||
|
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
|
||||||
tf.scalar_summary('accuracy', accuracy)
|
tf.scalar_summary('accuracy', accuracy)
|
||||||
|
|
||||||
# Merge all the summaries and write them out to /tmp/mnist_logs (by default)
|
# Merge all the summaries and write them out to /tmp/mnist_logs (by default)
|
||||||
merged = tf.merge_all_summaries()
|
merged = tf.merge_all_summaries()
|
||||||
writer = tf.train.SummaryWriter(FLAGS.summaries_dir, sess.graph)
|
train_writer = tf.train.SummaryWriter(FLAGS.summaries_dir + '/train', sess.graph)
|
||||||
|
test_writer = tf.train.SummaryWriter(FLAGS.summaries_dir + '/test')
|
||||||
tf.initialize_all_variables().run()
|
tf.initialize_all_variables().run()
|
||||||
|
|
||||||
# Train the model, and feed in test data and record summaries every 10 steps
|
# Train the model, and also write summaries.
|
||||||
|
# Every 10th step, measure test-set accuracy, and write test summaries
|
||||||
|
# All other steps, run train_step on training data, & add training summaries
|
||||||
|
|
||||||
|
def feed_dict(train):
|
||||||
|
"""Make a TensorFlow feed_dict: maps data onto Tensor placeholders."""
|
||||||
|
if train or FLAGS.fake_data:
|
||||||
|
xs, ys = mnist.train.next_batch(100, fake_data=FLAGS.fake_data)
|
||||||
|
k = FLAGS.dropout
|
||||||
|
else:
|
||||||
|
xs, ys = mnist.test.images, mnist.test.labels
|
||||||
|
k = 1.0
|
||||||
|
return {x: xs, y_: ys, keep_prob: k}
|
||||||
|
|
||||||
for i in range(FLAGS.max_steps):
|
for i in range(FLAGS.max_steps):
|
||||||
if i % 10 == 0: # Record summary data and the accuracy
|
if i % 10 == 0: # Record summaries and test-set accuracy
|
||||||
if FLAGS.fake_data:
|
summary, acc = sess.run([merged, accuracy], feed_dict=feed_dict(False))
|
||||||
batch_xs, batch_ys = mnist.train.next_batch(
|
test_writer.add_summary(summary, i)
|
||||||
100, fake_data=FLAGS.fake_data)
|
|
||||||
feed = {x: batch_xs, y_: batch_ys}
|
|
||||||
else:
|
|
||||||
feed = {x: mnist.test.images, y_: mnist.test.labels}
|
|
||||||
result = sess.run([merged, accuracy], feed_dict=feed)
|
|
||||||
summary_str = result[0]
|
|
||||||
acc = result[1]
|
|
||||||
writer.add_summary(summary_str, i)
|
|
||||||
print('Accuracy at step %s: %s' % (i, acc))
|
print('Accuracy at step %s: %s' % (i, acc))
|
||||||
else:
|
else: # Record train set summarieis, and train
|
||||||
batch_xs, batch_ys = mnist.train.next_batch(
|
summary, _ = sess.run([merged, train_step], feed_dict=feed_dict(True))
|
||||||
100, fake_data=FLAGS.fake_data)
|
train_writer.add_summary(summary, i)
|
||||||
feed = {x: batch_xs, y_: batch_ys}
|
|
||||||
sess.run(train_step, feed_dict=feed)
|
def main(_):
|
||||||
|
if tf.gfile.Exists(FLAGS.summaries_dir):
|
||||||
|
tf.gfile.DeleteRecursively(FLAGS.summaries_dir)
|
||||||
|
tf.gfile.MakeDirs(FLAGS.summaries_dir)
|
||||||
|
train()
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
tf.app.run()
|
tf.app.run()
|
||||||
|
@ -8,7 +8,8 @@ your TensorFlow graph, plot quantitative metrics about the execution of your
|
|||||||
graph, and show additional data like images that pass through it. When
|
graph, and show additional data like images that pass through it. When
|
||||||
TensorBoard is fully configured, it looks like this:
|
TensorBoard is fully configured, it looks like this:
|
||||||
|
|
||||||

|
[](http://tensorflow.org/tensorboard)
|
||||||
|
[*Click try a TensorBoard with data from this tutorial!*](http://tensorflow.org/tensorboard)
|
||||||
|
|
||||||
|
|
||||||
## Serializing the data
|
## Serializing the data
|
||||||
@ -75,56 +76,70 @@ statistics, such as how the weights or accuracy varied during training.
|
|||||||
The code below is an excerpt; full source is [here](https://www.tensorflow.org/code/tensorflow/examples/tutorials/mnist/mnist_with_summaries.py).
|
The code below is an excerpt; full source is [here](https://www.tensorflow.org/code/tensorflow/examples/tutorials/mnist/mnist_with_summaries.py).
|
||||||
|
|
||||||
```python
|
```python
|
||||||
# Create the model
|
def variable_summaries(var, name):
|
||||||
x = tf.placeholder(tf.float32, [None, 784], name="x-input")
|
with tf.name_scope("summaries"):
|
||||||
W = tf.Variable(tf.zeros([784,10]), name="weights")
|
mean = tf.reduce_mean(var)
|
||||||
b = tf.Variable(tf.zeros([10], name="bias"))
|
tf.scalar_summary('mean/' + name, mean)
|
||||||
|
with tf.name_scope('stddev'):
|
||||||
|
stddev = tf.sqrt(tf.reduce_sum(tf.square(var - mean)))
|
||||||
|
tf.scalar_summary('sttdev/' + name, stddev)
|
||||||
|
tf.scalar_summary('max/' + name, tf.reduce_max(var))
|
||||||
|
tf.scalar_summary('min/' + name, tf.reduce_min(var))
|
||||||
|
tf.histogram_summary(name, var)
|
||||||
|
|
||||||
# use a name scope to organize nodes in the graph visualizer
|
def nn_layer(input_tensor, input_dim, output_dim, layer_name):
|
||||||
with tf.name_scope("Wx_b") as scope:
|
"""Reusable code for making a simple neural net layer.
|
||||||
y = tf.nn.softmax(tf.matmul(x,W) + b)
|
|
||||||
|
|
||||||
# Add summary ops to collect data
|
It does a matrix multiply, bias add, and then uses relu to nonlinearize.
|
||||||
tf.histogram_summary("weights", W)
|
It also sets up name scoping so that the resultant graph is easy to read, and
|
||||||
tf.histogram_summary("biases", b)
|
adds a number of summary ops.
|
||||||
tf.histogram_summary("y", y)
|
"""
|
||||||
|
# Adding a name scope ensures logical grouping of the layers in the graph.
|
||||||
|
with tf.name_scope(layer_name):
|
||||||
|
# This Variable will hold the state of the weights for the layer
|
||||||
|
with tf.name_scope("weights"):
|
||||||
|
weights = weight_variable([input_dim, output_dim])
|
||||||
|
variable_summaries(weights, layer_name + '/weights')
|
||||||
|
with tf.name_scope("biases"):
|
||||||
|
biases = bias_variable([output_dim])
|
||||||
|
variable_summaries(biases, layer_name + '/biases')
|
||||||
|
with tf.name_scope('Wx_plus_b'):
|
||||||
|
activations = tf.matmul(input_tensor, weights) + biases
|
||||||
|
tf.histogram_summary(layer_name + '/activations', activations)
|
||||||
|
relu = tf.nn.relu(activations, 'relu')
|
||||||
|
tf.histogram_summary(layer_name + '/activations_relu', relu)
|
||||||
|
return tf.nn.dropout(relu, keep_prob)
|
||||||
|
|
||||||
# Define loss and optimizer
|
layer1 = nn_layer(x, 784, 50, 'layer1')
|
||||||
y_ = tf.placeholder(tf.float32, [None,10], name="y-input")
|
layer2 = nn_layer(layer1, 50, 10, 'layer2')
|
||||||
# More name scopes will clean up the graph representation
|
y = tf.nn.softmax(layer2, 'predictions')
|
||||||
with tf.name_scope("xent") as scope:
|
|
||||||
cross_entropy = -tf.reduce_sum(y_*tf.log(y))
|
|
||||||
tf.scalar_summary("cross entropy", cross_entropy)
|
|
||||||
with tf.name_scope("train") as scope:
|
|
||||||
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)
|
|
||||||
|
|
||||||
with tf.name_scope("test") as scope:
|
|
||||||
correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))
|
|
||||||
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
|
|
||||||
tf.scalar_summary("accuracy", accuracy)
|
|
||||||
|
|
||||||
# Merge all the summaries and write them out to /tmp/mnist_logs
|
with tf.name_scope('cross_entropy'):
|
||||||
|
diff = y_ * tf.log(y)
|
||||||
|
with tf.name_scope('total'):
|
||||||
|
cross_entropy = -tf.reduce_sum(diff)
|
||||||
|
with tf.name_scope('normalized'):
|
||||||
|
normalized_cross_entropy = -tf.reduce_mean(diff)
|
||||||
|
tf.scalar_summary('cross entropy', normalized_cross_entropy)
|
||||||
|
|
||||||
|
with tf.name_scope('train'):
|
||||||
|
train_step = tf.train.AdamOptimizer(
|
||||||
|
FLAGS.learning_rate).minimize(cross_entropy)
|
||||||
|
|
||||||
|
with tf.name_scope('accuracy'):
|
||||||
|
with tf.name_scope('correct_prediction'):
|
||||||
|
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
|
||||||
|
with tf.name_scope('accuracy'):
|
||||||
|
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
|
||||||
|
tf.scalar_summary('accuracy', accuracy)
|
||||||
|
|
||||||
|
# Merge all the summaries and write them out to /tmp/mnist_logs (by default)
|
||||||
merged = tf.merge_all_summaries()
|
merged = tf.merge_all_summaries()
|
||||||
writer = tf.train.SummaryWriter("/tmp/mnist_logs", sess.graph)
|
train_writer = tf.train.SummaryWriter(FLAGS.summaries_dir + '/train', sess.graph)
|
||||||
|
test_writer = tf.train.SummaryWriter(FLAGS.summaries_dir + '/test')
|
||||||
tf.initialize_all_variables().run()
|
tf.initialize_all_variables().run()
|
||||||
|
|
||||||
# Train the model, and feed in test data and record summaries every 10 steps
|
|
||||||
|
|
||||||
for i in range(1000):
|
|
||||||
if i % 10 == 0: # Record summary data, and the accuracy
|
|
||||||
feed = {x: mnist.test.images, y_: mnist.test.labels}
|
|
||||||
result = sess.run([merged, accuracy], feed_dict=feed)
|
|
||||||
summary_str = result[0]
|
|
||||||
acc = result[1]
|
|
||||||
writer.add_summary(summary_str, i)
|
|
||||||
print("Accuracy at step %s: %s" % (i, acc))
|
|
||||||
else:
|
|
||||||
batch_xs, batch_ys = mnist.train.next_batch(100)
|
|
||||||
feed = {x: batch_xs, y_: batch_ys}
|
|
||||||
sess.run(train_step, feed_dict=feed)
|
|
||||||
|
|
||||||
print(accuracy.eval({x: mnist.test.images, y_: mnist.test.labels}))
|
|
||||||
|
|
||||||
```
|
```
|
||||||
|
|
||||||
You're now all set to visualize this data using TensorBoard.
|
You're now all set to visualize this data using TensorBoard.
|
||||||
@ -135,7 +150,7 @@ You're now all set to visualize this data using TensorBoard.
|
|||||||
To run TensorBoard, use the command
|
To run TensorBoard, use the command
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
python tensorflow/tensorboard/tensorboard.py --logdir=path/to/log-directory
|
tensorboard --logdir=path/to/log-directory
|
||||||
```
|
```
|
||||||
|
|
||||||
where `logdir` points to the directory where the `SummaryWriter` serialized its
|
where `logdir` points to the directory where the `SummaryWriter` serialized its
|
||||||
@ -144,18 +159,8 @@ serialized data from separate runs, then TensorBoard will visualize the data
|
|||||||
from all of those runs. Once TensorBoard is running, navigate your web browser
|
from all of those runs. Once TensorBoard is running, navigate your web browser
|
||||||
to `localhost:6006` to view the TensorBoard.
|
to `localhost:6006` to view the TensorBoard.
|
||||||
|
|
||||||
If you have pip installed TensorFlow, `tensorboard` is installed into
|
|
||||||
the system path, so you can use the simpler command
|
|
||||||
|
|
||||||
```bash
|
|
||||||
tensorboard --logdir=/path/to/log-directory
|
|
||||||
```
|
|
||||||
|
|
||||||
When looking at TensorBoard, you will see the navigation tabs in the top right
|
When looking at TensorBoard, you will see the navigation tabs in the top right
|
||||||
corner. Each tab represents a set of serialized data that can be visualized.
|
corner. Each tab represents a set of serialized data that can be visualized.
|
||||||
For any tab you are looking at, if the logs being looked at by TensorBoard do
|
|
||||||
not contain any data relevant to that tab, a message will be displayed
|
|
||||||
indicating how to serialize data that is applicable to that tab.
|
|
||||||
|
|
||||||
For in depth information on how to use the *graph* tab to visualize your graph,
|
For in depth information on how to use the *graph* tab to visualize your graph,
|
||||||
see [TensorBoard: Graph Visualization](../../how_tos/graph_viz/index.md).
|
see [TensorBoard: Graph Visualization](../../how_tos/graph_viz/index.md).
|
||||||
|
Loading…
Reference in New Issue
Block a user