Convert tf.flags usage to argparse. Move use of FLAGS globals into main() only.

Change: 143799731
This commit is contained in:
Vijay Vasudevan 2017-01-06 12:07:23 -08:00 committed by TensorFlower Gardener
parent 4b3d59a771
commit 2b351f224d
12 changed files with 414 additions and 157 deletions

View File

@ -31,27 +31,21 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import argparse
import os import os
import socket import socket
import sys import sys
from tensorflow.python.platform import app
# pylint: disable=g-import-not-at-top # pylint: disable=g-import-not-at-top
# Official recommended way of turning on fast protocol buffers as of 10/21/14 # Official recommended way of turning on fast protocol buffers as of 10/21/14
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "cpp" os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "cpp"
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION_VERSION"] = "2" os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION_VERSION"] = "2"
from tensorflow.python.platform import app
from tensorflow.python.platform import flags
FLAGS = flags.FLAGS FLAGS = None
flags.DEFINE_string(
"password", None,
"Password to require. If set, the server will allow public access."
" Only used if notebook config file does not exist.")
flags.DEFINE_string("notebook_dir", "experimental/brain/notebooks",
"root location where to store notebooks")
ORIG_ARGV = sys.argv ORIG_ARGV = sys.argv
# Main notebook process calls itself with argv[1]="kernel" to start kernel # Main notebook process calls itself with argv[1]="kernel" to start kernel
@ -108,6 +102,21 @@ def main(unused_argv):
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--password",
type=str,
default=None,
help="""\
Password to require. If set, the server will allow public access. Only
used if notebook config file does not exist.\
""")
parser.add_argument(
"--notebook_dir",
type=str,
default="experimental/brain/notebooks",
help="root location where to store notebooks")
# When the user starts the main notebook process, we don't touch sys.argv. # When the user starts the main notebook process, we don't touch sys.argv.
# When the main process launches kernel subprocesses, it writes all flags # When the main process launches kernel subprocesses, it writes all flags
# to a tmpfile and sets --flagfile to that tmpfile, so for kernel # to a tmpfile and sets --flagfile to that tmpfile, so for kernel
@ -118,4 +127,6 @@ if __name__ == "__main__":
# Drop everything except --flagfile. # Drop everything except --flagfile.
sys.argv = ([sys.argv[0]] + sys.argv = ([sys.argv[0]] +
[x for x in sys.argv[1:] if x.startswith("--flagfile")]) [x for x in sys.argv[1:] if x.startswith("--flagfile")])
app.run()
FLAGS, unparsed = parser.parse_known_args()
app.run(main=main, argv=[sys.argv[0]] + unparsed)

View File

@ -17,23 +17,14 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import argparse
import sys import sys
# Google-internal import(s). # Google-internal import(s).
from tensorflow.python.debug import debug_data from tensorflow.python.debug import debug_data
from tensorflow.python.debug.cli import analyzer_cli from tensorflow.python.debug.cli import analyzer_cli
from tensorflow.python.platform import app from tensorflow.python.platform import app
from tensorflow.python.platform import flags
FLAGS = flags.FLAGS
flags.DEFINE_string("dump_dir", "", "tfdbg dump directory to load")
flags.DEFINE_string("ui_type", "curses",
"Command-line user interface type (curses | readline)")
flags.DEFINE_boolean(
"log_usage", True, "Whether the usage of this tool is to be logged")
flags.DEFINE_boolean(
"validate_graph", True,
"Whether the dumped tensors will be validated against the GraphDefs")
def main(_): def main(_):
@ -58,4 +49,30 @@ def main(_):
if __name__ == "__main__": if __name__ == "__main__":
app.run() parser = argparse.ArgumentParser()
parser.register("type", "bool", lambda v: v.lower() == "true")
parser.add_argument(
"--dump_dir", type=str, default="", help="tfdbg dump directory to load")
parser.add_argument(
"--log_usage",
type="bool",
nargs="?",
const=True,
default=True,
help="Whether the usage of this tool is to be logged")
parser.add_argument(
"--ui_type",
type=str,
default="curses"
help="Command-line user interface type (curses | readline)")
parser.add_argument(
"--validate_graph",
nargs="?",
const=True,
type="bool",
default=True,
help="""\
Whether the dumped tensors will be validated against the GraphDefs\
""")
FLAGS, unparsed = parser.parse_known_args()
app.run(main=main, argv=[sys.argv[0]] + unparsed)

View File

@ -17,20 +17,14 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import argparse
import sys
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
from tensorflow.python import debug as tf_debug from tensorflow.python import debug as tf_debug
flags = tf.app.flags
FLAGS = flags.FLAGS
flags.DEFINE_string("error", "shape_mismatch", "Type of the error to generate "
"(shape_mismatch | uninitialized_variable | no_error).")
flags.DEFINE_string("ui_type", "curses",
"Command-line user interface type (curses | readline)")
flags.DEFINE_boolean("debug", False,
"Use debugger to track down bad values during training")
def main(_): def main(_):
sess = tf.Session() sess = tf.Session()
@ -60,4 +54,27 @@ def main(_):
if __name__ == "__main__": if __name__ == "__main__":
tf.app.run() parser = argparse.ArgumentParser()
parser.register("type", "bool", lambda v: v.lower() == "true")
parser.add_argument(
"--error",
type=str,
default="shape_mismatch",
help="""\
Type of the error to generate (shape_mismatch | uninitialized_variable |
no_error).\
""")
parser.add_argument(
"--ui_type",
type=str,
default="curses"
help="Command-line user interface type (curses | readline)")
parser.add_argument(
"--debug",
type="bool",
nargs="?",
const=True,
default=False,
help="Use debugger to track down bad values during training")
FLAGS, unparsed = parser.parse_known_args()
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)

View File

@ -17,19 +17,16 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import argparse
import sys
import numpy as np import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf import tensorflow as tf
from tensorflow.python import debug as tf_debug from tensorflow.python import debug as tf_debug
flags = tf.app.flags FLAGS = None
FLAGS = flags.FLAGS
flags.DEFINE_integer("tensor_size", 30,
"Size of tensor. E.g., if the value is 30, the tensors "
"will have shape [30, 30].")
flags.DEFINE_integer("length", 20,
"Length of the fibonacci sequence to compute.")
def main(_): def main(_):
@ -54,4 +51,20 @@ def main(_):
if __name__ == "__main__": if __name__ == "__main__":
tf.app.run() parser = argparse.ArgumentParser()
parser.register("type", "bool", lambda v: v.lower() == "true")
parser.add_argument(
"--tensor_size",
type=int,
default=30,
help="""\
Size of tensor. E.g., if the value is 30, the tensors will have shape
[30, 30].\
""")
parser.add_argument(
"--length",
type=int,
default=20,
help="Length of the fibonacci sequence to compute.")
FLAGS, unparsed = parser.parse_known_args()
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)

View File

@ -24,22 +24,14 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import argparse
import sys
import tensorflow as tf import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data from tensorflow.examples.tutorials.mnist import input_data
from tensorflow.python import debug as tf_debug from tensorflow.python import debug as tf_debug
flags = tf.app.flags
FLAGS = flags.FLAGS
flags.DEFINE_integer("max_steps", 10, "Number of steps to run trainer.")
flags.DEFINE_integer("train_batch_size", 100,
"Batch size used during training.")
flags.DEFINE_float("learning_rate", 0.025, "Initial learning rate.")
flags.DEFINE_string("data_dir", "/tmp/mnist_data", "Directory for storing data")
flags.DEFINE_string("ui_type", "curses",
"Command-line user interface type (curses | readline)")
flags.DEFINE_boolean("debug", False,
"Use debugger to track down bad values during training")
IMAGE_SIZE = 28 IMAGE_SIZE = 28
HIDDEN_SIZE = 500 HIDDEN_SIZE = 500
@ -137,4 +129,39 @@ def main(_):
if __name__ == "__main__": if __name__ == "__main__":
tf.app.run() parser = argparse.ArgumentParser()
parser.register("type", "bool", lambda v: v.lower() == "true")
parser.add_argument(
"--max_steps",
type=int,
default=10,
help="Number of steps to run trainer.")
parser.add_argument(
"--train_batch_size",
type=int,
default=100,
help="Batch size used during training.")
parser.add_argument(
"--learning_rate",
type=float,
default=0.025,
help="Initial learning rate.")
parser.add_argument(
"--data_dir",
type=str,
default="/tmp/mnist_data",
help="Directory for storing data")
parser.add_argument(
"--ui_type",
type=str,
default="curses"
help="Command-line user interface type (curses | readline)")
parser.add_argument(
"--debug",
type="bool",
nargs="?",
const=True,
default=False,
help="Use debugger to track down bad values during training")
FLAGS, unparsed = parser.parse_known_args()
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)

View File

@ -17,7 +17,9 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import argparse
import os import os
import sys
import tempfile import tempfile
import numpy as np import numpy as np
@ -26,33 +28,26 @@ import tensorflow as tf
from tensorflow.python import debug as tf_debug from tensorflow.python import debug as tf_debug
flags = tf.app.flags
FLAGS = flags.FLAGS
flags.DEFINE_string("data_dir", "/tmp/iris_data",
"Directory to save the training and test data in.")
flags.DEFINE_string("model_dir", "", "Directory to save the trained model in.")
flags.DEFINE_integer("train_steps", 10, "Number of steps to run trainer.")
flags.DEFINE_string("ui_type", "curses",
"Command-line user interface type (curses | readline)")
flags.DEFINE_boolean("debug", False,
"Use debugger to track down bad values during training")
# URLs to download data sets from, if necessary. # URLs to download data sets from, if necessary.
IRIS_TRAINING_DATA_URL = "https://raw.githubusercontent.com/tensorflow/tensorflow/master/tensorflow/examples/tutorials/monitors/iris_training.csv" IRIS_TRAINING_DATA_URL = "https://raw.githubusercontent.com/tensorflow/tensorflow/master/tensorflow/examples/tutorials/monitors/iris_training.csv"
IRIS_TEST_DATA_URL = "https://raw.githubusercontent.com/tensorflow/tensorflow/master/tensorflow/examples/tutorials/monitors/iris_test.csv" IRIS_TEST_DATA_URL = "https://raw.githubusercontent.com/tensorflow/tensorflow/master/tensorflow/examples/tutorials/monitors/iris_test.csv"
def maybe_download_data(): def maybe_download_data(data_dir):
"""Download data sets if necessary. """Download data sets if necessary.
Args:
data_dir: Path to where data should be downloaded.
Returns: Returns:
Paths to the training and test data files. Paths to the training and test data files.
""" """
if not os.path.isdir(FLAGS.data_dir): if not os.path.isdir(data_dir):
os.makedirs(FLAGS.data_dir) os.makedirs(data_dir)
training_data_path = os.path.join(FLAGS.data_dir, training_data_path = os.path.join(data_dir,
os.path.basename(IRIS_TRAINING_DATA_URL)) os.path.basename(IRIS_TRAINING_DATA_URL))
if not os.path.isfile(training_data_path): if not os.path.isfile(training_data_path):
train_file = open(training_data_path, "wt") train_file = open(training_data_path, "wt")
@ -61,8 +56,7 @@ def maybe_download_data():
print("Training data are downloaded to %s" % train_file.name) print("Training data are downloaded to %s" % train_file.name)
test_data_path = os.path.join(FLAGS.data_dir, test_data_path = os.path.join(data_dir, os.path.basename(IRIS_TEST_DATA_URL))
os.path.basename(IRIS_TEST_DATA_URL))
if not os.path.isfile(test_data_path): if not os.path.isfile(test_data_path):
test_file = open(test_data_path, "wt") test_file = open(test_data_path, "wt")
urllib.request.urlretrieve(IRIS_TEST_DATA_URL, test_file.name) urllib.request.urlretrieve(IRIS_TEST_DATA_URL, test_file.name)
@ -74,7 +68,7 @@ def maybe_download_data():
def main(_): def main(_):
training_data_path, test_data_path = maybe_download_data() training_data_path, test_data_path = maybe_download_data(FLAGS.data_dir)
# Load datasets. # Load datasets.
training_set = tf.contrib.learn.datasets.base.load_csv_with_header( training_set = tf.contrib.learn.datasets.base.load_csv_with_header(
@ -115,4 +109,34 @@ def main(_):
if __name__ == "__main__": if __name__ == "__main__":
tf.app.run() parser = argparse.ArgumentParser()
parser.register("type", "bool", lambda v: v.lower() == "true")
parser.add_argument(
"--data_dir",
type=str,
default="/tmp/iris_data",
help="Directory to save the training and test data in.")
parser.add_argument(
"--model_dir",
type=str,
default="",
help="Directory to save the trained model in.")
parser.add_argument(
"--train_steps",
type=int,
default=10,
help="Number of steps to run trainer.")
parser.add_argument(
"--ui_type",
type=str,
default="curses"
help="Command-line user interface type (curses | readline)")
parser.add_argument(
"--debug",
type="bool",
nargs="?",
const=True,
default=False,
help="Use debugger to track down bad values during training")
FLAGS, unparsed = parser.parse_known_args()
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)

View File

@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import argparse
import collections import collections
import os.path import os.path
import sys import sys
@ -31,12 +32,7 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import docs from tensorflow.python.framework import docs
from tensorflow.python.framework import framework_lib from tensorflow.python.framework import framework_lib
FLAGS = None
tf.flags.DEFINE_string("out_dir", None,
"Directory to which docs should be written.")
tf.flags.DEFINE_boolean("print_hidden_regex", False,
"Dump a regular expression matching any hidden symbol")
FLAGS = tf.flags.FLAGS
PREFIX_TEXT = """ PREFIX_TEXT = """
@ -309,4 +305,19 @@ def main(unused_argv):
if __name__ == "__main__": if __name__ == "__main__":
tf.app.run() parser = argparse.ArgumentParser()
parser.register("type", "bool", lambda v: v.lower() == "true")
parser.add_argument(
"--out_dir",
type=str,
default=None,
help="Directory to which docs should be written.")
parser.add_argument(
"--print_hidden_regex",
type="bool",
nargs="?",
const=True,
default=False,
help="Dump a regular expression matching any hidden symbol")
FLAGS, unparsed = parser.parse_known_args()
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)

View File

@ -31,7 +31,10 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import argparse
import os import os
import sys
import tensorflow as tf import tensorflow as tf
from tensorflow.core.protobuf import meta_graph_pb2 from tensorflow.core.protobuf import meta_graph_pb2
@ -42,13 +45,7 @@ from tensorflow.python.saved_model import signature_def_utils
from tensorflow.python.saved_model import tag_constants from tensorflow.python.saved_model import tag_constants
from tensorflow.python.util import compat from tensorflow.python.util import compat
tf.app.flags.DEFINE_string("output_dir", "/tmp/saved_model_half_plus_two", FLAGS = None
"Directory where to ouput SavedModel.")
tf.app.flags.DEFINE_string("output_dir_pbtxt",
"/tmp/saved_model_half_plus_two_pbtxt",
"Directory where to ouput the text format of "
"SavedModel.")
FLAGS = tf.flags.FLAGS
def _write_assets(assets_directory, assets_filename): def _write_assets(assets_directory, assets_filename):
@ -172,4 +169,16 @@ def main(_):
if __name__ == "__main__": if __name__ == "__main__":
tf.app.run() parser = argparse.ArgumentParser()
parser.add_argument(
"--output_dir",
type=str,
default="/tmp/saved_model_half_plus_two",
help="Directory where to ouput SavedModel.")
parser.add_argument(
"--output_dir_pbtxt",
type=str,
default="/tmp/saved_model_half_plus_two_pbtxt",
help="Directory where to ouput the text format of SavedModel.")
FLAGS, unparsed = parser.parse_known_args()
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)

View File

@ -37,6 +37,9 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import argparse
import sys
from google.protobuf import text_format from google.protobuf import text_format
from tensorflow.core.framework import graph_pb2 from tensorflow.core.framework import graph_pb2
@ -45,37 +48,23 @@ from tensorflow.python.client import session
from tensorflow.python.framework import graph_util from tensorflow.python.framework import graph_util
from tensorflow.python.framework import importer from tensorflow.python.framework import importer
from tensorflow.python.platform import app from tensorflow.python.platform import app
from tensorflow.python.platform import flags
from tensorflow.python.platform import gfile from tensorflow.python.platform import gfile
from tensorflow.python.training import saver as saver_lib from tensorflow.python.training import saver as saver_lib
FLAGS = flags.FLAGS FLAGS = None
flags.DEFINE_string("input_graph", "",
"""TensorFlow 'GraphDef' file to load.""")
flags.DEFINE_string("input_saver", "", """TensorFlow saver file to load.""")
flags.DEFINE_string("input_checkpoint", "",
"""TensorFlow variables file to load.""")
flags.DEFINE_string("output_graph", "", """Output 'GraphDef' file name.""")
flags.DEFINE_boolean("input_binary", False,
"""Whether the input files are in binary format.""")
flags.DEFINE_string("output_node_names", "",
"""The name of the output nodes, comma separated.""")
flags.DEFINE_string("restore_op_name", "save/restore_all",
"""The name of the master restore operator.""")
flags.DEFINE_string("filename_tensor_name", "save/Const:0",
"""The name of the tensor holding the save path.""")
flags.DEFINE_boolean("clear_devices", True,
"""Whether to remove device specifications.""")
flags.DEFINE_string("initializer_nodes", "", "comma separated list of "
"initializer nodes to run before freezing.")
flags.DEFINE_string("variable_names_blacklist", "", "comma separated "
"list of variables to skip converting to constants ")
def freeze_graph(input_graph, input_saver, input_binary, input_checkpoint, def freeze_graph(input_graph,
output_node_names, restore_op_name, filename_tensor_name, input_saver,
output_graph, clear_devices, initializer_nodes): input_binary,
input_checkpoint,
output_node_names,
restore_op_name,
filename_tensor_name,
output_graph,
clear_devices,
initializer_nodes,
variable_names_blacklist=""):
"""Converts all variables in a graph and checkpoint into constants.""" """Converts all variables in a graph and checkpoint into constants."""
if not gfile.Exists(input_graph): if not gfile.Exists(input_graph):
@ -124,8 +113,8 @@ def freeze_graph(input_graph, input_saver, input_binary, input_checkpoint,
if initializer_nodes: if initializer_nodes:
sess.run(initializer_nodes) sess.run(initializer_nodes)
variable_names_blacklist = (FLAGS.variable_names_blacklist.split(",") if variable_names_blacklist = (variable_names_blacklist.split(",") if
FLAGS.variable_names_blacklist else None) variable_names_blacklist else None)
output_graph_def = graph_util.convert_variables_to_constants( output_graph_def = graph_util.convert_variables_to_constants(
sess, sess,
input_graph_def, input_graph_def,
@ -141,8 +130,73 @@ def main(unused_args):
freeze_graph(FLAGS.input_graph, FLAGS.input_saver, FLAGS.input_binary, freeze_graph(FLAGS.input_graph, FLAGS.input_saver, FLAGS.input_binary,
FLAGS.input_checkpoint, FLAGS.output_node_names, FLAGS.input_checkpoint, FLAGS.output_node_names,
FLAGS.restore_op_name, FLAGS.filename_tensor_name, FLAGS.restore_op_name, FLAGS.filename_tensor_name,
FLAGS.output_graph, FLAGS.clear_devices, FLAGS.initializer_nodes) FLAGS.output_graph, FLAGS.clear_devices, FLAGS.initializer_nodes,
FLAGS.variable_names_blacklist)
if __name__ == "__main__": if __name__ == "__main__":
app.run() parser = argparse.ArgumentParser()
parser.register("type", "bool", lambda v: v.lower() == "true")
parser.add_argument(
"--input_graph",
type=str,
default="",
help="TensorFlow \'GraphDef\' file to load.")
parser.add_argument(
"--input_saver",
type=str,
default="",
help="TensorFlow saver file to load.")
parser.add_argument(
"--input_checkpoint",
type=str,
default="",
help="TensorFlow variables file to load.")
parser.add_argument(
"--output_graph",
type=str,
default="",
help="Output \'GraphDef\' file name.")
parser.add_argument(
"--input_binary",
nargs="?",
const=True,
type="bool",
default=False,
help="Whether the input files are in binary format.")
parser.add_argument(
"--output_node_names",
type=str,
default="",
help="The name of the output nodes, comma separated.")
parser.add_argument(
"--restore_op_name",
type=str,
default="save/restore_all",
help="The name of the master restore operator.")
parser.add_argument(
"--filename_tensor_name",
type=str,
default="save/Const:0",
help="The name of the tensor holding the save path.")
parser.add_argument(
"--clear_devices",
nargs="?",
const=True,
type="bool",
default=True,
help="Whether to remove device specifications.")
parser.add_argument(
"--initializer_nodes",
type=str,
default="",
help="comma separated list of initializer nodes to run before freezing.")
parser.add_argument(
"--variable_names_blacklist",
type=str,
default="",
help="""\
comma separated list of variables to skip converting to constants\
""")
FLAGS, unparsed = parser.parse_known_args()
app.run(main=main, argv=[sys.argv[0]] + unparsed)

View File

@ -17,20 +17,16 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import argparse
import sys import sys
from tensorflow.python import pywrap_tensorflow from tensorflow.python import pywrap_tensorflow
from tensorflow.python.platform import app from tensorflow.python.platform import app
from tensorflow.python.platform import flags
FLAGS = flags.FLAGS FLAGS = None
flags.DEFINE_string("file_name", "", "Checkpoint filename")
flags.DEFINE_string("tensor_name", "", "Name of the tensor to inspect")
flags.DEFINE_bool("all_tensors", "False",
"If True, print the values of all the tensors.")
def print_tensors_in_checkpoint_file(file_name, tensor_name): def print_tensors_in_checkpoint_file(file_name, tensor_name, all_tensors):
"""Prints tensors in a checkpoint file. """Prints tensors in a checkpoint file.
If no `tensor_name` is provided, prints the tensor names and shapes If no `tensor_name` is provided, prints the tensor names and shapes
@ -41,10 +37,11 @@ def print_tensors_in_checkpoint_file(file_name, tensor_name):
Args: Args:
file_name: Name of the checkpoint file. file_name: Name of the checkpoint file.
tensor_name: Name of the tensor in the checkpoint file to print. tensor_name: Name of the tensor in the checkpoint file to print.
all_tensors: Boolean indicating whether to print all tensors.
""" """
try: try:
reader = pywrap_tensorflow.NewCheckpointReader(file_name) reader = pywrap_tensorflow.NewCheckpointReader(file_name)
if FLAGS.all_tensors: if all_tensors:
var_to_shape_map = reader.get_variable_to_shape_map() var_to_shape_map = reader.get_variable_to_shape_map()
for key in var_to_shape_map: for key in var_to_shape_map:
print("tensor_name: ", key) print("tensor_name: ", key)
@ -67,8 +64,26 @@ def main(unused_argv):
"[--tensor_name=tensor_to_print]") "[--tensor_name=tensor_to_print]")
sys.exit(1) sys.exit(1)
else: else:
print_tensors_in_checkpoint_file(FLAGS.file_name, FLAGS.tensor_name) print_tensors_in_checkpoint_file(FLAGS.file_name, FLAGS.tensor_name,
FLAGS.all_tensors)
if __name__ == "__main__": if __name__ == "__main__":
app.run() parser = argparse.ArgumentParser()
parser.register("type", "bool", lambda v: v.lower() == "true")
parser.add_argument(
"--file_name", type=str, default="", help="Checkpoint filename")
parser.add_argument(
"--tensor_name",
type=str,
default="",
help="Name of the tensor to inspect")
parser.add_argument(
"--all_tensors",
nargs="?",
const=True,
type="bool",
default=False,
help="If True, print the values of all the tensors.")
FLAGS, unparsed = parser.parse_known_args()
app.run(main=main, argv=[sys.argv[0]] + unparsed)

View File

@ -55,7 +55,9 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import argparse
import os import os
import sys
from google.protobuf import text_format from google.protobuf import text_format
@ -63,22 +65,10 @@ from tensorflow.core.framework import graph_pb2
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import graph_io from tensorflow.python.framework import graph_io
from tensorflow.python.platform import app from tensorflow.python.platform import app
from tensorflow.python.platform import flags as flags_lib
from tensorflow.python.platform import gfile from tensorflow.python.platform import gfile
from tensorflow.python.tools import optimize_for_inference_lib from tensorflow.python.tools import optimize_for_inference_lib
flags = flags_lib FLAGS = None
FLAGS = flags.FLAGS
flags.DEFINE_string("input", "", """TensorFlow 'GraphDef' file to load.""")
flags.DEFINE_string("output", "", """File to save the output graph to.""")
flags.DEFINE_string("input_names", "", """Input node names, comma separated.""")
flags.DEFINE_string("output_names", "",
"""Output node names, comma separated.""")
flags.DEFINE_boolean("frozen_graph", True,
"""If true, the input graph is a binary frozen GraphDef
file; if false, it is a text GraphDef proto file.""")
flags.DEFINE_integer("placeholder_type_enum", dtypes.float32.as_datatype_enum,
"""The AttrValue enum to use for placeholders.""")
def main(unused_args): def main(unused_args):
@ -110,4 +100,42 @@ def main(unused_args):
if __name__ == "__main__": if __name__ == "__main__":
app.run() parser = argparse.ArgumentParser()
parser.register("type", "bool", lambda v: v.lower() == "true")
parser.add_argument(
"--input",
type=str,
default="",
help="TensorFlow \'GraphDef\' file to load.")
parser.add_argument(
"--output",
type=str,
default="",
help="File to save the output graph to.")
parser.add_argument(
"--input_names",
type=str,
default="",
help="Input node names, comma separated.")
parser.add_argument(
"--output_names",
type=str,
default="",
help="Output node names, comma separated.")
parser.add_argument(
"--frozen_graph",
nargs="?",
const=True,
type="bool",
default=True,
help="""\
If true, the input graph is a binary frozen GraphDef
file; if false, it is a text GraphDef proto file.\
""")
parser.add_argument(
"--placeholder_type_enum",
type=int,
default=dtypes.float32.as_datatype_enum,
help="The AttrValue enum to use for placeholders.")
FLAGS, unparsed = parser.parse_known_args()
app.run(main=main, argv=[sys.argv[0]] + unparsed)

View File

@ -41,25 +41,14 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import argparse
import sys
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.platform import app from tensorflow.python.platform import app
from tensorflow.python.platform import flags
from tensorflow.python.tools import strip_unused_lib from tensorflow.python.tools import strip_unused_lib
FLAGS = flags.FLAGS FLAGS = None
flags.DEFINE_string("input_graph", "",
"""TensorFlow 'GraphDef' file to load.""")
flags.DEFINE_boolean("input_binary", False,
"""Whether the input files are in binary format.""")
flags.DEFINE_string("output_graph", "", """Output 'GraphDef' file name.""")
flags.DEFINE_boolean("output_binary", True,
"""Whether to write a binary format graph.""")
flags.DEFINE_string("input_node_names", "",
"""The name of the input nodes, comma separated.""")
flags.DEFINE_string("output_node_names", "",
"""The name of the output nodes, comma separated.""")
flags.DEFINE_integer("placeholder_type_enum", dtypes.float32.as_datatype_enum,
"""The AttrValue enum to use for placeholders.""")
def main(unused_args): def main(unused_args):
@ -72,5 +61,47 @@ def main(unused_args):
FLAGS.placeholder_type_enum) FLAGS.placeholder_type_enum)
if __name__ == "__main__": if __name__ == '__main__':
app.run() parser = argparse.ArgumentParser()
parser.register('type', 'bool', lambda v: v.lower() == 'true')
parser.add_argument(
'--input_graph',
type=str,
default='',
help='TensorFlow \'GraphDef\' file to load.')
parser.add_argument(
'--input_binary',
nargs='?',
const=True,
type='bool',
default=False,
help='Whether the input files are in binary format.')
parser.add_argument(
'--output_graph',
type=str,
default='',
help='Output \'GraphDef\' file name.')
parser.add_argument(
'--output_binary',
nargs='?',
const=True,
type='bool',
default=True,
help='Whether to write a binary format graph.')
parser.add_argument(
'--input_node_names',
type=str,
default='',
help='The name of the input nodes, comma separated.')
parser.add_argument(
'--output_node_names',
type=str,
default='',
help='The name of the output nodes, comma separated.')
parser.add_argument(
'--placeholder_type_enum',
type=int,
default=dtypes.float32.as_datatype_enum,
help='The AttrValue enum to use for placeholders.')
FLAGS, unparsed = parser.parse_known_args()
app.run(main=main, argv=[sys.argv[0]] + unparsed)