[XLA] Roll forward in-place op buffer sharing.
The original CL (cl/331352657) was automatically rolled back because it broke a GPU test in debug mode. The cause was that HloDataflowAnalysis has additional logic, used by HloOrdering, to determine which HLOs can share buffers with their operands, and there was a mismatch between what the new and the old mechanisms allowed as in-place buffers. The in-place checks have now been moved from alias analysis to dataflow analysis (for linkage reasons) so that there is a single source of truth for what counts as an in-place operation. PiperOrigin-RevId: 331779239 Change-Id: I226c3647cbcfbd8ec3ed7896a845ecaeab6ca84d
This commit is contained in:
parent
79b52dbdc9
commit
84cb50ea17
@ -1007,102 +1007,6 @@ bool BufferAssigner::MaybeAssignBuffer(BufferAllocation* allocation,
|
||||
return true;
|
||||
} // namespace xla
|
||||
|
||||
Status BufferAssigner::MergeInplaceOpBuffers(BufferAssignment* assignment) {
|
||||
// Try to allocate the same buffer for dynamic update slice's operand and output.
|
||||
|
||||
// If memory_space_assignment is run and there is information about a color in
|
||||
// preset assignments, don't merge those buffers. We expect
|
||||
// memory_space_assignment to have merged these buffers. If
|
||||
// memory_space_assignment didn't merge these buffers and has assigned
|
||||
// different offsets to the operand and the output buffer, merging the buffers
|
||||
// can cause memory corruption if memory_space_assignment assigned a different
|
||||
// buffer at the same offset.
|
||||
absl::flat_hash_set<int64> excluded_colors;
|
||||
if (preset_assignments_) {
|
||||
for (const auto& color_and_info :
|
||||
preset_assignments_->assignment_informations()) {
|
||||
excluded_colors.insert(color_and_info.first);
|
||||
}
|
||||
}
|
||||
|
||||
// TODO(yunxing): Move this logic to alias analysis and add a must-alias rule
|
||||
// to operations that can be done in place.
|
||||
for (HloComputation* computation : assignment->module().computations()) {
|
||||
for (HloInstruction* instruction : computation->instructions()) {
|
||||
if (!(instruction->opcode() == HloOpcode::kDynamicUpdateSlice ||
|
||||
(instruction->opcode() == HloOpcode::kFusion &&
|
||||
(instruction->fused_expression_root()->opcode() ==
|
||||
HloOpcode::kDynamicUpdateSlice)))) {
|
||||
continue;
|
||||
}
|
||||
if (instruction->parent()->IsFusionComputation()) {
|
||||
continue;
|
||||
}
|
||||
if (instruction->operand_count() == 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// The operand can't share the same buffer with the user based on dataflow
|
||||
// analysis.
|
||||
if (!assignment->dataflow_analysis().CanShareOperandBufferWithUser(
|
||||
instruction->mutable_operand(0), {}, instruction, {})) {
|
||||
continue;
|
||||
}
|
||||
HloBuffer& instruction_buffer =
|
||||
assignment->alias_analysis().GetUniqueBufferAt(instruction, {});
|
||||
|
||||
HloBuffer& operand_buffer =
|
||||
assignment->alias_analysis().GetUniqueBufferAt(
|
||||
instruction->operand(0), {});
|
||||
|
||||
// The instruction or operand color is excluded because it was assigned by
|
||||
// memory_space_assignment.
|
||||
if (excluded_colors.contains(instruction_buffer.color()) ||
|
||||
excluded_colors.contains(operand_buffer.color())) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Already have the same buffer. No need to merge those.
|
||||
if (instruction_buffer.id() == operand_buffer.id()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Do not perform in-place dynamic update slice if the operand buffer is
|
||||
// read-only.
|
||||
if (HloBufferIsReadOnly(operand_buffer)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
bool interfere = false;
|
||||
|
||||
for (const HloValue* instruction_value : instruction_buffer.values()) {
|
||||
for (const HloValue* operand_value : operand_buffer.values()) {
|
||||
if (assignment->hlo_ordering().MayInterfere(
|
||||
*instruction_value, *operand_value,
|
||||
assignment->dataflow_analysis())) {
|
||||
interfere = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (interfere) {
|
||||
continue;
|
||||
}
|
||||
if (assignment->alias_analysis().BufferLivesOut(instruction_buffer)) {
|
||||
continue;
|
||||
}
|
||||
if (instruction_buffer.color() != operand_buffer.color()) {
|
||||
continue;
|
||||
}
|
||||
VLOG(3) << "Merging inplace " << instruction_buffer << " and "
|
||||
<< operand_buffer;
|
||||
assignment->alias_analysis().MergeBuffers(instruction_buffer,
|
||||
operand_buffer);
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status BufferAssigner::AssignSingleHloBuffer(
|
||||
const HloBuffer* hlo_buffer, bool is_thread_local,
|
||||
absl::flat_hash_map<const HloComputation*,
|
||||
@ -1654,7 +1558,6 @@ StatusOr<std::unique_ptr<BufferAssignment>> BufferAssigner::CreateAssignment(
|
||||
VLOG(3) << "After coloring:";
|
||||
XLA_VLOG_LINES(3,
|
||||
assignment->alias_analysis().dataflow_analysis().ToString());
|
||||
TF_RETURN_IF_ERROR(MergeInplaceOpBuffers(assignment.get()));
|
||||
|
||||
std::vector<const HloComputation*> thread_local_computations;
|
||||
std::vector<const HloComputation*> global_computations;
|
||||
|
@ -635,10 +635,6 @@ class BufferAssigner {
|
||||
absl::flat_hash_set<const HloBuffer*>* assigned_buffers,
|
||||
BufferAssignment* assignment);
|
||||
|
||||
// Promotes operations (DUS, scatter) to be done in place: If an operation can
|
||||
// be done in place, merge its buffer with its operand buffer.
|
||||
Status MergeInplaceOpBuffers(BufferAssignment* assignment);
|
||||
|
||||
// Assigns a single hlo buffer to an HLO allocation.
|
||||
Status AssignSingleHloBuffer(
|
||||
const HloBuffer* hlo_buffer, bool is_thread_local,
|
||||
|
@ -1925,8 +1925,10 @@ ENTRY main {
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_text));
|
||||
HloInstruction* parameter =
|
||||
m->entry_computation()->GetInstructionWithName("get-tuple-element.4");
|
||||
HloInstruction* dus =
|
||||
HloInstruction* dus1 =
|
||||
m->entry_computation()->GetInstructionWithName("dynamic-update-slice.5");
|
||||
HloInstruction* dus2 =
|
||||
m->entry_computation()->GetInstructionWithName("dynamic-update-slice.9");
|
||||
|
||||
auto buffers = RunBufferAssignment(m.get());
|
||||
|
||||
@ -1934,8 +1936,10 @@ ENTRY main {
|
||||
const BufferAllocation& parameter_alloc =
|
||||
GetTopLevelAllocation(*buffers, parameter);
|
||||
|
||||
const BufferAllocation& dus_alloc = GetTopLevelAllocation(*buffers, dus);
|
||||
EXPECT_NE(parameter_alloc, dus_alloc);
|
||||
const BufferAllocation& dus1_alloc = GetTopLevelAllocation(*buffers, dus1);
|
||||
EXPECT_EQ(parameter_alloc, dus1_alloc);
|
||||
const BufferAllocation& dus2_alloc = GetTopLevelAllocation(*buffers, dus2);
|
||||
EXPECT_EQ(parameter_alloc, dus2_alloc);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -362,6 +362,19 @@ Status AddCopiesForConditional(const HloAliasAnalysis& alias_analysis,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Add copies for the operands of in-place operations. RemoveUnnecessaryCopies
|
||||
// will remove the unnecessary copies.
|
||||
Status AddCopiesForInPlaceOperation(const HloAliasAnalysis& alias_analysis,
|
||||
HloInstruction* in_place_op,
|
||||
int64 operand_number) {
|
||||
VLOG(2) << "Adding copies for in-place operation " << in_place_op->name();
|
||||
HloInstruction* operand = in_place_op->mutable_operand(operand_number);
|
||||
TF_ASSIGN_OR_RETURN(HloInstruction * deep_copy,
|
||||
in_place_op->parent()->DeepCopyInstruction(operand));
|
||||
TF_RETURN_IF_ERROR(operand->ReplaceUseWith(in_place_op, deep_copy));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Conservatively adds copies before root instruction of entry computation and
|
||||
// each aliased parameter to resolve interference of aliased input and output
|
||||
// buffer. We later rely on RemoveUnnecessaryCopies to drop the unnecessary
|
||||
@ -509,6 +522,12 @@ class CopyRemover {
|
||||
// value. The map is used to construct the copy info map below.
|
||||
absl::flat_hash_map<const HloValue*, ValueNode*> value_to_node;
|
||||
for (const HloBuffer& buffer : alias_analysis.buffers()) {
|
||||
// No copies should have been inserted within fused computations, so no
|
||||
// need to remove them. HloOrdering isn't compatible with HloValues inside
|
||||
// fusions, so skip copy removal for them.
|
||||
if (buffer.values().at(0)->defining_instruction()->IsFused()) {
|
||||
continue;
|
||||
}
|
||||
// Verify values contained in the buffer are strictly ordered. This
|
||||
// should always be the case after adding copies to eliminate
|
||||
// interference. Specifically, the addition of the control flow edges
|
||||
@ -591,7 +610,7 @@ class CopyRemover {
|
||||
void CreateCopyMap(
|
||||
const HloModule& module,
|
||||
const absl::flat_hash_map<const HloValue*, ValueNode*>& value_to_node) {
|
||||
for (HloComputation* computation : module.computations()) {
|
||||
for (HloComputation* computation : module.MakeNonfusionComputations()) {
|
||||
for (HloInstruction* instruction : computation->instructions()) {
|
||||
// Add copies with unambiguous source values to the map. Copies with
|
||||
// ambiguous sources are not removable.
|
||||
@ -1005,7 +1024,7 @@ Status CopyInsertion::AddCopiesToResolveInterference(HloModule* module) {
|
||||
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
|
||||
HloAliasAnalysis::Run(module, can_share_buffer_));
|
||||
|
||||
for (HloComputation* computation : module->MakeComputationPostOrder()) {
|
||||
for (HloComputation* computation : module->MakeNonfusionComputations()) {
|
||||
for (HloInstruction* instruction :
|
||||
computation->MakeInstructionPostOrder()) {
|
||||
if (instruction->opcode() == HloOpcode::kWhile) {
|
||||
@ -1013,6 +1032,15 @@ Status CopyInsertion::AddCopiesToResolveInterference(HloModule* module) {
|
||||
} else if (instruction->opcode() == HloOpcode::kConditional) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
AddCopiesForConditional(*alias_analysis, instruction));
|
||||
} else {
|
||||
for (const auto& operand_and_output_index :
|
||||
HloDataflowAnalysis::GetInPlaceInputOutputPairs(instruction)) {
|
||||
const HloUse& operand = operand_and_output_index.first;
|
||||
CHECK_EQ(operand.operand_index, ShapeIndex{})
|
||||
<< "Support for non-{} shape operand not currently implemented.";
|
||||
TF_RETURN_IF_ERROR(AddCopiesForInPlaceOperation(
|
||||
*alias_analysis, instruction, operand.operand_number));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -2530,5 +2530,250 @@ ENTRY Entry {
|
||||
EXPECT_EQ(CountCopies(*module), 1);
|
||||
}
|
||||
|
||||
TEST_F(CopyInsertionTest, DynamicUpdateSliceNoCopy) {
|
||||
absl::string_view hlo_string = R"(
|
||||
HloModule Module
|
||||
|
||||
ENTRY main {
|
||||
param = f32[1280,1,128] parameter(0)
|
||||
negate = f32[1280,1,128] negate(param)
|
||||
constant.1 = f32[] constant(0)
|
||||
broadcast.6 = f32[128,1,128] broadcast(constant.1), dimensions={}
|
||||
constant.3 = s32[] constant(0)
|
||||
ROOT dynamic-update-slice.5 = f32[1280,1,128] dynamic-update-slice(negate, broadcast.6, constant.3, constant.3, constant.3)
|
||||
}
|
||||
)";
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
||||
ParseAndReturnVerifiedModule(hlo_string));
|
||||
InsertCopies(module.get());
|
||||
EXPECT_EQ(CountCopies(*module), 0);
|
||||
}
|
||||
|
||||
TEST_F(CopyInsertionTest, FusedDynamicUpdateSliceNoCopy) {
|
||||
absl::string_view hlo_string = R"(
|
||||
HloModule Module
|
||||
|
||||
fused_computation {
|
||||
param0 = f32[1280,1,128] parameter(0)
|
||||
constant.1 = f32[] constant(0)
|
||||
broadcast.6 = f32[128,1,128] broadcast(constant.1), dimensions={}
|
||||
constant.3 = s32[] constant(0)
|
||||
ROOT dynamic-update-slice.5 = f32[1280,1,128] dynamic-update-slice(param0, broadcast.6, constant.3, constant.3, constant.3)
|
||||
}
|
||||
|
||||
ENTRY main {
|
||||
param = f32[1280,1,128] parameter(0)
|
||||
negate = f32[1280,1,128] negate(param)
|
||||
ROOT fusion = f32[1280,1,128] fusion(negate), kind=kLoop, calls=fused_computation
|
||||
}
|
||||
)";
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
||||
ParseAndReturnVerifiedModule(hlo_string));
|
||||
InsertCopies(module.get());
|
||||
EXPECT_EQ(CountCopies(*module), 0);
|
||||
}
|
||||
|
||||
TEST_F(CopyInsertionTest, DynamicUpdateSliceCopy) {
|
||||
absl::string_view hlo_string = R"(
|
||||
HloModule Module
|
||||
|
||||
ENTRY main {
|
||||
param = f32[1280,1,128] parameter(0)
|
||||
negate = f32[1280,1,128] negate(param)
|
||||
constant.1 = f32[] constant(0)
|
||||
broadcast.6 = f32[128,1,128] broadcast(constant.1), dimensions={}
|
||||
constant.3 = s32[] constant(0)
|
||||
add = f32[1280,1,128] add(negate, negate)
|
||||
dynamic-update-slice.5 = f32[1280,1,128] dynamic-update-slice(negate, broadcast.6, constant.3, constant.3, constant.3)
|
||||
ROOT tuple = (f32[1280,1,128], f32[1280,1,128]) tuple(add, dynamic-update-slice.5)
|
||||
}
|
||||
)";
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
||||
ParseAndReturnVerifiedModule(hlo_string));
|
||||
InsertCopies(module.get());
|
||||
EXPECT_EQ(CountCopies(*module), 1);
|
||||
}
|
||||
|
||||
TEST_F(CopyInsertionTest, DynamicUpdateSliceParameterShareCopy) {
|
||||
absl::string_view hlo_string = R"(
|
||||
HloModule Module
|
||||
|
||||
ENTRY main {
|
||||
param = f32[1280,1,128] parameter(0)
|
||||
constant.1 = f32[] constant(0)
|
||||
broadcast.6 = f32[128,1,128] broadcast(constant.1), dimensions={}
|
||||
constant.3 = s32[] constant(0)
|
||||
ROOT dynamic-update-slice.5 = f32[1280,1,128] dynamic-update-slice(param, broadcast.6, constant.3, constant.3, constant.3)
|
||||
}
|
||||
)";
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
||||
ParseAndReturnVerifiedModule(hlo_string));
|
||||
InsertCopies(module.get());
|
||||
EXPECT_EQ(CountCopies(*module), 1);
|
||||
}
|
||||
|
||||
TEST_F(CopyInsertionTest, FusedDynamicUpdateSliceCopy) {
|
||||
absl::string_view hlo_string = R"(
|
||||
HloModule Module
|
||||
|
||||
fused_computation {
|
||||
param0 = f32[1280,1,128] parameter(0)
|
||||
constant.1 = f32[] constant(0)
|
||||
broadcast.6 = f32[128,1,128] broadcast(constant.1), dimensions={}
|
||||
constant.3 = s32[] constant(0)
|
||||
ROOT dynamic-update-slice.5 = f32[1280,1,128] dynamic-update-slice(param0, broadcast.6, constant.3, constant.3, constant.3)
|
||||
}
|
||||
|
||||
ENTRY main {
|
||||
param = f32[1280,1,128] parameter(0)
|
||||
negate = f32[1280,1,128] negate(param)
|
||||
add = f32[1280,1,128] add(negate, negate)
|
||||
fusion = f32[1280,1,128] fusion(negate), kind=kLoop, calls=fused_computation
|
||||
ROOT tuple = (f32[1280,1,128], f32[1280,1,128]) tuple(negate, fusion)
|
||||
}
|
||||
)";
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
||||
ParseAndReturnVerifiedModule(hlo_string));
|
||||
InsertCopies(module.get());
|
||||
EXPECT_EQ(CountCopies(*module), 1);
|
||||
}
|
||||
|
||||
TEST_F(CopyInsertionTest, ChainDynamicUpdateSliceCopy) {
|
||||
absl::string_view hlo_string = R"(
|
||||
HloModule Module
|
||||
|
||||
ENTRY main {
|
||||
state = (s32[], f32[1280,1,128]{2,1,0}) parameter(0)
|
||||
constant.1 = f32[] constant(0)
|
||||
broadcast.6 = f32[128,1,128]{2,1,0} broadcast(constant.1), dimensions={}
|
||||
get-tuple-element.4 = f32[1280,1,128]{2,1,0} get-tuple-element(state), index=1
|
||||
get-tuple-element.3 = s32[] get-tuple-element(state), index=0
|
||||
constant.2 = s32[] constant(128)
|
||||
add.5 = s32[] add(get-tuple-element.3, constant.2)
|
||||
constant.3 = s32[] constant(0)
|
||||
dynamic-update-slice.5 = f32[1280,1,128]{2,1,0} dynamic-update-slice(get-tuple-element.4, broadcast.6, constant.3, constant.3, constant.3)
|
||||
dynamic-update-slice.9 = f32[1280,1,128]{2,1,0} dynamic-update-slice(dynamic-update-slice.5, broadcast.6, constant.3, constant.3, constant.3)
|
||||
ROOT tuple.85 = (s32[], s32[], s32[2]{0}, f32[1280,1,128]{2,1,0}) tuple(add.5, dynamic-update-slice.9)
|
||||
}
|
||||
)";
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
||||
ParseAndReturnVerifiedModule(hlo_string));
|
||||
InsertCopies(module.get());
|
||||
EXPECT_EQ(CountCopies(*module), 1);
|
||||
}
|
||||
|
||||
TEST_F(CopyInsertionTest, FusedDynamicUpdateSliceCopy2) {
|
||||
absl::string_view hlo_string = R"(
|
||||
HloModule Module
|
||||
|
||||
fused_computation.1 {
|
||||
param0 = f32[1280,1,128] parameter(0)
|
||||
constant.1 = f32[] constant(0)
|
||||
broadcast.6 = f32[128,1,128] broadcast(constant.1), dimensions={}
|
||||
constant.3 = s32[] constant(0)
|
||||
ROOT dynamic-update-slice.5 = f32[1280,1,128] dynamic-update-slice(param0, broadcast.6, constant.3, constant.3, constant.3)
|
||||
}
|
||||
|
||||
fused_computation.2 {
|
||||
param0 = f32[1280,1,128] parameter(0)
|
||||
param1 = f32[1280,1,128] parameter(1)
|
||||
slice = f32[128,1,128] slice(param1), slice={[0:128], [0:1], [0:128]}
|
||||
constant.3 = s32[] constant(0)
|
||||
ROOT dynamic-update-slice.5 = f32[1280,1,128] dynamic-update-slice(param0, slice, constant.3, constant.3, constant.3)
|
||||
}
|
||||
|
||||
ENTRY main {
|
||||
param = f32[1280,1,128] parameter(0)
|
||||
negate = f32[1280,1,128] negate(param)
|
||||
add = f32[1280,1,128] add(negate, negate)
|
||||
fusion1 = f32[1280,1,128] fusion(negate), kind=kLoop, calls=fused_computation.1
|
||||
ROOT fusion2 = f32[1280,1,128] fusion(fusion1, negate), kind=kLoop, calls=fused_computation.2
|
||||
}
|
||||
)";
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
||||
ParseAndReturnVerifiedModule(hlo_string));
|
||||
InsertCopies(module.get());
|
||||
EXPECT_EQ(CountCopies(*module), 1);
|
||||
}
|
||||
|
||||
TEST_F(CopyInsertionTest, MultiOutputFusedDynamicUpdateSliceCopy) {
|
||||
// Tests multi-output fusion with two DUS outputs, requiring two copies.
|
||||
absl::string_view hlo_string = R"(
|
||||
HloModule Module
|
||||
|
||||
fused_computation {
|
||||
param0 = f32[1280,1,128] parameter(0)
|
||||
param1 = f32[1280,1,128] parameter(1)
|
||||
param2 = f32[1280,1,128] parameter(2)
|
||||
constant.1 = f32[] constant(0)
|
||||
broadcast.6 = f32[128,1,128] broadcast(constant.1), dimensions={}
|
||||
constant.3 = s32[] constant(0)
|
||||
add.1 = f32[1280,1,128] add(param0, param0)
|
||||
dynamic-update-slice.5 = f32[1280,1,128] dynamic-update-slice(param1, broadcast.6, constant.3, constant.3, constant.3)
|
||||
dynamic-update-slice.6 = f32[1280,1,128] dynamic-update-slice(param2, broadcast.6, constant.3, constant.3, constant.3)
|
||||
ROOT tuple.1 = (f32[1280,1,128], f32[1280,1,128], f32[1280,1,128]) tuple(add.1, dynamic-update-slice.5, dynamic-update-slice.6)
|
||||
}
|
||||
|
||||
ENTRY main {
|
||||
param = f32[1280,1,128] parameter(0)
|
||||
negate0 = f32[1280,1,128] negate(param)
|
||||
negate1 = f32[1280,1,128] negate(param)
|
||||
negate2 = f32[1280,1,128] negate(param)
|
||||
fusion = (f32[1280,1,128], f32[1280,1,128], f32[1280,1,128]) fusion(negate0, negate1, negate2), kind=kLoop, calls=fused_computation
|
||||
gte0 = f32[1280,1,128] get-tuple-element(fusion), index=0
|
||||
gte1 = f32[1280,1,128] get-tuple-element(fusion), index=1
|
||||
gte2 = f32[1280,1,128] get-tuple-element(fusion), index=2
|
||||
add0 = f32[1280,1,128] add(negate0, gte0)
|
||||
add1 = f32[1280,1,128] add(negate1, gte1)
|
||||
add2 = f32[1280,1,128] add(negate2, gte2)
|
||||
ROOT tuple = (f32[1280,1,128], f32[1280,1,128], f32[1280,1,128]) tuple(add0, add1, add2)
|
||||
}
|
||||
)";
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
||||
ParseAndReturnVerifiedModule(hlo_string));
|
||||
InsertCopies(module.get());
|
||||
EXPECT_EQ(CountCopies(*module), 2);
|
||||
}
|
||||
|
||||
TEST_F(CopyInsertionTest, MultiOutputFusedDynamicUpdateSliceNoCopy) {
|
||||
// Same as above, but negate1 is not used beyond fusion, so it only needs one
|
||||
// copy for negate2.
|
||||
absl::string_view hlo_string = R"(
|
||||
HloModule Module
|
||||
|
||||
fused_computation {
|
||||
param0 = f32[1280,1,128] parameter(0)
|
||||
param1 = f32[1280,1,128] parameter(1)
|
||||
param2 = f32[1280,1,128] parameter(2)
|
||||
constant.1 = f32[] constant(0)
|
||||
broadcast.6 = f32[128,1,128] broadcast(constant.1), dimensions={}
|
||||
constant.3 = s32[] constant(0)
|
||||
add.1 = f32[1280,1,128] add(param0, param0)
|
||||
dynamic-update-slice.5 = f32[1280,1,128] dynamic-update-slice(param1, broadcast.6, constant.3, constant.3, constant.3)
|
||||
dynamic-update-slice.6 = f32[1280,1,128] dynamic-update-slice(param2, broadcast.6, constant.3, constant.3, constant.3)
|
||||
ROOT tuple.1 = (f32[1280,1,128], f32[1280,1,128], f32[1280,1,128]) tuple(add.1, dynamic-update-slice.5, dynamic-update-slice.6)
|
||||
}
|
||||
|
||||
ENTRY main {
|
||||
param = f32[1280,1,128] parameter(0)
|
||||
negate0 = f32[1280,1,128] negate(param)
|
||||
negate1 = f32[1280,1,128] negate(param)
|
||||
negate2 = f32[1280,1,128] negate(param)
|
||||
fusion = (f32[1280,1,128], f32[1280,1,128], f32[1280,1,128]) fusion(negate0, negate1, negate2), kind=kLoop, calls=fused_computation
|
||||
gte0 = f32[1280,1,128] get-tuple-element(fusion), index=0
|
||||
gte1 = f32[1280,1,128] get-tuple-element(fusion), index=1
|
||||
gte2 = f32[1280,1,128] get-tuple-element(fusion), index=2
|
||||
add0 = f32[1280,1,128] add(negate0, gte0)
|
||||
add1 = f32[1280,1,128] add(gte1, gte1)
|
||||
add2 = f32[1280,1,128] add(negate2, gte2)
|
||||
ROOT tuple = (f32[1280,1,128], f32[1280,1,128], f32[1280,1,128]) tuple(add0, add1, add2)
|
||||
}
|
||||
)";
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
||||
ParseAndReturnVerifiedModule(hlo_string));
|
||||
InsertCopies(module.get());
|
||||
EXPECT_EQ(CountCopies(*module), 1);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace xla
|
||||
|
@ -1,6 +1,6 @@
|
||||
// RUN: hlo_to_llvm_ir %s | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: define void @scatter_TensorFlowScatterV1(i8* noalias align 64 dereferenceable(36) %alloc0, i8* noalias align 16 dereferenceable(36) %alloc1, i8* noalias align 16 dereferenceable(24) %alloc2, i8* noalias align 16 dereferenceable(8) %alloc3) {
|
||||
// CHECK-LABEL: define void @scatter_TensorFlowScatterV1(i8* noalias align 16 dereferenceable(36) %alloc0, i8* noalias align 16 dereferenceable(24) %alloc1, i8* noalias align 16 dereferenceable(8) %alloc2) {
|
||||
// CHECK: entry:
|
||||
// CHECK: %[[VAL_32:.*]] = alloca i32, align 4
|
||||
// CHECK: %[[VAL_0:.*]] = getelementptr inbounds i8, i8* %[[VAL_1:.*]], i64 0
|
||||
@ -43,8 +43,8 @@
|
||||
// CHECK: store atomic i32 %[[VAL_36]], i32* %[[VAL_31]] unordered, align 4
|
||||
// CHECK: br label %[[VAL_23]]
|
||||
// CHECK: !nvvm.annotations = !{!0, !1}
|
||||
// CHECK: !0 = !{void (i8*, i8*, i8*, i8*)* @scatter_TensorFlowScatterV1, !"kernel", i32 1}
|
||||
// CHECK: !1 = !{void (i8*, i8*, i8*, i8*)* @scatter_TensorFlowScatterV1, !"reqntidx", i32 6}
|
||||
// CHECK: !0 = !{void (i8*, i8*, i8*)* @scatter_TensorFlowScatterV1, !"kernel", i32 1}
|
||||
// CHECK: !1 = !{void (i8*, i8*, i8*)* @scatter_TensorFlowScatterV1, !"reqntidx", i32 6}
|
||||
// CHECK: !2 = !{i32 0, i32 1}
|
||||
// CHECK: !3 = !{i32 0, i32 6}
|
||||
// CHECK: !4 = !{}
|
||||
@ -72,7 +72,7 @@ ENTRY main {
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: define void @scatter_ScatterIntoScalar(i8* noalias align 64 dereferenceable(4) %alloc0, i8* noalias align 16 dereferenceable(4) %alloc1, i8* noalias align 16 dereferenceable(4) %alloc2, i8* noalias align 16 %alloc3) {
|
||||
// CHECK-LABEL: define void @scatter_ScatterIntoScalar(i8* noalias align 16 dereferenceable(4) %alloc0, i8* noalias align 16 dereferenceable(4) %alloc1, i8* noalias align 16 %alloc2) {
|
||||
// CHECK: entry:
|
||||
// CHECK: %[[VAL_60:.*]] = alloca i32, align 4
|
||||
// CHECK: %[[VAL_37:.*]] = getelementptr inbounds i8, i8* %[[VAL_38:.*]], i64 0
|
||||
@ -104,8 +104,8 @@ ENTRY main {
|
||||
// CHECK: store atomic i32 %[[VAL_62]], i32* %[[VAL_39]] unordered, align 4
|
||||
// CHECK: br label %[[VAL_57]]
|
||||
// CHECK: !nvvm.annotations = !{!0, !1}
|
||||
// CHECK: !0 = !{void (i8*, i8*, i8*, i8*)* @scatter_ScatterIntoScalar, !"kernel", i32 1}
|
||||
// CHECK: !1 = !{void (i8*, i8*, i8*, i8*)* @scatter_ScatterIntoScalar, !"reqntidx", i32 1}
|
||||
// CHECK: !0 = !{void (i8*, i8*, i8*)* @scatter_ScatterIntoScalar, !"kernel", i32 1}
|
||||
// CHECK: !1 = !{void (i8*, i8*, i8*)* @scatter_ScatterIntoScalar, !"reqntidx", i32 1}
|
||||
// CHECK: !2 = !{i32 0, i32 1}
|
||||
// CHECK: !3 = !{}
|
||||
|
||||
@ -131,7 +131,7 @@ ENTRY main {
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: define void @scatter_TensorFlowScatter_Mul(i8* noalias align 64 dereferenceable(36) %alloc0, i8* noalias align 16 dereferenceable(36) %alloc1, i8* noalias align 16 dereferenceable(24) %alloc2, i8* noalias align 16 dereferenceable(8) %alloc3) {
|
||||
// CHECK-LABEL: define void @scatter_TensorFlowScatter_Mul(i8* noalias align 16 dereferenceable(36) %alloc0, i8* noalias align 16 dereferenceable(24) %alloc1, i8* noalias align 16 dereferenceable(8) %alloc2) {
|
||||
// CHECK: %[[VAL_63:.*]] = alloca i32, align 4
|
||||
// CHECK: %[[VAL_64:.*]] = alloca i32, align 4
|
||||
// CHECK: %[[VAL_98:.*]] = alloca i32, align 4
|
||||
@ -188,8 +188,8 @@ ENTRY main {
|
||||
// CHECK: %[[VAL_109:.*]] = extractvalue { i32, i1 } %[[VAL_107]], 1
|
||||
// CHECK: br i1 %[[VAL_109]], label %[[VAL_96]], label %[[VAL_104]]
|
||||
// CHECK: !nvvm.annotations = !{!0, !1}
|
||||
// CHECK: !0 = !{void (i8*, i8*, i8*, i8*)* @scatter_TensorFlowScatter_Mul, !"kernel", i32 1}
|
||||
// CHECK: !1 = !{void (i8*, i8*, i8*, i8*)* @scatter_TensorFlowScatter_Mul, !"reqntidx", i32 6}
|
||||
// CHECK: !0 = !{void (i8*, i8*, i8*)* @scatter_TensorFlowScatter_Mul, !"kernel", i32 1}
|
||||
// CHECK: !1 = !{void (i8*, i8*, i8*)* @scatter_TensorFlowScatter_Mul, !"reqntidx", i32 6}
|
||||
// CHECK: !2 = !{i32 0, i32 1}
|
||||
// CHECK: !3 = !{i32 0, i32 6}
|
||||
// CHECK: !4 = !{}
|
||||
@ -216,7 +216,7 @@ ENTRY main {
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: define void @scatter_ScalarUpdate(i8* noalias align 64 dereferenceable(16) %alloc0, i8* noalias align 16 dereferenceable(16) %alloc1, i8* noalias align 16 dereferenceable(4) %alloc2, i8* noalias align 16 dereferenceable(4) %alloc3) {
|
||||
// CHECK-LABEL: define void @scatter_ScalarUpdate(i8* noalias align 16 dereferenceable(16) %alloc0, i8* noalias align 16 dereferenceable(4) %alloc1, i8* noalias align 16 dereferenceable(4) %alloc2) {
|
||||
// CHECK: entry:
|
||||
// CHECK: %[[VAL_146:.*]] = alloca i32, align 4
|
||||
// CHECK: %[[VAL_118:.*]] = getelementptr inbounds i8, i8* %[[VAL_119:.*]], i64 0
|
||||
@ -253,8 +253,8 @@ ENTRY main {
|
||||
// CHECK: store atomic i32 %[[VAL_148]], i32* %[[VAL_145]] unordered, align 4
|
||||
// CHECK: br label %[[VAL_138]]
|
||||
// CHECK: !nvvm.annotations = !{!0, !1}
|
||||
// CHECK: !0 = !{void (i8*, i8*, i8*, i8*)* @scatter_ScalarUpdate, !"kernel", i32 1}
|
||||
// CHECK: !1 = !{void (i8*, i8*, i8*, i8*)* @scatter_ScalarUpdate, !"reqntidx", i32 1}
|
||||
// CHECK: !0 = !{void (i8*, i8*, i8*)* @scatter_ScalarUpdate, !"kernel", i32 1}
|
||||
// CHECK: !1 = !{void (i8*, i8*, i8*)* @scatter_ScalarUpdate, !"reqntidx", i32 1}
|
||||
// CHECK: !2 = !{i32 0, i32 1}
|
||||
// CHECK: !3 = !{}
|
||||
|
||||
|
@ -308,6 +308,39 @@ class BufferValueMap {
|
||||
}
|
||||
}
|
||||
|
||||
void ComputeInPlaceOperationAliasedBuffers(
|
||||
const HloValue& value, std::vector<BufferNumber>* aliased_buffers) {
|
||||
VLOG(3) << "Compute aliases for in-place operations (e.g. "
|
||||
"kDynamicUpdateSlice and kScatter)";
|
||||
for (const HloPosition& position : value.positions()) {
|
||||
HloInstruction* instruction = position.instruction;
|
||||
for (const auto& operand_and_output_index :
|
||||
HloDataflowAnalysis::GetInPlaceInputOutputPairs(instruction)) {
|
||||
if (position.index == operand_and_output_index.second) {
|
||||
const HloUse& operand = operand_and_output_index.first;
|
||||
const HloValue& operand_value = dataflow_.GetUniqueValueAt(
|
||||
instruction->operand(operand.operand_number),
|
||||
operand.operand_index);
|
||||
VLOG(3) << " operand value " << operand_value.ToShortString()
|
||||
<< " aliases.";
|
||||
aliased_buffers->push_back(GetBufferForValue(operand_value));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (const HloUse& use : value.uses()) {
|
||||
for (const auto& operand_and_output_index :
|
||||
HloDataflowAnalysis::GetInPlaceInputOutputPairs(use.instruction)) {
|
||||
if (use == operand_and_output_index.first) {
|
||||
const HloValue& use_value = dataflow_.GetUniqueValueAt(
|
||||
use.instruction, operand_and_output_index.second);
|
||||
VLOG(3) << " use value " << use_value.ToShortString() << " aliases.";
|
||||
aliased_buffers->push_back(GetBufferForValue(use_value));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Compute and return a vector of buffers that the given value must be
|
||||
// contained in due to HLO aliasing rules.
|
||||
std::vector<BufferNumber> ComputeAliasedBuffers(const HloValue& value) {
|
||||
@ -318,6 +351,7 @@ class BufferValueMap {
|
||||
ComputeInputOutputAliasedBuffers(value, &aliased_buffers);
|
||||
ComputeWhileAliasedBuffers(value, &aliased_buffers);
|
||||
ComputeConditionalAliasedBuffers(value, &aliased_buffers);
|
||||
ComputeInPlaceOperationAliasedBuffers(value, &aliased_buffers);
|
||||
// Uniquify aliased buffers.
|
||||
absl::c_sort(aliased_buffers);
|
||||
aliased_buffers.erase(
|
||||
|
@ -1062,6 +1062,118 @@ TEST_F(HloAliasAnalysisTest, MergeBuffersReverse) {
|
||||
analysis.BufferLivesOut(analysis.buffers()[0]);
|
||||
}
|
||||
|
||||
TEST_F(HloAliasAnalysisTest, DynamicUpdateSlice) {
|
||||
Shape shape = ShapeUtil::MakeShape(F32, {8});
|
||||
Shape update_shape = ShapeUtil::MakeShape(F32, {4});
|
||||
Shape index_shape = ShapeUtil::MakeShape(S32, {});
|
||||
auto builder = HloComputation::Builder(TestName());
|
||||
auto param0 = builder.AddInstruction(
|
||||
HloInstruction::CreateParameter(0, shape, "param0"));
|
||||
auto param1 = builder.AddInstruction(
|
||||
HloInstruction::CreateParameter(1, update_shape, "param1"));
|
||||
auto param2 = builder.AddInstruction(
|
||||
HloInstruction::CreateParameter(2, index_shape, "param2"));
|
||||
auto copy0 = builder.AddInstruction(
|
||||
HloInstruction::CreateUnary(shape, HloOpcode::kCopy, param0));
|
||||
auto dynamic_update_slice = builder.AddInstruction(
|
||||
HloInstruction::CreateDynamicUpdateSlice(shape, copy0, param1, {param2}));
|
||||
|
||||
module_->AddEntryComputation(builder.Build());
|
||||
SCOPED_TRACE(module_->ToString());
|
||||
|
||||
HloAliasAnalysis& analysis = RunAnalysis();
|
||||
|
||||
EXPECT_EQ(analysis.GetUniqueBufferAt(copy0),
|
||||
analysis.GetUniqueBufferAt(dynamic_update_slice));
|
||||
}
|
||||
|
||||
// Checks alias analysis on a multi-output loop fusion whose root tuple holds
// one add and two dynamic-update-slices (DUS). Fusion operands negate1 and
// negate2 feed fused parameters 1 and 2, which are operand 0 of the two DUS
// ops, so they are expected to share a buffer with fusion outputs {1} and {2}.
// negate0 only feeds the add, so it must not share a buffer with output {0}.
TEST_F(HloAliasAnalysisTest, DynamicUpdateSliceMultiOutputFusion) {
|
||||
absl::string_view hlo_string = R"(
|
||||
HloModule Module
|
||||
|
||||
fused_computation {
|
||||
param0 = f32[1280,1,128] parameter(0)
|
||||
param1 = f32[1280,1,128] parameter(1)
|
||||
param2 = f32[1280,1,128] parameter(2)
|
||||
constant.1 = f32[] constant(0)
|
||||
broadcast.6 = f32[128,1,128] broadcast(constant.1), dimensions={}
|
||||
constant.3 = s32[] constant(0)
|
||||
add.1 = f32[1280,1,128] add(param0, param0)
|
||||
dynamic-update-slice.5 = f32[1280,1,128] dynamic-update-slice(param1, broadcast.6, constant.3, constant.3, constant.3)
|
||||
dynamic-update-slice.6 = f32[1280,1,128] dynamic-update-slice(param2, broadcast.6, constant.3, constant.3, constant.3)
|
||||
ROOT tuple.1 = (f32[1280,1,128], f32[1280,1,128], f32[1280,1,128]) tuple(add.1, dynamic-update-slice.5, dynamic-update-slice.6)
|
||||
}
|
||||
|
||||
ENTRY main {
|
||||
param = f32[1280,1,128] parameter(0)
|
||||
negate0 = f32[1280,1,128] negate(param)
|
||||
negate1 = f32[1280,1,128] negate(param)
|
||||
negate2 = f32[1280,1,128] negate(param)
|
||||
ROOT fusion = (f32[1280,1,128], f32[1280,1,128], f32[1280,1,128]) fusion(negate0, negate1, negate2), kind=kLoop, calls=fused_computation
|
||||
}
|
||||
)";
|
||||
TF_ASSERT_OK_AND_ASSIGN(module_, ParseAndReturnVerifiedModule(hlo_string));
|
||||
|
||||
// Attach the module text to any failure reported below.
SCOPED_TRACE(module_->ToString());
|
||||
|
||||
HloAliasAnalysis& analysis = RunAnalysis();
|
||||
LOG(INFO) << analysis.ToString();
|
||||
|
||||
// Expect negate1 and negate2 to alias with fusion{1} and fusion{2}
|
||||
// respectively (due to DUS), but not negate0 and fusion{0} (an add output).
|
||||
const HloInstruction* fusion =
|
||||
module_->entry_computation()->GetInstructionWithName("fusion");
|
||||
const HloInstruction* negate0 =
|
||||
module_->entry_computation()->GetInstructionWithName("negate0");
|
||||
const HloInstruction* negate1 =
|
||||
module_->entry_computation()->GetInstructionWithName("negate1");
|
||||
const HloInstruction* negate2 =
|
||||
module_->entry_computation()->GetInstructionWithName("negate2");
|
||||
EXPECT_EQ(analysis.GetUniqueBufferAt(negate1),
|
||||
analysis.GetUniqueBufferAt(fusion, {1}));
|
||||
EXPECT_EQ(analysis.GetUniqueBufferAt(negate2),
|
||||
analysis.GetUniqueBufferAt(fusion, {2}));
|
||||
EXPECT_NE(analysis.GetUniqueBufferAt(negate0),
|
||||
analysis.GetUniqueBufferAt(fusion, {0}));
|
||||
}
|
||||
|
||||
// Checks that a fusion whose root dynamic-update-slice takes another
// dynamic-update-slice (rather than a fused parameter) as operand 0 is not
// treated as in-place: the fusion operand and the fusion output must end up in
// different buffers.
TEST_F(HloAliasAnalysisTest, ChainedDynamicUpdateSliceFusion) {
|
||||
// CPU and GPU backends may generate fusions with dynamic update slices
|
||||
// feeding each other. They expect the fusion to not be in-place if that is
|
||||
// the case.
|
||||
absl::string_view hlo_string = R"(
|
||||
HloModule Module
|
||||
|
||||
fused_computation {
|
||||
param0 = f32[1280,1,128] parameter(0)
|
||||
constant.1 = f32[] constant(0)
|
||||
broadcast.6 = f32[128,1,128] broadcast(constant.1), dimensions={}
|
||||
constant.3 = s32[] constant(0)
|
||||
dynamic-update-slice.5 = f32[1280,1,128] dynamic-update-slice(param0, broadcast.6, constant.3, constant.3, constant.3)
|
||||
ROOT dynamic-update-slice.6 = f32[1280,1,128] dynamic-update-slice(dynamic-update-slice.5, broadcast.6, constant.3, constant.3, constant.3)
|
||||
}
|
||||
|
||||
ENTRY main {
|
||||
param = f32[1280,1,128] parameter(0)
|
||||
negate0 = f32[1280,1,128] negate(param)
|
||||
ROOT fusion = f32[1280,1,128] fusion(negate0), kind=kLoop, calls=fused_computation
|
||||
}
|
||||
)";
|
||||
TF_ASSERT_OK_AND_ASSIGN(module_, ParseAndReturnVerifiedModule(hlo_string));
|
||||
|
||||
// Attach the module text to any failure reported below.
SCOPED_TRACE(module_->ToString());
|
||||
|
||||
HloAliasAnalysis& analysis = RunAnalysis();
|
||||
LOG(INFO) << analysis.ToString();
|
||||
|
||||
const HloInstruction* fusion =
|
||||
module_->entry_computation()->GetInstructionWithName("fusion");
|
||||
const HloInstruction* negate0 =
|
||||
module_->entry_computation()->GetInstructionWithName("negate0");
|
||||
// The fusion operand and the fusion result must not share a buffer.
EXPECT_NE(analysis.GetUniqueBufferAt(negate0),
|
||||
analysis.GetUniqueBufferAt(fusion));
|
||||
}
|
||||
|
||||
TEST_F(HloAliasAnalysisTest, BitcastInterference) {
|
||||
// A bitcast value simultaneously live with its operand should not cause
|
||||
// interference.
|
||||
|
@ -1178,69 +1178,49 @@ bool HloDataflowAnalysis::DoesNotUseOperandBuffer(
|
||||
return true;
|
||||
}
|
||||
|
||||
// Given a fusion whose root is a dynamic-update-slice op, determines whether
|
||||
// the fusion's output buffer can be shared with the buffer of fusion_param,
|
||||
// which must be a fused parameter of the fusion.
|
||||
//
|
||||
// Preconditions:
|
||||
//
|
||||
// - fusion's root is a dynamic-update-slice op.
|
||||
// - fusion_param is a parameter within the fusion.
|
||||
//
|
||||
// fusion_param may point to a subelement of the actual parameter instruction if
|
||||
// the param is a tuple; i.e. fusion_param->index() need not be the empty list.
|
||||
//
|
||||
// Returns true if:
|
||||
//
|
||||
// * fusion_param is used by the root of dynamic-update-slice as the "base" of
|
||||
// the update, i.e. the thing being updated, AND
|
||||
// * all other uses of fusion_param are dynamic-slices that slice the same
|
||||
// indices as are overwritten in the dynamic-update-slice.
|
||||
//
|
||||
// In the case that there are no other uses of fusion_param (last bullet point
|
||||
// is vacuously true) it's easy to see why an in-place DUS is safe; this is just
|
||||
// the "natural" implementation of DUS. If there are other users, in-place DUS
|
||||
// is safe on the assumption that the thread which writes element i of the
|
||||
// output will be the only one to read element i of fusion_param (via the
|
||||
// dynamic-slice ops).
|
||||
static bool CanDoInPlaceDynamicUpdateSlice(HloInstruction* fusion,
|
||||
const HloValue& fusion_param_value) {
|
||||
auto* root =
|
||||
Cast<HloDynamicUpdateSliceInstruction>(fusion->fused_expression_root());
|
||||
auto* fusion_param = fusion_param_value.instruction();
|
||||
CHECK_EQ(fusion_param->opcode(), HloOpcode::kParameter);
|
||||
CHECK_EQ(fusion_param->parent(), fusion->fused_instructions_computation());
|
||||
/*static*/ bool HloDataflowAnalysis::IsInPlaceOperation(HloOpcode opcode) {
|
||||
return opcode == HloOpcode::kDynamicUpdateSlice ||
|
||||
opcode == HloOpcode::kScatter;
|
||||
}
|
||||
|
||||
// fusion_param must be used by the root as the "base" of the
|
||||
// dynamic-update-slice. The natural way to check this would be
|
||||
//
|
||||
// `if (root->operand(0) != fusion_param)`
|
||||
//
|
||||
// but we also have to handle the case where the fusion parameter is
|
||||
// tuple-shaped and we're considering just one element of that tuple, i.e.
|
||||
// fusion_param.index() != {}.
|
||||
if (absl::c_count_if(fusion_param_value.uses(), [&](const HloUse& use) {
|
||||
return use.instruction == root;
|
||||
}) != 1) {
|
||||
return false;
|
||||
/*static*/ std::vector<std::pair<HloUse, ShapeIndex>>
|
||||
HloDataflowAnalysis::GetInPlaceInputOutputPairs(HloInstruction* instruction) {
|
||||
if (IsInPlaceOperation(instruction->opcode())) {
|
||||
return {{HloUse{instruction, 0, {}}, {}}};
|
||||
} else if (instruction->opcode() != HloOpcode::kFusion) {
|
||||
return {};
|
||||
}
|
||||
|
||||
// All other uses of fusion_param must be dynamic-slices that slice the same
|
||||
// indices as are overwritten by the dynamic-update-slice.
|
||||
for (const HloUse& use : fusion_param_value.uses()) {
|
||||
auto* user = use.instruction;
|
||||
if (user == root) {
|
||||
continue;
|
||||
std::vector<std::pair<HloUse, ShapeIndex>> input_output_pairs;
|
||||
for (auto& indexed_shape : ShapeUtil::GetLeafShapes(instruction->shape())) {
|
||||
const HloInstruction* hlo_generating_output =
|
||||
instruction->fused_expression_root();
|
||||
for (int64 i = 0; i < indexed_shape.index.size(); ++i) {
|
||||
if (hlo_generating_output->opcode() == HloOpcode::kTuple) {
|
||||
hlo_generating_output =
|
||||
hlo_generating_output->operand(indexed_shape.index[i]);
|
||||
} else {
|
||||
CHECK_EQ(i, indexed_shape.index.size() - 1);
|
||||
}
|
||||
}
|
||||
|
||||
// Check that `user` is a dynamic-slice op and has the same slice indices as
|
||||
// `root`.
|
||||
auto* ds = DynCast<HloDynamicSliceInstruction>(user);
|
||||
if (!ds || ds->index_operands() != root->index_operands()) {
|
||||
return false;
|
||||
if (IsInPlaceOperation(hlo_generating_output->opcode())) {
|
||||
ShapeIndex operand_index;
|
||||
const HloInstruction* fusion_parameter =
|
||||
hlo_generating_output->operand(0);
|
||||
while (fusion_parameter->opcode() == HloOpcode::kGetTupleElement) {
|
||||
operand_index.push_front(fusion_parameter->tuple_index());
|
||||
fusion_parameter = fusion_parameter->operand(0);
|
||||
}
|
||||
|
||||
if (fusion_parameter->opcode() == HloOpcode::kParameter) {
|
||||
input_output_pairs.emplace_back(
|
||||
HloUse{instruction, fusion_parameter->parameter_number(),
|
||||
operand_index},
|
||||
indexed_shape.index);
|
||||
}
|
||||
}
|
||||
}
|
||||
return true;
|
||||
return input_output_pairs;
|
||||
}
|
||||
|
||||
bool HloDataflowAnalysis::CanShareOperandBufferWithUser(
|
||||
@ -1261,24 +1241,17 @@ bool HloDataflowAnalysis::CanShareOperandBufferWithUser(
|
||||
return false;
|
||||
}
|
||||
|
||||
if (user->opcode() == HloOpcode::kFusion) {
|
||||
// Get the parameter associated with 'operand';
|
||||
HloInstruction* fusion_param =
|
||||
user->fused_parameter(user->operand_index(operand));
|
||||
|
||||
const HloValue& fusion_param_value =
|
||||
GetValueDefinedAt(fusion_param, operand_index);
|
||||
|
||||
// TODO(b/80315712): This code is in a bit of a weird intermediate state
|
||||
// at the moment. The in-place DUS check really needs to be common to all
|
||||
// backends, so it runs first. Then we run the backend-specific check if
|
||||
// provided, or go through the target-independent check if not.
|
||||
// Unfortunately, the notionally "target-independent" path actually contains
|
||||
// some target-specific code, so we can't run all of it *in addition* to the
|
||||
// target-specific function, like the interface documentation says.
|
||||
if (user->fused_expression_root()->opcode() ==
|
||||
HloOpcode::kDynamicUpdateSlice) {
|
||||
return CanDoInPlaceDynamicUpdateSlice(user, fusion_param_value);
|
||||
// Must-alias relationship returns true for in-place operations (DUS and DUS
|
||||
// fusions), regardless of the backend.
|
||||
for (const auto& operand_and_output_index :
|
||||
GetInPlaceInputOutputPairs(user)) {
|
||||
if (operand_and_output_index.second != user_index) {
|
||||
continue;
|
||||
}
|
||||
for (const HloUse& use : GetUniqueValueAt(operand, operand_index).uses()) {
|
||||
if (use == operand_and_output_index.first) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -49,6 +49,9 @@ class HloDataflowAnalysis {
|
||||
// Infrastructure for passing may-alias hints: HLO passes can populate the
|
||||
// may-alias table. If an empty optional is returned, default rules are used.
|
||||
//
|
||||
// Must-alias rules (as defined by GetInPlaceInputOutputPairs) cannot be
|
||||
// overridden using backend-specific overrides.
|
||||
//
|
||||
// The first parameter of the function should be the instruction, the
|
||||
// second parameter should be an operand of the instruction. The third
|
||||
// parameter should be the output index of the instruction.
|
||||
@ -160,6 +163,15 @@ class HloDataflowAnalysis {
|
||||
|
||||
const HloModule& module() const { return module_; }
|
||||
|
||||
// Returns true if the operation is an in-place operation and its operand 0
|
||||
// must alias with the output.
|
||||
static bool IsInPlaceOperation(HloOpcode opcode);
|
||||
|
||||
// Returns a vector consisting of the HloUse (operand number and shape index)
|
||||
// and output shape index of the in-place operations within this HLO.
|
||||
static std::vector<std::pair<HloUse, ShapeIndex>> GetInPlaceInputOutputPairs(
|
||||
HloInstruction* instruction);
|
||||
|
||||
protected:
|
||||
HloDataflowAnalysis(const HloModule& module, bool ssa_form,
|
||||
bool bitcast_defines_value = false,
|
||||
|
@ -2324,36 +2324,6 @@ TEST_F(CanShareOperandBufferWithUserTest,
|
||||
dataflow_analysis_->CanShareOperandBufferWithUser(param, {}, fusion, {}));
|
||||
}
|
||||
|
||||
// Checks CanShareOperandBufferWithUser on a fusion whose root is a
// dynamic-update-slice of p0 and where p0 is also read by a dynamic-slice.
// The dynamic-slice indexes with (p1, p2, p3) while the dynamic-update-slice
// indexes with (p1, p3, p2), so the test expects that sharing the operand
// buffer with the fusion output is rejected.
TEST_F(CanShareOperandBufferWithUserTest, DUSWithSliceWithDifferentIndices) {
|
||||
const char* kModule = R"(
|
||||
HloModule test
|
||||
|
||||
fused_computation {
|
||||
p0 = f32[10,20,30] parameter(0)
|
||||
p1 = s32[] parameter(1)
|
||||
p2 = s32[] parameter(2)
|
||||
p3 = s32[] parameter(3)
|
||||
slice = f32[1,1,30] dynamic-slice(p0, p1, p2, p3), dynamic_slice_sizes={1,1,30}
|
||||
ROOT dus = f32[10,20,30] dynamic-update-slice(p0, slice, p1, p3, p2)
|
||||
}
|
||||
|
||||
ENTRY test {
|
||||
p0 = f32[10,20,30] parameter(0)
|
||||
p1 = s32[] parameter(1)
|
||||
p2 = s32[] parameter(2)
|
||||
p3 = s32[] parameter(3)
|
||||
ROOT fusion = f32[10,20,30] fusion(p0, p1, p2, p3), kind=kLoop, calls=fused_computation
|
||||
}
|
||||
)";
|
||||
TF_ASSERT_OK_AND_ASSIGN(module_, ParseAndReturnVerifiedModule(kModule));
|
||||
// The entry root is the fusion; entry parameter 0 is the array being updated.
auto* fusion = module_->entry_computation()->root_instruction();
|
||||
auto* param = module_->entry_computation()->parameter_instruction(0);
|
||||
|
||||
RunAnalysis();
|
||||
EXPECT_FALSE(
|
||||
dataflow_analysis_->CanShareOperandBufferWithUser(param, {}, fusion, {}));
|
||||
}
|
||||
|
||||
TEST_F(CanShareOperandBufferWithUserTest, DUSWithSliceWithSameIndices) {
|
||||
const char* kModule = R"(
|
||||
HloModule test
|
||||
|
@ -232,6 +232,24 @@ class MemorySpaceAssignmentTest : public HloTestBase,
|
||||
return copies;
|
||||
}
|
||||
|
||||
// Looks up the chunk offset that `preset_assignments` records for the buffer
// holding `instruction` at shape index `index` (defaults to the top-level
// index).
//
// Alias analysis is re-run on the instruction's module on every call. The
// buffer found at (instruction, index) may hold several values; the first
// entry of preset_assignments.chunks() whose position equals the defining
// position of any of those values supplies the returned offset. Returns -1
// when no entry matches.
// NOTE(review): ValueOrDie() presumably aborts if alias analysis fails --
// confirm this is acceptable for the tests using this helper.
int64 GetAlternateMemoryOffset(const PresetAssignments& preset_assignments,
|
||||
const HloInstruction* instruction,
|
||||
const ShapeIndex& index = {}) const {
|
||||
// Returns the offset of the assignment, -1 if it's not in the alternate
|
||||
// memory.
|
||||
const HloModule* module = instruction->parent()->parent();
|
||||
auto alias_analysis = HloAliasAnalysis::Run(module).ValueOrDie();
|
||||
HloBuffer& buffer = alias_analysis->GetUniqueBufferAt(instruction, index);
|
||||
// Outer loop walks the preset chunks, inner loop walks the buffer's values;
// the first (chunk, value) pair with matching positions wins.
for (auto& pos_and_chunk : preset_assignments.chunks()) {
|
||||
for (auto& value : buffer.values()) {
|
||||
if (pos_and_chunk.first == value->defining_position()) {
|
||||
return pos_and_chunk.second.offset;
|
||||
}
|
||||
}
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
|
||||
std::unique_ptr<HloModule> CreateEvictAndPrefetchModule() {
|
||||
HloComputation::Builder builder(TestName());
|
||||
Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
|
||||
@ -4415,6 +4433,47 @@ TEST_P(MemorySpaceAssignmentTest, Determinism) {
|
||||
}
|
||||
}
|
||||
|
||||
// Checks that the operand and the result of an in-place fusion (root is a
// dynamic-update-slice of the fused parameter) are given the same offset by
// memory space assignment, as reported by GetAlternateMemoryOffset.
// NOTE(review): AssignMemorySpace is a fixture helper not visible here;
// presumably it runs memory space assignment on the module and returns the
// resulting preset assignments -- confirm.
TEST_P(MemorySpaceAssignmentTest, InPlaceOp) {
|
||||
// Tests that in-place ops like DynamicUpdateSlice get the same allocation as
|
||||
// their input.
|
||||
absl::string_view hlo_string = R"(
|
||||
HloModule Module, is_scheduled=true
|
||||
|
||||
fused_computation {
|
||||
param0 = f32[2,3] parameter(0)
|
||||
constant.1 = f32[] constant(0)
|
||||
broadcast = f32[2,1] broadcast(constant.1), dimensions={}
|
||||
constant.3 = s32[] constant(0)
|
||||
ROOT dynamic-update-slice.5 = f32[2,3] dynamic-update-slice(param0, broadcast, constant.3, constant.3)
|
||||
}
|
||||
|
||||
ENTRY main {
|
||||
param = f32[2,3] parameter(0)
|
||||
negate = f32[2,3] negate(param)
|
||||
fusion = f32[2,3] fusion(negate), kind=kLoop, calls=fused_computation
|
||||
ROOT add = f32[2,3] add(fusion, fusion)
|
||||
}
|
||||
)";
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
||||
ParseAndReturnVerifiedModule(hlo_string));
|
||||
auto preset_assignments = AssignMemorySpace(module.get());
|
||||
HloInstruction* negate_instruction =
|
||||
module->entry_computation()->GetInstructionWithName("negate");
|
||||
int64 negate_offset =
|
||||
GetAlternateMemoryOffset(*preset_assignments, negate_instruction);
|
||||
HloInstruction* fusion_instruction =
|
||||
module->entry_computation()->GetInstructionWithName("fusion");
|
||||
int64 fusion_offset =
|
||||
GetAlternateMemoryOffset(*preset_assignments, fusion_instruction);
|
||||
// We expect negate and fusion to get the same offsets.
|
||||
EXPECT_EQ(negate_offset, fusion_offset);
|
||||
// The equality above also holds when both offsets are -1 (no preset chunk
// found), so for the `true` test parameter additionally require that a
// chunk was actually found for negate.
const bool allocate_across_sequential_calls = GetParam();
|
||||
if (allocate_across_sequential_calls) {
|
||||
EXPECT_NE(negate_offset, -1);
|
||||
}
|
||||
}
|
||||
|
||||
// Instantiates every MemorySpaceAssignmentTest TEST_P case once per boolean
// parameter value (false, then true). InPlaceOp reads this parameter as
// `allocate_across_sequential_calls`.
INSTANTIATE_TEST_SUITE_P(MemorySpaceAssignmentInstantiation,
|
||||
MemorySpaceAssignmentTest,
|
||||
::testing::Values(false, true));
|
||||
|
Loading…
x
Reference in New Issue
Block a user