From 0ef1057c2d0850f3380b90d64f1daccce82f0a7c Mon Sep 17 00:00:00 2001 From: Yuanzhong Xu Date: Wed, 27 May 2020 14:33:20 -0700 Subject: [PATCH] [XLA:SPMD] Fix all-gather transpose when there are more than one sharded dimensions. PiperOrigin-RevId: 313465338 Change-Id: I9c8a2763dea5dbbf1c40e114c8b0b2f25aa9c941 --- tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc b/tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc index eb0a9c330c3..068442ad5c7 100644 --- a/tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc +++ b/tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc @@ -4605,8 +4605,8 @@ HloInstruction* SpmdPartitioner::AllGatherShards(SpmdBuilder* b, xpose_permutation[i] = i + tiled_dims.size() - split_dims_added; } else { xpose_permutation[i] = split_dims_added; + xpose_permutation[i + 1] = i + tiled_dims.size() - split_dims_added; split_dims_added++; - xpose_permutation[i + 1] = i + tiled_dims.size(); i++; } }