Fix a bug in algebraic simplification pass.
In the dot contracting dimensions reorder optimization, check if the reshape op squishes consecutive dimensions, and bail out if it is not the case.
This commit is contained in:
parent
8a8a109e56
commit
a4778b0138
@ -1581,6 +1581,15 @@ AlgebraicSimplifierVisitor::OptimizeDotOfReorderContractingDims(
|
|||||||
lhs_contracting_dims.Add(i);
|
lhs_contracting_dims.Add(i);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
// We require the "unsquished" lhs contracting dims to be consecutive.
|
||||||
|
auto is_iota = [](absl::Span<const int64> dims) {
|
||||||
|
return absl::c_adjacent_find(dims, [](const int64 a, const int64 b) {
|
||||||
|
return (b != a + 1);
|
||||||
|
}) == dims.end();
|
||||||
|
};
|
||||||
|
if (!is_iota(AsInt64Slice(lhs_contracting_dims))) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
lhs = lhs->mutable_operand(0);
|
lhs = lhs->mutable_operand(0);
|
||||||
|
|
||||||
// Check that the transpose only permutes the contracting dims.
|
// Check that the transpose only permutes the contracting dims.
|
||||||
@ -1597,6 +1606,7 @@ AlgebraicSimplifierVisitor::OptimizeDotOfReorderContractingDims(
|
|||||||
for (auto dim : lhs_contracting_dims) {
|
for (auto dim : lhs_contracting_dims) {
|
||||||
permutation.push_back(transpose_dims[dim] - lhs_contracting_dims[0]);
|
permutation.push_back(transpose_dims[dim] - lhs_contracting_dims[0]);
|
||||||
}
|
}
|
||||||
|
CHECK(IsPermutation(permutation, permutation.size()));
|
||||||
auto new_lhs_contracting_dims =
|
auto new_lhs_contracting_dims =
|
||||||
ComposePermutations(AsInt64Slice(lhs_contracting_dims), permutation);
|
ComposePermutations(AsInt64Slice(lhs_contracting_dims), permutation);
|
||||||
lhs_contracting_dims.Clear();
|
lhs_contracting_dims.Clear();
|
||||||
|
@ -5372,10 +5372,8 @@ TEST_F(AlgebraicSimplifierTest, DotContractingReorder_SizeOneDims) {
|
|||||||
EXPECT_THAT(transpose->dimensions(), ElementsAre(0, 2, 1, 3));
|
EXPECT_THAT(transpose->dimensions(), ElementsAre(0, 2, 1, 3));
|
||||||
}
|
}
|
||||||
|
|
||||||
// This test exposes a real bug: It tries to read an out-of-bounds array index
|
|
||||||
// from within ComposePermutations(). TODO(b/132330723): Fix this.
|
|
||||||
TEST_F(AlgebraicSimplifierTest,
|
TEST_F(AlgebraicSimplifierTest,
|
||||||
DISABLED_DotContractingReorder_NoChangeInContractingDimsOrder) {
|
DotContractingReorder_NoChangeInContractingDimsOrder) {
|
||||||
// No optimization opportunity here because the transpose does not reorder the
|
// No optimization opportunity here because the transpose does not reorder the
|
||||||
// contracting dims.
|
// contracting dims.
|
||||||
const char* kModuleStr = R"(
|
const char* kModuleStr = R"(
|
||||||
|
Loading…
Reference in New Issue
Block a user