[XLA:GPU] Fix all-reduce when it's degenerate and has arbitrary reduction computation.
- Do not call GetNcclAllReduceConfig() when checking whether the all-reduce is degenerate, as that can fail if the reduction computation does not match a known reduction kind; use GetNcclCollectiveConfigForMlir() instead.
- This allows an otherwise degenerate all-reduce with an arbitrary reduction computation to be compiled without failures (it is mapped to just a copy).

PiperOrigin-RevId: 359370850
Change-Id: I6d79c7f3dc75a438308676a22995c69105d748ad
This commit is contained in:
parent
e407b99940
commit
e30cc0bfb0
tensorflow/compiler/xla/service/gpu
@ -46,8 +46,8 @@ class NcclAllReduceThunk : public NcclCollectiveThunk {
|
||||
|
||||
static bool IsDegenerate(mlir::lmhlo::AllReduceOp op, int64 replica_count,
|
||||
int64 partition_count) {
|
||||
return GetNcclAllReduceConfig(op).config.IsDegenerate(replica_count,
|
||||
partition_count);
|
||||
return GetNcclCollectiveConfigForMlir(op, op.use_global_device_ids())
|
||||
.IsDegenerate(replica_count, partition_count);
|
||||
}
|
||||
|
||||
static CollectiveOpGroupMode GetGroupMode(mlir::lmhlo::AllReduceOp op) {
|
||||
|
New file: tensorflow/compiler/xla/service/gpu/tests/all_reduce.hlo (24 lines)
@ -0,0 +1,24 @@
|
||||
// RUN: hlo_to_llvm_ir %s

// Regression test: a degenerate all-reduce (replica_groups of size one) whose
// reduction computation is an arbitrary fusion — not a recognized NCCL
// reduction kind — must still compile (it lowers to a copy).
HloModule Test

%fused_computation (param_0.5307: bf16[], param_1.5984: bf16[]) -> bf16[] {
  %param_1.5984 = bf16[] parameter(1)
  %convert.72239 = f32[] convert(bf16[] %param_1.5984)
  %param_0.5307 = bf16[] parameter(0)
  %convert.72238 = f32[] convert(bf16[] %param_0.5307)
  %add.3846 = f32[] add(f32[] %convert.72239, f32[] %convert.72238), metadata={op_type="add" op_name="add"}
  ROOT %convert.72237 = bf16[] convert(f32[] %add.3846)
}

%all_reduce_computation (parameter.47449: bf16[], parameter.47450: bf16[]) -> bf16[] {
  %parameter.47450 = bf16[] parameter(1), metadata={op_type="add" op_name="add"}
  %parameter.47449 = bf16[] parameter(0), metadata={op_type="add" op_name="add"}
  ROOT %fusion.1743 = bf16[] fusion(bf16[] %parameter.47450, bf16[] %parameter.47449), kind=kLoop, calls=%fused_computation
}

ENTRY main {
  input = bf16[8]{0} parameter(0)
  ROOT crs = bf16[8]{0} all-reduce(input), replica_groups={{0}}, to_apply=%all_reduce_computation
}
|
||||
|
Loading…
Reference in New Issue
Block a user