Make freeze_graph work with empty tag-set i.e. --saved_model_tags=""
PiperOrigin-RevId: 304346115 Change-Id: I3adff9b5da8e6c56f6c7032a5f2f380886847456
This commit is contained in:
parent
6e38f2672b
commit
3d602790f8
tensorflow/python/tools
@ -135,6 +135,7 @@ py_test(
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:training",
|
||||
"//tensorflow/python:variables",
|
||||
"@absl_py//absl/testing:parameterized",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -357,7 +357,7 @@ def freeze_graph(input_graph,
|
||||
variable_names_blacklist,
|
||||
input_meta_graph_def,
|
||||
input_saved_model_dir,
|
||||
saved_model_tags.replace(" ", "").split(","),
|
||||
[tag for tag in saved_model_tags.replace(" ", "").split(",") if tag],
|
||||
checkpoint_version=checkpoint_version)
|
||||
|
||||
|
||||
|
@ -21,6 +21,8 @@ from __future__ import print_function
|
||||
import os
|
||||
import re
|
||||
|
||||
from absl.testing import parameterized
|
||||
|
||||
from tensorflow.core.example import example_pb2
|
||||
from tensorflow.core.framework import graph_pb2
|
||||
from tensorflow.core.protobuf import saver_pb2
|
||||
@ -46,7 +48,7 @@ from tensorflow.python.tools import freeze_graph
|
||||
from tensorflow.python.training import saver as saver_lib
|
||||
|
||||
|
||||
class FreezeGraphTest(test_util.TensorFlowTestCase):
|
||||
class FreezeGraphTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
||||
|
||||
def _testFreezeGraph(self, saver_write_version):
|
||||
|
||||
@ -124,7 +126,7 @@ class FreezeGraphTest(test_util.TensorFlowTestCase):
|
||||
feature_value])
|
||||
return example.SerializeToString()
|
||||
|
||||
def _writeDummySavedModel(self, path, feature_name):
|
||||
def _writeDummySavedModel(self, path, feature_name, tags):
|
||||
"""Writes a classifier with two input features to the given path."""
|
||||
with ops.Graph().as_default():
|
||||
examples = array_ops.placeholder(dtypes.string, name="input_node")
|
||||
@ -151,11 +153,12 @@ class FreezeGraphTest(test_util.TensorFlowTestCase):
|
||||
builder = saved_model_builder.SavedModelBuilder(path)
|
||||
builder.add_meta_graph_and_variables(
|
||||
sess,
|
||||
[tag_constants.SERVING],
|
||||
tags,
|
||||
signature_def_map={
|
||||
signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
|
||||
signature,
|
||||
},)
|
||||
},
|
||||
)
|
||||
builder.save(as_text=True)
|
||||
|
||||
@test_util.run_v1_only("b/120545219")
|
||||
@ -218,11 +221,14 @@ class FreezeGraphTest(test_util.TensorFlowTestCase):
|
||||
output = sess.run(output_node)
|
||||
self.assertNear(2.0, output, 0.00001)
|
||||
|
||||
def testFreezeSavedModel(self):
|
||||
@parameterized.named_parameters(
|
||||
("empty_tags_set", "", []),
|
||||
("default_tags_set", tag_constants.SERVING, [tag_constants.SERVING]))
|
||||
def testFreezeSavedModel(self, tags_string, tags_list):
|
||||
tmp_dir = self.get_temp_dir()
|
||||
saved_model_dir = os.path.join(tmp_dir, "saved_model_dir")
|
||||
feature_name = "feature"
|
||||
self._writeDummySavedModel(saved_model_dir, feature_name)
|
||||
self._writeDummySavedModel(saved_model_dir, feature_name, tags_list)
|
||||
output_graph_filename = os.path.join(tmp_dir, "output_graph.pb")
|
||||
|
||||
input_saved_model_dir = saved_model_dir
|
||||
@ -235,7 +241,7 @@ class FreezeGraphTest(test_util.TensorFlowTestCase):
|
||||
input_meta_graph = False
|
||||
checkpoint_path = None
|
||||
input_graph_filename = None
|
||||
saved_model_tags = tag_constants.SERVING
|
||||
saved_model_tags = tags_string
|
||||
|
||||
freeze_graph.freeze_graph(input_graph_filename, input_saver_def_path,
|
||||
input_binary, checkpoint_path, output_node_names,
|
||||
|
Loading…
Reference in New Issue
Block a user