Refactor PJRT.

- Add TransferToInfeed and TransferFromOutfeed to PjRtDevice's methods.

PiperOrigin-RevId: 338105874
Change-Id: I629b2efa27394bc99b26371c3de779b1104eea4f
This commit is contained in:
Qiao Zhang 2020-10-20 12:01:17 -07:00 committed by TensorFlower Gardener
parent 86ea5401c4
commit c98ac56bc2
3 changed files with 24 additions and 10 deletions
tensorflow/compiler/xla

View File

@ -851,6 +851,20 @@ void PjRtClient::MakeCrossHostReceiveBuffers(
EnqueueCrossHostReceive(std::move(buffers), std::move(notifier)); EnqueueCrossHostReceive(std::move(buffers), std::move(notifier));
} }
// Transfer the given literal to the infeed queue of the given local device.
Status PjRtDevice::TransferToInfeed(const LiteralSlice& literal) const {
// Only support infeed to local device.
TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device, GetLocalDeviceState());
return local_device->client()->TransferToInfeedLocal(
literal, local_device->device_ordinal());
}
StatusOr<Literal> PjRtDevice::TransferFromOutfeed(const Shape& shape) const {
TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device, GetLocalDeviceState());
return local_device->client()->TransferFromOutfeedLocal(
shape, local_device->device_ordinal());
}
PjRtBuffer::PjRtBuffer(Shape on_host_shape, Shape on_device_shape, PjRtBuffer::PjRtBuffer(Shape on_host_shape, Shape on_device_shape,
std::shared_ptr<TrackedDeviceBuffer> device_buffer, std::shared_ptr<TrackedDeviceBuffer> device_buffer,
PjRtClient* client, PjRtDevice* device) PjRtClient* client, PjRtDevice* device)

View File

@ -94,6 +94,13 @@ class PjRtDevice {
PjRtClient* client() const { return client_; } PjRtClient* client() const { return client_; }
// Transfer the given literal to the infeed queue of the given localdevice.
virtual Status TransferToInfeed(const LiteralSlice& literal) const;
// Transfer and return a value of the given shape from the outfeed of the
// given device.
virtual StatusOr<Literal> TransferFromOutfeed(const Shape& shape) const;
private: private:
friend class PjRtClient; friend class PjRtClient;

View File

@ -479,10 +479,7 @@ PYBIND11_MODULE(xla_extension, m) {
[](const PjRtDevice& device, const LiteralSlice& literal) { [](const PjRtDevice& device, const LiteralSlice& literal) {
GlobalPyRefManager()->CollectGarbage(); GlobalPyRefManager()->CollectGarbage();
py::gil_scoped_release gil_release; py::gil_scoped_release gil_release;
TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device, return device.TransferToInfeed(literal);
device.GetLocalDeviceState());
return local_device->client()->TransferToInfeedLocal(
literal, local_device->device_ordinal());
}) })
.def("transfer_from_outfeed", .def("transfer_from_outfeed",
[](const PjRtDevice& device, [](const PjRtDevice& device,
@ -491,8 +488,6 @@ PYBIND11_MODULE(xla_extension, m) {
std::shared_ptr<Literal> literal_shared; std::shared_ptr<Literal> literal_shared;
{ {
py::gil_scoped_release gil_release; py::gil_scoped_release gil_release;
TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device,
device.GetLocalDeviceState());
Shape shape_with_layout = shape; Shape shape_with_layout = shape;
ShapeUtil::ForEachMutableSubshape( ShapeUtil::ForEachMutableSubshape(
&shape_with_layout, [](Shape* subshape, const ShapeIndex&) { &shape_with_layout, [](Shape* subshape, const ShapeIndex&) {
@ -500,10 +495,8 @@ PYBIND11_MODULE(xla_extension, m) {
LayoutUtil::SetToDefaultLayout(subshape); LayoutUtil::SetToDefaultLayout(subshape);
} }
}); });
TF_ASSIGN_OR_RETURN( TF_ASSIGN_OR_RETURN(Literal literal, device.TransferFromOutfeed(
Literal literal, shape_with_layout));
local_device->client()->TransferFromOutfeedLocal(
shape_with_layout, local_device->device_ordinal()));
literal_shared = std::make_shared<Literal>(std::move(literal)); literal_shared = std::make_shared<Literal>(std::move(literal));
} }