Improve DotDecomposer to not add unnecessary non-contracting dimensions.
These would be removed by AlgebraicSimplifier, then DotDecomposer would add them again, which makes the HloPassFix iterate until it hits the maximum number of iterations. Also consider operands of dots without non-contracting dimensions to be canonical. PiperOrigin-RevId: 313174496 Change-Id: I8e8ac404198a9df01378820ad16834c9893336a5
This commit is contained in:
parent
e3e0ba5781
commit
86ea222104
@ -31,7 +31,7 @@ namespace {
|
||||
|
||||
// Convert a dot into a canonical form where non-contracting and contracting
|
||||
// dimensions are reshaped together and batch dimensions are the most major
|
||||
// dimensions. The requires transposing and reshapes the lhs and rhs and
|
||||
// dimensions. This requires transposing and reshapes of the lhs and rhs and
|
||||
// reshaping the output batch to the original shape.
|
||||
Status CanonicalizeDot(HloInstruction* original_dot) {
|
||||
auto computation = original_dot->parent();
|
||||
@ -80,7 +80,9 @@ Status CanonicalizeDot(HloInstruction* original_dot) {
|
||||
lhs_shape),
|
||||
original_dot->mutable_operand(0), lhs_transpose));
|
||||
std::vector<int64> lhs_reshape_dims = batch_dim_sizes;
|
||||
lhs_reshape_dims.push_back(lhs_non_contracting_size);
|
||||
if (lhs_non_contracting_size > 1) {
|
||||
lhs_reshape_dims.push_back(lhs_non_contracting_size);
|
||||
}
|
||||
lhs_reshape_dims.push_back(lhs_contracting_size);
|
||||
// Reshape the contracting and non-contracting dimensions together.
|
||||
HloInstruction* reshaped_lhs =
|
||||
@ -126,7 +128,9 @@ Status CanonicalizeDot(HloInstruction* original_dot) {
|
||||
|
||||
std::vector<int64> rhs_reshape_dims = batch_dim_sizes;
|
||||
rhs_reshape_dims.push_back(rhs_contracting_size);
|
||||
rhs_reshape_dims.push_back(rhs_non_contracting_size);
|
||||
if (rhs_non_contracting_size > 1) {
|
||||
rhs_reshape_dims.push_back(rhs_non_contracting_size);
|
||||
}
|
||||
// Reshape the contracting and non-contracting dimensions together.
|
||||
HloInstruction* reshaped_rhs =
|
||||
computation->AddInstruction(HloInstruction::CreateReshape(
|
||||
@ -134,15 +138,20 @@ Status CanonicalizeDot(HloInstruction* original_dot) {
|
||||
transposed_rhs));
|
||||
|
||||
std::vector<int64> dot_dims = batch_dim_sizes;
|
||||
dot_dims.push_back(lhs_non_contracting_size);
|
||||
dot_dims.push_back(rhs_non_contracting_size);
|
||||
if (lhs_non_contracting_size > 1) {
|
||||
dot_dims.push_back(lhs_non_contracting_size);
|
||||
}
|
||||
if (rhs_non_contracting_size > 1) {
|
||||
dot_dims.push_back(rhs_non_contracting_size);
|
||||
}
|
||||
|
||||
DotDimensionNumbers dot_dnums;
|
||||
for (int64 i = 0; i < num_batch_dims; ++i) {
|
||||
dot_dnums.add_lhs_batch_dimensions(i);
|
||||
dot_dnums.add_rhs_batch_dimensions(i);
|
||||
}
|
||||
dot_dnums.add_lhs_contracting_dimensions(num_batch_dims + 1);
|
||||
dot_dnums.add_lhs_contracting_dimensions(
|
||||
num_batch_dims + (lhs_non_contracting_size > 1 ? 1 : 0));
|
||||
dot_dnums.add_rhs_contracting_dimensions(num_batch_dims);
|
||||
|
||||
HloInstruction* dot = computation->AddInstruction(HloInstruction::CreateDot(
|
||||
@ -174,9 +183,9 @@ StatusOr<bool> DotDecomposer::Run(HloModule* module) {
|
||||
}
|
||||
// A dot is not canonical if it has more than one non-contracting
|
||||
// dimension.
|
||||
if (dnums.lhs_batch_dimensions_size() + 2 !=
|
||||
if (dnums.lhs_batch_dimensions_size() + 2 <
|
||||
instruction->operand(0)->shape().rank() ||
|
||||
dnums.rhs_batch_dimensions_size() + 2 !=
|
||||
dnums.rhs_batch_dimensions_size() + 2 <
|
||||
instruction->operand(1)->shape().rank()) {
|
||||
non_canonical_dots.push_back(instruction);
|
||||
continue;
|
||||
|
||||
@ -50,5 +50,75 @@ TEST_F(DotDecomposerTest, CanonicalizeMultipleNonContractingDims) {
|
||||
op::Shape("f32[4032,512]"))));
|
||||
}
|
||||
|
||||
// Checks that a dot whose operands consist only of a batch dimension and a
// contracting dimension (i.e. neither side has a non-contracting dimension)
// is left alone by DotDecomposer.
TEST_F(DotDecomposerTest, DontCanonicalizeIfNoNoncontractingDims) {
|
||||
absl::string_view module_string = R"(
|
||||
HloModule module
|
||||
|
||||
ENTRY main {
|
||||
p0 = f32[64,4]{1,0} parameter(0)
|
||||
p1 = f32[64,4]{1,0} parameter(1)
|
||||
ROOT dot = f32[64]{0} dot(p0, p1), lhs_batch_dims={0},
|
||||
lhs_contracting_dims={1},
|
||||
rhs_batch_dims={0},
|
||||
rhs_contracting_dims={1}
|
||||
})";
|
||||
|
||||
// Parse and verify the HLO text above, then run the pass once on it.
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
||||
ParseAndReturnVerifiedModule(module_string));
|
||||
TF_ASSERT_OK_AND_ASSIGN(bool canonicalized,
|
||||
DotDecomposer().Run(module.get()));
|
||||
// The returned flag is expected to be false: both operands have rank 2 with
// one batch and one contracting dimension, so the dot counts as canonical.
EXPECT_FALSE(canonicalized);
|
||||
}
|
||||
|
||||
// Checks that when the lhs has no non-contracting dimension (only batch and
// contracting) while the rhs has two non-contracting dimensions ([2,1]), the
// rewritten dot does not gain an extra lhs non-contracting dimension.
TEST_F(DotDecomposerTest, DontAddLhsNonContractingDimIfOne) {
|
||||
absl::string_view module_string = R"(
|
||||
HloModule module
|
||||
|
||||
ENTRY main {
|
||||
p0 = f32[64,4]{1,0} parameter(0)
|
||||
p1 = f32[64,4,2,1]{3,2,1,0} parameter(1)
|
||||
ROOT dot = f32[64,2,1]{2,1,0} dot(p0, p1), lhs_batch_dims={0},
|
||||
lhs_contracting_dims={1},
|
||||
rhs_batch_dims={0},
|
||||
rhs_contracting_dims={1}
|
||||
})";
|
||||
|
||||
// Parse and verify the HLO text above, then run the pass once on it.
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
||||
ParseAndReturnVerifiedModule(module_string));
|
||||
TF_ASSERT_OK_AND_ASSIGN(bool canonicalized,
|
||||
DotDecomposer().Run(module.get()));
|
||||
// The rhs has more than one non-contracting dimension, so a rewrite is
// expected here.
EXPECT_TRUE(canonicalized);
|
||||
// The new root is a reshape back to the original output shape, wrapping a dot
// of two reshaped operands. The lhs contracting dimension sits at index 1,
// directly after the batch dimension, and the dot result is f32[64,2]
// (batch 64, rhs non-contracting 2*1 = 2) with no size-1 lhs dimension.
EXPECT_THAT(module->entry_computation()->root_instruction(),
|
||||
op::Reshape(AllOf(op::Dot(op::Reshape(), op::Reshape(),
|
||||
/*lhs_contracting_dim=*/1,
|
||||
/*rhs_contracting_dim=*/1),
|
||||
op::Shape("f32[64,2]"))));
|
||||
}
|
||||
|
||||
// Mirror image of the lhs case: the rhs has no non-contracting dimension
// (only batch and contracting) while the lhs has two non-contracting
// dimensions ([2,1]). Checks that the rewritten dot does not gain an extra
// rhs non-contracting dimension.
TEST_F(DotDecomposerTest, DontAddRhsNonContractingDimIfOne) {
|
||||
absl::string_view module_string = R"(
|
||||
HloModule module
|
||||
|
||||
ENTRY main {
|
||||
p0 = f32[64,4,2,1]{3,2,1,0} parameter(0)
|
||||
p1 = f32[64,4]{1,0} parameter(1)
|
||||
ROOT dot = f32[64,2,1]{2,1,0} dot(p0, p1), lhs_batch_dims={0},
|
||||
lhs_contracting_dims={1},
|
||||
rhs_batch_dims={0},
|
||||
rhs_contracting_dims={1}
|
||||
})";
|
||||
|
||||
// Parse and verify the HLO text above, then run the pass once on it.
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
||||
ParseAndReturnVerifiedModule(module_string));
|
||||
TF_ASSERT_OK_AND_ASSIGN(bool canonicalized,
|
||||
DotDecomposer().Run(module.get()));
|
||||
// The lhs has more than one non-contracting dimension, so a rewrite is
// expected here.
EXPECT_TRUE(canonicalized);
|
||||
// The new root is a reshape back to the original output shape, wrapping a dot
// of two reshaped operands. The lhs contracting dimension is at index 2
// (after batch and the merged lhs non-contracting dimension), the rhs
// contracting dimension is at index 1, and the dot result is f32[64,2] with
// no size-1 rhs dimension.
EXPECT_THAT(module->entry_computation()->root_instruction(),
|
||||
op::Reshape(AllOf(op::Dot(op::Reshape(), op::Reshape(),
|
||||
/*lhs_contracting_dim=*/2,
|
||||
/*rhs_contracting_dim=*/1),
|
||||
op::Shape("f32[64,2]"))));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace xla
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user