Supports backedge propagation of layout for while loops.

- Supports propagating layout constraints through the backedge of while loops (sketched below, after the commit metadata).
- Uses BFS instead of DFS when propagating a layout from a reshape's output to its operand.

PiperOrigin-RevId: 256457555
Yunxing Dai 2019-07-03 16:23:41 -07:00 committed by TensorFlower Gardener
parent 52315e686c
commit e43de87265
2 changed files with 47 additions and 6 deletions
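
For context on the first commit-message bullet: an HLO kWhile repeatedly applies its body computation, and the body's root tuple is fed back into the body's parameter on the next iteration. That feedback edge is the backedge, so a layout chosen for element i of the root also describes element i of the parameter. Below is a minimal, self-contained sketch of that correspondence; the types and names are illustrative placeholders, not XLA's API.

#include <cassert>
#include <string>
#include <vector>

// Toy stand-in for a tuple-shaped HLO value: one layout string per element.
struct TupleLayouts {
  std::vector<std::string> elements;
};

// Backedge rule: the layout of element `i` of the body's root tuple is also
// the layout of element `i` of the body's parameter on the next iteration.
void PropagateThroughBackedge(const TupleLayouts& body_root,
                              TupleLayouts& body_parameter) {
  for (size_t i = 0; i < body_root.elements.size(); ++i) {
    if (body_parameter.elements[i].empty()) {  // only fill unassigned slots
      body_parameter.elements[i] = body_root.elements[i];
    }
  }
}

int main() {
  TupleLayouts root{{"{0,1}", "{1,0}"}};  // layouts chosen inside the body
  TupleLayouts param{{"", ""}};           // not yet assigned
  PropagateThroughBackedge(root, param);
  assert(param.elements[1] == "{1,0}");
  return 0;
}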

tensorflow/compiler/xla/service/layout_assignment.cc

@@ -178,6 +178,11 @@ Status LayoutConstraints::SetBufferLayout(const Layout& layout,
     return Status::OK();
   }
   if (curr_constraint.mandatory()) {
+    if (!mandatory) {
+      VLOG(3) << "Buffer " << buffer
+              << " already has a mandatory layout constraint, skipping";
+      return Status::OK();
+    }
     return FailedPrecondition(
         "Buffer %s already has the layout constraint %s, cannot add "
         "incompatible constraint %s",
@@ -1020,6 +1025,7 @@ std::unique_ptr<Layout> LayoutAssignment::ChooseOperandLayoutFromOutputLayout(
       // Don't assign a layout in case of R1 -> effective R1 reshape.
       return nullptr;
     }
+
     const Shape& output_shape = instruction->shape();
     Shape output_shape_with_layout = ShapeUtil::MakeShapeWithLayout(
         output_shape.element_type(), AsInt64Slice(output_shape.dimensions()),
@@ -1225,7 +1231,8 @@ namespace {
 // A transpose or a reshape that only changes trivial dimensions have meaningful
 // layouts that are valuable to propagate in a depthfirst manner to avoid
 // unassigned layouts in the graph.
-bool InstructionShouldPropagateDepthFirst(const HloInstruction& hlo) {
+bool InstructionShouldPropagateDepthFirst(const HloInstruction& hlo,
+                                          bool forward_propagation = true) {
   switch (hlo.opcode()) {
     case HloOpcode::kFusion:
       return hlo.IsCustomFusion();
@@ -1233,7 +1240,8 @@ bool InstructionShouldPropagateDepthFirst(const HloInstruction& hlo) {
       return true;
     case HloOpcode::kReshape:
       return hlo.operand(0)->shape().rank() == 1 ||
-             std::get<0>(hlo.ReshapeMerelyInsertsOrDeletes1SizedDimensions());
+             (forward_propagation &&
+              std::get<0>(hlo.ReshapeMerelyInsertsOrDeletes1SizedDimensions()));
     case HloOpcode::kScatter:
     case HloOpcode::kTranspose:
       return true;
@@ -1430,7 +1438,9 @@ Status LayoutAssignment::PropagateBufferConstraintToOperands(
       if (operand_layout != nullptr) {
         TF_RETURN_IF_ERROR(constraints->SetArrayOperandLayout(
             *operand_layout, instruction, operand_no, /*mandatory=*/false,
-            /*dfs=*/InstructionShouldPropagateDepthFirst(*instruction)));
+            /*dfs=*/
+            InstructionShouldPropagateDepthFirst(
+                *instruction, /*forward_propagation=*/false)));
       }
     } else {
       VLOG(6) << "Operand already has a constraint "
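
The two hunks above are what the second commit-message bullet refers to: when a layout is propagated backward from a reshape's output to its operand (forward_propagation=false), InstructionShouldPropagateDepthFirst now answers false for reshapes that merely insert or delete size-1 dimensions, so the resulting constraint is processed breadth-first (dfs=false). A minimal sketch of what a dfs flag like this typically means for a propagation worklist; the queue below is hypothetical, not XLA's internal structure:

#include <deque>
#include <string>
#include <utility>

struct PendingConstraint {
  std::string target;  // the buffer or operand to constrain next
};

// Depth-first: chase the freshly constrained value immediately by putting it
// at the front. Breadth-first: append it so sibling users are handled first.
void Enqueue(std::deque<PendingConstraint>& worklist, PendingConstraint c,
             bool dfs) {
  if (dfs) {
    worklist.push_front(std::move(c));
  } else {
    worklist.push_back(std::move(c));
  }
}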
@@ -1475,6 +1485,33 @@ Status LayoutAssignment::PropagateBufferConstraintToUses(
     }
   }
+
+  // Propagate to backedges of kWhile.
+  CallGraphNode& node = call_graph_->GetNode(buffer.instruction()->parent());
+  if (node.caller_callsites().size() != 1) {
+    return Status::OK();
+  }
+  const HloInstruction* parent = node.caller_callsites()[0].instruction();
+  if (parent->opcode() != HloOpcode::kWhile) {
+    return Status::OK();
+  }
+
+  for (HloInstruction* user : buffer.instruction()->users()) {
+    if (user->parent()->root_instruction()->opcode() != HloOpcode::kTuple) {
+      continue;
+    }
+    if (user->parent()->root_instruction() == user) {
+      VLOG(3) << "Propagating layout through backedge "
+              << buffer_constraint.layout().ToString();
+      int64 index = user->operand_index(buffer.instruction());
+      TF_ASSIGN_OR_RETURN(
+          auto buffer, constraints->points_to_analysis().GetBufferDefinedAt(
+                           user->parent()->parameter_instruction(0), {index}));
+      TF_RETURN_IF_ERROR(constraints->SetBufferLayout(
+          buffer_constraint.layout(), *buffer, /*mandatory=*/false));
+    }
+  }
   return Status::OK();
 }
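
The new block above keys everything off tuple indices: the buffer's computation must have exactly one callsite, that caller must be a kWhile, and if the buffer's defining instruction feeds the while body's root tuple at operand index i, then element {i} of parameter 0 holds the same value on the next iteration, so the layout is copied there as a non-mandatory constraint. A toy illustration of that operand-index correspondence (simplified types, not HloInstruction):

#include <cassert>
#include <vector>

// Simplified instruction node: just an operand list.
struct Node {
  std::vector<const Node*> operands;
  int operand_index(const Node* op) const {
    for (size_t i = 0; i < operands.size(); ++i) {
      if (operands[i] == op) return static_cast<int>(i);
    }
    return -1;
  }
};

int main() {
  Node counter, accum, root;
  root.operands = {&counter, &accum};  // body root = tuple(counter, accum)
  // accum feeds the root at index 1, so parameter element {1} receives
  // accum's layout through the backedge.
  assert(root.operand_index(&accum) == 1);
  return 0;
}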
@@ -1933,11 +1970,11 @@ Status LayoutAssignment::PropagateComputationLayouts(
 StatusOr<bool> LayoutAssignment::Run(HloModule* module) {
   VLOG(2) << "Running layout assignment on module " << module->name();
   TF_RETURN_IF_ERROR(Init());
-  std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
+  call_graph_ = CallGraph::Build(module);
   auto computations = module->computations();
-  // Clone Conditional computations wiht multiple callsites.
+  // Clone Conditional computations with multiple callsites.
   for (HloComputation* computation : computations) {
-    CallGraphNode& node = call_graph->GetNode(computation);
+    CallGraphNode& node = call_graph_->GetNode(computation);
     if (node.caller_callsites().size() == 1) {
       continue;
     }
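
This hunk supports the backedge logic above: CallGraph::Build(module) moves from a local in Run() into the call_graph_ member (declared in the header below), so PropagateBufferConstraintToUses can look up a computation's callsites later in the pass. A sketch of the ownership change, using assumed placeholder names rather than the real classes:

#include <memory>

struct CallGraphSketch {
  static std::unique_ptr<CallGraphSketch> Build(/*const Module* module*/) {
    return std::make_unique<CallGraphSketch>();
  }
};

class PassSketch {
 public:
  void Run(/*Module* module*/) {
    call_graph_ = CallGraphSketch::Build();  // built once per Run()
    Propagate();
  }

 private:
  void Propagate() {
    // Later phases query call_graph_ without it being threaded through every
    // helper's parameter list.
  }
  std::unique_ptr<CallGraphSketch> call_graph_;
};

int main() {
  PassSketch pass;
  pass.Run();
  return 0;
}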

tensorflow/compiler/xla/service/layout_assignment.h

@@ -27,6 +27,7 @@ limitations under the License.
 #include "absl/container/flat_hash_map.h"
 #include "absl/container/flat_hash_set.h"
+#include "tensorflow/compiler/xla/service/call_graph.h"
 #include "tensorflow/compiler/xla/service/computation_layout.h"
 #include "tensorflow/compiler/xla/service/hlo_computation.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -532,6 +533,9 @@ class LayoutAssignment : public HloModulePass {
   std::function<bool(const HloInstruction*)>
       instruction_can_change_layout_func_;
+
+  // CallGraph of the module, used to track callsites of each computation.
+  std::unique_ptr<CallGraph> call_graph_;
 };
 } // namespace xla