[XLA] Adjust outfeed_receiver to only create threads for local devices.
Otherwise, we get an error when using this with non-local devices, for example "Invalid argument: Device TPU_10(host=0,(1,1,0,0)) is not a local device." when using id_tap/host_callback in a multihost TPU setup - since we can only actually obtain the outfeeds for local devices. PiperOrigin-RevId: 351336978 Change-Id: I68f8b4345c9d01ef197dd400386a8bc04a59b0d2
This commit is contained in:
parent
9f14841393
commit
ed2134ae43
@ -329,13 +329,18 @@ tf_cc_test(
|
||||
deps = [
|
||||
":outfeed_receiver",
|
||||
"//tensorflow/compiler/jit:xla_cpu_jit",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/compiler/xla:test",
|
||||
"//tensorflow/compiler/xla/client:client_library",
|
||||
"//tensorflow/compiler/xla/client:executable_build_options",
|
||||
"//tensorflow/compiler/xla/client:xla_builder",
|
||||
"//tensorflow/compiler/xla/pjrt:cpu_device",
|
||||
"//tensorflow/compiler/xla/pjrt:pjrt_client",
|
||||
"//tensorflow/compiler/xla/pjrt:pjrt_stream_executor_client",
|
||||
"//tensorflow/compiler/xla/service:platform_util",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/synchronization",
|
||||
],
|
||||
)
|
||||
|
@ -231,7 +231,7 @@ OutfeedReceiverImpl::OutfeedReceiverImpl(
|
||||
callback_ = callback;
|
||||
max_callback_queue_size_bytes_ = max_callback_queue_size_bytes;
|
||||
for (const auto& client : clients) {
|
||||
for (auto device : client->devices()) {
|
||||
for (auto device : client->local_devices()) {
|
||||
devices_.push_back(device);
|
||||
}
|
||||
}
|
||||
|
@ -18,10 +18,13 @@ limitations under the License.
|
||||
#include <memory>
|
||||
|
||||
#include "absl/synchronization/mutex.h"
|
||||
#include "tensorflow/compiler/xla/client/client_library.h"
|
||||
#include "tensorflow/compiler/xla/client/executable_build_options.h"
|
||||
#include "tensorflow/compiler/xla/client/xla_builder.h"
|
||||
#include "tensorflow/compiler/xla/pjrt/cpu_device.h"
|
||||
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
|
||||
#include "tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h"
|
||||
#include "tensorflow/compiler/xla/service/platform_util.h"
|
||||
#include "tensorflow/compiler/xla/test.h"
|
||||
|
||||
namespace xla {
|
||||
@ -72,6 +75,35 @@ class Accumulator {
|
||||
std::vector<Data> received_ TF_GUARDED_BY(mutex_);
|
||||
};
|
||||
|
||||
StatusOr<std::unique_ptr<PjRtClient>> GetCpuClientWithNonLocalDevice() {
|
||||
TF_ASSIGN_OR_RETURN(se::Platform * platform,
|
||||
PlatformUtil::GetPlatform("Host"));
|
||||
if (platform->VisibleDeviceCount() <= 0) {
|
||||
return FailedPrecondition("CPU platform has no visible devices.");
|
||||
}
|
||||
LocalClientOptions options;
|
||||
options.set_platform(platform);
|
||||
TF_ASSIGN_OR_RETURN(LocalClient * client,
|
||||
ClientLibrary::GetOrCreateLocalClient(options));
|
||||
|
||||
se::StreamExecutorConfig config(0);
|
||||
TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor,
|
||||
platform->GetExecutor(config));
|
||||
auto device_state = absl::make_unique<LocalDeviceState>(
|
||||
executor, client, LocalDeviceState::kSynchronous, /*asynchronous=*/true,
|
||||
/*allow_event_reuse=*/false);
|
||||
|
||||
std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices;
|
||||
devices.push_back(absl::make_unique<CpuDevice>(0, std::move(device_state)));
|
||||
devices.push_back(absl::make_unique<CpuDevice>(1, nullptr));
|
||||
|
||||
return std::unique_ptr<PjRtClient>(std::make_unique<PjRtStreamExecutorClient>(
|
||||
kCpuName, client, std::move(devices), /*host_id=*/0,
|
||||
/*allocator=*/nullptr, /*host_memory_allocator=*/nullptr,
|
||||
/*should_stage_host_to_device_transfers=*/false,
|
||||
/*gpu_run_options=*/nullptr));
|
||||
}
|
||||
|
||||
TEST(OutfeedReceiverTest, ReceiveOutfeedSimple) {
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr<PjRtClient> cpu_client,
|
||||
GetCpuClient(true));
|
||||
@ -253,6 +285,39 @@ TEST(OutfeedReceiverTest, InvalidConsumerIdError) {
|
||||
testing::HasSubstr("Consumer ID cannot be a reserved value"));
|
||||
}
|
||||
|
||||
TEST(OutfeedReceiverTest, NonLocalDevicesIgnored) {
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr<PjRtClient> cpu_client,
|
||||
GetCpuClientWithNonLocalDevice());
|
||||
std::vector<PjRtClient*> clients{cpu_client.get()};
|
||||
|
||||
auto receiver = absl::make_unique<Accumulator>();
|
||||
OutfeedReceiver::Callback callback =
|
||||
[&receiver](PjRtDevice* device, uint32_t consumer_id,
|
||||
std::shared_ptr<Literal> data) {
|
||||
receiver->Receive(consumer_id, data);
|
||||
};
|
||||
auto outfeed_receiver =
|
||||
std::make_shared<OutfeedReceiver>(callback, clients, 128);
|
||||
outfeed_receiver->Start();
|
||||
|
||||
XlaBuilder builder("execute_test_outfeed");
|
||||
constexpr int consumer_id0 = 5;
|
||||
const Shape shape0 = ShapeUtil::MakeShape(U32, {16});
|
||||
XlaOp data = Iota(&builder, shape0, 0);
|
||||
XlaOp send = outfeed_receiver
|
||||
->AddOutfeedToBuilder(&builder, CreateToken(&builder),
|
||||
consumer_id0, {data})
|
||||
.ValueOrDie();
|
||||
EXPECT_TRUE(CompileAndExecute(&builder, send, 0, cpu_client.get()).ok());
|
||||
|
||||
// Shutdown the receiver, to force it to wait to deliver the callbacks.
|
||||
outfeed_receiver = nullptr;
|
||||
std::vector<Accumulator::Data> received = receiver->received();
|
||||
EXPECT_EQ(1, received.size());
|
||||
EXPECT_EQ(consumer_id0, received[0].consumer_id);
|
||||
EXPECT_EQ(ShapeUtil::MakeTupleShape({shape0}), received[0].data->shape());
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
} // namespace xla
|
||||
|
Loading…
Reference in New Issue
Block a user