MNIST TF 2.0 integration test
PiperOrigin-RevId: 220545522
This commit is contained in:
parent
665bd7a2ce
commit
42476f730e
32
tensorflow/examples/tf2_showcase/BUILD
Normal file
32
tensorflow/examples/tf2_showcase/BUILD
Normal file
@ -0,0 +1,32 @@
|
||||
# BUILD rules for the TF 2.0 showcase examples.
licenses(["notice"])  # Apache 2.0

package(
    default_visibility = ["//visibility:private"],
)

# Aggregates every showcase test. Tagged manual/no_oss/notap so the suite
# only runs when invoked explicitly, not on automated CI.
test_suite(
    name = "all_tests",
    tags = [
        "manual",
        "no_oss",
        "notap",
    ],
    tests = [
        ":mnist",
    ],
)

# MNIST training example; doubles as a TF 2.0 integration test.
py_test(
    name = "mnist",
    srcs = ["mnist.py"],
    tags = [
        "manual",
        "no_oss",
        "notap",
    ],
    deps = [
        "//tensorflow:tensorflow_py",
        "//third_party/py/absl:app",
        "//third_party/py/absl/flags",
    ],
)
|
25
tensorflow/examples/tf2_showcase/README.md
Normal file
25
tensorflow/examples/tf2_showcase/README.md
Normal file
@ -0,0 +1,25 @@
|
||||
# TF 2.0 Showcase

The code here shows idiomatic ways to write TensorFlow 2.0 code. It doubles as
an integration test.

## General guidelines for showcase code:

- Code should minimize dependencies and be self-contained in one file. A user
  should be able to copy-paste the example code into their project and have it
  just work.
- Code should emphasize simplicity over performance, as long as it performs
  within a factor of 2-3x of the optimized implementation.
- Code should work on CPU and single GPU.
- Code should run in Python 3.
- Code should conform to the [Google Python Style Guide](https://github.com/google/styleguide/blob/gh-pages/pyguide.md).

- Code should follow these guidelines:
  - Prefer Keras.
  - Split code into separate input pipeline and model code segments.
  - Don't use tf.cond or tf.while_loop; instead, make use of AutoGraph's
    functionality to compile Python `for`, `while`, and `if` statements.
  - Prefer a simple training loop over Estimator.
  - Save and restore a SavedModel.
  - Write basic TensorBoard metrics - loss, accuracy.
|
262
tensorflow/examples/tf2_showcase/mnist.py
Normal file
262
tensorflow/examples/tf2_showcase/mnist.py
Normal file
@ -0,0 +1,262 @@
|
||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""MNIST model training with TensorFlow eager execution.
|
||||
|
||||
See:
|
||||
https://research.googleblog.com/2017/10/eager-execution-imperative-define-by.html
|
||||
|
||||
This program demonstrates training, export, and inference of a convolutional
|
||||
neural network model with eager execution enabled.
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import time
|
||||
|
||||
from absl import app
|
||||
from absl import flags
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
# Alias for the eager-execution utilities that still live in tf.contrib.
tfe = tf.contrib.eager

# --- Command-line flags ----------------------------------------------------

flags.DEFINE_integer(
    name='log_interval',
    default=10,
    help='batches between logging training status')

flags.DEFINE_float(name='learning_rate', default=0.01, help='Learning rate.')

flags.DEFINE_float(
    name='momentum', short_name='m', default=0.5, help='SGD momentum.')

flags.DEFINE_integer(
    name='batch_size',
    default=100,
    help='Batch size to use during training / eval')

flags.DEFINE_integer(
    name='train_epochs', default=10, help='Number of epochs to train')

flags.DEFINE_string(
    name='model_dir',
    default='/tmp/tensorflow/mnist',
    help='Where to save checkpoints, tensorboard summaries, etc.')

flags.DEFINE_bool(
    name='clean',
    default=False,
    help='Whether to clear model directory before training')

# Singleton holding the parsed flag values; populated by app.run().
FLAGS = flags.FLAGS
|
||||
|
||||
|
||||
def create_model():
  """Build the MNIST digit-recognition convnet as a tf.keras model.

  Network structure is equivalent to:
  https://github.com/tensorflow/tensorflow/blob/r1.5/tensorflow/examples/tutorials/mnist/mnist_deep.py
  and
  https://github.com/tensorflow/models/blob/master/tutorials/image/mnist/convolutional.py
  But uses the tf.keras API.

  Returns:
    A tf.keras.Model.
  """
  # Assumes data_format == 'channel_last'.
  # See https://www.tensorflow.org/performance/performance_guide#data_formats
  layers = tf.keras.layers
  # One pooling layer instance, reused after each conv layer.
  pool = layers.MaxPooling2D((2, 2), (2, 2), padding='same')

  # A plain chain of layers, so tf.keras.Sequential (a subclass of
  # tf.keras.Model) gives the most compact description.
  model = tf.keras.Sequential([
      layers.Reshape(target_shape=[28, 28, 1], input_shape=(28 * 28,)),
      layers.Conv2D(2, 5, padding='same', activation=tf.nn.relu),
      pool,
      layers.Conv2D(4, 5, padding='same', activation=tf.nn.relu),
      pool,
      layers.Flatten(),
      layers.Dense(32, activation=tf.nn.relu),
      layers.Dropout(0.4),
      layers.Dense(10),
  ])

  # TODO(brianklee): Remove when @kaftan makes this happen by default.
  # TODO(brianklee): remove `autograph=True` when kwarg default is flipped.
  model.call = tfe.function(model.call, autograph=True)
  # Needs to have input_signature specified in order to be exported
  # since model.predict() is never called before saved_model.export()
  # TODO(brianklee): Update with input signature, depending on how the impl of
  # saved_model.restore() pans out.
  model.predict = tfe.function(model.predict, autograph=True)
  # ,input_signature=(tensor_spec.TensorSpec(shape=[28, 28, None], dtype=tf.float32),) # pylint: disable=line-too-long
  return model
|
||||
|
||||
|
||||
def mnist_datasets():
  """Load MNIST and return (train, test) tf.data.Datasets of (image, label)."""
  (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
  # Numpy defaults to dtype=float64; TF defaults to float32. Stick with
  # float32 by dividing by a float32 scalar.
  x_train = x_train / np.float32(255)
  x_test = x_test / np.float32(255)
  # int64 labels match tf.argmax's output type used downstream.
  y_train = y_train.astype(np.int64)
  y_test = y_test.astype(np.int64)
  train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
  test_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test))
  return train_dataset, test_dataset
|
||||
|
||||
|
||||
def loss(logits, labels):
  """Mean sparse softmax cross-entropy between `logits` and integer `labels`."""
  per_example = tf.nn.sparse_softmax_cross_entropy_with_logits(
      logits=logits, labels=labels)
  return tf.reduce_mean(per_example)
|
||||
|
||||
|
||||
def compute_accuracy(logits, labels):
  """Fraction of examples where argmax(logits) equals the label."""
  preds = tf.argmax(logits, axis=1, output_type=tf.int64)
  correct = tf.equal(preds, tf.cast(labels, tf.int64))
  return tf.reduce_mean(tf.cast(correct, dtype=tf.float32))
|
||||
|
||||
|
||||
# TODO(brianklee): Enable @tf.function on the training loop when zip, enumerate
|
||||
# are supported by autograph.
|
||||
def train(model, optimizer, dataset, step_counter, log_interval=None,
          num_steps=None):
  """Trains model on `dataset` using `optimizer`.

  Args:
    model: Model to train; invoked as model(images, training=True).
    optimizer: Optimizer whose apply_gradients() updates model.variables and
      advances `step_counter`.
    dataset: Iterable yielding (images, labels) batches.
    step_counter: Global-step variable passed to apply_gradients().
    log_interval: If set, print loss and step rate every `log_interval`
      batches.
    num_steps: If set, train on at most `num_steps` batches.
  """
  start = time.time()
  for (batch, (images, labels)) in enumerate(dataset):
    # Bug fix: the original `batch > num_steps` ran num_steps + 1 batches
    # (batch starts at 0); `>=` stops after exactly num_steps batches.
    if num_steps is not None and batch >= num_steps:
      break
    with tf.contrib.summary.record_summaries_every_n_global_steps(
        10, global_step=step_counter):
      # Record the operations used to compute the loss given the input,
      # so that the gradient of the loss with respect to the variables
      # can be computed.
      with tf.GradientTape() as tape:
        logits = model(images, training=True)
        loss_value = loss(logits, labels)
        tf.contrib.summary.scalar('loss', loss_value)
        tf.contrib.summary.scalar('accuracy', compute_accuracy(logits, labels))
      grads = tape.gradient(loss_value, model.variables)
      optimizer.apply_gradients(
          zip(grads, model.variables), global_step=step_counter)
    if log_interval and batch % log_interval == 0:
      rate = log_interval / (time.time() - start)
      print('Step #%d\tLoss: %.6f (%d steps/sec)' % (batch, loss_value, rate))
      start = time.time()
|
||||
|
||||
|
||||
def test(model, dataset):
  """Perform an evaluation of `model` on the examples from `dataset`."""
  mean_loss = tfe.metrics.Mean('loss', dtype=tf.float32)
  acc = tfe.metrics.Accuracy('accuracy', dtype=tf.float32)

  # Accumulate loss and accuracy over the whole eval set.
  for (images, labels) in dataset:
    logits = model(images, training=False)
    mean_loss(loss(logits, labels))
    acc(tf.argmax(logits, axis=1, output_type=tf.int64),
        tf.cast(labels, tf.int64))
  print('Test set: Average loss: %.4f, Accuracy: %4f%%\n' %
        (mean_loss.result(), 100 * acc.result()))
  # Eval summaries are written unconditionally, regardless of global step.
  with tf.contrib.summary.always_record_summaries():
    tf.contrib.summary.scalar('loss', mean_loss.result())
    tf.contrib.summary.scalar('accuracy', acc.result())
|
||||
|
||||
|
||||
def train_and_export(flags_obj):
  """Run MNIST training and eval loop in eager mode.

  Args:
    flags_obj: An object containing parsed flag values.
  """
  # Input pipeline: shuffle the full 60k training set, then batch both splits.
  train_ds, test_ds = mnist_datasets()
  train_ds = train_ds.shuffle(60000).batch(flags_obj.batch_size)
  test_ds = test_ds.batch(flags_obj.batch_size)

  # Model plus SGD-with-momentum optimizer.
  model = create_model()
  optimizer = tf.train.MomentumOptimizer(
      flags_obj.learning_rate, flags_obj.momentum)

  # Summary writers; view with `tensorboard --logdir=<model_dir>`.
  train_dir = os.path.join(flags_obj.model_dir, 'summaries', 'train')
  test_dir = os.path.join(flags_obj.model_dir, 'summaries', 'eval')
  train_writer = tf.contrib.summary.create_file_writer(
      train_dir, flush_millis=10000)
  eval_writer = tf.contrib.summary.create_file_writer(
      test_dir, flush_millis=10000, name='test')

  # Checkpointing: restore the latest checkpoint if one exists on the path.
  checkpoint_dir = os.path.join(flags_obj.model_dir, 'checkpoints')
  checkpoint_prefix = os.path.join(checkpoint_dir, 'ckpt')
  step_counter = tf.train.get_or_create_global_step()
  checkpoint = tf.train.Checkpoint(
      model=model, optimizer=optimizer, step_counter=step_counter)
  checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))

  # Alternate one training pass and one eval pass per epoch, checkpointing
  # after each epoch.
  for _ in range(flags_obj.train_epochs):
    epoch_start = time.time()
    with train_writer.as_default():
      train(model, optimizer, train_ds, step_counter,
            flags_obj.log_interval, num_steps=1)
    elapsed = time.time() - epoch_start
    print('\nTrain time for epoch #%d (%d total steps): %f' %
          (checkpoint.save_counter.numpy() + 1,
           step_counter.numpy(),
           elapsed))
    with eval_writer.as_default():
      test(model, test_ds)
    checkpoint.save(checkpoint_prefix)

  # TODO(brianklee): Enable this functionality after @allenl implements this.
  # export_path = os.path.join(flags_obj.model_dir, 'export')
  # tf.saved_model.save(export_path, model)
|
||||
|
||||
|
||||
def import_and_eval(flags_obj):
  """Restore the exported SavedModel and report its test-set accuracy."""
  export_path = os.path.join(flags_obj.model_dir, 'export')
  model = tf.saved_model.restore(export_path)
  _, (x_test, y_test) = tf.keras.datasets.mnist.load_data()
  x_test = x_test / np.float32(255)
  # NOTE(review): x_test is fed unflattened here, while create_model's Reshape
  # layer declares input_shape=(28 * 28,) -- confirm the expected input shape
  # once saved_model.restore() lands (this path is currently disabled in main).
  logits = model(x_test)
  acc = compute_accuracy(logits, y_test)
  print('Model accuracy: {:0.2f}%'.format(acc.numpy() * 100))
|
||||
|
||||
|
||||
def apply_clean(flags_obj):
  """Delete the model directory when --clean is set and the directory exists."""
  if not (flags_obj.clean and tf.gfile.Exists(flags_obj.model_dir)):
    return
  tf.logging.info('--clean flag set. Removing existing model dir: {}'.format(
      flags_obj.model_dir))
  tf.gfile.DeleteRecursively(flags_obj.model_dir)
|
||||
|
||||
|
||||
def main(_):
  """Entry point: optionally wipe the model dir, then train (and export).

  Args:
    _: Unused positional args left over after absl flag parsing.
  """
  # Consistency: use the module-level FLAGS constant rather than reaching
  # back through flags.FLAGS (they are the same singleton).
  apply_clean(FLAGS)
  train_and_export(FLAGS)
  # TODO(brianklee): Enable this functionality after @allenl implements this.
  # import_and_eval(FLAGS)
|
||||
|
||||
|
||||
if __name__ == '__main__':
  # absl's app.run parses command-line flags, then invokes main.
  app.run(main)
|
Loading…
Reference in New Issue
Block a user