Tweaks to saved_model_cli to support empty tag-sets.
Handling of tag sets via string join and split leads to convert an empty list into a list with empty string. Instead of fixing the get_meta_graph_def to receive a list of tags instead of a string and move the issue to the caller of it, get_meta_graph now ignores tags that are empty. Additionally change print of the tag sets so use '%r' so that it is clear there is a tag set (i.e. '') instead of printing an empty line. PiperOrigin-RevId: 286375150 Change-Id: I8b6b17c8c1a30d76bbebfbfffdf064e367e526f9
This commit is contained in:
parent
c19bfd4219
commit
6ce4169f68
@ -62,7 +62,7 @@ def _show_tag_sets(saved_model_dir):
|
||||
tag_sets = saved_model_utils.get_saved_model_tag_sets(saved_model_dir)
|
||||
print('The given SavedModel contains the following tag-sets:')
|
||||
for tag_set in sorted(tag_sets):
|
||||
print(', '.join(sorted(tag_set)))
|
||||
print('%r' % ', '.join(sorted(tag_set)))
|
||||
|
||||
|
||||
def _show_signature_def_map_keys(saved_model_dir, tag_set):
|
||||
|
@ -244,7 +244,7 @@ Defined Functions:
|
||||
with captured_output() as (out, err):
|
||||
saved_model_cli.show(args)
|
||||
output = out.getvalue().strip()
|
||||
exp_out = 'The given SavedModel contains the following tag-sets:\nserve'
|
||||
exp_out = 'The given SavedModel contains the following tag-sets:\n\'serve\''
|
||||
self.assertMultiLineEqual(output, exp_out)
|
||||
self.assertEqual(err.getvalue().strip(), '')
|
||||
|
||||
|
@ -83,7 +83,8 @@ def get_saved_model_tag_sets(saved_model_dir):
|
||||
saved_model_dir: Directory containing the SavedModel.
|
||||
|
||||
Returns:
|
||||
String representation of all tag-sets in the SavedModel.
|
||||
List of all tag-sets in the SavedModel, where a tag-set is represented as a
|
||||
list of strings.
|
||||
"""
|
||||
saved_model = read_saved_model(saved_model_dir)
|
||||
all_tags = []
|
||||
@ -98,10 +99,11 @@ def get_meta_graph_def(saved_model_dir, tag_set):
|
||||
Returns the MetaGraphDef for the given tag-set and SavedModel directory.
|
||||
|
||||
Args:
|
||||
saved_model_dir: Directory containing the SavedModel to inspect or execute.
|
||||
saved_model_dir: Directory containing the SavedModel to inspect.
|
||||
tag_set: Group of tag(s) of the MetaGraphDef to load, in string format,
|
||||
separated by ','. For tag-set contains multiple tags, all tags must be
|
||||
passed in.
|
||||
separated by ','. The empty string tag is ignored so that passing ''
|
||||
means the empty tag set. For tag-set contains multiple tags, all tags
|
||||
must be passed in.
|
||||
|
||||
Raises:
|
||||
RuntimeError: An error when the given tag-set does not exist in the
|
||||
@ -111,10 +113,11 @@ def get_meta_graph_def(saved_model_dir, tag_set):
|
||||
A MetaGraphDef corresponding to the tag-set.
|
||||
"""
|
||||
saved_model = read_saved_model(saved_model_dir)
|
||||
set_of_tags = set(tag_set.split(','))
|
||||
# Note: Discard empty tags so that "" can mean the empty tag set.
|
||||
set_of_tags = set([tag for tag in tag_set.split(",") if tag])
|
||||
for meta_graph_def in saved_model.meta_graphs:
|
||||
if set(meta_graph_def.meta_info_def.tags) == set_of_tags:
|
||||
return meta_graph_def
|
||||
|
||||
raise RuntimeError('MetaGraphDef associated with tag-set ' + tag_set +
|
||||
' could not be found in SavedModel')
|
||||
raise RuntimeError("MetaGraphDef associated with tag-set %r could not be"
|
||||
" found in SavedModel" % tag_set)
|
||||
|
Loading…
x
Reference in New Issue
Block a user