[PJRT] Parallelize GPU compilation.
Repurpose the h2d transfer thread pool into a more general common thread
pool, and size it to the default maximum parallelism. Provide the thread
pool to XLA:GPU, allowing for parallel compilation. Also rename a misnamed
member of xla::ExecutableBuildOptions.

PiperOrigin-RevId: 346882817
Change-Id: I0b83598466bb1a13cf89b38ea094c2f4a7bc974b
commit b938722556
parent 43946a6e32
@@ -119,7 +119,7 @@ class ExecutableBuildOptions {
   tensorflow::thread::ThreadPool* compile_thread_pool() const {
     return compile_thread_pool_;
   }
-  ExecutableBuildOptions& set_run_backend_only(
+  ExecutableBuildOptions& set_compile_thread_pool(
       tensorflow::thread::ThreadPool* compile_thread_pool) {
    compile_thread_pool_ = compile_thread_pool;
    return *this;
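Note: for context, a minimal caller-side sketch of the renamed setter (not
part of this commit; the helper name MakeBuildOptions and the pool name are
illustrative). The setter stores a non-owning pointer, so the pool must
outlive compilation:

    #include "tensorflow/compiler/xla/client/executable_build_options.h"
    #include "tensorflow/core/lib/core/threadpool.h"
    #include "tensorflow/core/platform/cpu_info.h"
    #include "tensorflow/core/platform/env.h"

    // Build options that route compilation work onto a caller-owned pool.
    xla::ExecutableBuildOptions MakeBuildOptions(
        tensorflow::thread::ThreadPool* pool) {
      xla::ExecutableBuildOptions options;
      options.set_compile_thread_pool(pool);
      return options;
    }

    // Usage: size the pool to the host's available parallelism.
    // tensorflow::thread::ThreadPool pool(tensorflow::Env::Default(),
    //                                     "my_compile_pool",
    //                                     tensorflow::port::MaxParallelism());
    // xla::ExecutableBuildOptions options = MakeBuildOptions(&pool);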
@@ -97,6 +97,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/compiler/xla/util.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/platform/cpu_info.h"
 #include "tensorflow/core/platform/errors.h"
 #include "tensorflow/core/platform/fingerprint.h"
 #include "tensorflow/core/platform/mem.h"
@@ -201,8 +202,9 @@ PjRtStreamExecutorClient::PjRtStreamExecutorClient(
       should_stage_host_to_device_transfers_(
           should_stage_host_to_device_transfers),
       gpu_run_options_(std::move(gpu_run_options)),
-      h2d_transfer_pool_(tensorflow::Env::Default(), "py_xla_h2d_transfer",
-                         client->device_count()) {
+      thread_pool_(tensorflow::Env::Default(), "pjrt_thread_pool",
+                   std::max<int>(tensorflow::port::MaxParallelism(),
+                                 client->device_count())) {
   if (owned_allocator_ != nullptr) {
     allocator_ = owned_allocator_.get();
   } else {
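Note: the new pool is sized to max(MaxParallelism(), device_count()) so it
can serve both of its roles: a 4-core host driving 8 devices still gets 8
threads (one in-flight H2D transfer per device), while a 64-core host with
2 devices gets 64 threads for compilation. A sketch of the sizing rule in
isolation (the helper name is illustrative; MaxParallelism() is, roughly,
the number of schedulable CPUs):

    #include <algorithm>
    #include "tensorflow/core/platform/cpu_info.h"

    // Mirror of the constructor's sizing expression above.
    int PjrtPoolSize(int device_count) {
      return std::max<int>(tensorflow::port::MaxParallelism(), device_count);
    }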
@@ -744,11 +746,11 @@ PjRtStreamExecutorClient::BufferFromHostBuffer(
         std::make_pair(std::move(buffer_reference), std::move(staging_buffer)));
   };
   if (is_cpu_platform) {
-    // Using the h2d_transfer_pool would be a double thread hop; the code
+    // Using the thread_pool would be a double thread hop; the code
     // already defers its work onto a stream (= thread on CPU).
     transfer_h2d();
   } else {
-    h2d_transfer_pool()->Schedule(transfer_h2d);
+    thread_pool()->Schedule(transfer_h2d);
   }
   return std::unique_ptr<PjRtBuffer>(std::move(py_buffer));
 }
@@ -843,7 +845,7 @@ PjRtStreamExecutorClient::BufferFromHostLiteral(const LiteralSlice& literal,
         .IgnoreError();  // Can return error::Unimplemented
     QCHECK(h2d_stream->ok());
   };
-  h2d_transfer_pool()->Schedule(transfer_h2d);
+  thread_pool()->Schedule(transfer_h2d);
   return std::unique_ptr<PjRtBuffer>(std::move(py_buffer));
 }
@@ -2234,6 +2236,9 @@ StatusOr<std::unique_ptr<PjRtExecutable>> PjRtStreamExecutorClient::Compile(
   tensorflow::profiler::TraceMe traceme("PjRtStreamExecutorClient::Compile");

   ExecutableBuildOptions& build_options = options.executable_build_options;
+  if (!build_options.compile_thread_pool()) {
+    build_options.set_compile_thread_pool(thread_pool());
+  }
   if (!build_options.device_allocator()) {
     build_options.set_device_allocator(allocator());
   }
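Note: Compile() only fills in the pool when the caller has not set one, so
callers can still supply their own. With the pool plumbed through
ExecutableBuildOptions, the GPU backend can fan independent compilation
work out over it. A generic sketch of that fan-out pattern (not the actual
XLA:GPU code; compile_one is a hypothetical work item):

    #include <functional>

    #include "tensorflow/core/lib/core/threadpool.h"
    #include "tensorflow/core/platform/blocking_counter.h"

    // Schedule n independent work items on the shared pool and block until
    // all of them have finished.
    void RunInParallel(tensorflow::thread::ThreadPool* pool, int n,
                       const std::function<void(int)>& compile_one) {
      tensorflow::BlockingCounter pending(n);
      for (int i = 0; i < n; ++i) {
        pool->Schedule([&pending, &compile_one, i] {
          compile_one(i);  // e.g. compile one module or kernel
          pending.DecrementCount();
        });
      }
      pending.Wait();  // every scheduled item has run to completion
    }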
@@ -223,9 +223,7 @@ class PjRtStreamExecutorClient : public PjRtClient {
     return gpu_run_options_.get();
   }

-  tensorflow::thread::ThreadPool* h2d_transfer_pool() {
-    return &h2d_transfer_pool_;
-  }
+  tensorflow::thread::ThreadPool* thread_pool() { return &thread_pool_; }

 protected:
  friend class PjRtStreamExecutorBuffer;
@@ -268,7 +266,7 @@ class PjRtStreamExecutorClient : public PjRtClient {

   std::unique_ptr<gpu::GpuExecutableRunOptions> gpu_run_options_;

-  tensorflow::thread::ThreadPool h2d_transfer_pool_;
+  tensorflow::thread::ThreadPool thread_pool_;
 };

 // Converts a 2D set of Device objects indexed by [replica][partition] into an