Merge pull request #33954 from ltn100:feature/LN/fix_import_pb_to_tensorboard

PiperOrigin-RevId: 281037512
Change-Id: I72428d1c5d80029cc9f2bd8810ea5c55c6251546
This commit is contained in:
TensorFlower Gardener 2019-11-18 03:58:11 -08:00
commit 89dffab453

View File

@ -21,13 +21,12 @@ from __future__ import print_function
import argparse import argparse
import sys import sys
from tensorflow.core.framework import graph_pb2
from tensorflow.python.client import session from tensorflow.python.client import session
from tensorflow.python.framework import importer from tensorflow.python.framework import importer
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.platform import app from tensorflow.python.platform import app
from tensorflow.python.platform import gfile
from tensorflow.python.summary import summary from tensorflow.python.summary import summary
from tensorflow.python.tools import saved_model_utils
# Try importing TensorRT ops if available # Try importing TensorRT ops if available
# TODO(aaroey): ideally we should import everything from contrib, but currently # TODO(aaroey): ideally we should import everything from contrib, but currently
@ -40,12 +39,16 @@ except ImportError:
pass pass
# pylint: enable=unused-import,g-import-not-at-top,wildcard-import # pylint: enable=unused-import,g-import-not-at-top,wildcard-import
def import_to_tensorboard(model_dir, log_dir):
def import_to_tensorboard(model_dir, log_dir, tag_set):
"""View an imported protobuf model (`.pb` file) as a graph in Tensorboard. """View an imported protobuf model (`.pb` file) as a graph in Tensorboard.
Args: Args:
model_dir: The location of the protobuf (`pb`) model to visualize model_dir: The location of the protobuf (`pb`) model to visualize
log_dir: The location for the Tensorboard log to begin visualization from. log_dir: The location for the Tensorboard log to begin visualization from.
tag_set: Group of tag(s) of the MetaGraphDef to load, in string format,
separated by ','. For tag-set contains multiple tags, all tags must be
passed in.
Usage: Usage:
Call this function with your model location and desired log directory. Call this function with your model location and desired log directory.
@ -53,10 +56,9 @@ def import_to_tensorboard(model_dir, log_dir):
View your imported `.pb` model as a graph. View your imported `.pb` model as a graph.
""" """
with session.Session(graph=ops.Graph()) as sess: with session.Session(graph=ops.Graph()) as sess:
with gfile.GFile(model_dir, "rb") as f: input_graph_def = saved_model_utils.get_meta_graph_def(model_dir,
graph_def = graph_pb2.GraphDef() tag_set).graph_def
graph_def.ParseFromString(f.read()) importer.import_graph_def(input_graph_def)
importer.import_graph_def(graph_def)
pb_visual_writer = summary.FileWriter(log_dir) pb_visual_writer = summary.FileWriter(log_dir)
pb_visual_writer.add_graph(sess.graph) pb_visual_writer.add_graph(sess.graph)
@ -64,8 +66,9 @@ def import_to_tensorboard(model_dir, log_dir):
"tensorboard --logdir={}".format(log_dir)) "tensorboard --logdir={}".format(log_dir))
def main(unused_args): def main(_):
import_to_tensorboard(FLAGS.model_dir, FLAGS.log_dir) import_to_tensorboard(FLAGS.model_dir, FLAGS.log_dir, FLAGS.tag_set)
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
@ -75,12 +78,18 @@ if __name__ == "__main__":
type=str, type=str,
default="", default="",
required=True, required=True,
help="The location of the protobuf (\'pb\') model to visualize.") help="The directory containing the SavedModel to import.")
parser.add_argument( parser.add_argument(
"--log_dir", "--log_dir",
type=str, type=str,
default="", default="",
required=True, required=True,
help="The location for the Tensorboard log to begin visualization from.") help="The location for the Tensorboard log to begin visualization from.")
parser.add_argument(
"--tag_set",
type=str,
default="serve",
required=False,
help='tag-set of graph in SavedModel to load, separated by \',\'')
FLAGS, unparsed = parser.parse_known_args() FLAGS, unparsed = parser.parse_known_args()
app.run(main=main, argv=[sys.argv[0]] + unparsed) app.run(main=main, argv=[sys.argv[0]] + unparsed)