This change allows the linkage of multithreaded XLA AOT CPU backend objects, such as multithreaded matmul, conv2d, etc. These are not enabled by default. New unit tests confirm that the objects are emitted and linked correctly, and the resulting computations are numerically correct. MKL service backend objects are not included. Other changes: * C++ Unit tests now use arg_feed_{x,y} instead of arg0/arg1, since the names are flaky (they may swap from the signature) * Add argument "multithreading=" to the bzl file and saved_model_cli. * Add unit tests using "nm" to ensure that the proper symbols are used when enabling or disabling multithreading (not sure if they are windows-friendly). * Use a simpler and more unique string for the entry_point string. PiperOrigin-RevId: 338112208 Change-Id: Id734e75e63e72db93a743f451ddb7eb6f489c1c7
492 lines
18 KiB
Python
492 lines
18 KiB
Python
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ==============================================================================
|
|
"""Helper utilities for AOT compilation."""
|
|
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
|
|
import collections
|
|
import copy
|
|
import os
|
|
import pipes
|
|
import re
|
|
import shlex
|
|
|
|
import six
|
|
|
|
from tensorflow.core.protobuf import config_pb2
|
|
from tensorflow.core.protobuf import meta_graph_pb2
|
|
from tensorflow.python.client import session
|
|
from tensorflow.python.framework import graph_util
|
|
from tensorflow.python.framework import ops as ops_lib
|
|
from tensorflow.python.framework import tensor_shape
|
|
from tensorflow.python.framework import versions
|
|
from tensorflow.python.grappler import tf_optimizer
|
|
from tensorflow.python.lib.io import file_io
|
|
from tensorflow.python.ops import array_ops
|
|
from tensorflow.python.platform import sysconfig as sysconfig_lib
|
|
from tensorflow.python.platform import test
|
|
from tensorflow.python.platform import tf_logging as logging
|
|
from tensorflow.python.training import saver as saver_lib
|
|
|
|
try:
|
|
from tensorflow.python import _pywrap_tfcompile # pylint: disable=g-import-not-at-top
|
|
except ImportError as e:
|
|
_pywrap_tfcompile_import_error = ImportError(
|
|
'Unable to import _pywrap_tfcompile; you must build TensorFlow '
|
|
'with XLA. You may need to build tensorflow with flag '
|
|
'--define=with_xla_support=true. Original error: {}'.format(str(e)))
|
|
else:
|
|
_pywrap_tfcompile_import_error = None
|
|
|
|
|
|
_READ_ONLY_VARIABLE_OPS = (
|
|
'ReadVariableOp',
|
|
'IsVariableInitializedOp',
|
|
'ResourceGather',
|
|
'ResourceGatherNd',
|
|
'VariableShape',
|
|
)
|
|
|
|
_PASS_THROUGH_VARIABLE_OPS = ('Identity', 'IdentityN')
|
|
|
|
|
|
def _shlex_quote(s):
|
|
if six.PY2:
|
|
return pipes.quote(s)
|
|
else:
|
|
return shlex.quote(s)
|
|
|
|
|
|
def _sysconfig_module():
|
|
"""Load tf.sysconfig if available and working (i.e., inside a pip package)."""
|
|
try:
|
|
_ = sysconfig_lib.get_include()
|
|
except ImportError:
|
|
return None
|
|
return sysconfig_lib
|
|
|
|
|
|
def _parse_tensor_name(name):
|
|
"""Convert a tensor name like 'tensor:0' into a tuple ('tensor', 0)."""
|
|
if ':' in name and not name.endswith(':'):
|
|
node_name = name[:name.rfind(':')]
|
|
output_slot = int(name[name.rfind(':') + 1:])
|
|
return node_name, output_slot
|
|
else:
|
|
return name, None
|
|
|
|
|
|
_XLA_MAKEFILE_TEMPLATE = """
|
|
INC = -I{tensorflow_includes}
|
|
LIB = -L{compiled_dir}
|
|
CXXFLAGS = {cxx_flags}
|
|
"""
|
|
|
|
|
|
def _xla_makefile_string(output_prefix):
|
|
"""Returns a Makefile string with variables for using XLA binary object files.
|
|
|
|
Attempts to identify the right include header paths when run from either
|
|
an installed TensorFlow pip package, or from bazel run.
|
|
|
|
Args:
|
|
output_prefix: A string containing the output prefix for the XLA AOT
|
|
compiled header + object files.
|
|
|
|
Returns:
|
|
A string containing a filled out `_XLA_MAKEFILE_TEMPLATE`.
|
|
"""
|
|
sysconfig = _sysconfig_module()
|
|
output_dir, _ = os.path.split(output_prefix)
|
|
if sysconfig:
|
|
tensorflow_includes = _shlex_quote(sysconfig.get_include())
|
|
else:
|
|
# Try hard to find the real source directory if this is a local bazel run.
|
|
if os.path.islink(__file__):
|
|
this_file = __file__
|
|
while os.path.islink(this_file):
|
|
this_file = os.readlink(this_file)
|
|
base = os.path.realpath(
|
|
os.path.join(os.path.dirname(this_file), *([os.path.pardir] * 3)))
|
|
else:
|
|
try:
|
|
base = test.test_src_dir_path('')
|
|
except KeyError: # Can't find TEST_SRCDIR in environment path.
|
|
base = os.path.realpath(
|
|
os.path.join(os.path.dirname(__file__), *([os.path.pardir] * 3)))
|
|
expected_header = os.path.join(
|
|
base, 'tensorflow', 'compiler', 'tf2xla', 'xla_compiled_cpu_function.h')
|
|
if not os.path.exists(expected_header):
|
|
logging.error(
|
|
'Could not find includes path. Missing file: {}'
|
|
.format(expected_header))
|
|
tensorflow_includes = base
|
|
|
|
return _XLA_MAKEFILE_TEMPLATE.format(
|
|
tensorflow_includes=tensorflow_includes,
|
|
compiled_dir=_shlex_quote(output_dir),
|
|
cxx_flags='-D_GLIBCXX_USE_CXX11_ABI={}'.format(
|
|
versions.CXX11_ABI_FLAG))
|
|
|
|
|
|
def _get_variable_nodes_from_graph_def(graph_def):
|
|
"""Get the list of Variable nodes from `graph_def`.
|
|
|
|
Args:
|
|
graph_def: An instance of `GraphDef`. This GraphDef *must*
|
|
have already been optimized by Grappler. In particular, function
|
|
inlining must have already happened.
|
|
|
|
Returns:
|
|
A dict mapping string names of variables to tuples `(node_def, modified)`,
|
|
where `node_def` is the `NodeDef` corresponding to variable, and `modified`
|
|
is a python bool describing whether the variable is modified during runtime.
|
|
"""
|
|
variables = [n for n in graph_def.node if n.op == 'VarHandleOp']
|
|
variable_name_map = dict((n.name, n) for n in variables)
|
|
child_map = collections.defaultdict(lambda: [])
|
|
for n in graph_def.node:
|
|
for inp in n.input:
|
|
if not inp.startswith('^'):
|
|
child_map[inp].append(n)
|
|
variables = {}
|
|
for (v_name, v_node) in variable_name_map.items():
|
|
queue = list(child_map[v_name])
|
|
processed = set([])
|
|
while queue:
|
|
n_current = queue.pop()
|
|
if n_current.name in processed:
|
|
continue
|
|
processed.add(n_current.name)
|
|
if n_current.op in _PASS_THROUGH_VARIABLE_OPS:
|
|
children = child_map.get(n_current.name, [])
|
|
queue.extend(children)
|
|
elif n_current.op not in _READ_ONLY_VARIABLE_OPS:
|
|
variables[v_name] = (v_node, True)
|
|
queue = []
|
|
if v_name not in variables:
|
|
variables[v_name] = (v_node, False)
|
|
|
|
return variables
|
|
|
|
|
|
def _prune_removed_feed_nodes(signature_def, graph_def):
|
|
"""Identify the inputs in the signature no longer in graph_def, prune them.
|
|
|
|
Args:
|
|
signature_def: A `SignatureDef` instance.
|
|
graph_def: A `GraphDef` instance.
|
|
|
|
Returns:
|
|
A new pruned `SignatureDef`.
|
|
"""
|
|
node_names = set([n.name for n in graph_def.node])
|
|
new_signature_def = meta_graph_pb2.SignatureDef()
|
|
new_signature_def.CopyFrom(signature_def)
|
|
for (k, v) in signature_def.inputs.items():
|
|
tensor_name, _ = _parse_tensor_name(v.name)
|
|
if tensor_name not in node_names:
|
|
logging.warn(
|
|
'Signature input key \'{}\', tensor name \'{}\', has been pruned '
|
|
'while freezing the graph. Removing it from the compiled signatures.'
|
|
.format(k, tensor_name))
|
|
del new_signature_def.inputs[k]
|
|
return new_signature_def
|
|
|
|
|
|
def aot_compile_cpu_meta_graph_def(checkpoint_path,
|
|
meta_graph_def,
|
|
output_prefix,
|
|
signature_def_key,
|
|
cpp_class,
|
|
target_triple,
|
|
target_cpu,
|
|
variables_to_feed=(),
|
|
multithreading=False):
|
|
"""Compile a `MetaGraphDef` to header+object files in `output_prefix`.
|
|
|
|
Use XLA AOT (`tfcompile`) to convert the given meta graph and
|
|
signature into a header + object files. Also create an include makefile
|
|
that helps identify the appropriate necessary include and library paths
|
|
to incorporate these files into your C++ program.
|
|
|
|
The graph is always optimized with grappler, and optionally (by default)
|
|
variables are frozen as constants, before compilation happens.
|
|
|
|
If the `freeze_graph` is `True`, all variables are embedded as constants
|
|
into the graph and binary objects. If it is `False`, then the variable
|
|
values become inputs and outputs of the compiled class and the C++
|
|
caller must set these values manually.
|
|
|
|
Args:
|
|
checkpoint_path: Python string. Path to checkpoints/variables.
|
|
meta_graph_def: Instance of `MetaGraphDef`.
|
|
output_prefix: Python string. Path prefix for outputs.
|
|
signature_def_key: String, the signature_def to use in the SavedModel.
|
|
cpp_class: String, Name of output C++ class.
|
|
target_triple: String, LLVM target triple.
|
|
target_cpu: String, LLVM target cpu name.
|
|
variables_to_feed: A list of strings, the variables that will be fed by the
|
|
user; these won't be frozen. If `None`, then we will extract all the
|
|
variables in the graph and mark them as to-feed. The default behavior is
|
|
an empty tuple: all variables must be frozen.
|
|
multithreading: Whether to enable multithreading in the compiled
|
|
computation. Note that if using this option, the resulting object files
|
|
may have external dependencies on multithreading libraries like nsync.
|
|
|
|
Raises:
|
|
RuntimeError: If tensorflow was not built with XLA.
|
|
ImportError: If tensorflow was built with XLA but there was another
|
|
issue importing the tfcompile python wrapper.
|
|
ValueError: If `meta_graph_def.signature_def[signature_def_key]` is
|
|
missing or has empty outputs.
|
|
"""
|
|
if _pywrap_tfcompile_import_error:
|
|
raise _pywrap_tfcompile_import_error # pylint: disable=raising-bad-type
|
|
|
|
else:
|
|
# TODO(ebrevdo): Pipe DebugOptions through tfcompile::Main and pywrap
|
|
# so that we can set these directly instead of relying on env vars.
|
|
xla_flags = os.environ.get('XLA_FLAGS')
|
|
if not xla_flags:
|
|
xla_flags = '--xla_cpu_multi_thread_eigen={}'.format(
|
|
'true' if multithreading else 'false')
|
|
else:
|
|
xla_flags += ',--xla_cpu_multi_thread_eigen={}'.format(
|
|
'true' if multithreading else 'false')
|
|
os.environ['XLA_FLAGS'] = xla_flags
|
|
|
|
signature_def_map = meta_graph_def.signature_def
|
|
if signature_def_key not in signature_def_map:
|
|
raise ValueError(
|
|
'Unable to find signature_def key \'{}\' in signature def map. '
|
|
'Available keys: {}'.format(
|
|
signature_def_key,
|
|
list(signature_def_map.keys())))
|
|
signature_def = signature_def_map[signature_def_key]
|
|
if not signature_def.outputs:
|
|
raise ValueError(
|
|
'Signature key {} must have outputs, but saw none:\n{}'.format(
|
|
signature_def_key, str(signature_def)))
|
|
|
|
temp_dir = test.get_temp_dir()
|
|
file_io.recursive_create_dir(temp_dir)
|
|
if logging.get_verbosity() >= logging.INFO:
|
|
original_graph_def_location = os.path.join(temp_dir, 'original_graph.pb')
|
|
with file_io.FileIO(original_graph_def_location, 'wb') as graph_writer:
|
|
graph_writer.write(meta_graph_def.graph_def.SerializeToString())
|
|
|
|
# This updates graph_def in place.
|
|
_replace_input_placeholders_with_default_values(
|
|
meta_graph_def.graph_def, signature_def)
|
|
|
|
graph_def = _optimize_graph(meta_graph_def, signature_def)
|
|
|
|
all_variables = _get_variable_nodes_from_graph_def(graph_def)
|
|
if variables_to_feed is None:
|
|
variable_nodes_to_feed = list(all_variables.values())
|
|
else:
|
|
not_in_graph = set(variables_to_feed).difference(list(all_variables))
|
|
if not_in_graph:
|
|
raise ValueError(
|
|
'Asked to feed variables that were not found in graph: {}. '
|
|
'Variables contained in the graph: {}'.format(
|
|
not_in_graph, list(all_variables)))
|
|
variable_nodes_to_feed = [
|
|
all_variables[name] for name in variables_to_feed
|
|
]
|
|
|
|
if logging.get_verbosity() >= logging.INFO:
|
|
prefrozen_graph_def_location = os.path.join(temp_dir, 'prefrozen_graph.pb')
|
|
with file_io.FileIO(prefrozen_graph_def_location, 'wb') as graph_writer:
|
|
graph_writer.write(graph_def.SerializeToString())
|
|
|
|
# Load the Variables so that we can freeze the graph.
|
|
with session.Session(graph=ops_lib.Graph()) as sess:
|
|
restorer = saver_lib.import_meta_graph(meta_graph_def, clear_devices=True)
|
|
if restorer is not None:
|
|
restorer.restore(sess, checkpoint_path)
|
|
graph_def.CopyFrom(
|
|
graph_util.convert_variables_to_constants(
|
|
sess,
|
|
graph_def,
|
|
output_node_names=[
|
|
_parse_tensor_name(n.name)[0]
|
|
for n in signature_def.outputs.values()
|
|
],
|
|
variable_names_blacklist=[
|
|
n.name for n, _ in variable_nodes_to_feed
|
|
],
|
|
))
|
|
|
|
signature_def = _prune_removed_feed_nodes(signature_def, graph_def)
|
|
|
|
frozen_graph_def_location = os.path.join(temp_dir, 'frozen_graph.pb')
|
|
config_pbtxt_location = os.path.join(temp_dir, 'config.pbtxt')
|
|
logging.info('Writing graph def to: {}'.format(frozen_graph_def_location))
|
|
with file_io.FileIO(frozen_graph_def_location, 'wb') as graph_writer:
|
|
graph_writer.write(graph_def.SerializeToString())
|
|
config = _signature_to_tf2xla_config(
|
|
signature_def, variable_nodes_to_feed=variable_nodes_to_feed)
|
|
logging.info('Writing config_pbtxt to: {}'.format(config_pbtxt_location))
|
|
with file_io.FileIO(config_pbtxt_location, mode='w') as config_writer:
|
|
config_writer.write(str(config))
|
|
|
|
output_dir = os.path.dirname(output_prefix)
|
|
file_io.recursive_create_dir(output_dir)
|
|
|
|
entry_point = re.sub(
|
|
'[^0-9a-zA-Z]+', '_',
|
|
'__xla_' + output_prefix + '__' + cpp_class)
|
|
|
|
logging.info('Generating XLA AOT artifacts in: {}'.format(output_dir))
|
|
|
|
makefile_inc_location = '{}_makefile.inc'.format(output_prefix)
|
|
with file_io.FileIO(makefile_inc_location, mode='w') as makefile_writer:
|
|
makefile_writer.write(_xla_makefile_string(output_prefix))
|
|
|
|
output_prefix = _shlex_quote(output_prefix)
|
|
|
|
_pywrap_tfcompile.Compile(
|
|
graph=frozen_graph_def_location,
|
|
config=config_pbtxt_location,
|
|
cpp_class=cpp_class,
|
|
target_triple=target_triple,
|
|
target_cpu=target_cpu,
|
|
entry_point=entry_point,
|
|
out_function_object='{}.o'.format(output_prefix),
|
|
out_header='{}.h'.format(output_prefix),
|
|
out_metadata_object='{}_metadata.o'.format(output_prefix),
|
|
gen_name_to_index=True,
|
|
# ProgramShape isn't uniquefied by entry_point.
|
|
gen_program_shape=False)
|
|
|
|
|
|
def _optimize_graph(meta_graph_def, signature_def):
|
|
"""Optimize `meta_graph_def` using grappler. Returns a `GraphDef`."""
|
|
# We need to add a collection called 'train_op' so that grappler
|
|
# knows what the outputs are.
|
|
new_meta_graph_def = copy.deepcopy(meta_graph_def)
|
|
fetch_collection = meta_graph_pb2.CollectionDef()
|
|
for tensor_info in (
|
|
list(signature_def.inputs.values()) +
|
|
list(signature_def.outputs.values())):
|
|
fetch_collection.node_list.value.append(tensor_info.name)
|
|
|
|
new_meta_graph_def.collection_def['train_op'].CopyFrom(fetch_collection)
|
|
|
|
config = config_pb2.ConfigProto()
|
|
rewrite_options = config.graph_options.rewrite_options
|
|
rewrite_options.min_graph_nodes = -1 # do not skip small graphs
|
|
return tf_optimizer.OptimizeGraph(config, new_meta_graph_def)
|
|
|
|
|
|
def _replace_input_placeholders_with_default_values(graph_def, signature_def):
|
|
"""Replace graphdef's `tf.placeholder` input ops with all-zero constants."""
|
|
name_to_node_map = dict((n.name, n) for n in graph_def.node)
|
|
processed_nodes = set([])
|
|
for name, input_ in signature_def.inputs.items():
|
|
tensor_name, _ = _parse_tensor_name(input_.name)
|
|
if tensor_name in processed_nodes:
|
|
continue
|
|
processed_nodes.add(tensor_name)
|
|
if tensor_name not in name_to_node_map:
|
|
raise RuntimeError(
|
|
'Unable to find input signature tensor \'{}\' in optimized GraphDef. '
|
|
'Graph nodes are: {}'.format(tensor_name,
|
|
list(name_to_node_map.keys())))
|
|
node = name_to_node_map[tensor_name]
|
|
if node.op not in ('Placeholder', 'PlaceholderV2'):
|
|
logging.info(
|
|
'Tried to convert SavedModel input node \'{}\' from a placeholder, '
|
|
'but it doesn\'t look like a placeholder: {}'.format(tensor_name,
|
|
node))
|
|
continue
|
|
shape = tensor_shape.TensorShape(input_.tensor_shape)
|
|
if not shape.is_fully_defined():
|
|
raise ValueError(
|
|
'Expected fully defined input shape for signature_def \'{}\', '
|
|
'tensor name: \'{}\'; but shape is: {}.'
|
|
.format(name, tensor_name, shape))
|
|
temp_graph = ops_lib.Graph()
|
|
with temp_graph.as_default():
|
|
const = array_ops.zeros(
|
|
shape, dtype=input_.dtype, name=tensor_name)
|
|
node.CopyFrom(const.op.node_def)
|
|
# Sometimes zeros() also creates additional nodes
|
|
for op in temp_graph.get_operations():
|
|
if op.name == const.op.name: # We just inserted this one.
|
|
continue
|
|
graph_def.node.append(op.node_def)
|
|
name_to_node_map[op.node_def.name] = op.node_def
|
|
|
|
|
|
def _signature_to_tf2xla_config(signature_def, variable_nodes_to_feed):
|
|
"""Convert `signature_def` to tf2xla config. Returns a `tf2xla.Config` proto.
|
|
|
|
Args:
|
|
signature_def: Instance of `SignatureDef`.
|
|
variable_nodes_to_feed: List of tuples of form `(node_def, modified)`
|
|
corresponding to VarHandleOp, and a boolean `modified` that describes
|
|
whether the variable was modified during execution.
|
|
|
|
Returns:
|
|
An instance of `tf2xla.Config` proto.
|
|
|
|
Raises:
|
|
RuntimeError: If TensorFlow was not compiled with XLA.
|
|
"""
|
|
from tensorflow.compiler.tf2xla import tf2xla_pb2 # pylint: disable=g-import-not-at-top
|
|
|
|
config = tf2xla_pb2.Config()
|
|
tensor_id = tf2xla_pb2.TensorId
|
|
|
|
for name, input_ in signature_def.inputs.items():
|
|
name = name.replace('/', '_')
|
|
name = 'feed_{}'.format(name)
|
|
(node_name, output_index) = _parse_tensor_name(input_.name)
|
|
output_index = int(output_index)
|
|
config.feed.append(
|
|
tf2xla_pb2.Feed(
|
|
id=tensor_id(node_name=node_name, output_index=output_index),
|
|
name=name,
|
|
type=input_.dtype,
|
|
shape=input_.tensor_shape))
|
|
for name, output_ in signature_def.outputs.items():
|
|
name = name.replace('/', '_')
|
|
name = 'fetch_{}'.format(name)
|
|
(node_name, output_index) = _parse_tensor_name(output_.name)
|
|
output_index = int(output_index)
|
|
config.fetch.append(
|
|
tf2xla_pb2.Fetch(
|
|
id=tensor_id(node_name=node_name, output_index=output_index),
|
|
name=name,
|
|
type=output_.dtype,
|
|
shape=output_.tensor_shape))
|
|
for (node, modified) in variable_nodes_to_feed:
|
|
name = node.name.replace('/', '_')
|
|
name = 'param_{}'.format(name)
|
|
config.variable.append(
|
|
tf2xla_pb2.Variable(
|
|
node_name=node.name,
|
|
name=name,
|
|
type=node.attr['dtype'].type,
|
|
shape=node.attr['shape'].shape,
|
|
readonly=not modified))
|
|
|
|
return config
|