From 4f5c65d4494b4e4831d016176d506227c011f01b Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 18 Jun 2020 06:50:27 -0700 Subject: [PATCH] Make linear layout more explicit. PiperOrigin-RevId: 317093123 Change-Id: I7af437c2c8afb31683bb659b1939eac2ce851da5 --- .../compiler/xla/service/layout_assignment.cc | 14 +++++++------ .../xla/service/layout_assignment_test.cc | 21 +++++++++++++++++++ 2 files changed, 29 insertions(+), 6 deletions(-) diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc index 307fd82069e..a35ba140e86 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.cc +++ b/tensorflow/compiler/xla/service/layout_assignment.cc @@ -951,12 +951,7 @@ Status LayoutAssignment::CheckLayouts(HloModule* module) { if (!Shape::Equal() .IgnoreDynamicDimension() .MinorToMajorOnlyInLayout()(instruction_subshape, - buffer->shape()) && - // TODO(mingyao): Use explicit linear layout tiling to - // detect and allow special bitcast. - instruction->opcode() != HloOpcode::kBitcast && - instruction->opcode() != HloOpcode::kGetTupleElement && - instruction->opcode() != HloOpcode::kTuple) { + buffer->shape())) { return InternalError( "Layout of instruction %s at index {%s} does not match " "source LogicalBuffer %s: %s vs %s", @@ -1803,6 +1798,13 @@ Status LayoutAssignment::ClearComputationLayouts(HloComputation* computation) { // potential bugs in the layout assignment pass that may accidentally use the // existing layout. for (HloInstruction* instruction : computation->instructions()) { + if (instruction->opcode() == HloOpcode::kBitcast) { + // bitcasts are inherently layout sensitive and so a bitcast instruction + // present in the IR before layout assignment is a bug. + return InternalError( + "Unexpected bitcast operation seen during layout assignment: %s.", + instruction->ToString()); + } // Some instructions carry mandatory layouts in their shape. if (instruction->opcode() != HloOpcode::kInfeed && !IsLayoutConstrainedCustomCall(instruction) && diff --git a/tensorflow/compiler/xla/service/layout_assignment_test.cc b/tensorflow/compiler/xla/service/layout_assignment_test.cc index 6e575247e6b..304a80c7a52 100644 --- a/tensorflow/compiler/xla/service/layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/layout_assignment_test.cc @@ -814,6 +814,27 @@ TEST_F(LayoutAssignmentTest, ConditionalAsymmetricLayout) { EXPECT_THAT(false_result->opcode(), HloOpcode::kCopy); } +TEST_F(LayoutAssignmentTest, InternalErrorOnBitcast) { + auto builder = HloComputation::Builder(TestName()); + auto constant0 = builder.AddInstruction( + HloInstruction::CreateConstant(LiteralUtil::CreateR2WithLayout( + {{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({0, 1})))); + builder.AddInstruction( + HloInstruction::CreateBitcast(constant0->shape(), constant0)); + auto m = CreateNewVerifiedModule(); + m->AddEntryComputation(builder.Build()); + + ComputationLayout computation_layout( + m->entry_computation()->ComputeProgramShape()); + LayoutAssignment layout_assignment(&computation_layout); + Status error_status = layout_assignment.Run(m.get()).status(); + EXPECT_FALSE(error_status.ok()); + EXPECT_THAT( + error_status.error_message(), + ::testing::HasSubstr( + "Unexpected bitcast operation seen during layout assignment")); +} + TEST_F(LayoutAssignmentTest, ChannelLayoutMismatch) { // Pin non matching layouts to parameter and root. const char* module_str = R"(