[XLA:SPMD] Fix all-gather transpose when there are more than one sharded dimensions.
PiperOrigin-RevId: 313465338 Change-Id: I9c8a2763dea5dbbf1c40e114c8b0b2f25aa9c941
This commit is contained in:
parent
4eae0941b7
commit
0ef1057c2d
@ -4605,8 +4605,8 @@ HloInstruction* SpmdPartitioner::AllGatherShards(SpmdBuilder* b,
|
||||
xpose_permutation[i] = i + tiled_dims.size() - split_dims_added;
|
||||
} else {
|
||||
xpose_permutation[i] = split_dims_added;
|
||||
xpose_permutation[i + 1] = i + tiled_dims.size() - split_dims_added;
|
||||
split_dims_added++;
|
||||
xpose_permutation[i + 1] = i + tiled_dims.size();
|
||||
i++;
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user