[XLA:Python] Add support for zero-copy NumPy -> XLA buffer transfers on CPU.
When creating a PyLocalBuffer from a NumPy array, we don't need to copy the array if it is in the right layout and alignment already. This saves us a buffer copy and some thread hops. Note that a copy still occurs if the input NumPy array was not in C-contiguous layout. A copy will also take place if the NumPy array was not 64-byte aligned. Remove support for tuple trees in Buffer.from_pyval. JAX doesn't use tuple trees, and this simplifies the code and the API. PiperOrigin-RevId: 292650189 Change-Id: If6f7b34507c0aebc6dcbee136049b66e66d426c3
This commit is contained in:
parent
419ebe51e0
commit
9fc5b891b8
@ -291,21 +291,46 @@ StatusOr<DeviceAssignment> PyLocalClient::GetDefaultDeviceAssignment(
|
||||
}
|
||||
|
||||
/* static */
|
||||
StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalBuffer::FromLiterals(
|
||||
std::vector<BorrowingLiteral> leaves_literals, const Shape& tuple_shape,
|
||||
std::shared_ptr<void> leaves_reference,
|
||||
StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalBuffer::FromHostBuffer(
|
||||
const void* data, const Shape& shape, bool force_copy,
|
||||
std::shared_ptr<void> buffer_reference,
|
||||
std::shared_ptr<PyLocalClient> client, std::shared_ptr<Device> device) {
|
||||
tensorflow::profiler::TraceMe traceme("PyLocalBuffer::FromLiterals");
|
||||
VLOG(2) << "PyLocalBuffer::FromLiterals: shape: " << tuple_shape.ToString()
|
||||
VLOG(2) << "PyLocalBuffer::FromLiterals: shape: " << shape.ToString()
|
||||
<< " device: " << device->DebugString();
|
||||
TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device,
|
||||
device->GetLocalDeviceState());
|
||||
|
||||
// If we are on the host platform and the input buffer is sufficiently
|
||||
// aligned, we can simply point to the NumPy array's data without any further
|
||||
// copies. We require a 64-byte alignment because XLA may generate AVX512
|
||||
// code which requires it. Unfortunately NumPy's allocator doesn't align
|
||||
// quite as aggressively, so there's a high chance this test will fail.
|
||||
static constexpr int kMinimumAlignment = 64;
|
||||
if (!force_copy &&
|
||||
((absl::bit_cast<std::uintptr_t>(data) & (kMinimumAlignment - 1)) == 0) &&
|
||||
local_device->executor()->platform_kind() == se::PlatformKind::kHost) {
|
||||
std::function<void()> on_delete_callback =
|
||||
[buffer_reference{std::move(buffer_reference)}]() {
|
||||
// Frees buffer_reference.
|
||||
};
|
||||
se::DeviceMemoryBase buffer(const_cast<void*>(data),
|
||||
ShapeUtil::ByteSizeOf(shape));
|
||||
auto device_buffer = std::make_shared<SharedDeviceBuffer>(
|
||||
/*allocator=*/nullptr, local_device->device_ordinal(),
|
||||
std::initializer_list<se::DeviceMemoryBase>{buffer},
|
||||
/*children=*/std::vector<std::shared_ptr<SharedDeviceBuffer>>{},
|
||||
/*definition_event=*/nullptr, std::move(on_delete_callback));
|
||||
return absl::make_unique<PyLocalBuffer>(
|
||||
shape, shape, std::move(device_buffer), std::move(client),
|
||||
std::move(device));
|
||||
}
|
||||
|
||||
TransferManager* transfer_manager =
|
||||
client->client()->backend().transfer_manager();
|
||||
se::DeviceMemoryAllocator* allocator = client->allocator();
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
Shape compact_shape,
|
||||
transfer_manager->ChooseCompactLayoutForShape(tuple_shape));
|
||||
TF_ASSIGN_OR_RETURN(Shape compact_shape,
|
||||
transfer_manager->ChooseCompactLayoutForShape(shape));
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
ScopedShapedBuffer scoped_buffer,
|
||||
transfer_manager->AllocateScopedShapedBuffer(
|
||||
@ -330,12 +355,9 @@ StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalBuffer::FromLiterals(
|
||||
definition_event);
|
||||
Shape on_device_shape = scoped_buffer.on_device_shape();
|
||||
|
||||
// TODO(makro): Use move capture once C++ 14 features are available.
|
||||
auto leaves = std::make_shared<std::vector<BorrowingLiteral>>(
|
||||
std::move(leaves_literals));
|
||||
auto transfer_h2d = [client, transfer_manager, local_device, device_buffer,
|
||||
compact_shape, on_device_shape, leaves,
|
||||
leaves_reference]() {
|
||||
shape, compact_shape, on_device_shape, data,
|
||||
buffer_reference{std::move(buffer_reference)}]() {
|
||||
// This function uses TF_CHECK_OK and ValueOrDie() since we have no way to
|
||||
// report failures from a callback. However, the operations here are
|
||||
// unlikely to fail and not recoverable even if we were to fail: DMAs to
|
||||
@ -344,39 +366,27 @@ StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalBuffer::FromLiterals(
|
||||
compact_shape, on_device_shape, client->client()->platform());
|
||||
TF_CHECK_OK(transfer_manager->WriteTupleIndexTablesAsync(
|
||||
local_device->host_to_device_stream(), buffer));
|
||||
std::vector<std::shared_ptr<void>> staging_buffers;
|
||||
staging_buffers.reserve(leaves->size());
|
||||
auto it = leaves->begin();
|
||||
for (const ShapeUtil::IndexedShape& indexed_shape :
|
||||
ShapeUtil::GetLeafShapes(compact_shape)) {
|
||||
CHECK(it != leaves->end());
|
||||
ShapedBuffer leaf(
|
||||
indexed_shape.shape,
|
||||
transfer_manager->HostShapeToDeviceShape(indexed_shape.shape),
|
||||
client->client()->platform(), local_device->device_ordinal());
|
||||
leaf.buffers().CopySubtreeFrom(buffer.buffers(), indexed_shape.index, {});
|
||||
std::shared_ptr<void> staging_buffer;
|
||||
|
||||
// If applicable on the backend, stage the transfer via host memory
|
||||
// allocated via the host_memory_allocator. On GPU, this is pinned memory.
|
||||
if (client->host_memory_allocator()) {
|
||||
int64 size = it->size_bytes({});
|
||||
void* ptr = client->host_memory_allocator()->AllocateRaw(
|
||||
tensorflow::Allocator::kAllocatorAlignment, size);
|
||||
std::shared_ptr<void> staging_buffer(ptr, [client](void* ptr) {
|
||||
client->host_memory_allocator()->DeallocateRaw(ptr);
|
||||
});
|
||||
std::memcpy(ptr, it->untyped_data({}), size);
|
||||
BorrowingLiteral literal(static_cast<const char*>(staging_buffer.get()),
|
||||
it->shape());
|
||||
TF_CHECK_OK(transfer_manager->TransferLiteralToDeviceAsync(
|
||||
local_device->host_to_device_stream(), literal, leaf));
|
||||
staging_buffers.push_back(std::move(staging_buffer));
|
||||
} else {
|
||||
// Otherwise, just transfer the literal.
|
||||
TF_CHECK_OK(transfer_manager->TransferLiteralToDeviceAsync(
|
||||
local_device->host_to_device_stream(), *it, leaf));
|
||||
}
|
||||
++it;
|
||||
// If applicable on the backend, stage the transfer via host memory
|
||||
// allocated via the host_memory_allocator. On GPU, this is pinned memory.
|
||||
if (client->host_memory_allocator()) {
|
||||
int64 size = ShapeUtil::ByteSizeOf(shape);
|
||||
void* ptr = client->host_memory_allocator()->AllocateRaw(
|
||||
tensorflow::Allocator::kAllocatorAlignment, size);
|
||||
staging_buffer = std::shared_ptr<void>(ptr, [client](void* ptr) {
|
||||
client->host_memory_allocator()->DeallocateRaw(ptr);
|
||||
});
|
||||
std::memcpy(ptr, data, size);
|
||||
BorrowingLiteral literal(static_cast<const char*>(staging_buffer.get()),
|
||||
shape);
|
||||
TF_CHECK_OK(transfer_manager->TransferLiteralToDeviceAsync(
|
||||
local_device->host_to_device_stream(), literal, buffer));
|
||||
} else {
|
||||
BorrowingLiteral literal(static_cast<const char*>(data), shape);
|
||||
// Otherwise, just transfer the literal.
|
||||
TF_CHECK_OK(transfer_manager->TransferLiteralToDeviceAsync(
|
||||
local_device->host_to_device_stream(), literal, buffer));
|
||||
}
|
||||
|
||||
EventPool::Handle event =
|
||||
@ -397,7 +407,7 @@ StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalBuffer::FromLiterals(
|
||||
|
||||
local_device->ThenRelease(
|
||||
local_device->host_to_device_stream(),
|
||||
std::make_pair(leaves_reference, std::move(staging_buffers)));
|
||||
std::make_pair(buffer_reference, std::move(staging_buffer)));
|
||||
};
|
||||
client->h2d_transfer_pool()->Schedule(transfer_h2d);
|
||||
return absl::make_unique<PyLocalBuffer>(
|
||||
|
@ -202,9 +202,14 @@ class PyLocalClient {
|
||||
// Thread-safe.
|
||||
class PyLocalBuffer {
|
||||
public:
|
||||
static StatusOr<std::unique_ptr<PyLocalBuffer>> FromLiterals(
|
||||
std::vector<BorrowingLiteral> leaves_literals, const Shape& tuple_shape,
|
||||
std::shared_ptr<void> leaves_reference,
|
||||
// If `force_copy` is true, forces a copy of the input buffer on CPU.
|
||||
// Otherwise the library is free to alias the output buffer with `data`.
|
||||
// `buffer_reference` is an optional shared pointer that should be kept alive
|
||||
// by the runtime as long as the contents of `data` may still be accessed by
|
||||
// the runtime (may be nullptr).
|
||||
static StatusOr<std::unique_ptr<PyLocalBuffer>> FromHostBuffer(
|
||||
const void* data, const Shape& shape, bool force_copy,
|
||||
std::shared_ptr<void> buffer_reference,
|
||||
std::shared_ptr<PyLocalClient> client, std::shared_ptr<Device> device);
|
||||
|
||||
static StatusOr<std::unique_ptr<PyLocalBuffer>> MakeTuple(
|
||||
|
@ -15,6 +15,8 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/compiler/xla/python/python_ref_manager.h"
|
||||
|
||||
#include "absl/container/inlined_vector.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
namespace py = pybind11;
|
||||
@ -37,16 +39,27 @@ PythonRefManager::ManagedPyObjects::~ManagedPyObjects() {
|
||||
}
|
||||
}
|
||||
|
||||
std::shared_ptr<PythonRefManager::ManagedPyObjects>
|
||||
PythonRefManager::ManageReference(py::object object) {
|
||||
return std::make_shared<ManagedPyObjects>(this,
|
||||
absl::Span<py::object>(&object, 1));
|
||||
}
|
||||
|
||||
std::shared_ptr<PythonRefManager::ManagedPyObjects>
|
||||
PythonRefManager::ManageReferences(absl::Span<py::object> objects) {
|
||||
return std::make_shared<ManagedPyObjects>(this, objects);
|
||||
}
|
||||
|
||||
void PythonRefManager::CollectGarbage() {
|
||||
// TODO(phawkins): ideally we would assert that the GIL is held, but there is
|
||||
// no API to do this across all Python versions.
|
||||
absl::MutexLock lock(&mu_);
|
||||
python_garbage_.clear();
|
||||
// TODO(phawkins): we should CHECK(PyGILState_Check());
|
||||
std::deque<pybind11::object> garbage;
|
||||
{
|
||||
absl::MutexLock lock(&mu_);
|
||||
garbage.swap(python_garbage_);
|
||||
}
|
||||
// We defer deleting garbage until the lock is released. It's possible that
|
||||
// deleting garbage will lead to more Python garbage being added; if we held
|
||||
// the lock we would deadlock because absl::Mutex is not reentrant.
|
||||
}
|
||||
|
||||
PythonRefManager* GlobalPyRefManager() {
|
||||
|
@ -62,6 +62,7 @@ class PythonRefManager {
|
||||
// Creates a managed std::shared_ptr to an object. When the shared_ptr is
|
||||
// destroyed, the reference to 'object' will be added to python_garbage_,
|
||||
// and collected next time CollectGarbage() is called.
|
||||
std::shared_ptr<ManagedPyObjects> ManageReference(pybind11::object object);
|
||||
std::shared_ptr<ManagedPyObjects> ManageReferences(
|
||||
absl::Span<pybind11::object> objects);
|
||||
|
||||
|
@ -81,7 +81,7 @@ class TpuBackend(xla_client.Backend):
|
||||
def host_id(self):
|
||||
return self.client.host_id()
|
||||
|
||||
def buffer_from_pyval(self, pyval, device=None):
|
||||
def buffer_from_pyval(self, pyval, device=None, force_copy=False):
|
||||
if device is None:
|
||||
device = self.client.local_devices()[0]
|
||||
return _tpu_client.PyTpuBuffer.from_python(pyval, self.client, device)
|
||||
|
@ -96,7 +96,7 @@ std::vector<int64> IntSequenceToVector(const pybind11::object& sequence);
|
||||
// xla::BorrowingLiteral. Converts a Python array-like object into a buffer
|
||||
// pointer and shape.
|
||||
struct CastToArrayResult {
|
||||
pybind11::array array; // Holds a reference to the array to keep it alive.
|
||||
pybind11::object array; // Holds a reference to the array to keep it alive.
|
||||
const char* buf_ptr;
|
||||
xla::Shape shape;
|
||||
};
|
||||
|
@ -655,8 +655,8 @@ PYBIND11_MODULE(xla_extension, m) {
|
||||
"from_python",
|
||||
[](const pybind11::object& argument,
|
||||
std::shared_ptr<PyLocalClient> client,
|
||||
std::shared_ptr<Device> device)
|
||||
-> StatusOr<std::unique_ptr<PyLocalBuffer>> {
|
||||
std::shared_ptr<Device> device,
|
||||
bool force_copy) -> StatusOr<std::unique_ptr<PyLocalBuffer>> {
|
||||
CHECK(device != nullptr);
|
||||
auto iter = client->id_to_device().find(device->id());
|
||||
if (iter->second != device) {
|
||||
@ -665,23 +665,24 @@ PYBIND11_MODULE(xla_extension, m) {
|
||||
device->DebugString(), client->platform_name());
|
||||
}
|
||||
GlobalPyRefManager()->CollectGarbage();
|
||||
|
||||
absl::optional<CastToArrayResult> c = CastToArray(argument);
|
||||
if (!c) {
|
||||
return InvalidArgument("from_python argument must be an array.");
|
||||
}
|
||||
|
||||
TF_ASSIGN_OR_RETURN(PythonBufferTree tree,
|
||||
GetPythonBufferTree(argument));
|
||||
std::shared_ptr<PythonRefManager::ManagedPyObjects> py_buffer_ref =
|
||||
GlobalPyRefManager()->ManageReferences(
|
||||
absl::MakeSpan(tree.arrays));
|
||||
tree.arrays.clear();
|
||||
|
||||
std::vector<BorrowingLiteral> leaves;
|
||||
leaves.insert(leaves.end(),
|
||||
std::make_move_iterator(tree.leaves.begin()),
|
||||
std::make_move_iterator(tree.leaves.end()));
|
||||
GlobalPyRefManager()->ManageReference(std::move(c->array));
|
||||
|
||||
py::gil_scoped_release gil_release;
|
||||
return PyLocalBuffer::FromLiterals(
|
||||
std::move(leaves), tree.shape, std::move(py_buffer_ref),
|
||||
return PyLocalBuffer::FromHostBuffer(
|
||||
c->buf_ptr, c->shape, force_copy, std::move(py_buffer_ref),
|
||||
std::move(client), std::move(device));
|
||||
})
|
||||
},
|
||||
py::arg("argument"), py::arg("client"), py::arg("device"),
|
||||
py::arg("force_copy") = false)
|
||||
.def_static("make_tuple",
|
||||
[](const std::vector<PyLocalBuffer*> buffers,
|
||||
std::shared_ptr<PyLocalClient> client,
|
||||
|
@ -71,7 +71,7 @@ class Backend(object, metaclass=abc.ABCMeta):
|
||||
"""Returns the integer ID of this host."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def buffer_from_pyval(self, pyval, device=None):
|
||||
def buffer_from_pyval(self, pyval, device=None, force_copy=False):
|
||||
"""Allocates a fresh buffer and populates it with `pyval`."""
|
||||
|
||||
@abc.abstractmethod
|
||||
@ -129,10 +129,11 @@ class LocalBackend(Backend):
|
||||
def host_id(self):
|
||||
return self.client.host_id()
|
||||
|
||||
def buffer_from_pyval(self, pyval, device=None):
|
||||
def buffer_from_pyval(self, pyval, device=None, force_copy=False):
|
||||
if device is None:
|
||||
device = self.local_devices()[0]
|
||||
return _xla.PyLocalBuffer.from_python(pyval, self.client, device)
|
||||
return _xla.PyLocalBuffer.from_python(pyval, self.client, device,
|
||||
force_copy)
|
||||
|
||||
def make_tuple(self, c_buffers, device):
|
||||
return _xla.PyLocalBuffer.make_tuple(c_buffers, self.client, device)
|
||||
@ -391,10 +392,10 @@ class Buffer(object):
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def from_pyval(pyval, device=None, backend=None):
|
||||
def from_pyval(pyval, device=None, backend=None, force_copy=False):
|
||||
"""Copies the `pyval` to a freshly allocated on-device buffer."""
|
||||
backend = backend or get_local_backend()
|
||||
return backend.buffer_from_pyval(pyval, device)
|
||||
return backend.buffer_from_pyval(pyval, device, force_copy=force_copy)
|
||||
|
||||
@staticmethod
|
||||
def make_tuple(buffers, device, backend=None):
|
||||
|
@ -471,15 +471,16 @@ class BufferTest(ComputationTest):
|
||||
compiled_c.Execute([arg_buffer])
|
||||
|
||||
def testDestructureTupleEmpty(self):
|
||||
t = ()
|
||||
local_buffer = xla_client.Buffer.from_pyval(t)
|
||||
device = xla_client.get_local_backend().devices()[0]
|
||||
local_buffer = xla_client.Buffer.make_tuple((), device=device)
|
||||
pieces = local_buffer.destructure()
|
||||
self.assertFalse(local_buffer.is_deleted())
|
||||
self.assertEmpty(pieces)
|
||||
|
||||
def testDestructureTupleOneArrayElement(self):
|
||||
t = (np.array([1, 2, 3, 4], dtype=np.int32),)
|
||||
local_buffer = xla_client.Buffer.from_pyval(t)
|
||||
device = xla_client.get_local_backend().devices()[0]
|
||||
t = xla_client.Buffer.from_pyval(np.array([1, 2, 3, 4], dtype=np.int32))
|
||||
local_buffer = xla_client.Buffer.make_tuple((t,), device)
|
||||
pieces = local_buffer.destructure()
|
||||
self.assertFalse(local_buffer.is_deleted())
|
||||
self.assertLen(pieces, 1)
|
||||
@ -489,11 +490,13 @@ class BufferTest(ComputationTest):
|
||||
np.testing.assert_equal(want, got)
|
||||
|
||||
def testDestructureTupleTwoArrayElementDifferentType(self):
|
||||
device = xla_client.get_local_backend().devices()[0]
|
||||
t = (
|
||||
np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32),
|
||||
np.array([2, 3, 4, 5], dtype=np.int32),
|
||||
xla_client.Buffer.from_pyval(
|
||||
np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)),
|
||||
xla_client.Buffer.from_pyval(np.array([2, 3, 4, 5], dtype=np.int32)),
|
||||
)
|
||||
local_buffer = xla_client.Buffer.from_pyval(t)
|
||||
local_buffer = xla_client.Buffer.make_tuple(t, device)
|
||||
# Run the test twice to verify that the original tuple buffer remains valid
|
||||
# even after destructuring.
|
||||
for _ in range(2):
|
||||
@ -509,8 +512,12 @@ class BufferTest(ComputationTest):
|
||||
np.testing.assert_equal(want, got)
|
||||
|
||||
def testDestructureTupleNested(self):
|
||||
t = ((NumpyArrayF32([1.0, 2.0]), NumpyArrayS32([3, 4])), NumpyArrayS32([5]))
|
||||
local_buffer = xla_client.Buffer.from_pyval(t)
|
||||
device = xla_client.get_local_backend().devices()[0]
|
||||
t = xla_client.Buffer.make_tuple(
|
||||
(xla_client.Buffer.from_pyval(NumpyArrayF32([1.0, 2.0])),
|
||||
xla_client.Buffer.from_pyval(NumpyArrayS32([3, 4]))), device)
|
||||
local_buffer = xla_client.Buffer.make_tuple(
|
||||
(t, xla_client.Buffer.from_pyval(NumpyArrayS32([5]))), device)
|
||||
pieces = local_buffer.destructure()
|
||||
self.assertFalse(local_buffer.is_deleted())
|
||||
self.assertLen(pieces, 2)
|
||||
@ -2119,12 +2126,20 @@ class BufferProtocolTest(parameterized.TestCase):
|
||||
} for dtype in standard_dtypes for shape in testcase_shapes)
|
||||
def testRoundTrip(self, dtype, shape):
|
||||
x = np.array(np.random.rand(*shape) * 100, dtype=dtype)
|
||||
x_ptr = x.__array_interface__["data"][0]
|
||||
backend = xla_client.get_local_backend("cpu")
|
||||
buffer = xla_client.Buffer.from_pyval(x, backend=backend)
|
||||
y = np.array(buffer, copy=False)
|
||||
y_ptr = y.__array_interface__["data"][0]
|
||||
np.testing.assert_array_equal(x, y)
|
||||
self.assertEqual(y.__array_interface__["data"][0],
|
||||
buffer.unsafe_buffer_pointer())
|
||||
# If the input was sufficiently aligned, the input and output should alias.
|
||||
self.assertTrue((x_ptr & 63) != 0 or x_ptr == y_ptr)
|
||||
self.assertEqual(y_ptr, buffer.unsafe_buffer_pointer())
|
||||
|
||||
buffer2 = xla_client.Buffer.from_pyval(x, backend=backend, force_copy=True)
|
||||
z = np.array(buffer2, copy=False)
|
||||
self.assertNotEqual(x.__array_interface__["data"][0],
|
||||
z.__array_interface__["data"][0])
|
||||
|
||||
def testDeleteWithActiveView(self):
|
||||
x = np.random.randn(20, 10)
|
||||
|
Loading…
Reference in New Issue
Block a user