Create intra op threadpool. This is required for some CPU operations (e.g. FFT).

The number of threads is configurable via the newly-introduced --intra_op_thread_pool_size flag.

PiperOrigin-RevId: 248018440
parent 43343ce22a
commit a94a09ad78
@@ -72,6 +72,7 @@ cc_library(
        "//tensorflow/compiler/xla/tests:test_utils",
        "//tensorflow/core:framework_internal",
        "//tensorflow/core:lib",
        "//third_party/eigen3",
        "@com_google_absl//absl/types:span",
    ],
    alwayslink = True,

@@ -34,13 +34,17 @@ limitations under the License.
// Note: If you pass multiple modules, they will be compiled in parallel but run
// in series.

#define EIGEN_USE_THREADS

#include <stdio.h>

#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/types/span.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/compiler/xla/client/client.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/global_data.h"

@@ -61,6 +65,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/logging.h"

@@ -73,6 +78,9 @@ namespace {
// Command-line opts to this tool. See main() for descriptions of these
// fields.
struct Options {
  Options()
      : intra_op_thread_pool_size(tensorflow::port::NumSchedulableCPUs()) {}

  string fake_infeed_shape;
  string fake_outfeed_shape;

@@ -88,6 +96,8 @@ struct Options {
  bool use_fake_data = false;
  bool print_result = true;
  int num_runs = 1;

  int intra_op_thread_pool_size;
};

StatusOr<std::unique_ptr<LocalExecutable>> CompileExecutable(

@@ -282,10 +292,16 @@ StatusOr<Literal> ReplayComputation(const HloSnapshot& module,
  if (xla_hlo_profile && is_final_result) {
    LOG(INFO) << "\n\n***** Final run below ******";
  }
  tensorflow::thread::ThreadPool pool(tensorflow::Env::Default(), "XLAEigen",
                                      opts.intra_op_thread_pool_size);
  Eigen::ThreadPoolDevice thread_pool(pool.AsEigenThreadPool(),
                                      pool.NumThreads());

  ExecutionProfile profile;
  ExecutableRunOptions run_options;
  run_options.set_execution_profile(&profile);
  run_options.set_allocator(&allocator);
  run_options.set_intra_op_thread_pool(&thread_pool);

  TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result,
                      executable->Run(argument_ptrs, run_options));

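The hunk above is the core of the change: build a tensorflow::thread::ThreadPool, expose it to Eigen through a ThreadPoolDevice, and register that device on the ExecutableRunOptions so CPU ops such as FFT have an intra-op pool to run on. Below is a minimal standalone sketch of just that wiring; it is not the tool itself, it omits compilation, argument setup, and the actual Run call, and the xla::ExecutableRunOptions include path is an assumption (the tool picks it up transitively).

    #define EIGEN_USE_THREADS  // must be defined before the Eigen Tensor header

    #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
    #include "tensorflow/compiler/xla/executable_run_options.h"  // assumed path for xla::ExecutableRunOptions
    #include "tensorflow/core/lib/core/threadpool.h"
    #include "tensorflow/core/platform/cpu_info.h"
    #include "tensorflow/core/platform/env.h"

    int main() {
      // Default to one thread per schedulable CPU, mirroring the Options
      // constructor above; --intra_op_thread_pool_size overrides this in the tool.
      const int pool_size = tensorflow::port::NumSchedulableCPUs();

      // The pool and the Eigen device must outlive every executable run that
      // uses these options.
      tensorflow::thread::ThreadPool pool(tensorflow::Env::Default(), "XLAEigen",
                                          pool_size);
      Eigen::ThreadPoolDevice device(pool.AsEigenThreadPool(), pool.NumThreads());

      xla::ExecutableRunOptions run_options;
      run_options.set_intra_op_thread_pool(&device);

      // In the tool, executable->Run(argument_ptrs, run_options) would now
      // dispatch intra-op work (e.g. FFT) onto this pool.
      return 0;
    }
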
@@ -439,6 +455,10 @@ int main(int argc, char** argv) {
      tensorflow::Flag("generate_fake_outfeed", &opts.generate_fake_outfeed,
                       "Whether a fake outfeed shape should be derived "
                       "from the computation"),
      tensorflow::Flag("intra_op_thread_pool_size",
                       &opts.intra_op_thread_pool_size,
                       "How many threads to use in the intra-op thread pool. "
                       "Defaults to the number of CPUs."),
  };
  xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
  bool parse_ok = tensorflow::Flags::Parse(&argc, argv, flag_list);
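As a usage sketch, the new option is passed like any other flag. The binary name and snapshot path below are placeholders (this diff does not name them); only the flags themselves come from the tool's Options:

    replay_computation --use_fake_data --intra_op_thread_pool_size=8 /tmp/computation.hlo_snapshot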