Create intra op threadpool. This is required for some CPU operations (e.g. FFT).
The number of threads is configurable via the newly-introduced --intra_op_thread_pool_size flag. PiperOrigin-RevId: 248018440
This commit is contained in:
parent
43343ce22a
commit
a94a09ad78
@ -72,6 +72,7 @@ cc_library(
|
|||||||
"//tensorflow/compiler/xla/tests:test_utils",
|
"//tensorflow/compiler/xla/tests:test_utils",
|
||||||
"//tensorflow/core:framework_internal",
|
"//tensorflow/core:framework_internal",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
|
"//third_party/eigen3",
|
||||||
"@com_google_absl//absl/types:span",
|
"@com_google_absl//absl/types:span",
|
||||||
],
|
],
|
||||||
alwayslink = True,
|
alwayslink = True,
|
||||||
|
@ -34,13 +34,17 @@ limitations under the License.
|
|||||||
// Note: If you pass multiple modules, they will be compiled in parallel but run
|
// Note: If you pass multiple modules, they will be compiled in parallel but run
|
||||||
// in series.
|
// in series.
|
||||||
|
|
||||||
|
#define EIGEN_USE_THREADS
|
||||||
|
|
||||||
#include <stdio.h>
|
#include <stdio.h>
|
||||||
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "absl/types/span.h"
|
#include "absl/types/span.h"
|
||||||
|
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||||
#include "tensorflow/compiler/xla/client/client.h"
|
#include "tensorflow/compiler/xla/client/client.h"
|
||||||
#include "tensorflow/compiler/xla/client/client_library.h"
|
#include "tensorflow/compiler/xla/client/client_library.h"
|
||||||
#include "tensorflow/compiler/xla/client/global_data.h"
|
#include "tensorflow/compiler/xla/client/global_data.h"
|
||||||
@ -61,6 +65,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/xla/types.h"
|
#include "tensorflow/compiler/xla/types.h"
|
||||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||||
#include "tensorflow/core/lib/core/threadpool.h"
|
#include "tensorflow/core/lib/core/threadpool.h"
|
||||||
|
#include "tensorflow/core/platform/cpu_info.h"
|
||||||
#include "tensorflow/core/platform/env.h"
|
#include "tensorflow/core/platform/env.h"
|
||||||
#include "tensorflow/core/platform/init_main.h"
|
#include "tensorflow/core/platform/init_main.h"
|
||||||
#include "tensorflow/core/platform/logging.h"
|
#include "tensorflow/core/platform/logging.h"
|
||||||
@ -73,6 +78,9 @@ namespace {
|
|||||||
// Command-line opts to this tool. See main() for descriptions of these
|
// Command-line opts to this tool. See main() for descriptions of these
|
||||||
// fields.
|
// fields.
|
||||||
struct Options {
|
struct Options {
|
||||||
|
Options()
|
||||||
|
: intra_op_thread_pool_size(tensorflow::port::NumSchedulableCPUs()) {}
|
||||||
|
|
||||||
string fake_infeed_shape;
|
string fake_infeed_shape;
|
||||||
string fake_outfeed_shape;
|
string fake_outfeed_shape;
|
||||||
|
|
||||||
@ -88,6 +96,8 @@ struct Options {
|
|||||||
bool use_fake_data = false;
|
bool use_fake_data = false;
|
||||||
bool print_result = true;
|
bool print_result = true;
|
||||||
int num_runs = 1;
|
int num_runs = 1;
|
||||||
|
|
||||||
|
int intra_op_thread_pool_size;
|
||||||
};
|
};
|
||||||
|
|
||||||
StatusOr<std::unique_ptr<LocalExecutable>> CompileExecutable(
|
StatusOr<std::unique_ptr<LocalExecutable>> CompileExecutable(
|
||||||
@ -282,10 +292,16 @@ StatusOr<Literal> ReplayComputation(const HloSnapshot& module,
|
|||||||
if (xla_hlo_profile && is_final_result) {
|
if (xla_hlo_profile && is_final_result) {
|
||||||
LOG(INFO) << "\n\n***** Final run below ******";
|
LOG(INFO) << "\n\n***** Final run below ******";
|
||||||
}
|
}
|
||||||
|
tensorflow::thread::ThreadPool pool(tensorflow::Env::Default(), "XLAEigen",
|
||||||
|
opts.intra_op_thread_pool_size);
|
||||||
|
Eigen::ThreadPoolDevice thread_pool(pool.AsEigenThreadPool(),
|
||||||
|
pool.NumThreads());
|
||||||
|
|
||||||
ExecutionProfile profile;
|
ExecutionProfile profile;
|
||||||
ExecutableRunOptions run_options;
|
ExecutableRunOptions run_options;
|
||||||
run_options.set_execution_profile(&profile);
|
run_options.set_execution_profile(&profile);
|
||||||
run_options.set_allocator(&allocator);
|
run_options.set_allocator(&allocator);
|
||||||
|
run_options.set_intra_op_thread_pool(&thread_pool);
|
||||||
|
|
||||||
TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result,
|
TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result,
|
||||||
executable->Run(argument_ptrs, run_options));
|
executable->Run(argument_ptrs, run_options));
|
||||||
@ -439,6 +455,10 @@ int main(int argc, char** argv) {
|
|||||||
tensorflow::Flag("generate_fake_outfeed", &opts.generate_fake_outfeed,
|
tensorflow::Flag("generate_fake_outfeed", &opts.generate_fake_outfeed,
|
||||||
"Whether a fake outfeed shape should be derived "
|
"Whether a fake outfeed shape should be derived "
|
||||||
"from the computation"),
|
"from the computation"),
|
||||||
|
tensorflow::Flag("intra_op_thread_pool_size",
|
||||||
|
&opts.intra_op_thread_pool_size,
|
||||||
|
"How many threads to use in the intra-op thread pool. "
|
||||||
|
"Defaults to the number of CPUs."),
|
||||||
};
|
};
|
||||||
xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
|
xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
|
||||||
bool parse_ok = tensorflow::Flags::Parse(&argc, argv, flag_list);
|
bool parse_ok = tensorflow::Flags::Parse(&argc, argv, flag_list);
|
||||||
|
Loading…
Reference in New Issue
Block a user