Convert tf.flags usage to argparse. Move use of FLAGS globals into main() only.

Change: 143799731
This commit is contained in:
Vijay Vasudevan 2017-01-06 12:07:23 -08:00 committed by TensorFlower Gardener
parent 4b3d59a771
commit 2b351f224d
12 changed files with 414 additions and 157 deletions

View File

@ -31,27 +31,21 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import os
import socket
import sys
from tensorflow.python.platform import app
# pylint: disable=g-import-not-at-top
# Official recommended way of turning on fast protocol buffers as of 10/21/14
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "cpp"
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION_VERSION"] = "2"
from tensorflow.python.platform import app
from tensorflow.python.platform import flags
FLAGS = flags.FLAGS
FLAGS = None
flags.DEFINE_string(
"password", None,
"Password to require. If set, the server will allow public access."
" Only used if notebook config file does not exist.")
flags.DEFINE_string("notebook_dir", "experimental/brain/notebooks",
"root location where to store notebooks")
ORIG_ARGV = sys.argv
# Main notebook process calls itself with argv[1]="kernel" to start kernel
@ -108,6 +102,21 @@ def main(unused_argv):
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--password",
type=str,
default=None,
help="""\
Password to require. If set, the server will allow public access. Only
used if notebook config file does not exist.\
""")
parser.add_argument(
"--notebook_dir",
type=str,
default="experimental/brain/notebooks",
help="root location where to store notebooks")
# When the user starts the main notebook process, we don't touch sys.argv.
# When the main process launches kernel subprocesses, it writes all flags
# to a tmpfile and sets --flagfile to that tmpfile, so for kernel
@ -118,4 +127,6 @@ if __name__ == "__main__":
# Drop everything except --flagfile.
sys.argv = ([sys.argv[0]] +
[x for x in sys.argv[1:] if x.startswith("--flagfile")])
app.run()
FLAGS, unparsed = parser.parse_known_args()
app.run(main=main, argv=[sys.argv[0]] + unparsed)

View File

@ -17,23 +17,14 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import sys
# Google-internal import(s).
from tensorflow.python.debug import debug_data
from tensorflow.python.debug.cli import analyzer_cli
from tensorflow.python.platform import app
from tensorflow.python.platform import flags
FLAGS = flags.FLAGS
flags.DEFINE_string("dump_dir", "", "tfdbg dump directory to load")
flags.DEFINE_string("ui_type", "curses",
"Command-line user interface type (curses | readline)")
flags.DEFINE_boolean(
"log_usage", True, "Whether the usage of this tool is to be logged")
flags.DEFINE_boolean(
"validate_graph", True,
"Whether the dumped tensors will be validated against the GraphDefs")
def main(_):
@ -58,4 +49,30 @@ def main(_):
if __name__ == "__main__":
app.run()
parser = argparse.ArgumentParser()
parser.register("type", "bool", lambda v: v.lower() == "true")
parser.add_argument(
"--dump_dir", type=str, default="", help="tfdbg dump directory to load")
parser.add_argument(
"--log_usage",
type="bool",
nargs="?",
const=True,
default=True,
help="Whether the usage of this tool is to be logged")
parser.add_argument(
"--ui_type",
type=str,
default="curses"
help="Command-line user interface type (curses | readline)")
parser.add_argument(
"--validate_graph",
nargs="?",
const=True,
type="bool",
default=True,
help="""\
Whether the dumped tensors will be validated against the GraphDefs\
""")
FLAGS, unparsed = parser.parse_known_args()
app.run(main=main, argv=[sys.argv[0]] + unparsed)

View File

@ -17,20 +17,14 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import sys
import numpy as np
import tensorflow as tf
from tensorflow.python import debug as tf_debug
flags = tf.app.flags
FLAGS = flags.FLAGS
flags.DEFINE_string("error", "shape_mismatch", "Type of the error to generate "
"(shape_mismatch | uninitialized_variable | no_error).")
flags.DEFINE_string("ui_type", "curses",
"Command-line user interface type (curses | readline)")
flags.DEFINE_boolean("debug", False,
"Use debugger to track down bad values during training")
def main(_):
sess = tf.Session()
@ -60,4 +54,27 @@ def main(_):
if __name__ == "__main__":
tf.app.run()
parser = argparse.ArgumentParser()
parser.register("type", "bool", lambda v: v.lower() == "true")
parser.add_argument(
"--error",
type=str,
default="shape_mismatch",
help="""\
Type of the error to generate (shape_mismatch | uninitialized_variable |
no_error).\
""")
parser.add_argument(
"--ui_type",
type=str,
default="curses"
help="Command-line user interface type (curses | readline)")
parser.add_argument(
"--debug",
type="bool",
nargs="?",
const=True,
default=False,
help="Use debugger to track down bad values during training")
FLAGS, unparsed = parser.parse_known_args()
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)

View File

@ -17,19 +17,16 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import sys
import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf
from tensorflow.python import debug as tf_debug
flags = tf.app.flags
FLAGS = flags.FLAGS
flags.DEFINE_integer("tensor_size", 30,
"Size of tensor. E.g., if the value is 30, the tensors "
"will have shape [30, 30].")
flags.DEFINE_integer("length", 20,
"Length of the fibonacci sequence to compute.")
FLAGS = None
def main(_):
@ -54,4 +51,20 @@ def main(_):
if __name__ == "__main__":
tf.app.run()
parser = argparse.ArgumentParser()
parser.register("type", "bool", lambda v: v.lower() == "true")
parser.add_argument(
"--tensor_size",
type=int,
default=30,
help="""\
Size of tensor. E.g., if the value is 30, the tensors will have shape
[30, 30].\
""")
parser.add_argument(
"--length",
type=int,
default=20,
help="Length of the fibonacci sequence to compute.")
FLAGS, unparsed = parser.parse_known_args()
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)

View File

@ -24,22 +24,14 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import sys
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
from tensorflow.python import debug as tf_debug
flags = tf.app.flags
FLAGS = flags.FLAGS
flags.DEFINE_integer("max_steps", 10, "Number of steps to run trainer.")
flags.DEFINE_integer("train_batch_size", 100,
"Batch size used during training.")
flags.DEFINE_float("learning_rate", 0.025, "Initial learning rate.")
flags.DEFINE_string("data_dir", "/tmp/mnist_data", "Directory for storing data")
flags.DEFINE_string("ui_type", "curses",
"Command-line user interface type (curses | readline)")
flags.DEFINE_boolean("debug", False,
"Use debugger to track down bad values during training")
IMAGE_SIZE = 28
HIDDEN_SIZE = 500
@ -137,4 +129,39 @@ def main(_):
if __name__ == "__main__":
tf.app.run()
parser = argparse.ArgumentParser()
parser.register("type", "bool", lambda v: v.lower() == "true")
parser.add_argument(
"--max_steps",
type=int,
default=10,
help="Number of steps to run trainer.")
parser.add_argument(
"--train_batch_size",
type=int,
default=100,
help="Batch size used during training.")
parser.add_argument(
"--learning_rate",
type=float,
default=0.025,
help="Initial learning rate.")
parser.add_argument(
"--data_dir",
type=str,
default="/tmp/mnist_data",
help="Directory for storing data")
parser.add_argument(
"--ui_type",
type=str,
default="curses"
help="Command-line user interface type (curses | readline)")
parser.add_argument(
"--debug",
type="bool",
nargs="?",
const=True,
default=False,
help="Use debugger to track down bad values during training")
FLAGS, unparsed = parser.parse_known_args()
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)

View File

@ -17,7 +17,9 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import os
import sys
import tempfile
import numpy as np
@ -26,33 +28,26 @@ import tensorflow as tf
from tensorflow.python import debug as tf_debug
flags = tf.app.flags
FLAGS = flags.FLAGS
flags.DEFINE_string("data_dir", "/tmp/iris_data",
"Directory to save the training and test data in.")
flags.DEFINE_string("model_dir", "", "Directory to save the trained model in.")
flags.DEFINE_integer("train_steps", 10, "Number of steps to run trainer.")
flags.DEFINE_string("ui_type", "curses",
"Command-line user interface type (curses | readline)")
flags.DEFINE_boolean("debug", False,
"Use debugger to track down bad values during training")
# URLs to download data sets from, if necessary.
IRIS_TRAINING_DATA_URL = "https://raw.githubusercontent.com/tensorflow/tensorflow/master/tensorflow/examples/tutorials/monitors/iris_training.csv"
IRIS_TEST_DATA_URL = "https://raw.githubusercontent.com/tensorflow/tensorflow/master/tensorflow/examples/tutorials/monitors/iris_test.csv"
def maybe_download_data():
def maybe_download_data(data_dir):
"""Download data sets if necessary.
Args:
data_dir: Path to where data should be downloaded.
Returns:
Paths to the training and test data files.
"""
if not os.path.isdir(FLAGS.data_dir):
os.makedirs(FLAGS.data_dir)
if not os.path.isdir(data_dir):
os.makedirs(data_dir)
training_data_path = os.path.join(FLAGS.data_dir,
training_data_path = os.path.join(data_dir,
os.path.basename(IRIS_TRAINING_DATA_URL))
if not os.path.isfile(training_data_path):
train_file = open(training_data_path, "wt")
@ -61,8 +56,7 @@ def maybe_download_data():
print("Training data are downloaded to %s" % train_file.name)
test_data_path = os.path.join(FLAGS.data_dir,
os.path.basename(IRIS_TEST_DATA_URL))
test_data_path = os.path.join(data_dir, os.path.basename(IRIS_TEST_DATA_URL))
if not os.path.isfile(test_data_path):
test_file = open(test_data_path, "wt")
urllib.request.urlretrieve(IRIS_TEST_DATA_URL, test_file.name)
@ -74,7 +68,7 @@ def maybe_download_data():
def main(_):
training_data_path, test_data_path = maybe_download_data()
training_data_path, test_data_path = maybe_download_data(FLAGS.data_dir)
# Load datasets.
training_set = tf.contrib.learn.datasets.base.load_csv_with_header(
@ -115,4 +109,34 @@ def main(_):
if __name__ == "__main__":
tf.app.run()
parser = argparse.ArgumentParser()
parser.register("type", "bool", lambda v: v.lower() == "true")
parser.add_argument(
"--data_dir",
type=str,
default="/tmp/iris_data",
help="Directory to save the training and test data in.")
parser.add_argument(
"--model_dir",
type=str,
default="",
help="Directory to save the trained model in.")
parser.add_argument(
"--train_steps",
type=int,
default=10,
help="Number of steps to run trainer.")
parser.add_argument(
"--ui_type",
type=str,
default="curses"
help="Command-line user interface type (curses | readline)")
parser.add_argument(
"--debug",
type="bool",
nargs="?",
const=True,
default=False,
help="Use debugger to track down bad values during training")
FLAGS, unparsed = parser.parse_known_args()
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)

View File

@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import collections
import os.path
import sys
@ -31,12 +32,7 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import docs
from tensorflow.python.framework import framework_lib
tf.flags.DEFINE_string("out_dir", None,
"Directory to which docs should be written.")
tf.flags.DEFINE_boolean("print_hidden_regex", False,
"Dump a regular expression matching any hidden symbol")
FLAGS = tf.flags.FLAGS
FLAGS = None
PREFIX_TEXT = """
@ -309,4 +305,19 @@ def main(unused_argv):
if __name__ == "__main__":
tf.app.run()
parser = argparse.ArgumentParser()
parser.register("type", "bool", lambda v: v.lower() == "true")
parser.add_argument(
"--out_dir",
type=str,
default=None,
help="Directory to which docs should be written.")
parser.add_argument(
"--print_hidden_regex",
type="bool",
nargs="?",
const=True,
default=False,
help="Dump a regular expression matching any hidden symbol")
FLAGS, unparsed = parser.parse_known_args()
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)

View File

@ -31,7 +31,10 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import os
import sys
import tensorflow as tf
from tensorflow.core.protobuf import meta_graph_pb2
@ -42,13 +45,7 @@ from tensorflow.python.saved_model import signature_def_utils
from tensorflow.python.saved_model import tag_constants
from tensorflow.python.util import compat
tf.app.flags.DEFINE_string("output_dir", "/tmp/saved_model_half_plus_two",
"Directory where to ouput SavedModel.")
tf.app.flags.DEFINE_string("output_dir_pbtxt",
"/tmp/saved_model_half_plus_two_pbtxt",
"Directory where to ouput the text format of "
"SavedModel.")
FLAGS = tf.flags.FLAGS
FLAGS = None
def _write_assets(assets_directory, assets_filename):
@ -172,4 +169,16 @@ def main(_):
if __name__ == "__main__":
tf.app.run()
parser = argparse.ArgumentParser()
parser.add_argument(
"--output_dir",
type=str,
default="/tmp/saved_model_half_plus_two",
help="Directory where to ouput SavedModel.")
parser.add_argument(
"--output_dir_pbtxt",
type=str,
default="/tmp/saved_model_half_plus_two_pbtxt",
help="Directory where to ouput the text format of SavedModel.")
FLAGS, unparsed = parser.parse_known_args()
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)

View File

@ -37,6 +37,9 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import sys
from google.protobuf import text_format
from tensorflow.core.framework import graph_pb2
@ -45,37 +48,23 @@ from tensorflow.python.client import session
from tensorflow.python.framework import graph_util
from tensorflow.python.framework import importer
from tensorflow.python.platform import app
from tensorflow.python.platform import flags
from tensorflow.python.platform import gfile
from tensorflow.python.training import saver as saver_lib
FLAGS = flags.FLAGS
flags.DEFINE_string("input_graph", "",
"""TensorFlow 'GraphDef' file to load.""")
flags.DEFINE_string("input_saver", "", """TensorFlow saver file to load.""")
flags.DEFINE_string("input_checkpoint", "",
"""TensorFlow variables file to load.""")
flags.DEFINE_string("output_graph", "", """Output 'GraphDef' file name.""")
flags.DEFINE_boolean("input_binary", False,
"""Whether the input files are in binary format.""")
flags.DEFINE_string("output_node_names", "",
"""The name of the output nodes, comma separated.""")
flags.DEFINE_string("restore_op_name", "save/restore_all",
"""The name of the master restore operator.""")
flags.DEFINE_string("filename_tensor_name", "save/Const:0",
"""The name of the tensor holding the save path.""")
flags.DEFINE_boolean("clear_devices", True,
"""Whether to remove device specifications.""")
flags.DEFINE_string("initializer_nodes", "", "comma separated list of "
"initializer nodes to run before freezing.")
flags.DEFINE_string("variable_names_blacklist", "", "comma separated "
"list of variables to skip converting to constants ")
FLAGS = None
def freeze_graph(input_graph, input_saver, input_binary, input_checkpoint,
output_node_names, restore_op_name, filename_tensor_name,
output_graph, clear_devices, initializer_nodes):
def freeze_graph(input_graph,
input_saver,
input_binary,
input_checkpoint,
output_node_names,
restore_op_name,
filename_tensor_name,
output_graph,
clear_devices,
initializer_nodes,
variable_names_blacklist=""):
"""Converts all variables in a graph and checkpoint into constants."""
if not gfile.Exists(input_graph):
@ -124,8 +113,8 @@ def freeze_graph(input_graph, input_saver, input_binary, input_checkpoint,
if initializer_nodes:
sess.run(initializer_nodes)
variable_names_blacklist = (FLAGS.variable_names_blacklist.split(",") if
FLAGS.variable_names_blacklist else None)
variable_names_blacklist = (variable_names_blacklist.split(",") if
variable_names_blacklist else None)
output_graph_def = graph_util.convert_variables_to_constants(
sess,
input_graph_def,
@ -141,8 +130,73 @@ def main(unused_args):
freeze_graph(FLAGS.input_graph, FLAGS.input_saver, FLAGS.input_binary,
FLAGS.input_checkpoint, FLAGS.output_node_names,
FLAGS.restore_op_name, FLAGS.filename_tensor_name,
FLAGS.output_graph, FLAGS.clear_devices, FLAGS.initializer_nodes)
FLAGS.output_graph, FLAGS.clear_devices, FLAGS.initializer_nodes,
FLAGS.variable_names_blacklist)
if __name__ == "__main__":
app.run()
parser = argparse.ArgumentParser()
parser.register("type", "bool", lambda v: v.lower() == "true")
parser.add_argument(
"--input_graph",
type=str,
default="",
help="TensorFlow \'GraphDef\' file to load.")
parser.add_argument(
"--input_saver",
type=str,
default="",
help="TensorFlow saver file to load.")
parser.add_argument(
"--input_checkpoint",
type=str,
default="",
help="TensorFlow variables file to load.")
parser.add_argument(
"--output_graph",
type=str,
default="",
help="Output \'GraphDef\' file name.")
parser.add_argument(
"--input_binary",
nargs="?",
const=True,
type="bool",
default=False,
help="Whether the input files are in binary format.")
parser.add_argument(
"--output_node_names",
type=str,
default="",
help="The name of the output nodes, comma separated.")
parser.add_argument(
"--restore_op_name",
type=str,
default="save/restore_all",
help="The name of the master restore operator.")
parser.add_argument(
"--filename_tensor_name",
type=str,
default="save/Const:0",
help="The name of the tensor holding the save path.")
parser.add_argument(
"--clear_devices",
nargs="?",
const=True,
type="bool",
default=True,
help="Whether to remove device specifications.")
parser.add_argument(
"--initializer_nodes",
type=str,
default="",
help="comma separated list of initializer nodes to run before freezing.")
parser.add_argument(
"--variable_names_blacklist",
type=str,
default="",
help="""\
comma separated list of variables to skip converting to constants\
""")
FLAGS, unparsed = parser.parse_known_args()
app.run(main=main, argv=[sys.argv[0]] + unparsed)

View File

@ -17,20 +17,16 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import sys
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.platform import app
from tensorflow.python.platform import flags
FLAGS = flags.FLAGS
flags.DEFINE_string("file_name", "", "Checkpoint filename")
flags.DEFINE_string("tensor_name", "", "Name of the tensor to inspect")
flags.DEFINE_bool("all_tensors", "False",
"If True, print the values of all the tensors.")
FLAGS = None
def print_tensors_in_checkpoint_file(file_name, tensor_name):
def print_tensors_in_checkpoint_file(file_name, tensor_name, all_tensors):
"""Prints tensors in a checkpoint file.
If no `tensor_name` is provided, prints the tensor names and shapes
@ -41,10 +37,11 @@ def print_tensors_in_checkpoint_file(file_name, tensor_name):
Args:
file_name: Name of the checkpoint file.
tensor_name: Name of the tensor in the checkpoint file to print.
all_tensors: Boolean indicating whether to print all tensors.
"""
try:
reader = pywrap_tensorflow.NewCheckpointReader(file_name)
if FLAGS.all_tensors:
if all_tensors:
var_to_shape_map = reader.get_variable_to_shape_map()
for key in var_to_shape_map:
print("tensor_name: ", key)
@ -67,8 +64,26 @@ def main(unused_argv):
"[--tensor_name=tensor_to_print]")
sys.exit(1)
else:
print_tensors_in_checkpoint_file(FLAGS.file_name, FLAGS.tensor_name)
print_tensors_in_checkpoint_file(FLAGS.file_name, FLAGS.tensor_name,
FLAGS.all_tensors)
if __name__ == "__main__":
app.run()
parser = argparse.ArgumentParser()
parser.register("type", "bool", lambda v: v.lower() == "true")
parser.add_argument(
"--file_name", type=str, default="", help="Checkpoint filename")
parser.add_argument(
"--tensor_name",
type=str,
default="",
help="Name of the tensor to inspect")
parser.add_argument(
"--all_tensors",
nargs="?",
const=True,
type="bool",
default=False,
help="If True, print the values of all the tensors.")
FLAGS, unparsed = parser.parse_known_args()
app.run(main=main, argv=[sys.argv[0]] + unparsed)

View File

@ -55,7 +55,9 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import os
import sys
from google.protobuf import text_format
@ -63,22 +65,10 @@ from tensorflow.core.framework import graph_pb2
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import graph_io
from tensorflow.python.platform import app
from tensorflow.python.platform import flags as flags_lib
from tensorflow.python.platform import gfile
from tensorflow.python.tools import optimize_for_inference_lib
flags = flags_lib
FLAGS = flags.FLAGS
flags.DEFINE_string("input", "", """TensorFlow 'GraphDef' file to load.""")
flags.DEFINE_string("output", "", """File to save the output graph to.""")
flags.DEFINE_string("input_names", "", """Input node names, comma separated.""")
flags.DEFINE_string("output_names", "",
"""Output node names, comma separated.""")
flags.DEFINE_boolean("frozen_graph", True,
"""If true, the input graph is a binary frozen GraphDef
file; if false, it is a text GraphDef proto file.""")
flags.DEFINE_integer("placeholder_type_enum", dtypes.float32.as_datatype_enum,
"""The AttrValue enum to use for placeholders.""")
FLAGS = None
def main(unused_args):
@ -110,4 +100,42 @@ def main(unused_args):
if __name__ == "__main__":
app.run()
parser = argparse.ArgumentParser()
parser.register("type", "bool", lambda v: v.lower() == "true")
parser.add_argument(
"--input",
type=str,
default="",
help="TensorFlow \'GraphDef\' file to load.")
parser.add_argument(
"--output",
type=str,
default="",
help="File to save the output graph to.")
parser.add_argument(
"--input_names",
type=str,
default="",
help="Input node names, comma separated.")
parser.add_argument(
"--output_names",
type=str,
default="",
help="Output node names, comma separated.")
parser.add_argument(
"--frozen_graph",
nargs="?",
const=True,
type="bool",
default=True,
help="""\
If true, the input graph is a binary frozen GraphDef
file; if false, it is a text GraphDef proto file.\
""")
parser.add_argument(
"--placeholder_type_enum",
type=int,
default=dtypes.float32.as_datatype_enum,
help="The AttrValue enum to use for placeholders.")
FLAGS, unparsed = parser.parse_known_args()
app.run(main=main, argv=[sys.argv[0]] + unparsed)

View File

@ -41,25 +41,14 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import sys
from tensorflow.python.framework import dtypes
from tensorflow.python.platform import app
from tensorflow.python.platform import flags
from tensorflow.python.tools import strip_unused_lib
FLAGS = flags.FLAGS
flags.DEFINE_string("input_graph", "",
"""TensorFlow 'GraphDef' file to load.""")
flags.DEFINE_boolean("input_binary", False,
"""Whether the input files are in binary format.""")
flags.DEFINE_string("output_graph", "", """Output 'GraphDef' file name.""")
flags.DEFINE_boolean("output_binary", True,
"""Whether to write a binary format graph.""")
flags.DEFINE_string("input_node_names", "",
"""The name of the input nodes, comma separated.""")
flags.DEFINE_string("output_node_names", "",
"""The name of the output nodes, comma separated.""")
flags.DEFINE_integer("placeholder_type_enum", dtypes.float32.as_datatype_enum,
"""The AttrValue enum to use for placeholders.""")
FLAGS = None
def main(unused_args):
@ -72,5 +61,47 @@ def main(unused_args):
FLAGS.placeholder_type_enum)
if __name__ == "__main__":
app.run()
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.register('type', 'bool', lambda v: v.lower() == 'true')
parser.add_argument(
'--input_graph',
type=str,
default='',
help='TensorFlow \'GraphDef\' file to load.')
parser.add_argument(
'--input_binary',
nargs='?',
const=True,
type='bool',
default=False,
help='Whether the input files are in binary format.')
parser.add_argument(
'--output_graph',
type=str,
default='',
help='Output \'GraphDef\' file name.')
parser.add_argument(
'--output_binary',
nargs='?',
const=True,
type='bool',
default=True,
help='Whether to write a binary format graph.')
parser.add_argument(
'--input_node_names',
type=str,
default='',
help='The name of the input nodes, comma separated.')
parser.add_argument(
'--output_node_names',
type=str,
default='',
help='The name of the output nodes, comma separated.')
parser.add_argument(
'--placeholder_type_enum',
type=int,
default=dtypes.float32.as_datatype_enum,
help='The AttrValue enum to use for placeholders.')
FLAGS, unparsed = parser.parse_known_args()
app.run(main=main, argv=[sys.argv[0]] + unparsed)