[XLA] Change LocalClient::TransferFromOutfeedLocal to write into a caller-provided literal rather than allocating its own literal.
PiperOrigin-RevId: 353904746
Change-Id: Ia9868bee8c90c5e2a99fbb787eede16dbe75d2d1
parent 8e985babf8
commit f5d89fe581
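A minimal sketch of the call-site change this implies (the `client`, `shape`, and device ordinal below are illustrative assumptions, not part of the diff): the caller now allocates the destination literal with the desired shape and passes it in, instead of receiving a freshly allocated StatusOr<Literal>.

    // Before: the client allocated and returned the literal.
    //   StatusOr<Literal> result =
    //       client->TransferFromOutfeedLocal(shape, /*device_ordinal=*/0);
    // After: the caller allocates the literal and the client writes into it.
    Literal result(shape);
    TF_RETURN_IF_ERROR(
        client->TransferFromOutfeedLocal(/*device_ordinal=*/0, &result));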
@@ -433,14 +433,12 @@ Status LocalClient::TransferToInfeedLocal(const LiteralSlice& literal,
                                                                literal);
 }
 
-StatusOr<Literal> LocalClient::TransferFromOutfeedLocal(const Shape& shape,
-                                                        int device_ordinal) {
+Status LocalClient::TransferFromOutfeedLocal(int device_ordinal,
+                                             MutableBorrowingLiteral literal) {
   TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor,
                       backend().stream_executor(device_ordinal));
-  auto literal = Literal::CreateFromShape(shape);
-  TF_RETURN_IF_ERROR(backend().transfer_manager()->TransferLiteralFromOutfeed(
-      executor, &literal));
-  return std::move(literal);
+  return backend().transfer_manager()->TransferLiteralFromOutfeed(executor,
+                                                                  literal);
 }
 
 StatusOr<int> LocalClient::ReplicaNumberToDeviceOrdinal(int replica_number) {
@@ -172,13 +172,13 @@ class LocalClient : public Client {
   // Client::TransferToInfeed.
   Status TransferToInfeedLocal(const LiteralSlice& literal, int device_ordinal);
 
-  // Transfer and return a value of the given shape from the outfeed of the
-  // given device.
+  // Transfer and return a value from the outfeed of the given device. The
+  // shape of the object to transfer is determined by `literal`'s shape.
   // TODO(b/69670845): Remove the 'Local' from the name when LocalClient does
   // not inherit from Client and there is no possibility of confusion with
   // Client::TransferFromOutfeed.
-  StatusOr<Literal> TransferFromOutfeedLocal(const Shape& shape,
-                                             int device_ordinal);
+  Status TransferFromOutfeedLocal(int device_ordinal,
+                                  MutableBorrowingLiteral literal);
 
   // Returns the device ordinal that corresponds to the given replica number.
   //
@@ -132,6 +132,7 @@ cc_library(
     hdrs = ["pjrt_client.h"],
     visibility = ["//tensorflow/compiler/xla:friends"],
    deps = [
+        "//tensorflow/compiler/xla:literal",
         "//tensorflow/compiler/xla:shape_util",
         "//tensorflow/compiler/xla:status",
         "//tensorflow/compiler/xla:statusor",
@@ -27,6 +27,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/client/executable_build_options.h"
 #include "tensorflow/compiler/xla/client/xla_computation.h"
 #include "tensorflow/compiler/xla/layout.h"
+#include "tensorflow/compiler/xla/literal.h"
 #include "tensorflow/compiler/xla/service/hlo_cost_analysis.h"
 #include "tensorflow/compiler/xla/service/hlo_module.h"
 #include "tensorflow/compiler/xla/shape.h"
@@ -93,7 +94,7 @@ class PjRtDevice {
   virtual Status TransferToInfeed(const LiteralSlice& literal) const = 0;
 
   // Transfer and return a value of the given shape from the outfeed queue.
-  virtual StatusOr<Literal> TransferFromOutfeed(const Shape& shape) const = 0;
+  virtual Status TransferFromOutfeed(MutableBorrowingLiteral literal) const = 0;
 };
 
 // Forward declaration.
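A hedged sketch of how a caller of the new PjRtDevice method is expected to look, assuming MutableBorrowingLiteral can be constructed from a Literal* (the `device` pointer and `expected_shape` below are illustrative, not from the diff):

    // Caller-allocated destination; TransferFromOutfeed fills it in place.
    auto literal = std::make_unique<Literal>(expected_shape);
    TF_RETURN_IF_ERROR(device->TransferFromOutfeed(literal.get()));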
@@ -910,11 +910,11 @@ Status PjRtStreamExecutorDevice::TransferToInfeed(
       literal, local_device->device_ordinal());
 }
 
-StatusOr<Literal> PjRtStreamExecutorDevice::TransferFromOutfeed(
-    const Shape& shape) const {
+Status PjRtStreamExecutorDevice::TransferFromOutfeed(
+    MutableBorrowingLiteral literal) const {
   TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device, GetLocalDeviceState());
   return local_device->client()->TransferFromOutfeedLocal(
-      shape, local_device->device_ordinal());
+      local_device->device_ordinal(), literal);
 }
 
 StatusOr<PjRtDevice*> PjRtStreamExecutorClient::LookupAddressableDevice(
@@ -32,6 +32,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/client/local_client.h"
 #include "tensorflow/compiler/xla/client/xla_computation.h"
 #include "tensorflow/compiler/xla/layout.h"
+#include "tensorflow/compiler/xla/literal.h"
 #include "tensorflow/compiler/xla/pjrt/local_device_state.h"
 #include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
 #include "tensorflow/compiler/xla/pjrt/tracked_device_buffer.h"
@@ -106,7 +107,7 @@ class PjRtStreamExecutorDevice : public PjRtDevice {
 
   Status TransferToInfeed(const LiteralSlice& literal) const override;
 
-  StatusOr<Literal> TransferFromOutfeed(const Shape& shape) const override;
+  Status TransferFromOutfeed(MutableBorrowingLiteral literal) const override;
 
  private:
   const int id_;
@@ -341,11 +341,9 @@ void OutfeedReceiverImpl::EnqueueReceivedData(
 
 StatusOr<std::unique_ptr<Literal>> OutfeedReceiverImpl::ReceiveRawFromOutfeed(
     const PjRtDevice* device, const Shape& shape) {
-  std::shared_ptr<Literal> literal_shared;
-
-  TF_ASSIGN_OR_RETURN(Literal literal, device->TransferFromOutfeed(shape));
-
-  return absl::make_unique<Literal>(std::move(literal));
+  auto literal = std::make_unique<Literal>(shape);
+  TF_RETURN_IF_ERROR(device->TransferFromOutfeed(literal.get()));
+  return literal;
 }
 
 void OutfeedReceiverImpl::CallbackThreadLoop() {
@@ -69,7 +69,7 @@ class TpuDevice : public PjRtDevice {
     return Unimplemented("Infeed not yet implemented via this API");
   }
 
-  StatusOr<Literal> TransferFromOutfeed(const Shape& shape) const override {
+  Status TransferFromOutfeed(MutableBorrowingLiteral literal) const override {
    return Unimplemented("Outfeed not yet implemented via this API");
   }
 
@@ -146,7 +146,7 @@ PYBIND11_MODULE(xla_extension, m) {
           [](const PjRtDevice& device,
              const Shape& shape) -> StatusOr<py::object> {
             GlobalPyRefManager()->CollectGarbage();
-            std::shared_ptr<Literal> literal_shared;
+            std::shared_ptr<Literal> literal;
             {
               py::gil_scoped_release gil_release;
               Shape shape_with_layout = shape;
@@ -156,12 +156,10 @@ PYBIND11_MODULE(xla_extension, m) {
                       LayoutUtil::SetToDefaultLayout(subshape);
                     }
                   });
-              TF_ASSIGN_OR_RETURN(Literal literal, device.TransferFromOutfeed(
-                                                       shape_with_layout));
-
-              literal_shared = std::make_shared<Literal>(std::move(literal));
+              literal = std::make_shared<Literal>(shape_with_layout);
+              TF_RETURN_IF_ERROR(device.TransferFromOutfeed(literal.get()));
             }
-            return LiteralToPython(std::move(literal_shared));
+            return LiteralToPython(std::move(literal));
           });
 
   py::class_<CpuDevice, PjRtDevice, ClientAndPtr<CpuDevice>>(m, "CpuDevice")
@@ -937,9 +937,9 @@ XLA_TEST_F(LocalClientExecuteTest, DISABLED_ON_INTERPRETER(InfeedOutfeedTest)) {
                   LiteralUtil::CreateR1<float>({-5.0, 123.0, 42.0}),
                   local_client_->default_device_ordinal()));
 
-  TF_ASSERT_OK_AND_ASSIGN(Literal result,
-                          local_client_->TransferFromOutfeedLocal(
-                              shape, local_client_->default_device_ordinal()));
+  Literal result(shape);
+  ASSERT_IS_OK(local_client_->TransferFromOutfeedLocal(
+      local_client_->default_device_ordinal(), &result));
 
   LiteralTestUtil::ExpectR1Equal<float>({-4.0, 125.0, 45.0}, result);
 }
@@ -92,9 +92,8 @@ void TestWithDeviceCount(const int device_count) {
 
   for (int device_ordinal = 0; device_ordinal < device_count;
        device_ordinal++) {
-    TF_ASSERT_OK_AND_ASSIGN(Literal outfeed,
-                            client->TransferFromOutfeedLocal(
-                                ShapeUtil::MakeShape(S32, {}), device_ordinal));
+    Literal outfeed(ShapeUtil::MakeShape(S32, {}));
+    TF_ASSERT_OK(client->TransferFromOutfeedLocal(device_ordinal, &outfeed));
     EXPECT_EQ(outfeed, LiteralUtil::CreateR0<int32>(device_ordinal * 100 + 1));
   }
 
@@ -282,9 +282,9 @@ StatusOr<Literal> ReplayComputation(const HloSnapshot& module,
     outfeed_thread_pool.emplace(tensorflow::Env::Default(), "infeed",
                                 /*num_threads=*/1);
     auto consume_outfeed = [client, outfeed_shape] {
+      Literal outfeed(*outfeed_shape);
       TF_CHECK_OK(
-          client->TransferFromOutfeedLocal(*outfeed_shape, /*device_ordinal=*/0)
-              .status());
+          client->TransferFromOutfeedLocal(/*device_ordinal=*/0, &outfeed));
       VLOG(1) << "Received outfeed data of shape "
               << ShapeUtil::HumanStringWithLayout(*outfeed_shape);
     };