[XLA] Adjust outfeed_receiver to only create threads for local devices.

Otherwise, we get an error when using this with non-local devices, for example
"Invalid argument: Device TPU_10(host=0,(1,1,0,0)) is not a local device." when
using id_tap/host_callback in a multihost TPU setup - since we can only actually
obtain the outfeeds for local devices.

PiperOrigin-RevId: 351336978
Change-Id: I68f8b4345c9d01ef197dd400386a8bc04a59b0d2
This commit is contained in:
A. Unique TensorFlower 2021-01-12 03:28:32 -08:00 committed by TensorFlower Gardener
parent 9f14841393
commit ed2134ae43
3 changed files with 71 additions and 1 deletion

View File

@@ -329,13 +329,18 @@ tf_cc_test(
deps = [
":outfeed_receiver",
"//tensorflow/compiler/jit:xla_cpu_jit",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla/client:client_library",
"//tensorflow/compiler/xla/client:executable_build_options",
"//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/pjrt:cpu_device",
"//tensorflow/compiler/xla/pjrt:pjrt_client",
"//tensorflow/compiler/xla/pjrt:pjrt_stream_executor_client",
"//tensorflow/compiler/xla/service:platform_util",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/synchronization",
],
)

View File

@@ -231,7 +231,7 @@ OutfeedReceiverImpl::OutfeedReceiverImpl(
callback_ = callback;
max_callback_queue_size_bytes_ = max_callback_queue_size_bytes;
for (const auto& client : clients) {
for (auto device : client->devices()) {
for (auto device : client->local_devices()) {
devices_.push_back(device);
}
}

View File

@@ -18,10 +18,13 @@ limitations under the License.
#include <memory>
#include "absl/synchronization/mutex.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/executable_build_options.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/pjrt/cpu_device.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h"
#include "tensorflow/compiler/xla/service/platform_util.h"
#include "tensorflow/compiler/xla/test.h"
namespace xla {
@@ -72,6 +75,35 @@ class Accumulator {
std::vector<Data> received_ TF_GUARDED_BY(mutex_);
};
// Builds a CPU PjRtStreamExecutorClient with two devices: device 0 is backed
// by a real LocalDeviceState, while device 1 is constructed with a null
// device state — presumably making it non-local — TODO confirm against
// PjRtStreamExecutorDevice semantics. Lets tests exercise the case where a
// client's device list contains devices the receiver cannot outfeed from.
StatusOr<std::unique_ptr<PjRtClient>> GetCpuClientWithNonLocalDevice() {
// Resolve the "Host" (CPU) StreamExecutor platform.
TF_ASSIGN_OR_RETURN(se::Platform * platform,
PlatformUtil::GetPlatform("Host"));
if (platform->VisibleDeviceCount() <= 0) {
return FailedPrecondition("CPU platform has no visible devices.");
}
LocalClientOptions options;
options.set_platform(platform);
// A LocalClient is required to build the LocalDeviceState below.
TF_ASSIGN_OR_RETURN(LocalClient * client,
ClientLibrary::GetOrCreateLocalClient(options));
// Executor ordinal 0: the single visible CPU executor.
se::StreamExecutorConfig config(0);
TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor,
platform->GetExecutor(config));
auto device_state = absl::make_unique<LocalDeviceState>(
executor, client, LocalDeviceState::kSynchronous, /*asynchronous=*/true,
/*allow_event_reuse=*/false);
std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices;
// Device 0 has a device state; device 1 deliberately has none.
devices.push_back(absl::make_unique<CpuDevice>(0, std::move(device_state)));
devices.push_back(absl::make_unique<CpuDevice>(1, nullptr));
return std::unique_ptr<PjRtClient>(std::make_unique<PjRtStreamExecutorClient>(
kCpuName, client, std::move(devices), /*host_id=*/0,
/*allocator=*/nullptr, /*host_memory_allocator=*/nullptr,
/*should_stage_host_to_device_transfers=*/false,
/*gpu_run_options=*/nullptr));
}
TEST(OutfeedReceiverTest, ReceiveOutfeedSimple) {
TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr<PjRtClient> cpu_client,
GetCpuClient(true));
@@ -253,6 +285,39 @@ TEST(OutfeedReceiverTest, InvalidConsumerIdError) {
testing::HasSubstr("Consumer ID cannot be a reserved value"));
}
// Verifies that an OutfeedReceiver constructed from a client whose device
// list includes a device without a local device state still starts, runs an
// outfeed on the usable device, and delivers exactly one callback — i.e. the
// extra device does not cause an error (see commit message: previously this
// produced "Device ... is not a local device").
TEST(OutfeedReceiverTest, NonLocalDevicesIgnored) {
TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr<PjRtClient> cpu_client,
GetCpuClientWithNonLocalDevice());
std::vector<PjRtClient*> clients{cpu_client.get()};
// Accumulator records every (consumer_id, literal) pair the receiver
// delivers so the test can assert on the callback count below.
auto receiver = absl::make_unique<Accumulator>();
OutfeedReceiver::Callback callback =
[&receiver](PjRtDevice* device, uint32_t consumer_id,
std::shared_ptr<Literal> data) {
receiver->Receive(consumer_id, data);
};
// 128 bytes: small callback queue bound — enough for the single outfeed.
auto outfeed_receiver =
std::make_shared<OutfeedReceiver>(callback, clients, 128);
outfeed_receiver->Start();
XlaBuilder builder("execute_test_outfeed");
constexpr int consumer_id0 = 5;
const Shape shape0 = ShapeUtil::MakeShape(U32, {16});
XlaOp data = Iota(&builder, shape0, 0);
// Attach the outfeed op for consumer_id0 to the computation.
XlaOp send = outfeed_receiver
->AddOutfeedToBuilder(&builder, CreateToken(&builder),
consumer_id0, {data})
.ValueOrDie();
EXPECT_TRUE(CompileAndExecute(&builder, send, 0, cpu_client.get()).ok());
// Shutdown the receiver, to force it to wait to deliver the callbacks.
outfeed_receiver = nullptr;
// Exactly one callback, from the single usable device, with the outfeed
// data wrapped in a tuple shape.
std::vector<Accumulator::Data> received = receiver->received();
EXPECT_EQ(1, received.size());
EXPECT_EQ(consumer_id0, received[0].consumer_id);
EXPECT_EQ(ShapeUtil::MakeTupleShape({shape0}), received[0].data->shape());
}
} // namespace
} // namespace xla