Make transpose layout assignement more clear since reasoning about the

composition of permutations is far more confusing.

PiperOrigin-RevId: 159630894
This commit is contained in:
Blake Hechtman 2017-06-20 16:21:28 -07:00 committed by TensorFlower Gardener
parent 5856f9ea6d
commit 4be287671a
2 changed files with 53 additions and 12 deletions
tensorflow/compiler/xla/service

View File

@ -758,10 +758,14 @@ std::unique_ptr<Layout> LayoutAssignment::ChooseOperandLayoutFromOutputLayout(
if (instruction->opcode() == HloOpcode::kTranspose) {
// Pick the operand layout that makes the transpose a bitcast.
std::vector<int64> perm =
ComposePermutations(instruction->dimensions(),
AsInt64Slice(output_layout.minor_to_major()));
Layout operand_layout = LayoutUtil::MakeLayout(perm);
int64 rank = ShapeUtil::Rank(instruction->shape());
std::vector<int64> new_minor_to_major(rank);
for (int64 i = 0; i < rank; ++i) {
int64 output_dim = output_layout.minor_to_major(i);
int64 operand_dim = instruction->dimensions(output_dim);
new_minor_to_major[i] = operand_dim;
}
Layout operand_layout = LayoutUtil::MakeLayout(new_minor_to_major);
TF_CHECK_OK(
LayoutUtil::ValidateLayoutForShape(operand_layout, operand->shape()));
return MakeUnique<Layout>(operand_layout);
@ -812,14 +816,16 @@ std::unique_ptr<Layout> LayoutAssignment::ChooseOutputLayoutFromOperandLayout(
}
if (user->opcode() == HloOpcode::kTranspose) {
// Pick the user layout that makes the reshape a bitcast.
// To become a bitcast, the layouts need to satisfy
// collapsing_order * output_layout = input_layout
// so output_layout = inverse(collapsing_order) * input_layout
std::vector<int64> perm =
Permute(InversePermutation(user->dimensions()),
AsInt64Slice(operand_layout.minor_to_major()));
Layout user_layout = LayoutUtil::MakeLayout(perm);
// Pick the user layout that makes the transpose a bitcast.
int64 rank = ShapeUtil::Rank(user->shape());
std::vector<int64> new_minor_to_major(rank);
auto inverse_dimensions = InversePermutation(user->dimensions());
for (int64 i = 0; i < rank; ++i) {
int64 operand_dim = operand_layout.minor_to_major(i);
int64 user_dim = inverse_dimensions[operand_dim];
new_minor_to_major[i] = user_dim;
}
Layout user_layout = LayoutUtil::MakeLayout(new_minor_to_major);
TF_CHECK_OK(LayoutUtil::ValidateLayoutForShape(user_layout, user->shape()));
return MakeUnique<Layout>(user_layout);
}

View File

@ -552,6 +552,41 @@ TEST_F(LayoutAssignmentTest, MakeOperandsTheSame) {
ElementsAre(1, 0));
}
// Test layout assignment of a transpose into a bitcast based on its operand.
TEST_F(LayoutAssignmentTest, TransposeToBitcastFromOperand) {
auto builder = HloComputation::Builder(TestName());
Shape input_shape_with_layout =
ShapeUtil::MakeShapeWithLayout(F32, {3, 5, 6, 7}, {2, 0, 3, 1});
auto param = builder.AddInstruction(
HloInstruction::CreateParameter(0, input_shape_with_layout, "param"));
auto transpose = builder.AddInstruction(HloInstruction::CreateTranspose(
ShapeUtil::MakeShape(F32, {6, 7, 3, 5}), param, {2, 3, 0, 1}));
auto module = CreateNewModule();
HloComputation* computation =
module->AddEntryComputation(builder.Build(transpose));
ComputationLayout computation_layout(computation->ComputeProgramShape());
AssignLayouts(module.get(), &computation_layout);
EXPECT_TRUE(ShapeUtil::TransposeIsBitcast(transpose->operand(0)->shape(),
transpose->shape(), {2, 3, 0, 1}));
}
// Test layout assignment of a transpose into a bitcast based on its user.
TEST_F(LayoutAssignmentTest, TransposeToBitcastToUser) {
auto builder = HloComputation::Builder(TestName());
Shape input_shape = ShapeUtil::MakeShape(F32, {3, 5, 6, 7});
auto constant = builder.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0f)));
auto broadcast = builder.AddInstruction(
HloInstruction::CreateBroadcast(input_shape, constant, {}));
auto transpose = builder.AddInstruction(HloInstruction::CreateTranspose(
ShapeUtil::MakeShape(F32, {6, 7, 3, 5}), broadcast, {2, 3, 0, 1}));
auto module = CreateNewModule();
HloComputation* computation =
module->AddEntryComputation(builder.Build(transpose));
ComputationLayout computation_layout(computation->ComputeProgramShape());
AssignLayouts(module.get(), &computation_layout);
EXPECT_TRUE(ShapeUtil::TransposeIsBitcast(transpose->operand(0)->shape(),
transpose->shape(), {2, 3, 0, 1}));
}
} // namespace
} // namespace xla