From e43de872650f2e0bcc3be13dc63ab3eb2a45bbf0 Mon Sep 17 00:00:00 2001
From: Yunxing Dai
Date: Wed, 3 Jul 2019 16:23:41 -0700
Subject: [PATCH] Supports backedge propagation of layout for while loops.

- Supports backedge propagation of layout for while loops.
- Use BFS when propagating a layout from a reshape's output to its operand.

PiperOrigin-RevId: 256457555
---
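[Note, not part of the commit message; `git am` ignores text between the
"---" above and the diffstat below.]

The patch does two related things: SetBufferLayout now lets a non-mandatory
constraint yield to an existing mandatory one instead of returning
FailedPrecondition, and PropagateBufferConstraintToUses copies a buffer
constraint from a while body's root-tuple operand back to the matching
parameter buffer (the loop backedge). It also stops giving reshapes
depth-first priority when propagating from output to operand, so that
direction now goes through the regular breadth-first worklist. The toy C++
sketch below models only the first two behaviors; every name in it
(ToyLayout, ToyBuffer, SetToyLayout) is invented for illustration, and none
of it is XLA code.

#include <cstdio>
#include <optional>
#include <vector>

// A layout here is just a minor-to-major permutation, e.g. {1, 0}.
using ToyLayout = std::vector<int>;

struct ToyBuffer {
  const char* name;
  std::optional<ToyLayout> layout;  // unset until some constraint arrives
  bool mandatory = false;           // mandatory constraints are never replaced
};

// Mirrors the patched SetBufferLayout: a non-mandatory constraint silently
// yields to an existing mandatory one instead of failing.
bool SetToyLayout(ToyBuffer& buf, const ToyLayout& layout, bool mandatory) {
  if (buf.layout.has_value() && buf.mandatory && !mandatory) {
    return false;  // skip, matching the new early-return in SetBufferLayout
  }
  buf.layout = layout;
  buf.mandatory = mandatory;
  return true;
}

int main() {
  // While body: element 0 of the root tuple flows back into parameter
  // element 0 on the next loop iteration.
  ToyBuffer parameter{"body_param{0}"};
  ToyBuffer root_operand{"root_tuple.operand(0)"};

  // Some user inside the body pins a layout on the root tuple's operand.
  SetToyLayout(root_operand, /*layout=*/{1, 0}, /*mandatory=*/true);

  // Backedge propagation: copy the constraint to the matching parameter
  // buffer non-mandatorily, so an explicit parameter constraint still wins.
  if (root_operand.layout.has_value()) {
    SetToyLayout(parameter, *root_operand.layout, /*mandatory=*/false);
  }

  std::printf("%s -> {%d,%d}\n", parameter.name, (*parameter.layout)[0],
              (*parameter.layout)[1]);
}

Propagating non-mandatorily along the backedge is presumably why the new
early-return in SetBufferLayout is needed: the parameter buffer may already
carry a mandatory constraint, and the backedge suggestion must then be
dropped rather than turned into an error.
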
 .../compiler/xla/service/layout_assignment.cc | 49 ++++++++++++++++---
 .../compiler/xla/service/layout_assignment.h  |  4 ++
 2 files changed, 47 insertions(+), 6 deletions(-)

diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc
index c2372aa0c8a..72ffcd26a72 100644
--- a/tensorflow/compiler/xla/service/layout_assignment.cc
+++ b/tensorflow/compiler/xla/service/layout_assignment.cc
@@ -178,6 +178,11 @@ Status LayoutConstraints::SetBufferLayout(const Layout& layout,
     return Status::OK();
   }
   if (curr_constraint.mandatory()) {
+    if (!mandatory) {
+      VLOG(3) << "Buffer " << buffer
+              << " already has a mandatory layout constraint, skipping";
+      return Status::OK();
+    }
     return FailedPrecondition(
         "Buffer %s already has the layout constraint %s, cannot add "
         "incompatible constraint %s",
@@ -1020,6 +1025,7 @@ std::unique_ptr<Layout> LayoutAssignment::ChooseOperandLayoutFromOutputLayout(
     // Don't assign a layout in case of R1 -> effective R1 reshape.
     return nullptr;
   }
+
   const Shape& output_shape = instruction->shape();
   Shape output_shape_with_layout = ShapeUtil::MakeShapeWithLayout(
       output_shape.element_type(), AsInt64Slice(output_shape.dimensions()),
@@ -1225,7 +1231,8 @@ namespace {
 // A transpose or a reshape that only changes trivial dimensions have meaningful
 // layouts that are valuable to propagate in a depthfirst manner to avoid
 // unassigned layouts in the graph.
-bool InstructionShouldPropagateDepthFirst(const HloInstruction& hlo) {
+bool InstructionShouldPropagateDepthFirst(const HloInstruction& hlo,
+                                          bool forward_propagation = true) {
   switch (hlo.opcode()) {
     case HloOpcode::kFusion:
       return hlo.IsCustomFusion();
@@ -1233,7 +1240,8 @@
       return true;
     case HloOpcode::kReshape:
       return hlo.operand(0)->shape().rank() == 1 ||
-             std::get<0>(hlo.ReshapeMerelyInsertsOrDeletes1SizedDimensions());
+             (forward_propagation &&
+              std::get<0>(hlo.ReshapeMerelyInsertsOrDeletes1SizedDimensions()));
     case HloOpcode::kScatter:
     case HloOpcode::kTranspose:
       return true;
@@ -1430,7 +1438,9 @@ Status LayoutAssignment::PropagateBufferConstraintToOperands(
       if (operand_layout != nullptr) {
         TF_RETURN_IF_ERROR(constraints->SetArrayOperandLayout(
             *operand_layout, instruction, operand_no, /*mandatory=*/false,
-            /*dfs=*/InstructionShouldPropagateDepthFirst(*instruction)));
+            /*dfs=*/
+            InstructionShouldPropagateDepthFirst(
+                *instruction, /*forward_propagation=*/false)));
       } else {
         VLOG(6) << "Operand already has a constraint "
@@ -1475,6 +1485,33 @@
     }
   }
 
+  // Propagate to backedges of kWhile.
+  CallGraphNode& node = call_graph_->GetNode(buffer.instruction()->parent());
+  if (node.caller_callsites().size() != 1) {
+    return Status::OK();
+  }
+  const HloInstruction* parent = node.caller_callsites()[0].instruction();
+  if (parent->opcode() != HloOpcode::kWhile) {
+    return Status::OK();
+  }
+
+  for (HloInstruction* user : buffer.instruction()->users()) {
+    if (user->parent()->root_instruction()->opcode() != HloOpcode::kTuple) {
+      continue;
+    }
+    if (user->parent()->root_instruction() == user) {
+      VLOG(3) << "Propagating layout through backedge "
+              << buffer_constraint.layout().ToString();
+      int64 index = user->operand_index(buffer.instruction());
+      TF_ASSIGN_OR_RETURN(
+          auto buffer, constraints->points_to_analysis().GetBufferDefinedAt(
+                           user->parent()->parameter_instruction(0), {index}));
+
+      TF_RETURN_IF_ERROR(constraints->SetBufferLayout(
+          buffer_constraint.layout(), *buffer, /*mandatory=*/false));
+    }
+  }
+
   return Status::OK();
 }
 
@@ -1933,11 +1970,11 @@ Status LayoutAssignment::PropagateComputationLayouts(
 StatusOr<bool> LayoutAssignment::Run(HloModule* module) {
   VLOG(2) << "Running layout assignment on module " << module->name();
   TF_RETURN_IF_ERROR(Init());
-  std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
+  call_graph_ = CallGraph::Build(module);
   auto computations = module->computations();
-  // Clone Conditional computations wiht multiple callsites.
+  // Clone Conditional computations with multiple callsites.
   for (HloComputation* computation : computations) {
-    CallGraphNode& node = call_graph->GetNode(computation);
+    CallGraphNode& node = call_graph_->GetNode(computation);
     if (node.caller_callsites().size() == 1) {
       continue;
     }
diff --git a/tensorflow/compiler/xla/service/layout_assignment.h b/tensorflow/compiler/xla/service/layout_assignment.h
index 6b6b3665317..6a202837e14 100644
--- a/tensorflow/compiler/xla/service/layout_assignment.h
+++ b/tensorflow/compiler/xla/service/layout_assignment.h
@@ -27,6 +27,7 @@ limitations under the License.
 
 #include "absl/container/flat_hash_map.h"
 #include "absl/container/flat_hash_set.h"
+#include "tensorflow/compiler/xla/service/call_graph.h"
 #include "tensorflow/compiler/xla/service/computation_layout.h"
 #include "tensorflow/compiler/xla/service/hlo_computation.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -532,6 +533,9 @@ class LayoutAssignment : public HloModulePass {
 
   std::function<bool(const HloInstruction*)>
       instruction_can_change_layout_func_;
+
+  // CallGraph of the module, used to track callsites of each computation.
+  std::unique_ptr<CallGraph> call_graph_;
 };
 
 } // namespace xla