[XLA] Teach TransposeFolding to fold dots with batch dimensions

Check that the transpose doesn't touch any batch dims, then fold it away.

PiperOrigin-RevId: 336096503
Change-Id: Id9811b787b48e0c665eef6396be16c7fd255043e
This commit is contained in:
Benjamin Kramer 2020-10-08 09:05:05 -07:00 committed by TensorFlower Gardener
parent 886229e977
commit 19db1548e8
2 changed files with 120 additions and 13 deletions

View File

@ -35,17 +35,46 @@ TransposeFolding::OperandIndices CanFoldOperandsIntoDot(
const HloInstruction& dot,
const TransposeFolding::TransposableGemmOperandsFn&
transposable_gemm_operands) {
if (HloOpcode::kDot != dot.opcode() ||
dot.dot_dimension_numbers().lhs_batch_dimensions_size() != 0) {
if (HloOpcode::kDot != dot.opcode()) {
return {};
}
if (!absl::c_equal(dot.dot_dimension_numbers().lhs_batch_dimensions(),
dot.dot_dimension_numbers().rhs_batch_dimensions())) {
return {};
}
int64 num_batch_dims =
dot.dot_dimension_numbers().lhs_batch_dimensions_size();
int64 expected_rank = 2 + num_batch_dims;
auto is_r2_transpose = [&](const HloInstruction& transpose) {
if (transpose.opcode() != HloOpcode::kTranspose) {
return false;
}
const auto& transpose_dims = transpose.dimensions();
if (transpose_dims.size() != expected_rank) {
return false;
}
// Check that the transpose doesn't touch any batch dimensions, but does
// transpose the non-batch ones.
for (int64 i = 0; i != expected_rank; ++i) {
bool is_batch = absl::c_linear_search(
dot.dot_dimension_numbers().lhs_batch_dimensions(),
transpose_dims[i]);
if ((transpose_dims[i] == i) != is_batch) {
return false;
}
}
return true;
};
TransposeFolding::OperandIndices operand_set;
for (int64 i = 0; i < dot.operand_count(); ++i) {
auto& operand = *dot.operand(i);
if (operand.IsRank2Transpose()) {
if (is_r2_transpose(operand)) {
operand_set.push_back(i);
} else if (operand.shape().rank() != 2) {
} else if (operand.shape().rank() != expected_rank) {
return {};
}
}
@ -84,25 +113,25 @@ Status FoldTransposeIntoDot(InstructionOperandsPair pair) {
HloInstruction* new_lhs = dot->mutable_operand(0);
HloInstruction* new_rhs = dot->mutable_operand(1);
CHECK_EQ(new_dim_numbers.lhs_batch_dimensions_size(), 0);
CHECK_EQ(new_dim_numbers.rhs_batch_dimensions_size(), 0);
CHECK_EQ(new_dim_numbers.lhs_contracting_dimensions_size(), 1);
CHECK_EQ(new_dim_numbers.rhs_contracting_dimensions_size(), 1);
for (int64 operand_index : pair.second) {
// We've checked that there aren't any batch dimensions and that the inputs
// are rank 2, and shape inference guarantees that there is exactly one
// contracting dimension.
// We checked that the batch dimensions are not touched by the transpose,
// and shape inference guarantees that there is exactly one contracting
// dimension.
if (operand_index == 0) {
CHECK_EQ(new_lhs->opcode(), HloOpcode::kTranspose);
new_dim_numbers.set_lhs_contracting_dimensions(
0, 1 - new_dim_numbers.lhs_contracting_dimensions(0));
0,
new_lhs->dimensions(new_dim_numbers.lhs_contracting_dimensions(0)));
new_lhs = new_lhs->mutable_operand(0);
} else {
CHECK_EQ(operand_index, 1);
CHECK_EQ(new_rhs->opcode(), HloOpcode::kTranspose);
new_dim_numbers.set_rhs_contracting_dimensions(
0, 1 - new_dim_numbers.rhs_contracting_dimensions(0));
0,
new_rhs->dimensions(new_dim_numbers.rhs_contracting_dimensions(0)));
new_rhs = new_rhs->mutable_operand(0);
}
}

View File

@ -42,7 +42,7 @@ namespace {
class TransposeFoldingTest : public HloTestBase {
protected:
void FoldTranspose(HloModule* module) {
bool FoldTranspose(HloModule* module) {
TransposeFolding transpose_folding(
[](const HloInstruction& dot,
const TransposeFolding::OperandIndices& candidate_operands) {
@ -52,7 +52,9 @@ class TransposeFoldingTest : public HloTestBase {
const TransposeFolding::OperandIndices& candidate_operands) {
return candidate_operands;
});
EXPECT_IS_OK(transpose_folding.Run(module).status());
auto folded = transpose_folding.Run(module);
EXPECT_IS_OK(folded.status());
return *folded;
}
};
@ -465,5 +467,81 @@ TEST_F(TransposeFoldingTest, FoldConvComplexTransposeLhs) {
new_conv->convolution_dimension_numbers().output_spatial_dimensions(1));
}
// Verifies that a transpose feeding a batched dot is folded away when it
// leaves the batch dimensions ({0,1}) in place and only swaps the two
// non-batch dimensions (dimensions={0,1,3,2}).
TEST_F(TransposeFoldingTest, FoldBatchDotTranspose) {
  string hlo_string = R"(
HloModule FoldBatchDotTranspose
ENTRY entry_computation {
x = f32[7,7,2,3]{3,2,1,0} parameter(0)
y = f32[7,7,2,3]{3,2,1,0} parameter(1)
transpose = f32[7,7,3,2]{3,2,1,0} transpose(y), dimensions={0,1,3,2}
ROOT dot = f32[7,7,2,2]{3,2,1,0} dot(x, transpose), lhs_contracting_dims={3},
rhs_contracting_dims={2}, lhs_batch_dims={0,1}, rhs_batch_dims={0,1}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  ASSERT_TRUE(FoldTranspose(module.get()));
  // After folding, the dot consumes parameter(1) directly and the rhs
  // contracting dimension is remapped through the transpose (2 -> 3).
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              op::Dot(op::Parameter(0), op::Parameter(1),
                      /*lhs_contracting_dim=*/3, /*rhs_contracting_dim=*/3));
}
// Verifies that no folding happens when the transpose permutes the batch
// dimensions themselves (dimensions={1,0,3,2} swaps batch dims 0 and 1);
// FoldTranspose must report that the module was left unchanged.
TEST_F(TransposeFoldingTest, NoFoldBatchDotTransposeBatch) {
  string hlo_string = R"(
HloModule NoFoldBatchDotTransposeBatch
ENTRY entry_computation {
x = f32[7,7,2,3]{3,2,1,0} parameter(0)
y = f32[7,7,2,3]{3,2,1,0} parameter(1)
transpose = f32[7,7,3,2]{3,2,1,0} transpose(y), dimensions={1,0,3,2}
ROOT dot = f32[7,7,2,2]{3,2,1,0} dot(x, transpose), lhs_contracting_dims={3},
rhs_contracting_dims={2}, lhs_batch_dims={0,1}, rhs_batch_dims={0,1}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  EXPECT_FALSE(FoldTranspose(module.get()));
}
// Verifies that folding also works when the batch dimensions are not
// contiguous (batch dims {0,2}): the transpose (dimensions={0,3,2,1}) keeps
// both batch dims fixed and swaps only the non-batch dims 1 and 3.
TEST_F(TransposeFoldingTest, FoldBatchDotTransposeNonContiguousBatch) {
  string hlo_string = R"(
HloModule FoldBatchDotTransposeNonContiguousBatch
ENTRY entry_computation {
x = f32[7,2,7,3]{3,2,1,0} parameter(0)
y = f32[7,2,7,3]{3,2,1,0} parameter(1)
transpose = f32[7,3,7,2]{3,2,1,0} transpose(y), dimensions={0,3,2,1}
ROOT dot = f32[7,7,2,2]{3,2,1,0} dot(x, transpose), lhs_contracting_dims={3},
rhs_contracting_dims={1}, lhs_batch_dims={0,2}, rhs_batch_dims={0,2}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  ASSERT_TRUE(FoldTranspose(module.get()));
  // The rhs contracting dimension 1 maps through the transpose to 3.
  EXPECT_THAT(module->entry_computation()->root_instruction(),
              op::Dot(op::Parameter(0), op::Parameter(1),
                      /*lhs_contracting_dim=*/3, /*rhs_contracting_dim=*/3));
}
// Verifies that an identity transpose (dimensions={0,1,2,3}) is not treated
// as foldable: it leaves the batch dims fixed but does NOT swap the
// non-batch dims, so FoldTranspose reports no change.
TEST_F(TransposeFoldingTest, NoFoldBatchDotTransposeIdentity) {
  string hlo_string = R"(
HloModule NoFoldBatchDotTransposeIdentity
ENTRY entry_computation {
x = f32[7,7,2,3]{3,2,1,0} parameter(0)
y = f32[7,7,3,2]{3,2,1,0} parameter(1)
transpose = f32[7,7,3,2]{3,2,1,0} transpose(y), dimensions={0,1,2,3}
ROOT dot = f32[7,7,2,2]{3,2,1,0} dot(x, transpose), lhs_contracting_dims={3},
rhs_contracting_dims={2}, lhs_batch_dims={0,1}, rhs_batch_dims={0,1}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  EXPECT_FALSE(FoldTranspose(module.get()));
}
} // namespace
} // namespace xla