More conversions of the flags library to argparse.

Add argv to the benchmark/main functions so that they can handle
command-line arguments passed to them.
Change: 144254260
Vijay Vasudevan 2017-01-11 15:01:00 -08:00 committed by TensorFlower Gardener
parent 37af1b8790
commit 963674de71
12 changed files with 282 additions and 178 deletions
tensorflow/compiler
tensorflow/contrib/cudnn_rnn/python/kernel_tests
tensorflow/examples/learn
tensorflow/examples/tutorials/estimators
tensorflow/g3doc/how_tos/distributed
tensorflow/g3doc/tutorials/estimators
tensorflow/python
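
All twelve files follow the same recipe, sketched below for orientation. This is a minimal, self-contained sketch of the pattern, not code from the commit; the `--data_dir` flag is hypothetical. Module-level `tf.app.flags` definitions move into an `argparse.ArgumentParser` built under `if __name__ == '__main__':`, `parse_known_args()` splits the flags this script owns from everything else, and the unrecognized remainder is forwarded through `argv` so that `tf.app.run()` or `test.main()` can still see it.

```python
# Minimal sketch of the conversion pattern, assuming a hypothetical
# --data_dir flag; the real scripts below define their own flags.
import argparse
import sys

import tensorflow as tf

FLAGS = None  # populated by parse_known_args() below


def main(_):
  print('data_dir = %s' % FLAGS.data_dir)


if __name__ == '__main__':
  parser = argparse.ArgumentParser()
  parser.add_argument('--data_dir', type=str, default='',
                      help='Hypothetical example flag.')
  # parse_known_args() keeps unrecognized flags instead of erroring out...
  FLAGS, unparsed = parser.parse_known_args()
  # ...so they can be forwarded to tf.app.run(), which passes them along.
  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
```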

View File

@@ -18,6 +18,9 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import argparse
+import sys
+
 from tensorflow.core.protobuf import saver_pb2
 from tensorflow.python.client import session
 from tensorflow.python.framework import constant_op
@@ -27,22 +30,18 @@ from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import variables
 from tensorflow.python.platform import app
-from tensorflow.python.platform import flags as flags_lib
 from tensorflow.python.training import saver as saver_lib
 
-flags = flags_lib
-FLAGS = flags.FLAGS
-flags.DEFINE_string('out_dir', '',
-                    'Output directory for graphs, checkpoints and savers.')
+FLAGS = None
 
 
-def tfadd():
+def tfadd(_):
   x = constant_op.constant([1], name='x_const')
   y = constant_op.constant([2], name='y_const')
   math_ops.add(x, y, name='x_y_sum')
 
 
-def tfadd_with_ckpt():
+def tfadd_with_ckpt(out_dir):
   x = array_ops.placeholder(dtypes.int32, name='x_hold')
   y = variables.Variable(constant_op.constant([0]), name='y_saved')
   math_ops.add(x, y, name='x_y_sum')
@@ -53,11 +52,11 @@ def tfadd_with_ckpt():
     sess.run(init_op)
     sess.run(y.assign(y + 42))
     # Without the checkpoint, the variable won't be set to 42.
-    ckpt = '%s/test_graph_tfadd_with_ckpt.ckpt' % FLAGS.out_dir
+    ckpt = '%s/test_graph_tfadd_with_ckpt.ckpt' % out_dir
     saver.save(sess, ckpt)
 
 
-def tfadd_with_ckpt_saver():
+def tfadd_with_ckpt_saver(out_dir):
   x = array_ops.placeholder(dtypes.int32, name='x_hold')
   y = variables.Variable(constant_op.constant([0]), name='y_saved')
   math_ops.add(x, y, name='x_y_sum')
@@ -68,27 +67,27 @@ def tfadd_with_ckpt_saver():
     sess.run(init_op)
     sess.run(y.assign(y + 42))
     # Without the checkpoint, the variable won't be set to 42.
-    ckpt_file = '%s/test_graph_tfadd_with_ckpt_saver.ckpt' % FLAGS.out_dir
+    ckpt_file = '%s/test_graph_tfadd_with_ckpt_saver.ckpt' % out_dir
     saver.save(sess, ckpt_file)
     # Without the SaverDef, the restore op won't be named correctly.
-    saver_file = '%s/test_graph_tfadd_with_ckpt_saver.saver' % FLAGS.out_dir
+    saver_file = '%s/test_graph_tfadd_with_ckpt_saver.saver' % out_dir
     with open(saver_file, 'w') as f:
       f.write(saver.as_saver_def().SerializeToString())
 
 
-def tfgather():
+def tfgather(_):
   params = array_ops.placeholder(dtypes.float32, name='params')
   indices = array_ops.placeholder(dtypes.int32, name='indices')
   array_ops.gather(params, indices, name='gather_output')
 
 
-def tfmatmul():
+def tfmatmul(_):
   x = array_ops.placeholder(dtypes.float32, name='x_hold')
   y = array_ops.placeholder(dtypes.float32, name='y_hold')
   math_ops.matmul(x, y, name='x_y_prod')
 
 
-def tfmatmulandadd():
+def tfmatmulandadd(_):
   # This tests multiple outputs.
   x = array_ops.placeholder(dtypes.float32, name='x_hold')
   y = array_ops.placeholder(dtypes.float32, name='y_hold')
@@ -96,24 +95,33 @@ def tfmatmulandadd():
   math_ops.add(x, y, name='x_y_sum')
 
 
-def write_graph(build_graph):
+def write_graph(build_graph, out_dir):
   """Build a graph using build_graph and write it out."""
   g = ops.Graph()
   with g.as_default():
-    build_graph()
-    filename = '%s/test_graph_%s.pb' % (FLAGS.out_dir, build_graph.__name__)
+    build_graph(out_dir)
+    filename = '%s/test_graph_%s.pb' % (out_dir, build_graph.__name__)
     with open(filename, 'w') as f:
       f.write(g.as_graph_def().SerializeToString())
 
 
 def main(_):
-  write_graph(tfadd)
-  write_graph(tfadd_with_ckpt)
-  write_graph(tfadd_with_ckpt_saver)
-  write_graph(tfgather)
-  write_graph(tfmatmul)
-  write_graph(tfmatmulandadd)
+  write_graph(tfadd, FLAGS.out_dir)
+  write_graph(tfadd_with_ckpt, FLAGS.out_dir)
+  write_graph(tfadd_with_ckpt_saver, FLAGS.out_dir)
+  write_graph(tfgather, FLAGS.out_dir)
+  write_graph(tfmatmul, FLAGS.out_dir)
+  write_graph(tfmatmulandadd, FLAGS.out_dir)
 
 
 if __name__ == '__main__':
-  app.run()
+  parser = argparse.ArgumentParser()
+  parser.register('type', 'bool', lambda v: v.lower() == 'true')
+  parser.add_argument(
+      '--out_dir',
+      type=str,
+      default='',
+      help='Output directory for graphs, checkpoints and savers.'
+  )
+  FLAGS, unparsed = parser.parse_known_args()
+  app.run(main=main, argv=[sys.argv[0]] + unparsed)

View File

@@ -18,7 +18,9 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import argparse
 import os
+import sys
 
 import numpy as np
 
@@ -32,29 +34,8 @@ from tensorflow.python.ops import gradients_impl
 from tensorflow.python.ops import init_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import variables
-from tensorflow.python.platform import flags as flags_lib
 from tensorflow.python.platform import test
 
-flags = flags_lib
-FLAGS = flags.FLAGS
-flags.DEFINE_integer('batch_size', 128,
-                     'Inputs are fed in batches of this size, for both '
-                     'inference and training. Larger values cause the matmul '
-                     'in each LSTM cell to have higher dimensionality.')
-flags.DEFINE_integer('seq_length', 60,
-                     'Length of the unrolled sequence of LSTM cells in a layer.'
-                     'Larger values cause more LSTM matmuls to be run.')
-flags.DEFINE_integer('num_inputs', 1024,
-                     'Dimension of inputs that are fed into each LSTM cell.')
-flags.DEFINE_integer('num_nodes', 1024, 'Number of nodes in each LSTM cell.')
-flags.DEFINE_string('device', 'gpu',
-                    'TensorFlow device to assign ops to, e.g. "gpu", "cpu". '
-                    'For details see documentation for tf.Graph.device.')
-flags.DEFINE_string('dump_graph_dir', '', 'If non-empty, dump graphs in '
-                    '*.pbtxt format to this directory.')
 
 def _DumpGraph(graph, basename):
   if FLAGS.dump_graph_dir:
@@ -290,4 +271,54 @@ class LSTMBenchmark(test.Benchmark):
 
 if __name__ == '__main__':
-  test.main()
+  parser = argparse.ArgumentParser()
+  parser.register('type', 'bool', lambda v: v.lower() == 'true')
+  parser.add_argument(
+      '--batch_size',
+      type=int,
+      default=128,
+      help="""\
+      Inputs are fed in batches of this size, for both inference and training.
+      Larger values cause the matmul in each LSTM cell to have higher
+      dimensionality.\
+      """
+  )
+  parser.add_argument(
+      '--seq_length',
+      type=int,
+      default=60,
+      help="""\
+      Length of the unrolled sequence of LSTM cells in a layer. Larger values
+      cause more LSTM matmuls to be run.\
+      """
+  )
+  parser.add_argument(
+      '--num_inputs',
+      type=int,
+      default=1024,
+      help='Dimension of inputs that are fed into each LSTM cell.'
+  )
+  parser.add_argument(
+      '--num_nodes',
+      type=int,
+      default=1024,
+      help='Number of nodes in each LSTM cell.'
+  )
+  parser.add_argument(
+      '--device',
+      type=str,
+      default='gpu',
+      help="""\
+      TensorFlow device to assign ops to, e.g. "gpu", "cpu". For details see
+      documentation for tf.Graph.device.\
+      """
+  )
+  parser.add_argument(
+      '--dump_graph_dir',
+      type=str,
+      default='',
+      help='If non-empty, dump graphs in *.pbtxt format to this directory.'
+  )
+  global FLAGS  # pylint:disable=global-at-module-level
+  FLAGS, unparsed = parser.parse_known_args()
+  test.main(argv=[sys.argv[0]] + unparsed)

View File

@@ -19,6 +19,7 @@ from __future__ import division
 from __future__ import print_function
 
 import time
+
 from tensorflow.contrib.cudnn_rnn.python.ops import cudnn_rnn_ops
 from tensorflow.contrib.rnn.python.ops import core_rnn
 from tensorflow.contrib.rnn.python.ops import core_rnn_cell_impl
@@ -31,12 +32,8 @@ from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import gradients_impl
 from tensorflow.python.ops import init_ops
 from tensorflow.python.ops import variables
-from tensorflow.python.platform import flags
 from tensorflow.python.platform import test
 
-flags.DEFINE_integer("batch_size", 64, "batch size.")
-FLAGS = flags.FLAGS
 
 class CudnnRNNBenchmark(test.Benchmark):
   """Benchmarks Cudnn LSTM and other related models.

View File

@@ -17,27 +17,15 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import argparse
+import sys
 import tempfile
 
 from six.moves import urllib
 
 import pandas as pd
 import tensorflow as tf
 
-flags = tf.app.flags
-FLAGS = flags.FLAGS
-
-flags.DEFINE_string("model_dir", "", "Base directory for output models.")
-flags.DEFINE_string("model_type", "wide_n_deep",
-                    "Valid model types: {'wide', 'deep', 'wide_n_deep'}.")
-flags.DEFINE_integer("train_steps", 200, "Number of training steps.")
-flags.DEFINE_string(
-    "train_data",
-    "",
-    "Path to the training data.")
-flags.DEFINE_string(
-    "test_data",
-    "",
-    "Path to the test data.")
 
 COLUMNS = ["age", "workclass", "fnlwgt", "education", "education_num",
            "marital_status", "occupation", "relationship", "race", "gender",
@@ -50,10 +38,10 @@ CONTINUOUS_COLUMNS = ["age", "education_num", "capital_gain", "capital_loss",
                      "hours_per_week"]
 
 
-def maybe_download():
+def maybe_download(train_data, test_data):
   """Maybe downloads training data and returns train and test file names."""
-  if FLAGS.train_data:
-    train_file_name = FLAGS.train_data
+  if train_data:
+    train_file_name = train_data
   else:
     train_file = tempfile.NamedTemporaryFile(delete=False)
    urllib.request.urlretrieve("http://mlr.cs.umass.edu/ml/machine-learning-databases/adult/adult.data", train_file.name)  # pylint: disable=line-too-long
@@ -61,8 +49,8 @@ def maybe_download():
    train_file.close()
    print("Training data is downloaded to %s" % train_file_name)
 
-  if FLAGS.test_data:
-    test_file_name = FLAGS.test_data
+  if test_data:
+    test_file_name = test_data
   else:
     test_file = tempfile.NamedTemporaryFile(delete=False)
     urllib.request.urlretrieve("http://mlr.cs.umass.edu/ml/machine-learning-databases/adult/adult.test", test_file.name)  # pylint: disable=line-too-long
@@ -73,7 +61,7 @@ def maybe_download():
   return train_file_name, test_file_name
 
 
-def build_estimator(model_dir):
+def build_estimator(model_dir, model_type):
   """Build an estimator."""
   # Sparse base columns.
   gender = tf.contrib.layers.sparse_column_with_keys(column_name="gender",
@@ -128,10 +116,10 @@ def build_estimator(model_dir):
       hours_per_week,
   ]
 
-  if FLAGS.model_type == "wide":
+  if model_type == "wide":
     m = tf.contrib.learn.LinearClassifier(model_dir=model_dir,
                                           feature_columns=wide_columns)
-  elif FLAGS.model_type == "deep":
+  elif model_type == "deep":
     m = tf.contrib.learn.DNNClassifier(model_dir=model_dir,
                                        feature_columns=deep_columns,
                                        hidden_units=[100, 50])
@@ -166,9 +154,9 @@ def input_fn(df):
   return feature_cols, label
 
 
-def train_and_eval():
+def train_and_eval(model_dir, model_type, train_steps, train_data, test_data):
   """Train and evaluate the model."""
-  train_file_name, test_file_name = maybe_download()
+  train_file_name, test_file_name = maybe_download(train_data, test_data)
   df_train = pd.read_csv(
       tf.gfile.Open(train_file_name),
       names=COLUMNS,
@@ -190,19 +178,56 @@ def train_and_eval():
   df_test[LABEL_COLUMN] = (
       df_test["income_bracket"].apply(lambda x: ">50K" in x)).astype(int)
 
-  model_dir = tempfile.mkdtemp() if not FLAGS.model_dir else FLAGS.model_dir
+  model_dir = tempfile.mkdtemp() if not model_dir else model_dir
   print("model directory = %s" % model_dir)
 
-  m = build_estimator(model_dir)
-  m.fit(input_fn=lambda: input_fn(df_train), steps=FLAGS.train_steps)
+  m = build_estimator(model_dir, model_type)
+  m.fit(input_fn=lambda: input_fn(df_train), steps=train_steps)
   results = m.evaluate(input_fn=lambda: input_fn(df_test), steps=1)
   for key in sorted(results):
     print("%s: %s" % (key, results[key]))
 
+FLAGS = None
+
 
 def main(_):
-  train_and_eval()
+  train_and_eval(FLAGS.model_dir, FLAGS.model_type, FLAGS.train_steps,
+                 FLAGS.train_data, FLAGS.test_data)
 
 
 if __name__ == "__main__":
-  tf.app.run()
+  parser = argparse.ArgumentParser()
+  parser.register("type", "bool", lambda v: v.lower() == "true")
+  parser.add_argument(
+      "--model_dir",
+      type=str,
+      default="",
+      help="Base directory for output models."
+  )
+  parser.add_argument(
+      "--model_type",
+      type=str,
+      default="wide_n_deep",
+      help="Valid model types: {'wide', 'deep', 'wide_n_deep'}."
+  )
+  parser.add_argument(
+      "--train_steps",
+      type=int,
+      default=200,
+      help="Number of training steps."
+  )
+  parser.add_argument(
+      "--train_data",
+      type=str,
      default="",
+      help="Path to the training data."
+  )
+  parser.add_argument(
+      "--test_data",
+      type=str,
+      default="",
+      help="Path to the test data."
+  )
+  FLAGS, unparsed = parser.parse_known_args()
+  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)

View File

@@ -17,27 +17,17 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import argparse
+import sys
 import tempfile
 
 from six.moves import urllib
 
 import numpy as np
 import tensorflow as tf
 
-flags = tf.app.flags
-FLAGS = flags.FLAGS
-flags.DEFINE_string(
-    "train_data",
-    "",
-    "Path to the training data.")
-flags.DEFINE_string(
-    "test_data",
-    "",
-    "Path to the test data.")
-flags.DEFINE_string(
-    "predict_data",
-    "",
-    "Path to the prediction data.")
+FLAGS = None
 
 tf.logging.set_verbosity(tf.logging.INFO)
@@ -45,31 +35,36 @@ tf.logging.set_verbosity(tf.logging.INFO)
 
 LEARNING_RATE = 0.001
 
 
-def maybe_download():
+def maybe_download(train_data, test_data, predict_data):
   """Maybe downloads training data and returns train and test file names."""
-  if FLAGS.train_data:
-    train_file_name = FLAGS.train_data
+  if train_data:
+    train_file_name = train_data
   else:
     train_file = tempfile.NamedTemporaryFile(delete=False)
-    urllib.request.urlretrieve("http://download.tensorflow.org/data/abalone_train.csv", train_file.name)  # pylint: disable=line-too-long
+    urllib.request.urlretrieve(
+        "http://download.tensorflow.org/data/abalone_train.csv",
+        train_file.name)  # pylint: disable=line-too-long
     train_file_name = train_file.name
     train_file.close()
    print("Training data is downloaded to %s" % train_file_name)
 
-  if FLAGS.test_data:
-    test_file_name = FLAGS.test_data
+  if test_data:
+    test_file_name = test_data
   else:
     test_file = tempfile.NamedTemporaryFile(delete=False)
-    urllib.request.urlretrieve("http://download.tensorflow.org/data/abalone_test.csv", test_file.name)  # pylint: disable=line-too-long
+    urllib.request.urlretrieve(
+        "http://download.tensorflow.org/data/abalone_test.csv", test_file.name)  # pylint: disable=line-too-long
    test_file_name = test_file.name
    test_file.close()
    print("Test data is downloaded to %s" % test_file_name)
 
-  if FLAGS.predict_data:
-    predict_file_name = FLAGS.predict_data
+  if predict_data:
+    predict_file_name = predict_data
   else:
     predict_file = tempfile.NamedTemporaryFile(delete=False)
-    urllib.request.urlretrieve("http://download.tensorflow.org/data/abalone_predict.csv", predict_file.name)  # pylint: disable=line-too-long
+    urllib.request.urlretrieve(
+        "http://download.tensorflow.org/data/abalone_predict.csv",
+        predict_file.name)  # pylint: disable=line-too-long
    predict_file_name = predict_file.name
    predict_file.close()
    print("Prediction data is downloaded to %s" % predict_file_name)
@@ -109,32 +104,26 @@ def model_fn(features, targets, mode, params):
 
 def main(unused_argv):
   # Load datasets
-  abalone_train, abalone_test, abalone_predict = maybe_download()
+  abalone_train, abalone_test, abalone_predict = maybe_download(
+      FLAGS.train_data, FLAGS.test_data, FLAGS.predict_data)
 
   # Training examples
-  training_set = tf.contrib.learn.datasets.base.load_csv_without_header(
-      filename=abalone_train,
-      target_dtype=np.int,
-      features_dtype=np.float64)
+  training_set = tf.contrib.learn.datasets.base.load_csv_without_header(
+      filename=abalone_train, target_dtype=np.int, features_dtype=np.float64)
 
   # Test examples
-  test_set = tf.contrib.learn.datasets.base.load_csv_without_header(
-      filename=abalone_test,
-      target_dtype=np.int,
-      features_dtype=np.float64)
+  test_set = tf.contrib.learn.datasets.base.load_csv_without_header(
+      filename=abalone_test, target_dtype=np.int, features_dtype=np.float64)
 
   # Set of 7 examples for which to predict abalone ages
-  prediction_set = tf.contrib.learn.datasets.base.load_csv_without_header(
-      filename=abalone_predict,
-      target_dtype=np.int,
-      features_dtype=np.float64)
+  prediction_set = tf.contrib.learn.datasets.base.load_csv_without_header(
+      filename=abalone_predict, target_dtype=np.int, features_dtype=np.float64)
 
   # Set model params
   model_params = {"learning_rate": LEARNING_RATE}
 
   # Build 2 layer fully connected DNN with 10, 10 units respectively.
-  nn = tf.contrib.learn.Estimator(
-      model_fn=model_fn, params=model_params)
+  nn = tf.contrib.learn.Estimator(model_fn=model_fn, params=model_params)
 
   # Fit
   nn.fit(x=training_set.data, y=training_set.target, steps=5000)
@@ -145,11 +134,22 @@ def main(unused_argv):
   print("Loss: %s" % loss_score)
 
   # Print out predictions
-  predictions = nn.predict(x=prediction_set.data,
-                           as_iterable=True)
+  predictions = nn.predict(x=prediction_set.data, as_iterable=True)
   for i, p in enumerate(predictions):
     print("Prediction %s: %s" % (i + 1, p["ages"]))
 
 
 if __name__ == "__main__":
-  tf.app.run()
+  parser = argparse.ArgumentParser()
+  parser.register("type", "bool", lambda v: v.lower() == "true")
+  parser.add_argument(
+      "--train_data", type=str, default="", help="Path to the training data.")
+  parser.add_argument(
+      "--test_data", type=str, default="", help="Path to the test data.")
+  parser.add_argument(
+      "--predict_data",
+      type=str,
+      default="",
+      help="Path to the prediction data.")
+  FLAGS, unparsed = parser.parse_known_args()
+  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)

View File

@@ -182,19 +182,12 @@ implementing **between-graph replication** and **asynchronous training**. It
 includes the code for the parameter server and worker tasks.
 
 ```python
+import argparse
+import sys
 import tensorflow as tf
 
-# Flags for defining the tf.train.ClusterSpec
-tf.app.flags.DEFINE_string("ps_hosts", "",
-                           "Comma-separated list of hostname:port pairs")
-tf.app.flags.DEFINE_string("worker_hosts", "",
-                           "Comma-separated list of hostname:port pairs")
-
-# Flags for defining the tf.train.Server
-tf.app.flags.DEFINE_string("job_name", "", "One of 'ps', 'worker'")
-tf.app.flags.DEFINE_integer("task_index", 0, "Index of task within the job")
-
-FLAGS = tf.app.flags.FLAGS
+FLAGS = None
 
 
 def main(_):
@@ -253,7 +246,36 @@ def main(_):
     sv.stop()
 
 if __name__ == "__main__":
-  tf.app.run()
+  parser = argparse.ArgumentParser()
+  parser.register("type", "bool", lambda v: v.lower() == "true")
+  # Flags for defining the tf.train.ClusterSpec
+  parser.add_argument(
+      "--ps_hosts",
+      type=str,
+      default="",
+      help="Comma-separated list of hostname:port pairs"
+  )
+  parser.add_argument(
+      "--worker_hosts",
+      type=str,
+      default="",
+      help="Comma-separated list of hostname:port pairs"
+  )
+  parser.add_argument(
+      "--job_name",
+      type=str,
+      default="",
+      help="One of 'ps', 'worker'"
+  )
+  # Flags for defining the tf.train.Server
+  parser.add_argument(
+      "--task_index",
+      type=int,
+      default=0,
+      help="Index of task within the job"
+  )
+  FLAGS, unparsed = parser.parse_known_args()
+  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
 ```
 
 To start the trainer with two parameter servers and two workers, use the

View File

@@ -101,35 +101,20 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import argparse
+import sys
 import tempfile
 import urllib
 
 import numpy as np
 import tensorflow as tf
+
+FLAGS = None
 ```
 
-Then define flags to allow users to optionally specify CSV files for training,
-test, and prediction datasets via the command line (by default, files will be
-downloaded from [tensorflow.org](https://www.tensorflow.org/)), and enable
-logging:
+Enable logging:
 
 ```python
-flags = tf.app.flags
-FLAGS = flags.FLAGS
-flags.DEFINE_string(
-    "train_data",
-    "",
-    "Path to the training data.")
-flags.DEFINE_string(
-    "test_data",
-    "",
-    "Path to the test data.")
-flags.DEFINE_string(
-    "predict_data",
-    "",
-    "Path to the prediction data.")
-
 tf.logging.set_verbosity(tf.logging.INFO)
 ```
@@ -138,10 +123,10 @@ command-line options, or downloaded from
 [tensorflow.org](https://www.tensorflow.org/)):
 
 ```python
-def maybe_download():
+def maybe_download(train_data, test_data, predict_data):
   """Maybe downloads training data and returns train and test file names."""
-  if FLAGS.train_data:
-    train_file_name = FLAGS.train_data
+  if train_data:
+    train_file_name = train_data
   else:
     train_file = tempfile.NamedTemporaryFile(delete=False)
     urllib.urlretrieve("http://download.tensorflow.org/data/abalone_train.csv", train_file.name)
@@ -149,8 +134,8 @@ def maybe_download():
     train_file.close()
    print("Training data is downloaded to %s" % train_file_name)
 
-  if FLAGS.test_data:
-    test_file_name = FLAGS.test_data
+  if test_data:
+    test_file_name = test_data
   else:
     test_file = tempfile.NamedTemporaryFile(delete=False)
     urllib.urlretrieve("http://download.tensorflow.org/data/abalone_test.csv", test_file.name)
@@ -158,8 +143,8 @@ def maybe_download():
   test_file.close()
   print("Test data is downloaded to %s" % test_file_name)
 
-  if FLAGS.predict_data:
-    predict_file_name = FLAGS.predict_data
+  if predict_data:
+    predict_file_name = predict_data
   else:
     predict_file = tempfile.NamedTemporaryFile(delete=False)
     urllib.urlretrieve("http://download.tensorflow.org/data/abalone_predict.csv", predict_file.name)
@@ -170,12 +155,16 @@ def maybe_download():
 
   return train_file_name, test_file_name, predict_file_name
 ```
 
-Finally, create `main()` and load the abalone CSVs into `Datasets`:
+Finally, create `main()` and load the abalone CSVs into `Datasets`,
+defining flags to allow users to optionally specify CSV files for training,
+test, and prediction datasets via the command line (by default, files will be
+downloaded from [tensorflow.org](https://www.tensorflow.org/)):
 
 ```python
 def main(unused_argv):
   # Load datasets
-  abalone_train, abalone_test, abalone_predict = maybe_download()
+  abalone_train, abalone_test, abalone_predict = maybe_download(
+      FLAGS.train_data, FLAGS.test_data, FLAGS.predict_data)
 
   # Training examples
   training_set = tf.contrib.learn.datasets.base.load_csv_without_header(
@@ -196,7 +185,28 @@ def main(unused_argv):
       features_dtype=np.float64)
 
 if __name__ == "__main__":
-  tf.app.run()
+  parser = argparse.ArgumentParser()
+  parser.register("type", "bool", lambda v: v.lower() == "true")
+  parser.add_argument(
+      "--train_data",
+      type=str,
+      default="",
+      help="Path to the training data."
+  )
+  parser.add_argument(
+      "--test_data",
+      type=str,
+      default="",
+      help="Path to the test data."
+  )
+  parser.add_argument(
+      "--predict_data",
+      type=str,
+      default="",
+      help="Path to the prediction data."
+  )
+  FLAGS, unparsed = parser.parse_known_args()
+  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
 ```
 
 ## Instantiating an Estimator

View File

@@ -18,6 +18,8 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import argparse
+import sys
 import time
 
 from tensorflow.python.client import session as session_lib
@@ -30,12 +32,8 @@ from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import nn_impl
 from tensorflow.python.ops import random_ops
 from tensorflow.python.ops import variables
-from tensorflow.python.platform import flags
 from tensorflow.python.platform import test
 
-FLAGS = flags.FLAGS
-
-flags.DEFINE_boolean("use_gpu", True, """Run GPU benchmarks.""")
 
 def batch_norm_op(tensor, mean, variance, beta, gamma, scale):
   """Fused kernel for batch normalization."""
@@ -245,4 +243,16 @@ class BatchNormBenchmark(test.Benchmark):
 
 if __name__ == "__main__":
-  test.main()
+  parser = argparse.ArgumentParser()
+  parser.register("type", "bool", lambda v: v.lower() == "true")
+  parser.add_argument(
+      "--use_gpu",
+      type="bool",
+      nargs="?",
+      const=True,
+      default=True,
+      help="Run GPU benchmarks."
+  )
+  global FLAGS  # pylint:disable=global-at-module-level
+  FLAGS, unparsed = parser.parse_known_args()
+  test.main(argv=[sys.argv[0]] + unparsed)
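
The `type="bool"` string above works because argparse resolves string type aliases through the registry populated by `parser.register()`, and `nargs="?"` with `const=True` makes a bare `--use_gpu` mean True. A standalone sketch of this behavior, for reference (not code from this commit):

```python
# argparse looks up the string alias type="bool" in the registry that
# parser.register('type', 'bool', ...) populates.
import argparse

parser = argparse.ArgumentParser()
parser.register('type', 'bool', lambda v: v.lower() == 'true')
parser.add_argument('--use_gpu', type='bool', nargs='?', const=True,
                    default=True, help='Run GPU benchmarks.')

print(parser.parse_args([]).use_gpu)                   # True  (default)
print(parser.parse_args(['--use_gpu']).use_gpu)        # True  (const)
print(parser.parse_args(['--use_gpu=false']).use_gpu)  # False (converter)
```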

View File

@@ -29,12 +29,8 @@ from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import gradients_impl
 from tensorflow.python.ops import variables
-from tensorflow.python.platform import flags
 from tensorflow.python.platform import test
 
-FLAGS = flags.FLAGS
-
-flags.DEFINE_boolean("use_gpu", True, """Run GPU benchmarks.""")
 
 def build_graph(device, input_shape, variable, num_inputs, axis, grad):
   """Build a graph containing a sequence of concat operations.

View File

@@ -312,13 +312,15 @@ def _run_benchmarks(regex):
     instance_benchmark_fn()
 
 
-def benchmarks_main(true_main):
-  """Run benchmarks as declared in args.
+def benchmarks_main(true_main, argv=None):
+  """Run benchmarks as declared in argv.
 
   Args:
     true_main: True main function to run if benchmarks are not requested.
+    argv: the command line arguments (if None, uses sys.argv).
   """
-  argv = sys.argv
+  if argv is None:
+    argv = sys.argv
   found_arg = [arg for arg in argv
                if arg.startswith("--benchmarks=")
                or arg.startswith("-benchmarks=")]
@@ -327,6 +329,6 @@ def benchmarks_main(true_main):
     argv.remove(found_arg[0])
     regex = found_arg[0].split("=")[1]
-    app.run(lambda _: _run_benchmarks(regex))
+    app.run(lambda _: _run_benchmarks(regex), argv=argv)
   else:
     true_main()
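
The effect of this change is that the `unparsed` flags a script forwards via `test.main(argv=...)` now reach `benchmarks_main`, which strips its own `--benchmarks=...` flag and hands the remainder to `app.run()`. The flow can be emulated in isolation (a simplified sketch, not the actual TensorFlow module):

```python
# Simplified stand-in for benchmarks_main to show the argv threading.
import sys


def benchmarks_main(true_main, argv=None):
  if argv is None:
    argv = sys.argv
  found = [a for a in argv if a.startswith('--benchmarks=')]
  if found:
    argv.remove(found[0])
    regex = found[0].split('=')[1]
    # The real code calls app.run(lambda _: _run_benchmarks(regex), argv=argv)
    print('would run benchmarks matching %r, remaining argv=%r' % (regex, argv))
  else:
    true_main()


benchmarks_main(lambda: print('true main'), argv=['prog'])
benchmarks_main(lambda: print('true main'),
                argv=['prog', '--benchmarks=.*LSTM.*', '--device=gpu'])
```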

View File

@@ -80,9 +80,12 @@ def g_main(argv):
 
 # Redefine main to allow running benchmarks
-def main():  # pylint: disable=function-redefined
+def main(argv=None):  # pylint: disable=function-redefined
   def main_wrapper():
-    return app.run(main=g_main, argv=sys.argv)
+    args = argv
+    if args is None:
+      args = sys.argv
+    return app.run(main=g_main, argv=args)
   benchmark.benchmarks_main(true_main=main_wrapper)

View File

@@ -88,9 +88,9 @@ else:
   Benchmark = _googletest.Benchmark  # pylint: disable=invalid-name
 
 
-def main():
+def main(argv=None):
   """Runs all unit tests."""
-  return _googletest.main()
+  return _googletest.main(argv)
 
 
 def get_temp_dir():