diff --git a/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc b/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc
index 3051db3af4a..c00edae9540 100644
--- a/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc
@@ -18,6 +18,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/util.h"
 
 #if GOOGLE_CUDA
+#include "absl/container/flat_hash_set.h"
 #include "absl/synchronization/blocking_counter.h"
 #include "third_party/nccl/nccl.h"
 #include "tensorflow/core/lib/core/blocking_counter.h"
@@ -76,6 +77,42 @@ struct ParticipantData {
 // This manager is responsible for establishing communication channels and
 // ultimately enqueueing the NCCL library operation onto the participating
 // streams.
+//
+// Implementation note: We make an effort to avoid initializing nccl
+// communication channels too often, as this is expensive.
+//
+// Ideally, we'd set up a nccl channel between each pair of devices that needs
+// to communicate, and close each channel when the GPUs won't be communicating
+// again "for a long time" (because channels hold memory on the GPU). As a
+// simplification to this ideal, we adopt the following policy.
+//
+//  - We maintain a set of GPUs that are "actively participating" in
+//    cross-device communications. That set of GPUs is always connected as a
+//    clique, using ncclCommInitAll.
+//
+//  - When a NcclAllReduceThunk touches a new GPU, we tear down the old clique
+//    and build a new, bigger one.
+//
+//  - All GPUs ever touched by a thunk are considered "actively in use" by that
+//    thunk until the thunk is destroyed. Destroying the thunk decrements the
+//    refcount of each GPU it has touched, and if any refcount goes to 0
+//    (meaning some GPU is no longer in use by any thunk), we tear down the
+//    clique and build a new, smaller one.
+//
+// This approximation is justified because:
+//
+//  - Currently the only collective operation we support is AllReduce, which
+//    requires a clique. When we support point-to-point operations, we may not
+//    want to build a communication clique.
+//
+//  - Tearing down and creating a new thunk is tantamount to running the whole
+//    XLA:GPU compiler. This is expensive, so it shouldn't happen often enough
+//    to cause thrashing here.
+//
+//  - XLA executables already keep resources on the GPU tied to the lifetime of
+//    the executable (e.g. constants stored in GPU memory), so tying the
+//    lifetime of the nccl communication channels to the lifetime of the
+//    executable is consistent.
 class GlobalRendezvousManager {
  public:
   // The GpuExecutable-executing threads call this in order to a) establish the
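For reference, the refcount-and-rebuild policy described in the comment above can be summarized by the following minimal, standalone sketch. It is not part of this patch: the class and member names (CliqueTracker, AddDevice, RemoveDevices, CliqueFor) are hypothetical, NCCL is replaced by a plain set of device ordinals, and thread-safety (the real code holds mutex_) is omitted.

  #include <map>
  #include <set>
  #include <vector>

  // Hypothetical illustration of the "actively participating devices" policy.
  class CliqueTracker {
   public:
    // A thunk touches a device for the first time: bump its refcount.
    void AddDevice(int device_ordinal) { ++refcounts_[device_ordinal]; }

    // A thunk is destroyed: decrement every device it ever touched. If any
    // refcount hits 0, drop the whole clique; it is rebuilt lazily later.
    void RemoveDevices(const std::vector<int>& device_ordinals) {
      for (int d : device_ordinals) {
        if (--refcounts_[d] == 0) {
          refcounts_.erase(d);
          clique_.clear();
        }
      }
    }

    // Just before a collective op: rebuild only if some participant is
    // missing, and rebuild as (old clique ∪ participants), mirroring
    // ReinitializeNcclClique below.
    const std::set<int>& CliqueFor(const std::vector<int>& participants) {
      bool missing = false;
      for (int d : participants) missing |= (clique_.count(d) == 0);
      if (missing) {
        for (int d : participants) clique_.insert(d);
      }
      return clique_;
    }

   private:
    std::map<int, int> refcounts_;
    std::set<int> clique_;  // Stands in for the map of open ncclComm_t handles.
  };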
@@ -98,18 +135,38 @@ class GlobalRendezvousManager {
     return current_generation_;
   }
 
-  private:
-  // Called by the primary thread to set up the communication links.
-  //
-  // TODO(b/125951860): This performs lots of (presumably) unnecessary
-  // host-side synchronization so that we can be paranoid about semantics in
-  // the earliest implementation. In the limit we should only need to
-  // synchronize host replica threads when the "number of replicas" or
-  // "participating device ordinals" change, to set up a new NCCL
-  // "communication" context, at which point we can enqueue onto device streams
-  // without host synchronization in our code -- this will likely be helpful
-  // for "lots of little AllReduce" cases.
-  Status InitializeCommunicationChannels() EXCLUSIVE_LOCKS_REQUIRED(mutex_);
+  // Increments the refcount of a GPU in our accounting of which devices are
+  // "actively participating" in cross-device operations.
+  //
+  // This doesn't actually do anything other than increment the refcount. If
+  // the GPU added here is novel, we'll rebuild the nccl communication clique
+  // when we actually go do the communication.
+  void AddrefParticipatingDevice(int device_ordinal);
+
+  // Decrements the refcount of a set of GPUs in our accounting of which
+  // devices are "actively participating" in cross-device operations.
+  //
+  // If one or more GPUs' refcounts go to 0, we immediately destroy the whole
+  // nccl communication clique. We'll rebuild a new, smaller clique the next
+  // time it's used.
+  void DecrefParticipatingDevices(absl::Span<const int64> device_ordinals);
+
+  // Gets the set of devices that have a NCCL channel currently open. This is
+  // primarily for testing.
+  absl::flat_hash_set<int> DevicesWithOpenNcclChannels() const {
+    absl::flat_hash_set<int> devices;
+    tensorflow::mutex_lock lock(mutex_);
+    for (const auto& kv : comms_) {
+      devices.insert(kv.first);
+    }
+    return devices;
+  }
+
+ private:
+  // Destroys the current nccl communication clique and builds a new one
+  // connecting the given devices.
+  Status ReinitializeNcclClique(
+      const absl::flat_hash_set<int>& device_ordinals)
+      EXCLUSIVE_LOCKS_REQUIRED(mutex_);
 
   // Called when all necessary participants are present, the functionality
   // that's implemented by all executing threads lives in here.
@@ -118,28 +175,51 @@ class GlobalRendezvousManager {
   // Puts all state back into a "reset" state for the next generation of
   // AllReduce requests.
   void DeinitializeGeneration() EXCLUSIVE_LOCKS_REQUIRED(mutex_) {
-    for (ncclComm_t& comm : comms_) {
-      ncclCommDestroy(comm);
-    }
-    comms_.clear();
     participants_.clear();
     current_generation_++;
     initialized_ = false;
     done_ = absl::nullopt;
   }
 
-  tensorflow::mutex mutex_;
+  mutable tensorflow::mutex mutex_;
   tensorflow::condition_variable all_participants_present_;
   tensorflow::condition_variable deinitialized_;
 
-  // Communication handles that correspond to the participants below.
-  std::vector<ncclComm_t> comms_ GUARDED_BY(mutex_);
-
   Status initialize_status_ GUARDED_BY(mutex_);
   std::vector<ParticipantData> participants_ GUARDED_BY(mutex_);
   int64 current_generation_ GUARDED_BY(mutex_) = 0;
   bool initialized_ GUARDED_BY(mutex_) = false;
 
+  struct Comm {
+    explicit Comm(ncclComm_t nccl_comm) : nccl_comm(nccl_comm) {}
+
+    // Movable, but not copyable.
+    Comm(Comm&& c) : nccl_comm(c.nccl_comm) { c.nccl_comm.reset(); }
+    Comm& operator=(Comm&& c) {
+      nccl_comm = c.nccl_comm;
+      c.nccl_comm.reset();
+      return *this;
+    }
+    Comm(const Comm&) = delete;
+    Comm& operator=(const Comm&) = delete;
+
+    absl::optional<ncclComm_t> nccl_comm;
+
+    ~Comm() {
+      if (nccl_comm.has_value()) {
+        VLOG(3) << absl::StreamFormat("Destroying comm %p", *nccl_comm);
+        ncclCommDestroy(*nccl_comm);
+      }
+    }
+  };
+  // Communication handles for our NCCL clique. Key is device ordinal.
+  absl::flat_hash_map<int, Comm> comms_ GUARDED_BY(mutex_);
+
+  // Refcounts of which devices are "actively participating" in all-reduces.
+  // These devices don't necessarily have an open comm, but the next time we
+  // run an operation, we'll create a NCCL clique between all of them.
+  absl::flat_hash_map<int, int64> device_refcounts_ GUARDED_BY(mutex_);
+
   // The participating threads wait for this to count down in order to know we
   // can begin the teardown process.
   absl::optional<tensorflow::BlockingCounter> done_;
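The Comm struct above is the usual move-only RAII wrapper for a C handle stored in a container, with the optional's nullopt state marking "moved-from, nothing to destroy". A generic sketch of the same idea follows; the CHandle type and c_handle_destroy function are hypothetical and exist only for illustration.

  #include <optional>

  struct CHandle;                 // Hypothetical C API.
  void c_handle_destroy(CHandle*);

  // Owns a CHandle*; movable but not copyable, so it can live in containers
  // such as absl::flat_hash_map while destroying the handle exactly once.
  class ScopedHandle {
   public:
    explicit ScopedHandle(CHandle* h) : handle_(h) {}
    ScopedHandle(ScopedHandle&& o) : handle_(o.handle_) { o.handle_.reset(); }
    ScopedHandle& operator=(ScopedHandle&& o) {
      if (this != &o) {
        Destroy();
        handle_ = o.handle_;
        o.handle_.reset();
      }
      return *this;
    }
    ScopedHandle(const ScopedHandle&) = delete;
    ScopedHandle& operator=(const ScopedHandle&) = delete;
    ~ScopedHandle() { Destroy(); }

   private:
    void Destroy() {
      if (handle_.has_value()) c_handle_destroy(*handle_);
      handle_.reset();
    }
    std::optional<CHandle*> handle_;  // nullopt means "moved-from".
  };

One difference worth noting: Comm::operator= in the patch overwrites an existing nccl_comm without destroying it first, which is fine as long as the flat_hash_map never move-assigns onto an engaged Comm, but the sketch's Destroy-then-assign form avoids the question entirely.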
@@ -151,11 +231,6 @@ Status GlobalRendezvousManager::SubmitParticipant(ParticipantData participant) {
     return participants_.size() >= participant.replica_count;
   };
 
-  // We remember the participant index at which we are inserted and use that
-  // same index for referring to auxiliary metadata (e.g. the ncclComm_t handle
-  // index) below.
-  int64 index;
-
   {
     tensorflow::mutex_lock lock(mutex_);
 
@@ -171,7 +246,6 @@ Status GlobalRendezvousManager::SubmitParticipant(ParticipantData participant) {
           "participants; existing: %s; submitted: %s)",
           participants_.back().ToString(), participant.ToString());
     }
-    index = participants_.size();
     participants_.push_back(participant);
 
     if (all_participants_present()) {
@@ -205,11 +279,35 @@ Status GlobalRendezvousManager::SubmitParticipant(ParticipantData participant) {
       VLOG(3) << "Primary initializing accounting data.";
       initialized_ = true;
       done_.emplace(participant.replica_count);
-      initialize_status_ = InitializeCommunicationChannels();
-      VLOG(3) << "Done initializing communication channels; status: "
-              << initialize_status_;
-      if (!initialize_status_.ok()) {
-        DeinitializeGeneration();
+
+      // Check if all participants_ are in comms_. If not, we will rebuild the
+      // clique to include them. (This can't be spelled using absl::c_any_of
+      // because it needs to touch comms_ and tensorflow::mutex lacks an
+      // AssertHeld() function that would let us assert that the lambda is run
+      // while holding the lock.)
+      bool new_devices_found = false;
+      for (const auto& p : participants_) {
+        if (!comms_.contains(p.device_ordinal)) {
+          new_devices_found = true;
+          break;
+        }
+      }
+
+      if (new_devices_found) {
+        absl::flat_hash_set<int> new_clique_device_ordinals;
+        for (const auto& kv : comms_) {
+          new_clique_device_ordinals.insert(kv.first);
+        }
+        for (const auto& p : participants_) {
+          new_clique_device_ordinals.insert(p.device_ordinal);
+        }
+
+        initialize_status_ = ReinitializeNcclClique(new_clique_device_ordinals);
+        VLOG(3) << "Done initializing communication channels; status: "
+                << initialize_status_;
+        if (!initialize_status_.ok()) {
+          DeinitializeGeneration();
+        }
+      }
     }
   }
 
@@ -218,7 +316,7 @@ Status GlobalRendezvousManager::SubmitParticipant(ParticipantData participant) {
       return initialize_status_;
     }
 
-    comm = comms_[index];
+    comm = *comms_.at(participant.device_ordinal).nccl_comm;
 
     // Drop the lock at the end of scope so other participants may enter.
   }
 
@@ -259,22 +357,30 @@ Status GlobalRendezvousManager::SubmitParticipant(ParticipantData participant) {
   return all_reduce_status;
 }
 
-Status GlobalRendezvousManager::InitializeCommunicationChannels() {
-  std::vector<int> ordinals;
-  for (ParticipantData& data : participants_) {
-    ordinals.push_back(data.device_ordinal);
-  }
-  comms_.resize(ordinals.size());
-  VLOG(3) << "Participants: " << participants_.size()
-          << "; initializing comms.";
-  ncclResult_t result = ncclCommInitAll(comms_.data(), comms_.size(),
-                                        /*devlist=*/ordinals.data());
+Status GlobalRendezvousManager::ReinitializeNcclClique(
+    const absl::flat_hash_set<int>& device_ordinals) {
+  comms_.clear();
+
+  std::vector<int> ordinals_vec(device_ordinals.begin(),
+                                device_ordinals.end());
+  std::vector<ncclComm_t> comm_vec;
+  comm_vec.resize(device_ordinals.size());
+
+  VLOG(3) << absl::StreamFormat(
+      "Initializing nccl comms for participant devices {%s}",
+      absl::StrJoin(ordinals_vec, ", "));
+  ncclResult_t result = ncclCommInitAll(comm_vec.data(), comm_vec.size(),
+                                        /*devlist=*/ordinals_vec.data());
   if (result != ncclSuccess) {
-    comms_.clear();
     return InternalError(
         "Failed to initialize NCCL communication channels for %d participants: "
        "%s",
-        participants_.size(), ncclGetErrorString(result));
+        ordinals_vec.size(), ncclGetErrorString(result));
+  }
+
+  for (int64 i = 0; i < ordinals_vec.size(); ++i) {
+    VLOG(3) << absl::StreamFormat("Device ordinal %d assigned ncclComm %p",
+                                  ordinals_vec[i], comm_vec[i]);
+    CHECK(comms_.emplace(ordinals_vec[i], Comm{comm_vec[i]}).second);
   }
   return Status::OK();
 }
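For readers unfamiliar with the NCCL calls used above, here is a minimal, self-contained sketch of the same single-process clique pattern (ncclCommInitAll, ncclAllReduce, ncclCommDestroy) outside of XLA. It is not the patch's code: it assumes at least two visible GPUs and NCCL 2.x, uses the group API because all calls are issued from one thread (the thunk instead uses one replica thread per device), and omits real error handling.

  #include <cstdio>
  #include <vector>
  #include <cuda_runtime.h>
  #include <nccl.h>

  int main() {
    const int kNumDevices = 2;  // Assumes >= 2 visible GPUs.
    const int kElems = 1024;
    std::vector<int> devs = {0, 1};

    // Build a clique over the chosen devices, as ReinitializeNcclClique does.
    std::vector<ncclComm_t> comms(kNumDevices);
    ncclCommInitAll(comms.data(), kNumDevices, devs.data());

    std::vector<float*> send(kNumDevices), recv(kNumDevices);
    std::vector<cudaStream_t> streams(kNumDevices);
    for (int i = 0; i < kNumDevices; ++i) {
      cudaSetDevice(devs[i]);
      cudaMalloc(&send[i], kElems * sizeof(float));
      cudaMalloc(&recv[i], kElems * sizeof(float));
      cudaStreamCreate(&streams[i]);
    }

    // One ncclAllReduce per participating device, as in DoAllReduce. Grouping
    // is required because a single thread issues every call in the clique.
    ncclGroupStart();
    for (int i = 0; i < kNumDevices; ++i) {
      ncclAllReduce(send[i], recv[i], kElems, ncclFloat, ncclSum, comms[i],
                    streams[i]);
    }
    ncclGroupEnd();

    for (int i = 0; i < kNumDevices; ++i) {
      cudaSetDevice(devs[i]);
      cudaStreamSynchronize(streams[i]);
      cudaFree(send[i]);
      cudaFree(recv[i]);
      ncclCommDestroy(comms[i]);  // Mirrors ~Comm in the patch.
    }
    printf("done\n");
    return 0;
  }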
@@ -289,6 +395,11 @@ Status GlobalRendezvousManager::DoAllReduce(ParticipantData participant,
           << " on device: " << participant.device_ordinal;
   void* send_buffer = participant.source_data.opaque();
   void* recv_buffer = participant.destination_data.opaque();
+  VLOG(3) << absl::StreamFormat(
+      "Calling ncclAllReduce(send_buffer=%p, recv_buffer=%p, count=%d, "
+      "datatype=ncclFloat, op=ncclSum, comm=%p, stream=%p)",
+      send_buffer, recv_buffer, participant.element_count,
+      static_cast<const void*>(comm), cu_stream);
   ncclResult_t result = ncclAllReduce(send_buffer, recv_buffer,
                                       /*count=*/participant.element_count,
                                       /*datatype=*/ncclFloat,
@@ -304,6 +415,36 @@ Status GlobalRendezvousManager::DoAllReduce(ParticipantData participant,
   return Status::OK();
 }
 
+void GlobalRendezvousManager::AddrefParticipatingDevice(int device_ordinal) {
+  // Addref'ing a device doesn't do anything other than increment its refcount.
+  // We'll update our nccl clique if necessary during the next call to
+  // SubmitParticipant.
+  tensorflow::mutex_lock lock(mutex_);
+  device_refcounts_[device_ordinal]++;
+}
+
+void GlobalRendezvousManager::DecrefParticipatingDevices(
+    absl::Span<const int64> device_ordinals) {
+  // Decref'ing devices causes us to destroy the nccl clique if any devices
+  // were removed due to having refcount 0. We'll rebuild a new, smaller clique
+  // during the next call to SubmitParticipant.
+  tensorflow::mutex_lock lock(mutex_);
+  bool removed_device = false;
+  for (int device_ordinal : device_ordinals) {
+    auto it = device_refcounts_.find(device_ordinal);
+    CHECK(it != device_refcounts_.end());
+    it->second--;
+    if (it->second == 0) {
+      device_refcounts_.erase(it);
+      removed_device = true;
+    }
+  }
+
+  if (removed_device) {
+    comms_.clear();
+  }
+}
+
 static GlobalRendezvousManager* GetGlobalRendezvous() {
   static auto* manager = new GlobalRendezvousManager;
   return manager;
@@ -311,6 +452,11 @@ static GlobalRendezvousManager* GetGlobalRendezvous() {
 }
 
 }  // namespace
 
+/*static*/ absl::flat_hash_set<int>
+NcclAllReduceThunk::DevicesWithOpenNcclChannels() {
+  return GetGlobalRendezvous()->DevicesWithOpenNcclChannels();
+}
+
 Status NcclAllReduceThunk::ExecuteOnStream(
     const BufferAllocations& buffer_allocations, se::Stream* stream,
     HloExecutionProfiler* profiler) {
@@ -327,8 +473,32 @@ Status NcclAllReduceThunk::ExecuteOnStream(
   participant.stream = stream;
   participant.originator = this;
 
+  // We currently say that all GPUs this thunk has ever touched are "actively
+  // participating" in cross-device operations, until the thunk itself is
+  // destroyed.
+  //
+  // This policy is an attempt to avoid thrashing the GPU (ncclCommInitAll is
+  // very expensive) while also freeing resources on the GPUs when we can. The
+  // idea is that creating new thunks is tantamount to running the whole
+  // XLA:GPU compiler stack, so that shouldn't happen terribly often.
+  bool new_device;
+  {
+    tensorflow::mutex_lock lock(mu_);
+    new_device = devices_seen_.insert(participant.device_ordinal).second;
+  }
+  if (new_device) {
+    GetGlobalRendezvous()->AddrefParticipatingDevice(
+        participant.device_ordinal);
+  }
+
   return GetGlobalRendezvous()->SubmitParticipant(std::move(participant));
 }
+
+NcclAllReduceThunk::~NcclAllReduceThunk() {
+  GetGlobalRendezvous()->DecrefParticipatingDevices(
+      std::vector<int64>(devices_seen_.begin(), devices_seen_.end()));
+}
+
 #else
 
 Status NcclAllReduceThunk::ExecuteOnStream(
@@ -339,6 +509,13 @@ Status NcclAllReduceThunk::ExecuteOnStream(
       "compiler, which is necessary to build the NCCL source library.");
 }
 
+NcclAllReduceThunk::~NcclAllReduceThunk() = default;
+
+/*static*/ absl::flat_hash_set<int>
+NcclAllReduceThunk::DevicesWithOpenNcclChannels() {
+  return {};
+}
+
 #endif  // GOOGLE_CUDA
 
 NcclAllReduceThunk::NcclAllReduceThunk(
diff --git a/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h b/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h
index 1a8d1356c00..9ff4fb187af 100644
--- a/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h
+++ b/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h
@@ -16,11 +16,13 @@ limitations under the License.
 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_NCCL_ALL_REDUCE_THUNK_H_
 #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_NCCL_ALL_REDUCE_THUNK_H_
 
+#include "absl/container/flat_hash_set.h"
 #include "tensorflow/compiler/xla/service/buffer_assignment.h"
 #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
 #include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
 #include "tensorflow/compiler/xla/service/gpu/thunk.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/core/platform/mutex.h"
 #include "tensorflow/core/platform/stream_executor_no_cuda.h"
 #include "tensorflow/core/platform/types.h"
 
@@ -38,12 +40,21 @@ class NcclAllReduceThunk : public Thunk {
   // error.
   static bool NcclIsEnabled();
 
+  // Gets the set of devices that have a NCCL channel open. This is primarily
+  // for testing.
+  //
+  // (Indeed, because the NCCL channels are a global variable, in the real
+  // world, the value returned here is stale as soon as you read it, so it's
+  // not clear how you *could* use it for anything other than tests.)
+  static absl::flat_hash_set<int> DevicesWithOpenNcclChannels();
+
   // TODO(b/125951860): Plumb more datatypes / reduction operators. Initial
   // implementation is simply F32 summation.
   NcclAllReduceThunk(int64 replica_count, int64 element_count,
                      const BufferAllocation::Slice& source_buffer,
                      const BufferAllocation::Slice& destination_buffer,
                      const HloInstruction* all_reduce);
+  ~NcclAllReduceThunk() override;
 
   Status ExecuteOnStream(const BufferAllocations& buffer_allocations,
                          se::Stream* stream,
@@ -54,6 +65,10 @@ class NcclAllReduceThunk : public Thunk {
   const int64 element_count_;
   const BufferAllocation::Slice source_buffer_;
   const BufferAllocation::Slice destination_buffer_;
+
+  tensorflow::mutex mu_;
+  // Set of GPUs that ExecuteOnStream has been called on.
+  absl::flat_hash_set<int> devices_seen_ GUARDED_BY(mu_);
 };
 
 }  // namespace gpu
diff --git a/tensorflow/compiler/xla/service/hlo_runner.cc b/tensorflow/compiler/xla/service/hlo_runner.cc
index b195dd7e4f4..5ba390acfd4 100644
--- a/tensorflow/compiler/xla/service/hlo_runner.cc
+++ b/tensorflow/compiler/xla/service/hlo_runner.cc
@@ -273,6 +273,12 @@ StatusOr<std::vector<Literal>> HloRunner::ExecuteReplicated(
   TF_ASSIGN_OR_RETURN(
       std::unique_ptr<Executable> executable,
       CreateExecutable(std::move(module), options.run_hlo_passes));
+  return ExecuteReplicated(executable.get(), options, device_assignment);
+}
+
+StatusOr<std::vector<Literal>> HloRunner::ExecuteReplicated(
+    Executable* executable, const ReplicatedExecuteOptions& options,
+    DeviceAssignment* device_assignment, ExecutionProfile* profile) {
   std::vector<std::unique_ptr<se::Stream>> streams;
   std::vector<ServiceExecutableRunOptions> service_run_options;
 
diff --git a/tensorflow/compiler/xla/service/hlo_runner.h b/tensorflow/compiler/xla/service/hlo_runner.h
index 5c5a82fc0fd..7e666a8186e 100644
--- a/tensorflow/compiler/xla/service/hlo_runner.h
+++ b/tensorflow/compiler/xla/service/hlo_runner.h
@@ -183,6 +183,15 @@ class HloRunner {
       const ReplicatedExecuteOptions& options,
       DeviceAssignment* device_assignment);
 
+  // Same as above, but with a reusable Executable. This may update the
+  // profile information in *executable.
+  //
+  // Note that this call ignores ReplicatedExecuteOptions::run_hlo_passes,
+  // since we've already compiled the Executable.
+  StatusOr<std::vector<Literal>> ExecuteReplicated(
+      Executable* executable, const ReplicatedExecuteOptions& options,
+      DeviceAssignment* device_assignment, ExecutionProfile* profile = nullptr);
+
   // If backend is not created in the constructor, creates and returns the
   // default backend. If creation fails, crashes the program.
   //
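The new overload enables the compile-once / run-many pattern used by the caching test below. A condensed usage sketch, assuming a runner, a parsed module whose config already carries the replica count and static device assignment, and an input literal (names here are illustrative, not part of the patch):

  // Sketch only: runs one Executable twice across devices {0, 1}.
  Status CompileOnceRunTwice(HloRunner& runner,
                             std::unique_ptr<HloModule> module,
                             const Literal& arg) {
    DeviceAssignment assignment(/*replica_count=*/2, /*computation_count=*/1);
    assignment(0, 0) = 0;  // replica 0 runs on device 0
    assignment(1, 0) = 1;  // replica 1 runs on device 1

    TF_ASSIGN_OR_RETURN(
        std::unique_ptr<Executable> executable,
        runner.CreateExecutable(std::move(module), /*run_hlo_passes=*/true));

    HloRunner::ReplicatedExecuteOptions opts;
    opts.num_replicas = 2;
    opts.use_threads = true;
    opts.arguments.push_back(&arg);

    // The same Executable is reused; only the first run should need to open
    // new NCCL channels under the caching policy above.
    for (int i = 0; i < 2; ++i) {
      TF_ASSIGN_OR_RETURN(std::vector<Literal> results,
                          runner.ExecuteReplicated(executable.get(), opts,
                                                   &assignment));
    }
    return Status::OK();
  }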
diff --git a/tensorflow/compiler/xla/tests/multi_device_all_reduce_test.cc b/tensorflow/compiler/xla/tests/multi_device_all_reduce_test.cc
index 1513d89ba9c..7895895e3e7 100644
--- a/tensorflow/compiler/xla/tests/multi_device_all_reduce_test.cc
+++ b/tensorflow/compiler/xla/tests/multi_device_all_reduce_test.cc
@@ -14,35 +14,86 @@ limitations under the License.
 ==============================================================================*/
 
 #include "tensorflow/compiler/xla/literal.h"
+#include "tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h"
 #include "tensorflow/compiler/xla/service/hlo_parser.h"
 #include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/compiler/xla/test.h"
 #include "tensorflow/compiler/xla/test_helpers.h"
 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
 #include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+
+// Tests cross-GPU all-reduce operations.
+//
+// This test requires multiple GPUs. For instructions on running this within
+// Google, see go/multi-gpu-unit-test.
 
 namespace xla {
 namespace {
 
-class MultiDeviceAllReduceTest : public HloTestBase {};
+using ::testing::IsEmpty;
+using ::testing::UnorderedElementsAre;
+
+class MultiDeviceAllReduceTest : public HloTestBase {
+ protected:
+  std::unique_ptr<HloModule> MakeCrsModule(int64 num_elems,
+                                           const HloModuleConfig& config) {
+    const char* kTemplate = R"(
+      HloModule test
+
+      add {
+        x = f32[] parameter(0)
+        y = f32[] parameter(1)
+        add = f32[] add(x, y)
+      }
+
+      ENTRY test_computation {
+        p = f32[NUM_ELEMS] parameter(0)
+        ROOT crs = f32[NUM_ELEMS] all-reduce(p), to_apply=add
+      }
+    )";
+    return ParseHloString(
+               absl::StrReplaceAll(
+                   kTemplate, {{"NUM_ELEMS", absl::StrCat(num_elems)}}),
+               config)
+        .ValueOrDie();
+  }
+};
+
+// Returns the non-empty subsets of {0, 1, ..., n-1}. For example,
+// PowerSetOfIota(3) = {{0}, {1}, {2}, {0,1}, {0,2}, {1,2}, {0,1,2}}.
+std::vector<std::vector<int64>> PowerSetOfIota(int64 n) {
+  std::vector<std::vector<int64>> power_set;
+  for (int64 i = 1; i < (1 << n); ++i) {
+    power_set.emplace_back();
+    for (int64 j = 0; j < n; ++j) {
+      if (i & (1 << j)) {
+        power_set.back().push_back(j);
+      }
+    }
+  }
+  return power_set;
+}
+
+// Makes a DeviceAssignment assigning replica-id i to devices[i].
+DeviceAssignment MakeDeviceAssn(std::vector<int64> devices) {
+  DeviceAssignment assn(/*replica_count=*/devices.size(),
+                        /*computation_count=*/1);
+  for (int64 i = 0; i < devices.size(); ++i) {
+    assn(i, 0) = devices[i];
+  }
+  return assn;
+}
+
+// Shorter alias for this function.
+absl::flat_hash_set<int> OpenNcclChannels() {
+  return gpu::NcclAllReduceThunk::DevicesWithOpenNcclChannels();
+}
 
 XLA_TEST_F(MultiDeviceAllReduceTest, TwoReplicasOneOperand) {
-  const char* module_str = R"(
-  HloModule test
-
-  add {
-    x = f32[] parameter(0)
-    y = f32[] parameter(1)
-    add = f32[] add(x, y)
-  }
-
-  ENTRY test_computation {
-    p = f32[3] parameter(0)
-    ROOT crs = f32[3] all-reduce(p), to_apply=add
-  })";
   auto config = GetModuleConfigForTest();
   config.set_replica_count(2);
-  auto module = ParseHloString(module_str, config).ValueOrDie();
+  auto module = MakeCrsModule(/*num_elems=*/3, config);
   auto literal = LiteralUtil::CreateR1<float>({1, 2, 3});
   auto expected = LiteralUtil::CreateR1<float>({2, 4, 6});
   TF_ASSERT_OK_AND_ASSIGN(std::vector<Literal> results,
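To make MakeDeviceAssn above concrete: a DeviceAssignment is a (replica, computation) -> device-ordinal table. A sketch of what the helper produces for devices {1, 2}, written as statements that could sit inside a test body (the replica_count()/computation_count() accessors are assumed from DeviceAssignment's API):

  // With devices = {1, 2}:
  //   replica 0, computation 0 -> device 1
  //   replica 1, computation 0 -> device 2
  DeviceAssignment assn = MakeDeviceAssn({1, 2});
  CHECK_EQ(assn.replica_count(), 2);
  CHECK_EQ(assn.computation_count(), 1);
  CHECK_EQ(assn(0, 0), 1);
  CHECK_EQ(assn(1, 0), 2);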
@@ -52,5 +103,112 @@ XLA_TEST_F(MultiDeviceAllReduceTest, TwoReplicasOneOperand) {
   EXPECT_EQ(expected, results[1]);
 }
 
+// Tries all-reduce operations across all 2^kNumDevices - 1 combinations of
+// devices in sequence.
+XLA_TEST_F(MultiDeviceAllReduceTest, AllCombinations) {
+  const int64 kNumDevices = 4;
+  const int64 kNumElems = 1024;
+
+  for (std::vector<int64> devices : PowerSetOfIota(kNumDevices)) {
+    SCOPED_TRACE(absl::StrFormat("Running on devices {%s}",
+                                 absl::StrJoin(devices, ", ")));
+
+    DeviceAssignment device_assn = MakeDeviceAssn(devices);
+
+    auto config = GetModuleConfigForTest();
+    config.set_replica_count(devices.size());
+    config.set_static_device_assignment(device_assn);
+
+    auto module = MakeCrsModule(kNumElems, config);
+
+    std::vector<float> input_vec(kNumElems);
+    absl::c_iota(input_vec, 0);
+    auto input_literal = LiteralUtil::CreateR1<float>(input_vec);
+
+    TF_ASSERT_OK_AND_ASSIGN(
+        std::vector<Literal> results,
+        ExecuteReplicated(std::move(module), {&input_literal},
+                          /*num_replicas=*/devices.size(), &device_assn,
+                          /*run_hlo_passes=*/true, /*use_threads=*/true));
+  }
+}
+
+// Check that the NCCL data structures in our all-reduce implementation are
+// cached as we expect.
+XLA_TEST_F(MultiDeviceAllReduceTest, NcclChannelCaching) {
+  const int64 kNumElems = 1024;
+
+  std::vector<float> input_vec(kNumElems);
+  absl::c_iota(input_vec, 0);
+  auto input_literal = LiteralUtil::CreateR1<float>(input_vec);
+
+  // Initially no NCCL channels should be open.
+  EXPECT_THAT(OpenNcclChannels(), IsEmpty());
+
+  // Create three Executables, touching devices {0,1}, {1,2}, and {0,1,2}.
+  struct ExecutableInfo {
+    std::unique_ptr<Executable> executable;
+    DeviceAssignment device_assn;
+    HloRunner::ReplicatedExecuteOptions opts;
+  };
+  std::vector<ExecutableInfo> executables;
+  for (const auto& devices :
+       std::vector<std::vector<int64>>{{0, 1}, {1, 2}, {0, 1, 2}}) {
+    executables.emplace_back();
+    auto& e = executables.back();
+
+    e.device_assn = MakeDeviceAssn(devices);
+
+    auto config = GetModuleConfigForTest();
+    config.set_replica_count(devices.size());
+    config.set_static_device_assignment(e.device_assn);
+    auto module = MakeCrsModule(kNumElems, config);
+    e.executable =
+        test_runner_
+            .CreateExecutable(std::move(module), /*run_hlo_passes=*/true)
+            .ValueOrDie();
+
+    e.opts.num_replicas = devices.size();
+    e.opts.use_threads = true;
+    e.opts.arguments.push_back(&input_literal);
+  }
+
+  auto run_executable = [&](int64 i) {
+    auto& e = executables[i];
+    TF_ASSERT_OK(
+        test_runner_
+            .ExecuteReplicated(e.executable.get(), e.opts, &e.device_assn)
+            .status());
+  };
+
+  // Compiling executables above shouldn't cause us to open any channels.
+  EXPECT_THAT(OpenNcclChannels(), IsEmpty());
+
+  // Run the executables and check that channels are opened as we expect.
+  run_executable(0);
+  EXPECT_THAT(OpenNcclChannels(), UnorderedElementsAre(0, 1));
+
+  run_executable(2);
+  EXPECT_THAT(OpenNcclChannels(), UnorderedElementsAre(0, 1, 2));
+
+  run_executable(1);
+  EXPECT_THAT(OpenNcclChannels(), UnorderedElementsAre(0, 1, 2));
+
+  // Tear down the executables and check that channels are closed as we expect.
+  // Note that after we tear down an executable *all* the nccl channels may go
+  // away, so we rerun all of the executables that haven't been torn down.
+  executables[2].executable.reset();
+  run_executable(0);
+  run_executable(1);
+  EXPECT_THAT(OpenNcclChannels(), UnorderedElementsAre(0, 1, 2));
+
+  executables[0].executable.reset();
+  run_executable(1);
+  EXPECT_THAT(OpenNcclChannels(), UnorderedElementsAre(1, 2));
+
+  executables[1].executable.reset();
+  EXPECT_THAT(OpenNcclChannels(), IsEmpty());
+}
+
 }  // namespace
 }  // namespace xla