Refactor PJRT.

- Add TransferToInfeed and TransferFromOutfeed to PjRtDevice's methods.

PiperOrigin-RevId: 338105874
Change-Id: I629b2efa27394bc99b26371c3de779b1104eea4f
This commit is contained in:
Qiao Zhang 2020-10-20 12:01:17 -07:00 committed by TensorFlower Gardener
parent 86ea5401c4
commit c98ac56bc2
3 changed files with 24 additions and 10 deletions

View File

@ -851,6 +851,20 @@ void PjRtClient::MakeCrossHostReceiveBuffers(
EnqueueCrossHostReceive(std::move(buffers), std::move(notifier));
}
// Transfer the given literal to the infeed queue of the given local device.
Status PjRtDevice::TransferToInfeed(const LiteralSlice& literal) const {
// Only support infeed to local device.
TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device, GetLocalDeviceState());
return local_device->client()->TransferToInfeedLocal(
literal, local_device->device_ordinal());
}
StatusOr<Literal> PjRtDevice::TransferFromOutfeed(const Shape& shape) const {
TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device, GetLocalDeviceState());
return local_device->client()->TransferFromOutfeedLocal(
shape, local_device->device_ordinal());
}
PjRtBuffer::PjRtBuffer(Shape on_host_shape, Shape on_device_shape,
std::shared_ptr<TrackedDeviceBuffer> device_buffer,
PjRtClient* client, PjRtDevice* device)

View File

@ -94,6 +94,13 @@ class PjRtDevice {
PjRtClient* client() const { return client_; }
// Transfer the given literal to the infeed queue of the given localdevice.
virtual Status TransferToInfeed(const LiteralSlice& literal) const;
// Transfer and return a value of the given shape from the outfeed of the
// given device.
virtual StatusOr<Literal> TransferFromOutfeed(const Shape& shape) const;
private:
friend class PjRtClient;

View File

@ -479,10 +479,7 @@ PYBIND11_MODULE(xla_extension, m) {
[](const PjRtDevice& device, const LiteralSlice& literal) {
GlobalPyRefManager()->CollectGarbage();
py::gil_scoped_release gil_release;
TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device,
device.GetLocalDeviceState());
return local_device->client()->TransferToInfeedLocal(
literal, local_device->device_ordinal());
return device.TransferToInfeed(literal);
})
.def("transfer_from_outfeed",
[](const PjRtDevice& device,
@ -491,8 +488,6 @@ PYBIND11_MODULE(xla_extension, m) {
std::shared_ptr<Literal> literal_shared;
{
py::gil_scoped_release gil_release;
TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device,
device.GetLocalDeviceState());
Shape shape_with_layout = shape;
ShapeUtil::ForEachMutableSubshape(
&shape_with_layout, [](Shape* subshape, const ShapeIndex&) {
@ -500,10 +495,8 @@ PYBIND11_MODULE(xla_extension, m) {
LayoutUtil::SetToDefaultLayout(subshape);
}
});
TF_ASSIGN_OR_RETURN(
Literal literal,
local_device->client()->TransferFromOutfeedLocal(
shape_with_layout, local_device->device_ordinal()));
TF_ASSIGN_OR_RETURN(Literal literal, device.TransferFromOutfeed(
shape_with_layout));
literal_shared = std::make_shared<Literal>(std::move(literal));
}