Fix GetNcclAllReduceConfig to initialize replica count
PiperOrigin-RevId: 336103268
Change-Id: Iea4f10da88fc728fae8b2298eeec19201915bb5f
This commit is contained in:
parent
f319210657
commit
3e0aee0d54
@ -24,7 +24,8 @@ struct NcclAllReduceConfig::AuxData {};
|
|||||||
NcclAllReduceConfig::NcclAllReduceConfig(NcclAllReduceConfig &&) = default;
|
NcclAllReduceConfig::NcclAllReduceConfig(NcclAllReduceConfig &&) = default;
|
||||||
NcclAllReduceConfig::~NcclAllReduceConfig() = default;
|
NcclAllReduceConfig::~NcclAllReduceConfig() = default;
|
||||||
|
|
||||||
NcclAllReduceConfig GetNcclAllReduceConfig(const HloInstruction *instr) {
|
NcclAllReduceConfig GetNcclAllReduceConfig(const HloInstruction *instr,
|
||||||
|
int64 replica_count) {
|
||||||
NcclAllReduceConfig config = {};
|
NcclAllReduceConfig config = {};
|
||||||
return config;
|
return config;
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1652,7 +1652,8 @@ Status IrEmitterUnnested::HandleAllReduce(HloInstruction* crs) {
|
|||||||
*crs, crs->shape().IsTuple() ? ShapeIndex({i}) : ShapeIndex({}));
|
*crs, crs->shape().IsTuple() ? ShapeIndex({i}) : ShapeIndex({}));
|
||||||
tuple_element_buffers.push_back(buffers[i].destination_buffer);
|
tuple_element_buffers.push_back(buffers[i].destination_buffer);
|
||||||
}
|
}
|
||||||
NcclAllReduceConfig config = GetNcclAllReduceConfig(crs);
|
NcclAllReduceConfig config =
|
||||||
|
GetNcclAllReduceConfig(crs, hlo_module_config_.replica_count());
|
||||||
auto all_reduce_thunk = absl::make_unique<NcclAllReduceThunk>(
|
auto all_reduce_thunk = absl::make_unique<NcclAllReduceThunk>(
|
||||||
GetThunkInfo(crs), std::move(config),
|
GetThunkInfo(crs), std::move(config),
|
||||||
/*buffers=*/std::move(buffers));
|
/*buffers=*/std::move(buffers));
|
||||||
|
|||||||
@ -522,7 +522,8 @@ struct NcclAllReduceConfig::AuxData {
|
|||||||
NcclAllReduceConfig::NcclAllReduceConfig(NcclAllReduceConfig&&) = default;
|
NcclAllReduceConfig::NcclAllReduceConfig(NcclAllReduceConfig&&) = default;
|
||||||
NcclAllReduceConfig::~NcclAllReduceConfig() = default;
|
NcclAllReduceConfig::~NcclAllReduceConfig() = default;
|
||||||
|
|
||||||
NcclAllReduceConfig GetNcclAllReduceConfig(const HloInstruction* instr) {
|
NcclAllReduceConfig GetNcclAllReduceConfig(const HloInstruction* instr,
|
||||||
|
int64 replica_count) {
|
||||||
NcclAllReduceConfig config;
|
NcclAllReduceConfig config;
|
||||||
config.operand_count = instr->operands().size();
|
config.operand_count = instr->operands().size();
|
||||||
config.operand_element_type.reserve(config.operand_count);
|
config.operand_element_type.reserve(config.operand_count);
|
||||||
@ -530,6 +531,7 @@ NcclAllReduceConfig GetNcclAllReduceConfig(const HloInstruction* instr) {
|
|||||||
config.operand_element_type.push_back(
|
config.operand_element_type.push_back(
|
||||||
instr->operand(i)->shape().element_type());
|
instr->operand(i)->shape().element_type());
|
||||||
}
|
}
|
||||||
|
config.replica_count = replica_count;
|
||||||
config.replica_groups = instr->replica_groups();
|
config.replica_groups = instr->replica_groups();
|
||||||
auto reduction_kind = MatchReductionComputation(instr->to_apply());
|
auto reduction_kind = MatchReductionComputation(instr->to_apply());
|
||||||
CHECK(reduction_kind.has_value());
|
CHECK(reduction_kind.has_value());
|
||||||
|
|||||||
@ -53,7 +53,8 @@ struct NcclAllReduceConfig {
|
|||||||
std::unique_ptr<AuxData> aux_data;
|
std::unique_ptr<AuxData> aux_data;
|
||||||
};
|
};
|
||||||
|
|
||||||
NcclAllReduceConfig GetNcclAllReduceConfig(const HloInstruction *instr);
|
NcclAllReduceConfig GetNcclAllReduceConfig(const HloInstruction *instr,
|
||||||
|
int64 replica_count);
|
||||||
|
|
||||||
// Thunk that performs a NCCL-based All-Reduce among CUDA GPU-based replicas.
|
// Thunk that performs a NCCL-based All-Reduce among CUDA GPU-based replicas.
|
||||||
class NcclAllReduceThunk : public Thunk {
|
class NcclAllReduceThunk : public Thunk {
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user