Convert tf.flags usage to argparse. Move use of FLAGS globals into main() only.
Change: 143799731
This commit is contained in:
parent
4b3d59a771
commit
2b351f224d
tensorflow/python
@ -31,27 +31,21 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import socket
|
||||
import sys
|
||||
|
||||
from tensorflow.python.platform import app
|
||||
|
||||
# pylint: disable=g-import-not-at-top
|
||||
# Official recommended way of turning on fast protocol buffers as of 10/21/14
|
||||
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "cpp"
|
||||
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION_VERSION"] = "2"
|
||||
|
||||
from tensorflow.python.platform import app
|
||||
from tensorflow.python.platform import flags
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
FLAGS = None
|
||||
|
||||
flags.DEFINE_string(
|
||||
"password", None,
|
||||
"Password to require. If set, the server will allow public access."
|
||||
" Only used if notebook config file does not exist.")
|
||||
|
||||
flags.DEFINE_string("notebook_dir", "experimental/brain/notebooks",
|
||||
"root location where to store notebooks")
|
||||
|
||||
ORIG_ARGV = sys.argv
|
||||
# Main notebook process calls itself with argv[1]="kernel" to start kernel
|
||||
@ -108,6 +102,21 @@ def main(unused_argv):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--password",
|
||||
type=str,
|
||||
default=None,
|
||||
help="""\
|
||||
Password to require. If set, the server will allow public access. Only
|
||||
used if notebook config file does not exist.\
|
||||
""")
|
||||
parser.add_argument(
|
||||
"--notebook_dir",
|
||||
type=str,
|
||||
default="experimental/brain/notebooks",
|
||||
help="root location where to store notebooks")
|
||||
|
||||
# When the user starts the main notebook process, we don't touch sys.argv.
|
||||
# When the main process launches kernel subprocesses, it writes all flags
|
||||
# to a tmpfile and sets --flagfile to that tmpfile, so for kernel
|
||||
@ -118,4 +127,6 @@ if __name__ == "__main__":
|
||||
# Drop everything except --flagfile.
|
||||
sys.argv = ([sys.argv[0]] +
|
||||
[x for x in sys.argv[1:] if x.startswith("--flagfile")])
|
||||
app.run()
|
||||
|
||||
FLAGS, unparsed = parser.parse_known_args()
|
||||
app.run(main=main, argv=[sys.argv[0]] + unparsed)
|
||||
|
@ -17,23 +17,14 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import argparse
|
||||
import sys
|
||||
|
||||
# Google-internal import(s).
|
||||
|
||||
from tensorflow.python.debug import debug_data
|
||||
from tensorflow.python.debug.cli import analyzer_cli
|
||||
from tensorflow.python.platform import app
|
||||
from tensorflow.python.platform import flags
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
flags.DEFINE_string("dump_dir", "", "tfdbg dump directory to load")
|
||||
flags.DEFINE_string("ui_type", "curses",
|
||||
"Command-line user interface type (curses | readline)")
|
||||
flags.DEFINE_boolean(
|
||||
"log_usage", True, "Whether the usage of this tool is to be logged")
|
||||
flags.DEFINE_boolean(
|
||||
"validate_graph", True,
|
||||
"Whether the dumped tensors will be validated against the GraphDefs")
|
||||
|
||||
|
||||
def main(_):
|
||||
@ -58,4 +49,30 @@ def main(_):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
app.run()
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.register("type", "bool", lambda v: v.lower() == "true")
|
||||
parser.add_argument(
|
||||
"--dump_dir", type=str, default="", help="tfdbg dump directory to load")
|
||||
parser.add_argument(
|
||||
"--log_usage",
|
||||
type="bool",
|
||||
nargs="?",
|
||||
const=True,
|
||||
default=True,
|
||||
help="Whether the usage of this tool is to be logged")
|
||||
parser.add_argument(
|
||||
"--ui_type",
|
||||
type=str,
|
||||
default="curses"
|
||||
help="Command-line user interface type (curses | readline)")
|
||||
parser.add_argument(
|
||||
"--validate_graph",
|
||||
nargs="?",
|
||||
const=True,
|
||||
type="bool",
|
||||
default=True,
|
||||
help="""\
|
||||
Whether the dumped tensors will be validated against the GraphDefs\
|
||||
""")
|
||||
FLAGS, unparsed = parser.parse_known_args()
|
||||
app.run(main=main, argv=[sys.argv[0]] + unparsed)
|
||||
|
@ -17,20 +17,14 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import argparse
|
||||
import sys
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
from tensorflow.python import debug as tf_debug
|
||||
|
||||
flags = tf.app.flags
|
||||
FLAGS = flags.FLAGS
|
||||
flags.DEFINE_string("error", "shape_mismatch", "Type of the error to generate "
|
||||
"(shape_mismatch | uninitialized_variable | no_error).")
|
||||
flags.DEFINE_string("ui_type", "curses",
|
||||
"Command-line user interface type (curses | readline)")
|
||||
flags.DEFINE_boolean("debug", False,
|
||||
"Use debugger to track down bad values during training")
|
||||
|
||||
|
||||
def main(_):
|
||||
sess = tf.Session()
|
||||
@ -60,4 +54,27 @@ def main(_):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
tf.app.run()
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.register("type", "bool", lambda v: v.lower() == "true")
|
||||
parser.add_argument(
|
||||
"--error",
|
||||
type=str,
|
||||
default="shape_mismatch",
|
||||
help="""\
|
||||
Type of the error to generate (shape_mismatch | uninitialized_variable |
|
||||
no_error).\
|
||||
""")
|
||||
parser.add_argument(
|
||||
"--ui_type",
|
||||
type=str,
|
||||
default="curses"
|
||||
help="Command-line user interface type (curses | readline)")
|
||||
parser.add_argument(
|
||||
"--debug",
|
||||
type="bool",
|
||||
nargs="?",
|
||||
const=True,
|
||||
default=False,
|
||||
help="Use debugger to track down bad values during training")
|
||||
FLAGS, unparsed = parser.parse_known_args()
|
||||
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
|
||||
|
@ -17,19 +17,16 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import argparse
|
||||
import sys
|
||||
|
||||
import numpy as np
|
||||
from six.moves import xrange # pylint: disable=redefined-builtin
|
||||
import tensorflow as tf
|
||||
|
||||
from tensorflow.python import debug as tf_debug
|
||||
|
||||
flags = tf.app.flags
|
||||
FLAGS = flags.FLAGS
|
||||
flags.DEFINE_integer("tensor_size", 30,
|
||||
"Size of tensor. E.g., if the value is 30, the tensors "
|
||||
"will have shape [30, 30].")
|
||||
flags.DEFINE_integer("length", 20,
|
||||
"Length of the fibonacci sequence to compute.")
|
||||
FLAGS = None
|
||||
|
||||
|
||||
def main(_):
|
||||
@ -54,4 +51,20 @@ def main(_):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
tf.app.run()
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.register("type", "bool", lambda v: v.lower() == "true")
|
||||
parser.add_argument(
|
||||
"--tensor_size",
|
||||
type=int,
|
||||
default=30,
|
||||
help="""\
|
||||
Size of tensor. E.g., if the value is 30, the tensors will have shape
|
||||
[30, 30].\
|
||||
""")
|
||||
parser.add_argument(
|
||||
"--length",
|
||||
type=int,
|
||||
default=20,
|
||||
help="Length of the fibonacci sequence to compute.")
|
||||
FLAGS, unparsed = parser.parse_known_args()
|
||||
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
|
||||
|
@ -24,22 +24,14 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import argparse
|
||||
import sys
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
from tensorflow.examples.tutorials.mnist import input_data
|
||||
from tensorflow.python import debug as tf_debug
|
||||
|
||||
flags = tf.app.flags
|
||||
FLAGS = flags.FLAGS
|
||||
flags.DEFINE_integer("max_steps", 10, "Number of steps to run trainer.")
|
||||
flags.DEFINE_integer("train_batch_size", 100,
|
||||
"Batch size used during training.")
|
||||
flags.DEFINE_float("learning_rate", 0.025, "Initial learning rate.")
|
||||
flags.DEFINE_string("data_dir", "/tmp/mnist_data", "Directory for storing data")
|
||||
flags.DEFINE_string("ui_type", "curses",
|
||||
"Command-line user interface type (curses | readline)")
|
||||
flags.DEFINE_boolean("debug", False,
|
||||
"Use debugger to track down bad values during training")
|
||||
|
||||
IMAGE_SIZE = 28
|
||||
HIDDEN_SIZE = 500
|
||||
@ -137,4 +129,39 @@ def main(_):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
tf.app.run()
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.register("type", "bool", lambda v: v.lower() == "true")
|
||||
parser.add_argument(
|
||||
"--max_steps",
|
||||
type=int,
|
||||
default=10,
|
||||
help="Number of steps to run trainer.")
|
||||
parser.add_argument(
|
||||
"--train_batch_size",
|
||||
type=int,
|
||||
default=100,
|
||||
help="Batch size used during training.")
|
||||
parser.add_argument(
|
||||
"--learning_rate",
|
||||
type=float,
|
||||
default=0.025,
|
||||
help="Initial learning rate.")
|
||||
parser.add_argument(
|
||||
"--data_dir",
|
||||
type=str,
|
||||
default="/tmp/mnist_data",
|
||||
help="Directory for storing data")
|
||||
parser.add_argument(
|
||||
"--ui_type",
|
||||
type=str,
|
||||
default="curses"
|
||||
help="Command-line user interface type (curses | readline)")
|
||||
parser.add_argument(
|
||||
"--debug",
|
||||
type="bool",
|
||||
nargs="?",
|
||||
const=True,
|
||||
default=False,
|
||||
help="Use debugger to track down bad values during training")
|
||||
FLAGS, unparsed = parser.parse_known_args()
|
||||
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
|
||||
|
@ -17,7 +17,9 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
|
||||
import numpy as np
|
||||
@ -26,33 +28,26 @@ import tensorflow as tf
|
||||
|
||||
from tensorflow.python import debug as tf_debug
|
||||
|
||||
flags = tf.app.flags
|
||||
FLAGS = flags.FLAGS
|
||||
flags.DEFINE_string("data_dir", "/tmp/iris_data",
|
||||
"Directory to save the training and test data in.")
|
||||
flags.DEFINE_string("model_dir", "", "Directory to save the trained model in.")
|
||||
flags.DEFINE_integer("train_steps", 10, "Number of steps to run trainer.")
|
||||
flags.DEFINE_string("ui_type", "curses",
|
||||
"Command-line user interface type (curses | readline)")
|
||||
flags.DEFINE_boolean("debug", False,
|
||||
"Use debugger to track down bad values during training")
|
||||
|
||||
# URLs to download data sets from, if necessary.
|
||||
IRIS_TRAINING_DATA_URL = "https://raw.githubusercontent.com/tensorflow/tensorflow/master/tensorflow/examples/tutorials/monitors/iris_training.csv"
|
||||
IRIS_TEST_DATA_URL = "https://raw.githubusercontent.com/tensorflow/tensorflow/master/tensorflow/examples/tutorials/monitors/iris_test.csv"
|
||||
|
||||
|
||||
def maybe_download_data():
|
||||
def maybe_download_data(data_dir):
|
||||
"""Download data sets if necessary.
|
||||
|
||||
Args:
|
||||
data_dir: Path to where data should be downloaded.
|
||||
|
||||
Returns:
|
||||
Paths to the training and test data files.
|
||||
"""
|
||||
|
||||
if not os.path.isdir(FLAGS.data_dir):
|
||||
os.makedirs(FLAGS.data_dir)
|
||||
if not os.path.isdir(data_dir):
|
||||
os.makedirs(data_dir)
|
||||
|
||||
training_data_path = os.path.join(FLAGS.data_dir,
|
||||
training_data_path = os.path.join(data_dir,
|
||||
os.path.basename(IRIS_TRAINING_DATA_URL))
|
||||
if not os.path.isfile(training_data_path):
|
||||
train_file = open(training_data_path, "wt")
|
||||
@ -61,8 +56,7 @@ def maybe_download_data():
|
||||
|
||||
print("Training data are downloaded to %s" % train_file.name)
|
||||
|
||||
test_data_path = os.path.join(FLAGS.data_dir,
|
||||
os.path.basename(IRIS_TEST_DATA_URL))
|
||||
test_data_path = os.path.join(data_dir, os.path.basename(IRIS_TEST_DATA_URL))
|
||||
if not os.path.isfile(test_data_path):
|
||||
test_file = open(test_data_path, "wt")
|
||||
urllib.request.urlretrieve(IRIS_TEST_DATA_URL, test_file.name)
|
||||
@ -74,7 +68,7 @@ def maybe_download_data():
|
||||
|
||||
|
||||
def main(_):
|
||||
training_data_path, test_data_path = maybe_download_data()
|
||||
training_data_path, test_data_path = maybe_download_data(FLAGS.data_dir)
|
||||
|
||||
# Load datasets.
|
||||
training_set = tf.contrib.learn.datasets.base.load_csv_with_header(
|
||||
@ -115,4 +109,34 @@ def main(_):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
tf.app.run()
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.register("type", "bool", lambda v: v.lower() == "true")
|
||||
parser.add_argument(
|
||||
"--data_dir",
|
||||
type=str,
|
||||
default="/tmp/iris_data",
|
||||
help="Directory to save the training and test data in.")
|
||||
parser.add_argument(
|
||||
"--model_dir",
|
||||
type=str,
|
||||
default="",
|
||||
help="Directory to save the trained model in.")
|
||||
parser.add_argument(
|
||||
"--train_steps",
|
||||
type=int,
|
||||
default=10,
|
||||
help="Number of steps to run trainer.")
|
||||
parser.add_argument(
|
||||
"--ui_type",
|
||||
type=str,
|
||||
default="curses"
|
||||
help="Command-line user interface type (curses | readline)")
|
||||
parser.add_argument(
|
||||
"--debug",
|
||||
type="bool",
|
||||
nargs="?",
|
||||
const=True,
|
||||
default=False,
|
||||
help="Use debugger to track down bad values during training")
|
||||
FLAGS, unparsed = parser.parse_known_args()
|
||||
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
|
||||
|
@ -18,6 +18,7 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import argparse
|
||||
import collections
|
||||
import os.path
|
||||
import sys
|
||||
@ -31,12 +32,7 @@ from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import docs
|
||||
from tensorflow.python.framework import framework_lib
|
||||
|
||||
|
||||
tf.flags.DEFINE_string("out_dir", None,
|
||||
"Directory to which docs should be written.")
|
||||
tf.flags.DEFINE_boolean("print_hidden_regex", False,
|
||||
"Dump a regular expression matching any hidden symbol")
|
||||
FLAGS = tf.flags.FLAGS
|
||||
FLAGS = None
|
||||
|
||||
|
||||
PREFIX_TEXT = """
|
||||
@ -309,4 +305,19 @@ def main(unused_argv):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
tf.app.run()
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.register("type", "bool", lambda v: v.lower() == "true")
|
||||
parser.add_argument(
|
||||
"--out_dir",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Directory to which docs should be written.")
|
||||
parser.add_argument(
|
||||
"--print_hidden_regex",
|
||||
type="bool",
|
||||
nargs="?",
|
||||
const=True,
|
||||
default=False,
|
||||
help="Dump a regular expression matching any hidden symbol")
|
||||
FLAGS, unparsed = parser.parse_known_args()
|
||||
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
|
||||
|
@ -31,7 +31,10 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
from tensorflow.core.protobuf import meta_graph_pb2
|
||||
@ -42,13 +45,7 @@ from tensorflow.python.saved_model import signature_def_utils
|
||||
from tensorflow.python.saved_model import tag_constants
|
||||
from tensorflow.python.util import compat
|
||||
|
||||
tf.app.flags.DEFINE_string("output_dir", "/tmp/saved_model_half_plus_two",
|
||||
"Directory where to ouput SavedModel.")
|
||||
tf.app.flags.DEFINE_string("output_dir_pbtxt",
|
||||
"/tmp/saved_model_half_plus_two_pbtxt",
|
||||
"Directory where to ouput the text format of "
|
||||
"SavedModel.")
|
||||
FLAGS = tf.flags.FLAGS
|
||||
FLAGS = None
|
||||
|
||||
|
||||
def _write_assets(assets_directory, assets_filename):
|
||||
@ -172,4 +169,16 @@ def main(_):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
tf.app.run()
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--output_dir",
|
||||
type=str,
|
||||
default="/tmp/saved_model_half_plus_two",
|
||||
help="Directory where to ouput SavedModel.")
|
||||
parser.add_argument(
|
||||
"--output_dir_pbtxt",
|
||||
type=str,
|
||||
default="/tmp/saved_model_half_plus_two_pbtxt",
|
||||
help="Directory where to ouput the text format of SavedModel.")
|
||||
FLAGS, unparsed = parser.parse_known_args()
|
||||
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
|
||||
|
@ -37,6 +37,9 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import argparse
|
||||
import sys
|
||||
|
||||
from google.protobuf import text_format
|
||||
|
||||
from tensorflow.core.framework import graph_pb2
|
||||
@ -45,37 +48,23 @@ from tensorflow.python.client import session
|
||||
from tensorflow.python.framework import graph_util
|
||||
from tensorflow.python.framework import importer
|
||||
from tensorflow.python.platform import app
|
||||
from tensorflow.python.platform import flags
|
||||
from tensorflow.python.platform import gfile
|
||||
from tensorflow.python.training import saver as saver_lib
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
flags.DEFINE_string("input_graph", "",
|
||||
"""TensorFlow 'GraphDef' file to load.""")
|
||||
flags.DEFINE_string("input_saver", "", """TensorFlow saver file to load.""")
|
||||
flags.DEFINE_string("input_checkpoint", "",
|
||||
"""TensorFlow variables file to load.""")
|
||||
flags.DEFINE_string("output_graph", "", """Output 'GraphDef' file name.""")
|
||||
flags.DEFINE_boolean("input_binary", False,
|
||||
"""Whether the input files are in binary format.""")
|
||||
flags.DEFINE_string("output_node_names", "",
|
||||
"""The name of the output nodes, comma separated.""")
|
||||
flags.DEFINE_string("restore_op_name", "save/restore_all",
|
||||
"""The name of the master restore operator.""")
|
||||
flags.DEFINE_string("filename_tensor_name", "save/Const:0",
|
||||
"""The name of the tensor holding the save path.""")
|
||||
flags.DEFINE_boolean("clear_devices", True,
|
||||
"""Whether to remove device specifications.""")
|
||||
flags.DEFINE_string("initializer_nodes", "", "comma separated list of "
|
||||
"initializer nodes to run before freezing.")
|
||||
flags.DEFINE_string("variable_names_blacklist", "", "comma separated "
|
||||
"list of variables to skip converting to constants ")
|
||||
FLAGS = None
|
||||
|
||||
|
||||
def freeze_graph(input_graph, input_saver, input_binary, input_checkpoint,
|
||||
output_node_names, restore_op_name, filename_tensor_name,
|
||||
output_graph, clear_devices, initializer_nodes):
|
||||
def freeze_graph(input_graph,
|
||||
input_saver,
|
||||
input_binary,
|
||||
input_checkpoint,
|
||||
output_node_names,
|
||||
restore_op_name,
|
||||
filename_tensor_name,
|
||||
output_graph,
|
||||
clear_devices,
|
||||
initializer_nodes,
|
||||
variable_names_blacklist=""):
|
||||
"""Converts all variables in a graph and checkpoint into constants."""
|
||||
|
||||
if not gfile.Exists(input_graph):
|
||||
@ -124,8 +113,8 @@ def freeze_graph(input_graph, input_saver, input_binary, input_checkpoint,
|
||||
if initializer_nodes:
|
||||
sess.run(initializer_nodes)
|
||||
|
||||
variable_names_blacklist = (FLAGS.variable_names_blacklist.split(",") if
|
||||
FLAGS.variable_names_blacklist else None)
|
||||
variable_names_blacklist = (variable_names_blacklist.split(",") if
|
||||
variable_names_blacklist else None)
|
||||
output_graph_def = graph_util.convert_variables_to_constants(
|
||||
sess,
|
||||
input_graph_def,
|
||||
@ -141,8 +130,73 @@ def main(unused_args):
|
||||
freeze_graph(FLAGS.input_graph, FLAGS.input_saver, FLAGS.input_binary,
|
||||
FLAGS.input_checkpoint, FLAGS.output_node_names,
|
||||
FLAGS.restore_op_name, FLAGS.filename_tensor_name,
|
||||
FLAGS.output_graph, FLAGS.clear_devices, FLAGS.initializer_nodes)
|
||||
FLAGS.output_graph, FLAGS.clear_devices, FLAGS.initializer_nodes,
|
||||
FLAGS.variable_names_blacklist)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
app.run()
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.register("type", "bool", lambda v: v.lower() == "true")
|
||||
parser.add_argument(
|
||||
"--input_graph",
|
||||
type=str,
|
||||
default="",
|
||||
help="TensorFlow \'GraphDef\' file to load.")
|
||||
parser.add_argument(
|
||||
"--input_saver",
|
||||
type=str,
|
||||
default="",
|
||||
help="TensorFlow saver file to load.")
|
||||
parser.add_argument(
|
||||
"--input_checkpoint",
|
||||
type=str,
|
||||
default="",
|
||||
help="TensorFlow variables file to load.")
|
||||
parser.add_argument(
|
||||
"--output_graph",
|
||||
type=str,
|
||||
default="",
|
||||
help="Output \'GraphDef\' file name.")
|
||||
parser.add_argument(
|
||||
"--input_binary",
|
||||
nargs="?",
|
||||
const=True,
|
||||
type="bool",
|
||||
default=False,
|
||||
help="Whether the input files are in binary format.")
|
||||
parser.add_argument(
|
||||
"--output_node_names",
|
||||
type=str,
|
||||
default="",
|
||||
help="The name of the output nodes, comma separated.")
|
||||
parser.add_argument(
|
||||
"--restore_op_name",
|
||||
type=str,
|
||||
default="save/restore_all",
|
||||
help="The name of the master restore operator.")
|
||||
parser.add_argument(
|
||||
"--filename_tensor_name",
|
||||
type=str,
|
||||
default="save/Const:0",
|
||||
help="The name of the tensor holding the save path.")
|
||||
parser.add_argument(
|
||||
"--clear_devices",
|
||||
nargs="?",
|
||||
const=True,
|
||||
type="bool",
|
||||
default=True,
|
||||
help="Whether to remove device specifications.")
|
||||
parser.add_argument(
|
||||
"--initializer_nodes",
|
||||
type=str,
|
||||
default="",
|
||||
help="comma separated list of initializer nodes to run before freezing.")
|
||||
parser.add_argument(
|
||||
"--variable_names_blacklist",
|
||||
type=str,
|
||||
default="",
|
||||
help="""\
|
||||
comma separated list of variables to skip converting to constants\
|
||||
""")
|
||||
FLAGS, unparsed = parser.parse_known_args()
|
||||
app.run(main=main, argv=[sys.argv[0]] + unparsed)
|
||||
|
@ -17,20 +17,16 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import argparse
|
||||
import sys
|
||||
|
||||
from tensorflow.python import pywrap_tensorflow
|
||||
from tensorflow.python.platform import app
|
||||
from tensorflow.python.platform import flags
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
flags.DEFINE_string("file_name", "", "Checkpoint filename")
|
||||
flags.DEFINE_string("tensor_name", "", "Name of the tensor to inspect")
|
||||
flags.DEFINE_bool("all_tensors", "False",
|
||||
"If True, print the values of all the tensors.")
|
||||
FLAGS = None
|
||||
|
||||
|
||||
def print_tensors_in_checkpoint_file(file_name, tensor_name):
|
||||
def print_tensors_in_checkpoint_file(file_name, tensor_name, all_tensors):
|
||||
"""Prints tensors in a checkpoint file.
|
||||
|
||||
If no `tensor_name` is provided, prints the tensor names and shapes
|
||||
@ -41,10 +37,11 @@ def print_tensors_in_checkpoint_file(file_name, tensor_name):
|
||||
Args:
|
||||
file_name: Name of the checkpoint file.
|
||||
tensor_name: Name of the tensor in the checkpoint file to print.
|
||||
all_tensors: Boolean indicating whether to print all tensors.
|
||||
"""
|
||||
try:
|
||||
reader = pywrap_tensorflow.NewCheckpointReader(file_name)
|
||||
if FLAGS.all_tensors:
|
||||
if all_tensors:
|
||||
var_to_shape_map = reader.get_variable_to_shape_map()
|
||||
for key in var_to_shape_map:
|
||||
print("tensor_name: ", key)
|
||||
@ -67,8 +64,26 @@ def main(unused_argv):
|
||||
"[--tensor_name=tensor_to_print]")
|
||||
sys.exit(1)
|
||||
else:
|
||||
print_tensors_in_checkpoint_file(FLAGS.file_name, FLAGS.tensor_name)
|
||||
print_tensors_in_checkpoint_file(FLAGS.file_name, FLAGS.tensor_name,
|
||||
FLAGS.all_tensors)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
app.run()
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.register("type", "bool", lambda v: v.lower() == "true")
|
||||
parser.add_argument(
|
||||
"--file_name", type=str, default="", help="Checkpoint filename")
|
||||
parser.add_argument(
|
||||
"--tensor_name",
|
||||
type=str,
|
||||
default="",
|
||||
help="Name of the tensor to inspect")
|
||||
parser.add_argument(
|
||||
"--all_tensors",
|
||||
nargs="?",
|
||||
const=True,
|
||||
type="bool",
|
||||
default=False,
|
||||
help="If True, print the values of all the tensors.")
|
||||
FLAGS, unparsed = parser.parse_known_args()
|
||||
app.run(main=main, argv=[sys.argv[0]] + unparsed)
|
||||
|
@ -55,7 +55,9 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
|
||||
from google.protobuf import text_format
|
||||
|
||||
@ -63,22 +65,10 @@ from tensorflow.core.framework import graph_pb2
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import graph_io
|
||||
from tensorflow.python.platform import app
|
||||
from tensorflow.python.platform import flags as flags_lib
|
||||
from tensorflow.python.platform import gfile
|
||||
from tensorflow.python.tools import optimize_for_inference_lib
|
||||
|
||||
flags = flags_lib
|
||||
FLAGS = flags.FLAGS
|
||||
flags.DEFINE_string("input", "", """TensorFlow 'GraphDef' file to load.""")
|
||||
flags.DEFINE_string("output", "", """File to save the output graph to.""")
|
||||
flags.DEFINE_string("input_names", "", """Input node names, comma separated.""")
|
||||
flags.DEFINE_string("output_names", "",
|
||||
"""Output node names, comma separated.""")
|
||||
flags.DEFINE_boolean("frozen_graph", True,
|
||||
"""If true, the input graph is a binary frozen GraphDef
|
||||
file; if false, it is a text GraphDef proto file.""")
|
||||
flags.DEFINE_integer("placeholder_type_enum", dtypes.float32.as_datatype_enum,
|
||||
"""The AttrValue enum to use for placeholders.""")
|
||||
FLAGS = None
|
||||
|
||||
|
||||
def main(unused_args):
|
||||
@ -110,4 +100,42 @@ def main(unused_args):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
app.run()
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.register("type", "bool", lambda v: v.lower() == "true")
|
||||
parser.add_argument(
|
||||
"--input",
|
||||
type=str,
|
||||
default="",
|
||||
help="TensorFlow \'GraphDef\' file to load.")
|
||||
parser.add_argument(
|
||||
"--output",
|
||||
type=str,
|
||||
default="",
|
||||
help="File to save the output graph to.")
|
||||
parser.add_argument(
|
||||
"--input_names",
|
||||
type=str,
|
||||
default="",
|
||||
help="Input node names, comma separated.")
|
||||
parser.add_argument(
|
||||
"--output_names",
|
||||
type=str,
|
||||
default="",
|
||||
help="Output node names, comma separated.")
|
||||
parser.add_argument(
|
||||
"--frozen_graph",
|
||||
nargs="?",
|
||||
const=True,
|
||||
type="bool",
|
||||
default=True,
|
||||
help="""\
|
||||
If true, the input graph is a binary frozen GraphDef
|
||||
file; if false, it is a text GraphDef proto file.\
|
||||
""")
|
||||
parser.add_argument(
|
||||
"--placeholder_type_enum",
|
||||
type=int,
|
||||
default=dtypes.float32.as_datatype_enum,
|
||||
help="The AttrValue enum to use for placeholders.")
|
||||
FLAGS, unparsed = parser.parse_known_args()
|
||||
app.run(main=main, argv=[sys.argv[0]] + unparsed)
|
||||
|
@ -41,25 +41,14 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import argparse
|
||||
import sys
|
||||
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.platform import app
|
||||
from tensorflow.python.platform import flags
|
||||
from tensorflow.python.tools import strip_unused_lib
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
flags.DEFINE_string("input_graph", "",
|
||||
"""TensorFlow 'GraphDef' file to load.""")
|
||||
flags.DEFINE_boolean("input_binary", False,
|
||||
"""Whether the input files are in binary format.""")
|
||||
flags.DEFINE_string("output_graph", "", """Output 'GraphDef' file name.""")
|
||||
flags.DEFINE_boolean("output_binary", True,
|
||||
"""Whether to write a binary format graph.""")
|
||||
flags.DEFINE_string("input_node_names", "",
|
||||
"""The name of the input nodes, comma separated.""")
|
||||
flags.DEFINE_string("output_node_names", "",
|
||||
"""The name of the output nodes, comma separated.""")
|
||||
flags.DEFINE_integer("placeholder_type_enum", dtypes.float32.as_datatype_enum,
|
||||
"""The AttrValue enum to use for placeholders.""")
|
||||
FLAGS = None
|
||||
|
||||
|
||||
def main(unused_args):
|
||||
@ -72,5 +61,47 @@ def main(unused_args):
|
||||
FLAGS.placeholder_type_enum)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
app.run()
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.register('type', 'bool', lambda v: v.lower() == 'true')
|
||||
parser.add_argument(
|
||||
'--input_graph',
|
||||
type=str,
|
||||
default='',
|
||||
help='TensorFlow \'GraphDef\' file to load.')
|
||||
parser.add_argument(
|
||||
'--input_binary',
|
||||
nargs='?',
|
||||
const=True,
|
||||
type='bool',
|
||||
default=False,
|
||||
help='Whether the input files are in binary format.')
|
||||
parser.add_argument(
|
||||
'--output_graph',
|
||||
type=str,
|
||||
default='',
|
||||
help='Output \'GraphDef\' file name.')
|
||||
parser.add_argument(
|
||||
'--output_binary',
|
||||
nargs='?',
|
||||
const=True,
|
||||
type='bool',
|
||||
default=True,
|
||||
help='Whether to write a binary format graph.')
|
||||
parser.add_argument(
|
||||
'--input_node_names',
|
||||
type=str,
|
||||
default='',
|
||||
help='The name of the input nodes, comma separated.')
|
||||
parser.add_argument(
|
||||
'--output_node_names',
|
||||
type=str,
|
||||
default='',
|
||||
help='The name of the output nodes, comma separated.')
|
||||
parser.add_argument(
|
||||
'--placeholder_type_enum',
|
||||
type=int,
|
||||
default=dtypes.float32.as_datatype_enum,
|
||||
help='The AttrValue enum to use for placeholders.')
|
||||
FLAGS, unparsed = parser.parse_known_args()
|
||||
app.run(main=main, argv=[sys.argv[0]] + unparsed)
|
||||
|
Loading…
Reference in New Issue
Block a user