[XLA:SPMD] Fix all-gather transpose when there are more than one sharded dimensions.

PiperOrigin-RevId: 313465338
Change-Id: I9c8a2763dea5dbbf1c40e114c8b0b2f25aa9c941
This commit is contained in:
Yuanzhong Xu 2020-05-27 14:33:20 -07:00 committed by TensorFlower Gardener
parent 4eae0941b7
commit 0ef1057c2d

View File

@ -4605,8 +4605,8 @@ HloInstruction* SpmdPartitioner::AllGatherShards(SpmdBuilder* b,
xpose_permutation[i] = i + tiled_dims.size() - split_dims_added;
} else {
xpose_permutation[i] = split_dims_added;
xpose_permutation[i + 1] = i + tiled_dims.size() - split_dims_added;
split_dims_added++;
xpose_permutation[i + 1] = i + tiled_dims.size();
i++;
}
}