Moved tensorflow/models to models/tutorials and replaced all tutorial references to tensorflow/models
Change: 141503531
parent: 2d00e6f17d
commit: e3f8d2a085
tensorflow/
  BUILD
  contrib/cmake
  core
  models/
    __init__.py
    embedding
    image/
      __init__.py
      alexnet
      cifar10/
        BUILD
        README.md
        __init__.py
        cifar10.py
        cifar10_eval.py
        cifar10_input.py
        cifar10_input_test.py
        cifar10_multi_gpu_train.py
        cifar10_train.py
      imagenet
      mnist
    rnn
  python/ops
  tools
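For code that consumed these tutorials from the main package, the visible effect of this change is a new home for the modules rather than a new API. The sketch below is illustrative only: the old import path is taken from the files deleted in this diff, while the post-move location is an assumption based on the commit title and the `tensorflow_models/tutorials/...` paths referenced in the cmake hunks further down.

```python
# Sketch of the import-path change for user code; paths outside this diff are
# assumptions, not taken from the commit itself.

# Before this commit the tutorial modules were importable from the main package:
#   from tensorflow.models.embedding import word2vec_optimized
# That package is deleted here, so the import now fails:
try:
    from tensorflow.models.embedding import word2vec_optimized  # noqa: F401
except ImportError:
    # After the move, the tutorial is expected to live in the separate
    # tensorflow/models repository (e.g. cloned from
    # https://github.com/tensorflow/models) and to be run as a script:
    #   python tutorials/embedding/word2vec_optimized.py \
    #       --train_data=text8 --eval_data=questions-words.txt --save_path=/tmp/
    pass
```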
@@ -164,14 +164,6 @@ filegroup(
         "//tensorflow/java:all_files",
         "//tensorflow/java/src/main/java/org/tensorflow/examples:all_files",
         "//tensorflow/java/src/main/native:all_files",
-        "//tensorflow/models/embedding:all_files",
-        "//tensorflow/models/image/alexnet:all_files",
-        "//tensorflow/models/image/cifar10:all_files",
-        "//tensorflow/models/image/imagenet:all_files",
-        "//tensorflow/models/image/mnist:all_files",
-        "//tensorflow/models/rnn:all_files",
-        "//tensorflow/models/rnn/ptb:all_files",
-        "//tensorflow/models/rnn/translate:all_files",
         "//tensorflow/python:all_files",
         "//tensorflow/python/debug:all_files",
         "//tensorflow/python/kernel_tests:all_files",
@@ -2,7 +2,7 @@
 # tf_models_word2vec_ops library
 ########################################################
 file(GLOB tf_models_word2vec_ops_srcs
-    "${tensorflow_source_dir}/tensorflow/models/embedding/word2vec_ops.cc"
+    "${tensorflow_source_dir}/tensorflow_models/tutorials/embedding/word2vec_ops.cc"
 )

 add_library(tf_models_word2vec_ops OBJECT ${tf_models_word2vec_ops_srcs})
@@ -13,7 +13,7 @@ add_dependencies(tf_models_word2vec_ops tf_core_framework)
 # tf_models_word2vec_kernels library
 ########################################################
 file(GLOB tf_models_word2vec_kernels_srcs
-    "${tensorflow_source_dir}/tensorflow/models/embedding/word2vec_kernels.cc"
+    "${tensorflow_source_dir}/tensorflow_models/tutorials/embedding/word2vec_kernels.cc"
 )

 add_library(tf_models_word2vec_kernels OBJECT ${tf_models_word2vec_kernels_srcs})
@@ -440,6 +440,15 @@ cc_library(
     alwayslink = 1,
 )

+cc_library(
+    name = "word2vec_ops",
+    srcs = ["ops/word2vec_ops.cc"],
+    linkstatic = 1,
+    visibility = ["//tensorflow:internal"],
+    deps = ["//tensorflow/core:framework"],
+    alwayslink = 1,
+)
+
 cc_library(
     name = "ops",
     visibility = ["//visibility:public"],
@@ -469,7 +478,7 @@ cc_library(
         ":string_ops_op_lib",
         ":training_ops_op_lib",
         ":user_ops_op_lib",
-        "//tensorflow/models/embedding:word2vec_ops",
+        ":word2vec_ops",
     ],
     alwayslink = 1,
 )
@@ -591,7 +600,7 @@ cc_library(
         "//tensorflow/core/kernels:state",
         "//tensorflow/core/kernels:string",
         "//tensorflow/core/kernels:training_ops",
-        "//tensorflow/models/embedding:word2vec_kernels",
+        "//tensorflow/core/kernels:word2vec_kernels",
     ] + if_not_windows([
         "//tensorflow/core/kernels:fact_op",
         "//tensorflow/core/kernels:array_not_windows",
@@ -3240,6 +3240,17 @@ tf_cuda_cc_test(
     ],
 )

+tf_kernel_library(
+    name = "word2vec_kernels",
+    prefix = "word2vec_kernels",
+    deps = [
+        "//tensorflow/core",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:lib_internal",
+    ],
+)
+
 # Android libraries -----------------------------------------------------------

 # Changes to the Android srcs here should be replicated in
@@ -18,6 +18,9 @@ limitations under the License.
 namespace tensorflow {

 REGISTER_OP("Skipgram")
+    .Deprecated(19,
+                "Moving word2vec into tensorflow_models/tutorials and "
+                "deprecating its ops here as a result")
     .Output("vocab_word: string")
     .Output("vocab_freq: int32")
     .Output("words_per_epoch: int64")
@@ -51,6 +54,9 @@ subsample: Threshold for word occurrence. Words that appear with higher
 )doc");

 REGISTER_OP("NegTrain")
+    .Deprecated(19,
+                "Moving word2vec into tensorflow_models/tutorials and "
+                "deprecating its ops here as a result")
     .Input("w_in: Ref(float)")
     .Input("w_out: Ref(float)")
     .Input("examples: int32")
@@ -76,7 +76,7 @@ limitations under the License.

 #define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0
 #define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0
-#define TF_GRAPH_DEF_VERSION 18
+#define TF_GRAPH_DEF_VERSION 19

 // Checkpoint compatibility versions (the versions field in SavedSliceMeta).
 //
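The op deprecation and the GraphDef version bump above are two halves of one change: `Skipgram` and `NegTrain` are marked `.Deprecated(19, ...)` while `TF_GRAPH_DEF_VERSION` moves from 18 to 19, so graphs produced by this build can no longer be constructed with the in-tree ops, while previously serialized graphs keep loading. The Python sketch below restates that version gate for illustration; the function name and error message are assumptions for exposition, not TensorFlow's actual implementation.

```python
# Illustrative sketch of how an op deprecation version interacts with the
# GraphDef producer version bumped in this commit. Names and logic are
# assumptions for exposition, not TensorFlow source.

SKIPGRAM_DEPRECATED_AT = 19  # from .Deprecated(19, ...) in word2vec_ops.cc
TF_GRAPH_DEF_VERSION = 19    # bumped from 18 in version.h


def check_op_deprecation(producer_version, deprecated_at, explanation):
    """Reject graphs produced at or after the op's deprecation version."""
    if producer_version >= deprecated_at:
        raise ValueError("Op is deprecated as of GraphDef version %d: %s"
                         % (deprecated_at, explanation))


# A graph produced at version 18 or earlier can still carry the op ...
check_op_deprecation(18, SKIPGRAM_DEPRECATED_AT, "Moving word2vec out of tree")
# ... but one produced by this build (version 19) cannot.
try:
    check_op_deprecation(TF_GRAPH_DEF_VERSION, SKIPGRAM_DEPRECATED_AT,
                         "Moving word2vec out of tree")
except ValueError as e:
    print(e)
```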
@@ -1,123 +0,0 @@
# Description:
# TensorFlow model for word2vec

package(default_visibility = ["//tensorflow:internal"])

licenses(["notice"])  # Apache 2.0

exports_files(["LICENSE"])

load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py")

py_library(
    name = "package",
    srcs = [
        "__init__.py",
    ],
    srcs_version = "PY2AND3",
    visibility = ["//tensorflow:__subpackages__"],
    deps = [
        ":gen_word2vec",
        ":word2vec",
        ":word2vec_optimized",
    ],
)

py_binary(
    name = "word2vec",
    srcs = [
        "word2vec.py",
    ],
    srcs_version = "PY2AND3",
    deps = [
        ":gen_word2vec",
        "//tensorflow:tensorflow_py",
        "//tensorflow/python:platform",
    ],
)

py_binary(
    name = "word2vec_optimized",
    srcs = [
        "word2vec_optimized.py",
    ],
    srcs_version = "PY2AND3",
    deps = [
        ":gen_word2vec",
        "//tensorflow:tensorflow_py",
        "//tensorflow/python:platform",
    ],
)

py_test(
    name = "word2vec_test",
    size = "small",
    srcs = ["word2vec_test.py"],
    srcs_version = "PY2AND3",
    tags = [
        "notsan",  # b/25864127
    ],
    deps = [
        ":word2vec",
        "//tensorflow:tensorflow_py",
    ],
)

py_test(
    name = "word2vec_optimized_test",
    size = "small",
    srcs = ["word2vec_optimized_test.py"],
    srcs_version = "PY2AND3",
    tags = [
        "notsan",
    ],
    deps = [
        ":word2vec_optimized",
        "//tensorflow:tensorflow_py",
    ],
)

cc_library(
    name = "word2vec_ops",
    srcs = [
        "word2vec_ops.cc",
    ],
    linkstatic = 1,
    visibility = ["//tensorflow:internal"],
    deps = [
        "//tensorflow/core:framework",
    ],
    alwayslink = 1,
)

cc_library(
    name = "word2vec_kernels",
    srcs = [
        "word2vec_kernels.cc",
    ],
    linkstatic = 1,
    visibility = ["//tensorflow:internal"],
    deps = [
        ":word2vec_ops",
        "//tensorflow/core",
    ],
    alwayslink = 1,
)

tf_gen_op_wrapper_py(
    name = "gen_word2vec",
    out = "gen_word2vec.py",
    deps = [":word2vec_ops"],
)

filegroup(
    name = "all_files",
    srcs = glob(
        ["**/*"],
        exclude = [
            "**/METADATA",
            "**/OWNERS",
        ],
    ),
    visibility = ["//tensorflow:__subpackages__"],
)
@@ -1,51 +0,0 @@
This directory contains models for unsupervised training of word embeddings
using the model described in:

(Mikolov, et. al.) [Efficient Estimation of Word Representations in Vector Space](http://arxiv.org/abs/1301.3781),
ICLR 2013.

Detailed instructions on how to get started and use them are available in the
tutorials. Brief instructions are below.

* [Word2Vec Tutorial](http://tensorflow.org/tutorials/word2vec/index.md)

To download the example text and evaluation data:

```shell
wget http://mattmahoney.net/dc/text8.zip -O text8.zip
unzip text8.zip
wget https://storage.googleapis.com/google-code-archive-source/v2/code.google.com/word2vec/source-archive.zip
unzip -p source-archive.zip word2vec/trunk/questions-words.txt > questions-words.txt
rm source-archive.zip
```

Assuming you are using the pip package install and have cloned the git
repository, navigate into this directory and run using:

```shell
cd tensorflow/models/embedding
python word2vec_optimized.py \
  --train_data=text8 \
  --eval_data=questions-words.txt \
  --save_path=/tmp/
```

To run the code from sources using bazel:

```shell
bazel run -c opt tensorflow/models/embedding/word2vec_optimized -- \
  --train_data=text8 \
  --eval_data=questions-words.txt \
  --save_path=/tmp/
```

Here is a short overview of what is in this directory.

File | What's in it?
--- | ---
`word2vec.py` | A version of word2vec implemented using TensorFlow ops and minibatching.
`word2vec_test.py` | Integration test for word2vec.
`word2vec_optimized.py` | A version of word2vec implemented using C ops that does no minibatching.
`word2vec_optimized_test.py` | Integration test for word2vec_optimized.
`word2vec_kernels.cc` | Kernels for the custom input and training ops.
`word2vec_ops.cc` | The declarations of the custom ops.
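The `questions-words.txt` analogies referenced in the deleted README are scored by vector arithmetic on the learned embeddings, as implemented in `build_eval_graph`/`eval` further down in this diff: a question a:b :: c:? is answered by the vocabulary word closest to c + (b - a) under cosine similarity. The standalone numpy sketch below restates that rule; the toy vectors and vocabulary are invented purely for illustration.

```python
import numpy as np

# Toy, hand-made embeddings standing in for the trained [vocab_size, emb_dim]
# word2vec table; rows are L2-normalized as in build_eval_graph.
vocab = ["king", "queen", "man", "woman"]
emb = np.array([[0.9, 0.1, 0.4],
                [0.8, 0.2, 0.6],
                [0.7, 0.1, 0.1],
                [0.6, 0.2, 0.3]])
emb = emb / np.linalg.norm(emb, axis=1, keepdims=True)


def analogy(a, b, c):
    """Answer 'a is to b as c is to ?' by nearest neighbour to c + (b - a)."""
    ids = {w: i for i, w in enumerate(vocab)}
    target = emb[ids[c]] + (emb[ids[b]] - emb[ids[a]])
    # Cosine similarity against every vocabulary word (rows are unit norm).
    scores = emb @ target
    # Skip the three query words, mirroring the eval() loop in word2vec.py.
    for i in np.argsort(-scores):
        if vocab[i] not in (a, b, c):
            return vocab[i]
    return "unknown"


print(analogy("man", "king", "woman"))  # prints "queen" for these toy vectors
```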
@@ -1,21 +0,0 @@
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Import generated word2vec optimized ops into embedding package."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.models.embedding import gen_word2vec
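`gen_word2vec` is the Python wrapper generated from `word2vec_ops.cc` by the `tf_gen_op_wrapper_py` rule in the deleted BUILD file above. For orientation, the snippet below mirrors how that wrapper is called later in this diff (in `word2vec_optimized.py`); it is a sketch that only runs against a pre-move TensorFlow build with a `text8` corpus on disk, and the flag values are the script's documented defaults.

```python
import tensorflow as tf
from tensorflow.models.embedding import gen_word2vec as word2vec  # pre-move path

# Build the custom skipgram input pipeline the same way word2vec_optimized.py
# does; the op reads and tokenizes the corpus, returning vocab and batches.
(words, counts, words_per_epoch, current_epoch, total_words_processed,
 examples, labels) = word2vec.skipgram(filename="text8",
                                       batch_size=500,
                                       window_size=5,
                                       min_count=5,
                                       subsample=1e-3)

with tf.Session() as session:
    vocab_words, vocab_counts = session.run([words, counts])
    print("vocabulary size:", len(vocab_words))
```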
@@ -1,534 +0,0 @@
|
||||
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
"""Multi-threaded word2vec mini-batched skip-gram model.
|
||||
|
||||
Trains the model described in:
|
||||
(Mikolov, et. al.) Efficient Estimation of Word Representations in Vector Space
|
||||
ICLR 2013.
|
||||
http://arxiv.org/abs/1301.3781
|
||||
This model does traditional minibatching.
|
||||
|
||||
The key ops used are:
|
||||
* placeholder for feeding in tensors for each example.
|
||||
* embedding_lookup for fetching rows from the embedding matrix.
|
||||
* sigmoid_cross_entropy_with_logits to calculate the loss.
|
||||
* GradientDescentOptimizer for optimizing the loss.
|
||||
* skipgram custom op that does input processing.
|
||||
"""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
|
||||
from six.moves import xrange # pylint: disable=redefined-builtin
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
from tensorflow.models.embedding import gen_word2vec as word2vec
|
||||
|
||||
flags = tf.app.flags
|
||||
|
||||
flags.DEFINE_string("save_path", None, "Directory to write the model and "
|
||||
"training summaries.")
|
||||
flags.DEFINE_string("train_data", None, "Training text file. "
|
||||
"E.g., unzipped file http://mattmahoney.net/dc/text8.zip.")
|
||||
flags.DEFINE_string(
|
||||
"eval_data", None, "File consisting of analogies of four tokens."
|
||||
"embedding 2 - embedding 1 + embedding 3 should be close "
|
||||
"to embedding 4."
|
||||
"See README.md for how to get 'questions-words.txt'.")
|
||||
flags.DEFINE_integer("embedding_size", 200, "The embedding dimension size.")
|
||||
flags.DEFINE_integer(
|
||||
"epochs_to_train", 15,
|
||||
"Number of epochs to train. Each epoch processes the training data once "
|
||||
"completely.")
|
||||
flags.DEFINE_float("learning_rate", 0.2, "Initial learning rate.")
|
||||
flags.DEFINE_integer("num_neg_samples", 100,
|
||||
"Negative samples per training example.")
|
||||
flags.DEFINE_integer("batch_size", 16,
|
||||
"Number of training examples processed per step "
|
||||
"(size of a minibatch).")
|
||||
flags.DEFINE_integer("concurrent_steps", 12,
|
||||
"The number of concurrent training steps.")
|
||||
flags.DEFINE_integer("window_size", 5,
|
||||
"The number of words to predict to the left and right "
|
||||
"of the target word.")
|
||||
flags.DEFINE_integer("min_count", 5,
|
||||
"The minimum number of word occurrences for it to be "
|
||||
"included in the vocabulary.")
|
||||
flags.DEFINE_float("subsample", 1e-3,
|
||||
"Subsample threshold for word occurrence. Words that appear "
|
||||
"with higher frequency will be randomly down-sampled. Set "
|
||||
"to 0 to disable.")
|
||||
flags.DEFINE_boolean(
|
||||
"interactive", False,
|
||||
"If true, enters an IPython interactive session to play with the trained "
|
||||
"model. E.g., try model.analogy(b'france', b'paris', b'russia') and "
|
||||
"model.nearby([b'proton', b'elephant', b'maxwell'])")
|
||||
flags.DEFINE_integer("statistics_interval", 5,
|
||||
"Print statistics every n seconds.")
|
||||
flags.DEFINE_integer("summary_interval", 5,
|
||||
"Save training summary to file every n seconds (rounded "
|
||||
"up to statistics interval).")
|
||||
flags.DEFINE_integer("checkpoint_interval", 600,
|
||||
"Checkpoint the model (i.e. save the parameters) every n "
|
||||
"seconds (rounded up to statistics interval).")
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
|
||||
class Options(object):
|
||||
"""Options used by our word2vec model."""
|
||||
|
||||
def __init__(self):
|
||||
# Model options.
|
||||
|
||||
# Embedding dimension.
|
||||
self.emb_dim = FLAGS.embedding_size
|
||||
|
||||
# Training options.
|
||||
# The training text file.
|
||||
self.train_data = FLAGS.train_data
|
||||
|
||||
# Number of negative samples per example.
|
||||
self.num_samples = FLAGS.num_neg_samples
|
||||
|
||||
# The initial learning rate.
|
||||
self.learning_rate = FLAGS.learning_rate
|
||||
|
||||
# Number of epochs to train. After these many epochs, the learning
|
||||
# rate decays linearly to zero and the training stops.
|
||||
self.epochs_to_train = FLAGS.epochs_to_train
|
||||
|
||||
# Concurrent training steps.
|
||||
self.concurrent_steps = FLAGS.concurrent_steps
|
||||
|
||||
# Number of examples for one training step.
|
||||
self.batch_size = FLAGS.batch_size
|
||||
|
||||
# The number of words to predict to the left and right of the target word.
|
||||
self.window_size = FLAGS.window_size
|
||||
|
||||
# The minimum number of word occurrences for it to be included in the
|
||||
# vocabulary.
|
||||
self.min_count = FLAGS.min_count
|
||||
|
||||
# Subsampling threshold for word occurrence.
|
||||
self.subsample = FLAGS.subsample
|
||||
|
||||
# How often to print statistics.
|
||||
self.statistics_interval = FLAGS.statistics_interval
|
||||
|
||||
# How often to write to the summary file (rounds up to the nearest
|
||||
# statistics_interval).
|
||||
self.summary_interval = FLAGS.summary_interval
|
||||
|
||||
# How often to write checkpoints (rounds up to the nearest statistics
|
||||
# interval).
|
||||
self.checkpoint_interval = FLAGS.checkpoint_interval
|
||||
|
||||
# Where to write out summaries.
|
||||
self.save_path = FLAGS.save_path
|
||||
if not os.path.exists(self.save_path):
|
||||
os.makedirs(self.save_path)
|
||||
|
||||
# Eval options.
|
||||
# The text file for eval.
|
||||
self.eval_data = FLAGS.eval_data
|
||||
|
||||
|
||||
class Word2Vec(object):
|
||||
"""Word2Vec model (Skipgram)."""
|
||||
|
||||
def __init__(self, options, session):
|
||||
self._options = options
|
||||
self._session = session
|
||||
self._word2id = {}
|
||||
self._id2word = []
|
||||
self.build_graph()
|
||||
self.build_eval_graph()
|
||||
self.save_vocab()
|
||||
|
||||
def read_analogies(self):
|
||||
"""Reads through the analogy question file.
|
||||
|
||||
Returns:
|
||||
questions: a [n, 4] numpy array containing the analogy question's
|
||||
word ids.
|
||||
questions_skipped: questions skipped due to unknown words.
|
||||
"""
|
||||
questions = []
|
||||
questions_skipped = 0
|
||||
with open(self._options.eval_data, "rb") as analogy_f:
|
||||
for line in analogy_f:
|
||||
if line.startswith(b":"): # Skip comments.
|
||||
continue
|
||||
words = line.strip().lower().split(b" ")
|
||||
ids = [self._word2id.get(w.strip()) for w in words]
|
||||
if None in ids or len(ids) != 4:
|
||||
questions_skipped += 1
|
||||
else:
|
||||
questions.append(np.array(ids))
|
||||
print("Eval analogy file: ", self._options.eval_data)
|
||||
print("Questions: ", len(questions))
|
||||
print("Skipped: ", questions_skipped)
|
||||
self._analogy_questions = np.array(questions, dtype=np.int32)
|
||||
|
||||
def forward(self, examples, labels):
|
||||
"""Build the graph for the forward pass."""
|
||||
opts = self._options
|
||||
|
||||
# Declare all variables we need.
|
||||
# Embedding: [vocab_size, emb_dim]
|
||||
init_width = 0.5 / opts.emb_dim
|
||||
emb = tf.Variable(
|
||||
tf.random_uniform(
|
||||
[opts.vocab_size, opts.emb_dim], -init_width, init_width),
|
||||
name="emb")
|
||||
self._emb = emb
|
||||
|
||||
# Softmax weight: [vocab_size, emb_dim]. Transposed.
|
||||
sm_w_t = tf.Variable(
|
||||
tf.zeros([opts.vocab_size, opts.emb_dim]),
|
||||
name="sm_w_t")
|
||||
|
||||
# Softmax bias: [emb_dim].
|
||||
sm_b = tf.Variable(tf.zeros([opts.vocab_size]), name="sm_b")
|
||||
|
||||
# Global step: scalar, i.e., shape [].
|
||||
self.global_step = tf.Variable(0, name="global_step")
|
||||
|
||||
# Nodes to compute the nce loss w/ candidate sampling.
|
||||
labels_matrix = tf.reshape(
|
||||
tf.cast(labels,
|
||||
dtype=tf.int64),
|
||||
[opts.batch_size, 1])
|
||||
|
||||
# Negative sampling.
|
||||
sampled_ids, _, _ = (tf.nn.fixed_unigram_candidate_sampler(
|
||||
true_classes=labels_matrix,
|
||||
num_true=1,
|
||||
num_sampled=opts.num_samples,
|
||||
unique=True,
|
||||
range_max=opts.vocab_size,
|
||||
distortion=0.75,
|
||||
unigrams=opts.vocab_counts.tolist()))
|
||||
|
||||
# Embeddings for examples: [batch_size, emb_dim]
|
||||
example_emb = tf.nn.embedding_lookup(emb, examples)
|
||||
|
||||
# Weights for labels: [batch_size, emb_dim]
|
||||
true_w = tf.nn.embedding_lookup(sm_w_t, labels)
|
||||
# Biases for labels: [batch_size, 1]
|
||||
true_b = tf.nn.embedding_lookup(sm_b, labels)
|
||||
|
||||
# Weights for sampled ids: [num_sampled, emb_dim]
|
||||
sampled_w = tf.nn.embedding_lookup(sm_w_t, sampled_ids)
|
||||
# Biases for sampled ids: [num_sampled, 1]
|
||||
sampled_b = tf.nn.embedding_lookup(sm_b, sampled_ids)
|
||||
|
||||
# True logits: [batch_size, 1]
|
||||
true_logits = tf.reduce_sum(tf.mul(example_emb, true_w), 1) + true_b
|
||||
|
||||
# Sampled logits: [batch_size, num_sampled]
|
||||
# We replicate sampled noise labels for all examples in the batch
|
||||
# using the matmul.
|
||||
sampled_b_vec = tf.reshape(sampled_b, [opts.num_samples])
|
||||
sampled_logits = tf.matmul(example_emb,
|
||||
sampled_w,
|
||||
transpose_b=True) + sampled_b_vec
|
||||
return true_logits, sampled_logits
|
||||
|
||||
def nce_loss(self, true_logits, sampled_logits):
|
||||
"""Build the graph for the NCE loss."""
|
||||
|
||||
# cross-entropy(logits, labels)
|
||||
opts = self._options
|
||||
true_xent = tf.nn.sigmoid_cross_entropy_with_logits(
|
||||
true_logits, tf.ones_like(true_logits))
|
||||
sampled_xent = tf.nn.sigmoid_cross_entropy_with_logits(
|
||||
sampled_logits, tf.zeros_like(sampled_logits))
|
||||
|
||||
# NCE-loss is the sum of the true and noise (sampled words)
|
||||
# contributions, averaged over the batch.
|
||||
nce_loss_tensor = (tf.reduce_sum(true_xent) +
|
||||
tf.reduce_sum(sampled_xent)) / opts.batch_size
|
||||
return nce_loss_tensor
|
||||
|
||||
def optimize(self, loss):
|
||||
"""Build the graph to optimize the loss function."""
|
||||
|
||||
# Optimizer nodes.
|
||||
# Linear learning rate decay.
|
||||
opts = self._options
|
||||
words_to_train = float(opts.words_per_epoch * opts.epochs_to_train)
|
||||
lr = opts.learning_rate * tf.maximum(
|
||||
0.0001, 1.0 - tf.cast(self._words, tf.float32) / words_to_train)
|
||||
self._lr = lr
|
||||
optimizer = tf.train.GradientDescentOptimizer(lr)
|
||||
train = optimizer.minimize(loss,
|
||||
global_step=self.global_step,
|
||||
gate_gradients=optimizer.GATE_NONE)
|
||||
self._train = train
|
||||
|
||||
def build_eval_graph(self):
|
||||
"""Build the eval graph."""
|
||||
# Eval graph
|
||||
|
||||
# Each analogy task is to predict the 4th word (d) given three
|
||||
# words: a, b, c. E.g., a=italy, b=rome, c=france, we should
|
||||
# predict d=paris.
|
||||
|
||||
# The eval feeds three vectors of word ids for a, b, c, each of
|
||||
# which is of size N, where N is the number of analogies we want to
|
||||
# evaluate in one batch.
|
||||
analogy_a = tf.placeholder(dtype=tf.int32) # [N]
|
||||
analogy_b = tf.placeholder(dtype=tf.int32) # [N]
|
||||
analogy_c = tf.placeholder(dtype=tf.int32) # [N]
|
||||
|
||||
# Normalized word embeddings of shape [vocab_size, emb_dim].
|
||||
nemb = tf.nn.l2_normalize(self._emb, 1)
|
||||
|
||||
# Each row of a_emb, b_emb, c_emb is a word's embedding vector.
|
||||
# They all have the shape [N, emb_dim]
|
||||
a_emb = tf.gather(nemb, analogy_a) # a's embs
|
||||
b_emb = tf.gather(nemb, analogy_b) # b's embs
|
||||
c_emb = tf.gather(nemb, analogy_c) # c's embs
|
||||
|
||||
# We expect that d's embedding vectors on the unit hyper-sphere is
|
||||
# near: c_emb + (b_emb - a_emb), which has the shape [N, emb_dim].
|
||||
target = c_emb + (b_emb - a_emb)
|
||||
|
||||
# Compute cosine distance between each pair of target and vocab.
|
||||
# dist has shape [N, vocab_size].
|
||||
dist = tf.matmul(target, nemb, transpose_b=True)
|
||||
|
||||
# For each question (row in dist), find the top 4 words.
|
||||
_, pred_idx = tf.nn.top_k(dist, 4)
|
||||
|
||||
# Nodes for computing neighbors for a given word according to
|
||||
# their cosine distance.
|
||||
nearby_word = tf.placeholder(dtype=tf.int32) # word id
|
||||
nearby_emb = tf.gather(nemb, nearby_word)
|
||||
nearby_dist = tf.matmul(nearby_emb, nemb, transpose_b=True)
|
||||
nearby_val, nearby_idx = tf.nn.top_k(nearby_dist,
|
||||
min(1000, self._options.vocab_size))
|
||||
|
||||
# Nodes in the construct graph which are used by training and
|
||||
# evaluation to run/feed/fetch.
|
||||
self._analogy_a = analogy_a
|
||||
self._analogy_b = analogy_b
|
||||
self._analogy_c = analogy_c
|
||||
self._analogy_pred_idx = pred_idx
|
||||
self._nearby_word = nearby_word
|
||||
self._nearby_val = nearby_val
|
||||
self._nearby_idx = nearby_idx
|
||||
|
||||
def build_graph(self):
|
||||
"""Build the graph for the full model."""
|
||||
opts = self._options
|
||||
# The training data. A text file.
|
||||
(words, counts, words_per_epoch, self._epoch, self._words, examples,
|
||||
labels) = word2vec.skipgram(filename=opts.train_data,
|
||||
batch_size=opts.batch_size,
|
||||
window_size=opts.window_size,
|
||||
min_count=opts.min_count,
|
||||
subsample=opts.subsample)
|
||||
(opts.vocab_words, opts.vocab_counts,
|
||||
opts.words_per_epoch) = self._session.run([words, counts, words_per_epoch])
|
||||
opts.vocab_size = len(opts.vocab_words)
|
||||
print("Data file: ", opts.train_data)
|
||||
print("Vocab size: ", opts.vocab_size - 1, " + UNK")
|
||||
print("Words per epoch: ", opts.words_per_epoch)
|
||||
self._examples = examples
|
||||
self._labels = labels
|
||||
self._id2word = opts.vocab_words
|
||||
for i, w in enumerate(self._id2word):
|
||||
self._word2id[w] = i
|
||||
true_logits, sampled_logits = self.forward(examples, labels)
|
||||
loss = self.nce_loss(true_logits, sampled_logits)
|
||||
tf.contrib.deprecated.scalar_summary("NCE loss", loss)
|
||||
self._loss = loss
|
||||
self.optimize(loss)
|
||||
|
||||
# Properly initialize all variables.
|
||||
tf.global_variables_initializer().run()
|
||||
|
||||
self.saver = tf.train.Saver()
|
||||
|
||||
def save_vocab(self):
|
||||
"""Save the vocabulary to a file so the model can be reloaded."""
|
||||
opts = self._options
|
||||
with open(os.path.join(opts.save_path, "vocab.txt"), "w") as f:
|
||||
for i in xrange(opts.vocab_size):
|
||||
vocab_word = tf.compat.as_text(opts.vocab_words[i]).encode("utf-8")
|
||||
f.write("%s %d\n" % (vocab_word,
|
||||
opts.vocab_counts[i]))
|
||||
|
||||
def _train_thread_body(self):
|
||||
initial_epoch, = self._session.run([self._epoch])
|
||||
while True:
|
||||
_, epoch = self._session.run([self._train, self._epoch])
|
||||
if epoch != initial_epoch:
|
||||
break
|
||||
|
||||
def train(self):
|
||||
"""Train the model."""
|
||||
opts = self._options
|
||||
|
||||
initial_epoch, initial_words = self._session.run([self._epoch, self._words])
|
||||
|
||||
summary_op = tf.summary.merge_all()
|
||||
summary_writer = tf.summary.FileWriter(opts.save_path, self._session.graph)
|
||||
workers = []
|
||||
for _ in xrange(opts.concurrent_steps):
|
||||
t = threading.Thread(target=self._train_thread_body)
|
||||
t.start()
|
||||
workers.append(t)
|
||||
|
||||
last_words, last_time, last_summary_time = initial_words, time.time(), 0
|
||||
last_checkpoint_time = 0
|
||||
while True:
|
||||
time.sleep(opts.statistics_interval) # Reports our progress once a while.
|
||||
(epoch, step, loss, words, lr) = self._session.run(
|
||||
[self._epoch, self.global_step, self._loss, self._words, self._lr])
|
||||
now = time.time()
|
||||
last_words, last_time, rate = words, now, (words - last_words) / (
|
||||
now - last_time)
|
||||
print("Epoch %4d Step %8d: lr = %5.3f loss = %6.2f words/sec = %8.0f\r" %
|
||||
(epoch, step, lr, loss, rate), end="")
|
||||
sys.stdout.flush()
|
||||
if now - last_summary_time > opts.summary_interval:
|
||||
summary_str = self._session.run(summary_op)
|
||||
summary_writer.add_summary(summary_str, step)
|
||||
last_summary_time = now
|
||||
if now - last_checkpoint_time > opts.checkpoint_interval:
|
||||
self.saver.save(self._session,
|
||||
os.path.join(opts.save_path, "model.ckpt"),
|
||||
global_step=step.astype(int))
|
||||
last_checkpoint_time = now
|
||||
if epoch != initial_epoch:
|
||||
break
|
||||
|
||||
for t in workers:
|
||||
t.join()
|
||||
|
||||
return epoch
|
||||
|
||||
def _predict(self, analogy):
|
||||
"""Predict the top 4 answers for analogy questions."""
|
||||
idx, = self._session.run([self._analogy_pred_idx], {
|
||||
self._analogy_a: analogy[:, 0],
|
||||
self._analogy_b: analogy[:, 1],
|
||||
self._analogy_c: analogy[:, 2]
|
||||
})
|
||||
return idx
|
||||
|
||||
def eval(self):
|
||||
"""Evaluate analogy questions and reports accuracy."""
|
||||
|
||||
# How many questions we get right at precision@1.
|
||||
correct = 0
|
||||
|
||||
try:
|
||||
total = self._analogy_questions.shape[0]
|
||||
except AttributeError as e:
|
||||
raise AttributeError("Need to read analogy questions.")
|
||||
|
||||
start = 0
|
||||
while start < total:
|
||||
limit = start + 2500
|
||||
sub = self._analogy_questions[start:limit, :]
|
||||
idx = self._predict(sub)
|
||||
start = limit
|
||||
for question in xrange(sub.shape[0]):
|
||||
for j in xrange(4):
|
||||
if idx[question, j] == sub[question, 3]:
|
||||
# Bingo! We predicted correctly. E.g., [italy, rome, france, paris].
|
||||
correct += 1
|
||||
break
|
||||
elif idx[question, j] in sub[question, :3]:
|
||||
# We need to skip words already in the question.
|
||||
continue
|
||||
else:
|
||||
# The correct label is not the precision@1
|
||||
break
|
||||
print()
|
||||
print("Eval %4d/%d accuracy = %4.1f%%" % (correct, total,
|
||||
correct * 100.0 / total))
|
||||
|
||||
def analogy(self, w0, w1, w2):
|
||||
"""Predict word w3 as in w0:w1 vs w2:w3."""
|
||||
wid = np.array([[self._word2id.get(w, 0) for w in [w0, w1, w2]]])
|
||||
idx = self._predict(wid)
|
||||
for c in [self._id2word[i] for i in idx[0, :]]:
|
||||
if c not in [w0, w1, w2]:
|
||||
print(c)
|
||||
break
|
||||
print("unknown")
|
||||
|
||||
def nearby(self, words, num=20):
|
||||
"""Prints out nearby words given a list of words."""
|
||||
ids = np.array([self._word2id.get(x, 0) for x in words])
|
||||
vals, idx = self._session.run(
|
||||
[self._nearby_val, self._nearby_idx], {self._nearby_word: ids})
|
||||
for i in xrange(len(words)):
|
||||
print("\n%s\n=====================================" % (words[i]))
|
||||
for (neighbor, distance) in zip(idx[i, :num], vals[i, :num]):
|
||||
print("%-20s %6.4f" % (self._id2word[neighbor], distance))
|
||||
|
||||
|
||||
def _start_shell(local_ns=None):
|
||||
# An interactive shell is useful for debugging/development.
|
||||
import IPython
|
||||
user_ns = {}
|
||||
if local_ns:
|
||||
user_ns.update(local_ns)
|
||||
user_ns.update(globals())
|
||||
IPython.start_ipython(argv=[], user_ns=user_ns)
|
||||
|
||||
|
||||
def main(_):
|
||||
"""Train a word2vec model."""
|
||||
if not FLAGS.train_data or not FLAGS.eval_data or not FLAGS.save_path:
|
||||
print("--train_data --eval_data and --save_path must be specified.")
|
||||
sys.exit(1)
|
||||
opts = Options()
|
||||
with tf.Graph().as_default(), tf.Session() as session:
|
||||
with tf.device("/cpu:0"):
|
||||
model = Word2Vec(opts, session)
|
||||
model.read_analogies() # Read analogy questions
|
||||
for _ in xrange(opts.epochs_to_train):
|
||||
model.train() # Process one epoch
|
||||
model.eval() # Eval analogies.
|
||||
# Perform a final save.
|
||||
model.saver.save(session,
|
||||
os.path.join(opts.save_path, "model.ckpt"),
|
||||
global_step=model.global_step)
|
||||
if FLAGS.interactive:
|
||||
# E.g.,
|
||||
# [0]: model.analogy(b'france', b'paris', b'russia')
|
||||
# [1]: model.nearby([b'proton', b'elephant', b'maxwell'])
|
||||
_start_shell(locals())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
tf.app.run()
|
@@ -1,439 +0,0 @@
|
||||
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
"""Multi-threaded word2vec unbatched skip-gram model.
|
||||
|
||||
Trains the model described in:
|
||||
(Mikolov, et. al.) Efficient Estimation of Word Representations in Vector Space
|
||||
ICLR 2013.
|
||||
http://arxiv.org/abs/1301.3781
|
||||
This model does true SGD (i.e. no minibatching). To do this efficiently, custom
|
||||
ops are used to sequentially process data within a 'batch'.
|
||||
|
||||
The key ops used are:
|
||||
* skipgram custom op that does input processing.
|
||||
* neg_train custom op that efficiently calculates and applies the gradient using
|
||||
true SGD.
|
||||
"""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
|
||||
from six.moves import xrange # pylint: disable=redefined-builtin
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
from tensorflow.models.embedding import gen_word2vec as word2vec
|
||||
|
||||
flags = tf.app.flags
|
||||
|
||||
flags.DEFINE_string("save_path", None, "Directory to write the model.")
|
||||
flags.DEFINE_string(
|
||||
"train_data", None,
|
||||
"Training data. E.g., unzipped file http://mattmahoney.net/dc/text8.zip.")
|
||||
flags.DEFINE_string(
|
||||
"eval_data", None, "Analogy questions. "
|
||||
"See README.md for how to get 'questions-words.txt'.")
|
||||
flags.DEFINE_integer("embedding_size", 200, "The embedding dimension size.")
|
||||
flags.DEFINE_integer(
|
||||
"epochs_to_train", 15,
|
||||
"Number of epochs to train. Each epoch processes the training data once "
|
||||
"completely.")
|
||||
flags.DEFINE_float("learning_rate", 0.025, "Initial learning rate.")
|
||||
flags.DEFINE_integer("num_neg_samples", 25,
|
||||
"Negative samples per training example.")
|
||||
flags.DEFINE_integer("batch_size", 500,
|
||||
"Numbers of training examples each step processes "
|
||||
"(no minibatching).")
|
||||
flags.DEFINE_integer("concurrent_steps", 12,
|
||||
"The number of concurrent training steps.")
|
||||
flags.DEFINE_integer("window_size", 5,
|
||||
"The number of words to predict to the left and right "
|
||||
"of the target word.")
|
||||
flags.DEFINE_integer("min_count", 5,
|
||||
"The minimum number of word occurrences for it to be "
|
||||
"included in the vocabulary.")
|
||||
flags.DEFINE_float("subsample", 1e-3,
|
||||
"Subsample threshold for word occurrence. Words that appear "
|
||||
"with higher frequency will be randomly down-sampled. Set "
|
||||
"to 0 to disable.")
|
||||
flags.DEFINE_boolean(
|
||||
"interactive", False,
|
||||
"If true, enters an IPython interactive session to play with the trained "
|
||||
"model. E.g., try model.analogy(b'france', b'paris', b'russia') and "
|
||||
"model.nearby([b'proton', b'elephant', b'maxwell'])")
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
|
||||
class Options(object):
|
||||
"""Options used by our word2vec model."""
|
||||
|
||||
def __init__(self):
|
||||
# Model options.
|
||||
|
||||
# Embedding dimension.
|
||||
self.emb_dim = FLAGS.embedding_size
|
||||
|
||||
# Training options.
|
||||
|
||||
# The training text file.
|
||||
self.train_data = FLAGS.train_data
|
||||
|
||||
# Number of negative samples per example.
|
||||
self.num_samples = FLAGS.num_neg_samples
|
||||
|
||||
# The initial learning rate.
|
||||
self.learning_rate = FLAGS.learning_rate
|
||||
|
||||
# Number of epochs to train. After these many epochs, the learning
|
||||
# rate decays linearly to zero and the training stops.
|
||||
self.epochs_to_train = FLAGS.epochs_to_train
|
||||
|
||||
# Concurrent training steps.
|
||||
self.concurrent_steps = FLAGS.concurrent_steps
|
||||
|
||||
# Number of examples for one training step.
|
||||
self.batch_size = FLAGS.batch_size
|
||||
|
||||
# The number of words to predict to the left and right of the target word.
|
||||
self.window_size = FLAGS.window_size
|
||||
|
||||
# The minimum number of word occurrences for it to be included in the
|
||||
# vocabulary.
|
||||
self.min_count = FLAGS.min_count
|
||||
|
||||
# Subsampling threshold for word occurrence.
|
||||
self.subsample = FLAGS.subsample
|
||||
|
||||
# Where to write out summaries.
|
||||
self.save_path = FLAGS.save_path
|
||||
if not os.path.exists(self.save_path):
|
||||
os.makedirs(self.save_path)
|
||||
|
||||
# Eval options.
|
||||
|
||||
# The text file for eval.
|
||||
self.eval_data = FLAGS.eval_data
|
||||
|
||||
|
||||
class Word2Vec(object):
|
||||
"""Word2Vec model (Skipgram)."""
|
||||
|
||||
def __init__(self, options, session):
|
||||
self._options = options
|
||||
self._session = session
|
||||
self._word2id = {}
|
||||
self._id2word = []
|
||||
self.build_graph()
|
||||
self.build_eval_graph()
|
||||
self.save_vocab()
|
||||
|
||||
def read_analogies(self):
|
||||
"""Reads through the analogy question file.
|
||||
|
||||
Returns:
|
||||
questions: a [n, 4] numpy array containing the analogy question's
|
||||
word ids.
|
||||
questions_skipped: questions skipped due to unknown words.
|
||||
"""
|
||||
questions = []
|
||||
questions_skipped = 0
|
||||
with open(self._options.eval_data, "rb") as analogy_f:
|
||||
for line in analogy_f:
|
||||
if line.startswith(b":"): # Skip comments.
|
||||
continue
|
||||
words = line.strip().lower().split(b" ")
|
||||
ids = [self._word2id.get(w.strip()) for w in words]
|
||||
if None in ids or len(ids) != 4:
|
||||
questions_skipped += 1
|
||||
else:
|
||||
questions.append(np.array(ids))
|
||||
print("Eval analogy file: ", self._options.eval_data)
|
||||
print("Questions: ", len(questions))
|
||||
print("Skipped: ", questions_skipped)
|
||||
self._analogy_questions = np.array(questions, dtype=np.int32)
|
||||
|
||||
def build_graph(self):
|
||||
"""Build the model graph."""
|
||||
opts = self._options
|
||||
|
||||
# The training data. A text file.
|
||||
(words, counts, words_per_epoch, current_epoch, total_words_processed,
|
||||
examples, labels) = word2vec.skipgram(filename=opts.train_data,
|
||||
batch_size=opts.batch_size,
|
||||
window_size=opts.window_size,
|
||||
min_count=opts.min_count,
|
||||
subsample=opts.subsample)
|
||||
(opts.vocab_words, opts.vocab_counts,
|
||||
opts.words_per_epoch) = self._session.run([words, counts, words_per_epoch])
|
||||
opts.vocab_size = len(opts.vocab_words)
|
||||
print("Data file: ", opts.train_data)
|
||||
print("Vocab size: ", opts.vocab_size - 1, " + UNK")
|
||||
print("Words per epoch: ", opts.words_per_epoch)
|
||||
|
||||
self._id2word = opts.vocab_words
|
||||
for i, w in enumerate(self._id2word):
|
||||
self._word2id[w] = i
|
||||
|
||||
# Declare all variables we need.
|
||||
# Input words embedding: [vocab_size, emb_dim]
|
||||
w_in = tf.Variable(
|
||||
tf.random_uniform(
|
||||
[opts.vocab_size,
|
||||
opts.emb_dim], -0.5 / opts.emb_dim, 0.5 / opts.emb_dim),
|
||||
name="w_in")
|
||||
|
||||
# Global step: scalar, i.e., shape [].
|
||||
w_out = tf.Variable(tf.zeros([opts.vocab_size, opts.emb_dim]), name="w_out")
|
||||
|
||||
# Global step: []
|
||||
global_step = tf.Variable(0, name="global_step")
|
||||
|
||||
# Linear learning rate decay.
|
||||
words_to_train = float(opts.words_per_epoch * opts.epochs_to_train)
|
||||
lr = opts.learning_rate * tf.maximum(
|
||||
0.0001,
|
||||
1.0 - tf.cast(total_words_processed, tf.float32) / words_to_train)
|
||||
|
||||
# Training nodes.
|
||||
inc = global_step.assign_add(1)
|
||||
with tf.control_dependencies([inc]):
|
||||
train = word2vec.neg_train(w_in,
|
||||
w_out,
|
||||
examples,
|
||||
labels,
|
||||
lr,
|
||||
vocab_count=opts.vocab_counts.tolist(),
|
||||
num_negative_samples=opts.num_samples)
|
||||
|
||||
self._w_in = w_in
|
||||
self._examples = examples
|
||||
self._labels = labels
|
||||
self._lr = lr
|
||||
self._train = train
|
||||
self.global_step = global_step
|
||||
self._epoch = current_epoch
|
||||
self._words = total_words_processed
|
||||
|
||||
def save_vocab(self):
|
||||
"""Save the vocabulary to a file so the model can be reloaded."""
|
||||
opts = self._options
|
||||
with open(os.path.join(opts.save_path, "vocab.txt"), "w") as f:
|
||||
for i in xrange(opts.vocab_size):
|
||||
vocab_word = tf.compat.as_text(opts.vocab_words[i]).encode("utf-8")
|
||||
f.write("%s %d\n" % (vocab_word,
|
||||
opts.vocab_counts[i]))
|
||||
|
||||
def build_eval_graph(self):
|
||||
"""Build the evaluation graph."""
|
||||
# Eval graph
|
||||
opts = self._options
|
||||
|
||||
# Each analogy task is to predict the 4th word (d) given three
|
||||
# words: a, b, c. E.g., a=italy, b=rome, c=france, we should
|
||||
# predict d=paris.
|
||||
|
||||
# The eval feeds three vectors of word ids for a, b, c, each of
|
||||
# which is of size N, where N is the number of analogies we want to
|
||||
# evaluate in one batch.
|
||||
analogy_a = tf.placeholder(dtype=tf.int32) # [N]
|
||||
analogy_b = tf.placeholder(dtype=tf.int32) # [N]
|
||||
analogy_c = tf.placeholder(dtype=tf.int32) # [N]
|
||||
|
||||
# Normalized word embeddings of shape [vocab_size, emb_dim].
|
||||
nemb = tf.nn.l2_normalize(self._w_in, 1)
|
||||
|
||||
# Each row of a_emb, b_emb, c_emb is a word's embedding vector.
|
||||
# They all have the shape [N, emb_dim]
|
||||
a_emb = tf.gather(nemb, analogy_a) # a's embs
|
||||
b_emb = tf.gather(nemb, analogy_b) # b's embs
|
||||
c_emb = tf.gather(nemb, analogy_c) # c's embs
|
||||
|
||||
# We expect that d's embedding vectors on the unit hyper-sphere is
|
||||
# near: c_emb + (b_emb - a_emb), which has the shape [N, emb_dim].
|
||||
target = c_emb + (b_emb - a_emb)
|
||||
|
||||
# Compute cosine distance between each pair of target and vocab.
|
||||
# dist has shape [N, vocab_size].
|
||||
dist = tf.matmul(target, nemb, transpose_b=True)
|
||||
|
||||
# For each question (row in dist), find the top 4 words.
|
||||
_, pred_idx = tf.nn.top_k(dist, 4)
|
||||
|
||||
# Nodes for computing neighbors for a given word according to
|
||||
# their cosine distance.
|
||||
nearby_word = tf.placeholder(dtype=tf.int32) # word id
|
||||
nearby_emb = tf.gather(nemb, nearby_word)
|
||||
nearby_dist = tf.matmul(nearby_emb, nemb, transpose_b=True)
|
||||
nearby_val, nearby_idx = tf.nn.top_k(nearby_dist,
|
||||
min(1000, opts.vocab_size))
|
||||
|
||||
# Nodes in the construct graph which are used by training and
|
||||
# evaluation to run/feed/fetch.
|
||||
self._analogy_a = analogy_a
|
||||
self._analogy_b = analogy_b
|
||||
self._analogy_c = analogy_c
|
||||
self._analogy_pred_idx = pred_idx
|
||||
self._nearby_word = nearby_word
|
||||
self._nearby_val = nearby_val
|
||||
self._nearby_idx = nearby_idx
|
||||
|
||||
# Properly initialize all variables.
|
||||
tf.global_variables_initializer().run()
|
||||
|
||||
self.saver = tf.train.Saver()
|
||||
|
||||
def _train_thread_body(self):
|
||||
initial_epoch, = self._session.run([self._epoch])
|
||||
while True:
|
||||
_, epoch = self._session.run([self._train, self._epoch])
|
||||
if epoch != initial_epoch:
|
||||
break
|
||||
|
||||
def train(self):
|
||||
"""Train the model."""
|
||||
opts = self._options
|
||||
|
||||
initial_epoch, initial_words = self._session.run([self._epoch, self._words])
|
||||
|
||||
workers = []
|
||||
for _ in xrange(opts.concurrent_steps):
|
||||
t = threading.Thread(target=self._train_thread_body)
|
||||
t.start()
|
||||
workers.append(t)
|
||||
|
||||
last_words, last_time = initial_words, time.time()
|
||||
while True:
|
||||
time.sleep(5) # Reports our progress once a while.
|
||||
(epoch, step, words, lr) = self._session.run(
|
||||
[self._epoch, self.global_step, self._words, self._lr])
|
||||
now = time.time()
|
||||
last_words, last_time, rate = words, now, (words - last_words) / (
|
||||
now - last_time)
|
||||
print("Epoch %4d Step %8d: lr = %5.3f words/sec = %8.0f\r" % (epoch, step,
|
||||
lr, rate),
|
||||
end="")
|
||||
sys.stdout.flush()
|
||||
if epoch != initial_epoch:
|
||||
break
|
||||
|
||||
for t in workers:
|
||||
t.join()
|
||||
|
||||
def _predict(self, analogy):
|
||||
"""Predict the top 4 answers for analogy questions."""
|
||||
idx, = self._session.run([self._analogy_pred_idx], {
|
||||
self._analogy_a: analogy[:, 0],
|
||||
self._analogy_b: analogy[:, 1],
|
||||
self._analogy_c: analogy[:, 2]
|
||||
})
|
||||
return idx
|
||||
|
||||
def eval(self):
|
||||
"""Evaluate analogy questions and reports accuracy."""
|
||||
|
||||
# How many questions we get right at precision@1.
|
||||
correct = 0
|
||||
|
||||
try:
|
||||
total = self._analogy_questions.shape[0]
|
||||
except AttributeError as e:
|
||||
raise AttributeError("Need to read analogy questions.")
|
||||
|
||||
start = 0
|
||||
while start < total:
|
||||
limit = start + 2500
|
||||
sub = self._analogy_questions[start:limit, :]
|
||||
idx = self._predict(sub)
|
||||
start = limit
|
||||
for question in xrange(sub.shape[0]):
|
||||
for j in xrange(4):
|
||||
if idx[question, j] == sub[question, 3]:
|
||||
# Bingo! We predicted correctly. E.g., [italy, rome, france, paris].
|
||||
correct += 1
|
||||
break
|
||||
elif idx[question, j] in sub[question, :3]:
|
||||
# We need to skip words already in the question.
|
||||
continue
|
||||
else:
|
||||
# The correct label is not the precision@1
|
||||
break
|
||||
print()
|
||||
print("Eval %4d/%d accuracy = %4.1f%%" % (correct, total,
|
||||
correct * 100.0 / total))
|
||||
|
||||
def analogy(self, w0, w1, w2):
|
||||
"""Predict word w3 as in w0:w1 vs w2:w3."""
|
||||
wid = np.array([[self._word2id.get(w, 0) for w in [w0, w1, w2]]])
|
||||
idx = self._predict(wid)
|
||||
for c in [self._id2word[i] for i in idx[0, :]]:
|
||||
if c not in [w0, w1, w2]:
|
||||
print(c)
|
||||
break
|
||||
print("unknown")
|
||||
|
||||
def nearby(self, words, num=20):
|
||||
"""Prints out nearby words given a list of words."""
|
||||
ids = np.array([self._word2id.get(x, 0) for x in words])
|
||||
vals, idx = self._session.run(
|
||||
[self._nearby_val, self._nearby_idx], {self._nearby_word: ids})
|
||||
for i in xrange(len(words)):
|
||||
print("\n%s\n=====================================" % (words[i]))
|
||||
for (neighbor, distance) in zip(idx[i, :num], vals[i, :num]):
|
||||
print("%-20s %6.4f" % (self._id2word[neighbor], distance))
|
||||
|
||||
|
||||
def _start_shell(local_ns=None):
|
||||
# An interactive shell is useful for debugging/development.
|
||||
import IPython
|
||||
user_ns = {}
|
||||
if local_ns:
|
||||
user_ns.update(local_ns)
|
||||
user_ns.update(globals())
|
||||
IPython.start_ipython(argv=[], user_ns=user_ns)
|
||||
|
||||
|
||||
def main(_):
|
||||
"""Train a word2vec model."""
|
||||
if not FLAGS.train_data or not FLAGS.eval_data or not FLAGS.save_path:
|
||||
print("--train_data --eval_data and --save_path must be specified.")
|
||||
sys.exit(1)
|
||||
opts = Options()
|
||||
with tf.Graph().as_default(), tf.Session() as session:
|
||||
with tf.device("/cpu:0"):
|
||||
model = Word2Vec(opts, session)
|
||||
model.read_analogies() # Read analogy questions
|
||||
for _ in xrange(opts.epochs_to_train):
|
||||
model.train() # Process one epoch
|
||||
model.eval() # Eval analogies.
|
||||
# Perform a final save.
|
||||
model.saver.save(session, os.path.join(opts.save_path, "model.ckpt"),
|
||||
global_step=model.global_step)
|
||||
if FLAGS.interactive:
|
||||
# E.g.,
|
||||
# [0]: model.analogy(b'france', b'paris', b'russia')
|
||||
# [1]: model.nearby([b'proton', b'elephant', b'maxwell'])
|
||||
_start_shell(locals())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
tf.app.run()
|
@@ -1,62 +0,0 @@
|
||||
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
"""Tests for word2vec_optimized module."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
from tensorflow.models.embedding import word2vec_optimized
|
||||
|
||||
flags = tf.app.flags
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
|
||||
class Word2VecTest(tf.test.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
FLAGS.train_data = os.path.join(self.get_temp_dir() + "test-text.txt")
|
||||
FLAGS.eval_data = os.path.join(self.get_temp_dir() + "eval-text.txt")
|
||||
FLAGS.save_path = self.get_temp_dir()
|
||||
with open(FLAGS.train_data, "w") as f:
|
||||
f.write(
|
||||
"""alice was beginning to get very tired of sitting by her sister on
|
||||
the bank, and of having nothing to do: once or twice she had peeped
|
||||
into the book her sister was reading, but it had no pictures or
|
||||
conversations in it, 'and what is the use of a book,' thought alice
|
||||
'without pictures or conversations?' So she was considering in her own
|
||||
mind (as well as she could, for the hot day made her feel very sleepy
|
||||
and stupid), whether the pleasure of making a daisy-chain would be
|
||||
worth the trouble of getting up and picking the daisies, when suddenly
|
||||
a White rabbit with pink eyes ran close by her.\n""")
|
||||
with open(FLAGS.eval_data, "w") as f:
|
||||
f.write("alice she rabbit once\n")
|
||||
|
||||
def testWord2VecOptimized(self):
|
||||
FLAGS.batch_size = 5
|
||||
FLAGS.num_neg_samples = 10
|
||||
FLAGS.epochs_to_train = 1
|
||||
FLAGS.min_count = 0
|
||||
word2vec_optimized.main([])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
tf.test.main()
|
@@ -1,62 +0,0 @@
|
||||
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
"""Tests for word2vec module."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
from tensorflow.models.embedding import word2vec
|
||||
|
||||
flags = tf.app.flags
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
|
||||
class Word2VecTest(tf.test.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
FLAGS.train_data = os.path.join(self.get_temp_dir(), "test-text.txt")
|
||||
FLAGS.eval_data = os.path.join(self.get_temp_dir(), "eval-text.txt")
|
||||
FLAGS.save_path = self.get_temp_dir()
|
||||
with open(FLAGS.train_data, "w") as f:
|
||||
f.write(
|
||||
"""alice was beginning to get very tired of sitting by her sister on
|
||||
the bank, and of having nothing to do: once or twice she had peeped
|
||||
into the book her sister was reading, but it had no pictures or
|
||||
conversations in it, 'and what is the use of a book,' thought alice
|
||||
'without pictures or conversations?' So she was considering in her own
|
||||
mind (as well as she could, for the hot day made her feel very sleepy
|
||||
and stupid), whether the pleasure of making a daisy-chain would be
|
||||
worth the trouble of getting up and picking the daisies, when suddenly
|
||||
a White rabbit with pink eyes ran close by her.\n""")
|
||||
with open(FLAGS.eval_data, "w") as f:
|
||||
f.write("alice she rabbit once\n")
|
||||
|
||||
def testWord2Vec(self):
|
||||
FLAGS.batch_size = 5
|
||||
FLAGS.num_neg_samples = 10
|
||||
FLAGS.epochs_to_train = 1
|
||||
FLAGS.min_count = 0
|
||||
word2vec.main([])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
tf.test.main()
|
@@ -1,29 +0,0 @@
# Description:
# Benchmark for AlexNet.

licenses(["notice"])  # Apache 2.0

exports_files(["LICENSE"])

py_binary(
    name = "alexnet_benchmark",
    srcs = [
        "alexnet_benchmark.py",
    ],
    srcs_version = "PY2AND3",
    deps = [
        "//tensorflow:tensorflow_py",
    ],
)

filegroup(
    name = "all_files",
    srcs = glob(
        ["**/*"],
        exclude = [
            "**/METADATA",
            "**/OWNERS",
        ],
    ),
    visibility = ["//tensorflow:__subpackages__"],
)
@@ -1,246 +0,0 @@
|
||||
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
"""Timing benchmark for AlexNet inference.
|
||||
|
||||
To run, use:
|
||||
bazel run -c opt --config=cuda \
|
||||
third_party/tensorflow/models/image/alexnet:alexnet_benchmark
|
||||
|
||||
Across 100 steps on batch size = 128.
|
||||
|
||||
Forward pass:
|
||||
Run on Tesla K40c: 145 +/- 1.5 ms / batch
|
||||
Run on Titan X: 70 +/- 0.1 ms / batch
|
||||
|
||||
Forward-backward pass:
|
||||
Run on Tesla K40c: 480 +/- 48 ms / batch
|
||||
Run on Titan X: 244 +/- 30 ms / batch
|
||||
"""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import argparse
|
||||
from datetime import datetime
|
||||
import math
|
||||
import sys
|
||||
import time
|
||||
|
||||
from six.moves import xrange # pylint: disable=redefined-builtin
|
||||
import tensorflow as tf
|
||||
|
||||
FLAGS = None
|
||||
|
||||
|
||||
def print_activations(t):
|
||||
print(t.op.name, ' ', t.get_shape().as_list())
|
||||
|
||||
|
||||
def inference(images):
|
||||
"""Build the AlexNet model.
|
||||
|
||||
Args:
|
||||
images: Images Tensor
|
||||
|
||||
Returns:
|
||||
pool5: the last Tensor in the convolutional component of AlexNet.
|
||||
parameters: a list of Tensors corresponding to the weights and biases of the
|
||||
AlexNet model.
|
||||
"""
|
||||
parameters = []
|
||||
# conv1
|
||||
with tf.name_scope('conv1') as scope:
|
||||
kernel = tf.Variable(tf.truncated_normal([11, 11, 3, 64], dtype=tf.float32,
|
||||
stddev=1e-1), name='weights')
|
||||
conv = tf.nn.conv2d(images, kernel, [1, 4, 4, 1], padding='SAME')
|
||||
biases = tf.Variable(tf.constant(0.0, shape=[64], dtype=tf.float32),
|
||||
trainable=True, name='biases')
|
||||
bias = tf.nn.bias_add(conv, biases)
|
||||
conv1 = tf.nn.relu(bias, name=scope)
|
||||
print_activations(conv1)
|
||||
parameters += [kernel, biases]
|
||||
|
||||
# lrn1
|
||||
# TODO(shlens, jiayq): Add a GPU version of local response normalization.
|
||||
|
||||
# pool1
|
||||
pool1 = tf.nn.max_pool(conv1,
|
||||
ksize=[1, 3, 3, 1],
|
||||
strides=[1, 2, 2, 1],
|
||||
padding='VALID',
|
||||
name='pool1')
|
||||
print_activations(pool1)
|
||||
|
||||
# conv2
|
||||
with tf.name_scope('conv2') as scope:
|
||||
kernel = tf.Variable(tf.truncated_normal([5, 5, 64, 192], dtype=tf.float32,
|
||||
stddev=1e-1), name='weights')
|
||||
conv = tf.nn.conv2d(pool1, kernel, [1, 1, 1, 1], padding='SAME')
|
||||
biases = tf.Variable(tf.constant(0.0, shape=[192], dtype=tf.float32),
|
||||
trainable=True, name='biases')
|
||||
bias = tf.nn.bias_add(conv, biases)
|
||||
conv2 = tf.nn.relu(bias, name=scope)
|
||||
parameters += [kernel, biases]
|
||||
print_activations(conv2)
|
||||
|
||||
# pool2
|
||||
pool2 = tf.nn.max_pool(conv2,
|
||||
ksize=[1, 3, 3, 1],
|
||||
strides=[1, 2, 2, 1],
|
||||
padding='VALID',
|
||||
name='pool2')
|
||||
print_activations(pool2)
|
||||
|
||||
# conv3
|
||||
with tf.name_scope('conv3') as scope:
|
||||
kernel = tf.Variable(tf.truncated_normal([3, 3, 192, 384],
|
||||
dtype=tf.float32,
|
||||
stddev=1e-1), name='weights')
|
||||
conv = tf.nn.conv2d(pool2, kernel, [1, 1, 1, 1], padding='SAME')
|
||||
biases = tf.Variable(tf.constant(0.0, shape=[384], dtype=tf.float32),
|
||||
trainable=True, name='biases')
|
||||
bias = tf.nn.bias_add(conv, biases)
|
||||
conv3 = tf.nn.relu(bias, name=scope)
|
||||
parameters += [kernel, biases]
|
||||
print_activations(conv3)
|
||||
|
||||
# conv4
|
||||
with tf.name_scope('conv4') as scope:
|
||||
kernel = tf.Variable(tf.truncated_normal([3, 3, 384, 256],
|
||||
dtype=tf.float32,
|
||||
stddev=1e-1), name='weights')
|
||||
conv = tf.nn.conv2d(conv3, kernel, [1, 1, 1, 1], padding='SAME')
|
||||
biases = tf.Variable(tf.constant(0.0, shape=[256], dtype=tf.float32),
|
||||
trainable=True, name='biases')
|
||||
bias = tf.nn.bias_add(conv, biases)
|
||||
conv4 = tf.nn.relu(bias, name=scope)
|
||||
parameters += [kernel, biases]
|
||||
print_activations(conv4)
|
||||
|
||||
# conv5
|
||||
with tf.name_scope('conv5') as scope:
|
||||
kernel = tf.Variable(tf.truncated_normal([3, 3, 256, 256],
|
||||
dtype=tf.float32,
|
||||
stddev=1e-1), name='weights')
|
||||
conv = tf.nn.conv2d(conv4, kernel, [1, 1, 1, 1], padding='SAME')
|
||||
biases = tf.Variable(tf.constant(0.0, shape=[256], dtype=tf.float32),
|
||||
trainable=True, name='biases')
|
||||
bias = tf.nn.bias_add(conv, biases)
|
||||
conv5 = tf.nn.relu(bias, name=scope)
|
||||
parameters += [kernel, biases]
|
||||
print_activations(conv5)
|
||||
|
||||
# pool5
|
||||
pool5 = tf.nn.max_pool(conv5,
|
||||
ksize=[1, 3, 3, 1],
|
||||
strides=[1, 2, 2, 1],
|
||||
padding='VALID',
|
||||
name='pool5')
|
||||
print_activations(pool5)
|
||||
|
||||
return pool5, parameters
|
||||
|
||||
|
||||
def time_tensorflow_run(session, target, info_string):
|
||||
"""Run the computation to obtain the target tensor and print timing stats.
|
||||
|
||||
Args:
|
||||
session: the TensorFlow session to run the computation under.
|
||||
target: the target Tensor that is passed to the session's run() function.
|
||||
info_string: a string summarizing this run, to be printed with the stats.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
num_steps_burn_in = 10
|
||||
total_duration = 0.0
|
||||
total_duration_squared = 0.0
|
||||
for i in xrange(FLAGS.num_batches + num_steps_burn_in):
|
||||
start_time = time.time()
|
||||
_ = session.run(target)
|
||||
duration = time.time() - start_time
|
||||
if i >= num_steps_burn_in:
|
||||
if not i % 10:
|
||||
print ('%s: step %d, duration = %.3f' %
|
||||
(datetime.now(), i - num_steps_burn_in, duration))
|
||||
total_duration += duration
|
||||
total_duration_squared += duration * duration
|
||||
mn = total_duration / FLAGS.num_batches
|
||||
vr = total_duration_squared / FLAGS.num_batches - mn * mn
|
||||
sd = math.sqrt(vr)
|
||||
print ('%s: %s across %d steps, %.3f +/- %.3f sec / batch' %
|
||||
(datetime.now(), info_string, FLAGS.num_batches, mn, sd))
|
||||
|
||||
|
||||
|
||||
def run_benchmark():
|
||||
"""Run the benchmark on AlexNet."""
|
||||
with tf.Graph().as_default():
|
||||
# Generate some dummy images.
|
||||
image_size = 224
|
||||
    # Note that our padding definition is slightly different from cuda-convnet.
    # In order to force the model to start with the same activation sizes,
    # we add 3 to the image_size and employ VALID padding above.
|
||||
images = tf.Variable(tf.random_normal([FLAGS.batch_size,
|
||||
image_size,
|
||||
image_size, 3],
|
||||
dtype=tf.float32,
|
||||
stddev=1e-1))
|
||||
|
||||
# Build a Graph that computes the logits predictions from the
|
||||
# inference model.
|
||||
pool5, parameters = inference(images)
|
||||
|
||||
# Build an initialization operation.
|
||||
init = tf.global_variables_initializer()
|
||||
|
||||
# Start running operations on the Graph.
|
||||
config = tf.ConfigProto()
|
||||
config.gpu_options.allocator_type = 'BFC'
|
||||
sess = tf.Session(config=config)
|
||||
sess.run(init)
|
||||
|
||||
# Run the forward benchmark.
|
||||
time_tensorflow_run(sess, pool5, "Forward")
|
||||
|
||||
# Add a simple objective so we can calculate the backward pass.
|
||||
objective = tf.nn.l2_loss(pool5)
|
||||
# Compute the gradient with respect to all the parameters.
|
||||
grad = tf.gradients(objective, parameters)
|
||||
# Run the backward benchmark.
|
||||
time_tensorflow_run(sess, grad, "Forward-backward")
|
||||
|
||||
|
||||
def main(_):
|
||||
run_benchmark()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
'--batch_size',
|
||||
type=int,
|
||||
default=128,
|
||||
help='Batch size.'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--num_batches',
|
||||
type=int,
|
||||
default=100,
|
||||
help='Number of batches to run.'
|
||||
)
|
||||
FLAGS, unparsed = parser.parse_known_args()
|
||||
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
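The timing loop in time_tensorflow_run() above discards num_steps_burn_in warm-up iterations and then reports the mean and standard deviation of per-batch wall time. A minimal standalone sketch of that bookkeeping, with a hypothetical helper name and no TensorFlow dependency:

import math
import time


def time_callable(fn, num_batches=100, num_steps_burn_in=10):
  # Illustrative only: times fn() the way time_tensorflow_run times session.run(target).
  total_duration = 0.0
  total_duration_squared = 0.0
  for i in range(num_batches + num_steps_burn_in):
    start_time = time.time()
    fn()
    duration = time.time() - start_time
    if i >= num_steps_burn_in:
      total_duration += duration
      total_duration_squared += duration * duration
  mean = total_duration / num_batches
  stddev = math.sqrt(total_duration_squared / num_batches - mean * mean)
  return mean, stddev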
@ -1,87 +0,0 @@
# Description:
# Example TensorFlow models for CIFAR-10

licenses(["notice"])  # Apache 2.0

exports_files(["LICENSE"])

py_library(
    name = "cifar10_input",
    srcs = ["cifar10_input.py"],
    srcs_version = "PY2AND3",
    visibility = ["//tensorflow:internal"],
    deps = [
        "//tensorflow:tensorflow_py",
    ],
)

py_test(
    name = "cifar10_input_test",
    size = "small",
    srcs = ["cifar10_input_test.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":cifar10_input",
        "//tensorflow:tensorflow_py",
        "//tensorflow/python:framework_test_lib",
        "//tensorflow/python:platform_test",
    ],
)

py_library(
    name = "cifar10",
    srcs = ["cifar10.py"],
    srcs_version = "PY2AND3",
    deps = [
        ":cifar10_input",
        "//tensorflow:tensorflow_py",
    ],
)

py_binary(
    name = "cifar10_eval",
    srcs = [
        "cifar10_eval.py",
    ],
    srcs_version = "PY2AND3",
    visibility = ["//tensorflow:__subpackages__"],
    deps = [
        ":cifar10",
    ],
)

py_binary(
    name = "cifar10_train",
    srcs = [
        "cifar10_train.py",
    ],
    srcs_version = "PY2AND3",
    visibility = ["//tensorflow:__subpackages__"],
    deps = [
        ":cifar10",
    ],
)

py_binary(
    name = "cifar10_multi_gpu_train",
    srcs = [
        "cifar10_multi_gpu_train.py",
    ],
    srcs_version = "PY2AND3",
    visibility = ["//tensorflow:__subpackages__"],
    deps = [
        ":cifar10",
    ],
)

filegroup(
    name = "all_files",
    srcs = glob(
        ["**/*"],
        exclude = [
            "**/METADATA",
            "**/OWNERS",
        ],
    ),
    visibility = ["//tensorflow:__subpackages__"],
)
@ -1,10 +0,0 @@
CIFAR-10 is a common benchmark in machine learning for image recognition.

http://www.cs.toronto.edu/~kriz/cifar.html

Code in this directory demonstrates how to use TensorFlow to train and evaluate a convolutional neural network (CNN) on both CPU and GPU. We also demonstrate how to train a CNN over multiple GPUs.

Detailed instructions on how to get started are available at:

http://tensorflow.org/tutorials/deep_cnn/
@ -1,22 +0,0 @@
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Makes helper libraries available in the cifar10 package."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.models.image.cifar10 import cifar10
from tensorflow.models.image.cifar10 import cifar10_input
@ -1,399 +0,0 @@
|
||||
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
"""Builds the CIFAR-10 network.
|
||||
|
||||
Summary of available functions:
|
||||
|
||||
# Compute input images and labels for training. If you would like to run
|
||||
# evaluations, use inputs() instead.
|
||||
inputs, labels = distorted_inputs()
|
||||
|
||||
# Compute inference on the model inputs to make a prediction.
|
||||
predictions = inference(inputs)
|
||||
|
||||
# Compute the total loss of the prediction with respect to the labels.
|
||||
loss = loss(predictions, labels)
|
||||
|
||||
# Create a graph to run one step of training with respect to the loss.
|
||||
train_op = train(loss, global_step)
|
||||
"""
|
||||
# pylint: disable=missing-docstring
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import gzip
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import tarfile
|
||||
|
||||
from six.moves import urllib
|
||||
import tensorflow as tf
|
||||
|
||||
from tensorflow.models.image.cifar10 import cifar10_input
|
||||
|
||||
FLAGS = tf.app.flags.FLAGS
|
||||
|
||||
# Basic model parameters.
|
||||
tf.app.flags.DEFINE_integer('batch_size', 128,
|
||||
"""Number of images to process in a batch.""")
|
||||
tf.app.flags.DEFINE_string('data_dir', '/tmp/cifar10_data',
|
||||
"""Path to the CIFAR-10 data directory.""")
|
||||
tf.app.flags.DEFINE_boolean('use_fp16', False,
|
||||
"""Train the model using fp16.""")
|
||||
|
||||
# Global constants describing the CIFAR-10 data set.
|
||||
IMAGE_SIZE = cifar10_input.IMAGE_SIZE
|
||||
NUM_CLASSES = cifar10_input.NUM_CLASSES
|
||||
NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN = cifar10_input.NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN
|
||||
NUM_EXAMPLES_PER_EPOCH_FOR_EVAL = cifar10_input.NUM_EXAMPLES_PER_EPOCH_FOR_EVAL
|
||||
|
||||
|
||||
# Constants describing the training process.
|
||||
MOVING_AVERAGE_DECAY = 0.9999 # The decay to use for the moving average.
|
||||
NUM_EPOCHS_PER_DECAY = 350.0 # Epochs after which learning rate decays.
|
||||
LEARNING_RATE_DECAY_FACTOR = 0.1 # Learning rate decay factor.
|
||||
INITIAL_LEARNING_RATE = 0.1 # Initial learning rate.
|
||||
|
||||
# If a model is trained with multiple GPUs, prefix all Op names with tower_name
|
||||
# to differentiate the operations. Note that this prefix is removed from the
|
||||
# names of the summaries when visualizing a model.
|
||||
TOWER_NAME = 'tower'
|
||||
|
||||
DATA_URL = 'http://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz'
|
||||
|
||||
|
||||
def _activation_summary(x):
|
||||
"""Helper to create summaries for activations.
|
||||
|
||||
Creates a summary that provides a histogram of activations.
|
||||
Creates a summary that measures the sparsity of activations.
|
||||
|
||||
Args:
|
||||
x: Tensor
|
||||
Returns:
|
||||
nothing
|
||||
"""
|
||||
# Remove 'tower_[0-9]/' from the name in case this is a multi-GPU training
|
||||
# session. This helps the clarity of presentation on tensorboard.
|
||||
tensor_name = re.sub('%s_[0-9]*/' % TOWER_NAME, '', x.op.name)
|
||||
tf.contrib.deprecated.histogram_summary(tensor_name + '/activations', x)
|
||||
tf.contrib.deprecated.scalar_summary(tensor_name + '/sparsity',
|
||||
tf.nn.zero_fraction(x))
|
||||
|
||||
|
||||
def _variable_on_cpu(name, shape, initializer):
|
||||
"""Helper to create a Variable stored on CPU memory.
|
||||
|
||||
Args:
|
||||
name: name of the variable
|
||||
shape: list of ints
|
||||
initializer: initializer for Variable
|
||||
|
||||
Returns:
|
||||
Variable Tensor
|
||||
"""
|
||||
with tf.device('/cpu:0'):
|
||||
dtype = tf.float16 if FLAGS.use_fp16 else tf.float32
|
||||
var = tf.get_variable(name, shape, initializer=initializer, dtype=dtype)
|
||||
return var
|
||||
|
||||
|
||||
def _variable_with_weight_decay(name, shape, stddev, wd):
|
||||
"""Helper to create an initialized Variable with weight decay.
|
||||
|
||||
Note that the Variable is initialized with a truncated normal distribution.
|
||||
A weight decay is added only if one is specified.
|
||||
|
||||
Args:
|
||||
name: name of the variable
|
||||
shape: list of ints
|
||||
stddev: standard deviation of a truncated Gaussian
|
||||
wd: add L2Loss weight decay multiplied by this float. If None, weight
|
||||
decay is not added for this Variable.
|
||||
|
||||
Returns:
|
||||
Variable Tensor
|
||||
"""
|
||||
dtype = tf.float16 if FLAGS.use_fp16 else tf.float32
|
||||
var = _variable_on_cpu(
|
||||
name,
|
||||
shape,
|
||||
tf.truncated_normal_initializer(stddev=stddev, dtype=dtype))
|
||||
if wd is not None:
|
||||
weight_decay = tf.mul(tf.nn.l2_loss(var), wd, name='weight_loss')
|
||||
tf.add_to_collection('losses', weight_decay)
|
||||
return var
|
||||
|
||||
|
||||
def distorted_inputs():
|
||||
"""Construct distorted input for CIFAR training using the Reader ops.
|
||||
|
||||
Returns:
|
||||
images: Images. 4D tensor of [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3] size.
|
||||
labels: Labels. 1D tensor of [batch_size] size.
|
||||
|
||||
Raises:
|
||||
ValueError: If no data_dir
|
||||
"""
|
||||
if not FLAGS.data_dir:
|
||||
raise ValueError('Please supply a data_dir')
|
||||
data_dir = os.path.join(FLAGS.data_dir, 'cifar-10-batches-bin')
|
||||
images, labels = cifar10_input.distorted_inputs(data_dir=data_dir,
|
||||
batch_size=FLAGS.batch_size)
|
||||
if FLAGS.use_fp16:
|
||||
images = tf.cast(images, tf.float16)
|
||||
labels = tf.cast(labels, tf.float16)
|
||||
return images, labels
|
||||
|
||||
|
||||
def inputs(eval_data):
|
||||
"""Construct input for CIFAR evaluation using the Reader ops.
|
||||
|
||||
Args:
|
||||
eval_data: bool, indicating if one should use the train or eval data set.
|
||||
|
||||
Returns:
|
||||
images: Images. 4D tensor of [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3] size.
|
||||
labels: Labels. 1D tensor of [batch_size] size.
|
||||
|
||||
Raises:
|
||||
ValueError: If no data_dir
|
||||
"""
|
||||
if not FLAGS.data_dir:
|
||||
raise ValueError('Please supply a data_dir')
|
||||
data_dir = os.path.join(FLAGS.data_dir, 'cifar-10-batches-bin')
|
||||
images, labels = cifar10_input.inputs(eval_data=eval_data,
|
||||
data_dir=data_dir,
|
||||
batch_size=FLAGS.batch_size)
|
||||
if FLAGS.use_fp16:
|
||||
images = tf.cast(images, tf.float16)
|
||||
labels = tf.cast(labels, tf.float16)
|
||||
return images, labels
|
||||
|
||||
|
||||
def inference(images):
|
||||
"""Build the CIFAR-10 model.
|
||||
|
||||
Args:
|
||||
images: Images returned from distorted_inputs() or inputs().
|
||||
|
||||
Returns:
|
||||
Logits.
|
||||
"""
|
||||
# We instantiate all variables using tf.get_variable() instead of
|
||||
# tf.Variable() in order to share variables across multiple GPU training runs.
|
||||
# If we only ran this model on a single GPU, we could simplify this function
|
||||
# by replacing all instances of tf.get_variable() with tf.Variable().
|
||||
#
|
||||
# conv1
|
||||
with tf.variable_scope('conv1') as scope:
|
||||
kernel = _variable_with_weight_decay('weights',
|
||||
shape=[5, 5, 3, 64],
|
||||
stddev=5e-2,
|
||||
wd=0.0)
|
||||
conv = tf.nn.conv2d(images, kernel, [1, 1, 1, 1], padding='SAME')
|
||||
biases = _variable_on_cpu('biases', [64], tf.constant_initializer(0.0))
|
||||
pre_activation = tf.nn.bias_add(conv, biases)
|
||||
conv1 = tf.nn.relu(pre_activation, name=scope.name)
|
||||
_activation_summary(conv1)
|
||||
|
||||
# pool1
|
||||
pool1 = tf.nn.max_pool(conv1, ksize=[1, 3, 3, 1], strides=[1, 2, 2, 1],
|
||||
padding='SAME', name='pool1')
|
||||
# norm1
|
||||
norm1 = tf.nn.lrn(pool1, 4, bias=1.0, alpha=0.001 / 9.0, beta=0.75,
|
||||
name='norm1')
|
||||
|
||||
# conv2
|
||||
with tf.variable_scope('conv2') as scope:
|
||||
kernel = _variable_with_weight_decay('weights',
|
||||
shape=[5, 5, 64, 64],
|
||||
stddev=5e-2,
|
||||
wd=0.0)
|
||||
conv = tf.nn.conv2d(norm1, kernel, [1, 1, 1, 1], padding='SAME')
|
||||
biases = _variable_on_cpu('biases', [64], tf.constant_initializer(0.1))
|
||||
pre_activation = tf.nn.bias_add(conv, biases)
|
||||
conv2 = tf.nn.relu(pre_activation, name=scope.name)
|
||||
_activation_summary(conv2)
|
||||
|
||||
# norm2
|
||||
norm2 = tf.nn.lrn(conv2, 4, bias=1.0, alpha=0.001 / 9.0, beta=0.75,
|
||||
name='norm2')
|
||||
# pool2
|
||||
pool2 = tf.nn.max_pool(norm2, ksize=[1, 3, 3, 1],
|
||||
strides=[1, 2, 2, 1], padding='SAME', name='pool2')
|
||||
|
||||
# local3
|
||||
with tf.variable_scope('local3') as scope:
|
||||
# Move everything into depth so we can perform a single matrix multiply.
|
||||
reshape = tf.reshape(pool2, [FLAGS.batch_size, -1])
|
||||
dim = reshape.get_shape()[1].value
|
||||
weights = _variable_with_weight_decay('weights', shape=[dim, 384],
|
||||
stddev=0.04, wd=0.004)
|
||||
biases = _variable_on_cpu('biases', [384], tf.constant_initializer(0.1))
|
||||
local3 = tf.nn.relu(tf.matmul(reshape, weights) + biases, name=scope.name)
|
||||
_activation_summary(local3)
|
||||
|
||||
# local4
|
||||
with tf.variable_scope('local4') as scope:
|
||||
weights = _variable_with_weight_decay('weights', shape=[384, 192],
|
||||
stddev=0.04, wd=0.004)
|
||||
biases = _variable_on_cpu('biases', [192], tf.constant_initializer(0.1))
|
||||
local4 = tf.nn.relu(tf.matmul(local3, weights) + biases, name=scope.name)
|
||||
_activation_summary(local4)
|
||||
|
||||
# linear layer(WX + b),
|
||||
# We don't apply softmax here because
|
||||
# tf.nn.sparse_softmax_cross_entropy_with_logits accepts the unscaled logits
|
||||
# and performs the softmax internally for efficiency.
|
||||
with tf.variable_scope('softmax_linear') as scope:
|
||||
weights = _variable_with_weight_decay('weights', [192, NUM_CLASSES],
|
||||
stddev=1/192.0, wd=0.0)
|
||||
biases = _variable_on_cpu('biases', [NUM_CLASSES],
|
||||
tf.constant_initializer(0.0))
|
||||
softmax_linear = tf.add(tf.matmul(local4, weights), biases, name=scope.name)
|
||||
_activation_summary(softmax_linear)
|
||||
|
||||
return softmax_linear
|
||||
|
||||
|
||||
def loss(logits, labels):
|
||||
"""Add L2Loss to all the trainable variables.
|
||||
|
||||
Add summary for "Loss" and "Loss/avg".
|
||||
Args:
|
||||
logits: Logits from inference().
|
||||
labels: Labels from distorted_inputs or inputs(). 1-D tensor
|
||||
of shape [batch_size]
|
||||
|
||||
Returns:
|
||||
Loss tensor of type float.
|
||||
"""
|
||||
# Calculate the average cross entropy loss across the batch.
|
||||
labels = tf.cast(labels, tf.int64)
|
||||
cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
|
||||
logits, labels, name='cross_entropy_per_example')
|
||||
cross_entropy_mean = tf.reduce_mean(cross_entropy, name='cross_entropy')
|
||||
tf.add_to_collection('losses', cross_entropy_mean)
|
||||
|
||||
# The total loss is defined as the cross entropy loss plus all of the weight
|
||||
# decay terms (L2 loss).
|
||||
return tf.add_n(tf.get_collection('losses'), name='total_loss')
|
||||
|
||||
|
||||
def _add_loss_summaries(total_loss):
|
||||
"""Add summaries for losses in CIFAR-10 model.
|
||||
|
||||
Generates moving average for all losses and associated summaries for
|
||||
visualizing the performance of the network.
|
||||
|
||||
Args:
|
||||
total_loss: Total loss from loss().
|
||||
Returns:
|
||||
loss_averages_op: op for generating moving averages of losses.
|
||||
"""
|
||||
# Compute the moving average of all individual losses and the total loss.
|
||||
loss_averages = tf.train.ExponentialMovingAverage(0.9, name='avg')
|
||||
losses = tf.get_collection('losses')
|
||||
loss_averages_op = loss_averages.apply(losses + [total_loss])
|
||||
|
||||
# Attach a scalar summary to all individual losses and the total loss; do the
|
||||
# same for the averaged version of the losses.
|
||||
for l in losses + [total_loss]:
|
||||
# Name each loss as '(raw)' and name the moving average version of the loss
|
||||
# as the original loss name.
|
||||
tf.contrib.deprecated.scalar_summary(l.op.name + ' (raw)', l)
|
||||
tf.contrib.deprecated.scalar_summary(l.op.name, loss_averages.average(l))
|
||||
|
||||
return loss_averages_op
|
||||
|
||||
|
||||
def train(total_loss, global_step):
|
||||
"""Train CIFAR-10 model.
|
||||
|
||||
Create an optimizer and apply to all trainable variables. Add moving
|
||||
average for all trainable variables.
|
||||
|
||||
Args:
|
||||
total_loss: Total loss from loss().
|
||||
global_step: Integer Variable counting the number of training steps
|
||||
processed.
|
||||
Returns:
|
||||
train_op: op for training.
|
||||
"""
|
||||
# Variables that affect learning rate.
|
||||
num_batches_per_epoch = NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN / FLAGS.batch_size
|
||||
decay_steps = int(num_batches_per_epoch * NUM_EPOCHS_PER_DECAY)
|
||||
|
||||
# Decay the learning rate exponentially based on the number of steps.
|
||||
lr = tf.train.exponential_decay(INITIAL_LEARNING_RATE,
|
||||
global_step,
|
||||
decay_steps,
|
||||
LEARNING_RATE_DECAY_FACTOR,
|
||||
staircase=True)
|
||||
tf.contrib.deprecated.scalar_summary('learning_rate', lr)
|
||||
|
||||
# Generate moving averages of all losses and associated summaries.
|
||||
loss_averages_op = _add_loss_summaries(total_loss)
|
||||
|
||||
# Compute gradients.
|
||||
with tf.control_dependencies([loss_averages_op]):
|
||||
opt = tf.train.GradientDescentOptimizer(lr)
|
||||
grads = opt.compute_gradients(total_loss)
|
||||
|
||||
# Apply gradients.
|
||||
apply_gradient_op = opt.apply_gradients(grads, global_step=global_step)
|
||||
|
||||
# Add histograms for trainable variables.
|
||||
for var in tf.trainable_variables():
|
||||
tf.contrib.deprecated.histogram_summary(var.op.name, var)
|
||||
|
||||
# Add histograms for gradients.
|
||||
for grad, var in grads:
|
||||
if grad is not None:
|
||||
tf.contrib.deprecated.histogram_summary(var.op.name + '/gradients', grad)
|
||||
|
||||
# Track the moving averages of all trainable variables.
|
||||
variable_averages = tf.train.ExponentialMovingAverage(
|
||||
MOVING_AVERAGE_DECAY, global_step)
|
||||
variables_averages_op = variable_averages.apply(tf.trainable_variables())
|
||||
|
||||
with tf.control_dependencies([apply_gradient_op, variables_averages_op]):
|
||||
train_op = tf.no_op(name='train')
|
||||
|
||||
return train_op
|
||||
|
||||
|
||||
def maybe_download_and_extract():
|
||||
"""Download and extract the tarball from Alex's website."""
|
||||
dest_directory = FLAGS.data_dir
|
||||
if not os.path.exists(dest_directory):
|
||||
os.makedirs(dest_directory)
|
||||
filename = DATA_URL.split('/')[-1]
|
||||
filepath = os.path.join(dest_directory, filename)
|
||||
if not os.path.exists(filepath):
|
||||
def _progress(count, block_size, total_size):
|
||||
sys.stdout.write('\r>> Downloading %s %.1f%%' % (filename,
|
||||
float(count * block_size) / float(total_size) * 100.0))
|
||||
sys.stdout.flush()
|
||||
filepath, _ = urllib.request.urlretrieve(DATA_URL, filepath, _progress)
|
||||
print()
|
||||
statinfo = os.stat(filepath)
|
||||
print('Successfully downloaded', filename, statinfo.st_size, 'bytes.')
|
||||
|
||||
tarfile.open(filepath, 'r:gz').extractall(dest_directory)
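The train() function above uses tf.train.exponential_decay with staircase=True, so the effective schedule multiplies INITIAL_LEARNING_RATE by LEARNING_RATE_DECAY_FACTOR once every decay_steps steps. A pure-Python sketch of that schedule, using the module constants as default values (the helper name is illustrative, not part of cifar10.py):

def cifar10_learning_rate(global_step,
                          initial_lr=0.1,
                          decay_factor=0.1,
                          num_epochs_per_decay=350.0,
                          num_examples_per_epoch=50000,
                          batch_size=128):
  # Staircase exponential decay: drop the rate by decay_factor every decay_steps steps.
  # Illustrative sketch; cifar10.py builds the equivalent schedule with tf.train.exponential_decay.
  num_batches_per_epoch = num_examples_per_epoch / batch_size
  decay_steps = int(num_batches_per_epoch * num_epochs_per_decay)
  return initial_lr * decay_factor ** (global_step // decay_steps)

With these defaults the rate stays at 0.1 for the first 136,718 steps and drops to 0.01 afterwards.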
|
@ -1,157 +0,0 @@
|
||||
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
"""Evaluation for CIFAR-10.
|
||||
|
||||
Accuracy:
|
||||
cifar10_train.py achieves 83.0% accuracy after 100K steps (256 epochs
|
||||
of data) as judged by cifar10_eval.py.
|
||||
|
||||
Speed:
|
||||
On a single Tesla K40, cifar10_train.py processes a single batch of 128 images
|
||||
in 0.25-0.35 sec (i.e. 350 - 600 images/sec). The model reaches ~86%
|
||||
accuracy after 100K steps in 8 hours of training time.
|
||||
|
||||
Usage:
|
||||
Please see the tutorial and website for how to download the CIFAR-10
|
||||
data set, compile the program and train the model.
|
||||
|
||||
http://tensorflow.org/tutorials/deep_cnn/
|
||||
"""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from datetime import datetime
|
||||
import math
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
from tensorflow.models.image.cifar10 import cifar10
|
||||
|
||||
FLAGS = tf.app.flags.FLAGS
|
||||
|
||||
tf.app.flags.DEFINE_string('eval_dir', '/tmp/cifar10_eval',
|
||||
"""Directory where to write event logs.""")
|
||||
tf.app.flags.DEFINE_string('eval_data', 'test',
|
||||
"""Either 'test' or 'train_eval'.""")
|
||||
tf.app.flags.DEFINE_string('checkpoint_dir', '/tmp/cifar10_train',
|
||||
"""Directory where to read model checkpoints.""")
|
||||
tf.app.flags.DEFINE_integer('eval_interval_secs', 60 * 5,
|
||||
"""How often to run the eval.""")
|
||||
tf.app.flags.DEFINE_integer('num_examples', 10000,
|
||||
"""Number of examples to run.""")
|
||||
tf.app.flags.DEFINE_boolean('run_once', False,
|
||||
"""Whether to run eval only once.""")
|
||||
|
||||
|
||||
def eval_once(saver, summary_writer, top_k_op, summary_op):
|
||||
"""Run Eval once.
|
||||
|
||||
Args:
|
||||
saver: Saver.
|
||||
summary_writer: Summary writer.
|
||||
top_k_op: Top K op.
|
||||
summary_op: Summary op.
|
||||
"""
|
||||
with tf.Session() as sess:
|
||||
ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
|
||||
if ckpt and ckpt.model_checkpoint_path:
|
||||
# Restores from checkpoint
|
||||
saver.restore(sess, ckpt.model_checkpoint_path)
|
||||
# Assuming model_checkpoint_path looks something like:
|
||||
# /my-favorite-path/cifar10_train/model.ckpt-0,
|
||||
# extract global_step from it.
|
||||
global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
|
||||
else:
|
||||
print('No checkpoint file found')
|
||||
return
|
||||
|
||||
# Start the queue runners.
|
||||
coord = tf.train.Coordinator()
|
||||
try:
|
||||
threads = []
|
||||
for qr in tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS):
|
||||
threads.extend(qr.create_threads(sess, coord=coord, daemon=True,
|
||||
start=True))
|
||||
|
||||
num_iter = int(math.ceil(FLAGS.num_examples / FLAGS.batch_size))
|
||||
true_count = 0 # Counts the number of correct predictions.
|
||||
total_sample_count = num_iter * FLAGS.batch_size
|
||||
step = 0
|
||||
while step < num_iter and not coord.should_stop():
|
||||
predictions = sess.run([top_k_op])
|
||||
true_count += np.sum(predictions)
|
||||
step += 1
|
||||
|
||||
# Compute precision @ 1.
|
||||
precision = true_count / total_sample_count
|
||||
print('%s: precision @ 1 = %.3f' % (datetime.now(), precision))
|
||||
|
||||
summary = tf.Summary()
|
||||
summary.ParseFromString(sess.run(summary_op))
|
||||
summary.value.add(tag='Precision @ 1', simple_value=precision)
|
||||
summary_writer.add_summary(summary, global_step)
|
||||
except Exception as e: # pylint: disable=broad-except
|
||||
coord.request_stop(e)
|
||||
|
||||
coord.request_stop()
|
||||
coord.join(threads, stop_grace_period_secs=10)
|
||||
|
||||
|
||||
def evaluate():
|
||||
"""Eval CIFAR-10 for a number of steps."""
|
||||
with tf.Graph().as_default() as g:
|
||||
# Get images and labels for CIFAR-10.
|
||||
eval_data = FLAGS.eval_data == 'test'
|
||||
images, labels = cifar10.inputs(eval_data=eval_data)
|
||||
|
||||
# Build a Graph that computes the logits predictions from the
|
||||
# inference model.
|
||||
logits = cifar10.inference(images)
|
||||
|
||||
# Calculate predictions.
|
||||
top_k_op = tf.nn.in_top_k(logits, labels, 1)
|
||||
|
||||
# Restore the moving average version of the learned variables for eval.
|
||||
variable_averages = tf.train.ExponentialMovingAverage(
|
||||
cifar10.MOVING_AVERAGE_DECAY)
|
||||
variables_to_restore = variable_averages.variables_to_restore()
|
||||
saver = tf.train.Saver(variables_to_restore)
|
||||
|
||||
# Build the summary operation based on the TF collection of Summaries.
|
||||
summary_op = tf.summary.merge_all()
|
||||
|
||||
summary_writer = tf.summary.FileWriter(FLAGS.eval_dir, g)
|
||||
|
||||
while True:
|
||||
eval_once(saver, summary_writer, top_k_op, summary_op)
|
||||
if FLAGS.run_once:
|
||||
break
|
||||
time.sleep(FLAGS.eval_interval_secs)
|
||||
|
||||
|
||||
def main(argv=None): # pylint: disable=unused-argument
|
||||
cifar10.maybe_download_and_extract()
|
||||
if tf.gfile.Exists(FLAGS.eval_dir):
|
||||
tf.gfile.DeleteRecursively(FLAGS.eval_dir)
|
||||
tf.gfile.MakeDirs(FLAGS.eval_dir)
|
||||
evaluate()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
tf.app.run()
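eval_once() above accumulates the number of correct top-1 predictions over ceil(num_examples / batch_size) batches and divides by the total sample count. That aggregation as a small sketch, with an illustrative callable standing in for sess.run([top_k_op]):

import math


def eval_precision_at_1(correct_in_batch, num_examples=10000, batch_size=128):
  # correct_in_batch() returns the number of correct top-1 predictions in one batch.
  num_iter = int(math.ceil(num_examples / batch_size))
  true_count = sum(correct_in_batch() for _ in range(num_iter))
  total_sample_count = num_iter * batch_size
  return true_count / total_sample_count

Note that, as in eval_once(), the denominator is num_iter * batch_size rather than num_examples, so slightly more than num_examples examples are counted when batch_size does not divide num_examples evenly.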
|
@ -1,249 +0,0 @@
|
||||
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
"""Routine for decoding the CIFAR-10 binary file format."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
|
||||
from six.moves import xrange # pylint: disable=redefined-builtin
|
||||
import tensorflow as tf
|
||||
|
||||
# Process images of this size. Note that this differs from the original CIFAR
|
||||
# image size of 32 x 32. If one alters this number, then the entire model
|
||||
# architecture will change and any model would need to be retrained.
|
||||
IMAGE_SIZE = 24
|
||||
|
||||
# Global constants describing the CIFAR-10 data set.
|
||||
NUM_CLASSES = 10
|
||||
NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN = 50000
|
||||
NUM_EXAMPLES_PER_EPOCH_FOR_EVAL = 10000
|
||||
|
||||
|
||||
def read_cifar10(filename_queue):
|
||||
"""Reads and parses examples from CIFAR10 data files.
|
||||
|
||||
Recommendation: if you want N-way read parallelism, call this function
|
||||
N times. This will give you N independent Readers reading different
|
||||
files & positions within those files, which will give better mixing of
|
||||
examples.
|
||||
|
||||
Args:
|
||||
filename_queue: A queue of strings with the filenames to read from.
|
||||
|
||||
Returns:
|
||||
An object representing a single example, with the following fields:
|
||||
height: number of rows in the result (32)
|
||||
width: number of columns in the result (32)
|
||||
depth: number of color channels in the result (3)
|
||||
key: a scalar string Tensor describing the filename & record number
|
||||
for this example.
|
||||
label: an int32 Tensor with the label in the range 0..9.
|
||||
uint8image: a [height, width, depth] uint8 Tensor with the image data
|
||||
"""
|
||||
|
||||
class CIFAR10Record(object):
|
||||
pass
|
||||
result = CIFAR10Record()
|
||||
|
||||
# Dimensions of the images in the CIFAR-10 dataset.
|
||||
# See http://www.cs.toronto.edu/~kriz/cifar.html for a description of the
|
||||
# input format.
|
||||
label_bytes = 1 # 2 for CIFAR-100
|
||||
result.height = 32
|
||||
result.width = 32
|
||||
result.depth = 3
|
||||
image_bytes = result.height * result.width * result.depth
|
||||
# Every record consists of a label followed by the image, with a
|
||||
# fixed number of bytes for each.
|
||||
record_bytes = label_bytes + image_bytes
|
||||
|
||||
# Read a record, getting filenames from the filename_queue. No
|
||||
# header or footer in the CIFAR-10 format, so we leave header_bytes
|
||||
# and footer_bytes at their default of 0.
|
||||
reader = tf.FixedLengthRecordReader(record_bytes=record_bytes)
|
||||
result.key, value = reader.read(filename_queue)
|
||||
|
||||
# Convert from a string to a vector of uint8 that is record_bytes long.
|
||||
record_bytes = tf.decode_raw(value, tf.uint8)
|
||||
|
||||
# The first bytes represent the label, which we convert from uint8->int32.
|
||||
result.label = tf.cast(
|
||||
tf.strided_slice(record_bytes, [0], [label_bytes]), tf.int32)
|
||||
|
||||
# The remaining bytes after the label represent the image, which we reshape
|
||||
# from [depth * height * width] to [depth, height, width].
|
||||
depth_major = tf.reshape(
|
||||
tf.strided_slice(record_bytes, [label_bytes],
|
||||
[label_bytes + image_bytes]),
|
||||
[result.depth, result.height, result.width])
|
||||
# Convert from [depth, height, width] to [height, width, depth].
|
||||
result.uint8image = tf.transpose(depth_major, [1, 2, 0])
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def _generate_image_and_label_batch(image, label, min_queue_examples,
|
||||
batch_size, shuffle):
|
||||
"""Construct a queued batch of images and labels.
|
||||
|
||||
Args:
|
||||
image: 3-D Tensor of [height, width, 3] of type.float32.
|
||||
label: 1-D Tensor of type.int32
|
||||
min_queue_examples: int32, minimum number of samples to retain
|
||||
in the queue that provides batches of examples.
|
||||
batch_size: Number of images per batch.
|
||||
shuffle: boolean indicating whether to use a shuffling queue.
|
||||
|
||||
Returns:
|
||||
images: Images. 4D tensor of [batch_size, height, width, 3] size.
|
||||
labels: Labels. 1D tensor of [batch_size] size.
|
||||
"""
|
||||
# Create a queue that shuffles the examples, and then
|
||||
# read 'batch_size' images + labels from the example queue.
|
||||
num_preprocess_threads = 16
|
||||
if shuffle:
|
||||
images, label_batch = tf.train.shuffle_batch(
|
||||
[image, label],
|
||||
batch_size=batch_size,
|
||||
num_threads=num_preprocess_threads,
|
||||
capacity=min_queue_examples + 3 * batch_size,
|
||||
min_after_dequeue=min_queue_examples)
|
||||
else:
|
||||
images, label_batch = tf.train.batch(
|
||||
[image, label],
|
||||
batch_size=batch_size,
|
||||
num_threads=num_preprocess_threads,
|
||||
capacity=min_queue_examples + 3 * batch_size)
|
||||
|
||||
# Display the training images in the visualizer.
|
||||
tf.contrib.deprecated.image_summary('images', images)
|
||||
|
||||
return images, tf.reshape(label_batch, [batch_size])
|
||||
|
||||
|
||||
def distorted_inputs(data_dir, batch_size):
|
||||
"""Construct distorted input for CIFAR training using the Reader ops.
|
||||
|
||||
Args:
|
||||
data_dir: Path to the CIFAR-10 data directory.
|
||||
batch_size: Number of images per batch.
|
||||
|
||||
Returns:
|
||||
images: Images. 4D tensor of [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3] size.
|
||||
labels: Labels. 1D tensor of [batch_size] size.
|
||||
"""
|
||||
filenames = [os.path.join(data_dir, 'data_batch_%d.bin' % i)
|
||||
for i in xrange(1, 6)]
|
||||
for f in filenames:
|
||||
if not tf.gfile.Exists(f):
|
||||
raise ValueError('Failed to find file: ' + f)
|
||||
|
||||
# Create a queue that produces the filenames to read.
|
||||
filename_queue = tf.train.string_input_producer(filenames)
|
||||
|
||||
# Read examples from files in the filename queue.
|
||||
read_input = read_cifar10(filename_queue)
|
||||
reshaped_image = tf.cast(read_input.uint8image, tf.float32)
|
||||
|
||||
height = IMAGE_SIZE
|
||||
width = IMAGE_SIZE
|
||||
|
||||
# Image processing for training the network. Note the many random
|
||||
# distortions applied to the image.
|
||||
|
||||
# Randomly crop a [height, width] section of the image.
|
||||
distorted_image = tf.random_crop(reshaped_image, [height, width, 3])
|
||||
|
||||
# Randomly flip the image horizontally.
|
||||
distorted_image = tf.image.random_flip_left_right(distorted_image)
|
||||
|
||||
# Because these operations are not commutative, consider randomizing
|
||||
# the order of their operation.
|
||||
distorted_image = tf.image.random_brightness(distorted_image,
|
||||
max_delta=63)
|
||||
distorted_image = tf.image.random_contrast(distorted_image,
|
||||
lower=0.2, upper=1.8)
|
||||
|
||||
# Subtract off the mean and divide by the variance of the pixels.
|
||||
float_image = tf.image.per_image_standardization(distorted_image)
|
||||
|
||||
# Ensure that the random shuffling has good mixing properties.
|
||||
min_fraction_of_examples_in_queue = 0.4
|
||||
min_queue_examples = int(NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN *
|
||||
min_fraction_of_examples_in_queue)
|
||||
print ('Filling queue with %d CIFAR images before starting to train. '
|
||||
'This will take a few minutes.' % min_queue_examples)
|
||||
|
||||
# Generate a batch of images and labels by building up a queue of examples.
|
||||
return _generate_image_and_label_batch(float_image, read_input.label,
|
||||
min_queue_examples, batch_size,
|
||||
shuffle=True)
|
||||
|
||||
|
||||
def inputs(eval_data, data_dir, batch_size):
|
||||
"""Construct input for CIFAR evaluation using the Reader ops.
|
||||
|
||||
Args:
|
||||
eval_data: bool, indicating if one should use the train or eval data set.
|
||||
data_dir: Path to the CIFAR-10 data directory.
|
||||
batch_size: Number of images per batch.
|
||||
|
||||
Returns:
|
||||
images: Images. 4D tensor of [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3] size.
|
||||
labels: Labels. 1D tensor of [batch_size] size.
|
||||
"""
|
||||
if not eval_data:
|
||||
filenames = [os.path.join(data_dir, 'data_batch_%d.bin' % i)
|
||||
for i in xrange(1, 6)]
|
||||
num_examples_per_epoch = NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN
|
||||
else:
|
||||
filenames = [os.path.join(data_dir, 'test_batch.bin')]
|
||||
num_examples_per_epoch = NUM_EXAMPLES_PER_EPOCH_FOR_EVAL
|
||||
|
||||
for f in filenames:
|
||||
if not tf.gfile.Exists(f):
|
||||
raise ValueError('Failed to find file: ' + f)
|
||||
|
||||
# Create a queue that produces the filenames to read.
|
||||
filename_queue = tf.train.string_input_producer(filenames)
|
||||
|
||||
# Read examples from files in the filename queue.
|
||||
read_input = read_cifar10(filename_queue)
|
||||
reshaped_image = tf.cast(read_input.uint8image, tf.float32)
|
||||
|
||||
height = IMAGE_SIZE
|
||||
width = IMAGE_SIZE
|
||||
|
||||
# Image processing for evaluation.
|
||||
# Crop the central [height, width] of the image.
|
||||
resized_image = tf.image.resize_image_with_crop_or_pad(reshaped_image,
|
||||
width, height)
|
||||
|
||||
# Subtract off the mean and divide by the variance of the pixels.
|
||||
float_image = tf.image.per_image_standardization(resized_image)
|
||||
|
||||
# Ensure that the random shuffling has good mixing properties.
|
||||
min_fraction_of_examples_in_queue = 0.4
|
||||
min_queue_examples = int(num_examples_per_epoch *
|
||||
min_fraction_of_examples_in_queue)
|
||||
|
||||
# Generate a batch of images and labels by building up a queue of examples.
|
||||
return _generate_image_and_label_batch(float_image, read_input.label,
|
||||
min_queue_examples, batch_size,
|
||||
shuffle=False)
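read_cifar10() above decodes fixed-length records of 1 label byte followed by a 32x32x3 image stored depth-major. The same decoding in plain NumPy for a single record, outside a TensorFlow input pipeline (helper name illustrative; assumes numpy is available):

import numpy as np


def parse_cifar10_record(record_bytes):
  # One record is 1 + 32 * 32 * 3 = 3073 bytes: a label byte, then the image bytes.
  label = int(np.frombuffer(record_bytes, dtype=np.uint8, count=1)[0])
  image = np.frombuffer(record_bytes, dtype=np.uint8, count=32 * 32 * 3, offset=1)
  # Stored as [depth, height, width]; convert to [height, width, depth] like the code above.
  image = image.reshape(3, 32, 32).transpose(1, 2, 0)
  return label, image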
@ -1,66 +0,0 @@
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Tests for cifar10 input."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

import tensorflow as tf

from tensorflow.models.image.cifar10 import cifar10_input


class CIFAR10InputTest(tf.test.TestCase):

  def _record(self, label, red, green, blue):
    image_size = 32 * 32
    record = bytes(bytearray([label] + [red] * image_size +
                             [green] * image_size + [blue] * image_size))
    expected = [[[red, green, blue]] * 32] * 32
    return record, expected

  def testSimple(self):
    labels = [9, 3, 0]
    records = [self._record(labels[0], 0, 128, 255),
               self._record(labels[1], 255, 0, 1),
               self._record(labels[2], 254, 255, 0)]
    contents = b"".join([record for record, _ in records])
    expected = [expected for _, expected in records]
    filename = os.path.join(self.get_temp_dir(), "cifar")
    open(filename, "wb").write(contents)

    with self.test_session() as sess:
      q = tf.FIFOQueue(99, [tf.string], shapes=())
      q.enqueue([filename]).run()
      q.close().run()
      result = cifar10_input.read_cifar10(q)

      for i in range(3):
        key, label, uint8image = sess.run([
            result.key, result.label, result.uint8image])
        self.assertEqual("%s:%d" % (filename, i), tf.compat.as_text(key))
        self.assertEqual(labels[i], label)
        self.assertAllEqual(expected[i], uint8image)

      with self.assertRaises(tf.errors.OutOfRangeError):
        sess.run([result.key, result.uint8image])


if __name__ == "__main__":
  tf.test.main()
@ -1,273 +0,0 @@
|
||||
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
"""A binary to train CIFAR-10 using multiple GPU's with synchronous updates.
|
||||
|
||||
Accuracy:
|
||||
cifar10_multi_gpu_train.py achieves ~86% accuracy after 100K steps (256
|
||||
epochs of data) as judged by cifar10_eval.py.
|
||||
|
||||
Speed: With batch_size 128.
|
||||
|
||||
System | Step Time (sec/batch) | Accuracy
|
||||
--------------------------------------------------------------------
|
||||
1 Tesla K20m | 0.35-0.60 | ~86% at 60K steps (5 hours)
|
||||
1 Tesla K40m | 0.25-0.35 | ~86% at 100K steps (4 hours)
|
||||
2 Tesla K20m | 0.13-0.20 | ~84% at 30K steps (2.5 hours)
|
||||
3 Tesla K20m | 0.13-0.18 | ~84% at 30K steps
|
||||
4 Tesla K20m | ~0.10 | ~84% at 30K steps
|
||||
|
||||
Usage:
|
||||
Please see the tutorial and website for how to download the CIFAR-10
|
||||
data set, compile the program and train the model.
|
||||
|
||||
http://tensorflow.org/tutorials/deep_cnn/
|
||||
"""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from datetime import datetime
|
||||
import os.path
|
||||
import re
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
from six.moves import xrange # pylint: disable=redefined-builtin
|
||||
import tensorflow as tf
|
||||
from tensorflow.models.image.cifar10 import cifar10
|
||||
|
||||
FLAGS = tf.app.flags.FLAGS
|
||||
|
||||
tf.app.flags.DEFINE_string('train_dir', '/tmp/cifar10_train',
|
||||
"""Directory where to write event logs """
|
||||
"""and checkpoint.""")
|
||||
tf.app.flags.DEFINE_integer('max_steps', 1000000,
|
||||
"""Number of batches to run.""")
|
||||
tf.app.flags.DEFINE_integer('num_gpus', 1,
|
||||
"""How many GPUs to use.""")
|
||||
tf.app.flags.DEFINE_boolean('log_device_placement', False,
|
||||
"""Whether to log device placement.""")
|
||||
|
||||
|
||||
def tower_loss(scope):
|
||||
"""Calculate the total loss on a single tower running the CIFAR model.
|
||||
|
||||
Args:
|
||||
scope: unique prefix string identifying the CIFAR tower, e.g. 'tower_0'
|
||||
|
||||
Returns:
|
||||
Tensor of shape [] containing the total loss for a batch of data
|
||||
"""
|
||||
# Get images and labels for CIFAR-10.
|
||||
images, labels = cifar10.distorted_inputs()
|
||||
|
||||
# Build inference Graph.
|
||||
logits = cifar10.inference(images)
|
||||
|
||||
# Build the portion of the Graph calculating the losses. Note that we will
|
||||
# assemble the total_loss using a custom function below.
|
||||
_ = cifar10.loss(logits, labels)
|
||||
|
||||
# Assemble all of the losses for the current tower only.
|
||||
losses = tf.get_collection('losses', scope)
|
||||
|
||||
# Calculate the total loss for the current tower.
|
||||
total_loss = tf.add_n(losses, name='total_loss')
|
||||
|
||||
# Attach a scalar summary to all individual losses and the total loss; do the
|
||||
# same for the averaged version of the losses.
|
||||
for l in losses + [total_loss]:
|
||||
# Remove 'tower_[0-9]/' from the name in case this is a multi-GPU training
|
||||
# session. This helps the clarity of presentation on tensorboard.
|
||||
loss_name = re.sub('%s_[0-9]*/' % cifar10.TOWER_NAME, '', l.op.name)
|
||||
tf.contrib.deprecated.scalar_summary(loss_name, l)
|
||||
|
||||
return total_loss
|
||||
|
||||
|
||||
def average_gradients(tower_grads):
|
||||
"""Calculate the average gradient for each shared variable across all towers.
|
||||
|
||||
Note that this function provides a synchronization point across all towers.
|
||||
|
||||
Args:
|
||||
tower_grads: List of lists of (gradient, variable) tuples. The outer list
|
||||
is over individual gradients. The inner list is over the gradient
|
||||
calculation for each tower.
|
||||
Returns:
|
||||
List of pairs of (gradient, variable) where the gradient has been averaged
|
||||
across all towers.
|
||||
"""
|
||||
average_grads = []
|
||||
for grad_and_vars in zip(*tower_grads):
|
||||
# Note that each grad_and_vars looks like the following:
|
||||
# ((grad0_gpu0, var0_gpu0), ... , (grad0_gpuN, var0_gpuN))
|
||||
grads = []
|
||||
for g, _ in grad_and_vars:
|
||||
# Add 0 dimension to the gradients to represent the tower.
|
||||
expanded_g = tf.expand_dims(g, 0)
|
||||
|
||||
# Append on a 'tower' dimension which we will average over below.
|
||||
grads.append(expanded_g)
|
||||
|
||||
# Average over the 'tower' dimension.
|
||||
grad = tf.concat_v2(grads, 0)
|
||||
grad = tf.reduce_mean(grad, 0)
|
||||
|
||||
# Keep in mind that the Variables are redundant because they are shared
|
||||
# across towers. So .. we will just return the first tower's pointer to
|
||||
# the Variable.
|
||||
v = grad_and_vars[0][1]
|
||||
grad_and_var = (grad, v)
|
||||
average_grads.append(grad_and_var)
|
||||
return average_grads
|
||||
|
||||
|
||||
def train():
|
||||
"""Train CIFAR-10 for a number of steps."""
|
||||
with tf.Graph().as_default(), tf.device('/cpu:0'):
|
||||
# Create a variable to count the number of train() calls. This equals the
|
||||
# number of batches processed * FLAGS.num_gpus.
|
||||
global_step = tf.get_variable(
|
||||
'global_step', [],
|
||||
initializer=tf.constant_initializer(0), trainable=False)
|
||||
|
||||
# Calculate the learning rate schedule.
|
||||
num_batches_per_epoch = (cifar10.NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN /
|
||||
FLAGS.batch_size)
|
||||
decay_steps = int(num_batches_per_epoch * cifar10.NUM_EPOCHS_PER_DECAY)
|
||||
|
||||
# Decay the learning rate exponentially based on the number of steps.
|
||||
lr = tf.train.exponential_decay(cifar10.INITIAL_LEARNING_RATE,
|
||||
global_step,
|
||||
decay_steps,
|
||||
cifar10.LEARNING_RATE_DECAY_FACTOR,
|
||||
staircase=True)
|
||||
|
||||
# Create an optimizer that performs gradient descent.
|
||||
opt = tf.train.GradientDescentOptimizer(lr)
|
||||
|
||||
# Calculate the gradients for each model tower.
|
||||
tower_grads = []
|
||||
for i in xrange(FLAGS.num_gpus):
|
||||
with tf.device('/gpu:%d' % i):
|
||||
with tf.name_scope('%s_%d' % (cifar10.TOWER_NAME, i)) as scope:
|
||||
# Calculate the loss for one tower of the CIFAR model. This function
|
||||
# constructs the entire CIFAR model but shares the variables across
|
||||
# all towers.
|
||||
loss = tower_loss(scope)
|
||||
|
||||
# Reuse variables for the next tower.
|
||||
tf.get_variable_scope().reuse_variables()
|
||||
|
||||
# Retain the summaries from the final tower.
|
||||
summaries = tf.get_collection(tf.GraphKeys.SUMMARIES, scope)
|
||||
|
||||
# Calculate the gradients for the batch of data on this CIFAR tower.
|
||||
grads = opt.compute_gradients(loss)
|
||||
|
||||
# Keep track of the gradients across all towers.
|
||||
tower_grads.append(grads)
|
||||
|
||||
# We must calculate the mean of each gradient. Note that this is the
|
||||
# synchronization point across all towers.
|
||||
grads = average_gradients(tower_grads)
|
||||
|
||||
# Add a summary to track the learning rate.
|
||||
summaries.append(tf.contrib.deprecated.scalar_summary('learning_rate', lr))
|
||||
|
||||
# Add histograms for gradients.
|
||||
for grad, var in grads:
|
||||
if grad is not None:
|
||||
summaries.append(
|
||||
tf.contrib.deprecated.histogram_summary(var.op.name + '/gradients',
|
||||
grad))
|
||||
|
||||
# Apply the gradients to adjust the shared variables.
|
||||
apply_gradient_op = opt.apply_gradients(grads, global_step=global_step)
|
||||
|
||||
# Add histograms for trainable variables.
|
||||
for var in tf.trainable_variables():
|
||||
summaries.append(
|
||||
tf.contrib.deprecated.histogram_summary(var.op.name, var))
|
||||
|
||||
# Track the moving averages of all trainable variables.
|
||||
variable_averages = tf.train.ExponentialMovingAverage(
|
||||
cifar10.MOVING_AVERAGE_DECAY, global_step)
|
||||
variables_averages_op = variable_averages.apply(tf.trainable_variables())
|
||||
|
||||
    # Group all updates into a single train op.
|
||||
train_op = tf.group(apply_gradient_op, variables_averages_op)
|
||||
|
||||
# Create a saver.
|
||||
saver = tf.train.Saver(tf.all_variables())
|
||||
|
||||
# Build the summary operation from the last tower summaries.
|
||||
summary_op = tf.contrib.deprecated.merge_summary(summaries)
|
||||
|
||||
# Build an initialization operation to run below.
|
||||
init = tf.global_variables_initializer()
|
||||
|
||||
# Start running operations on the Graph. allow_soft_placement must be set to
|
||||
# True to build towers on GPU, as some of the ops do not have GPU
|
||||
# implementations.
|
||||
sess = tf.Session(config=tf.ConfigProto(
|
||||
allow_soft_placement=True,
|
||||
log_device_placement=FLAGS.log_device_placement))
|
||||
sess.run(init)
|
||||
|
||||
# Start the queue runners.
|
||||
tf.train.start_queue_runners(sess=sess)
|
||||
|
||||
summary_writer = tf.summary.FileWriter(FLAGS.train_dir, sess.graph)
|
||||
|
||||
for step in xrange(FLAGS.max_steps):
|
||||
start_time = time.time()
|
||||
_, loss_value = sess.run([train_op, loss])
|
||||
duration = time.time() - start_time
|
||||
|
||||
assert not np.isnan(loss_value), 'Model diverged with loss = NaN'
|
||||
|
||||
if step % 10 == 0:
|
||||
num_examples_per_step = FLAGS.batch_size * FLAGS.num_gpus
|
||||
examples_per_sec = num_examples_per_step / duration
|
||||
sec_per_batch = duration / FLAGS.num_gpus
|
||||
|
||||
format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
|
||||
'sec/batch)')
|
||||
print (format_str % (datetime.now(), step, loss_value,
|
||||
examples_per_sec, sec_per_batch))
|
||||
|
||||
if step % 100 == 0:
|
||||
summary_str = sess.run(summary_op)
|
||||
summary_writer.add_summary(summary_str, step)
|
||||
|
||||
# Save the model checkpoint periodically.
|
||||
if step % 1000 == 0 or (step + 1) == FLAGS.max_steps:
|
||||
checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt')
|
||||
saver.save(sess, checkpoint_path, global_step=step)
|
||||
|
||||
|
||||
def main(argv=None): # pylint: disable=unused-argument
|
||||
cifar10.maybe_download_and_extract()
|
||||
if tf.gfile.Exists(FLAGS.train_dir):
|
||||
tf.gfile.DeleteRecursively(FLAGS.train_dir)
|
||||
tf.gfile.MakeDirs(FLAGS.train_dir)
|
||||
train()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
tf.app.run()
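average_gradients() above stacks each variable's per-tower gradients along a new leading dimension and averages over it, keeping the first tower's variable reference because the variables are shared. A NumPy analogue of that reduction (illustrative only; the real code operates on TensorFlow tensors with tf.concat_v2 and tf.reduce_mean):

import numpy as np


def average_tower_gradients(tower_grads):
  # tower_grads: one list of (gradient, variable) pairs per tower.
  average_grads = []
  for grad_and_vars in zip(*tower_grads):
    grads = np.stack([g for g, _ in grad_and_vars], axis=0)
    # Variables are shared across towers, so return the first tower's pointer.
    average_grads.append((grads.mean(axis=0), grad_and_vars[0][1]))
  return average_grads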
|
@ -1,121 +0,0 @@
|
||||
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
"""A binary to train CIFAR-10 using a single GPU.
|
||||
|
||||
Accuracy:
|
||||
cifar10_train.py achieves ~86% accuracy after 100K steps (256 epochs of
|
||||
data) as judged by cifar10_eval.py.
|
||||
|
||||
Speed: With batch_size 128.
|
||||
|
||||
System | Step Time (sec/batch) | Accuracy
|
||||
------------------------------------------------------------------
|
||||
1 Tesla K20m | 0.35-0.60 | ~86% at 60K steps (5 hours)
|
||||
1 Tesla K40m | 0.25-0.35 | ~86% at 100K steps (4 hours)
|
||||
|
||||
Usage:
|
||||
Please see the tutorial and website for how to download the CIFAR-10
|
||||
data set, compile the program and train the model.
|
||||
|
||||
http://tensorflow.org/tutorials/deep_cnn/
|
||||
"""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from datetime import datetime
|
||||
import time
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
from tensorflow.models.image.cifar10 import cifar10
|
||||
|
||||
FLAGS = tf.app.flags.FLAGS
|
||||
|
||||
tf.app.flags.DEFINE_string('train_dir', '/tmp/cifar10_train',
|
||||
"""Directory where to write event logs """
|
||||
"""and checkpoint.""")
|
||||
tf.app.flags.DEFINE_integer('max_steps', 1000000,
|
||||
"""Number of batches to run.""")
|
||||
tf.app.flags.DEFINE_boolean('log_device_placement', False,
|
||||
"""Whether to log device placement.""")
|
||||
|
||||
|
||||
def train():
|
||||
"""Train CIFAR-10 for a number of steps."""
|
||||
with tf.Graph().as_default():
|
||||
global_step = tf.contrib.framework.get_or_create_global_step()
|
||||
|
||||
# Get images and labels for CIFAR-10.
|
||||
images, labels = cifar10.distorted_inputs()
|
||||
|
||||
# Build a Graph that computes the logits predictions from the
|
||||
# inference model.
|
||||
logits = cifar10.inference(images)
|
||||
|
||||
# Calculate loss.
|
||||
loss = cifar10.loss(logits, labels)
|
||||
|
||||
# Build a Graph that trains the model with one batch of examples and
|
||||
# updates the model parameters.
|
||||
train_op = cifar10.train(loss, global_step)
|
||||
|
||||
class _LoggerHook(tf.train.SessionRunHook):
|
||||
"""Logs loss and runtime."""
|
||||
|
||||
def begin(self):
|
||||
self._step = -1
|
||||
|
||||
def before_run(self, run_context):
|
||||
self._step += 1
|
||||
self._start_time = time.time()
|
||||
return tf.train.SessionRunArgs(loss) # Asks for loss value.
|
||||
|
||||
def after_run(self, run_context, run_values):
|
||||
duration = time.time() - self._start_time
|
||||
loss_value = run_values.results
|
||||
if self._step % 10 == 0:
|
||||
num_examples_per_step = FLAGS.batch_size
|
||||
examples_per_sec = num_examples_per_step / duration
|
||||
sec_per_batch = float(duration)
|
||||
|
||||
format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
|
||||
'sec/batch)')
|
||||
print (format_str % (datetime.now(), self._step, loss_value,
|
||||
examples_per_sec, sec_per_batch))
|
||||
|
||||
with tf.train.MonitoredTrainingSession(
|
||||
checkpoint_dir=FLAGS.train_dir,
|
||||
hooks=[tf.train.StopAtStepHook(last_step=FLAGS.max_steps),
|
||||
tf.train.NanTensorHook(loss),
|
||||
_LoggerHook()],
|
||||
config=tf.ConfigProto(
|
||||
log_device_placement=FLAGS.log_device_placement)) as mon_sess:
|
||||
while not mon_sess.should_stop():
|
||||
mon_sess.run(train_op)
|
||||
|
||||
|
||||
def main(argv=None): # pylint: disable=unused-argument
|
||||
cifar10.maybe_download_and_extract()
|
||||
if tf.gfile.Exists(FLAGS.train_dir):
|
||||
tf.gfile.DeleteRecursively(FLAGS.train_dir)
|
||||
tf.gfile.MakeDirs(FLAGS.train_dir)
|
||||
train()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
tf.app.run()
|
||||
|
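cifar10_train.py drives its loop through tf.train.MonitoredTrainingSession plus a custom SessionRunHook. A self-contained sketch of that pattern with the CIFAR-10 model stripped out; the toy graph and step count below are assumptions for illustration, not part of this change:

import tensorflow as tf

class CountHook(tf.train.SessionRunHook):
  """Counts completed session.run() calls, mirroring _LoggerHook's bookkeeping."""

  def begin(self):
    self.count = 0

  def after_run(self, run_context, run_values):
    self.count += 1

with tf.Graph().as_default():
  global_step = tf.contrib.framework.get_or_create_global_step()
  train_op = tf.assign_add(global_step, 1)  # stand-in for cifar10.train(...)
  hook = CountHook()
  with tf.train.MonitoredTrainingSession(
      hooks=[tf.train.StopAtStepHook(last_step=5), hook]) as mon_sess:
    while not mon_sess.should_stop():
      mon_sess.run(train_op)
  print('ran %d steps' % hook.count)  # expected: 5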
@ -1,30 +0,0 @@
|
||||
# Description:
|
||||
# Example TensorFlow models for ImageNet.
|
||||
|
||||
licenses(["notice"]) # Apache 2.0
|
||||
|
||||
exports_files(["LICENSE"])
|
||||
|
||||
py_binary(
|
||||
name = "classify_image",
|
||||
srcs = [
|
||||
"classify_image.py",
|
||||
],
|
||||
srcs_version = "PY2AND3",
|
||||
visibility = ["//tensorflow:__subpackages__"],
|
||||
deps = [
|
||||
"//tensorflow:tensorflow_py",
|
||||
],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "all_files",
|
||||
srcs = glob(
|
||||
["**/*"],
|
||||
exclude = [
|
||||
"**/METADATA",
|
||||
"**/OWNERS",
|
||||
],
|
||||
),
|
||||
visibility = ["//tensorflow:__subpackages__"],
|
||||
)
|
@ -1,227 +0,0 @@
|
||||
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
"""Simple image classification with Inception.
|
||||
|
||||
Run image classification with Inception trained on ImageNet 2012 Challenge data
|
||||
set.
|
||||
|
||||
This program creates a graph from a saved GraphDef protocol buffer,
|
||||
and runs inference on an input JPEG image. It outputs human readable
|
||||
strings of the top 5 predictions along with their probabilities.
|
||||
|
||||
Change the --image_file argument to any jpg image to compute a
|
||||
classification of that image.
|
||||
|
||||
Please see the tutorial and website for a detailed description of how
|
||||
to use this script to perform image recognition.
|
||||
|
||||
https://tensorflow.org/tutorials/image_recognition/
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import argparse
|
||||
import os.path
|
||||
import re
|
||||
import sys
|
||||
import tarfile
|
||||
|
||||
import numpy as np
|
||||
from six.moves import urllib
|
||||
import tensorflow as tf
|
||||
|
||||
FLAGS = None
|
||||
|
||||
# pylint: disable=line-too-long
|
||||
DATA_URL = 'http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz'
|
||||
# pylint: enable=line-too-long
|
||||
|
||||
|
||||
class NodeLookup(object):
|
||||
"""Converts integer node ID's to human readable labels."""
|
||||
|
||||
def __init__(self,
|
||||
label_lookup_path=None,
|
||||
uid_lookup_path=None):
|
||||
if not label_lookup_path:
|
||||
label_lookup_path = os.path.join(
|
||||
FLAGS.model_dir, 'imagenet_2012_challenge_label_map_proto.pbtxt')
|
||||
if not uid_lookup_path:
|
||||
uid_lookup_path = os.path.join(
|
||||
FLAGS.model_dir, 'imagenet_synset_to_human_label_map.txt')
|
||||
self.node_lookup = self.load(label_lookup_path, uid_lookup_path)
|
||||
|
||||
def load(self, label_lookup_path, uid_lookup_path):
|
||||
"""Loads a human readable English name for each softmax node.
|
||||
|
||||
Args:
|
||||
label_lookup_path: string UID to integer node ID.
|
||||
uid_lookup_path: string UID to human-readable string.
|
||||
|
||||
Returns:
|
||||
dict from integer node ID to human-readable string.
|
||||
"""
|
||||
if not tf.gfile.Exists(uid_lookup_path):
|
||||
tf.logging.fatal('File does not exist %s', uid_lookup_path)
|
||||
if not tf.gfile.Exists(label_lookup_path):
|
||||
tf.logging.fatal('File does not exist %s', label_lookup_path)
|
||||
|
||||
# Loads mapping from string UID to human-readable string
|
||||
proto_as_ascii_lines = tf.gfile.GFile(uid_lookup_path).readlines()
|
||||
uid_to_human = {}
|
||||
p = re.compile(r'[n\d]*[ \S,]*')
|
||||
for line in proto_as_ascii_lines:
|
||||
parsed_items = p.findall(line)
|
||||
uid = parsed_items[0]
|
||||
human_string = parsed_items[2]
|
||||
uid_to_human[uid] = human_string
|
||||
|
||||
# Loads mapping from string UID to integer node ID.
|
||||
node_id_to_uid = {}
|
||||
proto_as_ascii = tf.gfile.GFile(label_lookup_path).readlines()
|
||||
for line in proto_as_ascii:
|
||||
if line.startswith(' target_class:'):
|
||||
target_class = int(line.split(': ')[1])
|
||||
if line.startswith(' target_class_string:'):
|
||||
target_class_string = line.split(': ')[1]
|
||||
node_id_to_uid[target_class] = target_class_string[1:-2]
|
||||
|
||||
# Loads the final mapping of integer node ID to human-readable string
|
||||
node_id_to_name = {}
|
||||
for key, val in node_id_to_uid.items():
|
||||
if val not in uid_to_human:
|
||||
tf.logging.fatal('Failed to locate: %s', val)
|
||||
name = uid_to_human[val]
|
||||
node_id_to_name[key] = name
|
||||
|
||||
return node_id_to_name
|
||||
|
||||
def id_to_string(self, node_id):
|
||||
if node_id not in self.node_lookup:
|
||||
return ''
|
||||
return self.node_lookup[node_id]
|
||||
|
||||
|
||||
def create_graph():
|
||||
"""Creates a graph from saved GraphDef file and returns a saver."""
|
||||
# Creates graph from saved graph_def.pb.
|
||||
with tf.gfile.FastGFile(os.path.join(
|
||||
FLAGS.model_dir, 'classify_image_graph_def.pb'), 'rb') as f:
|
||||
graph_def = tf.GraphDef()
|
||||
graph_def.ParseFromString(f.read())
|
||||
_ = tf.import_graph_def(graph_def, name='')
|
||||
|
||||
|
||||
def run_inference_on_image(image):
|
||||
"""Runs inference on an image.
|
||||
|
||||
Args:
|
||||
image: Image file name.
|
||||
|
||||
Returns:
|
||||
Nothing
|
||||
"""
|
||||
if not tf.gfile.Exists(image):
|
||||
tf.logging.fatal('File does not exist %s', image)
|
||||
image_data = tf.gfile.FastGFile(image, 'rb').read()
|
||||
|
||||
# Creates graph from saved GraphDef.
|
||||
create_graph()
|
||||
|
||||
with tf.Session() as sess:
|
||||
# Some useful tensors:
|
||||
# 'softmax:0': A tensor containing the normalized prediction across
|
||||
# 1000 labels.
|
||||
# 'pool_3:0': A tensor containing the next-to-last layer containing 2048
|
||||
# float description of the image.
|
||||
# 'DecodeJpeg/contents:0': A tensor containing a string providing JPEG
|
||||
# encoding of the image.
|
||||
# Runs the softmax tensor by feeding the image_data as input to the graph.
|
||||
softmax_tensor = sess.graph.get_tensor_by_name('softmax:0')
|
||||
predictions = sess.run(softmax_tensor,
|
||||
{'DecodeJpeg/contents:0': image_data})
|
||||
predictions = np.squeeze(predictions)
|
||||
|
||||
# Creates node ID --> English string lookup.
|
||||
node_lookup = NodeLookup()
|
||||
|
||||
top_k = predictions.argsort()[-FLAGS.num_top_predictions:][::-1]
|
||||
for node_id in top_k:
|
||||
human_string = node_lookup.id_to_string(node_id)
|
||||
score = predictions[node_id]
|
||||
print('%s (score = %.5f)' % (human_string, score))
|
||||
|
||||
|
||||
def maybe_download_and_extract():
|
||||
"""Download and extract model tar file."""
|
||||
dest_directory = FLAGS.model_dir
|
||||
if not os.path.exists(dest_directory):
|
||||
os.makedirs(dest_directory)
|
||||
filename = DATA_URL.split('/')[-1]
|
||||
filepath = os.path.join(dest_directory, filename)
|
||||
if not os.path.exists(filepath):
|
||||
def _progress(count, block_size, total_size):
|
||||
sys.stdout.write('\r>> Downloading %s %.1f%%' % (
|
||||
filename, float(count * block_size) / float(total_size) * 100.0))
|
||||
sys.stdout.flush()
|
||||
filepath, _ = urllib.request.urlretrieve(DATA_URL, filepath, _progress)
|
||||
print()
|
||||
statinfo = os.stat(filepath)
|
||||
print('Successfully downloaded', filename, statinfo.st_size, 'bytes.')
|
||||
tarfile.open(filepath, 'r:gz').extractall(dest_directory)
|
||||
|
||||
|
||||
def main(_):
|
||||
maybe_download_and_extract()
|
||||
image = (FLAGS.image_file if FLAGS.image_file else
|
||||
os.path.join(FLAGS.model_dir, 'cropped_panda.jpg'))
|
||||
run_inference_on_image(image)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
# classify_image_graph_def.pb:
|
||||
# Binary representation of the GraphDef protocol buffer.
|
||||
# imagenet_synset_to_human_label_map.txt:
|
||||
# Map from synset ID to a human readable string.
|
||||
# imagenet_2012_challenge_label_map_proto.pbtxt:
|
||||
# Text representation of a protocol buffer mapping a label to synset ID.
|
||||
parser.add_argument(
|
||||
'--model_dir',
|
||||
type=str,
|
||||
default='/tmp/imagenet',
|
||||
help="""\
|
||||
Path to classify_image_graph_def.pb,
|
||||
imagenet_synset_to_human_label_map.txt, and
|
||||
imagenet_2012_challenge_label_map_proto.pbtxt.\
|
||||
"""
|
||||
)
|
||||
parser.add_argument(
|
||||
'--image_file',
|
||||
type=str,
|
||||
default='',
|
||||
help='Absolute path to image file.'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--num_top_predictions',
|
||||
type=int,
|
||||
default=5,
|
||||
help='Display this many predictions.'
|
||||
)
|
||||
FLAGS, unparsed = parser.parse_known_args()
|
||||
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
|
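The top-5 readout in run_inference_on_image above boils down to a NumPy argsort over the softmax vector. A small hedged sketch with made-up scores; the array below is illustrative, not real Inception output:

import numpy as np

predictions = np.array([0.02, 0.70, 0.08, 0.15, 0.05])  # assumed softmax scores
num_top_predictions = 3
top_k = predictions.argsort()[-num_top_predictions:][::-1]  # indices, best first
for node_id in top_k:
  print('class %d (score = %.5f)' % (node_id, predictions[node_id]))
# prints classes 1, 3, 2 -- descending score order

The script itself is invoked with the flags declared above, e.g. python classify_image.py --image_file=/path/to/some.jpg (path is a placeholder).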
@ -1,42 +0,0 @@
|
||||
# Description:
|
||||
# Example TensorFlow models for MNIST that achieves high accuracy
|
||||
|
||||
licenses(["notice"]) # Apache 2.0
|
||||
|
||||
exports_files(["LICENSE"])
|
||||
|
||||
py_binary(
|
||||
name = "convolutional",
|
||||
srcs = [
|
||||
"convolutional.py",
|
||||
],
|
||||
srcs_version = "PY2AND3",
|
||||
visibility = ["//tensorflow:__subpackages__"],
|
||||
deps = ["//tensorflow:tensorflow_py"],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "convolutional_test",
|
||||
size = "medium",
|
||||
srcs = [
|
||||
"convolutional.py",
|
||||
],
|
||||
args = [
|
||||
"--self_test",
|
||||
],
|
||||
main = "convolutional.py",
|
||||
srcs_version = "PY2AND3",
|
||||
deps = ["//tensorflow:tensorflow_py"],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "all_files",
|
||||
srcs = glob(
|
||||
["**/*"],
|
||||
exclude = [
|
||||
"**/METADATA",
|
||||
"**/OWNERS",
|
||||
],
|
||||
),
|
||||
visibility = ["//tensorflow:__subpackages__"],
|
||||
)
|
@ -1,339 +0,0 @@
|
||||
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
"""Simple, end-to-end, LeNet-5-like convolutional MNIST model example.
|
||||
|
||||
This should achieve a test error of 0.7%. Please keep this model as simple and
|
||||
linear as possible, it is meant as a tutorial for simple convolutional models.
|
||||
Run with --self_test on the command line to execute a short self-test.
|
||||
"""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import argparse
|
||||
import gzip
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
|
||||
import numpy
|
||||
from six.moves import urllib
|
||||
from six.moves import xrange # pylint: disable=redefined-builtin
|
||||
import tensorflow as tf
|
||||
|
||||
SOURCE_URL = 'http://yann.lecun.com/exdb/mnist/'
|
||||
WORK_DIRECTORY = 'data'
|
||||
IMAGE_SIZE = 28
|
||||
NUM_CHANNELS = 1
|
||||
PIXEL_DEPTH = 255
|
||||
NUM_LABELS = 10
|
||||
VALIDATION_SIZE = 5000 # Size of the validation set.
|
||||
SEED = 66478 # Set to None for random seed.
|
||||
BATCH_SIZE = 64
|
||||
NUM_EPOCHS = 10
|
||||
EVAL_BATCH_SIZE = 64
|
||||
EVAL_FREQUENCY = 100 # Number of steps between evaluations.
|
||||
|
||||
|
||||
FLAGS = None
|
||||
|
||||
|
||||
def data_type():
|
||||
"""Return the type of the activations, weights, and placeholder variables."""
|
||||
if FLAGS.use_fp16:
|
||||
return tf.float16
|
||||
else:
|
||||
return tf.float32
|
||||
|
||||
|
||||
def maybe_download(filename):
|
||||
"""Download the data from Yann's website, unless it's already here."""
|
||||
if not tf.gfile.Exists(WORK_DIRECTORY):
|
||||
tf.gfile.MakeDirs(WORK_DIRECTORY)
|
||||
filepath = os.path.join(WORK_DIRECTORY, filename)
|
||||
if not tf.gfile.Exists(filepath):
|
||||
filepath, _ = urllib.request.urlretrieve(SOURCE_URL + filename, filepath)
|
||||
with tf.gfile.GFile(filepath) as f:
|
||||
size = f.size()
|
||||
print('Successfully downloaded', filename, size, 'bytes.')
|
||||
return filepath
|
||||
|
||||
|
||||
def extract_data(filename, num_images):
|
||||
"""Extract the images into a 4D tensor [image index, y, x, channels].
|
||||
|
||||
Values are rescaled from [0, 255] down to [-0.5, 0.5].
|
||||
"""
|
||||
print('Extracting', filename)
|
||||
with gzip.open(filename) as bytestream:
|
||||
bytestream.read(16)
|
||||
buf = bytestream.read(IMAGE_SIZE * IMAGE_SIZE * num_images * NUM_CHANNELS)
|
||||
data = numpy.frombuffer(buf, dtype=numpy.uint8).astype(numpy.float32)
|
||||
data = (data - (PIXEL_DEPTH / 2.0)) / PIXEL_DEPTH
|
||||
data = data.reshape(num_images, IMAGE_SIZE, IMAGE_SIZE, NUM_CHANNELS)
|
||||
return data
|
||||
|
||||
|
||||
def extract_labels(filename, num_images):
|
||||
"""Extract the labels into a vector of int64 label IDs."""
|
||||
print('Extracting', filename)
|
||||
with gzip.open(filename) as bytestream:
|
||||
bytestream.read(8)
|
||||
buf = bytestream.read(1 * num_images)
|
||||
labels = numpy.frombuffer(buf, dtype=numpy.uint8).astype(numpy.int64)
|
||||
return labels
|
||||
|
||||
|
||||
def fake_data(num_images):
|
||||
"""Generate a fake dataset that matches the dimensions of MNIST."""
|
||||
data = numpy.ndarray(
|
||||
shape=(num_images, IMAGE_SIZE, IMAGE_SIZE, NUM_CHANNELS),
|
||||
dtype=numpy.float32)
|
||||
labels = numpy.zeros(shape=(num_images,), dtype=numpy.int64)
|
||||
for image in xrange(num_images):
|
||||
label = image % 2
|
||||
data[image, :, :, 0] = label - 0.5
|
||||
labels[image] = label
|
||||
return data, labels
|
||||
|
||||
|
||||
def error_rate(predictions, labels):
|
||||
"""Return the error rate based on dense predictions and sparse labels."""
|
||||
return 100.0 - (
|
||||
100.0 *
|
||||
numpy.sum(numpy.argmax(predictions, 1) == labels) /
|
||||
predictions.shape[0])
|
||||
|
||||
|
||||
def main(_):
|
||||
if FLAGS.self_test:
|
||||
print('Running self-test.')
|
||||
train_data, train_labels = fake_data(256)
|
||||
validation_data, validation_labels = fake_data(EVAL_BATCH_SIZE)
|
||||
test_data, test_labels = fake_data(EVAL_BATCH_SIZE)
|
||||
num_epochs = 1
|
||||
else:
|
||||
# Get the data.
|
||||
train_data_filename = maybe_download('train-images-idx3-ubyte.gz')
|
||||
train_labels_filename = maybe_download('train-labels-idx1-ubyte.gz')
|
||||
test_data_filename = maybe_download('t10k-images-idx3-ubyte.gz')
|
||||
test_labels_filename = maybe_download('t10k-labels-idx1-ubyte.gz')
|
||||
|
||||
# Extract it into numpy arrays.
|
||||
train_data = extract_data(train_data_filename, 60000)
|
||||
train_labels = extract_labels(train_labels_filename, 60000)
|
||||
test_data = extract_data(test_data_filename, 10000)
|
||||
test_labels = extract_labels(test_labels_filename, 10000)
|
||||
|
||||
# Generate a validation set.
|
||||
validation_data = train_data[:VALIDATION_SIZE, ...]
|
||||
validation_labels = train_labels[:VALIDATION_SIZE]
|
||||
train_data = train_data[VALIDATION_SIZE:, ...]
|
||||
train_labels = train_labels[VALIDATION_SIZE:]
|
||||
num_epochs = NUM_EPOCHS
|
||||
train_size = train_labels.shape[0]
|
||||
|
||||
# This is where training samples and labels are fed to the graph.
|
||||
# These placeholder nodes will be fed a batch of training data at each
|
||||
# training step using the {feed_dict} argument to the Run() call below.
|
||||
train_data_node = tf.placeholder(
|
||||
data_type(),
|
||||
shape=(BATCH_SIZE, IMAGE_SIZE, IMAGE_SIZE, NUM_CHANNELS))
|
||||
train_labels_node = tf.placeholder(tf.int64, shape=(BATCH_SIZE,))
|
||||
eval_data = tf.placeholder(
|
||||
data_type(),
|
||||
shape=(EVAL_BATCH_SIZE, IMAGE_SIZE, IMAGE_SIZE, NUM_CHANNELS))
|
||||
|
||||
# The variables below hold all the trainable weights. They are passed an
|
||||
# initial value which will be assigned when we call:
|
||||
# {tf.global_variables_initializer().run()}
|
||||
conv1_weights = tf.Variable(
|
||||
tf.truncated_normal([5, 5, NUM_CHANNELS, 32], # 5x5 filter, depth 32.
|
||||
stddev=0.1,
|
||||
seed=SEED, dtype=data_type()))
|
||||
conv1_biases = tf.Variable(tf.zeros([32], dtype=data_type()))
|
||||
conv2_weights = tf.Variable(tf.truncated_normal(
|
||||
[5, 5, 32, 64], stddev=0.1,
|
||||
seed=SEED, dtype=data_type()))
|
||||
conv2_biases = tf.Variable(tf.constant(0.1, shape=[64], dtype=data_type()))
|
||||
fc1_weights = tf.Variable( # fully connected, depth 512.
|
||||
tf.truncated_normal([IMAGE_SIZE // 4 * IMAGE_SIZE // 4 * 64, 512],
|
||||
stddev=0.1,
|
||||
seed=SEED,
|
||||
dtype=data_type()))
|
||||
fc1_biases = tf.Variable(tf.constant(0.1, shape=[512], dtype=data_type()))
|
||||
fc2_weights = tf.Variable(tf.truncated_normal([512, NUM_LABELS],
|
||||
stddev=0.1,
|
||||
seed=SEED,
|
||||
dtype=data_type()))
|
||||
fc2_biases = tf.Variable(tf.constant(
|
||||
0.1, shape=[NUM_LABELS], dtype=data_type()))
|
||||
|
||||
# We will replicate the model structure for the training subgraph, as well
|
||||
# as the evaluation subgraphs, while sharing the trainable parameters.
|
||||
def model(data, train=False):
|
||||
"""The Model definition."""
|
||||
# 2D convolution, with 'SAME' padding (i.e. the output feature map has
|
||||
# the same size as the input). Note that {strides} is a 4D array whose
|
||||
# shape matches the data layout: [image index, y, x, depth].
|
||||
conv = tf.nn.conv2d(data,
|
||||
conv1_weights,
|
||||
strides=[1, 1, 1, 1],
|
||||
padding='SAME')
|
||||
# Bias and rectified linear non-linearity.
|
||||
relu = tf.nn.relu(tf.nn.bias_add(conv, conv1_biases))
|
||||
# Max pooling. The kernel size spec {ksize} also follows the layout of
|
||||
# the data. Here we have a pooling window of 2, and a stride of 2.
|
||||
pool = tf.nn.max_pool(relu,
|
||||
ksize=[1, 2, 2, 1],
|
||||
strides=[1, 2, 2, 1],
|
||||
padding='SAME')
|
||||
conv = tf.nn.conv2d(pool,
|
||||
conv2_weights,
|
||||
strides=[1, 1, 1, 1],
|
||||
padding='SAME')
|
||||
relu = tf.nn.relu(tf.nn.bias_add(conv, conv2_biases))
|
||||
pool = tf.nn.max_pool(relu,
|
||||
ksize=[1, 2, 2, 1],
|
||||
strides=[1, 2, 2, 1],
|
||||
padding='SAME')
|
||||
# Reshape the feature map cuboid into a 2D matrix to feed it to the
|
||||
# fully connected layers.
|
||||
pool_shape = pool.get_shape().as_list()
|
||||
reshape = tf.reshape(
|
||||
pool,
|
||||
[pool_shape[0], pool_shape[1] * pool_shape[2] * pool_shape[3]])
|
||||
# Fully connected layer. Note that the '+' operation automatically
|
||||
# broadcasts the biases.
|
||||
hidden = tf.nn.relu(tf.matmul(reshape, fc1_weights) + fc1_biases)
|
||||
# Add a 50% dropout during training only. Dropout also scales
|
||||
# activations such that no rescaling is needed at evaluation time.
|
||||
if train:
|
||||
hidden = tf.nn.dropout(hidden, 0.5, seed=SEED)
|
||||
return tf.matmul(hidden, fc2_weights) + fc2_biases
|
||||
|
||||
# Training computation: logits + cross-entropy loss.
|
||||
logits = model(train_data_node, True)
|
||||
loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
|
||||
logits, train_labels_node))
|
||||
|
||||
# L2 regularization for the fully connected parameters.
|
||||
regularizers = (tf.nn.l2_loss(fc1_weights) + tf.nn.l2_loss(fc1_biases) +
|
||||
tf.nn.l2_loss(fc2_weights) + tf.nn.l2_loss(fc2_biases))
|
||||
# Add the regularization term to the loss.
|
||||
loss += 5e-4 * regularizers
|
||||
|
||||
# Optimizer: set up a variable that's incremented once per batch and
|
||||
# controls the learning rate decay.
|
||||
batch = tf.Variable(0, dtype=data_type())
|
||||
# Decay once per epoch, using an exponential schedule starting at 0.01.
|
||||
learning_rate = tf.train.exponential_decay(
|
||||
0.01, # Base learning rate.
|
||||
batch * BATCH_SIZE, # Current index into the dataset.
|
||||
train_size, # Decay step.
|
||||
0.95, # Decay rate.
|
||||
staircase=True)
|
||||
# Use simple momentum for the optimization.
|
||||
optimizer = tf.train.MomentumOptimizer(learning_rate,
|
||||
0.9).minimize(loss,
|
||||
global_step=batch)
|
||||
|
||||
# Predictions for the current training minibatch.
|
||||
train_prediction = tf.nn.softmax(logits)
|
||||
|
||||
# Predictions for the test and validation, which we'll compute less often.
|
||||
eval_prediction = tf.nn.softmax(model(eval_data))
|
||||
|
||||
# Small utility function to evaluate a dataset by feeding batches of data to
|
||||
# {eval_data} and pulling the results from {eval_predictions}.
|
||||
# Saves memory and enables this to run on smaller GPUs.
|
||||
def eval_in_batches(data, sess):
|
||||
"""Get all predictions for a dataset by running it in small batches."""
|
||||
size = data.shape[0]
|
||||
if size < EVAL_BATCH_SIZE:
|
||||
raise ValueError("batch size for evals larger than dataset: %d" % size)
|
||||
predictions = numpy.ndarray(shape=(size, NUM_LABELS), dtype=numpy.float32)
|
||||
for begin in xrange(0, size, EVAL_BATCH_SIZE):
|
||||
end = begin + EVAL_BATCH_SIZE
|
||||
if end <= size:
|
||||
predictions[begin:end, :] = sess.run(
|
||||
eval_prediction,
|
||||
feed_dict={eval_data: data[begin:end, ...]})
|
||||
else:
|
||||
batch_predictions = sess.run(
|
||||
eval_prediction,
|
||||
feed_dict={eval_data: data[-EVAL_BATCH_SIZE:, ...]})
|
||||
predictions[begin:, :] = batch_predictions[begin - size:, :]
|
||||
return predictions
|
||||
|
||||
# Create a local session to run the training.
|
||||
start_time = time.time()
|
||||
with tf.Session() as sess:
|
||||
# Run all the initializers to prepare the trainable parameters.
|
||||
tf.global_variables_initializer().run()
|
||||
print('Initialized!')
|
||||
# Loop through training steps.
|
||||
for step in xrange(int(num_epochs * train_size) // BATCH_SIZE):
|
||||
# Compute the offset of the current minibatch in the data.
|
||||
# Note that we could use better randomization across epochs.
|
||||
offset = (step * BATCH_SIZE) % (train_size - BATCH_SIZE)
|
||||
batch_data = train_data[offset:(offset + BATCH_SIZE), ...]
|
||||
batch_labels = train_labels[offset:(offset + BATCH_SIZE)]
|
||||
# This dictionary maps the batch data (as a numpy array) to the
|
||||
# node in the graph it should be fed to.
|
||||
feed_dict = {train_data_node: batch_data,
|
||||
train_labels_node: batch_labels}
|
||||
# Run the optimizer to update weights.
|
||||
sess.run(optimizer, feed_dict=feed_dict)
|
||||
# print some extra information once reach the evaluation frequency
|
||||
if step % EVAL_FREQUENCY == 0:
|
||||
# fetch some extra nodes' data
|
||||
l, lr, predictions = sess.run([loss, learning_rate, train_prediction],
|
||||
feed_dict=feed_dict)
|
||||
elapsed_time = time.time() - start_time
|
||||
start_time = time.time()
|
||||
print('Step %d (epoch %.2f), %.1f ms' %
|
||||
(step, float(step) * BATCH_SIZE / train_size,
|
||||
1000 * elapsed_time / EVAL_FREQUENCY))
|
||||
print('Minibatch loss: %.3f, learning rate: %.6f' % (l, lr))
|
||||
print('Minibatch error: %.1f%%' % error_rate(predictions, batch_labels))
|
||||
print('Validation error: %.1f%%' % error_rate(
|
||||
eval_in_batches(validation_data, sess), validation_labels))
|
||||
sys.stdout.flush()
|
||||
# Finally print the result!
|
||||
test_error = error_rate(eval_in_batches(test_data, sess), test_labels)
|
||||
print('Test error: %.1f%%' % test_error)
|
||||
if FLAGS.self_test:
|
||||
print('test_error', test_error)
|
||||
assert test_error == 0.0, 'expected 0.0 test_error, got %.2f' % (
|
||||
test_error,)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
'--use_fp16',
|
||||
default=False,
|
||||
help='Use half floats instead of full floats if True.',
|
||||
action='store_true')
|
||||
parser.add_argument(
|
||||
'--self_test',
|
||||
default=False,
|
||||
action='store_true',
|
||||
help='True if running a self test.')
|
||||
|
||||
FLAGS, unparsed = parser.parse_known_args()
|
||||
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
|
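The optimizer above decays the learning rate once per epoch via tf.train.exponential_decay with staircase=True. A plain-Python sketch of that schedule using the constants from this file (base rate 0.01, decay 0.95, batch size 64, 55000 training examples after the validation split); the step values fed to it below are just examples:

base_lr = 0.01       # base learning rate passed to exponential_decay
decay_rate = 0.95    # decay rate, applied once per epoch
batch_size = 64      # BATCH_SIZE
train_size = 55000   # 60000 training images minus VALIDATION_SIZE

def learning_rate_at(step):
  # staircase=True: the exponent is the whole number of completed epochs.
  completed_epochs = (step * batch_size) // train_size
  return base_lr * (decay_rate ** completed_epochs)

for step in (0, 1000, 5000, 10000):
  print('step %d: lr = %.6f' % (step, learning_rate_at(step)))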
@ -1,80 +0,0 @@
|
||||
# Description:
|
||||
# Example RNN models, including language models and sequence-to-sequence models.
|
||||
|
||||
package(default_visibility = ["//visibility:public"])
|
||||
|
||||
licenses(["notice"]) # Apache 2.0
|
||||
|
||||
exports_files(["LICENSE"])
|
||||
|
||||
py_library(
|
||||
name = "linear",
|
||||
srcs = [
|
||||
"linear.py",
|
||||
],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
"//tensorflow:tensorflow_py",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "rnn_cell",
|
||||
srcs = [
|
||||
"rnn_cell.py",
|
||||
],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":linear",
|
||||
"//tensorflow:tensorflow_py",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "package",
|
||||
srcs = [
|
||||
"__init__.py",
|
||||
],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":rnn",
|
||||
":rnn_cell",
|
||||
":seq2seq",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "rnn",
|
||||
srcs = [
|
||||
"rnn.py",
|
||||
],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":rnn_cell",
|
||||
"//tensorflow:tensorflow_py",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "seq2seq",
|
||||
srcs = [
|
||||
"seq2seq.py",
|
||||
],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":rnn",
|
||||
"//tensorflow:tensorflow_py",
|
||||
],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "all_files",
|
||||
srcs = glob(
|
||||
["**/*"],
|
||||
exclude = [
|
||||
"**/METADATA",
|
||||
"**/OWNERS",
|
||||
],
|
||||
),
|
||||
visibility = ["//tensorflow:__subpackages__"],
|
||||
)
|
@ -1,13 +0,0 @@
|
||||
This directory contains functions for creating recurrent neural networks
and sequence-to-sequence models. Detailed instructions on how to get started
and use them are available in the tutorials.

* [RNN Tutorial](http://tensorflow.org/tutorials/recurrent/index.md)
* [Sequence-to-Sequence Tutorial](http://tensorflow.org/tutorials/seq2seq/index.md)

Here is a short overview of what is in this directory.

File | What's in it?
--- | ---
`ptb/` | PTB language model, see the [RNN Tutorial](http://tensorflow.org/tutorials/recurrent/)
`translate/` | Translation model, see the [Sequence-to-Sequence Tutorial](http://tensorflow.org/tutorials/seq2seq/)
@ -1,19 +0,0 @@
|
||||
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
"""Libraries to build Recurrent Neural Networks."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
@ -1,20 +0,0 @@
|
||||
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Import linear python op for backward compatibility."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
raise ImportError("This module is deprecated. Use tf.contrib.layers.linear.")
|
@ -1,61 +0,0 @@
|
||||
# Description:
|
||||
# Python support for TensorFlow.
|
||||
|
||||
package(default_visibility = ["//tensorflow:internal"])
|
||||
|
||||
licenses(["notice"]) # Apache 2.0
|
||||
|
||||
exports_files(["LICENSE"])
|
||||
|
||||
py_library(
|
||||
name = "package",
|
||||
srcs = [
|
||||
"__init__.py",
|
||||
],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":reader",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "reader",
|
||||
srcs = ["reader.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = ["//tensorflow:tensorflow_py"],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "reader_test",
|
||||
size = "small",
|
||||
srcs = ["reader_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":reader",
|
||||
"//tensorflow:tensorflow_py",
|
||||
],
|
||||
)
|
||||
|
||||
py_binary(
|
||||
name = "ptb_word_lm",
|
||||
srcs = [
|
||||
"ptb_word_lm.py",
|
||||
],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":reader",
|
||||
"//tensorflow:tensorflow_py",
|
||||
],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "all_files",
|
||||
srcs = glob(
|
||||
["**/*"],
|
||||
exclude = [
|
||||
"**/METADATA",
|
||||
"**/OWNERS",
|
||||
],
|
||||
),
|
||||
visibility = ["//tensorflow:__subpackages__"],
|
||||
)
|
@ -1,21 +0,0 @@
|
||||
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
"""Makes helper libraries available in the ptb package."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.models.rnn.ptb import reader
|
@ -1,374 +0,0 @@
|
||||
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
"""Example / benchmark for building a PTB LSTM model.
|
||||
|
||||
Trains the model described in:
|
||||
(Zaremba, et. al.) Recurrent Neural Network Regularization
|
||||
http://arxiv.org/abs/1409.2329
|
||||
|
||||
There are 3 supported model configurations:
|
||||
===========================================
|
||||
| config | epochs | train | valid | test
|
||||
===========================================
|
||||
| small | 13 | 37.99 | 121.39 | 115.91
|
||||
| medium | 39 | 48.45 | 86.16 | 82.07
|
||||
| large | 55 | 37.87 | 82.62 | 78.29
|
||||
The exact results may vary depending on the random initialization.
|
||||
|
||||
The hyperparameters used in the model:
|
||||
- init_scale - the initial scale of the weights
|
||||
- learning_rate - the initial value of the learning rate
|
||||
- max_grad_norm - the maximum permissible norm of the gradient
|
||||
- num_layers - the number of LSTM layers
|
||||
- num_steps - the number of unrolled steps of LSTM
|
||||
- hidden_size - the number of LSTM units
|
||||
- max_epoch - the number of epochs trained with the initial learning rate
|
||||
- max_max_epoch - the total number of epochs for training
|
||||
- keep_prob - the probability of keeping weights in the dropout layer
|
||||
- lr_decay - the decay of the learning rate for each epoch after "max_epoch"
|
||||
- batch_size - the batch size
|
||||
|
||||
The data required for this example is in the data/ dir of the
|
||||
PTB dataset from Tomas Mikolov's webpage:
|
||||
|
||||
$ wget http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz
|
||||
$ tar xvf simple-examples.tgz
|
||||
|
||||
To run:
|
||||
|
||||
$ python ptb_word_lm.py --data_path=simple-examples/data/
|
||||
|
||||
"""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
from tensorflow.models.rnn.ptb import reader
|
||||
|
||||
flags = tf.flags
|
||||
logging = tf.logging
|
||||
|
||||
flags.DEFINE_string(
|
||||
"model", "small",
|
||||
"A type of model. Possible options are: small, medium, large.")
|
||||
flags.DEFINE_string("data_path", None,
|
||||
"Where the training/test data is stored.")
|
||||
flags.DEFINE_string("save_path", None,
|
||||
"Model output directory.")
|
||||
flags.DEFINE_bool("use_fp16", False,
|
||||
"Train using 16-bit floats instead of 32bit floats")
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
|
||||
def data_type():
|
||||
return tf.float16 if FLAGS.use_fp16 else tf.float32
|
||||
|
||||
|
||||
class PTBInput(object):
|
||||
"""The input data."""
|
||||
|
||||
def __init__(self, config, data, name=None):
|
||||
self.batch_size = batch_size = config.batch_size
|
||||
self.num_steps = num_steps = config.num_steps
|
||||
self.epoch_size = ((len(data) // batch_size) - 1) // num_steps
|
||||
self.input_data, self.targets = reader.ptb_producer(
|
||||
data, batch_size, num_steps, name=name)
|
||||
|
||||
|
||||
class PTBModel(object):
|
||||
"""The PTB model."""
|
||||
|
||||
def __init__(self, is_training, config, input_):
|
||||
self._input = input_
|
||||
|
||||
batch_size = input_.batch_size
|
||||
num_steps = input_.num_steps
|
||||
size = config.hidden_size
|
||||
vocab_size = config.vocab_size
|
||||
|
||||
# Slightly better results can be obtained with forget gate biases
|
||||
# initialized to 1 but the hyperparameters of the model would need to be
|
||||
# different than reported in the paper.
|
||||
lstm_cell = tf.contrib.rnn.BasicLSTMCell(
|
||||
size, forget_bias=0.0, state_is_tuple=True)
|
||||
if is_training and config.keep_prob < 1:
|
||||
lstm_cell = tf.contrib.rnn.DropoutWrapper(
|
||||
lstm_cell, output_keep_prob=config.keep_prob)
|
||||
cell = tf.contrib.rnn.MultiRNNCell(
|
||||
[lstm_cell] * config.num_layers, state_is_tuple=True)
|
||||
|
||||
self._initial_state = cell.zero_state(batch_size, data_type())
|
||||
|
||||
with tf.device("/cpu:0"):
|
||||
embedding = tf.get_variable(
|
||||
"embedding", [vocab_size, size], dtype=data_type())
|
||||
inputs = tf.nn.embedding_lookup(embedding, input_.input_data)
|
||||
|
||||
if is_training and config.keep_prob < 1:
|
||||
inputs = tf.nn.dropout(inputs, config.keep_prob)
|
||||
|
||||
# Simplified version of tensorflow.models.rnn.rnn.py's rnn().
|
||||
# This builds an unrolled LSTM for tutorial purposes only.
|
||||
# In general, use the rnn() or state_saving_rnn() from rnn.py.
|
||||
#
|
||||
# The alternative version of the code below is:
|
||||
#
|
||||
# inputs = [tf.squeeze(input_step, [1])
|
||||
# for input_step in tf.split(value=inputs,
|
||||
# num_or_size_splits=num_steps,
|
||||
# axis=1)]
|
||||
# outputs, state = tf.nn.rnn(cell, inputs,
|
||||
# initial_state=self._initial_state)
|
||||
outputs = []
|
||||
state = self._initial_state
|
||||
with tf.variable_scope("RNN"):
|
||||
for time_step in range(num_steps):
|
||||
if time_step > 0: tf.get_variable_scope().reuse_variables()
|
||||
(cell_output, state) = cell(inputs[:, time_step, :], state)
|
||||
outputs.append(cell_output)
|
||||
|
||||
output = tf.reshape(tf.concat_v2(outputs, 1), [-1, size])
|
||||
softmax_w = tf.get_variable(
|
||||
"softmax_w", [size, vocab_size], dtype=data_type())
|
||||
softmax_b = tf.get_variable("softmax_b", [vocab_size], dtype=data_type())
|
||||
logits = tf.matmul(output, softmax_w) + softmax_b
|
||||
loss = tf.contrib.legacy_seq2seq.sequence_loss_by_example(
|
||||
[logits],
|
||||
[tf.reshape(input_.targets, [-1])],
|
||||
[tf.ones([batch_size * num_steps], dtype=data_type())])
|
||||
self._cost = cost = tf.reduce_sum(loss) / batch_size
|
||||
self._final_state = state
|
||||
|
||||
if not is_training:
|
||||
return
|
||||
|
||||
self._lr = tf.Variable(0.0, trainable=False)
|
||||
tvars = tf.trainable_variables()
|
||||
grads, _ = tf.clip_by_global_norm(tf.gradients(cost, tvars),
|
||||
config.max_grad_norm)
|
||||
optimizer = tf.train.GradientDescentOptimizer(self._lr)
|
||||
self._train_op = optimizer.apply_gradients(
|
||||
zip(grads, tvars),
|
||||
global_step=tf.contrib.framework.get_or_create_global_step())
|
||||
|
||||
self._new_lr = tf.placeholder(
|
||||
tf.float32, shape=[], name="new_learning_rate")
|
||||
self._lr_update = tf.assign(self._lr, self._new_lr)
|
||||
|
||||
def assign_lr(self, session, lr_value):
|
||||
session.run(self._lr_update, feed_dict={self._new_lr: lr_value})
|
||||
|
||||
@property
|
||||
def input(self):
|
||||
return self._input
|
||||
|
||||
@property
|
||||
def initial_state(self):
|
||||
return self._initial_state
|
||||
|
||||
@property
|
||||
def cost(self):
|
||||
return self._cost
|
||||
|
||||
@property
|
||||
def final_state(self):
|
||||
return self._final_state
|
||||
|
||||
@property
|
||||
def lr(self):
|
||||
return self._lr
|
||||
|
||||
@property
|
||||
def train_op(self):
|
||||
return self._train_op
|
||||
|
||||
|
||||
class SmallConfig(object):
|
||||
"""Small config."""
|
||||
init_scale = 0.1
|
||||
learning_rate = 1.0
|
||||
max_grad_norm = 5
|
||||
num_layers = 2
|
||||
num_steps = 20
|
||||
hidden_size = 200
|
||||
max_epoch = 4
|
||||
max_max_epoch = 13
|
||||
keep_prob = 1.0
|
||||
lr_decay = 0.5
|
||||
batch_size = 20
|
||||
vocab_size = 10000
|
||||
|
||||
|
||||
class MediumConfig(object):
|
||||
"""Medium config."""
|
||||
init_scale = 0.05
|
||||
learning_rate = 1.0
|
||||
max_grad_norm = 5
|
||||
num_layers = 2
|
||||
num_steps = 35
|
||||
hidden_size = 650
|
||||
max_epoch = 6
|
||||
max_max_epoch = 39
|
||||
keep_prob = 0.5
|
||||
lr_decay = 0.8
|
||||
batch_size = 20
|
||||
vocab_size = 10000
|
||||
|
||||
|
||||
class LargeConfig(object):
|
||||
"""Large config."""
|
||||
init_scale = 0.04
|
||||
learning_rate = 1.0
|
||||
max_grad_norm = 10
|
||||
num_layers = 2
|
||||
num_steps = 35
|
||||
hidden_size = 1500
|
||||
max_epoch = 14
|
||||
max_max_epoch = 55
|
||||
keep_prob = 0.35
|
||||
lr_decay = 1 / 1.15
|
||||
batch_size = 20
|
||||
vocab_size = 10000
|
||||
|
||||
|
||||
class TestConfig(object):
|
||||
"""Tiny config, for testing."""
|
||||
init_scale = 0.1
|
||||
learning_rate = 1.0
|
||||
max_grad_norm = 1
|
||||
num_layers = 1
|
||||
num_steps = 2
|
||||
hidden_size = 2
|
||||
max_epoch = 1
|
||||
max_max_epoch = 1
|
||||
keep_prob = 1.0
|
||||
lr_decay = 0.5
|
||||
batch_size = 20
|
||||
vocab_size = 10000
|
||||
|
||||
|
||||
def run_epoch(session, model, eval_op=None, verbose=False):
|
||||
"""Runs the model on the given data."""
|
||||
start_time = time.time()
|
||||
costs = 0.0
|
||||
iters = 0
|
||||
state = session.run(model.initial_state)
|
||||
|
||||
fetches = {
|
||||
"cost": model.cost,
|
||||
"final_state": model.final_state,
|
||||
}
|
||||
if eval_op is not None:
|
||||
fetches["eval_op"] = eval_op
|
||||
|
||||
for step in range(model.input.epoch_size):
|
||||
feed_dict = {}
|
||||
for i, (c, h) in enumerate(model.initial_state):
|
||||
feed_dict[c] = state[i].c
|
||||
feed_dict[h] = state[i].h
|
||||
|
||||
vals = session.run(fetches, feed_dict)
|
||||
cost = vals["cost"]
|
||||
state = vals["final_state"]
|
||||
|
||||
costs += cost
|
||||
iters += model.input.num_steps
|
||||
|
||||
if verbose and step % (model.input.epoch_size // 10) == 10:
|
||||
print("%.3f perplexity: %.3f speed: %.0f wps" %
|
||||
(step * 1.0 / model.input.epoch_size, np.exp(costs / iters),
|
||||
iters * model.input.batch_size / (time.time() - start_time)))
|
||||
|
||||
return np.exp(costs / iters)
|
||||
|
||||
|
||||
def get_config():
|
||||
if FLAGS.model == "small":
|
||||
return SmallConfig()
|
||||
elif FLAGS.model == "medium":
|
||||
return MediumConfig()
|
||||
elif FLAGS.model == "large":
|
||||
return LargeConfig()
|
||||
elif FLAGS.model == "test":
|
||||
return TestConfig()
|
||||
else:
|
||||
raise ValueError("Invalid model: %s", FLAGS.model)
|
||||
|
||||
|
||||
def main(_):
|
||||
if not FLAGS.data_path:
|
||||
raise ValueError("Must set --data_path to PTB data directory")
|
||||
|
||||
raw_data = reader.ptb_raw_data(FLAGS.data_path)
|
||||
train_data, valid_data, test_data, _ = raw_data
|
||||
|
||||
config = get_config()
|
||||
eval_config = get_config()
|
||||
eval_config.batch_size = 1
|
||||
eval_config.num_steps = 1
|
||||
|
||||
with tf.Graph().as_default():
|
||||
initializer = tf.random_uniform_initializer(-config.init_scale,
|
||||
config.init_scale)
|
||||
|
||||
with tf.name_scope("Train"):
|
||||
train_input = PTBInput(config=config, data=train_data, name="TrainInput")
|
||||
with tf.variable_scope("Model", reuse=None, initializer=initializer):
|
||||
m = PTBModel(is_training=True, config=config, input_=train_input)
|
||||
tf.contrib.deprecated.scalar_summary("Training Loss", m.cost)
|
||||
tf.contrib.deprecated.scalar_summary("Learning Rate", m.lr)
|
||||
|
||||
with tf.name_scope("Valid"):
|
||||
valid_input = PTBInput(config=config, data=valid_data, name="ValidInput")
|
||||
with tf.variable_scope("Model", reuse=True, initializer=initializer):
|
||||
mvalid = PTBModel(is_training=False, config=config, input_=valid_input)
|
||||
tf.contrib.deprecated.scalar_summary("Validation Loss", mvalid.cost)
|
||||
|
||||
with tf.name_scope("Test"):
|
||||
test_input = PTBInput(config=eval_config, data=test_data, name="TestInput")
|
||||
with tf.variable_scope("Model", reuse=True, initializer=initializer):
|
||||
mtest = PTBModel(is_training=False, config=eval_config,
|
||||
input_=test_input)
|
||||
|
||||
sv = tf.train.Supervisor(logdir=FLAGS.save_path)
|
||||
with sv.managed_session() as session:
|
||||
for i in range(config.max_max_epoch):
|
||||
lr_decay = config.lr_decay ** max(i + 1 - config.max_epoch, 0.0)
|
||||
m.assign_lr(session, config.learning_rate * lr_decay)
|
||||
|
||||
print("Epoch: %d Learning rate: %.3f" % (i + 1, session.run(m.lr)))
|
||||
train_perplexity = run_epoch(session, m, eval_op=m.train_op,
|
||||
verbose=True)
|
||||
print("Epoch: %d Train Perplexity: %.3f" % (i + 1, train_perplexity))
|
||||
valid_perplexity = run_epoch(session, mvalid)
|
||||
print("Epoch: %d Valid Perplexity: %.3f" % (i + 1, valid_perplexity))
|
||||
|
||||
test_perplexity = run_epoch(session, mtest)
|
||||
print("Test Perplexity: %.3f" % test_perplexity)
|
||||
|
||||
if FLAGS.save_path:
|
||||
print("Saving model to %s." % FLAGS.save_path)
|
||||
sv.saver.save(session, FLAGS.save_path, global_step=sv.global_step)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
tf.app.run()
|
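run_epoch above reports perplexity as exp(total cost / total target words), where each per-step cost is the batch-averaged sum of cross-entropy over num_steps words. A hedged sketch of that bookkeeping; the per-step costs below are made up, not real session.run() outputs:

import numpy as np

num_steps = 20                        # words consumed per run (SmallConfig.num_steps)
step_costs = [120.0, 110.0, 105.0]    # assumed costs returned for three steps

costs = sum(step_costs)
iters = num_steps * len(step_costs)
perplexity = np.exp(costs / iters)    # exp of the average per-word cross-entropy
print('perplexity: %.3f' % perplexity)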
@ -1,120 +0,0 @@
|
||||
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
|
||||
"""Utilities for parsing PTB text files."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
import os
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
|
||||
def _read_words(filename):
|
||||
with tf.gfile.GFile(filename, "r") as f:
|
||||
return f.read().decode("utf-8").replace("\n", "<eos>").split()
|
||||
|
||||
|
||||
def _build_vocab(filename):
|
||||
data = _read_words(filename)
|
||||
|
||||
counter = collections.Counter(data)
|
||||
count_pairs = sorted(counter.items(), key=lambda x: (-x[1], x[0]))
|
||||
|
||||
words, _ = list(zip(*count_pairs))
|
||||
word_to_id = dict(zip(words, range(len(words))))
|
||||
|
||||
return word_to_id
|
||||
|
||||
|
||||
def _file_to_word_ids(filename, word_to_id):
|
||||
data = _read_words(filename)
|
||||
return [word_to_id[word] for word in data if word in word_to_id]
|
||||
|
||||
|
||||
def ptb_raw_data(data_path=None):
|
||||
"""Load PTB raw data from data directory "data_path".
|
||||
|
||||
Reads PTB text files, converts strings to integer ids,
|
||||
and performs mini-batching of the inputs.
|
||||
|
||||
The PTB dataset comes from Tomas Mikolov's webpage:
|
||||
|
||||
http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz
|
||||
|
||||
Args:
|
||||
data_path: string path to the directory where simple-examples.tgz has
|
||||
been extracted.
|
||||
|
||||
Returns:
|
||||
tuple (train_data, valid_data, test_data, vocabulary)
|
||||
where each of the data objects can be passed to PTBIterator.
|
||||
"""
|
||||
|
||||
train_path = os.path.join(data_path, "ptb.train.txt")
|
||||
valid_path = os.path.join(data_path, "ptb.valid.txt")
|
||||
test_path = os.path.join(data_path, "ptb.test.txt")
|
||||
|
||||
word_to_id = _build_vocab(train_path)
|
||||
train_data = _file_to_word_ids(train_path, word_to_id)
|
||||
valid_data = _file_to_word_ids(valid_path, word_to_id)
|
||||
test_data = _file_to_word_ids(test_path, word_to_id)
|
||||
vocabulary = len(word_to_id)
|
||||
return train_data, valid_data, test_data, vocabulary
|
||||
|
||||
|
||||
def ptb_producer(raw_data, batch_size, num_steps, name=None):
|
||||
"""Iterate on the raw PTB data.
|
||||
|
||||
This chunks up raw_data into batches of examples and returns Tensors that
|
||||
are drawn from these batches.
|
||||
|
||||
Args:
|
||||
raw_data: one of the raw data outputs from ptb_raw_data.
|
||||
batch_size: int, the batch size.
|
||||
num_steps: int, the number of unrolls.
|
||||
name: the name of this operation (optional).
|
||||
|
||||
Returns:
|
||||
A pair of Tensors, each shaped [batch_size, num_steps]. The second element
|
||||
of the tuple is the same data time-shifted to the right by one.
|
||||
|
||||
Raises:
|
||||
tf.errors.InvalidArgumentError: if batch_size or num_steps are too high.
|
||||
"""
|
||||
with tf.name_scope(name, "PTBProducer", [raw_data, batch_size, num_steps]):
|
||||
raw_data = tf.convert_to_tensor(raw_data, name="raw_data", dtype=tf.int32)
|
||||
|
||||
data_len = tf.size(raw_data)
|
||||
batch_len = data_len // batch_size
|
||||
data = tf.reshape(raw_data[0 : batch_size * batch_len],
|
||||
[batch_size, batch_len])
|
||||
|
||||
epoch_size = (batch_len - 1) // num_steps
|
||||
assertion = tf.assert_positive(
|
||||
epoch_size,
|
||||
message="epoch_size == 0, decrease batch_size or num_steps")
|
||||
with tf.control_dependencies([assertion]):
|
||||
epoch_size = tf.identity(epoch_size, name="epoch_size")
|
||||
|
||||
i = tf.train.range_input_producer(epoch_size, shuffle=False).dequeue()
|
||||
x = tf.strided_slice(data, [0, i * num_steps],
|
||||
[batch_size, (i + 1) * num_steps])
|
||||
y = tf.strided_slice(data, [0, i * num_steps + 1],
|
||||
[batch_size, (i + 1) * num_steps + 1])
|
||||
return x, y
|
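The reader builds its vocabulary by counting words and assigning ids in order of decreasing frequency, with ties broken alphabetically. The same pipeline run on an in-memory string instead of a PTB file; the sentence is made up for illustration:

import collections

text = "the cat sat on the mat <eos> the cat <eos>"
data = text.split()                                   # _read_words, minus file I/O
counter = collections.Counter(data)
count_pairs = sorted(counter.items(), key=lambda x: (-x[1], x[0]))
words, _ = list(zip(*count_pairs))
word_to_id = dict(zip(words, range(len(words))))      # _build_vocab
print(word_to_id)                                     # "the" (3 occurrences) gets id 0
print([word_to_id[w] for w in data])                  # the _file_to_word_ids step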
@ -1,68 +0,0 @@
|
||||
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
"""Tests for tensorflow.models.ptb_lstm.ptb_reader."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os.path
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
from tensorflow.models.rnn.ptb import reader
|
||||
|
||||
|
||||
class PtbReaderTest(tf.test.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self._string_data = "\n".join(
|
||||
[" hello there i am",
|
||||
" rain as day",
|
||||
" want some cheesy puffs ?"])
|
||||
|
||||
def testPtbRawData(self):
|
||||
tmpdir = tf.test.get_temp_dir()
|
||||
for suffix in "train", "valid", "test":
|
||||
filename = os.path.join(tmpdir, "ptb.%s.txt" % suffix)
|
||||
with tf.gfile.GFile(filename, "w") as fh:
|
||||
fh.write(self._string_data)
|
||||
# Smoke test
|
||||
output = reader.ptb_raw_data(tmpdir)
|
||||
self.assertEqual(len(output), 4)
|
||||
|
||||
def testPtbProducer(self):
|
||||
raw_data = [4, 3, 2, 1, 0, 5, 6, 1, 1, 1, 1, 0, 3, 4, 1]
|
||||
batch_size = 3
|
||||
num_steps = 2
|
||||
x, y = reader.ptb_producer(raw_data, batch_size, num_steps)
|
||||
with self.test_session() as session:
|
||||
coord = tf.train.Coordinator()
|
||||
tf.train.start_queue_runners(session, coord=coord)
|
||||
try:
|
||||
xval, yval = session.run([x, y])
|
||||
self.assertAllEqual(xval, [[4, 3], [5, 6], [1, 0]])
|
||||
self.assertAllEqual(yval, [[3, 2], [6, 1], [0, 3]])
|
||||
xval, yval = session.run([x, y])
|
||||
self.assertAllEqual(xval, [[2, 1], [1, 1], [3, 4]])
|
||||
self.assertAllEqual(yval, [[1, 0], [1, 1], [4, 1]])
|
||||
finally:
|
||||
coord.request_stop()
|
||||
coord.join()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
tf.test.main()
|
@ -1,21 +0,0 @@
|
||||
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Import rnn python ops for backward compatibility."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
raise ImportError("This module is deprecated. Use tf.nn.rnn_* instead.")
|
@ -1,21 +0,0 @@
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Import rnn_cell python ops for backward compatibility."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

raise ImportError("This module is deprecated. Use tf.contrib.rnn instead.")
@ -1,22 +0,0 @@
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Import seq2seq python ops for backward compatibility."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

raise ImportError(
    "This module is deprecated. Use tf.contrib.legacy_seq2seq instead.")
@ -1,84 +0,0 @@
# Description:
# Example neural translation models.

package(default_visibility = ["//visibility:public"])

licenses(["notice"])  # Apache 2.0

exports_files(["LICENSE"])

py_library(
    name = "package",
    srcs = [
        "__init__.py",
    ],
    srcs_version = "PY2AND3",
    deps = [
        ":data_utils",
        ":seq2seq_model",
    ],
)

py_library(
    name = "data_utils",
    srcs = [
        "data_utils.py",
    ],
    srcs_version = "PY2AND3",
    deps = ["//tensorflow:tensorflow_py"],
)

py_library(
    name = "seq2seq_model",
    srcs = [
        "seq2seq_model.py",
    ],
    srcs_version = "PY2AND3",
    deps = [
        ":data_utils",
        "//tensorflow:tensorflow_py",
    ],
)

py_binary(
    name = "translate",
    srcs = [
        "translate.py",
    ],
    srcs_version = "PY2AND3",
    deps = [
        ":data_utils",
        ":seq2seq_model",
        "//tensorflow:tensorflow_py",
    ],
)

py_test(
    name = "translate_test",
    size = "medium",
    srcs = [
        "translate.py",
    ],
    args = [
        "--self_test=True",
    ],
    main = "translate.py",
    srcs_version = "PY2AND3",
    deps = [
        ":data_utils",
        ":seq2seq_model",
        "//tensorflow:tensorflow_py",
    ],
)

filegroup(
    name = "all_files",
    srcs = glob(
        ["**/*"],
        exclude = [
            "**/METADATA",
            "**/OWNERS",
        ],
    ),
    visibility = ["//tensorflow:__subpackages__"],
)
@ -1,22 +0,0 @@
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Makes helper libraries available in the translate package."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.models.rnn.translate import data_utils
from tensorflow.models.rnn.translate import seq2seq_model
@ -1,290 +0,0 @@
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Utilities for downloading data from WMT, tokenizing, vocabularies."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import gzip
import os
import re
import tarfile

from six.moves import urllib

from tensorflow.python.platform import gfile
import tensorflow as tf

# Special vocabulary symbols - we always put them at the start.
_PAD = b"_PAD"
_GO = b"_GO"
_EOS = b"_EOS"
_UNK = b"_UNK"
_START_VOCAB = [_PAD, _GO, _EOS, _UNK]

PAD_ID = 0
GO_ID = 1
EOS_ID = 2
UNK_ID = 3

# Regular expressions used to tokenize.
_WORD_SPLIT = re.compile(b"([.,!?\"':;)(])")
_DIGIT_RE = re.compile(br"\d")

# URLs for WMT data.
_WMT_ENFR_TRAIN_URL = "http://www.statmt.org/wmt10/training-giga-fren.tar"
_WMT_ENFR_DEV_URL = "http://www.statmt.org/wmt15/dev-v2.tgz"


def maybe_download(directory, filename, url):
  """Download filename from url unless it's already in directory."""
  if not os.path.exists(directory):
    print("Creating directory %s" % directory)
    os.mkdir(directory)
  filepath = os.path.join(directory, filename)
  if not os.path.exists(filepath):
    print("Downloading %s to %s" % (url, filepath))
    filepath, _ = urllib.request.urlretrieve(url, filepath)
    statinfo = os.stat(filepath)
    print("Successfully downloaded", filename, statinfo.st_size, "bytes")
  return filepath


def gunzip_file(gz_path, new_path):
  """Unzips from gz_path into new_path."""
  print("Unpacking %s to %s" % (gz_path, new_path))
  with gzip.open(gz_path, "rb") as gz_file:
    with open(new_path, "wb") as new_file:
      for line in gz_file:
        new_file.write(line)


def get_wmt_enfr_train_set(directory):
  """Download the WMT en-fr training corpus to directory unless it's there."""
  train_path = os.path.join(directory, "giga-fren.release2.fixed")
  if not (gfile.Exists(train_path + ".fr") and gfile.Exists(train_path + ".en")):
    corpus_file = maybe_download(directory, "training-giga-fren.tar",
                                 _WMT_ENFR_TRAIN_URL)
    print("Extracting tar file %s" % corpus_file)
    with tarfile.open(corpus_file, "r") as corpus_tar:
      corpus_tar.extractall(directory)
    gunzip_file(train_path + ".fr.gz", train_path + ".fr")
    gunzip_file(train_path + ".en.gz", train_path + ".en")
  return train_path


def get_wmt_enfr_dev_set(directory):
  """Download the WMT en-fr dev corpus to directory unless it's there."""
  dev_name = "newstest2013"
  dev_path = os.path.join(directory, dev_name)
  if not (gfile.Exists(dev_path + ".fr") and gfile.Exists(dev_path + ".en")):
    dev_file = maybe_download(directory, "dev-v2.tgz", _WMT_ENFR_DEV_URL)
    print("Extracting tgz file %s" % dev_file)
    with tarfile.open(dev_file, "r:gz") as dev_tar:
      fr_dev_file = dev_tar.getmember("dev/" + dev_name + ".fr")
      en_dev_file = dev_tar.getmember("dev/" + dev_name + ".en")
      fr_dev_file.name = dev_name + ".fr"  # Extract without "dev/" prefix.
      en_dev_file.name = dev_name + ".en"
      dev_tar.extract(fr_dev_file, directory)
      dev_tar.extract(en_dev_file, directory)
  return dev_path


def basic_tokenizer(sentence):
  """Very basic tokenizer: split the sentence into a list of tokens."""
  words = []
  for space_separated_fragment in sentence.strip().split():
    words.extend(_WORD_SPLIT.split(space_separated_fragment))
  return [w for w in words if w]


def create_vocabulary(vocabulary_path, data_path, max_vocabulary_size,
                      tokenizer=None, normalize_digits=True):
  """Create vocabulary file (if it does not exist yet) from data file.

  Data file is assumed to contain one sentence per line. Each sentence is
  tokenized and digits are normalized (if normalize_digits is set).
  Vocabulary contains the most-frequent tokens up to max_vocabulary_size.
  We write it to vocabulary_path in a one-token-per-line format, so that later
  the token in the first line gets id=0, the second line gets id=1, and so on.

  Args:
    vocabulary_path: path where the vocabulary will be created.
    data_path: data file that will be used to create vocabulary.
    max_vocabulary_size: limit on the size of the created vocabulary.
    tokenizer: a function to use to tokenize each data sentence;
      if None, basic_tokenizer will be used.
    normalize_digits: Boolean; if true, all digits are replaced by 0s.
  """
  if not gfile.Exists(vocabulary_path):
    print("Creating vocabulary %s from data %s" % (vocabulary_path, data_path))
    vocab = {}
    with gfile.GFile(data_path, mode="rb") as f:
      counter = 0
      for line in f:
        counter += 1
        if counter % 100000 == 0:
          print("  processing line %d" % counter)
        line = tf.compat.as_bytes(line)
        tokens = tokenizer(line) if tokenizer else basic_tokenizer(line)
        for w in tokens:
          word = _DIGIT_RE.sub(b"0", w) if normalize_digits else w
          if word in vocab:
            vocab[word] += 1
          else:
            vocab[word] = 1
      vocab_list = _START_VOCAB + sorted(vocab, key=vocab.get, reverse=True)
      if len(vocab_list) > max_vocabulary_size:
        vocab_list = vocab_list[:max_vocabulary_size]
      with gfile.GFile(vocabulary_path, mode="wb") as vocab_file:
        for w in vocab_list:
          vocab_file.write(w + b"\n")


def initialize_vocabulary(vocabulary_path):
  """Initialize vocabulary from file.

  We assume the vocabulary is stored one-item-per-line, so a file:
    dog
    cat
  will result in a vocabulary {"dog": 0, "cat": 1}, and this function will
  also return the reversed-vocabulary ["dog", "cat"].

  Args:
    vocabulary_path: path to the file containing the vocabulary.

  Returns:
    a pair: the vocabulary (a dictionary mapping string to integers), and
    the reversed vocabulary (a list, which reverses the vocabulary mapping).

  Raises:
    ValueError: if the provided vocabulary_path does not exist.
  """
  if gfile.Exists(vocabulary_path):
    rev_vocab = []
    with gfile.GFile(vocabulary_path, mode="rb") as f:
      rev_vocab.extend(f.readlines())
    rev_vocab = [line.strip() for line in rev_vocab]
    vocab = dict([(x, y) for (y, x) in enumerate(rev_vocab)])
    return vocab, rev_vocab
  else:
    raise ValueError("Vocabulary file %s not found." % vocabulary_path)


def sentence_to_token_ids(sentence, vocabulary,
                          tokenizer=None, normalize_digits=True):
  """Convert a string to list of integers representing token-ids.

  For example, a sentence "I have a dog" may become tokenized into
  ["I", "have", "a", "dog"] and with vocabulary {"I": 1, "have": 2,
  "a": 4, "dog": 7} this function will return [1, 2, 4, 7].

  Args:
    sentence: the sentence in bytes format to convert to token-ids.
    vocabulary: a dictionary mapping tokens to integers.
    tokenizer: a function to use to tokenize each sentence;
      if None, basic_tokenizer will be used.
    normalize_digits: Boolean; if true, all digits are replaced by 0s.

  Returns:
    a list of integers, the token-ids for the sentence.
  """

  if tokenizer:
    words = tokenizer(sentence)
  else:
    words = basic_tokenizer(sentence)
  if not normalize_digits:
    return [vocabulary.get(w, UNK_ID) for w in words]
  # Normalize digits by 0 before looking words up in the vocabulary.
  return [vocabulary.get(_DIGIT_RE.sub(b"0", w), UNK_ID) for w in words]


def data_to_token_ids(data_path, target_path, vocabulary_path,
                      tokenizer=None, normalize_digits=True):
  """Tokenize data file and turn into token-ids using given vocabulary file.

  This function loads data line-by-line from data_path, calls the above
  sentence_to_token_ids, and saves the result to target_path. See comment
  for sentence_to_token_ids on the details of token-ids format.

  Args:
    data_path: path to the data file in one-sentence-per-line format.
    target_path: path where the file with token-ids will be created.
    vocabulary_path: path to the vocabulary file.
    tokenizer: a function to use to tokenize each sentence;
      if None, basic_tokenizer will be used.
    normalize_digits: Boolean; if true, all digits are replaced by 0s.
  """
  if not gfile.Exists(target_path):
    print("Tokenizing data in %s" % data_path)
    vocab, _ = initialize_vocabulary(vocabulary_path)
    with gfile.GFile(data_path, mode="rb") as data_file:
      with gfile.GFile(target_path, mode="w") as tokens_file:
        counter = 0
        for line in data_file:
          counter += 1
          if counter % 100000 == 0:
            print("  tokenizing line %d" % counter)
          token_ids = sentence_to_token_ids(line, vocab, tokenizer,
                                            normalize_digits)
          tokens_file.write(" ".join([str(tok) for tok in token_ids]) + "\n")


def prepare_wmt_data(data_dir, en_vocabulary_size, fr_vocabulary_size,
                     tokenizer=None):
  """Get WMT data into data_dir, create vocabularies and tokenize data.

  Args:
    data_dir: directory in which the data sets will be stored.
    en_vocabulary_size: size of the English vocabulary to create and use.
    fr_vocabulary_size: size of the French vocabulary to create and use.
    tokenizer: a function to use to tokenize each data sentence;
      if None, basic_tokenizer will be used.

  Returns:
    A tuple of 6 elements:
      (1) path to the token-ids for English training data-set,
      (2) path to the token-ids for French training data-set,
      (3) path to the token-ids for English development data-set,
      (4) path to the token-ids for French development data-set,
      (5) path to the English vocabulary file,
      (6) path to the French vocabulary file.
  """
  # Get wmt data to the specified directory.
  train_path = get_wmt_enfr_train_set(data_dir)
  dev_path = get_wmt_enfr_dev_set(data_dir)

  # Create vocabularies of the appropriate sizes.
  fr_vocab_path = os.path.join(data_dir, "vocab%d.fr" % fr_vocabulary_size)
  en_vocab_path = os.path.join(data_dir, "vocab%d.en" % en_vocabulary_size)
  create_vocabulary(fr_vocab_path, train_path + ".fr", fr_vocabulary_size,
                    tokenizer)
  create_vocabulary(en_vocab_path, train_path + ".en", en_vocabulary_size,
                    tokenizer)

  # Create token ids for the training data.
  fr_train_ids_path = train_path + (".ids%d.fr" % fr_vocabulary_size)
  en_train_ids_path = train_path + (".ids%d.en" % en_vocabulary_size)
  data_to_token_ids(train_path + ".fr", fr_train_ids_path, fr_vocab_path,
                    tokenizer)
  data_to_token_ids(train_path + ".en", en_train_ids_path, en_vocab_path,
                    tokenizer)

  # Create token ids for the development data.
  fr_dev_ids_path = dev_path + (".ids%d.fr" % fr_vocabulary_size)
  en_dev_ids_path = dev_path + (".ids%d.en" % en_vocabulary_size)
  data_to_token_ids(dev_path + ".fr", fr_dev_ids_path, fr_vocab_path,
                    tokenizer)
  data_to_token_ids(dev_path + ".en", en_dev_ids_path, en_vocab_path,
                    tokenizer)

  return (en_train_ids_path, fr_train_ids_path,
          en_dev_ids_path, fr_dev_ids_path,
          en_vocab_path, fr_vocab_path)
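The tokenization and digit-normalization step above is small enough to exercise stand-alone. The sketch below re-implements just that part in plain Python rather than importing the deleted module; the toy vocabulary and sentence are made up for illustration.

import re

_WORD_SPLIT = re.compile(b"([.,!?\"':;)(])")
_DIGIT_RE = re.compile(br"\d")
UNK_ID = 3

def basic_tokenizer(sentence):
    # Split on whitespace, then split punctuation into separate tokens.
    words = []
    for fragment in sentence.strip().split():
        words.extend(_WORD_SPLIT.split(fragment))
    return [w for w in words if w]

def sentence_to_token_ids(sentence, vocabulary, normalize_digits=True):
    words = basic_tokenizer(sentence)
    if not normalize_digits:
        return [vocabulary.get(w, UNK_ID) for w in words]
    # Each digit is rewritten to b"0" before the vocabulary lookup.
    return [vocabulary.get(_DIGIT_RE.sub(b"0", w), UNK_ID) for w in words]

vocab = {b"i": 4, b"have": 5, b"0": 6, b"dogs": 7, b".": 8}
print(sentence_to_token_ids(b"i have 7 dogs .", vocab))
# -> [4, 5, 6, 7, 8]; the single digit "7" is normalized to "0" and found in
# the vocabulary, while any unknown token would map to UNK_ID (3).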
@ -1,313 +0,0 @@
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Sequence-to-sequence model with an attention mechanism."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import random

import numpy as np
from six.moves import xrange  # pylint: disable=redefined-builtin
import tensorflow as tf

from tensorflow.models.rnn.translate import data_utils


class Seq2SeqModel(object):
  """Sequence-to-sequence model with attention and for multiple buckets.

  This class implements a multi-layer recurrent neural network as encoder,
  and an attention-based decoder. This is the same as the model described in
  this paper: http://arxiv.org/abs/1412.7449 - please look there for details,
  or into the seq2seq library for complete model implementation.
  This class also allows using GRU cells in addition to LSTM cells, and
  sampled softmax to handle large output vocabulary size. A single-layer
  version of this model, but with bi-directional encoder, was presented in
    http://arxiv.org/abs/1409.0473
  and sampled softmax is described in Section 3 of the following paper.
    http://arxiv.org/abs/1412.2007
  """

  def __init__(self,
               source_vocab_size,
               target_vocab_size,
               buckets,
               size,
               num_layers,
               max_gradient_norm,
               batch_size,
               learning_rate,
               learning_rate_decay_factor,
               use_lstm=False,
               num_samples=512,
               forward_only=False,
               dtype=tf.float32):
    """Create the model.

    Args:
      source_vocab_size: size of the source vocabulary.
      target_vocab_size: size of the target vocabulary.
      buckets: a list of pairs (I, O), where I specifies maximum input length
        that will be processed in that bucket, and O specifies maximum output
        length. Training instances that have inputs longer than I or outputs
        longer than O will be pushed to the next bucket and padded accordingly.
        We assume that the list is sorted, e.g., [(2, 4), (8, 16)].
      size: number of units in each layer of the model.
      num_layers: number of layers in the model.
      max_gradient_norm: gradients will be clipped to maximally this norm.
      batch_size: the size of the batches used during training;
        the model construction is independent of batch_size, so it can be
        changed after initialization if this is convenient, e.g., for decoding.
      learning_rate: learning rate to start with.
      learning_rate_decay_factor: decay learning rate by this much when needed.
      use_lstm: if true, we use LSTM cells instead of GRU cells.
      num_samples: number of samples for sampled softmax.
      forward_only: if set, we do not construct the backward pass in the model.
      dtype: the data type to use to store internal variables.
    """
    self.source_vocab_size = source_vocab_size
    self.target_vocab_size = target_vocab_size
    self.buckets = buckets
    self.batch_size = batch_size
    self.learning_rate = tf.Variable(
        float(learning_rate), trainable=False, dtype=dtype)
    self.learning_rate_decay_op = self.learning_rate.assign(
        self.learning_rate * learning_rate_decay_factor)
    self.global_step = tf.Variable(0, trainable=False)

    # If we use sampled softmax, we need an output projection.
    output_projection = None
    softmax_loss_function = None
    # Sampled softmax only makes sense if we sample less than vocabulary size.
    if num_samples > 0 and num_samples < self.target_vocab_size:
      w_t = tf.get_variable("proj_w", [self.target_vocab_size, size],
                            dtype=dtype)
      w = tf.transpose(w_t)
      b = tf.get_variable("proj_b", [self.target_vocab_size], dtype=dtype)
      output_projection = (w, b)

      def sampled_loss(labels, inputs):
        labels = tf.reshape(labels, [-1, 1])
        # We need to compute the sampled_softmax_loss using 32bit floats to
        # avoid numerical instabilities.
        local_w_t = tf.cast(w_t, tf.float32)
        local_b = tf.cast(b, tf.float32)
        local_inputs = tf.cast(inputs, tf.float32)
        return tf.cast(
            tf.nn.sampled_softmax_loss(
                weights=local_w_t,
                biases=local_b,
                labels=labels,
                inputs=local_inputs,
                num_sampled=num_samples,
                num_classes=self.target_vocab_size),
            dtype)
      softmax_loss_function = sampled_loss

    # Create the internal multi-layer cell for our RNN.
    single_cell = tf.contrib.rnn.GRUCell(size)
    if use_lstm:
      single_cell = tf.contrib.rnn.BasicLSTMCell(size)
    cell = single_cell
    if num_layers > 1:
      cell = tf.contrib.rnn.MultiRNNCell([single_cell] * num_layers)

    # The seq2seq function: we use embedding for the input and attention.
    def seq2seq_f(encoder_inputs, decoder_inputs, do_decode):
      return tf.contrib.legacy_seq2seq.embedding_attention_seq2seq(
          encoder_inputs,
          decoder_inputs,
          cell,
          num_encoder_symbols=source_vocab_size,
          num_decoder_symbols=target_vocab_size,
          embedding_size=size,
          output_projection=output_projection,
          feed_previous=do_decode,
          dtype=dtype)

    # Feeds for inputs.
    self.encoder_inputs = []
    self.decoder_inputs = []
    self.target_weights = []
    for i in xrange(buckets[-1][0]):  # Last bucket is the biggest one.
      self.encoder_inputs.append(tf.placeholder(tf.int32, shape=[None],
                                                name="encoder{0}".format(i)))
    for i in xrange(buckets[-1][1] + 1):
      self.decoder_inputs.append(tf.placeholder(tf.int32, shape=[None],
                                                name="decoder{0}".format(i)))
      self.target_weights.append(tf.placeholder(dtype, shape=[None],
                                                name="weight{0}".format(i)))

    # Our targets are decoder inputs shifted by one.
    targets = [self.decoder_inputs[i + 1]
               for i in xrange(len(self.decoder_inputs) - 1)]

    # Training outputs and losses.
    if forward_only:
      self.outputs, self.losses = tf.contrib.legacy_seq2seq.model_with_buckets(
          self.encoder_inputs, self.decoder_inputs, targets,
          self.target_weights, buckets, lambda x, y: seq2seq_f(x, y, True),
          softmax_loss_function=softmax_loss_function)
      # If we use output projection, we need to project outputs for decoding.
      if output_projection is not None:
        for b in xrange(len(buckets)):
          self.outputs[b] = [
              tf.matmul(output, output_projection[0]) + output_projection[1]
              for output in self.outputs[b]
          ]
    else:
      self.outputs, self.losses = tf.contrib.legacy_seq2seq.model_with_buckets(
          self.encoder_inputs, self.decoder_inputs, targets,
          self.target_weights, buckets,
          lambda x, y: seq2seq_f(x, y, False),
          softmax_loss_function=softmax_loss_function)

    # Gradients and SGD update operation for training the model.
    params = tf.trainable_variables()
    if not forward_only:
      self.gradient_norms = []
      self.updates = []
      opt = tf.train.GradientDescentOptimizer(self.learning_rate)
      for b in xrange(len(buckets)):
        gradients = tf.gradients(self.losses[b], params)
        clipped_gradients, norm = tf.clip_by_global_norm(gradients,
                                                         max_gradient_norm)
        self.gradient_norms.append(norm)
        self.updates.append(opt.apply_gradients(
            zip(clipped_gradients, params), global_step=self.global_step))

    self.saver = tf.train.Saver(tf.all_variables())

  def step(self, session, encoder_inputs, decoder_inputs, target_weights,
           bucket_id, forward_only):
    """Run a step of the model feeding the given inputs.

    Args:
      session: tensorflow session to use.
      encoder_inputs: list of numpy int vectors to feed as encoder inputs.
      decoder_inputs: list of numpy int vectors to feed as decoder inputs.
      target_weights: list of numpy float vectors to feed as target weights.
      bucket_id: which bucket of the model to use.
      forward_only: whether to do the backward step or only forward.

    Returns:
      A triple consisting of gradient norm (or None if we did not do backward),
      average perplexity, and the outputs.

    Raises:
      ValueError: if length of encoder_inputs, decoder_inputs, or
        target_weights disagrees with bucket size for the specified bucket_id.
    """
    # Check if the sizes match.
    encoder_size, decoder_size = self.buckets[bucket_id]
    if len(encoder_inputs) != encoder_size:
      raise ValueError("Encoder length must be equal to the one in bucket,"
                       " %d != %d." % (len(encoder_inputs), encoder_size))
    if len(decoder_inputs) != decoder_size:
      raise ValueError("Decoder length must be equal to the one in bucket,"
                       " %d != %d." % (len(decoder_inputs), decoder_size))
    if len(target_weights) != decoder_size:
      raise ValueError("Weights length must be equal to the one in bucket,"
                       " %d != %d." % (len(target_weights), decoder_size))

    # Input feed: encoder inputs, decoder inputs, target_weights, as provided.
    input_feed = {}
    for l in xrange(encoder_size):
      input_feed[self.encoder_inputs[l].name] = encoder_inputs[l]
    for l in xrange(decoder_size):
      input_feed[self.decoder_inputs[l].name] = decoder_inputs[l]
      input_feed[self.target_weights[l].name] = target_weights[l]

    # Since our targets are decoder inputs shifted by one, we need one more.
    last_target = self.decoder_inputs[decoder_size].name
    input_feed[last_target] = np.zeros([self.batch_size], dtype=np.int32)

    # Output feed: depends on whether we do a backward step or not.
    if not forward_only:
      output_feed = [self.updates[bucket_id],  # Update Op that does SGD.
                     self.gradient_norms[bucket_id],  # Gradient norm.
                     self.losses[bucket_id]]  # Loss for this batch.
    else:
      output_feed = [self.losses[bucket_id]]  # Loss for this batch.
      for l in xrange(decoder_size):  # Output logits.
        output_feed.append(self.outputs[bucket_id][l])

    outputs = session.run(output_feed, input_feed)
    if not forward_only:
      return outputs[1], outputs[2], None  # Gradient norm, loss, no outputs.
    else:
      return None, outputs[0], outputs[1:]  # No gradient norm, loss, outputs.

  def get_batch(self, data, bucket_id):
    """Get a random batch of data from the specified bucket, prepare for step.

    To feed data in step(..) it must be a list of batch-major vectors, while
    data here contains single length-major cases. So the main logic of this
    function is to re-index data cases to be in the proper format for feeding.

    Args:
      data: a tuple of size len(self.buckets) in which each element contains
        lists of pairs of input and output data that we use to create a batch.
      bucket_id: integer, which bucket to get the batch for.

    Returns:
      The triple (encoder_inputs, decoder_inputs, target_weights) for
      the constructed batch that has the proper format to call step(...) later.
    """
    encoder_size, decoder_size = self.buckets[bucket_id]
    encoder_inputs, decoder_inputs = [], []

    # Get a random batch of encoder and decoder inputs from data,
    # pad them if needed, reverse encoder inputs and add GO to decoder.
    for _ in xrange(self.batch_size):
      encoder_input, decoder_input = random.choice(data[bucket_id])

      # Encoder inputs are padded and then reversed.
      encoder_pad = [data_utils.PAD_ID] * (encoder_size - len(encoder_input))
      encoder_inputs.append(list(reversed(encoder_input + encoder_pad)))

      # Decoder inputs get an extra "GO" symbol, and are padded then.
      decoder_pad_size = decoder_size - len(decoder_input) - 1
      decoder_inputs.append([data_utils.GO_ID] + decoder_input +
                            [data_utils.PAD_ID] * decoder_pad_size)

    # Now we create batch-major vectors from the data selected above.
    batch_encoder_inputs, batch_decoder_inputs, batch_weights = [], [], []

    # Batch encoder inputs are just re-indexed encoder_inputs.
    for length_idx in xrange(encoder_size):
      batch_encoder_inputs.append(
          np.array([encoder_inputs[batch_idx][length_idx]
                    for batch_idx in xrange(self.batch_size)], dtype=np.int32))

    # Batch decoder inputs are re-indexed decoder_inputs, we create weights.
    for length_idx in xrange(decoder_size):
      batch_decoder_inputs.append(
          np.array([decoder_inputs[batch_idx][length_idx]
                    for batch_idx in xrange(self.batch_size)], dtype=np.int32))

      # Create target_weights to be 0 for targets that are padding.
      batch_weight = np.ones(self.batch_size, dtype=np.float32)
      for batch_idx in xrange(self.batch_size):
        # We set weight to 0 if the corresponding target is a PAD symbol.
        # The corresponding target is decoder_input shifted by 1 forward.
        if length_idx < decoder_size - 1:
          target = decoder_inputs[batch_idx][length_idx + 1]
        if length_idx == decoder_size - 1 or target == data_utils.PAD_ID:
          batch_weight[batch_idx] = 0.0
      batch_weights.append(batch_weight)
    return batch_encoder_inputs, batch_decoder_inputs, batch_weights
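The padding, reversal, and weighting scheme in get_batch is easier to see on a single example. The sketch below is a minimal, self-contained illustration (not an import of the deleted class) using a made-up (3, 3) bucket and batch size 1.

import numpy as np

PAD_ID, GO_ID = 0, 1
encoder_size, decoder_size = 3, 3

# One training pair from a (3, 3) bucket: source [5, 6], target [7].
encoder_input, decoder_input = [5, 6], [7]

# Pad and reverse the encoder input; prepend GO and pad the decoder input.
encoder = list(reversed(encoder_input +
                        [PAD_ID] * (encoder_size - len(encoder_input))))
decoder = [GO_ID] + decoder_input + \
    [PAD_ID] * (decoder_size - len(decoder_input) - 1)
print(encoder)  # [0, 6, 5]
print(decoder)  # [1, 7, 0]

# With batch_size == 1, batch-major re-indexing is just a transpose.
batch_encoder = np.array([encoder], dtype=np.int32).T  # shape (encoder_size, 1)
batch_decoder = np.array([decoder], dtype=np.int32).T  # shape (decoder_size, 1)

# Weights are 1 except where the target (decoder input shifted by one) is PAD
# or where we are at the last decoder position.
weights = np.ones((decoder_size, 1), dtype=np.float32)
for t in range(decoder_size):
    if t == decoder_size - 1 or batch_decoder[t + 1, 0] == PAD_ID:
        weights[t, 0] = 0.0
print(weights.ravel().tolist())  # [1.0, 0.0, 0.0]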
@ -1,297 +0,0 @@
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Binary for training translation models and decoding from them.

Running this program without --decode will download the WMT corpus into
the directory specified as --data_dir and tokenize it in a very basic way,
and then start training a model saving checkpoints to --train_dir.

Running with --decode starts an interactive loop so you can see how
the current checkpoint translates English sentences into French.

See the following papers for more information on neural translation models.
 * http://arxiv.org/abs/1409.3215
 * http://arxiv.org/abs/1409.0473
 * http://arxiv.org/abs/1412.2007
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math
import os
import random
import sys
import time
import logging

import numpy as np
from six.moves import xrange  # pylint: disable=redefined-builtin
import tensorflow as tf

from tensorflow.models.rnn.translate import data_utils
from tensorflow.models.rnn.translate import seq2seq_model


tf.app.flags.DEFINE_float("learning_rate", 0.5, "Learning rate.")
tf.app.flags.DEFINE_float("learning_rate_decay_factor", 0.99,
                          "Learning rate decays by this much.")
tf.app.flags.DEFINE_float("max_gradient_norm", 5.0,
                          "Clip gradients to this norm.")
tf.app.flags.DEFINE_integer("batch_size", 64,
                            "Batch size to use during training.")
tf.app.flags.DEFINE_integer("size", 1024, "Size of each model layer.")
tf.app.flags.DEFINE_integer("num_layers", 3, "Number of layers in the model.")
tf.app.flags.DEFINE_integer("en_vocab_size", 40000, "English vocabulary size.")
tf.app.flags.DEFINE_integer("fr_vocab_size", 40000, "French vocabulary size.")
tf.app.flags.DEFINE_string("data_dir", "/tmp", "Data directory")
tf.app.flags.DEFINE_string("train_dir", "/tmp", "Training directory.")
tf.app.flags.DEFINE_integer("max_train_data_size", 0,
                            "Limit on the size of training data (0: no limit).")
tf.app.flags.DEFINE_integer("steps_per_checkpoint", 200,
                            "How many training steps to do per checkpoint.")
tf.app.flags.DEFINE_boolean("decode", False,
                            "Set to True for interactive decoding.")
tf.app.flags.DEFINE_boolean("self_test", False,
                            "Run a self-test if this is set to True.")
tf.app.flags.DEFINE_boolean("use_fp16", False,
                            "Train using fp16 instead of fp32.")

FLAGS = tf.app.flags.FLAGS

# We use a number of buckets and pad to the closest one for efficiency.
# See seq2seq_model.Seq2SeqModel for details of how they work.
_buckets = [(5, 10), (10, 15), (20, 25), (40, 50)]


def read_data(source_path, target_path, max_size=None):
  """Read data from source and target files and put into buckets.

  Args:
    source_path: path to the files with token-ids for the source language.
    target_path: path to the file with token-ids for the target language;
      it must be aligned with the source file: n-th line contains the desired
      output for n-th line from the source_path.
    max_size: maximum number of lines to read; all others will be ignored;
      if 0 or None, data files will be read completely (no limit).

  Returns:
    data_set: a list of length len(_buckets); data_set[n] contains a list of
      (source, target) pairs read from the provided data files that fit
      into the n-th bucket, i.e., such that len(source) < _buckets[n][0] and
      len(target) < _buckets[n][1]; source and target are lists of token-ids.
  """
  data_set = [[] for _ in _buckets]
  with tf.gfile.GFile(source_path, mode="r") as source_file:
    with tf.gfile.GFile(target_path, mode="r") as target_file:
      source, target = source_file.readline(), target_file.readline()
      counter = 0
      while source and target and (not max_size or counter < max_size):
        counter += 1
        if counter % 100000 == 0:
          print("  reading data line %d" % counter)
          sys.stdout.flush()
        source_ids = [int(x) for x in source.split()]
        target_ids = [int(x) for x in target.split()]
        target_ids.append(data_utils.EOS_ID)
        for bucket_id, (source_size, target_size) in enumerate(_buckets):
          if len(source_ids) < source_size and len(target_ids) < target_size:
            data_set[bucket_id].append([source_ids, target_ids])
            break
        source, target = source_file.readline(), target_file.readline()
  return data_set


def create_model(session, forward_only):
  """Create translation model and initialize or load parameters in session."""
  dtype = tf.float16 if FLAGS.use_fp16 else tf.float32
  model = seq2seq_model.Seq2SeqModel(
      FLAGS.en_vocab_size,
      FLAGS.fr_vocab_size,
      _buckets,
      FLAGS.size,
      FLAGS.num_layers,
      FLAGS.max_gradient_norm,
      FLAGS.batch_size,
      FLAGS.learning_rate,
      FLAGS.learning_rate_decay_factor,
      forward_only=forward_only,
      dtype=dtype)
  ckpt = tf.train.get_checkpoint_state(FLAGS.train_dir)
  if ckpt and tf.train.checkpoint_exists(ckpt.model_checkpoint_path):
    print("Reading model parameters from %s" % ckpt.model_checkpoint_path)
    model.saver.restore(session, ckpt.model_checkpoint_path)
  else:
    print("Created model with fresh parameters.")
    session.run(tf.global_variables_initializer())
  return model


def train():
  """Train an en->fr translation model using WMT data."""
  # Prepare WMT data.
  print("Preparing WMT data in %s" % FLAGS.data_dir)
  en_train, fr_train, en_dev, fr_dev, _, _ = data_utils.prepare_wmt_data(
      FLAGS.data_dir, FLAGS.en_vocab_size, FLAGS.fr_vocab_size)

  with tf.Session() as sess:
    # Create model.
    print("Creating %d layers of %d units." % (FLAGS.num_layers, FLAGS.size))
    model = create_model(sess, False)

    # Read data into buckets and compute their sizes.
    print("Reading development and training data (limit: %d)."
          % FLAGS.max_train_data_size)
    dev_set = read_data(en_dev, fr_dev)
    train_set = read_data(en_train, fr_train, FLAGS.max_train_data_size)
    train_bucket_sizes = [len(train_set[b]) for b in xrange(len(_buckets))]
    train_total_size = float(sum(train_bucket_sizes))

    # A bucket scale is a list of increasing numbers from 0 to 1 that we'll use
    # to select a bucket. Length of [scale[i], scale[i+1]] is proportional to
    # the size of the i-th training bucket, as used later.
    train_buckets_scale = [sum(train_bucket_sizes[:i + 1]) / train_total_size
                           for i in xrange(len(train_bucket_sizes))]

    # This is the training loop.
    step_time, loss = 0.0, 0.0
    current_step = 0
    previous_losses = []
    while True:
      # Choose a bucket according to data distribution. We pick a random number
      # in [0, 1] and use the corresponding interval in train_buckets_scale.
      random_number_01 = np.random.random_sample()
      bucket_id = min([i for i in xrange(len(train_buckets_scale))
                       if train_buckets_scale[i] > random_number_01])

      # Get a batch and make a step.
      start_time = time.time()
      encoder_inputs, decoder_inputs, target_weights = model.get_batch(
          train_set, bucket_id)
      _, step_loss, _ = model.step(sess, encoder_inputs, decoder_inputs,
                                   target_weights, bucket_id, False)
      step_time += (time.time() - start_time) / FLAGS.steps_per_checkpoint
      loss += step_loss / FLAGS.steps_per_checkpoint
      current_step += 1

      # Once in a while, we save checkpoint, print statistics, and run evals.
      if current_step % FLAGS.steps_per_checkpoint == 0:
        # Print statistics for the previous epoch.
        perplexity = math.exp(float(loss)) if loss < 300 else float("inf")
        print("global step %d learning rate %.4f step-time %.2f perplexity "
              "%.2f" % (model.global_step.eval(), model.learning_rate.eval(),
                        step_time, perplexity))
        # Decrease learning rate if no improvement was seen over last 3 times.
        if len(previous_losses) > 2 and loss > max(previous_losses[-3:]):
          sess.run(model.learning_rate_decay_op)
        previous_losses.append(loss)
        # Save checkpoint and zero timer and loss.
        checkpoint_path = os.path.join(FLAGS.train_dir, "translate.ckpt")
        model.saver.save(sess, checkpoint_path, global_step=model.global_step)
        step_time, loss = 0.0, 0.0
        # Run evals on development set and print their perplexity.
        for bucket_id in xrange(len(_buckets)):
          if len(dev_set[bucket_id]) == 0:
            print("  eval: empty bucket %d" % (bucket_id))
            continue
          encoder_inputs, decoder_inputs, target_weights = model.get_batch(
              dev_set, bucket_id)
          _, eval_loss, _ = model.step(sess, encoder_inputs, decoder_inputs,
                                       target_weights, bucket_id, True)
          eval_ppx = math.exp(float(eval_loss)) if eval_loss < 300 else float(
              "inf")
          print("  eval: bucket %d perplexity %.2f" % (bucket_id, eval_ppx))
        sys.stdout.flush()


def decode():
  with tf.Session() as sess:
    # Create model and load parameters.
    model = create_model(sess, True)
    model.batch_size = 1  # We decode one sentence at a time.

    # Load vocabularies.
    en_vocab_path = os.path.join(FLAGS.data_dir,
                                 "vocab%d.en" % FLAGS.en_vocab_size)
    fr_vocab_path = os.path.join(FLAGS.data_dir,
                                 "vocab%d.fr" % FLAGS.fr_vocab_size)
    en_vocab, _ = data_utils.initialize_vocabulary(en_vocab_path)
    _, rev_fr_vocab = data_utils.initialize_vocabulary(fr_vocab_path)

    # Decode from standard input.
    sys.stdout.write("> ")
    sys.stdout.flush()
    sentence = sys.stdin.readline()
    while sentence:
      # Get token-ids for the input sentence.
      token_ids = data_utils.sentence_to_token_ids(
          tf.compat.as_bytes(sentence), en_vocab)
      # Which bucket does it belong to?
      bucket_id = len(_buckets) - 1
      for i, bucket in enumerate(_buckets):
        if bucket[0] >= len(token_ids):
          bucket_id = i
          break
      else:
        logging.warning("Sentence truncated: %s", sentence)

      # Get a 1-element batch to feed the sentence to the model.
      encoder_inputs, decoder_inputs, target_weights = model.get_batch(
          {bucket_id: [(token_ids, [])]}, bucket_id)
      # Get output logits for the sentence.
      _, _, output_logits = model.step(sess, encoder_inputs, decoder_inputs,
                                       target_weights, bucket_id, True)
      # This is a greedy decoder - outputs are just argmaxes of output_logits.
      outputs = [int(np.argmax(logit, axis=1)) for logit in output_logits]
      # If there is an EOS symbol in outputs, cut them at that point.
      if data_utils.EOS_ID in outputs:
        outputs = outputs[:outputs.index(data_utils.EOS_ID)]
      # Print out French sentence corresponding to outputs.
      print(" ".join([tf.compat.as_str(rev_fr_vocab[output])
                      for output in outputs]))
      print("> ", end="")
      sys.stdout.flush()
      sentence = sys.stdin.readline()


def self_test():
  """Test the translation model."""
  with tf.Session() as sess:
    print("Self-test for neural translation model.")
    # Create model with vocabularies of 10, 2 small buckets, 2 layers of 32.
    model = seq2seq_model.Seq2SeqModel(10, 10, [(3, 3), (6, 6)], 32, 2,
                                       5.0, 32, 0.3, 0.99, num_samples=8)
    sess.run(tf.global_variables_initializer())

    # Fake data set for both the (3, 3) and (6, 6) bucket.
    data_set = ([([1, 1], [2, 2]), ([3, 3], [4]), ([5], [6])],
                [([1, 1, 1, 1, 1], [2, 2, 2, 2, 2]), ([3, 3, 3], [5, 6])])
    for _ in xrange(5):  # Train the fake model for 5 steps.
      bucket_id = random.choice([0, 1])
      encoder_inputs, decoder_inputs, target_weights = model.get_batch(
          data_set, bucket_id)
      model.step(sess, encoder_inputs, decoder_inputs, target_weights,
                 bucket_id, False)


def main(_):
  if FLAGS.self_test:
    self_test()
  elif FLAGS.decode:
    decode()
  else:
    train()

if __name__ == "__main__":
  tf.app.run()
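The bucket-sampling trick in the training loop above is easy to misread. The hedged sketch below, with made-up bucket sizes, shows that train_buckets_scale is just a cumulative distribution and that the min(...) lookup samples each bucket with probability proportional to its size.

import numpy as np

# Suppose the four buckets hold 100, 300, 400 and 200 training pairs.
train_bucket_sizes = [100, 300, 400, 200]
train_total_size = float(sum(train_bucket_sizes))
train_buckets_scale = [sum(train_bucket_sizes[:i + 1]) / train_total_size
                       for i in range(len(train_bucket_sizes))]
print(train_buckets_scale)  # [0.1, 0.4, 0.8, 1.0]

# Drawing a uniform number and taking the first scale entry above it selects
# bucket i with probability proportional to train_bucket_sizes[i].
counts = [0] * len(train_bucket_sizes)
rng = np.random.RandomState(0)
for _ in range(10000):
    r = rng.random_sample()
    bucket_id = min(i for i in range(len(train_buckets_scale))
                    if train_buckets_scale[i] > r)
    counts[bucket_id] += 1
print(counts)  # roughly [1000, 3000, 4000, 2000]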
@ -290,3 +290,7 @@ Fact

# training_ops
# (None)

# word2vec deprecated ops
NegTrain
Skipgram
@ -180,7 +180,7 @@ test_cifar10_train() {
  fi

  run_in_directory "${TEST_DIR}" "${LOG_FILE}" \
    tensorflow/models/image/cifar10/cifar10_train.py \
    tensorflow_models/tutorials/image/cifar10/cifar10_train.py \
    --data_dir="${TUT_TEST_DATA_DIR}/cifar10" --max_steps=50 \
    --train_dir="${TUT_TEST_ROOT}/cifar10_train"

@ -208,7 +208,7 @@ test_word2vec_test() {
  LOG_FILE=$1

  run_in_directory "${TEST_DIR}" "${LOG_FILE}" \
    tensorflow/models/embedding/word2vec_test.py
    tensorflow_models/tutorials/embedding/word2vec_test.py
}


@ -218,7 +218,7 @@ test_word2vec_optimized_test() {
  LOG_FILE=$1

  run_in_directory "${TEST_DIR}" "${LOG_FILE}" \
    tensorflow/models/embedding/word2vec_optimized_test.py
    tensorflow_models/tutorials/embedding/word2vec_optimized_test.py
}


@ -251,7 +251,7 @@ test_ptb_word_lm() {
  fi

  run_in_directory "${TEST_DIR}" "${LOG_FILE}" \
    tensorflow/models/rnn/ptb/ptb_word_lm.py \
    tensorflow_models/tutorials/rnn/ptb/ptb_word_lm.py \
    --data_path="${DATA_DIR}/simple-examples/data" --model test

  if [[ $? != 0 ]]; then
@ -282,7 +282,7 @@ test_translate_test() {
  LOG_FILE=$1

  run_in_directory "${TEST_DIR}" "${LOG_FILE}" \
    tensorflow/models/rnn/translate/translate.py --self_test=True
    tensorflow_models/tutorials/rnn/translate/translate.py --self_test=True
}


@ -58,20 +58,12 @@ py_binary(
        "//tensorflow/contrib/tensor_forest:all_files",
        "//tensorflow/contrib/tensor_forest/hybrid:all_files",
        "//tensorflow/examples/tutorials/mnist:package",
        "//tensorflow/models/embedding:package",
        "//tensorflow/models/image/alexnet:all_files",
        "//tensorflow/models/image/cifar10:all_files",
        "//tensorflow/models/image/imagenet:all_files",
        "//tensorflow/models/rnn:package",
        "//tensorflow/models/rnn/ptb:package",
        "//tensorflow/models/rnn/translate:package",
        "//tensorflow/python:util_example_parser_configuration",
        "//tensorflow/python/debug:all_files",
        "//tensorflow/python/saved_model:all_files",
        "//tensorflow/python/tools:all_files",
        # The following two targets have an issue when archiving them into
        # the python zip, exclude them for now.
        # "//tensorflow/models/image/mnist:convolutional",
        # "//tensorflow/tensorboard",
    ],
    srcs_version = "PY2AND3",
@ -125,14 +117,6 @@ sh_binary(
        "//tensorflow/contrib/tensor_forest:all_files",
        "//tensorflow/contrib/tensor_forest/hybrid:all_files",
        "//tensorflow/examples/tutorials/mnist:package",
        "//tensorflow/models/embedding:package",
        "//tensorflow/models/image/alexnet:all_files",
        "//tensorflow/models/image/cifar10:all_files",
        "//tensorflow/models/image/imagenet:all_files",
        "//tensorflow/models/image/mnist:convolutional",
        "//tensorflow/models/rnn:package",
        "//tensorflow/models/rnn/ptb:package",
        "//tensorflow/models/rnn/translate:package",
        "//tensorflow/python:util_example_parser_configuration",
        "//tensorflow/python/debug:all_files",
        "//tensorflow/python/saved_model:all_files",