Make linear layout more explicit.
PiperOrigin-RevId: 317093123 Change-Id: I7af437c2c8afb31683bb659b1939eac2ce851da5
This commit is contained in:
parent
2ad5792853
commit
4f5c65d449
@ -951,12 +951,7 @@ Status LayoutAssignment::CheckLayouts(HloModule* module) {
|
||||
if (!Shape::Equal()
|
||||
.IgnoreDynamicDimension()
|
||||
.MinorToMajorOnlyInLayout()(instruction_subshape,
|
||||
buffer->shape()) &&
|
||||
// TODO(mingyao): Use explicit linear layout tiling to
|
||||
// detect and allow special bitcast.
|
||||
instruction->opcode() != HloOpcode::kBitcast &&
|
||||
instruction->opcode() != HloOpcode::kGetTupleElement &&
|
||||
instruction->opcode() != HloOpcode::kTuple) {
|
||||
buffer->shape())) {
|
||||
return InternalError(
|
||||
"Layout of instruction %s at index {%s} does not match "
|
||||
"source LogicalBuffer %s: %s vs %s",
|
||||
@ -1803,6 +1798,13 @@ Status LayoutAssignment::ClearComputationLayouts(HloComputation* computation) {
|
||||
// potential bugs in the layout assignment pass that may accidentally use the
|
||||
// existing layout.
|
||||
for (HloInstruction* instruction : computation->instructions()) {
|
||||
if (instruction->opcode() == HloOpcode::kBitcast) {
|
||||
// bitcasts are inherently layout sensitive and so a bitcast instruction
|
||||
// present in the IR before layout assignment is a bug.
|
||||
return InternalError(
|
||||
"Unexpected bitcast operation seen during layout assignment: %s.",
|
||||
instruction->ToString());
|
||||
}
|
||||
// Some instructions carry mandatory layouts in their shape.
|
||||
if (instruction->opcode() != HloOpcode::kInfeed &&
|
||||
!IsLayoutConstrainedCustomCall(instruction) &&
|
||||
|
@ -814,6 +814,27 @@ TEST_F(LayoutAssignmentTest, ConditionalAsymmetricLayout) {
|
||||
EXPECT_THAT(false_result->opcode(), HloOpcode::kCopy);
|
||||
}
|
||||
|
||||
TEST_F(LayoutAssignmentTest, InternalErrorOnBitcast) {
|
||||
auto builder = HloComputation::Builder(TestName());
|
||||
auto constant0 = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR2WithLayout<float>(
|
||||
{{1.0, 2.0}, {3.0, 4.0}}, LayoutUtil::MakeLayout({0, 1}))));
|
||||
builder.AddInstruction(
|
||||
HloInstruction::CreateBitcast(constant0->shape(), constant0));
|
||||
auto m = CreateNewVerifiedModule();
|
||||
m->AddEntryComputation(builder.Build());
|
||||
|
||||
ComputationLayout computation_layout(
|
||||
m->entry_computation()->ComputeProgramShape());
|
||||
LayoutAssignment layout_assignment(&computation_layout);
|
||||
Status error_status = layout_assignment.Run(m.get()).status();
|
||||
EXPECT_FALSE(error_status.ok());
|
||||
EXPECT_THAT(
|
||||
error_status.error_message(),
|
||||
::testing::HasSubstr(
|
||||
"Unexpected bitcast operation seen during layout assignment"));
|
||||
}
|
||||
|
||||
TEST_F(LayoutAssignmentTest, ChannelLayoutMismatch) {
|
||||
// Pin non matching layouts to parameter and root.
|
||||
const char* module_str = R"(
|
||||
|
Loading…
x
Reference in New Issue
Block a user