[PJRT] Parallelize GPU compilation.
Repurpose the H2D transfer thread pool as a more general, client-wide thread pool, sized to the default maximum parallelism. Provide this thread pool to XLA:GPU, allowing it to compile in parallel. Also rename a misnamed member of xla::ExecutableBuildOptions (set_run_backend_only -> set_compile_thread_pool).

PiperOrigin-RevId: 347597838
Change-Id: I481c188d05ca5750284a6711f6eeabd141c28d4c
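For orientation, a minimal sketch (not part of the diff) of how a caller might use the renamed setter. The function name, pool name, and thread count are hypothetical; ExecutableBuildOptions, set_compile_thread_pool(), and tensorflow::thread::ThreadPool are the real types touched below.

#include "tensorflow/compiler/xla/client/executable_build_options.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/platform/env.h"

void CompileWithSharedPool() {
  // Caller-owned pool; it must stay alive for the duration of compilation.
  tensorflow::thread::ThreadPool compile_pool(tensorflow::Env::Default(),
                                              "compile_pool",
                                              /*num_threads=*/8);
  xla::ExecutableBuildOptions build_options;
  // Previously (and misleadingly) spelled set_run_backend_only().
  build_options.set_compile_thread_pool(&compile_pool);
  // ... pass build_options to compilation, or leave the option unset and
  // let PjRtStreamExecutorClient::Compile() fill in its client-wide pool.
}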
parent 70bc79a21a
commit c70e89e48a
@@ -119,7 +119,7 @@ class ExecutableBuildOptions {
   tensorflow::thread::ThreadPool* compile_thread_pool() const {
     return compile_thread_pool_;
   }
-  ExecutableBuildOptions& set_run_backend_only(
+  ExecutableBuildOptions& set_compile_thread_pool(
       tensorflow::thread::ThreadPool* compile_thread_pool) {
     compile_thread_pool_ = compile_thread_pool;
     return *this;
@@ -65,6 +65,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h"

 #include <cstddef>
+#include <cstdlib>
 #include <memory>
 #include <string>
 #include <vector>
@@ -98,6 +99,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/compiler/xla/util.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/platform/cpu_info.h"
 #include "tensorflow/core/platform/errors.h"
 #include "tensorflow/core/platform/fingerprint.h"
 #include "tensorflow/core/platform/mem.h"
@@ -185,6 +187,19 @@ class CpuAllocator : public tensorflow::Allocator {
   }
 };

+static int DefaultThreadPoolSize() {
+  // Google's CI system exposes an environment variable NPROC that describes
+  // a CPU reservation for tests.
+  // TODO(phawkins): expose a better thought-out set of knobs to control
+  // parallelism.
+  const char* nproc_str = std::getenv("NPROC");
+  int nproc = 0;
+  if (nproc_str && absl::SimpleAtoi(nproc_str, &nproc)) {
+    return std::max(0, nproc);
+  }
+  return tensorflow::port::MaxParallelism();
+}
+
 PjRtStreamExecutorClient::PjRtStreamExecutorClient(
     std::string platform_name, LocalClient* client,
     std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices, int host_id,
@@ -202,8 +217,9 @@ PjRtStreamExecutorClient::PjRtStreamExecutorClient(
       should_stage_host_to_device_transfers_(
           should_stage_host_to_device_transfers),
       gpu_run_options_(std::move(gpu_run_options)),
-      h2d_transfer_pool_(tensorflow::Env::Default(), "py_xla_h2d_transfer",
-                         client->device_count()) {
+      thread_pool_(tensorflow::Env::Default(), "pjrt_thread_pool",
+                   std::max<int>(DefaultThreadPoolSize(),
+                                 client->device_count())) {
   if (owned_allocator_ != nullptr) {
     allocator_ = owned_allocator_.get();
   } else {
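Sizing note (not part of the diff): the old h2d_transfer_pool_ had one thread per local device, and the std::max above preserves that lower bound, presumably so host-to-device transfers cannot be starved, while growing the pool to DefaultThreadPoolSize() so compilation can use the host's full parallelism. For example, 4 cores and 8 local devices yield max(4, 8) = 8 threads; 16 cores and 2 devices yield 16.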
@@ -745,11 +761,11 @@ PjRtStreamExecutorClient::BufferFromHostBuffer(
         std::make_pair(std::move(buffer_reference), std::move(staging_buffer)));
   };
   if (is_cpu_platform) {
-    // Using the h2d_transfer_pool would be a double thread hop; the code
+    // Using the thread_pool would be a double thread hop; the code
     // already defers its work onto a stream (= thread on CPU).
     transfer_h2d();
   } else {
-    h2d_transfer_pool()->Schedule(transfer_h2d);
+    thread_pool()->Schedule(transfer_h2d);
   }
   return std::unique_ptr<PjRtBuffer>(std::move(py_buffer));
 }
@@ -844,7 +860,7 @@ PjRtStreamExecutorClient::BufferFromHostLiteral(const LiteralSlice& literal,
         .IgnoreError();  // Can return error::Unimplemented
     QCHECK(h2d_stream->ok());
   };
-  h2d_transfer_pool()->Schedule(transfer_h2d);
+  thread_pool()->Schedule(transfer_h2d);
   return std::unique_ptr<PjRtBuffer>(std::move(py_buffer));
 }
@@ -2094,6 +2110,9 @@ StatusOr<std::unique_ptr<PjRtExecutable>> PjRtStreamExecutorClient::Compile(
   tensorflow::profiler::TraceMe traceme("PjRtStreamExecutorClient::Compile");

   ExecutableBuildOptions& build_options = options.executable_build_options;
+  if (!build_options.compile_thread_pool()) {
+    build_options.set_compile_thread_pool(thread_pool());
+  }
   if (!build_options.device_allocator()) {
     build_options.set_device_allocator(allocator());
   }
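Note (not part of the diff): the client-wide pool is only a fallback here. A caller that has already installed its own pool through set_compile_thread_pool(), as sketched near the top, keeps it, since Compile() fills the option only when it is unset, mirroring the device_allocator() default just below.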
@@ -223,9 +223,7 @@ class PjRtStreamExecutorClient : public PjRtClient {
     return gpu_run_options_.get();
   }

-  tensorflow::thread::ThreadPool* h2d_transfer_pool() {
-    return &h2d_transfer_pool_;
-  }
+  tensorflow::thread::ThreadPool* thread_pool() { return &thread_pool_; }

 protected:
   friend class PjRtStreamExecutorBuffer;
@@ -268,7 +266,7 @@ class PjRtStreamExecutorClient : public PjRtClient {

   std::unique_ptr<gpu::GpuExecutableRunOptions> gpu_run_options_;

-  tensorflow::thread::ThreadPool h2d_transfer_pool_;
+  tensorflow::thread::ThreadPool thread_pool_;
 };

 // Converts a 2D set of Device objects indexed by [replica][partition] into an