[PJRT] Fix potential misuse of PjRtBuffer::FromHostBuffer.

* Add a new `PjRtBuffer::HostBufferSemantics` enum that describes the possible contracts between caller and runtime.
* Change `FromHostBuffer(..., force_copy, ...)` to `FromHostBuffer(..., host_buffer_semantics, ...)`.

We were seeing data races between modifications to a NumPy array and JAX on CPU, due to unintended buffer aliasing. This change allows clients to control whether they want zero-copy behavior or not.

PiperOrigin-RevId: 316672280
Change-Id: Ibee296305005e0aa306a2c0aacf4b35a3d6c3ac1
parent d9532e6526
commit 572442eb16
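
Not part of the commit, for illustration only: a minimal sketch of how a caller might pick one of the new semantics with the changed FromHostBuffer signature. The wrapper function UploadVector and the include paths are assumptions based on the 2020-era tree layout shown in this diff, not code from this change.

// Illustrative sketch only; UploadVector is a hypothetical helper.
#include <memory>
#include <vector>

#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"

namespace xla {

// Copies (or stages) `host_data` onto `device`. The caller promises not to
// mutate or free `host_data` until the transfer completes, which is what
// kImmutableUntilTransferCompletes requires; kZeroCopy would avoid the copy
// on CPU but extends that promise to the whole lifetime of the buffer.
StatusOr<std::unique_ptr<PjRtBuffer>> UploadVector(
    const std::vector<float>& host_data, PjRtClient* client, Device* device) {
  Shape shape =
      ShapeUtil::MakeShape(F32, {static_cast<int64>(host_data.size())});
  return PjRtBuffer::FromHostBuffer(
      host_data.data(), shape,
      PjRtBuffer::HostBufferSemantics::kImmutableUntilTransferCompletes,
      /*buffer_reference=*/nullptr, client, device);
}

}  // namespace xla

kImmutableOnlyDuringCall is the most conservative choice (the runtime copies before returning) and kZeroCopy the least; on non-CPU platforms kZeroCopy behaves like kImmutableUntilTransferCompletes, per the enum comments in the header diff below.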
@@ -59,6 +59,7 @@ StatusOr<std::shared_ptr<PjRtClient>> GetCpuClient(bool asynchronous) {
   return std::make_shared<PjRtClient>(
       kCpuPlatformName, client, std::move(devices), /*host_id=*/0,
       /*allocator=*/nullptr, /*host_memory_allocator=*/nullptr,
+      /*should_stage_host_to_device_transfers=*/false,
       /*gpu_run_options=*/nullptr);
 }
 
@@ -72,18 +72,21 @@ TEST(GpuMultiStream, Basics) {
   TF_ASSERT_OK_AND_ASSIGN(
       auto dummy_buffer,
       PjRtBuffer::FromHostBuffer(
-          dummy_inputs.data(), dummy_shape, /*force_copy=*/false,
+          dummy_inputs.data(), dummy_shape,
+          PjRtBuffer::HostBufferSemantics::kImmutableUntilTransferCompletes,
          /*buffer_reference=*/nullptr, client.get(), device));
   TF_ASSERT_OK_AND_ASSIGN(
       auto in_buffer0,
-      PjRtBuffer::FromHostBuffer(inputs.data(), shape, /*force_copy=*/false,
-                                 /*buffer_reference=*/nullptr, client.get(),
-                                 device));
+      PjRtBuffer::FromHostBuffer(
+          inputs.data(), shape,
+          PjRtBuffer::HostBufferSemantics::kImmutableUntilTransferCompletes,
+          /*buffer_reference=*/nullptr, client.get(), device));
   TF_ASSERT_OK_AND_ASSIGN(
       auto in_buffer1,
-      PjRtBuffer::FromHostBuffer(inputs.data(), shape, /*force_copy=*/false,
-                                 /*buffer_reference=*/nullptr, client.get(),
-                                 device));
+      PjRtBuffer::FromHostBuffer(
+          inputs.data(), shape,
+          PjRtBuffer::HostBufferSemantics::kImmutableUntilTransferCompletes,
+          /*buffer_reference=*/nullptr, client.get(), device));
   // The execution may be enqueued before the transfers complete, requiring
   // adequate device-side synchronization.
   ExecuteOptions options;
@@ -53,6 +53,7 @@ StatusOr<std::shared_ptr<PjRtClient>> GetInterpreterClient() {
   return std::make_shared<PjRtClient>(
       kInterpreterPlatformName, client, std::move(devices), /*host_id=*/0,
       /*allocator=*/nullptr, /*host_memory_allocator=*/nullptr,
+      /*should_stage_host_to_device_transfers=*/false,
       /*gpu_run_options=*/nullptr);
 }
 
@@ -316,6 +316,7 @@ StatusOr<std::shared_ptr<PjRtClient>> GetNvidiaGpuClient(
       "gpu", xla_client, std::move(devices),
       /*node_id=*/node_id, std::move(allocator),
       std::move(host_memory_allocator),
+      /*should_stage_host_to_device_transfers=*/true,
       /*gpu_run_options=*/std::move(gpu_run_options));
   return pyclient;
 }
@@ -95,6 +95,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/util.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
 #include "tensorflow/core/platform/errors.h"
+#include "tensorflow/core/platform/mem.h"
 #include "tensorflow/core/platform/status.h"
 #include "tensorflow/core/platform/types.h"
 #include "tensorflow/core/profiler/lib/traceme.h"
@@ -154,18 +155,35 @@ StatusOr<DeviceAssignment> DevicesToDeviceAssignment(
   return xla_assignment;
 }
 
+class CpuAllocator : public tensorflow::Allocator {
+ public:
+  CpuAllocator() = default;
+
+  string Name() override { return "cpu"; }
+
+  void* AllocateRaw(size_t alignment, size_t num_bytes) override {
+    return tensorflow::port::AlignedMalloc(num_bytes, alignment);
+  }
+  void DeallocateRaw(void* ptr) override {
+    return tensorflow::port::AlignedFree(ptr);
+  }
+};
+
 PjRtClient::PjRtClient(
     std::string platform_name, LocalClient* client,
     std::vector<std::unique_ptr<Device>> devices, int host_id,
     std::unique_ptr<se::DeviceMemoryAllocator> allocator,
     std::unique_ptr<tensorflow::Allocator> host_memory_allocator,
+    bool should_stage_host_to_device_transfers,
     std::unique_ptr<GpuExecutableRunOptions> gpu_run_options)
     : platform_name_(std::move(platform_name)),
       client_(client),
-      host_memory_allocator_(std::move(host_memory_allocator)),
      devices_(std::move(devices)),
       host_id_(host_id),
       owned_allocator_(std::move(allocator)),
+      host_memory_allocator_(std::move(host_memory_allocator)),
+      should_stage_host_to_device_transfers_(
+          should_stage_host_to_device_transfers),
       gpu_run_options_(std::move(gpu_run_options)),
       h2d_transfer_pool_(tensorflow::Env::Default(), "py_xla_h2d_transfer",
                          client->device_count()) {
@@ -175,6 +193,10 @@ PjRtClient::PjRtClient(
     allocator_ = client_->backend().memory_allocator();
   }
 
+  if (!host_memory_allocator_) {
+    host_memory_allocator_ = std::make_unique<CpuAllocator>();
+  }
+
   for (const std::unique_ptr<Device>& device : devices_) {
     CHECK(id_to_device_.insert({device->id(), device.get()}).second)
         << "Duplicate device id: " << device->id();
@@ -526,7 +548,8 @@ void PjRtBuffer::ScopedHold::AddToInput(
 
 /* static */
 StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::FromHostBuffer(
-    const void* data, const Shape& shape, bool force_copy,
+    const void* data, const Shape& shape,
+    HostBufferSemantics host_buffer_semantics,
     std::shared_ptr<void> buffer_reference, PjRtClient* client,
     Device* device) {
   tensorflow::profiler::TraceMe traceme("PjRtBuffer::FromHostBuffer");
@@ -537,34 +560,63 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::FromHostBuffer(
   }
   TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device,
                       device->GetLocalDeviceState());
 
-  // If we are on the host platform and the input buffer is sufficiently
-  // aligned, we can simply point to the input array's data without any further
-  // copies. At the time of writing we require a 16-byte alignment because XLA
-  // may generate code which requires it.
-  if (!force_copy &&
-      ((absl::bit_cast<std::uintptr_t>(data) &
-        (cpu_function_runtime::kMinAlign - 1)) == 0) &&
-      local_device->executor()->platform()->id() == se::host::kHostPlatformId) {
-    std::function<void()> on_delete_callback =
-        [buffer_reference{std::move(buffer_reference)}]() {
-          // Frees buffer_reference.
-        };
-    se::DeviceMemoryBase buffer(const_cast<void*>(data),
-                                ShapeUtil::ByteSizeOf(shape));
-    absl::Span<const std::shared_ptr<BufferSequencingEvent>> definition_events;
-    auto device_buffer = std::make_shared<TrackedDeviceBuffer>(
-        /*allocator=*/nullptr, local_device->device_ordinal(),
-        std::initializer_list<se::DeviceMemoryBase>{buffer}, definition_events,
-        std::move(on_delete_callback));
-    return absl::make_unique<PjRtBuffer>(shape, shape, std::move(device_buffer),
-                                         client, device);
-  }
+  int64 size = ShapeUtil::ByteSizeOf(shape);
 
   TransferManager* transfer_manager =
       client->client()->backend().transfer_manager();
   TF_ASSIGN_OR_RETURN(Shape compact_shape,
                       transfer_manager->ChooseCompactLayoutForShape(shape));
+
+  // The CPU platform is special because the "host" and the "device" are in the
+  // same memory space. If the input shape is in the correct layout and we don't
+  // want to defer the copy onto a thread, we can use the following fast
+  // path.
+  bool is_cpu_platform =
+      local_device->executor()->platform()->id() == se::host::kHostPlatformId;
+  if (is_cpu_platform) {
+    // If we are on the host platform and the input buffer is sufficiently
+    // aligned, we can simply point to the input array's data without any
+    // further copies. At the time of writing we require a 16-byte alignment
+    // because XLA may generate code which requires it.
+    bool can_use_zero_copy =
+        host_buffer_semantics == HostBufferSemantics::kZeroCopy &&
+        ((absl::bit_cast<std::uintptr_t>(data) &
+          (cpu_function_runtime::kMinAlign - 1)) == 0);
+    if (shape.layout() == compact_shape.layout() &&
+        (host_buffer_semantics ==
+             HostBufferSemantics::kImmutableOnlyDuringCall ||
+         can_use_zero_copy)) {
+      std::function<void()> on_delete_callback;
+      se::DeviceMemoryBase buffer;
+      // If we are on the host platform and the input buffer is sufficiently
+      // aligned, we can simply point to the input array's data without any
+      // further copies. At the time of writing we require a 16-byte alignment
+      // because XLA may generate code which requires it.
+      if (can_use_zero_copy) {
+        on_delete_callback = [buffer_reference{std::move(buffer_reference)}]() {
+          // Frees buffer_reference.
+        };
+        buffer = se::DeviceMemoryBase(const_cast<void*>(data), size);
+      } else {
+        void* staging_buffer = client->host_memory_allocator()->AllocateRaw(
+            cpu_function_runtime::kMinAlign, size);
+        on_delete_callback = [staging_buffer, client]() {
+          client->host_memory_allocator()->DeallocateRaw(staging_buffer);
+        };
+        buffer = se::DeviceMemoryBase(staging_buffer, size);
+        std::memcpy(staging_buffer, data, size);
+      }
+      absl::Span<const std::shared_ptr<BufferSequencingEvent>>
+          definition_events;
+      auto device_buffer = std::make_shared<TrackedDeviceBuffer>(
+          /*allocator=*/nullptr, local_device->device_ordinal(),
+          std::initializer_list<se::DeviceMemoryBase>{buffer},
+          definition_events, std::move(on_delete_callback));
+      return absl::make_unique<PjRtBuffer>(
+          shape, shape, std::move(device_buffer), client, device);
+    }
+  }
   TF_ASSIGN_OR_RETURN(
       std::unique_ptr<PjRtBuffer> py_buffer,
       AllocateDestinationBuffer(compact_shape, device, local_device,
@@ -573,17 +625,41 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::FromHostBuffer(
   ScopedHold device_buffer(py_buffer->GetBufferWithUsageHold());
   CHECK(device_buffer.ok());
 
+  // If necessary, allocate a host-side buffer for staging host-to-device
+  // transfers. On GPU this is a buffer in pinned memory.
+  std::shared_ptr<void> staging_buffer;
+  if (host_buffer_semantics == HostBufferSemantics::kImmutableOnlyDuringCall ||
+      client->should_stage_host_to_device_transfers()) {
+    void* ptr = client->host_memory_allocator()->AllocateRaw(
+        tensorflow::Allocator::kAllocatorAlignment, size);
+    staging_buffer = std::shared_ptr<void>(ptr, [client](void* ptr) {
+      client->host_memory_allocator()->DeallocateRaw(ptr);
+    });
+  }
+
+  // Copy the buffer into a staging buffer before returning control to the
+  // caller if the caller only guaranteed that the buffer is valid for the
+  // duration of the call. Otherwise, we stage (if necessary) on a separate
+  // thread.
+  if (host_buffer_semantics == HostBufferSemantics::kImmutableOnlyDuringCall) {
+    std::memcpy(staging_buffer.get(), data, size);
+    buffer_reference.reset();
+    data = nullptr;
+  }
+
   // The host to device transfer is performed on a thread pool, mostly because
   // it includes linearization that may be slow. It is OK to capture the
   // py_buffer pointer because the py_buffer can't be deleted until all the
   // usage holds have gone away.
   // TODO(misard) assess if it would be preferable to introduce a heuristic to
   // put the transfer into the calling thread for small literals.
-  auto transfer_h2d = [client, transfer_manager, local_device,
-                       movable_device_buffer{device_buffer.ToClosure()}, data,
-                       shape, py_buffer{py_buffer.get()}, compact_shape,
+  auto transfer_h2d = [client, transfer_manager, local_device, data, size,
+                       movable_device_buffer{device_buffer.ToClosure()}, shape,
+                       py_buffer{py_buffer.get()}, compact_shape,
                        on_device_shape{py_buffer->on_device_shape()},
-                       buffer_reference{std::move(buffer_reference)}]() {
+                       staging_buffer{std::move(staging_buffer)},
+                       buffer_reference{std::move(buffer_reference)},
+                       host_buffer_semantics]() {
     ScopedHold device_buffer(movable_device_buffer);
     // This function uses TF_CHECK_OK and ValueOrDie() since we have no way
    // to report failures from a callback. However, the operations here are
@@ -593,20 +669,16 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::FromHostBuffer(
 
     ShapedBuffer buffer = device_buffer->AsShapedBuffer(
         compact_shape, on_device_shape, client->client()->platform());
 
-    std::shared_ptr<void> staging_buffer;
-
-    // If applicable on the backend, stage the transfer via host memory
-    // allocated via the host_memory_allocator. On GPU, this is pinned
-    // memory.
-    if (client->host_memory_allocator()) {
-      int64 size = ShapeUtil::ByteSizeOf(shape);
-      void* ptr = client->host_memory_allocator()->AllocateRaw(
-          tensorflow::Allocator::kAllocatorAlignment, size);
-      staging_buffer = std::shared_ptr<void>(ptr, [client](void* ptr) {
-        client->host_memory_allocator()->DeallocateRaw(ptr);
-      });
-      std::memcpy(ptr, data, size);
+    if (staging_buffer) {
+      // If we didn't already copy the input buffer into the staging buffer,
+      // do so now.
+      if (host_buffer_semantics !=
+          HostBufferSemantics::kImmutableOnlyDuringCall) {
+        std::memcpy(staging_buffer.get(), data, size);
+      }
       BorrowingLiteral literal(static_cast<const char*>(staging_buffer.get()),
                                shape);
       TF_CHECK_OK(transfer_manager->TransferLiteralToDeviceAsync(
@@ -626,9 +698,15 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::FromHostBuffer(
 
     local_device->ThenRelease(
         local_device->host_to_device_stream(),
-        std::make_pair(buffer_reference, std::move(staging_buffer)));
+        std::make_pair(std::move(buffer_reference), std::move(staging_buffer)));
   };
-  client->h2d_transfer_pool()->Schedule(transfer_h2d);
+  if (is_cpu_platform) {
+    // Using the h2d_transfer_pool would be a double thread hop; the code
+    // already defers its work onto a stream (= thread on CPU).
+    transfer_h2d();
+  } else {
+    client->h2d_transfer_pool()->Schedule(transfer_h2d);
+  }
   return py_buffer;
 }
 
@@ -128,6 +128,7 @@ class PjRtClient {
              std::vector<std::unique_ptr<Device>> devices, int host_id,
              std::unique_ptr<se::DeviceMemoryAllocator> allocator,
              std::unique_ptr<tensorflow::Allocator> host_memory_allocator,
+             bool should_stage_host_to_device_transfers,
             std::unique_ptr<GpuExecutableRunOptions> gpu_run_options);
   virtual ~PjRtClient() = default;
 
@@ -153,6 +154,9 @@ class PjRtClient {
   tensorflow::Allocator* host_memory_allocator() const {
     return host_memory_allocator_.get();
   }
+  bool should_stage_host_to_device_transfers() const {
+    return should_stage_host_to_device_transfers_;
+  }
 
   GpuExecutableRunOptions* gpu_run_options() const {
     return gpu_run_options_.get();
@@ -190,6 +194,9 @@ class PjRtClient {
   std::string platform_name_;
   LocalClient* client_;
 
+  // Allocator to be used for staging memory transfers to devices.
+  std::unique_ptr<tensorflow::Allocator> host_memory_allocator_;
+
   // Includes all devices, including non-local devices on multi-host platforms.
   std::vector<std::unique_ptr<Device>> devices_;
   // Maps Device::id() to the corresponding Device. Includes all devices.
@@ -201,10 +208,10 @@ class PjRtClient {
   se::DeviceMemoryAllocator* allocator_;
   std::unique_ptr<se::DeviceMemoryAllocator> owned_allocator_;
 
-  // Allocator to be used for staging memory transfers to devices. Optional;
-  // only used on GPU where it is more efficient to copy buffers to and from the
-  // device via a staging area of pinned memory.
-  std::unique_ptr<tensorflow::Allocator> host_memory_allocator_;
+  // Should we always prefer to stage host-to-device transfers via memory
+  // allocated on host_memory_allocator_? True only on GPU, where we prefer to
+  // transfer via pinned memory.
+  bool should_stage_host_to_device_transfers_;
 
   std::unique_ptr<GpuExecutableRunOptions> gpu_run_options_;
 
@@ -396,13 +403,35 @@ class PjRtBuffer {
     StatusOr<std::shared_ptr<TrackedDeviceBuffer>> buffer_or_;
   };
 
-  // If `force_copy` is true, forces a copy of the input buffer on CPU.
-  // Otherwise the library is free to alias the output buffer with `data`.
-  // `buffer_reference` is an optional shared pointer that should be kept alive
-  // by the runtime as long as the contents of `data` may still be accessed by
-  // the runtime (may be nullptr).
+  // Describes the semantics the caller to FromHostBuffer expects from the
+  // runtime, in a total order from most restrictive to least restrictive.
+  enum class HostBufferSemantics {
+    // The runtime may not hold references to `data` after the call to
+    // `FromHostBuffer` completes. The caller promises that `data` is immutable
+    // and will not be freed only for the duration of the FromHostBuffer call.
+    // `buffer_reference` will be freed by the time `FromHostBuffer` returns.
+    kImmutableOnlyDuringCall,
+
+    // The runtime may hold onto `data` after the call to `FromHostBuffer`
+    // returns while the runtime completes a transfer to the device. The caller
+    // promises not to mutate or free `data` until the transfer completes, at
+    // which point the runtime will release `buffer_reference`. It is also
+    // correct to wait on the host (directly or indirectly) for the buffer's
+    // definition event to complete.
+    kImmutableUntilTransferCompletes,
+
+    // The PjRtBuffer may alias `data` internally and the runtime may use the
+    // `data` contents as long as the buffer is alive.
+    // The caller promises to keep `data` alive and not to mutate its contents
+    // as long as the buffer is alive; to notify the caller that the buffer may
+    // be freed, the runtime will release its `buffer_reference` when the
+    // PjRtBuffer is freed. On non-CPU platforms this acts identically to
+    // kImmutableUntilTransferCompletes.
+    kZeroCopy,
+  };
   static StatusOr<std::unique_ptr<PjRtBuffer>> FromHostBuffer(
-      const void* data, const Shape& shape, bool force_copy,
+      const void* data, const Shape& shape,
+      HostBufferSemantics host_buffer_semantics,
      std::shared_ptr<void> buffer_reference, PjRtClient* client,
       Device* device);
 
@@ -84,7 +84,8 @@ PyClient::GetDefaultDeviceAssignment1D(int num_replicas) {
 }
 
 StatusOr<std::unique_ptr<PyBuffer>> PyClient::BufferFromPyval(
-    const pybind11::object& argument, Device* device, bool force_copy) {
+    const pybind11::object& argument, Device* device, bool force_copy,
+    PjRtBuffer::HostBufferSemantics host_buffer_semantics) {
   if (device == nullptr) {
     TF_RET_CHECK(!pjrt_client_->local_devices().empty());
     device = pjrt_client_->local_devices().front();
@@ -111,9 +112,9 @@ StatusOr<std::unique_ptr<PyBuffer>> PyClient::BufferFromPyval(
   {
     py::gil_scoped_release gil_release;
     TF_ASSIGN_OR_RETURN(
-        buffer, PjRtBuffer::FromHostBuffer(c->buf_ptr, c->shape, force_copy,
-                                           std::move(py_buffer_ref),
-                                           pjrt_client_.get(), device));
+        buffer, PjRtBuffer::FromHostBuffer(
+                    c->buf_ptr, c->shape, host_buffer_semantics,
+                    std::move(py_buffer_ref), pjrt_client_.get(), device));
   }
   auto traceback = Traceback::Get();
   return std::make_unique<PyBuffer>(shared_from_this(), std::move(buffer),
@@ -120,7 +120,8 @@ class PyClient : public std::enable_shared_from_this<PyClient> {
   }
 
   StatusOr<std::unique_ptr<PyBuffer>> BufferFromPyval(
-      const pybind11::object& argument, Device* device, bool force_copy);
+      const pybind11::object& argument, Device* device, bool force_copy,
+      PjRtBuffer::HostBufferSemantics host_buffer_semantics);
 
   StatusOr<std::unique_ptr<PyExecutable>> Compile(
       const XlaComputation& computation, CompileOptions options);
@@ -509,6 +509,13 @@ PYBIND11_MODULE(xla_extension, m) {
       .value("PLATFORM", GpuAllocatorConfig::Kind::kPlatform)
       .value("BFC", GpuAllocatorConfig::Kind::kBFC);
 
+  py::enum_<PjRtBuffer::HostBufferSemantics>(m, "HostBufferSemantics")
+      .value("IMMUTABLE_ONLY_DURING_CALL",
+             PjRtBuffer::HostBufferSemantics::kImmutableOnlyDuringCall)
+      .value("IMMUTABLE_UNTIL_TRANSFER_COMPLETES",
+             PjRtBuffer::HostBufferSemantics::kImmutableUntilTransferCompletes)
+      .value("ZERO_COPY", PjRtBuffer::HostBufferSemantics::kZeroCopy);
+
   py::class_<PyClient, std::shared_ptr<PyClient>> py_local_client(m, "Client");
   py_local_client.def_property_readonly("platform", &PyClient::platform_name)
       .def("device_count", &PyClient::device_count)
@@ -527,7 +534,9 @@ PYBIND11_MODULE(xla_extension, m) {
       .def("create_host_to_device_channel_handle",
           &PyClient::CreateHostToDeviceChannelHandle)
       .def("buffer_from_pyval", &PyClient::BufferFromPyval, py::arg("argument"),
-           py::arg("device") = nullptr, py::arg("force_copy") = false)
+           py::arg("device") = nullptr, py::arg("force_copy") = false,
+           py::arg("host_buffer_semantics") =
+               PjRtBuffer::HostBufferSemantics::kZeroCopy)
       .def("compile", &PyClient::Compile, py::arg("computation"),
           py::arg("compile_options") = CompileOptions())
       .def("heap_profile", &PyClient::HeapProfile);
@@ -304,6 +304,7 @@ def computation_count():
 Device = _xla.Device
 CompileOptions = _xla.CompileOptions
 
+HostBufferSemantics = _xla.HostBufferSemantics
 
 # An Executable is a C++ class that duck types with the following API:
 # class Executable(object):
@@ -1986,7 +1986,8 @@ def TestFactory(xla_backend, cloud_tpu=False):
     def testRoundTrip(self, dtype, shape):
       x = np.array(np.random.rand(*shape) * 100, dtype=dtype)
       x_ptr = x.__array_interface__["data"][0]
-      buffer = self.backend.buffer_from_pyval(x)
+      buffer = self.backend.buffer_from_pyval(
+          x, host_buffer_semantics=xla_client.HostBufferSemantics.ZERO_COPY)
       y = np.array(buffer, copy=False)
       y_ptr = y.__array_interface__["data"][0]
       np.testing.assert_array_equal(x, y)
@@ -1995,7 +1996,9 @@ def TestFactory(xla_backend, cloud_tpu=False):
       self.assertTrue((x_ptr & 15) != 0 or x_ptr == y_ptr)
       self.assertEqual(y_ptr, buffer.unsafe_buffer_pointer())
 
-      buffer2 = self.backend.buffer_from_pyval(x, force_copy=True)
+      during_call = xla_client.HostBufferSemantics.IMMUTABLE_ONLY_DURING_CALL
+      buffer2 = self.backend.buffer_from_pyval(
+          x, host_buffer_semantics=during_call)
       z = np.array(buffer2, copy=False)
       self.assertNotEqual(x.__array_interface__["data"][0],
                           z.__array_interface__["data"][0])