From ed2134ae4327e1e09ed3808e748489405a845382 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Tue, 12 Jan 2021 03:28:32 -0800
Subject: [PATCH] [XLA] Adjust outfeed_receiver to only create threads for
 local devices.

Otherwise, we get an error when using this with non-local devices, for
example "Invalid argument: Device TPU_10(host=0,(1,1,0,0)) is not a
local device." when using id_tap/host_callback in a multihost TPU
setup, since we can only actually obtain the outfeeds for local
devices.

PiperOrigin-RevId: 351336978
Change-Id: I68f8b4345c9d01ef197dd400386a8bc04a59b0d2
---
 tensorflow/compiler/xla/python/BUILD          |  5 ++
 .../compiler/xla/python/outfeed_receiver.cc   |  2 +-
 .../xla/python/outfeed_receiver_test.cc       | 65 +++++++++++++++++++
 3 files changed, 71 insertions(+), 1 deletion(-)

diff --git a/tensorflow/compiler/xla/python/BUILD b/tensorflow/compiler/xla/python/BUILD
index 055c8d973f1..b004e252c62 100644
--- a/tensorflow/compiler/xla/python/BUILD
+++ b/tensorflow/compiler/xla/python/BUILD
@@ -329,13 +329,18 @@ tf_cc_test(
     deps = [
         ":outfeed_receiver",
         "//tensorflow/compiler/jit:xla_cpu_jit",
+        "//tensorflow/compiler/xla:statusor",
         "//tensorflow/compiler/xla:test",
+        "//tensorflow/compiler/xla/client:client_library",
         "//tensorflow/compiler/xla/client:executable_build_options",
         "//tensorflow/compiler/xla/client:xla_builder",
         "//tensorflow/compiler/xla/pjrt:cpu_device",
         "//tensorflow/compiler/xla/pjrt:pjrt_client",
+        "//tensorflow/compiler/xla/pjrt:pjrt_stream_executor_client",
+        "//tensorflow/compiler/xla/service:platform_util",
         "//tensorflow/core:test",
         "//tensorflow/core:test_main",
+        "@com_google_absl//absl/strings",
         "@com_google_absl//absl/synchronization",
     ],
 )
diff --git a/tensorflow/compiler/xla/python/outfeed_receiver.cc b/tensorflow/compiler/xla/python/outfeed_receiver.cc
index 92aae351085..69191604237 100644
--- a/tensorflow/compiler/xla/python/outfeed_receiver.cc
+++ b/tensorflow/compiler/xla/python/outfeed_receiver.cc
@@ -231,7 +231,7 @@ OutfeedReceiverImpl::OutfeedReceiverImpl(
   callback_ = callback;
   max_callback_queue_size_bytes_ = max_callback_queue_size_bytes;
   for (const auto& client : clients) {
-    for (auto device : client->devices()) {
+    for (auto device : client->local_devices()) {
       devices_.push_back(device);
     }
   }
diff --git a/tensorflow/compiler/xla/python/outfeed_receiver_test.cc b/tensorflow/compiler/xla/python/outfeed_receiver_test.cc
index 17d64b8aef4..8e60553d290 100644
--- a/tensorflow/compiler/xla/python/outfeed_receiver_test.cc
+++ b/tensorflow/compiler/xla/python/outfeed_receiver_test.cc
@@ -18,10 +18,13 @@ limitations under the License.
 #include <memory>
 
 #include "absl/synchronization/mutex.h"
+#include "tensorflow/compiler/xla/client/client_library.h"
 #include "tensorflow/compiler/xla/client/executable_build_options.h"
 #include "tensorflow/compiler/xla/client/xla_builder.h"
 #include "tensorflow/compiler/xla/pjrt/cpu_device.h"
 #include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
+#include "tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h"
+#include "tensorflow/compiler/xla/service/platform_util.h"
 #include "tensorflow/compiler/xla/test.h"
 
 namespace xla {
@@ -72,6 +75,35 @@ class Accumulator {
   std::vector<Data> received_ TF_GUARDED_BY(mutex_);
 };
 
+StatusOr<std::unique_ptr<PjRtClient>> GetCpuClientWithNonLocalDevice() {
+  TF_ASSIGN_OR_RETURN(se::Platform * platform,
+                      PlatformUtil::GetPlatform("Host"));
+  if (platform->VisibleDeviceCount() <= 0) {
+    return FailedPrecondition("CPU platform has no visible devices.");
+  }
+  LocalClientOptions options;
+  options.set_platform(platform);
+  TF_ASSIGN_OR_RETURN(LocalClient * client,
+                      ClientLibrary::GetOrCreateLocalClient(options));
+
+  se::StreamExecutorConfig config(0);
+  TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor,
+                      platform->GetExecutor(config));
+  auto device_state = absl::make_unique<LocalDeviceState>(
+      executor, client, LocalDeviceState::kSynchronous, /*asynchronous=*/true,
+      /*allow_event_reuse=*/false);
+
+  std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices;
+  devices.push_back(absl::make_unique<CpuDevice>(0, std::move(device_state)));
+  devices.push_back(absl::make_unique<CpuDevice>(1, nullptr));
+
+  return std::unique_ptr<PjRtClient>(std::make_unique<PjRtStreamExecutorClient>(
+      kCpuName, client, std::move(devices), /*host_id=*/0,
+      /*allocator=*/nullptr, /*host_memory_allocator=*/nullptr,
+      /*should_stage_host_to_device_transfers=*/false,
+      /*gpu_run_options=*/nullptr));
+}
+
 TEST(OutfeedReceiverTest, ReceiveOutfeedSimple) {
   TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr<PjRtClient> cpu_client,
                           GetCpuClient(true));
@@ -253,6 +285,39 @@ TEST(OutfeedReceiverTest, InvalidConsumerIdError) {
               testing::HasSubstr("Consumer ID cannot be a reserved value"));
 }
 
+TEST(OutfeedReceiverTest, NonLocalDevicesIgnored) {
+  TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr<PjRtClient> cpu_client,
+                          GetCpuClientWithNonLocalDevice());
+  std::vector<PjRtClient*> clients{cpu_client.get()};
+
+  auto receiver = absl::make_unique<Accumulator>();
+  OutfeedReceiver::Callback callback =
+      [&receiver](PjRtDevice* device, uint32_t consumer_id,
+                  std::shared_ptr<Literal> data) {
+        receiver->Receive(consumer_id, data);
+      };
+  auto outfeed_receiver =
+      std::make_shared<OutfeedReceiver>(callback, clients, 128);
+  outfeed_receiver->Start();
+
+  XlaBuilder builder("execute_test_outfeed");
+  constexpr int consumer_id0 = 5;
+  const Shape shape0 = ShapeUtil::MakeShape(U32, {16});
+  XlaOp data = Iota(&builder, shape0, 0);
+  XlaOp send = outfeed_receiver
+                   ->AddOutfeedToBuilder(&builder, CreateToken(&builder),
+                                         consumer_id0, {data})
+                   .ValueOrDie();
+  EXPECT_TRUE(CompileAndExecute(&builder, send, 0, cpu_client.get()).ok());
+
+  // Shut down the receiver; this blocks until the callbacks are delivered.
+  outfeed_receiver = nullptr;
+  std::vector<Accumulator::Data> received = receiver->received();
+  EXPECT_EQ(1, received.size());
+  EXPECT_EQ(consumer_id0, received[0].consumer_id);
+  EXPECT_EQ(ShapeUtil::MakeTupleShape({shape0}), received[0].data->shape());
+}
+
 }  // namespace
 
 }  // namespace xla
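
A minimal standalone sketch of the distinction this change relies on, using
hypothetical stand-in types (Device, Client) rather than the real
PjRtDevice/PjRtClient API: in a multihost setup, devices() enumerates every
device participating in the computation, while local_devices() returns only
the devices addressable from the current host, and those are the only ones an
outfeed listener thread can actually poll.

// illustrate_local_devices.cc -- standalone sketch; Device and Client are
// hypothetical stand-ins for PjRtDevice and PjRtClient.
#include <iostream>
#include <vector>

struct Device {
  int id;
  bool is_local;  // Addressable from this host?
};

struct Client {
  std::vector<Device> all_devices;

  // devices(): every device in the computation, across all hosts.
  std::vector<const Device*> devices() const {
    std::vector<const Device*> out;
    for (const Device& d : all_devices) out.push_back(&d);
    return out;
  }

  // local_devices(): only the devices this host can poll for outfeeds.
  std::vector<const Device*> local_devices() const {
    std::vector<const Device*> out;
    for (const Device& d : all_devices) {
      if (d.is_local) out.push_back(&d);
    }
    return out;
  }
};

int main() {
  // Two local devices plus one attached to another host, as in a
  // multihost TPU slice.
  Client client{{{0, true}, {1, true}, {10, false}}};

  // Iterating local_devices(), as the patched constructor does, skips
  // device 10 instead of triggering "... is not a local device".
  for (const Device* device : client.local_devices()) {
    std::cout << "starting outfeed listener thread for device "
              << device->id << "\n";
  }
  return 0;
}

Running the sketch starts listeners for devices 0 and 1 only; the patched
OutfeedReceiverImpl constructor gets the same effect by swapping devices()
for local_devices() when it collects the devices to listen on.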