diff --git a/tensorflow/compiler/xla/service/gpu/dummy_all_reduce_thunk.cc b/tensorflow/compiler/xla/service/gpu/dummy_all_reduce_thunk.cc index 1ec38960c15..4cc19a23201 100644 --- a/tensorflow/compiler/xla/service/gpu/dummy_all_reduce_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/dummy_all_reduce_thunk.cc @@ -24,7 +24,8 @@ struct NcclAllReduceConfig::AuxData {}; NcclAllReduceConfig::NcclAllReduceConfig(NcclAllReduceConfig &&) = default; NcclAllReduceConfig::~NcclAllReduceConfig() = default; -NcclAllReduceConfig GetNcclAllReduceConfig(const HloInstruction *instr) { +NcclAllReduceConfig GetNcclAllReduceConfig(const HloInstruction *instr, + int64 replica_count) { NcclAllReduceConfig config = {}; return config; } diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index 51bee21df4e..d33474f83c2 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -1652,7 +1652,8 @@ Status IrEmitterUnnested::HandleAllReduce(HloInstruction* crs) { *crs, crs->shape().IsTuple() ? 
ShapeIndex({i}) : ShapeIndex({})); tuple_element_buffers.push_back(buffers[i].destination_buffer); } - NcclAllReduceConfig config = GetNcclAllReduceConfig(crs); + NcclAllReduceConfig config = + GetNcclAllReduceConfig(crs, hlo_module_config_.replica_count()); auto all_reduce_thunk = absl::make_unique<NcclAllReduceThunk>( GetThunkInfo(crs), std::move(config), /*buffers=*/std::move(buffers)); diff --git a/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc b/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc index 8cdc465f84e..b13f71c5a13 100644 --- a/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.cc @@ -522,7 +522,8 @@ struct NcclAllReduceConfig::AuxData { NcclAllReduceConfig::NcclAllReduceConfig(NcclAllReduceConfig&&) = default; NcclAllReduceConfig::~NcclAllReduceConfig() = default; -NcclAllReduceConfig GetNcclAllReduceConfig(const HloInstruction* instr) { +NcclAllReduceConfig GetNcclAllReduceConfig(const HloInstruction* instr, + int64 replica_count) { NcclAllReduceConfig config; config.operand_count = instr->operands().size(); config.operand_element_type.reserve(config.operand_count); @@ -530,6 +531,7 @@ NcclAllReduceConfig GetNcclAllReduceConfig(const HloInstruction* instr) { config.operand_element_type.push_back( instr->operand(i)->shape().element_type()); } + config.replica_count = replica_count; config.replica_groups = instr->replica_groups(); auto reduction_kind = MatchReductionComputation(instr->to_apply()); CHECK(reduction_kind.has_value()); diff --git a/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h b/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h index 42060e82428..20e4adef7b1 100644 --- a/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h @@ -53,7 +53,8 @@ struct NcclAllReduceConfig { std::unique_ptr<AuxData> aux_data; }; -NcclAllReduceConfig GetNcclAllReduceConfig(const HloInstruction
*instr); +NcclAllReduceConfig GetNcclAllReduceConfig(const HloInstruction *instr, + int64 replica_count); // Thunk that performs a NCCL-based All-Reduce among CUDA GPU-based replicas. class NcclAllReduceThunk : public Thunk {