[XLA] Change LocalClient::TransferFromOutfeedLocal to write into a caller-provided literal rather than allocating its own literal.

PiperOrigin-RevId: 353904746
Change-Id: Ia9868bee8c90c5e2a99fbb787eede16dbe75d2d1
Peter Hawkins 2021-01-26 11:04:36 -08:00 committed by TensorFlower Gardener
parent 8e985babf8
commit f5d89fe581
12 changed files with 31 additions and 35 deletions
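
A minimal sketch of the new calling convention (variable names and the example shape are illustrative; the pattern follows the updated tests in this diff): the caller now allocates a Literal of the expected outfeed shape and passes it in, instead of receiving a freshly allocated Literal back.

    // Caller-provided storage for the expected outfeed shape (example shape).
    Shape shape = ShapeUtil::MakeShape(F32, {3});
    Literal result(shape);
    // The client writes into the caller's literal; a Literal* converts
    // implicitly to MutableBorrowingLiteral. Previously this call returned
    // StatusOr<Literal> and allocated the literal itself.
    TF_CHECK_OK(local_client->TransferFromOutfeedLocal(
        /*device_ordinal=*/0, &result));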

@@ -433,14 +433,12 @@ Status LocalClient::TransferToInfeedLocal(const LiteralSlice& literal,
                                                                literal);
 }
 
-StatusOr<Literal> LocalClient::TransferFromOutfeedLocal(const Shape& shape,
-                                                        int device_ordinal) {
+Status LocalClient::TransferFromOutfeedLocal(int device_ordinal,
+                                             MutableBorrowingLiteral literal) {
   TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor,
                       backend().stream_executor(device_ordinal));
-  auto literal = Literal::CreateFromShape(shape);
-  TF_RETURN_IF_ERROR(backend().transfer_manager()->TransferLiteralFromOutfeed(
-      executor, &literal));
-  return std::move(literal);
+  return backend().transfer_manager()->TransferLiteralFromOutfeed(executor,
+                                                                  literal);
 }
 
 StatusOr<int> LocalClient::ReplicaNumberToDeviceOrdinal(int replica_number) {

@ -172,13 +172,13 @@ class LocalClient : public Client {
// Client::TransferToInfeed. // Client::TransferToInfeed.
Status TransferToInfeedLocal(const LiteralSlice& literal, int device_ordinal); Status TransferToInfeedLocal(const LiteralSlice& literal, int device_ordinal);
// Transfer and return a value of the given shape from the outfeed of the // Transfer and return a value from the outfeed of the given device. The
// given device. // shape of the object to transfer is determined by `literal`'s shape.
// TODO(b/69670845): Remove the 'Local' from the name when LocalClient does // TODO(b/69670845): Remove the 'Local' from the name when LocalClient does
// not inherit from Client and there is no possibility of confusion with // not inherit from Client and there is no possibility of confusion with
// Client::TransferFromOutfeed. // Client::TransferFromOutfeed.
StatusOr<Literal> TransferFromOutfeedLocal(const Shape& shape, Status TransferFromOutfeedLocal(int device_ordinal,
int device_ordinal); MutableBorrowingLiteral literal);
// Returns the device ordinal that corresponds to the given replica number. // Returns the device ordinal that corresponds to the given replica number.
// //

@@ -132,6 +132,7 @@ cc_library(
     hdrs = ["pjrt_client.h"],
     visibility = ["//tensorflow/compiler/xla:friends"],
     deps = [
+        "//tensorflow/compiler/xla:literal",
         "//tensorflow/compiler/xla:shape_util",
         "//tensorflow/compiler/xla:status",
         "//tensorflow/compiler/xla:statusor",

@@ -27,6 +27,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/client/executable_build_options.h"
 #include "tensorflow/compiler/xla/client/xla_computation.h"
 #include "tensorflow/compiler/xla/layout.h"
+#include "tensorflow/compiler/xla/literal.h"
 #include "tensorflow/compiler/xla/service/hlo_cost_analysis.h"
 #include "tensorflow/compiler/xla/service/hlo_module.h"
 #include "tensorflow/compiler/xla/shape.h"
@@ -93,7 +94,7 @@ class PjRtDevice {
   virtual Status TransferToInfeed(const LiteralSlice& literal) const = 0;
 
   // Transfer and return a value of the given shape from the outfeed queue.
-  virtual StatusOr<Literal> TransferFromOutfeed(const Shape& shape) const = 0;
+  virtual Status TransferFromOutfeed(MutableBorrowingLiteral literal) const = 0;
 };
 
 // Forward declaration.

@@ -910,11 +910,11 @@ Status PjRtStreamExecutorDevice::TransferToInfeed(
       literal, local_device->device_ordinal());
 }
 
-StatusOr<Literal> PjRtStreamExecutorDevice::TransferFromOutfeed(
-    const Shape& shape) const {
+Status PjRtStreamExecutorDevice::TransferFromOutfeed(
+    MutableBorrowingLiteral literal) const {
   TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device, GetLocalDeviceState());
   return local_device->client()->TransferFromOutfeedLocal(
-      shape, local_device->device_ordinal());
+      local_device->device_ordinal(), literal);
 }
 
 StatusOr<PjRtDevice*> PjRtStreamExecutorClient::LookupAddressableDevice(

@@ -32,6 +32,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/client/local_client.h"
 #include "tensorflow/compiler/xla/client/xla_computation.h"
 #include "tensorflow/compiler/xla/layout.h"
+#include "tensorflow/compiler/xla/literal.h"
 #include "tensorflow/compiler/xla/pjrt/local_device_state.h"
 #include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
 #include "tensorflow/compiler/xla/pjrt/tracked_device_buffer.h"
@@ -106,7 +107,7 @@ class PjRtStreamExecutorDevice : public PjRtDevice {
   Status TransferToInfeed(const LiteralSlice& literal) const override;
 
-  StatusOr<Literal> TransferFromOutfeed(const Shape& shape) const override;
+  Status TransferFromOutfeed(MutableBorrowingLiteral literal) const override;
 
  private:
   const int id_;

@@ -341,11 +341,9 @@ void OutfeedReceiverImpl::EnqueueReceivedData(
 StatusOr<std::unique_ptr<Literal>> OutfeedReceiverImpl::ReceiveRawFromOutfeed(
     const PjRtDevice* device, const Shape& shape) {
-  std::shared_ptr<Literal> literal_shared;
-
-  TF_ASSIGN_OR_RETURN(Literal literal, device->TransferFromOutfeed(shape));
-
-  return absl::make_unique<Literal>(std::move(literal));
+  auto literal = std::make_unique<Literal>(shape);
+  TF_RETURN_IF_ERROR(device->TransferFromOutfeed(literal.get()));
+  return literal;
 }
 
 void OutfeedReceiverImpl::CallbackThreadLoop() {

@@ -69,7 +69,7 @@ class TpuDevice : public PjRtDevice {
     return Unimplemented("Infeed not yet implemented via this API");
   }
 
-  StatusOr<Literal> TransferFromOutfeed(const Shape& shape) const override {
+  Status TransferFromOutfeed(MutableBorrowingLiteral literal) const override {
     return Unimplemented("Outfeed not yet implemented via this API");
   }

@@ -146,7 +146,7 @@ PYBIND11_MODULE(xla_extension, m) {
           [](const PjRtDevice& device,
              const Shape& shape) -> StatusOr<py::object> {
             GlobalPyRefManager()->CollectGarbage();
-            std::shared_ptr<Literal> literal_shared;
+            std::shared_ptr<Literal> literal;
             {
               py::gil_scoped_release gil_release;
               Shape shape_with_layout = shape;
@@ -156,12 +156,10 @@ PYBIND11_MODULE(xla_extension, m) {
                     LayoutUtil::SetToDefaultLayout(subshape);
                   }
                 });
-              TF_ASSIGN_OR_RETURN(Literal literal, device.TransferFromOutfeed(
-                                                       shape_with_layout));
-              literal_shared = std::make_shared<Literal>(std::move(literal));
+              literal = std::make_shared<Literal>(shape_with_layout);
+              TF_RETURN_IF_ERROR(device.TransferFromOutfeed(literal.get()));
             }
-            return LiteralToPython(std::move(literal_shared));
+            return LiteralToPython(std::move(literal));
           });
 
   py::class_<CpuDevice, PjRtDevice, ClientAndPtr<CpuDevice>>(m, "CpuDevice")

@@ -937,9 +937,9 @@ XLA_TEST_F(LocalClientExecuteTest, DISABLED_ON_INTERPRETER(InfeedOutfeedTest)) {
       LiteralUtil::CreateR1<float>({-5.0, 123.0, 42.0}),
       local_client_->default_device_ordinal()));
 
-  TF_ASSERT_OK_AND_ASSIGN(Literal result,
-                          local_client_->TransferFromOutfeedLocal(
-                              shape, local_client_->default_device_ordinal()));
+  Literal result(shape);
+  ASSERT_IS_OK(local_client_->TransferFromOutfeedLocal(
+      local_client_->default_device_ordinal(), &result));
 
   LiteralTestUtil::ExpectR1Equal<float>({-4.0, 125.0, 45.0}, result);
 }

@@ -92,9 +92,8 @@ void TestWithDeviceCount(const int device_count) {
   for (int device_ordinal = 0; device_ordinal < device_count;
        device_ordinal++) {
-    TF_ASSERT_OK_AND_ASSIGN(Literal outfeed,
-                            client->TransferFromOutfeedLocal(
-                                ShapeUtil::MakeShape(S32, {}), device_ordinal));
+    Literal outfeed(ShapeUtil::MakeShape(S32, {}));
+    TF_ASSERT_OK(client->TransferFromOutfeedLocal(device_ordinal, &outfeed));
     EXPECT_EQ(outfeed, LiteralUtil::CreateR0<int32>(device_ordinal * 100 + 1));
   }

@@ -282,9 +282,9 @@ StatusOr<Literal> ReplayComputation(const HloSnapshot& module,
     outfeed_thread_pool.emplace(tensorflow::Env::Default(), "infeed",
                                 /*num_threads=*/1);
     auto consume_outfeed = [client, outfeed_shape] {
-      TF_CHECK_OK(
-          client->TransferFromOutfeedLocal(*outfeed_shape, /*device_ordinal=*/0)
-              .status());
+      Literal outfeed(*outfeed_shape);
+      TF_CHECK_OK(
+          client->TransferFromOutfeedLocal(/*device_ordinal=*/0, &outfeed));
       VLOG(1) << "Received outfeed data of shape "
               << ShapeUtil::HumanStringWithLayout(*outfeed_shape);
     };