diff --git a/tensorflow/compiler/aot/tests/make_test_graphs.py b/tensorflow/compiler/aot/tests/make_test_graphs.py index 261dfcbdf8c..2a2d13dc498 100644 --- a/tensorflow/compiler/aot/tests/make_test_graphs.py +++ b/tensorflow/compiler/aot/tests/make_test_graphs.py @@ -18,6 +18,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import argparse +import sys + from tensorflow.core.protobuf import saver_pb2 from tensorflow.python.client import session from tensorflow.python.framework import constant_op @@ -27,22 +30,18 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import variables from tensorflow.python.platform import app -from tensorflow.python.platform import flags as flags_lib from tensorflow.python.training import saver as saver_lib -flags = flags_lib -FLAGS = flags.FLAGS -flags.DEFINE_string('out_dir', '', - 'Output directory for graphs, checkpoints and savers.') +FLAGS = None -def tfadd(): +def tfadd(_): x = constant_op.constant([1], name='x_const') y = constant_op.constant([2], name='y_const') math_ops.add(x, y, name='x_y_sum') -def tfadd_with_ckpt(): +def tfadd_with_ckpt(out_dir): x = array_ops.placeholder(dtypes.int32, name='x_hold') y = variables.Variable(constant_op.constant([0]), name='y_saved') math_ops.add(x, y, name='x_y_sum') @@ -53,11 +52,11 @@ def tfadd_with_ckpt(): sess.run(init_op) sess.run(y.assign(y + 42)) # Without the checkpoint, the variable won't be set to 42. - ckpt = '%s/test_graph_tfadd_with_ckpt.ckpt' % FLAGS.out_dir + ckpt = '%s/test_graph_tfadd_with_ckpt.ckpt' % out_dir saver.save(sess, ckpt) -def tfadd_with_ckpt_saver(): +def tfadd_with_ckpt_saver(out_dir): x = array_ops.placeholder(dtypes.int32, name='x_hold') y = variables.Variable(constant_op.constant([0]), name='y_saved') math_ops.add(x, y, name='x_y_sum') @@ -68,27 +67,27 @@ def tfadd_with_ckpt_saver(): sess.run(init_op) sess.run(y.assign(y + 42)) # Without the checkpoint, the variable won't be set to 42. - ckpt_file = '%s/test_graph_tfadd_with_ckpt_saver.ckpt' % FLAGS.out_dir + ckpt_file = '%s/test_graph_tfadd_with_ckpt_saver.ckpt' % out_dir saver.save(sess, ckpt_file) # Without the SaverDef, the restore op won't be named correctly. - saver_file = '%s/test_graph_tfadd_with_ckpt_saver.saver' % FLAGS.out_dir + saver_file = '%s/test_graph_tfadd_with_ckpt_saver.saver' % out_dir with open(saver_file, 'w') as f: f.write(saver.as_saver_def().SerializeToString()) -def tfgather(): +def tfgather(_): params = array_ops.placeholder(dtypes.float32, name='params') indices = array_ops.placeholder(dtypes.int32, name='indices') array_ops.gather(params, indices, name='gather_output') -def tfmatmul(): +def tfmatmul(_): x = array_ops.placeholder(dtypes.float32, name='x_hold') y = array_ops.placeholder(dtypes.float32, name='y_hold') math_ops.matmul(x, y, name='x_y_prod') -def tfmatmulandadd(): +def tfmatmulandadd(_): # This tests multiple outputs. x = array_ops.placeholder(dtypes.float32, name='x_hold') y = array_ops.placeholder(dtypes.float32, name='y_hold') @@ -96,24 +95,33 @@ def tfmatmulandadd(): math_ops.add(x, y, name='x_y_sum') -def write_graph(build_graph): +def write_graph(build_graph, out_dir): """Build a graph using build_graph and write it out.""" g = ops.Graph() with g.as_default(): - build_graph() - filename = '%s/test_graph_%s.pb' % (FLAGS.out_dir, build_graph.__name__) + build_graph(out_dir) + filename = '%s/test_graph_%s.pb' % (out_dir, build_graph.__name__) with open(filename, 'w') as f: f.write(g.as_graph_def().SerializeToString()) def main(_): - write_graph(tfadd) - write_graph(tfadd_with_ckpt) - write_graph(tfadd_with_ckpt_saver) - write_graph(tfgather) - write_graph(tfmatmul) - write_graph(tfmatmulandadd) + write_graph(tfadd, FLAGS.out_dir) + write_graph(tfadd_with_ckpt, FLAGS.out_dir) + write_graph(tfadd_with_ckpt_saver, FLAGS.out_dir) + write_graph(tfgather, FLAGS.out_dir) + write_graph(tfmatmul, FLAGS.out_dir) + write_graph(tfmatmulandadd, FLAGS.out_dir) if __name__ == '__main__': - app.run() + parser = argparse.ArgumentParser() + parser.register('type', 'bool', lambda v: v.lower() == 'true') + parser.add_argument( + '--out_dir', + type=str, + default='', + help='Output directory for graphs, checkpoints and savers.' + ) + FLAGS, unparsed = parser.parse_known_args() + app.run(main=main, argv=[sys.argv[0]] + unparsed) diff --git a/tensorflow/compiler/tests/lstm_test.py b/tensorflow/compiler/tests/lstm_test.py index 9ffeb6c2a2f..31093c65713 100644 --- a/tensorflow/compiler/tests/lstm_test.py +++ b/tensorflow/compiler/tests/lstm_test.py @@ -18,7 +18,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import argparse import os +import sys import numpy as np @@ -32,29 +34,8 @@ from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import variables -from tensorflow.python.platform import flags as flags_lib from tensorflow.python.platform import test -flags = flags_lib -FLAGS = flags.FLAGS - -flags.DEFINE_integer('batch_size', 128, - 'Inputs are fed in batches of this size, for both ' - 'inference and training. Larger values cause the matmul ' - 'in each LSTM cell to have higher dimensionality.') -flags.DEFINE_integer('seq_length', 60, - 'Length of the unrolled sequence of LSTM cells in a layer.' - 'Larger values cause more LSTM matmuls to be run.') -flags.DEFINE_integer('num_inputs', 1024, - 'Dimension of inputs that are fed into each LSTM cell.') -flags.DEFINE_integer('num_nodes', 1024, 'Number of nodes in each LSTM cell.') -flags.DEFINE_string('device', 'gpu', - 'TensorFlow device to assign ops to, e.g. "gpu", "cpu". ' - 'For details see documentation for tf.Graph.device.') - -flags.DEFINE_string('dump_graph_dir', '', 'If non-empty, dump graphs in ' - '*.pbtxt format to this directory.') - def _DumpGraph(graph, basename): if FLAGS.dump_graph_dir: @@ -290,4 +271,54 @@ class LSTMBenchmark(test.Benchmark): if __name__ == '__main__': - test.main() + parser = argparse.ArgumentParser() + parser.register('type', 'bool', lambda v: v.lower() == 'true') + parser.add_argument( + '--batch_size', + type=int, + default=128, + help="""\ + Inputs are fed in batches of this size, for both inference and training. + Larger values cause the matmul in each LSTM cell to have higher + dimensionality.\ + """ + ) + parser.add_argument( + '--seq_length', + type=int, + default=60, + help="""\ + Length of the unrolled sequence of LSTM cells in a layer.Larger values + cause more LSTM matmuls to be run.\ + """ + ) + parser.add_argument( + '--num_inputs', + type=int, + default=1024, + help='Dimension of inputs that are fed into each LSTM cell.' + ) + parser.add_argument( + '--num_nodes', + type=int, + default=1024, + help='Number of nodes in each LSTM cell.' + ) + parser.add_argument( + '--device', + type=str, + default='gpu', + help="""\ + TensorFlow device to assign ops to, e.g. "gpu", "cpu". For details see + documentation for tf.Graph.device.\ + """ + ) + parser.add_argument( + '--dump_graph_dir', + type=str, + default='', + help='If non-empty, dump graphs in *.pbtxt format to this directory.' + ) + global FLAGS # pylint:disable=global-at-module-level + FLAGS, unparsed = parser.parse_known_args() + test.main(argv=[sys.argv[0]] + unparsed) diff --git a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_benchmark.py b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_benchmark.py index 8d5ff341acd..24b726ac098 100644 --- a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_benchmark.py +++ b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_benchmark.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function import time + from tensorflow.contrib.cudnn_rnn.python.ops import cudnn_rnn_ops from tensorflow.contrib.rnn.python.ops import core_rnn from tensorflow.contrib.rnn.python.ops import core_rnn_cell_impl @@ -31,12 +32,8 @@ from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import init_ops from tensorflow.python.ops import variables -from tensorflow.python.platform import flags from tensorflow.python.platform import test -flags.DEFINE_integer("batch_size", 64, "batch size.") -FLAGS = flags.FLAGS - class CudnnRNNBenchmark(test.Benchmark): """Benchmarks Cudnn LSTM and other related models. diff --git a/tensorflow/examples/learn/wide_n_deep_tutorial.py b/tensorflow/examples/learn/wide_n_deep_tutorial.py index 760b26bd529..09af375314e 100644 --- a/tensorflow/examples/learn/wide_n_deep_tutorial.py +++ b/tensorflow/examples/learn/wide_n_deep_tutorial.py @@ -17,27 +17,15 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import argparse +import sys import tempfile + from six.moves import urllib import pandas as pd import tensorflow as tf -flags = tf.app.flags -FLAGS = flags.FLAGS - -flags.DEFINE_string("model_dir", "", "Base directory for output models.") -flags.DEFINE_string("model_type", "wide_n_deep", - "Valid model types: {'wide', 'deep', 'wide_n_deep'}.") -flags.DEFINE_integer("train_steps", 200, "Number of training steps.") -flags.DEFINE_string( - "train_data", - "", - "Path to the training data.") -flags.DEFINE_string( - "test_data", - "", - "Path to the test data.") COLUMNS = ["age", "workclass", "fnlwgt", "education", "education_num", "marital_status", "occupation", "relationship", "race", "gender", @@ -50,10 +38,10 @@ CONTINUOUS_COLUMNS = ["age", "education_num", "capital_gain", "capital_loss", "hours_per_week"] -def maybe_download(): +def maybe_download(train_data, test_data): """Maybe downloads training data and returns train and test file names.""" - if FLAGS.train_data: - train_file_name = FLAGS.train_data + if train_data: + train_file_name = train_data else: train_file = tempfile.NamedTemporaryFile(delete=False) urllib.request.urlretrieve("http://mlr.cs.umass.edu/ml/machine-learning-databases/adult/adult.data", train_file.name) # pylint: disable=line-too-long @@ -61,8 +49,8 @@ def maybe_download(): train_file.close() print("Training data is downloaded to %s" % train_file_name) - if FLAGS.test_data: - test_file_name = FLAGS.test_data + if test_data: + test_file_name = test_data else: test_file = tempfile.NamedTemporaryFile(delete=False) urllib.request.urlretrieve("http://mlr.cs.umass.edu/ml/machine-learning-databases/adult/adult.test", test_file.name) # pylint: disable=line-too-long @@ -73,7 +61,7 @@ def maybe_download(): return train_file_name, test_file_name -def build_estimator(model_dir): +def build_estimator(model_dir, model_type): """Build an estimator.""" # Sparse base columns. gender = tf.contrib.layers.sparse_column_with_keys(column_name="gender", @@ -128,10 +116,10 @@ def build_estimator(model_dir): hours_per_week, ] - if FLAGS.model_type == "wide": + if model_type == "wide": m = tf.contrib.learn.LinearClassifier(model_dir=model_dir, feature_columns=wide_columns) - elif FLAGS.model_type == "deep": + elif model_type == "deep": m = tf.contrib.learn.DNNClassifier(model_dir=model_dir, feature_columns=deep_columns, hidden_units=[100, 50]) @@ -166,9 +154,9 @@ def input_fn(df): return feature_cols, label -def train_and_eval(): +def train_and_eval(model_dir, model_type, train_steps, train_data, test_data): """Train and evaluate the model.""" - train_file_name, test_file_name = maybe_download() + train_file_name, test_file_name = maybe_download(train_data, test_data) df_train = pd.read_csv( tf.gfile.Open(train_file_name), names=COLUMNS, @@ -190,19 +178,56 @@ def train_and_eval(): df_test[LABEL_COLUMN] = ( df_test["income_bracket"].apply(lambda x: ">50K" in x)).astype(int) - model_dir = tempfile.mkdtemp() if not FLAGS.model_dir else FLAGS.model_dir + model_dir = tempfile.mkdtemp() if not model_dir else model_dir print("model directory = %s" % model_dir) - m = build_estimator(model_dir) - m.fit(input_fn=lambda: input_fn(df_train), steps=FLAGS.train_steps) + m = build_estimator(model_dir, model_type) + m.fit(input_fn=lambda: input_fn(df_train), steps=train_steps) results = m.evaluate(input_fn=lambda: input_fn(df_test), steps=1) for key in sorted(results): print("%s: %s" % (key, results[key])) +FLAGS = None + + def main(_): - train_and_eval() + train_and_eval(FLAGS.model_dir, FLAGS.model_type, FLAGS.train_steps, + FLAGS.train_data, FLAGS.test_data) if __name__ == "__main__": - tf.app.run() + parser = argparse.ArgumentParser() + parser.register("type", "bool", lambda v: v.lower() == "true") + parser.add_argument( + "--model_dir", + type=str, + default="", + help="Base directory for output models." + ) + parser.add_argument( + "--model_type", + type=str, + default="wide_n_deep", + help="Valid model types: {'wide', 'deep', 'wide_n_deep'}." + ) + parser.add_argument( + "--train_steps", + type=int, + default=200, + help="Number of training steps." + ) + parser.add_argument( + "--train_data", + type=str, + default="", + help="Path to the training data." + ) + parser.add_argument( + "--test_data", + type=str, + default="", + help="Path to the test data." + ) + FLAGS, unparsed = parser.parse_known_args() + tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) diff --git a/tensorflow/examples/tutorials/estimators/abalone.py b/tensorflow/examples/tutorials/estimators/abalone.py index 6d8ce2cbc7e..9faaa3e6305 100644 --- a/tensorflow/examples/tutorials/estimators/abalone.py +++ b/tensorflow/examples/tutorials/estimators/abalone.py @@ -17,27 +17,17 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import argparse +import sys import tempfile + from six.moves import urllib import numpy as np import tensorflow as tf -flags = tf.app.flags -FLAGS = flags.FLAGS -flags.DEFINE_string( - "train_data", - "", - "Path to the training data.") -flags.DEFINE_string( - "test_data", - "", - "Path to the test data.") -flags.DEFINE_string( - "predict_data", - "", - "Path to the prediction data.") +FLAGS = None tf.logging.set_verbosity(tf.logging.INFO) @@ -45,31 +35,36 @@ tf.logging.set_verbosity(tf.logging.INFO) LEARNING_RATE = 0.001 -def maybe_download(): +def maybe_download(train_data, test_data, predict_data): """Maybe downloads training data and returns train and test file names.""" - if FLAGS.train_data: - train_file_name = FLAGS.train_data + if train_data: + train_file_name = train_data else: train_file = tempfile.NamedTemporaryFile(delete=False) - urllib.request.urlretrieve("http://download.tensorflow.org/data/abalone_train.csv", train_file.name) # pylint: disable=line-too-long + urllib.request.urlretrieve( + "http://download.tensorflow.org/data/abalone_train.csv", + train_file.name) # pylint: disable=line-too-long train_file_name = train_file.name train_file.close() print("Training data is downloaded to %s" % train_file_name) - if FLAGS.test_data: - test_file_name = FLAGS.test_data + if test_data: + test_file_name = test_data else: test_file = tempfile.NamedTemporaryFile(delete=False) - urllib.request.urlretrieve("http://download.tensorflow.org/data/abalone_test.csv", test_file.name) # pylint: disable=line-too-long + urllib.request.urlretrieve( + "http://download.tensorflow.org/data/abalone_test.csv", test_file.name) # pylint: disable=line-too-long test_file_name = test_file.name test_file.close() print("Test data is downloaded to %s" % test_file_name) - if FLAGS.predict_data: - predict_file_name = FLAGS.predict_data + if predict_data: + predict_file_name = predict_data else: predict_file = tempfile.NamedTemporaryFile(delete=False) - urllib.request.urlretrieve("http://download.tensorflow.org/data/abalone_predict.csv", predict_file.name) # pylint: disable=line-too-long + urllib.request.urlretrieve( + "http://download.tensorflow.org/data/abalone_predict.csv", + predict_file.name) # pylint: disable=line-too-long predict_file_name = predict_file.name predict_file.close() print("Prediction data is downloaded to %s" % predict_file_name) @@ -109,32 +104,26 @@ def model_fn(features, targets, mode, params): def main(unused_argv): # Load datasets - abalone_train, abalone_test, abalone_predict = maybe_download() + abalone_train, abalone_test, abalone_predict = maybe_download( + FLAGS.train_data, FLAGS.test_data, FLAGS.predict_data) # Training examples training_set = tf.contrib.learn.datasets.base.load_csv_without_header( - filename=abalone_train, - target_dtype=np.int, - features_dtype=np.float64) + filename=abalone_train, target_dtype=np.int, features_dtype=np.float64) # Test examples test_set = tf.contrib.learn.datasets.base.load_csv_without_header( - filename=abalone_test, - target_dtype=np.int, - features_dtype=np.float64) + filename=abalone_test, target_dtype=np.int, features_dtype=np.float64) # Set of 7 examples for which to predict abalone ages prediction_set = tf.contrib.learn.datasets.base.load_csv_without_header( - filename=abalone_predict, - target_dtype=np.int, - features_dtype=np.float64) + filename=abalone_predict, target_dtype=np.int, features_dtype=np.float64) # Set model params model_params = {"learning_rate": LEARNING_RATE} # Build 2 layer fully connected DNN with 10, 10 units respectively. - nn = tf.contrib.learn.Estimator( - model_fn=model_fn, params=model_params) + nn = tf.contrib.learn.Estimator(model_fn=model_fn, params=model_params) # Fit nn.fit(x=training_set.data, y=training_set.target, steps=5000) @@ -145,11 +134,22 @@ def main(unused_argv): print("Loss: %s" % loss_score) # Print out predictions - predictions = nn.predict(x=prediction_set.data, - as_iterable=True) + predictions = nn.predict(x=prediction_set.data, as_iterable=True) for i, p in enumerate(predictions): print("Prediction %s: %s" % (i + 1, p["ages"])) if __name__ == "__main__": - tf.app.run() + parser = argparse.ArgumentParser() + parser.register("type", "bool", lambda v: v.lower() == "true") + parser.add_argument( + "--train_data", type=str, default="", help="Path to the training data.") + parser.add_argument( + "--test_data", type=str, default="", help="Path to the test data.") + parser.add_argument( + "--predict_data", + type=str, + default="", + help="Path to the prediction data.") + FLAGS, unparsed = parser.parse_known_args() + tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) diff --git a/tensorflow/g3doc/how_tos/distributed/index.md b/tensorflow/g3doc/how_tos/distributed/index.md index 859bd3b7aa5..bce6af6f803 100644 --- a/tensorflow/g3doc/how_tos/distributed/index.md +++ b/tensorflow/g3doc/how_tos/distributed/index.md @@ -182,19 +182,12 @@ implementing **between-graph replication** and **asynchronous training**. It includes the code for the parameter server and worker tasks. ```python +import argparse +import sys + import tensorflow as tf -# Flags for defining the tf.train.ClusterSpec -tf.app.flags.DEFINE_string("ps_hosts", "", - "Comma-separated list of hostname:port pairs") -tf.app.flags.DEFINE_string("worker_hosts", "", - "Comma-separated list of hostname:port pairs") - -# Flags for defining the tf.train.Server -tf.app.flags.DEFINE_string("job_name", "", "One of 'ps', 'worker'") -tf.app.flags.DEFINE_integer("task_index", 0, "Index of task within the job") - -FLAGS = tf.app.flags.FLAGS +FLAGS = None def main(_): @@ -253,7 +246,36 @@ def main(_): sv.stop() if __name__ == "__main__": - tf.app.run() + parser = argparse.ArgumentParser() + parser.register("type", "bool", lambda v: v.lower() == "true") + # Flags for defining the tf.train.ClusterSpec + parser.add_argument( + "--ps_hosts", + type=str, + default="", + help="Comma-separated list of hostname:port pairs" + ) + parser.add_argument( + "--worker_hosts", + type=str, + default="", + help="Comma-separated list of hostname:port pairs" + ) + parser.add_argument( + "--job_name", + type=str, + default="", + help="One of 'ps', 'worker'" + ) + # Flags for defining the tf.train.Server + parser.add_argument( + "--task_index", + type=int, + default=0, + help="Index of task within the job" + ) + FLAGS, unparsed = parser.parse_known_args() + tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) ``` To start the trainer with two parameter servers and two workers, use the diff --git a/tensorflow/g3doc/tutorials/estimators/index.md b/tensorflow/g3doc/tutorials/estimators/index.md index 0c7e12f51d7..6dede9eccc9 100644 --- a/tensorflow/g3doc/tutorials/estimators/index.md +++ b/tensorflow/g3doc/tutorials/estimators/index.md @@ -101,35 +101,20 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import argparse +import sys import tempfile import urllib import numpy as np import tensorflow as tf + +FLAGS = None ``` -Then define flags to allow users to optionally specify CSV files for training, -test, and prediction datasets via the command line (by default, files will be -downloaded from [tensorflow.org](https://www.tensorflow.org/)), and enable -logging: +Enable logging: ```python -flags = tf.app.flags -FLAGS = flags.FLAGS - -flags.DEFINE_string( - "train_data", - "", - "Path to the training data.") -flags.DEFINE_string( - "test_data", - "", - "Path to the test data.") -flags.DEFINE_string( - "predict_data", - "", - "Path to the prediction data.") - tf.logging.set_verbosity(tf.logging.INFO) ``` @@ -138,10 +123,10 @@ command-line options, or downloaded from [tensorflow.org](https://www.tensorflow.org/)): ```python -def maybe_download(): +def maybe_download(train_data, test_data, predict_data): """Maybe downloads training data and returns train and test file names.""" - if FLAGS.train_data: - train_file_name = FLAGS.train_data + if train_data: + train_file_name = train_data else: train_file = tempfile.NamedTemporaryFile(delete=False) urllib.urlretrieve("http://download.tensorflow.org/data/abalone_train.csv", train_file.name) @@ -149,8 +134,8 @@ def maybe_download(): train_file.close() print("Training data is downloaded to %s" % train_file_name) - if FLAGS.test_data: - test_file_name = FLAGS.test_data + if test_data: + test_file_name = test_data else: test_file = tempfile.NamedTemporaryFile(delete=False) urllib.urlretrieve("http://download.tensorflow.org/data/abalone_test.csv", test_file.name) @@ -158,8 +143,8 @@ def maybe_download(): test_file.close() print("Test data is downloaded to %s" % test_file_name) - if FLAGS.predict_data: - predict_file_name = FLAGS.predict_data + if predict_data: + predict_file_name = predict_data else: predict_file = tempfile.NamedTemporaryFile(delete=False) urllib.urlretrieve("http://download.tensorflow.org/data/abalone_predict.csv", predict_file.name) @@ -170,12 +155,16 @@ def maybe_download(): return train_file_name, test_file_name, predict_file_name ``` -Finally, create `main()` and load the abalone CSVs into `Datasets`: +Finally, create `main()` and load the abalone CSVs into `Datasets`, +defining flags to allow users to optionally specify CSV files for training, +test, and prediction datasets via the command line (by default, files will be +downloaded from [tensorflow.org](https://www.tensorflow.org/)): ```python def main(unused_argv): # Load datasets - abalone_train, abalone_test, abalone_predict = maybe_download() + abalone_train, abalone_test, abalone_predict = maybe_download( + FLAGS.train_data, FLAGS.test_data, FLAGS.predict_data) # Training examples training_set = tf.contrib.learn.datasets.base.load_csv_without_header( @@ -196,7 +185,28 @@ def main(unused_argv): features_dtype=np.float64) if __name__ == "__main__": - tf.app.run() + parser = argparse.ArgumentParser() + parser.register("type", "bool", lambda v: v.lower() == "true") + parser.add_argument( + "--train_data", + type=str, + default="", + help="Path to the training data." + ) + parser.add_argument( + "--test_data", + type=str, + default="", + help="Path to the test data." + ) + parser.add_argument( + "--predict_data", + type=str, + default="", + help="Path to the prediction data." + ) + FLAGS, unparsed = parser.parse_known_args() + tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) ``` ## Instantiating an Estimator diff --git a/tensorflow/python/ops/batch_norm_benchmark.py b/tensorflow/python/ops/batch_norm_benchmark.py index a4940e09ccb..93a96c90c38 100644 --- a/tensorflow/python/ops/batch_norm_benchmark.py +++ b/tensorflow/python/ops/batch_norm_benchmark.py @@ -18,6 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import argparse +import sys import time from tensorflow.python.client import session as session_lib @@ -30,12 +32,8 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_impl from tensorflow.python.ops import random_ops from tensorflow.python.ops import variables -from tensorflow.python.platform import flags from tensorflow.python.platform import test -FLAGS = flags.FLAGS -flags.DEFINE_boolean("use_gpu", True, """Run GPU benchmarks.""") - def batch_norm_op(tensor, mean, variance, beta, gamma, scale): """Fused kernel for batch normalization.""" @@ -245,4 +243,16 @@ class BatchNormBenchmark(test.Benchmark): if __name__ == "__main__": - test.main() + parser = argparse.ArgumentParser() + parser.register("type", "bool", lambda v: v.lower() == "true") + parser.add_argument( + "--use_gpu", + type="bool", + nargs="?", + const=True, + default=True, + help="Run GPU benchmarks." + ) + global FLAGS # pylint:disable=global-at-module-level + FLAGS, unparsed = parser.parse_known_args() + test.main(argv=[sys.argv[0]] + unparsed) diff --git a/tensorflow/python/ops/concat_benchmark.py b/tensorflow/python/ops/concat_benchmark.py index 094f8bb2dc8..1ce48b511fc 100644 --- a/tensorflow/python/ops/concat_benchmark.py +++ b/tensorflow/python/ops/concat_benchmark.py @@ -29,12 +29,8 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import variables -from tensorflow.python.platform import flags from tensorflow.python.platform import test -FLAGS = flags.FLAGS -flags.DEFINE_boolean("use_gpu", True, """Run GPU benchmarks.""") - def build_graph(device, input_shape, variable, num_inputs, axis, grad): """Build a graph containing a sequence of concat operations. diff --git a/tensorflow/python/platform/benchmark.py b/tensorflow/python/platform/benchmark.py index 1fa0165d87d..d91c19eeb46 100644 --- a/tensorflow/python/platform/benchmark.py +++ b/tensorflow/python/platform/benchmark.py @@ -312,13 +312,15 @@ def _run_benchmarks(regex): instance_benchmark_fn() -def benchmarks_main(true_main): - """Run benchmarks as declared in args. +def benchmarks_main(true_main, argv=None): + """Run benchmarks as declared in argv. Args: true_main: True main function to run if benchmarks are not requested. + argv: the command line arguments (if None, uses sys.argv). """ - argv = sys.argv + if argv is None: + argv = sys.argv found_arg = [arg for arg in argv if arg.startswith("--benchmarks=") or arg.startswith("-benchmarks=")] @@ -327,6 +329,6 @@ def benchmarks_main(true_main): argv.remove(found_arg[0]) regex = found_arg[0].split("=")[1] - app.run(lambda _: _run_benchmarks(regex)) + app.run(lambda _: _run_benchmarks(regex), argv=argv) else: true_main() diff --git a/tensorflow/python/platform/googletest.py b/tensorflow/python/platform/googletest.py index d88a056e786..be5c833d6da 100644 --- a/tensorflow/python/platform/googletest.py +++ b/tensorflow/python/platform/googletest.py @@ -80,9 +80,12 @@ def g_main(argv): # Redefine main to allow running benchmarks -def main(): # pylint: disable=function-redefined +def main(argv=None): # pylint: disable=function-redefined def main_wrapper(): - return app.run(main=g_main, argv=sys.argv) + args = argv + if args is None: + args = sys.argv + return app.run(main=g_main, argv=args) benchmark.benchmarks_main(true_main=main_wrapper) diff --git a/tensorflow/python/platform/test.py b/tensorflow/python/platform/test.py index 0563b370ea0..3f2c1d97b48 100644 --- a/tensorflow/python/platform/test.py +++ b/tensorflow/python/platform/test.py @@ -88,9 +88,9 @@ else: Benchmark = _googletest.Benchmark # pylint: disable=invalid-name -def main(): +def main(argv=None): """Runs all unit tests.""" - return _googletest.main() + return _googletest.main(argv) def get_temp_dir():