From f5d89fe581e38b8bd44bd96df36af412f03634e1 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Tue, 26 Jan 2021 11:04:36 -0800 Subject: [PATCH] [XLA] Change LocalClient::TransferFromOutfeedLocal to write into a caller-provided literal rather than allocating its own literal. PiperOrigin-RevId: 353904746 Change-Id: Ia9868bee8c90c5e2a99fbb787eede16dbe75d2d1 --- tensorflow/compiler/xla/client/local_client.cc | 10 ++++------ tensorflow/compiler/xla/client/local_client.h | 8 ++++---- tensorflow/compiler/xla/pjrt/BUILD | 1 + tensorflow/compiler/xla/pjrt/pjrt_client.h | 3 ++- .../compiler/xla/pjrt/pjrt_stream_executor_client.cc | 6 +++--- .../compiler/xla/pjrt/pjrt_stream_executor_client.h | 3 ++- tensorflow/compiler/xla/python/outfeed_receiver.cc | 8 +++----- .../compiler/xla/python/tpu_driver/client/tpu_client.h | 2 +- tensorflow/compiler/xla/python/xla.cc | 10 ++++------ .../compiler/xla/tests/local_client_execute_test.cc | 6 +++--- .../xla/tests/multiple_devices_on_host_test.cc | 5 ++--- tensorflow/compiler/xla/tools/replay_computation.cc | 4 ++-- 12 files changed, 31 insertions(+), 35 deletions(-) diff --git a/tensorflow/compiler/xla/client/local_client.cc b/tensorflow/compiler/xla/client/local_client.cc index 1b2a19ba2aa..8c0e8426a03 100644 --- a/tensorflow/compiler/xla/client/local_client.cc +++ b/tensorflow/compiler/xla/client/local_client.cc @@ -433,14 +433,12 @@ Status LocalClient::TransferToInfeedLocal(const LiteralSlice& literal, literal); } -StatusOr LocalClient::TransferFromOutfeedLocal(const Shape& shape, - int device_ordinal) { +Status LocalClient::TransferFromOutfeedLocal(int device_ordinal, + MutableBorrowingLiteral literal) { TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor, backend().stream_executor(device_ordinal)); - auto literal = Literal::CreateFromShape(shape); - TF_RETURN_IF_ERROR(backend().transfer_manager()->TransferLiteralFromOutfeed( - executor, &literal)); - return std::move(literal); + return backend().transfer_manager()->TransferLiteralFromOutfeed(executor, + literal); } StatusOr LocalClient::ReplicaNumberToDeviceOrdinal(int replica_number) { diff --git a/tensorflow/compiler/xla/client/local_client.h b/tensorflow/compiler/xla/client/local_client.h index bb072a0fe2c..12a779618bb 100644 --- a/tensorflow/compiler/xla/client/local_client.h +++ b/tensorflow/compiler/xla/client/local_client.h @@ -172,13 +172,13 @@ class LocalClient : public Client { // Client::TransferToInfeed. Status TransferToInfeedLocal(const LiteralSlice& literal, int device_ordinal); - // Transfer and return a value of the given shape from the outfeed of the - // given device. + // Transfer and return a value from the outfeed of the given device. The + // shape of the object to transfer is determined by `literal`'s shape. // TODO(b/69670845): Remove the 'Local' from the name when LocalClient does // not inherit from Client and there is no possibility of confusion with // Client::TransferFromOutfeed. - StatusOr TransferFromOutfeedLocal(const Shape& shape, - int device_ordinal); + Status TransferFromOutfeedLocal(int device_ordinal, + MutableBorrowingLiteral literal); // Returns the device ordinal that corresponds to the given replica number. // diff --git a/tensorflow/compiler/xla/pjrt/BUILD b/tensorflow/compiler/xla/pjrt/BUILD index fcf840d2fdb..170c3bd622f 100644 --- a/tensorflow/compiler/xla/pjrt/BUILD +++ b/tensorflow/compiler/xla/pjrt/BUILD @@ -132,6 +132,7 @@ cc_library( hdrs = ["pjrt_client.h"], visibility = ["//tensorflow/compiler/xla:friends"], deps = [ + "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status", "//tensorflow/compiler/xla:statusor", diff --git a/tensorflow/compiler/xla/pjrt/pjrt_client.h b/tensorflow/compiler/xla/pjrt/pjrt_client.h index d37b947a0dc..d0afcf356ff 100644 --- a/tensorflow/compiler/xla/pjrt/pjrt_client.h +++ b/tensorflow/compiler/xla/pjrt/pjrt_client.h @@ -27,6 +27,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/executable_build_options.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/layout.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/hlo_cost_analysis.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/shape.h" @@ -93,7 +94,7 @@ class PjRtDevice { virtual Status TransferToInfeed(const LiteralSlice& literal) const = 0; // Transfer and return a value of the given shape from the outfeed queue. - virtual StatusOr TransferFromOutfeed(const Shape& shape) const = 0; + virtual Status TransferFromOutfeed(MutableBorrowingLiteral literal) const = 0; }; // Forward declaration. diff --git a/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc b/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc index 9578fb13f0c..cfe2303aade 100644 --- a/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc +++ b/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc @@ -910,11 +910,11 @@ Status PjRtStreamExecutorDevice::TransferToInfeed( literal, local_device->device_ordinal()); } -StatusOr PjRtStreamExecutorDevice::TransferFromOutfeed( - const Shape& shape) const { +Status PjRtStreamExecutorDevice::TransferFromOutfeed( + MutableBorrowingLiteral literal) const { TF_ASSIGN_OR_RETURN(LocalDeviceState * local_device, GetLocalDeviceState()); return local_device->client()->TransferFromOutfeedLocal( - shape, local_device->device_ordinal()); + local_device->device_ordinal(), literal); } StatusOr PjRtStreamExecutorClient::LookupAddressableDevice( diff --git a/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h b/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h index ea2f19af257..d87f77f402a 100644 --- a/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h +++ b/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h @@ -32,6 +32,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/layout.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/pjrt/local_device_state.h" #include "tensorflow/compiler/xla/pjrt/pjrt_client.h" #include "tensorflow/compiler/xla/pjrt/tracked_device_buffer.h" @@ -106,7 +107,7 @@ class PjRtStreamExecutorDevice : public PjRtDevice { Status TransferToInfeed(const LiteralSlice& literal) const override; - StatusOr TransferFromOutfeed(const Shape& shape) const override; + Status TransferFromOutfeed(MutableBorrowingLiteral literal) const override; private: const int id_; diff --git a/tensorflow/compiler/xla/python/outfeed_receiver.cc b/tensorflow/compiler/xla/python/outfeed_receiver.cc index 608b686f26e..37c5f357290 100644 --- a/tensorflow/compiler/xla/python/outfeed_receiver.cc +++ b/tensorflow/compiler/xla/python/outfeed_receiver.cc @@ -341,11 +341,9 @@ void OutfeedReceiverImpl::EnqueueReceivedData( StatusOr> OutfeedReceiverImpl::ReceiveRawFromOutfeed( const PjRtDevice* device, const Shape& shape) { - std::shared_ptr literal_shared; - - TF_ASSIGN_OR_RETURN(Literal literal, device->TransferFromOutfeed(shape)); - - return absl::make_unique(std::move(literal)); + auto literal = std::make_unique(shape); + TF_RETURN_IF_ERROR(device->TransferFromOutfeed(literal.get())); + return literal; } void OutfeedReceiverImpl::CallbackThreadLoop() { diff --git a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.h b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.h index 33513746b12..73ac6dc40f4 100644 --- a/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.h +++ b/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client.h @@ -69,7 +69,7 @@ class TpuDevice : public PjRtDevice { return Unimplemented("Infeed not yet implemented via this API"); } - StatusOr TransferFromOutfeed(const Shape& shape) const override { + Status TransferFromOutfeed(MutableBorrowingLiteral literal) const override { return Unimplemented("Outfeed not yet implemented via this API"); } diff --git a/tensorflow/compiler/xla/python/xla.cc b/tensorflow/compiler/xla/python/xla.cc index de343bc1ef3..e409e68e94e 100644 --- a/tensorflow/compiler/xla/python/xla.cc +++ b/tensorflow/compiler/xla/python/xla.cc @@ -146,7 +146,7 @@ PYBIND11_MODULE(xla_extension, m) { [](const PjRtDevice& device, const Shape& shape) -> StatusOr { GlobalPyRefManager()->CollectGarbage(); - std::shared_ptr literal_shared; + std::shared_ptr literal; { py::gil_scoped_release gil_release; Shape shape_with_layout = shape; @@ -156,12 +156,10 @@ PYBIND11_MODULE(xla_extension, m) { LayoutUtil::SetToDefaultLayout(subshape); } }); - TF_ASSIGN_OR_RETURN(Literal literal, device.TransferFromOutfeed( - shape_with_layout)); - - literal_shared = std::make_shared(std::move(literal)); + literal = std::make_shared(shape_with_layout); + TF_RETURN_IF_ERROR(device.TransferFromOutfeed(literal.get())); } - return LiteralToPython(std::move(literal_shared)); + return LiteralToPython(std::move(literal)); }); py::class_>(m, "CpuDevice") diff --git a/tensorflow/compiler/xla/tests/local_client_execute_test.cc b/tensorflow/compiler/xla/tests/local_client_execute_test.cc index fd83492a1ae..d915b6ed2a5 100644 --- a/tensorflow/compiler/xla/tests/local_client_execute_test.cc +++ b/tensorflow/compiler/xla/tests/local_client_execute_test.cc @@ -937,9 +937,9 @@ XLA_TEST_F(LocalClientExecuteTest, DISABLED_ON_INTERPRETER(InfeedOutfeedTest)) { LiteralUtil::CreateR1({-5.0, 123.0, 42.0}), local_client_->default_device_ordinal())); - TF_ASSERT_OK_AND_ASSIGN(Literal result, - local_client_->TransferFromOutfeedLocal( - shape, local_client_->default_device_ordinal())); + Literal result(shape); + ASSERT_IS_OK(local_client_->TransferFromOutfeedLocal( + local_client_->default_device_ordinal(), &result)); LiteralTestUtil::ExpectR1Equal({-4.0, 125.0, 45.0}, result); } diff --git a/tensorflow/compiler/xla/tests/multiple_devices_on_host_test.cc b/tensorflow/compiler/xla/tests/multiple_devices_on_host_test.cc index 2231fc6feab..d54139f384d 100644 --- a/tensorflow/compiler/xla/tests/multiple_devices_on_host_test.cc +++ b/tensorflow/compiler/xla/tests/multiple_devices_on_host_test.cc @@ -92,9 +92,8 @@ void TestWithDeviceCount(const int device_count) { for (int device_ordinal = 0; device_ordinal < device_count; device_ordinal++) { - TF_ASSERT_OK_AND_ASSIGN(Literal outfeed, - client->TransferFromOutfeedLocal( - ShapeUtil::MakeShape(S32, {}), device_ordinal)); + Literal outfeed(ShapeUtil::MakeShape(S32, {})); + TF_ASSERT_OK(client->TransferFromOutfeedLocal(device_ordinal, &outfeed)); EXPECT_EQ(outfeed, LiteralUtil::CreateR0(device_ordinal * 100 + 1)); } diff --git a/tensorflow/compiler/xla/tools/replay_computation.cc b/tensorflow/compiler/xla/tools/replay_computation.cc index 311a4a38e8b..0bf5d0e937b 100644 --- a/tensorflow/compiler/xla/tools/replay_computation.cc +++ b/tensorflow/compiler/xla/tools/replay_computation.cc @@ -282,9 +282,9 @@ StatusOr ReplayComputation(const HloSnapshot& module, outfeed_thread_pool.emplace(tensorflow::Env::Default(), "infeed", /*num_threads=*/1); auto consume_outfeed = [client, outfeed_shape] { + Literal outfeed(*outfeed_shape); TF_CHECK_OK( - client->TransferFromOutfeedLocal(*outfeed_shape, /*device_ordinal=*/0) - .status()); + client->TransferFromOutfeedLocal(/*device_ordinal=*/0, &outfeed)); VLOG(1) << "Received outfeed data of shape " << ShapeUtil::HumanStringWithLayout(*outfeed_shape); };