[PJRT] Remove on_host_shape from PjRtBuffer API.

The on-host shape can be derived from the on-device shape, so there's no need for PJRT to maintain both.

In many places that consumed the host shape, it should be legal to consume the device shape instead. In a handful of places where I was not sure, I used xla::ShapeUtil::DeviceShapeToHostShape to construct the host shape again.

Note: device shapes and host shapes are identical on CPU and GPU.
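Illustrative sketch of the caller-side pattern this change relies on (not part of the diff; the function name is hypothetical, and the include paths are assumed from the TensorFlow/XLA tree of this era): where a genuine host-layout shape is still needed, it is rebuilt from the buffer's on-device shape via the existing xla::ShapeUtil::DeviceShapeToHostShape helper rather than read from a stored on_host_shape.

// Hypothetical helper, assuming standard XLA headers; mirrors the new
// ToLiteral() overload in the diff below.
#include <memory>

#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/compiler/xla/shape_util.h"

namespace xla {

// Replaces a former buffer.on_host_shape() use: derive the host shape from
// the on-device shape, then allocate a host literal of that shape.
std::shared_ptr<Literal> AllocateHostLiteralFor(const PjRtBuffer& buffer) {
  Shape host_shape =
      ShapeUtil::DeviceShapeToHostShape(buffer.on_device_shape());
  return std::make_shared<Literal>(host_shape);
}

}  // namespace xla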
PiperOrigin-RevId: 354401607
Change-Id: I38064cfaf8c1be908448d2a6131d47fad03e2ddf
Peter Hawkins 2021-01-28 14:58:51 -08:00 committed by TensorFlower Gardener
parent 4ec61b6146
commit b13ed5c441
11 changed files with 65 additions and 80 deletions

View File

@@ -275,7 +275,6 @@ class PjRtBuffer {
  public:
   virtual ~PjRtBuffer() = default;
-  virtual const Shape& on_host_shape() const = 0;
   virtual const Shape& on_device_shape() const = 0;
   virtual PjRtDevice* device() const = 0;
   virtual PjRtClient* client() const = 0;
@@ -319,7 +318,8 @@ class PjRtBuffer {
   // Convenience synchronous overload that allocates a literal with a default
   // layout.
   StatusOr<std::shared_ptr<Literal>> ToLiteral() {
-    auto literal = std::make_shared<Literal>(on_host_shape());
+    auto literal = std::make_shared<Literal>(
+        ShapeUtil::DeviceShapeToHostShape(on_device_shape()));
     TF_RETURN_IF_ERROR(ToLiteral(literal.get()));
     return literal;
   }

View File

@@ -359,6 +359,7 @@ void RecordUsage(PjRtStreamExecutorBuffer::ScopedHold device_buffer,
 //
 // The caller may optionally provide a definition event to be recorded in
 // the buffer.
+// TODO(phawkins): replace on_host_shape here with on_device_shape.
 StatusOr<std::unique_ptr<PjRtStreamExecutorBuffer>> AllocateDestinationBuffer(
     const Shape& on_host_shape, PjRtDevice* device,
     LocalDeviceState* local_device, se::Stream* copy_stream,
@@ -453,8 +454,7 @@ StatusOr<std::unique_ptr<PjRtStreamExecutorBuffer>> AllocateDestinationBuffer(
           definition_events);
   auto py_buffer = absl::make_unique<PjRtStreamExecutorBuffer>(
-      on_host_shape, on_device_shape, std::move(dst_device_buffer), client,
-      device);
+      on_device_shape, std::move(dst_device_buffer), client, device);
   if (on_device_shape.IsTuple()) {
     // Add a usage hold for the tuple table write and immediately convert it to
@@ -670,7 +670,7 @@ PjRtStreamExecutorClient::BufferFromHostBuffer(
           definition_events, std::move(on_delete_callback));
       return std::unique_ptr<PjRtBuffer>(
           std::make_unique<PjRtStreamExecutorBuffer>(
-              shape, shape, std::move(device_buffer), this, device));
+              shape, std::move(device_buffer), this, device));
     }
   }
@@ -719,7 +719,7 @@ PjRtStreamExecutorClient::BufferFromHostBuffer(
   auto transfer_h2d = [local_client = client(), transfer_manager, local_device,
                        data, size,
                        movable_device_buffer{device_buffer.ToClosure()}, shape,
-                       py_buffer{py_buffer.get()}, compact_shape,
+                       py_buffer{py_buffer.get()},
                        on_device_shape{py_buffer->on_device_shape()},
                        staging_buffer{std::move(staging_buffer)},
                        on_done_with_host_buffer{
@@ -732,8 +732,7 @@ PjRtStreamExecutorClient::BufferFromHostBuffer(
     // memory that has already been allocated, and a possible Event
     // allocation.
-    ShapedBuffer buffer =
-        device_buffer->AsShapedBuffer(compact_shape, on_device_shape);
+    ShapedBuffer buffer = device_buffer->AsShapedBuffer(on_device_shape);
     // If applicable on the backend, stage the transfer via host memory
     // allocated via the host_memory_allocator. On GPU, this is pinned
     // memory.
@@ -846,7 +845,7 @@ PjRtStreamExecutorClient::BufferFromHostLiteral(const LiteralSlice& literal,
   // put the transfer into the calling thread for small literals.
   auto transfer_h2d = [local_client = client(), transfer_manager, local_device,
                        movable_device_buffer{device_buffer.ToClosure()},
-                       literal, py_buffer{py_buffer.get()}, compact_shape,
+                       literal, py_buffer{py_buffer.get()},
                        on_device_shape{py_buffer->on_device_shape()}]() {
     PjRtStreamExecutorBuffer::ScopedHold device_buffer(movable_device_buffer);
     // This function uses TF_CHECK_OK and ValueOrDie() since we have no way
@@ -856,8 +855,7 @@ PjRtStreamExecutorClient::BufferFromHostLiteral(const LiteralSlice& literal,
     // allocation.
     se::Stream* h2d_stream = local_device->host_to_device_stream();
-    ShapedBuffer buffer =
-        device_buffer->AsShapedBuffer(compact_shape, on_device_shape);
+    ShapedBuffer buffer = device_buffer->AsShapedBuffer(on_device_shape);
     TF_CHECK_OK(transfer_manager->TransferLiteralToDeviceAsync(
         h2d_stream, literal, buffer));
@@ -924,7 +922,7 @@ PjRtStreamExecutorClient::CreateViewOfDeviceBuffer(
       std::initializer_list<se::DeviceMemoryBase>{buffer}, definition_events,
       std::move(on_delete_callback));
   return std::unique_ptr<PjRtBuffer>(std::make_unique<PjRtStreamExecutorBuffer>(
-      shape, shape, std::move(device_buffer), this, device));
+      shape, std::move(device_buffer), this, device));
 }
 // Transfer the given literal to the infeed queue of the given local device.
@@ -955,11 +953,9 @@ StatusOr<PjRtDevice*> PjRtStreamExecutorClient::LookupAddressableDevice(
 }
 PjRtStreamExecutorBuffer::PjRtStreamExecutorBuffer(
-    Shape on_host_shape, Shape on_device_shape,
-    std::shared_ptr<TrackedDeviceBuffer> device_buffer, PjRtClient* client,
-    PjRtDevice* device)
+    Shape on_device_shape, std::shared_ptr<TrackedDeviceBuffer> device_buffer,
+    PjRtClient* client, PjRtDevice* device)
     : client_(tensorflow::down_cast<PjRtStreamExecutorClient*>(client)),
-      on_host_shape_(std::move(on_host_shape)),
       on_device_shape_(std::move(on_device_shape)),
       device_(tensorflow::down_cast<PjRtStreamExecutorDevice*>(device)),
       device_buffer_(std::move(device_buffer)),
@@ -1195,8 +1191,7 @@ void PjRtStreamExecutorBuffer::ToLiteral(MutableLiteralBase* literal,
   }
   WaitForBufferDefinitionEventsOnStream(*device_buffer, stream);
-  ShapedBuffer shaped_buffer =
-      device_buffer->AsShapedBuffer(literal->shape(), on_device_shape_);
+  ShapedBuffer shaped_buffer = device_buffer->AsShapedBuffer(on_device_shape_);
   StatusOr<EventPool::Handle> event_or =
       local_device->event_pool().AllocateEvent(stream->parent());
   if (!event_or.ok()) {
@@ -1233,7 +1228,7 @@ StatusOr<ShapedBuffer> PjRtStreamExecutorBuffer::AsShapedBuffer() const {
     return InvalidArgument(
         "Attempted to fetch value of invalid/deleted buffer.");
   }
-  return device_buffer_->AsShapedBuffer(on_host_shape_, on_device_shape_);
+  return device_buffer_->AsShapedBuffer(on_device_shape_);
 }
 PjRtStreamExecutorBuffer::ScopedHold
@@ -1257,11 +1252,11 @@ PjRtStreamExecutorBuffer::CopyToDeviceHelper(
     PjRtDevice* dst_device, LocalDeviceState* dst_local_device,
     LocalDeviceState* transfer_local_device, se::Stream* transfer_stream,
     std::shared_ptr<TrackedDeviceBuffer> src_device_buffer) {
-  TF_ASSIGN_OR_RETURN(
-      std::unique_ptr<PjRtStreamExecutorBuffer> py_buffer,
-      AllocateDestinationBuffer(on_host_shape_, dst_device, dst_local_device,
-                                transfer_stream,
-                                /*is_uninitialized_create=*/false, client_));
+  TF_ASSIGN_OR_RETURN(std::unique_ptr<PjRtStreamExecutorBuffer> py_buffer,
+                      AllocateDestinationBuffer(
+                          ShapeUtil::DeviceShapeToHostShape(on_device_shape_),
+                          dst_device, dst_local_device, transfer_stream,
+                          /*is_uninitialized_create=*/false, client_));
   TF_ASSIGN_OR_RETURN(ShapedBuffer src_buffer, AsShapedBuffer());
@@ -1269,8 +1264,7 @@ PjRtStreamExecutorBuffer::CopyToDeviceHelper(
   ScopedHold dst_device_buffer(py_buffer->GetBufferWithUsageHold());
   CHECK(dst_device_buffer.ok());
-  ShapedBuffer dst_buffer =
-      dst_device_buffer->AsShapedBuffer(on_host_shape_, on_device_shape_);
+  ShapedBuffer dst_buffer = dst_device_buffer->AsShapedBuffer(on_device_shape_);
   // Copy the leaf buffers.
   StatusOr<std::shared_ptr<BufferSequencingEvent>> copy_event_or =
@@ -1451,10 +1445,8 @@ StatusOr<TupleHandle> MakeTupleHelper(
   host_shapes.reserve(py_buffers.size());
   device_shapes.reserve(py_buffers.size());
   for (const PjRtBuffer* buffer : py_buffers) {
-    host_shapes.push_back(buffer->on_host_shape());
     device_shapes.push_back(buffer->on_device_shape());
   }
-  Shape on_host_shape = ShapeUtil::MakeTupleShape(host_shapes);
   Shape on_device_shape = ShapeUtil::MakeTupleShape(device_shapes);
   se::DeviceMemoryAllocator* allocator =
@@ -1469,7 +1461,7 @@ StatusOr<TupleHandle> MakeTupleHelper(
       se::OwningDeviceMemory root_table_memory,
       allocator->Allocate(
           device_ordinal,
-          transfer_manager->GetByteSizeRequirement(on_host_shape)));
+          transfer_manager->GetByteSizeRequirement(on_device_shape)));
   if (local_device->allocation_model() ==
       LocalDeviceState::kComputeSynchronized) {
@@ -1479,7 +1471,7 @@ StatusOr<TupleHandle> MakeTupleHelper(
         local_device->compute_stream()->parent(), root_table_memory.cref()));
   }
-  ExecutionInput execution_input(on_device_shape, on_host_shape);
+  ExecutionInput execution_input(on_device_shape);
   ShapeTree<MaybeOwningDeviceMemory>::iterator input_iterator =
       execution_input.MutableBuffers()->begin();
   ShapeTree<MaybeOwningDeviceMemory>::iterator iterator_end =
@@ -1521,8 +1513,7 @@ std::unique_ptr<PjRtBuffer> OutputBufferHelper(
       TrackedDeviceBuffer::FromScopedShapedBuffer(result_buffer,
                                                   {definition_event});
   auto pjrt_buffer = absl::make_unique<PjRtStreamExecutorBuffer>(
-      result_buffer->on_host_shape(), result_buffer->on_device_shape(),
-      std::move(out_buffer), client, device);
+      result_buffer->on_device_shape(), std::move(out_buffer), client, device);
   RecordUsage(pjrt_buffer->GetBufferWithUsageHold(), local_device, local_device,
               definition_event, local_device->compute_stream(),
               /*prefer_to_retain_reference=*/false);
@@ -1621,8 +1612,7 @@ PjRtStreamExecutorExecutable::MakeExecutionInputsAndWaitForEvents(
     PjRtBuffer* handle = argument_handles[i];
     // Make an ExecutionInput from the device buffer.
-    execution_inputs.emplace_back(handle->on_device_shape(),
-                                  handle->on_host_shape());
+    execution_inputs.emplace_back(handle->on_device_shape());
     ExecutionInput& execution_input = execution_inputs.back();
     ShapeTree<MaybeOwningDeviceMemory>::iterator input_iterator =
         execution_input.MutableBuffers()->begin();
@@ -1794,8 +1784,8 @@ PjRtStreamExecutorExecutable::MakeOutputBuffers(
     PjRtDevice* device) const {
   std::vector<std::unique_ptr<PjRtBuffer>> outputs;
   LocalDeviceState* device_state = &(client_->device_state(device_ordinal));
-  if (options.untuple_result && result_buffer.on_host_shape().IsTuple()) {
-    int tuple_count = result_buffer.on_host_shape().tuple_shapes_size();
+  if (options.untuple_result && result_buffer.on_device_shape().IsTuple()) {
+    int tuple_count = result_buffer.on_device_shape().tuple_shapes_size();
     outputs.reserve(tuple_count);
     // Take ownership of each of the output values, leaving only the root table
     // in result_buffer.

View File

@@ -455,7 +455,7 @@ class PjRtStreamExecutorBuffer : public PjRtBuffer {
     StatusOr<std::shared_ptr<TrackedDeviceBuffer>> buffer_or_;
   };
-  PjRtStreamExecutorBuffer(Shape on_host_shape, Shape on_device_shape,
+  PjRtStreamExecutorBuffer(Shape on_device_shape,
                           std::shared_ptr<TrackedDeviceBuffer> device_buffer,
                           PjRtClient* client, PjRtDevice* device);
   ~PjRtStreamExecutorBuffer() override;
@@ -465,14 +465,14 @@ class PjRtStreamExecutorBuffer : public PjRtBuffer {
   PjRtStreamExecutorBuffer& operator=(const PjRtStreamExecutorBuffer&) = delete;
   PjRtStreamExecutorBuffer& operator=(PjRtStreamExecutorBuffer&&) = delete;
-  const Shape& on_host_shape() const override { return on_host_shape_; }
   const Shape& on_device_shape() const override { return on_device_shape_; }
   PjRtStreamExecutorDevice* device() const override { return device_; }
   PjRtPlatformId platform_id() const { return client_->platform_id(); }
   absl::string_view platform_name() const { return client_->platform_name(); }
   PjRtStreamExecutorClient* client() const override { return client_; }
   bool IsEmptyTuple() const {
-    return on_host_shape_.IsTuple() && on_host_shape_.tuple_shapes_size() == 0;
+    return on_device_shape_.IsTuple() &&
+           on_device_shape_.tuple_shapes_size() == 0;
   }
   int64 OnDeviceSizeInBytes() const override;
@@ -603,7 +603,6 @@ class PjRtStreamExecutorBuffer : public PjRtBuffer {
       std::shared_ptr<TrackedDeviceBuffer> src_device_buffer);
   PjRtStreamExecutorClient* const client_;
-  const Shape on_host_shape_;
   const Shape on_device_shape_;
   PjRtStreamExecutorDevice* const device_;

View File

@@ -118,8 +118,8 @@ TrackedDeviceBuffer::FromScopedShapedBuffer(
 }
 ShapedBuffer TrackedDeviceBuffer::AsShapedBuffer(
-    const Shape& on_host_shape, const Shape& on_device_shape) const {
-  ShapedBuffer shaped_buffer(on_host_shape, on_device_shape, device_ordinal_);
+    const Shape& on_device_shape) const {
+  ShapedBuffer shaped_buffer(on_device_shape, device_ordinal_);
   ShapeTree<se::DeviceMemoryBase>::iterator iterator =
       shaped_buffer.buffers().begin();
   for (const se::DeviceMemoryBase& buf : device_memory_) {

View File

@@ -137,11 +137,8 @@ class TrackedDeviceBuffer {
       absl::Span<const std::shared_ptr<BufferSequencingEvent>>
           definition_events);
-  // Builds a ShapedBuffer view onto the buffers of 'tree'. We require but do
-  // not verify that TransferManager::HostShapeToDeviceShape(on_host_shape) ==
-  // on_device_shape().
-  ShapedBuffer AsShapedBuffer(const Shape& on_host_shape,
-                              const Shape& on_device_shape) const;
+  // Builds a ShapedBuffer view onto the buffers of 'tree'.
+  ShapedBuffer AsShapedBuffer(const Shape& on_device_shape) const;
   // Adds the owned device buffers in order to 'iterator'. Used to add the
   // buffers to an ExecutionInput. We require but do not verify that 'iterator'
View File

@@ -65,13 +65,10 @@ TEST(TrackedDeviceBufferTest, AsShapedBuffer) {
       a_buffer->device_memory()[0], b_buffer->device_memory()[0],
       c_buffer->device_memory()[0]};
   ShapedBuffer shaped_a = a_buffer->AsShapedBuffer(
-      a_shape,
       client->backend().transfer_manager()->HostShapeToDeviceShape(a_shape));
   ShapedBuffer shaped_b = b_buffer->AsShapedBuffer(
-      b_shape,
       client->backend().transfer_manager()->HostShapeToDeviceShape(b_shape));
   ShapedBuffer shaped_c = c_buffer->AsShapedBuffer(
-      c_shape,
       client->backend().transfer_manager()->HostShapeToDeviceShape(c_shape));
   auto expected_it = expected_buffer_sequence.begin();
   for (auto it = shaped_a.buffers().begin(); it != shaped_a.buffers().end();

View File

@@ -295,15 +295,15 @@ StatusOr<py::capsule> BufferToDLPackManagedTensor(py::handle py_buffer,
   pack->tensor.deleter = DLPackTensorDeleter;
   TF_ASSIGN_OR_RETURN(dt.ctx, DLContextForDevice(*buffer->buffer()->device()));
   dt.ctx.device_id = buffer->buffer()->device()->local_hardware_id();
-  dt.ndim = buffer->buffer()->on_host_shape().dimensions_size();
+  dt.ndim = buffer->buffer()->on_device_shape().dimensions_size();
   TF_ASSIGN_OR_RETURN(dt.dtype,
                       PrimitiveTypeToDLDataType(
-                          buffer->buffer()->on_host_shape().element_type()));
+                          buffer->buffer()->on_device_shape().element_type()));
-  pack->shape =
-      std::vector<int64>(buffer->buffer()->on_host_shape().dimensions().begin(),
-                         buffer->buffer()->on_host_shape().dimensions().end());
-  pack->strides = StridesForShape(buffer->buffer()->on_host_shape());
+  pack->shape = std::vector<int64>(
+      buffer->buffer()->on_device_shape().dimensions().begin(),
+      buffer->buffer()->on_device_shape().dimensions().end());
+  pack->strides = StridesForShape(buffer->buffer()->on_device_shape());
   dt.shape = reinterpret_cast<std::int64_t*>(pack->shape.data());
   dt.strides = reinterpret_cast<std::int64_t*>(pack->strides.data());
   dt.byte_offset = 0;

View File

@@ -438,8 +438,8 @@ xla::StatusOr<ArgSignature> ArgSignatureOfValue(pybind11::handle arg,
       [](py::handle h, bool jax_enable_x64) -> xla::StatusOr<ArgSignature> {
         xla::PyBuffer* buffer = py::cast<xla::PyBuffer*>(h);
         bool weak_type = py::cast<py::bool_>(h.attr("aval").attr("weak_type"));
-        return ArgSignature(buffer->buffer()->on_host_shape().element_type(),
-                            buffer->buffer()->on_host_shape().dimensions(),
+        return ArgSignature(buffer->buffer()->on_device_shape().element_type(),
+                            buffer->buffer()->on_device_shape().dimensions(),
                             weak_type);
       };
   (*p)[py::type::handle_of<xla::DeviceArrayBase>().ptr()] = buffer_handler;
@@ -1015,8 +1015,9 @@ xla::Status ConvertArgsToBuffers(bool jax_enable_x64, xla::PyClient& pyclient,
       keep_alive.emplace_back(std::move(on_device.owned_buffer));
     }
-    ArgSignature sig(buffer->on_host_shape().element_type(),
-                     buffer->on_host_shape().dimensions(), on_device.weak_type);
+    ArgSignature sig(buffer->on_device_shape().element_type(),
+                     buffer->on_device_shape().dimensions(),
+                     on_device.weak_type);
     arguments.signature.dynamic_args_signatures.push_back(std::move(sig));
   }
   return xla::Status::OK();

View File

@@ -56,11 +56,11 @@ PyBuffer::~PyBuffer() {
 }
 pybind11::tuple PyBuffer::python_shape() const {
-  return IntSpanToTuple(buffer()->on_host_shape().dimensions());
+  return IntSpanToTuple(buffer()->on_device_shape().dimensions());
 }
 pybind11::dtype PyBuffer::python_dtype() const {
-  PrimitiveType primitive = buffer()->on_host_shape().element_type();
+  PrimitiveType primitive = buffer()->on_device_shape().element_type();
   return PrimitiveTypeToDtype(primitive).ValueOrDie();
 }
@@ -91,7 +91,8 @@ Status PyBuffer::BlockHostUntilReady() {
 Status PyBuffer::CopyToHostAsync() {
   if (!buffer_->IsOnCpu() && !host_value_) {
     host_value_ = std::make_shared<HostValue>();
-    host_value_->value = std::make_shared<Literal>(buffer_->on_host_shape());
+    host_value_->value = std::make_shared<Literal>(
+        ShapeUtil::DeviceShapeToHostShape(buffer_->on_device_shape()));
     buffer_->ToLiteral(host_value_->value.get(),
                        [host_value{host_value_}](Status status) {
                          host_value->status = std::move(status);
@@ -110,7 +111,7 @@ StatusOr<pybind11::object> PyBuffer::AsNumPyArray(py::handle this_obj) {
   if (buffer_->IsOnCpu()) {
     TF_ASSIGN_OR_RETURN(
         py::dtype dtype,
-        PrimitiveTypeToDtype(buffer_->on_host_shape().element_type()));
+        PrimitiveTypeToDtype(buffer_->on_device_shape().element_type()));
     // Objects that must be kept alive while the array is alive.
     struct Hold {
       py::object buffer;
@@ -124,8 +125,8 @@ StatusOr<pybind11::object> PyBuffer::AsNumPyArray(py::handle this_obj) {
    void* data = hold->external_reference_hold->OpaqueDeviceMemoryDataPointer();
    py::capsule hold_capsule(hold.release(),
                             [](void* h) { delete static_cast<Hold*>(h); });
-    py::array array(dtype, buffer_->on_host_shape().dimensions(),
-                    ByteStridesForShape(buffer_->on_host_shape()), data,
+    py::array array(dtype, buffer_->on_device_shape().dimensions(),
+                    ByteStridesForShape(buffer_->on_device_shape()), data,
                     hold_capsule);
    array.attr("flags").attr("writeable") = Py_False;
    {
@@ -171,18 +172,18 @@ StatusOr<py::dict> PyBuffer::CudaArrayInterface() const {
     return InvalidArgument(
         "__cuda_array_interface__ is only defined for array buffers.");
   }
-  if (buffer_->on_host_shape().element_type() == BF16) {
+  if (buffer_->on_device_shape().element_type() == BF16) {
     return InvalidArgument(
         "__cuda_array_interface__ is not supported for bfloat16 buffers.");
   }
-  TF_RET_CHECK(
-      LayoutUtil::IsMonotonicWithDim0Major(buffer_->on_host_shape().layout()));
+  TF_RET_CHECK(LayoutUtil::IsMonotonicWithDim0Major(
+      buffer_->on_device_shape().layout()));
   py::dict result;
-  result["shape"] = IntSpanToTuple(buffer_->on_host_shape().dimensions());
-  TF_ASSIGN_OR_RETURN(
-      py::str typestr,
-      TypeDescriptorForPrimitiveType(buffer_->on_host_shape().element_type()));
+  result["shape"] = IntSpanToTuple(buffer_->on_device_shape().dimensions());
+  TF_ASSIGN_OR_RETURN(py::str typestr,
+                      TypeDescriptorForPrimitiveType(
+                          buffer_->on_device_shape().element_type()));
   result["typestr"] = std::move(typestr);
   TF_ASSIGN_OR_RETURN(std::unique_ptr<PjRtBuffer::ExternalReferenceHold>
                           external_reference_hold,
@@ -235,7 +236,7 @@ int PjRtBufferGetBuffer(PyObject* exporter, Py_buffer* view, int flags) {
     // If we allowed exports of formatted BF16 buffers, consumers would get
     // confused about the type because there is no way to describe BF16 to
     // Python.
-    if (buffer.on_host_shape().element_type() == BF16 &&
+    if (buffer.on_device_shape().element_type() == BF16 &&
        ((flags & PyBUF_FORMAT) == PyBUF_FORMAT)) {
      return InvalidArgument(
          "bfloat16 buffer format not supported by Python buffer protocol.");
@@ -249,7 +250,7 @@ int PjRtBufferGetBuffer(PyObject* exporter, Py_buffer* view, int flags) {
    if (buffer.IsDeleted()) {
      return InvalidArgument("Deleted buffer used in buffer protocol.");
    }
-    const Shape& shape = buffer.on_host_shape();
+    const Shape& shape = buffer.on_device_shape();
    if (((flags & PyBUF_C_CONTIGUOUS) == PyBUF_C_CONTIGUOUS ||
         (flags & PyBUF_STRIDES) == PyBUF_ND) &&
        !LayoutUtil::IsMonotonicWithDim0Major(shape.layout())) {

View File

@@ -79,7 +79,7 @@ class PyBuffer : public DeviceArrayBase {
   Status BlockHostUntilReady();
   Status CopyToHostAsync();
-  const Shape& shape() { return buffer_->on_host_shape(); }
+  const Shape& shape() { return buffer_->on_device_shape(); }
   StatusOr<std::uintptr_t> UnsafeBufferPointer() const;
@@ -93,10 +93,10 @@ class PyBuffer : public DeviceArrayBase {
   Traceback* traceback() { return traceback_.get(); }
   // Returns the size (i.e. number of elements) of the (host) numpy array.
-  int64 size() { return ShapeUtil::ElementsIn(buffer()->on_host_shape()); }
+  int64 size() { return ShapeUtil::ElementsIn(buffer()->on_device_shape()); }
   // Returns the number of dimensions of the (host) numpy array.
-  int ndim() const { return buffer()->on_host_shape().dimensions_size(); }
+  int ndim() const { return buffer()->on_device_shape().dimensions_size(); }
   pybind11::tuple python_shape() const;
   pybind11::dtype python_dtype() const;

View File

@@ -301,13 +301,13 @@ PYBIND11_MODULE(xla_extension, m) {
           "shape",
           [](const PyBuffer& pybuffer) -> pybind11::tuple {
             return IntSpanToTuple(
-                pybuffer.buffer()->on_host_shape().dimensions());
+                pybuffer.buffer()->on_device_shape().dimensions());
           })
       .def_property_readonly(
           "dtype",
          [](const PyBuffer& buffer) {
            PrimitiveType primitive =
-                buffer.buffer()->on_host_shape().element_type();
+                buffer.buffer()->on_device_shape().element_type();
            return PrimitiveTypeToDtype(primitive).ValueOrDie();
          })
      .def_property_readonly("size", &PyBuffer::size)