Make transpose layout assignement more clear since reasoning about the
composition of permutations is far more confusing. PiperOrigin-RevId: 159630894
This commit is contained in:
parent
5856f9ea6d
commit
4be287671a
tensorflow/compiler/xla/service
@ -758,10 +758,14 @@ std::unique_ptr<Layout> LayoutAssignment::ChooseOperandLayoutFromOutputLayout(
|
||||
|
||||
if (instruction->opcode() == HloOpcode::kTranspose) {
|
||||
// Pick the operand layout that makes the transpose a bitcast.
|
||||
std::vector<int64> perm =
|
||||
ComposePermutations(instruction->dimensions(),
|
||||
AsInt64Slice(output_layout.minor_to_major()));
|
||||
Layout operand_layout = LayoutUtil::MakeLayout(perm);
|
||||
int64 rank = ShapeUtil::Rank(instruction->shape());
|
||||
std::vector<int64> new_minor_to_major(rank);
|
||||
for (int64 i = 0; i < rank; ++i) {
|
||||
int64 output_dim = output_layout.minor_to_major(i);
|
||||
int64 operand_dim = instruction->dimensions(output_dim);
|
||||
new_minor_to_major[i] = operand_dim;
|
||||
}
|
||||
Layout operand_layout = LayoutUtil::MakeLayout(new_minor_to_major);
|
||||
TF_CHECK_OK(
|
||||
LayoutUtil::ValidateLayoutForShape(operand_layout, operand->shape()));
|
||||
return MakeUnique<Layout>(operand_layout);
|
||||
@ -812,14 +816,16 @@ std::unique_ptr<Layout> LayoutAssignment::ChooseOutputLayoutFromOperandLayout(
|
||||
}
|
||||
|
||||
if (user->opcode() == HloOpcode::kTranspose) {
|
||||
// Pick the user layout that makes the reshape a bitcast.
|
||||
// To become a bitcast, the layouts need to satisfy
|
||||
// collapsing_order * output_layout = input_layout
|
||||
// so output_layout = inverse(collapsing_order) * input_layout
|
||||
std::vector<int64> perm =
|
||||
Permute(InversePermutation(user->dimensions()),
|
||||
AsInt64Slice(operand_layout.minor_to_major()));
|
||||
Layout user_layout = LayoutUtil::MakeLayout(perm);
|
||||
// Pick the user layout that makes the transpose a bitcast.
|
||||
int64 rank = ShapeUtil::Rank(user->shape());
|
||||
std::vector<int64> new_minor_to_major(rank);
|
||||
auto inverse_dimensions = InversePermutation(user->dimensions());
|
||||
for (int64 i = 0; i < rank; ++i) {
|
||||
int64 operand_dim = operand_layout.minor_to_major(i);
|
||||
int64 user_dim = inverse_dimensions[operand_dim];
|
||||
new_minor_to_major[i] = user_dim;
|
||||
}
|
||||
Layout user_layout = LayoutUtil::MakeLayout(new_minor_to_major);
|
||||
TF_CHECK_OK(LayoutUtil::ValidateLayoutForShape(user_layout, user->shape()));
|
||||
return MakeUnique<Layout>(user_layout);
|
||||
}
|
||||
|
@ -552,6 +552,41 @@ TEST_F(LayoutAssignmentTest, MakeOperandsTheSame) {
|
||||
ElementsAre(1, 0));
|
||||
}
|
||||
|
||||
// Test layout assignment of a transpose into a bitcast based on its operand.
|
||||
TEST_F(LayoutAssignmentTest, TransposeToBitcastFromOperand) {
|
||||
auto builder = HloComputation::Builder(TestName());
|
||||
Shape input_shape_with_layout =
|
||||
ShapeUtil::MakeShapeWithLayout(F32, {3, 5, 6, 7}, {2, 0, 3, 1});
|
||||
auto param = builder.AddInstruction(
|
||||
HloInstruction::CreateParameter(0, input_shape_with_layout, "param"));
|
||||
auto transpose = builder.AddInstruction(HloInstruction::CreateTranspose(
|
||||
ShapeUtil::MakeShape(F32, {6, 7, 3, 5}), param, {2, 3, 0, 1}));
|
||||
auto module = CreateNewModule();
|
||||
HloComputation* computation =
|
||||
module->AddEntryComputation(builder.Build(transpose));
|
||||
ComputationLayout computation_layout(computation->ComputeProgramShape());
|
||||
AssignLayouts(module.get(), &computation_layout);
|
||||
EXPECT_TRUE(ShapeUtil::TransposeIsBitcast(transpose->operand(0)->shape(),
|
||||
transpose->shape(), {2, 3, 0, 1}));
|
||||
}
|
||||
// Test layout assignment of a transpose into a bitcast based on its user.
|
||||
TEST_F(LayoutAssignmentTest, TransposeToBitcastToUser) {
|
||||
auto builder = HloComputation::Builder(TestName());
|
||||
Shape input_shape = ShapeUtil::MakeShape(F32, {3, 5, 6, 7});
|
||||
auto constant = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0f)));
|
||||
auto broadcast = builder.AddInstruction(
|
||||
HloInstruction::CreateBroadcast(input_shape, constant, {}));
|
||||
auto transpose = builder.AddInstruction(HloInstruction::CreateTranspose(
|
||||
ShapeUtil::MakeShape(F32, {6, 7, 3, 5}), broadcast, {2, 3, 0, 1}));
|
||||
auto module = CreateNewModule();
|
||||
HloComputation* computation =
|
||||
module->AddEntryComputation(builder.Build(transpose));
|
||||
ComputationLayout computation_layout(computation->ComputeProgramShape());
|
||||
AssignLayouts(module.get(), &computation_layout);
|
||||
EXPECT_TRUE(ShapeUtil::TransposeIsBitcast(transpose->operand(0)->shape(),
|
||||
transpose->shape(), {2, 3, 0, 1}));
|
||||
}
|
||||
} // namespace
|
||||
} // namespace xla
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user