[NFC] Eliminate reference to HLO Inst from NcclAllReduceThunk
- Create NcclAllReduceConfig to stash properties from the HLO inst used by NcclAllReduceThunk. - Also move some other properties from the thunk to the config object. - Remove the destructor for NcclAllReduceThunk now that the unique_ptr<> to the opaque object is moved to the config object. - Eliminate RendezvousKey::FromInstruction as it's not used any more. PiperOrigin-RevId: 335656689 Change-Id: I73d3c021c0e366d11736f0884a43984be2984e44
This commit is contained in:
parent
8b5e015c51
commit
64a324460a
@ -82,22 +82,6 @@ struct RendezvousKey {
|
|||||||
collective_op_kind(collective_op_kind),
|
collective_op_kind(collective_op_kind),
|
||||||
op_id(op_id) {}
|
op_id(op_id) {}
|
||||||
|
|
||||||
static RendezvousKey FromInstruction(
|
|
||||||
const RunId& run_id, std::vector<GlobalDeviceId> global_devices,
|
|
||||||
int num_local_participants, const HloInstruction* instr) {
|
|
||||||
CollectiveOpKind collective_op_kind;
|
|
||||||
int64 op_id;
|
|
||||||
|
|
||||||
std::tie(collective_op_kind, op_id) =
|
|
||||||
instr->channel_id().has_value()
|
|
||||||
? std::make_pair(kCrossModule, instr->channel_id().value())
|
|
||||||
: std::make_pair(
|
|
||||||
kCrossReplica,
|
|
||||||
static_cast<int64>(instr->GetModule()->unique_id()));
|
|
||||||
return RendezvousKey(run_id, std::move(global_devices),
|
|
||||||
num_local_participants, collective_op_kind, op_id);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename H>
|
template <typename H>
|
||||||
friend H AbslHashValue(H h, const RendezvousKey& k) {
|
friend H AbslHashValue(H h, const RendezvousKey& k) {
|
||||||
return H::combine(std::move(h), k.run_id, k.global_devices,
|
return H::combine(std::move(h), k.run_id, k.global_devices,
|
||||||
|
@ -462,6 +462,7 @@ tf_cuda_library(
|
|||||||
"//tensorflow/compiler/xla/service:hlo",
|
"//tensorflow/compiler/xla/service:hlo",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"//tensorflow/core:stream_executor_no_cuda",
|
"//tensorflow/core:stream_executor_no_cuda",
|
||||||
|
"//tensorflow/compiler/xla:xla_data_proto_cc",
|
||||||
] + if_cuda([
|
] + if_cuda([
|
||||||
"//tensorflow/stream_executor/cuda:cuda_activation",
|
"//tensorflow/stream_executor/cuda:cuda_activation",
|
||||||
"//tensorflow/stream_executor/cuda:cuda_gpu_executor",
|
"//tensorflow/stream_executor/cuda:cuda_gpu_executor",
|
||||||
|
@ -14,10 +14,21 @@ limitations under the License.
|
|||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
#include "tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h"
|
#include "tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h"
|
||||||
|
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||||
|
|
||||||
namespace xla {
|
namespace xla {
|
||||||
namespace gpu {
|
namespace gpu {
|
||||||
|
|
||||||
|
struct NcclAllReduceConfig::AuxData {};
|
||||||
|
|
||||||
|
NcclAllReduceConfig::NcclAllReduceConfig(NcclAllReduceConfig &&) = default;
|
||||||
|
NcclAllReduceConfig::~NcclAllReduceConfig() = default;
|
||||||
|
|
||||||
|
NcclAllReduceConfig GetNcclAllReduceConfig(const HloInstruction *instr) {
|
||||||
|
NcclAllReduceConfig config = {};
|
||||||
|
return config;
|
||||||
|
}
|
||||||
|
|
||||||
/* static */ bool NcclAllReduceThunk::NcclIsEnabled() {
|
/* static */ bool NcclAllReduceThunk::NcclIsEnabled() {
|
||||||
return false; // Skylark selects this source file if NCCL is disabled.
|
return false; // Skylark selects this source file if NCCL is disabled.
|
||||||
}
|
}
|
||||||
@ -32,20 +43,16 @@ Status NcclAllReduceThunk::ExecuteOnStream(const ExecuteParams& params) {
|
|||||||
"compiler, which is necessary to build the NCCL source library.");
|
"compiler, which is necessary to build the NCCL source library.");
|
||||||
}
|
}
|
||||||
|
|
||||||
NcclAllReduceThunk::~NcclAllReduceThunk() = default;
|
|
||||||
|
|
||||||
/*static*/ absl::flat_hash_set<GlobalDeviceId>
|
/*static*/ absl::flat_hash_set<GlobalDeviceId>
|
||||||
NcclAllReduceThunk::DevicesWithOpenNcclChannels() {
|
NcclAllReduceThunk::DevicesWithOpenNcclChannels() {
|
||||||
return {};
|
return {};
|
||||||
}
|
}
|
||||||
|
|
||||||
struct NcclAllReduceThunk::AuxData {};
|
|
||||||
|
|
||||||
NcclAllReduceThunk::NcclAllReduceThunk(
|
NcclAllReduceThunk::NcclAllReduceThunk(
|
||||||
ThunkInfo thunk_info, int64 replica_count,
|
ThunkInfo thunk_info, NcclAllReduceConfig &&config,
|
||||||
std::vector<NcclAllReduceThunk::Buffer> buffers)
|
std::vector<NcclAllReduceThunk::Buffer> buffers)
|
||||||
: Thunk(Thunk::kNcclAllReduce, thunk_info),
|
: Thunk(Thunk::kNcclAllReduce, thunk_info),
|
||||||
replica_count_(replica_count),
|
config_(std::move(config)),
|
||||||
buffers_(std::move(buffers)) {}
|
buffers_(std::move(buffers)) {}
|
||||||
|
|
||||||
} // namespace gpu
|
} // namespace gpu
|
||||||
|
@ -1658,9 +1658,9 @@ Status IrEmitterUnnested::HandleAllReduce(HloInstruction* crs) {
|
|||||||
*crs, crs->shape().IsTuple() ? ShapeIndex({i}) : ShapeIndex({}));
|
*crs, crs->shape().IsTuple() ? ShapeIndex({i}) : ShapeIndex({}));
|
||||||
tuple_element_buffers.push_back(buffers[i].destination_buffer);
|
tuple_element_buffers.push_back(buffers[i].destination_buffer);
|
||||||
}
|
}
|
||||||
|
NcclAllReduceConfig config = GetNcclAllReduceConfig(crs);
|
||||||
auto all_reduce_thunk = absl::make_unique<NcclAllReduceThunk>(
|
auto all_reduce_thunk = absl::make_unique<NcclAllReduceThunk>(
|
||||||
GetThunkInfo(crs),
|
GetThunkInfo(crs), std::move(config),
|
||||||
/*replica_count=*/hlo_module_config_.replica_count(),
|
|
||||||
/*buffers=*/std::move(buffers));
|
/*buffers=*/std::move(buffers));
|
||||||
if (crs->shape().IsTuple()) {
|
if (crs->shape().IsTuple()) {
|
||||||
std::vector<std::unique_ptr<Thunk>> thunks;
|
std::vector<std::unique_ptr<Thunk>> thunks;
|
||||||
|
@ -514,11 +514,38 @@ void RendezvousNcclAllReduce::CleanupImpl(std::shared_ptr<NcclClique> handle,
|
|||||||
// header. In particular, this stores the thunk's cache of all NcclCliques it's
|
// header. In particular, this stores the thunk's cache of all NcclCliques it's
|
||||||
// ever used. This causes those cliques to stay alive as long as the thunk
|
// ever used. This causes those cliques to stay alive as long as the thunk
|
||||||
// lives, which is how we avoid expensive reinitialization of NCCL cliques.
|
// lives, which is how we avoid expensive reinitialization of NCCL cliques.
|
||||||
struct NcclAllReduceThunk::AuxData {
|
struct NcclAllReduceConfig::AuxData {
|
||||||
tensorflow::mutex mu;
|
tensorflow::mutex mu;
|
||||||
absl::flat_hash_set<std::shared_ptr<NcclClique>> cliques TF_GUARDED_BY(mu);
|
absl::flat_hash_set<std::shared_ptr<NcclClique>> cliques TF_GUARDED_BY(mu);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
NcclAllReduceConfig::NcclAllReduceConfig(NcclAllReduceConfig&&) = default;
|
||||||
|
NcclAllReduceConfig::~NcclAllReduceConfig() = default;
|
||||||
|
|
||||||
|
NcclAllReduceConfig GetNcclAllReduceConfig(const HloInstruction* instr) {
|
||||||
|
NcclAllReduceConfig config;
|
||||||
|
config.operand_count = instr->operands().size();
|
||||||
|
config.operand_element_type.reserve(config.operand_count);
|
||||||
|
for (int i = 0; i < config.operand_count; i++) {
|
||||||
|
config.operand_element_type.push_back(
|
||||||
|
instr->operand(i)->shape().element_type());
|
||||||
|
}
|
||||||
|
config.replica_groups = instr->replica_groups();
|
||||||
|
auto reduction_kind = MatchReductionComputation(instr->to_apply());
|
||||||
|
CHECK(reduction_kind.has_value());
|
||||||
|
config.reduction_kind = reduction_kind.value();
|
||||||
|
|
||||||
|
if (instr->channel_id().has_value()) {
|
||||||
|
config.collective_op_kind = RendezvousKey::kCrossModule;
|
||||||
|
config.op_id = instr->channel_id().value();
|
||||||
|
} else {
|
||||||
|
config.collective_op_kind = RendezvousKey::kCrossReplica;
|
||||||
|
config.op_id = static_cast<int64>(instr->GetModule()->unique_id());
|
||||||
|
}
|
||||||
|
config.aux_data = std::make_unique<NcclAllReduceConfig::AuxData>();
|
||||||
|
return config;
|
||||||
|
}
|
||||||
|
|
||||||
/*static*/ bool NcclAllReduceThunk::CanImplement(const HloInstruction* crs) {
|
/*static*/ bool NcclAllReduceThunk::CanImplement(const HloInstruction* crs) {
|
||||||
auto operands_are_supported = [crs]() {
|
auto operands_are_supported = [crs]() {
|
||||||
return absl::c_all_of(crs->operands(), [](HloInstruction* operand) {
|
return absl::c_all_of(crs->operands(), [](HloInstruction* operand) {
|
||||||
@ -541,14 +568,12 @@ NcclAllReduceThunk::DevicesWithOpenNcclChannels() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
NcclAllReduceThunk::NcclAllReduceThunk(
|
NcclAllReduceThunk::NcclAllReduceThunk(
|
||||||
ThunkInfo thunk_info, int64 replica_count,
|
ThunkInfo thunk_info, NcclAllReduceConfig&& config,
|
||||||
std::vector<NcclAllReduceThunk::Buffer> buffers)
|
std::vector<NcclAllReduceThunk::Buffer> buffers)
|
||||||
: Thunk(Thunk::kNcclAllReduce, thunk_info),
|
: Thunk(Thunk::kNcclAllReduce, thunk_info),
|
||||||
hlo_instruction_(thunk_info.hlo_instruction),
|
config_(std::move(config)),
|
||||||
replica_count_(replica_count),
|
buffers_(std::move(buffers)) {
|
||||||
buffers_(std::move(buffers)),
|
CHECK_EQ(config_.operand_count, buffers_.size());
|
||||||
aux_data_(absl::make_unique<AuxData>()) {
|
|
||||||
CHECK_EQ(hlo_instruction_->operand_count(), buffers_.size());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Figures out which devices (named by their replica-ids) are participating in
|
// Figures out which devices (named by their replica-ids) are participating in
|
||||||
@ -558,7 +583,6 @@ Status NcclAllReduceThunk::ExecuteOnStream(const ExecuteParams& params) {
|
|||||||
auto op_profiler =
|
auto op_profiler =
|
||||||
params.profiler->MakeScopedInstructionProfiler(profile_index());
|
params.profiler->MakeScopedInstructionProfiler(profile_index());
|
||||||
|
|
||||||
auto* instr = Cast<HloAllReduceInstruction>(hlo_instruction_);
|
|
||||||
int64 local_device_ordinal = params.stream->parent()->device_ordinal();
|
int64 local_device_ordinal = params.stream->parent()->device_ordinal();
|
||||||
GlobalDeviceId global_device_id;
|
GlobalDeviceId global_device_id;
|
||||||
if (params.gpu_global_device_ids) {
|
if (params.gpu_global_device_ids) {
|
||||||
@ -574,10 +598,10 @@ Status NcclAllReduceThunk::ExecuteOnStream(const ExecuteParams& params) {
|
|||||||
// the same collective group as the caller.
|
// the same collective group as the caller.
|
||||||
TF_ASSIGN_OR_RETURN(
|
TF_ASSIGN_OR_RETURN(
|
||||||
std::vector<int64> global_participating_replicas,
|
std::vector<int64> global_participating_replicas,
|
||||||
GetParticipatingReplicas(global_device_id, instr->replica_groups(),
|
GetParticipatingReplicas(global_device_id, config_.replica_groups,
|
||||||
replica_count_, *params.device_assn));
|
config_.replica_count, *params.device_assn));
|
||||||
if (IsGlobalNcclConfig() &&
|
if (IsGlobalNcclConfig() &&
|
||||||
global_participating_replicas.size() != replica_count_) {
|
global_participating_replicas.size() != config_.replica_count) {
|
||||||
return InvalidArgument(
|
return InvalidArgument(
|
||||||
"Partial replica groups are not allowed when using NCCL_COMM_ID "
|
"Partial replica groups are not allowed when using NCCL_COMM_ID "
|
||||||
"environment configuration.");
|
"environment configuration.");
|
||||||
@ -605,10 +629,10 @@ Status NcclAllReduceThunk::ExecuteOnStream(const ExecuteParams& params) {
|
|||||||
}
|
}
|
||||||
absl::c_sort(global_devices);
|
absl::c_sort(global_devices);
|
||||||
|
|
||||||
// Find or create the rendezvous for this collective operation.
|
// Create the rendezvous for this collective operation.
|
||||||
RendezvousKey rendezvous_key = RendezvousKey::FromInstruction(
|
RendezvousKey rendezvous_key(params.run_id, global_devices,
|
||||||
params.run_id, global_devices, local_devices.size(), hlo_instruction_);
|
local_devices.size(), config_.collective_op_kind,
|
||||||
|
config_.op_id);
|
||||||
if (VLOG_IS_ON(2)) {
|
if (VLOG_IS_ON(2)) {
|
||||||
std::vector<std::string> local_participants;
|
std::vector<std::string> local_participants;
|
||||||
for (const auto& entry : local_devices) {
|
for (const auto& entry : local_devices) {
|
||||||
@ -633,15 +657,12 @@ Status NcclAllReduceThunk::ExecuteOnStream(const ExecuteParams& params) {
|
|||||||
params.buffer_allocations->GetDeviceAddress(buffer.source_buffer);
|
params.buffer_allocations->GetDeviceAddress(buffer.source_buffer);
|
||||||
pbuffer.destination_data =
|
pbuffer.destination_data =
|
||||||
params.buffer_allocations->GetDeviceAddress(buffer.destination_buffer);
|
params.buffer_allocations->GetDeviceAddress(buffer.destination_buffer);
|
||||||
pbuffer.primitive_type =
|
pbuffer.primitive_type = config_.operand_element_type[i];
|
||||||
hlo_instruction_->operand(i)->shape().element_type();
|
|
||||||
participant.buffers.push_back(pbuffer);
|
participant.buffers.push_back(pbuffer);
|
||||||
}
|
}
|
||||||
participant.local_devices = std::move(local_devices);
|
participant.local_devices = std::move(local_devices);
|
||||||
participant.nccl_unique_id_callback = params.nccl_unique_id_callback;
|
participant.nccl_unique_id_callback = params.nccl_unique_id_callback;
|
||||||
auto reduction_kind = MatchReductionComputation(hlo_instruction_->to_apply());
|
participant.reduction_kind = config_.reduction_kind;
|
||||||
CHECK(reduction_kind.has_value());
|
|
||||||
participant.reduction_kind = *reduction_kind;
|
|
||||||
|
|
||||||
auto rendezvous_factory = [](const RendezvousKey& k) {
|
auto rendezvous_factory = [](const RendezvousKey& k) {
|
||||||
return absl::make_unique<RendezvousNcclAllReduce>(k);
|
return absl::make_unique<RendezvousNcclAllReduce>(k);
|
||||||
@ -658,13 +679,11 @@ Status NcclAllReduceThunk::ExecuteOnStream(const ExecuteParams& params) {
|
|||||||
// Keep the clique we used alive for as long as this Thunk lives. Creating
|
// Keep the clique we used alive for as long as this Thunk lives. Creating
|
||||||
// new NCCL cliques is expensive, and this is how we avoid thrashing them.
|
// new NCCL cliques is expensive, and this is how we avoid thrashing them.
|
||||||
{
|
{
|
||||||
tensorflow::mutex_lock lock(aux_data_->mu);
|
tensorflow::mutex_lock lock(config_.aux_data->mu);
|
||||||
aux_data_->cliques.insert(std::move(clique));
|
config_.aux_data->cliques.insert(std::move(clique));
|
||||||
}
|
}
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
NcclAllReduceThunk::~NcclAllReduceThunk() {}
|
|
||||||
|
|
||||||
} // namespace gpu
|
} // namespace gpu
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
@ -18,11 +18,13 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "absl/container/flat_hash_set.h"
|
#include "absl/container/flat_hash_set.h"
|
||||||
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
|
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
|
||||||
|
#include "tensorflow/compiler/xla/service/collective_ops_utils.h"
|
||||||
#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
|
#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
|
||||||
#include "tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h"
|
#include "tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h"
|
||||||
#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
|
#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
|
||||||
#include "tensorflow/compiler/xla/service/gpu/thunk.h"
|
#include "tensorflow/compiler/xla/service/gpu/thunk.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||||
|
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||||
#include "tensorflow/core/platform/mutex.h"
|
#include "tensorflow/core/platform/mutex.h"
|
||||||
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
|
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
|
||||||
#include "tensorflow/core/platform/types.h"
|
#include "tensorflow/core/platform/types.h"
|
||||||
@ -30,6 +32,29 @@ limitations under the License.
|
|||||||
namespace xla {
|
namespace xla {
|
||||||
namespace gpu {
|
namespace gpu {
|
||||||
|
|
||||||
|
struct NcclAllReduceConfig {
|
||||||
|
int64 operand_count;
|
||||||
|
std::vector<PrimitiveType> operand_element_type;
|
||||||
|
int64 replica_count;
|
||||||
|
std::vector<ReplicaGroup> replica_groups;
|
||||||
|
ReductionKind reduction_kind;
|
||||||
|
RendezvousKey::CollectiveOpKind collective_op_kind;
|
||||||
|
int64 op_id;
|
||||||
|
|
||||||
|
NcclAllReduceConfig() = default;
|
||||||
|
NcclAllReduceConfig(NcclAllReduceConfig &&);
|
||||||
|
~NcclAllReduceConfig();
|
||||||
|
|
||||||
|
// Extra data stored in NcclAllReduceThunk whose types we don't want exposed
|
||||||
|
// in the header file. (This is mainly because the implementation of
|
||||||
|
// NcclAllReduceThunk is different depending on whether CUDA is enabled in the
|
||||||
|
// build, and we don't want to expose *that* mess in the header.)
|
||||||
|
struct AuxData;
|
||||||
|
std::unique_ptr<AuxData> aux_data;
|
||||||
|
};
|
||||||
|
|
||||||
|
NcclAllReduceConfig GetNcclAllReduceConfig(const HloInstruction *instr);
|
||||||
|
|
||||||
// Thunk that performs a NCCL-based All-Reduce among CUDA GPU-based replicas.
|
// Thunk that performs a NCCL-based All-Reduce among CUDA GPU-based replicas.
|
||||||
class NcclAllReduceThunk : public Thunk {
|
class NcclAllReduceThunk : public Thunk {
|
||||||
public:
|
public:
|
||||||
@ -56,9 +81,8 @@ class NcclAllReduceThunk : public Thunk {
|
|||||||
BufferAllocation::Slice source_buffer;
|
BufferAllocation::Slice source_buffer;
|
||||||
BufferAllocation::Slice destination_buffer;
|
BufferAllocation::Slice destination_buffer;
|
||||||
};
|
};
|
||||||
NcclAllReduceThunk(ThunkInfo thunk_info, int64 replica_count,
|
NcclAllReduceThunk(ThunkInfo thunk_info, NcclAllReduceConfig &&config,
|
||||||
std::vector<Buffer> buffers);
|
std::vector<Buffer> buffers);
|
||||||
~NcclAllReduceThunk() override;
|
|
||||||
|
|
||||||
Status ExecuteOnStream(const ExecuteParams& params) override;
|
Status ExecuteOnStream(const ExecuteParams& params) override;
|
||||||
|
|
||||||
@ -67,16 +91,8 @@ class NcclAllReduceThunk : public Thunk {
|
|||||||
static bool CanImplement(const HloInstruction* crs);
|
static bool CanImplement(const HloInstruction* crs);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
// Extra data stored in NcclAllReduceThunk whose types we don't want exposed
|
NcclAllReduceConfig config_;
|
||||||
// in the header file. (This is mainly because the implementation of
|
|
||||||
// NcclAllReduceThunk is different depending on whether CUDA is enabled in the
|
|
||||||
// build, and we don't want to expose *that* mess in the header.)
|
|
||||||
struct AuxData;
|
|
||||||
|
|
||||||
const HloInstruction* hlo_instruction_;
|
|
||||||
const int64 replica_count_;
|
|
||||||
const std::vector<Buffer> buffers_;
|
const std::vector<Buffer> buffers_;
|
||||||
std::unique_ptr<AuxData> aux_data_;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace gpu
|
} // namespace gpu
|
||||||
|
Loading…
x
Reference in New Issue
Block a user