Fix a bug in algebraic simplification pass.

In the dot contracting dimensions reorder optimization, check if
the reshape op squishes consecutive dimensions, and bail out if it
is not the case.
This commit is contained in:
Bin Fan 2019-05-09 21:56:00 -07:00
parent 8a8a109e56
commit a4778b0138
2 changed files with 11 additions and 3 deletions

View File

@ -1581,6 +1581,15 @@ AlgebraicSimplifierVisitor::OptimizeDotOfReorderContractingDims(
lhs_contracting_dims.Add(i);
}
}
// We require the "unsquished" lhs contracting dims to be consecutive.
auto is_iota = [](absl::Span<const int64> dims) {
return absl::c_adjacent_find(dims, [](const int64 a, const int64 b) {
return (b != a + 1);
}) == dims.end();
};
if (!is_iota(AsInt64Slice(lhs_contracting_dims))) {
return nullptr;
}
lhs = lhs->mutable_operand(0);
// Check that the transpose only permutes the contracting dims.
@ -1597,6 +1606,7 @@ AlgebraicSimplifierVisitor::OptimizeDotOfReorderContractingDims(
for (auto dim : lhs_contracting_dims) {
permutation.push_back(transpose_dims[dim] - lhs_contracting_dims[0]);
}
CHECK(IsPermutation(permutation, permutation.size()));
auto new_lhs_contracting_dims =
ComposePermutations(AsInt64Slice(lhs_contracting_dims), permutation);
lhs_contracting_dims.Clear();

View File

@ -5372,10 +5372,8 @@ TEST_F(AlgebraicSimplifierTest, DotContractingReorder_SizeOneDims) {
EXPECT_THAT(transpose->dimensions(), ElementsAre(0, 2, 1, 3));
}
// This test exposes a real bug: It tries to read an out-of-bounds array index
// from within ComposePermutations(). TODO(b/132330723): Fix this.
TEST_F(AlgebraicSimplifierTest,
DISABLED_DotContractingReorder_NoChangeInContractingDimsOrder) {
DotContractingReorder_NoChangeInContractingDimsOrder) {
// No optimization opportunity here because the transpose does not reorder the
// contracting dims.
const char* kModuleStr = R"(