MNIST TF 2.0 integration test

PiperOrigin-RevId: 220545522
A. Unique TensorFlower 2018-11-07 15:59:13 -08:00 committed by TensorFlower Gardener
parent 665bd7a2ce
commit 42476f730e
3 changed files with 319 additions and 0 deletions

@@ -0,0 +1,32 @@
licenses(["notice"])  # Apache 2.0

package(
    default_visibility = ["//visibility:private"],
)

test_suite(
    name = "all_tests",
    tags = [
        "manual",
        "no_oss",
        "notap",
    ],
    tests = [
        ":mnist",
    ],
)

py_test(
    name = "mnist",
    srcs = ["mnist.py"],
    tags = [
        "manual",
        "no_oss",
        "notap",
    ],
    deps = [
        "//tensorflow:tensorflow_py",
        "//third_party/py/absl:app",
        "//third_party/py/absl/flags",
    ],
)

@@ -0,0 +1,25 @@
# TF 2.0 Showcase

The code here shows idiomatic ways to write TensorFlow 2.0 code. It doubles as
an integration test.

## General guidelines for showcase code

- Code should minimize dependencies and be self-contained in one file. A user
  should be able to copy-paste the example code into their project and have it
  just work.
- Code should emphasize simplicity over performance, as long as it performs
  within a factor of 2-3x of the optimized implementation.
- Code should work on CPU and single GPU.
- Code should run in Python 3.
- Code should conform to the
  [Google Python Style Guide](https://github.com/google/styleguide/blob/gh-pages/pyguide.md).
- Code should follow these guidelines:
  - Prefer Keras.
  - Split code into separate input pipeline and model code segments.
  - Don't use `tf.cond` or `tf.while_loop`; instead, make use of AutoGraph's
    functionality to compile Python `for`, `while`, and `if` statements (see
    the sketch after this list).
  - Prefer a simple training loop over Estimator.
  - Save and restore a SavedModel.
  - Write basic TensorBoard metrics: loss and accuracy.
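
As an illustration of the AutoGraph guideline, here is a minimal sketch,
assuming the TF 2.0 `tf.function` API: plain Python `for` and `if` statements
are compiled into graph control flow, with no hand-written `tf.while_loop` or
`tf.cond`.

```python
import tensorflow as tf


@tf.function  # AutoGraph compiles the Python control flow below into graph ops.
def clipped_sum(values, limit):
  """Sums `values`, clipping each element at `limit`."""
  total = tf.constant(0.0)
  for v in values:   # Python `for` over a tensor; no explicit tf.while_loop.
    if v > limit:    # Python `if` on a tensor predicate; no explicit tf.cond.
      v = limit
    total += v
  return total


print(clipped_sum(tf.constant([1.0, 5.0, 2.0]), tf.constant(3.0)))  # 6.0
```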

@@ -0,0 +1,262 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""MNIST model training with TensorFlow eager execution.
See:
https://research.googleblog.com/2017/10/eager-execution-imperative-define-by.html
This program demonstrates training, export, and inference of a convolutional
neural network model with eager execution enabled.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import time
from absl import app
from absl import flags
import numpy as np
import tensorflow as tf
tfe = tf.contrib.eager

flags.DEFINE_integer(
    name='log_interval',
    default=10,
    help='batches between logging training status')
flags.DEFINE_float(name='learning_rate', default=0.01, help='Learning rate.')
flags.DEFINE_float(
    name='momentum', short_name='m', default=0.5, help='SGD momentum.')
flags.DEFINE_integer(
    name='batch_size',
    default=100,
    help='Batch size to use during training / eval')
flags.DEFINE_integer(
    name='train_epochs', default=10, help='Number of epochs to train')
flags.DEFINE_string(
    name='model_dir',
    default='/tmp/tensorflow/mnist',
    help='Where to save checkpoints, tensorboard summaries, etc.')
flags.DEFINE_bool(
    name='clean',
    default=False,
    help='Whether to clear model directory before training')

FLAGS = flags.FLAGS


def create_model():
  """Model to recognize digits in the MNIST dataset.

  Network structure is equivalent to:
  https://github.com/tensorflow/tensorflow/blob/r1.5/tensorflow/examples/tutorials/mnist/mnist_deep.py
  and
  https://github.com/tensorflow/models/blob/master/tutorials/image/mnist/convolutional.py

  But uses the tf.keras API.

  Returns:
    A tf.keras.Model.
  """
  # Assumes data_format == 'channels_last'.
  # See https://www.tensorflow.org/performance/performance_guide#data_formats
  input_shape = [28, 28, 1]

  l = tf.keras.layers
  max_pool = l.MaxPooling2D((2, 2), (2, 2), padding='same')
  # The model consists of a sequential chain of layers, so tf.keras.Sequential
  # (a subclass of tf.keras.Model) makes for a compact description.
  model = tf.keras.Sequential(
      [
          l.Reshape(
              target_shape=input_shape,
              input_shape=(28 * 28,)),
          l.Conv2D(2, 5, padding='same', activation=tf.nn.relu),
          max_pool,
          l.Conv2D(4, 5, padding='same', activation=tf.nn.relu),
          max_pool,
          l.Flatten(),
          l.Dense(32, activation=tf.nn.relu),
          l.Dropout(0.4),
          l.Dense(10)
      ])

  # TODO(brianklee): Remove when @kaftan makes this happen by default.
  # TODO(brianklee): remove `autograph=True` when kwarg default is flipped.
  model.call = tfe.function(model.call, autograph=True)
  # Needs to have input_signature specified in order to be exported
  # since model.predict() is never called before saved_model.export()
  # TODO(brianklee): Update with input signature, depending on how the impl of
  # saved_model.restore() pans out.
  model.predict = tfe.function(model.predict, autograph=True)
  # ,input_signature=(tensor_spec.TensorSpec(shape=[28, 28, None], dtype=tf.float32),)  # pylint: disable=line-too-long
  return model
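
# Quick shape check (illustrative sketch, assuming eager execution is enabled;
# not part of the training path): the model maps a batch of flattened images
# to a batch of 10 logits.
#
#   model = create_model()
#   dummy_images = tf.zeros([1, 28 * 28])
#   logits = model(dummy_images, training=False)
#   assert logits.shape == (1, 10)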


def mnist_datasets():
  (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
  # Numpy defaults to dtype=float64; TF defaults to float32. Stick with float32.
  x_train, x_test = x_train / np.float32(255), x_test / np.float32(255)
  y_train, y_test = y_train.astype(np.int64), y_test.astype(np.int64)
  train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
  test_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test))
  return train_dataset, test_dataset
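
# Illustrative sketch: each dataset yields (image, label) pairs, where images
# are float32 arrays of shape (28, 28) and labels are int64 scalars. In eager
# mode the datasets are directly iterable, e.g.:
#
#   train_ds, _ = mnist_datasets()
#   images, labels = next(iter(train_ds.batch(32)))
#   # images.shape == (32, 28, 28); labels.shape == (32,)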


def loss(logits, labels):
  return tf.reduce_mean(
      tf.nn.sparse_softmax_cross_entropy_with_logits(
          logits=logits, labels=labels))


def compute_accuracy(logits, labels):
  predictions = tf.argmax(logits, axis=1, output_type=tf.int64)
  labels = tf.cast(labels, tf.int64)
  return tf.reduce_mean(
      tf.cast(tf.equal(predictions, labels), dtype=tf.float32))
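
# Worked example (illustrative): for logits [[2.0, 0.5]] and label [0], the
# cross-entropy loss is -log(softmax(logits)[0]) = log(1 + exp(-1.5)) ~= 0.20,
# and the argmax prediction matches the label, so accuracy is 1.0.
#
#   logits = tf.constant([[2.0, 0.5]])
#   labels = tf.constant([0], dtype=tf.int64)
#   loss(logits, labels)              # ~0.20
#   compute_accuracy(logits, labels)  # 1.0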


# TODO(brianklee): Enable @tf.function on the training loop when zip, enumerate
# are supported by autograph.
def train(model, optimizer, dataset, step_counter, log_interval=None,
          num_steps=None):
  """Trains model on `dataset` using `optimizer`."""
  start = time.time()
  for (batch, (images, labels)) in enumerate(dataset):
    if num_steps is not None and batch > num_steps:
      break
    with tf.contrib.summary.record_summaries_every_n_global_steps(
        10, global_step=step_counter):
      # Record the operations used to compute the loss given the input,
      # so that the gradient of the loss with respect to the variables
      # can be computed.
      with tf.GradientTape() as tape:
        logits = model(images, training=True)
        loss_value = loss(logits, labels)
        tf.contrib.summary.scalar('loss', loss_value)
        tf.contrib.summary.scalar('accuracy', compute_accuracy(logits, labels))
      grads = tape.gradient(loss_value, model.variables)
      optimizer.apply_gradients(
          zip(grads, model.variables), global_step=step_counter)
      if log_interval and batch % log_interval == 0:
        rate = log_interval / (time.time() - start)
        print('Step #%d\tLoss: %.6f (%d steps/sec)' % (batch, loss_value, rate))
        start = time.time()


def test(model, dataset):
  """Perform an evaluation of `model` on the examples from `dataset`."""
  avg_loss = tfe.metrics.Mean('loss', dtype=tf.float32)
  accuracy = tfe.metrics.Accuracy('accuracy', dtype=tf.float32)

  for (images, labels) in dataset:
    logits = model(images, training=False)
    avg_loss(loss(logits, labels))
    accuracy(
        tf.argmax(logits, axis=1, output_type=tf.int64),
        tf.cast(labels, tf.int64))
  print('Test set: Average loss: %.4f, Accuracy: %4f%%\n' %
        (avg_loss.result(), 100 * accuracy.result()))
  with tf.contrib.summary.always_record_summaries():
    tf.contrib.summary.scalar('loss', avg_loss.result())
    tf.contrib.summary.scalar('accuracy', accuracy.result())


def train_and_export(flags_obj):
  """Run MNIST training and eval loop in eager mode.

  Args:
    flags_obj: An object containing parsed flag values.
  """
  # Load the datasets
  train_ds, test_ds = mnist_datasets()
  train_ds = train_ds.shuffle(60000).batch(flags_obj.batch_size)
  test_ds = test_ds.batch(flags_obj.batch_size)

  # Create the model and optimizer
  model = create_model()
  optimizer = tf.train.MomentumOptimizer(
      flags_obj.learning_rate, flags_obj.momentum)

  # See summaries with `tensorboard --logdir=<model_dir>`
  train_dir = os.path.join(flags_obj.model_dir, 'summaries', 'train')
  test_dir = os.path.join(flags_obj.model_dir, 'summaries', 'eval')
  summary_writer = tf.contrib.summary.create_file_writer(
      train_dir, flush_millis=10000)
  test_summary_writer = tf.contrib.summary.create_file_writer(
      test_dir, flush_millis=10000, name='test')

  # Create and restore checkpoint (if one exists on the path)
  checkpoint_dir = os.path.join(flags_obj.model_dir, 'checkpoints')
  checkpoint_prefix = os.path.join(checkpoint_dir, 'ckpt')
  step_counter = tf.train.get_or_create_global_step()
  checkpoint = tf.train.Checkpoint(
      model=model, optimizer=optimizer, step_counter=step_counter)
  # Restore variables on creation if a checkpoint exists.
  checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))

  # Train and evaluate for a set number of epochs.
  for _ in range(flags_obj.train_epochs):
    start = time.time()
    with summary_writer.as_default():
      train(model, optimizer, train_ds, step_counter,
            flags_obj.log_interval, num_steps=1)
    end = time.time()
    print('\nTrain time for epoch #%d (%d total steps): %f' %
          (checkpoint.save_counter.numpy() + 1,
           step_counter.numpy(),
           end - start))
    with test_summary_writer.as_default():
      test(model, test_ds)
    checkpoint.save(checkpoint_prefix)

  # TODO(brianklee): Enable this functionality after @allenl implements this.
  # export_path = os.path.join(flags_obj.model_dir, 'export')
  # tf.saved_model.save(export_path, model)


def import_and_eval(flags_obj):
  export_path = os.path.join(flags_obj.model_dir, 'export')
  model = tf.saved_model.restore(export_path)
  _, (x_test, y_test) = tf.keras.datasets.mnist.load_data()
  x_test = x_test / np.float32(255)
  y_predict = model(x_test)
  accuracy = compute_accuracy(y_predict, y_test)
  print('Model accuracy: {:0.2f}%'.format(accuracy.numpy() * 100))


def apply_clean(flags_obj):
  if flags_obj.clean and tf.gfile.Exists(flags_obj.model_dir):
    tf.logging.info('--clean flag set. Removing existing model dir: {}'.format(
        flags_obj.model_dir))
    tf.gfile.DeleteRecursively(flags_obj.model_dir)


def main(_):
  apply_clean(flags.FLAGS)
  train_and_export(flags.FLAGS)
  # TODO(brianklee): Enable this functionality after @allenl implements this.
  # import_and_eval(flags.FLAGS)


if __name__ == '__main__':
  app.run(main)
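
# Usage sketch (assuming absl's standard flag handling), e.g. a one-epoch run
# from a clean model directory:
#
#   python mnist.py --train_epochs=1 --model_dir=/tmp/tensorflow/mnist --clean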