[XLA] Adjust outfeed_receiver to only create threads for local devices.
Otherwise, we get an error when using this with non-local devices, for example "Invalid argument: Device TPU_10(host=0,(1,1,0,0)) is not a local device." when using id_tap/host_callback in a multihost TPU setup - since we can only actually obtain the outfeeds for local devices. PiperOrigin-RevId: 351336978 Change-Id: I68f8b4345c9d01ef197dd400386a8bc04a59b0d2
This commit is contained in:
parent
9f14841393
commit
ed2134ae43
@ -329,13 +329,18 @@ tf_cc_test(
|
|||||||
deps = [
|
deps = [
|
||||||
":outfeed_receiver",
|
":outfeed_receiver",
|
||||||
"//tensorflow/compiler/jit:xla_cpu_jit",
|
"//tensorflow/compiler/jit:xla_cpu_jit",
|
||||||
|
"//tensorflow/compiler/xla:statusor",
|
||||||
"//tensorflow/compiler/xla:test",
|
"//tensorflow/compiler/xla:test",
|
||||||
|
"//tensorflow/compiler/xla/client:client_library",
|
||||||
"//tensorflow/compiler/xla/client:executable_build_options",
|
"//tensorflow/compiler/xla/client:executable_build_options",
|
||||||
"//tensorflow/compiler/xla/client:xla_builder",
|
"//tensorflow/compiler/xla/client:xla_builder",
|
||||||
"//tensorflow/compiler/xla/pjrt:cpu_device",
|
"//tensorflow/compiler/xla/pjrt:cpu_device",
|
||||||
"//tensorflow/compiler/xla/pjrt:pjrt_client",
|
"//tensorflow/compiler/xla/pjrt:pjrt_client",
|
||||||
|
"//tensorflow/compiler/xla/pjrt:pjrt_stream_executor_client",
|
||||||
|
"//tensorflow/compiler/xla/service:platform_util",
|
||||||
"//tensorflow/core:test",
|
"//tensorflow/core:test",
|
||||||
"//tensorflow/core:test_main",
|
"//tensorflow/core:test_main",
|
||||||
|
"@com_google_absl//absl/strings",
|
||||||
"@com_google_absl//absl/synchronization",
|
"@com_google_absl//absl/synchronization",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
@ -231,7 +231,7 @@ OutfeedReceiverImpl::OutfeedReceiverImpl(
|
|||||||
callback_ = callback;
|
callback_ = callback;
|
||||||
max_callback_queue_size_bytes_ = max_callback_queue_size_bytes;
|
max_callback_queue_size_bytes_ = max_callback_queue_size_bytes;
|
||||||
for (const auto& client : clients) {
|
for (const auto& client : clients) {
|
||||||
for (auto device : client->devices()) {
|
for (auto device : client->local_devices()) {
|
||||||
devices_.push_back(device);
|
devices_.push_back(device);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -18,10 +18,13 @@ limitations under the License.
|
|||||||
#include <memory>
|
#include <memory>
|
||||||
|
|
||||||
#include "absl/synchronization/mutex.h"
|
#include "absl/synchronization/mutex.h"
|
||||||
|
#include "tensorflow/compiler/xla/client/client_library.h"
|
||||||
#include "tensorflow/compiler/xla/client/executable_build_options.h"
|
#include "tensorflow/compiler/xla/client/executable_build_options.h"
|
||||||
#include "tensorflow/compiler/xla/client/xla_builder.h"
|
#include "tensorflow/compiler/xla/client/xla_builder.h"
|
||||||
#include "tensorflow/compiler/xla/pjrt/cpu_device.h"
|
#include "tensorflow/compiler/xla/pjrt/cpu_device.h"
|
||||||
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
|
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
|
||||||
|
#include "tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h"
|
||||||
|
#include "tensorflow/compiler/xla/service/platform_util.h"
|
||||||
#include "tensorflow/compiler/xla/test.h"
|
#include "tensorflow/compiler/xla/test.h"
|
||||||
|
|
||||||
namespace xla {
|
namespace xla {
|
||||||
@ -72,6 +75,35 @@ class Accumulator {
|
|||||||
std::vector<Data> received_ TF_GUARDED_BY(mutex_);
|
std::vector<Data> received_ TF_GUARDED_BY(mutex_);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
StatusOr<std::unique_ptr<PjRtClient>> GetCpuClientWithNonLocalDevice() {
|
||||||
|
TF_ASSIGN_OR_RETURN(se::Platform * platform,
|
||||||
|
PlatformUtil::GetPlatform("Host"));
|
||||||
|
if (platform->VisibleDeviceCount() <= 0) {
|
||||||
|
return FailedPrecondition("CPU platform has no visible devices.");
|
||||||
|
}
|
||||||
|
LocalClientOptions options;
|
||||||
|
options.set_platform(platform);
|
||||||
|
TF_ASSIGN_OR_RETURN(LocalClient * client,
|
||||||
|
ClientLibrary::GetOrCreateLocalClient(options));
|
||||||
|
|
||||||
|
se::StreamExecutorConfig config(0);
|
||||||
|
TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor,
|
||||||
|
platform->GetExecutor(config));
|
||||||
|
auto device_state = absl::make_unique<LocalDeviceState>(
|
||||||
|
executor, client, LocalDeviceState::kSynchronous, /*asynchronous=*/true,
|
||||||
|
/*allow_event_reuse=*/false);
|
||||||
|
|
||||||
|
std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices;
|
||||||
|
devices.push_back(absl::make_unique<CpuDevice>(0, std::move(device_state)));
|
||||||
|
devices.push_back(absl::make_unique<CpuDevice>(1, nullptr));
|
||||||
|
|
||||||
|
return std::unique_ptr<PjRtClient>(std::make_unique<PjRtStreamExecutorClient>(
|
||||||
|
kCpuName, client, std::move(devices), /*host_id=*/0,
|
||||||
|
/*allocator=*/nullptr, /*host_memory_allocator=*/nullptr,
|
||||||
|
/*should_stage_host_to_device_transfers=*/false,
|
||||||
|
/*gpu_run_options=*/nullptr));
|
||||||
|
}
|
||||||
|
|
||||||
TEST(OutfeedReceiverTest, ReceiveOutfeedSimple) {
|
TEST(OutfeedReceiverTest, ReceiveOutfeedSimple) {
|
||||||
TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr<PjRtClient> cpu_client,
|
TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr<PjRtClient> cpu_client,
|
||||||
GetCpuClient(true));
|
GetCpuClient(true));
|
||||||
@ -253,6 +285,39 @@ TEST(OutfeedReceiverTest, InvalidConsumerIdError) {
|
|||||||
testing::HasSubstr("Consumer ID cannot be a reserved value"));
|
testing::HasSubstr("Consumer ID cannot be a reserved value"));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(OutfeedReceiverTest, NonLocalDevicesIgnored) {
|
||||||
|
TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr<PjRtClient> cpu_client,
|
||||||
|
GetCpuClientWithNonLocalDevice());
|
||||||
|
std::vector<PjRtClient*> clients{cpu_client.get()};
|
||||||
|
|
||||||
|
auto receiver = absl::make_unique<Accumulator>();
|
||||||
|
OutfeedReceiver::Callback callback =
|
||||||
|
[&receiver](PjRtDevice* device, uint32_t consumer_id,
|
||||||
|
std::shared_ptr<Literal> data) {
|
||||||
|
receiver->Receive(consumer_id, data);
|
||||||
|
};
|
||||||
|
auto outfeed_receiver =
|
||||||
|
std::make_shared<OutfeedReceiver>(callback, clients, 128);
|
||||||
|
outfeed_receiver->Start();
|
||||||
|
|
||||||
|
XlaBuilder builder("execute_test_outfeed");
|
||||||
|
constexpr int consumer_id0 = 5;
|
||||||
|
const Shape shape0 = ShapeUtil::MakeShape(U32, {16});
|
||||||
|
XlaOp data = Iota(&builder, shape0, 0);
|
||||||
|
XlaOp send = outfeed_receiver
|
||||||
|
->AddOutfeedToBuilder(&builder, CreateToken(&builder),
|
||||||
|
consumer_id0, {data})
|
||||||
|
.ValueOrDie();
|
||||||
|
EXPECT_TRUE(CompileAndExecute(&builder, send, 0, cpu_client.get()).ok());
|
||||||
|
|
||||||
|
// Shutdown the receiver, to force it to wait to deliver the callbacks.
|
||||||
|
outfeed_receiver = nullptr;
|
||||||
|
std::vector<Accumulator::Data> received = receiver->received();
|
||||||
|
EXPECT_EQ(1, received.size());
|
||||||
|
EXPECT_EQ(consumer_id0, received[0].consumer_id);
|
||||||
|
EXPECT_EQ(ShapeUtil::MakeTupleShape({shape0}), received[0].data->shape());
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
Loading…
x
Reference in New Issue
Block a user