[XLA:Python] Add support for zero-copy NumPy -> XLA buffer transfers on CPU.

When creating a PyLocalBuffer from a NumPy array, we don't need to copy the array if it is in the right layout and alignment already. This saves us a buffer copy and some thread hops.
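As an illustration, here is a minimal sketch of what the zero-copy path looks like from Python on the CPU backend, modeled on the round-trip test added in this change (the in-tree import path is an assumption for the example):

import numpy as np
from tensorflow.compiler.xla.python import xla_client  # assumed in-tree import path

backend = xla_client.get_local_backend("cpu")
x = np.arange(8, dtype=np.float32)  # C-contiguous by construction
buf = xla_client.Buffer.from_pyval(x, backend=backend)
y = np.array(buf, copy=False)
# When x happens to be 64-byte aligned, the device buffer aliases x's storage,
# so both arrays report the same data pointer; otherwise a copy was made.
x_ptr = x.__array_interface__["data"][0]
y_ptr = y.__array_interface__["data"][0]
assert (x_ptr % 64 != 0) or (x_ptr == y_ptr)
# force_copy=True always gives the buffer its own storage.
buf2 = xla_client.Buffer.from_pyval(x, backend=backend, force_copy=True)
z = np.array(buf2, copy=False)
assert z.__array_interface__["data"][0] != x_ptr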

Note that a copy still occurs if the input NumPy array is not in C-contiguous layout or is not 64-byte aligned.
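A caller who wants to guarantee the zero-copy path can check both conditions up front, or stage the data into an aligned, contiguous array first. The sketch below is illustrative only; aligned_copy is a hypothetical helper, not part of this change:

import numpy as np

def is_zero_copy_eligible(x, alignment=64):
  # The two conditions above: C-contiguous layout and 64-byte alignment.
  return x.flags["C_CONTIGUOUS"] and x.__array_interface__["data"][0] % alignment == 0

def aligned_copy(x, alignment=64):
  # Over-allocate a byte buffer, carve out an aligned contiguous view of the
  # right size, and copy the source array into it.
  nbytes = x.size * x.itemsize
  raw = np.empty(nbytes + alignment, dtype=np.uint8)
  offset = (-raw.ctypes.data) % alignment
  out = raw[offset:offset + nbytes].view(x.dtype).reshape(x.shape)
  np.copyto(out, x)
  return out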

Remove support for tuple trees in Buffer.from_pyval. JAX doesn't use tuple trees, and this simplifies the code and the API.
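In other words, callers that previously passed a (possibly nested) tuple of arrays to Buffer.from_pyval now transfer the leaves individually and assemble tuples on device with Buffer.make_tuple, as the updated tests do. A short sketch (import path assumed):

import numpy as np
from tensorflow.compiler.xla.python import xla_client  # assumed in-tree import path

backend = xla_client.get_local_backend()
device = backend.devices()[0]
# Transfer the leaves individually...
a = xla_client.Buffer.from_pyval(np.array([1.0, 2.0], dtype=np.float32))
b = xla_client.Buffer.from_pyval(np.array([3, 4], dtype=np.int32))
# ...then assemble them on device; nesting is expressed as tuples of tuples.
pair = xla_client.Buffer.make_tuple((a, b), device)
nested = xla_client.Buffer.make_tuple(
    (pair, xla_client.Buffer.from_pyval(np.array([5], dtype=np.int32))), device)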

PiperOrigin-RevId: 292650189
Change-Id: If6f7b34507c0aebc6dcbee136049b66e66d426c3
Author: Peter Hawkins (committed by TensorFlower Gardener)
Date: 2020-01-31 17:54:34 -08:00
Commit: 9fc5b891b8 (parent 419ebe51e0)
9 changed files with 129 additions and 83 deletions


@@ -291,21 +291,46 @@ StatusOr<DeviceAssignment> PyLocalClient::GetDefaultDeviceAssignment(
}
/* static */
StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalBuffer::FromLiterals(
std::vector<BorrowingLiteral> leaves_literals, const Shape& tuple_shape,
std::shared_ptr<void> leaves_reference,
StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalBuffer::FromHostBuffer(
const void* data, const Shape& shape, bool force_copy,
std::shared_ptr<void> buffer_reference,
std::shared_ptr<PyLocalClient> client, std::shared_ptr<Device> device) {
tensorflow::profiler::TraceMe traceme("PyLocalBuffer::FromLiterals");
VLOG(2) << "PyLocalBuffer::FromLiterals: shape: " << tuple_shape.ToString()
VLOG(2) << "PyLocalBuffer::FromLiterals: shape: " << shape.ToString()
<< " device: " << device->DebugString();
TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device,
device->GetLocalDeviceState());
// If we are on the host platform and the input buffer is sufficiently
// aligned, we can simply point to the NumPy array's data without any further
// copies. We require a 64-byte alignment because XLA may generate AVX512
// code which requires it. Unfortunately NumPy's allocator doesn't align
// quite as aggressively, so there's a high chance this test will fail.
static constexpr int kMinimumAlignment = 64;
if (!force_copy &&
((absl::bit_cast<std::uintptr_t>(data) & (kMinimumAlignment - 1)) == 0) &&
local_device->executor()->platform_kind() == se::PlatformKind::kHost) {
std::function<void()> on_delete_callback =
[buffer_reference{std::move(buffer_reference)}]() {
// Frees buffer_reference.
};
se::DeviceMemoryBase buffer(const_cast<void*>(data),
ShapeUtil::ByteSizeOf(shape));
auto device_buffer = std::make_shared<SharedDeviceBuffer>(
/*allocator=*/nullptr, local_device->device_ordinal(),
std::initializer_list<se::DeviceMemoryBase>{buffer},
/*children=*/std::vector<std::shared_ptr<SharedDeviceBuffer>>{},
/*definition_event=*/nullptr, std::move(on_delete_callback));
return absl::make_unique<PyLocalBuffer>(
shape, shape, std::move(device_buffer), std::move(client),
std::move(device));
}
TransferManager* transfer_manager =
client->client()->backend().transfer_manager();
se::DeviceMemoryAllocator* allocator = client->allocator();
TF_ASSIGN_OR_RETURN(
Shape compact_shape,
transfer_manager->ChooseCompactLayoutForShape(tuple_shape));
TF_ASSIGN_OR_RETURN(Shape compact_shape,
transfer_manager->ChooseCompactLayoutForShape(shape));
TF_ASSIGN_OR_RETURN(
ScopedShapedBuffer scoped_buffer,
transfer_manager->AllocateScopedShapedBuffer(
@@ -330,12 +355,9 @@ StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalBuffer::FromLiterals(
definition_event);
Shape on_device_shape = scoped_buffer.on_device_shape();
// TODO(makro): Use move capture once C++ 14 features are available.
auto leaves = std::make_shared<std::vector<BorrowingLiteral>>(
std::move(leaves_literals));
auto transfer_h2d = [client, transfer_manager, local_device, device_buffer,
compact_shape, on_device_shape, leaves,
leaves_reference]() {
shape, compact_shape, on_device_shape, data,
buffer_reference{std::move(buffer_reference)}]() {
// This function uses TF_CHECK_OK and ValueOrDie() since we have no way to
// report failures from a callback. However, the operations here are
// unlikely to fail and not recoverable even if we were to fail: DMAs to
@@ -344,39 +366,27 @@ StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalBuffer::FromLiterals(
compact_shape, on_device_shape, client->client()->platform());
TF_CHECK_OK(transfer_manager->WriteTupleIndexTablesAsync(
local_device->host_to_device_stream(), buffer));
std::vector<std::shared_ptr<void>> staging_buffers;
staging_buffers.reserve(leaves->size());
auto it = leaves->begin();
for (const ShapeUtil::IndexedShape& indexed_shape :
ShapeUtil::GetLeafShapes(compact_shape)) {
CHECK(it != leaves->end());
ShapedBuffer leaf(
indexed_shape.shape,
transfer_manager->HostShapeToDeviceShape(indexed_shape.shape),
client->client()->platform(), local_device->device_ordinal());
leaf.buffers().CopySubtreeFrom(buffer.buffers(), indexed_shape.index, {});
std::shared_ptr<void> staging_buffer;
// If applicable on the backend, stage the transfer via host memory
// allocated via the host_memory_allocator. On GPU, this is pinned memory.
if (client->host_memory_allocator()) {
int64 size = it->size_bytes({});
void* ptr = client->host_memory_allocator()->AllocateRaw(
tensorflow::Allocator::kAllocatorAlignment, size);
std::shared_ptr<void> staging_buffer(ptr, [client](void* ptr) {
client->host_memory_allocator()->DeallocateRaw(ptr);
});
std::memcpy(ptr, it->untyped_data({}), size);
BorrowingLiteral literal(static_cast<const char*>(staging_buffer.get()),
it->shape());
TF_CHECK_OK(transfer_manager->TransferLiteralToDeviceAsync(
local_device->host_to_device_stream(), literal, leaf));
staging_buffers.push_back(std::move(staging_buffer));
} else {
// Otherwise, just transfer the literal.
TF_CHECK_OK(transfer_manager->TransferLiteralToDeviceAsync(
local_device->host_to_device_stream(), *it, leaf));
}
++it;
// If applicable on the backend, stage the transfer via host memory
// allocated via the host_memory_allocator. On GPU, this is pinned memory.
if (client->host_memory_allocator()) {
int64 size = ShapeUtil::ByteSizeOf(shape);
void* ptr = client->host_memory_allocator()->AllocateRaw(
tensorflow::Allocator::kAllocatorAlignment, size);
staging_buffer = std::shared_ptr<void>(ptr, [client](void* ptr) {
client->host_memory_allocator()->DeallocateRaw(ptr);
});
std::memcpy(ptr, data, size);
BorrowingLiteral literal(static_cast<const char*>(staging_buffer.get()),
shape);
TF_CHECK_OK(transfer_manager->TransferLiteralToDeviceAsync(
local_device->host_to_device_stream(), literal, buffer));
} else {
BorrowingLiteral literal(static_cast<const char*>(data), shape);
// Otherwise, just transfer the literal.
TF_CHECK_OK(transfer_manager->TransferLiteralToDeviceAsync(
local_device->host_to_device_stream(), literal, buffer));
}
EventPool::Handle event =
@@ -397,7 +407,7 @@ StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalBuffer::FromLiterals(
local_device->ThenRelease(
local_device->host_to_device_stream(),
std::make_pair(leaves_reference, std::move(staging_buffers)));
std::make_pair(buffer_reference, std::move(staging_buffer)));
};
client->h2d_transfer_pool()->Schedule(transfer_h2d);
return absl::make_unique<PyLocalBuffer>(


@@ -202,9 +202,14 @@ class PyLocalClient {
// Thread-safe.
class PyLocalBuffer {
public:
static StatusOr<std::unique_ptr<PyLocalBuffer>> FromLiterals(
std::vector<BorrowingLiteral> leaves_literals, const Shape& tuple_shape,
std::shared_ptr<void> leaves_reference,
// If `force_copy` is true, forces a copy of the input buffer on CPU.
// Otherwise the library is free to alias the output buffer with `data`.
// `buffer_reference` is an optional shared pointer that should be kept alive
// by the runtime as long as the contents of `data` may still be accessed by
// the runtime (may be nullptr).
static StatusOr<std::unique_ptr<PyLocalBuffer>> FromHostBuffer(
const void* data, const Shape& shape, bool force_copy,
std::shared_ptr<void> buffer_reference,
std::shared_ptr<PyLocalClient> client, std::shared_ptr<Device> device);
static StatusOr<std::unique_ptr<PyLocalBuffer>> MakeTuple(


@@ -15,6 +15,8 @@ limitations under the License.
#include "tensorflow/compiler/xla/python/python_ref_manager.h"
#include "absl/container/inlined_vector.h"
namespace xla {
namespace py = pybind11;
@@ -37,16 +39,27 @@ PythonRefManager::ManagedPyObjects::~ManagedPyObjects() {
}
}
std::shared_ptr<PythonRefManager::ManagedPyObjects>
PythonRefManager::ManageReference(py::object object) {
return std::make_shared<ManagedPyObjects>(this,
absl::Span<py::object>(&object, 1));
}
std::shared_ptr<PythonRefManager::ManagedPyObjects>
PythonRefManager::ManageReferences(absl::Span<py::object> objects) {
return std::make_shared<ManagedPyObjects>(this, objects);
}
void PythonRefManager::CollectGarbage() {
// TODO(phawkins): ideally we would assert that the GIL is held, but there is
// no API to do this across all Python versions.
absl::MutexLock lock(&mu_);
python_garbage_.clear();
// TODO(phawkins): we should CHECK(PyGILState_Check());
std::deque<pybind11::object> garbage;
{
absl::MutexLock lock(&mu_);
garbage.swap(python_garbage_);
}
// We defer deleting garbage until the lock is released. It's possible that
// deleting garbage will lead to more Python garbage being added; if we held
// the lock we would deadlock because absl::Mutex is not reentrant.
}
PythonRefManager* GlobalPyRefManager() {


@@ -62,6 +62,7 @@ class PythonRefManager {
// Creates a managed std::shared_ptr to an object. When the shared_ptr is
// destroyed, the reference to 'object' will be added to python_garbage_,
// and collected next time CollectGarbage() is called.
std::shared_ptr<ManagedPyObjects> ManageReference(pybind11::object object);
std::shared_ptr<ManagedPyObjects> ManageReferences(
absl::Span<pybind11::object> objects);


@@ -81,7 +81,7 @@ class TpuBackend(xla_client.Backend):
def host_id(self):
return self.client.host_id()
def buffer_from_pyval(self, pyval, device=None):
def buffer_from_pyval(self, pyval, device=None, force_copy=False):
if device is None:
device = self.client.local_devices()[0]
return _tpu_client.PyTpuBuffer.from_python(pyval, self.client, device)


@@ -96,7 +96,7 @@ std::vector<int64> IntSequenceToVector(const pybind11::object& sequence);
// xla::BorrowingLiteral. Converts a Python array-like object into a buffer
// pointer and shape.
struct CastToArrayResult {
pybind11::array array; // Holds a reference to the array to keep it alive.
pybind11::object array; // Holds a reference to the array to keep it alive.
const char* buf_ptr;
xla::Shape shape;
};


@@ -655,8 +655,8 @@ PYBIND11_MODULE(xla_extension, m) {
"from_python",
[](const pybind11::object& argument,
std::shared_ptr<PyLocalClient> client,
std::shared_ptr<Device> device)
-> StatusOr<std::unique_ptr<PyLocalBuffer>> {
std::shared_ptr<Device> device,
bool force_copy) -> StatusOr<std::unique_ptr<PyLocalBuffer>> {
CHECK(device != nullptr);
auto iter = client->id_to_device().find(device->id());
if (iter->second != device) {
@@ -665,23 +665,24 @@ PYBIND11_MODULE(xla_extension, m) {
device->DebugString(), client->platform_name());
}
GlobalPyRefManager()->CollectGarbage();
absl::optional<CastToArrayResult> c = CastToArray(argument);
if (!c) {
return InvalidArgument("from_python argument must be an array.");
}
TF_ASSIGN_OR_RETURN(PythonBufferTree tree,
GetPythonBufferTree(argument));
std::shared_ptr<PythonRefManager::ManagedPyObjects> py_buffer_ref =
GlobalPyRefManager()->ManageReferences(
absl::MakeSpan(tree.arrays));
tree.arrays.clear();
std::vector<BorrowingLiteral> leaves;
leaves.insert(leaves.end(),
std::make_move_iterator(tree.leaves.begin()),
std::make_move_iterator(tree.leaves.end()));
GlobalPyRefManager()->ManageReference(std::move(c->array));
py::gil_scoped_release gil_release;
return PyLocalBuffer::FromLiterals(
std::move(leaves), tree.shape, std::move(py_buffer_ref),
return PyLocalBuffer::FromHostBuffer(
c->buf_ptr, c->shape, force_copy, std::move(py_buffer_ref),
std::move(client), std::move(device));
})
},
py::arg("argument"), py::arg("client"), py::arg("device"),
py::arg("force_copy") = false)
.def_static("make_tuple",
[](const std::vector<PyLocalBuffer*> buffers,
std::shared_ptr<PyLocalClient> client,


@@ -71,7 +71,7 @@ class Backend(object, metaclass=abc.ABCMeta):
"""Returns the integer ID of this host."""
@abc.abstractmethod
def buffer_from_pyval(self, pyval, device=None):
def buffer_from_pyval(self, pyval, device=None, force_copy=False):
"""Allocates a fresh buffer and populates it with `pyval`."""
@abc.abstractmethod
@@ -129,10 +129,11 @@ class LocalBackend(Backend):
def host_id(self):
return self.client.host_id()
def buffer_from_pyval(self, pyval, device=None):
def buffer_from_pyval(self, pyval, device=None, force_copy=False):
if device is None:
device = self.local_devices()[0]
return _xla.PyLocalBuffer.from_python(pyval, self.client, device)
return _xla.PyLocalBuffer.from_python(pyval, self.client, device,
force_copy)
def make_tuple(self, c_buffers, device):
return _xla.PyLocalBuffer.make_tuple(c_buffers, self.client, device)
@@ -391,10 +392,10 @@ class Buffer(object):
"""
@staticmethod
def from_pyval(pyval, device=None, backend=None):
def from_pyval(pyval, device=None, backend=None, force_copy=False):
"""Copies the `pyval` to a freshly allocated on-device buffer."""
backend = backend or get_local_backend()
return backend.buffer_from_pyval(pyval, device)
return backend.buffer_from_pyval(pyval, device, force_copy=force_copy)
@staticmethod
def make_tuple(buffers, device, backend=None):


@@ -471,15 +471,16 @@ class BufferTest(ComputationTest):
compiled_c.Execute([arg_buffer])
def testDestructureTupleEmpty(self):
t = ()
local_buffer = xla_client.Buffer.from_pyval(t)
device = xla_client.get_local_backend().devices()[0]
local_buffer = xla_client.Buffer.make_tuple((), device=device)
pieces = local_buffer.destructure()
self.assertFalse(local_buffer.is_deleted())
self.assertEmpty(pieces)
def testDestructureTupleOneArrayElement(self):
t = (np.array([1, 2, 3, 4], dtype=np.int32),)
local_buffer = xla_client.Buffer.from_pyval(t)
device = xla_client.get_local_backend().devices()[0]
t = xla_client.Buffer.from_pyval(np.array([1, 2, 3, 4], dtype=np.int32))
local_buffer = xla_client.Buffer.make_tuple((t,), device)
pieces = local_buffer.destructure()
self.assertFalse(local_buffer.is_deleted())
self.assertLen(pieces, 1)
@@ -489,11 +490,13 @@ class BufferTest(ComputationTest):
np.testing.assert_equal(want, got)
def testDestructureTupleTwoArrayElementDifferentType(self):
device = xla_client.get_local_backend().devices()[0]
t = (
np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32),
np.array([2, 3, 4, 5], dtype=np.int32),
xla_client.Buffer.from_pyval(
np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)),
xla_client.Buffer.from_pyval(np.array([2, 3, 4, 5], dtype=np.int32)),
)
local_buffer = xla_client.Buffer.from_pyval(t)
local_buffer = xla_client.Buffer.make_tuple(t, device)
# Run the test twice to verify that the original tuple buffer remains valid
# even after destructuring.
for _ in range(2):
@@ -509,8 +512,12 @@ class BufferTest(ComputationTest):
np.testing.assert_equal(want, got)
def testDestructureTupleNested(self):
t = ((NumpyArrayF32([1.0, 2.0]), NumpyArrayS32([3, 4])), NumpyArrayS32([5]))
local_buffer = xla_client.Buffer.from_pyval(t)
device = xla_client.get_local_backend().devices()[0]
t = xla_client.Buffer.make_tuple(
(xla_client.Buffer.from_pyval(NumpyArrayF32([1.0, 2.0])),
xla_client.Buffer.from_pyval(NumpyArrayS32([3, 4]))), device)
local_buffer = xla_client.Buffer.make_tuple(
(t, xla_client.Buffer.from_pyval(NumpyArrayS32([5]))), device)
pieces = local_buffer.destructure()
self.assertFalse(local_buffer.is_deleted())
self.assertLen(pieces, 2)
@@ -2119,12 +2126,20 @@ class BufferProtocolTest(parameterized.TestCase):
} for dtype in standard_dtypes for shape in testcase_shapes)
def testRoundTrip(self, dtype, shape):
x = np.array(np.random.rand(*shape) * 100, dtype=dtype)
x_ptr = x.__array_interface__["data"][0]
backend = xla_client.get_local_backend("cpu")
buffer = xla_client.Buffer.from_pyval(x, backend=backend)
y = np.array(buffer, copy=False)
y_ptr = y.__array_interface__["data"][0]
np.testing.assert_array_equal(x, y)
self.assertEqual(y.__array_interface__["data"][0],
buffer.unsafe_buffer_pointer())
# If the input was sufficiently aligned, the input and output should alias.
self.assertTrue((x_ptr & 63) != 0 or x_ptr == y_ptr)
self.assertEqual(y_ptr, buffer.unsafe_buffer_pointer())
buffer2 = xla_client.Buffer.from_pyval(x, backend=backend, force_copy=True)
z = np.array(buffer2, copy=False)
self.assertNotEqual(x.__array_interface__["data"][0],
z.__array_interface__["data"][0])
def testDeleteWithActiveView(self):
x = np.random.randn(20, 10)