diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc index 704f12b5e87..b0abce5eb2c 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.cc +++ b/tensorflow/compiler/xla/service/layout_assignment.cc @@ -585,10 +585,10 @@ Status LayoutAssignment::AddMandatoryConstraints( // Constrain the output and the operand of the while instruction to match // the computations. - TF_RETURN_IF_ERROR(constraints->SetInstructionLayout( - body_layout.result_shape(), instruction)); TF_RETURN_IF_ERROR(constraints->SetOperandLayout( body_layout.result_shape(), instruction, 0)); + TF_RETURN_IF_ERROR(constraints->SetInstructionLayout( + body_layout.result_shape(), instruction)); } else if (instruction->opcode() == HloOpcode::kConditional) { // Find the conditional branch with the most instructions and force all // other computations to match that layout. A potentially better decison @@ -1227,7 +1227,8 @@ namespace { bool InstructionShouldPropagateDepthFirst(const HloInstruction& hlo) { switch (hlo.opcode()) { case HloOpcode::kReshape: - return std::get<0>(hlo.ReshapeMerelyInsertsOrDeletes1SizedDimensions()); + return hlo.operand(0)->shape().rank() == 1 || + std::get<0>(hlo.ReshapeMerelyInsertsOrDeletes1SizedDimensions()); case HloOpcode::kTranspose: return true; default: @@ -1593,18 +1594,6 @@ Status LayoutAssignment::AssignLayouts(const LayoutConstraints& constraints, for (HloInstruction* instruction : computation->MakeInstructionPostOrder()) { LayoutUtil::ClearLayout(instruction->mutable_shape()); - // Create a copy of an operand if the operand instruction's layout does not - // match the use constraint (OperandLayoutConstraint). - for (int64 operand_no = 0; operand_no < instruction->operand_count(); - ++operand_no) { - const ShapeLayout* operand_layout = - constraints.OperandLayout(instruction, operand_no); - if (operand_layout != nullptr) { - TF_RETURN_IF_ERROR(CopyOperandIfLayoutsDiffer(*operand_layout, - instruction, operand_no)); - } - } - // Set the layouts of the array shapes this instruction defines as indicated // by the respective BufferLayoutConstraints. Any array shapes in the output // of the instruction which are not defined by the instruction (eg, array @@ -1647,6 +1636,18 @@ Status LayoutAssignment::AssignLayouts(const LayoutConstraints& constraints, return Status::OK(); })); + // Create a copy of an operand if the operand instruction's layout does not + // match the use constraint (OperandLayoutConstraint). + for (int64 operand_no = 0; operand_no < instruction->operand_count(); + ++operand_no) { + const ShapeLayout* operand_layout = + constraints.OperandLayout(instruction, operand_no); + if (operand_layout != nullptr) { + TF_RETURN_IF_ERROR(CopyOperandIfLayoutsDiffer(*operand_layout, + instruction, operand_no)); + } + } + // Fusion instructions require some layouts to be set on fused instructions // inside the fusion instruction. if (instruction->opcode() == HloOpcode::kFusion) { diff --git a/tensorflow/compiler/xla/service/layout_assignment.h b/tensorflow/compiler/xla/service/layout_assignment.h index 97ace4a062c..fc6d43f1b61 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.h +++ b/tensorflow/compiler/xla/service/layout_assignment.h @@ -465,9 +465,9 @@ class LayoutAssignment : public HloModulePass { // Creates a copy of the given operand if the operand's layout does not match // the given layout. This copy replaces the use in the given instruction. // Tuple operands will be deep-copied. - Status CopyOperandIfLayoutsDiffer(const ShapeLayout& operand_layout, - HloInstruction* instruction, - int64 operand_no); + virtual Status CopyOperandIfLayoutsDiffer(const ShapeLayout& operand_layout, + HloInstruction* instruction, + int64 operand_no); // Registers a copy instruction added by the layout assignment pass. void RegisterAddedCopy(HloInstruction* copy) {