Allow PJRT infeed/outfeed methods to be non-const.
PiperOrigin-RevId: 355321939 Change-Id: Ifc081fd55121d328631ce4fce87af1f7a9452b2e
This commit is contained in:
parent
1df8d8fd63
commit
ddb089b72f
@ -91,10 +91,10 @@ class PjRtDevice {
|
||||
virtual std::string DebugString() const = 0;
|
||||
|
||||
// Transfer the given literal to the infeed queue.
|
||||
virtual Status TransferToInfeed(const LiteralSlice& literal) const = 0;
|
||||
virtual Status TransferToInfeed(const LiteralSlice& literal) = 0;
|
||||
|
||||
// Transfer and return a value of the given shape from the outfeed queue.
|
||||
virtual Status TransferFromOutfeed(MutableBorrowingLiteral literal) const = 0;
|
||||
virtual Status TransferFromOutfeed(MutableBorrowingLiteral literal) = 0;
|
||||
};
|
||||
|
||||
// Forward declaration.
|
||||
|
@ -961,8 +961,7 @@ PjRtStreamExecutorClient::CreateViewOfDeviceBuffer(
|
||||
}
|
||||
|
||||
// Transfer the given literal to the infeed queue of the given local device.
|
||||
Status PjRtStreamExecutorDevice::TransferToInfeed(
|
||||
const LiteralSlice& literal) const {
|
||||
Status PjRtStreamExecutorDevice::TransferToInfeed(const LiteralSlice& literal) {
|
||||
// Only support infeed to local device.
|
||||
TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device, GetLocalDeviceState());
|
||||
return local_device->client()->TransferToInfeedLocal(
|
||||
@ -970,7 +969,7 @@ Status PjRtStreamExecutorDevice::TransferToInfeed(
|
||||
}
|
||||
|
||||
Status PjRtStreamExecutorDevice::TransferFromOutfeed(
|
||||
MutableBorrowingLiteral literal) const {
|
||||
MutableBorrowingLiteral literal) {
|
||||
TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device, GetLocalDeviceState());
|
||||
return local_device->client()->TransferFromOutfeedLocal(
|
||||
local_device->device_ordinal(), literal);
|
||||
|
@ -105,9 +105,9 @@ class PjRtStreamExecutorDevice : public PjRtDevice {
|
||||
|
||||
std::string DebugString() const override;
|
||||
|
||||
Status TransferToInfeed(const LiteralSlice& literal) const override;
|
||||
Status TransferToInfeed(const LiteralSlice& literal) override;
|
||||
|
||||
Status TransferFromOutfeed(MutableBorrowingLiteral literal) const override;
|
||||
Status TransferFromOutfeed(MutableBorrowingLiteral literal) override;
|
||||
|
||||
private:
|
||||
const int id_;
|
||||
|
@ -188,8 +188,8 @@ class OutfeedReceiverImpl {
|
||||
Status SendShutdownOutfeedHeader(int device_idx);
|
||||
|
||||
// Receives a raw Literal from a device outfeed.
|
||||
StatusOr<std::unique_ptr<Literal>> ReceiveRawFromOutfeed(
|
||||
const PjRtDevice* device, const Shape& shape);
|
||||
StatusOr<std::unique_ptr<Literal>> ReceiveRawFromOutfeed(PjRtDevice* device,
|
||||
const Shape& shape);
|
||||
|
||||
// Enqueues received data in the callbaback queue.
|
||||
void EnqueueReceivedData(std::unique_ptr<OutfeedData> received)
|
||||
@ -340,7 +340,7 @@ void OutfeedReceiverImpl::EnqueueReceivedData(
|
||||
}
|
||||
|
||||
StatusOr<std::unique_ptr<Literal>> OutfeedReceiverImpl::ReceiveRawFromOutfeed(
|
||||
const PjRtDevice* device, const Shape& shape) {
|
||||
PjRtDevice* device, const Shape& shape) {
|
||||
auto literal = std::make_unique<Literal>(shape);
|
||||
TF_RETURN_IF_ERROR(device->TransferFromOutfeed(literal.get()));
|
||||
return literal;
|
||||
|
@ -65,11 +65,11 @@ class TpuDevice : public PjRtDevice {
|
||||
|
||||
absl::string_view device_kind() const override { return device_kind_; }
|
||||
|
||||
Status TransferToInfeed(const LiteralSlice& literal) const override {
|
||||
Status TransferToInfeed(const LiteralSlice& literal) override {
|
||||
return Unimplemented("Infeed not yet implemented via this API");
|
||||
}
|
||||
|
||||
Status TransferFromOutfeed(MutableBorrowingLiteral literal) const override {
|
||||
Status TransferFromOutfeed(MutableBorrowingLiteral literal) override {
|
||||
return Unimplemented("Outfeed not yet implemented via this API");
|
||||
}
|
||||
|
||||
|
@ -142,14 +142,13 @@ PYBIND11_MODULE(xla_extension, m) {
|
||||
[](const ClientAndPtr<PjRtDevice>& device) { return device.client; })
|
||||
.def("__str__", &PjRtDevice::DebugString)
|
||||
.def("transfer_to_infeed",
|
||||
[](const PjRtDevice& device, const LiteralSlice& literal) {
|
||||
[](PjRtDevice& device, const LiteralSlice& literal) {
|
||||
GlobalPyRefManager()->CollectGarbage();
|
||||
py::gil_scoped_release gil_release;
|
||||
return device.TransferToInfeed(literal);
|
||||
})
|
||||
.def("transfer_from_outfeed",
|
||||
[](const PjRtDevice& device,
|
||||
const Shape& shape) -> StatusOr<py::object> {
|
||||
[](PjRtDevice& device, const Shape& shape) -> StatusOr<py::object> {
|
||||
GlobalPyRefManager()->CollectGarbage();
|
||||
std::shared_ptr<Literal> literal;
|
||||
{
|
||||
|
Loading…
x
Reference in New Issue
Block a user