Refactor PJRT.
- Add TransferToInfeed and TransferFromOutfeed to PjRtDevice's methods. PiperOrigin-RevId: 338105874 Change-Id: I629b2efa27394bc99b26371c3de779b1104eea4f
This commit is contained in:
parent
86ea5401c4
commit
c98ac56bc2
@ -851,6 +851,20 @@ void PjRtClient::MakeCrossHostReceiveBuffers(
|
||||
EnqueueCrossHostReceive(std::move(buffers), std::move(notifier));
|
||||
}
|
||||
|
||||
// Transfer the given literal to the infeed queue of the given local device.
|
||||
Status PjRtDevice::TransferToInfeed(const LiteralSlice& literal) const {
|
||||
// Only support infeed to local device.
|
||||
TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device, GetLocalDeviceState());
|
||||
return local_device->client()->TransferToInfeedLocal(
|
||||
literal, local_device->device_ordinal());
|
||||
}
|
||||
|
||||
StatusOr<Literal> PjRtDevice::TransferFromOutfeed(const Shape& shape) const {
|
||||
TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device, GetLocalDeviceState());
|
||||
return local_device->client()->TransferFromOutfeedLocal(
|
||||
shape, local_device->device_ordinal());
|
||||
}
|
||||
|
||||
PjRtBuffer::PjRtBuffer(Shape on_host_shape, Shape on_device_shape,
|
||||
std::shared_ptr<TrackedDeviceBuffer> device_buffer,
|
||||
PjRtClient* client, PjRtDevice* device)
|
||||
|
@ -94,6 +94,13 @@ class PjRtDevice {
|
||||
|
||||
PjRtClient* client() const { return client_; }
|
||||
|
||||
// Transfer the given literal to the infeed queue of the given localdevice.
|
||||
virtual Status TransferToInfeed(const LiteralSlice& literal) const;
|
||||
|
||||
// Transfer and return a value of the given shape from the outfeed of the
|
||||
// given device.
|
||||
virtual StatusOr<Literal> TransferFromOutfeed(const Shape& shape) const;
|
||||
|
||||
private:
|
||||
friend class PjRtClient;
|
||||
|
||||
|
@ -479,10 +479,7 @@ PYBIND11_MODULE(xla_extension, m) {
|
||||
[](const PjRtDevice& device, const LiteralSlice& literal) {
|
||||
GlobalPyRefManager()->CollectGarbage();
|
||||
py::gil_scoped_release gil_release;
|
||||
TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device,
|
||||
device.GetLocalDeviceState());
|
||||
return local_device->client()->TransferToInfeedLocal(
|
||||
literal, local_device->device_ordinal());
|
||||
return device.TransferToInfeed(literal);
|
||||
})
|
||||
.def("transfer_from_outfeed",
|
||||
[](const PjRtDevice& device,
|
||||
@ -491,8 +488,6 @@ PYBIND11_MODULE(xla_extension, m) {
|
||||
std::shared_ptr<Literal> literal_shared;
|
||||
{
|
||||
py::gil_scoped_release gil_release;
|
||||
TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device,
|
||||
device.GetLocalDeviceState());
|
||||
Shape shape_with_layout = shape;
|
||||
ShapeUtil::ForEachMutableSubshape(
|
||||
&shape_with_layout, [](Shape* subshape, const ShapeIndex&) {
|
||||
@ -500,10 +495,8 @@ PYBIND11_MODULE(xla_extension, m) {
|
||||
LayoutUtil::SetToDefaultLayout(subshape);
|
||||
}
|
||||
});
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
Literal literal,
|
||||
local_device->client()->TransferFromOutfeedLocal(
|
||||
shape_with_layout, local_device->device_ordinal()));
|
||||
TF_ASSIGN_OR_RETURN(Literal literal, device.TransferFromOutfeed(
|
||||
shape_with_layout));
|
||||
|
||||
literal_shared = std::make_shared<Literal>(std::move(literal));
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user