[PJRT] Parallelize GPU compilation.

Repurpose the h2d transfer thread pool as a more general common thread pool, and size it to the default maximum parallelism. Provide the thread pool to XLA:GPU, enabling parallel compilation.

Rename a misnamed setter of xla::ExecutableBuildOptions (set_run_backend_only -> set_compile_thread_pool).

PiperOrigin-RevId: 347597838
Change-Id: I481c188d05ca5750284a6711f6eeabd141c28d4c
Peter Hawkins 2020-12-15 06:13:00 -08:00 committed by TensorFlower Gardener
parent 70bc79a21a
commit c70e89e48a
3 changed files with 27 additions and 10 deletions

tensorflow/compiler/xla/client/executable_build_options.h

@@ -119,7 +119,7 @@ class ExecutableBuildOptions {
   tensorflow::thread::ThreadPool* compile_thread_pool() const {
     return compile_thread_pool_;
   }
-  ExecutableBuildOptions& set_run_backend_only(
+  ExecutableBuildOptions& set_compile_thread_pool(
      tensorflow::thread::ThreadPool* compile_thread_pool) {
    compile_thread_pool_ = compile_thread_pool;
    return *this;

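For orientation, a minimal sketch of how a caller could use the renamed setter. The pool name, size, and ownership pattern are illustrative, not part of this change; a caller-owned pool must outlive any compilation that uses it.

#include "tensorflow/compiler/xla/client/executable_build_options.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/threadpool.h"

void UseCustomCompilePool() {
  // Illustrative: a caller-owned pool handed to the build options.
  tensorflow::thread::ThreadPool compile_pool(tensorflow::Env::Default(),
                                              "my_compile_pool",
                                              /*num_threads=*/8);
  xla::ExecutableBuildOptions build_options;
  build_options.set_compile_thread_pool(&compile_pool);
  // build_options.compile_thread_pool() now returns &compile_pool.
}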
tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc

@@ -65,6 +65,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h"
 #include <cstddef>
+#include <cstdlib>
 #include <memory>
 #include <string>
 #include <vector>
@@ -98,6 +99,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/compiler/xla/util.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/platform/cpu_info.h"
 #include "tensorflow/core/platform/errors.h"
 #include "tensorflow/core/platform/fingerprint.h"
 #include "tensorflow/core/platform/mem.h"
@@ -185,6 +187,19 @@ class CpuAllocator : public tensorflow::Allocator {
   }
 };
 
+static int DefaultThreadPoolSize() {
+  // Google's CI system exposes an environment variable NPROC that describes
+  // a CPU reservation for tests.
+  // TODO(phawkins): expose a better thought-out set of knobs to control
+  // parallelism.
+  const char* nproc_str = std::getenv("NPROC");
+  int nproc = 0;
+  if (nproc_str && absl::SimpleAtoi(nproc_str, &nproc)) {
+    return std::max(0, nproc);
+  }
+  return tensorflow::port::MaxParallelism();
+}
+
 PjRtStreamExecutorClient::PjRtStreamExecutorClient(
     std::string platform_name, LocalClient* client,
     std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices, int host_id,
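A small sketch of how the NPROC override in DefaultThreadPoolSize() behaves; setenv/unsetenv are POSIX, and the printed values are illustrative:

#include <cstdlib>
#include <iostream>

// Assumes DefaultThreadPoolSize() above is visible in this translation unit.
int main() {
  setenv("NPROC", "4", /*overwrite=*/1);
  std::cout << DefaultThreadPoolSize() << "\n";  // 4: the CI reservation wins.
  unsetenv("NPROC");
  std::cout << DefaultThreadPoolSize() << "\n";  // MaxParallelism(): core count.
  return 0;
}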
@@ -202,8 +217,9 @@ PjRtStreamExecutorClient::PjRtStreamExecutorClient(
       should_stage_host_to_device_transfers_(
           should_stage_host_to_device_transfers),
       gpu_run_options_(std::move(gpu_run_options)),
-      h2d_transfer_pool_(tensorflow::Env::Default(), "py_xla_h2d_transfer",
-                         client->device_count()) {
+      thread_pool_(tensorflow::Env::Default(), "pjrt_thread_pool",
+                   std::max<int>(DefaultThreadPoolSize(),
+                                 client->device_count())) {
   if (owned_allocator_ != nullptr) {
     allocator_ = owned_allocator_.get();
   } else {
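The sizing expression is the heart of the repurposing: the pool now serves both host-to-device transfers (which want at least one thread per device, the old sizing) and compilation (which wants roughly one thread per core). Taking the max keeps either workload from starving the other. A sketch of the same rule in isolation, with illustrative names:

#include <algorithm>
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/threadpool.h"

// Illustrative: at least one thread per device for transfer tasks, and at
// least the default parallelism for compilation work.
void MakeSharedPool(int default_size, int device_count) {
  const int pool_size = std::max(default_size, device_count);
  tensorflow::thread::ThreadPool pool(tensorflow::Env::Default(),
                                      "pjrt_thread_pool", pool_size);
  // ... hand &pool to whatever needs it; it joins its threads on destruction.
}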
@@ -745,11 +761,11 @@ PjRtStreamExecutorClient::BufferFromHostBuffer(
         std::make_pair(std::move(buffer_reference), std::move(staging_buffer)));
   };
   if (is_cpu_platform) {
-    // Using the h2d_transfer_pool would be a double thread hop; the code
+    // Using the thread_pool would be a double thread hop; the code
     // already defers its work onto a stream (= thread on CPU).
     transfer_h2d();
   } else {
-    h2d_transfer_pool()->Schedule(transfer_h2d);
+    thread_pool()->Schedule(transfer_h2d);
   }
   return std::unique_ptr<PjRtBuffer>(std::move(py_buffer));
 }
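The "double thread hop" comment is about Schedule() semantics: ThreadPool::Schedule enqueues the closure to run later on a pool thread, while on the CPU platform the transfer work is already deferred onto a stream's thread, so calling transfer_h2d() inline avoids queueing the same work twice. A generic sketch of Schedule, with illustrative names:

#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/threadpool.h"

void ScheduleExample() {
  tensorflow::thread::ThreadPool pool(tensorflow::Env::Default(),
                                      "example", /*num_threads=*/2);
  // Illustrative: the closure runs asynchronously on one of the pool threads.
  pool.Schedule([] {
    // ... e.g. stage and enqueue a host-to-device copy ...
  });
  // The pool's destructor waits for scheduled work to finish.
}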
@@ -844,7 +860,7 @@ PjRtStreamExecutorClient::BufferFromHostLiteral(const LiteralSlice& literal,
         .IgnoreError();  // Can return error::Unimplemented
     QCHECK(h2d_stream->ok());
   };
-  h2d_transfer_pool()->Schedule(transfer_h2d);
+  thread_pool()->Schedule(transfer_h2d);
   return std::unique_ptr<PjRtBuffer>(std::move(py_buffer));
 }
@@ -2094,6 +2110,9 @@ StatusOr<std::unique_ptr<PjRtExecutable>> PjRtStreamExecutorClient::Compile(
   tensorflow::profiler::TraceMe traceme("PjRtStreamExecutorClient::Compile");
 
   ExecutableBuildOptions& build_options = options.executable_build_options;
+  if (!build_options.compile_thread_pool()) {
+    build_options.set_compile_thread_pool(thread_pool());
+  }
   if (!build_options.device_allocator()) {
     build_options.set_device_allocator(allocator());
   }

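Since Compile() installs the client's shared pool only when compile_thread_pool is unset, a caller-supplied pool takes precedence. A hedged usage sketch; variable names are illustrative, and CompileOptions carries the ExecutableBuildOptions as shown above:

// Illustrative: leave compile_thread_pool unset to inherit the client's
// shared pool, or set it explicitly to override.
xla::CompileOptions options;
// options.executable_build_options.set_compile_thread_pool(&my_pool);
auto executable_or = client->Compile(computation, options);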
tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h

@@ -223,9 +223,7 @@ class PjRtStreamExecutorClient : public PjRtClient {
     return gpu_run_options_.get();
   }
 
-  tensorflow::thread::ThreadPool* h2d_transfer_pool() {
-    return &h2d_transfer_pool_;
-  }
+  tensorflow::thread::ThreadPool* thread_pool() { return &thread_pool_; }
 
  protected:
   friend class PjRtStreamExecutorBuffer;
@@ -268,7 +266,7 @@ class PjRtStreamExecutorClient : public PjRtClient {
   std::unique_ptr<gpu::GpuExecutableRunOptions> gpu_run_options_;
 
-  tensorflow::thread::ThreadPool h2d_transfer_pool_;
+  tensorflow::thread::ThreadPool thread_pool_;
 };
 
 // Converts a 2D set of Device objects indexed by [replica][partition] into an