Create intra op threadpool. This is required for some CPU operations (e.g. FFT).

The number of threads is configurable via the newly-introduced --intra_op_thread_pool_size flag.

PiperOrigin-RevId: 248018440
This commit is contained in:
Skye Wanderman-Milne 2019-05-13 15:03:50 -07:00 committed by TensorFlower Gardener
parent 43343ce22a
commit a94a09ad78
2 changed files with 21 additions and 0 deletions

View File

@ -72,6 +72,7 @@ cc_library(
"//tensorflow/compiler/xla/tests:test_utils",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//third_party/eigen3",
"@com_google_absl//absl/types:span",
],
alwayslink = True,

View File

@ -34,13 +34,17 @@ limitations under the License.
// Note: If you pass multiple modules, they will be compiled in parallel but run
// in series.
#define EIGEN_USE_THREADS
#include <stdio.h>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "absl/types/span.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/compiler/xla/client/client.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/global_data.h"
@ -61,6 +65,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/logging.h"
@ -73,6 +78,9 @@ namespace {
// Command-line opts to this tool. See main() for descriptions of these
// fields.
struct Options {
Options()
: intra_op_thread_pool_size(tensorflow::port::NumSchedulableCPUs()) {}
string fake_infeed_shape;
string fake_outfeed_shape;
@ -88,6 +96,8 @@ struct Options {
bool use_fake_data = false;
bool print_result = true;
int num_runs = 1;
int intra_op_thread_pool_size;
};
StatusOr<std::unique_ptr<LocalExecutable>> CompileExecutable(
@ -282,10 +292,16 @@ StatusOr<Literal> ReplayComputation(const HloSnapshot& module,
if (xla_hlo_profile && is_final_result) {
LOG(INFO) << "\n\n***** Final run below ******";
}
tensorflow::thread::ThreadPool pool(tensorflow::Env::Default(), "XLAEigen",
opts.intra_op_thread_pool_size);
Eigen::ThreadPoolDevice thread_pool(pool.AsEigenThreadPool(),
pool.NumThreads());
ExecutionProfile profile;
ExecutableRunOptions run_options;
run_options.set_execution_profile(&profile);
run_options.set_allocator(&allocator);
run_options.set_intra_op_thread_pool(&thread_pool);
TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result,
executable->Run(argument_ptrs, run_options));
@ -439,6 +455,10 @@ int main(int argc, char** argv) {
tensorflow::Flag("generate_fake_outfeed", &opts.generate_fake_outfeed,
"Whether a fake outfeed shape should be derived "
"from the computation"),
tensorflow::Flag("intra_op_thread_pool_size",
&opts.intra_op_thread_pool_size,
"How many threads to use in the intra-op thread pool. "
"Defaults to the number of CPUs."),
};
xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
bool parse_ok = tensorflow::Flags::Parse(&argc, argv, flag_list);