Improve DotDecomposer to not add unnecessary non-contracting dimensions.

These would be removed by AlgebraicSimplifier, then DotDecomposer would add
them again, which makes the HloPassFix iterate until it hits the maximum number
of iterations.
Also consider operands of dots without non-contracting dimension to be
canonical.

PiperOrigin-RevId: 313174496
Change-Id: I8e8ac404198a9df01378820ad16834c9893336a5
This commit is contained in:
Adrian Kuegel 2020-05-26 05:46:06 -07:00 committed by TensorFlower Gardener
parent e3e0ba5781
commit 86ea222104
2 changed files with 87 additions and 8 deletions

View File

@ -31,7 +31,7 @@ namespace {
// Convert a dot into a canonical form where non-contracting and contracting
// dimensions are reshaped together and batch dimensions are the most major
// dimensions. The requires transposing and reshapes the lhs and rhs and
// dimensions. This requires transposing and reshapes of the lhs and rhs and
// reshaping the output batch to the original shape.
Status CanonicalizeDot(HloInstruction* original_dot) {
auto computation = original_dot->parent();
@ -80,7 +80,9 @@ Status CanonicalizeDot(HloInstruction* original_dot) {
lhs_shape),
original_dot->mutable_operand(0), lhs_transpose));
std::vector<int64> lhs_reshape_dims = batch_dim_sizes;
lhs_reshape_dims.push_back(lhs_non_contracting_size);
if (lhs_non_contracting_size > 1) {
lhs_reshape_dims.push_back(lhs_non_contracting_size);
}
lhs_reshape_dims.push_back(lhs_contracting_size);
// Reshape the contracting and non-contracting dimensions together.
HloInstruction* reshaped_lhs =
@ -126,7 +128,9 @@ Status CanonicalizeDot(HloInstruction* original_dot) {
std::vector<int64> rhs_reshape_dims = batch_dim_sizes;
rhs_reshape_dims.push_back(rhs_contracting_size);
rhs_reshape_dims.push_back(rhs_non_contracting_size);
if (rhs_non_contracting_size > 1) {
rhs_reshape_dims.push_back(rhs_non_contracting_size);
}
// Reshape the contracting and non-contracting dimensions together.
HloInstruction* reshaped_rhs =
computation->AddInstruction(HloInstruction::CreateReshape(
@ -134,15 +138,20 @@ Status CanonicalizeDot(HloInstruction* original_dot) {
transposed_rhs));
std::vector<int64> dot_dims = batch_dim_sizes;
dot_dims.push_back(lhs_non_contracting_size);
dot_dims.push_back(rhs_non_contracting_size);
if (lhs_non_contracting_size > 1) {
dot_dims.push_back(lhs_non_contracting_size);
}
if (rhs_non_contracting_size > 1) {
dot_dims.push_back(rhs_non_contracting_size);
}
DotDimensionNumbers dot_dnums;
for (int64 i = 0; i < num_batch_dims; ++i) {
dot_dnums.add_lhs_batch_dimensions(i);
dot_dnums.add_rhs_batch_dimensions(i);
}
dot_dnums.add_lhs_contracting_dimensions(num_batch_dims + 1);
dot_dnums.add_lhs_contracting_dimensions(
num_batch_dims + (lhs_non_contracting_size > 1 ? 1 : 0));
dot_dnums.add_rhs_contracting_dimensions(num_batch_dims);
HloInstruction* dot = computation->AddInstruction(HloInstruction::CreateDot(
@ -174,9 +183,9 @@ StatusOr<bool> DotDecomposer::Run(HloModule* module) {
}
// A dot is not canonical if it has more than one non-contracting
// dimension.
if (dnums.lhs_batch_dimensions_size() + 2 !=
if (dnums.lhs_batch_dimensions_size() + 2 <
instruction->operand(0)->shape().rank() ||
dnums.rhs_batch_dimensions_size() + 2 !=
dnums.rhs_batch_dimensions_size() + 2 <
instruction->operand(1)->shape().rank()) {
non_canonical_dots.push_back(instruction);
continue;

View File

@ -50,5 +50,75 @@ TEST_F(DotDecomposerTest, CanonicalizeMultipleNonContractingDims) {
op::Shape("f32[4032,512]"))));
}
TEST_F(DotDecomposerTest, DontCanonicalizeIfNoNoncontractingDims) {
absl::string_view module_string = R"(
HloModule module
ENTRY main {
p0 = f32[64,4]{1,0} parameter(0)
p1 = f32[64,4]{1,0} parameter(1)
ROOT dot = f32[64]{0} dot(p0, p1), lhs_batch_dims={0},
lhs_contracting_dims={1},
rhs_batch_dims={0},
rhs_contracting_dims={1}
})";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(module_string));
TF_ASSERT_OK_AND_ASSIGN(bool canonicalized,
DotDecomposer().Run(module.get()));
EXPECT_FALSE(canonicalized);
}
TEST_F(DotDecomposerTest, DontAddLhsNonContractingDimIfOne) {
absl::string_view module_string = R"(
HloModule module
ENTRY main {
p0 = f32[64,4]{1,0} parameter(0)
p1 = f32[64,4,2,1]{3,2,1,0} parameter(1)
ROOT dot = f32[64,2,1]{2,1,0} dot(p0, p1), lhs_batch_dims={0},
lhs_contracting_dims={1},
rhs_batch_dims={0},
rhs_contracting_dims={1}
})";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(module_string));
TF_ASSERT_OK_AND_ASSIGN(bool canonicalized,
DotDecomposer().Run(module.get()));
EXPECT_TRUE(canonicalized);
EXPECT_THAT(module->entry_computation()->root_instruction(),
op::Reshape(AllOf(op::Dot(op::Reshape(), op::Reshape(),
/*lhs_contracting_dim=*/1,
/*rhs_contracting_dim=*/1),
op::Shape("f32[64,2]"))));
}
TEST_F(DotDecomposerTest, DontAddRhsNonContractingDimIfOne) {
absl::string_view module_string = R"(
HloModule module
ENTRY main {
p0 = f32[64,4,2,1]{3,2,1,0} parameter(0)
p1 = f32[64,4]{1,0} parameter(1)
ROOT dot = f32[64,2,1]{2,1,0} dot(p0, p1), lhs_batch_dims={0},
lhs_contracting_dims={1},
rhs_batch_dims={0},
rhs_contracting_dims={1}
})";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(module_string));
TF_ASSERT_OK_AND_ASSIGN(bool canonicalized,
DotDecomposer().Run(module.get()));
EXPECT_TRUE(canonicalized);
EXPECT_THAT(module->entry_computation()->root_instruction(),
op::Reshape(AllOf(op::Dot(op::Reshape(), op::Reshape(),
/*lhs_contracting_dim=*/2,
/*rhs_contracting_dim=*/1),
op::Shape("f32[64,2]"))));
}
} // namespace
} // namespace xla