Make freeze_graph work with empty tag-set i.e. --saved_model_tags=""

PiperOrigin-RevId: 304346115
Change-Id: I3adff9b5da8e6c56f6c7032a5f2f380886847456
This commit is contained in:
A. Unique TensorFlower 2020-04-02 00:53:18 -07:00 committed by TensorFlower Gardener
parent 6e38f2672b
commit 3d602790f8
3 changed files with 15 additions and 8 deletions

View File

@ -135,6 +135,7 @@ py_test(
"//tensorflow/python:math_ops", "//tensorflow/python:math_ops",
"//tensorflow/python:training", "//tensorflow/python:training",
"//tensorflow/python:variables", "//tensorflow/python:variables",
"@absl_py//absl/testing:parameterized",
], ],
) )

View File

@ -357,7 +357,7 @@ def freeze_graph(input_graph,
variable_names_blacklist, variable_names_blacklist,
input_meta_graph_def, input_meta_graph_def,
input_saved_model_dir, input_saved_model_dir,
saved_model_tags.replace(" ", "").split(","), [tag for tag in saved_model_tags.replace(" ", "").split(",") if tag],
checkpoint_version=checkpoint_version) checkpoint_version=checkpoint_version)

View File

@ -21,6 +21,8 @@ from __future__ import print_function
import os import os
import re import re
from absl.testing import parameterized
from tensorflow.core.example import example_pb2 from tensorflow.core.example import example_pb2
from tensorflow.core.framework import graph_pb2 from tensorflow.core.framework import graph_pb2
from tensorflow.core.protobuf import saver_pb2 from tensorflow.core.protobuf import saver_pb2
@ -46,7 +48,7 @@ from tensorflow.python.tools import freeze_graph
from tensorflow.python.training import saver as saver_lib from tensorflow.python.training import saver as saver_lib
class FreezeGraphTest(test_util.TensorFlowTestCase): class FreezeGraphTest(test_util.TensorFlowTestCase, parameterized.TestCase):
def _testFreezeGraph(self, saver_write_version): def _testFreezeGraph(self, saver_write_version):
@ -124,7 +126,7 @@ class FreezeGraphTest(test_util.TensorFlowTestCase):
feature_value]) feature_value])
return example.SerializeToString() return example.SerializeToString()
def _writeDummySavedModel(self, path, feature_name): def _writeDummySavedModel(self, path, feature_name, tags):
"""Writes a classifier with two input features to the given path.""" """Writes a classifier with two input features to the given path."""
with ops.Graph().as_default(): with ops.Graph().as_default():
examples = array_ops.placeholder(dtypes.string, name="input_node") examples = array_ops.placeholder(dtypes.string, name="input_node")
@ -151,11 +153,12 @@ class FreezeGraphTest(test_util.TensorFlowTestCase):
builder = saved_model_builder.SavedModelBuilder(path) builder = saved_model_builder.SavedModelBuilder(path)
builder.add_meta_graph_and_variables( builder.add_meta_graph_and_variables(
sess, sess,
[tag_constants.SERVING], tags,
signature_def_map={ signature_def_map={
signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
signature, signature,
},) },
)
builder.save(as_text=True) builder.save(as_text=True)
@test_util.run_v1_only("b/120545219") @test_util.run_v1_only("b/120545219")
@ -218,11 +221,14 @@ class FreezeGraphTest(test_util.TensorFlowTestCase):
output = sess.run(output_node) output = sess.run(output_node)
self.assertNear(2.0, output, 0.00001) self.assertNear(2.0, output, 0.00001)
def testFreezeSavedModel(self): @parameterized.named_parameters(
("empty_tags_set", "", []),
("default_tags_set", tag_constants.SERVING, [tag_constants.SERVING]))
def testFreezeSavedModel(self, tags_string, tags_list):
tmp_dir = self.get_temp_dir() tmp_dir = self.get_temp_dir()
saved_model_dir = os.path.join(tmp_dir, "saved_model_dir") saved_model_dir = os.path.join(tmp_dir, "saved_model_dir")
feature_name = "feature" feature_name = "feature"
self._writeDummySavedModel(saved_model_dir, feature_name) self._writeDummySavedModel(saved_model_dir, feature_name, tags_list)
output_graph_filename = os.path.join(tmp_dir, "output_graph.pb") output_graph_filename = os.path.join(tmp_dir, "output_graph.pb")
input_saved_model_dir = saved_model_dir input_saved_model_dir = saved_model_dir
@ -235,7 +241,7 @@ class FreezeGraphTest(test_util.TensorFlowTestCase):
input_meta_graph = False input_meta_graph = False
checkpoint_path = None checkpoint_path = None
input_graph_filename = None input_graph_filename = None
saved_model_tags = tag_constants.SERVING saved_model_tags = tags_string
freeze_graph.freeze_graph(input_graph_filename, input_saver_def_path, freeze_graph.freeze_graph(input_graph_filename, input_saver_def_path,
input_binary, checkpoint_path, output_node_names, input_binary, checkpoint_path, output_node_names,