Provide MetaGraphDef transformer, and call it from SavedModel export.

A MetaGraphDef transformation consists of a GraphDef transformation provided by the Graph Transform Tool, followed by some surgery at the MetaGraphDef level to remove references to any nodes that were removed.

This allows users to request Graph Transform Tool (GTT) rewrites integrated with Estimator.export_savedmodel(). It also supports graph freezing, interleaved with those rewrites, even though freezing itself is not a GTT transform.

A limitation, for now, is that all Variables and many of their associated Save and Restore Ops are retained even if they are unused and strip_unused_nodes is requested (pending further clarity on which ones may be safe to remove).
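
For example, a caller might export one untransformed MetaGraphDef for serving plus a second, stripped one (a minimal sketch based on the new estimator test below; est and serving_input_fn come from an already-trained Estimator, and the tag names are illustrative):

from tensorflow.contrib.learn.python.learn.estimators.estimator import GraphRewriteSpec

export_dir = est.export_savedmodel(
    export_dir_base, serving_input_fn,
    graph_rewrite_specs=[
        # The first spec must request no transforms; it produces the
        # untransformed base MetaGraphDef.
        GraphRewriteSpec(tags=['serve'], transforms=[]),
        # Each later spec adds a MetaGraphDef, rewritten by the named
        # GTT transforms and tagged as specified.
        GraphRewriteSpec(tags=['serve', 'stripped'],
                         transforms=['strip_unused_nodes'])])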

PiperOrigin-RevId: 163475341
This commit is contained in:
David Soergel 2017-07-28 09:00:28 -07:00 committed by TensorFlower Gardener
parent 123f84b8e0
commit 34c49b133d
12 changed files with 967 additions and 37 deletions


@@ -280,6 +280,7 @@ filegroup(
"//tensorflow/contrib/linear_optimizer:all_files",
"//tensorflow/contrib/lookup:all_files",
"//tensorflow/contrib/losses:all_files",
"//tensorflow/contrib/meta_graph_transform:all_files",
"//tensorflow/contrib/metrics:all_files",
"//tensorflow/contrib/ndlstm:all_files",
"//tensorflow/contrib/nn:all_files",


@@ -47,6 +47,7 @@ py_library(
"//tensorflow/contrib/lookup:lookup_py",
"//tensorflow/contrib/losses:losses_py",
"//tensorflow/contrib/memory_stats:memory_stats_py",
"//tensorflow/contrib/meta_graph_transform",
"//tensorflow/contrib/metrics:metrics_py",
"//tensorflow/contrib/nccl:nccl_py",
"//tensorflow/contrib/ndlstm",


@@ -240,6 +240,8 @@ add_python_module("tensorflow/python/training")
add_python_module("tensorflow/python/user_ops")
add_python_module("tensorflow/python/util")
add_python_module("tensorflow/python/util/protobuf")
add_python_module("tensorflow/tools")
add_python_module("tensorflow/tools/graph_transforms")
add_python_module("tensorflow/contrib")
add_python_module("tensorflow/contrib/android")
add_python_module("tensorflow/contrib/android/java")
@@ -440,6 +442,7 @@ add_python_module("tensorflow/contrib/memory_stats/ops")
add_python_module("tensorflow/contrib/memory_stats/python")
add_python_module("tensorflow/contrib/memory_stats/python/kernel_tests")
add_python_module("tensorflow/contrib/memory_stats/python/ops")
add_python_module("tensorflow/contrib/meta_graph_transform")
add_python_module("tensorflow/contrib/metrics")
add_python_module("tensorflow/contrib/metrics/kernels")
add_python_module("tensorflow/contrib/metrics/ops")


@@ -141,6 +141,7 @@ if (tensorflow_BUILD_PYTHON_TESTS)
"${tensorflow_source_dir}/tensorflow/python/debug/lib/*_test.py"
"${tensorflow_source_dir}/tensorflow/python/debug/wrappers/*_test.py"
"${tensorflow_source_dir}/tensorflow/python/kernel_tests/*.py"
"${tensorflow_source_dir}/tensorflow/python/meta_graph_transform/*_test.py"
"${tensorflow_source_dir}/tensorflow/python/profiler/*_test.py"
"${tensorflow_source_dir}/tensorflow/python/profiler/internal/*_test.py"
"${tensorflow_source_dir}/tensorflow/python/saved_model/*_test.py"


@@ -31,6 +31,7 @@ py_library(
"//tensorflow/contrib/linear_optimizer:sdca_ops_py",
"//tensorflow/contrib/lookup:lookup_py",
"//tensorflow/contrib/losses:losses_py",
"//tensorflow/contrib/meta_graph_transform",
"//tensorflow/contrib/metrics:metrics_py",
"//tensorflow/contrib/rnn:rnn_py",
"//tensorflow/contrib/session_bundle:exporter",


@@ -303,6 +303,7 @@ from tensorflow.contrib.learn.python.learn.estimators.dnn_linear_combined import
from tensorflow.contrib.learn.python.learn.estimators.dynamic_rnn_estimator import DynamicRnnEstimator
from tensorflow.contrib.learn.python.learn.estimators.estimator import BaseEstimator
from tensorflow.contrib.learn.python.learn.estimators.estimator import Estimator
from tensorflow.contrib.learn.python.learn.estimators.estimator import GraphRewriteSpec
from tensorflow.contrib.learn.python.learn.estimators.estimator import infer_real_valued_columns_from_input
from tensorflow.contrib.learn.python.learn.estimators.estimator import infer_real_valued_columns_from_input_fn
from tensorflow.contrib.learn.python.learn.estimators.estimator import SKCompat


@@ -20,6 +20,7 @@ from __future__ import division
from __future__ import print_function
import abc
import collections
import copy
import os
import tempfile
@@ -34,7 +35,6 @@ from tensorflow.contrib.framework import deprecated
from tensorflow.contrib.framework import deprecated_args
from tensorflow.contrib.framework import list_variables
from tensorflow.contrib.framework import load_variable
from tensorflow.contrib.framework.python.ops import variables as contrib_variables
from tensorflow.contrib.learn.python.learn import evaluable
from tensorflow.contrib.learn.python.learn import metric_spec
from tensorflow.contrib.learn.python.learn import monitors as monitor_lib
@@ -49,6 +49,7 @@ from tensorflow.contrib.learn.python.learn.estimators._sklearn import NotFittedE
from tensorflow.contrib.learn.python.learn.learn_io import data_feeder
from tensorflow.contrib.learn.python.learn.utils import export
from tensorflow.contrib.learn.python.learn.utils import saved_model_export_utils
from tensorflow.contrib.meta_graph_transform import meta_graph_transform
from tensorflow.contrib.training.python.training import evaluation
from tensorflow.core.framework import summary_pb2
from tensorflow.core.protobuf import config_pb2
@@ -69,6 +70,7 @@ from tensorflow.python.training import device_setter
from tensorflow.python.training import monitored_session
from tensorflow.python.training import saver
from tensorflow.python.training import summary_io
from tensorflow.python.training import training_util
from tensorflow.python.util import compat
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_inspect
@@ -346,12 +348,17 @@ def _write_dict_to_summary(output_dir,
value.simple_value = int(dictionary[key])
else:
logging.warn(
'Skipping summary for %s, must be a float, np.float32, np.int64, np.int32 or int.',
'Skipping summary for %s, must be a float, np.float32, '
'np.int64, np.int32 or int.',
key)
summary_writer.add_summary(summary_proto, current_global_step)
summary_writer.flush()
GraphRewriteSpec = collections.namedtuple('GraphRewriteSpec',
['tags', 'transforms'])
class BaseEstimator(
sklearn.BaseEstimator, evaluable.Evaluable, trainable.Trainable):
"""Abstract BaseEstimator class to train and evaluate TensorFlow models.
@@ -1229,7 +1236,8 @@ class Estimator(BaseEstimator):
default_output_alternative_key=None,
assets_extra=None,
as_text=False,
checkpoint_path=None):
checkpoint_path=None,
graph_rewrite_specs=(GraphRewriteSpec((tag_constants.SERVING,), ()),)):
"""Exports inference graph as a SavedModel into given dir.
Args:
@@ -1249,6 +1257,10 @@
as_text: whether to write the SavedModel proto in text format.
checkpoint_path: The checkpoint path to export. If None (the default),
the most recent checkpoint found within the model directory is chosen.
graph_rewrite_specs: an iterable of `GraphRewriteSpec`. Each element will
produce a separate MetaGraphDef within the exported SavedModel, tagged
and rewritten as specified. Defaults to a single entry using the
default serving tag ("serve") and no rewriting.
Returns:
The string path to the exported directory.
@@ -1259,8 +1271,20 @@
if serving_input_fn is None:
raise ValueError('serving_input_fn must be defined.')
if not checkpoint_path:
# Locate the latest checkpoint
checkpoint_path = saver.latest_checkpoint(self._model_dir)
if not checkpoint_path:
raise NotFittedError("Couldn't find trained model at %s."
% self._model_dir)
export_dir = saved_model_export_utils.get_timestamped_export_dir(
export_dir_base)
builder = saved_model_builder.SavedModelBuilder(export_dir)
# Build the base graph
with ops.Graph().as_default() as g:
contrib_variables.create_global_step(g)
training_util.create_global_step(g)
# Call the serving_input_fn and collect the input alternatives.
input_ops = serving_input_fn()
@@ -1281,55 +1305,87 @@
saved_model_export_utils.get_output_alternatives(
model_fn_ops, default_output_alternative_key))
init_op = control_flow_ops.group(
variables.local_variables_initializer(),
resources.initialize_resources(resources.shared_resources()),
lookup_ops.tables_initializer())
# Build the SignatureDefs from all pairs of input and output alternatives
signature_def_map = saved_model_export_utils.build_all_signature_defs(
input_alternatives, output_alternatives,
actual_default_output_alternative_key)
if not checkpoint_path:
# Locate the latest checkpoint
checkpoint_path = saver.latest_checkpoint(self._model_dir)
if not checkpoint_path:
raise NotFittedError("Couldn't find trained model at %s."
% self._model_dir)
export_dir = saved_model_export_utils.get_timestamped_export_dir(
export_dir_base)
if (model_fn_ops.scaffold is not None and
model_fn_ops.scaffold.saver is not None):
saver_for_restore = model_fn_ops.scaffold.saver
else:
saver_for_restore = saver.Saver(sharded=True)
# Export the first MetaGraphDef with variables, assets etc.
with tf_session.Session('') as session:
# pylint: disable=protected-access
saveables = variables._all_saveable_objects()
# pylint: enable=protected-access
if (model_fn_ops.scaffold is not None and
model_fn_ops.scaffold.saver is not None):
saver_for_restore = model_fn_ops.scaffold.saver
elif saveables:
saver_for_restore = saver.Saver(saveables, sharded=True)
saver_for_restore.restore(session, checkpoint_path)
init_op = control_flow_ops.group(
variables.local_variables_initializer(),
resources.initialize_resources(resources.shared_resources()),
lookup_ops.tables_initializer())
# Perform the export
builder = saved_model_builder.SavedModelBuilder(export_dir)
if not graph_rewrite_specs or graph_rewrite_specs[0].transforms:
raise ValueError('The first element of graph_rewrite_specs '
'must specify no transforms.')
untransformed_tags = graph_rewrite_specs[0].tags
# TODO(soergel): switch to main_op or otherwise update when dust settles
builder.add_meta_graph_and_variables(
session, [tag_constants.SERVING],
session, untransformed_tags,
signature_def_map=signature_def_map,
assets_collection=ops.get_collection(
ops.GraphKeys.ASSET_FILEPATHS),
legacy_init_op=init_op)
builder.save(as_text)
# Add the extra assets
if assets_extra:
assets_extra_path = os.path.join(compat.as_bytes(export_dir),
compat.as_bytes('assets.extra'))
for dest_relative, source in assets_extra.items():
dest_absolute = os.path.join(compat.as_bytes(assets_extra_path),
compat.as_bytes(dest_relative))
dest_path = os.path.dirname(dest_absolute)
gfile.MakeDirs(dest_path)
gfile.Copy(source, dest_absolute)
# pylint: disable=protected-access
base_meta_graph_def = builder._saved_model.meta_graphs[0]
# pylint: enable=protected-access
return export_dir
if graph_rewrite_specs[1:]:
# Prepare the input_names and output_names needed for the
# meta_graph_transform call below.
input_names = [tensor.name
for input_dict in input_alternatives.values()
for tensor in input_dict.values()]
output_names = [tensor.name
for output_alternative in output_alternatives.values()
for tensor in output_alternative[1].values()]
# Write the additional MetaGraphDefs
for graph_rewrite_spec in graph_rewrite_specs[1:]:
# TODO(soergel) consider moving most of this to saved_model.builder_impl
# as e.g. builder.add_rewritten_meta_graph(rewritten_graph_def, tags)
transformed_meta_graph_def = meta_graph_transform.meta_graph_transform(
base_meta_graph_def, input_names, output_names,
graph_rewrite_spec.transforms, graph_rewrite_spec.tags)
# pylint: disable=protected-access
meta_graph_def = builder._saved_model.meta_graphs.add()
# pylint: enable=protected-access
meta_graph_def.CopyFrom(transformed_meta_graph_def)
# Add the extra assets
if assets_extra:
assets_extra_path = os.path.join(compat.as_bytes(export_dir),
compat.as_bytes('assets.extra'))
for dest_relative, source in assets_extra.items():
dest_absolute = os.path.join(compat.as_bytes(assets_extra_path),
compat.as_bytes(dest_relative))
dest_path = os.path.dirname(dest_absolute)
gfile.MakeDirs(dest_path)
gfile.Copy(source, dest_absolute)
builder.save(as_text)
return export_dir
# For time of deprecation x,y from Estimator allow direct access.


@@ -1006,6 +1006,130 @@ class EstimatorTest(test.TestCase):
# cleanup
gfile.DeleteRecursively(tmpdir)
def test_export_savedmodel_with_graph_transforms(self):
tmpdir = tempfile.mkdtemp()
est, serving_input_fn = _build_estimator_for_export_tests(tmpdir)
extra_file_name = os.path.join(
compat.as_bytes(tmpdir), compat.as_bytes('my_extra_file'))
extra_file = gfile.GFile(extra_file_name, mode='w')
extra_file.write(EXTRA_FILE_CONTENT)
extra_file.close()
assets_extra = {'some/sub/directory/my_extra_file': extra_file_name}
export_dir_base = os.path.join(
compat.as_bytes(tmpdir), compat.as_bytes('export'))
export_dir = est.export_savedmodel(
export_dir_base, serving_input_fn, assets_extra=assets_extra,
graph_rewrite_specs=[
estimator.GraphRewriteSpec(['tag_1'], []),
estimator.GraphRewriteSpec(['tag_2', 'tag_3'],
['strip_unused_nodes'])])
self.assertTrue(gfile.Exists(export_dir_base))
self.assertTrue(gfile.Exists(export_dir))
self.assertTrue(
gfile.Exists(
os.path.join(
compat.as_bytes(export_dir), compat.as_bytes(
'saved_model.pb'))))
self.assertTrue(
gfile.Exists(
os.path.join(
compat.as_bytes(export_dir), compat.as_bytes('variables'))))
self.assertTrue(
gfile.Exists(
os.path.join(
compat.as_bytes(export_dir),
compat.as_bytes('variables/variables.index'))))
self.assertTrue(
gfile.Exists(
os.path.join(
compat.as_bytes(export_dir),
compat.as_bytes('variables/variables.data-00000-of-00001'))))
self.assertTrue(
gfile.Exists(
os.path.join(
compat.as_bytes(export_dir), compat.as_bytes('assets'))))
self.assertTrue(
gfile.Exists(
os.path.join(
compat.as_bytes(export_dir),
compat.as_bytes('assets/my_vocab_file'))))
self.assertEqual(
compat.as_bytes(VOCAB_FILE_CONTENT),
compat.as_bytes(
gfile.GFile(
os.path.join(
compat.as_bytes(export_dir),
compat.as_bytes('assets/my_vocab_file'))).read()))
expected_extra_path = os.path.join(
compat.as_bytes(export_dir),
compat.as_bytes('assets.extra/some/sub/directory/my_extra_file'))
self.assertTrue(
gfile.Exists(
os.path.join(
compat.as_bytes(export_dir), compat.as_bytes('assets.extra'))))
self.assertTrue(gfile.Exists(expected_extra_path))
self.assertEqual(
compat.as_bytes(EXTRA_FILE_CONTENT),
compat.as_bytes(gfile.GFile(expected_extra_path).read()))
expected_vocab_file = os.path.join(
compat.as_bytes(tmpdir), compat.as_bytes('my_vocab_file'))
# Restore, to validate that the export was well-formed.
# tag_1 is untransformed.
tags = ['tag_1']
with ops.Graph().as_default() as graph:
with session_lib.Session(graph=graph) as sess:
loader.load(sess, tags, export_dir)
assets = [
x.eval()
for x in graph.get_collection(ops.GraphKeys.ASSET_FILEPATHS)
]
self.assertItemsEqual([expected_vocab_file], assets)
graph_ops = [x.name for x in graph.get_operations()]
self.assertTrue('input_example_tensor' in graph_ops)
self.assertTrue('ParseExample/ParseExample' in graph_ops)
self.assertTrue('linear/linear/feature/matmul' in graph_ops)
# Since there were no transforms, both save ops are still present.
self.assertTrue('save/SaveV2/tensor_names' in graph_ops)
self.assertTrue('save_1/SaveV2/tensor_names' in graph_ops)
# Since there were no transforms, the hash table lookup is still there.
self.assertTrue('hash_table_Lookup' in graph_ops)
# Restore, to validate that the export was well-formed.
# tag_2, tag_3 was subjected to strip_unused_nodes.
tags = ['tag_2', 'tag_3']
with ops.Graph().as_default() as graph:
with session_lib.Session(graph=graph) as sess:
loader.load(sess, tags, export_dir)
assets = [
x.eval()
for x in graph.get_collection(ops.GraphKeys.ASSET_FILEPATHS)
]
self.assertItemsEqual([expected_vocab_file], assets)
graph_ops = [x.name for x in graph.get_operations()]
self.assertTrue('input_example_tensor' in graph_ops)
self.assertTrue('ParseExample/ParseExample' in graph_ops)
self.assertTrue('linear/linear/feature/matmul' in graph_ops)
# The Saver used to restore the checkpoint into the export Session
# was not added to the SAVERS collection, so strip_unused_nodes removes
# it. The one explicitly created in export_savedmodel is tracked in
# the MetaGraphDef saver_def field, so that one is retained.
# TODO(soergel): Make Savers sane again. I understand this is all a bit
# nuts but for now the test demonstrates what actually happens.
self.assertFalse('save/SaveV2/tensor_names' in graph_ops)
self.assertTrue('save_1/SaveV2/tensor_names' in graph_ops)
# The fake hash table lookup wasn't connected to anything; stripped.
self.assertFalse('hash_table_Lookup' in graph_ops)
# cleanup
gfile.DeleteRecursively(tmpdir)
class InferRealValuedColumnsTest(test.TestCase):


@@ -0,0 +1,59 @@
# Description:
# Utility for applying the Graph Transform tool to a MetaGraphDef.
package(default_visibility = ["//visibility:public"])
licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE"])
load(
"//tensorflow:tensorflow.bzl",
"py_test",
)
py_library(
name = "meta_graph_transform",
srcs = [
"__init__.py",
"meta_graph_transform.py",
],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/core:protos_all_py",
"//tensorflow/python:ops",
"//tensorflow/python/saved_model:constants",
"//tensorflow/tools/graph_transforms:transform_graph_py",
],
)
py_test(
name = "meta_graph_transform_test",
size = "small",
srcs = ["meta_graph_transform_test.py"],
srcs_version = "PY2AND3",
visibility = ["//visibility:private"],
deps = [
":meta_graph_transform",
"//tensorflow/python:client_testlib",
],
)
filegroup(
name = "py_srcs",
data = glob([
"**/*.py",
]),
)
filegroup(
name = "all_files",
srcs = glob(
["**/*"],
exclude = [
"**/METADATA",
"**/OWNERS",
],
),
visibility = ["//tensorflow:__subpackages__"],
)


@@ -0,0 +1,28 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utility for applying the Graph Transform tool to a MetaGraphDef."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.meta_graph_transform import meta_graph_transform
from tensorflow.python.util.all_util import remove_undocumented
_allowed_symbols = ['meta_graph_transform']
remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)


@@ -0,0 +1,453 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Apply graph_transforms tool to MetaGraphDefs."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.core.framework import graph_pb2 as _graph_pb2
from tensorflow.core.protobuf import meta_graph_pb2 as _meta_graph_pb2
from tensorflow.python.client import session as _session
from tensorflow.python.framework import graph_util as _graph_util
from tensorflow.python.framework import importer as _importer
from tensorflow.python.framework import ops as _ops
from tensorflow.python.saved_model import constants as _saved_model_constants
from tensorflow.python.training import saver as _saver_lib
from tensorflow.python.util import compat
from tensorflow.tools import graph_transforms as _graph_transforms
def _op_name(tensor_name):
"""Get the op name from a tensor name."""
# control dependency inputs start with ^
if tensor_name[0] == '^':
tensor_name = tensor_name[1:]
if ':' in tensor_name:
op_name, _ = tensor_name.split(':')
return op_name
return tensor_name
def _do_transforms(graph_def, input_names, output_names, initializer_names,
transforms, saver_def=None, checkpoint_path=None):
"""Apply requested transforms to a GraphDef, including freezing.
This applies the Graph Transform Tool interleaved with graph freezing.
Args:
graph_def: A GraphDef proto to be transformed.
input_names: Names of input nodes.
output_names: Names of output nodes.
initializer_names: Names of "infrastructural" nodes (initializers, save and
restore ops, etc.) that should be retained even if they are not
transitively reachable from output nodes.
transforms: A list of strings naming the graph transforms to be applied in
order. These transform names are exactly those supported by the Graph
Transform Tool, with the addition of the 'freeze_graph' transform.
saver_def: A SaverDef proto used for restoring a checkpoint during freezing,
if needed (default None).
checkpoint_path: A path to a checkpoint to restore during freezing,
if needed (default None).
Returns:
The transformed GraphDef.
"""
if not transforms:
transformed_graph_def = _graph_pb2.GraphDef()
transformed_graph_def.CopyFrom(graph_def)
return transformed_graph_def
else:
try:
freeze_index = transforms.index('freeze_graph')
except ValueError:
# No freeze_graph requested, so do all transforms in one go.
all_output_names = output_names + initializer_names
return _graph_transforms.TransformGraph(
graph_def, input_names, all_output_names, transforms)
# freeze_graph requested, possibly with transforms before and after.
phase_1_transforms = transforms[:freeze_index]
phase_2_transforms = transforms[freeze_index+1:]
graph_def = _do_transforms(
graph_def, input_names, output_names, initializer_names,
phase_1_transforms, saver_def, checkpoint_path)
output_node_names = [_op_name(x) for x in output_names]
graph_def = _freeze_graph_with_def_protos(
graph_def, output_node_names, saver_def, checkpoint_path)
# No need for saver or checkpoint anymore
return _do_transforms(
graph_def, input_names, output_names, [], phase_2_transforms)
# forked and modified from freeze_graph.py
def _freeze_graph_with_def_protos(
input_graph_def,
output_node_names,
input_saver_def,
input_checkpoint):
"""Converts all variables in a graph and checkpoint into constants."""
with _ops.Graph().as_default():
_ = _importer.import_graph_def(input_graph_def, name='')
with _session.Session() as sess:
saver = _saver_lib.Saver(saver_def=input_saver_def)
saver.restore(sess, input_checkpoint)
output_graph_def = _graph_util.convert_variables_to_constants(
sess, input_graph_def, output_node_names)
return output_graph_def
def _find_all_mandatory_retain_ops(base_meta_graph_def):
"""Identify all infrastructural Ops, to ensure that they are retained.
We need to retain infrastructural Ops (init and saver stuff), in addition
to the desired outputs.
For now we retain *all* save and restore ops, variable initializers,
table initializers, and main init ops.
This means that strip_unused_nodes will not remove unused variables.
Args:
base_meta_graph_def: a MetaGraphDef proto in which to identify nodes to retain.
Returns:
A list of node names to be retained.
"""
# TODO(b/63447631): implement variable stripping.
initializer_names = []
# Primary SaverDef and SAVERS collection
saver_defs = []
if base_meta_graph_def.HasField('saver_def'):
saver_defs.append(base_meta_graph_def.saver_def)
saver_defs.extend(_get_all_protos_from_collection(
base_meta_graph_def, _ops.GraphKeys.SAVERS))
for saver_def in saver_defs:
initializer_names.append(saver_def.filename_tensor_name)
initializer_names.append(saver_def.save_tensor_name)
initializer_names.append(saver_def.restore_op_name)
# Variable initializers
variable_collections = [
_ops.GraphKeys.GLOBAL_VARIABLES,
_ops.GraphKeys.TRAINABLE_VARIABLES,
_ops.GraphKeys.MOVING_AVERAGE_VARIABLES,
_ops.GraphKeys.LOCAL_VARIABLES,
_ops.GraphKeys.MODEL_VARIABLES]
for var_coll in variable_collections:
variables = _get_all_protos_from_collection(base_meta_graph_def, var_coll)
var_init_names = [v.initializer_name for v in variables]
if var_init_names:
initializer_names.extend(var_init_names)
# Table initializers
op_names = _get_all_node_names_from_collection(
base_meta_graph_def, _ops.GraphKeys.TABLE_INITIALIZERS)
if op_names:
initializer_names.extend(op_names)
# Various init ops
various_init_op_collections = [_saved_model_constants.LEGACY_INIT_OP_KEY,
_saved_model_constants.MAIN_OP_KEY,
_ops.GraphKeys.INIT_OP,
_ops.GraphKeys.LOCAL_INIT_OP,
_ops.GraphKeys.READY_OP,
_ops.GraphKeys.READY_FOR_LOCAL_INIT_OP]
for op_coll in various_init_op_collections:
op_name = _get_single_node_name_from_collection(
base_meta_graph_def, op_coll)
if op_name:
initializer_names.append(op_name)
return initializer_names
def _add_pruned_collection(base_meta_graph_def, meta_graph_def,
collection_name, removed_op_names):
"""Copy collection to the transformed MetaGraphDef, omitting removed items."""
base_collection = base_meta_graph_def.collection_def[collection_name]
collection = meta_graph_def.collection_def[collection_name]
if base_collection.HasField('any_list'):
for any_value in base_collection.any_list.value:
# just search the serialized proto as a string
if not _is_removed_mentioned(any_value.value, removed_op_names):
copied_any = collection.any_list.value.add()
copied_any.CopyFrom(any_value)
elif base_collection.HasField('bytes_list'):
collection.bytes_list.value[:] = [
s for s in base_collection.bytes_list.value
if not _is_removed_mentioned(s, removed_op_names)]
elif base_collection.HasField('node_list'):
collection.node_list.value[:] = [
s for s in base_collection.node_list.value
if not _is_removed(s, removed_op_names)]
else:
collection.CopyFrom(base_collection)
def _add_pruned_saver(base_meta_graph_def, meta_graph_def, removed_op_names):
"""Copy the Saver into the transformed MetaGraphDef, if valid.
Currently this copies the Saver as is, after verifying that none of the
referenced Save & Restore ops were removed. A future version will modify
the Save and Restore ops themselves as needed to account for removed
Variables.
Args:
base_meta_graph_def: The untransformed MetaGraphDef.
meta_graph_def: The transformed MetaGraphDef being built.
removed_op_names: An iterable of names of ops that were removed.
"""
# Note this does surgery on meta_graph_def.graph_def too, so that should have
# been copied already.
if base_meta_graph_def.HasField('saver_def'):
filename_tensor_name = base_meta_graph_def.saver_def.filename_tensor_name
save_tensor_name = base_meta_graph_def.saver_def.save_tensor_name
restore_op_name = base_meta_graph_def.saver_def.restore_op_name
_check_tensor_not_removed(filename_tensor_name, removed_op_names)
_check_tensor_not_removed(save_tensor_name, removed_op_names)
_check_tensor_not_removed(restore_op_name, removed_op_names)
# TODO(b/63447631): Once we strip unused variables, remove references to
# them from save and restore ops. Retain those ops only if they also refer
# to retained Variables.
# saver_name, restore_all = restore_op_name.rsplit('/', 1)
# if restore_all != 'restore_all':
# raise ValueError(
# 'SaverDef restore_op_name did not have expected form */restore_all')
# save_tensor_names_op_name = '{}/SaveV2/tensor_names'.format(saver_name)
# restore_tensor_names_op_name = (
# '{}/RestoreV2/tensor_names'.format(saver_name))
# save_tensor_names_op = _find_op(meta_graph_def.graph_def,
# save_tensor_names_op_name)
# save_tensor_names_value_tensor = save_tensor_names_op.attr['value'].tensor
# save_tensor_names_value_tensor.string_val[:] = [
# s for s in save_tensor_names_value_tensor.string_val
# if not _is_removed(s, removed_op_names)]
# restore_tensor_names_op = _find_op(
# meta_graph_def.graph_def, restore_tensor_names_op_name)
# restore_tensor_names_value_tensor = (
# restore_tensor_names_op.attr['value'].tensor)
# restore_tensor_names_value_tensor.string_val[:] = [
# s for s in restore_tensor_names_value_tensor.string_val
# if not _is_removed(s, removed_op_names)]
# if (save_tensor_names_value_tensor.string_val
# or restore_tensor_names_value_tensor.string_val):
meta_graph_def.saver_def.CopyFrom(base_meta_graph_def.saver_def)
def _find_op(graph_def, op_name):
"""Fetch a node from a GraphDef proto by name."""
for node_def in graph_def.node:
if node_def.name == op_name:
return node_def
return None
def _add_pruned_signature(base_meta_graph_def, meta_graph_def,
signature_name, removed_op_names):
"""Copy the named signature into the transformed MetaGraphDef, if valid.
If any input or output mentioned in the signature was removed by the graph
transform, the signature is silently omitted from the transformed
MetaGraphDef.
Args:
base_meta_graph_def: The untransformed MetaGraphDef.
meta_graph_def: The transformed MetaGraphDef being built.
signature_name: The name of the signature to copy.
removed_op_names: An iterable of names of ops that were removed.
"""
try:
base_signature = base_meta_graph_def.signature_def[signature_name]
for key in base_signature.inputs:
_check_tensor_not_removed(base_signature.inputs[key].name,
removed_op_names)
for key in base_signature.outputs:
_check_tensor_not_removed(base_signature.outputs[key].name,
removed_op_names)
meta_graph_def.signature_def[signature_name].CopyFrom(base_signature)
except ValueError:
# exclude any signature that mentions a removed node
pass
def _get_single_node_name_from_collection(meta_graph_def, collection_key):
"""Obtain a node name that is the single element of a collection."""
if collection_key not in meta_graph_def.collection_def:
return None
collection = meta_graph_def.collection_def[collection_key]
if not collection.node_list.value:
raise ValueError(
'Collection {} is present but type is not node_list.'.format(
collection_key))
if len(collection.node_list.value) != 1:
raise ValueError(
'Collection {} has {} elements; expected exactly one.'.format(
collection_key, len(collection.node_list.value)))
return collection.node_list.value[0]
def _get_all_node_names_from_collection(meta_graph_def, collection_key):
"""Obtain node names from a collection."""
if collection_key not in meta_graph_def.collection_def:
return None
collection = meta_graph_def.collection_def[collection_key]
if not collection.node_list.value:
raise ValueError(
'Collection {} is present but type is not node_list.'.format(
collection_key))
return collection.node_list.value
def _get_all_protos_from_collection(meta_graph_def, collection_key):
"""Obtain node names from a collection."""
if collection_key not in meta_graph_def.collection_def:
return []
collection = meta_graph_def.collection_def[collection_key]
if not collection.bytes_list.value:
raise ValueError(
'Collection {} is present but type is not bytes_list.'.format(
collection_key))
proto_type = _ops.get_collection_proto_type(collection_key)
result = []
for value in collection.bytes_list.value:
proto = proto_type()
proto.ParseFromString(value)
result.append(proto)
return result
def _is_removed(tensor_name, removed_op_names):
"""Determine whether the named tensor is an output of a removed op."""
for removed_op_name in removed_op_names:
if tensor_name.startswith(removed_op_name):
return True
return False
def _is_removed_mentioned(s, removed_op_names):
"""Determine whether any removed op is mentioned in the given object.
This relies on the string representation of the object. This is used for
proto messages that may mention ops by name in nested fields. The string
representation of the proto includes those field values, so this string
search approach is sufficient.
Args:
s: an object to search for removed op names.
removed_op_names: An iterable of names of ops that were removed.
Returns:
True if any removed op is mentioned in the given object, False otherwise.
"""
for removed_op_name in removed_op_names:
if removed_op_name in compat.as_str_any(s):
return True
return False
def _check_tensor_not_removed(tensor_name, removed_op_names):
"""Verify that the named tensor was not removed.
Args:
tensor_name: the name of a tensor to check.
removed_op_names: An iterable of names of ops that were removed.
Raises:
ValueError: if the tensor was removed.
"""
if not tensor_name:
raise ValueError('Tensor name should not be empty')
if _is_removed(tensor_name, removed_op_names):
raise ValueError(
'Expected Tensor, but it was removed: {}'.format(tensor_name))
def meta_graph_transform(
base_meta_graph_def, input_names, output_names, transforms, tags,
checkpoint_path=None):
"""Apply the Graph Transform tool to a MetaGraphDef.
Args:
base_meta_graph_def: A MetaGraphDef protocol buffer to transform.
input_names: Names of input nodes.
output_names: Names of output nodes.
transforms: A list of strings naming the graph transforms to be applied in
order. These transform names are exactly those supported by the Graph
Transform Tool, with the addition of the 'freeze_graph' transform.
tags: A list of tags with which to annotate the transformed MetaGraphDef.
checkpoint_path: A path to a checkpoint to restore during freezing,
if needed (default None).
Returns:
A new transformed MetaGraphDef protocol buffer.
"""
meta_graph_def = _meta_graph_pb2.MetaGraphDef()
initializer_names = _find_all_mandatory_retain_ops(base_meta_graph_def)
transformed_graph_def = _do_transforms(
base_meta_graph_def.graph_def,
input_names,
output_names,
initializer_names,
transforms,
base_meta_graph_def.saver_def,
checkpoint_path)
meta_graph_def.graph_def.CopyFrom(transformed_graph_def)
meta_graph_def.meta_info_def.CopyFrom(base_meta_graph_def.meta_info_def)
meta_graph_def.meta_info_def.ClearField('tags')
for tag in tags:
meta_graph_def.meta_info_def.tags.append(tag)
base_op_names = [compat.as_str(node.name)
for node in base_meta_graph_def.graph_def.node]
retained_op_names = [compat.as_str(node.name)
for node in meta_graph_def.graph_def.node]
removed_op_names = set(base_op_names) - set(retained_op_names)
# Copy saver, excluding any pruned nodes
_add_pruned_saver(base_meta_graph_def, meta_graph_def, removed_op_names)
# Copy collections, excluding any pruned nodes
for collection_name in base_meta_graph_def.collection_def:
_add_pruned_collection(
base_meta_graph_def, meta_graph_def, collection_name,
removed_op_names)
# Copy signature_defs, excluding any pruned nodes
for signature_name in base_meta_graph_def.signature_def:
_add_pruned_signature(
base_meta_graph_def, meta_graph_def, signature_name,
removed_op_names)
return meta_graph_def
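
The same entry point can be used standalone on any MetaGraphDef (a minimal sketch; base_meta_graph_def is assumed to come from saver.export_meta_graph(), the input/output names, transform list, and checkpoint path are illustrative, and checkpoint_path is only needed because 'freeze_graph' is requested):

from tensorflow.contrib.meta_graph_transform import meta_graph_transform

# 'freeze_graph' is the pseudo-transform handled by _do_transforms above;
# transforms listed before it run pre-freeze, those after run post-freeze.
transformed = meta_graph_transform.meta_graph_transform(
    base_meta_graph_def,
    input_names=['input_example_tensor'],
    output_names=['predictions:0'],
    transforms=['fold_constants', 'freeze_graph', 'strip_unused_nodes'],
    tags=['serve'],
    checkpoint_path='/tmp/model/model.ckpt-1000')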


@@ -0,0 +1,202 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for MetaGraphDef Transform Tool."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from google.protobuf.any_pb2 import Any
from tensorflow.contrib.meta_graph_transform import meta_graph_transform
from tensorflow.core.framework import function_pb2
from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.python.client import session as tf_session
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
from tensorflow.python.training import saver
from tensorflow.python.util import compat
def _make_asset_file_def_any(node_name):
asset_file_def = meta_graph_pb2.AssetFileDef()
asset_file_def.tensor_info.name = node_name
any_message = Any()
any_message.Pack(asset_file_def)
return any_message
class MetaGraphTransformTest(test.TestCase):
def test_meta_graph_transform(self):
with ops.Graph().as_default():
with tf_session.Session(''):
a = array_ops.placeholder(dtypes.int64, [1], name='a')
b = array_ops.placeholder(dtypes.int64, [1], name='b')
c = array_ops.placeholder(dtypes.int64, [1], name='c')
_ = a * b
_ = b * c
base_meta_graph_def = saver.export_meta_graph()
with ops.Graph().as_default():
with tf_session.Session(''):
a = array_ops.placeholder(dtypes.int64, [1], name='a')
b = array_ops.placeholder(dtypes.int64, [1], name='b')
_ = a * b
meta_info_def = meta_graph_pb2.MetaGraphDef.MetaInfoDef()
meta_info_def.tags.append('tag_ab')
expected_meta_graph_def = saver.export_meta_graph(
meta_info_def=meta_info_def)
# Graph rewriter clears versions field, so we expect that.
expected_meta_graph_def.graph_def.ClearField('versions')
# Graph rewriter adds an empty library field, so we expect that.
expected_meta_graph_def.graph_def.library.CopyFrom(
function_pb2.FunctionDefLibrary())
input_names = ['a', 'b']
output_names = ['mul:0']
transforms = ['strip_unused_nodes']
tags = ['tag_ab']
transformed_meta_graph_def = meta_graph_transform.meta_graph_transform(
base_meta_graph_def, input_names, output_names, transforms, tags)
self.assertEqual(expected_meta_graph_def, transformed_meta_graph_def)
def test_add_pruned_collection_node(self):
collection_name = 'node_collection'
base_meta_graph_def = meta_graph_pb2.MetaGraphDef()
base_meta_graph_def.collection_def[collection_name].node_list.value.extend(
['node1', 'node2', 'node3', 'node4'])
meta_graph_def = meta_graph_pb2.MetaGraphDef()
removed_op_names = ['node2', 'node4', 'node5']
meta_graph_transform._add_pruned_collection(
base_meta_graph_def, meta_graph_def, collection_name, removed_op_names)
collection = meta_graph_def.collection_def[collection_name]
expected_nodes = ['node1', 'node3']
self.assertEqual(expected_nodes, collection.node_list.value)
def test_add_pruned_collection_int(self):
collection_name = 'int_collection'
base_meta_graph_def = meta_graph_pb2.MetaGraphDef()
base_meta_graph_def.collection_def[collection_name].int64_list.value[:] = (
[10, 20, 30, 40])
meta_graph_def = meta_graph_pb2.MetaGraphDef()
removed_op_names = ['node2', 'node4', 'node5']
meta_graph_transform._add_pruned_collection(
base_meta_graph_def, meta_graph_def, collection_name, removed_op_names)
collection = meta_graph_def.collection_def[collection_name]
expected_ints = [10, 20, 30, 40]
self.assertEqual(expected_ints, collection.int64_list.value)
def test_add_pruned_collection_proto_in_any_list(self):
collection_name = 'proto_collection'
base_meta_graph_def = meta_graph_pb2.MetaGraphDef()
base_meta_graph_def.collection_def[collection_name].any_list.value.extend(
[_make_asset_file_def_any('node1'),
_make_asset_file_def_any('node2'),
_make_asset_file_def_any('node3'),
_make_asset_file_def_any('node4')])
meta_graph_def = meta_graph_pb2.MetaGraphDef()
removed_op_names = ['node2', 'node4', 'node5']
meta_graph_transform._add_pruned_collection(
base_meta_graph_def, meta_graph_def, collection_name, removed_op_names)
collection = meta_graph_def.collection_def[collection_name]
expected_protos = [_make_asset_file_def_any('node1'),
_make_asset_file_def_any('node3')]
self.assertEqual(expected_protos, collection.any_list.value[:])
def test_add_pruned_collection_proto_in_bytes_list(self):
collection_name = 'proto_collection'
base_meta_graph_def = meta_graph_pb2.MetaGraphDef()
base_meta_graph_def.collection_def[collection_name].bytes_list.value.extend(
[compat.as_bytes(compat.as_str_any(_make_asset_file_def_any('node1'))),
compat.as_bytes(compat.as_str_any(_make_asset_file_def_any('node2'))),
compat.as_bytes(compat.as_str_any(_make_asset_file_def_any('node3'))),
compat.as_bytes(compat.as_str_any(_make_asset_file_def_any('node4')))])
meta_graph_def = meta_graph_pb2.MetaGraphDef()
removed_op_names = ['node2', 'node4', 'node5']
meta_graph_transform._add_pruned_collection(
base_meta_graph_def, meta_graph_def, collection_name, removed_op_names)
collection = meta_graph_def.collection_def[collection_name]
expected_values = [
compat.as_bytes(compat.as_str_any(_make_asset_file_def_any('node1'))),
compat.as_bytes(compat.as_str_any(_make_asset_file_def_any('node3')))]
self.assertEqual(expected_values, collection.bytes_list.value[:])
def test_add_pruned_saver(self):
base_meta_graph_def = meta_graph_pb2.MetaGraphDef()
base_meta_graph_def.saver_def.filename_tensor_name = 'node1'
base_meta_graph_def.saver_def.save_tensor_name = 'node3'
base_meta_graph_def.saver_def.restore_op_name = 'node6'
meta_graph_def = meta_graph_pb2.MetaGraphDef()
removed_op_names = ['node2', 'node4', 'node5']
meta_graph_transform._add_pruned_saver(base_meta_graph_def,
meta_graph_def,
removed_op_names)
# TODO(b/63447631): For now the saver is just copied unchanged
self.assertEqual(base_meta_graph_def.saver_def, meta_graph_def.saver_def)
def test_add_pruned_signature(self):
base_meta_graph_def = meta_graph_pb2.MetaGraphDef()
signature_name_keep = 'test_signature_keep'
base_sig_keep = base_meta_graph_def.signature_def[signature_name_keep]
base_sig_keep.inputs['input_1'].name = 'input_1'
base_sig_keep.outputs['output_1'].name = 'output_1'
signature_name_remove = 'test_signature_remove'
base_sig_remove = base_meta_graph_def.signature_def[signature_name_remove]
base_sig_remove.inputs['node2'].name = 'node2'
base_sig_remove.outputs['output_1'].name = 'output_1'
meta_graph_def = meta_graph_pb2.MetaGraphDef()
removed_op_names = ['node2', 'node4', 'node5']
meta_graph_transform._add_pruned_signature(base_meta_graph_def,
meta_graph_def,
signature_name_keep,
removed_op_names)
meta_graph_transform._add_pruned_signature(base_meta_graph_def,
meta_graph_def,
signature_name_remove,
removed_op_names)
self.assertTrue(signature_name_keep in meta_graph_def.signature_def)
sig_keep = meta_graph_def.signature_def[signature_name_keep]
self.assertEqual(base_sig_keep, sig_keep)
self.assertFalse(signature_name_remove in meta_graph_def.signature_def)
if __name__ == '__main__':
test.main()