[XLA:GPU] Improve performance of AllReduce by caching nccl channels.

Previously we'd tear down nccl communication channels after every all-reduce
operation.  Now we keep the channels alive for what we hope will be a reasonable
period of time.

PiperOrigin-RevId: 246746547
Justin Lebar, 2019-05-05 16:00:32 -07:00 (committed by TensorFlower Gardener)
parent 274062297a
commit acc20217e6
5 changed files with 423 additions and 58 deletions


@@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/util.h"
#if GOOGLE_CUDA
#include "absl/container/flat_hash_set.h"
#include "absl/synchronization/blocking_counter.h"
#include "third_party/nccl/nccl.h"
#include "tensorflow/core/lib/core/blocking_counter.h"
@@ -76,6 +77,42 @@ struct ParticipantData {
// This manager is responsible for establishing communication channels and
// ultimately enqueueing the NCCL library operation onto the participating
// streams.
//
// Implementation note: We make an effort to avoid initializing nccl
// communication channels too often, as this is expensive.
//
// Ideally, we'd set up a nccl channel between each pair of devices that needs
// to communicate, and close each channel when the GPUs won't be communicating
// again "for a long time" (because channels hold memory on the GPU). As a
// simplification to this ideal, we adopt the following policy.
//
// - We maintain a set of GPUs that are "actively participating" in
// cross-device communications. That set of GPUs is always connected as a
// clique, using ncclCommInitAll.
//
// - When a NcclAllReduceThunk touches a new GPU, we tear down the old clique
// and build a new, bigger one.
//
// - All GPUs ever touched by a thunk are considered "actively in use" by that
// thunk until the thunk is destroyed. Destroying the thunk decrements the
// refcount of the GPUs it's touched, and if that refcount goes to 0
// (meaning, some GPUs are no longer in use by any thunk), we tear down the
// clique and build a new, smaller one.
//
// This approximation is justified because:
//
// - Currently the only collective operation we support is AllReduce, which
// requires a clique. When we support point-to-point operations, we may not
// want to build a communication clique.
//
// - Tearing down a thunk and creating a new one is tantamount to running the
// whole XLA:GPU compiler. This is expensive, so it shouldn't happen often
// enough to cause thrashing here.
//
// - XLA executables already keep resources on the GPU tied to the lifetime of
// the executable (e.g. constants stored in GPU memory), so tying the
// lifetime of the nccl communication channels to the lifetime of the
// executable is consistent.
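To make the policy above concrete, here is a minimal standalone sketch of the same bookkeeping (hypothetical names; plain STL containers and a printf stand in for the absl containers and the real NCCL calls):

#include <cstdio>
#include <map>
#include <set>

// Illustrative stand-in for the rendezvous manager's device accounting.
class CliqueTrackerSketch {
 public:
  // A thunk addrefs each device the first time it touches it.
  void AddrefDevice(int ordinal) { ++refcounts_[ordinal]; }

  // A dying thunk decrefs every device it ever touched.  If any device drops
  // to refcount 0, the whole clique is torn down; it is rebuilt lazily (and
  // smaller) the next time a collective runs.
  void DecrefDevices(const std::set<int>& ordinals) {
    bool removed = false;
    for (int d : ordinals) {
      if (--refcounts_[d] == 0) {
        refcounts_.erase(d);
        removed = true;
      }
    }
    if (removed) clique_.clear();  // Stand-in for destroying the nccl comms.
  }

  // Before a collective runs: if any participant is missing from the clique,
  // rebuild a bigger clique covering the old members plus all participants.
  void EnsureClique(const std::set<int>& participants) {
    for (int d : participants) {
      if (clique_.count(d) == 0) {
        clique_.insert(participants.begin(), participants.end());
        std::printf("Rebuilt clique with %zu devices\n", clique_.size());
        return;  // Stand-in for ncclCommInitAll over the enlarged set.
      }
    }
  }

 private:
  std::map<int, int> refcounts_;  // device ordinal -> #thunks using it
  std::set<int> clique_;          // device ordinals with an open (fake) comm
};

In the real code below, comms_ maps device ordinals to RAII Comm wrappers, so clearing that map is what actually destroys the NCCL communicators.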
class GlobalRendezvousManager {
public:
// The GpuExecutable-executing threads call this in order to a) establish the
@@ -98,18 +135,38 @@ class GlobalRendezvousManager {
return current_generation_;
}
private:
// Called by the primary thread to set up the communication links.
// Increments the refcount of a GPU in our accounting of which devices are
// "actively participating" in cross-device operations.
//
// TODO(b/125951860): This performs lots of (presumably) unnecessary host-side
// synchronization so that we can be paranoid about semantics in the earliest
// implementation. In the limit we should only need to synchronize host
// replica threads when the "number of replicas" or "participating device
// ordinals" change, to set up a new NCCL "communication" context, at which
// point we can enqueue onto device streams without host synchronization in
// our code -- this will likely be helpful for "lots of little AllReduce"
// cases.
Status InitializeCommunicationChannels() EXCLUSIVE_LOCKS_REQUIRED(mutex_);
// This doesn't actually do anything other than increment the refcount. If
// the GPU added here is novel, we'll rebuild the nccl communication clique
// when we actually go do the communication.
void AddrefParticipatingDevice(int device_ordinal);
// Decrements the refcounts of a set of GPUs in our accounting of which devices
// are "actively participating" in cross-device operations.
//
// If one or more GPUs' refcounts go to 0, we immediately destroy the whole
// nccl communication clique. We'll rebuild a new, smaller clique the next
// time it's used.
void DecrefParticipatingDevices(absl::Span<const int> device_ordinals);
// Gets the set of devices that have a NCCL channel currently open. This is
// primarily for testing.
absl::flat_hash_set<int> DevicesWithOpenNcclChannels() const {
absl::flat_hash_set<int> devices;
tensorflow::mutex_lock lock(mutex_);
for (const auto& kv : comms_) {
devices.insert(kv.first);
}
return devices;
}
private:
// Destroys the current nccl communication clique and builds a new one
// connecting the given devices.
Status ReinitializeNcclClique(const absl::flat_hash_set<int>& device_ordinals)
EXCLUSIVE_LOCKS_REQUIRED(mutex_);
// Called when all necessary participants are present; the functionality
// that's implemented by all executing threads lives in here.
@@ -118,28 +175,51 @@ class GlobalRendezvousManager {
// Puts all state back into a "reset" state for the next generation of
// AllReduce requests.
void DeinitializeGeneration() EXCLUSIVE_LOCKS_REQUIRED(mutex_) {
for (ncclComm_t& comm : comms_) {
ncclCommDestroy(comm);
}
comms_.clear();
participants_.clear();
current_generation_++;
initialized_ = false;
done_ = absl::nullopt;
}
tensorflow::mutex mutex_;
mutable tensorflow::mutex mutex_;
tensorflow::condition_variable all_participants_present_;
tensorflow::condition_variable deinitialized_;
// Communication handles that correspond to the participants below.
std::vector<ncclComm_t> comms_ GUARDED_BY(mutex_);
Status initialize_status_ GUARDED_BY(mutex_);
std::vector<ParticipantData> participants_ GUARDED_BY(mutex_);
int64 current_generation_ GUARDED_BY(mutex_) = 0;
bool initialized_ GUARDED_BY(mutex_) = false;
struct Comm {
explicit Comm(ncclComm_t nccl_comm) : nccl_comm(nccl_comm) {}
// Movable, but not copyable.
Comm(Comm&& c) : nccl_comm(c.nccl_comm) { c.nccl_comm.reset(); }
Comm& operator=(Comm&& c) {
nccl_comm = c.nccl_comm;
c.nccl_comm.reset();
return *this;
}
Comm(const Comm&) = delete;
Comm& operator=(const Comm&) = delete;
absl::optional<ncclComm_t> nccl_comm;
~Comm() {
if (nccl_comm.has_value()) {
VLOG(3) << absl::StreamFormat("Destroying comm %p", *nccl_comm);
ncclCommDestroy(*nccl_comm);
}
}
};
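One note on the design of the Comm wrapper above: keeping the handle in an absl::optional lets a moved-from Comm mark itself empty, so the communicator is destroyed exactly once, in the destructor of the final owner. A minimal sketch of the same move-only pattern with a hypothetical handle type (std::optional standing in for absl::optional):

#include <cstdio>
#include <optional>
#include <utility>

using FakeHandle = int;  // Hypothetical stand-in for ncclComm_t.
void DestroyHandle(FakeHandle h) { std::printf("destroying handle %d\n", h); }

struct OwnedHandle {
  explicit OwnedHandle(FakeHandle h) : handle(h) {}
  // Movable, but not copyable: the source relinquishes ownership.
  OwnedHandle(OwnedHandle&& o) : handle(o.handle) { o.handle.reset(); }
  OwnedHandle(const OwnedHandle&) = delete;
  OwnedHandle& operator=(const OwnedHandle&) = delete;
  ~OwnedHandle() {
    // A moved-from OwnedHandle is empty, so the handle is destroyed only once.
    if (handle.has_value()) DestroyHandle(*handle);
  }
  std::optional<FakeHandle> handle;
};

int main() {
  OwnedHandle a(42);
  OwnedHandle b = std::move(a);  // Only `b` destroys handle 42.
}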
// Communication handles for our NCCL clique. Key is device ordinal.
absl::flat_hash_map<int, Comm> comms_ GUARDED_BY(mutex_);
// Refcounts of which devices are "actively participating" in all-reduces.
// These devices don't necessarily have an open comm, but the next time we run
// an operation, we'll create a NCCL clique between all of them.
absl::flat_hash_map<int, int64> device_refcounts_ GUARDED_BY(mutex_);
// The participating threads wait for this to count down in order to know we
// can begin the teardown process.
absl::optional<tensorflow::BlockingCounter> done_;
@@ -151,11 +231,6 @@ Status GlobalRendezvousManager::SubmitParticipant(ParticipantData participant) {
return participants_.size() >= participant.replica_count;
};
// We remember the participant index at which we are inserted and use that
// same index for referring to auxiliary metadata (e.g. the ncclComm_t handle
// index) below.
int64 index;
{
tensorflow::mutex_lock lock(mutex_);
@@ -171,7 +246,6 @@ Status GlobalRendezvousManager::SubmitParticipant(ParticipantData participant) {
"participants; existing: %s; submitted: %s)",
participants_.back().ToString(), participant.ToString());
}
index = participants_.size();
participants_.push_back(participant);
if (all_participants_present()) {
@@ -205,20 +279,44 @@ Status GlobalRendezvousManager::SubmitParticipant(ParticipantData participant) {
VLOG(3) << "Primary initializing accounting data.";
initialized_ = true;
done_.emplace(participant.replica_count);
initialize_status_ = InitializeCommunicationChannels();
// Check if all participants_ are in comms_. If not, we will rebuild the
// clique to include them. (This can't be spelled using absl::c_any_of
// because it needs to touch comms_ and tensorflow::mutex lacks an
// AssertHeld() function that would let us assert that the lambda is run
// while holding the lock.)
bool new_devices_found = false;
for (const auto& p : participants_) {
if (!comms_.contains(p.device_ordinal)) {
new_devices_found = true;
break;
}
}
if (new_devices_found) {
absl::flat_hash_set<int> new_clique_device_ordinals;
for (const auto& kv : comms_) {
new_clique_device_ordinals.insert(kv.first);
}
for (const auto& p : participants_) {
new_clique_device_ordinals.insert(p.device_ordinal);
}
initialize_status_ = ReinitializeNcclClique(new_clique_device_ordinals);
VLOG(3) << "Done initializing communication channels; status: "
<< initialize_status_;
if (!initialize_status_.ok()) {
DeinitializeGeneration();
}
}
}
if (!initialize_status_.ok()) {
// TODO(b/125951860): If this fails once, it will fail forever.
return initialize_status_;
}
comm = comms_[index];
comm = *comms_.at(participant.device_ordinal).nccl_comm;
// Drop the lock at the end of scope so other participants may enter.
}
@@ -259,22 +357,30 @@ Status GlobalRendezvousManager::SubmitParticipant(ParticipantData participant) {
return all_reduce_status;
}
Status GlobalRendezvousManager::InitializeCommunicationChannels() {
std::vector<int> ordinals;
for (ParticipantData& data : participants_) {
ordinals.push_back(data.device_ordinal);
}
comms_.resize(ordinals.size());
VLOG(3) << "Participants: " << participants_.size()
<< "; initializing comms.";
ncclResult_t result = ncclCommInitAll(comms_.data(), comms_.size(),
/*devlist=*/ordinals.data());
if (result != ncclSuccess) {
Status GlobalRendezvousManager::ReinitializeNcclClique(
const absl::flat_hash_set<int>& device_ordinals) {
comms_.clear();
std::vector<int> ordinals_vec(device_ordinals.begin(), device_ordinals.end());
std::vector<ncclComm_t> comm_vec;
comm_vec.resize(device_ordinals.size());
VLOG(3) << absl::StreamFormat(
"Initializing nccl comms for participant devices {%s}",
absl::StrJoin(ordinals_vec, ", "));
ncclResult_t result = ncclCommInitAll(comm_vec.data(), comm_vec.size(),
/*devlist=*/ordinals_vec.data());
if (result != ncclSuccess) {
return InternalError(
"Failed to initialize NCCL communication channels for %d participants: "
"%s",
participants_.size(), ncclGetErrorString(result));
ordinals_vec.size(), ncclGetErrorString(result));
}
for (int64 i = 0; i < ordinals_vec.size(); ++i) {
VLOG(3) << absl::StreamFormat("Device ordinal %d assigned ncclComm %p",
ordinals_vec[i], comm_vec[i]);
CHECK(comms_.emplace(ordinals_vec[i], Comm{comm_vec[i]}).second);
}
return Status::OK();
}
@@ -289,6 +395,11 @@ Status GlobalRendezvousManager::DoAllReduce(ParticipantData participant,
<< " on device: " << participant.device_ordinal;
void* send_buffer = participant.source_data.opaque();
void* recv_buffer = participant.destination_data.opaque();
VLOG(3) << absl::StreamFormat(
"Calling ncclAllReduce(send_buffer=%p, recv_buffer=%p, count=%d, "
"datatype=ncclFloat, op=ncclSum, comm=%p, stream=%p)",
send_buffer, recv_buffer, participant.element_count,
static_cast<const void*>(comm), cu_stream);
ncclResult_t result = ncclAllReduce(send_buffer, recv_buffer,
/*count=*/participant.element_count,
/*datatype=*/ncclFloat,
@@ -304,6 +415,36 @@ Status GlobalRendezvousManager::DoAllReduce(ParticipantData participant,
return Status::OK();
}
void GlobalRendezvousManager::AddrefParticipatingDevice(int device_ordinal) {
// Addref'ing a device doesn't do anything other than increment its refcount.
// We'll update our nccl clique if necessary during the next call to
// SubmitParticipant.
tensorflow::mutex_lock lock(mutex_);
device_refcounts_[device_ordinal]++;
}
void GlobalRendezvousManager::DecrefParticipatingDevices(
absl::Span<const int> device_ordinals) {
// Decref'ing devices causes us to destroy the nccl clique if any devices were
// removed due to having refcount 0. We'll rebuild the new, smaller clique
// during the next call to SubmitParticipant.
tensorflow::mutex_lock lock(mutex_);
bool removed_device = false;
for (int device_ordinal : device_ordinals) {
auto it = device_refcounts_.find(device_ordinal);
CHECK(it != device_refcounts_.end());
it->second--;
if (it->second == 0) {
device_refcounts_.erase(it);
removed_device = true;
}
}
if (removed_device) {
comms_.clear();
}
}
static GlobalRendezvousManager* GetGlobalRendezvous() {
static auto* manager = new GlobalRendezvousManager;
return manager;
@@ -311,6 +452,11 @@ static GlobalRendezvousManager* GetGlobalRendezvous() {
} // namespace
/*static*/ absl::flat_hash_set<int>
NcclAllReduceThunk::DevicesWithOpenNcclChannels() {
return GetGlobalRendezvous()->DevicesWithOpenNcclChannels();
}
Status NcclAllReduceThunk::ExecuteOnStream(
const BufferAllocations& buffer_allocations, se::Stream* stream,
HloExecutionProfiler* profiler) {
@@ -327,8 +473,32 @@ Status NcclAllReduceThunk::ExecuteOnStream(
participant.stream = stream;
participant.originator = this;
// We currently say that all GPUs this thunk has ever touched are
// "actively participating" in cross-device operations, until the thunk itself
// is destroyed.
//
// This policy is an attempt to avoid thrashing the GPU (ncclCommInitAll is
// very expensive) while also freeing resources on the GPUs when we can. The
// idea is, creating new thunks is tantamount to running the whole XLA:GPU
// compiler stack, so that shouldn't happen terribly often.
bool new_device;
{
tensorflow::mutex_lock lock(mu_);
new_device = devices_seen_.insert(participant.device_ordinal).second;
}
if (new_device) {
GetGlobalRendezvous()->AddrefParticipatingDevice(
participant.device_ordinal);
}
return GetGlobalRendezvous()->SubmitParticipant(std::move(participant));
}
NcclAllReduceThunk::~NcclAllReduceThunk() {
GetGlobalRendezvous()->DecrefParticipatingDevices(
std::vector<int>(devices_seen_.begin(), devices_seen_.end()));
}
#else
Status NcclAllReduceThunk::ExecuteOnStream(
@@ -339,6 +509,13 @@ Status NcclAllReduceThunk::ExecuteOnStream(
"compiler, which is necessary to build the NCCL source library.");
}
NcclAllReduceThunk::~NcclAllReduceThunk() = default;
/*static*/ absl::flat_hash_set<int>
NcclAllReduceThunk::DevicesWithOpenNcclChannels() {
return {};
}
#endif // GOOGLE_CUDA
NcclAllReduceThunk::NcclAllReduceThunk(


@@ -16,11 +16,13 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_NCCL_ALL_REDUCE_THUNK_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_NCCL_ALL_REDUCE_THUNK_H_
#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
#include "tensorflow/compiler/xla/service/gpu/thunk.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/platform/types.h"
@@ -38,12 +40,21 @@ class NcclAllReduceThunk : public Thunk {
// error.
static bool NcclIsEnabled();
// Gets the set of devices that have a NCCL channel open. This is primarily
// for testing.
//
// (Indeed, because the NCCL channels are a global variable, in the real
// world, the value returned here is stale as soon as you read it, so it's not
// clear how you *could* use it for anything other than tests.)
static absl::flat_hash_set<int> DevicesWithOpenNcclChannels();
// TODO(b/125951860): Plumb more datatypes / reduction operators. Initial
// implementation is simply F32 summation.
NcclAllReduceThunk(int64 replica_count, int64 element_count,
const BufferAllocation::Slice& source_buffer,
const BufferAllocation::Slice& destination_buffer,
const HloInstruction* all_reduce);
~NcclAllReduceThunk() override;
Status ExecuteOnStream(const BufferAllocations& buffer_allocations,
se::Stream* stream,
@@ -54,6 +65,10 @@ class NcclAllReduceThunk : public Thunk {
const int64 element_count_;
const BufferAllocation::Slice source_buffer_;
const BufferAllocation::Slice destination_buffer_;
tensorflow::mutex mu_;
// Set of GPUs that ExecuteOnStream has been called on.
absl::flat_hash_set<int> devices_seen_ GUARDED_BY(mu_);
};
} // namespace gpu


@@ -273,6 +273,12 @@ StatusOr<std::vector<Literal>> HloRunner::ExecuteReplicated(
TF_ASSIGN_OR_RETURN(
std::unique_ptr<Executable> executable,
CreateExecutable(std::move(module), options.run_hlo_passes));
return ExecuteReplicated(executable.get(), options, device_assignment);
}
StatusOr<std::vector<Literal>> HloRunner::ExecuteReplicated(
Executable* executable, const ReplicatedExecuteOptions& options,
DeviceAssignment* device_assignment, ExecutionProfile* profile) {
std::vector<std::unique_ptr<se::Stream>> streams;
std::vector<ServiceExecutableRunOptions> service_run_options;


@@ -183,6 +183,15 @@ class HloRunner {
const ReplicatedExecuteOptions& options,
DeviceAssignment* device_assignment);
// Same as above, but with a reusable Executable. This may update the profile
// information in *executable.
//
// Note that this call ignores ReplicatedExecuteOptions::run_hlo_passes,
// since we've already compiled the Executable.
StatusOr<std::vector<Literal>> ExecuteReplicated(
Executable* executable, const ReplicatedExecuteOptions& options,
DeviceAssignment* device_assignment, ExecutionProfile* profile = nullptr);
// If backend is not created in the constructor, creates and returns the
// default backend. If creation fails, crashes the program.
//


@@ -14,20 +14,31 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/core/lib/core/status_test_util.h"
// Tests cross-GPU all-reduce operations.
//
// This test requires multiple GPUs. For instructions on running this within
// Google, see go/multi-gpu-unit-test.
namespace xla {
namespace {
class MultiDeviceAllReduceTest : public HloTestBase {};
using ::testing::IsEmpty;
using ::testing::UnorderedElementsAre;
XLA_TEST_F(MultiDeviceAllReduceTest, TwoReplicasOneOperand) {
const char* module_str = R"(
class MultiDeviceAllReduceTest : public HloTestBase {
protected:
std::unique_ptr<HloModule> MakeCrsModule(int64 num_elems,
const HloModuleConfig& config) {
const char* kTemplate = R"(
HloModule test
add {
@@ -37,12 +48,52 @@ XLA_TEST_F(MultiDeviceAllReduceTest, TwoReplicasOneOperand) {
}
ENTRY test_computation {
p = f32[3] parameter(0)
ROOT crs = f32[3] all-reduce(p), to_apply=add
})";
p = f32[NUM_ELEMS] parameter(0)
ROOT crs = f32[NUM_ELEMS] all-reduce(p), to_apply=add
}
)";
return ParseHloString(
absl::StrReplaceAll(kTemplate,
{{"NUM_ELEMS", absl::StrCat(num_elems)}}),
config)
.ValueOrDie();
}
};
// Returns the non-empty subsets of {0, 1, ..., n-1}. For example,
// PowerSetOfIota(3) = {{0}, {1}, {2}, {0,1}, {0,2}, {1,2}, {0,1,2}}.
std::vector<std::vector<int64>> PowerSetOfIota(int64 n) {
std::vector<std::vector<int64>> power_set;
for (int64 i = 1; i < (1 << n); ++i) {
power_set.emplace_back();
for (int64 j = 0; j < n; ++j) {
if (i & (1 << j)) {
power_set.back().push_back(j);
}
}
}
return power_set;
}
// Makes a DeviceAssignment assigning replica-id i to devices[i].
DeviceAssignment MakeDeviceAssn(std::vector<int64> devices) {
DeviceAssignment assn(/*replica_count=*/devices.size(),
/*computation_count=*/1);
for (int64 i = 0; i < devices.size(); ++i) {
assn(i, 0) = devices[i];
}
return assn;
}
// Shorter alias for NcclAllReduceThunk::DevicesWithOpenNcclChannels().
absl::flat_hash_set<int> OpenNcclChannels() {
return gpu::NcclAllReduceThunk::DevicesWithOpenNcclChannels();
}
XLA_TEST_F(MultiDeviceAllReduceTest, TwoReplicasOneOperand) {
auto config = GetModuleConfigForTest();
config.set_replica_count(2);
auto module = ParseHloString(module_str, config).ValueOrDie();
auto module = MakeCrsModule(/*num_elems=*/3, config);
auto literal = LiteralUtil::CreateR1<float>({1, 2, 3});
auto expected = LiteralUtil::CreateR1<float>({2, 4, 6});
TF_ASSERT_OK_AND_ASSIGN(std::vector<Literal> results,
@@ -52,5 +103,112 @@ XLA_TEST_F(MultiDeviceAllReduceTest, TwoReplicasOneOperand) {
EXPECT_EQ(expected, results[1]);
}
// Runs all-reduce operations across all 2^kNumDevices - 1 non-empty
// combinations of devices, in sequence.
XLA_TEST_F(MultiDeviceAllReduceTest, AllCombinations) {
const int64 kNumDevices = 4;
const int64 kNumElems = 1024;
for (std::vector<int64> devices : PowerSetOfIota(kNumDevices)) {
SCOPED_TRACE(absl::StrFormat("Running on devices {%s}",
absl::StrJoin(devices, ", ")));
DeviceAssignment device_assn = MakeDeviceAssn(devices);
auto config = GetModuleConfigForTest();
config.set_replica_count(devices.size());
config.set_static_device_assignment(device_assn);
auto module = MakeCrsModule(kNumElems, config);
std::vector<float> input_vec(kNumElems);
absl::c_iota(input_vec, 0);
auto input_literal = LiteralUtil::CreateR1<float>(input_vec);
TF_ASSERT_OK_AND_ASSIGN(
std::vector<Literal> results,
ExecuteReplicated(std::move(module), {&input_literal},
/*num_replicas=*/devices.size(), &device_assn,
/*run_hlo_passes=*/true, /*use_threads=*/true));
}
}
// Check that the NCCL data structures in our all-reduce implementation are
// cached as we expect.
XLA_TEST_F(MultiDeviceAllReduceTest, NcclChannelCaching) {
const int64 kNumElems = 1024;
std::vector<float> input_vec(kNumElems);
absl::c_iota(input_vec, 0);
auto input_literal = LiteralUtil::CreateR1<float>(input_vec);
// Initially no NCCL channels should be open.
EXPECT_THAT(OpenNcclChannels(), IsEmpty());
// Create three Executables, touching devices {0,1}, {1,2}, and {0,1,2}.
struct ExecutableInfo {
std::unique_ptr<Executable> executable;
DeviceAssignment device_assn;
HloRunner::ReplicatedExecuteOptions opts;
};
std::vector<ExecutableInfo> executables;
for (const auto& devices :
std::vector<std::vector<int64>>{{0, 1}, {1, 2}, {0, 1, 2}}) {
executables.emplace_back();
auto& e = executables.back();
e.device_assn = MakeDeviceAssn(devices);
auto config = GetModuleConfigForTest();
config.set_replica_count(devices.size());
config.set_static_device_assignment(e.device_assn);
auto module = MakeCrsModule(kNumElems, config);
e.executable =
test_runner_
.CreateExecutable(std::move(module), /*run_hlo_passes=*/true)
.ValueOrDie();
e.opts.num_replicas = devices.size();
e.opts.use_threads = true;
e.opts.arguments.push_back(&input_literal);
}
auto run_executable = [&](int64 i) {
auto& e = executables[i];
TF_ASSERT_OK(
test_runner_
.ExecuteReplicated(e.executable.get(), e.opts, &e.device_assn)
.status());
};
// Compiling executables above shouldn't cause us to open any channels.
EXPECT_THAT(OpenNcclChannels(), IsEmpty());
// Run the executables and check that channels are opened as we expect.
run_executable(0);
EXPECT_THAT(OpenNcclChannels(), UnorderedElementsAre(0, 1));
run_executable(2);
EXPECT_THAT(OpenNcclChannels(), UnorderedElementsAre(0, 1, 2));
run_executable(1);
EXPECT_THAT(OpenNcclChannels(), UnorderedElementsAre(0, 1, 2));
// Tear down the executables and check that channels are closed as we expect.
// Note that after we tear down an executable *all* the nccl channels may go
// away, so we rerun all of the executables that haven't been torn down.
executables[2].executable.reset();
run_executable(0);
run_executable(1);
EXPECT_THAT(OpenNcclChannels(), UnorderedElementsAre(0, 1, 2));
executables[0].executable.reset();
run_executable(1);
EXPECT_THAT(OpenNcclChannels(), UnorderedElementsAre(1, 2));
executables[1].executable.reset();
EXPECT_THAT(OpenNcclChannels(), IsEmpty());
}
} // namespace
} // namespace xla