Merge pull request #28583 from BinFan:dot-trans-bug
PiperOrigin-RevId: 247822664
This commit is contained in:
commit
7d27d511c1
@ -1584,6 +1584,15 @@ AlgebraicSimplifierVisitor::OptimizeDotOfReorderContractingDims(
|
||||
lhs_contracting_dims.Add(i);
|
||||
}
|
||||
}
|
||||
// We require the "unsquished" lhs contracting dims to be consecutive.
|
||||
auto is_iota = [](absl::Span<const int64> dims) {
|
||||
return absl::c_adjacent_find(dims, [](const int64 a, const int64 b) {
|
||||
return (b != a + 1);
|
||||
}) == dims.end();
|
||||
};
|
||||
if (!is_iota(AsInt64Slice(lhs_contracting_dims))) {
|
||||
return nullptr;
|
||||
}
|
||||
lhs = lhs->mutable_operand(0);
|
||||
|
||||
// Check that the transpose only permutes the contracting dims.
|
||||
@ -1600,6 +1609,7 @@ AlgebraicSimplifierVisitor::OptimizeDotOfReorderContractingDims(
|
||||
for (auto dim : lhs_contracting_dims) {
|
||||
permutation.push_back(transpose_dims[dim] - lhs_contracting_dims[0]);
|
||||
}
|
||||
CHECK(IsPermutation(permutation, permutation.size()));
|
||||
auto new_lhs_contracting_dims =
|
||||
ComposePermutations(AsInt64Slice(lhs_contracting_dims), permutation);
|
||||
lhs_contracting_dims.Clear();
|
||||
|
@ -5372,10 +5372,8 @@ TEST_F(AlgebraicSimplifierTest, DotContractingReorder_SizeOneDims) {
|
||||
EXPECT_THAT(transpose->dimensions(), ElementsAre(0, 2, 1, 3));
|
||||
}
|
||||
|
||||
// This test exposes a real bug: It tries to read an out-of-bounds array index
|
||||
// from within ComposePermutations(). TODO(b/132330723): Fix this.
|
||||
TEST_F(AlgebraicSimplifierTest,
|
||||
DISABLED_DotContractingReorder_NoChangeInContractingDimsOrder) {
|
||||
DotContractingReorder_NoChangeInContractingDimsOrder) {
|
||||
// No optimization opportunity here because the transpose does not reorder the
|
||||
// contracting dims.
|
||||
const char* kModuleStr = R"(
|
||||
|
Loading…
Reference in New Issue
Block a user