Merge pull request #28583 from BinFan:dot-trans-bug
PiperOrigin-RevId: 247822664
This commit is contained in:
commit
7d27d511c1
@ -1584,6 +1584,15 @@ AlgebraicSimplifierVisitor::OptimizeDotOfReorderContractingDims(
|
|||||||
lhs_contracting_dims.Add(i);
|
lhs_contracting_dims.Add(i);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
// We require the "unsquished" lhs contracting dims to be consecutive.
|
||||||
|
auto is_iota = [](absl::Span<const int64> dims) {
|
||||||
|
return absl::c_adjacent_find(dims, [](const int64 a, const int64 b) {
|
||||||
|
return (b != a + 1);
|
||||||
|
}) == dims.end();
|
||||||
|
};
|
||||||
|
if (!is_iota(AsInt64Slice(lhs_contracting_dims))) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
lhs = lhs->mutable_operand(0);
|
lhs = lhs->mutable_operand(0);
|
||||||
|
|
||||||
// Check that the transpose only permutes the contracting dims.
|
// Check that the transpose only permutes the contracting dims.
|
||||||
@ -1600,6 +1609,7 @@ AlgebraicSimplifierVisitor::OptimizeDotOfReorderContractingDims(
|
|||||||
for (auto dim : lhs_contracting_dims) {
|
for (auto dim : lhs_contracting_dims) {
|
||||||
permutation.push_back(transpose_dims[dim] - lhs_contracting_dims[0]);
|
permutation.push_back(transpose_dims[dim] - lhs_contracting_dims[0]);
|
||||||
}
|
}
|
||||||
|
CHECK(IsPermutation(permutation, permutation.size()));
|
||||||
auto new_lhs_contracting_dims =
|
auto new_lhs_contracting_dims =
|
||||||
ComposePermutations(AsInt64Slice(lhs_contracting_dims), permutation);
|
ComposePermutations(AsInt64Slice(lhs_contracting_dims), permutation);
|
||||||
lhs_contracting_dims.Clear();
|
lhs_contracting_dims.Clear();
|
||||||
|
@ -5372,10 +5372,8 @@ TEST_F(AlgebraicSimplifierTest, DotContractingReorder_SizeOneDims) {
|
|||||||
EXPECT_THAT(transpose->dimensions(), ElementsAre(0, 2, 1, 3));
|
EXPECT_THAT(transpose->dimensions(), ElementsAre(0, 2, 1, 3));
|
||||||
}
|
}
|
||||||
|
|
||||||
// This test exposes a real bug: It tries to read an out-of-bounds array index
|
|
||||||
// from within ComposePermutations(). TODO(b/132330723): Fix this.
|
|
||||||
TEST_F(AlgebraicSimplifierTest,
|
TEST_F(AlgebraicSimplifierTest,
|
||||||
DISABLED_DotContractingReorder_NoChangeInContractingDimsOrder) {
|
DotContractingReorder_NoChangeInContractingDimsOrder) {
|
||||||
// No optimization opportunity here because the transpose does not reorder the
|
// No optimization opportunity here because the transpose does not reorder the
|
||||||
// contracting dims.
|
// contracting dims.
|
||||||
const char* kModuleStr = R"(
|
const char* kModuleStr = R"(
|
||||||
|
Loading…
Reference in New Issue
Block a user