Fix GetNcclAllReduceConfig to initialize replica count
PiperOrigin-RevId: 336103268 Change-Id: Iea4f10da88fc728fae8b2298eeec19201915bb5f
This commit is contained in:
parent
f319210657
commit
3e0aee0d54
@ -24,7 +24,8 @@ struct NcclAllReduceConfig::AuxData {};
|
||||
NcclAllReduceConfig::NcclAllReduceConfig(NcclAllReduceConfig &&) = default;
|
||||
NcclAllReduceConfig::~NcclAllReduceConfig() = default;
|
||||
|
||||
NcclAllReduceConfig GetNcclAllReduceConfig(const HloInstruction *instr) {
|
||||
NcclAllReduceConfig GetNcclAllReduceConfig(const HloInstruction *instr,
|
||||
int64 replica_count) {
|
||||
NcclAllReduceConfig config = {};
|
||||
return config;
|
||||
}
|
||||
|
||||
@ -1652,7 +1652,8 @@ Status IrEmitterUnnested::HandleAllReduce(HloInstruction* crs) {
|
||||
*crs, crs->shape().IsTuple() ? ShapeIndex({i}) : ShapeIndex({}));
|
||||
tuple_element_buffers.push_back(buffers[i].destination_buffer);
|
||||
}
|
||||
NcclAllReduceConfig config = GetNcclAllReduceConfig(crs);
|
||||
NcclAllReduceConfig config =
|
||||
GetNcclAllReduceConfig(crs, hlo_module_config_.replica_count());
|
||||
auto all_reduce_thunk = absl::make_unique<NcclAllReduceThunk>(
|
||||
GetThunkInfo(crs), std::move(config),
|
||||
/*buffers=*/std::move(buffers));
|
||||
|
||||
@ -522,7 +522,8 @@ struct NcclAllReduceConfig::AuxData {
|
||||
NcclAllReduceConfig::NcclAllReduceConfig(NcclAllReduceConfig&&) = default;
|
||||
NcclAllReduceConfig::~NcclAllReduceConfig() = default;
|
||||
|
||||
NcclAllReduceConfig GetNcclAllReduceConfig(const HloInstruction* instr) {
|
||||
NcclAllReduceConfig GetNcclAllReduceConfig(const HloInstruction* instr,
|
||||
int64 replica_count) {
|
||||
NcclAllReduceConfig config;
|
||||
config.operand_count = instr->operands().size();
|
||||
config.operand_element_type.reserve(config.operand_count);
|
||||
@ -530,6 +531,7 @@ NcclAllReduceConfig GetNcclAllReduceConfig(const HloInstruction* instr) {
|
||||
config.operand_element_type.push_back(
|
||||
instr->operand(i)->shape().element_type());
|
||||
}
|
||||
config.replica_count = replica_count;
|
||||
config.replica_groups = instr->replica_groups();
|
||||
auto reduction_kind = MatchReductionComputation(instr->to_apply());
|
||||
CHECK(reduction_kind.has_value());
|
||||
|
||||
@ -53,7 +53,8 @@ struct NcclAllReduceConfig {
|
||||
std::unique_ptr<AuxData> aux_data;
|
||||
};
|
||||
|
||||
NcclAllReduceConfig GetNcclAllReduceConfig(const HloInstruction *instr);
|
||||
NcclAllReduceConfig GetNcclAllReduceConfig(const HloInstruction *instr,
|
||||
int64 replica_count);
|
||||
|
||||
// Thunk that performs a NCCL-based All-Reduce among CUDA GPU-based replicas.
|
||||
class NcclAllReduceThunk : public Thunk {
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user