From 8a72c4466a0f44de9bd1cfbf47a701727731036c Mon Sep 17 00:00:00 2001
From: Peter Hawkins
Date: Fri, 21 Feb 2020 13:56:28 -0800
Subject: [PATCH] [XLA:GPU] Add experimental, lightly tested support for
 multi-host and multi-process NCCL AllReduce.

This change makes several API changes:

* We allow the client to provide a mapping from the local device ordinals on
  the machine to global device IDs. If provided, we interpret the device IDs
  in the DeviceAssignment provided by the client as global IDs, not as local
  device ordinals. This allows us to describe computations that cross a host
  boundary.

* We allow the client to provide a callback for manufacturing an ncclUniqueId
  for a particular subset of global devices. The idea is that the client
  should use some other distributed system of their own (e.g., MPI) to share
  the ncclUniqueId values needed for a computation. NCCL allows
  cross-host/process collectives iff the same ncclUniqueId value is used.

Refactors the common collective logic, and the NCCL collective logic in
particular, to support a local/global distinction.

PiperOrigin-RevId: 296505571
Change-Id: I5ed42d65597b0960df78890745421f77e9789ba3
---
 .../compiler/xla/executable_run_options.cc    |  11 +
 .../compiler/xla/executable_run_options.h     |   8 +
 .../compiler/xla/refcounting_hash_map.h       |  17 +-
 .../compiler/xla/refcounting_hash_map_test.cc |  26 +-
 tensorflow/compiler/xla/service/BUILD         |   1 +
 .../xla/service/collective_ops_utils.cc       |   4 +-
 .../xla/service/collective_ops_utils.h        |  51 ++--
 .../compiler/xla/service/cpu/cpu_runtime.cc   |  42 ++--
 tensorflow/compiler/xla/service/gpu/BUILD     |  19 ++
 .../service/gpu/collective_permute_thunk.cc   |  11 +-
 .../xla/service/gpu/dummy_all_reduce_thunk.cc |   2 +-
 .../xla/service/gpu/gpu_executable.cc         |  11 +-
 .../service/gpu/gpu_executable_run_options.cc |  62 +++++
 .../service/gpu/gpu_executable_run_options.h  |  90 +++++++
 .../xla/service/gpu/nccl_all_reduce_thunk.cc  | 226 ++++++++++++------
 .../xla/service/gpu/nccl_all_reduce_thunk.h   |   3 +-
 tensorflow/compiler/xla/service/gpu/thunk.h   |   3 +
 .../compiler/xla/tests/collective_ops_test.cc |   2 +-
 18 files changed, 450 insertions(+), 139 deletions(-)
 create mode 100644 tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.cc
 create mode 100644 tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h

diff --git a/tensorflow/compiler/xla/executable_run_options.cc b/tensorflow/compiler/xla/executable_run_options.cc
index 1cfb449ebd0..452c87b23b7 100644
--- a/tensorflow/compiler/xla/executable_run_options.cc
+++ b/tensorflow/compiler/xla/executable_run_options.cc
@@ -100,6 +100,17 @@ const DeviceAssignment* ExecutableRunOptions::device_assignment() const {
   return device_assignment_;
 }
 
+ExecutableRunOptions& ExecutableRunOptions::set_gpu_executable_run_options(
+    const GpuExecutableRunOptions* gpu_executable_run_options) {
+  gpu_executable_run_options_ = gpu_executable_run_options;
+  return *this;
+}
+
+const GpuExecutableRunOptions*
+ExecutableRunOptions::gpu_executable_run_options() const {
+  return gpu_executable_run_options_;
+}
+
 ExecutableRunOptions& ExecutableRunOptions::set_rng_seed(int rng_seed) {
   rng_seed_ = rng_seed;
   return *this;
diff --git a/tensorflow/compiler/xla/executable_run_options.h b/tensorflow/compiler/xla/executable_run_options.h
index ed67bfbeb0d..b44d5f13b68 100644
--- a/tensorflow/compiler/xla/executable_run_options.h
+++ b/tensorflow/compiler/xla/executable_run_options.h
@@ -38,6 +38,7 @@ namespace xla {
 
 class DeviceAssignment;
 class ExecutionProfile;
+class GpuExecutableRunOptions;
 
 // A unique identifier for a particular "logical execution" of an XLA model.
 //
@@ -137,6 +138,12 @@ class ExecutableRunOptions {
     return then_execute_function_;
   }
 
+  // GPU-backend specific options. These are kept out-of-line to avoid bloating
+  // the size of this dependency for CPU-only AOT builds.
+  ExecutableRunOptions& set_gpu_executable_run_options(
+      const GpuExecutableRunOptions* gpu_executable_run_options);
+  const GpuExecutableRunOptions* gpu_executable_run_options() const;
+
  private:
   stream_executor::DeviceMemoryAllocator* allocator_ = nullptr;
   int device_ordinal_ = -1;
@@ -148,6 +155,7 @@ class ExecutableRunOptions {
   stream_executor::Stream* host_to_device_stream_ = nullptr;
   ThenExecuteFunction* then_execute_function_ = nullptr;
   RunId run_id_;
+  const GpuExecutableRunOptions* gpu_executable_run_options_ = nullptr;
 };
 
 }  // namespace xla
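
A sketch only, not part of this change: how a client process might attach the
new GPU options to its ExecutableRunOptions. Here process `host_index` of a
hypothetical 2-host x 2-GPU job owns local ordinals 0 and 1 and maps them to
global devices {2*host_index, 2*host_index + 1}. Note that the
GpuExecutableRunOptions object is held by raw pointer, so it must outlive the
execution.

    // Sketch under the assumptions above; AttachGpuOptions is a hypothetical
    // helper, not an XLA API.
    #include <vector>

    #include "tensorflow/compiler/xla/executable_run_options.h"
    #include "tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h"

    xla::ExecutableRunOptions AttachGpuOptions(
        int host_index, xla::GpuExecutableRunOptions* gpu_options) {
      // Local ordinal i on this host corresponds to global device
      // 2 * host_index + i.
      std::vector<xla::GlobalDeviceId> global_ids;
      for (int ordinal = 0; ordinal < 2; ++ordinal) {
        global_ids.push_back(xla::GlobalDeviceId(2 * host_index + ordinal));
      }
      gpu_options->set_gpu_global_device_ids(std::move(global_ids));

      xla::ExecutableRunOptions run_options;
      run_options.set_gpu_executable_run_options(gpu_options);  // not owned
      return run_options;
    }
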
diff --git a/tensorflow/compiler/xla/refcounting_hash_map.h b/tensorflow/compiler/xla/refcounting_hash_map.h
index 3ff6a50d85f..efa1b9e3a50 100644
--- a/tensorflow/compiler/xla/refcounting_hash_map.h
+++ b/tensorflow/compiler/xla/refcounting_hash_map.h
@@ -42,13 +42,7 @@ template <typename K, typename V>
 class RefcountingHashMap {
  public:
   // Default-constructs new values.
-  RefcountingHashMap()
-      : value_factory_([](const K&) { return absl::make_unique<V>(); }) {}
-
-  // Constructs new values according to the given factory function.
-  explicit RefcountingHashMap(
-      std::function<std::unique_ptr<V>(const K&)> value_factory)
-      : value_factory_(std::move(value_factory)) {}
+  RefcountingHashMap() = default;
 
   // Not copyable or movable because this contains internal pointers (namely,
   // instances of Deleter contain pointers to `this` and into `map_`).
@@ -60,8 +54,10 @@ class RefcountingHashMap {
   // Gets the value for the given key.
   //
   // If the map doesn't contain a live value for the key, constructs one
-  // according to the factory passed to the map's constructor.
-  std::shared_ptr<V> operator[](const K& key) {
+  // using `value_factory`.
+  std::shared_ptr<V> GetOrCreateIfAbsent(
+      const K& key,
+      const std::function<std::unique_ptr<V>(const K&)>& value_factory) {
     absl::MutexLock lock(&mu_);
     auto it = map_.find(key);
     // We ensure that the entry has not expired in case deleter was running when
@@ -76,7 +72,7 @@ class RefcountingHashMap {
       // Create entry in the map and then set its value, so the value can
       // contain a pointer back into the map.
       it = map_.emplace(key, std::weak_ptr<V>()).first;
-      std::shared_ptr<V> value(value_factory_(key).release(),
+      std::shared_ptr<V> value(value_factory(key).release(),
                                Deleter{&it->first, this});
       it->second = value;  // Set the weak ptr to the shared ptr.
       return value;
@@ -112,7 +108,6 @@ class RefcountingHashMap {
     }
   };
 
-  std::function<std::unique_ptr<V>(const K&)> value_factory_;
   absl::Mutex mu_;
   absl::node_hash_map<K, std::weak_ptr<V>> map_ ABSL_GUARDED_BY(mu_);
 };
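
A sketch only, not part of this change: with the factory now supplied at
lookup time rather than at map-construction time, two lookups of the same key
still share one live value. The LookupOrCreate helper below is hypothetical.

    #include <memory>
    #include <string>

    #include "absl/memory/memory.h"
    #include "absl/strings/str_cat.h"
    #include "tensorflow/compiler/xla/refcounting_hash_map.h"

    std::shared_ptr<std::string> LookupOrCreate(
        xla::RefcountingHashMap<int, std::string>* map, int key) {
      // The factory runs only if no live value exists for `key`.
      return map->GetOrCreateIfAbsent(key, [](const int& k) {
        return absl::make_unique<std::string>(absl::StrCat("value-", k));
      });
    }
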
diff --git a/tensorflow/compiler/xla/refcounting_hash_map_test.cc b/tensorflow/compiler/xla/refcounting_hash_map_test.cc
index 753c30dafbe..acb7d7afb46 100644
--- a/tensorflow/compiler/xla/refcounting_hash_map_test.cc
+++ b/tensorflow/compiler/xla/refcounting_hash_map_test.cc
@@ -47,22 +47,25 @@ struct DeleteNotifier {
 
 TEST(RefcountingHashMapTest, PointerIdentity) {
   RefcountingHashMap<int, int> m;
-  std::shared_ptr<int> a = m[0];
-  std::shared_ptr<int> b = m[0];
-  std::shared_ptr<int> c = m[1];
+  auto factory = [](const int&) { return absl::make_unique<int>(); };
+  std::shared_ptr<int> a = m.GetOrCreateIfAbsent(0, factory);
+  std::shared_ptr<int> b = m.GetOrCreateIfAbsent(0, factory);
+  std::shared_ptr<int> c = m.GetOrCreateIfAbsent(1, factory);
   EXPECT_EQ(a.get(), b.get());
   EXPECT_NE(a.get(), c.get());
 }
 
 TEST(RefcountingHashMapTest, DefaultInitialized) {
   RefcountingHashMap<int, int> m;
-  EXPECT_EQ(*m[42], 0);
+  auto factory = [](const int&) { return absl::make_unique<int>(); };
+  EXPECT_EQ(*m.GetOrCreateIfAbsent(42, factory), 0);
 }
 
 TEST(RefcountingHashMapTest, DeletesEagerly) {
   RefcountingHashMap<int, DeleteNotifier> m;
   bool deleted = false;
-  auto handle = m[0];
+  auto factory = [](const int&) { return absl::make_unique<DeleteNotifier>(); };
+  auto handle = m.GetOrCreateIfAbsent(0, factory);
   handle->fn = [&] { deleted = true; };
   EXPECT_FALSE(deleted);
   handle = nullptr;
@@ -70,10 +73,10 @@ TEST(RefcountingHashMapTest, DeletesEagerly) {
 }
 
 TEST(RefcountingHashMapTest, CustomFactory) {
-  RefcountingHashMap<int, int> m(
-      [](const int& x) { return absl::make_unique<int>(x + 1); });
-  EXPECT_EQ(*m[0], 1);
-  EXPECT_EQ(*m[100], 101);
+  RefcountingHashMap<int, int> m;
+  auto factory = [](const int& x) { return absl::make_unique<int>(x + 1); };
+  EXPECT_EQ(*m.GetOrCreateIfAbsent(0, factory), 1);
+  EXPECT_EQ(*m.GetOrCreateIfAbsent(100, factory), 101);
 }
 
 TEST(RefcountingHashMapTest, ForEachEmpty) {
@@ -85,8 +88,9 @@ TEST(RefcountingHashMapTest, ForEachEmpty) {
 
 TEST(RefcountingHashMapTest, ForEachNonempty) {
   RefcountingHashMap<int, DeleteNotifier> m;
-  auto a = m[0];
-  auto b = m[1];
+  auto factory = [](const int&) { return absl::make_unique<DeleteNotifier>(); };
+  auto a = m.GetOrCreateIfAbsent(0, factory);
+  auto b = m.GetOrCreateIfAbsent(1, factory);
   std::vector<int> seen_keys;
   std::vector<DeleteNotifier*> seen_values;
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index da50e92de32..12f7cd4654c 100755
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -4585,6 +4585,7 @@ cc_library(
         "//tensorflow/compiler/xla:executable_run_options",
         "//tensorflow/compiler/xla:statusor",
         "//tensorflow/compiler/xla/service:pattern_matcher",
+        "//tensorflow/compiler/xla/service/gpu:gpu_executable_run_options",
         "//tensorflow/core:lib",
        "//tensorflow/core:lib_internal",  # fixdeps: keep
        "//tensorflow/stream_executor/lib",
diff --git a/tensorflow/compiler/xla/service/collective_ops_utils.cc b/tensorflow/compiler/xla/service/collective_ops_utils.cc
index cfe586c6c0b..a4eba334f31 100644
--- a/tensorflow/compiler/xla/service/collective_ops_utils.cc
+++ b/tensorflow/compiler/xla/service/collective_ops_utils.cc
@@ -44,7 +44,7 @@ absl::optional<ReductionKind> MatchReductionComputation(
 }
 
 StatusOr<std::vector<int64>> GetParticipatingReplicas(
-    int64 device_ordinal, absl::Span<const ReplicaGroup> replica_groups,
+    GlobalDeviceId device_id, absl::Span<const ReplicaGroup> replica_groups,
     int64 total_replica_count, const DeviceAssignment& device_assn) {
   std::vector<int64> participating_replicas;
 
@@ -58,7 +58,7 @@ StatusOr<std::vector<int64>> GetParticipatingReplicas(
 
   // Use the DeviceAssignment to figure out our replica-id.
   TF_ASSIGN_OR_RETURN(int replica_id,
-                      device_assn.ReplicaIdForDeviceOrdinal(device_ordinal));
+                      device_assn.ReplicaIdForDeviceOrdinal(device_id.value()));
 
   // Figure out the other replicas that go together with this one.
   absl::optional<ReplicaGroup> replica_group;
diff --git a/tensorflow/compiler/xla/service/collective_ops_utils.h b/tensorflow/compiler/xla/service/collective_ops_utils.h
index 2524b4190e9..d9b6c48685b 100644
--- a/tensorflow/compiler/xla/service/collective_ops_utils.h
+++ b/tensorflow/compiler/xla/service/collective_ops_utils.h
@@ -21,6 +21,7 @@ limitations under the License.
 
 #include "tensorflow/compiler/xla/executable_run_options.h"
 #include "tensorflow/compiler/xla/service/computation_placer.h"
+#include "tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
 #include "tensorflow/compiler/xla/service/hlo_module.h"
 #include "tensorflow/compiler/xla/service/pattern_matcher.h"
@@ -37,9 +38,9 @@ absl::optional<ReductionKind> MatchReductionComputation(
     const HloComputation* computation);
 
 // Figures out which devices (named by their replica-ids) are participating in
-// the all-reduce subgroup that contains device_ordinal.
+// the all-reduce subgroup that contains device_id.
 StatusOr<std::vector<int64>> GetParticipatingReplicas(
-    int64 device_ordinal, absl::Span<const ReplicaGroup> replica_groups,
+    GlobalDeviceId device_id, absl::Span<const ReplicaGroup> replica_groups,
     int64 total_replica_count, const DeviceAssignment& device_assn);
 
 // Key that identifies a particular Rendezvous object in our global hashtable.
@@ -72,16 +73,18 @@ struct RendezvousKey {
   };
 
   explicit RendezvousKey(const RunId& run_id,
-                         std::vector<int64> participating_replicas,
+                         std::vector<GlobalDeviceId> global_devices,
+                         int num_local_participants,
                          CollectiveOpKind collective_op_kind, int64 op_id)
       : run_id(run_id),
-        participating_replicas(participating_replicas),
+        global_devices(std::move(global_devices)),
+        num_local_participants(num_local_participants),
         collective_op_kind(collective_op_kind),
         op_id(op_id) {}
 
   static RendezvousKey FromInstruction(
-      const RunId& run_id, std::vector<int64> participating_replicas,
-      const HloInstruction* instr) {
+      const RunId& run_id, std::vector<GlobalDeviceId> global_devices,
+      int num_local_participants, const HloInstruction* instr) {
     CollectiveOpKind collective_op_kind;
     int64 op_id;
 
@@ -91,20 +94,19 @@ struct RendezvousKey {
             : std::make_pair(
                   kCrossReplica,
                   static_cast<int64>(instr->GetModule()->unique_id()));
-    return RendezvousKey(run_id, participating_replicas, collective_op_kind,
-                         op_id);
+    return RendezvousKey(run_id, std::move(global_devices),
+                         num_local_participants, collective_op_kind, op_id);
   }
 
-  int num_participants() const { return participating_replicas.size(); }
-
   template <typename H>
   friend H AbslHashValue(H h, const RendezvousKey& k) {
-    return H::combine(std::move(h), k.run_id, k.participating_replicas,
+    return H::combine(std::move(h), k.run_id, k.global_devices,
+                      k.num_local_participants,
                       static_cast<int>(k.collective_op_kind), k.op_id);
   }
   friend bool operator==(const RendezvousKey& a, const RendezvousKey& b) {
-    return a.run_id == b.run_id &&
-           a.participating_replicas == b.participating_replicas &&
+    return a.run_id == b.run_id && a.global_devices == b.global_devices &&
+           a.num_local_participants == b.num_local_participants &&
            a.collective_op_kind == b.collective_op_kind &&  //
            a.op_id == b.op_id;
   }
@@ -114,14 +116,15 @@ struct RendezvousKey {
 
   string ToString() const {
     return absl::StrFormat(
-        "RendezvousKey{run_id=%s, participating_replicas=[%s], "
-        "collective_op_kind=%d, op_id=%d}",
-        run_id.ToString(), absl::StrJoin(participating_replicas, ","),
-        static_cast<int>(collective_op_kind), op_id);
+        "RendezvousKey{run_id=%s, global_devices=[%s], "
+        "num_local_participants=%d, collective_op_kind=%d, op_id=%d}",
+        run_id.ToString(), GlobalDeviceIdsToString(global_devices),
+        num_local_participants, static_cast<int>(collective_op_kind), op_id);
   }
 
   RunId run_id;
-  std::vector<int64> participating_replicas;
+  std::vector<GlobalDeviceId> global_devices;
+  int num_local_participants;
   CollectiveOpKind collective_op_kind;
   int64 op_id;
 };
@@ -164,10 +167,13 @@ struct AllReduceParticipantData {
   };
   std::vector<Buffer> buffers;
   se::Stream* stream;
+  const NcclUniqueIdCallback* nccl_unique_id_callback = nullptr;
 
   ReductionKind reduction_kind;
 
-  int num_participants() const { return rendezvous_key.num_participants(); }
+  // For each local all-reduce participant, a (global ID, local device ordinal)
+  // pair. Participants are in no particular order.
+  std::vector<std::pair<GlobalDeviceId, int64>> local_devices;
 
   string ToString() const {
     std::vector<string> buffer_strs;
@@ -303,12 +309,13 @@ class Rendezvous {
   const RendezvousKey key_;
 
   tensorflow::BlockingCounter all_participants_present_{
-      key_.num_participants()};
-  tensorflow::BlockingCounter done_{key_.num_participants()};
+      key_.num_local_participants};
+  tensorflow::BlockingCounter done_{key_.num_local_participants};
 
   // tensorflow::BlockingCounter returned by SubmitParticipant.
   std::shared_ptr<tensorflow::BlockingCounter> returned_blocking_counter_{
-      std::make_shared<tensorflow::BlockingCounter>(key_.num_participants())};
+      std::make_shared<tensorflow::BlockingCounter>(
+          key_.num_local_participants)};
 };
 
 }  // end namespace xla
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc
index 98c23b679fa..60e184411e9 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc
@@ -382,10 +382,7 @@ class CpuAllReduceRendezvous : public xla::Rendezvous {
 xla::RefcountingHashMap<xla::RendezvousKey, CpuAllReduceRendezvous>&
 GlobalRendezvousMap() {
   static auto& m =
-      *new xla::RefcountingHashMap<xla::RendezvousKey, CpuAllReduceRendezvous>(
-          [](const xla::RendezvousKey& k) {
-            return absl::make_unique<CpuAllReduceRendezvous>(k);
-          });
+      *new xla::RefcountingHashMap<xla::RendezvousKey, CpuAllReduceRendezvous>;
   return m;
 }
 
@@ -411,18 +408,28 @@ TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_AllReduce(
   std::vector<xla::ReplicaGroup> group =
       xla::ParseReplicaGroupsOnly(replica_groups_serialized).ValueOrDie();
-  xla::int32 replica_count = run_options->device_assignment()->replica_count();
-  std::vector<xla::int64> participating_replicas_vec =
-      xla::GetParticipatingReplicas(device_ordinal, group, replica_count,
+  const xla::DeviceAssignment& device_assignment =
+      *run_options->device_assignment();
+  xla::int32 replica_count = device_assignment.replica_count();
+  CHECK_EQ(device_assignment.computation_count(), 1);
+  std::vector<xla::int64> participating_replicas =
+      xla::GetParticipatingReplicas(xla::GlobalDeviceId(device_ordinal), group,
+                                    replica_count,
                                     *run_options->device_assignment())
           .ValueOrDie();
 
   xla::RendezvousKey::CollectiveOpKind op_kind =
       channel_id_present ? xla::RendezvousKey::kCrossModule
                          : xla::RendezvousKey::kCrossReplica;
-  xla::RendezvousKey rendezvous_key(run_options->run_id(),
-                                    participating_replicas_vec, op_kind, op_id);
-
+  std::vector<xla::GlobalDeviceId> participating_devices;
+  participating_devices.reserve(participating_replicas.size());
+  for (xla::int64 replica : participating_replicas) {
+    participating_devices.push_back(
+        xla::GlobalDeviceId(device_assignment(replica, 0)));
+  }
+  xla::RendezvousKey rendezvous_key(
+      run_options->run_id(), std::move(participating_devices),
+      participating_replicas.size(), op_kind, op_id);
 
   auto shape_str = ShapeString(shape_ptr, shape_length);
   VLOG(2) << "All-reduce input/output shape : " << shape_str;
@@ -444,10 +451,17 @@ TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_AllReduce(
   participant.buffers = {buffer};
   participant.reduction_kind = static_cast<xla::ReductionKind>(reduction_kind);
 
-  TF_CHECK_OK(
-      CpuAllReduceRendezvous::SubmitParticipant(
-          [&] { return GlobalRendezvousMap()[rendezvous_key]; }, participant)
-          .status());
+  auto make_cpu_rendezvous = [](const xla::RendezvousKey& k) {
+    return absl::make_unique<CpuAllReduceRendezvous>(k);
+  };
+
+  TF_CHECK_OK(CpuAllReduceRendezvous::SubmitParticipant(
+                  [&] {
+                    return GlobalRendezvousMap().GetOrCreateIfAbsent(
+                        rendezvous_key, make_cpu_rendezvous);
+                  },
+                  participant)
+                  .status());
 }
 
 TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_ReplicaId(
diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD
index d13eca30cdc..86da500b1dd 100755
--- a/tensorflow/compiler/xla/service/gpu/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/BUILD
@@ -54,6 +54,22 @@ tf_proto_library_cc(
     protodeps = ["//tensorflow/compiler/xla:xla_data_proto"],
 )
 
+cc_library(
+    name = "gpu_executable_run_options",
+    srcs = ["gpu_executable_run_options.cc"],
+    hdrs = ["gpu_executable_run_options.h"],
+    visibility = ["//visibility:public"],
+    deps = [
+        "//tensorflow/compiler/xla:statusor",
+        "//tensorflow/compiler/xla:types",
+        "//tensorflow/core:lib_internal",
+        "@com_google_absl//absl/algorithm:container",
+        "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/types:optional",
+        "@com_google_absl//absl/types:span",
+    ],
+)
+
 cc_library(
     name = "gpu_constants",
     srcs = ["gpu_constants.cc"],
@@ -385,6 +401,7 @@ cc_library(
     hdrs = ["thunk.h"],
     deps = [
         ":buffer_allocations",
+        ":gpu_executable_run_options",
         ":hlo_execution_profiler",
         "//tensorflow/compiler/xla:executable_run_options",
         "//tensorflow/compiler/xla/service:hlo",
@@ -413,6 +430,7 @@ tf_cuda_library(
         ":buffer_allocations",
         ":hlo_execution_profiler",
         ":thunk",
+        ":gpu_executable_run_options",
        "@com_google_absl//absl/base:core_headers",
        "//tensorflow/compiler/xla/service:pattern_matcher",
        "//tensorflow/compiler/xla:refcounting_hash_map",
@@ -522,6 +540,7 @@ cc_library(
         ":cudnn_batchnorm_runner",
         ":gpu_conv_runner",
         ":gpu_debug_info_manager",
+        ":gpu_executable_run_options",
         ":gpu_types",
         ":hlo_execution_profiler",
         ":infeed_manager",
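
A sketch only, not part of this change, illustrating the RendezvousKey and
DeviceAssignment changes above (collective_ops_utils.h and cpu_runtime.cc):
each participating replica is mapped to a global device through the
single-computation DeviceAssignment, and the key names all global devices
while counting only the locally present participants. MakeExampleRendezvousKey
is a hypothetical helper.

    #include <utility>
    #include <vector>

    #include "tensorflow/compiler/xla/executable_run_options.h"
    #include "tensorflow/compiler/xla/service/collective_ops_utils.h"
    #include "tensorflow/compiler/xla/service/computation_placer.h"

    xla::RendezvousKey MakeExampleRendezvousKey(
        const xla::RunId& run_id,
        const xla::DeviceAssignment& device_assignment,
        const std::vector<xla::int64>& participating_replicas,
        int num_local_participants, xla::int64 op_id) {
      std::vector<xla::GlobalDeviceId> global_devices;
      global_devices.reserve(participating_replicas.size());
      for (xla::int64 replica : participating_replicas) {
        // Replica r runs on global device (r, computation 0).
        global_devices.push_back(
            xla::GlobalDeviceId(device_assignment(replica, /*computation=*/0)));
      }
      return xla::RendezvousKey(run_id, std::move(global_devices),
                                num_local_participants,
                                xla::RendezvousKey::kCrossReplica, op_id);
    }
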
diff --git a/tensorflow/compiler/xla/service/gpu/collective_permute_thunk.cc b/tensorflow/compiler/xla/service/gpu/collective_permute_thunk.cc
index 2fe359861f8..2a071cd658d 100644
--- a/tensorflow/compiler/xla/service/gpu/collective_permute_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/collective_permute_thunk.cc
@@ -211,10 +211,7 @@ StatusOr> Rendezvous::SubmitParticipant(
 // Rendezvous objects are one-time use, so they're removed from this map once
 // we're through with them.
 RefcountingHashMap<RendezvousKey, Rendezvous>& GlobalRendezvousMap() {
-  static auto& m = *new RefcountingHashMap<RendezvousKey, Rendezvous>(
-      [](const RendezvousKey& key) {
-        return absl::make_unique<Rendezvous>(key);
-      });
+  static auto& m = *new RefcountingHashMap<RendezvousKey, Rendezvous>();
   return m;
 }
 
@@ -233,7 +230,11 @@ Status CollectivePermuteThunk::ExecuteOnStream(const ExecuteParams& params) {
   // Rendezvous with the threads for all other devices that are participating in
   // this CollectivePermute.
   RendezvousKey key{params.run_id, params.device_assn->replica_count()};
-  std::shared_ptr<Rendezvous> rendezvous = GlobalRendezvousMap()[key];
+  auto rendezvous_factory = [](const RendezvousKey& key) {
+    return absl::make_unique<Rendezvous>(key);
+  };
+  std::shared_ptr<Rendezvous> rendezvous =
+      GlobalRendezvousMap().GetOrCreateIfAbsent(key, rendezvous_factory);
 
   TF_ASSIGN_OR_RETURN(int64 replica_id,
                       params.device_assn->ReplicaIdForDeviceOrdinal(
diff --git a/tensorflow/compiler/xla/service/gpu/dummy_all_reduce_thunk.cc b/tensorflow/compiler/xla/service/gpu/dummy_all_reduce_thunk.cc
index 7c3d76c1c92..998a3ccb4ee 100644
--- a/tensorflow/compiler/xla/service/gpu/dummy_all_reduce_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/dummy_all_reduce_thunk.cc
@@ -34,7 +34,7 @@ Status NcclAllReduceThunk::ExecuteOnStream(const ExecuteParams& params) {
 
 NcclAllReduceThunk::~NcclAllReduceThunk() = default;
 
-/*static*/ absl::flat_hash_set<int>
+/*static*/ absl::flat_hash_set<GlobalDeviceId>
 NcclAllReduceThunk::DevicesWithOpenNcclChannels() {
   return {};
 }
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc
index df44f379b99..d4797e094fd 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc
@@ -24,6 +24,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/map_util.h"
 #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
 #include "tensorflow/compiler/xla/service/gpu/gpu_debug_info_manager.h"
+#include "tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h"
 #include "tensorflow/compiler/xla/service/gpu/gpu_types.h"
 #include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -196,13 +197,21 @@ Status GpuExecutable::ExecuteThunks(
 
     VLOG(2) << "Executing the thunk for "
             << thunk->hlo_instruction()->ToString() << " on stream "
             << stream_no;
+    const GpuExecutableRunOptions* gpu_options =
+        run_options->run_options().gpu_executable_run_options();
     Thunk::ExecuteParams thunk_params{
         &buffer_allocations,
         stream,
         run_options->run_options().run_id(),
         &profiler,
         run_options->run_options().device_assignment(),
-        &deferred_host_callbacks};
+        &deferred_host_callbacks,
+        gpu_options && gpu_options->gpu_global_device_ids()
+            ? &*gpu_options->gpu_global_device_ids()
+            : nullptr,
+        gpu_options && gpu_options->nccl_unique_id_callback()
+            ? &gpu_options->nccl_unique_id_callback()
+            : nullptr};
     TF_RETURN_IF_ERROR(thunk->ExecuteOnStream(thunk_params));
     if (thunk_schedule_->Depended(thunk)) {
       auto finish_event = absl::make_unique<se::Event>(main_stream->parent());
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.cc b/tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.cc
new file mode 100644
index 00000000000..b152962eb99
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.cc
@@ -0,0 +1,62 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h"
+
+#include "absl/algorithm/container.h"
+#include "absl/strings/str_join.h"
+
+namespace xla {
+
+std::string GlobalDeviceIdsToString(absl::Span<GlobalDeviceId const> ids) {
+  std::vector<int64> values;
+  values.reserve(ids.size());
+  for (GlobalDeviceId id : ids) {
+    values.push_back(id.value());
+  }
+  return absl::StrJoin(values, ",");
+}
+
+NcclCliqueKey::NcclCliqueKey(std::vector<GlobalDeviceId> devices)
+    : devices_(std::move(devices)) {
+  absl::c_sort(devices_);
+  CHECK(absl::c_adjacent_find(devices_) == devices_.end())
+      << "Duplicate devices are not allowed: "
+      << GlobalDeviceIdsToString(devices_);
+}
+
+GpuExecutableRunOptions& GpuExecutableRunOptions::set_gpu_global_device_ids(
+    absl::optional<std::vector<GlobalDeviceId>> gpu_global_device_ids) {
+  gpu_global_device_ids_ = std::move(gpu_global_device_ids);
+  return *this;
+}
+
+const absl::optional<std::vector<GlobalDeviceId>>&
+GpuExecutableRunOptions::gpu_global_device_ids() const {
+  return gpu_global_device_ids_;
+}
+
+GpuExecutableRunOptions& GpuExecutableRunOptions::set_nccl_unique_id_callback(
+    NcclUniqueIdCallback nccl_unique_id_callback) {
+  nccl_unique_id_callback_ = std::move(nccl_unique_id_callback);
+  return *this;
+}
+
+const NcclUniqueIdCallback& GpuExecutableRunOptions::nccl_unique_id_callback()
+    const {
+  return nccl_unique_id_callback_;
+}
+
+}  // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h b/tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h
new file mode 100644
index 00000000000..7a43c80121b
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h
@@ -0,0 +1,90 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_EXECUTABLE_RUN_OPTIONS_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_EXECUTABLE_RUN_OPTIONS_H_
+
+#include <functional>
+#include <string>
+#include <vector>
+
+#include "absl/types/optional.h"
+#include "absl/types/span.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/core/lib/gtl/int_type.h"
+
+namespace xla {
+
+// Strongly-typed integer type for naming a device globally within a
+// distributed system. XLA doesn't have a strong opinion about what global
+// numbering scheme is applied to GPUs; the user must provide a local ->
+// global mapping via GpuExecutableRunOptions for the local GPUs.
+TF_LIB_GTL_DEFINE_INT_TYPE(GlobalDeviceId, int64);
+
+// Returns a comma-separated string of global device IDs.
+std::string GlobalDeviceIdsToString(absl::Span<GlobalDeviceId const> ids);
+
+// Key for naming a particular NCCL clique. This is just a set of unique
+// device IDs (i.e. GPU IDs). The device IDs must be global within a cluster.
+class NcclCliqueKey {
+ public:
+  explicit NcclCliqueKey(std::vector<GlobalDeviceId> devices);
+
+  template <typename H>
+  friend H AbslHashValue(H h, const NcclCliqueKey& k) {
+    return H::combine(std::move(h), k.devices_);
+  }
+  friend bool operator==(const NcclCliqueKey& a, const NcclCliqueKey& b) {
+    return a.devices_ == b.devices_;
+  }
+
+  const std::vector<GlobalDeviceId>& devices() const { return devices_; }
+
+ private:
+  std::vector<GlobalDeviceId> devices_;
+};
+
+using NcclUniqueIdCallback =
+    std::function<StatusOr<std::string>(const NcclCliqueKey&)>;
+
+// GPU-specific executable options.
+// We keep these separate from ExecutableRunOptions to avoid adding
+// dependencies to ExecutableRunOptions.
+class GpuExecutableRunOptions {
+ public:
+  // Sets a mapping from local device ordinals to global device IDs.
+  // Used only on NVidia GPUs for cross-host NCCL collectives. If set, the
+  // elements of `device_assignment` are interpreted as global device IDs, not
+  // local device ordinals.
+  GpuExecutableRunOptions& set_gpu_global_device_ids(
+      absl::optional<std::vector<GlobalDeviceId>> gpu_global_device_ids);
+  const absl::optional<std::vector<GlobalDeviceId>>& gpu_global_device_ids()
+      const;
+
+  // Callback that returns an ncclUniqueId encoded as a string for a group of
+  // communicating GPU devices. Used only on NVidia GPUs.
+  GpuExecutableRunOptions& set_nccl_unique_id_callback(
+      NcclUniqueIdCallback nccl_unique_id_callback);
+  const NcclUniqueIdCallback& nccl_unique_id_callback() const;
+
+ private:
+  absl::optional<std::vector<GlobalDeviceId>> gpu_global_device_ids_;
+  NcclUniqueIdCallback nccl_unique_id_callback_;
+};
+
+}  // namespace xla
+
+#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_EXECUTABLE_RUN_OPTIONS_H_
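
A sketch only, not part of this change: one possible shape for a client's
nccl_unique_id_callback. MyDistributedKeyValueStore and
MakeNcclUniqueIdCallback are hypothetical stand-ins for whatever out-of-band
channel the client already has (MPI, etcd, a gRPC service, ...); only the XLA
and NCCL symbols are real. A real client would also decide per clique which
process generates the id (e.g., the owner of the smallest device ID) rather
than using a single flag.

    #include <string>

    #include "third_party/nccl/nccl.h"  // assumed include path for nccl.h
    #include "tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h"
    #include "tensorflow/compiler/xla/statusor.h"
    #include "tensorflow/compiler/xla/util.h"

    struct MyDistributedKeyValueStore {  // hypothetical client-side store
      virtual ~MyDistributedKeyValueStore() = default;
      virtual void Put(const std::string& key, const std::string& value) = 0;
      virtual std::string BlockingGet(const std::string& key) = 0;
    };

    xla::NcclUniqueIdCallback MakeNcclUniqueIdCallback(
        MyDistributedKeyValueStore* store, bool generates_ids) {
      return [store, generates_ids](
                 const xla::NcclCliqueKey& key) -> xla::StatusOr<std::string> {
        std::string kv_key =
            "nccl_clique:" + xla::GlobalDeviceIdsToString(key.devices());
        if (generates_ids) {
          // One process per clique creates the id and publishes it...
          ncclUniqueId id;
          if (ncclGetUniqueId(&id) != ncclSuccess) {
            return xla::InternalError("ncclGetUniqueId failed");
          }
          std::string serialized(id.internal, NCCL_UNIQUE_ID_BYTES);
          store->Put(kv_key, serialized);
          return serialized;
        }
        // ...and every other process blocks until it can fetch it.
        return store->BlockingGet(kv_key);
      };
    }
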
diff --git a/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc b/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc
index 4498793113a..719429771af 100644
--- a/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc
@@ -33,6 +33,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/layout_util.h"
 #include "tensorflow/compiler/xla/refcounting_hash_map.h"
 #include "tensorflow/compiler/xla/service/collective_ops_utils.h"
+#include "tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h"
 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
 #include "tensorflow/compiler/xla/util.h"
@@ -179,27 +180,18 @@ absl::optional<ncclDataType_t> DatatypeToNccl(PrimitiveType element_type) {
   }
 }
 
-// Key for looking up a particular NCCL clique. This is just a set of unique
-// device ordinals (i.e. GPU IDs).
-struct NcclCliqueKey {
-  explicit NcclCliqueKey(absl::Span<const int64> devices)
-      : devices(devices.begin(), devices.end()) {
-    absl::c_sort(this->devices);
-    CHECK(absl::c_adjacent_find(devices) == devices.end())
-        << "Duplicate devices are not allowed: "
-        << absl::StrJoin(devices, ", ");
+Status StringToNcclUniqueId(const std::string& str_id, ncclUniqueId* nccl_id) {
+  if (str_id.size() != NCCL_UNIQUE_ID_BYTES) {
+    return InvalidArgument(
+        "ncclUniqueId string must have %d bytes, got %d bytes",
+        NCCL_UNIQUE_ID_BYTES, str_id.size());
   }
-
-  template <typename H>
-  friend H AbslHashValue(H h, const NcclCliqueKey& k) {
-    return H::combine(std::move(h), k.devices);
-  }
-  friend bool operator==(const NcclCliqueKey& a, const NcclCliqueKey& b) {
-    return a.devices == b.devices;
-  }
-
-  std::vector<int64> devices;
-};
+  // NcclUniqueId is internally just a char[].
+  static_assert(sizeof(ncclUniqueId) == NCCL_UNIQUE_ID_BYTES,
+                "NCCL_UNIQUE_ID_BYTES");
+  std::memcpy(static_cast<void*>(nccl_id), str_id.data(), NCCL_UNIQUE_ID_BYTES);
+  return Status::OK();
+}
 
 // Owns a clique of NCCL comms which can be used for collective operations among
 // a particular set of GPUs.
@@ -216,20 +208,29 @@ struct NcclCliqueKey {
 // GPUs, you'll need a different clique.
 class NcclClique {
  public:
-  explicit NcclClique(absl::Span<const int64> devices)
-      : devices_(devices.begin(), devices.end()) {
-    absl::c_sort(devices_);
-    status_ = Init();
+  explicit NcclClique(
+      int64 num_global_devices, std::vector<int64> local_device_ordinals,
+      std::vector<int64> local_device_ranks,
+      const StatusOr<absl::optional<std::string>>& nccl_unique_id)
+      : num_global_devices_(num_global_devices),
+        local_device_ordinals_(std::move(local_device_ordinals)),
+        local_device_ranks_(std::move(local_device_ranks)) {
+    CHECK_EQ(local_device_ordinals_.size(), local_device_ranks_.size());
+    // It's unusual to pass a StatusOr<> into a class, but since this class
+    // already has an erroneous state, it turns out to be a little easier to
+    // implement this way than to change RefcountingHashMap.
+    status_ = Init(nccl_unique_id);
   }
 
   Status status() { return status_; }
 
-  absl::Span<const int64> devices() {
-    TF_CHECK_OK(status_);
-    return devices_;
-  }
-  ncclComm_t comm(int64 device) {
-    int64 idx = std::distance(devices_.begin(), absl::c_find(devices_, device));
+  // An NCCL communicator is the NCCL state associated with a participant
+  // (rank) in a reduction. This method returns the state associated with a
+  // particular local device ordinal.
+  ncclComm_t comm(int64 device_ordinal) {
+    int64 idx =
+        std::distance(local_device_ordinals_.begin(),
+                      absl::c_find(local_device_ordinals_, device_ordinal));
     return comms_.at(idx).comm();
   }
 
@@ -249,10 +250,12 @@ class NcclClique {
   }
 
  private:
-  Status Init() {
+  Status Init(
+      const StatusOr<absl::optional<std::string>>& maybe_nccl_unique_id) {
     VLOG(3) << absl::StreamFormat(
-        "Initializing nccl comms for participant devices {%s}",
-        absl::StrJoin(devices_, ", "));
+        "Initializing nccl comms for participant device ordinals %s ranks {%s}",
+        absl::StrJoin(local_device_ordinals_, ", "),
+        absl::StrJoin(local_device_ranks_, ", "));
 
     // Restore CUDA device after running this.  XLA shouldn't care, but maybe
     // another consumer does.
@@ -264,15 +267,23 @@ class NcclClique {
     // When using ncclGroupStart/End it seems that the ncclComm_t's are not
     // populated until the End() call.  This unfortunately makes error handling
     // tricky.
-    std::vector<ncclComm_t> raw_comms(devices_.size(), nullptr);
+    std::vector<ncclComm_t> raw_comms(local_device_ordinals_.size(), nullptr);
+    TF_ASSIGN_OR_RETURN(const absl::optional<std::string>& nccl_id_string,
+                        maybe_nccl_unique_id);
+
     ncclUniqueId nccl_id;
-    XLA_CUDA_RETURN_IF_ERROR(ncclGetUniqueId(&nccl_id));
+    if (nccl_id_string) {
+      TF_RETURN_IF_ERROR(StringToNcclUniqueId(*nccl_id_string, &nccl_id));
+    } else {
+      XLA_CUDA_RETURN_IF_ERROR(ncclGetUniqueId(&nccl_id));
+    }
     XLA_CUDA_RETURN_IF_ERROR(ncclGroupStart());
     Status status = [&] {
-      for (int i = 0; i < devices_.size(); ++i) {
-        XLA_CUDA_RETURN_IF_ERROR(cudaSetDevice(devices_[i]));
-        XLA_CUDA_RETURN_IF_ERROR(
-            ncclCommInitRank(&raw_comms[i], devices_.size(), nccl_id, i));
+      for (int i = 0; i < local_device_ordinals_.size(); ++i) {
+        XLA_CUDA_RETURN_IF_ERROR(cudaSetDevice(local_device_ordinals_[i]));
+        XLA_CUDA_RETURN_IF_ERROR(ncclCommInitRank(&raw_comms[i],
+                                                  num_global_devices_, nccl_id,
+                                                  local_device_ranks_.at(i)));
       }
       return Status::OK();
     }();
@@ -282,9 +293,9 @@ class NcclClique {
     // Populate comms_ from the raw comms we created above.  If we encountered
     // an error above we'll later clear comms_ thus destroying any raw comms
    // that were created before the error.
-    for (int i = 0; i < devices_.size(); ++i) {
-      VLOG(3) << absl::StreamFormat("Device %d assigned ncclComm %p",
-                                    devices_[i], raw_comms[i]);
+    for (int i = 0; i < local_device_ordinals_.size(); ++i) {
+      VLOG(3) << absl::StreamFormat("Device ordinal %d assigned ncclComm %p",
+                                    local_device_ordinals_[i], raw_comms[i]);
       CHECK(raw_comms[i] != nullptr || !status.ok());
       comms_.emplace_back(raw_comms[i]);
     }
@@ -296,7 +307,11 @@ class NcclClique {
   }
 
   Status status_;
-  std::vector<int64> devices_;
+  int64 num_global_devices_;
+  std::vector<int64> local_device_ordinals_;
+  // NCCL communicator rank for each local device. The rank of a device is
+  // equal to the offset of the local device in the global device set.
+  std::vector<int64> local_device_ranks_;
   std::vector comms_;
 
   // This mutex is in a unique_ptr so NcclClique can be movable.
@@ -312,10 +327,7 @@ class NcclClique {
 // have one clique alive for a given set of GPUs. This means that a process
 // will never do two collective operations concurrently on the same set of GPUs.
 RefcountingHashMap<NcclCliqueKey, NcclClique>& GlobalNcclCliqueMap() {
-  static auto& m = *new RefcountingHashMap<NcclCliqueKey, NcclClique>(
-      [](const NcclCliqueKey& key) {
-        return absl::make_unique<NcclClique>(key.devices);
-      });
+  static auto& m = *new RefcountingHashMap<NcclCliqueKey, NcclClique>();
   return m;
 }
 
@@ -341,10 +353,7 @@ class RendezvousNcclAllReduce : public Rendezvous<std::shared_ptr<NcclClique>> {
 RefcountingHashMap<RendezvousKey, RendezvousNcclAllReduce>&
 GlobalRendezvousMap() {
   static auto& m =
-      *new RefcountingHashMap<RendezvousKey, RendezvousNcclAllReduce>(
-          [](const RendezvousKey& k) {
-            return absl::make_unique<RendezvousNcclAllReduce>(k);
-          });
+      *new RefcountingHashMap<RendezvousKey, RendezvousNcclAllReduce>();
   return m;
 }
 
@@ -365,12 +374,46 @@ RendezvousNcclAllReduce::SubmitParticipantImpl(
     // ensuring that there's a NCCL clique available for us to use.
     primary = !initialized_;
 
+    TF_RET_CHECK(participant.local_devices.size() ==
+                 participant.rendezvous_key.num_local_participants);
+
     // Look up or create the NCCL clique for this set of devices.
-    std::vector<int64> devices;
-    for (const auto& p : participants_) {
-      devices.push_back(p.device_ordinal);
-    }
-    clique = GlobalNcclCliqueMap()[NcclCliqueKey(devices)];
+    NcclCliqueKey clique_key(participant.rendezvous_key.global_devices);
+
+    auto clique_factory =
+        [&](const NcclCliqueKey& key) -> std::unique_ptr<NcclClique> {
+      std::vector<int64> local_device_ranks;
+      std::vector<int64> local_device_ordinals;
+      local_device_ranks.reserve(participant.local_devices.size());
+      local_device_ordinals.reserve(participant.local_devices.size());
+      for (const auto& l : participant.local_devices) {
+        auto it =
+            absl::c_find(participant.rendezvous_key.global_devices, l.first);
+        CHECK(it != participant.rendezvous_key.global_devices.end()) << l.first;
+        local_device_ranks.push_back(std::distance(
+            participant.rendezvous_key.global_devices.begin(), it));
+        local_device_ordinals.push_back(l.second);
+      }
+      StatusOr<absl::optional<std::string>> nccl_unique_id;
+      if (participant.nccl_unique_id_callback) {
+        nccl_unique_id = (*participant.nccl_unique_id_callback)(clique_key);
+      } else {
+        if (participant.rendezvous_key.global_devices.size() !=
+            participant.rendezvous_key.num_local_participants) {
+          nccl_unique_id = InvalidArgument(
+              "Multihost AllReduce on GPU requires a nccl_unique_id_callback "
+              "to be provided by the client.");
+        } else {
+          nccl_unique_id = absl::optional<std::string>();
+        }
+      }
+      return absl::make_unique<NcclClique>(
+          participant.rendezvous_key.global_devices.size(),
+          std::move(local_device_ordinals), std::move(local_device_ranks),
+          nccl_unique_id);
+    };
+    clique =
+        GlobalNcclCliqueMap().GetOrCreateIfAbsent(clique_key, clique_factory);
 
     if (primary) {
       VLOG(3) << "Primary initializing accounting data.";
@@ -463,12 +506,12 @@ struct NcclAllReduceThunk::AuxData {
          crs->IsCrossReplicaAllReduce() && operands_are_supported();
 }
 
-/*static*/ absl::flat_hash_set<int>
+/*static*/ absl::flat_hash_set<GlobalDeviceId>
 NcclAllReduceThunk::DevicesWithOpenNcclChannels() {
-  absl::flat_hash_set<int> devices;
+  absl::flat_hash_set<GlobalDeviceId> devices;
   GlobalNcclCliqueMap().ForEach(
       [&](const NcclCliqueKey& k, const std::shared_ptr<NcclClique>&) {
-        devices.insert(k.devices.begin(), k.devices.end());
+        devices.insert(k.devices().begin(), k.devices().end());
       });
   return devices;
 }
@@ -491,23 +534,57 @@ Status NcclAllReduceThunk::ExecuteOnStream(const ExecuteParams& params) {
       params.profiler->MakeScopedInstructionProfiler(hlo_instruction());
 
   auto* instr = Cast<HloAllReduceInstruction>(hlo_instruction());
-  int64 device_ordinal = params.stream->parent()->device_ordinal();
+  int64 local_device_ordinal = params.stream->parent()->device_ordinal();
+  GlobalDeviceId global_device_id;
+  if (params.gpu_global_device_ids) {
+    TF_RET_CHECK(0 <= local_device_ordinal &&
+                 local_device_ordinal < params.gpu_global_device_ids->size());
+    global_device_id = (*params.gpu_global_device_ids)[local_device_ordinal];
+  } else {
+    // No local -> global mapping was provided; assume the identity mapping.
+    global_device_id = GlobalDeviceId(local_device_ordinal);
+  }
 
+  // Determines the set of global and local devices that are participating in
+  // the same collective group as the caller.
   TF_ASSIGN_OR_RETURN(
-      std::vector<int64> participating_replicas,
-      GetParticipatingReplicas(device_ordinal, instr->replica_groups(),
+      std::vector<int64> global_participating_replicas,
+      GetParticipatingReplicas(global_device_id, instr->replica_groups(),
                                replica_count_, *params.device_assn));
+  std::vector<GlobalDeviceId> global_devices;
+  std::vector<std::pair<GlobalDeviceId, int64>> local_devices;
+  local_devices.reserve(global_participating_replicas.size());
+  global_devices.reserve(global_participating_replicas.size());
+  TF_RET_CHECK(params.device_assn->computation_count() == 1)
+      << params.device_assn->ToString();
+  for (int64 replica : global_participating_replicas) {
+    GlobalDeviceId global_device(
+        (*params.device_assn)(replica, /*computation=*/0));
+    global_devices.push_back(global_device);
+    if (!params.gpu_global_device_ids) {
+      local_devices.emplace_back(global_device, global_device.value());
+    } else {
+      auto it = absl::c_find(*params.gpu_global_device_ids, global_device);
+      if (it != params.gpu_global_device_ids->end()) {
+        local_devices.emplace_back(
+            *it, std::distance(params.gpu_global_device_ids->begin(), it));
+      }
+    }
+  }
+  absl::c_sort(global_devices);
 
   // Find or create the rendezvous for this collective operation.
   RendezvousKey rendezvous_key = RendezvousKey::FromInstruction(
-      params.run_id, participating_replicas, hlo_instruction());
+      params.run_id, global_devices, local_devices.size(), hlo_instruction());
 
   VLOG(2) << "Rendezvous key: " << rendezvous_key.ToString()
-          << ", participating replicas: "
-          << absl::StrJoin(participating_replicas, ", ");
+          << ", global participating replicas: "
+          << absl::StrJoin(global_participating_replicas, ", ")
+          << ", global participating devices: "
+          << GlobalDeviceIdsToString(global_devices);
 
   AllReduceParticipantData participant(rendezvous_key);
-  participant.device_ordinal = device_ordinal;
+  participant.device_ordinal = local_device_ordinal;
   for (size_t i = 0; i < buffers_.size(); ++i) {
     const NcclAllReduceThunk::Buffer& buffer = buffers_[i];
     AllReduceParticipantData::Buffer pbuffer;
@@ -521,15 +598,24 @@ Status NcclAllReduceThunk::ExecuteOnStream(const ExecuteParams& params) {
     participant.buffers.push_back(pbuffer);
   }
   participant.stream = params.stream;
+  participant.local_devices = std::move(local_devices);
+  participant.nccl_unique_id_callback = params.nccl_unique_id_callback;
   auto reduction_kind =
       MatchReductionComputation(hlo_instruction()->to_apply());
   CHECK(reduction_kind.has_value());
   participant.reduction_kind = *reduction_kind;
 
-  TF_ASSIGN_OR_RETURN(
-      std::shared_ptr<NcclClique> clique,
-      RendezvousNcclAllReduce::SubmitParticipant(
-          [&] { return GlobalRendezvousMap()[rendezvous_key]; }, participant));
+  auto rendezvous_factory = [](const RendezvousKey& k) {
+    return absl::make_unique<RendezvousNcclAllReduce>(k);
+  };
+
+  TF_ASSIGN_OR_RETURN(std::shared_ptr<NcclClique> clique,
+                      RendezvousNcclAllReduce::SubmitParticipant(
+                          [&] {
+                            return GlobalRendezvousMap().GetOrCreateIfAbsent(
+                                rendezvous_key, rendezvous_factory);
+                          },
+                          participant));
 
   // Keep the clique we used alive for as long as this Thunk lives.  Creating
   // new NCCL cliques is expensive, and this is how we avoid thrashing them.
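
A sketch only, not part of this change: how local NCCL ranks fall out of the
sorted global device list built above. For a host that owns global devices
{2, 3} inside the group {0, 1, 2, 3}, the clique is initialized with
num_global_devices = 4 and this host's communicators get ranks 2 and 3; the
rank of each local device is simply the index of its global ID within the
whole group. LocalNcclRanks is a hypothetical helper that mirrors the
clique_factory logic.

    #include <vector>

    #include "absl/algorithm/container.h"
    #include "tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h"
    #include "tensorflow/compiler/xla/types.h"

    std::vector<xla::int64> LocalNcclRanks(
        const std::vector<xla::GlobalDeviceId>& global_devices,  // sorted group
        const std::vector<xla::GlobalDeviceId>& local_devices) { // this host's
      std::vector<xla::int64> ranks;
      ranks.reserve(local_devices.size());
      for (xla::GlobalDeviceId id : local_devices) {
        auto it = absl::c_find(global_devices, id);
        // Rank == position of the device's global ID within the whole group.
        ranks.push_back(std::distance(global_devices.begin(), it));
      }
      return ranks;  // {2, 3} for the example above
    }
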
diff --git a/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h b/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h
index 7633a99794f..90091ed2c7b 100644
--- a/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h
+++ b/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h
@@ -19,6 +19,7 @@ limitations under the License.
 #include "absl/container/flat_hash_set.h"
 #include "tensorflow/compiler/xla/service/buffer_assignment.h"
 #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
+#include "tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h"
 #include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
 #include "tensorflow/compiler/xla/service/gpu/thunk.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -46,7 +47,7 @@ class NcclAllReduceThunk : public Thunk {
   // (Indeed, because the NCCL channels are a global variable, in the real
   // world, the value returned here is stale as soon as you read it, so it's not
   // clear how you *could* use it for anything other than tests.)
-  static absl::flat_hash_set<int> DevicesWithOpenNcclChannels();
+  static absl::flat_hash_set<GlobalDeviceId> DevicesWithOpenNcclChannels();
 
   // TODO(b/125951860): Support all-reduces with replica groups, i.e.
   // all-reduces that compute multiple sums across subsets of all replicas.
diff --git a/tensorflow/compiler/xla/service/gpu/thunk.h b/tensorflow/compiler/xla/service/gpu/thunk.h
index abf829cee00..326c5a20716 100644
--- a/tensorflow/compiler/xla/service/gpu/thunk.h
+++ b/tensorflow/compiler/xla/service/gpu/thunk.h
@@ -21,6 +21,7 @@ limitations under the License.
 
 #include "tensorflow/compiler/xla/executable_run_options.h"
 #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
+#include "tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h"
 #include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
 #include "tensorflow/core/lib/core/status.h"
@@ -98,6 +99,8 @@ class Thunk {
     HloExecutionProfiler* profiler;                             // never null
     const DeviceAssignment* device_assn;                        // never null
     std::vector<std::function<void()>>* deferred_host_callbacks;  // never null
+    const std::vector<GlobalDeviceId>* gpu_global_device_ids;   // may be null
+    const NcclUniqueIdCallback* nccl_unique_id_callback;        // may be null
   };
 
   // Execute the kernel for the thunk on the given stream. This method must be
diff --git a/tensorflow/compiler/xla/tests/collective_ops_test.cc b/tensorflow/compiler/xla/tests/collective_ops_test.cc
index 5cdf9633ca4..464865506f7 100644
--- a/tensorflow/compiler/xla/tests/collective_ops_test.cc
+++ b/tensorflow/compiler/xla/tests/collective_ops_test.cc
@@ -159,7 +159,7 @@ DeviceAssignment MakeDeviceAssn(std::vector<int64> devices) {
 }
 
 // Shorter alias for this function.
-absl::flat_hash_set<int> OpenNcclChannels() {
+absl::flat_hash_set<GlobalDeviceId> OpenNcclChannels() {
   return gpu::NcclAllReduceThunk::DevicesWithOpenNcclChannels();
 }