Minor fix of bidirectional looped einsum when sharded along more than one dim.

PiperOrigin-RevId: 357311600
Change-Id: Ia449edeb63e16531cd2389f7d7e97764f74dedeb
This commit is contained in:
A. Unique TensorFlower 2021-02-12 19:55:25 -08:00 committed by TensorFlower Gardener
parent d38d28f3bf
commit 8a1c8335ed

View File

@ -418,10 +418,12 @@ StatusOr<HloInstruction*> PartitionBaseCase(
? &*lhs_sharding_transposed_to_match_output
: &*rhs_sharding_transposed_to_match_output;
}
CHECK_EQ(Product(slice_sharding->tile_assignment().dimensions()),
num_partitions);
int64 slice_sharding_dim = -1;
for (int64 i = 0; i < slice_sharding->tile_assignment().num_dimensions();
++i) {
if (slice_sharding->tile_assignment().dim(i) == num_partitions) {
if (slice_sharding->tile_assignment().dim(i) > 1) {
slice_sharding_dim = i;
break;
}