Merge pull request #28583 from BinFan:dot-trans-bug

PiperOrigin-RevId: 247822664
This commit is contained in:
TensorFlower Gardener 2019-05-12 05:15:53 -07:00
commit 7d27d511c1
2 changed files with 11 additions and 3 deletions

View File

@ -1584,6 +1584,15 @@ AlgebraicSimplifierVisitor::OptimizeDotOfReorderContractingDims(
lhs_contracting_dims.Add(i); lhs_contracting_dims.Add(i);
} }
} }
// We require the "unsquished" lhs contracting dims to be consecutive.
auto is_iota = [](absl::Span<const int64> dims) {
return absl::c_adjacent_find(dims, [](const int64 a, const int64 b) {
return (b != a + 1);
}) == dims.end();
};
if (!is_iota(AsInt64Slice(lhs_contracting_dims))) {
return nullptr;
}
lhs = lhs->mutable_operand(0); lhs = lhs->mutable_operand(0);
// Check that the transpose only permutes the contracting dims. // Check that the transpose only permutes the contracting dims.
@ -1600,6 +1609,7 @@ AlgebraicSimplifierVisitor::OptimizeDotOfReorderContractingDims(
for (auto dim : lhs_contracting_dims) { for (auto dim : lhs_contracting_dims) {
permutation.push_back(transpose_dims[dim] - lhs_contracting_dims[0]); permutation.push_back(transpose_dims[dim] - lhs_contracting_dims[0]);
} }
CHECK(IsPermutation(permutation, permutation.size()));
auto new_lhs_contracting_dims = auto new_lhs_contracting_dims =
ComposePermutations(AsInt64Slice(lhs_contracting_dims), permutation); ComposePermutations(AsInt64Slice(lhs_contracting_dims), permutation);
lhs_contracting_dims.Clear(); lhs_contracting_dims.Clear();

View File

@ -5372,10 +5372,8 @@ TEST_F(AlgebraicSimplifierTest, DotContractingReorder_SizeOneDims) {
EXPECT_THAT(transpose->dimensions(), ElementsAre(0, 2, 1, 3)); EXPECT_THAT(transpose->dimensions(), ElementsAre(0, 2, 1, 3));
} }
// This test exposes a real bug: It tries to read an out-of-bounds array index
// from within ComposePermutations(). TODO(b/132330723): Fix this.
TEST_F(AlgebraicSimplifierTest, TEST_F(AlgebraicSimplifierTest,
DISABLED_DotContractingReorder_NoChangeInContractingDimsOrder) { DotContractingReorder_NoChangeInContractingDimsOrder) {
// No optimization opportunity here because the transpose does not reorder the // No optimization opportunity here because the transpose does not reorder the
// contracting dims. // contracting dims.
const char* kModuleStr = R"( const char* kModuleStr = R"(