[XLA] Teach TransposeFolding to fold dots with batch dimensions
Check that the transpose doesn't touch any batch dims, then fold it away.

PiperOrigin-RevId: 336096503
Change-Id: Id9811b787b48e0c665eef6396be16c7fd255043e
commit 19db1548e8 (parent 886229e977)
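Note (illustration only, not part of the diff): a minimal standalone sketch of the foldability condition described in the commit message, with a hypothetical IsFoldableTransposePerm helper standing in for the is_r2_transpose lambda added below. A transpose operand of a batch dot is foldable when its permutation fixes every batch dimension and permutes only the remaining (non-batch) dimensions.

// Sketch only; IsFoldableTransposePerm is a hypothetical helper, not XLA API.
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

bool IsFoldableTransposePerm(const std::vector<int64_t>& perm,
                             const std::vector<int64_t>& batch_dims) {
  for (int64_t i = 0; i < static_cast<int64_t>(perm.size()); ++i) {
    const bool is_batch =
        std::find(batch_dims.begin(), batch_dims.end(), perm[i]) !=
        batch_dims.end();
    // Batch dimensions must stay in place; non-batch dimensions must move.
    if ((perm[i] == i) != is_batch) return false;
  }
  return true;
}

int main() {
  const std::vector<int64_t> batch_dims = {0, 1};
  // The same permutations the new tests below exercise.
  std::cout << IsFoldableTransposePerm({0, 1, 3, 2}, batch_dims) << "\n";  // 1: foldable
  std::cout << IsFoldableTransposePerm({1, 0, 3, 2}, batch_dims) << "\n";  // 0: permutes batch dims
  std::cout << IsFoldableTransposePerm({0, 1, 2, 3}, batch_dims) << "\n";  // 0: leaves non-batch dims in place
}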
@@ -35,17 +35,46 @@ TransposeFolding::OperandIndices CanFoldOperandsIntoDot(
     const HloInstruction& dot,
     const TransposeFolding::TransposableGemmOperandsFn&
         transposable_gemm_operands) {
-  if (HloOpcode::kDot != dot.opcode() ||
-      dot.dot_dimension_numbers().lhs_batch_dimensions_size() != 0) {
+  if (HloOpcode::kDot != dot.opcode()) {
     return {};
   }
 
+  if (!absl::c_equal(dot.dot_dimension_numbers().lhs_batch_dimensions(),
+                     dot.dot_dimension_numbers().rhs_batch_dimensions())) {
+    return {};
+  }
+
+  int64 num_batch_dims =
+      dot.dot_dimension_numbers().lhs_batch_dimensions_size();
+  int64 expected_rank = 2 + num_batch_dims;
+  auto is_r2_transpose = [&](const HloInstruction& transpose) {
+    if (transpose.opcode() != HloOpcode::kTranspose) {
+      return false;
+    }
+    const auto& transpose_dims = transpose.dimensions();
+    if (transpose_dims.size() != expected_rank) {
+      return false;
+    }
+
+    // Check that the transpose doesn't touch any batch dimensions, but does
+    // transpose the non-batch ones.
+    for (int64 i = 0; i != expected_rank; ++i) {
+      bool is_batch = absl::c_linear_search(
+          dot.dot_dimension_numbers().lhs_batch_dimensions(),
+          transpose_dims[i]);
+      if ((transpose_dims[i] == i) != is_batch) {
+        return false;
+      }
+    }
+    return true;
+  };
+
   TransposeFolding::OperandIndices operand_set;
   for (int64 i = 0; i < dot.operand_count(); ++i) {
     auto& operand = *dot.operand(i);
-    if (operand.IsRank2Transpose()) {
+    if (is_r2_transpose(operand)) {
       operand_set.push_back(i);
-    } else if (operand.shape().rank() != 2) {
+    } else if (operand.shape().rank() != expected_rank) {
       return {};
     }
   }
@@ -84,25 +113,25 @@ Status FoldTransposeIntoDot(InstructionOperandsPair pair) {
   HloInstruction* new_lhs = dot->mutable_operand(0);
   HloInstruction* new_rhs = dot->mutable_operand(1);
 
-  CHECK_EQ(new_dim_numbers.lhs_batch_dimensions_size(), 0);
-  CHECK_EQ(new_dim_numbers.rhs_batch_dimensions_size(), 0);
   CHECK_EQ(new_dim_numbers.lhs_contracting_dimensions_size(), 1);
   CHECK_EQ(new_dim_numbers.rhs_contracting_dimensions_size(), 1);
 
   for (int64 operand_index : pair.second) {
-    // We've checked that there aren't any batch dimensions and that the inputs
-    // are rank 2, and shape inference guarantees that there is exactly one
-    // contracting dimension.
+    // We checked that the batch dimensions are not touched by the transpose,
+    // and shape inference guarantees that there is exactly one contracting
+    // dimension.
     if (operand_index == 0) {
       CHECK_EQ(new_lhs->opcode(), HloOpcode::kTranspose);
       new_dim_numbers.set_lhs_contracting_dimensions(
-          0, 1 - new_dim_numbers.lhs_contracting_dimensions(0));
+          0,
+          new_lhs->dimensions(new_dim_numbers.lhs_contracting_dimensions(0)));
       new_lhs = new_lhs->mutable_operand(0);
     } else {
       CHECK_EQ(operand_index, 1);
       CHECK_EQ(new_rhs->opcode(), HloOpcode::kTranspose);
       new_dim_numbers.set_rhs_contracting_dimensions(
-          0, 1 - new_dim_numbers.rhs_contracting_dimensions(0));
+          0,
+          new_rhs->dimensions(new_dim_numbers.rhs_contracting_dimensions(0)));
       new_rhs = new_rhs->mutable_operand(0);
     }
   }
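Note (illustration only, not part of the diff): the contracting-dimension update above replaces the rank-2-only "1 - dim" trick with a lookup through the transpose's permutation. For output dimension d of a transpose with permutation perm, the data comes from input dimension perm[d], so that is the contracting dimension to use once the transpose is folded away. A minimal sketch with a hypothetical RemapContractingDim helper:

#include <cstdint>
#include <iostream>
#include <vector>

// Sketch only: maps a dot's contracting dimension on the transpose's output to
// the corresponding dimension of the transpose's input.
int64_t RemapContractingDim(const std::vector<int64_t>& perm, int64_t dim) {
  return perm[dim];
}

int main() {
  // As in the FoldBatchDotTranspose test below: transpose dimensions={0,1,3,2},
  // rhs_contracting_dims={2} before folding becomes {3} after folding.
  std::cout << RemapContractingDim({0, 1, 3, 2}, 2) << "\n";  // prints 3
  // The old rank-2 case, dimensions={1,0}: the lookup agrees with 1 - dim.
  std::cout << RemapContractingDim({1, 0}, 0) << "\n";  // prints 1
}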
@@ -42,7 +42,7 @@ namespace {
 
 class TransposeFoldingTest : public HloTestBase {
  protected:
-  void FoldTranspose(HloModule* module) {
+  bool FoldTranspose(HloModule* module) {
     TransposeFolding transpose_folding(
         [](const HloInstruction& dot,
            const TransposeFolding::OperandIndices& candidate_operands) {
@@ -52,7 +52,9 @@ class TransposeFoldingTest : public HloTestBase {
            const TransposeFolding::OperandIndices& candidate_operands) {
           return candidate_operands;
         });
-    EXPECT_IS_OK(transpose_folding.Run(module).status());
+    auto folded = transpose_folding.Run(module);
+    EXPECT_IS_OK(folded.status());
+    return *folded;
   }
 };
 
@@ -465,5 +467,81 @@ TEST_F(TransposeFoldingTest, FoldConvComplexTransposeLhs) {
       new_conv->convolution_dimension_numbers().output_spatial_dimensions(1));
 }
 
+TEST_F(TransposeFoldingTest, FoldBatchDotTranspose) {
+  string hlo_string = R"(
+HloModule FoldBatchDotTranspose
+
+ENTRY entry_computation {
+  x = f32[7,7,2,3]{3,2,1,0} parameter(0)
+  y = f32[7,7,2,3]{3,2,1,0} parameter(1)
+  transpose = f32[7,7,3,2]{3,2,1,0} transpose(y), dimensions={0,1,3,2}
+  ROOT dot = f32[7,7,2,2]{3,2,1,0} dot(x, transpose), lhs_contracting_dims={3},
+             rhs_contracting_dims={2}, lhs_batch_dims={0,1}, rhs_batch_dims={0,1}
+}
+)";
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnVerifiedModule(hlo_string));
+  ASSERT_TRUE(FoldTranspose(module.get()));
+
+  EXPECT_THAT(module->entry_computation()->root_instruction(),
+              op::Dot(op::Parameter(0), op::Parameter(1),
+                      /*lhs_contracting_dim=*/3, /*rhs_contracting_dim=*/3));
+}
+
+TEST_F(TransposeFoldingTest, NoFoldBatchDotTransposeBatch) {
+  string hlo_string = R"(
+HloModule NoFoldBatchDotTransposeBatch
+
+ENTRY entry_computation {
+  x = f32[7,7,2,3]{3,2,1,0} parameter(0)
+  y = f32[7,7,2,3]{3,2,1,0} parameter(1)
+  transpose = f32[7,7,3,2]{3,2,1,0} transpose(y), dimensions={1,0,3,2}
+  ROOT dot = f32[7,7,2,2]{3,2,1,0} dot(x, transpose), lhs_contracting_dims={3},
+             rhs_contracting_dims={2}, lhs_batch_dims={0,1}, rhs_batch_dims={0,1}
+}
+)";
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnVerifiedModule(hlo_string));
+  EXPECT_FALSE(FoldTranspose(module.get()));
+}
+
+TEST_F(TransposeFoldingTest, FoldBatchDotTransposeNonContiguousBatch) {
+  string hlo_string = R"(
+HloModule FoldBatchDotTransposeNonContiguousBatch
+
+ENTRY entry_computation {
+  x = f32[7,2,7,3]{3,2,1,0} parameter(0)
+  y = f32[7,2,7,3]{3,2,1,0} parameter(1)
+  transpose = f32[7,3,7,2]{3,2,1,0} transpose(y), dimensions={0,3,2,1}
+  ROOT dot = f32[7,7,2,2]{3,2,1,0} dot(x, transpose), lhs_contracting_dims={3},
+             rhs_contracting_dims={1}, lhs_batch_dims={0,2}, rhs_batch_dims={0,2}
+}
+)";
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnVerifiedModule(hlo_string));
+  ASSERT_TRUE(FoldTranspose(module.get()));
+
+  EXPECT_THAT(module->entry_computation()->root_instruction(),
+              op::Dot(op::Parameter(0), op::Parameter(1),
+                      /*lhs_contracting_dim=*/3, /*rhs_contracting_dim=*/3));
+}
+
+TEST_F(TransposeFoldingTest, NoFoldBatchDotTransposeIdentity) {
+  string hlo_string = R"(
+HloModule NoFoldBatchDotTransposeIdentity
+
+ENTRY entry_computation {
+  x = f32[7,7,2,3]{3,2,1,0} parameter(0)
+  y = f32[7,7,3,2]{3,2,1,0} parameter(1)
+  transpose = f32[7,7,3,2]{3,2,1,0} transpose(y), dimensions={0,1,2,3}
+  ROOT dot = f32[7,7,2,2]{3,2,1,0} dot(x, transpose), lhs_contracting_dims={3},
+             rhs_contracting_dims={2}, lhs_batch_dims={0,1}, rhs_batch_dims={0,1}
+}
+)";
+  TF_ASSERT_OK_AND_ASSIGN(auto module,
+                          ParseAndReturnVerifiedModule(hlo_string));
+  EXPECT_FALSE(FoldTranspose(module.get()));
+}
+
 }  // namespace
 }  // namespace xla
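Note (sketch only): mirroring the updated FoldTranspose helper above, running the pass outside the test fixture looks roughly like this, assuming the same two-callback TransposeFolding constructor and a StatusOr<bool>-returning Run() used by the test; the RunTransposeFolding wrapper name is hypothetical.

// Sketch only; error handling and headers beyond transpose_folding.h elided.
#include "tensorflow/compiler/xla/service/transpose_folding.h"

namespace xla {

StatusOr<bool> RunTransposeFolding(HloModule* module) {
  TransposeFolding transpose_folding(
      [](const HloInstruction& dot,
         const TransposeFolding::OperandIndices& candidate_operands) {
        return candidate_operands;  // fold every candidate dot operand
      },
      [](const HloInstruction& convolution,
         const TransposeFolding::OperandIndices& candidate_operands) {
        return candidate_operands;  // fold every candidate convolution operand
      });
  // Returns true iff any transpose was folded into a dot or convolution.
  return transpose_folding.Run(module);
}

}  // namespace xla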