[XLA] Change LocalClient::TransferFromOutfeedLocal to write into a caller-provided literal rather than allocating its own literal.

PiperOrigin-RevId: 353904746
Change-Id: Ia9868bee8c90c5e2a99fbb787eede16dbe75d2d1
This commit is contained in:
Peter Hawkins 2021-01-26 11:04:36 -08:00 committed by TensorFlower Gardener
parent 8e985babf8
commit f5d89fe581
12 changed files with 31 additions and 35 deletions

View File

@ -433,14 +433,12 @@ Status LocalClient::TransferToInfeedLocal(const LiteralSlice& literal,
literal);
}
StatusOr<Literal> LocalClient::TransferFromOutfeedLocal(const Shape& shape,
int device_ordinal) {
Status LocalClient::TransferFromOutfeedLocal(int device_ordinal,
MutableBorrowingLiteral literal) {
TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor,
backend().stream_executor(device_ordinal));
auto literal = Literal::CreateFromShape(shape);
TF_RETURN_IF_ERROR(backend().transfer_manager()->TransferLiteralFromOutfeed(
executor, &literal));
return std::move(literal);
return backend().transfer_manager()->TransferLiteralFromOutfeed(executor,
literal);
}
StatusOr<int> LocalClient::ReplicaNumberToDeviceOrdinal(int replica_number) {

View File

@ -172,13 +172,13 @@ class LocalClient : public Client {
// Client::TransferToInfeed.
Status TransferToInfeedLocal(const LiteralSlice& literal, int device_ordinal);
// Transfer and return a value of the given shape from the outfeed of the
// given device.
// Transfer and return a value from the outfeed of the given device. The
// shape of the object to transfer is determined by `literal`'s shape.
// TODO(b/69670845): Remove the 'Local' from the name when LocalClient does
// not inherit from Client and there is no possibility of confusion with
// Client::TransferFromOutfeed.
StatusOr<Literal> TransferFromOutfeedLocal(const Shape& shape,
int device_ordinal);
Status TransferFromOutfeedLocal(int device_ordinal,
MutableBorrowingLiteral literal);
// Returns the device ordinal that corresponds to the given replica number.
//

View File

@ -132,6 +132,7 @@ cc_library(
hdrs = ["pjrt_client.h"],
visibility = ["//tensorflow/compiler/xla:friends"],
deps = [
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status",
"//tensorflow/compiler/xla:statusor",

View File

@ -27,6 +27,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/executable_build_options.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/layout.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/hlo_cost_analysis.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/shape.h"
@ -93,7 +94,7 @@ class PjRtDevice {
virtual Status TransferToInfeed(const LiteralSlice& literal) const = 0;
// Transfer and return a value of the given shape from the outfeed queue.
virtual StatusOr<Literal> TransferFromOutfeed(const Shape& shape) const = 0;
virtual Status TransferFromOutfeed(MutableBorrowingLiteral literal) const = 0;
};
// Forward declaration.

View File

@ -910,11 +910,11 @@ Status PjRtStreamExecutorDevice::TransferToInfeed(
literal, local_device->device_ordinal());
}
StatusOr<Literal> PjRtStreamExecutorDevice::TransferFromOutfeed(
const Shape& shape) const {
Status PjRtStreamExecutorDevice::TransferFromOutfeed(
MutableBorrowingLiteral literal) const {
TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device, GetLocalDeviceState());
return local_device->client()->TransferFromOutfeedLocal(
shape, local_device->device_ordinal());
local_device->device_ordinal(), literal);
}
StatusOr<PjRtDevice*> PjRtStreamExecutorClient::LookupAddressableDevice(

View File

@ -32,6 +32,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/layout.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/pjrt/local_device_state.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/compiler/xla/pjrt/tracked_device_buffer.h"
@ -106,7 +107,7 @@ class PjRtStreamExecutorDevice : public PjRtDevice {
Status TransferToInfeed(const LiteralSlice& literal) const override;
StatusOr<Literal> TransferFromOutfeed(const Shape& shape) const override;
Status TransferFromOutfeed(MutableBorrowingLiteral literal) const override;
private:
const int id_;

View File

@ -341,11 +341,9 @@ void OutfeedReceiverImpl::EnqueueReceivedData(
StatusOr<std::unique_ptr<Literal>> OutfeedReceiverImpl::ReceiveRawFromOutfeed(
const PjRtDevice* device, const Shape& shape) {
std::shared_ptr<Literal> literal_shared;
TF_ASSIGN_OR_RETURN(Literal literal, device->TransferFromOutfeed(shape));
return absl::make_unique<Literal>(std::move(literal));
auto literal = std::make_unique<Literal>(shape);
TF_RETURN_IF_ERROR(device->TransferFromOutfeed(literal.get()));
return literal;
}
void OutfeedReceiverImpl::CallbackThreadLoop() {

View File

@ -69,7 +69,7 @@ class TpuDevice : public PjRtDevice {
return Unimplemented("Infeed not yet implemented via this API");
}
StatusOr<Literal> TransferFromOutfeed(const Shape& shape) const override {
Status TransferFromOutfeed(MutableBorrowingLiteral literal) const override {
return Unimplemented("Outfeed not yet implemented via this API");
}

View File

@ -146,7 +146,7 @@ PYBIND11_MODULE(xla_extension, m) {
[](const PjRtDevice& device,
const Shape& shape) -> StatusOr<py::object> {
GlobalPyRefManager()->CollectGarbage();
std::shared_ptr<Literal> literal_shared;
std::shared_ptr<Literal> literal;
{
py::gil_scoped_release gil_release;
Shape shape_with_layout = shape;
@ -156,12 +156,10 @@ PYBIND11_MODULE(xla_extension, m) {
LayoutUtil::SetToDefaultLayout(subshape);
}
});
TF_ASSIGN_OR_RETURN(Literal literal, device.TransferFromOutfeed(
shape_with_layout));
literal_shared = std::make_shared<Literal>(std::move(literal));
literal = std::make_shared<Literal>(shape_with_layout);
TF_RETURN_IF_ERROR(device.TransferFromOutfeed(literal.get()));
}
return LiteralToPython(std::move(literal_shared));
return LiteralToPython(std::move(literal));
});
py::class_<CpuDevice, PjRtDevice, ClientAndPtr<CpuDevice>>(m, "CpuDevice")

View File

@ -937,9 +937,9 @@ XLA_TEST_F(LocalClientExecuteTest, DISABLED_ON_INTERPRETER(InfeedOutfeedTest)) {
LiteralUtil::CreateR1<float>({-5.0, 123.0, 42.0}),
local_client_->default_device_ordinal()));
TF_ASSERT_OK_AND_ASSIGN(Literal result,
local_client_->TransferFromOutfeedLocal(
shape, local_client_->default_device_ordinal()));
Literal result(shape);
ASSERT_IS_OK(local_client_->TransferFromOutfeedLocal(
local_client_->default_device_ordinal(), &result));
LiteralTestUtil::ExpectR1Equal<float>({-4.0, 125.0, 45.0}, result);
}

View File

@ -92,9 +92,8 @@ void TestWithDeviceCount(const int device_count) {
for (int device_ordinal = 0; device_ordinal < device_count;
device_ordinal++) {
TF_ASSERT_OK_AND_ASSIGN(Literal outfeed,
client->TransferFromOutfeedLocal(
ShapeUtil::MakeShape(S32, {}), device_ordinal));
Literal outfeed(ShapeUtil::MakeShape(S32, {}));
TF_ASSERT_OK(client->TransferFromOutfeedLocal(device_ordinal, &outfeed));
EXPECT_EQ(outfeed, LiteralUtil::CreateR0<int32>(device_ordinal * 100 + 1));
}

View File

@ -282,9 +282,9 @@ StatusOr<Literal> ReplayComputation(const HloSnapshot& module,
outfeed_thread_pool.emplace(tensorflow::Env::Default(), "infeed",
/*num_threads=*/1);
auto consume_outfeed = [client, outfeed_shape] {
Literal outfeed(*outfeed_shape);
TF_CHECK_OK(
client->TransferFromOutfeedLocal(*outfeed_shape, /*device_ordinal=*/0)
.status());
client->TransferFromOutfeedLocal(/*device_ordinal=*/0, &outfeed));
VLOG(1) << "Received outfeed data of shape "
<< ShapeUtil::HumanStringWithLayout(*outfeed_shape);
};