diff --git a/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h b/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h index ab1d40c1a51..98b3be654a2 100644 --- a/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h @@ -46,8 +46,8 @@ class NcclAllReduceThunk : public NcclCollectiveThunk { static bool IsDegenerate(mlir::lmhlo::AllReduceOp op, int64 replica_count, int64 partition_count) { - return GetNcclAllReduceConfig(op).config.IsDegenerate(replica_count, - partition_count); + return GetNcclCollectiveConfigForMlir(op, op.use_global_device_ids()) + .IsDegenerate(replica_count, partition_count); } static CollectiveOpGroupMode GetGroupMode(mlir::lmhlo::AllReduceOp op) { diff --git a/tensorflow/compiler/xla/service/gpu/tests/all_reduce.hlo b/tensorflow/compiler/xla/service/gpu/tests/all_reduce.hlo new file mode 100644 index 00000000000..dfadb9b83fb --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/tests/all_reduce.hlo @@ -0,0 +1,24 @@ +// RUN: hlo_to_llvm_ir %s + +HloModule Test + +%fused_computation (param_0.5307: bf16[], param_1.5984: bf16[]) -> bf16[] { + %param_1.5984 = bf16[] parameter(1) + %convert.72239 = f32[] convert(bf16[] %param_1.5984) + %param_0.5307 = bf16[] parameter(0) + %convert.72238 = f32[] convert(bf16[] %param_0.5307) + %add.3846 = f32[] add(f32[] %convert.72239, f32[] %convert.72238), metadata={op_type="add" op_name="add"} + ROOT %convert.72237 = bf16[] convert(f32[] %add.3846) +} + +%all_reduce_computation (parameter.47449: bf16[], parameter.47450: bf16[]) -> bf16[] { + %parameter.47450 = bf16[] parameter(1), metadata={op_type="add" op_name="add"} + %parameter.47449 = bf16[] parameter(0), metadata={op_type="add" op_name="add"} + ROOT %fusion.1743 = bf16[] fusion(bf16[] %parameter.47450, bf16[] %parameter.47449), kind=kLoop, calls=%fused_computation +} + +ENTRY main { + input = bf16[8]{0} parameter(0) + ROOT crs = bf16[8]{0} all-reduce(input), replica_groups={{0}}, to_apply=%all_reduce_computation +} +