[XLA:GPU] Improve performance of AllReduce by caching nccl channels.
Previously we'd tear down nccl communication channels after every all-reduce operation. Now we keep the channels alive for what we hope will be a reasonable period of time.

PiperOrigin-RevId: 246746547
commit acc20217e6
parent 274062297a
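
Before the diff itself, a minimal sketch of the idea, assuming nothing beyond the commit message: keep communicators in a refcounted, per-device table and tear the clique down only when some device is no longer referenced by any live thunk. The names here (CommCache, Handle) are illustrative placeholders, not types from this change.

#include <cstdint>
#include <mutex>
#include <unordered_map>
#include <vector>

using Handle = std::intptr_t;  // placeholder for an opaque ncclComm_t-like handle

class CommCache {
 public:
  // A device starts "actively participating"; the clique is (re)built lazily
  // the next time a collective actually runs.
  void AddrefDevice(int device) {
    std::lock_guard<std::mutex> lock(mu_);
    ++refcounts_[device];
  }

  // Drop refcounts; if any device reaches zero, drop the whole cached clique.
  // It will be rebuilt (smaller) on the next collective.
  void DecrefDevices(const std::vector<int>& devices) {
    std::lock_guard<std::mutex> lock(mu_);
    bool removed = false;
    for (int d : devices) {
      if (--refcounts_[d] == 0) {
        refcounts_.erase(d);
        removed = true;
      }
    }
    if (removed) comms_.clear();  // real code would also destroy each handle
  }

 private:
  std::mutex mu_;
  std::unordered_map<int, std::int64_t> refcounts_;
  std::unordered_map<int, Handle> comms_;  // open channels, keyed by device
};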
@@ -18,6 +18,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/util.h"
 
 #if GOOGLE_CUDA
+#include "absl/container/flat_hash_set.h"
 #include "absl/synchronization/blocking_counter.h"
 #include "third_party/nccl/nccl.h"
 #include "tensorflow/core/lib/core/blocking_counter.h"
@@ -76,6 +77,42 @@ struct ParticipantData {
 // This manager is responsible for establishing communication channels and
 // ultimately enqueueing the NCCL library operation onto the participating
 // streams.
+//
+// Implementation note: We make an effort to avoid initializing nccl
+// communication channels too often, as this is expensive.
+//
+// Ideally, we'd set up a nccl channel between each pair of devices that needs
+// to communicate, and close each channel when the GPUs won't be communicating
+// again "for a long time" (because channels hold memory on the GPU). As a
+// simplification to this ideal, we adopt the following policy.
+//
+// - We maintain a set of GPUs that are "actively participating" in
+//   cross-device communications. That set of GPUs is always connected as a
+//   clique, using ncclCommInitAll.
+//
+// - When a NcclAllReduceThunk touches a new GPU, we tear down the old clique
+//   and build a new, bigger one.
+//
+// - All GPUs ever touched by a thunk are considered "actively in use" by that
+//   thunk until the thunk is destroyed. Destroying the thunk decrements the
+//   refcount of the GPUs it's touched, and if that refcount goes to 0
+//   (meaning, some GPUs are no longer in use by any thunk), we tear down the
+//   clique and build a new, smaller one.
+//
+// This approximation is justified because:
+//
+// - Currently the only collective operation we support is AllReduce, which
+//   requires a clique. When we support point-to-point operations, we may not
+//   want to build a communication clique.
+//
+// - Tearing down and creating a new thunk is tantamount to running the whole
+//   XLA:GPU compiler. This is expensive, so shouldn't happen "too often" to
+//   cause thrashing here.
+//
+// - XLA executables already keep resources on the GPU tied to the lifetime of
+//   the executable (e.g. constants stored in GPU memory), so tying the
+//   lifetime of the nccl communication channels to the lifetime of the
+//   executable is consistent.
 class GlobalRendezvousManager {
  public:
   // The GpuExecutable-executing threads call this in order to a) establish the
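
As a worked example of this policy (device numbers are illustrative; the NcclChannelCaching test at the end of this change exercises exactly this sequence):

  run thunk A on devices {0,1}   -> clique {0,1} is opened
  run thunk B on devices {1,2}   -> clique is rebuilt as {0,1,2}
  destroy thunk A                -> device 0's refcount hits zero, the whole
                                    clique is torn down
  run thunk B again              -> clique is rebuilt as just {1,2}
  destroy thunk B                -> no devices remain, no channels stay open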
@@ -98,18 +135,38 @@ class GlobalRendezvousManager {
     return current_generation_;
   }
 
- private:
-  // Called by the primary thread to set up the communication links.
-  //
-  // TODO(b/125951860): This performs lots of (presumably) unnecessary host-side
-  // synchronization so that we can be paranoid about semantics in the earliest
-  // implementation. In the limit we should only need to synchronize host
-  // replica threads when the "number of replicas" or "participating device
-  // ordinals" change, to set up a new NCCL "communication" context, at which
-  // point we can enqueue onto device streams without host synchronization in
-  // our code -- this will likely be helpful for "lots of little AllReduce"
-  // cases.
-  Status InitializeCommunicationChannels() EXCLUSIVE_LOCKS_REQUIRED(mutex_);
+  // Increments the refcount of a GPU in our accounting of which devices are
+  // "actively participating" in cross-device operations.
+  //
+  // This doesn't actually do anything other than increment the refcount. If
+  // the GPU added here is novel, we'll rebuild the nccl communication clique
+  // when we actually go do the communication.
+  void AddrefParticipatingDevice(int device_ordinal);
+
+  // Decrements the refcount of a set of GPUs in our accounting of which devices
+  // are "actively participating" in cross-device operations.
+  //
+  // If one or more GPUs' refcounts go to 0, we immediately destroy the whole
+  // nccl communication clique. We'll rebuild a new, smaller clique the next
+  // time it's used.
+  void DecrefParticipatingDevices(absl::Span<const int> device_ordinals);
+
+  // Gets the set of devices that have a NCCL channel currently open. This is
+  // primarily for testing.
+  absl::flat_hash_set<int> DevicesWithOpenNcclChannels() const {
+    absl::flat_hash_set<int> devices;
+    tensorflow::mutex_lock lock(mutex_);
+    for (const auto& kv : comms_) {
+      devices.insert(kv.first);
+    }
+    return devices;
+  }
+
+ private:
+  // Destroys the current nccl communication clique and builds a new one
+  // connecting the given devices.
+  Status ReinitializeNcclClique(const absl::flat_hash_set<int>& device_ordinals)
+      EXCLUSIVE_LOCKS_REQUIRED(mutex_);
 
   // Called when all necessary participants are present, the functionality
   // that's implemented by all executing threads lives in here.
@@ -118,28 +175,51 @@ class GlobalRendezvousManager {
   // Puts all state back into a "reset" state for the next generation of
   // AllReduce requests.
   void DeinitializeGeneration() EXCLUSIVE_LOCKS_REQUIRED(mutex_) {
-    for (ncclComm_t& comm : comms_) {
-      ncclCommDestroy(comm);
-    }
-    comms_.clear();
     participants_.clear();
     current_generation_++;
     initialized_ = false;
     done_ = absl::nullopt;
   }
 
-  tensorflow::mutex mutex_;
+  mutable tensorflow::mutex mutex_;
   tensorflow::condition_variable all_participants_present_;
   tensorflow::condition_variable deinitialized_;
 
-  // Communication handles that correspond to the participants below.
-  std::vector<ncclComm_t> comms_ GUARDED_BY(mutex_);
-
   Status initialize_status_ GUARDED_BY(mutex_);
   std::vector<ParticipantData> participants_ GUARDED_BY(mutex_);
   int64 current_generation_ GUARDED_BY(mutex_) = 0;
   bool initialized_ GUARDED_BY(mutex_) = false;
 
+  struct Comm {
+    explicit Comm(ncclComm_t nccl_comm) : nccl_comm(nccl_comm) {}
+
+    // Movable, but not copyable.
+    Comm(Comm&& c) : nccl_comm(c.nccl_comm) { c.nccl_comm.reset(); }
+    Comm& operator=(Comm&& c) {
+      nccl_comm = c.nccl_comm;
+      c.nccl_comm.reset();
+      return *this;
+    }
+    Comm(const Comm&) = delete;
+    Comm& operator=(const Comm&) = delete;
+
+    absl::optional<ncclComm_t> nccl_comm;
+
+    ~Comm() {
+      if (nccl_comm.has_value()) {
+        VLOG(3) << absl::StreamFormat("Destroying comm %p", *nccl_comm);
+        ncclCommDestroy(*nccl_comm);
+      }
+    }
+  };
+  // Communication handles for our NCCL clique. Key is device ordinal.
+  absl::flat_hash_map<int, Comm> comms_ GUARDED_BY(mutex_);
+
+  // Refcounts of which devices are "actively participating" in all-reduces.
+  // These devices don't necessarily have an open comm, but the next time we run
+  // an operation, we'll create a NCCL clique between all of them.
+  absl::flat_hash_map<int, int64> device_refcounts_ GUARDED_BY(mutex_);
+
   // The participating threads wait for this to count down in order to know we
   // can begin the teardown process.
   absl::optional<tensorflow::BlockingCounter> done_;
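
The new Comm struct above is a standard move-only RAII wrapper around a C-style handle. A generic, compilable sketch of the same pattern, using placeholder names (RawHandle, DestroyHandle) rather than the NCCL types:

#include <cstdio>
#include <optional>
#include <utility>

using RawHandle = int;             // stands in for ncclComm_t
void DestroyHandle(RawHandle h) {  // stands in for ncclCommDestroy
  std::printf("destroying handle %d\n", h);
}

struct ScopedHandle {
  explicit ScopedHandle(RawHandle h) : handle(h) {}

  // Movable but not copyable: the moved-from object resets its optional, so
  // exactly one destructor frees the underlying handle.
  ScopedHandle(ScopedHandle&& other) : handle(other.handle) {
    other.handle.reset();
  }
  ScopedHandle& operator=(ScopedHandle&& other) {
    if (handle.has_value()) DestroyHandle(*handle);  // don't leak what we held
    handle = other.handle;
    other.handle.reset();
    return *this;
  }
  ScopedHandle(const ScopedHandle&) = delete;
  ScopedHandle& operator=(const ScopedHandle&) = delete;

  ~ScopedHandle() {
    if (handle.has_value()) DestroyHandle(*handle);
  }

  std::optional<RawHandle> handle;
};

int main() {
  ScopedHandle a(1);
  ScopedHandle b = std::move(a);  // only b's destructor frees handle 1
}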
@@ -151,11 +231,6 @@ Status GlobalRendezvousManager::SubmitParticipant(ParticipantData participant) {
     return participants_.size() >= participant.replica_count;
   };
 
-  // We remember the participant index at which we are inserted and use that
-  // same index for referring to auxiliary metadata (e.g. the ncclComm_t handle
-  // index) below.
-  int64 index;
-
   {
     tensorflow::mutex_lock lock(mutex_);
 
@@ -171,7 +246,6 @@ Status GlobalRendezvousManager::SubmitParticipant(ParticipantData participant) {
           "participants; existing: %s; submitted: %s)",
           participants_.back().ToString(), participant.ToString());
     }
-    index = participants_.size();
     participants_.push_back(participant);
 
     if (all_participants_present()) {
@@ -205,20 +279,44 @@ Status GlobalRendezvousManager::SubmitParticipant(ParticipantData participant) {
       VLOG(3) << "Primary initializing accounting data.";
       initialized_ = true;
       done_.emplace(participant.replica_count);
-      initialize_status_ = InitializeCommunicationChannels();
+      // Check if all participants_ are in comms_. If not, we will rebuild the
+      // clique to include them. (This can't be spelled using absl::c_any_of
+      // because it needs to touch comms_ and tensorflow::mutex lacks an
+      // AssertHeld() function that would let us assert that the lambda is run
+      // while holding the lock.)
+      bool new_devices_found = false;
+      for (const auto& p : participants_) {
+        if (!comms_.contains(p.device_ordinal)) {
+          new_devices_found = true;
+          break;
+        }
+      }
+
+      if (new_devices_found) {
+        absl::flat_hash_set<int> new_clique_device_ordinals;
+        for (const auto& kv : comms_) {
+          new_clique_device_ordinals.insert(kv.first);
+        }
+        for (const auto& p : participants_) {
+          new_clique_device_ordinals.insert(p.device_ordinal);
+        }
+
+        initialize_status_ = ReinitializeNcclClique(new_clique_device_ordinals);
       VLOG(3) << "Done initializing communication channels; status: "
               << initialize_status_;
       if (!initialize_status_.ok()) {
        DeinitializeGeneration();
       }
     }
+      }
 
     if (!initialize_status_.ok()) {
       // TODO(b/125951860): If this fails once, it will fail forever.
       return initialize_status_;
     }
 
-    comm = comms_[index];
+    comm = *comms_.at(participant.device_ordinal).nccl_comm;
 
     // Drop the lock at the end of scope so other participants may enter.
   }
@@ -259,22 +357,30 @@ Status GlobalRendezvousManager::SubmitParticipant(ParticipantData participant) {
   return all_reduce_status;
 }
 
-Status GlobalRendezvousManager::InitializeCommunicationChannels() {
-  std::vector<int> ordinals;
-  for (ParticipantData& data : participants_) {
-    ordinals.push_back(data.device_ordinal);
-  }
-  comms_.resize(ordinals.size());
-  VLOG(3) << "Participants: " << participants_.size()
-          << "; initializing comms.";
-  ncclResult_t result = ncclCommInitAll(comms_.data(), comms_.size(),
-                                        /*devlist=*/ordinals.data());
-  if (result != ncclSuccess) {
+Status GlobalRendezvousManager::ReinitializeNcclClique(
+    const absl::flat_hash_set<int>& device_ordinals) {
   comms_.clear();
+
+  std::vector<int> ordinals_vec(device_ordinals.begin(), device_ordinals.end());
+  std::vector<ncclComm_t> comm_vec;
+  comm_vec.resize(device_ordinals.size());
+
+  VLOG(3) << absl::StreamFormat(
+      "Initializing nccl comms for participant devices {%s}",
+      absl::StrJoin(ordinals_vec, ", "));
+  ncclResult_t result = ncclCommInitAll(comm_vec.data(), comm_vec.size(),
+                                        /*devlist=*/ordinals_vec.data());
+  if (result != ncclSuccess) {
     return InternalError(
         "Failed to initialize NCCL communication channels for %d participants: "
         "%s",
-        participants_.size(), ncclGetErrorString(result));
+        ordinals_vec.size(), ncclGetErrorString(result));
+  }
+
+  for (int64 i = 0; i < ordinals_vec.size(); ++i) {
+    VLOG(3) << absl::StreamFormat("Device ordinal %d assigned ncclComm %p",
+                                  ordinals_vec[i], comm_vec[i]);
+    CHECK(comms_.emplace(ordinals_vec[i], Comm{comm_vec[i]}).second);
   }
   return Status::OK();
 }
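
For readers unfamiliar with the NCCL call being wrapped above: a minimal, standalone sketch of the ncclCommInitAll / ncclCommDestroy pattern that ReinitializeNcclClique builds on. The helper names and error handling here are illustrative only, not part of this change.

#include <cstdio>
#include <vector>

#include "third_party/nccl/nccl.h"

// Opens one communicator per device in `device_ordinals`, forming a clique.
// Returns an empty vector on failure.
std::vector<ncclComm_t> InitClique(const std::vector<int>& device_ordinals) {
  std::vector<ncclComm_t> comms(device_ordinals.size());
  ncclResult_t result =
      ncclCommInitAll(comms.data(), comms.size(), device_ordinals.data());
  if (result != ncclSuccess) {
    std::fprintf(stderr, "ncclCommInitAll failed: %s\n",
                 ncclGetErrorString(result));
    return {};
  }
  return comms;
}

// Destroys every communicator in the clique, releasing its GPU memory.
void DestroyClique(const std::vector<ncclComm_t>& comms) {
  for (ncclComm_t comm : comms) ncclCommDestroy(comm);
}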
@@ -289,6 +395,11 @@ Status GlobalRendezvousManager::DoAllReduce(ParticipantData participant,
           << " on device: " << participant.device_ordinal;
   void* send_buffer = participant.source_data.opaque();
   void* recv_buffer = participant.destination_data.opaque();
+  VLOG(3) << absl::StreamFormat(
+      "Calling ncclAllReduce(send_buffer=%p, recv_buffer=%p, count=%d, "
+      "datatype=ncclFloat, op=ncclSum, comm=%p, stream=%p)",
+      send_buffer, recv_buffer, participant.element_count,
+      static_cast<const void*>(comm), cu_stream);
   ncclResult_t result = ncclAllReduce(send_buffer, recv_buffer,
                                       /*count=*/participant.element_count,
                                       /*datatype=*/ncclFloat,
@@ -304,6 +415,36 @@ Status GlobalRendezvousManager::DoAllReduce(ParticipantData participant,
   return Status::OK();
 }
 
+void GlobalRendezvousManager::AddrefParticipatingDevice(int device_ordinal) {
+  // Addref'ing a device doesn't do anything other than increment its refcount.
+  // We'll update our nccl clique if necessary during the next call to
+  // SubmitParticipant.
+  tensorflow::mutex_lock lock(mutex_);
+  device_refcounts_[device_ordinal]++;
+}
+
+void GlobalRendezvousManager::DecrefParticipatingDevices(
+    absl::Span<const int> device_ordinals) {
+  // Decref'ing devices causes us to destroy the nccl clique if any devices were
+  // removed due to having refcount 0. We'll rebuild the new, smaller clique
+  // during the next call to SubmitParticipant.
+  tensorflow::mutex_lock lock(mutex_);
+  bool removed_device = false;
+  for (int device_ordinal : device_ordinals) {
+    auto it = device_refcounts_.find(device_ordinal);
+    CHECK(it != device_refcounts_.end());
+    it->second--;
+    if (it->second == 0) {
+      device_refcounts_.erase(it);
+      removed_device = true;
+    }
+  }
+
+  if (removed_device) {
+    comms_.clear();
+  }
+}
+
 static GlobalRendezvousManager* GetGlobalRendezvous() {
   static auto* manager = new GlobalRendezvousManager;
   return manager;
@@ -311,6 +452,11 @@ static GlobalRendezvousManager* GetGlobalRendezvous() {
 
 }  // namespace
 
+/*static*/ absl::flat_hash_set<int>
+NcclAllReduceThunk::DevicesWithOpenNcclChannels() {
+  return GetGlobalRendezvous()->DevicesWithOpenNcclChannels();
+}
+
 Status NcclAllReduceThunk::ExecuteOnStream(
     const BufferAllocations& buffer_allocations, se::Stream* stream,
     HloExecutionProfiler* profiler) {
@@ -327,8 +473,32 @@ Status NcclAllReduceThunk::ExecuteOnStream(
   participant.stream = stream;
   participant.originator = this;
 
+  // We currently say that all GPUs this thunk has ever touched are
+  // "actively participating" in cross-device operations, until the thunk itself
+  // is destroyed.
+  //
+  // This policy is an attempt to avoid thrashing the GPU (ncclCommInitAll is
+  // very expensive) while also freeing resources on the GPUs when we can. The
+  // idea is, creating new thunks is tantamount to running the whole XLA:GPU
+  // compiler stack, so that shouldn't happen terribly often.
+  bool new_device;
+  {
+    tensorflow::mutex_lock lock(mu_);
+    new_device = devices_seen_.insert(participant.device_ordinal).second;
+  }
+  if (new_device) {
+    GetGlobalRendezvous()->AddrefParticipatingDevice(
+        participant.device_ordinal);
+  }
+
   return GetGlobalRendezvous()->SubmitParticipant(std::move(participant));
 }
 
+NcclAllReduceThunk::~NcclAllReduceThunk() {
+  GetGlobalRendezvous()->DecrefParticipatingDevices(
+      std::vector<int>(devices_seen_.begin(), devices_seen_.end()));
+}
+
 #else
 
 Status NcclAllReduceThunk::ExecuteOnStream(
@@ -339,6 +509,13 @@ Status NcclAllReduceThunk::ExecuteOnStream(
       "compiler, which is necessary to build the NCCL source library.");
 }
 
+NcclAllReduceThunk::~NcclAllReduceThunk() = default;
+
+/*static*/ absl::flat_hash_set<int>
+NcclAllReduceThunk::DevicesWithOpenNcclChannels() {
+  return {};
+}
+
 #endif  // GOOGLE_CUDA
 
 NcclAllReduceThunk::NcclAllReduceThunk(
@@ -16,11 +16,13 @@ limitations under the License.
 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_NCCL_ALL_REDUCE_THUNK_H_
 #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_NCCL_ALL_REDUCE_THUNK_H_
 
+#include "absl/container/flat_hash_set.h"
 #include "tensorflow/compiler/xla/service/buffer_assignment.h"
 #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
 #include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
 #include "tensorflow/compiler/xla/service/gpu/thunk.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/core/platform/mutex.h"
 #include "tensorflow/core/platform/stream_executor_no_cuda.h"
 #include "tensorflow/core/platform/types.h"
 
@@ -38,12 +40,21 @@ class NcclAllReduceThunk : public Thunk {
   // error.
   static bool NcclIsEnabled();
 
+  // Gets the set of devices that have a NCCL channel open. This is primarily
+  // for testing.
+  //
+  // (Indeed, because the NCCL channels are a global variable, in the real
+  // world, the value returned here is stale as soon as you read it, so it's not
+  // clear how you *could* use it for anything other than tests.)
+  static absl::flat_hash_set<int> DevicesWithOpenNcclChannels();
+
   // TODO(b/125951860): Plumb more datatypes / reduction operators. Initial
   // implementation is simply F32 summation.
   NcclAllReduceThunk(int64 replica_count, int64 element_count,
                      const BufferAllocation::Slice& source_buffer,
                      const BufferAllocation::Slice& destination_buffer,
                      const HloInstruction* all_reduce);
+  ~NcclAllReduceThunk() override;
 
   Status ExecuteOnStream(const BufferAllocations& buffer_allocations,
                          se::Stream* stream,
@@ -54,6 +65,10 @@ class NcclAllReduceThunk : public Thunk {
   const int64 element_count_;
   const BufferAllocation::Slice source_buffer_;
   const BufferAllocation::Slice destination_buffer_;
+
+  tensorflow::mutex mu_;
+  // Set of GPUs that ExecuteOnStream has been called on.
+  absl::flat_hash_set<int> devices_seen_ GUARDED_BY(mu_);
 };
 
 }  // namespace gpu
@@ -273,6 +273,12 @@ StatusOr<std::vector<Literal>> HloRunner::ExecuteReplicated(
   TF_ASSIGN_OR_RETURN(
       std::unique_ptr<Executable> executable,
       CreateExecutable(std::move(module), options.run_hlo_passes));
+  return ExecuteReplicated(executable.get(), options, device_assignment);
+}
+
+StatusOr<std::vector<Literal>> HloRunner::ExecuteReplicated(
+    Executable* executable, const ReplicatedExecuteOptions& options,
+    DeviceAssignment* device_assignment, ExecutionProfile* profile) {
   std::vector<std::unique_ptr<se::Stream>> streams;
   std::vector<ServiceExecutableRunOptions> service_run_options;
 
@@ -183,6 +183,15 @@ class HloRunner {
       const ReplicatedExecuteOptions& options,
       DeviceAssignment* device_assignment);
 
+  // Same as above, but with a reusable Executable. This may update the profile
+  // information in *executable.
+  //
+  // Note that this call ignores ReplicatedExecutionOptions::run_hlo_passes,
+  // since we've already compiled the Executable.
+  StatusOr<std::vector<Literal>> ExecuteReplicated(
+      Executable* executable, const ReplicatedExecuteOptions& options,
+      DeviceAssignment* device_assignment, ExecutionProfile* profile = nullptr);
+
   // If backend is not created in the constructor, creates and returns the
   // default backend. If creation fails, crashes the program.
   //
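
A sketch of how this new overload is meant to be used, assuming an HloRunner named runner, a parsed module, and an already-populated DeviceAssignment and ReplicatedExecuteOptions (this mirrors what the NcclChannelCaching test below does): compile once with CreateExecutable, then execute the same Executable repeatedly so the NCCL channels opened on the first run can be reused.

// Sketch only; `runner`, `module`, `opts`, and `assignment` are assumed to
// exist in the surrounding test or tool.
std::unique_ptr<Executable> executable =
    runner.CreateExecutable(std::move(module), /*run_hlo_passes=*/true)
        .ValueOrDie();

for (int i = 0; i < 3; ++i) {
  // run_hlo_passes is irrelevant here: the Executable is already compiled.
  TF_ASSERT_OK(
      runner.ExecuteReplicated(executable.get(), opts, &assignment).status());
}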
@@ -14,20 +14,31 @@ limitations under the License.
 ==============================================================================*/
 
 #include "tensorflow/compiler/xla/literal.h"
+#include "tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h"
 #include "tensorflow/compiler/xla/service/hlo_parser.h"
 #include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/compiler/xla/test.h"
 #include "tensorflow/compiler/xla/test_helpers.h"
 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
 #include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+
+// Tests cross-GPU all-reduce operations.
+//
+// This test requires multiple GPUs. For instructions on running this within
+// Google, see go/multi-gpu-unit-test.
 
 namespace xla {
 namespace {
 
-class MultiDeviceAllReduceTest : public HloTestBase {};
+using ::testing::IsEmpty;
+using ::testing::UnorderedElementsAre;
 
-XLA_TEST_F(MultiDeviceAllReduceTest, TwoReplicasOneOperand) {
-  const char* module_str = R"(
+class MultiDeviceAllReduceTest : public HloTestBase {
+ protected:
+  std::unique_ptr<HloModule> MakeCrsModule(int64 num_elems,
+                                           const HloModuleConfig& config) {
+    const char* kTemplate = R"(
 HloModule test
 
 add {
@@ -37,12 +48,52 @@ XLA_TEST_F(MultiDeviceAllReduceTest, TwoReplicasOneOperand) {
 }
 
 ENTRY test_computation {
-  p = f32[3] parameter(0)
-  ROOT crs = f32[3] all-reduce(p), to_apply=add
-})";
+  p = f32[NUM_ELEMS] parameter(0)
+  ROOT crs = f32[NUM_ELEMS] all-reduce(p), to_apply=add
+}
+)";
+    return ParseHloString(
+               absl::StrReplaceAll(kTemplate,
+                                   {{"NUM_ELEMS", absl::StrCat(num_elems)}}),
+               config)
+        .ValueOrDie();
+  }
+};
+
+// Returns the non-empty subsets of {0, 1, ..., n-1}. For example,
+// PowerSetOfIota(3) = {{0}, {1}, {2}, {0,1}, {0,2}, {1,2}, {0,1,2}}.
+std::vector<std::vector<int64>> PowerSetOfIota(int64 n) {
+  std::vector<std::vector<int64>> power_set;
+  for (int64 i = 1; i < (1 << n); ++i) {
+    power_set.emplace_back();
+    for (int64 j = 0; j < n; ++j) {
+      if (i & (1 << j)) {
+        power_set.back().push_back(j);
+      }
+    }
+  }
+  return power_set;
+}
+
+// Makes a DeviceAssignment assigning replica-id i to devices[i].
+DeviceAssignment MakeDeviceAssn(std::vector<int64> devices) {
+  DeviceAssignment assn(/*replica_count=*/devices.size(),
+                        /*computation_count=*/1);
+  for (int64 i = 0; i < devices.size(); ++i) {
+    assn(i, 0) = devices[i];
+  }
+  return assn;
+}
+
+// Shorter alias for this function.
+absl::flat_hash_set<int> OpenNcclChannels() {
+  return gpu::NcclAllReduceThunk::DevicesWithOpenNcclChannels();
+}
+
+XLA_TEST_F(MultiDeviceAllReduceTest, TwoReplicasOneOperand) {
   auto config = GetModuleConfigForTest();
   config.set_replica_count(2);
-  auto module = ParseHloString(module_str, config).ValueOrDie();
+  auto module = MakeCrsModule(/*num_elems=*/3, config);
   auto literal = LiteralUtil::CreateR1<float>({1, 2, 3});
   auto expected = LiteralUtil::CreateR1<float>({2, 4, 6});
   TF_ASSERT_OK_AND_ASSIGN(std::vector<Literal> results,
@@ -52,5 +103,112 @@ XLA_TEST_F(MultiDeviceAllReduceTest, TwoReplicasOneOperand) {
   EXPECT_EQ(expected, results[1]);
 }
 
+// Tries all-to-all operations across all 2^kNumDevices - 1 combinations of
+// devices in sequence.
+XLA_TEST_F(MultiDeviceAllReduceTest, AllCombinations) {
+  const int64 kNumDevices = 4;
+  const int64 kNumElems = 1024;
+
+  for (std::vector<int64> devices : PowerSetOfIota(kNumDevices)) {
+    SCOPED_TRACE(absl::StrFormat("Running on devices {%s}",
+                                 absl::StrJoin(devices, ", ")));
+
+    DeviceAssignment device_assn = MakeDeviceAssn(devices);
+
+    auto config = GetModuleConfigForTest();
+    config.set_replica_count(devices.size());
+    config.set_static_device_assignment(device_assn);
+
+    auto module = MakeCrsModule(kNumElems, config);
+
+    std::vector<float> input_vec(kNumElems);
+    absl::c_iota(input_vec, 0);
+    auto input_literal = LiteralUtil::CreateR1<float>(input_vec);
+
+    TF_ASSERT_OK_AND_ASSIGN(
+        std::vector<Literal> results,
+        ExecuteReplicated(std::move(module), {&input_literal},
+                          /*num_replicas=*/devices.size(), &device_assn,
+                          /*run_hlo_passes=*/true, /*use_threads=*/true));
+  }
+}
+
+// Check that the NCCL data structures in our all-reduce implementation are
+// cached as we expect.
+XLA_TEST_F(MultiDeviceAllReduceTest, NcclChannelCaching) {
+  const int64 kNumElems = 1024;
+
+  std::vector<float> input_vec(kNumElems);
+  absl::c_iota(input_vec, 0);
+  auto input_literal = LiteralUtil::CreateR1<float>(input_vec);
+
+  // Initially no NCCL channels should be open.
+  EXPECT_THAT(OpenNcclChannels(), IsEmpty());
+
+  // Create three Executables, touching devices {0,1}, {1,2}, and {0,1,2}.
+  struct ExecutableInfo {
+    std::unique_ptr<Executable> executable;
+    DeviceAssignment device_assn;
+    HloRunner::ReplicatedExecuteOptions opts;
+  };
+  std::vector<ExecutableInfo> executables;
+  for (const auto& devices :
+       std::vector<std::vector<int64>>{{0, 1}, {1, 2}, {0, 1, 2}}) {
+    executables.emplace_back();
+    auto& e = executables.back();
+
+    e.device_assn = MakeDeviceAssn(devices);
+
+    auto config = GetModuleConfigForTest();
+    config.set_replica_count(devices.size());
+    config.set_static_device_assignment(e.device_assn);
+    auto module = MakeCrsModule(kNumElems, config);
+    e.executable =
+        test_runner_
+            .CreateExecutable(std::move(module), /*run_hlo_passes=*/true)
+            .ValueOrDie();
+
+    e.opts.num_replicas = devices.size();
+    e.opts.use_threads = true;
+    e.opts.arguments.push_back(&input_literal);
+  }
+
+  auto run_executable = [&](int64 i) {
+    auto& e = executables[i];
+    TF_ASSERT_OK(
+        test_runner_
+            .ExecuteReplicated(e.executable.get(), e.opts, &e.device_assn)
+            .status());
+  };
+
+  // Compiling executables above shouldn't cause us to open any channels.
+  EXPECT_THAT(OpenNcclChannels(), IsEmpty());
+
+  // Run the executables and check that channels are opened as we expect.
+  run_executable(0);
+  EXPECT_THAT(OpenNcclChannels(), UnorderedElementsAre(0, 1));
+
+  run_executable(2);
+  EXPECT_THAT(OpenNcclChannels(), UnorderedElementsAre(0, 1, 2));
+
+  run_executable(1);
+  EXPECT_THAT(OpenNcclChannels(), UnorderedElementsAre(0, 1, 2));
+
+  // Tear down the executables and check that channels are closed as we expect.
+  // Note that after we tear down an executable *all* the nccl channels may go
+  // away, so we rerun all of the executables that haven't been torn down.
+  executables[2].executable.reset();
+  run_executable(0);
+  run_executable(1);
+  EXPECT_THAT(OpenNcclChannels(), UnorderedElementsAre(0, 1, 2));
+
+  executables[0].executable.reset();
+  run_executable(1);
+  EXPECT_THAT(OpenNcclChannels(), UnorderedElementsAre(1, 2));
+
+  executables[1].executable.reset();
+  EXPECT_THAT(OpenNcclChannels(), IsEmpty());
+}
+
 }  // namespace
 }  // namespace xla