Propagate the layout of reshapes from rank 1 depth first.

PiperOrigin-RevId: 241259236
This commit is contained in:
Blake Hechtman 2019-03-31 22:00:22 -07:00 committed by TensorFlower Gardener
parent 7b0fc100ec
commit 9fd90ad4c7
2 changed files with 19 additions and 18 deletions

View File

@ -585,10 +585,10 @@ Status LayoutAssignment::AddMandatoryConstraints(
// Constrain the output and the operand of the while instruction to match // Constrain the output and the operand of the while instruction to match
// the computations. // the computations.
TF_RETURN_IF_ERROR(constraints->SetInstructionLayout(
body_layout.result_shape(), instruction));
TF_RETURN_IF_ERROR(constraints->SetOperandLayout( TF_RETURN_IF_ERROR(constraints->SetOperandLayout(
body_layout.result_shape(), instruction, 0)); body_layout.result_shape(), instruction, 0));
TF_RETURN_IF_ERROR(constraints->SetInstructionLayout(
body_layout.result_shape(), instruction));
} else if (instruction->opcode() == HloOpcode::kConditional) { } else if (instruction->opcode() == HloOpcode::kConditional) {
// Find the conditional branch with the most instructions and force all // Find the conditional branch with the most instructions and force all
// other computations to match that layout. A potentially better decison // other computations to match that layout. A potentially better decison
@ -1227,7 +1227,8 @@ namespace {
bool InstructionShouldPropagateDepthFirst(const HloInstruction& hlo) { bool InstructionShouldPropagateDepthFirst(const HloInstruction& hlo) {
switch (hlo.opcode()) { switch (hlo.opcode()) {
case HloOpcode::kReshape: case HloOpcode::kReshape:
return std::get<0>(hlo.ReshapeMerelyInsertsOrDeletes1SizedDimensions()); return hlo.operand(0)->shape().rank() == 1 ||
std::get<0>(hlo.ReshapeMerelyInsertsOrDeletes1SizedDimensions());
case HloOpcode::kTranspose: case HloOpcode::kTranspose:
return true; return true;
default: default:
@ -1593,18 +1594,6 @@ Status LayoutAssignment::AssignLayouts(const LayoutConstraints& constraints,
for (HloInstruction* instruction : computation->MakeInstructionPostOrder()) { for (HloInstruction* instruction : computation->MakeInstructionPostOrder()) {
LayoutUtil::ClearLayout(instruction->mutable_shape()); LayoutUtil::ClearLayout(instruction->mutable_shape());
// Create a copy of an operand if the operand instruction's layout does not
// match the use constraint (OperandLayoutConstraint).
for (int64 operand_no = 0; operand_no < instruction->operand_count();
++operand_no) {
const ShapeLayout* operand_layout =
constraints.OperandLayout(instruction, operand_no);
if (operand_layout != nullptr) {
TF_RETURN_IF_ERROR(CopyOperandIfLayoutsDiffer(*operand_layout,
instruction, operand_no));
}
}
// Set the layouts of the array shapes this instruction defines as indicated // Set the layouts of the array shapes this instruction defines as indicated
// by the respective BufferLayoutConstraints. Any array shapes in the output // by the respective BufferLayoutConstraints. Any array shapes in the output
// of the instruction which are not defined by the instruction (eg, array // of the instruction which are not defined by the instruction (eg, array
@ -1647,6 +1636,18 @@ Status LayoutAssignment::AssignLayouts(const LayoutConstraints& constraints,
return Status::OK(); return Status::OK();
})); }));
// Create a copy of an operand if the operand instruction's layout does not
// match the use constraint (OperandLayoutConstraint).
for (int64 operand_no = 0; operand_no < instruction->operand_count();
++operand_no) {
const ShapeLayout* operand_layout =
constraints.OperandLayout(instruction, operand_no);
if (operand_layout != nullptr) {
TF_RETURN_IF_ERROR(CopyOperandIfLayoutsDiffer(*operand_layout,
instruction, operand_no));
}
}
// Fusion instructions require some layouts to be set on fused instructions // Fusion instructions require some layouts to be set on fused instructions
// inside the fusion instruction. // inside the fusion instruction.
if (instruction->opcode() == HloOpcode::kFusion) { if (instruction->opcode() == HloOpcode::kFusion) {

View File

@ -465,7 +465,7 @@ class LayoutAssignment : public HloModulePass {
// Creates a copy of the given operand if the operand's layout does not match // Creates a copy of the given operand if the operand's layout does not match
// the given layout. This copy replaces the use in the given instruction. // the given layout. This copy replaces the use in the given instruction.
// Tuple operands will be deep-copied. // Tuple operands will be deep-copied.
Status CopyOperandIfLayoutsDiffer(const ShapeLayout& operand_layout, virtual Status CopyOperandIfLayoutsDiffer(const ShapeLayout& operand_layout,
HloInstruction* instruction, HloInstruction* instruction,
int64 operand_no); int64 operand_no);