[XLA] Provide a more generic infrastructure to pass may-alias hints
Currently, may-alias hints can be passed from the compiler to the buffer assignment through the FusionCanShareFunction callback. This has a number of disadvantages: - Only aliasing inside fusion is supported. It's often desirable to alias inside custom calls, which have efficient inout parameter implementations. - FusionCanShareFunction returns a boolean, which requires an all-or-nothing approach: either the function returns whether aliasing is permitted, or the function is not passed at all. This change replaces FusionCanShareFunction with MayAliasHint callback, which solves these problems: - MayAliasHint returns absl::optional<bool>, which allows the callback to say "I don't know", delegating to the default behavior. - The callback is called outside of fusion, allowing aliasing inside non-fused instructions. PiperOrigin-RevId: 254418380
This commit is contained in:
parent
8f4e309f2d
commit
572db7bf76
@ -763,12 +763,12 @@ StatusOr<std::unique_ptr<BufferAssignment>> BufferAssigner::Run(
|
||||
LogicalBuffer::AlignmentFunction color_alignment,
|
||||
bool allocate_buffers_for_constants, BufferAssigner::Colorer colorer,
|
||||
const absl::flat_hash_set<HloOpcode>& reuse_checker,
|
||||
HloDataflowAnalysis::FusionCanShareBufferFunction fusion_can_share_buffer) {
|
||||
HloDataflowAnalysis::CanShareBuffer can_share_buffer) {
|
||||
BufferAssigner assigner(allocate_buffers_for_constants, std::move(colorer),
|
||||
reuse_checker);
|
||||
return assigner.CreateAssignment(
|
||||
module, std::move(hlo_ordering), std::move(buffer_size),
|
||||
std::move(color_alignment), std::move(fusion_can_share_buffer));
|
||||
std::move(color_alignment), std::move(can_share_buffer));
|
||||
}
|
||||
|
||||
namespace {
|
||||
@ -1493,12 +1493,12 @@ StatusOr<std::unique_ptr<BufferAssignment>> BufferAssigner::CreateAssignment(
|
||||
const HloModule* module, std::unique_ptr<HloOrdering> hlo_ordering,
|
||||
BufferValue::SizeFunction buffer_size,
|
||||
LogicalBuffer::AlignmentFunction color_alignment,
|
||||
HloDataflowAnalysis::FusionCanShareBufferFunction fusion_can_share_buffer) {
|
||||
HloDataflowAnalysis::CanShareBuffer can_share_buffer) {
|
||||
TF_ASSIGN_OR_RETURN(std::unique_ptr<BufferLiveness> liveness,
|
||||
BufferLiveness::Run(module, std::move(hlo_ordering)));
|
||||
|
||||
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
|
||||
HloAliasAnalysis::Run(module, fusion_can_share_buffer));
|
||||
HloAliasAnalysis::Run(module, can_share_buffer));
|
||||
|
||||
VLOG(1) << "Assigning buffers to module " << module->name();
|
||||
XLA_VLOG_LINES(3, module->ToString());
|
||||
|
||||
@ -586,8 +586,7 @@ class BufferAssigner {
|
||||
bool allocate_buffers_for_constants = false,
|
||||
Colorer colorer = DefaultColorer(),
|
||||
const absl::flat_hash_set<HloOpcode>& must_not_live_out = {},
|
||||
HloDataflowAnalysis::FusionCanShareBufferFunction
|
||||
fusion_can_share_buffer = nullptr);
|
||||
HloDataflowAnalysis::CanShareBuffer can_share_buffer = nullptr);
|
||||
|
||||
private:
|
||||
BufferAssigner(bool allocate_buffers_for_constants, Colorer colorer,
|
||||
@ -602,8 +601,7 @@ class BufferAssigner {
|
||||
const HloModule* module, std::unique_ptr<HloOrdering> hlo_ordering,
|
||||
BufferValue::SizeFunction buffer_size,
|
||||
LogicalBuffer::AlignmentFunction color_alignment,
|
||||
HloDataflowAnalysis::FusionCanShareBufferFunction
|
||||
fusion_can_share_buffer);
|
||||
HloDataflowAnalysis::CanShareBuffer can_share_buffer);
|
||||
|
||||
// Assigns buffers to the instructions in the given computations. "assignment"
|
||||
// is modified to reflect the new buffer assignments. If is_thread_local is
|
||||
|
||||
@ -964,7 +964,7 @@ class CopyRemover {
|
||||
// instructions which have update-in-place semantics.
|
||||
Status CopyInsertion::AddCopiesToResolveInterference(HloModule* module) {
|
||||
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
|
||||
HloAliasAnalysis::Run(module, fusion_can_share_buffer_));
|
||||
HloAliasAnalysis::Run(module, can_share_buffer_));
|
||||
|
||||
for (HloComputation* computation : module->computations()) {
|
||||
for (HloInstruction* instruction : computation->instructions()) {
|
||||
@ -989,7 +989,7 @@ Status CopyInsertion::AddSpecialCaseCopies(HloModule* module) {
|
||||
Status CopyInsertion::AddSpecialCaseCopies(const CallGraph& call_graph,
|
||||
HloModule* module) {
|
||||
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
|
||||
HloAliasAnalysis::Run(module, fusion_can_share_buffer_));
|
||||
HloAliasAnalysis::Run(module, can_share_buffer_));
|
||||
|
||||
// Identify which shape indices of which instructions need to be copied. Store
|
||||
// these results in 'instructions_to_copy'.
|
||||
@ -1091,7 +1091,7 @@ Status CopyInsertion::AddSpecialCaseCopies(const CallGraph& call_graph,
|
||||
Status CopyInsertion::RemoveUnnecessaryCopies(const HloOrdering& ordering,
|
||||
HloModule* module) {
|
||||
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
|
||||
HloAliasAnalysis::Run(module, fusion_can_share_buffer_));
|
||||
HloAliasAnalysis::Run(module, can_share_buffer_));
|
||||
|
||||
CopyRemover copy_remover(*module, *alias_analysis, ordering);
|
||||
if (VLOG_IS_ON(3)) {
|
||||
|
||||
@ -52,9 +52,9 @@ class CopyInsertion : public HloModulePass {
|
||||
//
|
||||
// TODO(b/80315712): Find a better way to tell whether a fusion can share
|
||||
// buffer.
|
||||
CopyInsertion(const HloDataflowAnalysis::FusionCanShareBufferFunction&
|
||||
fusion_can_share_buffer = nullptr)
|
||||
: fusion_can_share_buffer_(fusion_can_share_buffer) {}
|
||||
explicit CopyInsertion(
|
||||
const HloDataflowAnalysis::CanShareBuffer& can_share_buffer = nullptr)
|
||||
: can_share_buffer_(can_share_buffer) {}
|
||||
|
||||
// Run the pass on the given module. Returns whether the module was changed
|
||||
// (copies were inserted).
|
||||
@ -85,7 +85,7 @@ class CopyInsertion : public HloModulePass {
|
||||
|
||||
// Backend specific function that decides whether a fusion can share buffer
|
||||
// with its operand.
|
||||
HloDataflowAnalysis::FusionCanShareBufferFunction fusion_can_share_buffer_;
|
||||
HloDataflowAnalysis::CanShareBuffer can_share_buffer_;
|
||||
|
||||
private:
|
||||
Status AddCopiesToResolveInterference(HloModule* module);
|
||||
|
||||
@ -490,8 +490,7 @@ string HloAliasAnalysis::ToString() const {
|
||||
/* static */
|
||||
StatusOr<std::unique_ptr<HloAliasAnalysis>> HloAliasAnalysis::Run(
|
||||
const HloModule* module,
|
||||
const HloDataflowAnalysis::FusionCanShareBufferFunction&
|
||||
fusion_can_share_buffer) {
|
||||
const HloDataflowAnalysis::CanShareBuffer& can_share_buffer) {
|
||||
VLOG(2) << "HloAliasAnalysis::Run on module " << module->name();
|
||||
XLA_VLOG_LINES(2, module->ToString());
|
||||
|
||||
@ -499,7 +498,7 @@ StatusOr<std::unique_ptr<HloAliasAnalysis>> HloAliasAnalysis::Run(
|
||||
TF_ASSIGN_OR_RETURN(alias_analysis->dataflow_analysis_,
|
||||
HloDataflowAnalysis::Run(*module, /*ssa_form=*/true,
|
||||
/*bitcast_defines_value=*/false,
|
||||
fusion_can_share_buffer));
|
||||
can_share_buffer));
|
||||
|
||||
BufferValueMap buffer_map(module, alias_analysis->dataflow_analysis());
|
||||
buffer_map.MergeAliasedBuffers();
|
||||
|
||||
@ -42,8 +42,7 @@ class HloAliasAnalysis {
|
||||
// (xla::FlattenCallGraph) prior to running the analysis.
|
||||
static StatusOr<std::unique_ptr<HloAliasAnalysis>> Run(
|
||||
const HloModule* module,
|
||||
const HloDataflowAnalysis::FusionCanShareBufferFunction&
|
||||
fusion_can_share_buffer = nullptr);
|
||||
const HloDataflowAnalysis::CanShareBuffer& can_share_buffer = nullptr);
|
||||
|
||||
string ToString() const;
|
||||
|
||||
|
||||
@ -48,7 +48,7 @@ class HloAliasAnalysisTest : public HloTestBase {
|
||||
// reference to the generated analysis stored in analysis_.
|
||||
HloAliasAnalysis& RunAnalysis() {
|
||||
analysis_ = HloAliasAnalysis::Run(module_.get(),
|
||||
/*fusion_can_share_buffer=*/nullptr)
|
||||
/*can_share_buffer=*/nullptr)
|
||||
.ConsumeValueOrDie();
|
||||
return *analysis_;
|
||||
}
|
||||
|
||||
@ -41,14 +41,14 @@ namespace xla {
|
||||
using absl::StrAppend;
|
||||
using absl::StrCat;
|
||||
|
||||
HloDataflowAnalysis::HloDataflowAnalysis(
|
||||
const HloModule& module, bool ssa_form, bool bitcast_defines_value,
|
||||
const FusionCanShareBufferFunction& fusion_can_share_buffer)
|
||||
HloDataflowAnalysis::HloDataflowAnalysis(const HloModule& module, bool ssa_form,
|
||||
bool bitcast_defines_value,
|
||||
const CanShareBuffer& can_share_buffer)
|
||||
: module_(module),
|
||||
ssa_form_(ssa_form),
|
||||
bitcast_defines_value_(bitcast_defines_value),
|
||||
call_graph_(CallGraph::Build(&module)),
|
||||
fusion_can_share_buffer_(fusion_can_share_buffer) {}
|
||||
can_share_buffer_(can_share_buffer) {}
|
||||
|
||||
bool HloDataflowAnalysis::AreTransitiveUsesElementwiseOrTuple(
|
||||
const HloInstruction* inst) {
|
||||
@ -849,12 +849,12 @@ Status HloDataflowAnalysis::InitializeInstructionValueSets() {
|
||||
/* static */
|
||||
StatusOr<std::unique_ptr<HloDataflowAnalysis>> HloDataflowAnalysis::Run(
|
||||
const HloModule& module, bool ssa_form, bool bitcast_defines_value,
|
||||
const FusionCanShareBufferFunction& fusion_can_share_buffer) {
|
||||
const CanShareBuffer& can_share_buffer) {
|
||||
VLOG(1) << "HloDataflowAnalysis::Run on module " << module.name();
|
||||
XLA_VLOG_LINES(2, module.ToString());
|
||||
|
||||
auto dataflow_analysis = absl::WrapUnique(new HloDataflowAnalysis(
|
||||
module, ssa_form, bitcast_defines_value, fusion_can_share_buffer));
|
||||
module, ssa_form, bitcast_defines_value, can_share_buffer));
|
||||
|
||||
TF_RETURN_IF_ERROR(dataflow_analysis->InitializeInstructionValueSets());
|
||||
dataflow_analysis->Propagate();
|
||||
@ -1055,10 +1055,20 @@ bool HloDataflowAnalysis::CanShareOperandBufferWithUser(
|
||||
HloOpcode::kDynamicUpdateSlice) {
|
||||
return CanDoInPlaceDynamicUpdateSlice(user, fusion_param_value);
|
||||
}
|
||||
}
|
||||
|
||||
if (fusion_can_share_buffer_ != nullptr) {
|
||||
return fusion_can_share_buffer_(user, operand, user_index);
|
||||
if (can_share_buffer_ != nullptr) {
|
||||
if (absl::optional<bool> hint =
|
||||
can_share_buffer_(user, operand, user_index)) {
|
||||
return *hint;
|
||||
}
|
||||
}
|
||||
|
||||
if (user->opcode() == HloOpcode::kFusion) {
|
||||
HloInstruction* fusion_param =
|
||||
user->fused_parameter(user->operand_index(operand));
|
||||
const HloValue& fusion_param_value =
|
||||
GetValueDefinedAt(fusion_param, operand_index);
|
||||
|
||||
if (user->IsLoopFusion() || user->IsInputFusion()) {
|
||||
return AreTransitiveUsesElementwiseOrTuple(fusion_param);
|
||||
|
||||
@ -42,21 +42,15 @@ namespace xla {
|
||||
// Analysis which identifies all HLO values and their uses in an HLO module.
|
||||
class HloDataflowAnalysis {
|
||||
public:
|
||||
// Different backends can have very different ways to do fusion, so we give
|
||||
// backends the flexibility to decide whether an fusion instruction can share
|
||||
// buffer with it's operands. If this is not specified, a default strategy
|
||||
// will be used; if this is specified, it will be applied *in addition* to the
|
||||
// default strategy.
|
||||
// Infrastructure for passing may-alias hints: HLO passes can populate the
|
||||
// may-alias table. If an empty optional is returned, default rules are used.
|
||||
//
|
||||
// The first parameter of the function should be the fusion instruction, the
|
||||
// second parameter should be an operand of the fusion instruction. The third
|
||||
// parameter should be the output index of the fusion.
|
||||
//
|
||||
// TODO(b/80315712): Find a better way to tell whether a fusion can share
|
||||
// buffer.
|
||||
using FusionCanShareBufferFunction = std::function<bool(
|
||||
const HloInstruction* fusion, const HloInstruction* operand,
|
||||
const ShapeIndex& fusion_index)>;
|
||||
// The first parameter of the function should be the instruction, the
|
||||
// second parameter should be an operand of the instruction. The third
|
||||
// parameter should be the output index of the instruction.
|
||||
using CanShareBuffer = std::function<absl::optional<bool>(
|
||||
const HloInstruction* instr, const HloInstruction* operand,
|
||||
const ShapeIndex& user_index)>;
|
||||
|
||||
// Run dataflow analysis on the given module. Parameters:
|
||||
//
|
||||
@ -78,7 +72,7 @@ class HloDataflowAnalysis {
|
||||
static StatusOr<std::unique_ptr<HloDataflowAnalysis>> Run(
|
||||
const HloModule& module, bool ssa_form = false,
|
||||
bool bitcast_defines_value = false,
|
||||
const FusionCanShareBufferFunction& fusion_can_share_buffer = nullptr);
|
||||
const CanShareBuffer& can_share_buffer = nullptr);
|
||||
|
||||
static bool AreTransitiveUsesElementwiseOrTuple(const HloInstruction* inst);
|
||||
|
||||
@ -160,10 +154,9 @@ class HloDataflowAnalysis {
|
||||
const HloModule& module() const { return module_; }
|
||||
|
||||
protected:
|
||||
HloDataflowAnalysis(
|
||||
const HloModule& module, bool ssa_form,
|
||||
bool bitcast_defines_value = false,
|
||||
const FusionCanShareBufferFunction& fusion_can_share_buffer = nullptr);
|
||||
HloDataflowAnalysis(const HloModule& module, bool ssa_form,
|
||||
bool bitcast_defines_value = false,
|
||||
const CanShareBuffer& can_share_buffer = nullptr);
|
||||
|
||||
// Returns a new HloValue defined at the given instruction and shape index.
|
||||
HloValue* NewHloValue(HloInstruction* instruction, const ShapeIndex& index,
|
||||
@ -249,9 +242,9 @@ class HloDataflowAnalysis {
|
||||
// The Id to use for the next HloValue.
|
||||
HloValue::Id next_value_id_ = 0;
|
||||
|
||||
// Backend specific function that decides whether a fusion can share buffer
|
||||
// with its operand.
|
||||
FusionCanShareBufferFunction fusion_can_share_buffer_ = nullptr;
|
||||
// Backend specific function that decides whether an instruction can share
|
||||
// a buffer with its operand.
|
||||
CanShareBuffer can_share_buffer_ = nullptr;
|
||||
};
|
||||
|
||||
} // namespace xla
|
||||
|
||||
@ -1941,14 +1941,13 @@ class HloDataflowAnalysisTestBase : public HloTestBase {
|
||||
computation_ = module_->AddEntryComputation(std::move(computation));
|
||||
}
|
||||
|
||||
void RunAnalysis(const HloDataflowAnalysis::FusionCanShareBufferFunction&
|
||||
fusion_can_share_buffer = nullptr) {
|
||||
void RunAnalysis(
|
||||
const HloDataflowAnalysis::CanShareBuffer& can_share_buffer = nullptr) {
|
||||
CHECK_NOTNULL(module_.get());
|
||||
dataflow_analysis_ =
|
||||
HloDataflowAnalysis::Run(*module_, /*ssa_form=*/false,
|
||||
/*bitcast_defines_value=*/false,
|
||||
fusion_can_share_buffer)
|
||||
.ConsumeValueOrDie();
|
||||
dataflow_analysis_ = HloDataflowAnalysis::Run(
|
||||
*module_, /*ssa_form=*/false,
|
||||
/*bitcast_defines_value=*/false, can_share_buffer)
|
||||
.ConsumeValueOrDie();
|
||||
}
|
||||
|
||||
void BuildModuleAndRunAnalysis(std::unique_ptr<HloComputation> computation) {
|
||||
@ -2575,9 +2574,9 @@ TEST_F(CanShareOperandBufferWithUserTest, FusionCanShareBufferCustomized) {
|
||||
BuildModule(builder.Build());
|
||||
auto fusion = computation_->CreateFusionInstruction(
|
||||
{add, two, mul}, HloInstruction::FusionKind::kInput);
|
||||
RunAnalysis(/*fusion_can_share_buffer=*/[](const HloInstruction* fusion,
|
||||
const HloInstruction*,
|
||||
const ShapeIndex& output_index) {
|
||||
RunAnalysis(/*can_share_buffer=*/[](const HloInstruction* fusion,
|
||||
const HloInstruction*,
|
||||
const ShapeIndex&) {
|
||||
return fusion->IsLoopFusion();
|
||||
});
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user