From b13ed5c4417059a51fd193ca02d0ed32e3a7ba62 Mon Sep 17 00:00:00 2001
From: Peter Hawkins
Date: Thu, 28 Jan 2021 14:58:51 -0800
Subject: [PATCH] [PJRT] Remove on_host_shape from PjRtBuffer API.

The on-host shape can be derived from the on-device shape, so there's no need
for PJRT to maintain both. In many places that were consuming the host shape,
it should be legal to consume the device shape. In a handful of places where I
was not sure, I used xla::ShapeUtil::DeviceShapeToHostShape to construct the
host shape again.

Note: device shapes and host shapes are identical on CPU and GPU.

PiperOrigin-RevId: 354401607
Change-Id: I38064cfaf8c1be908448d2a6131d47fad03e2ddf
---
 tensorflow/compiler/xla/pjrt/pjrt_client.h    |  4 +-
 .../xla/pjrt/pjrt_stream_executor_client.cc   | 58 ++++++++-----------
 .../xla/pjrt/pjrt_stream_executor_client.h    |  7 +--
 .../xla/pjrt/tracked_device_buffer.cc         |  4 +-
 .../compiler/xla/pjrt/tracked_device_buffer.h |  7 +--
 .../xla/pjrt/tracked_device_buffer_test.cc    |  3 -
 tensorflow/compiler/xla/python/dlpack.cc      | 12 ++--
 tensorflow/compiler/xla/python/jax_jit.cc     |  9 +--
 tensorflow/compiler/xla/python/py_buffer.cc   | 31 +++++-----
 tensorflow/compiler/xla/python/py_buffer.h    |  6 +-
 tensorflow/compiler/xla/python/xla.cc         |  4 +-
 11 files changed, 65 insertions(+), 80 deletions(-)

diff --git a/tensorflow/compiler/xla/pjrt/pjrt_client.h b/tensorflow/compiler/xla/pjrt/pjrt_client.h
index 3d124a0e852..74597c1aaa6 100644
--- a/tensorflow/compiler/xla/pjrt/pjrt_client.h
+++ b/tensorflow/compiler/xla/pjrt/pjrt_client.h
@@ -275,7 +275,6 @@ class PjRtBuffer {
  public:
   virtual ~PjRtBuffer() = default;

-  virtual const Shape& on_host_shape() const = 0;
   virtual const Shape& on_device_shape() const = 0;
   virtual PjRtDevice* device() const = 0;
   virtual PjRtClient* client() const = 0;
@@ -319,7 +318,8 @@ class PjRtBuffer {
   // Convenience synchronous overload that allocates a literal with a default
   // layout.
   StatusOr<std::shared_ptr<Literal>> ToLiteral() {
-    auto literal = std::make_shared<Literal>(on_host_shape());
+    auto literal = std::make_shared<Literal>(
+        ShapeUtil::DeviceShapeToHostShape(on_device_shape()));
     TF_RETURN_IF_ERROR(ToLiteral(literal.get()));
     return literal;
   }
diff --git a/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc b/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc
index 30506a11a4b..54b9b7c3314 100644
--- a/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc
+++ b/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc
@@ -359,6 +359,7 @@ void RecordUsage(PjRtStreamExecutorBuffer::ScopedHold device_buffer,
 //
 // The caller may optionally provide a definition event to be recorded in
 // the buffer.
+// TODO(phawkins): replace on_host_shape here with on_device_shape.
 StatusOr<std::unique_ptr<PjRtStreamExecutorBuffer>> AllocateDestinationBuffer(
     const Shape& on_host_shape, PjRtDevice* device,
     LocalDeviceState* local_device, se::Stream* copy_stream,
@@ -453,8 +454,7 @@ StatusOr<std::unique_ptr<PjRtStreamExecutorBuffer>> AllocateDestinationBuffer(
       definition_events);

   auto py_buffer = absl::make_unique<PjRtStreamExecutorBuffer>(
-      on_host_shape, on_device_shape, std::move(dst_device_buffer), client,
-      device);
+      on_device_shape, std::move(dst_device_buffer), client, device);

   if (on_device_shape.IsTuple()) {
     // Add a usage hold for the tuple table write and immediately convert it to
@@ -670,7 +670,7 @@ PjRtStreamExecutorClient::BufferFromHostBuffer(
           definition_events, std::move(on_delete_callback));
       return std::unique_ptr<PjRtBuffer>(
           std::make_unique<PjRtStreamExecutorBuffer>(
-              shape, shape, std::move(device_buffer), this, device));
+              shape, std::move(device_buffer), this, device));
     }
   }

@@ -719,7 +719,7 @@ PjRtStreamExecutorClient::BufferFromHostBuffer(
   auto transfer_h2d = [local_client = client(), transfer_manager, local_device,
                        data, size,
                        movable_device_buffer{device_buffer.ToClosure()}, shape,
-                       py_buffer{py_buffer.get()}, compact_shape,
+                       py_buffer{py_buffer.get()},
                        on_device_shape{py_buffer->on_device_shape()},
                        staging_buffer{std::move(staging_buffer)},
                        on_done_with_host_buffer{
@@ -732,8 +732,7 @@ PjRtStreamExecutorClient::BufferFromHostBuffer(
     // memory that has already been allocated, and a possible Event
     // allocation.

-    ShapedBuffer buffer =
-        device_buffer->AsShapedBuffer(compact_shape, on_device_shape);
+    ShapedBuffer buffer = device_buffer->AsShapedBuffer(on_device_shape);
     // If applicable on the backend, stage the transfer via host memory
     // allocated via the host_memory_allocator. On GPU, this is pinned
     // memory.
@@ -846,7 +845,7 @@ PjRtStreamExecutorClient::BufferFromHostLiteral(const LiteralSlice& literal,
   // put the transfer into the calling thread for small literals.
   auto transfer_h2d = [local_client = client(), transfer_manager, local_device,
                        movable_device_buffer{device_buffer.ToClosure()},
-                       literal, py_buffer{py_buffer.get()}, compact_shape,
+                       literal, py_buffer{py_buffer.get()},
                        on_device_shape{py_buffer->on_device_shape()}]() {
     PjRtStreamExecutorBuffer::ScopedHold device_buffer(movable_device_buffer);
     // This function uses TF_CHECK_OK and ValueOrDie() since we have no way
@@ -856,8 +855,7 @@ PjRtStreamExecutorClient::BufferFromHostLiteral(const LiteralSlice& literal,
     // allocation.

     se::Stream* h2d_stream = local_device->host_to_device_stream();
-    ShapedBuffer buffer =
-        device_buffer->AsShapedBuffer(compact_shape, on_device_shape);
+    ShapedBuffer buffer = device_buffer->AsShapedBuffer(on_device_shape);
     TF_CHECK_OK(transfer_manager->TransferLiteralToDeviceAsync(
         h2d_stream, literal, buffer));

@@ -924,7 +922,7 @@ PjRtStreamExecutorClient::CreateViewOfDeviceBuffer(
       std::initializer_list<se::DeviceMemoryBase>{buffer}, definition_events,
       std::move(on_delete_callback));
   return std::unique_ptr<PjRtBuffer>(std::make_unique<PjRtStreamExecutorBuffer>(
-      shape, shape, std::move(device_buffer), this, device));
+      shape, std::move(device_buffer), this, device));
 }

 // Transfer the given literal to the infeed queue of the given local device.
@@ -955,11 +953,9 @@ StatusOr<PjRtDevice*> PjRtStreamExecutorClient::LookupAddressableDevice(
 }

 PjRtStreamExecutorBuffer::PjRtStreamExecutorBuffer(
-    Shape on_host_shape, Shape on_device_shape,
-    std::shared_ptr<TrackedDeviceBuffer> device_buffer, PjRtClient* client,
-    PjRtDevice* device)
+    Shape on_device_shape, std::shared_ptr<TrackedDeviceBuffer> device_buffer,
+    PjRtClient* client, PjRtDevice* device)
     : client_(tensorflow::down_cast<PjRtStreamExecutorClient*>(client)),
-      on_host_shape_(std::move(on_host_shape)),
       on_device_shape_(std::move(on_device_shape)),
       device_(tensorflow::down_cast<PjRtStreamExecutorDevice*>(device)),
       device_buffer_(std::move(device_buffer)),
@@ -1195,8 +1191,7 @@ void PjRtStreamExecutorBuffer::ToLiteral(MutableLiteralBase* literal,
   }

   WaitForBufferDefinitionEventsOnStream(*device_buffer, stream);
-  ShapedBuffer shaped_buffer =
-      device_buffer->AsShapedBuffer(literal->shape(), on_device_shape_);
+  ShapedBuffer shaped_buffer = device_buffer->AsShapedBuffer(on_device_shape_);
   StatusOr<EventPool::Handle> event_or =
       local_device->event_pool().AllocateEvent(stream->parent());
   if (!event_or.ok()) {
@@ -1233,7 +1228,7 @@ StatusOr<ShapedBuffer> PjRtStreamExecutorBuffer::AsShapedBuffer() const {
     return InvalidArgument(
         "Attempted to fetch value of invalid/deleted buffer.");
   }
-  return device_buffer_->AsShapedBuffer(on_host_shape_, on_device_shape_);
+  return device_buffer_->AsShapedBuffer(on_device_shape_);
 }

 PjRtStreamExecutorBuffer::ScopedHold
@@ -1257,11 +1252,11 @@ PjRtStreamExecutorBuffer::CopyToDeviceHelper(
     PjRtDevice* dst_device, LocalDeviceState* dst_local_device,
     LocalDeviceState* transfer_local_device, se::Stream* transfer_stream,
     std::shared_ptr<TrackedDeviceBuffer> src_device_buffer) {
-  TF_ASSIGN_OR_RETURN(
-      std::unique_ptr<PjRtStreamExecutorBuffer> py_buffer,
-      AllocateDestinationBuffer(on_host_shape_, dst_device, dst_local_device,
-                                transfer_stream,
-                                /*is_uninitialized_create=*/false, client_));
+  TF_ASSIGN_OR_RETURN(std::unique_ptr<PjRtStreamExecutorBuffer> py_buffer,
+                      AllocateDestinationBuffer(
+                          ShapeUtil::DeviceShapeToHostShape(on_device_shape_),
+                          dst_device, dst_local_device, transfer_stream,
+                          /*is_uninitialized_create=*/false, client_));

   TF_ASSIGN_OR_RETURN(ShapedBuffer src_buffer, AsShapedBuffer());

@@ -1269,8 +1264,7 @@ PjRtStreamExecutorBuffer::CopyToDeviceHelper(
   ScopedHold dst_device_buffer(py_buffer->GetBufferWithUsageHold());
   CHECK(dst_device_buffer.ok());

-  ShapedBuffer dst_buffer =
-      dst_device_buffer->AsShapedBuffer(on_host_shape_, on_device_shape_);
+  ShapedBuffer dst_buffer = dst_device_buffer->AsShapedBuffer(on_device_shape_);

   // Copy the leaf buffers.
 StatusOr<std::shared_ptr<BufferSequencingEvent>> copy_event_or =
@@ -1451,10 +1445,8 @@ StatusOr<TupleHandle> MakeTupleHelper(
   host_shapes.reserve(py_buffers.size());
   device_shapes.reserve(py_buffers.size());
   for (const PjRtBuffer* buffer : py_buffers) {
-    host_shapes.push_back(buffer->on_host_shape());
     device_shapes.push_back(buffer->on_device_shape());
   }
-  Shape on_host_shape = ShapeUtil::MakeTupleShape(host_shapes);
   Shape on_device_shape = ShapeUtil::MakeTupleShape(device_shapes);

   se::DeviceMemoryAllocator* allocator =
@@ -1469,7 +1461,7 @@ StatusOr<TupleHandle> MakeTupleHelper(
       se::OwningDeviceMemory root_table_memory,
       allocator->Allocate(
           device_ordinal,
-          transfer_manager->GetByteSizeRequirement(on_host_shape)));
+          transfer_manager->GetByteSizeRequirement(on_device_shape)));

   if (local_device->allocation_model() ==
       LocalDeviceState::kComputeSynchronized) {
@@ -1479,7 +1471,7 @@ StatusOr<TupleHandle> MakeTupleHelper(
         local_device->compute_stream()->parent(), root_table_memory.cref()));
   }

-  ExecutionInput execution_input(on_device_shape, on_host_shape);
+  ExecutionInput execution_input(on_device_shape);
   ShapeTree<MaybeOwningDeviceMemory>::iterator input_iterator =
       execution_input.MutableBuffers()->begin();
   ShapeTree<MaybeOwningDeviceMemory>::iterator iterator_end =
@@ -1521,8 +1513,7 @@ std::unique_ptr<PjRtBuffer> OutputBufferHelper(
       TrackedDeviceBuffer::FromScopedShapedBuffer(result_buffer,
                                                   {definition_event});
   auto pjrt_buffer = absl::make_unique<PjRtStreamExecutorBuffer>(
-      result_buffer->on_host_shape(), result_buffer->on_device_shape(),
-      std::move(out_buffer), client, device);
+      result_buffer->on_device_shape(), std::move(out_buffer), client, device);
   RecordUsage(pjrt_buffer->GetBufferWithUsageHold(), local_device, local_device,
               definition_event, local_device->compute_stream(),
               /*prefer_to_retain_reference=*/false);
@@ -1621,8 +1612,7 @@ PjRtStreamExecutorExecutable::MakeExecutionInputsAndWaitForEvents(
     PjRtBuffer* handle = argument_handles[i];

     // Make an ExecutionInput from the device buffer.
-    execution_inputs.emplace_back(handle->on_device_shape(),
-                                  handle->on_host_shape());
+    execution_inputs.emplace_back(handle->on_device_shape());
     ExecutionInput& execution_input = execution_inputs.back();
     ShapeTree<MaybeOwningDeviceMemory>::iterator input_iterator =
         execution_input.MutableBuffers()->begin();
@@ -1794,8 +1784,8 @@ PjRtStreamExecutorExecutable::MakeOutputBuffers(
     PjRtDevice* device) const {
   std::vector<std::unique_ptr<PjRtBuffer>> outputs;
   LocalDeviceState* device_state = &(client_->device_state(device_ordinal));
-  if (options.untuple_result && result_buffer.on_host_shape().IsTuple()) {
-    int tuple_count = result_buffer.on_host_shape().tuple_shapes_size();
+  if (options.untuple_result && result_buffer.on_device_shape().IsTuple()) {
+    int tuple_count = result_buffer.on_device_shape().tuple_shapes_size();
     outputs.reserve(tuple_count);
     // Take ownership of each of the output values, leaving only the root table
     // in result_buffer.
diff --git a/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h b/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h
index 924066028eb..ff3b8802116 100644
--- a/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h
+++ b/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h
@@ -455,7 +455,7 @@ class PjRtStreamExecutorBuffer : public PjRtBuffer {
     StatusOr<std::shared_ptr<TrackedDeviceBuffer>> buffer_or_;
   };

-  PjRtStreamExecutorBuffer(Shape on_host_shape, Shape on_device_shape,
+  PjRtStreamExecutorBuffer(Shape on_device_shape,
                            std::shared_ptr<TrackedDeviceBuffer> device_buffer,
                            PjRtClient* client, PjRtDevice* device);
   ~PjRtStreamExecutorBuffer() override;
@@ -465,14 +465,14 @@ class PjRtStreamExecutorBuffer : public PjRtBuffer {
   PjRtStreamExecutorBuffer& operator=(const PjRtStreamExecutorBuffer&) = delete;
   PjRtStreamExecutorBuffer& operator=(PjRtStreamExecutorBuffer&&) = delete;

-  const Shape& on_host_shape() const override { return on_host_shape_; }
   const Shape& on_device_shape() const override { return on_device_shape_; }
   PjRtStreamExecutorDevice* device() const override { return device_; }
   PjRtPlatformId platform_id() const { return client_->platform_id(); }
   absl::string_view platform_name() const { return client_->platform_name(); }
   PjRtStreamExecutorClient* client() const override { return client_; }
   bool IsEmptyTuple() const {
-    return on_host_shape_.IsTuple() && on_host_shape_.tuple_shapes_size() == 0;
+    return on_device_shape_.IsTuple() &&
+           on_device_shape_.tuple_shapes_size() == 0;
   }

   int64 OnDeviceSizeInBytes() const override;
@@ -603,7 +603,6 @@ class PjRtStreamExecutorBuffer : public PjRtBuffer {
       std::shared_ptr<TrackedDeviceBuffer> src_device_buffer);

   PjRtStreamExecutorClient* const client_;
-  const Shape on_host_shape_;
   const Shape on_device_shape_;
   PjRtStreamExecutorDevice* const device_;

diff --git a/tensorflow/compiler/xla/pjrt/tracked_device_buffer.cc b/tensorflow/compiler/xla/pjrt/tracked_device_buffer.cc
index cd7a37a117a..c6d0ed9a081 100644
--- a/tensorflow/compiler/xla/pjrt/tracked_device_buffer.cc
+++ b/tensorflow/compiler/xla/pjrt/tracked_device_buffer.cc
@@ -118,8 +118,8 @@ TrackedDeviceBuffer::FromScopedShapedBuffer(
 }

 ShapedBuffer TrackedDeviceBuffer::AsShapedBuffer(
-    const Shape& on_host_shape, const Shape& on_device_shape) const {
-  ShapedBuffer shaped_buffer(on_host_shape, on_device_shape, device_ordinal_);
+    const Shape& on_device_shape) const {
+  ShapedBuffer shaped_buffer(on_device_shape, device_ordinal_);
   ShapeTree<se::DeviceMemoryBase>::iterator iterator =
       shaped_buffer.buffers().begin();
   for (const se::DeviceMemoryBase& buf : device_memory_) {
diff --git a/tensorflow/compiler/xla/pjrt/tracked_device_buffer.h b/tensorflow/compiler/xla/pjrt/tracked_device_buffer.h
index 1476dc2039e..ad61ed3a3d0 100644
--- a/tensorflow/compiler/xla/pjrt/tracked_device_buffer.h
+++ b/tensorflow/compiler/xla/pjrt/tracked_device_buffer.h
@@ -137,11 +137,8 @@ class TrackedDeviceBuffer {
       absl::Span<const std::shared_ptr<BufferSequencingEvent>>
          definition_events);

-  // Builds a ShapedBuffer view onto the buffers of 'tree'. We require but do
-  // not verify that TransferManager::HostShapeToDeviceShape(on_host_shape) ==
-  // on_device_shape().
-  ShapedBuffer AsShapedBuffer(const Shape& on_host_shape,
-                              const Shape& on_device_shape) const;
+  // Builds a ShapedBuffer view onto the buffers of 'tree'.
+  ShapedBuffer AsShapedBuffer(const Shape& on_device_shape) const;

   // Adds the owned device buffers in order to 'iterator'. Used to add the
   // buffers to an ExecutionInput. We require but do not verify that 'iterator'
diff --git a/tensorflow/compiler/xla/pjrt/tracked_device_buffer_test.cc b/tensorflow/compiler/xla/pjrt/tracked_device_buffer_test.cc
index ffeb7c002a0..1bc2ba40a83 100644
--- a/tensorflow/compiler/xla/pjrt/tracked_device_buffer_test.cc
+++ b/tensorflow/compiler/xla/pjrt/tracked_device_buffer_test.cc
@@ -65,13 +65,10 @@ TEST(TrackedDeviceBufferTest, AsShapedBuffer) {
       a_buffer->device_memory()[0], b_buffer->device_memory()[0],
       c_buffer->device_memory()[0]};
   ShapedBuffer shaped_a = a_buffer->AsShapedBuffer(
-      a_shape,
       client->backend().transfer_manager()->HostShapeToDeviceShape(a_shape));
   ShapedBuffer shaped_b = b_buffer->AsShapedBuffer(
-      b_shape,
       client->backend().transfer_manager()->HostShapeToDeviceShape(b_shape));
   ShapedBuffer shaped_c = c_buffer->AsShapedBuffer(
-      c_shape,
       client->backend().transfer_manager()->HostShapeToDeviceShape(c_shape));
   auto expected_it = expected_buffer_sequence.begin();
   for (auto it = shaped_a.buffers().begin(); it != shaped_a.buffers().end();
diff --git a/tensorflow/compiler/xla/python/dlpack.cc b/tensorflow/compiler/xla/python/dlpack.cc
index 3edc94a769f..d92e9cac3b4 100644
--- a/tensorflow/compiler/xla/python/dlpack.cc
+++ b/tensorflow/compiler/xla/python/dlpack.cc
@@ -295,15 +295,15 @@ StatusOr<py::capsule> BufferToDLPackManagedTensor(py::handle py_buffer,
   pack->tensor.deleter = DLPackTensorDeleter;
   TF_ASSIGN_OR_RETURN(dt.ctx, DLContextForDevice(*buffer->buffer()->device()));
   dt.ctx.device_id = buffer->buffer()->device()->local_hardware_id();
-  dt.ndim = buffer->buffer()->on_host_shape().dimensions_size();
+  dt.ndim = buffer->buffer()->on_device_shape().dimensions_size();
   TF_ASSIGN_OR_RETURN(dt.dtype,
                       PrimitiveTypeToDLDataType(
-                          buffer->buffer()->on_host_shape().element_type()));
+                          buffer->buffer()->on_device_shape().element_type()));

-  pack->shape =
-      std::vector<int64>(buffer->buffer()->on_host_shape().dimensions().begin(),
-                         buffer->buffer()->on_host_shape().dimensions().end());
-  pack->strides = StridesForShape(buffer->buffer()->on_host_shape());
+  pack->shape = std::vector<int64>(
+      buffer->buffer()->on_device_shape().dimensions().begin(),
+      buffer->buffer()->on_device_shape().dimensions().end());
+  pack->strides = StridesForShape(buffer->buffer()->on_device_shape());
   dt.shape = reinterpret_cast<std::int64_t*>(pack->shape.data());
   dt.strides = reinterpret_cast<std::int64_t*>(pack->strides.data());
   dt.byte_offset = 0;
diff --git a/tensorflow/compiler/xla/python/jax_jit.cc b/tensorflow/compiler/xla/python/jax_jit.cc
index bf08f9242e1..232716b3c74 100644
--- a/tensorflow/compiler/xla/python/jax_jit.cc
+++ b/tensorflow/compiler/xla/python/jax_jit.cc
@@ -438,8 +438,8 @@ xla::StatusOr<ArgSignature> ArgSignatureOfValue(pybind11::handle arg,
       [](py::handle h, bool jax_enable_x64) -> xla::StatusOr<ArgSignature> {
         xla::PyBuffer* buffer = py::cast<xla::PyBuffer*>(h);
         bool weak_type = py::cast<bool>(h.attr("aval").attr("weak_type"));
-        return ArgSignature(buffer->buffer()->on_host_shape().element_type(),
-                            buffer->buffer()->on_host_shape().dimensions(),
+        return ArgSignature(buffer->buffer()->on_device_shape().element_type(),
+                            buffer->buffer()->on_device_shape().dimensions(),
                             weak_type);
       };
   (*p)[py::type::handle_of().ptr()] = buffer_handler;
@@ -1015,8 +1015,9 @@ xla::Status ConvertArgsToBuffers(bool jax_enable_x64, xla::PyClient& pyclient,
       keep_alive.emplace_back(std::move(on_device.owned_buffer));
     }

-    ArgSignature sig(buffer->on_host_shape().element_type(),
-                     buffer->on_host_shape().dimensions(), on_device.weak_type);
+    ArgSignature sig(buffer->on_device_shape().element_type(),
+                     buffer->on_device_shape().dimensions(),
+                     on_device.weak_type);
     arguments.signature.dynamic_args_signatures.push_back(std::move(sig));
   }
   return xla::Status::OK();
diff --git a/tensorflow/compiler/xla/python/py_buffer.cc b/tensorflow/compiler/xla/python/py_buffer.cc
index 127521d523e..4f2ef95d2ef 100644
--- a/tensorflow/compiler/xla/python/py_buffer.cc
+++ b/tensorflow/compiler/xla/python/py_buffer.cc
@@ -56,11 +56,11 @@ PyBuffer::~PyBuffer() {
 }

 pybind11::tuple PyBuffer::python_shape() const {
-  return IntSpanToTuple(buffer()->on_host_shape().dimensions());
+  return IntSpanToTuple(buffer()->on_device_shape().dimensions());
 }

 pybind11::dtype PyBuffer::python_dtype() const {
-  PrimitiveType primitive = buffer()->on_host_shape().element_type();
+  PrimitiveType primitive = buffer()->on_device_shape().element_type();
   return PrimitiveTypeToDtype(primitive).ValueOrDie();
 }

@@ -91,7 +91,8 @@ Status PyBuffer::BlockHostUntilReady() {
 Status PyBuffer::CopyToHostAsync() {
   if (!buffer_->IsOnCpu() && !host_value_) {
     host_value_ = std::make_shared<HostValue>();
-    host_value_->value = std::make_shared<Literal>(buffer_->on_host_shape());
+    host_value_->value = std::make_shared<Literal>(
+        ShapeUtil::DeviceShapeToHostShape(buffer_->on_device_shape()));
     buffer_->ToLiteral(host_value_->value.get(),
                        [host_value{host_value_}](Status status) {
                          host_value->status = std::move(status);
@@ -110,7 +111,7 @@ StatusOr<pybind11::object> PyBuffer::AsNumPyArray(py::handle this_obj) {
   if (buffer_->IsOnCpu()) {
     TF_ASSIGN_OR_RETURN(
         py::dtype dtype,
-        PrimitiveTypeToDtype(buffer_->on_host_shape().element_type()));
+        PrimitiveTypeToDtype(buffer_->on_device_shape().element_type()));
     // Objects that must be kept alive while the array is alive.
     struct Hold {
       py::object buffer;
@@ -124,8 +125,8 @@ StatusOr<pybind11::object> PyBuffer::AsNumPyArray(py::handle this_obj) {
     void* data = hold->external_reference_hold->OpaqueDeviceMemoryDataPointer();
     py::capsule hold_capsule(hold.release(),
                              [](void* h) { delete static_cast<Hold*>(h); });
-    py::array array(dtype, buffer_->on_host_shape().dimensions(),
-                    ByteStridesForShape(buffer_->on_host_shape()), data,
+    py::array array(dtype, buffer_->on_device_shape().dimensions(),
+                    ByteStridesForShape(buffer_->on_device_shape()), data,
                     hold_capsule);
     array.attr("flags").attr("writeable") = Py_False;
     {
@@ -171,18 +172,18 @@ StatusOr<pybind11::dict> PyBuffer::CudaArrayInterface() const {
     return InvalidArgument(
         "__cuda_array_interface__ is only defined for array buffers.");
   }
-  if (buffer_->on_host_shape().element_type() == BF16) {
+  if (buffer_->on_device_shape().element_type() == BF16) {
     return InvalidArgument(
         "__cuda_array_interface__ is not supported for bfloat16 buffers.");
   }
-  TF_RET_CHECK(
-      LayoutUtil::IsMonotonicWithDim0Major(buffer_->on_host_shape().layout()));
+  TF_RET_CHECK(LayoutUtil::IsMonotonicWithDim0Major(
+      buffer_->on_device_shape().layout()));

   py::dict result;
-  result["shape"] = IntSpanToTuple(buffer_->on_host_shape().dimensions());
-  TF_ASSIGN_OR_RETURN(
-      py::str typestr,
-      TypeDescriptorForPrimitiveType(buffer_->on_host_shape().element_type()));
+  result["shape"] = IntSpanToTuple(buffer_->on_device_shape().dimensions());
+  TF_ASSIGN_OR_RETURN(py::str typestr,
+                      TypeDescriptorForPrimitiveType(
+                          buffer_->on_device_shape().element_type()));
   result["typestr"] = std::move(typestr);
   TF_ASSIGN_OR_RETURN(std::unique_ptr<PjRtBuffer::ExternalReference>
                           external_reference_hold,
@@ -235,7 +236,7 @@ int PjRtBufferGetBuffer(PyObject* exporter, Py_buffer* view, int flags) {
     // If we allowed exports of formatted BF16 buffers, consumers would get
    // confused about the type because there is no way to describe BF16 to
    // Python.
-    if (buffer.on_host_shape().element_type() == BF16 &&
+    if (buffer.on_device_shape().element_type() == BF16 &&
         ((flags & PyBUF_FORMAT) == PyBUF_FORMAT)) {
       return InvalidArgument(
           "bfloat16 buffer format not supported by Python buffer protocol.");
@@ -249,7 +250,7 @@ int PjRtBufferGetBuffer(PyObject* exporter, Py_buffer* view, int flags) {
     if (buffer.IsDeleted()) {
       return InvalidArgument("Deleted buffer used in buffer protocol.");
     }
-    const Shape& shape = buffer.on_host_shape();
+    const Shape& shape = buffer.on_device_shape();
     if (((flags & PyBUF_C_CONTIGUOUS) == PyBUF_C_CONTIGUOUS ||
          (flags & PyBUF_STRIDES) == PyBUF_ND) &&
         !LayoutUtil::IsMonotonicWithDim0Major(shape.layout())) {
diff --git a/tensorflow/compiler/xla/python/py_buffer.h b/tensorflow/compiler/xla/python/py_buffer.h
index 13ac891d30b..f412c4abbec 100644
--- a/tensorflow/compiler/xla/python/py_buffer.h
+++ b/tensorflow/compiler/xla/python/py_buffer.h
@@ -79,7 +79,7 @@ class PyBuffer : public DeviceArrayBase {
   Status BlockHostUntilReady();
   Status CopyToHostAsync();

-  const Shape& shape() { return buffer_->on_host_shape(); }
+  const Shape& shape() { return buffer_->on_device_shape(); }

   StatusOr<std::uintptr_t> UnsafeBufferPointer() const;

@@ -93,10 +93,10 @@ class PyBuffer : public DeviceArrayBase {
   Traceback* traceback() { return traceback_.get(); }

   // Returns the size (i.e. number of elements) of the (host) numpy array.
-  int64 size() { return ShapeUtil::ElementsIn(buffer()->on_host_shape()); }
+  int64 size() { return ShapeUtil::ElementsIn(buffer()->on_device_shape()); }

   // Returns the number of dimensions of the (host) numpy array.
-  int ndim() const { return buffer()->on_host_shape().dimensions_size(); }
+  int ndim() const { return buffer()->on_device_shape().dimensions_size(); }

   pybind11::tuple python_shape() const;
   pybind11::dtype python_dtype() const;
diff --git a/tensorflow/compiler/xla/python/xla.cc b/tensorflow/compiler/xla/python/xla.cc
index 530132f3e25..e8c4b26489f 100644
--- a/tensorflow/compiler/xla/python/xla.cc
+++ b/tensorflow/compiler/xla/python/xla.cc
@@ -301,13 +301,13 @@ PYBIND11_MODULE(xla_extension, m) {
           "shape",
           [](const PyBuffer& pybuffer) -> pybind11::tuple {
             return IntSpanToTuple(
-                pybuffer.buffer()->on_host_shape().dimensions());
+                pybuffer.buffer()->on_device_shape().dimensions());
           })
       .def_property_readonly(
           "dtype",
           [](const PyBuffer& buffer) {
             PrimitiveType primitive =
-                buffer.buffer()->on_host_shape().element_type();
+                buffer.buffer()->on_device_shape().element_type();
             return PrimitiveTypeToDtype(primitive).ValueOrDie();
           })
      .def_property_readonly("size", &PyBuffer::size)
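
The migration pattern the commit message describes can be summarized in a small sketch. The helper below is illustrative only and is not part of the patch; only ShapeUtil::DeviceShapeToHostShape and PjRtBuffer::on_device_shape() come from the change itself, the function name is hypothetical.

#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/shape_util.h"

namespace xla {

// Hypothetical helper: code that previously called buffer.on_host_shape()
// can rebuild the host-side view from the device shape. Per the commit
// message, the two shapes are identical on CPU and GPU, so this is a plain
// copy there; on other backends it strips device-only layout details.
Shape HostShapeOf(const PjRtBuffer& buffer) {
  return ShapeUtil::DeviceShapeToHostShape(buffer.on_device_shape());
}

}  // namespace xla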
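
For call sites that construct buffers, the updated signatures take only the device shape. A hedged sketch of the post-patch calls follows; the wrapper function and its argument names are assumptions for illustration, not code from the tree.

#include <memory>
#include <utility>

#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h"
#include "tensorflow/compiler/xla/pjrt/tracked_device_buffer.h"

namespace xla {

// Hypothetical wrapper showing the post-patch calling convention:
// TrackedDeviceBuffer::AsShapedBuffer and the PjRtStreamExecutorBuffer
// constructor both take a single on-device shape now.
std::unique_ptr<PjRtStreamExecutorBuffer> WrapTrackedBuffer(
    const Shape& on_device_shape,
    std::shared_ptr<TrackedDeviceBuffer> tracked_buffer,
    PjRtStreamExecutorClient* client, PjRtStreamExecutorDevice* device) {
  // A ShapedBuffer view can be built from the device shape alone
  // (previously the host shape had to be passed as well).
  ShapedBuffer view = tracked_buffer->AsShapedBuffer(on_device_shape);
  (void)view;  // Built here only to illustrate the new single-shape call.
  return absl::make_unique<PjRtStreamExecutorBuffer>(
      on_device_shape, std::move(tracked_buffer), client, device);
}

}  // namespace xla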
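
Many of the Python-facing updates (py_buffer.cc, dlpack.cc, jax_jit.cc) only read dimensions and element type, which deriving the host shape does not alter, so they can query the device shape directly. A small sketch of that pattern, with hypothetical helper names assumed for illustration:

#include <cstdint>
#include <vector>

#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/compiler/xla/shape.h"

namespace xla {

// Hypothetical query helpers in the spirit of the py_buffer.cc changes:
// element type and dimensions are read straight from on_device_shape().
PrimitiveType BufferElementType(const PjRtBuffer& buffer) {
  return buffer.on_device_shape().element_type();
}

std::vector<int64_t> BufferDimensions(const PjRtBuffer& buffer) {
  const Shape& shape = buffer.on_device_shape();
  return std::vector<int64_t>(shape.dimensions().begin(),
                              shape.dimensions().end());
}

}  // namespace xla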