diff --git a/tensorflow/compiler/xla/tools/BUILD b/tensorflow/compiler/xla/tools/BUILD index e9244ecf9f1..4edd13c79c7 100644 --- a/tensorflow/compiler/xla/tools/BUILD +++ b/tensorflow/compiler/xla/tools/BUILD @@ -72,6 +72,7 @@ cc_library( "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", + "//third_party/eigen3", "@com_google_absl//absl/types:span", ], alwayslink = True, diff --git a/tensorflow/compiler/xla/tools/replay_computation.cc b/tensorflow/compiler/xla/tools/replay_computation.cc index 3d443beeecb..852f1dfa9b0 100644 --- a/tensorflow/compiler/xla/tools/replay_computation.cc +++ b/tensorflow/compiler/xla/tools/replay_computation.cc @@ -34,13 +34,17 @@ limitations under the License. // Note: If you pass multiple modules, they will be compiled in parallel but run // in series. +#define EIGEN_USE_THREADS + #include + #include #include #include #include #include "absl/types/span.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/compiler/xla/client/client.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/global_data.h" @@ -61,6 +65,7 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/threadpool.h" +#include "tensorflow/core/platform/cpu_info.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/platform/logging.h" @@ -73,6 +78,9 @@ namespace { // Command-line opts to this tool. See main() for descriptions of these // fields. struct Options { + Options() + : intra_op_thread_pool_size(tensorflow::port::NumSchedulableCPUs()) {} + string fake_infeed_shape; string fake_outfeed_shape; @@ -88,6 +96,8 @@ struct Options { bool use_fake_data = false; bool print_result = true; int num_runs = 1; + + int intra_op_thread_pool_size; }; StatusOr> CompileExecutable( @@ -282,10 +292,16 @@ StatusOr ReplayComputation(const HloSnapshot& module, if (xla_hlo_profile && is_final_result) { LOG(INFO) << "\n\n***** Final run below ******"; } + tensorflow::thread::ThreadPool pool(tensorflow::Env::Default(), "XLAEigen", + opts.intra_op_thread_pool_size); + Eigen::ThreadPoolDevice thread_pool(pool.AsEigenThreadPool(), + pool.NumThreads()); + ExecutionProfile profile; ExecutableRunOptions run_options; run_options.set_execution_profile(&profile); run_options.set_allocator(&allocator); + run_options.set_intra_op_thread_pool(&thread_pool); TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result, executable->Run(argument_ptrs, run_options)); @@ -439,6 +455,10 @@ int main(int argc, char** argv) { tensorflow::Flag("generate_fake_outfeed", &opts.generate_fake_outfeed, "Whether a fake outfeed shape should be derived " "from the computation"), + tensorflow::Flag("intra_op_thread_pool_size", + &opts.intra_op_thread_pool_size, + "How many threads to use in the intra-op thread pool. " + "Defaults to the number of CPUs."), }; xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); bool parse_ok = tensorflow::Flags::Parse(&argc, argv, flag_list);