Make convolutional.py easier to run on GPUs with less RAM.

A few cosmetic fixes. Test error was already 0.7%, not 0.8%.
https://github.com/tensorflow/tensorflow/issues/609
Change: 111624237
Vincent Vanhoucke, 2016-01-07 13:26:26 -08:00 (committed by Vijay Vasudevan)
parent cb91829d58
commit ca47376b3c
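
The core change below replaces the whole-dataset tf.constant eval nodes, which bake the validation and test sets into the graph (and onto the GPU), with a single fixed-size tf.placeholder that is fed EVAL_BATCH_SIZE examples at a time. Here is a minimal standalone sketch of that pattern, assuming the TF 0.x API this file targets; the linear softmax model, NUM_FEATURES, and the fake evaluation set are illustrative stand-ins, not part of this commit:

import numpy
import tensorflow as tf

EVAL_BATCH_SIZE = 64
NUM_FEATURES = 784  # hypothetical flattened 28x28 input
NUM_LABELS = 10

# Fixed-size input node: only one small batch is ever resident on the device.
eval_data = tf.placeholder(tf.float32, shape=(EVAL_BATCH_SIZE, NUM_FEATURES))
weights = tf.Variable(tf.truncated_normal([NUM_FEATURES, NUM_LABELS], seed=0))
eval_prediction = tf.nn.softmax(tf.matmul(eval_data, weights))

def eval_in_batches(data, sess):
  """Collect predictions for a whole dataset, one small batch at a time."""
  size = data.shape[0]
  if size < EVAL_BATCH_SIZE:
    raise ValueError('batch size for evals larger than dataset: %d' % size)
  predictions = numpy.ndarray(shape=(size, NUM_LABELS), dtype=numpy.float32)
  for begin in range(0, size, EVAL_BATCH_SIZE):
    end = begin + EVAL_BATCH_SIZE
    if end <= size:
      predictions[begin:end, :] = sess.run(
          eval_prediction, feed_dict={eval_data: data[begin:end, :]})
    else:
      # The placeholder shape is fixed, so re-feed the last full batch and
      # keep only the rows not yet covered (see the note after the diff).
      batch_predictions = sess.run(
          eval_prediction, feed_dict={eval_data: data[-EVAL_BATCH_SIZE:, :]})
      predictions[begin:, :] = batch_predictions[begin - size:, :]
  return predictions

with tf.Session() as sess:
  tf.initialize_all_variables().run()
  fake_eval_set = numpy.random.rand(150, NUM_FEATURES).astype(numpy.float32)
  print(eval_in_batches(fake_eval_set, sess).shape)  # (150, 10)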


@@ -15,7 +15,7 @@
 """Simple, end-to-end, LeNet-5-like convolutional MNIST model example.
 
-This should achieve a test error of 0.8%. Please keep this model as simple and
+This should achieve a test error of 0.7%. Please keep this model as simple and
 linear as possible, it is meant as a tutorial for simple convolutional models.
 
 Run with --self_test on the command line to execute a short self-test.
 """
@@ -45,6 +45,8 @@ VALIDATION_SIZE = 5000  # Size of the validation set.
 SEED = 66478  # Set to None for random seed.
 BATCH_SIZE = 64
 NUM_EPOCHS = 10
+EVAL_BATCH_SIZE = 64
+EVAL_FREQUENCY = 100  # Number of steps between evaluations.
 
 
 tf.app.flags.DEFINE_boolean("self_test", False, "True if running a self test.")
@@ -114,8 +116,8 @@ def main(argv=None):  # pylint: disable=unused-argument
   if FLAGS.self_test:
     print('Running self-test.')
     train_data, train_labels = fake_data(256)
-    validation_data, validation_labels = fake_data(16)
-    test_data, test_labels = fake_data(256)
+    validation_data, validation_labels = fake_data(EVAL_BATCH_SIZE)
+    test_data, test_labels = fake_data(EVAL_BATCH_SIZE)
     num_epochs = 1
   else:
     # Get the data.
@@ -131,9 +133,9 @@ def main(argv=None):  # pylint: disable=unused-argument
     test_labels = extract_labels(test_labels_filename, 10000)
 
     # Generate a validation set.
-    validation_data = train_data[:VALIDATION_SIZE, :, :, :]
+    validation_data = train_data[:VALIDATION_SIZE, ...]
     validation_labels = train_labels[:VALIDATION_SIZE]
-    train_data = train_data[VALIDATION_SIZE:, :, :, :]
+    train_data = train_data[VALIDATION_SIZE:, ...]
     train_labels = train_labels[VALIDATION_SIZE:]
     num_epochs = NUM_EPOCHS
   train_size = train_labels.shape[0]
@@ -146,10 +148,9 @@ def main(argv=None):  # pylint: disable=unused-argument
       shape=(BATCH_SIZE, IMAGE_SIZE, IMAGE_SIZE, NUM_CHANNELS))
   train_labels_node = tf.placeholder(tf.float32,
                                      shape=(BATCH_SIZE, NUM_LABELS))
-  # For the validation and test data, we'll just hold the entire dataset in
-  # one constant node.
-  validation_data_node = tf.constant(validation_data)
-  test_data_node = tf.constant(test_data)
+  eval_data = tf.placeholder(
+      tf.float32,
+      shape=(EVAL_BATCH_SIZE, IMAGE_SIZE, IMAGE_SIZE, NUM_CHANNELS))
 
   # The variables below hold all the trainable weights. They are passed an
   # initial value which will be assigned when we call:
@@ -245,46 +246,68 @@ def main(argv=None):  # pylint: disable=unused-argument
                                          0.9).minimize(loss,
                                                        global_step=batch)
 
-  # Predictions for the minibatch, validation set and test set.
+  # Predictions for the current training minibatch.
   train_prediction = tf.nn.softmax(logits)
-  # We'll compute them only once in a while by calling their {eval()} method.
-  validation_prediction = tf.nn.softmax(model(validation_data_node))
-  test_prediction = tf.nn.softmax(model(test_data_node))
-  # Create a local session to run this computation.
+
+  # Predictions for the test and validation, which we'll compute less often.
+  eval_prediction = tf.nn.softmax(model(eval_data))
+
+  # Small utility function to evaluate a dataset by feeding batches of data to
+  # {eval_data} and pulling the results from {eval_prediction}.
+  # Saves memory and enables this to run on smaller GPUs.
+  def eval_in_batches(data, sess):
+    """Get all predictions for a dataset by running it in small batches."""
+    size = data.shape[0]
+    if size < EVAL_BATCH_SIZE:
+      raise ValueError("batch size for evals larger than dataset: %d" % size)
+    predictions = numpy.ndarray(shape=(size, NUM_LABELS), dtype=numpy.float32)
+    for begin in xrange(0, size, EVAL_BATCH_SIZE):
+      end = begin + EVAL_BATCH_SIZE
+      if end <= size:
+        predictions[begin:end, :] = sess.run(
+            eval_prediction,
+            feed_dict={eval_data: data[begin:end, ...]})
+      else:
+        batch_predictions = sess.run(
+            eval_prediction,
+            feed_dict={eval_data: data[-EVAL_BATCH_SIZE:, ...]})
+        predictions[begin:, :] = batch_predictions[begin - size:, :]
+    return predictions
+
+  # Create a local session to run the training.
   start_time = time.time()
-  log_each_x_steps = 100
-  with tf.Session() as s:
+  with tf.Session() as sess:
     # Run all the initializers to prepare the trainable parameters.
     tf.initialize_all_variables().run()
     print('Initialized!')
     # Loop through training steps.
-    for step in xrange(num_epochs * train_size // BATCH_SIZE):
+    for step in xrange(int(num_epochs * train_size) // BATCH_SIZE):
       # Compute the offset of the current minibatch in the data.
       # Note that we could use better randomization across epochs.
       offset = (step * BATCH_SIZE) % (train_size - BATCH_SIZE)
-      batch_data = train_data[offset:(offset + BATCH_SIZE), :, :, :]
+      batch_data = train_data[offset:(offset + BATCH_SIZE), ...]
       batch_labels = train_labels[offset:(offset + BATCH_SIZE)]
       # This dictionary maps the batch data (as a numpy array) to the
       # node in the graph it should be fed to.
       feed_dict = {train_data_node: batch_data,
                    train_labels_node: batch_labels}
       # Run the graph and fetch some of the nodes.
-      _, l, lr, predictions = s.run(
+      _, l, lr, predictions = sess.run(
           [optimizer, loss, learning_rate, train_prediction],
           feed_dict=feed_dict)
-      if step % log_each_x_steps == 0:
+      if step % EVAL_FREQUENCY == 0:
         elapsed_time = time.time() - start_time
         start_time = time.time()
-        print('Step %d, %.1f ms'%(step, 1000 * elapsed_time / log_each_x_steps))
-        print('Epoch %.2f' % (float(step) * BATCH_SIZE / train_size))
+        print('Step %d (epoch %.2f), %.1f ms' %
+              (step, float(step) * BATCH_SIZE / train_size,
+               1000 * elapsed_time / EVAL_FREQUENCY))
         print('Minibatch loss: %.3f, learning rate: %.6f' % (l, lr))
         print('Minibatch error: %.1f%%' % error_rate(predictions, batch_labels))
-        print('Validation error: %.1f%%' %
-              error_rate(validation_prediction.eval(), validation_labels))
+        print('Validation error: %.1f%%' % error_rate(
+            eval_in_batches(validation_data, sess), validation_labels))
         sys.stdout.flush()
     # Finally print the result!
-    test_error = error_rate(test_prediction.eval(), test_labels)
+    test_error = error_rate(eval_in_batches(test_data, sess), test_labels)
     print('Test error: %.1f%%' % test_error)
     if FLAGS.self_test:
       print('test_error', test_error)
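
The only subtle part of eval_in_batches is the final partial batch: because the placeholder's first dimension is fixed at EVAL_BATCH_SIZE, the function re-feeds the last EVAL_BATCH_SIZE rows of the dataset and copies out only the rows the earlier full batches did not cover. A quick arithmetic walk-through of that indexing, in pure Python, with size=150 chosen purely for illustration:

size, EVAL_BATCH_SIZE = 150, 64
for begin in range(0, size, EVAL_BATCH_SIZE):
  end = begin + EVAL_BATCH_SIZE
  if end <= size:
    print('full batch: predicts rows %d..%d' % (begin, end - 1))
  else:
    # Feed rows 86..149 (the last 64); rows 86..127 were already predicted
    # by the second full batch, so keep batch_predictions[begin - size:],
    # i.e. the last size - begin = 22 rows, which are rows 128..149.
    print('tail batch: feeds rows %d..%d, keeps the last %d'
          % (size - EVAL_BATCH_SIZE, size - 1, size - begin))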