diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index 2441e64f3d0..c6626201e3f 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -1584,6 +1584,15 @@ AlgebraicSimplifierVisitor::OptimizeDotOfReorderContractingDims( lhs_contracting_dims.Add(i); } } + // We require the "unsquished" lhs contracting dims to be consecutive. + auto is_iota = [](absl::Span dims) { + return absl::c_adjacent_find(dims, [](const int64 a, const int64 b) { + return (b != a + 1); + }) == dims.end(); + }; + if (!is_iota(AsInt64Slice(lhs_contracting_dims))) { + return nullptr; + } lhs = lhs->mutable_operand(0); // Check that the transpose only permutes the contracting dims. @@ -1600,6 +1609,7 @@ AlgebraicSimplifierVisitor::OptimizeDotOfReorderContractingDims( for (auto dim : lhs_contracting_dims) { permutation.push_back(transpose_dims[dim] - lhs_contracting_dims[0]); } + CHECK(IsPermutation(permutation, permutation.size())); auto new_lhs_contracting_dims = ComposePermutations(AsInt64Slice(lhs_contracting_dims), permutation); lhs_contracting_dims.Clear(); diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index fee95ae7e44..e37b69c5cba 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -5372,10 +5372,8 @@ TEST_F(AlgebraicSimplifierTest, DotContractingReorder_SizeOneDims) { EXPECT_THAT(transpose->dimensions(), ElementsAre(0, 2, 1, 3)); } -// This test exposes a real bug: It tries to read an out-of-bounds array index -// from within ComposePermutations(). TODO(b/132330723): Fix this. TEST_F(AlgebraicSimplifierTest, - DISABLED_DotContractingReorder_NoChangeInContractingDimsOrder) { + DotContractingReorder_NoChangeInContractingDimsOrder) { // No optimization opportunity here because the transpose does not reorder the // contracting dims. const char* kModuleStr = R"(