Make linear layout more explicit.

PiperOrigin-RevId: 317093123
Change-Id: I7af437c2c8afb31683bb659b1939eac2ce851da5
This commit is contained in:
A. Unique TensorFlower 2020-06-18 06:50:27 -07:00 committed by TensorFlower Gardener
parent 2ad5792853
commit 4f5c65d449
2 changed files with 29 additions and 6 deletions

View File

@ -951,12 +951,7 @@ Status LayoutAssignment::CheckLayouts(HloModule* module) {
if (!Shape::Equal()
.IgnoreDynamicDimension()
.MinorToMajorOnlyInLayout()(instruction_subshape,
buffer->shape()) &&
// TODO(mingyao): Use explicit linear layout tiling to
// detect and allow special bitcast.
instruction->opcode() != HloOpcode::kBitcast &&
instruction->opcode() != HloOpcode::kGetTupleElement &&
instruction->opcode() != HloOpcode::kTuple) {
buffer->shape())) {
return InternalError(
"Layout of instruction %s at index {%s} does not match "
"source LogicalBuffer %s: %s vs %s",
@ -1803,6 +1798,13 @@ Status LayoutAssignment::ClearComputationLayouts(HloComputation* computation) {
// potential bugs in the layout assignment pass that may accidentally use the
// existing layout.
for (HloInstruction* instruction : computation->instructions()) {
if (instruction->opcode() == HloOpcode::kBitcast) {
// bitcasts are inherently layout sensitive and so a bitcast instruction
// present in the IR before layout assignment is a bug.
return InternalError(
"Unexpected bitcast operation seen during layout assignment: %s.",
instruction->ToString());
}
// Some instructions carry mandatory layouts in their shape.
if (instruction->opcode() != HloOpcode::kInfeed &&
!IsLayoutConstrainedCustomCall(instruction) &&

View File

@ -814,6 +814,27 @@ TEST_F(LayoutAssignmentTest, ConditionalAsymmetricLayout) {
EXPECT_THAT(false_result->opcode(), HloOpcode::kCopy);
}
TEST_F(LayoutAssignmentTest, InternalErrorOnBitcast) {
auto builder = HloComputation::Builder(TestName());
auto constant0 = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR2WithLayout<float>(
{{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({0, 1}))));
builder.AddInstruction(
HloInstruction::CreateBitcast(constant0->shape(), constant0));
auto m = CreateNewVerifiedModule();
m->AddEntryComputation(builder.Build());
ComputationLayout computation_layout(
m->entry_computation()->ComputeProgramShape());
LayoutAssignment layout_assignment(&computation_layout);
Status error_status = layout_assignment.Run(m.get()).status();
EXPECT_FALSE(error_status.ok());
EXPECT_THAT(
error_status.error_message(),
::testing::HasSubstr(
"Unexpected bitcast operation seen during layout assignment"));
}
TEST_F(LayoutAssignmentTest, ChannelLayoutMismatch) {
// Pin non matching layouts to parameter and root.
const char* module_str = R"(