[XLA:Python] Add support for zero-copy NumPy -> XLA buffer transfers on CPU.
When creating a PyLocalBuffer from a NumPy array, we do not need to copy the array if it already has the right layout and alignment; this saves a buffer copy and some thread hops. A copy is still made if the input NumPy array is not in C-contiguous layout or is not 64-byte aligned.

Also remove support for tuple trees in Buffer.from_pyval: JAX does not use tuple trees, and dropping them simplifies both the code and the API.

PiperOrigin-RevId: 292650189
Change-Id: If6f7b34507c0aebc6dcbee136049b66e66d426c3
Parent: 419ebe51e0
Commit: 9fc5b891b8
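A minimal usage sketch of the Python API this change introduces (import path as used in the XLA Python tests; behaviour on the default path is hedged, since aliasing only happens when the input qualifies for the CPU zero-copy path):

    import numpy as np
    from tensorflow.compiler.xla.python import xla_client

    backend = xla_client.get_local_backend("cpu")
    x = np.arange(16, dtype=np.float32)

    # Default: the runtime may alias the NumPy data if it is C-contiguous,
    # 64-byte aligned, and the target device is a host (CPU) device.
    buf = xla_client.Buffer.from_pyval(x, backend=backend)

    # force_copy=True always makes a private copy of the host data.
    buf_copy = xla_client.Buffer.from_pyval(x, backend=backend, force_copy=True)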
@@ -291,21 +291,46 @@ StatusOr<DeviceAssignment> PyLocalClient::GetDefaultDeviceAssignment(
   }
 }
 
 /* static */
-StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalBuffer::FromLiterals(
-    std::vector<BorrowingLiteral> leaves_literals, const Shape& tuple_shape,
-    std::shared_ptr<void> leaves_reference,
+StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalBuffer::FromHostBuffer(
+    const void* data, const Shape& shape, bool force_copy,
+    std::shared_ptr<void> buffer_reference,
     std::shared_ptr<PyLocalClient> client, std::shared_ptr<Device> device) {
   tensorflow::profiler::TraceMe traceme("PyLocalBuffer::FromLiterals");
-  VLOG(2) << "PyLocalBuffer::FromLiterals: shape: " << tuple_shape.ToString()
+  VLOG(2) << "PyLocalBuffer::FromLiterals: shape: " << shape.ToString()
           << " device: " << device->DebugString();
   TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device,
                       device->GetLocalDeviceState());
 
+  // If we are on the host platform and the input buffer is sufficiently
+  // aligned, we can simply point to the NumPy array's data without any further
+  // copies. We require a 64-byte alignment because XLA may generate AVX512
+  // code which requires it. Unfortunately NumPy's allocator doesn't align
+  // quite as aggressively, so there's a high chance this test will fail.
+  static constexpr int kMinimumAlignment = 64;
+  if (!force_copy &&
+      ((absl::bit_cast<std::uintptr_t>(data) & (kMinimumAlignment - 1)) == 0) &&
+      local_device->executor()->platform_kind() == se::PlatformKind::kHost) {
+    std::function<void()> on_delete_callback =
+        [buffer_reference{std::move(buffer_reference)}]() {
+          // Frees buffer_reference.
+        };
+    se::DeviceMemoryBase buffer(const_cast<void*>(data),
+                                ShapeUtil::ByteSizeOf(shape));
+    auto device_buffer = std::make_shared<SharedDeviceBuffer>(
+        /*allocator=*/nullptr, local_device->device_ordinal(),
+        std::initializer_list<se::DeviceMemoryBase>{buffer},
+        /*children=*/std::vector<std::shared_ptr<SharedDeviceBuffer>>{},
+        /*definition_event=*/nullptr, std::move(on_delete_callback));
+    return absl::make_unique<PyLocalBuffer>(
+        shape, shape, std::move(device_buffer), std::move(client),
+        std::move(device));
+  }
+
   TransferManager* transfer_manager =
       client->client()->backend().transfer_manager();
   se::DeviceMemoryAllocator* allocator = client->allocator();
-  TF_ASSIGN_OR_RETURN(
-      Shape compact_shape,
-      transfer_manager->ChooseCompactLayoutForShape(tuple_shape));
+  TF_ASSIGN_OR_RETURN(Shape compact_shape,
+                      transfer_manager->ChooseCompactLayoutForShape(shape));
   TF_ASSIGN_OR_RETURN(
       ScopedShapedBuffer scoped_buffer,
       transfer_manager->AllocateScopedShapedBuffer(
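For orientation, the fast path added above only triggers when the host pointer is 64-byte aligned and the device sits on the host (CPU) platform. The equivalent precondition can be sketched on the NumPy side; this is illustrative only (the authoritative test is the C++ condition above), and the helper name is hypothetical:

    import numpy as np

    def cpu_zero_copy_eligible(x: np.ndarray) -> bool:
        # Mirrors the C++ fast-path condition: C-contiguous data whose base
        # address is 64-byte aligned (XLA may emit AVX512 loads/stores).
        return x.flags["C_CONTIGUOUS"] and (x.ctypes.data % 64 == 0)

NumPy's default allocator does not guarantee 64-byte alignment, so this predicate is often false, in which case the copying path below is taken instead.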
@@ -330,12 +355,9 @@ StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalBuffer::FromLiterals(
       definition_event);
   Shape on_device_shape = scoped_buffer.on_device_shape();
 
-  // TODO(makro): Use move capture once C++ 14 features are available.
-  auto leaves = std::make_shared<std::vector<BorrowingLiteral>>(
-      std::move(leaves_literals));
   auto transfer_h2d = [client, transfer_manager, local_device, device_buffer,
-                       compact_shape, on_device_shape, leaves,
-                       leaves_reference]() {
+                       shape, compact_shape, on_device_shape, data,
+                       buffer_reference{std::move(buffer_reference)}]() {
     // This function uses TF_CHECK_OK and ValueOrDie() since we have no way to
     // report failures from a callback. However, the operations here are
     // unlikely to fail and not recoverable even if we were to fail: DMAs to
@@ -344,39 +366,27 @@ StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalBuffer::FromLiterals(
         compact_shape, on_device_shape, client->client()->platform());
     TF_CHECK_OK(transfer_manager->WriteTupleIndexTablesAsync(
         local_device->host_to_device_stream(), buffer));
-    std::vector<std::shared_ptr<void>> staging_buffers;
-    staging_buffers.reserve(leaves->size());
-    auto it = leaves->begin();
-    for (const ShapeUtil::IndexedShape& indexed_shape :
-         ShapeUtil::GetLeafShapes(compact_shape)) {
-      CHECK(it != leaves->end());
-      ShapedBuffer leaf(
-          indexed_shape.shape,
-          transfer_manager->HostShapeToDeviceShape(indexed_shape.shape),
-          client->client()->platform(), local_device->device_ordinal());
-      leaf.buffers().CopySubtreeFrom(buffer.buffers(), indexed_shape.index, {});
-
-      // If applicable on the backend, stage the transfer via host memory
-      // allocated via the host_memory_allocator. On GPU, this is pinned memory.
-      if (client->host_memory_allocator()) {
-        int64 size = it->size_bytes({});
-        void* ptr = client->host_memory_allocator()->AllocateRaw(
-            tensorflow::Allocator::kAllocatorAlignment, size);
-        std::shared_ptr<void> staging_buffer(ptr, [client](void* ptr) {
-          client->host_memory_allocator()->DeallocateRaw(ptr);
-        });
-        std::memcpy(ptr, it->untyped_data({}), size);
-        BorrowingLiteral literal(static_cast<const char*>(staging_buffer.get()),
-                                 it->shape());
-        TF_CHECK_OK(transfer_manager->TransferLiteralToDeviceAsync(
-            local_device->host_to_device_stream(), literal, leaf));
-        staging_buffers.push_back(std::move(staging_buffer));
-      } else {
-        // Otherwise, just transfer the literal.
-        TF_CHECK_OK(transfer_manager->TransferLiteralToDeviceAsync(
-            local_device->host_to_device_stream(), *it, leaf));
-      }
-      ++it;
-    }
+    std::shared_ptr<void> staging_buffer;
+
+    // If applicable on the backend, stage the transfer via host memory
+    // allocated via the host_memory_allocator. On GPU, this is pinned memory.
+    if (client->host_memory_allocator()) {
+      int64 size = ShapeUtil::ByteSizeOf(shape);
+      void* ptr = client->host_memory_allocator()->AllocateRaw(
+          tensorflow::Allocator::kAllocatorAlignment, size);
+      staging_buffer = std::shared_ptr<void>(ptr, [client](void* ptr) {
+        client->host_memory_allocator()->DeallocateRaw(ptr);
+      });
+      std::memcpy(ptr, data, size);
+      BorrowingLiteral literal(static_cast<const char*>(staging_buffer.get()),
+                               shape);
+      TF_CHECK_OK(transfer_manager->TransferLiteralToDeviceAsync(
+          local_device->host_to_device_stream(), literal, buffer));
+    } else {
+      BorrowingLiteral literal(static_cast<const char*>(data), shape);
+      // Otherwise, just transfer the literal.
+      TF_CHECK_OK(transfer_manager->TransferLiteralToDeviceAsync(
+          local_device->host_to_device_stream(), literal, buffer));
+    }
 
     EventPool::Handle event =
@@ -397,7 +407,7 @@ StatusOr<std::unique_ptr<PyLocalBuffer>> PyLocalBuffer::FromLiterals(
 
     local_device->ThenRelease(
         local_device->host_to_device_stream(),
-        std::make_pair(leaves_reference, std::move(staging_buffers)));
+        std::make_pair(buffer_reference, std::move(staging_buffer)));
   };
   client->h2d_transfer_pool()->Schedule(transfer_h2d);
   return absl::make_unique<PyLocalBuffer>(
@@ -202,9 +202,14 @@ class PyLocalClient {
 // Thread-safe.
 class PyLocalBuffer {
  public:
-  static StatusOr<std::unique_ptr<PyLocalBuffer>> FromLiterals(
-      std::vector<BorrowingLiteral> leaves_literals, const Shape& tuple_shape,
-      std::shared_ptr<void> leaves_reference,
+  // If `force_copy` is true, forces a copy of the input buffer on CPU.
+  // Otherwise the library is free to alias the output buffer with `data`.
+  // `buffer_reference` is an optional shared pointer that should be kept alive
+  // by the runtime as long as the contents of `data` may still be accessed by
+  // the runtime (may be nullptr).
+  static StatusOr<std::unique_ptr<PyLocalBuffer>> FromHostBuffer(
+      const void* data, const Shape& shape, bool force_copy,
+      std::shared_ptr<void> buffer_reference,
       std::shared_ptr<PyLocalClient> client, std::shared_ptr<Device> device);
 
   static StatusOr<std::unique_ptr<PyLocalBuffer>> MakeTuple(
@@ -15,6 +15,8 @@ limitations under the License.
 
 #include "tensorflow/compiler/xla/python/python_ref_manager.h"
 
+#include "absl/container/inlined_vector.h"
+
 namespace xla {
 
 namespace py = pybind11;
@@ -37,16 +39,27 @@ PythonRefManager::ManagedPyObjects::~ManagedPyObjects() {
   }
 }
 
+std::shared_ptr<PythonRefManager::ManagedPyObjects>
+PythonRefManager::ManageReference(py::object object) {
+  return std::make_shared<ManagedPyObjects>(this,
+                                            absl::Span<py::object>(&object, 1));
+}
+
 std::shared_ptr<PythonRefManager::ManagedPyObjects>
 PythonRefManager::ManageReferences(absl::Span<py::object> objects) {
   return std::make_shared<ManagedPyObjects>(this, objects);
 }
 
 void PythonRefManager::CollectGarbage() {
-  // TODO(phawkins): ideally we would assert that the GIL is held, but there is
-  // no API to do this across all Python versions.
-  absl::MutexLock lock(&mu_);
-  python_garbage_.clear();
+  // TODO(phawkins): we should CHECK(PyGILState_Check());
+  std::deque<pybind11::object> garbage;
+  {
+    absl::MutexLock lock(&mu_);
+    garbage.swap(python_garbage_);
+  }
+  // We defer deleting garbage until the lock is released. It's possible that
+  // deleting garbage will lead to more Python garbage being added; if we held
+  // the lock we would deadlock because absl::Mutex is not reentrant.
 }
 
 PythonRefManager* GlobalPyRefManager() {
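The CollectGarbage() change above swaps the pending garbage out while holding the mutex and drops the Python references only after the lock is released, because dropping them can enqueue more garbage and absl::Mutex is not reentrant. A small Python sketch of the same pattern (class and method names are illustrative, not part of the XLA API):

    import threading
    from collections import deque

    class RefManager:
        def __init__(self):
            self._lock = threading.Lock()
            self._garbage = deque()

        def defer_delete(self, obj):
            with self._lock:
                self._garbage.append(obj)

        def collect_garbage(self):
            with self._lock:
                garbage, self._garbage = self._garbage, deque()
            # Drop the references outside the lock: their destructors may call
            # defer_delete() again, which would deadlock on a held lock.
            garbage.clear()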
@@ -62,6 +62,7 @@ class PythonRefManager {
   // Creates a managed std::shared_ptr to an object. When the shared_ptr is
   // destroyed, the reference to 'object' will be added to python_garbage_,
   // and collected next time CollectGarbage() is called.
+  std::shared_ptr<ManagedPyObjects> ManageReference(pybind11::object object);
   std::shared_ptr<ManagedPyObjects> ManageReferences(
       absl::Span<pybind11::object> objects);
 
@@ -81,7 +81,7 @@ class TpuBackend(xla_client.Backend):
   def host_id(self):
     return self.client.host_id()
 
-  def buffer_from_pyval(self, pyval, device=None):
+  def buffer_from_pyval(self, pyval, device=None, force_copy=False):
     if device is None:
       device = self.client.local_devices()[0]
     return _tpu_client.PyTpuBuffer.from_python(pyval, self.client, device)
@@ -96,7 +96,7 @@ std::vector<int64> IntSequenceToVector(const pybind11::object& sequence);
 // xla::BorrowingLiteral. Converts a Python array-like object into a buffer
 // pointer and shape.
 struct CastToArrayResult {
-  pybind11::array array;  // Holds a reference to the array to keep it alive.
+  pybind11::object array;  // Holds a reference to the array to keep it alive.
   const char* buf_ptr;
   xla::Shape shape;
 };
@@ -655,8 +655,8 @@ PYBIND11_MODULE(xla_extension, m) {
           "from_python",
           [](const pybind11::object& argument,
              std::shared_ptr<PyLocalClient> client,
-             std::shared_ptr<Device> device)
-              -> StatusOr<std::unique_ptr<PyLocalBuffer>> {
+             std::shared_ptr<Device> device,
+             bool force_copy) -> StatusOr<std::unique_ptr<PyLocalBuffer>> {
            CHECK(device != nullptr);
            auto iter = client->id_to_device().find(device->id());
            if (iter->second != device) {
|
|||||||
device->DebugString(), client->platform_name());
|
device->DebugString(), client->platform_name());
|
||||||
}
|
}
|
||||||
GlobalPyRefManager()->CollectGarbage();
|
GlobalPyRefManager()->CollectGarbage();
|
||||||
|
|
||||||
|
absl::optional<CastToArrayResult> c = CastToArray(argument);
|
||||||
|
if (!c) {
|
||||||
|
return InvalidArgument("from_python argument must be an array.");
|
||||||
|
}
|
||||||
|
|
||||||
TF_ASSIGN_OR_RETURN(PythonBufferTree tree,
|
TF_ASSIGN_OR_RETURN(PythonBufferTree tree,
|
||||||
GetPythonBufferTree(argument));
|
GetPythonBufferTree(argument));
|
||||||
std::shared_ptr<PythonRefManager::ManagedPyObjects> py_buffer_ref =
|
std::shared_ptr<PythonRefManager::ManagedPyObjects> py_buffer_ref =
|
||||||
GlobalPyRefManager()->ManageReferences(
|
GlobalPyRefManager()->ManageReference(std::move(c->array));
|
||||||
absl::MakeSpan(tree.arrays));
|
|
||||||
tree.arrays.clear();
|
|
||||||
|
|
||||||
std::vector<BorrowingLiteral> leaves;
|
|
||||||
leaves.insert(leaves.end(),
|
|
||||||
std::make_move_iterator(tree.leaves.begin()),
|
|
||||||
std::make_move_iterator(tree.leaves.end()));
|
|
||||||
|
|
||||||
py::gil_scoped_release gil_release;
|
py::gil_scoped_release gil_release;
|
||||||
return PyLocalBuffer::FromLiterals(
|
return PyLocalBuffer::FromHostBuffer(
|
||||||
std::move(leaves), tree.shape, std::move(py_buffer_ref),
|
c->buf_ptr, c->shape, force_copy, std::move(py_buffer_ref),
|
||||||
std::move(client), std::move(device));
|
std::move(client), std::move(device));
|
||||||
})
|
},
|
||||||
|
py::arg("argument"), py::arg("client"), py::arg("device"),
|
||||||
|
py::arg("force_copy") = false)
|
||||||
.def_static("make_tuple",
|
.def_static("make_tuple",
|
||||||
[](const std::vector<PyLocalBuffer*> buffers,
|
[](const std::vector<PyLocalBuffer*> buffers,
|
||||||
std::shared_ptr<PyLocalClient> client,
|
std::shared_ptr<PyLocalClient> client,
|
||||||
|
@@ -71,7 +71,7 @@ class Backend(object, metaclass=abc.ABCMeta):
     """Returns the integer ID of this host."""
 
   @abc.abstractmethod
-  def buffer_from_pyval(self, pyval, device=None):
+  def buffer_from_pyval(self, pyval, device=None, force_copy=False):
     """Allocates a fresh buffer and populates it with `pyval`."""
 
   @abc.abstractmethod
@@ -129,10 +129,11 @@ class LocalBackend(Backend):
   def host_id(self):
     return self.client.host_id()
 
-  def buffer_from_pyval(self, pyval, device=None):
+  def buffer_from_pyval(self, pyval, device=None, force_copy=False):
     if device is None:
       device = self.local_devices()[0]
-    return _xla.PyLocalBuffer.from_python(pyval, self.client, device)
+    return _xla.PyLocalBuffer.from_python(pyval, self.client, device,
+                                          force_copy)
 
   def make_tuple(self, c_buffers, device):
     return _xla.PyLocalBuffer.make_tuple(c_buffers, self.client, device)
@@ -391,10 +392,10 @@ class Buffer(object):
     """
 
   @staticmethod
-  def from_pyval(pyval, device=None, backend=None):
+  def from_pyval(pyval, device=None, backend=None, force_copy=False):
     """Copies the `pyval` to a freshly allocated on-device buffer."""
     backend = backend or get_local_backend()
-    return backend.buffer_from_pyval(pyval, device)
+    return backend.buffer_from_pyval(pyval, device, force_copy=force_copy)
 
   @staticmethod
   def make_tuple(buffers, device, backend=None):
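One practical consequence of the aliasing default, sketched below. This is hedged: it only applies when the CPU zero-copy path was actually taken (contiguous, 64-byte-aligned input on a host device). When the device buffer aliases the NumPy array, later in-place mutation of the array can be observed through the buffer, whereas force_copy=True takes a private snapshot.

    import numpy as np
    from tensorflow.compiler.xla.python import xla_client

    backend = xla_client.get_local_backend("cpu")
    x = np.zeros(8, dtype=np.float32)
    snapshot = xla_client.Buffer.from_pyval(x, backend=backend, force_copy=True)
    maybe_alias = xla_client.Buffer.from_pyval(x, backend=backend)
    x[0] = 1.0  # may be observable via maybe_alias, never via snapshot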
@@ -471,15 +471,16 @@ class BufferTest(ComputationTest):
     compiled_c.Execute([arg_buffer])
 
   def testDestructureTupleEmpty(self):
-    t = ()
-    local_buffer = xla_client.Buffer.from_pyval(t)
+    device = xla_client.get_local_backend().devices()[0]
+    local_buffer = xla_client.Buffer.make_tuple((), device=device)
     pieces = local_buffer.destructure()
     self.assertFalse(local_buffer.is_deleted())
     self.assertEmpty(pieces)
 
   def testDestructureTupleOneArrayElement(self):
-    t = (np.array([1, 2, 3, 4], dtype=np.int32),)
-    local_buffer = xla_client.Buffer.from_pyval(t)
+    device = xla_client.get_local_backend().devices()[0]
+    t = xla_client.Buffer.from_pyval(np.array([1, 2, 3, 4], dtype=np.int32))
+    local_buffer = xla_client.Buffer.make_tuple((t,), device)
     pieces = local_buffer.destructure()
     self.assertFalse(local_buffer.is_deleted())
     self.assertLen(pieces, 1)
@@ -489,11 +490,13 @@ class BufferTest(ComputationTest):
       np.testing.assert_equal(want, got)
 
   def testDestructureTupleTwoArrayElementDifferentType(self):
+    device = xla_client.get_local_backend().devices()[0]
     t = (
-        np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32),
-        np.array([2, 3, 4, 5], dtype=np.int32),
+        xla_client.Buffer.from_pyval(
+            np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)),
+        xla_client.Buffer.from_pyval(np.array([2, 3, 4, 5], dtype=np.int32)),
     )
-    local_buffer = xla_client.Buffer.from_pyval(t)
+    local_buffer = xla_client.Buffer.make_tuple(t, device)
     # Run the test twice to verify that the original tuple buffer remains valid
     # even after destructuring.
     for _ in range(2):
@@ -509,8 +512,12 @@ class BufferTest(ComputationTest):
       np.testing.assert_equal(want, got)
 
   def testDestructureTupleNested(self):
-    t = ((NumpyArrayF32([1.0, 2.0]), NumpyArrayS32([3, 4])), NumpyArrayS32([5]))
-    local_buffer = xla_client.Buffer.from_pyval(t)
+    device = xla_client.get_local_backend().devices()[0]
+    t = xla_client.Buffer.make_tuple(
+        (xla_client.Buffer.from_pyval(NumpyArrayF32([1.0, 2.0])),
+         xla_client.Buffer.from_pyval(NumpyArrayS32([3, 4]))), device)
+    local_buffer = xla_client.Buffer.make_tuple(
+        (t, xla_client.Buffer.from_pyval(NumpyArrayS32([5]))), device)
     pieces = local_buffer.destructure()
     self.assertFalse(local_buffer.is_deleted())
     self.assertLen(pieces, 2)
@@ -2119,12 +2126,20 @@ class BufferProtocolTest(parameterized.TestCase):
   } for dtype in standard_dtypes for shape in testcase_shapes)
   def testRoundTrip(self, dtype, shape):
     x = np.array(np.random.rand(*shape) * 100, dtype=dtype)
+    x_ptr = x.__array_interface__["data"][0]
     backend = xla_client.get_local_backend("cpu")
     buffer = xla_client.Buffer.from_pyval(x, backend=backend)
     y = np.array(buffer, copy=False)
+    y_ptr = y.__array_interface__["data"][0]
     np.testing.assert_array_equal(x, y)
-    self.assertEqual(y.__array_interface__["data"][0],
-                     buffer.unsafe_buffer_pointer())
+    # If the input was sufficiently aligned, the input and output should alias.
+    self.assertTrue((x_ptr & 63) != 0 or x_ptr == y_ptr)
+    self.assertEqual(y_ptr, buffer.unsafe_buffer_pointer())
+
+    buffer2 = xla_client.Buffer.from_pyval(x, backend=backend, force_copy=True)
+    z = np.array(buffer2, copy=False)
+    self.assertNotEqual(x.__array_interface__["data"][0],
+                        z.__array_interface__["data"][0])
 
   def testDeleteWithActiveView(self):
     x = np.random.randn(20, 10)