Make freeze_graph work with empty tag-set i.e. --saved_model_tags=""
PiperOrigin-RevId: 304346115 Change-Id: I3adff9b5da8e6c56f6c7032a5f2f380886847456
This commit is contained in:
parent
6e38f2672b
commit
3d602790f8
@ -135,6 +135,7 @@ py_test(
|
|||||||
"//tensorflow/python:math_ops",
|
"//tensorflow/python:math_ops",
|
||||||
"//tensorflow/python:training",
|
"//tensorflow/python:training",
|
||||||
"//tensorflow/python:variables",
|
"//tensorflow/python:variables",
|
||||||
|
"@absl_py//absl/testing:parameterized",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -357,7 +357,7 @@ def freeze_graph(input_graph,
|
|||||||
variable_names_blacklist,
|
variable_names_blacklist,
|
||||||
input_meta_graph_def,
|
input_meta_graph_def,
|
||||||
input_saved_model_dir,
|
input_saved_model_dir,
|
||||||
saved_model_tags.replace(" ", "").split(","),
|
[tag for tag in saved_model_tags.replace(" ", "").split(",") if tag],
|
||||||
checkpoint_version=checkpoint_version)
|
checkpoint_version=checkpoint_version)
|
||||||
|
|
||||||
|
|
||||||
|
@ -21,6 +21,8 @@ from __future__ import print_function
|
|||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
|
|
||||||
|
from absl.testing import parameterized
|
||||||
|
|
||||||
from tensorflow.core.example import example_pb2
|
from tensorflow.core.example import example_pb2
|
||||||
from tensorflow.core.framework import graph_pb2
|
from tensorflow.core.framework import graph_pb2
|
||||||
from tensorflow.core.protobuf import saver_pb2
|
from tensorflow.core.protobuf import saver_pb2
|
||||||
@ -46,7 +48,7 @@ from tensorflow.python.tools import freeze_graph
|
|||||||
from tensorflow.python.training import saver as saver_lib
|
from tensorflow.python.training import saver as saver_lib
|
||||||
|
|
||||||
|
|
||||||
class FreezeGraphTest(test_util.TensorFlowTestCase):
|
class FreezeGraphTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
||||||
|
|
||||||
def _testFreezeGraph(self, saver_write_version):
|
def _testFreezeGraph(self, saver_write_version):
|
||||||
|
|
||||||
@ -124,7 +126,7 @@ class FreezeGraphTest(test_util.TensorFlowTestCase):
|
|||||||
feature_value])
|
feature_value])
|
||||||
return example.SerializeToString()
|
return example.SerializeToString()
|
||||||
|
|
||||||
def _writeDummySavedModel(self, path, feature_name):
|
def _writeDummySavedModel(self, path, feature_name, tags):
|
||||||
"""Writes a classifier with two input features to the given path."""
|
"""Writes a classifier with two input features to the given path."""
|
||||||
with ops.Graph().as_default():
|
with ops.Graph().as_default():
|
||||||
examples = array_ops.placeholder(dtypes.string, name="input_node")
|
examples = array_ops.placeholder(dtypes.string, name="input_node")
|
||||||
@ -151,11 +153,12 @@ class FreezeGraphTest(test_util.TensorFlowTestCase):
|
|||||||
builder = saved_model_builder.SavedModelBuilder(path)
|
builder = saved_model_builder.SavedModelBuilder(path)
|
||||||
builder.add_meta_graph_and_variables(
|
builder.add_meta_graph_and_variables(
|
||||||
sess,
|
sess,
|
||||||
[tag_constants.SERVING],
|
tags,
|
||||||
signature_def_map={
|
signature_def_map={
|
||||||
signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
|
signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
|
||||||
signature,
|
signature,
|
||||||
},)
|
},
|
||||||
|
)
|
||||||
builder.save(as_text=True)
|
builder.save(as_text=True)
|
||||||
|
|
||||||
@test_util.run_v1_only("b/120545219")
|
@test_util.run_v1_only("b/120545219")
|
||||||
@ -218,11 +221,14 @@ class FreezeGraphTest(test_util.TensorFlowTestCase):
|
|||||||
output = sess.run(output_node)
|
output = sess.run(output_node)
|
||||||
self.assertNear(2.0, output, 0.00001)
|
self.assertNear(2.0, output, 0.00001)
|
||||||
|
|
||||||
def testFreezeSavedModel(self):
|
@parameterized.named_parameters(
|
||||||
|
("empty_tags_set", "", []),
|
||||||
|
("default_tags_set", tag_constants.SERVING, [tag_constants.SERVING]))
|
||||||
|
def testFreezeSavedModel(self, tags_string, tags_list):
|
||||||
tmp_dir = self.get_temp_dir()
|
tmp_dir = self.get_temp_dir()
|
||||||
saved_model_dir = os.path.join(tmp_dir, "saved_model_dir")
|
saved_model_dir = os.path.join(tmp_dir, "saved_model_dir")
|
||||||
feature_name = "feature"
|
feature_name = "feature"
|
||||||
self._writeDummySavedModel(saved_model_dir, feature_name)
|
self._writeDummySavedModel(saved_model_dir, feature_name, tags_list)
|
||||||
output_graph_filename = os.path.join(tmp_dir, "output_graph.pb")
|
output_graph_filename = os.path.join(tmp_dir, "output_graph.pb")
|
||||||
|
|
||||||
input_saved_model_dir = saved_model_dir
|
input_saved_model_dir = saved_model_dir
|
||||||
@ -235,7 +241,7 @@ class FreezeGraphTest(test_util.TensorFlowTestCase):
|
|||||||
input_meta_graph = False
|
input_meta_graph = False
|
||||||
checkpoint_path = None
|
checkpoint_path = None
|
||||||
input_graph_filename = None
|
input_graph_filename = None
|
||||||
saved_model_tags = tag_constants.SERVING
|
saved_model_tags = tags_string
|
||||||
|
|
||||||
freeze_graph.freeze_graph(input_graph_filename, input_saver_def_path,
|
freeze_graph.freeze_graph(input_graph_filename, input_saver_def_path,
|
||||||
input_binary, checkpoint_path, output_node_names,
|
input_binary, checkpoint_path, output_node_names,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user