Refactor PJRT.
- Make static methods of PjRtBuffer and PjRtExecutable instance methods on PjRtClient to allow us to extract a set of interfaces out of PJRT. PiperOrigin-RevId: 338101552 Change-Id: I8c10295948ea73d7d4157760a1cd8991384a01dc
This commit is contained in:
parent
b737cff5fd
commit
f187f93d7b
@ -54,9 +54,9 @@ TEST(GpuMultiStream, Basics) {
|
||||
device_assignment(0, 0) = device->id();
|
||||
compile_options.executable_build_options.set_device_assignment(
|
||||
device_assignment);
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<PjRtExecutable> executable,
|
||||
PjRtExecutable::Compile(computation, client.get(),
|
||||
std::move(compile_options)));
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
std::unique_ptr<PjRtExecutable> executable,
|
||||
client->Compile(computation, std::move(compile_options)));
|
||||
|
||||
int64 dummy_size = 1 << 20;
|
||||
std::vector<int32> dummy_inputs(dummy_size);
|
||||
@ -71,22 +71,22 @@ TEST(GpuMultiStream, Basics) {
|
||||
// must wait.
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
auto dummy_buffer,
|
||||
PjRtBuffer::FromHostBuffer(
|
||||
client->BufferFromHostBuffer(
|
||||
dummy_inputs.data(), dummy_shape,
|
||||
PjRtBuffer::HostBufferSemantics::kImmutableUntilTransferCompletes,
|
||||
/*buffer_reference=*/nullptr, client.get(), device));
|
||||
PjRtClient::HostBufferSemantics::kImmutableUntilTransferCompletes,
|
||||
/*buffer_reference=*/nullptr, device));
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
auto in_buffer0,
|
||||
PjRtBuffer::FromHostBuffer(
|
||||
client->BufferFromHostBuffer(
|
||||
inputs.data(), shape,
|
||||
PjRtBuffer::HostBufferSemantics::kImmutableUntilTransferCompletes,
|
||||
/*buffer_reference=*/nullptr, client.get(), device));
|
||||
PjRtClient::HostBufferSemantics::kImmutableUntilTransferCompletes,
|
||||
/*buffer_reference=*/nullptr, device));
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
auto in_buffer1,
|
||||
PjRtBuffer::FromHostBuffer(
|
||||
client->BufferFromHostBuffer(
|
||||
inputs.data(), shape,
|
||||
PjRtBuffer::HostBufferSemantics::kImmutableUntilTransferCompletes,
|
||||
/*buffer_reference=*/nullptr, client.get(), device));
|
||||
PjRtClient::HostBufferSemantics::kImmutableUntilTransferCompletes,
|
||||
/*buffer_reference=*/nullptr, device));
|
||||
// The execution may be enqueued before the transfers complete, requiring
|
||||
// adequate device-side synchronization.
|
||||
ExecuteOptions options;
|
||||
|
@ -576,24 +576,21 @@ void PjRtBuffer::ScopedHold::AddToInput(
|
||||
}
|
||||
}
|
||||
|
||||
/* static */
|
||||
StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::FromHostBuffer(
|
||||
StatusOr<std::unique_ptr<PjRtBuffer>> PjRtClient::BufferFromHostBuffer(
|
||||
const void* data, const Shape& shape,
|
||||
HostBufferSemantics host_buffer_semantics,
|
||||
std::shared_ptr<void> buffer_reference, PjRtClient* client,
|
||||
PjRtDevice* device) {
|
||||
tensorflow::profiler::TraceMe traceme("PjRtBuffer::FromHostBuffer");
|
||||
VLOG(2) << "PjRtBuffer::FromHostBuffer: shape: " << shape.ToString()
|
||||
std::shared_ptr<void> buffer_reference, PjRtDevice* device) {
|
||||
tensorflow::profiler::TraceMe traceme("PjRtClient::BufferFromHostBuffer");
|
||||
VLOG(2) << "PjRtClient::BufferFromHostBuffer: shape: " << shape.ToString()
|
||||
<< " device: " << device->DebugString();
|
||||
if (shape.IsTuple()) {
|
||||
return InvalidArgument("Use FromHostLiteral to transfer a tuple");
|
||||
return InvalidArgument("Use BufferFromHostLiteral to transfer a tuple");
|
||||
}
|
||||
TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device,
|
||||
device->GetLocalDeviceState());
|
||||
int64 size = ShapeUtil::ByteSizeOf(shape);
|
||||
|
||||
TransferManager* transfer_manager =
|
||||
client->client()->backend().transfer_manager();
|
||||
TransferManager* transfer_manager = client()->backend().transfer_manager();
|
||||
TF_ASSIGN_OR_RETURN(Shape compact_shape,
|
||||
transfer_manager->ChooseCompactLayoutForShape(shape));
|
||||
|
||||
@ -628,10 +625,11 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::FromHostBuffer(
|
||||
};
|
||||
buffer = se::DeviceMemoryBase(const_cast<void*>(data), size);
|
||||
} else {
|
||||
void* staging_buffer = client->host_memory_allocator()->AllocateRaw(
|
||||
void* staging_buffer = host_memory_allocator()->AllocateRaw(
|
||||
cpu_function_runtime::kMinAlign, size);
|
||||
on_delete_callback = [staging_buffer, client]() {
|
||||
client->host_memory_allocator()->DeallocateRaw(staging_buffer);
|
||||
on_delete_callback = [staging_buffer, host_memory_allocator =
|
||||
host_memory_allocator()]() {
|
||||
host_memory_allocator->DeallocateRaw(staging_buffer);
|
||||
};
|
||||
buffer = se::DeviceMemoryBase(staging_buffer, size);
|
||||
std::memcpy(staging_buffer, data, size);
|
||||
@ -643,7 +641,7 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::FromHostBuffer(
|
||||
std::initializer_list<se::DeviceMemoryBase>{buffer},
|
||||
definition_events, std::move(on_delete_callback));
|
||||
return absl::make_unique<PjRtBuffer>(
|
||||
shape, shape, std::move(device_buffer), client, device);
|
||||
shape, shape, std::move(device_buffer), this, device);
|
||||
}
|
||||
}
|
||||
|
||||
@ -651,21 +649,22 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::FromHostBuffer(
|
||||
std::unique_ptr<PjRtBuffer> py_buffer,
|
||||
AllocateDestinationBuffer(compact_shape, device, local_device,
|
||||
local_device->host_to_device_stream(),
|
||||
/*is_uninitialized_create=*/false, client));
|
||||
/*is_uninitialized_create=*/false, this));
|
||||
|
||||
ScopedHold device_buffer(py_buffer->GetBufferWithUsageHold());
|
||||
PjRtBuffer::ScopedHold device_buffer(py_buffer->GetBufferWithUsageHold());
|
||||
CHECK(device_buffer.ok());
|
||||
|
||||
// If necessary, allocate a host-side buffer for staging host-to-device
|
||||
// transfers. On GPU this is a buffer in pinned memory.
|
||||
std::shared_ptr<void> staging_buffer;
|
||||
if (host_buffer_semantics == HostBufferSemantics::kImmutableOnlyDuringCall ||
|
||||
client->should_stage_host_to_device_transfers()) {
|
||||
void* ptr = client->host_memory_allocator()->AllocateRaw(
|
||||
should_stage_host_to_device_transfers()) {
|
||||
void* ptr = host_memory_allocator()->AllocateRaw(
|
||||
tensorflow::Allocator::kAllocatorAlignment, size);
|
||||
staging_buffer = std::shared_ptr<void>(ptr, [client](void* ptr) {
|
||||
client->host_memory_allocator()->DeallocateRaw(ptr);
|
||||
});
|
||||
staging_buffer = std::shared_ptr<void>(
|
||||
ptr, [host_memory_allocator = host_memory_allocator()](void* ptr) {
|
||||
host_memory_allocator->DeallocateRaw(ptr);
|
||||
});
|
||||
}
|
||||
|
||||
// Copy the buffer into a staging buffer before returning control to the
|
||||
@ -684,14 +683,15 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::FromHostBuffer(
|
||||
// usage holds have gone away.
|
||||
// TODO(misard) assess if it would be preferable to introduce a heuristic to
|
||||
// put the transfer into the calling thread for small literals.
|
||||
auto transfer_h2d = [client, transfer_manager, local_device, data, size,
|
||||
auto transfer_h2d = [local_client = client(), transfer_manager, local_device,
|
||||
data, size,
|
||||
movable_device_buffer{device_buffer.ToClosure()}, shape,
|
||||
py_buffer{py_buffer.get()}, compact_shape,
|
||||
on_device_shape{py_buffer->on_device_shape()},
|
||||
staging_buffer{std::move(staging_buffer)},
|
||||
buffer_reference{std::move(buffer_reference)},
|
||||
host_buffer_semantics]() {
|
||||
ScopedHold device_buffer(movable_device_buffer);
|
||||
PjRtBuffer::ScopedHold device_buffer(movable_device_buffer);
|
||||
// This function uses TF_CHECK_OK and ValueOrDie() since we have no way
|
||||
// to report failures from a callback. However, the operations here are
|
||||
// unlikely to fail and not recoverable even if we were to fail: DMAs to
|
||||
@ -699,7 +699,7 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::FromHostBuffer(
|
||||
// allocation.
|
||||
|
||||
ShapedBuffer buffer = device_buffer->AsShapedBuffer(
|
||||
compact_shape, on_device_shape, client->client()->platform());
|
||||
compact_shape, on_device_shape, local_client->platform());
|
||||
// If applicable on the backend, stage the transfer via host memory
|
||||
// allocated via the host_memory_allocator. On GPU, this is pinned
|
||||
// memory.
|
||||
@ -736,41 +736,38 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::FromHostBuffer(
|
||||
// already defers its work onto a stream (= thread on CPU).
|
||||
transfer_h2d();
|
||||
} else {
|
||||
client->h2d_transfer_pool()->Schedule(transfer_h2d);
|
||||
h2d_transfer_pool()->Schedule(transfer_h2d);
|
||||
}
|
||||
return py_buffer;
|
||||
}
|
||||
|
||||
/* static */
|
||||
StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::CreateUninitialized(
|
||||
const Shape& shape, PjRtClient* client, PjRtDevice* device) {
|
||||
tensorflow::profiler::TraceMe traceme("PjRtBuffer::CreateUninitialized");
|
||||
VLOG(2) << "PjRtBuffer::CreateUninitialized: shape: " << shape.ToString()
|
||||
<< " device: " << device->DebugString();
|
||||
StatusOr<std::unique_ptr<PjRtBuffer>> PjRtClient::CreateUninitializedBuffer(
|
||||
const Shape& shape, PjRtDevice* device) {
|
||||
tensorflow::profiler::TraceMe traceme(
|
||||
"PjRtClient::CreateUninitializedBuffer");
|
||||
VLOG(2) << "PjRtClient::CreateUninitializedBuffer: shape: "
|
||||
<< shape.ToString() << " device: " << device->DebugString();
|
||||
TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device,
|
||||
device->GetLocalDeviceState());
|
||||
|
||||
TransferManager* transfer_manager =
|
||||
client->client()->backend().transfer_manager();
|
||||
TransferManager* transfer_manager = client()->backend().transfer_manager();
|
||||
TF_ASSIGN_OR_RETURN(Shape compact_shape,
|
||||
transfer_manager->ChooseCompactLayoutForShape(shape));
|
||||
|
||||
return AllocateDestinationBuffer(compact_shape, device, local_device,
|
||||
/*copy_stream=*/nullptr,
|
||||
/*is_uninitialized_create=*/true, client);
|
||||
/*is_uninitialized_create=*/true, this);
|
||||
}
|
||||
|
||||
/* static */
|
||||
StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::FromHostLiteral(
|
||||
const LiteralSlice& literal, PjRtClient* client, PjRtDevice* device) {
|
||||
tensorflow::profiler::TraceMe traceme("PjRtBuffer::FromHostLiteral");
|
||||
VLOG(2) << "PjRtBuffer::FromHostLiteral: shape: "
|
||||
StatusOr<std::unique_ptr<PjRtBuffer>> PjRtClient::BufferFromHostLiteral(
|
||||
const LiteralSlice& literal, PjRtDevice* device) {
|
||||
tensorflow::profiler::TraceMe traceme("PjRtClient::BufferFromHostLiteral");
|
||||
VLOG(2) << "PjRtClient::BufferFromHostLiteral: shape: "
|
||||
<< literal.shape().ToString() << " device: " << device->DebugString();
|
||||
TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device,
|
||||
device->GetLocalDeviceState());
|
||||
|
||||
TransferManager* transfer_manager =
|
||||
client->client()->backend().transfer_manager();
|
||||
TransferManager* transfer_manager = client()->backend().transfer_manager();
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
Shape compact_shape,
|
||||
transfer_manager->ChooseCompactLayoutForShape(literal.shape()));
|
||||
@ -778,9 +775,9 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::FromHostLiteral(
|
||||
std::unique_ptr<PjRtBuffer> py_buffer,
|
||||
AllocateDestinationBuffer(compact_shape, device, local_device,
|
||||
local_device->host_to_device_stream(),
|
||||
/*is_uninitialized_create=*/false, client));
|
||||
/*is_uninitialized_create=*/false, this));
|
||||
|
||||
ScopedHold device_buffer(py_buffer->GetBufferWithUsageHold());
|
||||
PjRtBuffer::ScopedHold device_buffer(py_buffer->GetBufferWithUsageHold());
|
||||
CHECK(device_buffer.ok());
|
||||
|
||||
// The host to device transfer is performed on a thread pool, mostly because
|
||||
@ -789,11 +786,11 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::FromHostLiteral(
|
||||
// usage holds have gone away.
|
||||
// TODO(misard) assess if it would be preferable to introduce a heuristic to
|
||||
// put the transfer into the calling thread for small literals.
|
||||
auto transfer_h2d = [client, transfer_manager, local_device,
|
||||
auto transfer_h2d = [local_client = client(), transfer_manager, local_device,
|
||||
movable_device_buffer{device_buffer.ToClosure()},
|
||||
literal, py_buffer{py_buffer.get()}, compact_shape,
|
||||
on_device_shape{py_buffer->on_device_shape()}]() {
|
||||
ScopedHold device_buffer(movable_device_buffer);
|
||||
PjRtBuffer::ScopedHold device_buffer(movable_device_buffer);
|
||||
// This function uses TF_CHECK_OK and ValueOrDie() since we have no way
|
||||
// to report failures from a callback. However, the operations here are
|
||||
// unlikely to fail and not recoverable even if we were to fail: DMAs to
|
||||
@ -802,7 +799,7 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::FromHostLiteral(
|
||||
|
||||
se::Stream* h2d_stream = local_device->host_to_device_stream();
|
||||
ShapedBuffer buffer = device_buffer->AsShapedBuffer(
|
||||
compact_shape, on_device_shape, client->client()->platform());
|
||||
compact_shape, on_device_shape, local_client->platform());
|
||||
TF_CHECK_OK(transfer_manager->TransferLiteralToDeviceAsync(
|
||||
h2d_stream, literal, buffer));
|
||||
|
||||
@ -817,12 +814,12 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::FromHostLiteral(
|
||||
.IgnoreError(); // Can return error::Unimplemented
|
||||
QCHECK(h2d_stream->ok());
|
||||
};
|
||||
client->h2d_transfer_pool()->Schedule(transfer_h2d);
|
||||
h2d_transfer_pool()->Schedule(transfer_h2d);
|
||||
return py_buffer;
|
||||
}
|
||||
|
||||
/*static*/ void PjRtBuffer::MakeCrossHostReceiveBuffers(
|
||||
absl::Span<const Shape> shapes, PjRtClient* client, PjRtDevice* device,
|
||||
void PjRtClient::MakeCrossHostReceiveBuffers(
|
||||
absl::Span<const Shape> shapes, PjRtDevice* device,
|
||||
PjRtCrossHostRecvNotifier&& notifier) {
|
||||
if (shapes.empty()) {
|
||||
notifier(InvalidArgument(
|
||||
@ -843,7 +840,7 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::FromHostLiteral(
|
||||
StatusOr<std::unique_ptr<PjRtBuffer>> buffer_or =
|
||||
AllocateDestinationBuffer(shape, device, local_device,
|
||||
/*copy_stream=*/nullptr,
|
||||
/*is_uninitialized_create=*/false, client);
|
||||
/*is_uninitialized_create=*/false, this);
|
||||
if (!buffer_or.ok()) {
|
||||
notifier(buffer_or.status());
|
||||
return;
|
||||
@ -851,7 +848,7 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::FromHostLiteral(
|
||||
buffers.push_back(buffer_or.ConsumeValueOrDie());
|
||||
}
|
||||
|
||||
client->EnqueueCrossHostReceive(std::move(buffers), std::move(notifier));
|
||||
EnqueueCrossHostReceive(std::move(buffers), std::move(notifier));
|
||||
}
|
||||
|
||||
PjRtBuffer::PjRtBuffer(Shape on_host_shape, Shape on_device_shape,
|
||||
@ -1159,7 +1156,7 @@ PjRtBuffer::CopyToHostAsyncInternal(bool discard_cached_copy,
|
||||
|
||||
StatusOr<std::shared_ptr<Literal>> PjRtBuffer::ToLiteral(
|
||||
const bool discard_cached_copy, absl::optional<xla::Layout> layout) {
|
||||
tensorflow::profiler::TraceMe traceme("PjRtBuffer::ToLiteral");
|
||||
tensorflow::profiler::TraceMe traceme("PjRtClient::ToLiteral");
|
||||
TF_ASSIGN_OR_RETURN(std::shared_ptr<HostValue> host_value,
|
||||
CopyToHostAsyncInternal(discard_cached_copy, layout));
|
||||
if (host_value == nullptr) {
|
||||
@ -1267,9 +1264,9 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::CopyToDevice(
|
||||
// Copying across PjRtClients involves a copy through the host.
|
||||
if (dst_device->client() != client_) {
|
||||
TF_ASSIGN_OR_RETURN(std::shared_ptr<Literal> literal, ToLiteral());
|
||||
return FromHostBuffer(literal->untyped_data(), literal->shape(),
|
||||
HostBufferSemantics::kZeroCopy, nullptr,
|
||||
dst_device->client(), dst_device);
|
||||
return dst_device->client()->BufferFromHostBuffer(
|
||||
literal->untyped_data(), literal->shape(),
|
||||
PjRtClient::HostBufferSemantics::kZeroCopy, nullptr, dst_device);
|
||||
}
|
||||
|
||||
TF_ASSIGN_OR_RETURN(LocalDeviceState * dst_local_device,
|
||||
@ -2061,14 +2058,13 @@ StatusOr<std::pair<std::vector<Shape>, Shape>> GetShardedProgramShapes(
|
||||
|
||||
} // namespace
|
||||
|
||||
/*static*/ StatusOr<std::unique_ptr<PjRtExecutable>> PjRtExecutable::Compile(
|
||||
const XlaComputation& computation, PjRtClient* client,
|
||||
CompileOptions options) {
|
||||
tensorflow::profiler::TraceMe traceme("LocalExecutable::Compile");
|
||||
StatusOr<std::unique_ptr<PjRtExecutable>> PjRtClient::Compile(
|
||||
const XlaComputation& computation, CompileOptions options) {
|
||||
tensorflow::profiler::TraceMe traceme("PjRtClient::Compile");
|
||||
|
||||
ExecutableBuildOptions& build_options = options.executable_build_options;
|
||||
if (!build_options.device_allocator()) {
|
||||
build_options.set_device_allocator(client->allocator());
|
||||
build_options.set_device_allocator(allocator());
|
||||
}
|
||||
|
||||
int num_replicas;
|
||||
@ -2084,14 +2080,14 @@ StatusOr<std::pair<std::vector<Shape>, Shape>> GetShardedProgramShapes(
|
||||
num_partitions = 1;
|
||||
} else {
|
||||
if (!build_options.has_device_assignment()) {
|
||||
VLOG(2) << "PjRtExecutable::Compile using default device_assignment.";
|
||||
VLOG(2) << "PjRtClient::Compile using default device_assignment.";
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
DeviceAssignment device_assignment,
|
||||
client->GetDefaultDeviceAssignment(build_options.num_replicas(),
|
||||
build_options.num_partitions()));
|
||||
GetDefaultDeviceAssignment(build_options.num_replicas(),
|
||||
build_options.num_partitions()));
|
||||
build_options.set_device_assignment(device_assignment);
|
||||
}
|
||||
VLOG(2) << "PjRtExecutable::Compile device_assignment:\n"
|
||||
VLOG(2) << "PjRtClient::Compile device_assignment:\n"
|
||||
<< build_options.device_assignment().ToString();
|
||||
num_replicas = build_options.device_assignment().replica_count();
|
||||
num_partitions = build_options.device_assignment().computation_count();
|
||||
@ -2118,7 +2114,8 @@ StatusOr<std::pair<std::vector<Shape>, Shape>> GetShardedProgramShapes(
|
||||
|
||||
// Assign a default layout based on `sharded_shape` to any array subshapes in
|
||||
// `dst_shape` that are missing layouts.
|
||||
auto assign_layouts = [client](const Shape& sharded_shape, Shape* dst_shape) {
|
||||
auto assign_layouts = [local_client = client()](const Shape& sharded_shape,
|
||||
Shape* dst_shape) {
|
||||
return ShapeUtil::ForEachMutableSubshapeWithStatus(
|
||||
dst_shape, [&](Shape* subshape, const ShapeIndex& idx) {
|
||||
if (subshape->IsArray() && !subshape->has_layout()) {
|
||||
@ -2126,8 +2123,7 @@ StatusOr<std::pair<std::vector<Shape>, Shape>> GetShardedProgramShapes(
|
||||
const Shape& sharded_subshape =
|
||||
ShapeUtil::GetSubshape(sharded_shape, idx);
|
||||
LayoutUtil::SetToDefaultLayout(subshape);
|
||||
TF_ASSIGN_OR_RETURN(Shape layout, client->client()
|
||||
->backend()
|
||||
TF_ASSIGN_OR_RETURN(Shape layout, local_client->backend()
|
||||
.transfer_manager()
|
||||
->ChooseCompactLayoutForShape(
|
||||
sharded_subshape));
|
||||
@ -2162,8 +2158,8 @@ StatusOr<std::pair<std::vector<Shape>, Shape>> GetShardedProgramShapes(
|
||||
for (int replica = 0; replica < num_replicas; ++replica) {
|
||||
for (int partition = 0; partition < num_partitions; ++partition) {
|
||||
int device_id = (*device_assignment)(replica, partition);
|
||||
PjRtDevice* device = LookupDevice(*client, device_id);
|
||||
if (device->host_id() != client->host_id()) {
|
||||
PjRtDevice* device = LookupDevice(*this, device_id);
|
||||
if (device->host_id() != host_id()) {
|
||||
VLOG(3) << "Non-local device: " << device_id;
|
||||
continue;
|
||||
}
|
||||
@ -2185,15 +2181,14 @@ StatusOr<std::pair<std::vector<Shape>, Shape>> GetShardedProgramShapes(
|
||||
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
std::vector<std::unique_ptr<LocalExecutable>> local_executables,
|
||||
client->client()->Compile(computation, argument_layout_pointers,
|
||||
build_options));
|
||||
client()->Compile(computation, argument_layout_pointers, build_options));
|
||||
|
||||
auto executable = absl::make_unique<PjRtExecutable>(
|
||||
std::move(local_executables), options.parameter_is_tupled_arguments,
|
||||
std::move(device_assignment), std::move(local_logical_device_ids),
|
||||
std::move(local_devices), client);
|
||||
std::move(local_devices), this);
|
||||
TF_RETURN_IF_ERROR(
|
||||
executable->SetUpDonation(client, options.parameter_is_tupled_arguments));
|
||||
executable->SetUpDonation(this, options.parameter_is_tupled_arguments));
|
||||
return executable;
|
||||
}
|
||||
|
||||
|
@ -120,6 +120,24 @@ struct PjRtCrossHostRecvBuffer {
|
||||
using PjRtCrossHostRecvNotifier =
|
||||
std::function<void(StatusOr<std::vector<PjRtCrossHostRecvBuffer>>&&)>;
|
||||
|
||||
struct CompileOptions {
|
||||
// The layouts of the arguments that the computation should expect.
|
||||
absl::optional<std::vector<Shape>> argument_layouts;
|
||||
|
||||
// If true, the supplied computation expects its arguments to be wrapped in a
|
||||
// tuple and passed as a single parameter.
|
||||
bool parameter_is_tupled_arguments = false;
|
||||
|
||||
// XLA's compilation time options.
|
||||
ExecutableBuildOptions executable_build_options;
|
||||
|
||||
// If true, the executable can be run on any device. May only be true if
|
||||
// !executable_build_options.has_device_assignment(), so only applies to
|
||||
// single-device executables. Beware: on GPUs, sometimes an executable
|
||||
// compiled for one device doesn't run on another.
|
||||
bool compile_portable_executable = false;
|
||||
};
|
||||
|
||||
class PjRtExecutable;
|
||||
|
||||
// Encapsulates the state of Python session with XLA.
|
||||
@ -198,6 +216,63 @@ class PjRtClient {
|
||||
// Returns a backend-specific HLO cost analysis visitor.
|
||||
virtual std::unique_ptr<HloCostAnalysis> GetHloCostAnalysis();
|
||||
|
||||
virtual StatusOr<std::unique_ptr<PjRtExecutable>> Compile(
|
||||
const XlaComputation& computation, CompileOptions options);
|
||||
|
||||
virtual StatusOr<std::unique_ptr<PjRtBuffer>> CreateUninitializedBuffer(
|
||||
const Shape& shape, PjRtDevice* device);
|
||||
|
||||
// Describes the semantics the caller to BufferFromHostBuffer expects from the
|
||||
// runtime, in a total order from most restrictive to least restrictive.
|
||||
enum class HostBufferSemantics {
|
||||
// The runtime may not hold references to `data` after the call to
|
||||
// `BufferFromHostBuffer` completes. The caller promises that `data` is
|
||||
// immutable and will not be freed only for the duration of the
|
||||
// BufferFromHostBuffer call. `buffer_reference` will be freed by the time
|
||||
// `BufferFromHostBuffer` returns.
|
||||
kImmutableOnlyDuringCall,
|
||||
|
||||
// The runtime may hold onto `data` after the call to `BufferFromHostBuffer`
|
||||
// returns while the runtime completes a transfer to the device. The caller
|
||||
// promises not to mutate or free `data` until the transfer completes, at
|
||||
// which point the runtime will release `buffer_reference`. It is also
|
||||
// correct to wait on the host (directly or indirectly) for the buffer's
|
||||
// definition event to complete.
|
||||
kImmutableUntilTransferCompletes,
|
||||
|
||||
// The PjRtBuffer may alias `data` internally and the runtime may use the
|
||||
// `data` contents as long as the buffer is alive. The caller promises to
|
||||
// keep `data` alive and not to mutate its contents as long as the buffer is
|
||||
// alive; to notify the caller that the buffer may be freed, the runtime
|
||||
// will release its `buffer_reference` when the PjRtBuffer is freed. On
|
||||
// non-CPU platforms this acts identically to
|
||||
// kImmutableUntilTransferCompletes.
|
||||
kZeroCopy,
|
||||
};
|
||||
virtual StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostBuffer(
|
||||
const void* data, const Shape& shape,
|
||||
HostBufferSemantics host_buffer_semantics,
|
||||
std::shared_ptr<void> buffer_reference, PjRtDevice* device);
|
||||
|
||||
// Note that literal must remain in scope until the transfer has completed, so
|
||||
// the caller should, for example, wait for BlockHostUntilReady() completes on
|
||||
// the return value before letting literal go out of scope.
|
||||
virtual StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostLiteral(
|
||||
const LiteralSlice& literal, PjRtDevice* device);
|
||||
|
||||
// Asynchronously makes a vector of PjRtBuffers that can be used to receive
|
||||
// cross host transfers using `client` on `device'. `shapes` must be the exact
|
||||
// shapes, with identical layouts, corresponding to the buffers that will be
|
||||
// sent. When resources for the transfer are available, notifier will be
|
||||
// called with a vector of PjRtCrossHostRecvBuffer structs, one for each
|
||||
// shape in `shapes`. Each struct contains a buffer that will contain the
|
||||
// received value, and an opaque string that should be transmitted to the
|
||||
// sending host and used in a call to CopyToRemoteDevice. None of the recv
|
||||
// buffers will become ready until *all* of the sends have completed.
|
||||
virtual void MakeCrossHostReceiveBuffers(
|
||||
absl::Span<const Shape> shapes, PjRtDevice* device,
|
||||
PjRtCrossHostRecvNotifier&& notifier);
|
||||
|
||||
protected:
|
||||
friend class PjRtBuffer;
|
||||
virtual void EnqueueCrossHostReceive(
|
||||
@ -385,6 +460,7 @@ class PjRtBuffer {
|
||||
|
||||
private:
|
||||
friend class PjRtBuffer;
|
||||
friend class PjRtClient;
|
||||
|
||||
// Helper struct that makes it possible to move a ScopedHold through a
|
||||
// closure.
|
||||
@ -423,62 +499,6 @@ class PjRtBuffer {
|
||||
StatusOr<std::shared_ptr<TrackedDeviceBuffer>> buffer_or_;
|
||||
};
|
||||
|
||||
// Returns a buffer with uninitialized contents.
|
||||
static StatusOr<std::unique_ptr<PjRtBuffer>> CreateUninitialized(
|
||||
const Shape& shape, PjRtClient* client, PjRtDevice* device);
|
||||
|
||||
// Describes the semantics the caller to FromHostBuffer expects from the
|
||||
// runtime, in a total order from most restrictive to least restrictive.
|
||||
enum class HostBufferSemantics {
|
||||
// The runtime may not hold references to `data` after the call to
|
||||
// `FromHostBuffer` completes. The caller promises that `data` is immutable
|
||||
// and will not be freed only for the duration of the FromHostBuffer call.
|
||||
// `buffer_reference` will be freed by the time `FromHostBuffer` returns.
|
||||
kImmutableOnlyDuringCall,
|
||||
|
||||
// The runtime may hold onto `data` after the call to `FromHostBuffer`
|
||||
// returns while the runtime completes a transfer to the device. The caller
|
||||
// promises not to mutate or free `data` until the transfer completes, at
|
||||
// which point the runtime will release `buffer_reference`. It is also
|
||||
// correct to wait on the host (directly or indirectly) for the buffer's
|
||||
// definition event to complete.
|
||||
kImmutableUntilTransferCompletes,
|
||||
|
||||
// The PjRtBuffer may alias `data` internally and the runtime may use the
|
||||
// `data` contents as long as the buffer is alive.
|
||||
// The caller promises to keep `data` alive and not to mutate its contents
|
||||
// as long as the buffer is alive; to notify the caller that the buffer may
|
||||
// be freed, the runtime will release its `buffer_reference` when the
|
||||
// PjRtBuffer is freed. On non-CPU platforms this acts identically to
|
||||
// kImmutableUntilTransferCompletes.
|
||||
kZeroCopy,
|
||||
};
|
||||
static StatusOr<std::unique_ptr<PjRtBuffer>> FromHostBuffer(
|
||||
const void* data, const Shape& shape,
|
||||
HostBufferSemantics host_buffer_semantics,
|
||||
std::shared_ptr<void> buffer_reference, PjRtClient* client,
|
||||
PjRtDevice* device);
|
||||
|
||||
// Note that literal must remain in scope until the transfer has completed, so
|
||||
// the caller should, for example, wait for BlockHostUntilReady() completes on
|
||||
// the return value before letting literal go out of scope.
|
||||
static StatusOr<std::unique_ptr<PjRtBuffer>> FromHostLiteral(
|
||||
const LiteralSlice& literal, PjRtClient* client, PjRtDevice* device);
|
||||
|
||||
// Asynchronously makes a vector of PjRtBuffers that can be used to receive
|
||||
// cross host transfers using `client` on `device'. `shapes` must be the exact
|
||||
// shapes, with identical layouts, corresponding to the buffers that will be
|
||||
// sent. When resources for the transfer are available, notifier will be
|
||||
// called with a vector of PjRtCrossHostRecvBuffer structs, one for each
|
||||
// shape in `shapes`. Each struct contains a buffer that will contain the
|
||||
// received value, and an opaque string that should be transmitted to the
|
||||
// sending host and used in a call to CopyToRemoteDevice. None of the recv
|
||||
// buffers will become ready until *all* of the sends have completed.
|
||||
static void MakeCrossHostReceiveBuffers(absl::Span<const Shape> shapes,
|
||||
PjRtClient* client,
|
||||
PjRtDevice* device,
|
||||
PjRtCrossHostRecvNotifier&& notifier);
|
||||
|
||||
PjRtBuffer(Shape on_host_shape, Shape on_device_shape,
|
||||
std::shared_ptr<TrackedDeviceBuffer> device_buffer,
|
||||
PjRtClient* client, PjRtDevice* device);
|
||||
@ -661,24 +681,6 @@ class PjRtBuffer {
|
||||
Semaphore donation_semaphore_;
|
||||
};
|
||||
|
||||
struct CompileOptions {
|
||||
// The layouts of the arguments that the computation should expect.
|
||||
absl::optional<std::vector<Shape>> argument_layouts;
|
||||
|
||||
// If true, the supplied computation expects its arguments to be wrapped in a
|
||||
// tuple and passed as a single parameter.
|
||||
bool parameter_is_tupled_arguments = false;
|
||||
|
||||
// XLA's compilation time options.
|
||||
ExecutableBuildOptions executable_build_options;
|
||||
|
||||
// If true, the executable can be run on any device. May only be true if
|
||||
// !executable_build_options.has_device_assignment(), so only applies to
|
||||
// single-device executables. Beware: on GPUs, sometimes an executable
|
||||
// compiled for one device doesn't run on another.
|
||||
bool compile_portable_executable = false;
|
||||
};
|
||||
|
||||
class ExecuteContext {
|
||||
public:
|
||||
virtual ~ExecuteContext() = default;
|
||||
@ -710,10 +712,6 @@ struct ExecuteOptions {
|
||||
// buffer will be donated when passed to the execution.
|
||||
class PjRtExecutable {
|
||||
public:
|
||||
static StatusOr<std::unique_ptr<PjRtExecutable>> Compile(
|
||||
const XlaComputation& computation, PjRtClient* client,
|
||||
CompileOptions options);
|
||||
|
||||
PjRtExecutable(std::vector<std::unique_ptr<LocalExecutable>> executables,
|
||||
bool parameter_is_tupled_arguments,
|
||||
std::shared_ptr<DeviceAssignment> device_assignment,
|
||||
@ -783,6 +781,7 @@ class PjRtExecutable {
|
||||
}
|
||||
|
||||
private:
|
||||
friend class PjRtClient;
|
||||
// Initializes information about which arguments to which executables must be
|
||||
// donated due to aliases that were specified by the computation.
|
||||
Status SetUpDonation(PjRtClient* client, bool tuple_inputs);
|
||||
|
@ -465,10 +465,10 @@ std::unique_ptr<xla::PjRtBuffer> ConvertToScalarBuffer(
|
||||
xla::PjRtDevice* device) {
|
||||
CppType data = py::cast<Pybind11Type>(scalar);
|
||||
xla::Shape shape = xla::ShapeUtil::MakeShapeWithType<CppType>({});
|
||||
return ValueOrThrow(xla::PjRtBuffer::FromHostBuffer(
|
||||
return ValueOrThrow(client->BufferFromHostBuffer(
|
||||
&data, shape,
|
||||
xla::PjRtBuffer::HostBufferSemantics::kImmutableOnlyDuringCall, nullptr,
|
||||
client, device));
|
||||
xla::PjRtClient::HostBufferSemantics::kImmutableOnlyDuringCall, nullptr,
|
||||
device));
|
||||
}
|
||||
|
||||
// Convert a scalar to the associated PjRtBuffer or raises an error if it is
|
||||
@ -502,17 +502,17 @@ StatusOr<std::unique_ptr<xla::PjRtBuffer>> ScalarToBuffer(
|
||||
if (jax_enable_x64) {
|
||||
xla::complex128 data(result.real, result.imag);
|
||||
xla::Shape shape = xla::ShapeUtil::MakeShapeWithType<xla::complex128>({});
|
||||
return ValueOrThrow(xla::PjRtBuffer::FromHostBuffer(
|
||||
return ValueOrThrow(client->BufferFromHostBuffer(
|
||||
&data, shape,
|
||||
xla::PjRtBuffer::HostBufferSemantics::kImmutableOnlyDuringCall,
|
||||
nullptr, client, device));
|
||||
xla::PjRtClient::HostBufferSemantics::kImmutableOnlyDuringCall,
|
||||
nullptr, device));
|
||||
} else {
|
||||
xla::complex64 data(result.real, result.imag);
|
||||
xla::Shape shape = xla::ShapeUtil::MakeShapeWithType<xla::complex64>({});
|
||||
return ValueOrThrow(xla::PjRtBuffer::FromHostBuffer(
|
||||
return ValueOrThrow(client->BufferFromHostBuffer(
|
||||
&data, shape,
|
||||
xla::PjRtBuffer::HostBufferSemantics::kImmutableOnlyDuringCall,
|
||||
nullptr, client, device));
|
||||
xla::PjRtClient::HostBufferSemantics::kImmutableOnlyDuringCall,
|
||||
nullptr, device));
|
||||
}
|
||||
}
|
||||
return InvalidArgument(
|
||||
@ -678,7 +678,7 @@ Status ConvertArgsToBuffers(bool jax_enable_x64, xla::PyClient& pyclient,
|
||||
ValueOrThrow(pyclient.BufferFromPyval(
|
||||
numpy_array, data_device,
|
||||
/*force_copy=*/false, /*host_buffer_semantics=*/
|
||||
xla::PjRtBuffer::HostBufferSemantics::kZeroCopy));
|
||||
xla::PjRtClient::HostBufferSemantics::kZeroCopy));
|
||||
arg_buffers.push_back(buffer->buffer());
|
||||
|
||||
ArgSignature sig;
|
||||
|
@ -409,10 +409,9 @@ Status OutfeedReceiverImpl::SendShutdownOutfeedHeader(int device_idx) {
|
||||
compile_options.executable_build_options.set_device_assignment(
|
||||
device_assignment);
|
||||
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
std::unique_ptr<PjRtExecutable> executable,
|
||||
PjRtExecutable::Compile(computation, devices_[device_idx]->client(),
|
||||
std::move(compile_options)));
|
||||
TF_ASSIGN_OR_RETURN(std::unique_ptr<PjRtExecutable> executable,
|
||||
devices_[device_idx]->client()->Compile(
|
||||
computation, std::move(compile_options)));
|
||||
ExecuteOptions execute_options;
|
||||
TF_ASSIGN_OR_RETURN(std::vector<std::unique_ptr<PjRtBuffer>> output_buffers,
|
||||
executable->Execute({}, execute_options));
|
||||
|
@ -40,9 +40,8 @@ Status CompileAndExecute(XlaBuilder* builder, XlaOp root, int device_id,
|
||||
compile_options.executable_build_options.set_device_assignment(
|
||||
device_assignment);
|
||||
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
std::unique_ptr<PjRtExecutable> executable,
|
||||
PjRtExecutable::Compile(computation, client, std::move(compile_options)));
|
||||
TF_ASSIGN_OR_RETURN(std::unique_ptr<PjRtExecutable> executable,
|
||||
client->Compile(computation, std::move(compile_options)));
|
||||
ExecuteOptions execute_options;
|
||||
TF_ASSIGN_OR_RETURN(std::vector<std::unique_ptr<PjRtBuffer>> output_buffers,
|
||||
executable->Execute({}, execute_options));
|
||||
|
@ -89,7 +89,7 @@ PyClient::GetDefaultDeviceAssignment1D(int num_replicas) {
|
||||
|
||||
StatusOr<std::unique_ptr<PyBuffer>> PyClient::BufferFromPyval(
|
||||
const pybind11::object& argument, PjRtDevice* device, bool force_copy,
|
||||
PjRtBuffer::HostBufferSemantics host_buffer_semantics) {
|
||||
PjRtClient::HostBufferSemantics host_buffer_semantics) {
|
||||
if (device == nullptr) {
|
||||
TF_RET_CHECK(!pjrt_client_->local_devices().empty());
|
||||
device = pjrt_client_->local_devices().front();
|
||||
@ -114,10 +114,9 @@ StatusOr<std::unique_ptr<PyBuffer>> PyClient::BufferFromPyval(
|
||||
std::unique_ptr<PjRtBuffer> buffer;
|
||||
{
|
||||
py::gil_scoped_release gil_release;
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
buffer, PjRtBuffer::FromHostBuffer(
|
||||
c->buf_ptr, c->shape, host_buffer_semantics,
|
||||
std::move(py_buffer_ref), pjrt_client_.get(), device));
|
||||
TF_ASSIGN_OR_RETURN(buffer, pjrt_client_->BufferFromHostBuffer(
|
||||
c->buf_ptr, c->shape, host_buffer_semantics,
|
||||
std::move(py_buffer_ref), device));
|
||||
}
|
||||
auto traceback = Traceback::Get();
|
||||
return std::make_unique<PyBuffer>(shared_from_this(), std::move(buffer),
|
||||
@ -131,8 +130,7 @@ StatusOr<std::shared_ptr<PyExecutable>> PyClient::Compile(
|
||||
{
|
||||
py::gil_scoped_release gil_release;
|
||||
TF_ASSIGN_OR_RETURN(executable,
|
||||
PjRtExecutable::Compile(computation, pjrt_client_.get(),
|
||||
std::move(options)));
|
||||
pjrt_client_->Compile(computation, std::move(options)));
|
||||
TF_ASSIGN_OR_RETURN(fingerprint,
|
||||
pjrt_client_->ExecutableFingerprint(*executable));
|
||||
}
|
||||
|
@ -123,7 +123,7 @@ class PyClient : public std::enable_shared_from_this<PyClient> {
|
||||
|
||||
StatusOr<std::unique_ptr<PyBuffer>> BufferFromPyval(
|
||||
const pybind11::object& argument, PjRtDevice* device, bool force_copy,
|
||||
PjRtBuffer::HostBufferSemantics host_buffer_semantics);
|
||||
PjRtClient::HostBufferSemantics host_buffer_semantics);
|
||||
|
||||
StatusOr<std::shared_ptr<PyExecutable>> Compile(
|
||||
const XlaComputation& computation, CompileOptions options);
|
||||
|
@ -535,12 +535,12 @@ PYBIND11_MODULE(xla_extension, m) {
|
||||
.value("PLATFORM", GpuAllocatorConfig::Kind::kPlatform)
|
||||
.value("BFC", GpuAllocatorConfig::Kind::kBFC);
|
||||
|
||||
py::enum_<PjRtBuffer::HostBufferSemantics>(m, "HostBufferSemantics")
|
||||
py::enum_<PjRtClient::HostBufferSemantics>(m, "HostBufferSemantics")
|
||||
.value("IMMUTABLE_ONLY_DURING_CALL",
|
||||
PjRtBuffer::HostBufferSemantics::kImmutableOnlyDuringCall)
|
||||
PjRtClient::HostBufferSemantics::kImmutableOnlyDuringCall)
|
||||
.value("IMMUTABLE_UNTIL_TRANSFER_COMPLETES",
|
||||
PjRtBuffer::HostBufferSemantics::kImmutableUntilTransferCompletes)
|
||||
.value("ZERO_COPY", PjRtBuffer::HostBufferSemantics::kZeroCopy);
|
||||
PjRtClient::HostBufferSemantics::kImmutableUntilTransferCompletes)
|
||||
.value("ZERO_COPY", PjRtClient::HostBufferSemantics::kZeroCopy);
|
||||
|
||||
py::class_<PyClient, std::shared_ptr<PyClient>> py_local_client(m, "Client");
|
||||
py_local_client.def_property_readonly("platform", &PyClient::platform_name)
|
||||
@ -562,7 +562,7 @@ PYBIND11_MODULE(xla_extension, m) {
|
||||
.def("buffer_from_pyval", &PyClient::BufferFromPyval, py::arg("argument"),
|
||||
py::arg("device") = nullptr, py::arg("force_copy") = false,
|
||||
py::arg("host_buffer_semantics") =
|
||||
PjRtBuffer::HostBufferSemantics::kZeroCopy)
|
||||
PjRtClient::HostBufferSemantics::kZeroCopy)
|
||||
.def("compile", &PyClient::Compile, py::arg("computation"),
|
||||
py::arg("compile_options") = CompileOptions())
|
||||
.def("heap_profile", &PyClient::HeapProfile);
|
||||
|
Loading…
Reference in New Issue
Block a user