Propagate the layout of reshapes from rank 1 depth first.
PiperOrigin-RevId: 241259236
This commit is contained in:
parent
7b0fc100ec
commit
9fd90ad4c7
@ -585,10 +585,10 @@ Status LayoutAssignment::AddMandatoryConstraints(
|
||||
|
||||
// Constrain the output and the operand of the while instruction to match
|
||||
// the computations.
|
||||
TF_RETURN_IF_ERROR(constraints->SetInstructionLayout(
|
||||
body_layout.result_shape(), instruction));
|
||||
TF_RETURN_IF_ERROR(constraints->SetOperandLayout(
|
||||
body_layout.result_shape(), instruction, 0));
|
||||
TF_RETURN_IF_ERROR(constraints->SetInstructionLayout(
|
||||
body_layout.result_shape(), instruction));
|
||||
} else if (instruction->opcode() == HloOpcode::kConditional) {
|
||||
// Find the conditional branch with the most instructions and force all
|
||||
// other computations to match that layout. A potentially better decison
|
||||
@ -1227,7 +1227,8 @@ namespace {
|
||||
bool InstructionShouldPropagateDepthFirst(const HloInstruction& hlo) {
|
||||
switch (hlo.opcode()) {
|
||||
case HloOpcode::kReshape:
|
||||
return std::get<0>(hlo.ReshapeMerelyInsertsOrDeletes1SizedDimensions());
|
||||
return hlo.operand(0)->shape().rank() == 1 ||
|
||||
std::get<0>(hlo.ReshapeMerelyInsertsOrDeletes1SizedDimensions());
|
||||
case HloOpcode::kTranspose:
|
||||
return true;
|
||||
default:
|
||||
@ -1593,18 +1594,6 @@ Status LayoutAssignment::AssignLayouts(const LayoutConstraints& constraints,
|
||||
for (HloInstruction* instruction : computation->MakeInstructionPostOrder()) {
|
||||
LayoutUtil::ClearLayout(instruction->mutable_shape());
|
||||
|
||||
// Create a copy of an operand if the operand instruction's layout does not
|
||||
// match the use constraint (OperandLayoutConstraint).
|
||||
for (int64 operand_no = 0; operand_no < instruction->operand_count();
|
||||
++operand_no) {
|
||||
const ShapeLayout* operand_layout =
|
||||
constraints.OperandLayout(instruction, operand_no);
|
||||
if (operand_layout != nullptr) {
|
||||
TF_RETURN_IF_ERROR(CopyOperandIfLayoutsDiffer(*operand_layout,
|
||||
instruction, operand_no));
|
||||
}
|
||||
}
|
||||
|
||||
// Set the layouts of the array shapes this instruction defines as indicated
|
||||
// by the respective BufferLayoutConstraints. Any array shapes in the output
|
||||
// of the instruction which are not defined by the instruction (eg, array
|
||||
@ -1647,6 +1636,18 @@ Status LayoutAssignment::AssignLayouts(const LayoutConstraints& constraints,
|
||||
return Status::OK();
|
||||
}));
|
||||
|
||||
// Create a copy of an operand if the operand instruction's layout does not
|
||||
// match the use constraint (OperandLayoutConstraint).
|
||||
for (int64 operand_no = 0; operand_no < instruction->operand_count();
|
||||
++operand_no) {
|
||||
const ShapeLayout* operand_layout =
|
||||
constraints.OperandLayout(instruction, operand_no);
|
||||
if (operand_layout != nullptr) {
|
||||
TF_RETURN_IF_ERROR(CopyOperandIfLayoutsDiffer(*operand_layout,
|
||||
instruction, operand_no));
|
||||
}
|
||||
}
|
||||
|
||||
// Fusion instructions require some layouts to be set on fused instructions
|
||||
// inside the fusion instruction.
|
||||
if (instruction->opcode() == HloOpcode::kFusion) {
|
||||
|
@ -465,9 +465,9 @@ class LayoutAssignment : public HloModulePass {
|
||||
// Creates a copy of the given operand if the operand's layout does not match
|
||||
// the given layout. This copy replaces the use in the given instruction.
|
||||
// Tuple operands will be deep-copied.
|
||||
Status CopyOperandIfLayoutsDiffer(const ShapeLayout& operand_layout,
|
||||
HloInstruction* instruction,
|
||||
int64 operand_no);
|
||||
virtual Status CopyOperandIfLayoutsDiffer(const ShapeLayout& operand_layout,
|
||||
HloInstruction* instruction,
|
||||
int64 operand_no);
|
||||
|
||||
// Registers a copy instruction added by the layout assignment pass.
|
||||
void RegisterAddedCopy(HloInstruction* copy) {
|
||||
|
Loading…
Reference in New Issue
Block a user