Allow PJRT infeed/outfeed methods to be non-const.

PiperOrigin-RevId: 355321939
Change-Id: Ifc081fd55121d328631ce4fce87af1f7a9452b2e
This commit is contained in:
Qiao Zhang 2021-02-02 21:53:22 -08:00 committed by TensorFlower Gardener
parent 1df8d8fd63
commit ddb089b72f
6 changed files with 13 additions and 15 deletions

View File

@ -91,10 +91,10 @@ class PjRtDevice {
virtual std::string DebugString() const = 0;
// Transfer the given literal to the infeed queue.
virtual Status TransferToInfeed(const LiteralSlice& literal) const = 0;
virtual Status TransferToInfeed(const LiteralSlice& literal) = 0;
// Transfer and return a value of the given shape from the outfeed queue.
virtual Status TransferFromOutfeed(MutableBorrowingLiteral literal) const = 0;
virtual Status TransferFromOutfeed(MutableBorrowingLiteral literal) = 0;
};
// Forward declaration.

View File

@ -961,8 +961,7 @@ PjRtStreamExecutorClient::CreateViewOfDeviceBuffer(
}
// Transfer the given literal to the infeed queue of the given local device.
Status PjRtStreamExecutorDevice::TransferToInfeed(
const LiteralSlice& literal) const {
Status PjRtStreamExecutorDevice::TransferToInfeed(const LiteralSlice& literal) {
// Only support infeed to local device.
TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device, GetLocalDeviceState());
return local_device->client()->TransferToInfeedLocal(
@ -970,7 +969,7 @@ Status PjRtStreamExecutorDevice::TransferToInfeed(
}
Status PjRtStreamExecutorDevice::TransferFromOutfeed(
MutableBorrowingLiteral literal) const {
MutableBorrowingLiteral literal) {
TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device, GetLocalDeviceState());
return local_device->client()->TransferFromOutfeedLocal(
local_device->device_ordinal(), literal);

View File

@ -105,9 +105,9 @@ class PjRtStreamExecutorDevice : public PjRtDevice {
std::string DebugString() const override;
Status TransferToInfeed(const LiteralSlice& literal) const override;
Status TransferToInfeed(const LiteralSlice& literal) override;
Status TransferFromOutfeed(MutableBorrowingLiteral literal) const override;
Status TransferFromOutfeed(MutableBorrowingLiteral literal) override;
private:
const int id_;

View File

@ -188,8 +188,8 @@ class OutfeedReceiverImpl {
Status SendShutdownOutfeedHeader(int device_idx);
// Receives a raw Literal from a device outfeed.
StatusOr<std::unique_ptr<Literal>> ReceiveRawFromOutfeed(
const PjRtDevice* device, const Shape& shape);
StatusOr<std::unique_ptr<Literal>> ReceiveRawFromOutfeed(PjRtDevice* device,
const Shape& shape);
// Enqueues received data in the callbaback queue.
void EnqueueReceivedData(std::unique_ptr<OutfeedData> received)
@ -340,7 +340,7 @@ void OutfeedReceiverImpl::EnqueueReceivedData(
}
StatusOr<std::unique_ptr<Literal>> OutfeedReceiverImpl::ReceiveRawFromOutfeed(
const PjRtDevice* device, const Shape& shape) {
PjRtDevice* device, const Shape& shape) {
auto literal = std::make_unique<Literal>(shape);
TF_RETURN_IF_ERROR(device->TransferFromOutfeed(literal.get()));
return literal;

View File

@ -65,11 +65,11 @@ class TpuDevice : public PjRtDevice {
absl::string_view device_kind() const override { return device_kind_; }
Status TransferToInfeed(const LiteralSlice& literal) const override {
Status TransferToInfeed(const LiteralSlice& literal) override {
return Unimplemented("Infeed not yet implemented via this API");
}
Status TransferFromOutfeed(MutableBorrowingLiteral literal) const override {
Status TransferFromOutfeed(MutableBorrowingLiteral literal) override {
return Unimplemented("Outfeed not yet implemented via this API");
}

View File

@ -142,14 +142,13 @@ PYBIND11_MODULE(xla_extension, m) {
[](const ClientAndPtr<PjRtDevice>& device) { return device.client; })
.def("__str__", &PjRtDevice::DebugString)
.def("transfer_to_infeed",
[](const PjRtDevice& device, const LiteralSlice& literal) {
[](PjRtDevice& device, const LiteralSlice& literal) {
GlobalPyRefManager()->CollectGarbage();
py::gil_scoped_release gil_release;
return device.TransferToInfeed(literal);
})
.def("transfer_from_outfeed",
[](const PjRtDevice& device,
const Shape& shape) -> StatusOr<py::object> {
[](PjRtDevice& device, const Shape& shape) -> StatusOr<py::object> {
GlobalPyRefManager()->CollectGarbage();
std::shared_ptr<Literal> literal;
{