Fix GetNcclAllReduceConfig to initialize replica count

PiperOrigin-RevId: 336103268
Change-Id: Iea4f10da88fc728fae8b2298eeec19201915bb5f
This commit is contained in:
Rahul Joshi 2020-10-08 09:38:58 -07:00 committed by TensorFlower Gardener
parent f319210657
commit 3e0aee0d54
4 changed files with 9 additions and 4 deletions

View File

@ -24,7 +24,8 @@ struct NcclAllReduceConfig::AuxData {};
NcclAllReduceConfig::NcclAllReduceConfig(NcclAllReduceConfig &&) = default; NcclAllReduceConfig::NcclAllReduceConfig(NcclAllReduceConfig &&) = default;
NcclAllReduceConfig::~NcclAllReduceConfig() = default; NcclAllReduceConfig::~NcclAllReduceConfig() = default;
NcclAllReduceConfig GetNcclAllReduceConfig(const HloInstruction *instr) { NcclAllReduceConfig GetNcclAllReduceConfig(const HloInstruction *instr,
int64 replica_count) {
NcclAllReduceConfig config = {}; NcclAllReduceConfig config = {};
return config; return config;
} }

View File

@ -1652,7 +1652,8 @@ Status IrEmitterUnnested::HandleAllReduce(HloInstruction* crs) {
*crs, crs->shape().IsTuple() ? ShapeIndex({i}) : ShapeIndex({})); *crs, crs->shape().IsTuple() ? ShapeIndex({i}) : ShapeIndex({}));
tuple_element_buffers.push_back(buffers[i].destination_buffer); tuple_element_buffers.push_back(buffers[i].destination_buffer);
} }
NcclAllReduceConfig config = GetNcclAllReduceConfig(crs); NcclAllReduceConfig config =
GetNcclAllReduceConfig(crs, hlo_module_config_.replica_count());
auto all_reduce_thunk = absl::make_unique<NcclAllReduceThunk>( auto all_reduce_thunk = absl::make_unique<NcclAllReduceThunk>(
GetThunkInfo(crs), std::move(config), GetThunkInfo(crs), std::move(config),
/*buffers=*/std::move(buffers)); /*buffers=*/std::move(buffers));

View File

@ -522,7 +522,8 @@ struct NcclAllReduceConfig::AuxData {
NcclAllReduceConfig::NcclAllReduceConfig(NcclAllReduceConfig&&) = default; NcclAllReduceConfig::NcclAllReduceConfig(NcclAllReduceConfig&&) = default;
NcclAllReduceConfig::~NcclAllReduceConfig() = default; NcclAllReduceConfig::~NcclAllReduceConfig() = default;
NcclAllReduceConfig GetNcclAllReduceConfig(const HloInstruction* instr) { NcclAllReduceConfig GetNcclAllReduceConfig(const HloInstruction* instr,
int64 replica_count) {
NcclAllReduceConfig config; NcclAllReduceConfig config;
config.operand_count = instr->operands().size(); config.operand_count = instr->operands().size();
config.operand_element_type.reserve(config.operand_count); config.operand_element_type.reserve(config.operand_count);
@ -530,6 +531,7 @@ NcclAllReduceConfig GetNcclAllReduceConfig(const HloInstruction* instr) {
config.operand_element_type.push_back( config.operand_element_type.push_back(
instr->operand(i)->shape().element_type()); instr->operand(i)->shape().element_type());
} }
config.replica_count = replica_count;
config.replica_groups = instr->replica_groups(); config.replica_groups = instr->replica_groups();
auto reduction_kind = MatchReductionComputation(instr->to_apply()); auto reduction_kind = MatchReductionComputation(instr->to_apply());
CHECK(reduction_kind.has_value()); CHECK(reduction_kind.has_value());

View File

@ -53,7 +53,8 @@ struct NcclAllReduceConfig {
std::unique_ptr<AuxData> aux_data; std::unique_ptr<AuxData> aux_data;
}; };
NcclAllReduceConfig GetNcclAllReduceConfig(const HloInstruction *instr); NcclAllReduceConfig GetNcclAllReduceConfig(const HloInstruction *instr,
int64 replica_count);
// Thunk that performs a NCCL-based All-Reduce among CUDA GPU-based replicas. // Thunk that performs a NCCL-based All-Reduce among CUDA GPU-based replicas.
class NcclAllReduceThunk : public Thunk { class NcclAllReduceThunk : public Thunk {