[PJRT] Fix potential misuse of PjRtBuffer::FromHostBuffer.

Add a new `PjRtBuffer::HostBufferSemantics` enum that describes the possible contracts between caller and runtime.

* Change `FromHostBuffer(..., force_copy, ...)` to `FromHostBuffer(..., host_buffer_semantics, ...)`.

We were seeing data races on CPU between in-place modifications to a NumPy array and JAX computations reading the same memory, due to unintended buffer aliasing. This change allows clients to control whether they want zero-copy behavior or not.
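To make the new contract concrete, here is a minimal caller-side sketch (not part of this change; the helper name and surrounding setup are hypothetical, while the `FromHostBuffer` signature and enum values are the ones introduced in the diff below). The caller must keep `host_data` alive for as long as the chosen semantics require.

    #include <memory>
    #include <vector>

    #include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
    #include "tensorflow/compiler/xla/shape_util.h"

    // Hypothetical helper: upload a host vector to `device` owned by `client`.
    xla::StatusOr<std::unique_ptr<xla::PjRtBuffer>> UploadVector(
        const std::vector<float>& host_data, xla::PjRtClient* client,
        xla::Device* device) {
      xla::Shape shape = xla::ShapeUtil::MakeShape(
          xla::F32, {static_cast<xla::int64>(host_data.size())});
      // kImmutableUntilTransferCompletes: the runtime may keep reading
      // `host_data` until the device transfer finishes, but never aliases it
      // afterwards. Use kZeroCopy only if `host_data` outlives the returned
      // buffer and is never mutated; use kImmutableOnlyDuringCall if the
      // caller wants to reuse `host_data` as soon as this call returns.
      return xla::PjRtBuffer::FromHostBuffer(
          host_data.data(), shape,
          xla::PjRtBuffer::HostBufferSemantics::kImmutableUntilTransferCompletes,
          /*buffer_reference=*/nullptr, client, device);
    }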

PiperOrigin-RevId: 316672280
Change-Id: Ibee296305005e0aa306a2c0aacf4b35a3d6c3ac1
Author: Peter Hawkins (2020-06-16 06:56:05 -07:00), committed by TensorFlower Gardener
Commit: 572442eb16
Parent: d9532e6526
11 changed files with 195 additions and 67 deletions


@@ -59,6 +59,7 @@ StatusOr<std::shared_ptr<PjRtClient>> GetCpuClient(bool asynchronous) {
   return std::make_shared<PjRtClient>(
       kCpuPlatformName, client, std::move(devices), /*host_id=*/0,
       /*allocator=*/nullptr, /*host_memory_allocator=*/nullptr,
+      /*should_stage_host_to_device_transfers=*/false,
       /*gpu_run_options=*/nullptr);
 }


@@ -72,18 +72,21 @@ TEST(GpuMultiStream, Basics) {
   TF_ASSERT_OK_AND_ASSIGN(
       auto dummy_buffer,
       PjRtBuffer::FromHostBuffer(
-          dummy_inputs.data(), dummy_shape, /*force_copy=*/false,
+          dummy_inputs.data(), dummy_shape,
+          PjRtBuffer::HostBufferSemantics::kImmutableUntilTransferCompletes,
           /*buffer_reference=*/nullptr, client.get(), device));
   TF_ASSERT_OK_AND_ASSIGN(
       auto in_buffer0,
-      PjRtBuffer::FromHostBuffer(inputs.data(), shape, /*force_copy=*/false,
-                                 /*buffer_reference=*/nullptr, client.get(),
-                                 device));
+      PjRtBuffer::FromHostBuffer(
+          inputs.data(), shape,
+          PjRtBuffer::HostBufferSemantics::kImmutableUntilTransferCompletes,
+          /*buffer_reference=*/nullptr, client.get(), device));
   TF_ASSERT_OK_AND_ASSIGN(
       auto in_buffer1,
-      PjRtBuffer::FromHostBuffer(inputs.data(), shape, /*force_copy=*/false,
-                                 /*buffer_reference=*/nullptr, client.get(),
-                                 device));
+      PjRtBuffer::FromHostBuffer(
+          inputs.data(), shape,
+          PjRtBuffer::HostBufferSemantics::kImmutableUntilTransferCompletes,
+          /*buffer_reference=*/nullptr, client.get(), device));
   // The execution may be enqueued before the transfers complete, requiring
   // adequate device-side synchronization.
   ExecuteOptions options;


@@ -53,6 +53,7 @@ StatusOr<std::shared_ptr<PjRtClient>> GetInterpreterClient() {
   return std::make_shared<PjRtClient>(
       kInterpreterPlatformName, client, std::move(devices), /*host_id=*/0,
       /*allocator=*/nullptr, /*host_memory_allocator=*/nullptr,
+      /*should_stage_host_to_device_transfers=*/false,
       /*gpu_run_options=*/nullptr);
 }


@@ -316,6 +316,7 @@ StatusOr<std::shared_ptr<PjRtClient>> GetNvidiaGpuClient(
       "gpu", xla_client, std::move(devices),
       /*node_id=*/node_id, std::move(allocator),
       std::move(host_memory_allocator),
+      /*should_stage_host_to_device_transfers=*/true,
       /*gpu_run_options=*/std::move(gpu_run_options));
   return pyclient;
 }


@@ -95,6 +95,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/util.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
 #include "tensorflow/core/platform/errors.h"
+#include "tensorflow/core/platform/mem.h"
 #include "tensorflow/core/platform/status.h"
 #include "tensorflow/core/platform/types.h"
 #include "tensorflow/core/profiler/lib/traceme.h"
@@ -154,18 +155,35 @@ StatusOr<DeviceAssignment> DevicesToDeviceAssignment(
   return xla_assignment;
 }

+class CpuAllocator : public tensorflow::Allocator {
+ public:
+  CpuAllocator() = default;
+
+  string Name() override { return "cpu"; }
+
+  void* AllocateRaw(size_t alignment, size_t num_bytes) override {
+    return tensorflow::port::AlignedMalloc(num_bytes, alignment);
+  }
+  void DeallocateRaw(void* ptr) override {
+    return tensorflow::port::AlignedFree(ptr);
+  }
+};
+
 PjRtClient::PjRtClient(
     std::string platform_name, LocalClient* client,
     std::vector<std::unique_ptr<Device>> devices, int host_id,
     std::unique_ptr<se::DeviceMemoryAllocator> allocator,
     std::unique_ptr<tensorflow::Allocator> host_memory_allocator,
+    bool should_stage_host_to_device_transfers,
     std::unique_ptr<GpuExecutableRunOptions> gpu_run_options)
     : platform_name_(std::move(platform_name)),
       client_(client),
+      host_memory_allocator_(std::move(host_memory_allocator)),
       devices_(std::move(devices)),
       host_id_(host_id),
       owned_allocator_(std::move(allocator)),
-      host_memory_allocator_(std::move(host_memory_allocator)),
+      should_stage_host_to_device_transfers_(
+          should_stage_host_to_device_transfers),
       gpu_run_options_(std::move(gpu_run_options)),
       h2d_transfer_pool_(tensorflow::Env::Default(), "py_xla_h2d_transfer",
                          client->device_count()) {
@@ -175,6 +193,10 @@ PjRtClient::PjRtClient(
     allocator_ = client_->backend().memory_allocator();
   }

+  if (!host_memory_allocator_) {
+    host_memory_allocator_ = std::make_unique<CpuAllocator>();
+  }
+
   for (const std::unique_ptr<Device>& device : devices_) {
     CHECK(id_to_device_.insert({device->id(), device.get()}).second)
         << "Duplicate device id: " << device->id();
@@ -526,7 +548,8 @@ void PjRtBuffer::ScopedHold::AddToInput(

 /* static */
 StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::FromHostBuffer(
-    const void* data, const Shape& shape, bool force_copy,
+    const void* data, const Shape& shape,
+    HostBufferSemantics host_buffer_semantics,
     std::shared_ptr<void> buffer_reference, PjRtClient* client,
     Device* device) {
   tensorflow::profiler::TraceMe traceme("PjRtBuffer::FromHostBuffer");
@@ -537,34 +560,63 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::FromHostBuffer(
   }
   TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device,
                       device->GetLocalDeviceState());
+  int64 size = ShapeUtil::ByteSizeOf(shape);

-  // If we are on the host platform and the input buffer is sufficiently
-  // aligned, we can simply point to the input array's data without any further
-  // copies. At the time of writing we require a 16-byte alignment because XLA
-  // may generate code which requires it.
-  if (!force_copy &&
-      ((absl::bit_cast<std::uintptr_t>(data) &
-        (cpu_function_runtime::kMinAlign - 1)) == 0) &&
-      local_device->executor()->platform()->id() == se::host::kHostPlatformId) {
-    std::function<void()> on_delete_callback =
-        [buffer_reference{std::move(buffer_reference)}]() {
-          // Frees buffer_reference.
-        };
-    se::DeviceMemoryBase buffer(const_cast<void*>(data),
-                                ShapeUtil::ByteSizeOf(shape));
-    absl::Span<const std::shared_ptr<BufferSequencingEvent>> definition_events;
-    auto device_buffer = std::make_shared<TrackedDeviceBuffer>(
-        /*allocator=*/nullptr, local_device->device_ordinal(),
-        std::initializer_list<se::DeviceMemoryBase>{buffer}, definition_events,
-        std::move(on_delete_callback));
-    return absl::make_unique<PjRtBuffer>(shape, shape, std::move(device_buffer),
-                                         client, device);
-  }
-
   TransferManager* transfer_manager =
       client->client()->backend().transfer_manager();
   TF_ASSIGN_OR_RETURN(Shape compact_shape,
                       transfer_manager->ChooseCompactLayoutForShape(shape));
+
+  // The CPU platform is special because the "host" and the "device" are in the
+  // same memory space. If the input shape is in the correct layout and we don't
+  // want to defer the copy onto a thread, we can use the following fast
+  // path.
+  bool is_cpu_platform =
+      local_device->executor()->platform()->id() == se::host::kHostPlatformId;
+  if (is_cpu_platform) {
+    // If we are on the host platform and the input buffer is sufficiently
+    // aligned, we can simply point to the input array's data without any
+    // further copies. At the time of writing we require a 16-byte alignment
+    // because XLA may generate code which requires it.
+    bool can_use_zero_copy =
+        host_buffer_semantics == HostBufferSemantics::kZeroCopy &&
+        ((absl::bit_cast<std::uintptr_t>(data) &
+          (cpu_function_runtime::kMinAlign - 1)) == 0);
+    if (shape.layout() == compact_shape.layout() &&
+        (host_buffer_semantics ==
+             HostBufferSemantics::kImmutableOnlyDuringCall ||
+         can_use_zero_copy)) {
+      std::function<void()> on_delete_callback;
+      se::DeviceMemoryBase buffer;
+      // If we are on the host platform and the input buffer is sufficiently
+      // aligned, we can simply point to the input array's data without any
+      // further copies. At the time of writing we require a 16-byte alignment
+      // because XLA may generate code which requires it.
+      if (can_use_zero_copy) {
+        on_delete_callback = [buffer_reference{std::move(buffer_reference)}]() {
+          // Frees buffer_reference.
+        };
+        buffer = se::DeviceMemoryBase(const_cast<void*>(data), size);
+      } else {
+        void* staging_buffer = client->host_memory_allocator()->AllocateRaw(
+            cpu_function_runtime::kMinAlign, size);
+        on_delete_callback = [staging_buffer, client]() {
+          client->host_memory_allocator()->DeallocateRaw(staging_buffer);
+        };
+        buffer = se::DeviceMemoryBase(staging_buffer, size);
+        std::memcpy(staging_buffer, data, size);
+      }
+      absl::Span<const std::shared_ptr<BufferSequencingEvent>>
+          definition_events;
+      auto device_buffer = std::make_shared<TrackedDeviceBuffer>(
+          /*allocator=*/nullptr, local_device->device_ordinal(),
+          std::initializer_list<se::DeviceMemoryBase>{buffer},
+          definition_events, std::move(on_delete_callback));
+      return absl::make_unique<PjRtBuffer>(
+          shape, shape, std::move(device_buffer), client, device);
+    }
+  }

   TF_ASSIGN_OR_RETURN(
       std::unique_ptr<PjRtBuffer> py_buffer,
       AllocateDestinationBuffer(compact_shape, device, local_device,
@@ -573,17 +625,41 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::FromHostBuffer(
   ScopedHold device_buffer(py_buffer->GetBufferWithUsageHold());
   CHECK(device_buffer.ok());

+  // If necessary, allocate a host-side buffer for staging host-to-device
+  // transfers. On GPU this is a buffer in pinned memory.
+  std::shared_ptr<void> staging_buffer;
+  if (host_buffer_semantics == HostBufferSemantics::kImmutableOnlyDuringCall ||
+      client->should_stage_host_to_device_transfers()) {
+    void* ptr = client->host_memory_allocator()->AllocateRaw(
+        tensorflow::Allocator::kAllocatorAlignment, size);
+    staging_buffer = std::shared_ptr<void>(ptr, [client](void* ptr) {
+      client->host_memory_allocator()->DeallocateRaw(ptr);
+    });
+  }
+
+  // Copy the buffer into a staging buffer before returning control to the
+  // caller if the caller only guaranteed that the buffer is valid for the
+  // duration of the call. Otherwise, we stage (if necessary) on a separate
+  // thread.
+  if (host_buffer_semantics == HostBufferSemantics::kImmutableOnlyDuringCall) {
+    std::memcpy(staging_buffer.get(), data, size);
+    buffer_reference.reset();
+    data = nullptr;
+  }
+
   // The host to device transfer is performed on a thread pool, mostly because
   // it includes linearization that may be slow. It is OK to capture the
   // py_buffer pointer because the py_buffer can't be deleted until all the
   // usage holds have gone away.
   // TODO(misard) assess if it would be preferable to introduce a heuristic to
   // put the transfer into the calling thread for small literals.
-  auto transfer_h2d = [client, transfer_manager, local_device,
-                       movable_device_buffer{device_buffer.ToClosure()}, data,
-                       shape, py_buffer{py_buffer.get()}, compact_shape,
+  auto transfer_h2d = [client, transfer_manager, local_device, data, size,
+                       movable_device_buffer{device_buffer.ToClosure()}, shape,
+                       py_buffer{py_buffer.get()}, compact_shape,
                        on_device_shape{py_buffer->on_device_shape()},
-                       buffer_reference{std::move(buffer_reference)}]() {
+                       staging_buffer{std::move(staging_buffer)},
+                       buffer_reference{std::move(buffer_reference)},
+                       host_buffer_semantics]() {
     ScopedHold device_buffer(movable_device_buffer);
     // This function uses TF_CHECK_OK and ValueOrDie() since we have no way
     // to report failures from a callback. However, the operations here are
@@ -593,20 +669,16 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::FromHostBuffer(
     ShapedBuffer buffer = device_buffer->AsShapedBuffer(
         compact_shape, on_device_shape, client->client()->platform());

-    std::shared_ptr<void> staging_buffer;
-
     // If applicable on the backend, stage the transfer via host memory
     // allocated via the host_memory_allocator. On GPU, this is pinned
     // memory.
-    if (client->host_memory_allocator()) {
-      int64 size = ShapeUtil::ByteSizeOf(shape);
-      void* ptr = client->host_memory_allocator()->AllocateRaw(
-          tensorflow::Allocator::kAllocatorAlignment, size);
-      staging_buffer = std::shared_ptr<void>(ptr, [client](void* ptr) {
-        client->host_memory_allocator()->DeallocateRaw(ptr);
-      });
-      std::memcpy(ptr, data, size);
+    if (staging_buffer) {
+      // If we didn't already copy the input buffer into the staging buffer,
+      // do so now.
+      if (host_buffer_semantics !=
+          HostBufferSemantics::kImmutableOnlyDuringCall) {
+        std::memcpy(staging_buffer.get(), data, size);
+      }
       BorrowingLiteral literal(static_cast<const char*>(staging_buffer.get()),
                                shape);
       TF_CHECK_OK(transfer_manager->TransferLiteralToDeviceAsync(
@@ -626,9 +698,15 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::FromHostBuffer(
     local_device->ThenRelease(
         local_device->host_to_device_stream(),
-        std::make_pair(buffer_reference, std::move(staging_buffer)));
+        std::make_pair(std::move(buffer_reference), std::move(staging_buffer)));
   };
-  client->h2d_transfer_pool()->Schedule(transfer_h2d);
+  if (is_cpu_platform) {
+    // Using the h2d_transfer_pool would be a double thread hop; the code
+    // already defers its work onto a stream (= thread on CPU).
+    transfer_h2d();
+  } else {
+    client->h2d_transfer_pool()->Schedule(transfer_h2d);
+  }
   return py_buffer;
 }


@@ -128,6 +128,7 @@ class PjRtClient {
       std::vector<std::unique_ptr<Device>> devices, int host_id,
       std::unique_ptr<se::DeviceMemoryAllocator> allocator,
       std::unique_ptr<tensorflow::Allocator> host_memory_allocator,
+      bool should_stage_host_to_device_transfers,
       std::unique_ptr<GpuExecutableRunOptions> gpu_run_options);

   virtual ~PjRtClient() = default;
@@ -153,6 +154,9 @@ class PjRtClient {
   tensorflow::Allocator* host_memory_allocator() const {
     return host_memory_allocator_.get();
   }
+  bool should_stage_host_to_device_transfers() const {
+    return should_stage_host_to_device_transfers_;
+  }

   GpuExecutableRunOptions* gpu_run_options() const {
     return gpu_run_options_.get();
@@ -190,6 +194,9 @@ class PjRtClient {
   std::string platform_name_;
   LocalClient* client_;

+  // Allocator to be used for staging memory transfers to devices.
+  std::unique_ptr<tensorflow::Allocator> host_memory_allocator_;
+
   // Includes all devices, including non-local devices on multi-host platforms.
   std::vector<std::unique_ptr<Device>> devices_;
   // Maps Device::id() to the corresponding Device. Includes all devices.
@@ -201,10 +208,10 @@ class PjRtClient {
   se::DeviceMemoryAllocator* allocator_;
   std::unique_ptr<se::DeviceMemoryAllocator> owned_allocator_;

-  // Allocator to be used for staging memory transfers to devices. Optional;
-  // only used on GPU where it is more efficient to copy buffers to and from the
-  // device via a staging area of pinned memory.
-  std::unique_ptr<tensorflow::Allocator> host_memory_allocator_;
+  // Should we always prefer to stage host-to-device transfers via memory
+  // allocated on host_memory_allocator_? True only on GPU, where we prefer to
+  // transfer via pinned memory.
+  bool should_stage_host_to_device_transfers_;

   std::unique_ptr<GpuExecutableRunOptions> gpu_run_options_;
@@ -396,13 +403,35 @@ class PjRtBuffer {
     StatusOr<std::shared_ptr<TrackedDeviceBuffer>> buffer_or_;
   };

-  // If `force_copy` is true, forces a copy of the input buffer on CPU.
-  // Otherwise the library is free to alias the output buffer with `data`.
-  // `buffer_reference` is an optional shared pointer that should be kept alive
-  // by the runtime as long as the contents of `data` may still be accessed by
-  // the runtime (may be nullptr).
+  // Describes the semantics the caller to FromHostBuffer expects from the
+  // runtime, in a total order from most restrictive to least restrictive.
+  enum class HostBufferSemantics {
+    // The runtime may not hold references to `data` after the call to
+    // `FromHostBuffer` completes. The caller promises that `data` is immutable
+    // and will not be freed only for the duration of the FromHostBuffer call.
+    // `buffer_reference` will be freed by the time `FromHostBuffer` returns.
+    kImmutableOnlyDuringCall,
+
+    // The runtime may hold onto `data` after the call to `FromHostBuffer`
+    // returns while the runtime completes a transfer to the device. The caller
+    // promises not to mutate or free `data` until the transfer completes, at
+    // which point the runtime will release `buffer_reference`. It is also
+    // correct to wait on the host (directly or indirectly) for the buffer's
+    // definition event to complete.
+    kImmutableUntilTransferCompletes,
+
+    // The PjRtBuffer may alias `data` internally and the runtime may use the
+    // `data` contents as long as the buffer is alive.
+    // The caller promises to keep `data` alive and not to mutate its contents
+    // as long as the buffer is alive; to notify the caller that the buffer may
+    // be freed, the runtime will release its `buffer_reference` when the
+    // PjRtBuffer is freed. On non-CPU platforms this acts identically to
+    // kImmutableUntilTransferCompletes.
+    kZeroCopy,
+  };
+
   static StatusOr<std::unique_ptr<PjRtBuffer>> FromHostBuffer(
-      const void* data, const Shape& shape, bool force_copy,
+      const void* data, const Shape& shape,
+      HostBufferSemantics host_buffer_semantics,
       std::shared_ptr<void> buffer_reference, PjRtClient* client,
       Device* device);


@@ -84,7 +84,8 @@ PyClient::GetDefaultDeviceAssignment1D(int num_replicas) {
 }

 StatusOr<std::unique_ptr<PyBuffer>> PyClient::BufferFromPyal(
-    const pybind11::object& argument, Device* device, bool force_copy) {
+    const pybind11::object& argument, Device* device, bool force_copy,
+    PjRtBuffer::HostBufferSemantics host_buffer_semantics) {
   if (device == nullptr) {
     TF_RET_CHECK(!pjrt_client_->local_devices().empty());
     device = pjrt_client_->local_devices().front();
@@ -111,9 +112,9 @@ StatusOr<std::unique_ptr<PyBuffer>> PyClient::BufferFromPyal(
   {
     py::gil_scoped_release gil_release;
     TF_ASSIGN_OR_RETURN(
-        buffer, PjRtBuffer::FromHostBuffer(c->buf_ptr, c->shape, force_copy,
-                                           std::move(py_buffer_ref),
-                                           pjrt_client_.get(), device));
+        buffer, PjRtBuffer::FromHostBuffer(
+                    c->buf_ptr, c->shape, host_buffer_semantics,
+                    std::move(py_buffer_ref), pjrt_client_.get(), device));
   }
   auto traceback = Traceback::Get();
   return std::make_unique<PyBuffer>(shared_from_this(), std::move(buffer),


@@ -120,7 +120,8 @@ class PyClient : public std::enable_shared_from_this<PyClient> {
   }

   StatusOr<std::unique_ptr<PyBuffer>> BufferFromPyal(
-      const pybind11::object& argument, Device* device, bool force_copy);
+      const pybind11::object& argument, Device* device, bool force_copy,
+      PjRtBuffer::HostBufferSemantics host_buffer_semantics);

   StatusOr<std::unique_ptr<PyExecutable>> Compile(
       const XlaComputation& computation, CompileOptions options);


@@ -509,6 +509,13 @@ PYBIND11_MODULE(xla_extension, m) {
       .value("PLATFORM", GpuAllocatorConfig::Kind::kPlatform)
       .value("BFC", GpuAllocatorConfig::Kind::kBFC);

+  py::enum_<PjRtBuffer::HostBufferSemantics>(m, "HostBufferSemantics")
+      .value("IMMUTABLE_ONLY_DURING_CALL",
+             PjRtBuffer::HostBufferSemantics::kImmutableOnlyDuringCall)
+      .value("IMMUTABLE_UNTIL_TRANSFER_COMPLETES",
+             PjRtBuffer::HostBufferSemantics::kImmutableUntilTransferCompletes)
+      .value("ZERO_COPY", PjRtBuffer::HostBufferSemantics::kZeroCopy);
+
   py::class_<PyClient, std::shared_ptr<PyClient>> py_local_client(m, "Client");
   py_local_client.def_property_readonly("platform", &PyClient::platform_name)
       .def("device_count", &PyClient::device_count)
@@ -527,7 +534,9 @@ PYBIND11_MODULE(xla_extension, m) {
       .def("create_host_to_device_channel_handle",
           &PyClient::CreateHostToDeviceChannelHandle)
       .def("buffer_from_pyval", &PyClient::BufferFromPyal, py::arg("argument"),
-           py::arg("device") = nullptr, py::arg("force_copy") = false)
+           py::arg("device") = nullptr, py::arg("force_copy") = false,
+           py::arg("host_buffer_semantics") =
+               PjRtBuffer::HostBufferSemantics::kZeroCopy)
       .def("compile", &PyClient::Compile, py::arg("computation"),
           py::arg("compile_options") = CompileOptions())
       .def("heap_profile", &PyClient::HeapProfile);


@@ -304,6 +304,7 @@ def computation_count():
 Device = _xla.Device
 CompileOptions = _xla.CompileOptions
+HostBufferSemantics = _xla.HostBufferSemantics

 # An Executable is a C++ class that duck types with the following API:
 # class Executable(object):


@@ -1986,7 +1986,8 @@ def TestFactory(xla_backend, cloud_tpu=False):
     def testRoundTrip(self, dtype, shape):
       x = np.array(np.random.rand(*shape) * 100, dtype=dtype)
       x_ptr = x.__array_interface__["data"][0]
-      buffer = self.backend.buffer_from_pyval(x)
+      buffer = self.backend.buffer_from_pyval(
+          x, host_buffer_semantics=xla_client.HostBufferSemantics.ZERO_COPY)
       y = np.array(buffer, copy=False)
       y_ptr = y.__array_interface__["data"][0]
       np.testing.assert_array_equal(x, y)
@@ -1995,7 +1996,9 @@ def TestFactory(xla_backend, cloud_tpu=False):
       self.assertTrue((x_ptr & 15) != 0 or x_ptr == y_ptr)
       self.assertEqual(y_ptr, buffer.unsafe_buffer_pointer())

-      buffer2 = self.backend.buffer_from_pyval(x, force_copy=True)
+      during_call = xla_client.HostBufferSemantics.IMMUTABLE_ONLY_DURING_CALL
+      buffer2 = self.backend.buffer_from_pyval(
+          x, host_buffer_semantics=during_call)
       z = np.array(buffer2, copy=False)
       self.assertNotEqual(x.__array_interface__["data"][0],
                           z.__array_interface__["data"][0])