More conversions of flags library to argparse.

Add argv to benchmark/main function so they can handle passing
command line arguments.
Change: 144254260
This commit is contained in:
Vijay Vasudevan 2017-01-11 15:01:00 -08:00 committed by TensorFlower Gardener
parent 37af1b8790
commit 963674de71
12 changed files with 282 additions and 178 deletions

View File

@ -18,6 +18,9 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import sys
from tensorflow.core.protobuf import saver_pb2
from tensorflow.python.client import session
from tensorflow.python.framework import constant_op
@ -27,22 +30,18 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import app
from tensorflow.python.platform import flags as flags_lib
from tensorflow.python.training import saver as saver_lib
flags = flags_lib
FLAGS = flags.FLAGS
flags.DEFINE_string('out_dir', '',
'Output directory for graphs, checkpoints and savers.')
FLAGS = None
def tfadd():
def tfadd(_):
x = constant_op.constant([1], name='x_const')
y = constant_op.constant([2], name='y_const')
math_ops.add(x, y, name='x_y_sum')
def tfadd_with_ckpt():
def tfadd_with_ckpt(out_dir):
x = array_ops.placeholder(dtypes.int32, name='x_hold')
y = variables.Variable(constant_op.constant([0]), name='y_saved')
math_ops.add(x, y, name='x_y_sum')
@ -53,11 +52,11 @@ def tfadd_with_ckpt():
sess.run(init_op)
sess.run(y.assign(y + 42))
# Without the checkpoint, the variable won't be set to 42.
ckpt = '%s/test_graph_tfadd_with_ckpt.ckpt' % FLAGS.out_dir
ckpt = '%s/test_graph_tfadd_with_ckpt.ckpt' % out_dir
saver.save(sess, ckpt)
def tfadd_with_ckpt_saver():
def tfadd_with_ckpt_saver(out_dir):
x = array_ops.placeholder(dtypes.int32, name='x_hold')
y = variables.Variable(constant_op.constant([0]), name='y_saved')
math_ops.add(x, y, name='x_y_sum')
@ -68,27 +67,27 @@ def tfadd_with_ckpt_saver():
sess.run(init_op)
sess.run(y.assign(y + 42))
# Without the checkpoint, the variable won't be set to 42.
ckpt_file = '%s/test_graph_tfadd_with_ckpt_saver.ckpt' % FLAGS.out_dir
ckpt_file = '%s/test_graph_tfadd_with_ckpt_saver.ckpt' % out_dir
saver.save(sess, ckpt_file)
# Without the SaverDef, the restore op won't be named correctly.
saver_file = '%s/test_graph_tfadd_with_ckpt_saver.saver' % FLAGS.out_dir
saver_file = '%s/test_graph_tfadd_with_ckpt_saver.saver' % out_dir
with open(saver_file, 'w') as f:
f.write(saver.as_saver_def().SerializeToString())
def tfgather():
def tfgather(_):
params = array_ops.placeholder(dtypes.float32, name='params')
indices = array_ops.placeholder(dtypes.int32, name='indices')
array_ops.gather(params, indices, name='gather_output')
def tfmatmul():
def tfmatmul(_):
x = array_ops.placeholder(dtypes.float32, name='x_hold')
y = array_ops.placeholder(dtypes.float32, name='y_hold')
math_ops.matmul(x, y, name='x_y_prod')
def tfmatmulandadd():
def tfmatmulandadd(_):
# This tests multiple outputs.
x = array_ops.placeholder(dtypes.float32, name='x_hold')
y = array_ops.placeholder(dtypes.float32, name='y_hold')
@ -96,24 +95,33 @@ def tfmatmulandadd():
math_ops.add(x, y, name='x_y_sum')
def write_graph(build_graph):
def write_graph(build_graph, out_dir):
"""Build a graph using build_graph and write it out."""
g = ops.Graph()
with g.as_default():
build_graph()
filename = '%s/test_graph_%s.pb' % (FLAGS.out_dir, build_graph.__name__)
build_graph(out_dir)
filename = '%s/test_graph_%s.pb' % (out_dir, build_graph.__name__)
with open(filename, 'w') as f:
f.write(g.as_graph_def().SerializeToString())
def main(_):
write_graph(tfadd)
write_graph(tfadd_with_ckpt)
write_graph(tfadd_with_ckpt_saver)
write_graph(tfgather)
write_graph(tfmatmul)
write_graph(tfmatmulandadd)
write_graph(tfadd, FLAGS.out_dir)
write_graph(tfadd_with_ckpt, FLAGS.out_dir)
write_graph(tfadd_with_ckpt_saver, FLAGS.out_dir)
write_graph(tfgather, FLAGS.out_dir)
write_graph(tfmatmul, FLAGS.out_dir)
write_graph(tfmatmulandadd, FLAGS.out_dir)
if __name__ == '__main__':
app.run()
parser = argparse.ArgumentParser()
parser.register('type', 'bool', lambda v: v.lower() == 'true')
parser.add_argument(
'--out_dir',
type=str,
default='',
help='Output directory for graphs, checkpoints and savers.'
)
FLAGS, unparsed = parser.parse_known_args()
app.run(main=main, argv=[sys.argv[0]] + unparsed)

View File

@ -18,7 +18,9 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import os
import sys
import numpy as np
@ -32,29 +34,8 @@ from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import flags as flags_lib
from tensorflow.python.platform import test
flags = flags_lib
FLAGS = flags.FLAGS
flags.DEFINE_integer('batch_size', 128,
'Inputs are fed in batches of this size, for both '
'inference and training. Larger values cause the matmul '
'in each LSTM cell to have higher dimensionality.')
flags.DEFINE_integer('seq_length', 60,
'Length of the unrolled sequence of LSTM cells in a layer.'
'Larger values cause more LSTM matmuls to be run.')
flags.DEFINE_integer('num_inputs', 1024,
'Dimension of inputs that are fed into each LSTM cell.')
flags.DEFINE_integer('num_nodes', 1024, 'Number of nodes in each LSTM cell.')
flags.DEFINE_string('device', 'gpu',
'TensorFlow device to assign ops to, e.g. "gpu", "cpu". '
'For details see documentation for tf.Graph.device.')
flags.DEFINE_string('dump_graph_dir', '', 'If non-empty, dump graphs in '
'*.pbtxt format to this directory.')
def _DumpGraph(graph, basename):
if FLAGS.dump_graph_dir:
@ -290,4 +271,54 @@ class LSTMBenchmark(test.Benchmark):
if __name__ == '__main__':
test.main()
parser = argparse.ArgumentParser()
parser.register('type', 'bool', lambda v: v.lower() == 'true')
parser.add_argument(
'--batch_size',
type=int,
default=128,
help="""\
Inputs are fed in batches of this size, for both inference and training.
Larger values cause the matmul in each LSTM cell to have higher
dimensionality.\
"""
)
parser.add_argument(
'--seq_length',
type=int,
default=60,
help="""\
Length of the unrolled sequence of LSTM cells in a layer.Larger values
cause more LSTM matmuls to be run.\
"""
)
parser.add_argument(
'--num_inputs',
type=int,
default=1024,
help='Dimension of inputs that are fed into each LSTM cell.'
)
parser.add_argument(
'--num_nodes',
type=int,
default=1024,
help='Number of nodes in each LSTM cell.'
)
parser.add_argument(
'--device',
type=str,
default='gpu',
help="""\
TensorFlow device to assign ops to, e.g. "gpu", "cpu". For details see
documentation for tf.Graph.device.\
"""
)
parser.add_argument(
'--dump_graph_dir',
type=str,
default='',
help='If non-empty, dump graphs in *.pbtxt format to this directory.'
)
global FLAGS # pylint:disable=global-at-module-level
FLAGS, unparsed = parser.parse_known_args()
test.main(argv=[sys.argv[0]] + unparsed)

View File

@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
import time
from tensorflow.contrib.cudnn_rnn.python.ops import cudnn_rnn_ops
from tensorflow.contrib.rnn.python.ops import core_rnn
from tensorflow.contrib.rnn.python.ops import core_rnn_cell_impl
@ -31,12 +32,8 @@ from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import flags
from tensorflow.python.platform import test
flags.DEFINE_integer("batch_size", 64, "batch size.")
FLAGS = flags.FLAGS
class CudnnRNNBenchmark(test.Benchmark):
"""Benchmarks Cudnn LSTM and other related models.

View File

@ -17,27 +17,15 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import sys
import tempfile
from six.moves import urllib
import pandas as pd
import tensorflow as tf
flags = tf.app.flags
FLAGS = flags.FLAGS
flags.DEFINE_string("model_dir", "", "Base directory for output models.")
flags.DEFINE_string("model_type", "wide_n_deep",
"Valid model types: {'wide', 'deep', 'wide_n_deep'}.")
flags.DEFINE_integer("train_steps", 200, "Number of training steps.")
flags.DEFINE_string(
"train_data",
"",
"Path to the training data.")
flags.DEFINE_string(
"test_data",
"",
"Path to the test data.")
COLUMNS = ["age", "workclass", "fnlwgt", "education", "education_num",
"marital_status", "occupation", "relationship", "race", "gender",
@ -50,10 +38,10 @@ CONTINUOUS_COLUMNS = ["age", "education_num", "capital_gain", "capital_loss",
"hours_per_week"]
def maybe_download():
def maybe_download(train_data, test_data):
"""Maybe downloads training data and returns train and test file names."""
if FLAGS.train_data:
train_file_name = FLAGS.train_data
if train_data:
train_file_name = train_data
else:
train_file = tempfile.NamedTemporaryFile(delete=False)
urllib.request.urlretrieve("http://mlr.cs.umass.edu/ml/machine-learning-databases/adult/adult.data", train_file.name) # pylint: disable=line-too-long
@ -61,8 +49,8 @@ def maybe_download():
train_file.close()
print("Training data is downloaded to %s" % train_file_name)
if FLAGS.test_data:
test_file_name = FLAGS.test_data
if test_data:
test_file_name = test_data
else:
test_file = tempfile.NamedTemporaryFile(delete=False)
urllib.request.urlretrieve("http://mlr.cs.umass.edu/ml/machine-learning-databases/adult/adult.test", test_file.name) # pylint: disable=line-too-long
@ -73,7 +61,7 @@ def maybe_download():
return train_file_name, test_file_name
def build_estimator(model_dir):
def build_estimator(model_dir, model_type):
"""Build an estimator."""
# Sparse base columns.
gender = tf.contrib.layers.sparse_column_with_keys(column_name="gender",
@ -128,10 +116,10 @@ def build_estimator(model_dir):
hours_per_week,
]
if FLAGS.model_type == "wide":
if model_type == "wide":
m = tf.contrib.learn.LinearClassifier(model_dir=model_dir,
feature_columns=wide_columns)
elif FLAGS.model_type == "deep":
elif model_type == "deep":
m = tf.contrib.learn.DNNClassifier(model_dir=model_dir,
feature_columns=deep_columns,
hidden_units=[100, 50])
@ -166,9 +154,9 @@ def input_fn(df):
return feature_cols, label
def train_and_eval():
def train_and_eval(model_dir, model_type, train_steps, train_data, test_data):
"""Train and evaluate the model."""
train_file_name, test_file_name = maybe_download()
train_file_name, test_file_name = maybe_download(train_data, test_data)
df_train = pd.read_csv(
tf.gfile.Open(train_file_name),
names=COLUMNS,
@ -190,19 +178,56 @@ def train_and_eval():
df_test[LABEL_COLUMN] = (
df_test["income_bracket"].apply(lambda x: ">50K" in x)).astype(int)
model_dir = tempfile.mkdtemp() if not FLAGS.model_dir else FLAGS.model_dir
model_dir = tempfile.mkdtemp() if not model_dir else model_dir
print("model directory = %s" % model_dir)
m = build_estimator(model_dir)
m.fit(input_fn=lambda: input_fn(df_train), steps=FLAGS.train_steps)
m = build_estimator(model_dir, model_type)
m.fit(input_fn=lambda: input_fn(df_train), steps=train_steps)
results = m.evaluate(input_fn=lambda: input_fn(df_test), steps=1)
for key in sorted(results):
print("%s: %s" % (key, results[key]))
FLAGS = None
def main(_):
train_and_eval()
train_and_eval(FLAGS.model_dir, FLAGS.model_type, FLAGS.train_steps,
FLAGS.train_data, FLAGS.test_data)
if __name__ == "__main__":
tf.app.run()
parser = argparse.ArgumentParser()
parser.register("type", "bool", lambda v: v.lower() == "true")
parser.add_argument(
"--model_dir",
type=str,
default="",
help="Base directory for output models."
)
parser.add_argument(
"--model_type",
type=str,
default="wide_n_deep",
help="Valid model types: {'wide', 'deep', 'wide_n_deep'}."
)
parser.add_argument(
"--train_steps",
type=int,
default=200,
help="Number of training steps."
)
parser.add_argument(
"--train_data",
type=str,
default="",
help="Path to the training data."
)
parser.add_argument(
"--test_data",
type=str,
default="",
help="Path to the test data."
)
FLAGS, unparsed = parser.parse_known_args()
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)

View File

@ -17,27 +17,17 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import sys
import tempfile
from six.moves import urllib
import numpy as np
import tensorflow as tf
flags = tf.app.flags
FLAGS = flags.FLAGS
flags.DEFINE_string(
"train_data",
"",
"Path to the training data.")
flags.DEFINE_string(
"test_data",
"",
"Path to the test data.")
flags.DEFINE_string(
"predict_data",
"",
"Path to the prediction data.")
FLAGS = None
tf.logging.set_verbosity(tf.logging.INFO)
@ -45,31 +35,36 @@ tf.logging.set_verbosity(tf.logging.INFO)
LEARNING_RATE = 0.001
def maybe_download():
def maybe_download(train_data, test_data, predict_data):
"""Maybe downloads training data and returns train and test file names."""
if FLAGS.train_data:
train_file_name = FLAGS.train_data
if train_data:
train_file_name = train_data
else:
train_file = tempfile.NamedTemporaryFile(delete=False)
urllib.request.urlretrieve("http://download.tensorflow.org/data/abalone_train.csv", train_file.name) # pylint: disable=line-too-long
urllib.request.urlretrieve(
"http://download.tensorflow.org/data/abalone_train.csv",
train_file.name) # pylint: disable=line-too-long
train_file_name = train_file.name
train_file.close()
print("Training data is downloaded to %s" % train_file_name)
if FLAGS.test_data:
test_file_name = FLAGS.test_data
if test_data:
test_file_name = test_data
else:
test_file = tempfile.NamedTemporaryFile(delete=False)
urllib.request.urlretrieve("http://download.tensorflow.org/data/abalone_test.csv", test_file.name) # pylint: disable=line-too-long
urllib.request.urlretrieve(
"http://download.tensorflow.org/data/abalone_test.csv", test_file.name) # pylint: disable=line-too-long
test_file_name = test_file.name
test_file.close()
print("Test data is downloaded to %s" % test_file_name)
if FLAGS.predict_data:
predict_file_name = FLAGS.predict_data
if predict_data:
predict_file_name = predict_data
else:
predict_file = tempfile.NamedTemporaryFile(delete=False)
urllib.request.urlretrieve("http://download.tensorflow.org/data/abalone_predict.csv", predict_file.name) # pylint: disable=line-too-long
urllib.request.urlretrieve(
"http://download.tensorflow.org/data/abalone_predict.csv",
predict_file.name) # pylint: disable=line-too-long
predict_file_name = predict_file.name
predict_file.close()
print("Prediction data is downloaded to %s" % predict_file_name)
@ -109,32 +104,26 @@ def model_fn(features, targets, mode, params):
def main(unused_argv):
# Load datasets
abalone_train, abalone_test, abalone_predict = maybe_download()
abalone_train, abalone_test, abalone_predict = maybe_download(
FLAGS.train_data, FLAGS.test_data, FLAGS.predict_data)
# Training examples
training_set = tf.contrib.learn.datasets.base.load_csv_without_header(
filename=abalone_train,
target_dtype=np.int,
features_dtype=np.float64)
filename=abalone_train, target_dtype=np.int, features_dtype=np.float64)
# Test examples
test_set = tf.contrib.learn.datasets.base.load_csv_without_header(
filename=abalone_test,
target_dtype=np.int,
features_dtype=np.float64)
filename=abalone_test, target_dtype=np.int, features_dtype=np.float64)
# Set of 7 examples for which to predict abalone ages
prediction_set = tf.contrib.learn.datasets.base.load_csv_without_header(
filename=abalone_predict,
target_dtype=np.int,
features_dtype=np.float64)
filename=abalone_predict, target_dtype=np.int, features_dtype=np.float64)
# Set model params
model_params = {"learning_rate": LEARNING_RATE}
# Build 2 layer fully connected DNN with 10, 10 units respectively.
nn = tf.contrib.learn.Estimator(
model_fn=model_fn, params=model_params)
nn = tf.contrib.learn.Estimator(model_fn=model_fn, params=model_params)
# Fit
nn.fit(x=training_set.data, y=training_set.target, steps=5000)
@ -145,11 +134,22 @@ def main(unused_argv):
print("Loss: %s" % loss_score)
# Print out predictions
predictions = nn.predict(x=prediction_set.data,
as_iterable=True)
predictions = nn.predict(x=prediction_set.data, as_iterable=True)
for i, p in enumerate(predictions):
print("Prediction %s: %s" % (i + 1, p["ages"]))
if __name__ == "__main__":
tf.app.run()
parser = argparse.ArgumentParser()
parser.register("type", "bool", lambda v: v.lower() == "true")
parser.add_argument(
"--train_data", type=str, default="", help="Path to the training data.")
parser.add_argument(
"--test_data", type=str, default="", help="Path to the test data.")
parser.add_argument(
"--predict_data",
type=str,
default="",
help="Path to the prediction data.")
FLAGS, unparsed = parser.parse_known_args()
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)

View File

@ -182,19 +182,12 @@ implementing **between-graph replication** and **asynchronous training**. It
includes the code for the parameter server and worker tasks.
```python
import argparse
import sys
import tensorflow as tf
# Flags for defining the tf.train.ClusterSpec
tf.app.flags.DEFINE_string("ps_hosts", "",
"Comma-separated list of hostname:port pairs")
tf.app.flags.DEFINE_string("worker_hosts", "",
"Comma-separated list of hostname:port pairs")
# Flags for defining the tf.train.Server
tf.app.flags.DEFINE_string("job_name", "", "One of 'ps', 'worker'")
tf.app.flags.DEFINE_integer("task_index", 0, "Index of task within the job")
FLAGS = tf.app.flags.FLAGS
FLAGS = None
def main(_):
@ -253,7 +246,36 @@ def main(_):
sv.stop()
if __name__ == "__main__":
tf.app.run()
parser = argparse.ArgumentParser()
parser.register("type", "bool", lambda v: v.lower() == "true")
# Flags for defining the tf.train.ClusterSpec
parser.add_argument(
"--ps_hosts",
type=str,
default="",
help="Comma-separated list of hostname:port pairs"
)
parser.add_argument(
"--worker_hosts",
type=str,
default="",
help="Comma-separated list of hostname:port pairs"
)
parser.add_argument(
"--job_name",
type=str,
default="",
help="One of 'ps', 'worker'"
)
# Flags for defining the tf.train.Server
parser.add_argument(
"--task_index",
type=int,
default=0,
help="Index of task within the job"
)
FLAGS, unparsed = parser.parse_known_args()
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
```
To start the trainer with two parameter servers and two workers, use the

View File

@ -101,35 +101,20 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import sys
import tempfile
import urllib
import numpy as np
import tensorflow as tf
FLAGS = None
```
Then define flags to allow users to optionally specify CSV files for training,
test, and prediction datasets via the command line (by default, files will be
downloaded from [tensorflow.org](https://www.tensorflow.org/)), and enable
logging:
Enable logging:
```python
flags = tf.app.flags
FLAGS = flags.FLAGS
flags.DEFINE_string(
"train_data",
"",
"Path to the training data.")
flags.DEFINE_string(
"test_data",
"",
"Path to the test data.")
flags.DEFINE_string(
"predict_data",
"",
"Path to the prediction data.")
tf.logging.set_verbosity(tf.logging.INFO)
```
@ -138,10 +123,10 @@ command-line options, or downloaded from
[tensorflow.org](https://www.tensorflow.org/)):
```python
def maybe_download():
def maybe_download(train_data, test_data, predict_data):
"""Maybe downloads training data and returns train and test file names."""
if FLAGS.train_data:
train_file_name = FLAGS.train_data
if train_data:
train_file_name = train_data
else:
train_file = tempfile.NamedTemporaryFile(delete=False)
urllib.urlretrieve("http://download.tensorflow.org/data/abalone_train.csv", train_file.name)
@ -149,8 +134,8 @@ def maybe_download():
train_file.close()
print("Training data is downloaded to %s" % train_file_name)
if FLAGS.test_data:
test_file_name = FLAGS.test_data
if test_data:
test_file_name = test_data
else:
test_file = tempfile.NamedTemporaryFile(delete=False)
urllib.urlretrieve("http://download.tensorflow.org/data/abalone_test.csv", test_file.name)
@ -158,8 +143,8 @@ def maybe_download():
test_file.close()
print("Test data is downloaded to %s" % test_file_name)
if FLAGS.predict_data:
predict_file_name = FLAGS.predict_data
if predict_data:
predict_file_name = predict_data
else:
predict_file = tempfile.NamedTemporaryFile(delete=False)
urllib.urlretrieve("http://download.tensorflow.org/data/abalone_predict.csv", predict_file.name)
@ -170,12 +155,16 @@ def maybe_download():
return train_file_name, test_file_name, predict_file_name
```
Finally, create `main()` and load the abalone CSVs into `Datasets`:
Finally, create `main()` and load the abalone CSVs into `Datasets`,
defining flags to allow users to optionally specify CSV files for training,
test, and prediction datasets via the command line (by default, files will be
downloaded from [tensorflow.org](https://www.tensorflow.org/)):
```python
def main(unused_argv):
# Load datasets
abalone_train, abalone_test, abalone_predict = maybe_download()
abalone_train, abalone_test, abalone_predict = maybe_download(
FLAGS.train_data, FLAGS.test_data, FLAGS.predict_data)
# Training examples
training_set = tf.contrib.learn.datasets.base.load_csv_without_header(
@ -196,7 +185,28 @@ def main(unused_argv):
features_dtype=np.float64)
if __name__ == "__main__":
tf.app.run()
parser = argparse.ArgumentParser()
parser.register("type", "bool", lambda v: v.lower() == "true")
parser.add_argument(
"--train_data",
type=str,
default="",
help="Path to the training data."
)
parser.add_argument(
"--test_data",
type=str,
default="",
help="Path to the test data."
)
parser.add_argument(
"--predict_data",
type=str,
default="",
help="Path to the prediction data."
)
FLAGS, unparsed = parser.parse_known_args()
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
```
## Instantiating an Estimator

View File

@ -18,6 +18,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import sys
import time
from tensorflow.python.client import session as session_lib
@ -30,12 +32,8 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_impl
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import flags
from tensorflow.python.platform import test
FLAGS = flags.FLAGS
flags.DEFINE_boolean("use_gpu", True, """Run GPU benchmarks.""")
def batch_norm_op(tensor, mean, variance, beta, gamma, scale):
"""Fused kernel for batch normalization."""
@ -245,4 +243,16 @@ class BatchNormBenchmark(test.Benchmark):
if __name__ == "__main__":
test.main()
parser = argparse.ArgumentParser()
parser.register("type", "bool", lambda v: v.lower() == "true")
parser.add_argument(
"--use_gpu",
type="bool",
nargs="?",
const=True,
default=True,
help="Run GPU benchmarks."
)
global FLAGS # pylint:disable=global-at-module-level
FLAGS, unparsed = parser.parse_known_args()
test.main(argv=[sys.argv[0]] + unparsed)

View File

@ -29,12 +29,8 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import variables
from tensorflow.python.platform import flags
from tensorflow.python.platform import test
FLAGS = flags.FLAGS
flags.DEFINE_boolean("use_gpu", True, """Run GPU benchmarks.""")
def build_graph(device, input_shape, variable, num_inputs, axis, grad):
"""Build a graph containing a sequence of concat operations.

View File

@ -312,13 +312,15 @@ def _run_benchmarks(regex):
instance_benchmark_fn()
def benchmarks_main(true_main):
"""Run benchmarks as declared in args.
def benchmarks_main(true_main, argv=None):
"""Run benchmarks as declared in argv.
Args:
true_main: True main function to run if benchmarks are not requested.
argv: the command line arguments (if None, uses sys.argv).
"""
argv = sys.argv
if argv is None:
argv = sys.argv
found_arg = [arg for arg in argv
if arg.startswith("--benchmarks=")
or arg.startswith("-benchmarks=")]
@ -327,6 +329,6 @@ def benchmarks_main(true_main):
argv.remove(found_arg[0])
regex = found_arg[0].split("=")[1]
app.run(lambda _: _run_benchmarks(regex))
app.run(lambda _: _run_benchmarks(regex), argv=argv)
else:
true_main()

View File

@ -80,9 +80,12 @@ def g_main(argv):
# Redefine main to allow running benchmarks
def main(): # pylint: disable=function-redefined
def main(argv=None): # pylint: disable=function-redefined
def main_wrapper():
return app.run(main=g_main, argv=sys.argv)
args = argv
if args is None:
args = sys.argv
return app.run(main=g_main, argv=args)
benchmark.benchmarks_main(true_main=main_wrapper)

View File

@ -88,9 +88,9 @@ else:
Benchmark = _googletest.Benchmark # pylint: disable=invalid-name
def main():
def main(argv=None):
"""Runs all unit tests."""
return _googletest.main()
return _googletest.main(argv)
def get_temp_dir():