diff --git a/tensorflow/python/tools/saved_model_cli.py b/tensorflow/python/tools/saved_model_cli.py index 57ffc3f05c2..4cc9cabec4f 100644 --- a/tensorflow/python/tools/saved_model_cli.py +++ b/tensorflow/python/tools/saved_model_cli.py @@ -62,7 +62,7 @@ def _show_tag_sets(saved_model_dir): tag_sets = saved_model_utils.get_saved_model_tag_sets(saved_model_dir) print('The given SavedModel contains the following tag-sets:') for tag_set in sorted(tag_sets): - print(', '.join(sorted(tag_set))) + print('%r' % ', '.join(sorted(tag_set))) def _show_signature_def_map_keys(saved_model_dir, tag_set): diff --git a/tensorflow/python/tools/saved_model_cli_test.py b/tensorflow/python/tools/saved_model_cli_test.py index 671b73a1ccd..488bce93974 100644 --- a/tensorflow/python/tools/saved_model_cli_test.py +++ b/tensorflow/python/tools/saved_model_cli_test.py @@ -244,7 +244,7 @@ Defined Functions: with captured_output() as (out, err): saved_model_cli.show(args) output = out.getvalue().strip() - exp_out = 'The given SavedModel contains the following tag-sets:\nserve' + exp_out = 'The given SavedModel contains the following tag-sets:\n\'serve\'' self.assertMultiLineEqual(output, exp_out) self.assertEqual(err.getvalue().strip(), '') diff --git a/tensorflow/python/tools/saved_model_utils.py b/tensorflow/python/tools/saved_model_utils.py index 17c4b8cb831..c5c168c1d27 100644 --- a/tensorflow/python/tools/saved_model_utils.py +++ b/tensorflow/python/tools/saved_model_utils.py @@ -83,7 +83,8 @@ def get_saved_model_tag_sets(saved_model_dir): saved_model_dir: Directory containing the SavedModel. Returns: - String representation of all tag-sets in the SavedModel. + List of all tag-sets in the SavedModel, where a tag-set is represented as a + list of strings. """ saved_model = read_saved_model(saved_model_dir) all_tags = [] @@ -98,10 +99,11 @@ def get_meta_graph_def(saved_model_dir, tag_set): Returns the MetaGraphDef for the given tag-set and SavedModel directory. Args: - saved_model_dir: Directory containing the SavedModel to inspect or execute. + saved_model_dir: Directory containing the SavedModel to inspect. tag_set: Group of tag(s) of the MetaGraphDef to load, in string format, - separated by ','. For tag-set contains multiple tags, all tags must be - passed in. + separated by ','. The empty string tag is ignored so that passing '' + means the empty tag set. For tag-set contains multiple tags, all tags + must be passed in. Raises: RuntimeError: An error when the given tag-set does not exist in the @@ -111,10 +113,11 @@ def get_meta_graph_def(saved_model_dir, tag_set): A MetaGraphDef corresponding to the tag-set. """ saved_model = read_saved_model(saved_model_dir) - set_of_tags = set(tag_set.split(',')) + # Note: Discard empty tags so that "" can mean the empty tag set. + set_of_tags = set([tag for tag in tag_set.split(",") if tag]) for meta_graph_def in saved_model.meta_graphs: if set(meta_graph_def.meta_info_def.tags) == set_of_tags: return meta_graph_def - raise RuntimeError('MetaGraphDef associated with tag-set ' + tag_set + - ' could not be found in SavedModel') + raise RuntimeError("MetaGraphDef associated with tag-set %r could not be" + " found in SavedModel" % tag_set)