From 8a1c8335edbedb275aef564308577c80a9066ccf Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Fri, 12 Feb 2021 19:55:25 -0800
Subject: [PATCH] Minor fix of bidirectional looped einsum when sharded along
 more than one dim.

PiperOrigin-RevId: 357311600
Change-Id: Ia449edeb63e16531cd2389f7d7e97764f74dedeb
---
 tensorflow/compiler/xla/service/spmd/dot_handler.cc | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/tensorflow/compiler/xla/service/spmd/dot_handler.cc b/tensorflow/compiler/xla/service/spmd/dot_handler.cc
index 85ebc3118de..85c2719c222 100644
--- a/tensorflow/compiler/xla/service/spmd/dot_handler.cc
+++ b/tensorflow/compiler/xla/service/spmd/dot_handler.cc
@@ -418,10 +418,12 @@ StatusOr<HloInstruction*> PartitionBaseCase(
             ? &*lhs_sharding_transposed_to_match_output
             : &*rhs_sharding_transposed_to_match_output;
   }
+  CHECK_EQ(Product(slice_sharding->tile_assignment().dimensions()),
+           num_partitions);
   int64 slice_sharding_dim = -1;
   for (int64 i = 0; i < slice_sharding->tile_assignment().num_dimensions();
        ++i) {
-    if (slice_sharding->tile_assignment().dim(i) == num_partitions) {
+    if (slice_sharding->tile_assignment().dim(i) > 1) {
       slice_sharding_dim = i;
       break;
     }
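
Note on the fix, for readers tracing the logic: the old loop searched for a tile-assignment
dimension whose size equals num_partitions, which never matches when the sharding spans more
than one dimension (e.g. a {2, 4} tile assignment over 8 partitions), so slice_sharding_dim
stayed at -1. The patch relaxes the test to any dimension of size > 1 and adds a CHECK_EQ
asserting that the tile assignment's total size is exactly num_partitions, so a non-trivial
dimension cannot come from replication. The following is a minimal standalone sketch of the
before/after selection; FindSliceShardingDim and the sample values are illustrative, not the
XLA implementation.

  // Sketch: old vs. new selection of the sharded dimension, assuming the
  // tile assignment's dimension sizes are given as a plain vector.
  #include <cstdint>
  #include <iostream>
  #include <vector>

  int64_t FindSliceShardingDim(const std::vector<int64_t>& tile_dims,
                               int64_t num_partitions, bool old_logic) {
    for (int64_t i = 0; i < static_cast<int64_t>(tile_dims.size()); ++i) {
      // Old logic: matches only when a single dim holds all partitions.
      // New logic: any dim of size > 1 qualifies, guarded in the real code
      // by CHECK_EQ(product of dims, num_partitions).
      if (old_logic ? tile_dims[i] == num_partitions : tile_dims[i] > 1) {
        return i;
      }
    }
    return -1;  // no sharded dim found
  }

  int main() {
    const std::vector<int64_t> tile_dims = {2, 4};  // sharded along two dims
    const int64_t num_partitions = 8;               // 2 * 4
    std::cout << FindSliceShardingDim(tile_dims, num_partitions, true)
              << "\n";  // -1: no single dim equals 8 (the old bug)
    std::cout << FindSliceShardingDim(tile_dims, num_partitions, false)
              << "\n";  // 0: first dim with size > 1 (the fixed behavior)
    return 0;
  }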