[PJRT] Remove on_host_shape from PjRtBuffer API.
The on-host shape can be derived from the on-device shape, so there's no need for PJRT to maintain both. In many places that were consuming the host shape, it should be legal to consume the device shape. In a handful of places where I was not sure, I used xla::ShapeUtil::DeviceShapeToHostShape to construct the host shape again. Note: device shapes and host shapes are identical on CPU and GPU.

PiperOrigin-RevId: 354401607
Change-Id: I38064cfaf8c1be908448d2a6131d47fad03e2ddf
parent 4ec61b6146
commit b13ed5c441
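To illustrate the pattern this change relies on, here is a minimal sketch (not part of the commit) of allocating a host-side literal from a buffer's on-device shape. The MakeHostLiteralFor helper name is purely illustrative, and the include paths assume the TensorFlow source layout of this revision.

#include <memory>

#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/shape_util.h"

namespace xla {

// Illustrative helper (not in the PJRT API): allocate a host literal whose
// shape is the host view of a buffer's on-device shape. This mirrors the
// pattern used by the updated ToLiteral() and PyBuffer::CopyToHostAsync()
// call sites in this commit.
std::shared_ptr<Literal> MakeHostLiteralFor(const Shape& on_device_shape) {
  // DeviceShapeToHostShape strips device-specific layout details; on CPU and
  // GPU the host and device shapes are identical, so this is a no-op there.
  Shape host_shape = ShapeUtil::DeviceShapeToHostShape(on_device_shape);
  return std::make_shared<Literal>(host_shape);
}

}  // namespace xla

Call sites that can operate directly on the device shape (element type, dimensions, tuple arity) simply switch to on_device_shape(), as the hunks below show.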
@@ -275,7 +275,6 @@ class PjRtBuffer {
 public:
  virtual ~PjRtBuffer() = default;

-  virtual const Shape& on_host_shape() const = 0;
  virtual const Shape& on_device_shape() const = 0;
  virtual PjRtDevice* device() const = 0;
  virtual PjRtClient* client() const = 0;
@@ -319,7 +318,8 @@ class PjRtBuffer {
  // Convenience synchronous overload that allocates a literal with a default
  // layout.
  StatusOr<std::shared_ptr<Literal>> ToLiteral() {
-    auto literal = std::make_shared<Literal>(on_host_shape());
+    auto literal = std::make_shared<Literal>(
+        ShapeUtil::DeviceShapeToHostShape(on_device_shape()));
    TF_RETURN_IF_ERROR(ToLiteral(literal.get()));
    return literal;
  }
@@ -359,6 +359,7 @@ void RecordUsage(PjRtStreamExecutorBuffer::ScopedHold device_buffer,
 //
 // The caller may optionally provide a definition event to be recorded in
 // the buffer.
+// TODO(phawkins): replace on_host_shape here with on_device_shape.
 StatusOr<std::unique_ptr<PjRtStreamExecutorBuffer>> AllocateDestinationBuffer(
     const Shape& on_host_shape, PjRtDevice* device,
     LocalDeviceState* local_device, se::Stream* copy_stream,
@@ -453,8 +454,7 @@ StatusOr<std::unique_ptr<PjRtStreamExecutorBuffer>> AllocateDestinationBuffer(
       definition_events);

   auto py_buffer = absl::make_unique<PjRtStreamExecutorBuffer>(
-      on_host_shape, on_device_shape, std::move(dst_device_buffer), client,
-      device);
+      on_device_shape, std::move(dst_device_buffer), client, device);

   if (on_device_shape.IsTuple()) {
     // Add a usage hold for the tuple table write and immediately convert it to
@@ -670,7 +670,7 @@ PjRtStreamExecutorClient::BufferFromHostBuffer(
         definition_events, std::move(on_delete_callback));
     return std::unique_ptr<PjRtBuffer>(
         std::make_unique<PjRtStreamExecutorBuffer>(
-            shape, shape, std::move(device_buffer), this, device));
+            shape, std::move(device_buffer), this, device));
   }
 }

@@ -719,7 +719,7 @@ PjRtStreamExecutorClient::BufferFromHostBuffer(
   auto transfer_h2d = [local_client = client(), transfer_manager, local_device,
                        data, size,
                        movable_device_buffer{device_buffer.ToClosure()}, shape,
-                       py_buffer{py_buffer.get()}, compact_shape,
+                       py_buffer{py_buffer.get()},
                        on_device_shape{py_buffer->on_device_shape()},
                        staging_buffer{std::move(staging_buffer)},
                        on_done_with_host_buffer{
@@ -732,8 +732,7 @@ PjRtStreamExecutorClient::BufferFromHostBuffer(
     // memory that has already been allocated, and a possible Event
     // allocation.

-    ShapedBuffer buffer =
-        device_buffer->AsShapedBuffer(compact_shape, on_device_shape);
+    ShapedBuffer buffer = device_buffer->AsShapedBuffer(on_device_shape);
     // If applicable on the backend, stage the transfer via host memory
     // allocated via the host_memory_allocator. On GPU, this is pinned
     // memory.
@@ -846,7 +845,7 @@ PjRtStreamExecutorClient::BufferFromHostLiteral(const LiteralSlice& literal,
   // put the transfer into the calling thread for small literals.
   auto transfer_h2d = [local_client = client(), transfer_manager, local_device,
                        movable_device_buffer{device_buffer.ToClosure()},
-                       literal, py_buffer{py_buffer.get()}, compact_shape,
+                       literal, py_buffer{py_buffer.get()},
                        on_device_shape{py_buffer->on_device_shape()}]() {
     PjRtStreamExecutorBuffer::ScopedHold device_buffer(movable_device_buffer);
     // This function uses TF_CHECK_OK and ValueOrDie() since we have no way
@@ -856,8 +855,7 @@ PjRtStreamExecutorClient::BufferFromHostLiteral(const LiteralSlice& literal,
     // allocation.

     se::Stream* h2d_stream = local_device->host_to_device_stream();
-    ShapedBuffer buffer =
-        device_buffer->AsShapedBuffer(compact_shape, on_device_shape);
+    ShapedBuffer buffer = device_buffer->AsShapedBuffer(on_device_shape);
     TF_CHECK_OK(transfer_manager->TransferLiteralToDeviceAsync(
         h2d_stream, literal, buffer));

@@ -924,7 +922,7 @@ PjRtStreamExecutorClient::CreateViewOfDeviceBuffer(
       std::initializer_list<se::DeviceMemoryBase>{buffer}, definition_events,
       std::move(on_delete_callback));
   return std::unique_ptr<PjRtBuffer>(std::make_unique<PjRtStreamExecutorBuffer>(
-      shape, shape, std::move(device_buffer), this, device));
+      shape, std::move(device_buffer), this, device));
 }

 // Transfer the given literal to the infeed queue of the given local device.
@@ -955,11 +953,9 @@ StatusOr<PjRtDevice*> PjRtStreamExecutorClient::LookupAddressableDevice(
 }

 PjRtStreamExecutorBuffer::PjRtStreamExecutorBuffer(
-    Shape on_host_shape, Shape on_device_shape,
-    std::shared_ptr<TrackedDeviceBuffer> device_buffer, PjRtClient* client,
-    PjRtDevice* device)
+    Shape on_device_shape, std::shared_ptr<TrackedDeviceBuffer> device_buffer,
+    PjRtClient* client, PjRtDevice* device)
     : client_(tensorflow::down_cast<PjRtStreamExecutorClient*>(client)),
-      on_host_shape_(std::move(on_host_shape)),
       on_device_shape_(std::move(on_device_shape)),
       device_(tensorflow::down_cast<PjRtStreamExecutorDevice*>(device)),
       device_buffer_(std::move(device_buffer)),
@@ -1195,8 +1191,7 @@ void PjRtStreamExecutorBuffer::ToLiteral(MutableLiteralBase* literal,
   }

   WaitForBufferDefinitionEventsOnStream(*device_buffer, stream);
-  ShapedBuffer shaped_buffer =
-      device_buffer->AsShapedBuffer(literal->shape(), on_device_shape_);
+  ShapedBuffer shaped_buffer = device_buffer->AsShapedBuffer(on_device_shape_);
   StatusOr<EventPool::Handle> event_or =
       local_device->event_pool().AllocateEvent(stream->parent());
   if (!event_or.ok()) {
@@ -1233,7 +1228,7 @@ StatusOr<ShapedBuffer> PjRtStreamExecutorBuffer::AsShapedBuffer() const {
     return InvalidArgument(
         "Attempted to fetch value of invalid/deleted buffer.");
   }
-  return device_buffer_->AsShapedBuffer(on_host_shape_, on_device_shape_);
+  return device_buffer_->AsShapedBuffer(on_device_shape_);
 }

 PjRtStreamExecutorBuffer::ScopedHold
@@ -1257,11 +1252,11 @@ PjRtStreamExecutorBuffer::CopyToDeviceHelper(
     PjRtDevice* dst_device, LocalDeviceState* dst_local_device,
     LocalDeviceState* transfer_local_device, se::Stream* transfer_stream,
     std::shared_ptr<TrackedDeviceBuffer> src_device_buffer) {
-  TF_ASSIGN_OR_RETURN(
-      std::unique_ptr<PjRtStreamExecutorBuffer> py_buffer,
-      AllocateDestinationBuffer(on_host_shape_, dst_device, dst_local_device,
-                                transfer_stream,
-                                /*is_uninitialized_create=*/false, client_));
+  TF_ASSIGN_OR_RETURN(std::unique_ptr<PjRtStreamExecutorBuffer> py_buffer,
+                      AllocateDestinationBuffer(
+                          ShapeUtil::DeviceShapeToHostShape(on_device_shape_),
+                          dst_device, dst_local_device, transfer_stream,
+                          /*is_uninitialized_create=*/false, client_));

   TF_ASSIGN_OR_RETURN(ShapedBuffer src_buffer, AsShapedBuffer());

@@ -1269,8 +1264,7 @@ PjRtStreamExecutorBuffer::CopyToDeviceHelper(

   ScopedHold dst_device_buffer(py_buffer->GetBufferWithUsageHold());
   CHECK(dst_device_buffer.ok());
-  ShapedBuffer dst_buffer =
-      dst_device_buffer->AsShapedBuffer(on_host_shape_, on_device_shape_);
+  ShapedBuffer dst_buffer = dst_device_buffer->AsShapedBuffer(on_device_shape_);

   // Copy the leaf buffers.
   StatusOr<std::shared_ptr<BufferSequencingEvent>> copy_event_or =
@@ -1451,10 +1445,8 @@ StatusOr<TupleHandle> MakeTupleHelper(
   host_shapes.reserve(py_buffers.size());
   device_shapes.reserve(py_buffers.size());
   for (const PjRtBuffer* buffer : py_buffers) {
-    host_shapes.push_back(buffer->on_host_shape());
     device_shapes.push_back(buffer->on_device_shape());
   }
-  Shape on_host_shape = ShapeUtil::MakeTupleShape(host_shapes);
   Shape on_device_shape = ShapeUtil::MakeTupleShape(device_shapes);

   se::DeviceMemoryAllocator* allocator =
@@ -1469,7 +1461,7 @@ StatusOr<TupleHandle> MakeTupleHelper(
       se::OwningDeviceMemory root_table_memory,
       allocator->Allocate(
           device_ordinal,
-          transfer_manager->GetByteSizeRequirement(on_host_shape)));
+          transfer_manager->GetByteSizeRequirement(on_device_shape)));

   if (local_device->allocation_model() ==
       LocalDeviceState::kComputeSynchronized) {
@@ -1479,7 +1471,7 @@ StatusOr<TupleHandle> MakeTupleHelper(
         local_device->compute_stream()->parent(), root_table_memory.cref()));
   }

-  ExecutionInput execution_input(on_device_shape, on_host_shape);
+  ExecutionInput execution_input(on_device_shape);
   ShapeTree<MaybeOwningDeviceMemory>::iterator input_iterator =
       execution_input.MutableBuffers()->begin();
   ShapeTree<MaybeOwningDeviceMemory>::iterator iterator_end =
@@ -1521,8 +1513,7 @@ std::unique_ptr<PjRtBuffer> OutputBufferHelper(
       TrackedDeviceBuffer::FromScopedShapedBuffer(result_buffer,
                                                   {definition_event});
   auto pjrt_buffer = absl::make_unique<PjRtStreamExecutorBuffer>(
-      result_buffer->on_host_shape(), result_buffer->on_device_shape(),
-      std::move(out_buffer), client, device);
+      result_buffer->on_device_shape(), std::move(out_buffer), client, device);
   RecordUsage(pjrt_buffer->GetBufferWithUsageHold(), local_device, local_device,
               definition_event, local_device->compute_stream(),
               /*prefer_to_retain_reference=*/false);
@@ -1621,8 +1612,7 @@ PjRtStreamExecutorExecutable::MakeExecutionInputsAndWaitForEvents(
     PjRtBuffer* handle = argument_handles[i];

     // Make an ExecutionInput from the device buffer.
-    execution_inputs.emplace_back(handle->on_device_shape(),
-                                  handle->on_host_shape());
+    execution_inputs.emplace_back(handle->on_device_shape());
     ExecutionInput& execution_input = execution_inputs.back();
     ShapeTree<MaybeOwningDeviceMemory>::iterator input_iterator =
         execution_input.MutableBuffers()->begin();
@@ -1794,8 +1784,8 @@ PjRtStreamExecutorExecutable::MakeOutputBuffers(
     PjRtDevice* device) const {
   std::vector<std::unique_ptr<PjRtBuffer>> outputs;
   LocalDeviceState* device_state = &(client_->device_state(device_ordinal));
-  if (options.untuple_result && result_buffer.on_host_shape().IsTuple()) {
-    int tuple_count = result_buffer.on_host_shape().tuple_shapes_size();
+  if (options.untuple_result && result_buffer.on_device_shape().IsTuple()) {
+    int tuple_count = result_buffer.on_device_shape().tuple_shapes_size();
     outputs.reserve(tuple_count);
     // Take ownership of each of the output values, leaving only the root table
     // in result_buffer.
@@ -455,7 +455,7 @@ class PjRtStreamExecutorBuffer : public PjRtBuffer {
     StatusOr<std::shared_ptr<TrackedDeviceBuffer>> buffer_or_;
   };

-  PjRtStreamExecutorBuffer(Shape on_host_shape, Shape on_device_shape,
+  PjRtStreamExecutorBuffer(Shape on_device_shape,
                            std::shared_ptr<TrackedDeviceBuffer> device_buffer,
                            PjRtClient* client, PjRtDevice* device);
   ~PjRtStreamExecutorBuffer() override;
@@ -465,14 +465,14 @@ class PjRtStreamExecutorBuffer : public PjRtBuffer {
  PjRtStreamExecutorBuffer& operator=(const PjRtStreamExecutorBuffer&) = delete;
  PjRtStreamExecutorBuffer& operator=(PjRtStreamExecutorBuffer&&) = delete;

-  const Shape& on_host_shape() const override { return on_host_shape_; }
  const Shape& on_device_shape() const override { return on_device_shape_; }
  PjRtStreamExecutorDevice* device() const override { return device_; }
  PjRtPlatformId platform_id() const { return client_->platform_id(); }
  absl::string_view platform_name() const { return client_->platform_name(); }
  PjRtStreamExecutorClient* client() const override { return client_; }
  bool IsEmptyTuple() const {
-    return on_host_shape_.IsTuple() && on_host_shape_.tuple_shapes_size() == 0;
+    return on_device_shape_.IsTuple() &&
+           on_device_shape_.tuple_shapes_size() == 0;
  }

  int64 OnDeviceSizeInBytes() const override;
@@ -603,7 +603,6 @@ class PjRtStreamExecutorBuffer : public PjRtBuffer {
       std::shared_ptr<TrackedDeviceBuffer> src_device_buffer);

   PjRtStreamExecutorClient* const client_;
-  const Shape on_host_shape_;
   const Shape on_device_shape_;
   PjRtStreamExecutorDevice* const device_;

@@ -118,8 +118,8 @@ TrackedDeviceBuffer::FromScopedShapedBuffer(
 }

 ShapedBuffer TrackedDeviceBuffer::AsShapedBuffer(
-    const Shape& on_host_shape, const Shape& on_device_shape) const {
-  ShapedBuffer shaped_buffer(on_host_shape, on_device_shape, device_ordinal_);
+    const Shape& on_device_shape) const {
+  ShapedBuffer shaped_buffer(on_device_shape, device_ordinal_);
   ShapeTree<se::DeviceMemoryBase>::iterator iterator =
       shaped_buffer.buffers().begin();
   for (const se::DeviceMemoryBase& buf : device_memory_) {
@@ -137,11 +137,8 @@ class TrackedDeviceBuffer {
       absl::Span<const std::shared_ptr<BufferSequencingEvent>>
           definition_events);

-  // Builds a ShapedBuffer view onto the buffers of 'tree'. We require but do
-  // not verify that TransferManager::HostShapeToDeviceShape(on_host_shape) ==
-  // on_device_shape().
-  ShapedBuffer AsShapedBuffer(const Shape& on_host_shape,
-                              const Shape& on_device_shape) const;
+  // Builds a ShapedBuffer view onto the buffers of 'tree'.
+  ShapedBuffer AsShapedBuffer(const Shape& on_device_shape) const;

  // Adds the owned device buffers in order to 'iterator'. Used to add the
  // buffers to an ExecutionInput. We require but do not verify that 'iterator'
@@ -65,13 +65,10 @@ TEST(TrackedDeviceBufferTest, AsShapedBuffer) {
       a_buffer->device_memory()[0], b_buffer->device_memory()[0],
       c_buffer->device_memory()[0]};
   ShapedBuffer shaped_a = a_buffer->AsShapedBuffer(
-      a_shape,
       client->backend().transfer_manager()->HostShapeToDeviceShape(a_shape));
   ShapedBuffer shaped_b = b_buffer->AsShapedBuffer(
-      b_shape,
       client->backend().transfer_manager()->HostShapeToDeviceShape(b_shape));
   ShapedBuffer shaped_c = c_buffer->AsShapedBuffer(
-      c_shape,
       client->backend().transfer_manager()->HostShapeToDeviceShape(c_shape));
   auto expected_it = expected_buffer_sequence.begin();
   for (auto it = shaped_a.buffers().begin(); it != shaped_a.buffers().end();
@@ -295,15 +295,15 @@ StatusOr<py::capsule> BufferToDLPackManagedTensor(py::handle py_buffer,
  pack->tensor.deleter = DLPackTensorDeleter;
  TF_ASSIGN_OR_RETURN(dt.ctx, DLContextForDevice(*buffer->buffer()->device()));
  dt.ctx.device_id = buffer->buffer()->device()->local_hardware_id();
-  dt.ndim = buffer->buffer()->on_host_shape().dimensions_size();
+  dt.ndim = buffer->buffer()->on_device_shape().dimensions_size();
  TF_ASSIGN_OR_RETURN(dt.dtype,
                      PrimitiveTypeToDLDataType(
-                          buffer->buffer()->on_host_shape().element_type()));
+                          buffer->buffer()->on_device_shape().element_type()));

-  pack->shape =
-      std::vector<int64>(buffer->buffer()->on_host_shape().dimensions().begin(),
-                         buffer->buffer()->on_host_shape().dimensions().end());
-  pack->strides = StridesForShape(buffer->buffer()->on_host_shape());
+  pack->shape = std::vector<int64>(
+      buffer->buffer()->on_device_shape().dimensions().begin(),
+      buffer->buffer()->on_device_shape().dimensions().end());
+  pack->strides = StridesForShape(buffer->buffer()->on_device_shape());
  dt.shape = reinterpret_cast<std::int64_t*>(pack->shape.data());
  dt.strides = reinterpret_cast<std::int64_t*>(pack->strides.data());
  dt.byte_offset = 0;
@@ -438,8 +438,8 @@ xla::StatusOr<ArgSignature> ArgSignatureOfValue(pybind11::handle arg,
       [](py::handle h, bool jax_enable_x64) -> xla::StatusOr<ArgSignature> {
         xla::PyBuffer* buffer = py::cast<xla::PyBuffer*>(h);
         bool weak_type = py::cast<py::bool_>(h.attr("aval").attr("weak_type"));
-        return ArgSignature(buffer->buffer()->on_host_shape().element_type(),
-                            buffer->buffer()->on_host_shape().dimensions(),
+        return ArgSignature(buffer->buffer()->on_device_shape().element_type(),
+                            buffer->buffer()->on_device_shape().dimensions(),
                             weak_type);
       };
   (*p)[py::type::handle_of<xla::DeviceArrayBase>().ptr()] = buffer_handler;
@@ -1015,8 +1015,9 @@ xla::Status ConvertArgsToBuffers(bool jax_enable_x64, xla::PyClient& pyclient,
       keep_alive.emplace_back(std::move(on_device.owned_buffer));
     }

-    ArgSignature sig(buffer->on_host_shape().element_type(),
-                     buffer->on_host_shape().dimensions(), on_device.weak_type);
+    ArgSignature sig(buffer->on_device_shape().element_type(),
+                     buffer->on_device_shape().dimensions(),
+                     on_device.weak_type);
     arguments.signature.dynamic_args_signatures.push_back(std::move(sig));
   }
   return xla::Status::OK();
@@ -56,11 +56,11 @@ PyBuffer::~PyBuffer() {
 }

 pybind11::tuple PyBuffer::python_shape() const {
-  return IntSpanToTuple(buffer()->on_host_shape().dimensions());
+  return IntSpanToTuple(buffer()->on_device_shape().dimensions());
 }

 pybind11::dtype PyBuffer::python_dtype() const {
-  PrimitiveType primitive = buffer()->on_host_shape().element_type();
+  PrimitiveType primitive = buffer()->on_device_shape().element_type();
   return PrimitiveTypeToDtype(primitive).ValueOrDie();
 }

@@ -91,7 +91,8 @@ Status PyBuffer::BlockHostUntilReady() {
 Status PyBuffer::CopyToHostAsync() {
   if (!buffer_->IsOnCpu() && !host_value_) {
     host_value_ = std::make_shared<HostValue>();
-    host_value_->value = std::make_shared<Literal>(buffer_->on_host_shape());
+    host_value_->value = std::make_shared<Literal>(
+        ShapeUtil::DeviceShapeToHostShape(buffer_->on_device_shape()));
     buffer_->ToLiteral(host_value_->value.get(),
                        [host_value{host_value_}](Status status) {
                          host_value->status = std::move(status);
@@ -110,7 +111,7 @@ StatusOr<pybind11::object> PyBuffer::AsNumPyArray(py::handle this_obj) {
   if (buffer_->IsOnCpu()) {
     TF_ASSIGN_OR_RETURN(
         py::dtype dtype,
-        PrimitiveTypeToDtype(buffer_->on_host_shape().element_type()));
+        PrimitiveTypeToDtype(buffer_->on_device_shape().element_type()));
     // Objects that must be kept alive while the array is alive.
     struct Hold {
       py::object buffer;
@@ -124,8 +125,8 @@ StatusOr<pybind11::object> PyBuffer::AsNumPyArray(py::handle this_obj) {
     void* data = hold->external_reference_hold->OpaqueDeviceMemoryDataPointer();
     py::capsule hold_capsule(hold.release(),
                              [](void* h) { delete static_cast<Hold*>(h); });
-    py::array array(dtype, buffer_->on_host_shape().dimensions(),
-                    ByteStridesForShape(buffer_->on_host_shape()), data,
+    py::array array(dtype, buffer_->on_device_shape().dimensions(),
+                    ByteStridesForShape(buffer_->on_device_shape()), data,
                     hold_capsule);
     array.attr("flags").attr("writeable") = Py_False;
     {
@@ -171,18 +172,18 @@ StatusOr<py::dict> PyBuffer::CudaArrayInterface() const {
     return InvalidArgument(
         "__cuda_array_interface__ is only defined for array buffers.");
   }
-  if (buffer_->on_host_shape().element_type() == BF16) {
+  if (buffer_->on_device_shape().element_type() == BF16) {
     return InvalidArgument(
         "__cuda_array_interface__ is not supported for bfloat16 buffers.");
   }
-  TF_RET_CHECK(
-      LayoutUtil::IsMonotonicWithDim0Major(buffer_->on_host_shape().layout()));
+  TF_RET_CHECK(LayoutUtil::IsMonotonicWithDim0Major(
+      buffer_->on_device_shape().layout()));

   py::dict result;
-  result["shape"] = IntSpanToTuple(buffer_->on_host_shape().dimensions());
-  TF_ASSIGN_OR_RETURN(
-      py::str typestr,
-      TypeDescriptorForPrimitiveType(buffer_->on_host_shape().element_type()));
+  result["shape"] = IntSpanToTuple(buffer_->on_device_shape().dimensions());
+  TF_ASSIGN_OR_RETURN(py::str typestr,
+                      TypeDescriptorForPrimitiveType(
+                          buffer_->on_device_shape().element_type()));
   result["typestr"] = std::move(typestr);
   TF_ASSIGN_OR_RETURN(std::unique_ptr<PjRtBuffer::ExternalReferenceHold>
                           external_reference_hold,
@@ -235,7 +236,7 @@ int PjRtBufferGetBuffer(PyObject* exporter, Py_buffer* view, int flags) {
     // If we allowed exports of formatted BF16 buffers, consumers would get
     // confused about the type because there is no way to describe BF16 to
    // Python.
-    if (buffer.on_host_shape().element_type() == BF16 &&
+    if (buffer.on_device_shape().element_type() == BF16 &&
        ((flags & PyBUF_FORMAT) == PyBUF_FORMAT)) {
      return InvalidArgument(
          "bfloat16 buffer format not supported by Python buffer protocol.");
@@ -249,7 +250,7 @@ int PjRtBufferGetBuffer(PyObject* exporter, Py_buffer* view, int flags) {
    if (buffer.IsDeleted()) {
      return InvalidArgument("Deleted buffer used in buffer protocol.");
    }
-    const Shape& shape = buffer.on_host_shape();
+    const Shape& shape = buffer.on_device_shape();
    if (((flags & PyBUF_C_CONTIGUOUS) == PyBUF_C_CONTIGUOUS ||
         (flags & PyBUF_STRIDES) == PyBUF_ND) &&
        !LayoutUtil::IsMonotonicWithDim0Major(shape.layout())) {
@@ -79,7 +79,7 @@ class PyBuffer : public DeviceArrayBase {
  Status BlockHostUntilReady();
  Status CopyToHostAsync();

-  const Shape& shape() { return buffer_->on_host_shape(); }
+  const Shape& shape() { return buffer_->on_device_shape(); }

  StatusOr<std::uintptr_t> UnsafeBufferPointer() const;

@@ -93,10 +93,10 @@ class PyBuffer : public DeviceArrayBase {
  Traceback* traceback() { return traceback_.get(); }

  // Returns the size (i.e. number of elements) of the (host) numpy array.
-  int64 size() { return ShapeUtil::ElementsIn(buffer()->on_host_shape()); }
+  int64 size() { return ShapeUtil::ElementsIn(buffer()->on_device_shape()); }

  // Returns the number of dimensions of the (host) numpy array.
-  int ndim() const { return buffer()->on_host_shape().dimensions_size(); }
+  int ndim() const { return buffer()->on_device_shape().dimensions_size(); }

  pybind11::tuple python_shape() const;
  pybind11::dtype python_dtype() const;
@@ -301,13 +301,13 @@ PYBIND11_MODULE(xla_extension, m) {
          "shape",
          [](const PyBuffer& pybuffer) -> pybind11::tuple {
            return IntSpanToTuple(
-                pybuffer.buffer()->on_host_shape().dimensions());
+                pybuffer.buffer()->on_device_shape().dimensions());
          })
      .def_property_readonly(
          "dtype",
          [](const PyBuffer& buffer) {
            PrimitiveType primitive =
-                buffer.buffer()->on_host_shape().element_type();
+                buffer.buffer()->on_device_shape().element_type();
            return PrimitiveTypeToDtype(primitive).ValueOrDie();
          })
      .def_property_readonly("size", &PyBuffer::size)