Fix GetNcclAllReduceConfig to initialize replica count

PiperOrigin-RevId: 336103268
Change-Id: Iea4f10da88fc728fae8b2298eeec19201915bb5f
This commit is contained in:
Rahul Joshi 2020-10-08 09:38:58 -07:00 committed by TensorFlower Gardener
parent f319210657
commit 3e0aee0d54
4 changed files with 9 additions and 4 deletions

View File

@ -24,7 +24,8 @@ struct NcclAllReduceConfig::AuxData {};
NcclAllReduceConfig::NcclAllReduceConfig(NcclAllReduceConfig &&) = default;
NcclAllReduceConfig::~NcclAllReduceConfig() = default;
NcclAllReduceConfig GetNcclAllReduceConfig(const HloInstruction *instr) {
NcclAllReduceConfig GetNcclAllReduceConfig(const HloInstruction *instr,
int64 replica_count) {
NcclAllReduceConfig config = {};
return config;
}

View File

@ -1652,7 +1652,8 @@ Status IrEmitterUnnested::HandleAllReduce(HloInstruction* crs) {
*crs, crs->shape().IsTuple() ? ShapeIndex({i}) : ShapeIndex({}));
tuple_element_buffers.push_back(buffers[i].destination_buffer);
}
NcclAllReduceConfig config = GetNcclAllReduceConfig(crs);
NcclAllReduceConfig config =
GetNcclAllReduceConfig(crs, hlo_module_config_.replica_count());
auto all_reduce_thunk = absl::make_unique<NcclAllReduceThunk>(
GetThunkInfo(crs), std::move(config),
/*buffers=*/std::move(buffers));

View File

@ -522,7 +522,8 @@ struct NcclAllReduceConfig::AuxData {
NcclAllReduceConfig::NcclAllReduceConfig(NcclAllReduceConfig&&) = default;
NcclAllReduceConfig::~NcclAllReduceConfig() = default;
NcclAllReduceConfig GetNcclAllReduceConfig(const HloInstruction* instr) {
NcclAllReduceConfig GetNcclAllReduceConfig(const HloInstruction* instr,
int64 replica_count) {
NcclAllReduceConfig config;
config.operand_count = instr->operands().size();
config.operand_element_type.reserve(config.operand_count);
@ -530,6 +531,7 @@ NcclAllReduceConfig GetNcclAllReduceConfig(const HloInstruction* instr) {
config.operand_element_type.push_back(
instr->operand(i)->shape().element_type());
}
config.replica_count = replica_count;
config.replica_groups = instr->replica_groups();
auto reduction_kind = MatchReductionComputation(instr->to_apply());
CHECK(reduction_kind.has_value());

View File

@ -53,7 +53,8 @@ struct NcclAllReduceConfig {
std::unique_ptr<AuxData> aux_data;
};
NcclAllReduceConfig GetNcclAllReduceConfig(const HloInstruction *instr);
NcclAllReduceConfig GetNcclAllReduceConfig(const HloInstruction *instr,
int64 replica_count);
// Thunk that performs a NCCL-based All-Reduce among CUDA GPU-based replicas.
class NcclAllReduceThunk : public Thunk {