Some feature additions to saved_model_cli AOT compile & bugfixes in freeze graph
1. saved_model_cli now properly identifies read-only and non-read-only variables in a graph, and only freezes the read-only variables (marking the others as not read-only). 2. Fixed bugs in convert_variables_to_constants where the blacklist/whitelist was not properly respected for chains of operations. While the VarHandleOp node was properly blacklisted, downstream ops that used the resource would be converted as if the variable had been frozen. For example, the following graph would break: VarHandleOp -> Identity -> [first arg in] ResourceAssign. The attrs of the Identity op would be changed to DT_FLOAT, though they should stay DT_RESOURCE. 3. Added support for freezing of *Nd ops (ResourceGatherNd, ResourceScatterNd). PiperOrigin-RevId: 293239272 Change-Id: I06de3d139c5585a93ba585f076edf92137e4c48a
This commit is contained in:
parent
3262f347ae
commit
96c84224c8
@ -514,7 +514,7 @@ def _convert_variables_to_constants_v2_impl(func,
|
||||
# Get dtype and data for non-variable Placeholders (ex. values for 1.X
|
||||
# Const ops that are loaded as Placeholders in 2.0)
|
||||
_save_placeholder(node.name, node.attr["dtype"])
|
||||
elif node.op in ["ReadVariableOp", "ResourceGather"]:
|
||||
elif node.op in ["ReadVariableOp", "ResourceGather", "ResourceGatherNd"]:
|
||||
# Get dtype and data for Placeholder ops associated with ReadVariableOp
|
||||
# and ResourceGather ops. There can be an Identity in between the
|
||||
# resource op and Placeholder. Store the dtype for the Identity ops.
|
||||
@ -570,6 +570,15 @@ def _convert_variables_to_constants_v2_impl(func,
|
||||
output_node.attr["Taxis"].CopyFrom(axis_dtype)
|
||||
if "_class" in input_node.attr:
|
||||
output_node.attr["_class"].CopyFrom(input_node.attr["_class"])
|
||||
elif input_node.op == "ResourceGatherNd":
|
||||
output_node.op = "GatherNd"
|
||||
output_node.name = input_node.name
|
||||
output_node.input.extend(
|
||||
[input_node.input[0], input_node.input[1]])
|
||||
output_node.attr["Tparams"].CopyFrom(input_node.attr["dtype"])
|
||||
output_node.attr["Tindices"].CopyFrom(input_node.attr["Tindices"])
|
||||
if "_class" in input_node.attr:
|
||||
output_node.attr["_class"].CopyFrom(input_node.attr["_class"])
|
||||
# Update the function names and argument types for the conditional ops.
|
||||
elif input_node.op in _CONDITIONAL_OPS:
|
||||
_populate_if_op(output_node, input_node, function_data)
|
||||
|
@ -237,7 +237,9 @@ def tensor_shape_from_node_def_name(graph, input_name):
|
||||
return shape
|
||||
|
||||
|
||||
def _update_resource_identities(resource_identities, output_graph_def):
|
||||
def _update_resource_identities(resource_identities, output_graph_def,
|
||||
variable_names_whitelist,
|
||||
variable_names_blacklist):
|
||||
"""Updates the type of DT_RESOURCE Identity ops.
|
||||
|
||||
Updates the type of the `resource_identities` to the type of the node that
|
||||
@ -248,6 +250,10 @@ def _update_resource_identities(resource_identities, output_graph_def):
|
||||
resource_identities: List of NodeDef protos that are Identity ops with the
|
||||
type DT_RESOURCE.
|
||||
output_graph_def: GraphDef proto.
|
||||
variable_names_whitelist: The set of variable names to convert (by default,
|
||||
all variables are converted).
|
||||
variable_names_blacklist: The set of variable names to omit converting
|
||||
to constants.
|
||||
"""
|
||||
# Identify the nodes in the graph and the nodes consuming each node.
|
||||
map_name_to_node = {}
|
||||
@ -280,14 +286,27 @@ def _update_resource_identities(resource_identities, output_graph_def):
|
||||
"This Identity's type was changed from DT_RESOURCE during graph "
|
||||
"freezing.")
|
||||
if input_node.attr["T"].type != dtypes.resource:
|
||||
if input_node.op in _CONTROL_FLOW_OP_NAMES_OR_IDENTITY:
|
||||
if (input_node.op in _CONTROL_FLOW_OP_NAMES_OR_IDENTITY
|
||||
and _should_convert(
|
||||
input_node.input[0],
|
||||
variable_names_whitelist,
|
||||
variable_names_blacklist)):
|
||||
node.attr["T"].CopyFrom(input_node.attr["T"])
|
||||
node.attr["_debugging"].s = debugging_message
|
||||
elif input_node.op == "VarHandleOp":
|
||||
elif (input_node.op == "VarHandleOp"
|
||||
and _should_convert(
|
||||
input_node.name,
|
||||
variable_names_whitelist,
|
||||
variable_names_blacklist)):
|
||||
node.attr["T"].CopyFrom(input_node.attr["dtype"])
|
||||
node.attr["_debugging"].s = debugging_message
|
||||
|
||||
|
||||
def _should_convert(name, whitelist, blacklist):
  """Returns whether the variable called `name` is selected for conversion.

  Args:
    name: Name of the variable node.
    whitelist: Collection of names to convert, or None to accept every name.
    blacklist: Collection of names to skip, or None to skip nothing.

  Returns:
    True if `name` passes the whitelist and is not rejected by the blacklist.
  """
  if whitelist is not None and name not in whitelist:
    return False
  if blacklist is not None and name in blacklist:
    return False
  return True
|
||||
|
||||
|
||||
@deprecation.deprecated(
|
||||
date=None,
|
||||
instructions="Use `tf.compat.v1.graph_util.convert_variables_to_constants`")
|
||||
@ -315,6 +334,10 @@ def convert_variables_to_constants(sess,
|
||||
|
||||
Returns:
|
||||
GraphDef containing a simplified version of the original.
|
||||
|
||||
Raises:
|
||||
RuntimeError: if a DT_RESOURCE op is found whose ancestor Variables are both
|
||||
blacklisted AND whitelisted for freezing.
|
||||
"""
|
||||
|
||||
get_input_name = lambda node, index=0: node.input[index].split(":")[0]
|
||||
@ -344,44 +367,60 @@ def convert_variables_to_constants(sess,
|
||||
variable_names = []
|
||||
variable_dict_names = []
|
||||
resource_op_types = {}
|
||||
|
||||
for node in inference_graph.node:
|
||||
if node.op in ["Variable", "VariableV2", "VarHandleOp"]:
|
||||
variable_name = node.name
|
||||
if ((variable_names_whitelist is not None and
|
||||
variable_name not in variable_names_whitelist) or
|
||||
(variable_names_blacklist is not None and
|
||||
variable_name in variable_names_blacklist)):
|
||||
if not _should_convert(
|
||||
variable_name, variable_names_whitelist, variable_names_blacklist):
|
||||
continue
|
||||
variable_dict_names.append(variable_name)
|
||||
if node.op == "VarHandleOp":
|
||||
variable_names.append(variable_name + "/Read/ReadVariableOp:0")
|
||||
else:
|
||||
variable_names.append(variable_name + ":0")
|
||||
elif node.op in ["ReadVariableOp", "ResourceGather"]:
|
||||
elif node.op in ["ReadVariableOp", "ResourceGather", "ResourceGatherNd"]:
|
||||
# There can be one or more Identity or control flow ops in between the
|
||||
# ReadVariableOp and VarHandleOp. Store the ops with the associated
|
||||
# dtypes.
|
||||
source_op_names = [get_input_name(node)]
|
||||
candidate_resource_op_types = {}
|
||||
while (source_op_names and map_name_to_node[source_op_names[0]].op in
|
||||
_CONTROL_FLOW_OP_NAMES_OR_IDENTITY):
|
||||
source_op_name = source_op_names.pop()
|
||||
current_node = map_name_to_node[source_op_name]
|
||||
|
||||
if source_op_name not in resource_op_types:
|
||||
resource_op_types[source_op_name] = node.attr["dtype"]
|
||||
if (source_op_name not in resource_op_types and
|
||||
source_op_name not in candidate_resource_op_types):
|
||||
candidate_resource_op_types[source_op_name] = node.attr["dtype"]
|
||||
source_op_names.append(get_input_name(current_node))
|
||||
|
||||
if current_node == "Merge":
|
||||
merge_resource_name = get_input_name(current_node, index=1)
|
||||
if merge_resource_name not in resource_op_types:
|
||||
resource_op_types[merge_resource_name] = node.attr["dtype"]
|
||||
if (merge_resource_name not in resource_op_types
|
||||
and merge_resource_name not in candidate_resource_op_types):
|
||||
candidate_resource_op_types[merge_resource_name] = (
|
||||
node.attr["dtype"])
|
||||
source_op_names.append(
|
||||
get_input_name(map_name_to_node[merge_resource_name]))
|
||||
|
||||
should_convert_all = None
|
||||
for source_node in source_op_names:
|
||||
if map_name_to_node[source_node].op != "VarHandleOp":
|
||||
raise ValueError("Cannot find the variable that is an input "
|
||||
"to the ReadVariableOp.")
|
||||
should_convert_node = _should_convert(
|
||||
source_node, variable_names_whitelist, variable_names_blacklist)
|
||||
if should_convert_all is None:
|
||||
should_convert_all = should_convert_node
|
||||
elif should_convert_all != should_convert_node:
|
||||
raise RuntimeError(
|
||||
"Found DT_RESOURCE node whose ancestor Variables are both "
|
||||
"blacklisted AND whitelisted for freezing. Originating "
|
||||
"descendant node: {}. Ancestor variables: {}.".format(
|
||||
node.name, source_op_names))
|
||||
if should_convert_all in (None, True):
|
||||
resource_op_types.update(candidate_resource_op_types)
|
||||
|
||||
# Gets map of variables and the associated data.
|
||||
if variable_names:
|
||||
@ -391,6 +430,15 @@ def convert_variables_to_constants(sess,
|
||||
variables_data_map = dict(zip(variable_dict_names, returned_variables))
|
||||
logging.info("Froze %d variables.", len(returned_variables))
|
||||
|
||||
def _should_convert_ancestor(node):
  """Returns whether the node that ultimately feeds `node` should be converted.

  Follows the first input of `node` upwards for as long as the producer is an
  op listed in `_CONTROL_FLOW_OP_NAMES_OR_IDENTITY` whose "T" attr is
  DT_RESOURCE, then applies the whitelist/blacklist to the node it stops at.
  """
  ancestor = map_name_to_node[_node_name(node.input[0])]
  while True:
    forwards_resource = (
        ancestor.op in _CONTROL_FLOW_OP_NAMES_OR_IDENTITY and
        ancestor.attr["T"].type == dtypes.resource)
    if not forwards_resource:
      break
    ancestor = map_name_to_node[_node_name(ancestor.input[0])]
  return _should_convert(
      ancestor.name, variable_names_whitelist, variable_names_blacklist)
|
||||
|
||||
# Reconstruct the graph with constants in place of variables.
|
||||
output_graph_def = graph_pb2.GraphDef()
|
||||
how_many_converted = 0
|
||||
@ -413,7 +461,8 @@ def convert_variables_to_constants(sess,
|
||||
if str(attr_name) != "_output_shapes":
|
||||
output_node.attr[attr_name].CopyFrom(input_node.attr[attr_name])
|
||||
output_node.attr["T"].CopyFrom(resource_op_types[input_node.name])
|
||||
elif input_node.op == "ReadVariableOp":
|
||||
elif (input_node.op == "ReadVariableOp"
|
||||
and _should_convert_ancestor(input_node)):
|
||||
# The first branch converts all VarHandleOps of ResourceVariables to
|
||||
# constants, so we need to convert the associated ReadVariableOps to
|
||||
# Identity ops.
|
||||
@ -423,7 +472,8 @@ def convert_variables_to_constants(sess,
|
||||
output_node.attr["T"].CopyFrom(input_node.attr["dtype"])
|
||||
if "_class" in input_node.attr:
|
||||
output_node.attr["_class"].CopyFrom(input_node.attr["_class"])
|
||||
elif input_node.op == "ResourceGather":
|
||||
elif (input_node.op == "ResourceGather"
|
||||
and _should_convert_ancestor(input_node)):
|
||||
# The first branch converts all VarHandleOps of ResourceGather to
|
||||
# constants, so we need to convert the associated ResourceGather to Gather
|
||||
# ops with a Const axis feeding into it.
|
||||
@ -444,9 +494,19 @@ def convert_variables_to_constants(sess,
|
||||
output_node.attr["Taxis"].CopyFrom(axis_dtype)
|
||||
if "_class" in input_node.attr:
|
||||
output_node.attr["_class"].CopyFrom(input_node.attr["_class"])
|
||||
elif (input_node.op == "ResourceGatherNd"
|
||||
and _should_convert_ancestor(input_node)):
|
||||
output_node.op = "GatherNd"
|
||||
output_node.name = input_node.name
|
||||
output_node.input.extend(
|
||||
[input_node.input[0], input_node.input[1]])
|
||||
output_node.attr["Tparams"].CopyFrom(input_node.attr["dtype"])
|
||||
output_node.attr["Tindices"].CopyFrom(input_node.attr["Tindices"])
|
||||
if "_class" in input_node.attr:
|
||||
output_node.attr["_class"].CopyFrom(input_node.attr["_class"])
|
||||
else:
|
||||
output_node.CopyFrom(input_node)
|
||||
output_graph_def.node.extend([output_node])
|
||||
output_graph_def.node.append(output_node)
|
||||
|
||||
# Update the types of the DT_RESOURCE Identity nodes that do not have an
|
||||
# associated ReadVariableOp.
|
||||
@ -455,7 +515,10 @@ def convert_variables_to_constants(sess,
|
||||
if node.op == "Identity" and node.attr["T"].type == dtypes.resource:
|
||||
resource_identities.append(node)
|
||||
if resource_identities:
|
||||
_update_resource_identities(resource_identities, output_graph_def)
|
||||
_update_resource_identities(resource_identities,
|
||||
output_graph_def,
|
||||
variable_names_whitelist,
|
||||
variable_names_blacklist)
|
||||
|
||||
output_graph_def.library.CopyFrom(inference_graph.library)
|
||||
logging.info("Converted %d variables to const ops.", how_many_converted)
|
||||
|
@ -36,6 +36,7 @@ from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_util
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.grappler import tf_optimizer
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import gen_math_ops
|
||||
from tensorflow.python.ops import gen_state_ops
|
||||
@ -418,6 +419,53 @@ class ConvertVariablesToConstantsTest(test.TestCase):
|
||||
output = self.evaluate(output_node)
|
||||
self.assertNear(2.0, output, 0.00001)
|
||||
|
||||
def test_resource_variable_can_be_written_after_blacklisting(self):
  """A blacklisted resource variable stays a VarHandleOp that can be updated."""
  with ops.Graph().as_default():
    with variable_scope.variable_scope("", use_resource=True):
      kept_variable = variable_scope.get_variable(
          "variable_node", initializer=1.0)
      frozen_variable = variable_scope.get_variable(
          "unused_variable_node", initializer=2.0)
      accumulate = kept_variable.assign(frozen_variable + kept_variable)
      with ops.control_dependencies([accumulate]):
        output_node = array_ops.identity(kept_variable, name="output_node")
      initializer_name = kept_variable.initializer.name
      with session.Session() as sess:
        for variable in (kept_variable, frozen_variable):
          self.evaluate(variable.initializer)
        self.assertNear(3.0, self.evaluate(output_node), 0.00001)
        variable_graph_def = sess.graph.as_graph_def()

        # Blacklist "variable_node": it must not become a const, and the
        # paths that read from and assign to it must remain valid.
        constant_graph_def_with_blacklist = (
            graph_util.convert_variables_to_constants(
                sess,
                variable_graph_def,
                ["output_node", initializer_name],
                variable_names_blacklist={"variable_node"}))

        nodes_by_name = {
            node.name: node
            for node in constant_graph_def_with_blacklist.node
        }
        blacklisted_node = nodes_by_name.get("variable_node")
        self.assertIsNotNone(blacklisted_node)
        self.assertEqual(blacklisted_node.op, "VarHandleOp")

  # Re-import the converted graph: the other variable is now a constant, the
  # blacklisted one is still a variable, so every execution of the output
  # updates it again.
  with ops.Graph().as_default():
    importer.import_graph_def(constant_graph_def_with_blacklist, name="")
    with session.Session() as sess:
      output_node = sess.graph.get_tensor_by_name("output_node:0")
      self.evaluate(sess.graph.get_operation_by_name(initializer_name))
      for expected in (3.0, 5.0):
        self.assertNear(expected, self.evaluate(output_node), 0.00001)
|
||||
|
||||
def _inline_functions(self, graph_def, arrays):
|
||||
meta_graph = export_meta_graph(graph_def=graph_def)
|
||||
fetch_collection = meta_graph_pb2.CollectionDef()
|
||||
|
@ -322,9 +322,19 @@ py_library(
|
||||
srcs = ["saved_model_cli.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":saved_model_aot_compile",
|
||||
":saved_model_utils",
|
||||
"//tensorflow/python",
|
||||
"//tensorflow/python/debug:local_cli_wrapper",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "saved_model_aot_compile",
|
||||
srcs = ["saved_model_aot_compile.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
"//tensorflow/python",
|
||||
"//tensorflow/python:tf_optimizer",
|
||||
] + if_xla_available(
|
||||
["//tensorflow/compiler/tf2xla:tf2xla_proto_py"],
|
||||
|
472
tensorflow/python/tools/saved_model_aot_compile.py
Normal file
472
tensorflow/python/tools/saved_model_aot_compile.py
Normal file
@ -0,0 +1,472 @@
|
||||
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Helper utilities for AOT compilation."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
|
||||
import copy
|
||||
import hashlib
|
||||
import os
|
||||
import pipes
|
||||
import shlex
|
||||
|
||||
import six
|
||||
|
||||
from tensorflow.core.protobuf import config_pb2
|
||||
from tensorflow.core.protobuf import meta_graph_pb2
|
||||
from tensorflow.python.client import session
|
||||
from tensorflow.python.framework import graph_util
|
||||
from tensorflow.python.framework import ops as ops_lib
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.framework import versions
|
||||
from tensorflow.python.grappler import tf_optimizer
|
||||
from tensorflow.python.lib.io import file_io
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.platform import sysconfig as sysconfig_lib
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.training import saver as saver_lib
|
||||
|
||||
try:
|
||||
from tensorflow.python import _pywrap_tfcompile # pylint: disable=g-import-not-at-top
|
||||
except ImportError as e:
|
||||
_pywrap_tfcompile_import_error = ImportError(
|
||||
'Unable to import _pywrap_tfcompile; you must build TensorFlow '
|
||||
'with XLA. You may need to build tensorflow with flag '
|
||||
'--define=with_xla_support=true. Original error: {}'.format(str(e)))
|
||||
else:
|
||||
_pywrap_tfcompile_import_error = None
|
||||
|
||||
|
||||
# Ops that consume a variable's resource handle without writing to it.
_READ_ONLY_VARIABLE_OPS = (
    'ReadVariableOp', 'IsVariableInitializedOp',
    'ResourceGather', 'ResourceGatherNd',
    'VariableShape',
)

# Ops that hand a resource handle on to their own consumers.
_PASS_THROUGH_VARIABLE_OPS = (
    'Identity',
    'IdentityN',
)
|
||||
|
||||
|
||||
def _shlex_quote(s):
  """Returns `s` quoted for use as a single shell token.

  Uses `pipes.quote` when running under Python 2 and `shlex.quote` otherwise.
  """
  quote = pipes.quote if six.PY2 else shlex.quote
  return quote(s)
|
||||
|
||||
|
||||
def _sysconfig_module():
  """Returns tf.sysconfig when it is usable (i.e., inside a pip package).

  Returns:
    The `sysconfig_lib` module, or None if calling `get_include()` on it
    raises ImportError.
  """
  try:
    sysconfig_lib.get_include()
  except ImportError:
    return None
  else:
    return sysconfig_lib
|
||||
|
||||
|
||||
def _parse_tensor_name(name):
  """Splits a tensor name such as 'tensor:0' into ('tensor', 0).

  Args:
    name: A tensor or node name.

  Returns:
    A `(node_name, output_slot)` tuple. `output_slot` is None, and `name` is
    returned untouched, when `name` has no ':' or ends with ':'.
  """
  node_name, colon, slot = name.rpartition(':')
  if not colon or not slot:
    return name, None
  return node_name, int(slot)
|
||||
|
||||
|
||||
# Makefile fragment exposing the include path, library path and C++ flags
# needed to build against the generated XLA AOT header + object files.
_XLA_MAKEFILE_TEMPLATE = """
INC = -I{tensorflow_includes}
LIB = -L{compiled_dir}
CXXFLAGS = {cxx_flags}
"""


def _xla_makefile_string(output_prefix):
  """Returns a Makefile string with variables for using XLA binary object files.

  Attempts to identify the right include header paths when run from either
  an installed TensorFlow pip package, or from bazel run.

  Args:
    output_prefix: A string containing the output prefix for the XLA AOT
      compiled header + object files.

  Returns:
    A string containing a filled out `_XLA_MAKEFILE_TEMPLATE`.
  """
  sysconfig = _sysconfig_module()
  output_dir, _ = os.path.split(output_prefix)
  if sysconfig:
    tensorflow_includes = _shlex_quote(sysconfig.get_include())
  else:
    # Try hard to find the real source directory if this is a local bazel run.
    if os.path.islink(__file__):
      # realpath() follows the whole symlink chain, including links whose
      # target is relative to the directory that contains the link (a bare
      # os.readlink() loop would resolve those against the current directory).
      this_file = os.path.realpath(__file__)
      base = os.path.realpath(
          os.path.join(os.path.dirname(this_file), *([os.path.pardir] * 3)))
    else:
      try:
        base = test.test_src_dir_path('')
      except KeyError:  # Can't find TEST_SRCDIR in environment path.
        base = os.path.realpath(
            os.path.join(os.path.dirname(__file__), *([os.path.pardir] * 3)))
    expected_header = os.path.join(
        base, 'tensorflow', 'compiler', 'tf2xla', 'xla_compiled_cpu_function.h')
    if not os.path.exists(expected_header):
      # Best effort: report the problem but still emit the makefile.
      logging.error(
          'Could not find includes path. Missing file: {}'
          .format(expected_header))
    # Quote the path like the pip-package branch and `compiled_dir` do, so
    # paths containing spaces or shell metacharacters stay one token.
    tensorflow_includes = _shlex_quote(base)

  return _XLA_MAKEFILE_TEMPLATE.format(
      tensorflow_includes=tensorflow_includes,
      compiled_dir=_shlex_quote(output_dir),
      cxx_flags='-D_GLIBCXX_USE_CXX11_ABI={}'.format(
          versions.CXX11_ABI_FLAG))
|
||||
|
||||
|
||||
def _get_variable_nodes_from_graph_def(graph_def):
  """Get the list of Variable nodes from `graph_def`.

  A variable is reported as modified when its resource handle reaches, either
  directly or through ops in `_PASS_THROUGH_VARIABLE_OPS`, any op that is not
  listed in `_READ_ONLY_VARIABLE_OPS`.

  Args:
    graph_def: An instance of `GraphDef`.  This GraphDef *must*
      have already been optimized by Grappler.  In particular, function
      inlining must have already happened.

  Returns:
    A dict mapping string names of variables to tuples `(node_def, modified)`,
    where `node_def` is the `NodeDef` corresponding to variable, and `modified`
    is a python bool describing whether the variable is modified during runtime.
  """

  def _producer_and_slot(input_name):
    # Data inputs are spelled "node" (output slot 0) or "node:N".
    producer, colon, slot = input_name.rpartition(':')
    if colon and slot.isdigit():
      return producer, int(slot)
    return input_name, 0

  # Maps (producer name, output slot) -> [(consumer node, input index)].
  # Keying on the parsed tensor rather than the raw input string makes sure
  # consumers that reference "node:N" are found as well.
  consumers = collections.defaultdict(list)
  for node in graph_def.node:
    for input_index, input_name in enumerate(node.input):
      if input_name.startswith('^'):
        continue  # Control inputs do not carry the resource handle.
      consumers[_producer_and_slot(input_name)].append((node, input_index))

  variables = {}
  for v_node in graph_def.node:
    if v_node.op != 'VarHandleOp':
      continue
    modified = False
    pending = [(v_node.name, 0)]
    visited = set()
    while pending and not modified:
      tensor = pending.pop()
      if tensor in visited:
        continue
      visited.add(tensor)
      for consumer, input_index in consumers.get(tensor, ()):
        if consumer.op in _PASS_THROUGH_VARIABLE_OPS:
          # IdentityN forwards input i to output i; Identity has one output.
          slot = input_index if consumer.op == 'IdentityN' else 0
          pending.append((consumer.name, slot))
        elif consumer.op not in _READ_ONLY_VARIABLE_OPS:
          modified = True
          break
    variables[v_node.name] = (v_node, modified)

  return variables
|
||||
|
||||
|
||||
def _prune_removed_feed_nodes(signature_def, graph_def):
  """Drops signature inputs whose node is no longer present in `graph_def`.

  Args:
    signature_def: A `SignatureDef` instance.
    graph_def: A `GraphDef` instance.

  Returns:
    A new `SignatureDef`: a copy of `signature_def` without the inputs whose
    node name is missing from `graph_def`. A warning is logged for each
    removed input.
  """
  pruned_signature_def = meta_graph_pb2.SignatureDef()
  pruned_signature_def.CopyFrom(signature_def)
  surviving_names = {node.name for node in graph_def.node}
  # Iterate over the original so deleting from the copy is safe.
  for input_key, tensor_info in signature_def.inputs.items():
    node_name = _parse_tensor_name(tensor_info.name)[0]
    if node_name in surviving_names:
      continue
    logging.warn(
        'Signature input key \'{}\', tensor name \'{}\', has been pruned '
        'while freezing the graph. Removing it from the compiled signatures.'
        .format(input_key, node_name))
    del pruned_signature_def.inputs[input_key]
  return pruned_signature_def
|
||||
|
||||
|
||||
def aot_compile_cpu_meta_graph_def(checkpoint_path,
                                   meta_graph_def,
                                   output_prefix,
                                   signature_def_key,
                                   cpp_class,
                                   target_triple,
                                   variables_to_feed=()):
  """Compile a `MetaGraphDef` to header+object files in `output_prefix`.

  Use XLA AOT (`tfcompile`) to convert the given meta graph and
  signature into a header + object files. Also create an include makefile
  that helps identify the appropriate necessary include and library paths
  to incorporate these files into your C++ program.

  The graph is always optimized with grappler before variables are frozen.
  Variables that the graph only reads are converted to constants; variables
  detected as modified at runtime are blacklisted from freezing.

  Intermediate artifacts (frozen graph, tf2xla config and, when logging
  verbosity is at least INFO, the original and pre-frozen graphs) are written
  to `test.get_temp_dir()`.

  Args:
    checkpoint_path: Python string. Path to checkpoints/variables.
    meta_graph_def: Instance of `MetaGraphDef`. Its `graph_def` is updated in
      place: signature placeholders are replaced with constant defaults.
    output_prefix: Python string. Path prefix for outputs.
    signature_def_key: String, the signature_def to use in the SavedModel.
    cpp_class: String, Name of output C++ class.
    target_triple: String, LLVM target triple.
    variables_to_feed: A list of strings, the variables that will be fed by the
      user. If `None`, all the variables in the graph are marked as to-feed.
      The default is an empty tuple: no variable is fed. The selected variable
      nodes are passed on to the tf2xla config.
      NOTE(review): only variables detected as modified are excluded from
      freezing here; confirm whether read-only variables listed in
      `variables_to_feed` should be excluded as well.

  Raises:
    ImportError: If the tfcompile python wrapper could not be imported (e.g.
      tensorflow was not built with XLA).
    ValueError: If `meta_graph_def.signature_def[signature_def_key]` is
      missing or has empty outputs, or if `variables_to_feed` names variables
      that are not in the graph.
  """
  if _pywrap_tfcompile_import_error:
    raise _pywrap_tfcompile_import_error

  signature_def_map = meta_graph_def.signature_def
  if signature_def_key not in signature_def_map:
    raise ValueError(
        'Unable to find signature_def key \'{}\' in signature def map. '
        'Available keys: {}'.format(
            signature_def_key,
            list(signature_def_map.keys())))
  signature_def = signature_def_map[signature_def_key]
  if not signature_def.outputs:
    raise ValueError(
        'Signature key {} must have outputs, but saw none:\n{}'.format(
            signature_def_key, str(signature_def)))

  temp_dir = test.get_temp_dir()
  file_io.recursive_create_dir(temp_dir)
  if logging.get_verbosity() >= logging.INFO:
    original_graph_def_location = os.path.join(temp_dir, 'original_graph.pb')
    with file_io.FileIO(original_graph_def_location, 'wb') as graph_writer:
      graph_writer.write(meta_graph_def.graph_def.SerializeToString())

  # This updates graph_def in place.
  _replace_input_placeholders_with_default_values(
      meta_graph_def.graph_def, signature_def)

  graph_def = _optimize_graph(meta_graph_def, signature_def)

  all_variables = _get_variable_nodes_from_graph_def(graph_def)
  if variables_to_feed is None:
    variable_nodes_to_feed = list(all_variables.values())
  else:
    not_in_graph = set(variables_to_feed).difference(list(all_variables))
    if not_in_graph:
      raise ValueError(
          'Asked to feed variables that were not found in graph: {}. '
          'Variables contained in the graph: {}'.format(
              not_in_graph, list(all_variables)))
    variable_nodes_to_feed = [
        all_variables[name] for name in variables_to_feed
    ]

  if logging.get_verbosity() >= logging.INFO:
    prefrozen_graph_def_location = os.path.join(temp_dir, 'prefrozen_graph.pb')
    with file_io.FileIO(prefrozen_graph_def_location, 'wb') as graph_writer:
      graph_writer.write(graph_def.SerializeToString())

  # Load the Variables so that we can freeze the graph.
  with session.Session(graph=ops_lib.Graph()) as sess:
    restorer = saver_lib.import_meta_graph(meta_graph_def, clear_devices=True)
    # import_meta_graph() returns None when the meta graph has no variables,
    # in which case there is nothing to restore.
    if restorer is not None:
      restorer.restore(sess, checkpoint_path)
    graph_def.CopyFrom(
        graph_util.convert_variables_to_constants(
            sess,
            graph_def,
            output_node_names=[
                _parse_tensor_name(n.name)[0]
                for n in signature_def.outputs.values()
            ],
            # Variables written at runtime must stay variables.
            variable_names_blacklist=[
                name for (name, node_modified) in all_variables.items()
                if node_modified[1]
            ],
        ))

  signature_def = _prune_removed_feed_nodes(signature_def, graph_def)

  frozen_graph_def_location = os.path.join(temp_dir, 'frozen_graph.pb')
  config_pbtxt_location = os.path.join(temp_dir, 'config.pbtxt')
  logging.info('Writing graph def to: {}'.format(frozen_graph_def_location))
  with file_io.FileIO(frozen_graph_def_location, 'wb') as graph_writer:
    graph_writer.write(graph_def.SerializeToString())
  config = _signature_to_tf2xla_config(
      signature_def, variable_nodes_to_feed=variable_nodes_to_feed)
  logging.info('Writing config_pbtxt to: {}'.format(config_pbtxt_location))
  with file_io.FileIO(config_pbtxt_location, mode='w') as config_writer:
    config_writer.write(str(config))

  output_dir = os.path.dirname(output_prefix)
  file_io.recursive_create_dir(output_dir)

  # The entry point name is derived from the config and graph contents.
  entry_digest = hashlib.md5()
  entry_digest.update(str(config).encode())
  entry_digest.update(str(graph_def).encode())
  entry_digest = entry_digest.hexdigest()

  logging.info('Generating XLA AOT artifacts in: {}'.format(output_dir))

  makefile_inc_location = '{}_makefile.inc'.format(output_prefix)
  with file_io.FileIO(makefile_inc_location, mode='w') as makefile_writer:
    makefile_writer.write(_xla_makefile_string(output_prefix))

  output_prefix = _shlex_quote(output_prefix)

  _pywrap_tfcompile.Compile(
      graph=frozen_graph_def_location,
      config=config_pbtxt_location,
      cpp_class=cpp_class,
      target_triple=target_triple,
      entry_point='entry_{}'.format(entry_digest),
      out_function_object='{}.o'.format(output_prefix),
      out_header='{}.h'.format(output_prefix),
      out_metadata_object='{}_metadata.o'.format(output_prefix),
      gen_name_to_index=True,
      # ProgramShape isn't uniquefied by entry_point.
      gen_program_shape=False)
|
||||
|
||||
|
||||
def _optimize_graph(meta_graph_def, signature_def):
  """Runs grappler on a copy of `meta_graph_def` and returns a `GraphDef`.

  Args:
    meta_graph_def: Instance of `MetaGraphDef`; it is deep-copied, not modified.
    signature_def: Instance of `SignatureDef` whose input and output tensor
      names are registered as the 'train_op' collection.

  Returns:
    The result of `tf_optimizer.OptimizeGraph` on the annotated copy.
  """
  # Grappler learns what the outputs are from a collection called 'train_op',
  # so list every signature input and output tensor there.
  fetch_collection = meta_graph_pb2.CollectionDef()
  fetch_collection.node_list.value.extend([
      tensor_info.name
      for tensor_infos in (signature_def.inputs, signature_def.outputs)
      for tensor_info in tensor_infos.values()
  ])

  annotated_meta_graph_def = copy.deepcopy(meta_graph_def)
  annotated_meta_graph_def.collection_def['train_op'].CopyFrom(
      fetch_collection)

  return tf_optimizer.OptimizeGraph(
      config_pb2.ConfigProto(), annotated_meta_graph_def)
|
||||
|
||||
|
||||
def _replace_input_placeholders_with_default_values(graph_def, signature_def):
  """Swap placeholder nodes named by signature inputs for all-zero constants.

  `graph_def` is updated in place.  Signature inputs whose node is not a
  `Placeholder`/`PlaceholderV2` are logged and left alone, and a node shared
  by several signature inputs is only processed once.

  Args:
    graph_def: The `GraphDef` whose nodes are rewritten.
    signature_def: The `SignatureDef` whose `inputs` name the nodes to replace.

  Raises:
    RuntimeError: If a signature input names a node missing from `graph_def`.
    ValueError: If a placeholder input's shape is not fully defined.
  """
  nodes_by_name = {node.name: node for node in graph_def.node}
  seen = set()
  for input_key, tensor_info in signature_def.inputs.items():
    tensor_name = _parse_tensor_name(tensor_info.name)[0]
    if tensor_name in seen:
      continue
    seen.add(tensor_name)

    node = nodes_by_name.get(tensor_name)
    if node is None:
      raise RuntimeError(
          'Unable to find input signature tensor \'{}\' in optimized GraphDef. '
          'Graph nodes are: {}'.format(tensor_name,
                                       list(nodes_by_name.keys())))

    if node.op not in ('Placeholder', 'PlaceholderV2'):
      logging.info(
          'Tried to convert SavedModel input node \'{}\' from a placeholder, '
          'but it doesn\'t look like a placeholder: {}'.format(tensor_name,
                                                                node))
      continue

    static_shape = tensor_shape.TensorShape(tensor_info.tensor_shape)
    if not static_shape.is_fully_defined():
      raise ValueError(
          'Expected fully defined input shape for signature_def \'{}\', '
          'tensor name: \'{}\'; but shape is: {}.'
          .format(input_key, tensor_name, static_shape))

    # Build the zeros tensor in a scratch graph so its NodeDef(s) can be
    # copied over without touching any default graph.
    scratch_graph = ops_lib.Graph()
    with scratch_graph.as_default():
      zeros = array_ops.zeros(
          static_shape, dtype=tensor_info.dtype, name=tensor_name)
    node.CopyFrom(zeros.op.node_def)

    # Sometimes zeros() also creates additional nodes; carry those over too,
    # skipping the op that was just copied into `node`.
    for extra_op in scratch_graph.get_operations():
      if extra_op.name != zeros.op.name:
        graph_def.node.append(extra_op.node_def)
        nodes_by_name[extra_op.node_def.name] = extra_op.node_def
|
||||
|
||||
|
||||
def _signature_to_tf2xla_config(signature_def, variable_nodes_to_feed):
  """Convert `signature_def` to tf2xla config.  Returns a `tf2xla.Config` proto.

  Each signature input becomes a `Feed` named `feed_<key>`, each signature
  output a `Fetch` named `fetch_<key>`, and each entry of
  `variable_nodes_to_feed` a `Variable` named `param_<node name>`.  Any '/' in
  those keys / node names is replaced by '_'.

  Args:
    signature_def: Instance of `SignatureDef`.
    variable_nodes_to_feed: List of tuples of form `(node_def, modified)`
      corresponding to VarHandleOp, and a boolean `modified` that describes
      whether the variable was modified during execution.

  Returns:
    An instance of `tf2xla.Config` proto.

  Raises:
    ImportError: If `tensorflow.compiler.tf2xla.tf2xla_pb2` cannot be imported.
  """
  from tensorflow.compiler.tf2xla import tf2xla_pb2  # pylint: disable=g-import-not-at-top

  def _tensor_id(tensor_name):
    """Build the `TensorId` for a name like 'node:0' (or a bare 'node')."""
    node_name, output_index = _parse_tensor_name(tensor_name)
    # `_parse_tensor_name` yields None when the name has no ':<slot>' suffix;
    # int(None) would raise TypeError, so such names are read as output 0.
    if output_index is None:
      output_index = 0
    return tf2xla_pb2.TensorId(
        node_name=node_name, output_index=int(output_index))

  config = tf2xla_pb2.Config()

  for key, input_ in signature_def.inputs.items():
    config.feed.append(
        tf2xla_pb2.Feed(
            id=_tensor_id(input_.name),
            name='feed_{}'.format(key.replace('/', '_')),
            type=input_.dtype,
            shape=input_.tensor_shape))

  for key, output_ in signature_def.outputs.items():
    config.fetch.append(
        tf2xla_pb2.Fetch(
            id=_tensor_id(output_.name),
            name='fetch_{}'.format(key.replace('/', '_')),
            type=output_.dtype,
            shape=output_.tensor_shape))

  for (node, modified) in variable_nodes_to_feed:
    config.variable.append(
        tf2xla_pb2.Variable(
            node_name=node.name,
            name='param_{}'.format(node.name.replace('/', '_')),
            type=node.attr['dtype'].type,
            shape=node.attr['shape'].shape,
            # Only variables that are never written may be exposed read-only.
            readonly=not modified))

  return config
|
@ -25,12 +25,8 @@ from __future__ import print_function
|
||||
|
||||
import argparse
|
||||
import collections
|
||||
import copy
|
||||
import hashlib
|
||||
import os
|
||||
import pipes
|
||||
import re
|
||||
import shlex
|
||||
import sys
|
||||
|
||||
import numpy as np
|
||||
@ -38,31 +34,22 @@ import six
|
||||
|
||||
from tensorflow.core.example import example_pb2
|
||||
from tensorflow.core.framework import types_pb2
|
||||
from tensorflow.core.protobuf import config_pb2
|
||||
from tensorflow.core.protobuf import meta_graph_pb2
|
||||
from tensorflow.python.client import session
|
||||
from tensorflow.python.debug.wrappers import local_cli_wrapper
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.eager import function as defun
|
||||
from tensorflow.python.framework import graph_util
|
||||
from tensorflow.python.framework import meta_graph as meta_graph_lib
|
||||
from tensorflow.python.framework import ops as ops_lib
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.framework import tensor_spec
|
||||
from tensorflow.python.framework import versions
|
||||
from tensorflow.python.grappler import tf_optimizer
|
||||
from tensorflow.python.lib.io import file_io
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.platform import app # pylint: disable=unused-import
|
||||
from tensorflow.python.platform import sysconfig as sysconfig_lib
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.saved_model import load
|
||||
from tensorflow.python.saved_model import loader
|
||||
from tensorflow.python.saved_model import save
|
||||
from tensorflow.python.saved_model import signature_constants
|
||||
from tensorflow.python.tools import saved_model_aot_compile
|
||||
from tensorflow.python.tools import saved_model_utils
|
||||
from tensorflow.python.training import saver as saver_lib
|
||||
|
||||
|
||||
_XLA_DEBUG_OPTIONS_URL = (
|
||||
@ -70,100 +57,10 @@ _XLA_DEBUG_OPTIONS_URL = (
|
||||
'tensorflow/compiler/xla/debug_options_flags.cc')
|
||||
|
||||
|
||||
# `_pywrap_tfcompile` may be unavailable.  An import failure is stored here
# rather than raised at import time; `None` means the import succeeded.
_pywrap_tfcompile_import_error = None
try:
  from tensorflow.python import _pywrap_tfcompile  # pylint: disable=g-import-not-at-top
except ImportError as e:
  _pywrap_tfcompile_import_error = ImportError(
      'Unable to import _pywrap_tfcompile; you must build TensorFlow '
      'with XLA. You may need to build tensorflow with flag '
      '--define=with_xla_support=true. Original error: {}'.format(str(e)))


# Set of ops to blacklist.
_OP_BLACKLIST = {'WriteFile', 'ReadFile', 'PrintV2'}
|
||||
|
||||
|
||||
def _shlex_quote(s):
  """Return `s` quoted by `pipes.quote` on Python 2, else by `shlex.quote`."""
  quote = pipes.quote if six.PY2 else shlex.quote
  return quote(s)
|
||||
|
||||
|
||||
def _sysconfig_module():
  """Return tf.sysconfig if usable (i.e., inside a pip package), else None.

  Usability is probed by calling `get_include()`; an `ImportError` from that
  call means the module is treated as unavailable.
  """
  try:
    sysconfig_lib.get_include()
  except ImportError:
    return None
  else:
    return sysconfig_lib
|
||||
|
||||
|
||||
def _parse_tensor_name(name):
  """Split a tensor name such as 'tensor:0' into the tuple ('tensor', 0).

  The split happens at the last ':'.  When `name` has no ':' or ends with
  one, it is returned unchanged together with `None` as the output slot.
  """
  node_name, colon, slot_text = name.rpartition(':')
  if colon and slot_text:
    return node_name, int(slot_text)
  return name, None
|
||||
|
||||
|
||||
_XLA_MAKEFILE_TEMPLATE = """
|
||||
INC = -I{tensorflow_includes}
|
||||
LIB = -L{compiled_dir}
|
||||
CXXFLAGS = {cxx_flags}
|
||||
"""
|
||||
|
||||
|
||||
def _local_tensorflow_source_root():
  """Best-effort guess at the source root three levels above this file.

  Used when tf.sysconfig is unavailable (e.g. a local bazel run).
  """
  three_up = [os.path.pardir] * 3
  if os.path.islink(__file__):
    this_file = __file__
    while os.path.islink(this_file):
      # readlink() may return a target relative to the link's own directory.
      # join() keeps an absolute target as is and anchors a relative one
      # there, instead of resolving it against the current working directory.
      this_file = os.path.join(
          os.path.dirname(this_file), os.readlink(this_file))
    return os.path.realpath(
        os.path.join(os.path.dirname(this_file), *three_up))
  try:
    return test.test_src_dir_path('')
  except KeyError:  # Can't find TEST_SRCDIR in environment path.
    return os.path.realpath(
        os.path.join(os.path.dirname(__file__), *three_up))


def _xla_makefile_string(output_prefix):
  """Returns a Makefile string with variables for using XLA binary object files.

  Attempts to identify the right include header paths when run from either
  an installed TensorFlow pip package, or from bazel run.

  Args:
    output_prefix: A string containing the output prefix for the XLA AOT
      compiled header + object files.

  Returns:
    A string containing a filled out `_XLA_MAKEFILE_TEMPLATE`.
  """
  sysconfig = _sysconfig_module()
  output_dir, _ = os.path.split(output_prefix)
  if sysconfig:
    include_dir = sysconfig.get_include()
  else:
    # Try hard to find the real source directory if this is a local bazel run.
    include_dir = _local_tensorflow_source_root()
    expected_header = os.path.join(
        include_dir, 'tensorflow', 'compiler', 'tf2xla',
        'xla_compiled_cpu_function.h')
    if not os.path.exists(expected_header):
      # Only reported; the guessed directory is still written to the Makefile.
      logging.error(
          'Could not find includes path. Missing file: {}'
          .format(expected_header))

  return _XLA_MAKEFILE_TEMPLATE.format(
      # Quote both the sysconfig and the fallback path so directories that
      # contain spaces or shell metacharacters stay a single token.
      tensorflow_includes=_shlex_quote(include_dir),
      compiled_dir=_shlex_quote(output_dir),
      cxx_flags='-D_GLIBCXX_USE_CXX11_ABI={}'.format(
          versions.CXX11_ABI_FLAG))
|
||||
|
||||
|
||||
def _show_tag_sets(saved_model_dir):
|
||||
"""Prints the tag-sets stored in SavedModel directory.
|
||||
|
||||
@ -178,47 +75,6 @@ def _show_tag_sets(saved_model_dir):
|
||||
print('%r' % ', '.join(sorted(tag_set)))
|
||||
|
||||
|
||||
def _get_variable_nodes_from_graph_def(graph_def):
  """Collect every `VarHandleOp` node found in `graph_def`.

  Args:
    graph_def: An instance of `GraphDef`.

  Returns:
    A list of `NodeDef`: first the matching top-level nodes, then the matching
    nodes of each function in `graph_def.library`, in order.
  """
  def _is_variable(node):
    return node.op == 'VarHandleOp'

  found = list(filter(_is_variable, graph_def.node))
  for function_def in graph_def.library.function:
    found.extend(filter(_is_variable, function_def.node_def))
  return found
|
||||
|
||||
|
||||
def _prune_removed_feed_nodes(signature_def, graph_def):
  """Drop signature inputs whose node no longer exists in `graph_def`.

  Args:
    signature_def: A `SignatureDef` instance.  It is not modified.
    graph_def: A `GraphDef` instance.

  Returns:
    A new `SignatureDef`: a copy of `signature_def` without the inputs whose
    tensor refers to a node name absent from `graph_def`.
  """
  surviving_names = {node.name for node in graph_def.node}

  pruned = meta_graph_pb2.SignatureDef()
  pruned.CopyFrom(signature_def)

  # Iterate over the original's inputs while deleting from the copy.
  for input_key, tensor_info in signature_def.inputs.items():
    tensor_name = _parse_tensor_name(tensor_info.name)[0]
    if tensor_name in surviving_names:
      continue
    logging.warn(
        'Signature input key \'{}\', tensor name \'{}\', has been pruned '
        'while freezing the graph. Removing it from the compiled signatures.'
        .format(input_key, tensor_name))
    del pruned.inputs[input_key]
  return pruned
|
||||
|
||||
|
||||
def _show_signature_def_map_keys(saved_model_dir, tag_set):
|
||||
"""Prints the keys for each SignatureDef in the SignatureDef map.
|
||||
|
||||
@ -943,7 +799,7 @@ def aot_compile_cpu(args):
|
||||
variables_to_feed = None # We will identify them after.
|
||||
else:
|
||||
variables_to_feed = args.variables_to_feed.split(',')
|
||||
aot_compile_cpu_meta_graph_def(
|
||||
saved_model_aot_compile.aot_compile_cpu_meta_graph_def(
|
||||
checkpoint_path=checkpoint_path,
|
||||
meta_graph_def=saved_model_utils.get_meta_graph_def(
|
||||
args.dir, args.tag_set),
|
||||
@ -954,210 +810,6 @@ def aot_compile_cpu(args):
|
||||
cpp_class=args.cpp_class)
|
||||
|
||||
|
||||
def aot_compile_cpu_meta_graph_def(checkpoint_path,
                                   meta_graph_def,
                                   output_prefix,
                                   signature_def_key,
                                   cpp_class,
                                   target_triple,
                                   variables_to_feed=()):
  """Compile a `MetaGraphDef` to header+object files in `output_prefix`.

  Use XLA AOT (`tfcompile`) to convert the given meta graph and
  signature into a header + object files.  Also create an include makefile
  that helps identify the appropriate necessary include and library paths
  to incorporate these files into your C++ program.

  The signature's input placeholders are replaced by zero constants, the
  graph is optimized with grappler, and the checkpoint's variables are then
  converted to constants before compilation happens.  Variables listed in
  `variables_to_feed` are additionally declared in the tf2xla config.

  Args:
    checkpoint_path: Python string.  Path to checkpoints/variables.
    meta_graph_def: Instance of `MetaGraphDef`.  Its `graph_def` is updated
      in place when input placeholders are replaced.
    output_prefix: Python string.  Path prefix for outputs.
    signature_def_key: String, the signature_def to use in the SavedModel.
    cpp_class: String, Name of output C++ class.
    target_triple: String, LLVM target triple.
    variables_to_feed: A list of strings, the variables that will be fed by the
      user.  If `None`, then we will extract all the variables in the graph
      and mark them as to-feed.  The default is an empty tuple: no variable
      is marked as to-feed.

  Raises:
    ImportError: If the `_pywrap_tfcompile` python wrapper could not be
      imported (the error stored at module import time is raised).
    ValueError: If `meta_graph_def.signature_def[signature_def_key]` is
      missing or has empty outputs, or if `variables_to_feed` names a
      variable that is not found in the optimized graph.
  """
  if _pywrap_tfcompile_import_error:
    raise _pywrap_tfcompile_import_error

  signature_def_map = meta_graph_def.signature_def
  if signature_def_key not in signature_def_map:
    raise ValueError(
        'Unable to find signature_def key \'{}\' in signature def map. '
        'Available keys: {}'.format(
            signature_def_key,
            list(signature_def_map.keys())))
  signature_def = signature_def_map[signature_def_key]
  if not signature_def.outputs:
    raise ValueError(
        'Signature key {} must have outputs, but saw none:\n{}'.format(
            signature_def_key, str(signature_def)))

  temp_dir = test.get_temp_dir()
  file_io.recursive_create_dir(temp_dir)
  if logging.get_verbosity() >= logging.INFO:
    original_graph_def_location = os.path.join(temp_dir, 'original_graph.pb')
    with file_io.FileIO(original_graph_def_location, 'wb') as graph_writer:
      graph_writer.write(meta_graph_def.graph_def.SerializeToString())

  # This updates graph_def in place.
  _replace_input_placeholders_with_default_values(
      meta_graph_def.graph_def, signature_def)
  graph_def = _optimize_graph(meta_graph_def, signature_def)

  all_variables = _get_variable_nodes_from_graph_def(graph_def)
  if variables_to_feed is None:
    variable_nodes_to_feed = list(all_variables)
  else:
    not_in_graph = (
        set(variables_to_feed).difference([x.name for x in all_variables]))
    if not_in_graph:
      raise ValueError(
          'Asked to feed variables that were not found in graph: {}. '
          'Variables contained in the graph: {}'.format(
              not_in_graph, set([x.name for x in all_variables])))
    all_variables_map = dict((x.name, x) for x in all_variables)
    variable_nodes_to_feed = [
        all_variables_map[name] for name in variables_to_feed
    ]

  if logging.get_verbosity() >= logging.INFO:
    prefrozen_graph_def_location = os.path.join(temp_dir, 'prefrozen_graph.pb')
    with file_io.FileIO(prefrozen_graph_def_location, 'wb') as graph_writer:
      # NOTE(review): this dumps `meta_graph_def.graph_def`, not the optimized
      # `graph_def` -- confirm which one is intended.
      graph_writer.write(meta_graph_def.graph_def.SerializeToString())

  # Load the Variables so that we can freeze the graph.
  with session.Session(graph=ops_lib.Graph()) as sess:
    restorer = saver_lib.import_meta_graph(meta_graph_def, clear_devices=True)
    # import_meta_graph() returns None when the MetaGraphDef contains no
    # variables; there is nothing to restore in that case.
    if restorer is not None:
      restorer.restore(sess, checkpoint_path)
    graph_def.CopyFrom(
        graph_util.convert_variables_to_constants(
            sess,
            graph_def,
            output_node_names=[
                _parse_tensor_name(n.name)[0]
                for n in signature_def.outputs.values()
            ],
        ))

  signature_def = _prune_removed_feed_nodes(signature_def, graph_def)

  frozen_graph_def_location = os.path.join(temp_dir, 'frozen_graph.pb')
  config_pbtxt_location = os.path.join(temp_dir, 'config.pbtxt')
  logging.info('Writing graph def to: {}'.format(frozen_graph_def_location))
  with file_io.FileIO(frozen_graph_def_location, 'wb') as graph_writer:
    graph_writer.write(graph_def.SerializeToString())
  config = _signature_to_tf2xla_config(
      signature_def, variable_nodes_to_feed=variable_nodes_to_feed)
  logging.info('Writing config_pbtxt to: {}'.format(config_pbtxt_location))
  with file_io.FileIO(config_pbtxt_location, mode='w') as config_writer:
    config_writer.write(str(config))

  output_dir = os.path.dirname(output_prefix)
  file_io.recursive_create_dir(output_dir)

  # The entry point name is derived from a digest of the config and the
  # frozen graph (an identifier, not a security measure).
  entry_digest = hashlib.md5()
  entry_digest.update(str(config).encode())
  entry_digest.update(str(graph_def).encode())
  entry_digest = entry_digest.hexdigest()

  logging.info('Generating XLA AOT artifacts in: {}'.format(output_dir))

  makefile_inc_location = '{}_makefile.inc'.format(output_prefix)
  with file_io.FileIO(makefile_inc_location, mode='w') as makefile_writer:
    makefile_writer.write(_xla_makefile_string(output_prefix))

  output_prefix = _shlex_quote(output_prefix)

  _pywrap_tfcompile.Compile(
      graph=frozen_graph_def_location,
      config=config_pbtxt_location,
      cpp_class=cpp_class,
      target_triple=target_triple,
      entry_point='entry_{}'.format(entry_digest),
      out_function_object='{}.o'.format(output_prefix),
      out_header='{}.h'.format(output_prefix),
      out_metadata_object='{}_metadata.o'.format(output_prefix),
      gen_name_to_index=True,
      # ProgramShape isn't uniquefied by entry_point.
      gen_program_shape=False)
|
||||
|
||||
|
||||
def _optimize_graph(meta_graph_def, signature_def):
  """Run grappler over a copy of `meta_graph_def`; returns the result.

  Args:
    meta_graph_def: The `MetaGraphDef` to optimize.  It is deep-copied first,
      so the caller's proto is not modified by this function.
    signature_def: The `SignatureDef` whose input and output tensor names are
      recorded in the copy's 'train_op' collection.

  Returns:
    The value returned by `tf_optimizer.OptimizeGraph` for the annotated copy.
  """
  annotated_meta_graph = copy.deepcopy(meta_graph_def)

  # Grappler needs a collection called 'train_op' to know what the outputs
  # are; list every signature input and output tensor in it.
  fetches = meta_graph_pb2.CollectionDef()
  fetches.node_list.value.extend(
      [info.name for info in signature_def.inputs.values()] +
      [info.name for info in signature_def.outputs.values()])
  annotated_meta_graph.collection_def['train_op'].CopyFrom(fetches)

  return tf_optimizer.OptimizeGraph(
      config_pb2.ConfigProto(), annotated_meta_graph)
|
||||
|
||||
|
||||
def _replace_input_placeholders_with_default_values(graph_def, signature_def):
  """Swap placeholder nodes named by signature inputs for all-zero constants.

  `graph_def` is updated in place.  Signature inputs whose node is not a
  `Placeholder`/`PlaceholderV2` are logged and left alone, and a node shared
  by several signature inputs is only processed once.

  Args:
    graph_def: The `GraphDef` whose nodes are rewritten.
    signature_def: The `SignatureDef` whose `inputs` name the nodes to replace.

  Raises:
    RuntimeError: If a signature input names a node missing from `graph_def`.
    ValueError: If a placeholder input's shape is not fully defined.
  """
  nodes_by_name = {node.name: node for node in graph_def.node}
  seen = set()
  for input_key, tensor_info in signature_def.inputs.items():
    tensor_name = _parse_tensor_name(tensor_info.name)[0]
    if tensor_name in seen:
      continue
    seen.add(tensor_name)

    node = nodes_by_name.get(tensor_name)
    if node is None:
      raise RuntimeError(
          'Unable to find input signature tensor \'{}\' in optimized GraphDef. '
          'Graph nodes are: {}'.format(tensor_name,
                                       list(nodes_by_name.keys())))

    if node.op not in ('Placeholder', 'PlaceholderV2'):
      logging.info(
          'Tried to convert SavedModel input node \'{}\' from a placeholder, '
          'but it doesn\'t look like a placeholder: {}'.format(tensor_name,
                                                                node))
      continue

    static_shape = tensor_shape.TensorShape(tensor_info.tensor_shape)
    if not static_shape.is_fully_defined():
      raise ValueError(
          'Expected fully defined input shape for signature_def \'{}\', '
          'tensor name: \'{}\'; but shape is: {}.'
          .format(input_key, tensor_name, static_shape))

    # Build the zeros tensor in a scratch graph so its NodeDef(s) can be
    # copied over without touching any default graph.
    scratch_graph = ops_lib.Graph()
    with scratch_graph.as_default():
      zeros = array_ops.zeros(
          static_shape, dtype=tensor_info.dtype, name=tensor_name)
    node.CopyFrom(zeros.op.node_def)

    # Sometimes zeros() also creates additional nodes; carry those over too,
    # skipping the op that was just copied into `node`.
    for extra_op in scratch_graph.get_operations():
      if extra_op.name != zeros.op.name:
        graph_def.node.append(extra_op.node_def)
        nodes_by_name[extra_op.node_def.name] = extra_op.node_def
|
||||
|
||||
|
||||
def add_show_subparser(subparsers):
|
||||
"""Add parser for `show`."""
|
||||
show_msg = (
|
||||
@ -1482,61 +1134,6 @@ def create_parser():
|
||||
return parser
|
||||
|
||||
|
||||
def _signature_to_tf2xla_config(signature_def, variable_nodes_to_feed):
  """Convert `signature_def` to tf2xla config.  Returns a `tf2xla.Config` proto.

  Each signature input becomes a `Feed` named `feed_<key>`, each signature
  output a `Fetch` named `fetch_<key>`, and each node of
  `variable_nodes_to_feed` a read-only `Variable` named `param_<node name>`.
  Any '/' in those keys / node names is replaced by '_'.

  Args:
    signature_def: Instance of `SignatureDef`.
    variable_nodes_to_feed: List NodeDefs corresponding to VarHandleOp,
      the list of variables to feed.

  Returns:
    An instance of `tf2xla.Config` proto.

  Raises:
    ImportError: If `tensorflow.compiler.tf2xla.tf2xla_pb2` cannot be imported.
  """
  from tensorflow.compiler.tf2xla import tf2xla_pb2  # pylint: disable=g-import-not-at-top

  def _tensor_id(tensor_name):
    """Build the `TensorId` for a name like 'node:0' (or a bare 'node')."""
    node_name, output_index = _parse_tensor_name(tensor_name)
    # `_parse_tensor_name` yields None when the name has no ':<slot>' suffix;
    # int(None) would raise TypeError, so such names are read as output 0.
    if output_index is None:
      output_index = 0
    return tf2xla_pb2.TensorId(
        node_name=node_name, output_index=int(output_index))

  config = tf2xla_pb2.Config()

  for key, input_ in signature_def.inputs.items():
    config.feed.append(
        tf2xla_pb2.Feed(
            id=_tensor_id(input_.name),
            name='feed_{}'.format(key.replace('/', '_')),
            type=input_.dtype,
            shape=input_.tensor_shape))

  for key, output_ in signature_def.outputs.items():
    config.fetch.append(
        tf2xla_pb2.Fetch(
            id=_tensor_id(output_.name),
            name='fetch_{}'.format(key.replace('/', '_')),
            type=output_.dtype,
            shape=output_.tensor_shape))

  for node in variable_nodes_to_feed:
    config.variable.append(
        tf2xla_pb2.Variable(
            node_name=node.name,
            name='param_{}'.format(node.name.replace('/', '_')),
            type=node.attr['dtype'].type,
            shape=node.attr['shape'].shape,
            readonly=True))

  return config
|
||||
|
||||
|
||||
def main():
|
||||
logging.set_verbosity(logging.INFO)
|
||||
parser = create_parser()
|
||||
|
@ -733,6 +733,7 @@ Defined Functions:
|
||||
|
||||
def __init__(self):
|
||||
self.var = variables.Variable(1.0, name='my_var')
|
||||
self.write_var = variables.Variable(1.0, name='write_var')
|
||||
|
||||
@def_function.function(input_signature=[
|
||||
tensor_spec.TensorSpec(shape=(2, 2), dtype=dtypes.float32),
|
||||
@ -752,20 +753,32 @@ Defined Functions:
|
||||
del y
|
||||
return {'res': x + self.var}
|
||||
|
||||
@def_function.function(input_signature=[
|
||||
tensor_spec.TensorSpec(shape=(), dtype=dtypes.float32),
|
||||
tensor_spec.TensorSpec(shape=(), dtype=dtypes.float32),
|
||||
])
|
||||
def func_write(self, x, y):
|
||||
del y
|
||||
self.write_var.assign(x + self.var)
|
||||
return {'res': self.write_var}
|
||||
|
||||
@parameterized.named_parameters(
|
||||
('VariablesToFeedNone', '', 'func2'),
|
||||
('VariablesToFeedAll', 'all', 'func2'),
|
||||
('VariablesToFeedMyVar', 'my_var', 'func2'),
|
||||
('VariablesToFeedNoneLargeConstant', '', 'func3'))
|
||||
('VariablesToFeedNoneLargeConstant', '', 'func3'),
|
||||
('WriteToWriteVar', 'all', 'func_write'),
|
||||
)
|
||||
def testAOTCompileCPUFreezesAndCompiles(self, variables_to_feed, func):
|
||||
if not test.is_built_with_xla():
|
||||
self.skipTest('Skipping test because XLA is not compiled in.')
|
||||
|
||||
saved_model_dir = os.path.join(test.get_temp_dir(), 'dummy_model')
|
||||
dummy_model = self.AOTCompileDummyModel()
|
||||
func = dummy_model.func2 if func == 'func2' else dummy_model.func3
|
||||
func = getattr(dummy_model, func)
|
||||
with self.cached_session():
|
||||
self.evaluate(dummy_model.var.initializer)
|
||||
self.evaluate(dummy_model.write_var.initializer)
|
||||
save.save(dummy_model, saved_model_dir, signatures={'func': func})
|
||||
|
||||
self.parser = saved_model_cli.create_parser()
|
||||
@ -793,7 +806,15 @@ Defined Functions:
|
||||
# arg_y got filtered out as it's not used by the output.
|
||||
self.assertNotIn('arg_feed_y_data', header_contents)
|
||||
if variables_to_feed:
|
||||
self.assertIn('var_param_my_var', header_contents)
|
||||
# Read-only-variables' setters preserve constness.
|
||||
self.assertIn('set_var_param_my_var_data(const float', header_contents)
|
||||
self.assertNotIn('set_var_param_my_var_data(float', header_contents)
|
||||
if func == dummy_model.func_write:
|
||||
# Writeable variables setters do not preserve constness.
|
||||
self.assertIn('set_var_param_write_var_data(float', header_contents)
|
||||
self.assertNotIn(
|
||||
'set_var_param_write_var_data(const float', header_contents)
|
||||
|
||||
makefile_contents = file_io.read_file_to_string(
|
||||
'{}_makefile.inc'.format(output_prefix))
|
||||
self.assertIn('-D_GLIBCXX_USE_CXX11_ABI=', makefile_contents)
|
||||
|
Loading…
Reference in New Issue
Block a user