Refactor PJRT.

- Make the static methods of PjRtBuffer and PjRtExecutable instance methods on PjRtClient, so that a set of interfaces can later be extracted from PJRT (see the call-site sketch below the change summary).

PiperOrigin-RevId: 338101552
Change-Id: I8c10295948ea73d7d4157760a1cd8991384a01dc
Qiao Zhang 2020-10-20 11:39:31 -07:00 committed by TensorFlower Gardener
parent b737cff5fd
commit f187f93d7b
9 changed files with 182 additions and 192 deletions
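For orientation, here is a minimal before/after sketch of the call-site migration this change implies. The identifiers come from the diffs below; the wrapper function, its parameters, and the include path are illustrative assumptions, not part of the commit.

#include <memory>
#include <utility>

#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"  // assumed include path for this era of the TF tree

// Hypothetical helper contrasting the old static entry points (kept as
// comments) with the new instance methods on PjRtClient.
void MigrationSketch(xla::PjRtClient* client, xla::PjRtDevice* device,
                     const xla::XlaComputation& computation,
                     const void* data, const xla::Shape& shape,
                     xla::CompileOptions options) {
  // Before: static methods that took the client as an argument.
  //   PjRtExecutable::Compile(computation, client, std::move(options));
  //   PjRtBuffer::FromHostBuffer(data, shape,
  //       PjRtBuffer::HostBufferSemantics::kZeroCopy,
  //       /*buffer_reference=*/nullptr, client, device);

  // After: instance methods on PjRtClient, which is what makes it possible to
  // carve an abstract client interface out of PJRT later.
  xla::StatusOr<std::unique_ptr<xla::PjRtExecutable>> executable =
      client->Compile(computation, std::move(options));
  xla::StatusOr<std::unique_ptr<xla::PjRtBuffer>> buffer =
      client->BufferFromHostBuffer(
          data, shape, xla::PjRtClient::HostBufferSemantics::kZeroCopy,
          /*buffer_reference=*/nullptr, device);
  (void)executable;
  (void)buffer;
}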

View File

@@ -54,9 +54,9 @@ TEST(GpuMultiStream, Basics) {
device_assignment(0, 0) = device->id();
compile_options.executable_build_options.set_device_assignment(
device_assignment);
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<PjRtExecutable> executable,
PjRtExecutable::Compile(computation, client.get(),
std::move(compile_options)));
TF_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<PjRtExecutable> executable,
client->Compile(computation, std::move(compile_options)));
int64 dummy_size = 1 << 20;
std::vector<int32> dummy_inputs(dummy_size);
@@ -71,22 +71,22 @@ TEST(GpuMultiStream, Basics) {
// must wait.
TF_ASSERT_OK_AND_ASSIGN(
auto dummy_buffer,
PjRtBuffer::FromHostBuffer(
client->BufferFromHostBuffer(
dummy_inputs.data(), dummy_shape,
PjRtBuffer::HostBufferSemantics::kImmutableUntilTransferCompletes,
/*buffer_reference=*/nullptr, client.get(), device));
PjRtClient::HostBufferSemantics::kImmutableUntilTransferCompletes,
/*buffer_reference=*/nullptr, device));
TF_ASSERT_OK_AND_ASSIGN(
auto in_buffer0,
PjRtBuffer::FromHostBuffer(
client->BufferFromHostBuffer(
inputs.data(), shape,
PjRtBuffer::HostBufferSemantics::kImmutableUntilTransferCompletes,
/*buffer_reference=*/nullptr, client.get(), device));
PjRtClient::HostBufferSemantics::kImmutableUntilTransferCompletes,
/*buffer_reference=*/nullptr, device));
TF_ASSERT_OK_AND_ASSIGN(
auto in_buffer1,
PjRtBuffer::FromHostBuffer(
client->BufferFromHostBuffer(
inputs.data(), shape,
PjRtBuffer::HostBufferSemantics::kImmutableUntilTransferCompletes,
/*buffer_reference=*/nullptr, client.get(), device));
PjRtClient::HostBufferSemantics::kImmutableUntilTransferCompletes,
/*buffer_reference=*/nullptr, device));
// The execution may be enqueued before the transfers complete, requiring
// adequate device-side synchronization.
ExecuteOptions options;

View File

@@ -576,24 +576,21 @@ void PjRtBuffer::ScopedHold::AddToInput(
}
}
/* static */
StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::FromHostBuffer(
StatusOr<std::unique_ptr<PjRtBuffer>> PjRtClient::BufferFromHostBuffer(
const void* data, const Shape& shape,
HostBufferSemantics host_buffer_semantics,
std::shared_ptr<void> buffer_reference, PjRtClient* client,
PjRtDevice* device) {
tensorflow::profiler::TraceMe traceme("PjRtBuffer::FromHostBuffer");
VLOG(2) << "PjRtBuffer::FromHostBuffer: shape: " << shape.ToString()
std::shared_ptr<void> buffer_reference, PjRtDevice* device) {
tensorflow::profiler::TraceMe traceme("PjRtClient::BufferFromHostBuffer");
VLOG(2) << "PjRtClient::BufferFromHostBuffer: shape: " << shape.ToString()
<< " device: " << device->DebugString();
if (shape.IsTuple()) {
return InvalidArgument("Use FromHostLiteral to transfer a tuple");
return InvalidArgument("Use BufferFromHostLiteral to transfer a tuple");
}
TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device,
device->GetLocalDeviceState());
int64 size = ShapeUtil::ByteSizeOf(shape);
TransferManager* transfer_manager =
client->client()->backend().transfer_manager();
TransferManager* transfer_manager = client()->backend().transfer_manager();
TF_ASSIGN_OR_RETURN(Shape compact_shape,
transfer_manager->ChooseCompactLayoutForShape(shape));
@@ -628,10 +625,11 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::FromHostBuffer(
};
buffer = se::DeviceMemoryBase(const_cast<void*>(data), size);
} else {
void* staging_buffer = client->host_memory_allocator()->AllocateRaw(
void* staging_buffer = host_memory_allocator()->AllocateRaw(
cpu_function_runtime::kMinAlign, size);
on_delete_callback = [staging_buffer, client]() {
client->host_memory_allocator()->DeallocateRaw(staging_buffer);
on_delete_callback = [staging_buffer, host_memory_allocator =
host_memory_allocator()]() {
host_memory_allocator->DeallocateRaw(staging_buffer);
};
buffer = se::DeviceMemoryBase(staging_buffer, size);
std::memcpy(staging_buffer, data, size);
@@ -643,7 +641,7 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::FromHostBuffer(
std::initializer_list<se::DeviceMemoryBase>{buffer},
definition_events, std::move(on_delete_callback));
return absl::make_unique<PjRtBuffer>(
shape, shape, std::move(device_buffer), client, device);
shape, shape, std::move(device_buffer), this, device);
}
}
@@ -651,21 +649,22 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::FromHostBuffer(
std::unique_ptr<PjRtBuffer> py_buffer,
AllocateDestinationBuffer(compact_shape, device, local_device,
local_device->host_to_device_stream(),
/*is_uninitialized_create=*/false, client));
/*is_uninitialized_create=*/false, this));
ScopedHold device_buffer(py_buffer->GetBufferWithUsageHold());
PjRtBuffer::ScopedHold device_buffer(py_buffer->GetBufferWithUsageHold());
CHECK(device_buffer.ok());
// If necessary, allocate a host-side buffer for staging host-to-device
// transfers. On GPU this is a buffer in pinned memory.
std::shared_ptr<void> staging_buffer;
if (host_buffer_semantics == HostBufferSemantics::kImmutableOnlyDuringCall ||
client->should_stage_host_to_device_transfers()) {
void* ptr = client->host_memory_allocator()->AllocateRaw(
should_stage_host_to_device_transfers()) {
void* ptr = host_memory_allocator()->AllocateRaw(
tensorflow::Allocator::kAllocatorAlignment, size);
staging_buffer = std::shared_ptr<void>(ptr, [client](void* ptr) {
client->host_memory_allocator()->DeallocateRaw(ptr);
});
staging_buffer = std::shared_ptr<void>(
ptr, [host_memory_allocator = host_memory_allocator()](void* ptr) {
host_memory_allocator->DeallocateRaw(ptr);
});
}
// Copy the buffer into a staging buffer before returning control to the
@@ -684,14 +683,15 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::FromHostBuffer(
// usage holds have gone away.
// TODO(misard) assess if it would be preferable to introduce a heuristic to
// put the transfer into the calling thread for small literals.
auto transfer_h2d = [client, transfer_manager, local_device, data, size,
auto transfer_h2d = [local_client = client(), transfer_manager, local_device,
data, size,
movable_device_buffer{device_buffer.ToClosure()}, shape,
py_buffer{py_buffer.get()}, compact_shape,
on_device_shape{py_buffer->on_device_shape()},
staging_buffer{std::move(staging_buffer)},
buffer_reference{std::move(buffer_reference)},
host_buffer_semantics]() {
ScopedHold device_buffer(movable_device_buffer);
PjRtBuffer::ScopedHold device_buffer(movable_device_buffer);
// This function uses TF_CHECK_OK and ValueOrDie() since we have no way
// to report failures from a callback. However, the operations here are
// unlikely to fail and not recoverable even if we were to fail: DMAs to
@@ -699,7 +699,7 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::FromHostBuffer(
// allocation.
ShapedBuffer buffer = device_buffer->AsShapedBuffer(
compact_shape, on_device_shape, client->client()->platform());
compact_shape, on_device_shape, local_client->platform());
// If applicable on the backend, stage the transfer via host memory
// allocated via the host_memory_allocator. On GPU, this is pinned
// memory.
@@ -736,41 +736,38 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::FromHostBuffer(
// already defers its work onto a stream (= thread on CPU).
transfer_h2d();
} else {
client->h2d_transfer_pool()->Schedule(transfer_h2d);
h2d_transfer_pool()->Schedule(transfer_h2d);
}
return py_buffer;
}
/* static */
StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::CreateUninitialized(
const Shape& shape, PjRtClient* client, PjRtDevice* device) {
tensorflow::profiler::TraceMe traceme("PjRtBuffer::CreateUninitialized");
VLOG(2) << "PjRtBuffer::CreateUninitialized: shape: " << shape.ToString()
<< " device: " << device->DebugString();
StatusOr<std::unique_ptr<PjRtBuffer>> PjRtClient::CreateUninitializedBuffer(
const Shape& shape, PjRtDevice* device) {
tensorflow::profiler::TraceMe traceme(
"PjRtClient::CreateUninitializedBuffer");
VLOG(2) << "PjRtClient::CreateUninitializedBuffer: shape: "
<< shape.ToString() << " device: " << device->DebugString();
TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device,
device->GetLocalDeviceState());
TransferManager* transfer_manager =
client->client()->backend().transfer_manager();
TransferManager* transfer_manager = client()->backend().transfer_manager();
TF_ASSIGN_OR_RETURN(Shape compact_shape,
transfer_manager->ChooseCompactLayoutForShape(shape));
return AllocateDestinationBuffer(compact_shape, device, local_device,
/*copy_stream=*/nullptr,
/*is_uninitialized_create=*/true, client);
/*is_uninitialized_create=*/true, this);
}
/* static */
StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::FromHostLiteral(
const LiteralSlice& literal, PjRtClient* client, PjRtDevice* device) {
tensorflow::profiler::TraceMe traceme("PjRtBuffer::FromHostLiteral");
VLOG(2) << "PjRtBuffer::FromHostLiteral: shape: "
StatusOr<std::unique_ptr<PjRtBuffer>> PjRtClient::BufferFromHostLiteral(
const LiteralSlice& literal, PjRtDevice* device) {
tensorflow::profiler::TraceMe traceme("PjRtClient::BufferFromHostLiteral");
VLOG(2) << "PjRtClient::BufferFromHostLiteral: shape: "
<< literal.shape().ToString() << " device: " << device->DebugString();
TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device,
device->GetLocalDeviceState());
TransferManager* transfer_manager =
client->client()->backend().transfer_manager();
TransferManager* transfer_manager = client()->backend().transfer_manager();
TF_ASSIGN_OR_RETURN(
Shape compact_shape,
transfer_manager->ChooseCompactLayoutForShape(literal.shape()));
@@ -778,9 +775,9 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::FromHostLiteral(
std::unique_ptr<PjRtBuffer> py_buffer,
AllocateDestinationBuffer(compact_shape, device, local_device,
local_device->host_to_device_stream(),
/*is_uninitialized_create=*/false, client));
/*is_uninitialized_create=*/false, this));
ScopedHold device_buffer(py_buffer->GetBufferWithUsageHold());
PjRtBuffer::ScopedHold device_buffer(py_buffer->GetBufferWithUsageHold());
CHECK(device_buffer.ok());
// The host to device transfer is performed on a thread pool, mostly because
@@ -789,11 +786,11 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::FromHostLiteral(
// usage holds have gone away.
// TODO(misard) assess if it would be preferable to introduce a heuristic to
// put the transfer into the calling thread for small literals.
auto transfer_h2d = [client, transfer_manager, local_device,
auto transfer_h2d = [local_client = client(), transfer_manager, local_device,
movable_device_buffer{device_buffer.ToClosure()},
literal, py_buffer{py_buffer.get()}, compact_shape,
on_device_shape{py_buffer->on_device_shape()}]() {
ScopedHold device_buffer(movable_device_buffer);
PjRtBuffer::ScopedHold device_buffer(movable_device_buffer);
// This function uses TF_CHECK_OK and ValueOrDie() since we have no way
// to report failures from a callback. However, the operations here are
// unlikely to fail and not recoverable even if we were to fail: DMAs to
@@ -802,7 +799,7 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::FromHostLiteral(
se::Stream* h2d_stream = local_device->host_to_device_stream();
ShapedBuffer buffer = device_buffer->AsShapedBuffer(
compact_shape, on_device_shape, client->client()->platform());
compact_shape, on_device_shape, local_client->platform());
TF_CHECK_OK(transfer_manager->TransferLiteralToDeviceAsync(
h2d_stream, literal, buffer));
@@ -817,12 +814,12 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::FromHostLiteral(
.IgnoreError(); // Can return error::Unimplemented
QCHECK(h2d_stream->ok());
};
client->h2d_transfer_pool()->Schedule(transfer_h2d);
h2d_transfer_pool()->Schedule(transfer_h2d);
return py_buffer;
}
/*static*/ void PjRtBuffer::MakeCrossHostReceiveBuffers(
absl::Span<const Shape> shapes, PjRtClient* client, PjRtDevice* device,
void PjRtClient::MakeCrossHostReceiveBuffers(
absl::Span<const Shape> shapes, PjRtDevice* device,
PjRtCrossHostRecvNotifier&& notifier) {
if (shapes.empty()) {
notifier(InvalidArgument(
@@ -843,7 +840,7 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::FromHostLiteral(
StatusOr<std::unique_ptr<PjRtBuffer>> buffer_or =
AllocateDestinationBuffer(shape, device, local_device,
/*copy_stream=*/nullptr,
/*is_uninitialized_create=*/false, client);
/*is_uninitialized_create=*/false, this);
if (!buffer_or.ok()) {
notifier(buffer_or.status());
return;
@@ -851,7 +848,7 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::FromHostLiteral(
buffers.push_back(buffer_or.ConsumeValueOrDie());
}
client->EnqueueCrossHostReceive(std::move(buffers), std::move(notifier));
EnqueueCrossHostReceive(std::move(buffers), std::move(notifier));
}
PjRtBuffer::PjRtBuffer(Shape on_host_shape, Shape on_device_shape,
@@ -1159,7 +1156,7 @@ PjRtBuffer::CopyToHostAsyncInternal(bool discard_cached_copy,
StatusOr<std::shared_ptr<Literal>> PjRtBuffer::ToLiteral(
const bool discard_cached_copy, absl::optional<xla::Layout> layout) {
tensorflow::profiler::TraceMe traceme("PjRtBuffer::ToLiteral");
tensorflow::profiler::TraceMe traceme("PjRtClient::ToLiteral");
TF_ASSIGN_OR_RETURN(std::shared_ptr<HostValue> host_value,
CopyToHostAsyncInternal(discard_cached_copy, layout));
if (host_value == nullptr) {
@@ -1267,9 +1264,9 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::CopyToDevice(
// Copying across PjRtClients involves a copy through the host.
if (dst_device->client() != client_) {
TF_ASSIGN_OR_RETURN(std::shared_ptr<Literal> literal, ToLiteral());
return FromHostBuffer(literal->untyped_data(), literal->shape(),
HostBufferSemantics::kZeroCopy, nullptr,
dst_device->client(), dst_device);
return dst_device->client()->BufferFromHostBuffer(
literal->untyped_data(), literal->shape(),
PjRtClient::HostBufferSemantics::kZeroCopy, nullptr, dst_device);
}
TF_ASSIGN_OR_RETURN(LocalDeviceState * dst_local_device,
@@ -2061,14 +2058,13 @@ StatusOr<std::pair<std::vector<Shape>, Shape>> GetShardedProgramShapes(
} // namespace
/*static*/ StatusOr<std::unique_ptr<PjRtExecutable>> PjRtExecutable::Compile(
const XlaComputation& computation, PjRtClient* client,
CompileOptions options) {
tensorflow::profiler::TraceMe traceme("LocalExecutable::Compile");
StatusOr<std::unique_ptr<PjRtExecutable>> PjRtClient::Compile(
const XlaComputation& computation, CompileOptions options) {
tensorflow::profiler::TraceMe traceme("PjRtClient::Compile");
ExecutableBuildOptions& build_options = options.executable_build_options;
if (!build_options.device_allocator()) {
build_options.set_device_allocator(client->allocator());
build_options.set_device_allocator(allocator());
}
int num_replicas;
@@ -2084,14 +2080,14 @@ StatusOr<std::pair<std::vector<Shape>, Shape>> GetShardedProgramShapes(
num_partitions = 1;
} else {
if (!build_options.has_device_assignment()) {
VLOG(2) << "PjRtExecutable::Compile using default device_assignment.";
VLOG(2) << "PjRtClient::Compile using default device_assignment.";
TF_ASSIGN_OR_RETURN(
DeviceAssignment device_assignment,
client->GetDefaultDeviceAssignment(build_options.num_replicas(),
build_options.num_partitions()));
GetDefaultDeviceAssignment(build_options.num_replicas(),
build_options.num_partitions()));
build_options.set_device_assignment(device_assignment);
}
VLOG(2) << "PjRtExecutable::Compile device_assignment:\n"
VLOG(2) << "PjRtClient::Compile device_assignment:\n"
<< build_options.device_assignment().ToString();
num_replicas = build_options.device_assignment().replica_count();
num_partitions = build_options.device_assignment().computation_count();
@@ -2118,7 +2114,8 @@ StatusOr<std::pair<std::vector<Shape>, Shape>> GetShardedProgramShapes(
// Assign a default layout based on `sharded_shape` to any array subshapes in
// `dst_shape` that are missing layouts.
auto assign_layouts = [client](const Shape& sharded_shape, Shape* dst_shape) {
auto assign_layouts = [local_client = client()](const Shape& sharded_shape,
Shape* dst_shape) {
return ShapeUtil::ForEachMutableSubshapeWithStatus(
dst_shape, [&](Shape* subshape, const ShapeIndex& idx) {
if (subshape->IsArray() && !subshape->has_layout()) {
@@ -2126,8 +2123,7 @@ StatusOr<std::pair<std::vector<Shape>, Shape>> GetShardedProgramShapes(
const Shape& sharded_subshape =
ShapeUtil::GetSubshape(sharded_shape, idx);
LayoutUtil::SetToDefaultLayout(subshape);
TF_ASSIGN_OR_RETURN(Shape layout, client->client()
->backend()
TF_ASSIGN_OR_RETURN(Shape layout, local_client->backend()
.transfer_manager()
->ChooseCompactLayoutForShape(
sharded_subshape));
@@ -2162,8 +2158,8 @@ StatusOr<std::pair<std::vector<Shape>, Shape>> GetShardedProgramShapes(
for (int replica = 0; replica < num_replicas; ++replica) {
for (int partition = 0; partition < num_partitions; ++partition) {
int device_id = (*device_assignment)(replica, partition);
PjRtDevice* device = LookupDevice(*client, device_id);
if (device->host_id() != client->host_id()) {
PjRtDevice* device = LookupDevice(*this, device_id);
if (device->host_id() != host_id()) {
VLOG(3) << "Non-local device: " << device_id;
continue;
}
@@ -2185,15 +2181,14 @@ StatusOr<std::pair<std::vector<Shape>, Shape>> GetShardedProgramShapes(
TF_ASSIGN_OR_RETURN(
std::vector<std::unique_ptr<LocalExecutable>> local_executables,
client->client()->Compile(computation, argument_layout_pointers,
build_options));
client()->Compile(computation, argument_layout_pointers, build_options));
auto executable = absl::make_unique<PjRtExecutable>(
std::move(local_executables), options.parameter_is_tupled_arguments,
std::move(device_assignment), std::move(local_logical_device_ids),
std::move(local_devices), client);
std::move(local_devices), this);
TF_RETURN_IF_ERROR(
executable->SetUpDonation(client, options.parameter_is_tupled_arguments));
executable->SetUpDonation(this, options.parameter_is_tupled_arguments));
return executable;
}
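One detail worth calling out from the hunks above: now that the host-to-device staging path is a PjRtClient member, the cleanup lambdas capture the host memory allocator pointer directly instead of capturing the client. A standalone sketch of that shared_ptr-with-deleter pattern follows, assuming the allocator is a tensorflow::Allocator* as in the surrounding code; the helper name and sizing are illustrative.

#include <cstddef>
#include <memory>

#include "tensorflow/core/framework/allocator.h"

// Allocates staging memory and ties its lifetime to the returned shared_ptr.
std::shared_ptr<void> MakeStagingBuffer(tensorflow::Allocator* host_memory_allocator,
                                        size_t size) {
  void* ptr = host_memory_allocator->AllocateRaw(
      tensorflow::Allocator::kAllocatorAlignment, size);
  // Capturing the allocator pointer by value keeps the deleter independent of
  // the owning client object.
  return std::shared_ptr<void>(ptr,
                               [host_memory_allocator](void* p) {
                                 host_memory_allocator->DeallocateRaw(p);
                               });
}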

View File

@@ -120,6 +120,24 @@ struct PjRtCrossHostRecvBuffer {
using PjRtCrossHostRecvNotifier =
std::function<void(StatusOr<std::vector<PjRtCrossHostRecvBuffer>>&&)>;
struct CompileOptions {
// The layouts of the arguments that the computation should expect.
absl::optional<std::vector<Shape>> argument_layouts;
// If true, the supplied computation expects its arguments to be wrapped in a
// tuple and passed as a single parameter.
bool parameter_is_tupled_arguments = false;
// XLA's compilation time options.
ExecutableBuildOptions executable_build_options;
// If true, the executable can be run on any device. May only be true if
// !executable_build_options.has_device_assignment(), so only applies to
// single-device executables. Beware: on GPUs, sometimes an executable
// compiled for one device doesn't run on another.
bool compile_portable_executable = false;
};
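To make the struct above concrete, here is a small sketch of how a caller might populate it before handing it to PjRtClient::Compile; the helper function and the chosen values are illustrative assumptions, not code from the commit.

#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"  // assumed include path

// Builds options for a computation pinned to an explicit device assignment.
xla::CompileOptions MakeCompileOptions(
    const xla::DeviceAssignment& device_assignment) {
  xla::CompileOptions options;
  options.executable_build_options.set_device_assignment(device_assignment);
  // Arguments are passed individually rather than wrapped in a single tuple.
  options.parameter_is_tupled_arguments = false;
  return options;
}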
class PjRtExecutable;
// Encapsulates the state of Python session with XLA.
@@ -198,6 +216,63 @@ class PjRtClient {
// Returns a backend-specific HLO cost analysis visitor.
virtual std::unique_ptr<HloCostAnalysis> GetHloCostAnalysis();
virtual StatusOr<std::unique_ptr<PjRtExecutable>> Compile(
const XlaComputation& computation, CompileOptions options);
virtual StatusOr<std::unique_ptr<PjRtBuffer>> CreateUninitializedBuffer(
const Shape& shape, PjRtDevice* device);
// Describes the semantics the caller to BufferFromHostBuffer expects from the
// runtime, in a total order from most restrictive to least restrictive.
enum class HostBufferSemantics {
// The runtime may not hold references to `data` after the call to
// `BufferFromHostBuffer` completes. The caller promises that `data` is
// immutable and will not be freed only for the duration of the
// BufferFromHostBuffer call. `buffer_reference` will be freed by the time
// `BufferFromHostBuffer` returns.
kImmutableOnlyDuringCall,
// The runtime may hold onto `data` after the call to `BufferFromHostBuffer`
// returns while the runtime completes a transfer to the device. The caller
// promises not to mutate or free `data` until the transfer completes, at
// which point the runtime will release `buffer_reference`. It is also
// correct to wait on the host (directly or indirectly) for the buffer's
// definition event to complete.
kImmutableUntilTransferCompletes,
// The PjRtBuffer may alias `data` internally and the runtime may use the
// `data` contents as long as the buffer is alive. The caller promises to
// keep `data` alive and not to mutate its contents as long as the buffer is
// alive; to notify the caller that the buffer may be freed, the runtime
// will release its `buffer_reference` when the PjRtBuffer is freed. On
// non-CPU platforms this acts identically to
// kImmutableUntilTransferCompletes.
kZeroCopy,
};
virtual StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostBuffer(
const void* data, const Shape& shape,
HostBufferSemantics host_buffer_semantics,
std::shared_ptr<void> buffer_reference, PjRtDevice* device);
// Note that literal must remain in scope until the transfer has completed, so
// the caller should, for example, wait for BlockHostUntilReady() completes on
// the return value before letting literal go out of scope.
virtual StatusOr<std::unique_ptr<PjRtBuffer>> BufferFromHostLiteral(
const LiteralSlice& literal, PjRtDevice* device);
// Asynchronously makes a vector of PjRtBuffers that can be used to receive
// cross host transfers using `client` on `device'. `shapes` must be the exact
// shapes, with identical layouts, corresponding to the buffers that will be
// sent. When resources for the transfer are available, notifier will be
// called with a vector of PjRtCrossHostRecvBuffer structs, one for each
// shape in `shapes`. Each struct contains a buffer that will contain the
// received value, and an opaque string that should be transmitted to the
// sending host and used in a call to CopyToRemoteDevice. None of the recv
// buffers will become ready until *all* of the sends have completed.
virtual void MakeCrossHostReceiveBuffers(
absl::Span<const Shape> shapes, PjRtDevice* device,
PjRtCrossHostRecvNotifier&& notifier);
protected:
friend class PjRtBuffer;
virtual void EnqueueCrossHostReceive(
@@ -385,6 +460,7 @@ class PjRtBuffer {
private:
friend class PjRtBuffer;
friend class PjRtClient;
// Helper struct that makes it possible to move a ScopedHold through a
// closure.
@@ -423,62 +499,6 @@ class PjRtBuffer {
StatusOr<std::shared_ptr<TrackedDeviceBuffer>> buffer_or_;
};
// Returns a buffer with uninitialized contents.
static StatusOr<std::unique_ptr<PjRtBuffer>> CreateUninitialized(
const Shape& shape, PjRtClient* client, PjRtDevice* device);
// Describes the semantics the caller to FromHostBuffer expects from the
// runtime, in a total order from most restrictive to least restrictive.
enum class HostBufferSemantics {
// The runtime may not hold references to `data` after the call to
// `FromHostBuffer` completes. The caller promises that `data` is immutable
// and will not be freed only for the duration of the FromHostBuffer call.
// `buffer_reference` will be freed by the time `FromHostBuffer` returns.
kImmutableOnlyDuringCall,
// The runtime may hold onto `data` after the call to `FromHostBuffer`
// returns while the runtime completes a transfer to the device. The caller
// promises not to mutate or free `data` until the transfer completes, at
// which point the runtime will release `buffer_reference`. It is also
// correct to wait on the host (directly or indirectly) for the buffer's
// definition event to complete.
kImmutableUntilTransferCompletes,
// The PjRtBuffer may alias `data` internally and the runtime may use the
// `data` contents as long as the buffer is alive.
// The caller promises to keep `data` alive and not to mutate its contents
// as long as the buffer is alive; to notify the caller that the buffer may
// be freed, the runtime will release its `buffer_reference` when the
// PjRtBuffer is freed. On non-CPU platforms this acts identically to
// kImmutableUntilTransferCompletes.
kZeroCopy,
};
static StatusOr<std::unique_ptr<PjRtBuffer>> FromHostBuffer(
const void* data, const Shape& shape,
HostBufferSemantics host_buffer_semantics,
std::shared_ptr<void> buffer_reference, PjRtClient* client,
PjRtDevice* device);
// Note that literal must remain in scope until the transfer has completed, so
// the caller should, for example, wait for BlockHostUntilReady() completes on
// the return value before letting literal go out of scope.
static StatusOr<std::unique_ptr<PjRtBuffer>> FromHostLiteral(
const LiteralSlice& literal, PjRtClient* client, PjRtDevice* device);
// Asynchronously makes a vector of PjRtBuffers that can be used to receive
// cross host transfers using `client` on `device'. `shapes` must be the exact
// shapes, with identical layouts, corresponding to the buffers that will be
// sent. When resources for the transfer are available, notifier will be
// called with a vector of PjRtCrossHostRecvBuffer structs, one for each
// shape in `shapes`. Each struct contains a buffer that will contain the
// received value, and an opaque string that should be transmitted to the
// sending host and used in a call to CopyToRemoteDevice. None of the recv
// buffers will become ready until *all* of the sends have completed.
static void MakeCrossHostReceiveBuffers(absl::Span<const Shape> shapes,
PjRtClient* client,
PjRtDevice* device,
PjRtCrossHostRecvNotifier&& notifier);
PjRtBuffer(Shape on_host_shape, Shape on_device_shape,
std::shared_ptr<TrackedDeviceBuffer> device_buffer,
PjRtClient* client, PjRtDevice* device);
@@ -661,24 +681,6 @@ class PjRtBuffer {
Semaphore donation_semaphore_;
};
struct CompileOptions {
// The layouts of the arguments that the computation should expect.
absl::optional<std::vector<Shape>> argument_layouts;
// If true, the supplied computation expects its arguments to be wrapped in a
// tuple and passed as a single parameter.
bool parameter_is_tupled_arguments = false;
// XLA's compilation time options.
ExecutableBuildOptions executable_build_options;
// If true, the executable can be run on any device. May only be true if
// !executable_build_options.has_device_assignment(), so only applies to
// single-device executables. Beware: on GPUs, sometimes an executable
// compiled for one device doesn't run on another.
bool compile_portable_executable = false;
};
class ExecuteContext {
public:
virtual ~ExecuteContext() = default;
@@ -710,10 +712,6 @@ struct ExecuteOptions {
// buffer will be donated when passed to the execution.
class PjRtExecutable {
public:
static StatusOr<std::unique_ptr<PjRtExecutable>> Compile(
const XlaComputation& computation, PjRtClient* client,
CompileOptions options);
PjRtExecutable(std::vector<std::unique_ptr<LocalExecutable>> executables,
bool parameter_is_tupled_arguments,
std::shared_ptr<DeviceAssignment> device_assignment,
@@ -783,6 +781,7 @@ class PjRtExecutable {
}
private:
friend class PjRtClient;
// Initializes information about which arguments to which executables must be
// donated due to aliases that were specified by the computation.
Status SetUpDonation(PjRtClient* client, bool tuple_inputs);
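As a usage note for the header changes above: HostBufferSemantics now hangs off PjRtClient, so callers pick a semantics value when asking the client for a buffer. A hedged sketch follows, in which the helper, its shape construction, and the lack of error handling are illustrative rather than taken from the commit.

#include <memory>
#include <vector>

#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"  // assumed include path
#include "tensorflow/compiler/xla/shape_util.h"

// Uploads a host vector of floats to `device`. With
// kImmutableUntilTransferCompletes the caller keeps `host_data` alive and
// unmodified until the transfer finishes; no buffer_reference is needed here.
xla::StatusOr<std::unique_ptr<xla::PjRtBuffer>> UploadFloats(
    xla::PjRtClient* client, xla::PjRtDevice* device,
    const std::vector<float>& host_data) {
  xla::Shape shape = xla::ShapeUtil::MakeShape(
      xla::F32, {static_cast<xla::int64>(host_data.size())});
  return client->BufferFromHostBuffer(
      host_data.data(), shape,
      xla::PjRtClient::HostBufferSemantics::kImmutableUntilTransferCompletes,
      /*buffer_reference=*/nullptr, device);
}

With kZeroCopy the same call may alias the vector's storage, in which case the caller would have to keep host_data alive and unmodified for the whole lifetime of the returned buffer.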

View File

@@ -465,10 +465,10 @@ std::unique_ptr<xla::PjRtBuffer> ConvertToScalarBuffer(
xla::PjRtDevice* device) {
CppType data = py::cast<Pybind11Type>(scalar);
xla::Shape shape = xla::ShapeUtil::MakeShapeWithType<CppType>({});
return ValueOrThrow(xla::PjRtBuffer::FromHostBuffer(
return ValueOrThrow(client->BufferFromHostBuffer(
&data, shape,
xla::PjRtBuffer::HostBufferSemantics::kImmutableOnlyDuringCall, nullptr,
client, device));
xla::PjRtClient::HostBufferSemantics::kImmutableOnlyDuringCall, nullptr,
device));
}
// Convert a scalar to the associated PjRtBuffer or raises an error if it is
@@ -502,17 +502,17 @@ StatusOr<std::unique_ptr<xla::PjRtBuffer>> ScalarToBuffer(
if (jax_enable_x64) {
xla::complex128 data(result.real, result.imag);
xla::Shape shape = xla::ShapeUtil::MakeShapeWithType<xla::complex128>({});
return ValueOrThrow(xla::PjRtBuffer::FromHostBuffer(
return ValueOrThrow(client->BufferFromHostBuffer(
&data, shape,
xla::PjRtBuffer::HostBufferSemantics::kImmutableOnlyDuringCall,
nullptr, client, device));
xla::PjRtClient::HostBufferSemantics::kImmutableOnlyDuringCall,
nullptr, device));
} else {
xla::complex64 data(result.real, result.imag);
xla::Shape shape = xla::ShapeUtil::MakeShapeWithType<xla::complex64>({});
return ValueOrThrow(xla::PjRtBuffer::FromHostBuffer(
return ValueOrThrow(client->BufferFromHostBuffer(
&data, shape,
xla::PjRtBuffer::HostBufferSemantics::kImmutableOnlyDuringCall,
nullptr, client, device));
xla::PjRtClient::HostBufferSemantics::kImmutableOnlyDuringCall,
nullptr, device));
}
}
return InvalidArgument(
@@ -678,7 +678,7 @@ Status ConvertArgsToBuffers(bool jax_enable_x64, xla::PyClient& pyclient,
ValueOrThrow(pyclient.BufferFromPyval(
numpy_array, data_device,
/*force_copy=*/false, /*host_buffer_semantics=*/
xla::PjRtBuffer::HostBufferSemantics::kZeroCopy));
xla::PjRtClient::HostBufferSemantics::kZeroCopy));
arg_buffers.push_back(buffer->buffer());
ArgSignature sig;

View File

@@ -409,10 +409,9 @@ Status OutfeedReceiverImpl::SendShutdownOutfeedHeader(int device_idx) {
compile_options.executable_build_options.set_device_assignment(
device_assignment);
TF_ASSIGN_OR_RETURN(
std::unique_ptr<PjRtExecutable> executable,
PjRtExecutable::Compile(computation, devices_[device_idx]->client(),
std::move(compile_options)));
TF_ASSIGN_OR_RETURN(std::unique_ptr<PjRtExecutable> executable,
devices_[device_idx]->client()->Compile(
computation, std::move(compile_options)));
ExecuteOptions execute_options;
TF_ASSIGN_OR_RETURN(std::vector<std::unique_ptr<PjRtBuffer>> output_buffers,
executable->Execute({}, execute_options));

View File

@@ -40,9 +40,8 @@ Status CompileAndExecute(XlaBuilder* builder, XlaOp root, int device_id,
compile_options.executable_build_options.set_device_assignment(
device_assignment);
TF_ASSIGN_OR_RETURN(
std::unique_ptr<PjRtExecutable> executable,
PjRtExecutable::Compile(computation, client, std::move(compile_options)));
TF_ASSIGN_OR_RETURN(std::unique_ptr<PjRtExecutable> executable,
client->Compile(computation, std::move(compile_options)));
ExecuteOptions execute_options;
TF_ASSIGN_OR_RETURN(std::vector<std::unique_ptr<PjRtBuffer>> output_buffers,
executable->Execute({}, execute_options));

View File

@@ -89,7 +89,7 @@ PyClient::GetDefaultDeviceAssignment1D(int num_replicas) {
StatusOr<std::unique_ptr<PyBuffer>> PyClient::BufferFromPyval(
const pybind11::object& argument, PjRtDevice* device, bool force_copy,
PjRtBuffer::HostBufferSemantics host_buffer_semantics) {
PjRtClient::HostBufferSemantics host_buffer_semantics) {
if (device == nullptr) {
TF_RET_CHECK(!pjrt_client_->local_devices().empty());
device = pjrt_client_->local_devices().front();
@@ -114,10 +114,9 @@ StatusOr<std::unique_ptr<PyBuffer>> PyClient::BufferFromPyval(
std::unique_ptr<PjRtBuffer> buffer;
{
py::gil_scoped_release gil_release;
TF_ASSIGN_OR_RETURN(
buffer, PjRtBuffer::FromHostBuffer(
c->buf_ptr, c->shape, host_buffer_semantics,
std::move(py_buffer_ref), pjrt_client_.get(), device));
TF_ASSIGN_OR_RETURN(buffer, pjrt_client_->BufferFromHostBuffer(
c->buf_ptr, c->shape, host_buffer_semantics,
std::move(py_buffer_ref), device));
}
auto traceback = Traceback::Get();
return std::make_unique<PyBuffer>(shared_from_this(), std::move(buffer),
@@ -131,8 +130,7 @@ StatusOr<std::shared_ptr<PyExecutable>> PyClient::Compile(
{
py::gil_scoped_release gil_release;
TF_ASSIGN_OR_RETURN(executable,
PjRtExecutable::Compile(computation, pjrt_client_.get(),
std::move(options)));
pjrt_client_->Compile(computation, std::move(options)));
TF_ASSIGN_OR_RETURN(fingerprint,
pjrt_client_->ExecutableFingerprint(*executable));
}

View File

@@ -123,7 +123,7 @@ class PyClient : public std::enable_shared_from_this<PyClient> {
StatusOr<std::unique_ptr<PyBuffer>> BufferFromPyval(
const pybind11::object& argument, PjRtDevice* device, bool force_copy,
PjRtBuffer::HostBufferSemantics host_buffer_semantics);
PjRtClient::HostBufferSemantics host_buffer_semantics);
StatusOr<std::shared_ptr<PyExecutable>> Compile(
const XlaComputation& computation, CompileOptions options);

View File

@@ -535,12 +535,12 @@ PYBIND11_MODULE(xla_extension, m) {
.value("PLATFORM", GpuAllocatorConfig::Kind::kPlatform)
.value("BFC", GpuAllocatorConfig::Kind::kBFC);
py::enum_<PjRtBuffer::HostBufferSemantics>(m, "HostBufferSemantics")
py::enum_<PjRtClient::HostBufferSemantics>(m, "HostBufferSemantics")
.value("IMMUTABLE_ONLY_DURING_CALL",
PjRtBuffer::HostBufferSemantics::kImmutableOnlyDuringCall)
PjRtClient::HostBufferSemantics::kImmutableOnlyDuringCall)
.value("IMMUTABLE_UNTIL_TRANSFER_COMPLETES",
PjRtBuffer::HostBufferSemantics::kImmutableUntilTransferCompletes)
.value("ZERO_COPY", PjRtBuffer::HostBufferSemantics::kZeroCopy);
PjRtClient::HostBufferSemantics::kImmutableUntilTransferCompletes)
.value("ZERO_COPY", PjRtClient::HostBufferSemantics::kZeroCopy);
py::class_<PyClient, std::shared_ptr<PyClient>> py_local_client(m, "Client");
py_local_client.def_property_readonly("platform", &PyClient::platform_name)
@@ -562,7 +562,7 @@ PYBIND11_MODULE(xla_extension, m) {
.def("buffer_from_pyval", &PyClient::BufferFromPyval, py::arg("argument"),
py::arg("device") = nullptr, py::arg("force_copy") = false,
py::arg("host_buffer_semantics") =
PjRtBuffer::HostBufferSemantics::kZeroCopy)
PjRtClient::HostBufferSemantics::kZeroCopy)
.def("compile", &PyClient::Compile, py::arg("computation"),
py::arg("compile_options") = CompileOptions())
.def("heap_profile", &PyClient::HeapProfile);