Propagate the layout of reshapes from rank 1 depth first.
PiperOrigin-RevId: 241259236
This commit is contained in:
parent
7b0fc100ec
commit
9fd90ad4c7
@ -585,10 +585,10 @@ Status LayoutAssignment::AddMandatoryConstraints(
|
|||||||
|
|
||||||
// Constrain the output and the operand of the while instruction to match
|
// Constrain the output and the operand of the while instruction to match
|
||||||
// the computations.
|
// the computations.
|
||||||
TF_RETURN_IF_ERROR(constraints->SetInstructionLayout(
|
|
||||||
body_layout.result_shape(), instruction));
|
|
||||||
TF_RETURN_IF_ERROR(constraints->SetOperandLayout(
|
TF_RETURN_IF_ERROR(constraints->SetOperandLayout(
|
||||||
body_layout.result_shape(), instruction, 0));
|
body_layout.result_shape(), instruction, 0));
|
||||||
|
TF_RETURN_IF_ERROR(constraints->SetInstructionLayout(
|
||||||
|
body_layout.result_shape(), instruction));
|
||||||
} else if (instruction->opcode() == HloOpcode::kConditional) {
|
} else if (instruction->opcode() == HloOpcode::kConditional) {
|
||||||
// Find the conditional branch with the most instructions and force all
|
// Find the conditional branch with the most instructions and force all
|
||||||
// other computations to match that layout. A potentially better decison
|
// other computations to match that layout. A potentially better decison
|
||||||
@ -1227,7 +1227,8 @@ namespace {
|
|||||||
bool InstructionShouldPropagateDepthFirst(const HloInstruction& hlo) {
|
bool InstructionShouldPropagateDepthFirst(const HloInstruction& hlo) {
|
||||||
switch (hlo.opcode()) {
|
switch (hlo.opcode()) {
|
||||||
case HloOpcode::kReshape:
|
case HloOpcode::kReshape:
|
||||||
return std::get<0>(hlo.ReshapeMerelyInsertsOrDeletes1SizedDimensions());
|
return hlo.operand(0)->shape().rank() == 1 ||
|
||||||
|
std::get<0>(hlo.ReshapeMerelyInsertsOrDeletes1SizedDimensions());
|
||||||
case HloOpcode::kTranspose:
|
case HloOpcode::kTranspose:
|
||||||
return true;
|
return true;
|
||||||
default:
|
default:
|
||||||
@ -1593,18 +1594,6 @@ Status LayoutAssignment::AssignLayouts(const LayoutConstraints& constraints,
|
|||||||
for (HloInstruction* instruction : computation->MakeInstructionPostOrder()) {
|
for (HloInstruction* instruction : computation->MakeInstructionPostOrder()) {
|
||||||
LayoutUtil::ClearLayout(instruction->mutable_shape());
|
LayoutUtil::ClearLayout(instruction->mutable_shape());
|
||||||
|
|
||||||
// Create a copy of an operand if the operand instruction's layout does not
|
|
||||||
// match the use constraint (OperandLayoutConstraint).
|
|
||||||
for (int64 operand_no = 0; operand_no < instruction->operand_count();
|
|
||||||
++operand_no) {
|
|
||||||
const ShapeLayout* operand_layout =
|
|
||||||
constraints.OperandLayout(instruction, operand_no);
|
|
||||||
if (operand_layout != nullptr) {
|
|
||||||
TF_RETURN_IF_ERROR(CopyOperandIfLayoutsDiffer(*operand_layout,
|
|
||||||
instruction, operand_no));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set the layouts of the array shapes this instruction defines as indicated
|
// Set the layouts of the array shapes this instruction defines as indicated
|
||||||
// by the respective BufferLayoutConstraints. Any array shapes in the output
|
// by the respective BufferLayoutConstraints. Any array shapes in the output
|
||||||
// of the instruction which are not defined by the instruction (eg, array
|
// of the instruction which are not defined by the instruction (eg, array
|
||||||
@ -1647,6 +1636,18 @@ Status LayoutAssignment::AssignLayouts(const LayoutConstraints& constraints,
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}));
|
}));
|
||||||
|
|
||||||
|
// Create a copy of an operand if the operand instruction's layout does not
|
||||||
|
// match the use constraint (OperandLayoutConstraint).
|
||||||
|
for (int64 operand_no = 0; operand_no < instruction->operand_count();
|
||||||
|
++operand_no) {
|
||||||
|
const ShapeLayout* operand_layout =
|
||||||
|
constraints.OperandLayout(instruction, operand_no);
|
||||||
|
if (operand_layout != nullptr) {
|
||||||
|
TF_RETURN_IF_ERROR(CopyOperandIfLayoutsDiffer(*operand_layout,
|
||||||
|
instruction, operand_no));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Fusion instructions require some layouts to be set on fused instructions
|
// Fusion instructions require some layouts to be set on fused instructions
|
||||||
// inside the fusion instruction.
|
// inside the fusion instruction.
|
||||||
if (instruction->opcode() == HloOpcode::kFusion) {
|
if (instruction->opcode() == HloOpcode::kFusion) {
|
||||||
|
@ -465,9 +465,9 @@ class LayoutAssignment : public HloModulePass {
|
|||||||
// Creates a copy of the given operand if the operand's layout does not match
|
// Creates a copy of the given operand if the operand's layout does not match
|
||||||
// the given layout. This copy replaces the use in the given instruction.
|
// the given layout. This copy replaces the use in the given instruction.
|
||||||
// Tuple operands will be deep-copied.
|
// Tuple operands will be deep-copied.
|
||||||
Status CopyOperandIfLayoutsDiffer(const ShapeLayout& operand_layout,
|
virtual Status CopyOperandIfLayoutsDiffer(const ShapeLayout& operand_layout,
|
||||||
HloInstruction* instruction,
|
HloInstruction* instruction,
|
||||||
int64 operand_no);
|
int64 operand_no);
|
||||||
|
|
||||||
// Registers a copy instruction added by the layout assignment pass.
|
// Registers a copy instruction added by the layout assignment pass.
|
||||||
void RegisterAddedCopy(HloInstruction* copy) {
|
void RegisterAddedCopy(HloInstruction* copy) {
|
||||||
|
Loading…
Reference in New Issue
Block a user