Refactor PJRT.
- Add TransferToInfeed and TransferFromOutfeed to PjRtDevice's methods. PiperOrigin-RevId: 338105874 Change-Id: I629b2efa27394bc99b26371c3de779b1104eea4f
This commit is contained in:
parent
86ea5401c4
commit
c98ac56bc2
tensorflow/compiler/xla
@ -851,6 +851,20 @@ void PjRtClient::MakeCrossHostReceiveBuffers(
|
|||||||
EnqueueCrossHostReceive(std::move(buffers), std::move(notifier));
|
EnqueueCrossHostReceive(std::move(buffers), std::move(notifier));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Transfer the given literal to the infeed queue of the given local device.
|
||||||
|
Status PjRtDevice::TransferToInfeed(const LiteralSlice& literal) const {
|
||||||
|
// Only support infeed to local device.
|
||||||
|
TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device, GetLocalDeviceState());
|
||||||
|
return local_device->client()->TransferToInfeedLocal(
|
||||||
|
literal, local_device->device_ordinal());
|
||||||
|
}
|
||||||
|
|
||||||
|
StatusOr<Literal> PjRtDevice::TransferFromOutfeed(const Shape& shape) const {
|
||||||
|
TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device, GetLocalDeviceState());
|
||||||
|
return local_device->client()->TransferFromOutfeedLocal(
|
||||||
|
shape, local_device->device_ordinal());
|
||||||
|
}
|
||||||
|
|
||||||
PjRtBuffer::PjRtBuffer(Shape on_host_shape, Shape on_device_shape,
|
PjRtBuffer::PjRtBuffer(Shape on_host_shape, Shape on_device_shape,
|
||||||
std::shared_ptr<TrackedDeviceBuffer> device_buffer,
|
std::shared_ptr<TrackedDeviceBuffer> device_buffer,
|
||||||
PjRtClient* client, PjRtDevice* device)
|
PjRtClient* client, PjRtDevice* device)
|
||||||
|
@ -94,6 +94,13 @@ class PjRtDevice {
|
|||||||
|
|
||||||
PjRtClient* client() const { return client_; }
|
PjRtClient* client() const { return client_; }
|
||||||
|
|
||||||
|
// Transfer the given literal to the infeed queue of the given localdevice.
|
||||||
|
virtual Status TransferToInfeed(const LiteralSlice& literal) const;
|
||||||
|
|
||||||
|
// Transfer and return a value of the given shape from the outfeed of the
|
||||||
|
// given device.
|
||||||
|
virtual StatusOr<Literal> TransferFromOutfeed(const Shape& shape) const;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
friend class PjRtClient;
|
friend class PjRtClient;
|
||||||
|
|
||||||
|
@ -479,10 +479,7 @@ PYBIND11_MODULE(xla_extension, m) {
|
|||||||
[](const PjRtDevice& device, const LiteralSlice& literal) {
|
[](const PjRtDevice& device, const LiteralSlice& literal) {
|
||||||
GlobalPyRefManager()->CollectGarbage();
|
GlobalPyRefManager()->CollectGarbage();
|
||||||
py::gil_scoped_release gil_release;
|
py::gil_scoped_release gil_release;
|
||||||
TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device,
|
return device.TransferToInfeed(literal);
|
||||||
device.GetLocalDeviceState());
|
|
||||||
return local_device->client()->TransferToInfeedLocal(
|
|
||||||
literal, local_device->device_ordinal());
|
|
||||||
})
|
})
|
||||||
.def("transfer_from_outfeed",
|
.def("transfer_from_outfeed",
|
||||||
[](const PjRtDevice& device,
|
[](const PjRtDevice& device,
|
||||||
@ -491,8 +488,6 @@ PYBIND11_MODULE(xla_extension, m) {
|
|||||||
std::shared_ptr<Literal> literal_shared;
|
std::shared_ptr<Literal> literal_shared;
|
||||||
{
|
{
|
||||||
py::gil_scoped_release gil_release;
|
py::gil_scoped_release gil_release;
|
||||||
TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device,
|
|
||||||
device.GetLocalDeviceState());
|
|
||||||
Shape shape_with_layout = shape;
|
Shape shape_with_layout = shape;
|
||||||
ShapeUtil::ForEachMutableSubshape(
|
ShapeUtil::ForEachMutableSubshape(
|
||||||
&shape_with_layout, [](Shape* subshape, const ShapeIndex&) {
|
&shape_with_layout, [](Shape* subshape, const ShapeIndex&) {
|
||||||
@ -500,10 +495,8 @@ PYBIND11_MODULE(xla_extension, m) {
|
|||||||
LayoutUtil::SetToDefaultLayout(subshape);
|
LayoutUtil::SetToDefaultLayout(subshape);
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
TF_ASSIGN_OR_RETURN(
|
TF_ASSIGN_OR_RETURN(Literal literal, device.TransferFromOutfeed(
|
||||||
Literal literal,
|
shape_with_layout));
|
||||||
local_device->client()->TransferFromOutfeedLocal(
|
|
||||||
shape_with_layout, local_device->device_ordinal()));
|
|
||||||
|
|
||||||
literal_shared = std::make_shared<Literal>(std::move(literal));
|
literal_shared = std::make_shared<Literal>(std::move(literal));
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user