[NFC] Eliminate reference to HLO Inst from NcclAllReduceThunk
- Create NcclAllReduceConfig to stash the properties from the HLO instruction used by NcclAllReduceThunk.
- Also move some other properties from the thunk to the config object.
- Remove the destructor for NcclAllReduceThunk now that the unique_ptr<> to the opaque object has moved to the config object.
- Eliminate RendezvousKey::FromInstruction as it's no longer used.

PiperOrigin-RevId: 335656689
Change-Id: I73d3c021c0e366d11736f0884a43984be2984e44
parent 8b5e015c51 · commit 64a324460a
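Before the diff, a note on the shape of the change: the thunk used to keep a `const HloInstruction*` and query it at execution time; now every property it needs is snapshotted into a plain config struct when the thunk is built. A minimal sketch of the pattern, using hypothetical stand-in types rather than the real XLA ones:

```cpp
#include <cstdint>
#include <utility>
#include <vector>

// Hypothetical stand-ins for HloInstruction and the thunk types; none of
// these are real XLA names.
struct Instruction {
  std::vector<int> operand_element_types;
  int64_t module_unique_id;
};

// The config object is a plain value type: every property the thunk needs
// is copied out of the instruction exactly once, at thunk-build time.
struct Config {
  int64_t operand_count = 0;
  std::vector<int> operand_element_types;
  int64_t op_id = 0;
};

Config GetConfig(const Instruction& instr) {
  Config config;
  config.operand_count =
      static_cast<int64_t>(instr.operand_element_types.size());
  config.operand_element_types = instr.operand_element_types;  // snapshot
  config.op_id = instr.module_unique_id;
  return config;
}

// The thunk owns the config by value and never dereferences the
// instruction again, so it no longer pins the HLO module alive.
class ReduceThunk {
 public:
  explicit ReduceThunk(Config&& config) : config_(std::move(config)) {}
  int64_t op_id() const { return config_.op_id; }

 private:
  Config config_;
};
```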
```diff
@@ -82,22 +82,6 @@ struct RendezvousKey {
         collective_op_kind(collective_op_kind),
         op_id(op_id) {}
 
-  static RendezvousKey FromInstruction(
-      const RunId& run_id, std::vector<GlobalDeviceId> global_devices,
-      int num_local_participants, const HloInstruction* instr) {
-    CollectiveOpKind collective_op_kind;
-    int64 op_id;
-
-    std::tie(collective_op_kind, op_id) =
-        instr->channel_id().has_value()
-            ? std::make_pair(kCrossModule, instr->channel_id().value())
-            : std::make_pair(
-                  kCrossReplica,
-                  static_cast<int64>(instr->GetModule()->unique_id()));
-    return RendezvousKey(run_id, std::move(global_devices),
-                         num_local_participants, collective_op_kind, op_id);
-  }
-
   template <typename H>
   friend H AbslHashValue(H h, const RendezvousKey& k) {
     return H::combine(std::move(h), k.run_id, k.global_devices,
```
```diff
@@ -462,6 +462,7 @@ tf_cuda_library(
         "//tensorflow/compiler/xla/service:hlo",
         "//tensorflow/core:lib",
         "//tensorflow/core:stream_executor_no_cuda",
+        "//tensorflow/compiler/xla:xla_data_proto_cc",
     ] + if_cuda([
         "//tensorflow/stream_executor/cuda:cuda_activation",
         "//tensorflow/stream_executor/cuda:cuda_gpu_executor",
```
```diff
@@ -14,10 +14,21 @@ limitations under the License.
 ==============================================================================*/
 
 #include "tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
 
 namespace xla {
 namespace gpu {
 
+struct NcclAllReduceConfig::AuxData {};
+
+NcclAllReduceConfig::NcclAllReduceConfig(NcclAllReduceConfig &&) = default;
+NcclAllReduceConfig::~NcclAllReduceConfig() = default;
+
+NcclAllReduceConfig GetNcclAllReduceConfig(const HloInstruction *instr) {
+  NcclAllReduceConfig config = {};
+  return config;
+}
+
 /* static */ bool NcclAllReduceThunk::NcclIsEnabled() {
   return false;  // Skylark selects this source file if NCCL is disabled.
 }
```
```diff
@@ -32,20 +43,16 @@ Status NcclAllReduceThunk::ExecuteOnStream(const ExecuteParams& params) {
       "compiler, which is necessary to build the NCCL source library.");
 }
 
-NcclAllReduceThunk::~NcclAllReduceThunk() = default;
-
 /*static*/ absl::flat_hash_set<GlobalDeviceId>
 NcclAllReduceThunk::DevicesWithOpenNcclChannels() {
   return {};
 }
 
-struct NcclAllReduceThunk::AuxData {};
-
 NcclAllReduceThunk::NcclAllReduceThunk(
-    ThunkInfo thunk_info, int64 replica_count,
+    ThunkInfo thunk_info, NcclAllReduceConfig &&config,
     std::vector<NcclAllReduceThunk::Buffer> buffers)
     : Thunk(Thunk::kNcclAllReduce, thunk_info),
-      replica_count_(replica_count),
+      config_(std::move(config)),
       buffers_(std::move(buffers)) {}
 
 }  // namespace gpu
```
```diff
@@ -1658,9 +1658,9 @@ Status IrEmitterUnnested::HandleAllReduce(HloInstruction* crs) {
         *crs, crs->shape().IsTuple() ? ShapeIndex({i}) : ShapeIndex({}));
     tuple_element_buffers.push_back(buffers[i].destination_buffer);
   }
+  NcclAllReduceConfig config = GetNcclAllReduceConfig(crs);
   auto all_reduce_thunk = absl::make_unique<NcclAllReduceThunk>(
-      GetThunkInfo(crs),
-      /*replica_count=*/hlo_module_config_.replica_count(),
+      GetThunkInfo(crs), std::move(config),
       /*buffers=*/std::move(buffers));
   if (crs->shape().IsTuple()) {
     std::vector<std::unique_ptr<Thunk>> thunks;
```
```diff
@@ -514,11 +514,38 @@ void RendezvousNcclAllReduce::CleanupImpl(std::shared_ptr<NcclClique> handle,
 // header. In particular, this stores the thunk's cache of all NcclCliques it's
 // ever used. This causes those cliques to stay alive as long as the thunk
 // lives, which is how we avoid expensive reinitialization of NCCL cliques.
-struct NcclAllReduceThunk::AuxData {
+struct NcclAllReduceConfig::AuxData {
   tensorflow::mutex mu;
   absl::flat_hash_set<std::shared_ptr<NcclClique>> cliques TF_GUARDED_BY(mu);
 };
 
+NcclAllReduceConfig::NcclAllReduceConfig(NcclAllReduceConfig&&) = default;
+NcclAllReduceConfig::~NcclAllReduceConfig() = default;
+
+NcclAllReduceConfig GetNcclAllReduceConfig(const HloInstruction* instr) {
+  NcclAllReduceConfig config;
+  config.operand_count = instr->operands().size();
+  config.operand_element_type.reserve(config.operand_count);
+  for (int i = 0; i < config.operand_count; i++) {
+    config.operand_element_type.push_back(
+        instr->operand(i)->shape().element_type());
+  }
+  config.replica_groups = instr->replica_groups();
+  auto reduction_kind = MatchReductionComputation(instr->to_apply());
+  CHECK(reduction_kind.has_value());
+  config.reduction_kind = reduction_kind.value();
+
+  if (instr->channel_id().has_value()) {
+    config.collective_op_kind = RendezvousKey::kCrossModule;
+    config.op_id = instr->channel_id().value();
+  } else {
+    config.collective_op_kind = RendezvousKey::kCrossReplica;
+    config.op_id = static_cast<int64>(instr->GetModule()->unique_id());
+  }
+  config.aux_data = std::make_unique<NcclAllReduceConfig::AuxData>();
+  return config;
+}
+
 /*static*/ bool NcclAllReduceThunk::CanImplement(const HloInstruction* crs) {
   auto operands_are_supported = [crs]() {
     return absl::c_all_of(crs->operands(), [](HloInstruction* operand) {
```
```diff
@@ -541,14 +568,12 @@ NcclAllReduceThunk::DevicesWithOpenNcclChannels() {
 }
 
 NcclAllReduceThunk::NcclAllReduceThunk(
-    ThunkInfo thunk_info, int64 replica_count,
+    ThunkInfo thunk_info, NcclAllReduceConfig&& config,
     std::vector<NcclAllReduceThunk::Buffer> buffers)
     : Thunk(Thunk::kNcclAllReduce, thunk_info),
-      hlo_instruction_(thunk_info.hlo_instruction),
-      replica_count_(replica_count),
-      buffers_(std::move(buffers)),
-      aux_data_(absl::make_unique<AuxData>()) {
-  CHECK_EQ(hlo_instruction_->operand_count(), buffers_.size());
+      config_(std::move(config)),
+      buffers_(std::move(buffers)) {
+  CHECK_EQ(config_.operand_count, buffers_.size());
 }
 
 // Figures out which devices (named by their replica-ids) are participating in
```
```diff
@@ -558,7 +583,6 @@ Status NcclAllReduceThunk::ExecuteOnStream(const ExecuteParams& params) {
   auto op_profiler =
       params.profiler->MakeScopedInstructionProfiler(profile_index());
 
-  auto* instr = Cast<HloAllReduceInstruction>(hlo_instruction_);
   int64 local_device_ordinal = params.stream->parent()->device_ordinal();
   GlobalDeviceId global_device_id;
   if (params.gpu_global_device_ids) {
```
```diff
@@ -574,10 +598,10 @@ Status NcclAllReduceThunk::ExecuteOnStream(const ExecuteParams& params) {
   // the same collective group as the caller.
   TF_ASSIGN_OR_RETURN(
       std::vector<int64> global_participating_replicas,
-      GetParticipatingReplicas(global_device_id, instr->replica_groups(),
-                               replica_count_, *params.device_assn));
+      GetParticipatingReplicas(global_device_id, config_.replica_groups,
+                               config_.replica_count, *params.device_assn));
   if (IsGlobalNcclConfig() &&
-      global_participating_replicas.size() != replica_count_) {
+      global_participating_replicas.size() != config_.replica_count) {
     return InvalidArgument(
         "Partial replica groups are not allowed when using NCCL_COMM_ID "
         "environment configuration.");
```
```diff
@@ -605,10 +629,10 @@ Status NcclAllReduceThunk::ExecuteOnStream(const ExecuteParams& params) {
   }
   absl::c_sort(global_devices);
 
-  // Find or create the rendezvous for this collective operation.
-  RendezvousKey rendezvous_key = RendezvousKey::FromInstruction(
-      params.run_id, global_devices, local_devices.size(), hlo_instruction_);
-
+  // Create the rendezvous for this collective operation.
+  RendezvousKey rendezvous_key(params.run_id, global_devices,
+                               local_devices.size(), config_.collective_op_kind,
+                               config_.op_id);
   if (VLOG_IS_ON(2)) {
     std::vector<std::string> local_participants;
     for (const auto& entry : local_devices) {
```
```diff
@@ -633,15 +657,12 @@ Status NcclAllReduceThunk::ExecuteOnStream(const ExecuteParams& params) {
         params.buffer_allocations->GetDeviceAddress(buffer.source_buffer);
     pbuffer.destination_data =
         params.buffer_allocations->GetDeviceAddress(buffer.destination_buffer);
-    pbuffer.primitive_type =
-        hlo_instruction_->operand(i)->shape().element_type();
+    pbuffer.primitive_type = config_.operand_element_type[i];
     participant.buffers.push_back(pbuffer);
   }
   participant.local_devices = std::move(local_devices);
   participant.nccl_unique_id_callback = params.nccl_unique_id_callback;
-  auto reduction_kind = MatchReductionComputation(hlo_instruction_->to_apply());
-  CHECK(reduction_kind.has_value());
-  participant.reduction_kind = *reduction_kind;
+  participant.reduction_kind = config_.reduction_kind;
 
   auto rendezvous_factory = [](const RendezvousKey& k) {
     return absl::make_unique<RendezvousNcclAllReduce>(k);
```
```diff
@@ -658,13 +679,11 @@ Status NcclAllReduceThunk::ExecuteOnStream(const ExecuteParams& params) {
   // Keep the clique we used alive for as long as this Thunk lives. Creating
   // new NCCL cliques is expensive, and this is how we avoid thrashing them.
   {
-    tensorflow::mutex_lock lock(aux_data_->mu);
-    aux_data_->cliques.insert(std::move(clique));
+    tensorflow::mutex_lock lock(config_.aux_data->mu);
+    config_.aux_data->cliques.insert(std::move(clique));
   }
   return Status::OK();
 }
 
-NcclAllReduceThunk::~NcclAllReduceThunk() {}
-
 }  // namespace gpu
 }  // namespace xla
```
```diff
@@ -18,11 +18,13 @@ limitations under the License.
 
 #include "absl/container/flat_hash_set.h"
 #include "tensorflow/compiler/xla/service/buffer_assignment.h"
+#include "tensorflow/compiler/xla/service/collective_ops_utils.h"
 #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
 #include "tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h"
 #include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
 #include "tensorflow/compiler/xla/service/gpu/thunk.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
 #include "tensorflow/core/platform/mutex.h"
 #include "tensorflow/core/platform/stream_executor_no_cuda.h"
 #include "tensorflow/core/platform/types.h"
```
```diff
@@ -30,6 +32,29 @@ limitations under the License.
 namespace xla {
 namespace gpu {
 
+struct NcclAllReduceConfig {
+  int64 operand_count;
+  std::vector<PrimitiveType> operand_element_type;
+  int64 replica_count;
+  std::vector<ReplicaGroup> replica_groups;
+  ReductionKind reduction_kind;
+  RendezvousKey::CollectiveOpKind collective_op_kind;
+  int64 op_id;
+
+  NcclAllReduceConfig() = default;
+  NcclAllReduceConfig(NcclAllReduceConfig &&);
+  ~NcclAllReduceConfig();
+
+  // Extra data stored in NcclAllReduceThunk whose types we don't want exposed
+  // in the header file. (This is mainly because the implementation of
+  // NcclAllReduceThunk is different depending on whether CUDA is enabled in the
+  // build, and we don't want to expose *that* mess in the header.)
+  struct AuxData;
+  std::unique_ptr<AuxData> aux_data;
+};
+
+NcclAllReduceConfig GetNcclAllReduceConfig(const HloInstruction *instr);
+
 // Thunk that performs a NCCL-based All-Reduce among CUDA GPU-based replicas.
 class NcclAllReduceThunk : public Thunk {
  public:
```
```diff
@@ -56,9 +81,8 @@ class NcclAllReduceThunk : public Thunk {
     BufferAllocation::Slice source_buffer;
     BufferAllocation::Slice destination_buffer;
   };
-  NcclAllReduceThunk(ThunkInfo thunk_info, int64 replica_count,
+  NcclAllReduceThunk(ThunkInfo thunk_info, NcclAllReduceConfig &&config,
                      std::vector<Buffer> buffers);
-  ~NcclAllReduceThunk() override;
 
   Status ExecuteOnStream(const ExecuteParams& params) override;
 
```
```diff
@@ -67,16 +91,8 @@ class NcclAllReduceThunk : public Thunk {
   static bool CanImplement(const HloInstruction* crs);
 
  private:
-  // Extra data stored in NcclAllReduceThunk whose types we don't want exposed
-  // in the header file. (This is mainly because the implementation of
-  // NcclAllReduceThunk is different depending on whether CUDA is enabled in the
-  // build, and we don't want to expose *that* mess in the header.)
-  struct AuxData;
-
-  const HloInstruction* hlo_instruction_;
-  const int64 replica_count_;
+  NcclAllReduceConfig config_;
   const std::vector<Buffer> buffers_;
-  std::unique_ptr<AuxData> aux_data_;
 };
 
 }  // namespace gpu
```
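The `AuxData` member in the new header follows the same opaque-struct (pimpl-style) idiom the thunk used before: the header only forward-declares the struct, and each build-selected .cc file (CUDA vs. non-CUDA) supplies its own definition. That is also why `NcclAllReduceConfig` declares its move constructor and destructor in the header but defines them `= default` in the .cc files — `std::unique_ptr` of an incomplete type can only be destroyed where the type is complete. A minimal sketch with hypothetical names:

```cpp
// config.h -- hypothetical header: AuxData is only forward-declared, so the
// header stays free of implementation-only (e.g. CUDA-specific) types.
#include <memory>

struct Config {
  Config();
  Config(Config&&);   // declared here, defined where AuxData is complete
  ~Config();          // ditto: destroying unique_ptr<AuxData> needs the full type

  struct AuxData;
  std::unique_ptr<AuxData> aux_data;
};

// config_cuda.cc -- one of two build-selected implementations; the other
// (non-CUDA) file would define an empty AuxData instead.
struct Config::AuxData {
  // Implementation-specific members live here.
};

Config::Config() : aux_data(std::make_unique<AuxData>()) {}
Config::Config(Config&&) = default;
Config::~Config() = default;  // default works here: AuxData is complete
```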