Minor fix of bidirectional looped einsum when sharded along more than one dim.
PiperOrigin-RevId: 357311600 Change-Id: Ia449edeb63e16531cd2389f7d7e97764f74dedeb
This commit is contained in:
parent
d38d28f3bf
commit
8a1c8335ed
@ -418,10 +418,12 @@ StatusOr<HloInstruction*> PartitionBaseCase(
|
||||
? &*lhs_sharding_transposed_to_match_output
|
||||
: &*rhs_sharding_transposed_to_match_output;
|
||||
}
|
||||
CHECK_EQ(Product(slice_sharding->tile_assignment().dimensions()),
|
||||
num_partitions);
|
||||
int64 slice_sharding_dim = -1;
|
||||
for (int64 i = 0; i < slice_sharding->tile_assignment().num_dimensions();
|
||||
++i) {
|
||||
if (slice_sharding->tile_assignment().dim(i) == num_partitions) {
|
||||
if (slice_sharding->tile_assignment().dim(i) > 1) {
|
||||
slice_sharding_dim = i;
|
||||
break;
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user