Supports backedge propagation of layout for while loops.

- Use BFS when propagating a layout from a reshape's output to its operand.

PiperOrigin-RevId: 256457555
commit e43de87265
parent 52315e686c
tensorflow/compiler/xla/service/layout_assignment.cc

@@ -178,6 +178,11 @@ Status LayoutConstraints::SetBufferLayout(const Layout& layout,
       return Status::OK();
     }
     if (curr_constraint.mandatory()) {
+      if (!mandatory) {
+        VLOG(3) << "Buffer " << buffer
+                << " already has a mandatory layout constraint, skipping";
+        return Status::OK();
+      }
       return FailedPrecondition(
           "Buffer %s already has the layout constraint %s, cannot add "
           "incompatible constraint %s",
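Note on the hunk above: conflicting constraints are now reconciled by precedence rather than always failing. An existing mandatory constraint silently wins over a later non-mandatory suggestion, and FailedPrecondition is reserved for two incompatible mandatory constraints. A minimal standalone sketch of that rule (the Constraint/Reconcile names are hypothetical, not the XLA API):

    #include <iostream>

    // Toy stand-ins for XLA's Layout and BufferLayoutConstraint.
    struct Constraint {
      int layout;  // toy layout id
      bool mandatory;
    };

    enum class Outcome { kKeepExisting, kError };

    // Mirrors the new control flow: equal layouts are a no-op, a mandatory
    // constraint shrugs off a weaker conflicting one, and only two
    // conflicting mandatory constraints are an error.
    Outcome Reconcile(const Constraint& existing, const Constraint& incoming) {
      if (existing.layout == incoming.layout) return Outcome::kKeepExisting;
      if (existing.mandatory) {
        if (!incoming.mandatory) return Outcome::kKeepExisting;  // the new skip
        return Outcome::kError;  // FailedPrecondition in the real code
      }
      // Real code overwrites a non-mandatory constraint; not modeled here.
      return Outcome::kKeepExisting;
    }

    int main() {
      Constraint existing{/*layout=*/0, /*mandatory=*/true};
      Constraint hint{/*layout=*/1, /*mandatory=*/false};
      std::cout << (Reconcile(existing, hint) == Outcome::kKeepExisting) << "\n";
    }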
@@ -1020,6 +1025,7 @@ std::unique_ptr<Layout> LayoutAssignment::ChooseOperandLayoutFromOutputLayout(
       // Don't assign a layout in case of R1 -> effective R1 reshape.
       return nullptr;
     }
+
     const Shape& output_shape = instruction->shape();
     Shape output_shape_with_layout = ShapeUtil::MakeShapeWithLayout(
         output_shape.element_type(), AsInt64Slice(output_shape.dimensions()),
@@ -1225,7 +1231,8 @@ namespace {
 // A transpose or a reshape that only changes trivial dimensions has meaningful
 // layouts that are valuable to propagate in a depth-first manner to avoid
 // unassigned layouts in the graph.
-bool InstructionShouldPropagateDepthFirst(const HloInstruction& hlo) {
+bool InstructionShouldPropagateDepthFirst(const HloInstruction& hlo,
+                                          bool forward_propagation = true) {
   switch (hlo.opcode()) {
     case HloOpcode::kFusion:
       return hlo.IsCustomFusion();
@@ -1233,7 +1240,8 @@ bool InstructionShouldPropagateDepthFirst(const HloInstruction& hlo) {
       return true;
     case HloOpcode::kReshape:
       return hlo.operand(0)->shape().rank() == 1 ||
-             std::get<0>(hlo.ReshapeMerelyInsertsOrDeletes1SizedDimensions());
+             (forward_propagation &&
+              std::get<0>(hlo.ReshapeMerelyInsertsOrDeletes1SizedDimensions()));
     case HloOpcode::kScatter:
     case HloOpcode::kTranspose:
       return true;
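Taken together, the two hunks above make the kReshape case direction-aware: a reshape that merely inserts or deletes 1-sized dimensions still asks for depth-first propagation when a layout flows forward from operand to output, but propagation from the reshape's output back to its operand now defaults to BFS, which is the second bullet of the commit message. A self-contained restatement of just this predicate (hypothetical helper name, not the XLA function):

    #include <iostream>

    // Simplified stand-in for the kReshape arm of
    // InstructionShouldPropagateDepthFirst.
    bool ReshapePropagatesDepthFirst(int operand_rank,
                                     bool merely_touches_1_sized_dims,
                                     bool forward_propagation) {
      return operand_rank == 1 ||
             (forward_propagation && merely_touches_1_sized_dims);
    }

    int main() {
      // Forward through a trivial reshape: still depth-first.
      std::cout << ReshapePropagatesDepthFirst(4, true, true) << "\n";   // 1
      // Backward (output -> operand) through the same reshape: BFS now.
      std::cout << ReshapePropagatesDepthFirst(4, true, false) << "\n";  // 0
      // Rank-1 operand: depth-first in either direction.
      std::cout << ReshapePropagatesDepthFirst(1, false, false) << "\n"; // 1
    }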
@@ -1430,7 +1438,9 @@ Status LayoutAssignment::PropagateBufferConstraintToOperands(
         if (operand_layout != nullptr) {
           TF_RETURN_IF_ERROR(constraints->SetArrayOperandLayout(
               *operand_layout, instruction, operand_no, /*mandatory=*/false,
-              /*dfs=*/InstructionShouldPropagateDepthFirst(*instruction)));
+              /*dfs=*/
+              InstructionShouldPropagateDepthFirst(
+                  *instruction, /*forward_propagation=*/false)));
         }
       } else {
         VLOG(6) << "Operand already has a constraint "
@@ -1475,6 +1485,33 @@ Status LayoutAssignment::PropagateBufferConstraintToUses(
     }
   }
 
+  // Propagate to backedges of kWhile.
+  CallGraphNode& node = call_graph_->GetNode(buffer.instruction()->parent());
+  if (node.caller_callsites().size() != 1) {
+    return Status::OK();
+  }
+  const HloInstruction* parent = node.caller_callsites()[0].instruction();
+  if (parent->opcode() != HloOpcode::kWhile) {
+    return Status::OK();
+  }
+
+  for (HloInstruction* user : buffer.instruction()->users()) {
+    if (user->parent()->root_instruction()->opcode() != HloOpcode::kTuple) {
+      continue;
+    }
+    if (user->parent()->root_instruction() == user) {
+      VLOG(3) << "Propagating layout through backedge "
+              << buffer_constraint.layout().ToString();
+      int64 index = user->operand_index(buffer.instruction());
+      TF_ASSIGN_OR_RETURN(
+          auto buffer, constraints->points_to_analysis().GetBufferDefinedAt(
+                           user->parent()->parameter_instruction(0), {index}));
+
+      TF_RETURN_IF_ERROR(constraints->SetBufferLayout(
+          buffer_constraint.layout(), *buffer, /*mandatory=*/false));
+    }
+  }
+
   return Status::OK();
 }
 
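This is the backedge propagation named in the commit title. When the buffer lives in the body of a kWhile, the body's root tuple feeds parameter 0 of the next iteration, so a layout chosen for a buffer that reaches the root tuple is forwarded, non-mandatorily, to the matching tuple element of the parameter. A toy model of the structural check (placeholder types, not HloInstruction/HloComputation):

    #include <cstddef>
    #include <vector>

    struct Instruction;
    struct Computation {
      Instruction* root = nullptr;  // a while body's root is a tuple
    };
    struct Instruction {
      bool is_tuple = false;
      Computation* parent = nullptr;
      std::vector<Instruction*> operands;
    };

    // Returns the tuple index through which `value` feeds the while backedge,
    // or -1 if `value` does not reach the body's root tuple directly. The
    // real code uses this index to find the buffer defined at parameter 0.
    int BackedgeIndex(const Instruction& value, const Instruction& user) {
      if (!user.is_tuple || user.parent->root != &user) return -1;
      for (std::size_t i = 0; i < user.operands.size(); ++i) {
        if (user.operands[i] == &value) return static_cast<int>(i);
      }
      return -1;
    }

    int main() {
      Computation body;
      Instruction value{/*is_tuple=*/false, &body, {}};
      Instruction root{/*is_tuple=*/true, &body, {&value}};
      body.root = &root;
      return BackedgeIndex(value, root) == 0 ? 0 : 1;  // value feeds element 0
    }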
@@ -1933,11 +1970,11 @@ Status LayoutAssignment::PropagateComputationLayouts(
 StatusOr<bool> LayoutAssignment::Run(HloModule* module) {
   VLOG(2) << "Running layout assignment on module " << module->name();
   TF_RETURN_IF_ERROR(Init());
-  std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
+  call_graph_ = CallGraph::Build(module);
   auto computations = module->computations();
-  // Clone Conditional computations wiht multiple callsites.
+  // Clone Conditional computations with multiple callsites.
   for (HloComputation* computation : computations) {
-    CallGraphNode& node = call_graph->GetNode(computation);
+    CallGraphNode& node = call_graph_->GetNode(computation);
     if (node.caller_callsites().size() == 1) {
       continue;
     }
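The only behavioral change in Run() is that the call graph now outlives this function: it is built once per run and cached in the new call_graph_ member (declared in the header hunks below), so PropagateBufferConstraintToUses can look up callsites without the graph being threaded through every signature. A minimal sketch of that pattern, with placeholder types standing in for HloModule and CallGraph:

    #include <memory>

    struct Module {};
    struct CallGraph {
      static std::unique_ptr<CallGraph> Build(Module*) {
        return std::make_unique<CallGraph>();
      }
    };

    class Pass {
     public:
      bool Run(Module* module) {
        // Built once per Run(), as above; later phases read the member
        // instead of receiving the graph as a parameter.
        call_graph_ = CallGraph::Build(module);
        return PropagateUses();
      }

     private:
      bool PropagateUses() { return call_graph_ != nullptr; }
      std::unique_ptr<CallGraph> call_graph_;
    };

    int main() {
      Module m;
      Pass p;
      return p.Run(&m) ? 0 : 1;
    }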
tensorflow/compiler/xla/service/layout_assignment.h

@@ -27,6 +27,7 @@ limitations under the License.
 
 #include "absl/container/flat_hash_map.h"
 #include "absl/container/flat_hash_set.h"
+#include "tensorflow/compiler/xla/service/call_graph.h"
 #include "tensorflow/compiler/xla/service/computation_layout.h"
 #include "tensorflow/compiler/xla/service/hlo_computation.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -532,6 +533,9 @@ class LayoutAssignment : public HloModulePass {
 
   std::function<bool(const HloInstruction*)>
       instruction_can_change_layout_func_;
+
+  // CallGraph of the module, used to track callsites of each computation.
+  std::unique_ptr<CallGraph> call_graph_;
 };
 
 }  // namespace xla