[TF] [saved_model_cli] Add support for multithreaded CPU XLA AOT compilation. Off by default.

This change allows linking the multithreaded XLA AOT CPU backend objects,
such as multithreaded matmul and conv2d. These are not enabled by default.

New unit tests confirm that the objects are emitted and linked correctly,
and the resulting computations are numerically correct.

MKL backend objects are not included.

Other changes:
* C++ unit tests now use arg_feed_{x,y} instead of arg0/arg1, since the
  positional names are flaky (arg0 and arg1 may swap relative to the
  signature's argument order).
* Add a "multithreading=" argument to the bzl macro and to saved_model_cli
  (usage sketches follow the corresponding diffs below).
* Add unit tests that use "nm" to verify the proper symbols are linked when
  multithreading is enabled or disabled (these may not be Windows-friendly).
* Use a simpler and more unique entry_point string.

PiperOrigin-RevId: 338112208
Change-Id: Id734e75e63e72db93a743f451ddb7eb6f489c1c7
Author: Eugene Brevdo, 2020-10-20 12:29:23 -07:00 (committed by TensorFlower Gardener)
Commit: 84967b39fa (parent: ebab4d6209)
10 changed files with 220 additions and 51 deletions


@@ -196,7 +196,7 @@ filegroup(
srcs = [
"xla_compiled_cpu_function.h",
"//tensorflow/compiler/xla:cpu_runtime_hdrs",
"//tensorflow/compiler/xla/service/cpu:single_threaded_runtime_hdrs",
"//tensorflow/compiler/xla/service/cpu:runtime_hdrs",
"//tensorflow/core/kernels:xla_cpu_runtime_hdrs",
"//tensorflow/core/platform:xla_cpu_runtime_srcs",
],
@@ -208,7 +208,7 @@ filegroup(
srcs = [
"xla_compiled_cpu_function.cc",
"//tensorflow/compiler/xla:cpu_runtime_srcs",
"//tensorflow/compiler/xla/service/cpu:single_threaded_runtime_srcs",
"//tensorflow/compiler/xla/service/cpu:runtime_srcs",
"//tensorflow/core/kernels:xla_cpu_runtime_srcs",
"//tensorflow/core/platform:xla_cpu_runtime_srcs",
],
@@ -249,6 +249,11 @@ cc_library(
"//third_party/eigen3",
"//tensorflow/core/framework:numeric_types",
"//tensorflow/core/platform:bfloat16",
] + [
# Extra dependencies required for multithreaded runtime objects.
"//tensorflow/core/platform:blocking_counter",
"//tensorflow/core/platform:logging",
"//tensorflow/core/platform:mutex",
] + tf_additional_tensor_coding_deps(),
alwayslink = 1,
)


@@ -45,8 +45,9 @@ cc_library(
)
filegroup(
name = "single_threaded_runtime_srcs",
name = "runtime_srcs",
srcs = [
# Single-threaded support.
"runtime_fp16.cc",
"runtime_key_value_sort.cc",
"runtime_pow.cc",
@@ -54,13 +55,20 @@ filegroup(
"runtime_single_threaded_fft.cc",
"runtime_single_threaded_matmul.cc",
"runtime_topk.cc",
] + [
# Multi-threaded support.
"runtime_conv2d.cc",
"runtime_fft.cc",
"runtime_matmul.cc",
"runtime_fork_join.cc",
],
visibility = [":friends"],
)
filegroup(
name = "single_threaded_runtime_hdrs",
name = "runtime_hdrs",
srcs = [
# Single-threaded support.
"runtime_conv2d_impl.h",
"runtime_fft_impl.h",
"runtime_fp16.h",
@@ -70,6 +78,13 @@ filegroup(
"runtime_single_threaded_fft.h",
"runtime_single_threaded_matmul.h",
"runtime_topk.h",
] + [
# Multi-threaded support.
"runtime_conv2d.h",
"runtime_fft.h",
"runtime_fork_join.h",
"runtime_lightweight_check.h",
"runtime_matmul.h",
],
visibility = [":friends"],
)


@@ -417,6 +417,16 @@ saved_model_compile_aot(
tags = ["no_rocm"],
)
saved_model_compile_aot(
name = "aot_compiled_x_matmul_y_large_multithreaded",
cpp_class = "XMatmulYLargeMultithreaded",
directory = "//tensorflow/python/tools:x_matmul_y_large",
filegroups = [":aot_saved_models"],
force_without_xla_support_flag = False,
multithreading = True,
tags = ["no_rocm"],
)
saved_model_compile_aot(
name = "aot_compiled_x_matmul_y_small",
cpp_class = "XMatmulYSmall",
@@ -460,6 +470,32 @@ saved_model_compile_aot(
variables_to_feed = "variable_x",
)
sh_test(
name = "large_matmul_no_multithread_test",
srcs = if_xla_available(
["no_xla_multithread_symbols_test.sh"],
if_false = ["skip_test.sh"],
),
args = if_xla_available(["$(location :aot_compiled_x_matmul_y_large.o)"]),
data = if_xla_available([":aot_compiled_x_matmul_y_large.o"]),
)
sh_test(
name = "large_matmul_yes_multithread_test",
srcs = if_xla_available(
[
"xla_multithread_symbols_test.sh",
],
if_false = ["skip_test.sh"],
),
args = if_xla_available(
["$(location :aot_compiled_x_matmul_y_large_multithreaded.o)"],
),
data = if_xla_available(
[":aot_compiled_x_matmul_y_large_multithreaded.o"],
),
)
tf_cc_test(
name = "aot_compiled_test",
srcs = if_xla_available([
@@ -472,6 +508,7 @@ tf_cc_test(
":aot_compiled_vars_and_arithmetic",
":aot_compiled_vars_and_arithmetic_frozen",
":aot_compiled_x_matmul_y_large",
":aot_compiled_x_matmul_y_large_multithreaded",
":aot_compiled_x_matmul_y_small",
":aot_compiled_x_plus_y",
"//tensorflow/core:test",


@@ -13,12 +13,15 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#define EIGEN_USE_THREADS
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/python/tools/aot_compiled_vars_and_arithmetic.h"
#include "tensorflow/python/tools/aot_compiled_vars_and_arithmetic_frozen.h"
#include "tensorflow/python/tools/aot_compiled_x_matmul_y_large.h"
#include "tensorflow/python/tools/aot_compiled_x_matmul_y_large_multithreaded.h"
#include "tensorflow/python/tools/aot_compiled_x_matmul_y_small.h"
#include "tensorflow/python/tools/aot_compiled_x_plus_y.h"
@@ -36,24 +39,24 @@ TEST(AOTCompiledSavedModelTest, XPlusY) {
TEST(AOTCompiledSavedModelTest, XMatmulYLarge) {
XMatmulYLarge model;
// Calculation is: output_0 = x @ y.
EXPECT_EQ(model.arg0_size(), sizeof(float) * 3000 * 5000);
EXPECT_EQ(model.arg1_size(), sizeof(float) * 5000 * 4000);
EXPECT_EQ(model.result0_size(), sizeof(float) * 3000 * 4000);
EXPECT_EQ(model.arg_feed_x_count(), 3000 * 5000);
EXPECT_EQ(model.arg_feed_y_count(), 5000 * 4000);
EXPECT_EQ(model.result0_count(), 3000 * 4000);
Eigen::Tensor<float, 2, Eigen::RowMajor> arg0(3000, 5000);
Eigen::Tensor<float, 2, Eigen::RowMajor> arg1(5000, 4000);
arg0.setRandom();
arg1.setRandom();
Eigen::Tensor<float, 2, Eigen::RowMajor> arg_feed_x(3000, 5000);
Eigen::Tensor<float, 2, Eigen::RowMajor> arg_feed_y(5000, 4000);
arg_feed_x.setRandom();
arg_feed_y.setRandom();
// Set up dimensions for standard matmul.
const Eigen::array<Eigen::IndexPair<int>, 1> product_dims = {
Eigen::IndexPair<int>(1, 0)};
// Ground truth matmul.
const Eigen::Tensor<float, 2, Eigen::RowMajor> expected_output0 =
arg0.contract(arg1, product_dims);
arg_feed_x.contract(arg_feed_y, product_dims);
model.set_arg_feed_x_data(arg0.data());
model.set_arg_feed_y_data(arg1.data());
model.set_arg_feed_x_data(arg_feed_x.data());
model.set_arg_feed_y_data(arg_feed_y.data());
CHECK(model.Run());
EXPECT_NEAR(model.result_fetch_output_0(0, 0), expected_output0(0, 0),
/*abs_error=*/1e-6f);
@@ -62,27 +65,61 @@ TEST(AOTCompiledSavedModelTest, XMatmulYLarge) {
/*abs_error=*/1e-6f);
}
TEST(AOTCompiledSavedModelTest, XMatmulYSmall) {
XMatmulYSmall model;
// Calculation is: output_0 = x @ y.
EXPECT_EQ(model.arg0_size(), sizeof(float) * 3 * 5);
EXPECT_EQ(model.arg1_size(), sizeof(float) * 5 * 4);
EXPECT_EQ(model.result0_size(), sizeof(float) * 3 * 4);
TEST(AOTCompiledSavedModelTest, XMatmulYLargeMultithreaded) {
XMatmulYLargeMultithreaded model;
Eigen::Tensor<float, 2, Eigen::RowMajor> arg0(3, 5);
Eigen::Tensor<float, 2, Eigen::RowMajor> arg1(5, 4);
arg0.setRandom();
arg1.setRandom();
Eigen::ThreadPool pool(2);
Eigen::ThreadPoolDevice device(&pool, pool.NumThreads());
model.set_thread_pool(&device);
// Calculation is: output_0 = x @ y.
EXPECT_EQ(model.arg_feed_x_count(), 3000 * 5000);
EXPECT_EQ(model.arg_feed_y_count(), 5000 * 4000);
EXPECT_EQ(model.result0_count(), 3000 * 4000);
Eigen::Tensor<float, 2, Eigen::RowMajor> arg_feed_x(3000, 5000);
Eigen::Tensor<float, 2, Eigen::RowMajor> arg_feed_y(5000, 4000);
arg_feed_x.setRandom();
arg_feed_y.setRandom();
// Set up dimensions for standard matmul.
const Eigen::array<Eigen::IndexPair<int>, 1> product_dims = {
Eigen::IndexPair<int>(1, 0)};
// Ground truth matmul.
const Eigen::Tensor<float, 2, Eigen::RowMajor> expected_output0 =
arg0.contract(arg1, product_dims);
arg_feed_x.contract(arg_feed_y, product_dims);
model.set_arg_feed_x_data(arg0.data());
model.set_arg_feed_y_data(arg1.data());
model.set_arg_feed_x_data(arg_feed_x.data());
model.set_arg_feed_y_data(arg_feed_y.data());
CHECK(model.Run());
EXPECT_NEAR(model.result_fetch_output_0(0, 0), expected_output0(0, 0),
/*abs_error=*/1e-3f);
EXPECT_NEAR(model.result_fetch_output_0(2999, 3999),
expected_output0(2999, 3999),
/*abs_error=*/1e-3f);
}
TEST(AOTCompiledSavedModelTest, XMatmulYSmall) {
XMatmulYSmall model;
// Calculation is: output_0 = x @ y.
EXPECT_EQ(model.arg_feed_x_count(), 3 * 5);
EXPECT_EQ(model.arg_feed_y_count(), 5 * 4);
EXPECT_EQ(model.result0_count(), 3 * 4);
Eigen::Tensor<float, 2, Eigen::RowMajor> arg_feed_x(3, 5);
Eigen::Tensor<float, 2, Eigen::RowMajor> arg_feed_y(5, 4);
arg_feed_x.setRandom();
arg_feed_y.setRandom();
// Set up dimensions for standard matmul.
const Eigen::array<Eigen::IndexPair<int>, 1> product_dims = {
Eigen::IndexPair<int>(1, 0)};
// Ground truth matmul.
const Eigen::Tensor<float, 2, Eigen::RowMajor> expected_output0 =
arg_feed_x.contract(arg_feed_y, product_dims);
model.set_arg_feed_x_data(arg_feed_x.data());
model.set_arg_feed_y_data(arg_feed_y.data());
CHECK(model.Run());
EXPECT_NEAR(model.result_fetch_output_0(0, 0), expected_output0(0, 0),
/*abs_error=*/1e-6f);
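For context, the new multithreaded test above also shows the calling
convention: a model compiled with multithreading must be given an Eigen
thread-pool device before Run(). A minimal standalone caller, sketched from
the test (the header path, class name, and buffer sizes are the ones in this
diff; the pool size is an arbitrary choice):

#define EIGEN_USE_THREADS
#include <vector>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/python/tools/aot_compiled_x_matmul_y_large_multithreaded.h"

int main() {
  XMatmulYLargeMultithreaded model;
  // The pool and device must outlive every call to model.Run().
  Eigen::ThreadPool pool(4);
  Eigen::ThreadPoolDevice device(&pool, pool.NumThreads());
  model.set_thread_pool(&device);
  // Feed buffers are caller-owned: 3000x5000 and 5000x4000 floats here.
  std::vector<float> x(3000 * 5000, 1.0f), y(5000 * 4000, 1.0f);
  model.set_arg_feed_x_data(x.data());
  model.set_arg_feed_y_data(y.data());
  return model.Run() ? 0 : 1;  // On success, result_fetch_output_0(i, j) reads the output.
}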


@@ -0,0 +1,27 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
set -e
SYMBOLS=$(nm "$@" | grep __xla_cpu_runtime)
if echo "${SYMBOLS}" | grep -q SingleThread; then
exit 0
else
echo "" 1>&2
echo "Did not see SingleThread runtime symbol in $@:" 1>&2
echo "" 1>&2
echo "${SYMBOLS}" 1>&2
echo "" 1>&2
exit 1
fi
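This check is wired up as a Bazel sh_test in the python/tools BUILD diff
above, but it can also be run by hand against any compiled object file; a
hypothetical manual invocation (the object path is a placeholder):

bash no_xla_multithread_symbols_test.sh bazel-bin/tensorflow/python/tools/aot_compiled_x_matmul_y_large.o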


@@ -19,11 +19,10 @@ from __future__ import division
from __future__ import print_function
import collections
import copy
import hashlib
import os
import pipes
import re
import shlex
import six
@@ -217,7 +216,7 @@ def aot_compile_cpu_meta_graph_def(checkpoint_path,
target_triple,
target_cpu,
variables_to_feed=(),
enable_multithreading=False):
multithreading=False):
"""Compile a `MetaGraphDef` to header+object files in `output_prefix`.
Use XLA AOT (`tfcompile`) to convert the given meta graph and
@@ -245,8 +244,9 @@ def aot_compile_cpu_meta_graph_def(checkpoint_path,
user; these won't be frozen. If `None`, then we will extract all the
variables in the graph and mark them as to-feed. The default behavior is
an empty tuple: all variables must be frozen.
enable_multithreading: Not implemented. Enable multithreading in the
compiled computation.
multithreading: Whether to enable multithreading in the compiled
computation. Note that if using this option, the resulting object files
may have external dependencies on multithreading libraries like nsync.
Raises:
RuntimeError: If tensorflow was not built with XLA.
@@ -254,23 +254,20 @@ def aot_compile_cpu_meta_graph_def(checkpoint_path,
issue importing the tfcompile python wrapper.
ValueError: If `meta_graph_def.signature_def[signature_def_key]` is
missing or has empty outputs.
NotImplementedError: If `enable_multithreading is True`.
"""
if _pywrap_tfcompile_import_error:
raise _pywrap_tfcompile_import_error
raise _pywrap_tfcompile_import_error # pylint: disable=raising-bad-type
if enable_multithreading:
raise NotImplementedError(
'Multithreading is not currently supported because it requires '
'additional dependencies in the AOT runtime.')
else:
# TODO(ebrevdo): Pipe DebugOptions through tfcompile::Main and pywrap
# so that we can set these directly instead of relying on env vars.
xla_flags = os.environ.get('XLA_FLAGS')
if not xla_flags:
xla_flags = '--xla_cpu_multi_thread_eigen=false'
xla_flags = '--xla_cpu_multi_thread_eigen={}'.format(
'true' if multithreading else 'false')
else:
xla_flags += ',--xla_cpu_multi_thread_eigen=false'
xla_flags += ',--xla_cpu_multi_thread_eigen={}'.format(
'true' if multithreading else 'false')
os.environ['XLA_FLAGS'] = xla_flags
signature_def_map = meta_graph_def.signature_def
@@ -352,10 +349,9 @@ def aot_compile_cpu_meta_graph_def(checkpoint_path,
output_dir = os.path.dirname(output_prefix)
file_io.recursive_create_dir(output_dir)
entry_digest = hashlib.md5()
entry_digest.update(str(config).encode())
entry_digest.update(str(graph_def).encode())
entry_digest = entry_digest.hexdigest()
entry_point = re.sub(
'[^0-9a-zA-Z]+', '_',
'__xla_' + output_prefix + '__' + cpp_class)
logging.info('Generating XLA AOT artifacts in: {}'.format(output_dir))
@@ -371,7 +367,7 @@ def aot_compile_cpu_meta_graph_def(checkpoint_path,
cpp_class=cpp_class,
target_triple=target_triple,
target_cpu=target_cpu,
entry_point='entry_{}'.format(entry_digest),
entry_point=entry_point,
out_function_object='{}.o'.format(output_prefix),
out_header='{}.h'.format(output_prefix),
out_metadata_object='{}_metadata.o'.format(output_prefix),
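Two behavior changes in this file pair up: the XLA_FLAGS plumbing now honors
the multithreading argument instead of unconditionally disabling multithreaded
Eigen, and the entry point is derived deterministically from the output prefix
and C++ class instead of an MD5 digest of the config and graph. A small
illustration of both, assuming hypothetical prefix and class values:

import os
import re

multithreading = True
flag = '--xla_cpu_multi_thread_eigen={}'.format('true' if multithreading else 'false')
xla_flags = os.environ.get('XLA_FLAGS')
# Append to any user-provided flags, otherwise start fresh.
os.environ['XLA_FLAGS'] = xla_flags + ',' + flag if xla_flags else flag

output_prefix = 'tensorflow/python/tools/aot_compiled_x_matmul_y_large'  # hypothetical
cpp_class = 'XMatmulYLarge'  # hypothetical
entry_point = re.sub('[^0-9a-zA-Z]+', '_',
                     '__xla_' + output_prefix + '__' + cpp_class)
# entry_point == '_xla_tensorflow_python_tools_aot_compiled_x_matmul_y_large_XMatmulYLarge'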


@@ -821,6 +821,7 @@ def aot_compile_cpu(args):
variables_to_feed = None # We will identify them after.
else:
variables_to_feed = args.variables_to_feed.split(',')
saved_model_aot_compile.aot_compile_cpu_meta_graph_def(
checkpoint_path=checkpoint_path,
meta_graph_def=saved_model_utils.get_meta_graph_def(
@@ -831,7 +832,7 @@ def aot_compile_cpu(args):
target_triple=args.target_triple,
target_cpu=args.target_cpu,
cpp_class=args.cpp_class,
enable_multithreading=args.enable_multithreading)
multithreading=args.multithreading.lower() not in ('f', 'false', '0'))
def add_show_subparser(subparsers):
@@ -1140,11 +1141,13 @@ def add_aot_compile_cpu_subparser(subparsers):
'(this applies to all input arguments from the signature as '
'well).'))
parser_compile.add_argument(
'--enable_multithreading',
type=bool,
default='',
help=('*NOT CURRENTLY SUPPORTED* '
'Enable multithreading in the compiled computation.'))
'--multithreading',
type=str,
default='False',
help=('Enable multithreading in the compiled computation. '
'Note that if using this option, the resulting object files '
'may have external dependencies on multithreading libraries '
'like nsync.'))
parser_compile.set_defaults(func=aot_compile_cpu)
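A hedged example of the resulting CLI surface (the SavedModel directory and
names are placeholders; the other flags are the existing aot_compile_cpu
options):

saved_model_cli aot_compile_cpu \
  --dir /tmp/my_saved_model \
  --tag_set serve \
  --signature_def_key serving_default \
  --output_prefix /tmp/compiled/my_model \
  --cpp_class MyModel \
  --multithreading True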


@@ -0,0 +1,15 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
exit 0


@@ -21,6 +21,7 @@ def saved_model_compile_aot(
variables_to_feed = "",
target_triple = None,
target_cpu = None,
multithreading = False,
force_without_xla_support_flag = True,
tags = None):
"""Compile a SavedModel directory accessible from a filegroup.
@@ -93,6 +94,11 @@ def saved_model_compile_aot(
target architecture's triple). Similar to clang's -target flag.
target_cpu: The LLVM cpu name used for compilation. Similar to clang's
-mcpu flag.
multithreading: Whether to compile multithreaded AOT code.
Note, this increases the set of dependencies for binaries using
the AOT library at both build and runtime. For example,
the resulting object files may have external dependencies on
multithreading libraries like nsync.
force_without_xla_support_flag: Whether to compile even when
`--define=with_xla_support=true` is not set. If `False`, and the
define is not passed when building, then the created `cc_library`
@@ -135,6 +141,7 @@ def saved_model_compile_aot(
"--cpp_class {} ".format(cpp_class) +
"--variables_to_feed {} ".format(variables_to_feed) +
"--signature_def_key {} ".format(signature_def) +
"--multithreading {} ".format(multithreading) +
"--target_triple " + target_triple + " " +
("--target_cpu " + target_cpu + " " if target_cpu else "") +
"--tag_set {} ".format(tag_set)


@@ -0,0 +1,27 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
set -e
SYMBOLS=$(nm "$@" | grep __xla_cpu_runtime)
if echo "${SYMBOLS}" | grep -q SingleThread; then
echo "" 1>&2
echo "Saw a SingleThread runtime symbol in $@:" 1>&2
echo "" 1>&2
echo "${SYMBOLS}" 1>&2
echo "" 1>&2
exit 1
else
exit 0
fi