[PJRT] Fix potential misuse of PjRtBuffer::FromHostBuffer.

Add a new `PjRtBuffer::HostBufferSemantics` enum that describes the possible contracts between caller and runtime.

* Change `FromHostBuffer(..., force_copy, ...)` to `FromHostBuffer(..., host_buffer_semantics, ...)`.

We were seeing data races on CPU between in-place modifications to NumPy arrays and JAX, caused by unintended buffer aliasing. This change lets clients control whether or not they want zero-copy behavior.

PiperOrigin-RevId: 316672280
Change-Id: Ibee296305005e0aa306a2c0aacf4b35a3d6c3ac1
Author: Peter Hawkins, 2020-06-16 06:56:05 -07:00 (committed by TensorFlower Gardener)
Commit: 572442eb16, parent: d9532e6526
11 changed files with 195 additions and 67 deletions
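
A minimal sketch of the new call shape (mirroring the updated GPU multi-stream test below; `inputs`, `shape`, `client`, and `device` are that test's locals, not additions of this change):

    TF_ASSERT_OK_AND_ASSIGN(
        auto in_buffer,
        PjRtBuffer::FromHostBuffer(
            inputs.data(), shape,
            PjRtBuffer::HostBufferSemantics::kImmutableUntilTransferCompletes,
            /*buffer_reference=*/nullptr, client.get(), device));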

@@ -59,6 +59,7 @@ StatusOr<std::shared_ptr<PjRtClient>> GetCpuClient(bool asynchronous) {
return std::make_shared<PjRtClient>(
kCpuPlatformName, client, std::move(devices), /*host_id=*/0,
/*allocator=*/nullptr, /*host_memory_allocator=*/nullptr,
/*should_stage_host_to_device_transfers=*/false,
/*gpu_run_options=*/nullptr);
}

@@ -72,18 +72,21 @@ TEST(GpuMultiStream, Basics) {
TF_ASSERT_OK_AND_ASSIGN(
auto dummy_buffer,
PjRtBuffer::FromHostBuffer(
dummy_inputs.data(), dummy_shape, /*force_copy=*/false,
dummy_inputs.data(), dummy_shape,
PjRtBuffer::HostBufferSemantics::kImmutableUntilTransferCompletes,
/*buffer_reference=*/nullptr, client.get(), device));
TF_ASSERT_OK_AND_ASSIGN(
auto in_buffer0,
PjRtBuffer::FromHostBuffer(inputs.data(), shape, /*force_copy=*/false,
/*buffer_reference=*/nullptr, client.get(),
device));
PjRtBuffer::FromHostBuffer(
inputs.data(), shape,
PjRtBuffer::HostBufferSemantics::kImmutableUntilTransferCompletes,
/*buffer_reference=*/nullptr, client.get(), device));
TF_ASSERT_OK_AND_ASSIGN(
auto in_buffer1,
PjRtBuffer::FromHostBuffer(inputs.data(), shape, /*force_copy=*/false,
/*buffer_reference=*/nullptr, client.get(),
device));
PjRtBuffer::FromHostBuffer(
inputs.data(), shape,
PjRtBuffer::HostBufferSemantics::kImmutableUntilTransferCompletes,
/*buffer_reference=*/nullptr, client.get(), device));
// The execution may be enqueued before the transfers complete, requiring
// adequate device-side synchronization.
ExecuteOptions options;

@@ -53,6 +53,7 @@ StatusOr<std::shared_ptr<PjRtClient>> GetInterpreterClient() {
return std::make_shared<PjRtClient>(
kInterpreterPlatformName, client, std::move(devices), /*host_id=*/0,
/*allocator=*/nullptr, /*host_memory_allocator=*/nullptr,
/*should_stage_host_to_device_transfers=*/false,
/*gpu_run_options=*/nullptr);
}

@@ -316,6 +316,7 @@ StatusOr<std::shared_ptr<PjRtClient>> GetNvidiaGpuClient(
"gpu", xla_client, std::move(devices),
/*node_id=*/node_id, std::move(allocator),
std::move(host_memory_allocator),
/*should_stage_host_to_device_transfers=*/true,
/*gpu_run_options=*/std::move(gpu_run_options));
return pyclient;
}

@@ -95,6 +95,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/mem.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/profiler/lib/traceme.h"
@@ -154,18 +155,35 @@ StatusOr<DeviceAssignment> DevicesToDeviceAssignment(
return xla_assignment;
}
class CpuAllocator : public tensorflow::Allocator {
public:
CpuAllocator() = default;
string Name() override { return "cpu"; }
void* AllocateRaw(size_t alignment, size_t num_bytes) override {
return tensorflow::port::AlignedMalloc(num_bytes, alignment);
}
void DeallocateRaw(void* ptr) override {
return tensorflow::port::AlignedFree(ptr);
}
};
PjRtClient::PjRtClient(
std::string platform_name, LocalClient* client,
std::vector<std::unique_ptr<Device>> devices, int host_id,
std::unique_ptr<se::DeviceMemoryAllocator> allocator,
std::unique_ptr<tensorflow::Allocator> host_memory_allocator,
bool should_stage_host_to_device_transfers,
std::unique_ptr<GpuExecutableRunOptions> gpu_run_options)
: platform_name_(std::move(platform_name)),
client_(client),
host_memory_allocator_(std::move(host_memory_allocator)),
devices_(std::move(devices)),
host_id_(host_id),
owned_allocator_(std::move(allocator)),
host_memory_allocator_(std::move(host_memory_allocator)),
should_stage_host_to_device_transfers_(
should_stage_host_to_device_transfers),
gpu_run_options_(std::move(gpu_run_options)),
h2d_transfer_pool_(tensorflow::Env::Default(), "py_xla_h2d_transfer",
client->device_count()) {
@@ -175,6 +193,10 @@ PjRtClient::PjRtClient(
allocator_ = client_->backend().memory_allocator();
}
if (!host_memory_allocator_) {
host_memory_allocator_ = std::make_unique<CpuAllocator>();
}
for (const std::unique_ptr<Device>& device : devices_) {
CHECK(id_to_device_.insert({device->id(), device.get()}).second)
<< "Duplicate device id: " << device->id();
@@ -526,7 +548,8 @@ void PjRtBuffer::ScopedHold::AddToInput(
/* static */
StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::FromHostBuffer(
const void* data, const Shape& shape, bool force_copy,
const void* data, const Shape& shape,
HostBufferSemantics host_buffer_semantics,
std::shared_ptr<void> buffer_reference, PjRtClient* client,
Device* device) {
tensorflow::profiler::TraceMe traceme("PjRtBuffer::FromHostBuffer");
@@ -537,34 +560,63 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::FromHostBuffer(
}
TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device,
device->GetLocalDeviceState());
// If we are on the host platform and the input buffer is sufficiently
// aligned, we can simply point to the input array's data without any further
// copies. At the time of writing we require a 16-byte alignment because XLA
// may generate code which requires it.
if (!force_copy &&
((absl::bit_cast<std::uintptr_t>(data) &
(cpu_function_runtime::kMinAlign - 1)) == 0) &&
local_device->executor()->platform()->id() == se::host::kHostPlatformId) {
std::function<void()> on_delete_callback =
[buffer_reference{std::move(buffer_reference)}]() {
// Frees buffer_reference.
};
se::DeviceMemoryBase buffer(const_cast<void*>(data),
ShapeUtil::ByteSizeOf(shape));
absl::Span<const std::shared_ptr<BufferSequencingEvent>> definition_events;
auto device_buffer = std::make_shared<TrackedDeviceBuffer>(
/*allocator=*/nullptr, local_device->device_ordinal(),
std::initializer_list<se::DeviceMemoryBase>{buffer}, definition_events,
std::move(on_delete_callback));
return absl::make_unique<PjRtBuffer>(shape, shape, std::move(device_buffer),
client, device);
}
int64 size = ShapeUtil::ByteSizeOf(shape);
TransferManager* transfer_manager =
client->client()->backend().transfer_manager();
TF_ASSIGN_OR_RETURN(Shape compact_shape,
transfer_manager->ChooseCompactLayoutForShape(shape));
// The CPU platform is special because the "host" and the "device" are in the
// same memory space. If the input shape is in the correct layout and we don't
// want to defer the copy onto a thread, we can use the following fast
// path.
bool is_cpu_platform =
local_device->executor()->platform()->id() == se::host::kHostPlatformId;
if (is_cpu_platform) {
// If we are on the host platform and the input buffer is sufficiently
// aligned, we can simply point to the input array's data without any
// further copies. At the time of writing we require a 16-byte alignment
// because XLA may generate code which requires it.
bool can_use_zero_copy =
host_buffer_semantics == HostBufferSemantics::kZeroCopy &&
((absl::bit_cast<std::uintptr_t>(data) &
(cpu_function_runtime::kMinAlign - 1)) == 0);
if (shape.layout() == compact_shape.layout() &&
(host_buffer_semantics ==
HostBufferSemantics::kImmutableOnlyDuringCall ||
can_use_zero_copy)) {
std::function<void()> on_delete_callback;
se::DeviceMemoryBase buffer;
// If we are on the host platform and the input buffer is sufficiently
// aligned, we can simply point to the input array's data without any
// further copies. At the time of writing we require a 16-byte alignment
// because XLA may generate code which requires it.
if (can_use_zero_copy) {
on_delete_callback = [buffer_reference{std::move(buffer_reference)}]() {
// Frees buffer_reference.
};
buffer = se::DeviceMemoryBase(const_cast<void*>(data), size);
} else {
void* staging_buffer = client->host_memory_allocator()->AllocateRaw(
cpu_function_runtime::kMinAlign, size);
on_delete_callback = [staging_buffer, client]() {
client->host_memory_allocator()->DeallocateRaw(staging_buffer);
};
buffer = se::DeviceMemoryBase(staging_buffer, size);
std::memcpy(staging_buffer, data, size);
}
absl::Span<const std::shared_ptr<BufferSequencingEvent>>
definition_events;
auto device_buffer = std::make_shared<TrackedDeviceBuffer>(
/*allocator=*/nullptr, local_device->device_ordinal(),
std::initializer_list<se::DeviceMemoryBase>{buffer},
definition_events, std::move(on_delete_callback));
return absl::make_unique<PjRtBuffer>(
shape, shape, std::move(device_buffer), client, device);
}
}
TF_ASSIGN_OR_RETURN(
std::unique_ptr<PjRtBuffer> py_buffer,
AllocateDestinationBuffer(compact_shape, device, local_device,
@@ -573,17 +625,41 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::FromHostBuffer(
ScopedHold device_buffer(py_buffer->GetBufferWithUsageHold());
CHECK(device_buffer.ok());
// If necessary, allocate a host-side buffer for staging host-to-device
// transfers. On GPU this is a buffer in pinned memory.
std::shared_ptr<void> staging_buffer;
if (host_buffer_semantics == HostBufferSemantics::kImmutableOnlyDuringCall ||
client->should_stage_host_to_device_transfers()) {
void* ptr = client->host_memory_allocator()->AllocateRaw(
tensorflow::Allocator::kAllocatorAlignment, size);
staging_buffer = std::shared_ptr<void>(ptr, [client](void* ptr) {
client->host_memory_allocator()->DeallocateRaw(ptr);
});
}
// Copy the buffer into a staging buffer before returning control to the
// caller if the caller only guaranteed that the buffer is valid for the
// duration of the call. Otherwise, we stage (if necessary) on a separate
// thread.
if (host_buffer_semantics == HostBufferSemantics::kImmutableOnlyDuringCall) {
std::memcpy(staging_buffer.get(), data, size);
buffer_reference.reset();
data = nullptr;
}
// The host to device transfer is performed on a thread pool, mostly because
// it includes linearization that may be slow. It is OK to capture the
// py_buffer pointer because the py_buffer can't be deleted until all the
// usage holds have gone away.
// TODO(misard) assess if it would be preferable to introduce a heuristic to
// put the transfer into the calling thread for small literals.
auto transfer_h2d = [client, transfer_manager, local_device,
movable_device_buffer{device_buffer.ToClosure()}, data,
shape, py_buffer{py_buffer.get()}, compact_shape,
auto transfer_h2d = [client, transfer_manager, local_device, data, size,
movable_device_buffer{device_buffer.ToClosure()}, shape,
py_buffer{py_buffer.get()}, compact_shape,
on_device_shape{py_buffer->on_device_shape()},
buffer_reference{std::move(buffer_reference)}]() {
staging_buffer{std::move(staging_buffer)},
buffer_reference{std::move(buffer_reference)},
host_buffer_semantics]() {
ScopedHold device_buffer(movable_device_buffer);
// This function uses TF_CHECK_OK and ValueOrDie() since we have no way
// to report failures from a callback. However, the operations here are
@@ -593,20 +669,16 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::FromHostBuffer(
ShapedBuffer buffer = device_buffer->AsShapedBuffer(
compact_shape, on_device_shape, client->client()->platform());
std::shared_ptr<void> staging_buffer;
// If applicable on the backend, stage the transfer via host memory
// allocated via the host_memory_allocator. On GPU, this is pinned
// memory.
if (client->host_memory_allocator()) {
int64 size = ShapeUtil::ByteSizeOf(shape);
void* ptr = client->host_memory_allocator()->AllocateRaw(
tensorflow::Allocator::kAllocatorAlignment, size);
staging_buffer = std::shared_ptr<void>(ptr, [client](void* ptr) {
client->host_memory_allocator()->DeallocateRaw(ptr);
});
std::memcpy(ptr, data, size);
if (staging_buffer) {
// If we didn't already copy the input buffer into the staging buffer,
// do so now.
if (host_buffer_semantics !=
HostBufferSemantics::kImmutableOnlyDuringCall) {
std::memcpy(staging_buffer.get(), data, size);
}
BorrowingLiteral literal(static_cast<const char*>(staging_buffer.get()),
shape);
TF_CHECK_OK(transfer_manager->TransferLiteralToDeviceAsync(
@@ -626,9 +698,15 @@ StatusOr<std::unique_ptr<PjRtBuffer>> PjRtBuffer::FromHostBuffer(
local_device->ThenRelease(
local_device->host_to_device_stream(),
std::make_pair(buffer_reference, std::move(staging_buffer)));
std::make_pair(std::move(buffer_reference), std::move(staging_buffer)));
};
if (is_cpu_platform) {
// Using the h2d_transfer_pool would be a double thread hop; the code
// already defers its work onto a stream (= thread on CPU).
transfer_h2d();
} else {
client->h2d_transfer_pool()->Schedule(transfer_h2d);
}
return py_buffer;
}
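
For reference, the staging decision above reduces to the following predicate (a condensed restatement under a hypothetical helper name, not code from the commit):

    bool NeedsHostStagingBuffer(PjRtBuffer::HostBufferSemantics semantics,
                                const PjRtClient& client) {
      // Stage when the caller only guarantees `data` for the duration of the
      // call, or when the client prefers staged host-to-device transfers
      // (e.g. via pinned memory on GPU).
      return semantics ==
                 PjRtBuffer::HostBufferSemantics::kImmutableOnlyDuringCall ||
             client.should_stage_host_to_device_transfers();
    }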

@@ -128,6 +128,7 @@ class PjRtClient {
std::vector<std::unique_ptr<Device>> devices, int host_id,
std::unique_ptr<se::DeviceMemoryAllocator> allocator,
std::unique_ptr<tensorflow::Allocator> host_memory_allocator,
bool should_stage_host_to_device_transfers,
std::unique_ptr<GpuExecutableRunOptions> gpu_run_options);
virtual ~PjRtClient() = default;
@@ -153,6 +154,9 @@ class PjRtClient {
tensorflow::Allocator* host_memory_allocator() const {
return host_memory_allocator_.get();
}
bool should_stage_host_to_device_transfers() const {
return should_stage_host_to_device_transfers_;
}
GpuExecutableRunOptions* gpu_run_options() const {
return gpu_run_options_.get();
@@ -190,6 +194,9 @@ class PjRtClient {
std::string platform_name_;
LocalClient* client_;
// Allocator to be used for staging memory transfers to devices.
std::unique_ptr<tensorflow::Allocator> host_memory_allocator_;
// Includes all devices, including non-local devices on multi-host platforms.
std::vector<std::unique_ptr<Device>> devices_;
// Maps Device::id() to the corresponding Device. Includes all devices.
@@ -201,10 +208,10 @@ class PjRtClient {
se::DeviceMemoryAllocator* allocator_;
std::unique_ptr<se::DeviceMemoryAllocator> owned_allocator_;
// Allocator to be used for staging memory transfers to devices. Optional;
// only used on GPU where it is more efficient to copy buffers to and from the
// device via a staging area of pinned memory.
std::unique_ptr<tensorflow::Allocator> host_memory_allocator_;
// Should we always prefer to stage host-to-device transfers via memory
// allocated on host_memory_allocator_? True only on GPU, where we prefer to
// transfer via pinned memory.
bool should_stage_host_to_device_transfers_;
std::unique_ptr<GpuExecutableRunOptions> gpu_run_options_;
@@ -396,13 +403,35 @@ class PjRtBuffer {
StatusOr<std::shared_ptr<TrackedDeviceBuffer>> buffer_or_;
};
// If `force_copy` is true, forces a copy of the input buffer on CPU.
// Otherwise the library is free to alias the output buffer with `data`.
// `buffer_reference` is an optional shared pointer that should be kept alive
// by the runtime as long as the contents of `data` may still be accessed by
// the runtime (may be nullptr).
// Describes the semantics the caller to FromHostBuffer expects from the
// runtime, in a total order from most restrictive to least restrictive.
enum class HostBufferSemantics {
// The runtime may not hold references to `data` after the call to
// `FromHostBuffer` completes. The caller promises that `data` is immutable
// and will not be freed only for the duration of the FromHostBuffer call.
// `buffer_reference` will be freed by the time `FromHostBuffer` returns.
kImmutableOnlyDuringCall,
// The runtime may hold onto `data` after the call to `FromHostBuffer`
// returns while the runtime completes a transfer to the device. The caller
// promises not to mutate or free `data` until the transfer completes, at
// which point the runtime will release `buffer_reference`. It is also
// correct to wait on the host (directly or indirectly) for the buffer's
// definition event to complete.
kImmutableUntilTransferCompletes,
// The PjRtBuffer may alias `data` internally and the runtime may use the
// `data` contents as long as the buffer is alive.
// The caller promises to keep `data` alive and not to mutate its contents
// as long as the buffer is alive; to notify the caller that the buffer may
// be freed, the runtime will release its `buffer_reference` when the
// PjRtBuffer is freed. On non-CPU platforms this acts identically to
// kImmutableUntilTransferCompletes.
kZeroCopy,
};
static StatusOr<std::unique_ptr<PjRtBuffer>> FromHostBuffer(
const void* data, const Shape& shape, bool force_copy,
const void* data, const Shape& shape,
HostBufferSemantics host_buffer_semantics,
std::shared_ptr<void> buffer_reference, PjRtClient* client,
Device* device);
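
A hedged sketch of the kZeroCopy contract documented above (assumes code inside namespace xla with a valid `client` and `device` and the usual PJRT/XLA headers; `MakeZeroCopyBuffer` is a hypothetical helper, not part of this change):

    StatusOr<std::unique_ptr<PjRtBuffer>> MakeZeroCopyBuffer(PjRtClient* client,
                                                             Device* device) {
      // The caller keeps the host data alive through a shared_ptr and hands a
      // copy of it to the runtime as `buffer_reference`; the runtime releases
      // that reference when the PjRtBuffer is freed.
      auto host_data = std::make_shared<std::vector<float>>(1024, 1.0f);
      Shape shape = ShapeUtil::MakeShape(F32, {1024});
      TF_ASSIGN_OR_RETURN(
          std::unique_ptr<PjRtBuffer> buffer,
          PjRtBuffer::FromHostBuffer(
              host_data->data(), shape,
              PjRtBuffer::HostBufferSemantics::kZeroCopy,
              /*buffer_reference=*/host_data, client, device));
      return buffer;
    }

On CPU the zero-copy aliasing only happens when `data` is 16-byte aligned and already in the compact layout; otherwise the runtime falls back to copying, as in the implementation above.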

@@ -84,7 +84,8 @@ PyClient::GetDefaultDeviceAssignment1D(int num_replicas) {
}
StatusOr<std::unique_ptr<PyBuffer>> PyClient::BufferFromPyal(
const pybind11::object& argument, Device* device, bool force_copy) {
const pybind11::object& argument, Device* device, bool force_copy,
PjRtBuffer::HostBufferSemantics host_buffer_semantics) {
if (device == nullptr) {
TF_RET_CHECK(!pjrt_client_->local_devices().empty());
device = pjrt_client_->local_devices().front();
@@ -111,9 +112,9 @@ StatusOr<std::unique_ptr<PyBuffer>> PyClient::BufferFromPyal(
{
py::gil_scoped_release gil_release;
TF_ASSIGN_OR_RETURN(
buffer, PjRtBuffer::FromHostBuffer(c->buf_ptr, c->shape, force_copy,
std::move(py_buffer_ref),
pjrt_client_.get(), device));
buffer, PjRtBuffer::FromHostBuffer(
c->buf_ptr, c->shape, host_buffer_semantics,
std::move(py_buffer_ref), pjrt_client_.get(), device));
}
auto traceback = Traceback::Get();
return std::make_unique<PyBuffer>(shared_from_this(), std::move(buffer),

@@ -120,7 +120,8 @@ class PyClient : public std::enable_shared_from_this<PyClient> {
}
StatusOr<std::unique_ptr<PyBuffer>> BufferFromPyal(
const pybind11::object& argument, Device* device, bool force_copy);
const pybind11::object& argument, Device* device, bool force_copy,
PjRtBuffer::HostBufferSemantics host_buffer_semantics);
StatusOr<std::unique_ptr<PyExecutable>> Compile(
const XlaComputation& computation, CompileOptions options);

@@ -509,6 +509,13 @@ PYBIND11_MODULE(xla_extension, m) {
.value("PLATFORM", GpuAllocatorConfig::Kind::kPlatform)
.value("BFC", GpuAllocatorConfig::Kind::kBFC);
py::enum_<PjRtBuffer::HostBufferSemantics>(m, "HostBufferSemantics")
.value("IMMUTABLE_ONLY_DURING_CALL",
PjRtBuffer::HostBufferSemantics::kImmutableOnlyDuringCall)
.value("IMMUTABLE_UNTIL_TRANSFER_COMPLETES",
PjRtBuffer::HostBufferSemantics::kImmutableUntilTransferCompletes)
.value("ZERO_COPY", PjRtBuffer::HostBufferSemantics::kZeroCopy);
py::class_<PyClient, std::shared_ptr<PyClient>> py_local_client(m, "Client");
py_local_client.def_property_readonly("platform", &PyClient::platform_name)
.def("device_count", &PyClient::device_count)
@@ -527,7 +534,9 @@ PYBIND11_MODULE(xla_extension, m) {
.def("create_host_to_device_channel_handle",
&PyClient::CreateHostToDeviceChannelHandle)
.def("buffer_from_pyval", &PyClient::BufferFromPyal, py::arg("argument"),
py::arg("device") = nullptr, py::arg("force_copy") = false)
py::arg("device") = nullptr, py::arg("force_copy") = false,
py::arg("host_buffer_semantics") =
PjRtBuffer::HostBufferSemantics::kZeroCopy)
.def("compile", &PyClient::Compile, py::arg("computation"),
py::arg("compile_options") = CompileOptions())
.def("heap_profile", &PyClient::HeapProfile);

@@ -304,6 +304,7 @@ def computation_count():
Device = _xla.Device
CompileOptions = _xla.CompileOptions
HostBufferSemantics = _xla.HostBufferSemantics
# An Executable is a C++ class that duck types with the following API:
# class Executable(object):

@@ -1986,7 +1986,8 @@ def TestFactory(xla_backend, cloud_tpu=False):
def testRoundTrip(self, dtype, shape):
x = np.array(np.random.rand(*shape) * 100, dtype=dtype)
x_ptr = x.__array_interface__["data"][0]
buffer = self.backend.buffer_from_pyval(x)
buffer = self.backend.buffer_from_pyval(
x, host_buffer_semantics=xla_client.HostBufferSemantics.ZERO_COPY)
y = np.array(buffer, copy=False)
y_ptr = y.__array_interface__["data"][0]
np.testing.assert_array_equal(x, y)
@@ -1995,7 +1996,9 @@ def TestFactory(xla_backend, cloud_tpu=False):
self.assertTrue((x_ptr & 15) != 0 or x_ptr == y_ptr)
self.assertEqual(y_ptr, buffer.unsafe_buffer_pointer())
buffer2 = self.backend.buffer_from_pyval(x, force_copy=True)
during_call = xla_client.HostBufferSemantics.IMMUTABLE_ONLY_DURING_CALL
buffer2 = self.backend.buffer_from_pyval(
x, host_buffer_semantics=during_call)
z = np.array(buffer2, copy=False)
self.assertNotEqual(x.__array_interface__["data"][0],
z.__array_interface__["data"][0])