[NFC] Eliminate reference to HLO Inst from NcclAllReduceThunk

- Create NcclAllReduceConfig to stash the properties from the HLO instruction that
  NcclAllReduceThunk uses (see the sketch before the file diffs below).
- Also move some other properties from the thunk to the config object.
- Remove the destructor for NcclAllReduceThunk now that the unique_ptr<> to the opaque
  object is moved to the config object.
- Eliminate RendezvousKey::FromInstruction as it is no longer used.

PiperOrigin-RevId: 335656689
Change-Id: I73d3c021c0e366d11736f0884a43984be2984e44
Rahul Joshi 2020-10-06 09:30:56 -07:00 committed by TensorFlower Gardener
parent 8b5e015c51
commit 64a324460a
6 changed files with 86 additions and 59 deletions
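
For orientation, here is a minimal, illustrative sketch of the pattern this change applies. The type names below (Instruction, AllReduceConfig, AllReduceThunk) are simplified stand-ins, not the actual XLA classes: all instruction-derived properties are captured into a plain config struct when the thunk is built, so the thunk no longer needs a pointer to the HLO instruction at execution time.

#include <cstdint>
#include <utility>
#include <vector>

// Stand-in for the HLO instruction, exposing only what the config needs.
struct Instruction {
  std::vector<int> operand_element_types;
  bool has_channel_id = false;
  int64_t channel_id = 0;
  int64_t module_unique_id = 0;
};

// Plain struct holding everything the thunk will need at execution time.
// (In the real change this struct also owns the opaque unique_ptr<AuxData>,
// which is why it declares an out-of-line move constructor and destructor.)
struct AllReduceConfig {
  int64_t operand_count = 0;
  std::vector<int> operand_element_type;
  bool is_cross_module = false;
  int64_t op_id = 0;
};

// Built once from the instruction, typically during IR emission.
AllReduceConfig GetAllReduceConfig(const Instruction& instr) {
  AllReduceConfig config;
  config.operand_count =
      static_cast<int64_t>(instr.operand_element_types.size());
  config.operand_element_type = instr.operand_element_types;
  if (instr.has_channel_id) {
    config.is_cross_module = true;
    config.op_id = instr.channel_id;
  } else {
    config.is_cross_module = false;
    config.op_id = instr.module_unique_id;
  }
  return config;
}

// The thunk owns only the config and never dereferences the instruction.
class AllReduceThunk {
 public:
  explicit AllReduceThunk(AllReduceConfig&& config)
      : config_(std::move(config)) {}
  int64_t op_id() const { return config_.op_id; }

 private:
  AllReduceConfig config_;
};

Usage mirrors the change to IrEmitterUnnested::HandleAllReduce in the diff: build the config from the instruction, then move it into the thunk's constructor.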


@ -82,22 +82,6 @@ struct RendezvousKey {
collective_op_kind(collective_op_kind),
op_id(op_id) {}
static RendezvousKey FromInstruction(
const RunId& run_id, std::vector<GlobalDeviceId> global_devices,
int num_local_participants, const HloInstruction* instr) {
CollectiveOpKind collective_op_kind;
int64 op_id;
std::tie(collective_op_kind, op_id) =
instr->channel_id().has_value()
? std::make_pair(kCrossModule, instr->channel_id().value())
: std::make_pair(
kCrossReplica,
static_cast<int64>(instr->GetModule()->unique_id()));
return RendezvousKey(run_id, std::move(global_devices),
num_local_participants, collective_op_kind, op_id);
}
template <typename H>
friend H AbslHashValue(H h, const RendezvousKey& k) {
return H::combine(std::move(h), k.run_id, k.global_devices,


@ -462,6 +462,7 @@ tf_cuda_library(
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
"//tensorflow/compiler/xla:xla_data_proto_cc",
] + if_cuda([
"//tensorflow/stream_executor/cuda:cuda_activation",
"//tensorflow/stream_executor/cuda:cuda_gpu_executor",


@ -14,10 +14,21 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
namespace xla {
namespace gpu {
struct NcclAllReduceConfig::AuxData {};
NcclAllReduceConfig::NcclAllReduceConfig(NcclAllReduceConfig &&) = default;
NcclAllReduceConfig::~NcclAllReduceConfig() = default;
NcclAllReduceConfig GetNcclAllReduceConfig(const HloInstruction *instr) {
NcclAllReduceConfig config = {};
return config;
}
/* static */ bool NcclAllReduceThunk::NcclIsEnabled() {
return false; // Skylark selects this source file if NCCL is disabled.
}
@ -32,20 +43,16 @@ Status NcclAllReduceThunk::ExecuteOnStream(const ExecuteParams& params) {
"compiler, which is necessary to build the NCCL source library.");
}
NcclAllReduceThunk::~NcclAllReduceThunk() = default;
/*static*/ absl::flat_hash_set<GlobalDeviceId>
NcclAllReduceThunk::DevicesWithOpenNcclChannels() {
return {};
}
struct NcclAllReduceThunk::AuxData {};
NcclAllReduceThunk::NcclAllReduceThunk(
ThunkInfo thunk_info, int64 replica_count,
ThunkInfo thunk_info, NcclAllReduceConfig &&config,
std::vector<NcclAllReduceThunk::Buffer> buffers)
: Thunk(Thunk::kNcclAllReduce, thunk_info),
replica_count_(replica_count),
config_(std::move(config)),
buffers_(std::move(buffers)) {}
} // namespace gpu


@ -1658,9 +1658,9 @@ Status IrEmitterUnnested::HandleAllReduce(HloInstruction* crs) {
*crs, crs->shape().IsTuple() ? ShapeIndex({i}) : ShapeIndex({}));
tuple_element_buffers.push_back(buffers[i].destination_buffer);
}
NcclAllReduceConfig config = GetNcclAllReduceConfig(crs);
auto all_reduce_thunk = absl::make_unique<NcclAllReduceThunk>(
GetThunkInfo(crs),
/*replica_count=*/hlo_module_config_.replica_count(),
GetThunkInfo(crs), std::move(config),
/*buffers=*/std::move(buffers));
if (crs->shape().IsTuple()) {
std::vector<std::unique_ptr<Thunk>> thunks;


@ -514,11 +514,38 @@ void RendezvousNcclAllReduce::CleanupImpl(std::shared_ptr<NcclClique> handle,
// header. In particular, this stores the thunk's cache of all NcclCliques it's
// ever used. This causes those cliques to stay alive as long as the thunk
// lives, which is how we avoid expensive reinitialization of NCCL cliques.
struct NcclAllReduceThunk::AuxData {
struct NcclAllReduceConfig::AuxData {
tensorflow::mutex mu;
absl::flat_hash_set<std::shared_ptr<NcclClique>> cliques TF_GUARDED_BY(mu);
};
NcclAllReduceConfig::NcclAllReduceConfig(NcclAllReduceConfig&&) = default;
NcclAllReduceConfig::~NcclAllReduceConfig() = default;
NcclAllReduceConfig GetNcclAllReduceConfig(const HloInstruction* instr) {
NcclAllReduceConfig config;
config.operand_count = instr->operands().size();
config.operand_element_type.reserve(config.operand_count);
for (int i = 0; i < config.operand_count; i++) {
config.operand_element_type.push_back(
instr->operand(i)->shape().element_type());
}
config.replica_groups = instr->replica_groups();
auto reduction_kind = MatchReductionComputation(instr->to_apply());
CHECK(reduction_kind.has_value());
config.reduction_kind = reduction_kind.value();
if (instr->channel_id().has_value()) {
config.collective_op_kind = RendezvousKey::kCrossModule;
config.op_id = instr->channel_id().value();
} else {
config.collective_op_kind = RendezvousKey::kCrossReplica;
config.op_id = static_cast<int64>(instr->GetModule()->unique_id());
}
config.aux_data = std::make_unique<NcclAllReduceConfig::AuxData>();
return config;
}
/*static*/ bool NcclAllReduceThunk::CanImplement(const HloInstruction* crs) {
auto operands_are_supported = [crs]() {
return absl::c_all_of(crs->operands(), [](HloInstruction* operand) {
@ -541,14 +568,12 @@ NcclAllReduceThunk::DevicesWithOpenNcclChannels() {
}
NcclAllReduceThunk::NcclAllReduceThunk(
ThunkInfo thunk_info, int64 replica_count,
ThunkInfo thunk_info, NcclAllReduceConfig&& config,
std::vector<NcclAllReduceThunk::Buffer> buffers)
: Thunk(Thunk::kNcclAllReduce, thunk_info),
hlo_instruction_(thunk_info.hlo_instruction),
replica_count_(replica_count),
buffers_(std::move(buffers)),
aux_data_(absl::make_unique<AuxData>()) {
CHECK_EQ(hlo_instruction_->operand_count(), buffers_.size());
config_(std::move(config)),
buffers_(std::move(buffers)) {
CHECK_EQ(config_.operand_count, buffers_.size());
}
// Figures out which devices (named by their replica-ids) are participating in
@ -558,7 +583,6 @@ Status NcclAllReduceThunk::ExecuteOnStream(const ExecuteParams& params) {
auto op_profiler =
params.profiler->MakeScopedInstructionProfiler(profile_index());
auto* instr = Cast<HloAllReduceInstruction>(hlo_instruction_);
int64 local_device_ordinal = params.stream->parent()->device_ordinal();
GlobalDeviceId global_device_id;
if (params.gpu_global_device_ids) {
@ -574,10 +598,10 @@ Status NcclAllReduceThunk::ExecuteOnStream(const ExecuteParams& params) {
// the same collective group as the caller.
TF_ASSIGN_OR_RETURN(
std::vector<int64> global_participating_replicas,
GetParticipatingReplicas(global_device_id, instr->replica_groups(),
replica_count_, *params.device_assn));
GetParticipatingReplicas(global_device_id, config_.replica_groups,
config_.replica_count, *params.device_assn));
if (IsGlobalNcclConfig() &&
global_participating_replicas.size() != replica_count_) {
global_participating_replicas.size() != config_.replica_count) {
return InvalidArgument(
"Partial replica groups are not allowed when using NCCL_COMM_ID "
"environment configuration.");
@ -605,10 +629,10 @@ Status NcclAllReduceThunk::ExecuteOnStream(const ExecuteParams& params) {
}
absl::c_sort(global_devices);
// Find or create the rendezvous for this collective operation.
RendezvousKey rendezvous_key = RendezvousKey::FromInstruction(
params.run_id, global_devices, local_devices.size(), hlo_instruction_);
// Create the rendezvous for this collective operation.
RendezvousKey rendezvous_key(params.run_id, global_devices,
local_devices.size(), config_.collective_op_kind,
config_.op_id);
if (VLOG_IS_ON(2)) {
std::vector<std::string> local_participants;
for (const auto& entry : local_devices) {
@ -633,15 +657,12 @@ Status NcclAllReduceThunk::ExecuteOnStream(const ExecuteParams& params) {
params.buffer_allocations->GetDeviceAddress(buffer.source_buffer);
pbuffer.destination_data =
params.buffer_allocations->GetDeviceAddress(buffer.destination_buffer);
pbuffer.primitive_type =
hlo_instruction_->operand(i)->shape().element_type();
pbuffer.primitive_type = config_.operand_element_type[i];
participant.buffers.push_back(pbuffer);
}
participant.local_devices = std::move(local_devices);
participant.nccl_unique_id_callback = params.nccl_unique_id_callback;
auto reduction_kind = MatchReductionComputation(hlo_instruction_->to_apply());
CHECK(reduction_kind.has_value());
participant.reduction_kind = *reduction_kind;
participant.reduction_kind = config_.reduction_kind;
auto rendezvous_factory = [](const RendezvousKey& k) {
return absl::make_unique<RendezvousNcclAllReduce>(k);
@ -658,13 +679,11 @@ Status NcclAllReduceThunk::ExecuteOnStream(const ExecuteParams& params) {
// Keep the clique we used alive for as long as this Thunk lives. Creating
// new NCCL cliques is expensive, and this is how we avoid thrashing them.
{
tensorflow::mutex_lock lock(aux_data_->mu);
aux_data_->cliques.insert(std::move(clique));
tensorflow::mutex_lock lock(config_.aux_data->mu);
config_.aux_data->cliques.insert(std::move(clique));
}
return Status::OK();
}
NcclAllReduceThunk::~NcclAllReduceThunk() {}
} // namespace gpu
} // namespace xla


@ -18,11 +18,13 @@ limitations under the License.
#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/collective_ops_utils.h"
#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_executable_run_options.h"
#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
#include "tensorflow/compiler/xla/service/gpu/thunk.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/platform/types.h"
@ -30,6 +32,29 @@ limitations under the License.
namespace xla {
namespace gpu {
struct NcclAllReduceConfig {
int64 operand_count;
std::vector<PrimitiveType> operand_element_type;
int64 replica_count;
std::vector<ReplicaGroup> replica_groups;
ReductionKind reduction_kind;
RendezvousKey::CollectiveOpKind collective_op_kind;
int64 op_id;
NcclAllReduceConfig() = default;
NcclAllReduceConfig(NcclAllReduceConfig &&);
~NcclAllReduceConfig();
// Extra data stored in NcclAllReduceThunk whose types we don't want exposed
// in the header file. (This is mainly because the implementation of
// NcclAllReduceThunk is different depending on whether CUDA is enabled in the
// build, and we don't want to expose *that* mess in the header.)
struct AuxData;
std::unique_ptr<AuxData> aux_data;
};
NcclAllReduceConfig GetNcclAllReduceConfig(const HloInstruction *instr);
// Thunk that performs a NCCL-based All-Reduce among CUDA GPU-based replicas.
class NcclAllReduceThunk : public Thunk {
public:
@ -56,9 +81,8 @@ class NcclAllReduceThunk : public Thunk {
BufferAllocation::Slice source_buffer;
BufferAllocation::Slice destination_buffer;
};
NcclAllReduceThunk(ThunkInfo thunk_info, int64 replica_count,
NcclAllReduceThunk(ThunkInfo thunk_info, NcclAllReduceConfig &&config,
std::vector<Buffer> buffers);
~NcclAllReduceThunk() override;
Status ExecuteOnStream(const ExecuteParams& params) override;
@ -67,16 +91,8 @@ class NcclAllReduceThunk : public Thunk {
static bool CanImplement(const HloInstruction* crs);
private:
// Extra data stored in NcclAllReduceThunk whose types we don't want exposed
// in the header file. (This is mainly because the implementation of
// NcclAllReduceThunk is different depending on whether CUDA is enabled in the
// build, and we don't want to expose *that* mess in the header.)
struct AuxData;
const HloInstruction* hlo_instruction_;
const int64 replica_count_;
NcclAllReduceConfig config_;
const std::vector<Buffer> buffers_;
std::unique_ptr<AuxData> aux_data_;
};
} // namespace gpu