diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index 649ee116b4d..3283963a347 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -1581,6 +1581,15 @@ AlgebraicSimplifierVisitor::OptimizeDotOfReorderContractingDims( lhs_contracting_dims.Add(i); } } + // We require the "unsquished" lhs contracting dims to be consecutive. + auto is_iota = [](absl::Span dims) { + return absl::c_adjacent_find(dims, [](const int64 a, const int64 b) { + return (b != a + 1); + }) == dims.end(); + }; + if (!is_iota(AsInt64Slice(lhs_contracting_dims))) { + return nullptr; + } lhs = lhs->mutable_operand(0); // Check that the transpose only permutes the contracting dims. @@ -1597,6 +1606,7 @@ AlgebraicSimplifierVisitor::OptimizeDotOfReorderContractingDims( for (auto dim : lhs_contracting_dims) { permutation.push_back(transpose_dims[dim] - lhs_contracting_dims[0]); } + CHECK(IsPermutation(permutation, permutation.size())); auto new_lhs_contracting_dims = ComposePermutations(AsInt64Slice(lhs_contracting_dims), permutation); lhs_contracting_dims.Clear(); diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index 70a6b0161cb..34e0eb6008a 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -5372,10 +5372,8 @@ TEST_F(AlgebraicSimplifierTest, DotContractingReorder_SizeOneDims) { EXPECT_THAT(transpose->dimensions(), ElementsAre(0, 2, 1, 3)); } -// This test exposes a real bug: It tries to read an out-of-bounds array index -// from within ComposePermutations(). TODO(b/132330723): Fix this. TEST_F(AlgebraicSimplifierTest, - DISABLED_DotContractingReorder_NoChangeInContractingDimsOrder) { + DotContractingReorder_NoChangeInContractingDimsOrder) { // No optimization opportunity here because the transpose does not reorder the // contracting dims. const char* kModuleStr = R"(