Tweaks to saved_model_cli to support empty tag-sets.

Handling of tag sets via string join and split leads to convert an empty list
into a list with empty string. Instead of fixing the get_meta_graph_def to receive
a list of tags instead of a string and move the issue to the caller of it,
get_meta_graph now ignores tags that are empty.

Additionally change print of the tag sets so use '%r' so that it is clear there is
a tag set (i.e. '') instead of printing an empty line.

PiperOrigin-RevId: 286375150
Change-Id: I8b6b17c8c1a30d76bbebfbfffdf064e367e526f9
This commit is contained in:
Andr? Susano Pinto 2019-12-19 05:27:50 -08:00 committed by TensorFlower Gardener
parent c19bfd4219
commit 6ce4169f68
3 changed files with 12 additions and 9 deletions

View File

@ -62,7 +62,7 @@ def _show_tag_sets(saved_model_dir):
tag_sets = saved_model_utils.get_saved_model_tag_sets(saved_model_dir)
print('The given SavedModel contains the following tag-sets:')
for tag_set in sorted(tag_sets):
print(', '.join(sorted(tag_set)))
print('%r' % ', '.join(sorted(tag_set)))
def _show_signature_def_map_keys(saved_model_dir, tag_set):

View File

@ -244,7 +244,7 @@ Defined Functions:
with captured_output() as (out, err):
saved_model_cli.show(args)
output = out.getvalue().strip()
exp_out = 'The given SavedModel contains the following tag-sets:\nserve'
exp_out = 'The given SavedModel contains the following tag-sets:\n\'serve\''
self.assertMultiLineEqual(output, exp_out)
self.assertEqual(err.getvalue().strip(), '')

View File

@ -83,7 +83,8 @@ def get_saved_model_tag_sets(saved_model_dir):
saved_model_dir: Directory containing the SavedModel.
Returns:
String representation of all tag-sets in the SavedModel.
List of all tag-sets in the SavedModel, where a tag-set is represented as a
list of strings.
"""
saved_model = read_saved_model(saved_model_dir)
all_tags = []
@ -98,10 +99,11 @@ def get_meta_graph_def(saved_model_dir, tag_set):
Returns the MetaGraphDef for the given tag-set and SavedModel directory.
Args:
saved_model_dir: Directory containing the SavedModel to inspect or execute.
saved_model_dir: Directory containing the SavedModel to inspect.
tag_set: Group of tag(s) of the MetaGraphDef to load, in string format,
separated by ','. For tag-set contains multiple tags, all tags must be
passed in.
separated by ','. The empty string tag is ignored so that passing ''
means the empty tag set. For tag-set contains multiple tags, all tags
must be passed in.
Raises:
RuntimeError: An error when the given tag-set does not exist in the
@ -111,10 +113,11 @@ def get_meta_graph_def(saved_model_dir, tag_set):
A MetaGraphDef corresponding to the tag-set.
"""
saved_model = read_saved_model(saved_model_dir)
set_of_tags = set(tag_set.split(','))
# Note: Discard empty tags so that "" can mean the empty tag set.
set_of_tags = set([tag for tag in tag_set.split(",") if tag])
for meta_graph_def in saved_model.meta_graphs:
if set(meta_graph_def.meta_info_def.tags) == set_of_tags:
return meta_graph_def
raise RuntimeError('MetaGraphDef associated with tag-set ' + tag_set +
' could not be found in SavedModel')
raise RuntimeError("MetaGraphDef associated with tag-set %r could not be"
" found in SavedModel" % tag_set)