Add a --all_tensor_names option, which is useful if I only want to know all tensor names. It is especially useful in cases whether some of the tensors has huge size. Also update the usage description.
PiperOrigin-RevId: 175074541
This commit is contained in:
parent
12d6b450b2
commit
35febc0cc9
@ -29,7 +29,8 @@ from tensorflow.python.platform import flags
|
|||||||
FLAGS = None
|
FLAGS = None
|
||||||
|
|
||||||
|
|
||||||
def print_tensors_in_checkpoint_file(file_name, tensor_name, all_tensors):
|
def print_tensors_in_checkpoint_file(file_name, tensor_name, all_tensors,
|
||||||
|
all_tensor_names):
|
||||||
"""Prints tensors in a checkpoint file.
|
"""Prints tensors in a checkpoint file.
|
||||||
|
|
||||||
If no `tensor_name` is provided, prints the tensor names and shapes
|
If no `tensor_name` is provided, prints the tensor names and shapes
|
||||||
@ -41,13 +42,15 @@ def print_tensors_in_checkpoint_file(file_name, tensor_name, all_tensors):
|
|||||||
file_name: Name of the checkpoint file.
|
file_name: Name of the checkpoint file.
|
||||||
tensor_name: Name of the tensor in the checkpoint file to print.
|
tensor_name: Name of the tensor in the checkpoint file to print.
|
||||||
all_tensors: Boolean indicating whether to print all tensors.
|
all_tensors: Boolean indicating whether to print all tensors.
|
||||||
|
all_tensor_names: Boolean indicating whether to print all tensor names.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
reader = pywrap_tensorflow.NewCheckpointReader(file_name)
|
reader = pywrap_tensorflow.NewCheckpointReader(file_name)
|
||||||
if all_tensors:
|
if all_tensors or all_tensor_names:
|
||||||
var_to_shape_map = reader.get_variable_to_shape_map()
|
var_to_shape_map = reader.get_variable_to_shape_map()
|
||||||
for key in sorted(var_to_shape_map):
|
for key in sorted(var_to_shape_map):
|
||||||
print("tensor_name: ", key)
|
print("tensor_name: ", key)
|
||||||
|
if all_tensors:
|
||||||
print(reader.get_tensor(key))
|
print(reader.get_tensor(key))
|
||||||
elif not tensor_name:
|
elif not tensor_name:
|
||||||
print(reader.debug_string().decode("utf-8"))
|
print(reader.debug_string().decode("utf-8"))
|
||||||
@ -104,11 +107,14 @@ def parse_numpy_printoption(kv_str):
|
|||||||
def main(unused_argv):
|
def main(unused_argv):
|
||||||
if not FLAGS.file_name:
|
if not FLAGS.file_name:
|
||||||
print("Usage: inspect_checkpoint --file_name=checkpoint_file_name "
|
print("Usage: inspect_checkpoint --file_name=checkpoint_file_name "
|
||||||
"[--tensor_name=tensor_to_print]")
|
"[--tensor_name=tensor_to_print] "
|
||||||
|
"[--all_tensors] "
|
||||||
|
"[--all_tensor_names] "
|
||||||
|
"[--printoptions]")
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
else:
|
else:
|
||||||
print_tensors_in_checkpoint_file(FLAGS.file_name, FLAGS.tensor_name,
|
print_tensors_in_checkpoint_file(FLAGS.file_name, FLAGS.tensor_name,
|
||||||
FLAGS.all_tensors)
|
FLAGS.all_tensors, FLAGS.all_tensor_names)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
@ -130,6 +136,13 @@ if __name__ == "__main__":
|
|||||||
type="bool",
|
type="bool",
|
||||||
default=False,
|
default=False,
|
||||||
help="If True, print the values of all the tensors.")
|
help="If True, print the values of all the tensors.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--all_tensor_names",
|
||||||
|
nargs="?",
|
||||||
|
const=True,
|
||||||
|
type="bool",
|
||||||
|
default=False,
|
||||||
|
help="If True, print the names of all the tensors.")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--printoptions",
|
"--printoptions",
|
||||||
nargs="*",
|
nargs="*",
|
||||||
|
Loading…
Reference in New Issue
Block a user