Create intra op threadpool. This is required for some CPU operations (e.g. FFT).

The number of threads is configurable via the newly-introduced --intra_op_thread_pool_size flag.

PiperOrigin-RevId: 248018440
This commit is contained in:
Skye Wanderman-Milne 2019-05-13 15:03:50 -07:00 committed by TensorFlower Gardener
parent 43343ce22a
commit a94a09ad78
2 changed files with 21 additions and 0 deletions

View File

@@ -72,6 +72,7 @@ cc_library(
"//tensorflow/compiler/xla/tests:test_utils",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//third_party/eigen3",
"@com_google_absl//absl/types:span",
],
alwayslink = True,

View File

@@ -34,13 +34,17 @@ limitations under the License.
// Note: If you pass multiple modules, they will be compiled in parallel but run
// in series.
#define EIGEN_USE_THREADS
#include <stdio.h>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "absl/types/span.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/compiler/xla/client/client.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/global_data.h"
@@ -61,6 +65,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/logging.h"
@@ -73,6 +78,9 @@ namespace {
// Command-line opts to this tool. See main() for descriptions of these
// fields.
struct Options {
Options()
    : intra_op_thread_pool_size(tensorflow::port::NumSchedulableCPUs()) {}

string fake_infeed_shape;
string fake_outfeed_shape;
@@ -88,6 +96,8 @@ struct Options {
bool use_fake_data = false;
bool print_result = true;
int num_runs = 1;
int intra_op_thread_pool_size;
};
StatusOr<std::unique_ptr<LocalExecutable>> CompileExecutable(
@@ -282,10 +292,16 @@ StatusOr<Literal> ReplayComputation(const HloSnapshot& module,
if (xla_hlo_profile && is_final_result) {
LOG(INFO) << "\n\n***** Final run below ******";
}
tensorflow::thread::ThreadPool pool(tensorflow::Env::Default(), "XLAEigen",
                                    opts.intra_op_thread_pool_size);
Eigen::ThreadPoolDevice thread_pool(pool.AsEigenThreadPool(),
                                    pool.NumThreads());
ExecutionProfile profile;
ExecutableRunOptions run_options;
run_options.set_execution_profile(&profile);
run_options.set_allocator(&allocator);
run_options.set_intra_op_thread_pool(&thread_pool);
TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result,
                    executable->Run(argument_ptrs, run_options));
@@ -439,6 +455,10 @@ int main(int argc, char** argv) {
tensorflow::Flag("generate_fake_outfeed", &opts.generate_fake_outfeed,
                 "Whether a fake outfeed shape should be derived "
                 "from the computation"),
tensorflow::Flag("intra_op_thread_pool_size",
                 &opts.intra_op_thread_pool_size,
                 "How many threads to use in the intra-op thread pool. "
                 "Defaults to the number of CPUs."),
};
xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
bool parse_ok = tensorflow::Flags::Parse(&argc, argv, flag_list);