[XLA:GPU] Fix all-reduce when it's degenerate and has arbitrary reduction computation.

- Do not call GetNcclAllReduceConfig() when checking if the all-reduce is degenerate,
  as that can fail if the reduction computation does not match. Use
  GetNcclCollectiveConfigForMlir instead.
- This allows otherwise degenerate all-reduce with an arbitrary reduction computation
  to be compiled without failures (and mapped to just a copy).

PiperOrigin-RevId: 359370850
Change-Id: I6d79c7f3dc75a438308676a22995c69105d748ad
This commit is contained in:
Rahul Joshi 2021-02-24 14:27:46 -08:00 committed by TensorFlower Gardener
parent e407b99940
commit e30cc0bfb0
2 changed files with 26 additions and 2 deletions
tensorflow/compiler/xla/service/gpu

View File

@@ -46,8 +46,8 @@ class NcclAllReduceThunk : public NcclCollectiveThunk {
// Returns true when this all-reduce is degenerate for the given replica and
// partition counts, i.e. every replica group ends up with a single
// participant so the op reduces to a plain copy of its operand.
//
// NOTE: deliberately uses GetNcclCollectiveConfigForMlir() instead of
// GetNcclAllReduceConfig(). The latter additionally maps the reduction
// computation onto an NCCL reduction kind and fails for arbitrary
// computations — a mapping a degenerate all-reduce never needs, since it is
// lowered to a copy rather than an NCCL call.
static bool IsDegenerate(mlir::lmhlo::AllReduceOp op, int64 replica_count,
                         int64 partition_count) {
  return GetNcclCollectiveConfigForMlir(op, op.use_global_device_ids())
      .IsDegenerate(replica_count, partition_count);
}
static CollectiveOpGroupMode GetGroupMode(mlir::lmhlo::AllReduceOp op) {

View File

@@ -0,0 +1,24 @@
// RUN: hlo_to_llvm_ir %s
// Regression test: a degenerate all-reduce (single-member replica group)
// whose reduction computation is an arbitrary fusion (not a plain add) must
// still compile. See the mixed-precision add below: bf16 -> f32 -> add ->
// bf16, which does not match any builtin NCCL reduction kind.
HloModule Test

// Mixed-precision addition: converts both bf16 operands to f32, adds, and
// converts the sum back to bf16.
%fused_computation (param_0.5307: bf16[], param_1.5984: bf16[]) -> bf16[] {
  %param_1.5984 = bf16[] parameter(1)
  %convert.72239 = f32[] convert(bf16[] %param_1.5984)
  %param_0.5307 = bf16[] parameter(0)
  %convert.72238 = f32[] convert(bf16[] %param_0.5307)
  %add.3846 = f32[] add(f32[] %convert.72239, f32[] %convert.72238), metadata={op_type="add" op_name="add"}
  ROOT %convert.72237 = bf16[] convert(f32[] %add.3846)
}

// Reduction computation for the all-reduce, implemented as a kLoop fusion of
// the mixed-precision add above rather than a bare `add` instruction.
%all_reduce_computation (parameter.47449: bf16[], parameter.47450: bf16[]) -> bf16[] {
  %parameter.47450 = bf16[] parameter(1), metadata={op_type="add" op_name="add"}
  %parameter.47449 = bf16[] parameter(0), metadata={op_type="add" op_name="add"}
  ROOT %fusion.1743 = bf16[] fusion(bf16[] %parameter.47450, bf16[] %parameter.47449), kind=kLoop, calls=%fused_computation
}

ENTRY main {
  input = bf16[8]{0} parameter(0)
  // replica_groups={{0}} makes the all-reduce degenerate: each group has one
  // participant, so the op should lower to a copy, not an NCCL call.
  ROOT crs = bf16[8]{0} all-reduce(input), replica_groups={{0}}, to_apply=%all_reduce_computation
}