Make freeze_graph work with empty tag-set i.e. --saved_model_tags=""

PiperOrigin-RevId: 304346115
Change-Id: I3adff9b5da8e6c56f6c7032a5f2f380886847456
This commit is contained in:
A. Unique TensorFlower 2020-04-02 00:53:18 -07:00 committed by TensorFlower Gardener
parent 6e38f2672b
commit 3d602790f8
3 changed files with 15 additions and 8 deletions

View File

@ -135,6 +135,7 @@ py_test(
"//tensorflow/python:math_ops",
"//tensorflow/python:training",
"//tensorflow/python:variables",
"@absl_py//absl/testing:parameterized",
],
)

View File

@ -357,7 +357,7 @@ def freeze_graph(input_graph,
variable_names_blacklist,
input_meta_graph_def,
input_saved_model_dir,
saved_model_tags.replace(" ", "").split(","),
[tag for tag in saved_model_tags.replace(" ", "").split(",") if tag],
checkpoint_version=checkpoint_version)

View File

@ -21,6 +21,8 @@ from __future__ import print_function
import os
import re
from absl.testing import parameterized
from tensorflow.core.example import example_pb2
from tensorflow.core.framework import graph_pb2
from tensorflow.core.protobuf import saver_pb2
@ -46,7 +48,7 @@ from tensorflow.python.tools import freeze_graph
from tensorflow.python.training import saver as saver_lib
class FreezeGraphTest(test_util.TensorFlowTestCase):
class FreezeGraphTest(test_util.TensorFlowTestCase, parameterized.TestCase):
def _testFreezeGraph(self, saver_write_version):
@ -124,7 +126,7 @@ class FreezeGraphTest(test_util.TensorFlowTestCase):
feature_value])
return example.SerializeToString()
def _writeDummySavedModel(self, path, feature_name):
def _writeDummySavedModel(self, path, feature_name, tags):
"""Writes a classifier with two input features to the given path."""
with ops.Graph().as_default():
examples = array_ops.placeholder(dtypes.string, name="input_node")
@ -151,11 +153,12 @@ class FreezeGraphTest(test_util.TensorFlowTestCase):
builder = saved_model_builder.SavedModelBuilder(path)
builder.add_meta_graph_and_variables(
sess,
[tag_constants.SERVING],
tags,
signature_def_map={
signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
signature,
},)
},
)
builder.save(as_text=True)
@test_util.run_v1_only("b/120545219")
@ -218,11 +221,14 @@ class FreezeGraphTest(test_util.TensorFlowTestCase):
output = sess.run(output_node)
self.assertNear(2.0, output, 0.00001)
def testFreezeSavedModel(self):
@parameterized.named_parameters(
("empty_tags_set", "", []),
("default_tags_set", tag_constants.SERVING, [tag_constants.SERVING]))
def testFreezeSavedModel(self, tags_string, tags_list):
tmp_dir = self.get_temp_dir()
saved_model_dir = os.path.join(tmp_dir, "saved_model_dir")
feature_name = "feature"
self._writeDummySavedModel(saved_model_dir, feature_name)
self._writeDummySavedModel(saved_model_dir, feature_name, tags_list)
output_graph_filename = os.path.join(tmp_dir, "output_graph.pb")
input_saved_model_dir = saved_model_dir
@ -235,7 +241,7 @@ class FreezeGraphTest(test_util.TensorFlowTestCase):
input_meta_graph = False
checkpoint_path = None
input_graph_filename = None
saved_model_tags = tag_constants.SERVING
saved_model_tags = tags_string
freeze_graph.freeze_graph(input_graph_filename, input_saver_def_path,
input_binary, checkpoint_path, output_node_names,