From ddb089b72f7313f2ced3de84fde78f3184b76e7c Mon Sep 17 00:00:00 2001 From: Qiao Zhang Date: Tue, 2 Feb 2021 21:53:22 -0800 Subject: [PATCH] Allow PJRT infeed/outfeed methods to be non-const. PiperOrigin-RevId: 355321939 Change-Id: Ifc081fd55121d328631ce4fce87af1f7a9452b2e --- tensorflow/compiler/xla/pjrt/pjrt_client.h | 4 ++-- tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc | 5 ++--- tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h | 4 ++-- tensorflow/compiler/xla/python/outfeed_receiver.cc | 6 +++--- .../compiler/xla/python/tpu_driver/client/tpu_client.h | 4 ++-- tensorflow/compiler/xla/python/xla.cc | 5 ++--- 6 files changed, 13 insertions(+), 15 deletions(-) diff --git a/tensorflow/compiler/xla/pjrt/pjrt_client.h b/tensorflow/compiler/xla/pjrt/pjrt_client.h index 75098e22e24..732ee82cfb0 100644 --- a/tensorflow/compiler/xla/pjrt/pjrt_client.h +++ b/tensorflow/compiler/xla/pjrt/pjrt_client.h @@ -91,10 +91,10 @@ class PjRtDevice { virtual std::string DebugString() const = 0; // Transfer the given literal to the infeed queue. - virtual Status TransferToInfeed(const LiteralSlice& literal) const = 0; + virtual Status TransferToInfeed(const LiteralSlice& literal) = 0; // Transfer and return a value of the given shape from the outfeed queue. - virtual Status TransferFromOutfeed(MutableBorrowingLiteral literal) const = 0; + virtual Status TransferFromOutfeed(MutableBorrowingLiteral literal) = 0; }; // Forward declaration. diff --git a/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc b/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc index 58a85432777..e7ce0644878 100644 --- a/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc +++ b/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc @@ -961,8 +961,7 @@ PjRtStreamExecutorClient::CreateViewOfDeviceBuffer( } // Transfer the given literal to the infeed queue of the given local device. -Status PjRtStreamExecutorDevice::TransferToInfeed( - const LiteralSlice& literal) const { +Status PjRtStreamExecutorDevice::TransferToInfeed(const LiteralSlice& literal) { // Only support infeed to local device. TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device, GetLocalDeviceState()); return local_device->client()->TransferToInfeedLocal( @@ -970,7 +969,7 @@ Status PjRtStreamExecutorDevice::TransferToInfeed( } Status PjRtStreamExecutorDevice::TransferFromOutfeed( - MutableBorrowingLiteral literal) const { + MutableBorrowingLiteral literal) { TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device, GetLocalDeviceState()); return local_device->client()->TransferFromOutfeedLocal( local_device->device_ordinal(), literal); diff --git a/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h b/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h index ccc7057ebf3..34f819dc3c5 100644 --- a/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h +++ b/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h @@ -105,9 +105,9 @@ class PjRtStreamExecutorDevice : public PjRtDevice { std::string DebugString() const override; - Status TransferToInfeed(const LiteralSlice& literal) const override; + Status TransferToInfeed(const LiteralSlice& literal) override; - Status TransferFromOutfeed(MutableBorrowingLiteral literal) const override; + Status TransferFromOutfeed(MutableBorrowingLiteral literal) override; private: const int id_; diff --git a/tensorflow/compiler/xla/python/outfeed_receiver.cc b/tensorflow/compiler/xla/python/outfeed_receiver.cc index 37c5f357290..9396882262e 100644 --- a/tensorflow/compiler/xla/python/outfeed_receiver.cc +++ b/tensorflow/compiler/xla/python/outfeed_receiver.cc @@ -188,8 +188,8 @@ class OutfeedReceiverImpl { Status SendShutdownOutfeedHeader(int device_idx); // Receives a raw Literal from a device outfeed. - StatusOr> ReceiveRawFromOutfeed( - const PjRtDevice* device, const Shape& shape); + StatusOr> ReceiveRawFromOutfeed(PjRtDevice* device, + const Shape& shape); // Enqueues received data in the callbaback queue. void EnqueueReceivedData(std::unique_ptr received) @@ -340,7 +340,7 @@ void OutfeedReceiverImpl::EnqueueReceivedData( } StatusOr> OutfeedReceiverImpl::ReceiveRawFromOutfeed( - const PjRtDevice* device, const Shape& shape) { + PjRtDevice* device, const Shape& shape) { auto literal = std::make_unique(shape); TF_RETURN_IF_ERROR(device->TransferFromOutfeed(literal.get())); return literal; diff --git a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.h b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.h index decd4f852a6..6ae35903399 100644 --- a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.h +++ b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.h @@ -65,11 +65,11 @@ class TpuDevice : public PjRtDevice { absl::string_view device_kind() const override { return device_kind_; } - Status TransferToInfeed(const LiteralSlice& literal) const override { + Status TransferToInfeed(const LiteralSlice& literal) override { return Unimplemented("Infeed not yet implemented via this API"); } - Status TransferFromOutfeed(MutableBorrowingLiteral literal) const override { + Status TransferFromOutfeed(MutableBorrowingLiteral literal) override { return Unimplemented("Outfeed not yet implemented via this API"); } diff --git a/tensorflow/compiler/xla/python/xla.cc b/tensorflow/compiler/xla/python/xla.cc index d06374df39a..0c04ce9ae61 100644 --- a/tensorflow/compiler/xla/python/xla.cc +++ b/tensorflow/compiler/xla/python/xla.cc @@ -142,14 +142,13 @@ PYBIND11_MODULE(xla_extension, m) { [](const ClientAndPtr& device) { return device.client; }) .def("__str__", &PjRtDevice::DebugString) .def("transfer_to_infeed", - [](const PjRtDevice& device, const LiteralSlice& literal) { + [](PjRtDevice& device, const LiteralSlice& literal) { GlobalPyRefManager()->CollectGarbage(); py::gil_scoped_release gil_release; return device.TransferToInfeed(literal); }) .def("transfer_from_outfeed", - [](const PjRtDevice& device, - const Shape& shape) -> StatusOr { + [](PjRtDevice& device, const Shape& shape) -> StatusOr { GlobalPyRefManager()->CollectGarbage(); std::shared_ptr literal; {