More conversions of the flags library to argparse.
Add an argv parameter to the benchmark and main functions so they can handle command-line arguments passed through from argparse. Change: 144254260
This commit is contained in:
parent 37af1b8790
commit 963674de71
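Every file in this change follows the same conversion pattern: module-level `flags.DEFINE_*` calls become `argparse` arguments built inside the `if __name__ == '__main__'` block, `FLAGS` becomes a module-level `None` that `parse_known_args()` fills in, and anything argparse does not recognize is forwarded to the original entry point. A minimal sketch of that pattern, assuming a single `--out_dir` flag (taken from the first file below); `my_main` is a hypothetical stand-in for whichever main function a given script defines:

```python
import argparse
import sys

FLAGS = None  # Populated by parse_known_args() below.


def my_main(argv):
  # Hypothetical main; real scripts build graphs, run benchmarks, etc.
  print('out_dir=%s, leftover argv=%s' % (FLAGS.out_dir, argv))


if __name__ == '__main__':
  parser = argparse.ArgumentParser()
  # Teach argparse the --flag=true/--flag=false spellings that the old
  # flags library supported.
  parser.register('type', 'bool', lambda v: v.lower() == 'true')
  parser.add_argument('--out_dir', type=str, default='',
                      help='Output directory for graphs, checkpoints and savers.')
  # Recognized flags land in FLAGS; everything else is passed through.
  FLAGS, unparsed = parser.parse_known_args()
  my_main([sys.argv[0]] + unparsed)
```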
@@ -18,6 +18,9 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import argparse
+import sys
+
 from tensorflow.core.protobuf import saver_pb2
 from tensorflow.python.client import session
 from tensorflow.python.framework import constant_op
@@ -27,22 +30,18 @@ from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import variables
 from tensorflow.python.platform import app
-from tensorflow.python.platform import flags as flags_lib
 from tensorflow.python.training import saver as saver_lib
 
-flags = flags_lib
-FLAGS = flags.FLAGS
-flags.DEFINE_string('out_dir', '',
-                    'Output directory for graphs, checkpoints and savers.')
+FLAGS = None
 
 
-def tfadd():
+def tfadd(_):
   x = constant_op.constant([1], name='x_const')
   y = constant_op.constant([2], name='y_const')
   math_ops.add(x, y, name='x_y_sum')
 
 
-def tfadd_with_ckpt():
+def tfadd_with_ckpt(out_dir):
   x = array_ops.placeholder(dtypes.int32, name='x_hold')
   y = variables.Variable(constant_op.constant([0]), name='y_saved')
   math_ops.add(x, y, name='x_y_sum')
@@ -53,11 +52,11 @@ def tfadd_with_ckpt():
     sess.run(init_op)
     sess.run(y.assign(y + 42))
     # Without the checkpoint, the variable won't be set to 42.
-    ckpt = '%s/test_graph_tfadd_with_ckpt.ckpt' % FLAGS.out_dir
+    ckpt = '%s/test_graph_tfadd_with_ckpt.ckpt' % out_dir
     saver.save(sess, ckpt)
 
 
-def tfadd_with_ckpt_saver():
+def tfadd_with_ckpt_saver(out_dir):
   x = array_ops.placeholder(dtypes.int32, name='x_hold')
   y = variables.Variable(constant_op.constant([0]), name='y_saved')
   math_ops.add(x, y, name='x_y_sum')
@@ -68,27 +67,27 @@ def tfadd_with_ckpt_saver():
     sess.run(init_op)
     sess.run(y.assign(y + 42))
     # Without the checkpoint, the variable won't be set to 42.
-    ckpt_file = '%s/test_graph_tfadd_with_ckpt_saver.ckpt' % FLAGS.out_dir
+    ckpt_file = '%s/test_graph_tfadd_with_ckpt_saver.ckpt' % out_dir
     saver.save(sess, ckpt_file)
     # Without the SaverDef, the restore op won't be named correctly.
-    saver_file = '%s/test_graph_tfadd_with_ckpt_saver.saver' % FLAGS.out_dir
+    saver_file = '%s/test_graph_tfadd_with_ckpt_saver.saver' % out_dir
     with open(saver_file, 'w') as f:
       f.write(saver.as_saver_def().SerializeToString())
 
 
-def tfgather():
+def tfgather(_):
   params = array_ops.placeholder(dtypes.float32, name='params')
   indices = array_ops.placeholder(dtypes.int32, name='indices')
   array_ops.gather(params, indices, name='gather_output')
 
 
-def tfmatmul():
+def tfmatmul(_):
   x = array_ops.placeholder(dtypes.float32, name='x_hold')
   y = array_ops.placeholder(dtypes.float32, name='y_hold')
   math_ops.matmul(x, y, name='x_y_prod')
 
 
-def tfmatmulandadd():
+def tfmatmulandadd(_):
   # This tests multiple outputs.
   x = array_ops.placeholder(dtypes.float32, name='x_hold')
   y = array_ops.placeholder(dtypes.float32, name='y_hold')
@@ -96,24 +95,33 @@ def tfmatmulandadd():
   math_ops.add(x, y, name='x_y_sum')
 
 
-def write_graph(build_graph):
+def write_graph(build_graph, out_dir):
   """Build a graph using build_graph and write it out."""
   g = ops.Graph()
   with g.as_default():
-    build_graph()
-  filename = '%s/test_graph_%s.pb' % (FLAGS.out_dir, build_graph.__name__)
+    build_graph(out_dir)
+  filename = '%s/test_graph_%s.pb' % (out_dir, build_graph.__name__)
   with open(filename, 'w') as f:
     f.write(g.as_graph_def().SerializeToString())
 
 
 def main(_):
-  write_graph(tfadd)
-  write_graph(tfadd_with_ckpt)
-  write_graph(tfadd_with_ckpt_saver)
-  write_graph(tfgather)
-  write_graph(tfmatmul)
-  write_graph(tfmatmulandadd)
+  write_graph(tfadd, FLAGS.out_dir)
+  write_graph(tfadd_with_ckpt, FLAGS.out_dir)
+  write_graph(tfadd_with_ckpt_saver, FLAGS.out_dir)
+  write_graph(tfgather, FLAGS.out_dir)
+  write_graph(tfmatmul, FLAGS.out_dir)
+  write_graph(tfmatmulandadd, FLAGS.out_dir)
 
 
 if __name__ == '__main__':
-  app.run()
+  parser = argparse.ArgumentParser()
+  parser.register('type', 'bool', lambda v: v.lower() == 'true')
+  parser.add_argument(
+      '--out_dir',
+      type=str,
+      default='',
+      help='Output directory for graphs, checkpoints and savers.'
+  )
+  FLAGS, unparsed = parser.parse_known_args()
+  app.run(main=main, argv=[sys.argv[0]] + unparsed)
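Beyond the flag parsing itself, the graph-builder functions above now take `out_dir` as a parameter instead of reading the global `FLAGS`, so they can be driven directly from a test or a notebook. A hedged sketch of that benefit; the module name `make_test_graphs` is an assumption, since this excerpt does not show the file path:

```python
import tempfile

# Hypothetical import; the diff does not show the module's actual path.
import make_test_graphs

# Each builder now receives its output directory explicitly, so no global
# FLAGS object has to be configured before calling it.
out_dir = tempfile.mkdtemp()
make_test_graphs.write_graph(make_test_graphs.tfadd, out_dir)
```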
@@ -18,7 +18,9 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import argparse
 import os
+import sys
 
 import numpy as np
 
@@ -32,29 +34,8 @@ from tensorflow.python.ops import gradients_impl
 from tensorflow.python.ops import init_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import variables
-from tensorflow.python.platform import flags as flags_lib
 from tensorflow.python.platform import test
 
-flags = flags_lib
-FLAGS = flags.FLAGS
-
-flags.DEFINE_integer('batch_size', 128,
-                     'Inputs are fed in batches of this size, for both '
-                     'inference and training. Larger values cause the matmul '
-                     'in each LSTM cell to have higher dimensionality.')
-flags.DEFINE_integer('seq_length', 60,
-                     'Length of the unrolled sequence of LSTM cells in a layer.'
-                     'Larger values cause more LSTM matmuls to be run.')
-flags.DEFINE_integer('num_inputs', 1024,
-                     'Dimension of inputs that are fed into each LSTM cell.')
-flags.DEFINE_integer('num_nodes', 1024, 'Number of nodes in each LSTM cell.')
-flags.DEFINE_string('device', 'gpu',
-                    'TensorFlow device to assign ops to, e.g. "gpu", "cpu". '
-                    'For details see documentation for tf.Graph.device.')
-
-flags.DEFINE_string('dump_graph_dir', '', 'If non-empty, dump graphs in '
-                    '*.pbtxt format to this directory.')
 
 
 def _DumpGraph(graph, basename):
   if FLAGS.dump_graph_dir:
@@ -290,4 +271,54 @@ class LSTMBenchmark(test.Benchmark):
 
 
 if __name__ == '__main__':
-  test.main()
+  parser = argparse.ArgumentParser()
+  parser.register('type', 'bool', lambda v: v.lower() == 'true')
+  parser.add_argument(
+      '--batch_size',
+      type=int,
+      default=128,
+      help="""\
+      Inputs are fed in batches of this size, for both inference and training.
+      Larger values cause the matmul in each LSTM cell to have higher
+      dimensionality.\
+      """
+  )
+  parser.add_argument(
+      '--seq_length',
+      type=int,
+      default=60,
+      help="""\
+      Length of the unrolled sequence of LSTM cells in a layer.Larger values
+      cause more LSTM matmuls to be run.\
+      """
+  )
+  parser.add_argument(
+      '--num_inputs',
+      type=int,
+      default=1024,
+      help='Dimension of inputs that are fed into each LSTM cell.'
+  )
+  parser.add_argument(
+      '--num_nodes',
+      type=int,
+      default=1024,
+      help='Number of nodes in each LSTM cell.'
+  )
+  parser.add_argument(
+      '--device',
+      type=str,
+      default='gpu',
+      help="""\
+      TensorFlow device to assign ops to, e.g. "gpu", "cpu". For details see
+      documentation for tf.Graph.device.\
+      """
+  )
+  parser.add_argument(
+      '--dump_graph_dir',
+      type=str,
+      default='',
+      help='If non-empty, dump graphs in *.pbtxt format to this directory.'
+  )
+  global FLAGS  # pylint:disable=global-at-module-level
+  FLAGS, unparsed = parser.parse_known_args()
+  test.main(argv=[sys.argv[0]] + unparsed)
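The `parse_known_args()` call is what lets this benchmark keep accepting runner-owned flags such as `--benchmarks=...`: argparse consumes only the flags it was told about and returns everything else, which is re-attached after the program name and handed to `test.main`. A small self-contained demonstration of the split (flag names mirror the ones defined above):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--batch_size', type=int, default=128)

# Simulate a command line mixing a known flag with a runner-owned flag.
argv = ['--batch_size=256', '--benchmarks=LSTMBenchmark']
flags, unparsed = parser.parse_known_args(argv)

print(flags.batch_size)  # 256 -- consumed by argparse
print(unparsed)          # ['--benchmarks=LSTMBenchmark'] -- left for test.main
```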
@@ -19,6 +19,7 @@ from __future__ import division
 from __future__ import print_function
 
+import time
 
 from tensorflow.contrib.cudnn_rnn.python.ops import cudnn_rnn_ops
 from tensorflow.contrib.rnn.python.ops import core_rnn
 from tensorflow.contrib.rnn.python.ops import core_rnn_cell_impl
@@ -31,12 +32,8 @@ from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import gradients_impl
 from tensorflow.python.ops import init_ops
 from tensorflow.python.ops import variables
-from tensorflow.python.platform import flags
 from tensorflow.python.platform import test
 
-flags.DEFINE_integer("batch_size", 64, "batch size.")
-FLAGS = flags.FLAGS
-
 
 class CudnnRNNBenchmark(test.Benchmark):
   """Benchmarks Cudnn LSTM and other related models.
@@ -17,27 +17,15 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import argparse
+import sys
 import tempfile
 
 from six.moves import urllib
 
 import pandas as pd
 import tensorflow as tf
 
-flags = tf.app.flags
-FLAGS = flags.FLAGS
-
-flags.DEFINE_string("model_dir", "", "Base directory for output models.")
-flags.DEFINE_string("model_type", "wide_n_deep",
-                    "Valid model types: {'wide', 'deep', 'wide_n_deep'}.")
-flags.DEFINE_integer("train_steps", 200, "Number of training steps.")
-flags.DEFINE_string(
-    "train_data",
-    "",
-    "Path to the training data.")
-flags.DEFINE_string(
-    "test_data",
-    "",
-    "Path to the test data.")
 
 COLUMNS = ["age", "workclass", "fnlwgt", "education", "education_num",
            "marital_status", "occupation", "relationship", "race", "gender",
@@ -50,10 +38,10 @@ CONTINUOUS_COLUMNS = ["age", "education_num", "capital_gain", "capital_loss",
                       "hours_per_week"]
 
 
-def maybe_download():
+def maybe_download(train_data, test_data):
   """Maybe downloads training data and returns train and test file names."""
-  if FLAGS.train_data:
-    train_file_name = FLAGS.train_data
+  if train_data:
+    train_file_name = train_data
   else:
     train_file = tempfile.NamedTemporaryFile(delete=False)
     urllib.request.urlretrieve("http://mlr.cs.umass.edu/ml/machine-learning-databases/adult/adult.data", train_file.name)  # pylint: disable=line-too-long
@@ -61,8 +49,8 @@ def maybe_download():
     train_file.close()
     print("Training data is downloaded to %s" % train_file_name)
 
-  if FLAGS.test_data:
-    test_file_name = FLAGS.test_data
+  if test_data:
+    test_file_name = test_data
   else:
     test_file = tempfile.NamedTemporaryFile(delete=False)
     urllib.request.urlretrieve("http://mlr.cs.umass.edu/ml/machine-learning-databases/adult/adult.test", test_file.name)  # pylint: disable=line-too-long
@@ -73,7 +61,7 @@ def maybe_download():
   return train_file_name, test_file_name
 
 
-def build_estimator(model_dir):
+def build_estimator(model_dir, model_type):
   """Build an estimator."""
   # Sparse base columns.
   gender = tf.contrib.layers.sparse_column_with_keys(column_name="gender",
@@ -128,10 +116,10 @@ def build_estimator(model_dir):
       hours_per_week,
   ]
 
-  if FLAGS.model_type == "wide":
+  if model_type == "wide":
    m = tf.contrib.learn.LinearClassifier(model_dir=model_dir,
                                          feature_columns=wide_columns)
-  elif FLAGS.model_type == "deep":
+  elif model_type == "deep":
     m = tf.contrib.learn.DNNClassifier(model_dir=model_dir,
                                        feature_columns=deep_columns,
                                        hidden_units=[100, 50])
@@ -166,9 +154,9 @@ def input_fn(df):
   return feature_cols, label
 
 
-def train_and_eval():
+def train_and_eval(model_dir, model_type, train_steps, train_data, test_data):
   """Train and evaluate the model."""
-  train_file_name, test_file_name = maybe_download()
+  train_file_name, test_file_name = maybe_download(train_data, test_data)
   df_train = pd.read_csv(
       tf.gfile.Open(train_file_name),
       names=COLUMNS,
@@ -190,19 +178,56 @@ def train_and_eval():
   df_test[LABEL_COLUMN] = (
       df_test["income_bracket"].apply(lambda x: ">50K" in x)).astype(int)
 
-  model_dir = tempfile.mkdtemp() if not FLAGS.model_dir else FLAGS.model_dir
+  model_dir = tempfile.mkdtemp() if not model_dir else model_dir
   print("model directory = %s" % model_dir)
 
-  m = build_estimator(model_dir)
-  m.fit(input_fn=lambda: input_fn(df_train), steps=FLAGS.train_steps)
+  m = build_estimator(model_dir, model_type)
+  m.fit(input_fn=lambda: input_fn(df_train), steps=train_steps)
   results = m.evaluate(input_fn=lambda: input_fn(df_test), steps=1)
   for key in sorted(results):
     print("%s: %s" % (key, results[key]))
 
 
+FLAGS = None
+
+
 def main(_):
-  train_and_eval()
+  train_and_eval(FLAGS.model_dir, FLAGS.model_type, FLAGS.train_steps,
+                 FLAGS.train_data, FLAGS.test_data)
 
 
 if __name__ == "__main__":
-  tf.app.run()
+  parser = argparse.ArgumentParser()
+  parser.register("type", "bool", lambda v: v.lower() == "true")
+  parser.add_argument(
+      "--model_dir",
+      type=str,
+      default="",
+      help="Base directory for output models."
+  )
+  parser.add_argument(
+      "--model_type",
+      type=str,
+      default="wide_n_deep",
+      help="Valid model types: {'wide', 'deep', 'wide_n_deep'}."
+  )
+  parser.add_argument(
+      "--train_steps",
+      type=int,
+      default=200,
+      help="Number of training steps."
+  )
+  parser.add_argument(
+      "--train_data",
+      type=str,
+      default="",
+      help="Path to the training data."
+  )
+  parser.add_argument(
+      "--test_data",
+      type=str,
+      default="",
+      help="Path to the test data."
+  )
+  FLAGS, unparsed = parser.parse_known_args()
+  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
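With the flag reads hoisted into `main`, `train_and_eval` is now an ordinary function of five parameters, so a script or test can exercise it without constructing a command line at all. A sketch of that usage; the module name `wide_n_deep_tutorial` is an assumption not confirmed by this excerpt:

```python
# Hypothetical module name; the diff does not show the file path.
import wide_n_deep_tutorial as wnd

# Empty-string paths mean "download the data", mirroring the flag defaults.
wnd.train_and_eval(model_dir='',       # tempfile.mkdtemp() is used if empty
                   model_type='wide',  # one of 'wide', 'deep', 'wide_n_deep'
                   train_steps=10,     # small step count for a smoke test
                   train_data='',
                   test_data='')
```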
@@ -17,27 +17,17 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import argparse
+import sys
 import tempfile
 
 from six.moves import urllib
 
 import numpy as np
 import tensorflow as tf
 
-flags = tf.app.flags
-FLAGS = flags.FLAGS
-
-flags.DEFINE_string(
-    "train_data",
-    "",
-    "Path to the training data.")
-flags.DEFINE_string(
-    "test_data",
-    "",
-    "Path to the test data.")
-flags.DEFINE_string(
-    "predict_data",
-    "",
-    "Path to the prediction data.")
+FLAGS = None
 
 tf.logging.set_verbosity(tf.logging.INFO)
 
@@ -45,31 +35,36 @@ tf.logging.set_verbosity(tf.logging.INFO)
 LEARNING_RATE = 0.001
 
 
-def maybe_download():
+def maybe_download(train_data, test_data, predict_data):
   """Maybe downloads training data and returns train and test file names."""
-  if FLAGS.train_data:
-    train_file_name = FLAGS.train_data
+  if train_data:
+    train_file_name = train_data
   else:
     train_file = tempfile.NamedTemporaryFile(delete=False)
-    urllib.request.urlretrieve("http://download.tensorflow.org/data/abalone_train.csv", train_file.name)  # pylint: disable=line-too-long
+    urllib.request.urlretrieve(
+        "http://download.tensorflow.org/data/abalone_train.csv",
+        train_file.name)  # pylint: disable=line-too-long
     train_file_name = train_file.name
     train_file.close()
     print("Training data is downloaded to %s" % train_file_name)
 
-  if FLAGS.test_data:
-    test_file_name = FLAGS.test_data
+  if test_data:
+    test_file_name = test_data
   else:
     test_file = tempfile.NamedTemporaryFile(delete=False)
-    urllib.request.urlretrieve("http://download.tensorflow.org/data/abalone_test.csv", test_file.name)  # pylint: disable=line-too-long
+    urllib.request.urlretrieve(
+        "http://download.tensorflow.org/data/abalone_test.csv", test_file.name)  # pylint: disable=line-too-long
     test_file_name = test_file.name
     test_file.close()
     print("Test data is downloaded to %s" % test_file_name)
 
-  if FLAGS.predict_data:
-    predict_file_name = FLAGS.predict_data
+  if predict_data:
+    predict_file_name = predict_data
   else:
     predict_file = tempfile.NamedTemporaryFile(delete=False)
-    urllib.request.urlretrieve("http://download.tensorflow.org/data/abalone_predict.csv", predict_file.name)  # pylint: disable=line-too-long
+    urllib.request.urlretrieve(
+        "http://download.tensorflow.org/data/abalone_predict.csv",
+        predict_file.name)  # pylint: disable=line-too-long
     predict_file_name = predict_file.name
     predict_file.close()
     print("Prediction data is downloaded to %s" % predict_file_name)
@@ -109,32 +104,26 @@ def model_fn(features, targets, mode, params):
 
 def main(unused_argv):
   # Load datasets
-  abalone_train, abalone_test, abalone_predict = maybe_download()
+  abalone_train, abalone_test, abalone_predict = maybe_download(
+      FLAGS.train_data, FLAGS.test_data, FLAGS.predict_data)
 
   # Training examples
   training_set = tf.contrib.learn.datasets.base.load_csv_without_header(
-      filename=abalone_train,
-      target_dtype=np.int,
-      features_dtype=np.float64)
+      filename=abalone_train, target_dtype=np.int, features_dtype=np.float64)
 
   # Test examples
   test_set = tf.contrib.learn.datasets.base.load_csv_without_header(
-      filename=abalone_test,
-      target_dtype=np.int,
-      features_dtype=np.float64)
+      filename=abalone_test, target_dtype=np.int, features_dtype=np.float64)
 
   # Set of 7 examples for which to predict abalone ages
   prediction_set = tf.contrib.learn.datasets.base.load_csv_without_header(
-      filename=abalone_predict,
-      target_dtype=np.int,
-      features_dtype=np.float64)
+      filename=abalone_predict, target_dtype=np.int, features_dtype=np.float64)
 
   # Set model params
   model_params = {"learning_rate": LEARNING_RATE}
 
   # Build 2 layer fully connected DNN with 10, 10 units respectively.
-  nn = tf.contrib.learn.Estimator(
-      model_fn=model_fn, params=model_params)
+  nn = tf.contrib.learn.Estimator(model_fn=model_fn, params=model_params)
 
   # Fit
   nn.fit(x=training_set.data, y=training_set.target, steps=5000)
@@ -145,11 +134,22 @@ def main(unused_argv):
   print("Loss: %s" % loss_score)
 
   # Print out predictions
-  predictions = nn.predict(x=prediction_set.data,
-                           as_iterable=True)
+  predictions = nn.predict(x=prediction_set.data, as_iterable=True)
   for i, p in enumerate(predictions):
     print("Prediction %s: %s" % (i + 1, p["ages"]))
 
 
 if __name__ == "__main__":
-  tf.app.run()
+  parser = argparse.ArgumentParser()
+  parser.register("type", "bool", lambda v: v.lower() == "true")
+  parser.add_argument(
+      "--train_data", type=str, default="", help="Path to the training data.")
+  parser.add_argument(
+      "--test_data", type=str, default="", help="Path to the test data.")
+  parser.add_argument(
+      "--predict_data",
+      type=str,
+      default="",
+      help="Path to the prediction data.")
  FLAGS, unparsed = parser.parse_known_args()
  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
@@ -182,19 +182,12 @@ implementing **between-graph replication** and **asynchronous training**. It
 includes the code for the parameter server and worker tasks.
 
 ```python
+import argparse
+import sys
+
 import tensorflow as tf
 
-# Flags for defining the tf.train.ClusterSpec
-tf.app.flags.DEFINE_string("ps_hosts", "",
-                           "Comma-separated list of hostname:port pairs")
-tf.app.flags.DEFINE_string("worker_hosts", "",
-                           "Comma-separated list of hostname:port pairs")
-
-# Flags for defining the tf.train.Server
-tf.app.flags.DEFINE_string("job_name", "", "One of 'ps', 'worker'")
-tf.app.flags.DEFINE_integer("task_index", 0, "Index of task within the job")
-
-FLAGS = tf.app.flags.FLAGS
+FLAGS = None
 
 
 def main(_):
@@ -253,7 +246,36 @@ def main(_):
     sv.stop()
 
 if __name__ == "__main__":
-  tf.app.run()
+  parser = argparse.ArgumentParser()
+  parser.register("type", "bool", lambda v: v.lower() == "true")
+  # Flags for defining the tf.train.ClusterSpec
+  parser.add_argument(
+      "--ps_hosts",
+      type=str,
+      default="",
+      help="Comma-separated list of hostname:port pairs"
+  )
+  parser.add_argument(
+      "--worker_hosts",
+      type=str,
+      default="",
+      help="Comma-separated list of hostname:port pairs"
+  )
+  parser.add_argument(
+      "--job_name",
+      type=str,
+      default="",
+      help="One of 'ps', 'worker'"
+  )
+  # Flags for defining the tf.train.Server
+  parser.add_argument(
+      "--task_index",
+      type=int,
+      default=0,
+      help="Index of task within the job"
+  )
+  FLAGS, unparsed = parser.parse_known_args()
+  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
 ```
 
 To start the trainer with two parameter servers and two workers, use the
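For context, the parsed `ps_hosts` and `worker_hosts` strings are comma-separated `hostname:port` lists, and the elided `main` above is expected to split them to build the cluster. A hedged sketch of that step, consistent with the flags defined here but not shown verbatim in this excerpt:

```python
import tensorflow as tf


def make_cluster(ps_hosts, worker_hosts, job_name, task_index):
  # Split the comma-separated "hostname:port" lists from the flags.
  ps = ps_hosts.split(",")
  workers = worker_hosts.split(",")
  cluster = tf.train.ClusterSpec({"ps": ps, "worker": workers})
  # Each process then starts the server for its own task.
  server = tf.train.Server(cluster, job_name=job_name, task_index=task_index)
  return cluster, server
```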
@@ -101,35 +101,20 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import argparse
+import sys
 import tempfile
 import urllib
 
 import numpy as np
 import tensorflow as tf
+
+FLAGS = None
 ```
 
-Then define flags to allow users to optionally specify CSV files for training,
-test, and prediction datasets via the command line (by default, files will be
-downloaded from [tensorflow.org](https://www.tensorflow.org/)), and enable
-logging:
+Enable logging:
 
 ```python
-flags = tf.app.flags
-FLAGS = flags.FLAGS
-
-flags.DEFINE_string(
-    "train_data",
-    "",
-    "Path to the training data.")
-flags.DEFINE_string(
-    "test_data",
-    "",
-    "Path to the test data.")
-flags.DEFINE_string(
-    "predict_data",
-    "",
-    "Path to the prediction data.")
-
 tf.logging.set_verbosity(tf.logging.INFO)
 ```
 
@@ -138,10 +123,10 @@ command-line options, or downloaded from
 [tensorflow.org](https://www.tensorflow.org/)):
 
 ```python
-def maybe_download():
+def maybe_download(train_data, test_data, predict_data):
   """Maybe downloads training data and returns train and test file names."""
-  if FLAGS.train_data:
-    train_file_name = FLAGS.train_data
+  if train_data:
+    train_file_name = train_data
   else:
     train_file = tempfile.NamedTemporaryFile(delete=False)
     urllib.urlretrieve("http://download.tensorflow.org/data/abalone_train.csv", train_file.name)
@@ -149,8 +134,8 @@ def maybe_download():
     train_file.close()
     print("Training data is downloaded to %s" % train_file_name)
 
-  if FLAGS.test_data:
-    test_file_name = FLAGS.test_data
+  if test_data:
+    test_file_name = test_data
   else:
     test_file = tempfile.NamedTemporaryFile(delete=False)
     urllib.urlretrieve("http://download.tensorflow.org/data/abalone_test.csv", test_file.name)
@@ -158,8 +143,8 @@ def maybe_download():
     test_file.close()
     print("Test data is downloaded to %s" % test_file_name)
 
-  if FLAGS.predict_data:
-    predict_file_name = FLAGS.predict_data
+  if predict_data:
+    predict_file_name = predict_data
   else:
     predict_file = tempfile.NamedTemporaryFile(delete=False)
     urllib.urlretrieve("http://download.tensorflow.org/data/abalone_predict.csv", predict_file.name)
@@ -170,12 +155,16 @@ def maybe_download():
   return train_file_name, test_file_name, predict_file_name
 ```
 
-Finally, create `main()` and load the abalone CSVs into `Datasets`:
+Finally, create `main()` and load the abalone CSVs into `Datasets`,
+defining flags to allow users to optionally specify CSV files for training,
+test, and prediction datasets via the command line (by default, files will be
+downloaded from [tensorflow.org](https://www.tensorflow.org/)):
 
 ```python
 def main(unused_argv):
   # Load datasets
-  abalone_train, abalone_test, abalone_predict = maybe_download()
+  abalone_train, abalone_test, abalone_predict = maybe_download(
+      FLAGS.train_data, FLAGS.test_data, FLAGS.predict_data)
 
   # Training examples
   training_set = tf.contrib.learn.datasets.base.load_csv_without_header(
@@ -196,7 +185,28 @@ def main(unused_argv):
       features_dtype=np.float64)
 
 if __name__ == "__main__":
-  tf.app.run()
+  parser = argparse.ArgumentParser()
+  parser.register("type", "bool", lambda v: v.lower() == "true")
+  parser.add_argument(
+      "--train_data",
+      type=str,
+      default="",
+      help="Path to the training data."
+  )
+  parser.add_argument(
+      "--test_data",
+      type=str,
+      default="",
+      help="Path to the test data."
+  )
+  parser.add_argument(
+      "--predict_data",
+      type=str,
+      default="",
+      help="Path to the prediction data."
+  )
+  FLAGS, unparsed = parser.parse_known_args()
+  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
 ```
 
 ## Instantiating an Estimator
@@ -18,6 +18,8 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import argparse
+import sys
 import time
 
 from tensorflow.python.client import session as session_lib
@@ -30,12 +32,8 @@ from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import nn_impl
 from tensorflow.python.ops import random_ops
 from tensorflow.python.ops import variables
-from tensorflow.python.platform import flags
 from tensorflow.python.platform import test
 
-FLAGS = flags.FLAGS
-flags.DEFINE_boolean("use_gpu", True, """Run GPU benchmarks.""")
-
 
 def batch_norm_op(tensor, mean, variance, beta, gamma, scale):
   """Fused kernel for batch normalization."""
@@ -245,4 +243,16 @@ class BatchNormBenchmark(test.Benchmark):
 
 
 if __name__ == "__main__":
-  test.main()
+  parser = argparse.ArgumentParser()
+  parser.register("type", "bool", lambda v: v.lower() == "true")
+  parser.add_argument(
+      "--use_gpu",
+      type="bool",
+      nargs="?",
+      const=True,
+      default=True,
+      help="Run GPU benchmarks."
+  )
+  global FLAGS  # pylint:disable=global-at-module-level
+  FLAGS, unparsed = parser.parse_known_args()
+  test.main(argv=[sys.argv[0]] + unparsed)
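The `--use_gpu` argument shows why the `parser.register('type', 'bool', ...)` line appears in every converted file: plain `type=bool` in argparse would treat any non-empty string, including `'false'`, as true. With the registered converter plus `nargs='?'` and `const=True`, both `--use_gpu` and `--use_gpu=false` parse the way the old `flags.DEFINE_boolean` behaved. A self-contained check:

```python
import argparse

parser = argparse.ArgumentParser()
# Registered converter: only the literal string "true" maps to True.
parser.register('type', 'bool', lambda v: v.lower() == 'true')
parser.add_argument('--use_gpu', type='bool', nargs='?', const=True,
                    default=True, help='Run GPU benchmarks.')

print(parser.parse_args([]).use_gpu)                   # True (default)
print(parser.parse_args(['--use_gpu']).use_gpu)        # True (bare flag, const)
print(parser.parse_args(['--use_gpu=false']).use_gpu)  # False (converter)
```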
@@ -29,12 +29,8 @@ from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import gradients_impl
 from tensorflow.python.ops import variables
-from tensorflow.python.platform import flags
 from tensorflow.python.platform import test
 
-FLAGS = flags.FLAGS
-flags.DEFINE_boolean("use_gpu", True, """Run GPU benchmarks.""")
-
 
 def build_graph(device, input_shape, variable, num_inputs, axis, grad):
   """Build a graph containing a sequence of concat operations.
@@ -312,13 +312,15 @@ def _run_benchmarks(regex):
     instance_benchmark_fn()
 
 
-def benchmarks_main(true_main):
-  """Run benchmarks as declared in args.
+def benchmarks_main(true_main, argv=None):
+  """Run benchmarks as declared in argv.
 
   Args:
     true_main: True main function to run if benchmarks are not requested.
+    argv: the command line arguments (if None, uses sys.argv).
   """
-  argv = sys.argv
+  if argv is None:
+    argv = sys.argv
   found_arg = [arg for arg in argv
                if arg.startswith("--benchmarks=")
                or arg.startswith("-benchmarks=")]
@@ -327,6 +329,6 @@ def benchmarks_main(true_main):
     argv.remove(found_arg[0])
 
     regex = found_arg[0].split("=")[1]
-    app.run(lambda _: _run_benchmarks(regex))
+    app.run(lambda _: _run_benchmarks(regex), argv=argv)
   else:
     true_main()
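The `argv=None` / `if argv is None: argv = sys.argv` idiom used here is the standard way to make an argv-taking function both callable with an explicit list and safe as a default entry point; a default like `argv=sys.argv` would be captured once at import time. A minimal illustration of the dispatch logic, mirroring the function above with a print in place of `app.run`:

```python
import sys


def benchmarks_main(true_main, argv=None):
  # Fall back to the process arguments only when no explicit list is given.
  if argv is None:
    argv = sys.argv
  found = [arg for arg in argv if arg.startswith('--benchmarks=')]
  if found:
    argv.remove(found[0])
    regex = found[0].split('=')[1]
    print('would run benchmarks matching %r with argv=%r' % (regex, argv))
  else:
    true_main()


# Explicit argv: the --benchmarks flag is consumed, the rest passes through.
benchmarks_main(lambda: print('no benchmarks requested'),
                argv=['prog', '--benchmarks=BatchNorm.*'])
```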
@@ -80,9 +80,12 @@ def g_main(argv):
 
 
 # Redefine main to allow running benchmarks
-def main():  # pylint: disable=function-redefined
+def main(argv=None):  # pylint: disable=function-redefined
   def main_wrapper():
-    return app.run(main=g_main, argv=sys.argv)
+    args = argv
+    if args is None:
+      args = sys.argv
+    return app.run(main=g_main, argv=args)
   benchmark.benchmarks_main(true_main=main_wrapper)
 
 
@@ -88,9 +88,9 @@ else:
   Benchmark = _googletest.Benchmark  # pylint: disable=invalid-name
 
 
-def main():
+def main(argv=None):
   """Runs all unit tests."""
-  return _googletest.main()
+  return _googletest.main(argv)
 
 
 def get_temp_dir():