[PJRT] Fix potential misuse of PjRtBuffer::FromHostBuffer.

* Add a new `PjRtBuffer::HostBufferSemantics` enum that describes the possible
  contracts between caller and runtime.
* Change `FromHostBuffer(..., force_copy, ...)` to
  `FromHostBuffer(..., host_buffer_semantics, ...)`.

We were seeing some data races between modifications to a NumPy array and JAX
on CPU, due to unintended buffer aliasing. This change allows clients to
control whether or not they want zero-copy behavior.

PiperOrigin-RevId: 316672280
Change-Id: Ibee296305005e0aa306a2c0aacf4b35a3d6c3ac1
parent d9532e6526
commit 572442eb16
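For illustration, the call-site migration looks roughly like the following C++ sketch, modeled on the test updates in this change; `client`, `device`, `inputs`, and `shape` are assumed to be set up elsewhere (for example, as in the GpuMultiStream test below), so treat it as a sketch rather than exact code from the patch:

  // Before this change, callers could only request or suppress a copy:
  //   PjRtBuffer::FromHostBuffer(inputs.data(), shape, /*force_copy=*/false,
  //                              /*buffer_reference=*/nullptr, client.get(),
  //                              device);
  // After this change, callers state the intended contract explicitly:
  TF_ASSERT_OK_AND_ASSIGN(
      auto in_buffer,
      PjRtBuffer::FromHostBuffer(
          inputs.data(), shape,
          // The runtime may keep reading `inputs` until the host-to-device
          // transfer completes; the caller must not mutate or free it earlier.
          PjRtBuffer::HostBufferSemantics::kImmutableUntilTransferCompletes,
          /*buffer_reference=*/nullptr, client.get(), device));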
@@ -59,6 +59,7 @@ StatusOr<std::shared_ptr<PjRtClient>> GetCpuClient(bool asynchronous) {
   return std::make_shared<PjRtClient>(
       kCpuPlatformName, client, std::move(devices), /*host_id=*/0,
       /*allocator=*/nullptr, /*host_memory_allocator=*/nullptr,
+      /*should_stage_host_to_device_transfers=*/false,
       /*gpu_run_options=*/nullptr);
 }
 
@@ -72,18 +72,21 @@ TEST(GpuMultiStream, Basics) {
   TF_ASSERT_OK_AND_ASSIGN(
       auto dummy_buffer,
       PjRtBuffer::FromHostBuffer(
-          dummy_inputs.data(), dummy_shape, /*force_copy=*/false,
+          dummy_inputs.data(), dummy_shape,
+          PjRtBuffer::HostBufferSemantics::kImmutableUntilTransferCompletes,
           /*buffer_reference=*/nullptr, client.get(), device));
   TF_ASSERT_OK_AND_ASSIGN(
       auto in_buffer0,
-      PjRtBuffer::FromHostBuffer(inputs.data(), shape, /*force_copy=*/false,
-                                 /*buffer_reference=*/nullptr, client.get(),
-                                 device));
+      PjRtBuffer::FromHostBuffer(
+          inputs.data(), shape,
+          PjRtBuffer::HostBufferSemantics::kImmutableUntilTransferCompletes,
+          /*buffer_reference=*/nullptr, client.get(), device));
   TF_ASSERT_OK_AND_ASSIGN(
       auto in_buffer1,
-      PjRtBuffer::FromHostBuffer(inputs.data(), shape, /*force_copy=*/false,
-                                 /*buffer_reference=*/nullptr, client.get(),
-                                 device));
+      PjRtBuffer::FromHostBuffer(
+          inputs.data(), shape,
+          PjRtBuffer::HostBufferSemantics::kImmutableUntilTransferCompletes,
+          /*buffer_reference=*/nullptr, client.get(), device));
   // The execution may be enqueued before the transfers complete, requiring
   // adequate device-side synchronization.
   ExecuteOptions options;
@@ -53,6 +53,7 @@ StatusOr<std::shared_ptr<PjRtClient>> GetInterpreterClient() {
   return std::make_shared<PjRtClient>(
       kInterpreterPlatformName, client, std::move(devices), /*host_id=*/0,
       /*allocator=*/nullptr, /*host_memory_allocator=*/nullptr,
+      /*should_stage_host_to_device_transfers=*/false,
       /*gpu_run_options=*/nullptr);
 }
 
@@ -316,6 +316,7 @@ StatusOr<std::shared_ptr<PjRtClient>> GetNvidiaGpuClient(
       "gpu", xla_client, std::move(devices),
       /*node_id=*/node_id, std::move(allocator),
       std::move(host_memory_allocator),
+      /*should_stage_host_to_device_transfers=*/true,
       /*gpu_run_options=*/std::move(gpu_run_options));
   return pyclient;
 }
@@ -95,6 +95,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/util.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
 #include "tensorflow/core/platform/errors.h"
+#include "tensorflow/core/platform/mem.h"
 #include "tensorflow/core/platform/status.h"
 #include "tensorflow/core/platform/types.h"
 #include "tensorflow/core/profiler/lib/traceme.h"
@@ -154,18 +155,35 @@ StatusOr<DeviceAssignment> DevicesToDeviceAssignment(
   return xla_assignment;
 }
 
+class CpuAllocator : public tensorflow::Allocator {
+ public:
+  CpuAllocator() = default;
+
+  string Name() override { return "cpu"; }
+
+  void* AllocateRaw(size_t alignment, size_t num_bytes) override {
+    return tensorflow::port::AlignedMalloc(num_bytes, alignment);
+  }
+  void DeallocateRaw(void* ptr) override {
+    return tensorflow::port::AlignedFree(ptr);
+  }
+};
+
 PjRtClient::PjRtClient(
     std::string platform_name, LocalClient* client,
     std::vector<std::unique_ptr<Device>> devices, int host_id,
     std::unique_ptr<se::DeviceMemoryAllocator> allocator,
     std::unique_ptr<tensorflow::Allocator> host_memory_allocator,
+    bool should_stage_host_to_device_transfers,
    std::unique_ptr<GpuExecutableRunOptions> gpu_run_options)
     : platform_name_(std::move(platform_name)),
       client_(client),
+      host_memory_allocator_(std::move(host_memory_allocator)),
       devices_(std::move(devices)),
       host_id_(host_id),
       owned_allocator_(std::move(allocator)),
-      host_memory_allocator_(std::move(host_memory_allocator)),
+      should_stage_host_to_device_transfers_(
+          should_stage_host_to_device_transfers),
       gpu_run_options_(std::move(gpu_run_options)),
       h2d_transfer_pool_(tensorflow::Env::Default(), "py_xla_h2d_transfer",
                          client->device_count()) {
@@ -175,6 +193,10 @@ PjRtClient::PjRtClient(
     allocator_ = client_->backend().memory_allocator();
   }
 
+  if (!host_memory_allocator_) {
+    host_memory_allocator_ = std::make_unique<CpuAllocator>();
+  }
+
   for (const std::unique_ptr<Device>& device : devices_) {
     CHECK(id_to_device_.insert({device->id(), device.get()}).second)
         << "Duplicate device id: " << device->id();
@@ -526,7 +548,8 @@ void PjRtBuffer::ScopedHold::AddToInput(
 
 /* static */
 StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::FromHostBuffer(
-    const void* data, const Shape& shape, bool force_copy,
+    const void* data, const Shape& shape,
+    HostBufferSemantics host_buffer_semantics,
     std::shared_ptr<void> buffer_reference, PjRtClient* client,
     Device* device) {
   tensorflow::profiler::TraceMe traceme("PjRtBuffer::FromHostBuffer");
@@ -537,34 +560,63 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::FromHostBuffer(
   }
   TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device,
                       device->GetLocalDeviceState());
-  // If we are on the host platform and the input buffer is sufficiently
-  // aligned, we can simply point to the input array's data without any further
-  // copies. At the time of writing we require a 16-byte alignment because XLA
-  // may generate code which requires it.
-  if (!force_copy &&
-      ((absl::bit_cast<std::uintptr_t>(data) &
-        (cpu_function_runtime::kMinAlign - 1)) == 0) &&
-      local_device->executor()->platform()->id() == se::host::kHostPlatformId) {
-    std::function<void()> on_delete_callback =
-        [buffer_reference{std::move(buffer_reference)}]() {
-          // Frees buffer_reference.
-        };
-    se::DeviceMemoryBase buffer(const_cast<void*>(data),
-                                ShapeUtil::ByteSizeOf(shape));
-    absl::Span<const std::shared_ptr<BufferSequencingEvent>> definition_events;
-    auto device_buffer = std::make_shared<TrackedDeviceBuffer>(
-        /*allocator=*/nullptr, local_device->device_ordinal(),
-        std::initializer_list<se::DeviceMemoryBase>{buffer}, definition_events,
-        std::move(on_delete_callback));
-    return absl::make_unique<PjRtBuffer>(shape, shape, std::move(device_buffer),
-                                         client, device);
-  }
+  int64 size = ShapeUtil::ByteSizeOf(shape);
 
   TransferManager* transfer_manager =
       client->client()->backend().transfer_manager();
   TF_ASSIGN_OR_RETURN(Shape compact_shape,
                       transfer_manager->ChooseCompactLayoutForShape(shape));
 
+  // The CPU platform is special because the "host" and the "device" are in the
+  // same memory space. If the input shape is in the correct layout and we don't
+  // want to defer the copy onto a thread, we can use the following fast
+  // path.
+  bool is_cpu_platform =
+      local_device->executor()->platform()->id() == se::host::kHostPlatformId;
+  if (is_cpu_platform) {
+    // If we are on the host platform and the input buffer is sufficiently
+    // aligned, we can simply point to the input array's data without any
+    // further copies. At the time of writing we require a 16-byte alignment
+    // because XLA may generate code which requires it.
+    bool can_use_zero_copy =
+        host_buffer_semantics == HostBufferSemantics::kZeroCopy &&
+        ((absl::bit_cast<std::uintptr_t>(data) &
+          (cpu_function_runtime::kMinAlign - 1)) == 0);
+    if (shape.layout() == compact_shape.layout() &&
+        (host_buffer_semantics ==
+             HostBufferSemantics::kImmutableOnlyDuringCall ||
+         can_use_zero_copy)) {
+      std::function<void()> on_delete_callback;
+      se::DeviceMemoryBase buffer;
+      // If we are on the host platform and the input buffer is sufficiently
+      // aligned, we can simply point to the input array's data without any
+      // further copies. At the time of writing we require a 16-byte alignment
+      // because XLA may generate code which requires it.
+      if (can_use_zero_copy) {
+        on_delete_callback = [buffer_reference{std::move(buffer_reference)}]() {
+          // Frees buffer_reference.
+        };
+        buffer = se::DeviceMemoryBase(const_cast<void*>(data), size);
+      } else {
+        void* staging_buffer = client->host_memory_allocator()->AllocateRaw(
+            cpu_function_runtime::kMinAlign, size);
+        on_delete_callback = [staging_buffer, client]() {
+          client->host_memory_allocator()->DeallocateRaw(staging_buffer);
+        };
+        buffer = se::DeviceMemoryBase(staging_buffer, size);
+        std::memcpy(staging_buffer, data, size);
+      }
+      absl::Span<const std::shared_ptr<BufferSequencingEvent>>
+          definition_events;
+      auto device_buffer = std::make_shared<TrackedDeviceBuffer>(
+          /*allocator=*/nullptr, local_device->device_ordinal(),
+          std::initializer_list<se::DeviceMemoryBase>{buffer},
+          definition_events, std::move(on_delete_callback));
+      return absl::make_unique<PjRtBuffer>(
+          shape, shape, std::move(device_buffer), client, device);
+    }
+  }
+
   TF_ASSIGN_OR_RETURN(
       std::unique_ptr<PjRtBuffer> py_buffer,
       AllocateDestinationBuffer(compact_shape, device, local_device,
@@ -573,17 +625,41 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::FromHostBuffer(
   ScopedHold device_buffer(py_buffer->GetBufferWithUsageHold());
   CHECK(device_buffer.ok());
 
+  // If necessary, allocate a host-side buffer for staging host-to-device
+  // transfers. On GPU this is a buffer in pinned memory.
+  std::shared_ptr<void> staging_buffer;
+  if (host_buffer_semantics == HostBufferSemantics::kImmutableOnlyDuringCall ||
+      client->should_stage_host_to_device_transfers()) {
+    void* ptr = client->host_memory_allocator()->AllocateRaw(
+        tensorflow::Allocator::kAllocatorAlignment, size);
+    staging_buffer = std::shared_ptr<void>(ptr, [client](void* ptr) {
+      client->host_memory_allocator()->DeallocateRaw(ptr);
+    });
+  }
+
+  // Copy the buffer into a staging buffer before returning control to the
+  // caller if the caller only guaranteed that the buffer is valid for the
+  // duration of the call. Otherwise, we stage (if necessary) on a separate
+  // thread.
+  if (host_buffer_semantics == HostBufferSemantics::kImmutableOnlyDuringCall) {
+    std::memcpy(staging_buffer.get(), data, size);
+    buffer_reference.reset();
+    data = nullptr;
+  }
+
   // The host to device transfer is performed on a thread pool, mostly because
   // it includes linearization that may be slow. It is OK to capture the
   // py_buffer pointer because the py_buffer can't be deleted until all the
   // usage holds have gone away.
   // TODO(misard) assess if it would be preferable to introduce a heuristic to
   // put the transfer into the calling thread for small literals.
-  auto transfer_h2d = [client, transfer_manager, local_device,
-                       movable_device_buffer{device_buffer.ToClosure()}, data,
-                       shape, py_buffer{py_buffer.get()}, compact_shape,
+  auto transfer_h2d = [client, transfer_manager, local_device, data, size,
+                       movable_device_buffer{device_buffer.ToClosure()}, shape,
+                       py_buffer{py_buffer.get()}, compact_shape,
                        on_device_shape{py_buffer->on_device_shape()},
-                       buffer_reference{std::move(buffer_reference)}]() {
+                       staging_buffer{std::move(staging_buffer)},
+                       buffer_reference{std::move(buffer_reference)},
+                       host_buffer_semantics]() {
     ScopedHold device_buffer(movable_device_buffer);
     // This function uses TF_CHECK_OK and ValueOrDie() since we have no way
     // to report failures from a callback. However, the operations here are
@@ -593,20 +669,16 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::FromHostBuffer(
 
     ShapedBuffer buffer = device_buffer->AsShapedBuffer(
         compact_shape, on_device_shape, client->client()->platform());
-
-    std::shared_ptr<void> staging_buffer;
-
     // If applicable on the backend, stage the transfer via host memory
     // allocated via the host_memory_allocator. On GPU, this is pinned
     // memory.
-    if (client->host_memory_allocator()) {
-      int64 size = ShapeUtil::ByteSizeOf(shape);
-      void* ptr = client->host_memory_allocator()->AllocateRaw(
-          tensorflow::Allocator::kAllocatorAlignment, size);
-      staging_buffer = std::shared_ptr<void>(ptr, [client](void* ptr) {
-        client->host_memory_allocator()->DeallocateRaw(ptr);
-      });
-      std::memcpy(ptr, data, size);
+    if (staging_buffer) {
+      // If we didn't already copy the input buffer into the staging buffer,
+      // do so now.
+      if (host_buffer_semantics !=
+          HostBufferSemantics::kImmutableOnlyDuringCall) {
+        std::memcpy(staging_buffer.get(), data, size);
+      }
       BorrowingLiteral literal(static_cast<const char*>(staging_buffer.get()),
                                shape);
       TF_CHECK_OK(transfer_manager->TransferLiteralToDeviceAsync(
@@ -626,9 +698,15 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::FromHostBuffer(
 
     local_device->ThenRelease(
         local_device->host_to_device_stream(),
-        std::make_pair(buffer_reference, std::move(staging_buffer)));
+        std::make_pair(std::move(buffer_reference), std::move(staging_buffer)));
   };
+  if (is_cpu_platform) {
+    // Using the h2d_transfer_pool would be a double thread hop; the code
+    // already defers its work onto a stream (= thread on CPU).
+    transfer_h2d();
+  } else {
     client->h2d_transfer_pool()->Schedule(transfer_h2d);
+  }
   return py_buffer;
 }
 
@@ -128,6 +128,7 @@ class PjRtClient {
       std::vector<std::unique_ptr<Device>> devices, int host_id,
       std::unique_ptr<se::DeviceMemoryAllocator> allocator,
       std::unique_ptr<tensorflow::Allocator> host_memory_allocator,
+      bool should_stage_host_to_device_transfers,
       std::unique_ptr<GpuExecutableRunOptions> gpu_run_options);
   virtual ~PjRtClient() = default;
 
@@ -153,6 +154,9 @@ class PjRtClient {
   tensorflow::Allocator* host_memory_allocator() const {
     return host_memory_allocator_.get();
   }
+  bool should_stage_host_to_device_transfers() const {
+    return should_stage_host_to_device_transfers_;
+  }
 
   GpuExecutableRunOptions* gpu_run_options() const {
     return gpu_run_options_.get();
@@ -190,6 +194,9 @@ class PjRtClient {
   std::string platform_name_;
   LocalClient* client_;
 
+  // Allocator to be used for staging memory transfers to devices.
+  std::unique_ptr<tensorflow::Allocator> host_memory_allocator_;
+
   // Includes all devices, including non-local devices on multi-host platforms.
   std::vector<std::unique_ptr<Device>> devices_;
   // Maps Device::id() to the corresponding Device. Includes all devices.
@@ -201,10 +208,10 @@ class PjRtClient {
   se::DeviceMemoryAllocator* allocator_;
   std::unique_ptr<se::DeviceMemoryAllocator> owned_allocator_;
 
-  // Allocator to be used for staging memory transfers to devices. Optional;
-  // only used on GPU where it is more efficient to copy buffers to and from the
-  // device via a staging area of pinned memory.
-  std::unique_ptr<tensorflow::Allocator> host_memory_allocator_;
+  // Should we always prefer to stage host-to-device transfers via memory
+  // allocated on host_memory_allocator_? True only on GPU, where we prefer to
+  // transfer via pinned memory.
+  bool should_stage_host_to_device_transfers_;
 
   std::unique_ptr<GpuExecutableRunOptions> gpu_run_options_;
 
@@ -396,13 +403,35 @@ class PjRtBuffer {
     StatusOr<std::shared_ptr<TrackedDeviceBuffer>> buffer_or_;
   };
 
-  // If `force_copy` is true, forces a copy of the input buffer on CPU.
-  // Otherwise the library is free to alias the output buffer with `data`.
-  // `buffer_reference` is an optional shared pointer that should be kept alive
-  // by the runtime as long as the contents of `data` may still be accessed by
-  // the runtime (may be nullptr).
+  // Describes the semantics the caller to FromHostBuffer expects from the
+  // runtime, in a total order from most restrictive to least restrictive.
+  enum class HostBufferSemantics {
+    // The runtime may not hold references to `data` after the call to
+    // `FromHostBuffer` completes. The caller promises that `data` is immutable
+    // and will not be freed only for the duration of the FromHostBuffer call.
+    // `buffer_reference` will be freed by the time `FromHostBuffer` returns.
+    kImmutableOnlyDuringCall,
+
+    // The runtime may hold onto `data` after the call to `FromHostBuffer`
+    // returns while the runtime completes a transfer to the device. The caller
+    // promises not to mutate or free `data` until the transfer completes, at
+    // which point the runtime will release `buffer_reference`. It is also
+    // correct to wait on the host (directly or indirectly) for the buffer's
+    // definition event to complete.
+    kImmutableUntilTransferCompletes,
+
+    // The PjRtBuffer may alias `data` internally and the runtime may use the
+    // `data` contents as long as the buffer is alive.
+    // The caller promises to keep `data` alive and not to mutate its contents
+    // as long as the buffer is alive; to notify the caller that the buffer may
+    // be freed, the runtime will release its `buffer_reference` when the
+    // PjRtBuffer is freed. On non-CPU platforms this acts identically to
+    // kImmutableUntilTransferCompletes.
+    kZeroCopy,
+  };
   static StatusOr<std::unique_ptr<PjRtBuffer>> FromHostBuffer(
-      const void* data, const Shape& shape, bool force_copy,
+      const void* data, const Shape& shape,
+      HostBufferSemantics host_buffer_semantics,
       std::shared_ptr<void> buffer_reference, PjRtClient* client,
       Device* device);
 
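As a usage note for the enum above, here is a minimal C++ sketch of the zero-copy contract on CPU; it is not part of the patch, `client` and `device` are assumed to already exist, and `backing` is a hypothetical host array whose lifetime is tied to the buffer via `buffer_reference`:

  // Sketch only: kZeroCopy lets the PjRtBuffer alias the host data, so the
  // caller keeps it alive and unmodified while the buffer is alive. Passing
  // the owning shared_ptr as `buffer_reference` lets the runtime release it
  // (and thus the host allocation) when the PjRtBuffer is freed.
  auto backing = std::make_shared<std::vector<float>>(1024, 0.0f);
  Shape shape = ShapeUtil::MakeShape(F32, {1024});
  TF_ASSERT_OK_AND_ASSIGN(
      auto zero_copy_buffer,
      PjRtBuffer::FromHostBuffer(
          backing->data(), shape,
          PjRtBuffer::HostBufferSemantics::kZeroCopy,
          /*buffer_reference=*/backing, client.get(), device));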
@@ -84,7 +84,8 @@ PyClient::GetDefaultDeviceAssignment1D(int num_replicas) {
 }
 
 StatusOr<std::unique_ptr<PyBuffer>> PyClient::BufferFromPyal(
-    const pybind11::object& argument, Device* device, bool force_copy) {
+    const pybind11::object& argument, Device* device, bool force_copy,
+    PjRtBuffer::HostBufferSemantics host_buffer_semantics) {
   if (device == nullptr) {
     TF_RET_CHECK(!pjrt_client_->local_devices().empty());
     device = pjrt_client_->local_devices().front();
@@ -111,9 +112,9 @@ StatusOr<std::unique_ptr<PyBuffer>> PyClient::BufferFromPyal(
   {
     py::gil_scoped_release gil_release;
     TF_ASSIGN_OR_RETURN(
-        buffer, PjRtBuffer::FromHostBuffer(c->buf_ptr, c->shape, force_copy,
-                                           std::move(py_buffer_ref),
-                                           pjrt_client_.get(), device));
+        buffer, PjRtBuffer::FromHostBuffer(
+                    c->buf_ptr, c->shape, host_buffer_semantics,
+                    std::move(py_buffer_ref), pjrt_client_.get(), device));
   }
   auto traceback = Traceback::Get();
   return std::make_unique<PyBuffer>(shared_from_this(), std::move(buffer),
@@ -120,7 +120,8 @@ class PyClient : public std::enable_shared_from_this<PyClient> {
   }
 
   StatusOr<std::unique_ptr<PyBuffer>> BufferFromPyal(
-      const pybind11::object& argument, Device* device, bool force_copy);
+      const pybind11::object& argument, Device* device, bool force_copy,
+      PjRtBuffer::HostBufferSemantics host_buffer_semantics);
 
   StatusOr<std::unique_ptr<PyExecutable>> Compile(
       const XlaComputation& computation, CompileOptions options);
@@ -509,6 +509,13 @@ PYBIND11_MODULE(xla_extension, m) {
       .value("PLATFORM", GpuAllocatorConfig::Kind::kPlatform)
       .value("BFC", GpuAllocatorConfig::Kind::kBFC);
 
+  py::enum_<PjRtBuffer::HostBufferSemantics>(m, "HostBufferSemantics")
+      .value("IMMUTABLE_ONLY_DURING_CALL",
+             PjRtBuffer::HostBufferSemantics::kImmutableOnlyDuringCall)
+      .value("IMMUTABLE_UNTIL_TRANSFER_COMPLETES",
+             PjRtBuffer::HostBufferSemantics::kImmutableUntilTransferCompletes)
+      .value("ZERO_COPY", PjRtBuffer::HostBufferSemantics::kZeroCopy);
+
   py::class_<PyClient, std::shared_ptr<PyClient>> py_local_client(m, "Client");
   py_local_client.def_property_readonly("platform", &PyClient::platform_name)
       .def("device_count", &PyClient::device_count)
@@ -527,7 +534,9 @@ PYBIND11_MODULE(xla_extension, m) {
       .def("create_host_to_device_channel_handle",
           &PyClient::CreateHostToDeviceChannelHandle)
       .def("buffer_from_pyval", &PyClient::BufferFromPyal, py::arg("argument"),
-           py::arg("device") = nullptr, py::arg("force_copy") = false)
+           py::arg("device") = nullptr, py::arg("force_copy") = false,
+           py::arg("host_buffer_semantics") =
+               PjRtBuffer::HostBufferSemantics::kZeroCopy)
       .def("compile", &PyClient::Compile, py::arg("computation"),
            py::arg("compile_options") = CompileOptions())
       .def("heap_profile", &PyClient::HeapProfile);
@@ -304,6 +304,7 @@ def computation_count():
 Device = _xla.Device
 CompileOptions = _xla.CompileOptions
 
+HostBufferSemantics = _xla.HostBufferSemantics
 
 # An Executable is a C++ class that duck types with the following API:
 # class Executable(object):
@@ -1986,7 +1986,8 @@ def TestFactory(xla_backend, cloud_tpu=False):
     def testRoundTrip(self, dtype, shape):
       x = np.array(np.random.rand(*shape) * 100, dtype=dtype)
       x_ptr = x.__array_interface__["data"][0]
-      buffer = self.backend.buffer_from_pyval(x)
+      buffer = self.backend.buffer_from_pyval(
+          x, host_buffer_semantics=xla_client.HostBufferSemantics.ZERO_COPY)
       y = np.array(buffer, copy=False)
       y_ptr = y.__array_interface__["data"][0]
       np.testing.assert_array_equal(x, y)
@@ -1995,7 +1996,9 @@ def TestFactory(xla_backend, cloud_tpu=False):
       self.assertTrue((x_ptr & 15) != 0 or x_ptr == y_ptr)
       self.assertEqual(y_ptr, buffer.unsafe_buffer_pointer())
 
-      buffer2 = self.backend.buffer_from_pyval(x, force_copy=True)
+      during_call = xla_client.HostBufferSemantics.IMMUTABLE_ONLY_DURING_CALL
+      buffer2 = self.backend.buffer_from_pyval(
+          x, host_buffer_semantics=during_call)
       z = np.array(buffer2, copy=False)
       self.assertNotEqual(x.__array_interface__["data"][0],
                           z.__array_interface__["data"][0])