[XLA:GPU] Improve performance of AllReduce by caching nccl channels.
Previously we'd tear down nccl communication channels after every all-reduce operation. Now we keep the channel alive for what we hope will be a reasonable period of time. PiperOrigin-RevId: 246746547
This commit is contained in:
parent
274062297a
commit
acc20217e6
@ -18,6 +18,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/util.h"
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#include "absl/container/flat_hash_set.h"
|
||||
#include "absl/synchronization/blocking_counter.h"
|
||||
#include "third_party/nccl/nccl.h"
|
||||
#include "tensorflow/core/lib/core/blocking_counter.h"
|
||||
@ -76,6 +77,42 @@ struct ParticipantData {
|
||||
// This manager is responsible for establishing communication channels and
|
||||
// ultimately enqueueing the NCCL library operation onto the participating
|
||||
// streams.
|
||||
//
|
||||
// Implementation note: We make an effort to avoid initializing nccl
|
||||
// communciation channels too often, as this is expensive.
|
||||
//
|
||||
// Ideally, we'd set up a nccl channel between each pair of devices that needs
|
||||
// to communicate, and close each channel when the GPUs won't be communicating
|
||||
// again "for a long time" (because channels hold memory on the GPU). As a
|
||||
// simplification to this ideal, we adopt the following policy.
|
||||
//
|
||||
// - We maintain a set of GPUs that are "actively participating" in
|
||||
// cross-device communications. That set of GPUs is always connected as a
|
||||
// clique, using ncclCommInitAll.
|
||||
//
|
||||
// - When a NcclAllReduceThunk touches a new GPU, we tear down the old clique
|
||||
// and build a new, bigger one.
|
||||
//
|
||||
// - All GPUs ever touched by a thunk are considered "actively in use" by that
|
||||
// thunk until the thunk is destroyed. Destroying the thunk decrements the
|
||||
// refcount of the GPUs it's touched, and if that refcount goes to 0
|
||||
// (meaning, some GPUs are no longer in use by any thunk), we tear down the
|
||||
// clique and build a new, smaller one.
|
||||
//
|
||||
// This approximation is justified because:
|
||||
//
|
||||
// - Currently the only collective operation we support is AllReduce, which
|
||||
// requires a clique. When we support point-to-point operations, we may not
|
||||
// want to build a communication clique.
|
||||
//
|
||||
// - Tearing down and creating a new thunk is tantamount to running the whole
|
||||
// XLA:GPU compiler. This is expensive, so shouldn't happen "too often" to
|
||||
// cause thrashing here.
|
||||
//
|
||||
// - XLA executables already keep resources on the GPU tied to the lifetime of
|
||||
// the executable (e.g. constants stored in GPU memory), so tying the
|
||||
// lifetime of the nccl communication channels to the lifetime of the
|
||||
// executable is consistent.
|
||||
class GlobalRendezvousManager {
|
||||
public:
|
||||
// The GpuExecutable-executing threads call this in order to a) establish the
|
||||
@ -98,18 +135,38 @@ class GlobalRendezvousManager {
|
||||
return current_generation_;
|
||||
}
|
||||
|
||||
private:
|
||||
// Called by the primary thread to set up the communication links.
|
||||
// Increments the refcount of a GPU in our accounting of which devices are
|
||||
// "actively participating" in cross-device operations.
|
||||
//
|
||||
// TODO(b/125951860): This performs lots of (presumably) unnecessary host-side
|
||||
// synchronization so that we can be paranoid about semantics in the earliest
|
||||
// implementation. In the limit we should only need to synchronize host
|
||||
// replica threads when the "number of replicas" or "participating device
|
||||
// ordinals" change, to set up a new NCCL "communication" context, at which
|
||||
// point we can enqueue onto device streams without host synchronization in
|
||||
// our code -- this will likely be helpful for "lots of little AllReduce"
|
||||
// cases.
|
||||
Status InitializeCommunicationChannels() EXCLUSIVE_LOCKS_REQUIRED(mutex_);
|
||||
// This doesn't actually do anything other than increment the refcount. If
|
||||
// the GPU added here is novel, we'll rebuild the nccl communication clique
|
||||
// when we actually go do the communication.
|
||||
void AddrefParticipatingDevice(int device_ordinal);
|
||||
|
||||
// Decrements the refcount of a set of GPUs in our accounting of which devices
|
||||
// are "actively participating" in cross-device operations.
|
||||
//
|
||||
// If one or more GPUs' refcounts to go 0, we immediately destroy the whole
|
||||
// nccl communication clique. We'll rebuild a new, smaller clique the next
|
||||
// time it's used.
|
||||
void DecrefParticipatingDevices(absl::Span<const int> device_ordinals);
|
||||
|
||||
// Gets the set of devices that have a NCCL channel currently open. This is
|
||||
// primarily for testing.
|
||||
absl::flat_hash_set<int> DevicesWithOpenNcclChannels() const {
|
||||
absl::flat_hash_set<int> devices;
|
||||
tensorflow::mutex_lock lock(mutex_);
|
||||
for (const auto& kv : comms_) {
|
||||
devices.insert(kv.first);
|
||||
}
|
||||
return devices;
|
||||
}
|
||||
|
||||
private:
|
||||
// Destroys the current nccl communication clique and builds a new one
|
||||
// connecting the given devices.
|
||||
Status ReinitializeNcclClique(const absl::flat_hash_set<int>& device_ordinals)
|
||||
EXCLUSIVE_LOCKS_REQUIRED(mutex_);
|
||||
|
||||
// Called when all necessary participants are present, the functionality
|
||||
// that's implemented by all executing threads lives in here.
|
||||
@ -118,28 +175,51 @@ class GlobalRendezvousManager {
|
||||
// Puts all state back into a "reset" state for the next generation of
|
||||
// AllReduce requests.
|
||||
void DeinitializeGeneration() EXCLUSIVE_LOCKS_REQUIRED(mutex_) {
|
||||
for (ncclComm_t& comm : comms_) {
|
||||
ncclCommDestroy(comm);
|
||||
}
|
||||
comms_.clear();
|
||||
participants_.clear();
|
||||
current_generation_++;
|
||||
initialized_ = false;
|
||||
done_ = absl::nullopt;
|
||||
}
|
||||
|
||||
tensorflow::mutex mutex_;
|
||||
mutable tensorflow::mutex mutex_;
|
||||
tensorflow::condition_variable all_participants_present_;
|
||||
tensorflow::condition_variable deinitialized_;
|
||||
|
||||
// Communication handles that correspond to the participants below.
|
||||
std::vector<ncclComm_t> comms_ GUARDED_BY(mutex_);
|
||||
|
||||
Status initialize_status_ GUARDED_BY(mutex_);
|
||||
std::vector<ParticipantData> participants_ GUARDED_BY(mutex_);
|
||||
int64 current_generation_ GUARDED_BY(mutex_) = 0;
|
||||
bool initialized_ GUARDED_BY(mutex_) = false;
|
||||
|
||||
struct Comm {
|
||||
explicit Comm(ncclComm_t nccl_comm) : nccl_comm(nccl_comm) {}
|
||||
|
||||
// Movable, but not copyable.
|
||||
Comm(Comm&& c) : nccl_comm(c.nccl_comm) { c.nccl_comm.reset(); }
|
||||
Comm& operator=(Comm&& c) {
|
||||
nccl_comm = c.nccl_comm;
|
||||
c.nccl_comm.reset();
|
||||
return *this;
|
||||
}
|
||||
Comm(const Comm&) = delete;
|
||||
Comm& operator=(const Comm&) = delete;
|
||||
|
||||
absl::optional<ncclComm_t> nccl_comm;
|
||||
|
||||
~Comm() {
|
||||
if (nccl_comm.has_value()) {
|
||||
VLOG(3) << absl::StreamFormat("Destroying comm %p", *nccl_comm);
|
||||
ncclCommDestroy(*nccl_comm);
|
||||
}
|
||||
}
|
||||
};
|
||||
// Communication handles for our NCCL clique. Key is device ordinal.
|
||||
absl::flat_hash_map<int, Comm> comms_ GUARDED_BY(mutex_);
|
||||
|
||||
// Refcounts of which devices are "actively participating" in all-reduces.
|
||||
// These devices don't necessarily have an open comm, but the next time we run
|
||||
// an operation, we'll create a NCCL clique between all of them.
|
||||
absl::flat_hash_map<int, int64> device_refcounts_ GUARDED_BY(mutex_);
|
||||
|
||||
// The participating threads wait for this to count down in order to know we
|
||||
// can begin the teardown process.
|
||||
absl::optional<tensorflow::BlockingCounter> done_;
|
||||
@ -151,11 +231,6 @@ Status GlobalRendezvousManager::SubmitParticipant(ParticipantData participant) {
|
||||
return participants_.size() >= participant.replica_count;
|
||||
};
|
||||
|
||||
// We remember the participant index at which we are inserted and use that
|
||||
// same index for referring to auxiliary metadata (e.g. the ncclComm_t handle
|
||||
// index) below.
|
||||
int64 index;
|
||||
|
||||
{
|
||||
tensorflow::mutex_lock lock(mutex_);
|
||||
|
||||
@ -171,7 +246,6 @@ Status GlobalRendezvousManager::SubmitParticipant(ParticipantData participant) {
|
||||
"participants; existing: %s; submitted: %s)",
|
||||
participants_.back().ToString(), participant.ToString());
|
||||
}
|
||||
index = participants_.size();
|
||||
participants_.push_back(participant);
|
||||
|
||||
if (all_participants_present()) {
|
||||
@ -205,20 +279,44 @@ Status GlobalRendezvousManager::SubmitParticipant(ParticipantData participant) {
|
||||
VLOG(3) << "Primary initializing accounting data.";
|
||||
initialized_ = true;
|
||||
done_.emplace(participant.replica_count);
|
||||
initialize_status_ = InitializeCommunicationChannels();
|
||||
|
||||
// Check if all participants_ are in comms_. If not, we will rebuild the
|
||||
// clique to include them. (This can't be spelled using absl::c_any_of
|
||||
// because it needs to touch comms_ and tensorflow::mutex lacks an
|
||||
// AssertHeld() function that would let us assert that the lambda is run
|
||||
// while holding the lock.)
|
||||
bool new_devices_found = false;
|
||||
for (const auto& p : participants_) {
|
||||
if (!comms_.contains(p.device_ordinal)) {
|
||||
new_devices_found = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (new_devices_found) {
|
||||
absl::flat_hash_set<int> new_clique_device_ordinals;
|
||||
for (const auto& kv : comms_) {
|
||||
new_clique_device_ordinals.insert(kv.first);
|
||||
}
|
||||
for (const auto& p : participants_) {
|
||||
new_clique_device_ordinals.insert(p.device_ordinal);
|
||||
}
|
||||
|
||||
initialize_status_ = ReinitializeNcclClique(new_clique_device_ordinals);
|
||||
VLOG(3) << "Done initializing communication channels; status: "
|
||||
<< initialize_status_;
|
||||
if (!initialize_status_.ok()) {
|
||||
DeinitializeGeneration();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (!initialize_status_.ok()) {
|
||||
// TODO(b/125951860): If this fails once, it will fail forever.
|
||||
return initialize_status_;
|
||||
}
|
||||
|
||||
comm = comms_[index];
|
||||
comm = *comms_.at(participant.device_ordinal).nccl_comm;
|
||||
|
||||
// Drop the lock at the end of scope so other participants may enter.
|
||||
}
|
||||
@ -259,22 +357,30 @@ Status GlobalRendezvousManager::SubmitParticipant(ParticipantData participant) {
|
||||
return all_reduce_status;
|
||||
}
|
||||
|
||||
Status GlobalRendezvousManager::InitializeCommunicationChannels() {
|
||||
std::vector<int> ordinals;
|
||||
for (ParticipantData& data : participants_) {
|
||||
ordinals.push_back(data.device_ordinal);
|
||||
}
|
||||
comms_.resize(ordinals.size());
|
||||
VLOG(3) << "Participants: " << participants_.size()
|
||||
<< "; initializing comms.";
|
||||
ncclResult_t result = ncclCommInitAll(comms_.data(), comms_.size(),
|
||||
/*devlist=*/ordinals.data());
|
||||
if (result != ncclSuccess) {
|
||||
Status GlobalRendezvousManager::ReinitializeNcclClique(
|
||||
const absl::flat_hash_set<int>& device_ordinals) {
|
||||
comms_.clear();
|
||||
|
||||
std::vector<int> ordinals_vec(device_ordinals.begin(), device_ordinals.end());
|
||||
std::vector<ncclComm_t> comm_vec;
|
||||
comm_vec.resize(device_ordinals.size());
|
||||
|
||||
VLOG(3) << absl::StreamFormat(
|
||||
"Initializing nccl comms for participant devices {%s}",
|
||||
absl::StrJoin(ordinals_vec, ", "));
|
||||
ncclResult_t result = ncclCommInitAll(comm_vec.data(), comm_vec.size(),
|
||||
/*devlist=*/ordinals_vec.data());
|
||||
if (result != ncclSuccess) {
|
||||
return InternalError(
|
||||
"Failed to initialize NCCL communication channels for %d participants: "
|
||||
"%s",
|
||||
participants_.size(), ncclGetErrorString(result));
|
||||
ordinals_vec.size(), ncclGetErrorString(result));
|
||||
}
|
||||
|
||||
for (int64 i = 0; i < ordinals_vec.size(); ++i) {
|
||||
VLOG(3) << absl::StreamFormat("Device ordinal %d assigned ncclComm %p",
|
||||
ordinals_vec[i], comm_vec[i]);
|
||||
CHECK(comms_.emplace(ordinals_vec[i], Comm{comm_vec[i]}).second);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
@ -289,6 +395,11 @@ Status GlobalRendezvousManager::DoAllReduce(ParticipantData participant,
|
||||
<< " on device: " << participant.device_ordinal;
|
||||
void* send_buffer = participant.source_data.opaque();
|
||||
void* recv_buffer = participant.destination_data.opaque();
|
||||
VLOG(3) << absl::StreamFormat(
|
||||
"Calling ncclAllReduce(send_buffer=%p, recv_buffer=%p, count=%d, "
|
||||
"datatype=ncclFloat, op=ncclSum, comm=%p, stream=%p)",
|
||||
send_buffer, recv_buffer, participant.element_count,
|
||||
static_cast<const void*>(comm), cu_stream);
|
||||
ncclResult_t result = ncclAllReduce(send_buffer, recv_buffer,
|
||||
/*count=*/participant.element_count,
|
||||
/*datatype=*/ncclFloat,
|
||||
@ -304,6 +415,36 @@ Status GlobalRendezvousManager::DoAllReduce(ParticipantData participant,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
void GlobalRendezvousManager::AddrefParticipatingDevice(int device_ordinal) {
|
||||
// Addref'ing a device doesn't do anything other than increment its refcount.
|
||||
// We'll update our nccl clique if necessary during the next call to
|
||||
// SubmitParticipant.
|
||||
tensorflow::mutex_lock lock(mutex_);
|
||||
device_refcounts_[device_ordinal]++;
|
||||
}
|
||||
|
||||
void GlobalRendezvousManager::DecrefParticipatingDevices(
|
||||
absl::Span<const int> device_ordinals) {
|
||||
// Decref'ing devices causes us to destroy the nccl clique if any devices were
|
||||
// removed due to having refcount 0. We'll rebuild the new, smaller clique
|
||||
// during the next call to SubmitParticipant.
|
||||
tensorflow::mutex_lock lock(mutex_);
|
||||
bool removed_device = false;
|
||||
for (int device_ordinal : device_ordinals) {
|
||||
auto it = device_refcounts_.find(device_ordinal);
|
||||
CHECK(it != device_refcounts_.end());
|
||||
it->second--;
|
||||
if (it->second == 0) {
|
||||
device_refcounts_.erase(it);
|
||||
removed_device = true;
|
||||
}
|
||||
}
|
||||
|
||||
if (removed_device) {
|
||||
comms_.clear();
|
||||
}
|
||||
}
|
||||
|
||||
static GlobalRendezvousManager* GetGlobalRendezvous() {
|
||||
static auto* manager = new GlobalRendezvousManager;
|
||||
return manager;
|
||||
@ -311,6 +452,11 @@ static GlobalRendezvousManager* GetGlobalRendezvous() {
|
||||
|
||||
} // namespace
|
||||
|
||||
/*static*/ absl::flat_hash_set<int>
|
||||
NcclAllReduceThunk::DevicesWithOpenNcclChannels() {
|
||||
return GetGlobalRendezvous()->DevicesWithOpenNcclChannels();
|
||||
}
|
||||
|
||||
Status NcclAllReduceThunk::ExecuteOnStream(
|
||||
const BufferAllocations& buffer_allocations, se::Stream* stream,
|
||||
HloExecutionProfiler* profiler) {
|
||||
@ -327,8 +473,32 @@ Status NcclAllReduceThunk::ExecuteOnStream(
|
||||
participant.stream = stream;
|
||||
participant.originator = this;
|
||||
|
||||
// We currently say that that all GPUs this thunk has ever touched are
|
||||
// "actively participating" in cross-device operations, until the thunk itself
|
||||
// is destroyed.
|
||||
//
|
||||
// This policy is an attempt to avoid thrashing the GPU (ncclCommInitAll is
|
||||
// very expensive) while also freeing resources on the GPUs when we can. The
|
||||
// idea is, creating new thunks is tantamount to running the whole XLA:GPU
|
||||
// compiler stack, so that shouldn't happen terribly often.
|
||||
bool new_device;
|
||||
{
|
||||
tensorflow::mutex_lock lock(mu_);
|
||||
new_device = devices_seen_.insert(participant.device_ordinal).second;
|
||||
}
|
||||
if (new_device) {
|
||||
GetGlobalRendezvous()->AddrefParticipatingDevice(
|
||||
participant.device_ordinal);
|
||||
}
|
||||
|
||||
return GetGlobalRendezvous()->SubmitParticipant(std::move(participant));
|
||||
}
|
||||
|
||||
NcclAllReduceThunk::~NcclAllReduceThunk() {
|
||||
GetGlobalRendezvous()->DecrefParticipatingDevices(
|
||||
std::vector<int>(devices_seen_.begin(), devices_seen_.end()));
|
||||
}
|
||||
|
||||
#else
|
||||
|
||||
Status NcclAllReduceThunk::ExecuteOnStream(
|
||||
@ -339,6 +509,13 @@ Status NcclAllReduceThunk::ExecuteOnStream(
|
||||
"compiler, which is necessary to build the NCCL source library.");
|
||||
}
|
||||
|
||||
NcclAllReduceThunk::~NcclAllReduceThunk() = default;
|
||||
|
||||
/*static*/ absl::flat_hash_set<int>
|
||||
NcclAllReduceThunk::DevicesWithOpenNcclChannels() {
|
||||
return {};
|
||||
}
|
||||
|
||||
#endif // GOOGLE_CUDA
|
||||
|
||||
NcclAllReduceThunk::NcclAllReduceThunk(
|
||||
|
@ -16,11 +16,13 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_NCCL_ALL_REDUCE_THUNK_H_
|
||||
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_NCCL_ALL_REDUCE_THUNK_H_
|
||||
|
||||
#include "absl/container/flat_hash_set.h"
|
||||
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
|
||||
#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
|
||||
#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
|
||||
#include "tensorflow/compiler/xla/service/gpu/thunk.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||
#include "tensorflow/core/platform/mutex.h"
|
||||
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
@ -38,12 +40,21 @@ class NcclAllReduceThunk : public Thunk {
|
||||
// error.
|
||||
static bool NcclIsEnabled();
|
||||
|
||||
// Gets the set of devices that have a NCCL channel open. This is primarily
|
||||
// for testing.
|
||||
//
|
||||
// (Indeed, because the NCCL channels are a global variable, in the real
|
||||
// world, the value returned here is stale as soon as you read it, so it's not
|
||||
// clear how you *could* use it for anything other than tests.)
|
||||
static absl::flat_hash_set<int> DevicesWithOpenNcclChannels();
|
||||
|
||||
// TODO(b/125951860): Plumb more datatypes / reduction operators. Initial
|
||||
// implementation is simply F32 summation.
|
||||
NcclAllReduceThunk(int64 replica_count, int64 element_count,
|
||||
const BufferAllocation::Slice& source_buffer,
|
||||
const BufferAllocation::Slice& destination_buffer,
|
||||
const HloInstruction* all_reduce);
|
||||
~NcclAllReduceThunk() override;
|
||||
|
||||
Status ExecuteOnStream(const BufferAllocations& buffer_allocations,
|
||||
se::Stream* stream,
|
||||
@ -54,6 +65,10 @@ class NcclAllReduceThunk : public Thunk {
|
||||
const int64 element_count_;
|
||||
const BufferAllocation::Slice source_buffer_;
|
||||
const BufferAllocation::Slice destination_buffer_;
|
||||
|
||||
tensorflow::mutex mu_;
|
||||
// Set of GPUs that ExecuteOnStream has been called on.
|
||||
absl::flat_hash_set<int> devices_seen_ GUARDED_BY(mu_);
|
||||
};
|
||||
|
||||
} // namespace gpu
|
||||
|
@ -273,6 +273,12 @@ StatusOr<std::vector<Literal>> HloRunner::ExecuteReplicated(
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
std::unique_ptr<Executable> executable,
|
||||
CreateExecutable(std::move(module), options.run_hlo_passes));
|
||||
return ExecuteReplicated(executable.get(), options, device_assignment);
|
||||
}
|
||||
|
||||
StatusOr<std::vector<Literal>> HloRunner::ExecuteReplicated(
|
||||
Executable* executable, const ReplicatedExecuteOptions& options,
|
||||
DeviceAssignment* device_assignment, ExecutionProfile* profile) {
|
||||
std::vector<std::unique_ptr<se::Stream>> streams;
|
||||
std::vector<ServiceExecutableRunOptions> service_run_options;
|
||||
|
||||
|
@ -183,6 +183,15 @@ class HloRunner {
|
||||
const ReplicatedExecuteOptions& options,
|
||||
DeviceAssignment* device_assignment);
|
||||
|
||||
// Same as above, but with a reusable Executable. This may update the profile
|
||||
// information in *executable.
|
||||
//
|
||||
// Note that this call ignores ReplicatedExecutionOptions::run_hlo_passes,
|
||||
// since we've already compiled the Executable.
|
||||
StatusOr<std::vector<Literal>> ExecuteReplicated(
|
||||
Executable* executable, const ReplicatedExecuteOptions& options,
|
||||
DeviceAssignment* device_assignment, ExecutionProfile* profile = nullptr);
|
||||
|
||||
// If backend is not created in the constructor, creates and returns the
|
||||
// default backend. If creation fails, crashes the program.
|
||||
//
|
||||
|
@ -14,20 +14,31 @@ limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/xla/literal.h"
|
||||
#include "tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_parser.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/test.h"
|
||||
#include "tensorflow/compiler/xla/test_helpers.h"
|
||||
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
|
||||
#include "tensorflow/compiler/xla/tests/test_macros.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
|
||||
// Tests cross-GPU all-reduce operatons.
|
||||
//
|
||||
// This test requires multiple GPUs. For instructions on running this within
|
||||
// Google, see go/multi-gpu-unit-test.
|
||||
|
||||
namespace xla {
|
||||
namespace {
|
||||
|
||||
class MultiDeviceAllReduceTest : public HloTestBase {};
|
||||
using ::testing::IsEmpty;
|
||||
using ::testing::UnorderedElementsAre;
|
||||
|
||||
XLA_TEST_F(MultiDeviceAllReduceTest, TwoReplicasOneOperand) {
|
||||
const char* module_str = R"(
|
||||
class MultiDeviceAllReduceTest : public HloTestBase {
|
||||
protected:
|
||||
std::unique_ptr<HloModule> MakeCrsModule(int64 num_elems,
|
||||
const HloModuleConfig& config) {
|
||||
const char* kTemplate = R"(
|
||||
HloModule test
|
||||
|
||||
add {
|
||||
@ -37,12 +48,52 @@ XLA_TEST_F(MultiDeviceAllReduceTest, TwoReplicasOneOperand) {
|
||||
}
|
||||
|
||||
ENTRY test_computation {
|
||||
p = f32[3] parameter(0)
|
||||
ROOT crs = f32[3] all-reduce(p), to_apply=add
|
||||
})";
|
||||
p = f32[NUM_ELEMS] parameter(0)
|
||||
ROOT crs = f32[NUM_ELEMS] all-reduce(p), to_apply=add
|
||||
}
|
||||
)";
|
||||
return ParseHloString(
|
||||
absl::StrReplaceAll(kTemplate,
|
||||
{{"NUM_ELEMS", absl::StrCat(num_elems)}}),
|
||||
config)
|
||||
.ValueOrDie();
|
||||
}
|
||||
};
|
||||
|
||||
// Returns the non-empty subsets of {0, 1, ..., n}. For example,
|
||||
// PowerSetOfIota(3) = {{0}, {1}, {2}, {0,1}, {0,2}, {1,2}, {0,1,2}}.
|
||||
std::vector<std::vector<int64>> PowerSetOfIota(int64 n) {
|
||||
std::vector<std::vector<int64>> power_set;
|
||||
for (int64 i = 1; i < (1 << n); ++i) {
|
||||
power_set.emplace_back();
|
||||
for (int64 j = 0; j < n; ++j) {
|
||||
if (i & (1 << j)) {
|
||||
power_set.back().push_back(j);
|
||||
}
|
||||
}
|
||||
}
|
||||
return power_set;
|
||||
}
|
||||
|
||||
// Makes a DeviceAssignment assigning replica-id i to devices[i].
|
||||
DeviceAssignment MakeDeviceAssn(std::vector<int64> devices) {
|
||||
DeviceAssignment assn(/*replica_count=*/devices.size(),
|
||||
/*computation_count=*/1);
|
||||
for (int64 i = 0; i < devices.size(); ++i) {
|
||||
assn(i, 0) = devices[i];
|
||||
}
|
||||
return assn;
|
||||
}
|
||||
|
||||
// Shorter alias for this function.
|
||||
absl::flat_hash_set<int> OpenNcclChannels() {
|
||||
return gpu::NcclAllReduceThunk::DevicesWithOpenNcclChannels();
|
||||
}
|
||||
|
||||
XLA_TEST_F(MultiDeviceAllReduceTest, TwoReplicasOneOperand) {
|
||||
auto config = GetModuleConfigForTest();
|
||||
config.set_replica_count(2);
|
||||
auto module = ParseHloString(module_str, config).ValueOrDie();
|
||||
auto module = MakeCrsModule(/*num_elems=*/3, config);
|
||||
auto literal = LiteralUtil::CreateR1<float>({1, 2, 3});
|
||||
auto expected = LiteralUtil::CreateR1<float>({2, 4, 6});
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::vector<Literal> results,
|
||||
@ -52,5 +103,112 @@ XLA_TEST_F(MultiDeviceAllReduceTest, TwoReplicasOneOperand) {
|
||||
EXPECT_EQ(expected, results[1]);
|
||||
}
|
||||
|
||||
// Tries all-to-all operations across all 2^kNumDevices - 1 combinations of
|
||||
// devices in sequence.
|
||||
XLA_TEST_F(MultiDeviceAllReduceTest, AllCombinations) {
|
||||
const int64 kNumDevices = 4;
|
||||
const int64 kNumElems = 1024;
|
||||
|
||||
for (std::vector<int64> devices : PowerSetOfIota(kNumDevices)) {
|
||||
SCOPED_TRACE(absl::StrFormat("Running on devices {%s}",
|
||||
absl::StrJoin(devices, ", ")));
|
||||
|
||||
DeviceAssignment device_assn = MakeDeviceAssn(devices);
|
||||
|
||||
auto config = GetModuleConfigForTest();
|
||||
config.set_replica_count(devices.size());
|
||||
config.set_static_device_assignment(device_assn);
|
||||
|
||||
auto module = MakeCrsModule(kNumElems, config);
|
||||
|
||||
std::vector<float> input_vec(kNumElems);
|
||||
absl::c_iota(input_vec, 0);
|
||||
auto input_literal = LiteralUtil::CreateR1<float>(input_vec);
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
std::vector<Literal> results,
|
||||
ExecuteReplicated(std::move(module), {&input_literal},
|
||||
/*num_replicas=*/devices.size(), &device_assn,
|
||||
/*run_hlo_passes=*/true, /*use_threads=*/true));
|
||||
}
|
||||
}
|
||||
|
||||
// Check that the NCCL data structures in our all-reduce implementation are
|
||||
// cached as we expect.
|
||||
XLA_TEST_F(MultiDeviceAllReduceTest, NcclChannelCaching) {
|
||||
const int64 kNumElems = 1024;
|
||||
|
||||
std::vector<float> input_vec(kNumElems);
|
||||
absl::c_iota(input_vec, 0);
|
||||
auto input_literal = LiteralUtil::CreateR1<float>(input_vec);
|
||||
|
||||
// Initially no NCCL channels should be open.
|
||||
EXPECT_THAT(OpenNcclChannels(), IsEmpty());
|
||||
|
||||
// Create three Executables, touching devices {0,1}, {1,2}, and {0,1,2}.
|
||||
struct ExecutableInfo {
|
||||
std::unique_ptr<Executable> executable;
|
||||
DeviceAssignment device_assn;
|
||||
HloRunner::ReplicatedExecuteOptions opts;
|
||||
};
|
||||
std::vector<ExecutableInfo> executables;
|
||||
for (const auto& devices :
|
||||
std::vector<std::vector<int64>>{{0, 1}, {1, 2}, {0, 1, 2}}) {
|
||||
executables.emplace_back();
|
||||
auto& e = executables.back();
|
||||
|
||||
e.device_assn = MakeDeviceAssn(devices);
|
||||
|
||||
auto config = GetModuleConfigForTest();
|
||||
config.set_replica_count(devices.size());
|
||||
config.set_static_device_assignment(e.device_assn);
|
||||
auto module = MakeCrsModule(kNumElems, config);
|
||||
e.executable =
|
||||
test_runner_
|
||||
.CreateExecutable(std::move(module), /*run_hlo_passes=*/true)
|
||||
.ValueOrDie();
|
||||
|
||||
e.opts.num_replicas = devices.size();
|
||||
e.opts.use_threads = true;
|
||||
e.opts.arguments.push_back(&input_literal);
|
||||
}
|
||||
|
||||
auto run_executable = [&](int64 i) {
|
||||
auto& e = executables[i];
|
||||
TF_ASSERT_OK(
|
||||
test_runner_
|
||||
.ExecuteReplicated(e.executable.get(), e.opts, &e.device_assn)
|
||||
.status());
|
||||
};
|
||||
|
||||
// Compiling executables above shouldn't cause us to open any channels.
|
||||
EXPECT_THAT(OpenNcclChannels(), IsEmpty());
|
||||
|
||||
// Run the executables and check that channels are opened as we expect.
|
||||
run_executable(0);
|
||||
EXPECT_THAT(OpenNcclChannels(), UnorderedElementsAre(0, 1));
|
||||
|
||||
run_executable(2);
|
||||
EXPECT_THAT(OpenNcclChannels(), UnorderedElementsAre(0, 1, 2));
|
||||
|
||||
run_executable(1);
|
||||
EXPECT_THAT(OpenNcclChannels(), UnorderedElementsAre(0, 1, 2));
|
||||
|
||||
// Tear down the executables and check that channels are closed as we expect.
|
||||
// Note that after we tear down an executable *all* the nccl channels may go
|
||||
// away, so we rerun all of the executables that haven't been torn down.
|
||||
executables[2].executable.reset();
|
||||
run_executable(0);
|
||||
run_executable(1);
|
||||
EXPECT_THAT(OpenNcclChannels(), UnorderedElementsAre(0, 1, 2));
|
||||
|
||||
executables[0].executable.reset();
|
||||
run_executable(1);
|
||||
EXPECT_THAT(OpenNcclChannels(), UnorderedElementsAre(1, 2));
|
||||
|
||||
executables[1].executable.reset();
|
||||
EXPECT_THAT(OpenNcclChannels(), IsEmpty());
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace xla
|
||||
|
Loading…
Reference in New Issue
Block a user