[XLA] Provide a more generic infrastructure to pass may-alias hints

Currently, may-alias hints can be passed from the compiler to the buffer assignment
through the FusionCanShareFunction callback.
This has a number of disadvantages:

 - Only aliasing inside fusion is supported. It's often desirable to alias
   inside custom calls, which have efficient inout parameter implementations.

 - FusionCanShareFunction returns a boolean, which requires an all-or-nothing
   approach: either the function returns whether aliasing is permitted,
   or the function is not passed at all.

This change replaces FusionCanShareFunction with MayAliasHint callback,
which solves these problems:

 - MayAliasHint returns absl::optional<bool>, which allows the callback to say
   "I don't know", delegating to the default behavior.

 - The callback is called outside of fusion, allowing aliasing inside non-fused
   instructions.

PiperOrigin-RevId: 254418380
This commit is contained in:
George Karpenkov 2019-06-21 10:10:18 -07:00 committed by TensorFlower Gardener
parent 8f4e309f2d
commit 572db7bf76
10 changed files with 59 additions and 61 deletions

View File

@ -763,12 +763,12 @@ StatusOr<std::unique_ptr<BufferAssignment>> BufferAssigner::Run(
LogicalBuffer::AlignmentFunction color_alignment,
bool allocate_buffers_for_constants, BufferAssigner::Colorer colorer,
const absl::flat_hash_set<HloOpcode>& reuse_checker,
HloDataflowAnalysis::FusionCanShareBufferFunction fusion_can_share_buffer) {
HloDataflowAnalysis::CanShareBuffer can_share_buffer) {
BufferAssigner assigner(allocate_buffers_for_constants, std::move(colorer),
reuse_checker);
return assigner.CreateAssignment(
module, std::move(hlo_ordering), std::move(buffer_size),
std::move(color_alignment), std::move(fusion_can_share_buffer));
std::move(color_alignment), std::move(can_share_buffer));
}
namespace {
@ -1493,12 +1493,12 @@ StatusOr<std::unique_ptr<BufferAssignment>> BufferAssigner::CreateAssignment(
const HloModule* module, std::unique_ptr<HloOrdering> hlo_ordering,
BufferValue::SizeFunction buffer_size,
LogicalBuffer::AlignmentFunction color_alignment,
HloDataflowAnalysis::FusionCanShareBufferFunction fusion_can_share_buffer) {
HloDataflowAnalysis::CanShareBuffer can_share_buffer) {
TF_ASSIGN_OR_RETURN(std::unique_ptr<BufferLiveness> liveness,
BufferLiveness::Run(module, std::move(hlo_ordering)));
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
HloAliasAnalysis::Run(module, fusion_can_share_buffer));
HloAliasAnalysis::Run(module, can_share_buffer));
VLOG(1) << "Assigning buffers to module " << module->name();
XLA_VLOG_LINES(3, module->ToString());

View File

@ -586,8 +586,7 @@ class BufferAssigner {
bool allocate_buffers_for_constants = false,
Colorer colorer = DefaultColorer(),
const absl::flat_hash_set<HloOpcode>& must_not_live_out = {},
HloDataflowAnalysis::FusionCanShareBufferFunction
fusion_can_share_buffer = nullptr);
HloDataflowAnalysis::CanShareBuffer can_share_buffer = nullptr);
private:
BufferAssigner(bool allocate_buffers_for_constants, Colorer colorer,
@ -602,8 +601,7 @@ class BufferAssigner {
const HloModule* module, std::unique_ptr<HloOrdering> hlo_ordering,
BufferValue::SizeFunction buffer_size,
LogicalBuffer::AlignmentFunction color_alignment,
HloDataflowAnalysis::FusionCanShareBufferFunction
fusion_can_share_buffer);
HloDataflowAnalysis::CanShareBuffer can_share_buffer);
// Assigns buffers to the instructions in the given computations. "assignment"
// is modified to reflect the new buffer assignments. If is_thread_local is

View File

@ -964,7 +964,7 @@ class CopyRemover {
// instructions which have update-in-place semantics.
Status CopyInsertion::AddCopiesToResolveInterference(HloModule* module) {
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
HloAliasAnalysis::Run(module, fusion_can_share_buffer_));
HloAliasAnalysis::Run(module, can_share_buffer_));
for (HloComputation* computation : module->computations()) {
for (HloInstruction* instruction : computation->instructions()) {
@ -989,7 +989,7 @@ Status CopyInsertion::AddSpecialCaseCopies(HloModule* module) {
Status CopyInsertion::AddSpecialCaseCopies(const CallGraph& call_graph,
HloModule* module) {
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
HloAliasAnalysis::Run(module, fusion_can_share_buffer_));
HloAliasAnalysis::Run(module, can_share_buffer_));
// Identify which shape indices of which instructions need to be copied. Store
// these results in 'instructions_to_copy'.
@ -1091,7 +1091,7 @@ Status CopyInsertion::AddSpecialCaseCopies(const CallGraph& call_graph,
Status CopyInsertion::RemoveUnnecessaryCopies(const HloOrdering& ordering,
HloModule* module) {
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
HloAliasAnalysis::Run(module, fusion_can_share_buffer_));
HloAliasAnalysis::Run(module, can_share_buffer_));
CopyRemover copy_remover(*module, *alias_analysis, ordering);
if (VLOG_IS_ON(3)) {

View File

@ -52,9 +52,9 @@ class CopyInsertion : public HloModulePass {
//
// TODO(b/80315712): Find a better way to tell whether a fusion can share
// buffer.
CopyInsertion(const HloDataflowAnalysis::FusionCanShareBufferFunction&
fusion_can_share_buffer = nullptr)
: fusion_can_share_buffer_(fusion_can_share_buffer) {}
explicit CopyInsertion(
const HloDataflowAnalysis::CanShareBuffer& can_share_buffer = nullptr)
: can_share_buffer_(can_share_buffer) {}
// Run the pass on the given module. Returns whether the module was changed
// (copies were inserted).
@ -85,7 +85,7 @@ class CopyInsertion : public HloModulePass {
// Backend specific function that decides whether a fusion can share buffer
// with its operand.
HloDataflowAnalysis::FusionCanShareBufferFunction fusion_can_share_buffer_;
HloDataflowAnalysis::CanShareBuffer can_share_buffer_;
private:
Status AddCopiesToResolveInterference(HloModule* module);

View File

@ -490,8 +490,7 @@ string HloAliasAnalysis::ToString() const {
/* static */
StatusOr<std::unique_ptr<HloAliasAnalysis>> HloAliasAnalysis::Run(
const HloModule* module,
const HloDataflowAnalysis::FusionCanShareBufferFunction&
fusion_can_share_buffer) {
const HloDataflowAnalysis::CanShareBuffer& can_share_buffer) {
VLOG(2) << "HloAliasAnalysis::Run on module " << module->name();
XLA_VLOG_LINES(2, module->ToString());
@ -499,7 +498,7 @@ StatusOr<std::unique_ptr<HloAliasAnalysis>> HloAliasAnalysis::Run(
TF_ASSIGN_OR_RETURN(alias_analysis->dataflow_analysis_,
HloDataflowAnalysis::Run(*module, /*ssa_form=*/true,
/*bitcast_defines_value=*/false,
fusion_can_share_buffer));
can_share_buffer));
BufferValueMap buffer_map(module, alias_analysis->dataflow_analysis());
buffer_map.MergeAliasedBuffers();

View File

@ -42,8 +42,7 @@ class HloAliasAnalysis {
// (xla::FlattenCallGraph) prior to running the analysis.
static StatusOr<std::unique_ptr<HloAliasAnalysis>> Run(
const HloModule* module,
const HloDataflowAnalysis::FusionCanShareBufferFunction&
fusion_can_share_buffer = nullptr);
const HloDataflowAnalysis::CanShareBuffer& can_share_buffer = nullptr);
string ToString() const;

View File

@ -48,7 +48,7 @@ class HloAliasAnalysisTest : public HloTestBase {
// reference to the generated analysis stored in analysis_.
HloAliasAnalysis& RunAnalysis() {
analysis_ = HloAliasAnalysis::Run(module_.get(),
/*fusion_can_share_buffer=*/nullptr)
/*can_share_buffer=*/nullptr)
.ConsumeValueOrDie();
return *analysis_;
}

View File

@ -41,14 +41,14 @@ namespace xla {
using absl::StrAppend;
using absl::StrCat;
HloDataflowAnalysis::HloDataflowAnalysis(
const HloModule& module, bool ssa_form, bool bitcast_defines_value,
const FusionCanShareBufferFunction& fusion_can_share_buffer)
HloDataflowAnalysis::HloDataflowAnalysis(const HloModule& module, bool ssa_form,
bool bitcast_defines_value,
const CanShareBuffer& can_share_buffer)
: module_(module),
ssa_form_(ssa_form),
bitcast_defines_value_(bitcast_defines_value),
call_graph_(CallGraph::Build(&module)),
fusion_can_share_buffer_(fusion_can_share_buffer) {}
can_share_buffer_(can_share_buffer) {}
bool HloDataflowAnalysis::AreTransitiveUsesElementwiseOrTuple(
const HloInstruction* inst) {
@ -849,12 +849,12 @@ Status HloDataflowAnalysis::InitializeInstructionValueSets() {
/* static */
StatusOr<std::unique_ptr<HloDataflowAnalysis>> HloDataflowAnalysis::Run(
const HloModule& module, bool ssa_form, bool bitcast_defines_value,
const FusionCanShareBufferFunction& fusion_can_share_buffer) {
const CanShareBuffer& can_share_buffer) {
VLOG(1) << "HloDataflowAnalysis::Run on module " << module.name();
XLA_VLOG_LINES(2, module.ToString());
auto dataflow_analysis = absl::WrapUnique(new HloDataflowAnalysis(
module, ssa_form, bitcast_defines_value, fusion_can_share_buffer));
module, ssa_form, bitcast_defines_value, can_share_buffer));
TF_RETURN_IF_ERROR(dataflow_analysis->InitializeInstructionValueSets());
dataflow_analysis->Propagate();
@ -1055,10 +1055,20 @@ bool HloDataflowAnalysis::CanShareOperandBufferWithUser(
HloOpcode::kDynamicUpdateSlice) {
return CanDoInPlaceDynamicUpdateSlice(user, fusion_param_value);
}
}
if (fusion_can_share_buffer_ != nullptr) {
return fusion_can_share_buffer_(user, operand, user_index);
if (can_share_buffer_ != nullptr) {
if (absl::optional<bool> hint =
can_share_buffer_(user, operand, user_index)) {
return *hint;
}
}
if (user->opcode() == HloOpcode::kFusion) {
HloInstruction* fusion_param =
user->fused_parameter(user->operand_index(operand));
const HloValue& fusion_param_value =
GetValueDefinedAt(fusion_param, operand_index);
if (user->IsLoopFusion() || user->IsInputFusion()) {
return AreTransitiveUsesElementwiseOrTuple(fusion_param);

View File

@ -42,21 +42,15 @@ namespace xla {
// Analysis which identifies all HLO values and their uses in an HLO module.
class HloDataflowAnalysis {
public:
// Different backends can have very different ways to do fusion, so we give
// backends the flexibility to decide whether an fusion instruction can share
// buffer with it's operands. If this is not specified, a default strategy
// will be used; if this is specified, it will be applied *in addition* to the
// default strategy.
// Infrastructure for passing may-alias hints: HLO passes can populate the
// may-alias table. If an empty optional is returned, default rules are used.
//
// The first parameter of the function should be the fusion instruction, the
// second parameter should be an operand of the fusion instruction. The third
// parameter should be the output index of the fusion.
//
// TODO(b/80315712): Find a better way to tell whether a fusion can share
// buffer.
using FusionCanShareBufferFunction = std::function<bool(
const HloInstruction* fusion, const HloInstruction* operand,
const ShapeIndex& fusion_index)>;
// The first parameter of the function should be the instruction, the
// second parameter should be an operand of the instruction. The third
// parameter should be the output index of the instruction.
using CanShareBuffer = std::function<absl::optional<bool>(
const HloInstruction* instr, const HloInstruction* operand,
const ShapeIndex& user_index)>;
// Run dataflow analysis on the given module. Parameters:
//
@ -78,7 +72,7 @@ class HloDataflowAnalysis {
static StatusOr<std::unique_ptr<HloDataflowAnalysis>> Run(
const HloModule& module, bool ssa_form = false,
bool bitcast_defines_value = false,
const FusionCanShareBufferFunction& fusion_can_share_buffer = nullptr);
const CanShareBuffer& can_share_buffer = nullptr);
static bool AreTransitiveUsesElementwiseOrTuple(const HloInstruction* inst);
@ -160,10 +154,9 @@ class HloDataflowAnalysis {
const HloModule& module() const { return module_; }
protected:
HloDataflowAnalysis(
const HloModule& module, bool ssa_form,
bool bitcast_defines_value = false,
const FusionCanShareBufferFunction& fusion_can_share_buffer = nullptr);
HloDataflowAnalysis(const HloModule& module, bool ssa_form,
bool bitcast_defines_value = false,
const CanShareBuffer& can_share_buffer = nullptr);
// Returns a new HloValue defined at the given instruction and shape index.
HloValue* NewHloValue(HloInstruction* instruction, const ShapeIndex& index,
@ -249,9 +242,9 @@ class HloDataflowAnalysis {
// The Id to use for the next HloValue.
HloValue::Id next_value_id_ = 0;
// Backend specific function that decides whether a fusion can share buffer
// with its operand.
FusionCanShareBufferFunction fusion_can_share_buffer_ = nullptr;
// Backend specific function that decides whether an instruction can share
// a buffer with its operand.
CanShareBuffer can_share_buffer_ = nullptr;
};
} // namespace xla

View File

@ -1941,14 +1941,13 @@ class HloDataflowAnalysisTestBase : public HloTestBase {
computation_ = module_->AddEntryComputation(std::move(computation));
}
void RunAnalysis(const HloDataflowAnalysis::FusionCanShareBufferFunction&
fusion_can_share_buffer = nullptr) {
void RunAnalysis(
const HloDataflowAnalysis::CanShareBuffer& can_share_buffer = nullptr) {
CHECK_NOTNULL(module_.get());
dataflow_analysis_ =
HloDataflowAnalysis::Run(*module_, /*ssa_form=*/false,
/*bitcast_defines_value=*/false,
fusion_can_share_buffer)
.ConsumeValueOrDie();
dataflow_analysis_ = HloDataflowAnalysis::Run(
*module_, /*ssa_form=*/false,
/*bitcast_defines_value=*/false, can_share_buffer)
.ConsumeValueOrDie();
}
void BuildModuleAndRunAnalysis(std::unique_ptr<HloComputation> computation) {
@ -2575,9 +2574,9 @@ TEST_F(CanShareOperandBufferWithUserTest, FusionCanShareBufferCustomized) {
BuildModule(builder.Build());
auto fusion = computation_->CreateFusionInstruction(
{add, two, mul}, HloInstruction::FusionKind::kInput);
RunAnalysis(/*fusion_can_share_buffer=*/[](const HloInstruction* fusion,
const HloInstruction*,
const ShapeIndex& output_index) {
RunAnalysis(/*can_share_buffer=*/[](const HloInstruction* fusion,
const HloInstruction*,
const ShapeIndex&) {
return fusion->IsLoopFusion();
});