From c98ac56bc2929d08578a08e9b1b8e16477d28b08 Mon Sep 17 00:00:00 2001 From: Qiao Zhang Date: Tue, 20 Oct 2020 12:01:17 -0700 Subject: [PATCH] Refactor PJRT. - Add TransferToInfeed and TransferFromOutfeed to PjRtDevice's methods. PiperOrigin-RevId: 338105874 Change-Id: I629b2efa27394bc99b26371c3de779b1104eea4f --- tensorflow/compiler/xla/pjrt/pjrt_client.cc | 14 ++++++++++++++ tensorflow/compiler/xla/pjrt/pjrt_client.h | 7 +++++++ tensorflow/compiler/xla/python/xla.cc | 13 +++---------- 3 files changed, 24 insertions(+), 10 deletions(-) diff --git a/tensorflow/compiler/xla/pjrt/pjrt_client.cc b/tensorflow/compiler/xla/pjrt/pjrt_client.cc index 41afcb01511..8752b6260f6 100644 --- a/tensorflow/compiler/xla/pjrt/pjrt_client.cc +++ b/tensorflow/compiler/xla/pjrt/pjrt_client.cc @@ -851,6 +851,20 @@ void PjRtClient::MakeCrossHostReceiveBuffers( EnqueueCrossHostReceive(std::move(buffers), std::move(notifier)); } +// Transfer the given literal to the infeed queue of the given local device. +Status PjRtDevice::TransferToInfeed(const LiteralSlice& literal) const { + // Only support infeed to local device. + TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device, GetLocalDeviceState()); + return local_device->client()->TransferToInfeedLocal( + literal, local_device->device_ordinal()); +} + +StatusOr PjRtDevice::TransferFromOutfeed(const Shape& shape) const { + TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device, GetLocalDeviceState()); + return local_device->client()->TransferFromOutfeedLocal( + shape, local_device->device_ordinal()); +} + PjRtBuffer::PjRtBuffer(Shape on_host_shape, Shape on_device_shape, std::shared_ptr device_buffer, PjRtClient* client, PjRtDevice* device) diff --git a/tensorflow/compiler/xla/pjrt/pjrt_client.h b/tensorflow/compiler/xla/pjrt/pjrt_client.h index c10470f7d60..3331bf890cc 100644 --- a/tensorflow/compiler/xla/pjrt/pjrt_client.h +++ b/tensorflow/compiler/xla/pjrt/pjrt_client.h @@ -94,6 +94,13 @@ class PjRtDevice { PjRtClient* client() const { return client_; } + // Transfer the given literal to the infeed queue of the given localdevice. + virtual Status TransferToInfeed(const LiteralSlice& literal) const; + + // Transfer and return a value of the given shape from the outfeed of the + // given device. + virtual StatusOr TransferFromOutfeed(const Shape& shape) const; + private: friend class PjRtClient; diff --git a/tensorflow/compiler/xla/python/xla.cc b/tensorflow/compiler/xla/python/xla.cc index b0948fab2b7..2101191be86 100644 --- a/tensorflow/compiler/xla/python/xla.cc +++ b/tensorflow/compiler/xla/python/xla.cc @@ -479,10 +479,7 @@ PYBIND11_MODULE(xla_extension, m) { [](const PjRtDevice& device, const LiteralSlice& literal) { GlobalPyRefManager()->CollectGarbage(); py::gil_scoped_release gil_release; - TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device, - device.GetLocalDeviceState()); - return local_device->client()->TransferToInfeedLocal( - literal, local_device->device_ordinal()); + return device.TransferToInfeed(literal); }) .def("transfer_from_outfeed", [](const PjRtDevice& device, @@ -491,8 +488,6 @@ PYBIND11_MODULE(xla_extension, m) { std::shared_ptr literal_shared; { py::gil_scoped_release gil_release; - TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device, - device.GetLocalDeviceState()); Shape shape_with_layout = shape; ShapeUtil::ForEachMutableSubshape( &shape_with_layout, [](Shape* subshape, const ShapeIndex&) { @@ -500,10 +495,8 @@ PYBIND11_MODULE(xla_extension, m) { LayoutUtil::SetToDefaultLayout(subshape); } }); - TF_ASSIGN_OR_RETURN( - Literal literal, - local_device->client()->TransferFromOutfeedLocal( - shape_with_layout, local_device->device_ordinal())); + TF_ASSIGN_OR_RETURN(Literal literal, device.TransferFromOutfeed( + shape_with_layout)); literal_shared = std::make_shared(std::move(literal)); }