From 84967b39fa98d27f5984648f9ec47a159206cfda Mon Sep 17 00:00:00 2001 From: Eugene Brevdo Date: Tue, 20 Oct 2020 12:29:23 -0700 Subject: [PATCH] [TF] [saved_model_cli] Add support for multithreaded cpu service. Off by default. This change allows the linkage of multithreaded XLA AOT CPU backend objects, such as multithreaded matmul, conv2d, etc. These are not enabled by default. New unit tests confirm that the objects are emitted and linked correctly, and the resulting computations are numerically correct. MKL service backend objects are not included. Other changes: * C++ Unit tests now use arg_feed_{x,y} instead of arg0/arg1, since the names are flaky (they may swap from the signature) * Add argument "multithreading=" to the bzl file and saved_model_cli. * Add unit tests using "nm" to ensure that the proper symbols are used when enabling or disabling multithreading (not sure if they are windows-friendly). * Use a simpler and more unique string for the entry_point string. PiperOrigin-RevId: 338112208 Change-Id: Id734e75e63e72db93a743f451ddb7eb6f489c1c7 --- tensorflow/compiler/tf2xla/BUILD | 9 +- tensorflow/compiler/xla/service/cpu/BUILD | 19 ++++- tensorflow/python/tools/BUILD | 37 +++++++++ tensorflow/python/tools/aot_compiled_test.cc | 83 ++++++++++++++----- .../tools/no_xla_multithread_symbols_test.sh | 27 ++++++ .../python/tools/saved_model_aot_compile.py | 32 ++++--- tensorflow/python/tools/saved_model_cli.py | 15 ++-- tensorflow/python/tools/skip_test.sh | 15 ++++ tensorflow/python/tools/tools.bzl | 7 ++ .../tools/xla_multithread_symbols_test.sh | 27 ++++++ 10 files changed, 220 insertions(+), 51 deletions(-) create mode 100755 tensorflow/python/tools/no_xla_multithread_symbols_test.sh create mode 100755 tensorflow/python/tools/skip_test.sh create mode 100755 tensorflow/python/tools/xla_multithread_symbols_test.sh diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index 5641339e7ef..588b4269fee 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -196,7 +196,7 @@ filegroup( srcs = [ "xla_compiled_cpu_function.h", "//tensorflow/compiler/xla:cpu_runtime_hdrs", - "//tensorflow/compiler/xla/service/cpu:single_threaded_runtime_hdrs", + "//tensorflow/compiler/xla/service/cpu:runtime_hdrs", "//tensorflow/core/kernels:xla_cpu_runtime_hdrs", "//tensorflow/core/platform:xla_cpu_runtime_srcs", ], @@ -208,7 +208,7 @@ filegroup( srcs = [ "xla_compiled_cpu_function.cc", "//tensorflow/compiler/xla:cpu_runtime_srcs", - "//tensorflow/compiler/xla/service/cpu:single_threaded_runtime_srcs", + "//tensorflow/compiler/xla/service/cpu:runtime_srcs", "//tensorflow/core/kernels:xla_cpu_runtime_srcs", "//tensorflow/core/platform:xla_cpu_runtime_srcs", ], @@ -249,6 +249,11 @@ cc_library( "//third_party/eigen3", "//tensorflow/core/framework:numeric_types", "//tensorflow/core/platform:bfloat16", + ] + [ + # Extra dependencies required for multithreaded runtime objects. + "//tensorflow/core/platform:blocking_counter", + "//tensorflow/core/platform:logging", + "//tensorflow/core/platform:mutex", ] + tf_additional_tensor_coding_deps(), alwayslink = 1, ) diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD index 0cc27e32749..c64cfda0b94 100644 --- a/tensorflow/compiler/xla/service/cpu/BUILD +++ b/tensorflow/compiler/xla/service/cpu/BUILD @@ -45,8 +45,9 @@ cc_library( ) filegroup( - name = "single_threaded_runtime_srcs", + name = "runtime_srcs", srcs = [ + # Single-threaded support. "runtime_fp16.cc", "runtime_key_value_sort.cc", "runtime_pow.cc", @@ -54,13 +55,20 @@ filegroup( "runtime_single_threaded_fft.cc", "runtime_single_threaded_matmul.cc", "runtime_topk.cc", + ] + [ + # Multi-threaded support. + "runtime_conv2d.cc", + "runtime_fft.cc", + "runtime_matmul.cc", + "runtime_fork_join.cc", ], visibility = [":friends"], ) filegroup( - name = "single_threaded_runtime_hdrs", + name = "runtime_hdrs", srcs = [ + # Single-threaded support. "runtime_conv2d_impl.h", "runtime_fft_impl.h", "runtime_fp16.h", @@ -70,6 +78,13 @@ filegroup( "runtime_single_threaded_fft.h", "runtime_single_threaded_matmul.h", "runtime_topk.h", + ] + [ + # Multi-threaded support. + "runtime_conv2d.h", + "runtime_fft.h", + "runtime_fork_join.h", + "runtime_lightweight_check.h", + "runtime_matmul.h", ], visibility = [":friends"], ) diff --git a/tensorflow/python/tools/BUILD b/tensorflow/python/tools/BUILD index b2724ccd901..7b1f85dc0e9 100644 --- a/tensorflow/python/tools/BUILD +++ b/tensorflow/python/tools/BUILD @@ -417,6 +417,16 @@ saved_model_compile_aot( tags = ["no_rocm"], ) +saved_model_compile_aot( + name = "aot_compiled_x_matmul_y_large_multithreaded", + cpp_class = "XMatmulYLargeMultithreaded", + directory = "//tensorflow/python/tools:x_matmul_y_large", + filegroups = [":aot_saved_models"], + force_without_xla_support_flag = False, + multithreading = True, + tags = ["no_rocm"], +) + saved_model_compile_aot( name = "aot_compiled_x_matmul_y_small", cpp_class = "XMatmulYSmall", @@ -460,6 +470,32 @@ saved_model_compile_aot( variables_to_feed = "variable_x", ) +sh_test( + name = "large_matmul_no_multithread_test", + srcs = if_xla_available( + ["no_xla_multithread_symbols_test.sh"], + if_false = ["skip_test.sh"], + ), + args = if_xla_available(["$(location :aot_compiled_x_matmul_y_large.o)"]), + data = if_xla_available([":aot_compiled_x_matmul_y_large.o"]), +) + +sh_test( + name = "large_matmul_yes_multithread_test", + srcs = if_xla_available( + [ + "xla_multithread_symbols_test.sh", + ], + if_false = ["skip_test.sh"], + ), + args = if_xla_available( + ["$(location :aot_compiled_x_matmul_y_large_multithreaded.o)"], + ), + data = if_xla_available( + [":aot_compiled_x_matmul_y_large_multithreaded.o"], + ), +) + tf_cc_test( name = "aot_compiled_test", srcs = if_xla_available([ @@ -472,6 +508,7 @@ tf_cc_test( ":aot_compiled_vars_and_arithmetic", ":aot_compiled_vars_and_arithmetic_frozen", ":aot_compiled_x_matmul_y_large", + ":aot_compiled_x_matmul_y_large_multithreaded", ":aot_compiled_x_matmul_y_small", ":aot_compiled_x_plus_y", "//tensorflow/core:test", diff --git a/tensorflow/python/tools/aot_compiled_test.cc b/tensorflow/python/tools/aot_compiled_test.cc index e628a6a1c37..0c15e638841 100644 --- a/tensorflow/python/tools/aot_compiled_test.cc +++ b/tensorflow/python/tools/aot_compiled_test.cc @@ -13,12 +13,15 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#define EIGEN_USE_THREADS + #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/python/tools/aot_compiled_vars_and_arithmetic.h" #include "tensorflow/python/tools/aot_compiled_vars_and_arithmetic_frozen.h" #include "tensorflow/python/tools/aot_compiled_x_matmul_y_large.h" +#include "tensorflow/python/tools/aot_compiled_x_matmul_y_large_multithreaded.h" #include "tensorflow/python/tools/aot_compiled_x_matmul_y_small.h" #include "tensorflow/python/tools/aot_compiled_x_plus_y.h" @@ -36,24 +39,24 @@ TEST(AOTCompiledSavedModelTest, XPlusY) { TEST(AOTCompiledSavedModelTest, XMatmulYLarge) { XMatmulYLarge model; // Calculation is: output_0 = x @ y. - EXPECT_EQ(model.arg0_size(), sizeof(float) * 3000 * 5000); - EXPECT_EQ(model.arg1_size(), sizeof(float) * 5000 * 4000); - EXPECT_EQ(model.result0_size(), sizeof(float) * 3000 * 4000); + EXPECT_EQ(model.arg_feed_x_count(), 3000 * 5000); + EXPECT_EQ(model.arg_feed_y_count(), 5000 * 4000); + EXPECT_EQ(model.result0_count(), 3000 * 4000); - Eigen::Tensor arg0(3000, 5000); - Eigen::Tensor arg1(5000, 4000); - arg0.setRandom(); - arg1.setRandom(); + Eigen::Tensor arg_feed_x(3000, 5000); + Eigen::Tensor arg_feed_y(5000, 4000); + arg_feed_x.setRandom(); + arg_feed_y.setRandom(); // Set up dimensions for standard matmul. const Eigen::array, 1> product_dims = { Eigen::IndexPair(1, 0)}; // Ground truth matmul. const Eigen::Tensor expected_output0 = - arg0.contract(arg1, product_dims); + arg_feed_x.contract(arg_feed_y, product_dims); - model.set_arg_feed_x_data(arg0.data()); - model.set_arg_feed_y_data(arg1.data()); + model.set_arg_feed_x_data(arg_feed_x.data()); + model.set_arg_feed_y_data(arg_feed_y.data()); CHECK(model.Run()); EXPECT_NEAR(model.result_fetch_output_0(0, 0), expected_output0(0, 0), /*abs_error=*/1e-6f); @@ -62,27 +65,61 @@ TEST(AOTCompiledSavedModelTest, XMatmulYLarge) { /*abs_error=*/1e-6f); } -TEST(AOTCompiledSavedModelTest, XMatmulYSmall) { - XMatmulYSmall model; - // Calculation is: output_0 = x @ y. - EXPECT_EQ(model.arg0_size(), sizeof(float) * 3 * 5); - EXPECT_EQ(model.arg1_size(), sizeof(float) * 5 * 4); - EXPECT_EQ(model.result0_size(), sizeof(float) * 3 * 4); +TEST(AOTCompiledSavedModelTest, XMatmulYLargeMultithreaded) { + XMatmulYLargeMultithreaded model; - Eigen::Tensor arg0(3, 5); - Eigen::Tensor arg1(5, 4); - arg0.setRandom(); - arg1.setRandom(); + Eigen::ThreadPool pool(2); + Eigen::ThreadPoolDevice device(&pool, pool.NumThreads()); + model.set_thread_pool(&device); + + // Calculation is: output_0 = x @ y. + EXPECT_EQ(model.arg_feed_x_count(), 3000 * 5000); + EXPECT_EQ(model.arg_feed_y_count(), 5000 * 4000); + EXPECT_EQ(model.result0_count(), 3000 * 4000); + + Eigen::Tensor arg_feed_x(3000, 5000); + Eigen::Tensor arg_feed_y(5000, 4000); + arg_feed_x.setRandom(); + arg_feed_y.setRandom(); // Set up dimensions for standard matmul. const Eigen::array, 1> product_dims = { Eigen::IndexPair(1, 0)}; // Ground truth matmul. const Eigen::Tensor expected_output0 = - arg0.contract(arg1, product_dims); + arg_feed_x.contract(arg_feed_y, product_dims); - model.set_arg_feed_x_data(arg0.data()); - model.set_arg_feed_y_data(arg1.data()); + model.set_arg_feed_x_data(arg_feed_x.data()); + model.set_arg_feed_y_data(arg_feed_y.data()); + CHECK(model.Run()); + EXPECT_NEAR(model.result_fetch_output_0(0, 0), expected_output0(0, 0), + /*abs_error=*/1e-3f); + EXPECT_NEAR(model.result_fetch_output_0(2999, 3999), + expected_output0(2999, 3999), + /*abs_error=*/1e-3f); +} + +TEST(AOTCompiledSavedModelTest, XMatmulYSmall) { + XMatmulYSmall model; + // Calculation is: output_0 = x @ y. + EXPECT_EQ(model.arg_feed_x_count(), 3 * 5); + EXPECT_EQ(model.arg_feed_y_count(), 5 * 4); + EXPECT_EQ(model.result0_count(), 3 * 4); + + Eigen::Tensor arg_feed_x(3, 5); + Eigen::Tensor arg_feed_y(5, 4); + arg_feed_x.setRandom(); + arg_feed_y.setRandom(); + + // Set up dimensions for standard matmul. + const Eigen::array, 1> product_dims = { + Eigen::IndexPair(1, 0)}; + // Ground truth matmul. + const Eigen::Tensor expected_output0 = + arg_feed_x.contract(arg_feed_y, product_dims); + + model.set_arg_feed_x_data(arg_feed_x.data()); + model.set_arg_feed_y_data(arg_feed_y.data()); CHECK(model.Run()); EXPECT_NEAR(model.result_fetch_output_0(0, 0), expected_output0(0, 0), /*abs_error=*/1e-6f); diff --git a/tensorflow/python/tools/no_xla_multithread_symbols_test.sh b/tensorflow/python/tools/no_xla_multithread_symbols_test.sh new file mode 100755 index 00000000000..468c283ad98 --- /dev/null +++ b/tensorflow/python/tools/no_xla_multithread_symbols_test.sh @@ -0,0 +1,27 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +set -e + +SYMBOLS=$(nm "$@" | grep __xla_cpu_runtime) +if echo "${SYMBOLS}" | grep -q SingleThread; then + exit 0 +else + echo "" 1>&2 + echo "Did not see SingleThread runtime symbol in $@:" 1>&2 + echo "" 1>&2 + echo "${SYMBOLS}" 1>&2 + echo "" 1>&2 + exit 1 +fi diff --git a/tensorflow/python/tools/saved_model_aot_compile.py b/tensorflow/python/tools/saved_model_aot_compile.py index bf955ad825c..d1478e205d3 100644 --- a/tensorflow/python/tools/saved_model_aot_compile.py +++ b/tensorflow/python/tools/saved_model_aot_compile.py @@ -19,11 +19,10 @@ from __future__ import division from __future__ import print_function import collections - import copy -import hashlib import os import pipes +import re import shlex import six @@ -217,7 +216,7 @@ def aot_compile_cpu_meta_graph_def(checkpoint_path, target_triple, target_cpu, variables_to_feed=(), - enable_multithreading=False): + multithreading=False): """Compile a `MetaGraphDef` to header+object files in `output_prefix`. Use XLA AOT (`tfcompile`) to convert the given meta graph and @@ -245,8 +244,9 @@ def aot_compile_cpu_meta_graph_def(checkpoint_path, user; these won't be frozen. If `None`, then we will extract all the variables in the graph and mark them as to-feed. The default behavior is an empty tuple: all variables must be frozen. - enable_multithreading: Not implemented. Enable multithreading in the - compiled computation. + multithreading: Whether to enable multithreading in the compiled + computation. Note that if using this option, the resulting object files + may have external dependencies on multithreading libraries like nsync. Raises: RuntimeError: If tensorflow was not built with XLA. @@ -254,23 +254,20 @@ def aot_compile_cpu_meta_graph_def(checkpoint_path, issue importing the tfcompile python wrapper. ValueError: If `meta_graph_def.signature_def[signature_def_key]` is missing or has empty outputs. - NotImplementedError: If `enable_multithreading is True`. """ if _pywrap_tfcompile_import_error: - raise _pywrap_tfcompile_import_error + raise _pywrap_tfcompile_import_error # pylint: disable=raising-bad-type - if enable_multithreading: - raise NotImplementedError( - 'Multithreading is not currently supported because it requires ' - 'additional dependencies in the AOT runtime.') else: # TODO(ebrevdo): Pipe DebugOptions through tfcompile::Main and pywrap # so that we can set these directly instead of relying on env vars. xla_flags = os.environ.get('XLA_FLAGS') if not xla_flags: - xla_flags = '--xla_cpu_multi_thread_eigen=false' + xla_flags = '--xla_cpu_multi_thread_eigen={}'.format( + 'true' if multithreading else 'false') else: - xla_flags += ',--xla_cpu_multi_thread_eigen=false' + xla_flags += ',--xla_cpu_multi_thread_eigen={}'.format( + 'true' if multithreading else 'false') os.environ['XLA_FLAGS'] = xla_flags signature_def_map = meta_graph_def.signature_def @@ -352,10 +349,9 @@ def aot_compile_cpu_meta_graph_def(checkpoint_path, output_dir = os.path.dirname(output_prefix) file_io.recursive_create_dir(output_dir) - entry_digest = hashlib.md5() - entry_digest.update(str(config).encode()) - entry_digest.update(str(graph_def).encode()) - entry_digest = entry_digest.hexdigest() + entry_point = re.sub( + '[^0-9a-zA-Z]+', '_', + '__xla_' + output_prefix + '__' + cpp_class) logging.info('Generating XLA AOT artifacts in: {}'.format(output_dir)) @@ -371,7 +367,7 @@ def aot_compile_cpu_meta_graph_def(checkpoint_path, cpp_class=cpp_class, target_triple=target_triple, target_cpu=target_cpu, - entry_point='entry_{}'.format(entry_digest), + entry_point=entry_point, out_function_object='{}.o'.format(output_prefix), out_header='{}.h'.format(output_prefix), out_metadata_object='{}_metadata.o'.format(output_prefix), diff --git a/tensorflow/python/tools/saved_model_cli.py b/tensorflow/python/tools/saved_model_cli.py index 0c8b8f5576b..124686dff13 100644 --- a/tensorflow/python/tools/saved_model_cli.py +++ b/tensorflow/python/tools/saved_model_cli.py @@ -821,6 +821,7 @@ def aot_compile_cpu(args): variables_to_feed = None # We will identify them after. else: variables_to_feed = args.variables_to_feed.split(',') + saved_model_aot_compile.aot_compile_cpu_meta_graph_def( checkpoint_path=checkpoint_path, meta_graph_def=saved_model_utils.get_meta_graph_def( @@ -831,7 +832,7 @@ def aot_compile_cpu(args): target_triple=args.target_triple, target_cpu=args.target_cpu, cpp_class=args.cpp_class, - enable_multithreading=args.enable_multithreading) + multithreading=args.multithreading.lower() not in ('f', 'false', '0')) def add_show_subparser(subparsers): @@ -1140,11 +1141,13 @@ def add_aot_compile_cpu_subparser(subparsers): '(this applies to all input arguments from the signature as ' 'well).')) parser_compile.add_argument( - '--enable_multithreading', - type=bool, - default='', - help=('*NOT CURRENTLY SUPPORTED* ' - 'Enable multithreading in the compiled computation.')) + '--multithreading', + type=str, + default='False', + help=('Enable multithreading in the compiled computation. ' + 'Note that if using this option, the resulting object files ' + 'may have external dependencies on multithreading libraries ' + 'like nsync.')) parser_compile.set_defaults(func=aot_compile_cpu) diff --git a/tensorflow/python/tools/skip_test.sh b/tensorflow/python/tools/skip_test.sh new file mode 100755 index 00000000000..5c9407175fe --- /dev/null +++ b/tensorflow/python/tools/skip_test.sh @@ -0,0 +1,15 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +exit 0 diff --git a/tensorflow/python/tools/tools.bzl b/tensorflow/python/tools/tools.bzl index 79f771bbcad..db886746006 100644 --- a/tensorflow/python/tools/tools.bzl +++ b/tensorflow/python/tools/tools.bzl @@ -21,6 +21,7 @@ def saved_model_compile_aot( variables_to_feed = "", target_triple = None, target_cpu = None, + multithreading = False, force_without_xla_support_flag = True, tags = None): """Compile a SavedModel directory accessible from a filegroup. @@ -93,6 +94,11 @@ def saved_model_compile_aot( target architecture's triple). Similar to clang's -target flag. target_cpu: The LLVM cpu name used for compilation. Similar to clang's -mcpu flag. + multithreading: Whether to compile multithreaded AOT code. + Note, this increases the set of dependencies for binaries using + the AOT library at both build and runtime. For example, + the resulting object files may have external dependencies on + multithreading libraries like nsync. force_without_xla_support_flag: Whether to compile even when `--define=with_xla_support=true` is not set. If `False`, and the define is not passed when building, then the created `cc_library` @@ -135,6 +141,7 @@ def saved_model_compile_aot( "--cpp_class {} ".format(cpp_class) + "--variables_to_feed {} ".format(variables_to_feed) + "--signature_def_key {} ".format(signature_def) + + "--multithreading {} ".format(multithreading) + "--target_triple " + target_triple + " " + ("--target_cpu " + target_cpu + " " if target_cpu else "") + "--tag_set {} ".format(tag_set) diff --git a/tensorflow/python/tools/xla_multithread_symbols_test.sh b/tensorflow/python/tools/xla_multithread_symbols_test.sh new file mode 100755 index 00000000000..9576c762112 --- /dev/null +++ b/tensorflow/python/tools/xla_multithread_symbols_test.sh @@ -0,0 +1,27 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +set -e + +SYMBOLS=$(nm "$@" | grep __xla_cpu_runtime) +if echo "${SYMBOLS}" | grep -q SingleThread; then + echo "" 1>&2 + echo "Saw a SingleThread runtime symbol in $@:" 1>&2 + echo "" 1>&2 + echo "${SYMBOLS}" 1>&2 + echo "" 1>&2 + exit 1 +else + exit 0 +fi