Internal change
PiperOrigin-RevId: 331353990 Change-Id: I66751ebf00239baa016b39d2a8e6d4b2c31d4b48
This commit is contained in:
parent
e49b2326ad
commit
d33c01b88a
@ -1007,6 +1007,102 @@ bool BufferAssigner::MaybeAssignBuffer(BufferAllocation* allocation,
|
|||||||
return true;
|
return true;
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
|
||||||
|
Status BufferAssigner::MergeInplaceOpBuffers(BufferAssignment* assignment) {
|
||||||
|
// Try allocate same buffer for dynamic update slice's operand and output.
|
||||||
|
|
||||||
|
// If memory_space_assignment is run and there is information about a color in
|
||||||
|
// preset assignments, don't merge those buffers. We expect
|
||||||
|
// memory_space_assignment to have merged these buffers. If
|
||||||
|
// memory_space_assignment didn't merge these buffers and have assigned
|
||||||
|
// different offsets to the operand and the output buffer, merging the buffers
|
||||||
|
// can cause memory corruption if memory_space_assignment assigned a different
|
||||||
|
// buffer at the same offset.
|
||||||
|
absl::flat_hash_set<int64> excluded_colors;
|
||||||
|
if (preset_assignments_) {
|
||||||
|
for (const auto& color_and_info :
|
||||||
|
preset_assignments_->assignment_informations()) {
|
||||||
|
excluded_colors.insert(color_and_info.first);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO(yunxing): Moving this logic to alias analysis and add must-alias rule
|
||||||
|
// to operations that can be done in place.
|
||||||
|
for (HloComputation* computation : assignment->module().computations()) {
|
||||||
|
for (HloInstruction* instruction : computation->instructions()) {
|
||||||
|
if (!(instruction->opcode() == HloOpcode::kDynamicUpdateSlice ||
|
||||||
|
(instruction->opcode() == HloOpcode::kFusion &&
|
||||||
|
(instruction->fused_expression_root()->opcode() ==
|
||||||
|
HloOpcode::kDynamicUpdateSlice)))) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (instruction->parent()->IsFusionComputation()) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (instruction->operand_count() == 0) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// The operand can't share the same buffer with the user based on dataflow
|
||||||
|
// analysis.
|
||||||
|
if (!assignment->dataflow_analysis().CanShareOperandBufferWithUser(
|
||||||
|
instruction->mutable_operand(0), {}, instruction, {})) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
HloBuffer& instruction_buffer =
|
||||||
|
assignment->alias_analysis().GetUniqueBufferAt(instruction, {});
|
||||||
|
|
||||||
|
HloBuffer& operand_buffer =
|
||||||
|
assignment->alias_analysis().GetUniqueBufferAt(
|
||||||
|
instruction->operand(0), {});
|
||||||
|
|
||||||
|
// The instruction or operand color is excluded because it was assigned by
|
||||||
|
// memory_space_assignment.
|
||||||
|
if (excluded_colors.contains(instruction_buffer.color()) ||
|
||||||
|
excluded_colors.contains(operand_buffer.color())) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Already have the same buffer. No need to merge those.
|
||||||
|
if (instruction_buffer.id() == operand_buffer.id()) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Do not perform in-place dynamic update slice if the operand buffer is
|
||||||
|
// read-only.
|
||||||
|
if (HloBufferIsReadOnly(operand_buffer)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool interfere = false;
|
||||||
|
|
||||||
|
for (const HloValue* instruction_value : instruction_buffer.values()) {
|
||||||
|
for (const HloValue* operand_value : operand_buffer.values()) {
|
||||||
|
if (assignment->hlo_ordering().MayInterfere(
|
||||||
|
*instruction_value, *operand_value,
|
||||||
|
assignment->dataflow_analysis())) {
|
||||||
|
interfere = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (interfere) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (assignment->alias_analysis().BufferLivesOut(instruction_buffer)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if (instruction_buffer.color() != operand_buffer.color()) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
VLOG(3) << "Merging inplace " << instruction_buffer << " and "
|
||||||
|
<< operand_buffer;
|
||||||
|
assignment->alias_analysis().MergeBuffers(instruction_buffer,
|
||||||
|
operand_buffer);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
Status BufferAssigner::AssignSingleHloBuffer(
|
Status BufferAssigner::AssignSingleHloBuffer(
|
||||||
const HloBuffer* hlo_buffer, bool is_thread_local,
|
const HloBuffer* hlo_buffer, bool is_thread_local,
|
||||||
absl::flat_hash_map<const HloComputation*,
|
absl::flat_hash_map<const HloComputation*,
|
||||||
@ -1558,6 +1654,7 @@ StatusOr<std::unique_ptr<BufferAssignment>> BufferAssigner::CreateAssignment(
|
|||||||
VLOG(3) << "After coloring:";
|
VLOG(3) << "After coloring:";
|
||||||
XLA_VLOG_LINES(3,
|
XLA_VLOG_LINES(3,
|
||||||
assignment->alias_analysis().dataflow_analysis().ToString());
|
assignment->alias_analysis().dataflow_analysis().ToString());
|
||||||
|
TF_RETURN_IF_ERROR(MergeInplaceOpBuffers(assignment.get()));
|
||||||
|
|
||||||
std::vector<const HloComputation*> thread_local_computations;
|
std::vector<const HloComputation*> thread_local_computations;
|
||||||
std::vector<const HloComputation*> global_computations;
|
std::vector<const HloComputation*> global_computations;
|
||||||
|
|||||||
@ -635,6 +635,10 @@ class BufferAssigner {
|
|||||||
absl::flat_hash_set<const HloBuffer*>* assigned_buffers,
|
absl::flat_hash_set<const HloBuffer*>* assigned_buffers,
|
||||||
BufferAssignment* assignment);
|
BufferAssignment* assignment);
|
||||||
|
|
||||||
|
// Promotes operations (DUS, scatter) to be done in place: If an operation can
|
||||||
|
// be done in place, merge its buffer with its operand buffer.
|
||||||
|
Status MergeInplaceOpBuffers(BufferAssignment* assignment);
|
||||||
|
|
||||||
// Assigns a single hlo buffer to an HLO allocation.
|
// Assigns a single hlo buffer to an HLO allocation.
|
||||||
Status AssignSingleHloBuffer(
|
Status AssignSingleHloBuffer(
|
||||||
const HloBuffer* hlo_buffer, bool is_thread_local,
|
const HloBuffer* hlo_buffer, bool is_thread_local,
|
||||||
|
|||||||
@ -1925,10 +1925,8 @@ ENTRY main {
|
|||||||
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_text));
|
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_text));
|
||||||
HloInstruction* parameter =
|
HloInstruction* parameter =
|
||||||
m->entry_computation()->GetInstructionWithName("get-tuple-element.4");
|
m->entry_computation()->GetInstructionWithName("get-tuple-element.4");
|
||||||
HloInstruction* dus1 =
|
HloInstruction* dus =
|
||||||
m->entry_computation()->GetInstructionWithName("dynamic-update-slice.5");
|
m->entry_computation()->GetInstructionWithName("dynamic-update-slice.5");
|
||||||
HloInstruction* dus2 =
|
|
||||||
m->entry_computation()->GetInstructionWithName("dynamic-update-slice.9");
|
|
||||||
|
|
||||||
auto buffers = RunBufferAssignment(m.get());
|
auto buffers = RunBufferAssignment(m.get());
|
||||||
|
|
||||||
@ -1936,10 +1934,8 @@ ENTRY main {
|
|||||||
const BufferAllocation& parameter_alloc =
|
const BufferAllocation& parameter_alloc =
|
||||||
GetTopLevelAllocation(*buffers, parameter);
|
GetTopLevelAllocation(*buffers, parameter);
|
||||||
|
|
||||||
const BufferAllocation& dus1_alloc = GetTopLevelAllocation(*buffers, dus1);
|
const BufferAllocation& dus_alloc = GetTopLevelAllocation(*buffers, dus);
|
||||||
EXPECT_EQ(parameter_alloc, dus1_alloc);
|
EXPECT_NE(parameter_alloc, dus_alloc);
|
||||||
const BufferAllocation& dus2_alloc = GetTopLevelAllocation(*buffers, dus2);
|
|
||||||
EXPECT_EQ(parameter_alloc, dus2_alloc);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -362,19 +362,6 @@ Status AddCopiesForConditional(const HloAliasAnalysis& alias_analysis,
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add copies for the operands of in-place operations. RemoveUnnecessaryCopies
|
|
||||||
// will remove the unnecessary copies.
|
|
||||||
Status AddCopiesForInPlaceOperation(const HloAliasAnalysis& alias_analysis,
|
|
||||||
HloInstruction* in_place_op,
|
|
||||||
int64 operand_number) {
|
|
||||||
VLOG(2) << "Adding copies for in-place operation " << in_place_op->name();
|
|
||||||
HloInstruction* operand = in_place_op->mutable_operand(operand_number);
|
|
||||||
TF_ASSIGN_OR_RETURN(HloInstruction * deep_copy,
|
|
||||||
in_place_op->parent()->DeepCopyInstruction(operand));
|
|
||||||
TF_RETURN_IF_ERROR(operand->ReplaceUseWith(in_place_op, deep_copy));
|
|
||||||
return Status::OK();
|
|
||||||
}
|
|
||||||
|
|
||||||
// Conservatively adds copies before root instruction of entry computation and
|
// Conservatively adds copies before root instruction of entry computation and
|
||||||
// each aliased parameter to resolve interference of aliased input and output
|
// each aliased parameter to resolve interference of aliased input and output
|
||||||
// buffer. We later rely on RemoveUnnecessaryCopies to drop the unnecessary
|
// buffer. We later rely on RemoveUnnecessaryCopies to drop the unnecessary
|
||||||
@ -522,12 +509,6 @@ class CopyRemover {
|
|||||||
// value. The map is used to construct the copy info map below.
|
// value. The map is used to construct the copy info map below.
|
||||||
absl::flat_hash_map<const HloValue*, ValueNode*> value_to_node;
|
absl::flat_hash_map<const HloValue*, ValueNode*> value_to_node;
|
||||||
for (const HloBuffer& buffer : alias_analysis.buffers()) {
|
for (const HloBuffer& buffer : alias_analysis.buffers()) {
|
||||||
// No copies should have been inserted within fused computations, so no
|
|
||||||
// need to remove them. HloOrdering isn't compatible with HloValues inside
|
|
||||||
// fusions, so skip copy removal for them.
|
|
||||||
if (buffer.values().at(0)->defining_instruction()->IsFused()) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
// Verify values contained in the buffer are strictly ordered. This
|
// Verify values contained in the buffer are strictly ordered. This
|
||||||
// should always be the case after adding copies to eliminate
|
// should always be the case after adding copies to eliminate
|
||||||
// interference. Specifically, the addition of the control flow edges
|
// interference. Specifically, the addition of the control flow edges
|
||||||
@ -610,7 +591,7 @@ class CopyRemover {
|
|||||||
void CreateCopyMap(
|
void CreateCopyMap(
|
||||||
const HloModule& module,
|
const HloModule& module,
|
||||||
const absl::flat_hash_map<const HloValue*, ValueNode*>& value_to_node) {
|
const absl::flat_hash_map<const HloValue*, ValueNode*>& value_to_node) {
|
||||||
for (HloComputation* computation : module.MakeNonfusionComputations()) {
|
for (HloComputation* computation : module.computations()) {
|
||||||
for (HloInstruction* instruction : computation->instructions()) {
|
for (HloInstruction* instruction : computation->instructions()) {
|
||||||
// Add copies with unambiguous source values to the map. Copies with
|
// Add copies with unambiguous source values to the map. Copies with
|
||||||
// ambiguous sources are not removable.
|
// ambiguous sources are not removable.
|
||||||
@ -1024,7 +1005,7 @@ Status CopyInsertion::AddCopiesToResolveInterference(HloModule* module) {
|
|||||||
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
|
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
|
||||||
HloAliasAnalysis::Run(module, can_share_buffer_));
|
HloAliasAnalysis::Run(module, can_share_buffer_));
|
||||||
|
|
||||||
for (HloComputation* computation : module->MakeNonfusionComputations()) {
|
for (HloComputation* computation : module->MakeComputationPostOrder()) {
|
||||||
for (HloInstruction* instruction :
|
for (HloInstruction* instruction :
|
||||||
computation->MakeInstructionPostOrder()) {
|
computation->MakeInstructionPostOrder()) {
|
||||||
if (instruction->opcode() == HloOpcode::kWhile) {
|
if (instruction->opcode() == HloOpcode::kWhile) {
|
||||||
@ -1032,13 +1013,6 @@ Status CopyInsertion::AddCopiesToResolveInterference(HloModule* module) {
|
|||||||
} else if (instruction->opcode() == HloOpcode::kConditional) {
|
} else if (instruction->opcode() == HloOpcode::kConditional) {
|
||||||
TF_RETURN_IF_ERROR(
|
TF_RETURN_IF_ERROR(
|
||||||
AddCopiesForConditional(*alias_analysis, instruction));
|
AddCopiesForConditional(*alias_analysis, instruction));
|
||||||
} else {
|
|
||||||
for (const auto& operand_number_and_output_index :
|
|
||||||
HloAliasAnalysis::GetInPlaceInputOutputPairs(instruction)) {
|
|
||||||
TF_RETURN_IF_ERROR(AddCopiesForInPlaceOperation(
|
|
||||||
*alias_analysis, instruction,
|
|
||||||
operand_number_and_output_index.first));
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -2530,250 +2530,5 @@ ENTRY Entry {
|
|||||||
EXPECT_EQ(CountCopies(*module), 1);
|
EXPECT_EQ(CountCopies(*module), 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(CopyInsertionTest, DynamicUpdateSliceNoCopy) {
|
|
||||||
absl::string_view hlo_string = R"(
|
|
||||||
HloModule Module
|
|
||||||
|
|
||||||
ENTRY main {
|
|
||||||
param = f32[1280,1,128] parameter(0)
|
|
||||||
negate = f32[1280,1,128] negate(param)
|
|
||||||
constant.1 = f32[] constant(0)
|
|
||||||
broadcast.6 = f32[128,1,128] broadcast(constant.1), dimensions={}
|
|
||||||
constant.3 = s32[] constant(0)
|
|
||||||
ROOT dynamic-update-slice.5 = f32[1280,1,128] dynamic-update-slice(negate, broadcast.6, constant.3, constant.3, constant.3)
|
|
||||||
}
|
|
||||||
)";
|
|
||||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
|
||||||
ParseAndReturnVerifiedModule(hlo_string));
|
|
||||||
InsertCopies(module.get());
|
|
||||||
EXPECT_EQ(CountCopies(*module), 0);
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(CopyInsertionTest, FusedDynamicUpdateSliceNoCopy) {
|
|
||||||
absl::string_view hlo_string = R"(
|
|
||||||
HloModule Module
|
|
||||||
|
|
||||||
fused_computation {
|
|
||||||
param0 = f32[1280,1,128] parameter(0)
|
|
||||||
constant.1 = f32[] constant(0)
|
|
||||||
broadcast.6 = f32[128,1,128] broadcast(constant.1), dimensions={}
|
|
||||||
constant.3 = s32[] constant(0)
|
|
||||||
ROOT dynamic-update-slice.5 = f32[1280,1,128] dynamic-update-slice(param0, broadcast.6, constant.3, constant.3, constant.3)
|
|
||||||
}
|
|
||||||
|
|
||||||
ENTRY main {
|
|
||||||
param = f32[1280,1,128] parameter(0)
|
|
||||||
negate = f32[1280,1,128] negate(param)
|
|
||||||
ROOT fusion = f32[1280,1,128] fusion(negate), kind=kLoop, calls=fused_computation
|
|
||||||
}
|
|
||||||
)";
|
|
||||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
|
||||||
ParseAndReturnVerifiedModule(hlo_string));
|
|
||||||
InsertCopies(module.get());
|
|
||||||
EXPECT_EQ(CountCopies(*module), 0);
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(CopyInsertionTest, DynamicUpdateSliceCopy) {
|
|
||||||
absl::string_view hlo_string = R"(
|
|
||||||
HloModule Module
|
|
||||||
|
|
||||||
ENTRY main {
|
|
||||||
param = f32[1280,1,128] parameter(0)
|
|
||||||
negate = f32[1280,1,128] negate(param)
|
|
||||||
constant.1 = f32[] constant(0)
|
|
||||||
broadcast.6 = f32[128,1,128] broadcast(constant.1), dimensions={}
|
|
||||||
constant.3 = s32[] constant(0)
|
|
||||||
add = f32[1280,1,128] add(negate, negate)
|
|
||||||
dynamic-update-slice.5 = f32[1280,1,128] dynamic-update-slice(negate, broadcast.6, constant.3, constant.3, constant.3)
|
|
||||||
ROOT tuple = (f32[1280,1,128], f32[1280,1,128]) tuple(add, dynamic-update-slice.5)
|
|
||||||
}
|
|
||||||
)";
|
|
||||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
|
||||||
ParseAndReturnVerifiedModule(hlo_string));
|
|
||||||
InsertCopies(module.get());
|
|
||||||
EXPECT_EQ(CountCopies(*module), 1);
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(CopyInsertionTest, DynamicUpdateSliceParameterShareCopy) {
|
|
||||||
absl::string_view hlo_string = R"(
|
|
||||||
HloModule Module
|
|
||||||
|
|
||||||
ENTRY main {
|
|
||||||
param = f32[1280,1,128] parameter(0)
|
|
||||||
constant.1 = f32[] constant(0)
|
|
||||||
broadcast.6 = f32[128,1,128] broadcast(constant.1), dimensions={}
|
|
||||||
constant.3 = s32[] constant(0)
|
|
||||||
ROOT dynamic-update-slice.5 = f32[1280,1,128] dynamic-update-slice(param, broadcast.6, constant.3, constant.3, constant.3)
|
|
||||||
}
|
|
||||||
)";
|
|
||||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
|
||||||
ParseAndReturnVerifiedModule(hlo_string));
|
|
||||||
InsertCopies(module.get());
|
|
||||||
EXPECT_EQ(CountCopies(*module), 1);
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(CopyInsertionTest, FusedDynamicUpdateSliceCopy) {
|
|
||||||
absl::string_view hlo_string = R"(
|
|
||||||
HloModule Module
|
|
||||||
|
|
||||||
fused_computation {
|
|
||||||
param0 = f32[1280,1,128] parameter(0)
|
|
||||||
constant.1 = f32[] constant(0)
|
|
||||||
broadcast.6 = f32[128,1,128] broadcast(constant.1), dimensions={}
|
|
||||||
constant.3 = s32[] constant(0)
|
|
||||||
ROOT dynamic-update-slice.5 = f32[1280,1,128] dynamic-update-slice(param0, broadcast.6, constant.3, constant.3, constant.3)
|
|
||||||
}
|
|
||||||
|
|
||||||
ENTRY main {
|
|
||||||
param = f32[1280,1,128] parameter(0)
|
|
||||||
negate = f32[1280,1,128] negate(param)
|
|
||||||
add = f32[1280,1,128] add(negate, negate)
|
|
||||||
fusion = f32[1280,1,128] fusion(negate), kind=kLoop, calls=fused_computation
|
|
||||||
ROOT tuple = (f32[1280,1,128], f32[1280,1,128]) tuple(negate, fusion)
|
|
||||||
}
|
|
||||||
)";
|
|
||||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
|
||||||
ParseAndReturnVerifiedModule(hlo_string));
|
|
||||||
InsertCopies(module.get());
|
|
||||||
EXPECT_EQ(CountCopies(*module), 1);
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(CopyInsertionTest, ChainDynamicUpdateSliceCopy) {
|
|
||||||
absl::string_view hlo_string = R"(
|
|
||||||
HloModule Module
|
|
||||||
|
|
||||||
ENTRY main {
|
|
||||||
state = (s32[], f32[1280,1,128]{2,1,0}) parameter(0)
|
|
||||||
constant.1 = f32[] constant(0)
|
|
||||||
broadcast.6 = f32[128,1,128]{2,1,0} broadcast(constant.1), dimensions={}
|
|
||||||
get-tuple-element.4 = f32[1280,1,128]{2,1,0} get-tuple-element(state), index=1
|
|
||||||
get-tuple-element.3 = s32[] get-tuple-element(state), index=0
|
|
||||||
constant.2 = s32[] constant(128)
|
|
||||||
add.5 = s32[] add(get-tuple-element.3, constant.2)
|
|
||||||
constant.3 = s32[] constant(0)
|
|
||||||
dynamic-update-slice.5 = f32[1280,1,128]{2,1,0} dynamic-update-slice(get-tuple-element.4, broadcast.6, constant.3, constant.3, constant.3)
|
|
||||||
dynamic-update-slice.9 = f32[1280,1,128]{2,1,0} dynamic-update-slice(dynamic-update-slice.5, broadcast.6, constant.3, constant.3, constant.3)
|
|
||||||
ROOT tuple.85 = (s32[], s32[], s32[2]{0}, f32[1280,1,128]{2,1,0}) tuple(add.5, dynamic-update-slice.9)
|
|
||||||
}
|
|
||||||
)";
|
|
||||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
|
||||||
ParseAndReturnVerifiedModule(hlo_string));
|
|
||||||
InsertCopies(module.get());
|
|
||||||
EXPECT_EQ(CountCopies(*module), 1);
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(CopyInsertionTest, FusedDynamicUpdateSliceCopy2) {
|
|
||||||
absl::string_view hlo_string = R"(
|
|
||||||
HloModule Module
|
|
||||||
|
|
||||||
fused_computation.1 {
|
|
||||||
param0 = f32[1280,1,128] parameter(0)
|
|
||||||
constant.1 = f32[] constant(0)
|
|
||||||
broadcast.6 = f32[128,1,128] broadcast(constant.1), dimensions={}
|
|
||||||
constant.3 = s32[] constant(0)
|
|
||||||
ROOT dynamic-update-slice.5 = f32[1280,1,128] dynamic-update-slice(param0, broadcast.6, constant.3, constant.3, constant.3)
|
|
||||||
}
|
|
||||||
|
|
||||||
fused_computation.2 {
|
|
||||||
param0 = f32[1280,1,128] parameter(0)
|
|
||||||
param1 = f32[1280,1,128] parameter(1)
|
|
||||||
slice = f32[128,1,128] slice(param1), slice={[0:128], [0:1], [0:128]}
|
|
||||||
constant.3 = s32[] constant(0)
|
|
||||||
ROOT dynamic-update-slice.5 = f32[1280,1,128] dynamic-update-slice(param0, slice, constant.3, constant.3, constant.3)
|
|
||||||
}
|
|
||||||
|
|
||||||
ENTRY main {
|
|
||||||
param = f32[1280,1,128] parameter(0)
|
|
||||||
negate = f32[1280,1,128] negate(param)
|
|
||||||
add = f32[1280,1,128] add(negate, negate)
|
|
||||||
fusion1 = f32[1280,1,128] fusion(negate), kind=kLoop, calls=fused_computation.1
|
|
||||||
ROOT fusion2 = f32[1280,1,128] fusion(fusion1, negate), kind=kLoop, calls=fused_computation.2
|
|
||||||
}
|
|
||||||
)";
|
|
||||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
|
||||||
ParseAndReturnVerifiedModule(hlo_string));
|
|
||||||
InsertCopies(module.get());
|
|
||||||
EXPECT_EQ(CountCopies(*module), 1);
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(CopyInsertionTest, MultiOutputFusedDynamicUpdateSliceCopy) {
|
|
||||||
// Tests multi-output fusion with two DUS outputs, requiring two copies.
|
|
||||||
absl::string_view hlo_string = R"(
|
|
||||||
HloModule Module
|
|
||||||
|
|
||||||
fused_computation {
|
|
||||||
param0 = f32[1280,1,128] parameter(0)
|
|
||||||
param1 = f32[1280,1,128] parameter(1)
|
|
||||||
param2 = f32[1280,1,128] parameter(2)
|
|
||||||
constant.1 = f32[] constant(0)
|
|
||||||
broadcast.6 = f32[128,1,128] broadcast(constant.1), dimensions={}
|
|
||||||
constant.3 = s32[] constant(0)
|
|
||||||
add.1 = f32[1280,1,128] add(param0, param0)
|
|
||||||
dynamic-update-slice.5 = f32[1280,1,128] dynamic-update-slice(param1, broadcast.6, constant.3, constant.3, constant.3)
|
|
||||||
dynamic-update-slice.6 = f32[1280,1,128] dynamic-update-slice(param2, broadcast.6, constant.3, constant.3, constant.3)
|
|
||||||
ROOT tuple.1 = (f32[1280,1,128], f32[1280,1,128], f32[1280,1,128]) tuple(add.1, dynamic-update-slice.5, dynamic-update-slice.6)
|
|
||||||
}
|
|
||||||
|
|
||||||
ENTRY main {
|
|
||||||
param = f32[1280,1,128] parameter(0)
|
|
||||||
negate0 = f32[1280,1,128] negate(param)
|
|
||||||
negate1 = f32[1280,1,128] negate(param)
|
|
||||||
negate2 = f32[1280,1,128] negate(param)
|
|
||||||
fusion = (f32[1280,1,128], f32[1280,1,128], f32[1280,1,128]) fusion(negate0, negate1, negate2), kind=kLoop, calls=fused_computation
|
|
||||||
gte0 = f32[1280,1,128] get-tuple-element(fusion), index=0
|
|
||||||
gte1 = f32[1280,1,128] get-tuple-element(fusion), index=1
|
|
||||||
gte2 = f32[1280,1,128] get-tuple-element(fusion), index=2
|
|
||||||
add0 = f32[1280,1,128] add(negate0, gte0)
|
|
||||||
add1 = f32[1280,1,128] add(negate1, gte1)
|
|
||||||
add2 = f32[1280,1,128] add(negate2, gte2)
|
|
||||||
ROOT tuple = (f32[1280,1,128], f32[1280,1,128], f32[1280,1,128]) tuple(add0, add1, add2)
|
|
||||||
}
|
|
||||||
)";
|
|
||||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
|
||||||
ParseAndReturnVerifiedModule(hlo_string));
|
|
||||||
InsertCopies(module.get());
|
|
||||||
EXPECT_EQ(CountCopies(*module), 2);
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(CopyInsertionTest, MultiOutputFusedDynamicUpdateSliceNoCopy) {
|
|
||||||
// Same as above, but negate1 is not used beyond fusion, so it only needs one
|
|
||||||
// copy for negate0.
|
|
||||||
absl::string_view hlo_string = R"(
|
|
||||||
HloModule Module
|
|
||||||
|
|
||||||
fused_computation {
|
|
||||||
param0 = f32[1280,1,128] parameter(0)
|
|
||||||
param1 = f32[1280,1,128] parameter(1)
|
|
||||||
param2 = f32[1280,1,128] parameter(2)
|
|
||||||
constant.1 = f32[] constant(0)
|
|
||||||
broadcast.6 = f32[128,1,128] broadcast(constant.1), dimensions={}
|
|
||||||
constant.3 = s32[] constant(0)
|
|
||||||
add.1 = f32[1280,1,128] add(param0, param0)
|
|
||||||
dynamic-update-slice.5 = f32[1280,1,128] dynamic-update-slice(param1, broadcast.6, constant.3, constant.3, constant.3)
|
|
||||||
dynamic-update-slice.6 = f32[1280,1,128] dynamic-update-slice(param2, broadcast.6, constant.3, constant.3, constant.3)
|
|
||||||
ROOT tuple.1 = (f32[1280,1,128], f32[1280,1,128], f32[1280,1,128]) tuple(add.1, dynamic-update-slice.5, dynamic-update-slice.6)
|
|
||||||
}
|
|
||||||
|
|
||||||
ENTRY main {
|
|
||||||
param = f32[1280,1,128] parameter(0)
|
|
||||||
negate0 = f32[1280,1,128] negate(param)
|
|
||||||
negate1 = f32[1280,1,128] negate(param)
|
|
||||||
negate2 = f32[1280,1,128] negate(param)
|
|
||||||
fusion = (f32[1280,1,128], f32[1280,1,128], f32[1280,1,128]) fusion(negate0, negate1, negate2), kind=kLoop, calls=fused_computation
|
|
||||||
gte0 = f32[1280,1,128] get-tuple-element(fusion), index=0
|
|
||||||
gte1 = f32[1280,1,128] get-tuple-element(fusion), index=1
|
|
||||||
gte2 = f32[1280,1,128] get-tuple-element(fusion), index=2
|
|
||||||
add0 = f32[1280,1,128] add(negate0, gte0)
|
|
||||||
add1 = f32[1280,1,128] add(gte1, gte1)
|
|
||||||
add2 = f32[1280,1,128] add(negate2, gte2)
|
|
||||||
ROOT tuple = (f32[1280,1,128], f32[1280,1,128], f32[1280,1,128]) tuple(add0, add1, add2)
|
|
||||||
}
|
|
||||||
)";
|
|
||||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
|
||||||
ParseAndReturnVerifiedModule(hlo_string));
|
|
||||||
InsertCopies(module.get());
|
|
||||||
EXPECT_EQ(CountCopies(*module), 1);
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
// RUN: hlo_to_llvm_ir %s | FileCheck %s
|
// RUN: hlo_to_llvm_ir %s | FileCheck %s
|
||||||
|
|
||||||
// CHECK-LABEL: define void @scatter_TensorFlowScatterV1(i8* noalias align 16 dereferenceable(36) %alloc0, i8* noalias align 16 dereferenceable(24) %alloc1, i8* noalias align 16 dereferenceable(8) %alloc2) {
|
// CHECK-LABEL: define void @scatter_TensorFlowScatterV1(i8* noalias align 64 dereferenceable(36) %alloc0, i8* noalias align 16 dereferenceable(36) %alloc1, i8* noalias align 16 dereferenceable(24) %alloc2, i8* noalias align 16 dereferenceable(8) %alloc3) {
|
||||||
// CHECK: entry:
|
// CHECK: entry:
|
||||||
// CHECK: %[[VAL_32:.*]] = alloca i32, align 4
|
// CHECK: %[[VAL_32:.*]] = alloca i32, align 4
|
||||||
// CHECK: %[[VAL_0:.*]] = getelementptr inbounds i8, i8* %[[VAL_1:.*]], i64 0
|
// CHECK: %[[VAL_0:.*]] = getelementptr inbounds i8, i8* %[[VAL_1:.*]], i64 0
|
||||||
@ -43,8 +43,8 @@
|
|||||||
// CHECK: store atomic i32 %[[VAL_36]], i32* %[[VAL_31]] unordered, align 4
|
// CHECK: store atomic i32 %[[VAL_36]], i32* %[[VAL_31]] unordered, align 4
|
||||||
// CHECK: br label %[[VAL_23]]
|
// CHECK: br label %[[VAL_23]]
|
||||||
// CHECK: !nvvm.annotations = !{!0, !1}
|
// CHECK: !nvvm.annotations = !{!0, !1}
|
||||||
// CHECK: !0 = !{void (i8*, i8*, i8*)* @scatter_TensorFlowScatterV1, !"kernel", i32 1}
|
// CHECK: !0 = !{void (i8*, i8*, i8*, i8*)* @scatter_TensorFlowScatterV1, !"kernel", i32 1}
|
||||||
// CHECK: !1 = !{void (i8*, i8*, i8*)* @scatter_TensorFlowScatterV1, !"reqntidx", i32 6}
|
// CHECK: !1 = !{void (i8*, i8*, i8*, i8*)* @scatter_TensorFlowScatterV1, !"reqntidx", i32 6}
|
||||||
// CHECK: !2 = !{i32 0, i32 1}
|
// CHECK: !2 = !{i32 0, i32 1}
|
||||||
// CHECK: !3 = !{i32 0, i32 6}
|
// CHECK: !3 = !{i32 0, i32 6}
|
||||||
// CHECK: !4 = !{}
|
// CHECK: !4 = !{}
|
||||||
@ -72,7 +72,7 @@ ENTRY main {
|
|||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: define void @scatter_ScatterIntoScalar(i8* noalias align 16 dereferenceable(4) %alloc0, i8* noalias align 16 dereferenceable(4) %alloc1, i8* noalias align 16 %alloc2) {
|
// CHECK-LABEL: define void @scatter_ScatterIntoScalar(i8* noalias align 64 dereferenceable(4) %alloc0, i8* noalias align 16 dereferenceable(4) %alloc1, i8* noalias align 16 dereferenceable(4) %alloc2, i8* noalias align 16 %alloc3) {
|
||||||
// CHECK: entry:
|
// CHECK: entry:
|
||||||
// CHECK: %[[VAL_60:.*]] = alloca i32, align 4
|
// CHECK: %[[VAL_60:.*]] = alloca i32, align 4
|
||||||
// CHECK: %[[VAL_37:.*]] = getelementptr inbounds i8, i8* %[[VAL_38:.*]], i64 0
|
// CHECK: %[[VAL_37:.*]] = getelementptr inbounds i8, i8* %[[VAL_38:.*]], i64 0
|
||||||
@ -104,8 +104,8 @@ ENTRY main {
|
|||||||
// CHECK: store atomic i32 %[[VAL_62]], i32* %[[VAL_39]] unordered, align 4
|
// CHECK: store atomic i32 %[[VAL_62]], i32* %[[VAL_39]] unordered, align 4
|
||||||
// CHECK: br label %[[VAL_57]]
|
// CHECK: br label %[[VAL_57]]
|
||||||
// CHECK: !nvvm.annotations = !{!0, !1}
|
// CHECK: !nvvm.annotations = !{!0, !1}
|
||||||
// CHECK: !0 = !{void (i8*, i8*, i8*)* @scatter_ScatterIntoScalar, !"kernel", i32 1}
|
// CHECK: !0 = !{void (i8*, i8*, i8*, i8*)* @scatter_ScatterIntoScalar, !"kernel", i32 1}
|
||||||
// CHECK: !1 = !{void (i8*, i8*, i8*)* @scatter_ScatterIntoScalar, !"reqntidx", i32 1}
|
// CHECK: !1 = !{void (i8*, i8*, i8*, i8*)* @scatter_ScatterIntoScalar, !"reqntidx", i32 1}
|
||||||
// CHECK: !2 = !{i32 0, i32 1}
|
// CHECK: !2 = !{i32 0, i32 1}
|
||||||
// CHECK: !3 = !{}
|
// CHECK: !3 = !{}
|
||||||
|
|
||||||
@ -131,7 +131,7 @@ ENTRY main {
|
|||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: define void @scatter_TensorFlowScatter_Mul(i8* noalias align 16 dereferenceable(36) %alloc0, i8* noalias align 16 dereferenceable(24) %alloc1, i8* noalias align 16 dereferenceable(8) %alloc2) {
|
// CHECK-LABEL: define void @scatter_TensorFlowScatter_Mul(i8* noalias align 64 dereferenceable(36) %alloc0, i8* noalias align 16 dereferenceable(36) %alloc1, i8* noalias align 16 dereferenceable(24) %alloc2, i8* noalias align 16 dereferenceable(8) %alloc3) {
|
||||||
// CHECK: %[[VAL_63:.*]] = alloca i32, align 4
|
// CHECK: %[[VAL_63:.*]] = alloca i32, align 4
|
||||||
// CHECK: %[[VAL_64:.*]] = alloca i32, align 4
|
// CHECK: %[[VAL_64:.*]] = alloca i32, align 4
|
||||||
// CHECK: %[[VAL_98:.*]] = alloca i32, align 4
|
// CHECK: %[[VAL_98:.*]] = alloca i32, align 4
|
||||||
@ -188,8 +188,8 @@ ENTRY main {
|
|||||||
// CHECK: %[[VAL_109:.*]] = extractvalue { i32, i1 } %[[VAL_107]], 1
|
// CHECK: %[[VAL_109:.*]] = extractvalue { i32, i1 } %[[VAL_107]], 1
|
||||||
// CHECK: br i1 %[[VAL_109]], label %[[VAL_96]], label %[[VAL_104]]
|
// CHECK: br i1 %[[VAL_109]], label %[[VAL_96]], label %[[VAL_104]]
|
||||||
// CHECK: !nvvm.annotations = !{!0, !1}
|
// CHECK: !nvvm.annotations = !{!0, !1}
|
||||||
// CHECK: !0 = !{void (i8*, i8*, i8*)* @scatter_TensorFlowScatter_Mul, !"kernel", i32 1}
|
// CHECK: !0 = !{void (i8*, i8*, i8*, i8*)* @scatter_TensorFlowScatter_Mul, !"kernel", i32 1}
|
||||||
// CHECK: !1 = !{void (i8*, i8*, i8*)* @scatter_TensorFlowScatter_Mul, !"reqntidx", i32 6}
|
// CHECK: !1 = !{void (i8*, i8*, i8*, i8*)* @scatter_TensorFlowScatter_Mul, !"reqntidx", i32 6}
|
||||||
// CHECK: !2 = !{i32 0, i32 1}
|
// CHECK: !2 = !{i32 0, i32 1}
|
||||||
// CHECK: !3 = !{i32 0, i32 6}
|
// CHECK: !3 = !{i32 0, i32 6}
|
||||||
// CHECK: !4 = !{}
|
// CHECK: !4 = !{}
|
||||||
@ -216,7 +216,7 @@ ENTRY main {
|
|||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: define void @scatter_ScalarUpdate(i8* noalias align 16 dereferenceable(16) %alloc0, i8* noalias align 16 dereferenceable(4) %alloc1, i8* noalias align 16 dereferenceable(4) %alloc2) {
|
// CHECK-LABEL: define void @scatter_ScalarUpdate(i8* noalias align 64 dereferenceable(16) %alloc0, i8* noalias align 16 dereferenceable(16) %alloc1, i8* noalias align 16 dereferenceable(4) %alloc2, i8* noalias align 16 dereferenceable(4) %alloc3) {
|
||||||
// CHECK: entry:
|
// CHECK: entry:
|
||||||
// CHECK: %[[VAL_146:.*]] = alloca i32, align 4
|
// CHECK: %[[VAL_146:.*]] = alloca i32, align 4
|
||||||
// CHECK: %[[VAL_118:.*]] = getelementptr inbounds i8, i8* %[[VAL_119:.*]], i64 0
|
// CHECK: %[[VAL_118:.*]] = getelementptr inbounds i8, i8* %[[VAL_119:.*]], i64 0
|
||||||
@ -253,8 +253,8 @@ ENTRY main {
|
|||||||
// CHECK: store atomic i32 %[[VAL_148]], i32* %[[VAL_145]] unordered, align 4
|
// CHECK: store atomic i32 %[[VAL_148]], i32* %[[VAL_145]] unordered, align 4
|
||||||
// CHECK: br label %[[VAL_138]]
|
// CHECK: br label %[[VAL_138]]
|
||||||
// CHECK: !nvvm.annotations = !{!0, !1}
|
// CHECK: !nvvm.annotations = !{!0, !1}
|
||||||
// CHECK: !0 = !{void (i8*, i8*, i8*)* @scatter_ScalarUpdate, !"kernel", i32 1}
|
// CHECK: !0 = !{void (i8*, i8*, i8*, i8*)* @scatter_ScalarUpdate, !"kernel", i32 1}
|
||||||
// CHECK: !1 = !{void (i8*, i8*, i8*)* @scatter_ScalarUpdate, !"reqntidx", i32 1}
|
// CHECK: !1 = !{void (i8*, i8*, i8*, i8*)* @scatter_ScalarUpdate, !"reqntidx", i32 1}
|
||||||
// CHECK: !2 = !{i32 0, i32 1}
|
// CHECK: !2 = !{i32 0, i32 1}
|
||||||
// CHECK: !3 = !{}
|
// CHECK: !3 = !{}
|
||||||
|
|
||||||
|
|||||||
@ -308,39 +308,6 @@ class BufferValueMap {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void ComputeInPlaceOperationAliasedBuffers(
|
|
||||||
const HloValue& value, std::vector<BufferNumber>* aliased_buffers) {
|
|
||||||
VLOG(3) << "Compute aliases for in-place operations (e.g. "
|
|
||||||
"kDynamicUpdateSlice and kScatter)";
|
|
||||||
for (const HloPosition& position : value.positions()) {
|
|
||||||
HloInstruction* instruction = position.instruction;
|
|
||||||
for (const auto& operand_number_and_output_index :
|
|
||||||
HloAliasAnalysis::GetInPlaceInputOutputPairs(instruction)) {
|
|
||||||
if (position.index == operand_number_and_output_index.second) {
|
|
||||||
int64 operand_number = operand_number_and_output_index.first;
|
|
||||||
const HloValue& operand_value = dataflow_.GetUniqueValueAt(
|
|
||||||
instruction->operand(operand_number), {});
|
|
||||||
VLOG(3) << " operand value " << operand_value.ToShortString()
|
|
||||||
<< " aliases.";
|
|
||||||
aliased_buffers->push_back(GetBufferForValue(operand_value));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for (const HloUse& use : value.uses()) {
|
|
||||||
for (const auto& operand_number_and_output_index :
|
|
||||||
HloAliasAnalysis::GetInPlaceInputOutputPairs(use.instruction)) {
|
|
||||||
int64 operand_number = operand_number_and_output_index.first;
|
|
||||||
if (use.operand_number == operand_number) {
|
|
||||||
const HloValue& use_value = dataflow_.GetUniqueValueAt(
|
|
||||||
use.instruction, operand_number_and_output_index.second);
|
|
||||||
VLOG(3) << " use value " << use_value.ToShortString() << " aliases.";
|
|
||||||
aliased_buffers->push_back(GetBufferForValue(use_value));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Compute and return a vector of buffers that the given value must be
|
// Compute and return a vector of buffers that the given value must be
|
||||||
// contained in due to HLO aliasing rules.
|
// contained in due to HLO aliasing rules.
|
||||||
std::vector<BufferNumber> ComputeAliasedBuffers(const HloValue& value) {
|
std::vector<BufferNumber> ComputeAliasedBuffers(const HloValue& value) {
|
||||||
@ -351,7 +318,6 @@ class BufferValueMap {
|
|||||||
ComputeInputOutputAliasedBuffers(value, &aliased_buffers);
|
ComputeInputOutputAliasedBuffers(value, &aliased_buffers);
|
||||||
ComputeWhileAliasedBuffers(value, &aliased_buffers);
|
ComputeWhileAliasedBuffers(value, &aliased_buffers);
|
||||||
ComputeConditionalAliasedBuffers(value, &aliased_buffers);
|
ComputeConditionalAliasedBuffers(value, &aliased_buffers);
|
||||||
ComputeInPlaceOperationAliasedBuffers(value, &aliased_buffers);
|
|
||||||
// Uniquify aliased buffers.
|
// Uniquify aliased buffers.
|
||||||
absl::c_sort(aliased_buffers);
|
absl::c_sort(aliased_buffers);
|
||||||
aliased_buffers.erase(
|
aliased_buffers.erase(
|
||||||
@ -376,42 +342,6 @@ class BufferValueMap {
|
|||||||
BufferNumber next_buffer_number_ = 0;
|
BufferNumber next_buffer_number_ = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
/*static*/ bool HloAliasAnalysis::IsInPlaceOperation(HloOpcode opcode) {
|
|
||||||
return opcode == HloOpcode::kDynamicUpdateSlice ||
|
|
||||||
opcode == HloOpcode::kScatter;
|
|
||||||
}
|
|
||||||
|
|
||||||
/*static*/ std::vector<std::pair<int64, ShapeIndex>>
|
|
||||||
HloAliasAnalysis::GetInPlaceInputOutputPairs(
|
|
||||||
const HloInstruction* instruction) {
|
|
||||||
if (IsInPlaceOperation(instruction->opcode())) {
|
|
||||||
return {{0, {}}};
|
|
||||||
} else if (instruction->opcode() != HloOpcode::kFusion) {
|
|
||||||
return {};
|
|
||||||
}
|
|
||||||
std::vector<std::pair<int64, ShapeIndex>> input_output_pairs;
|
|
||||||
for (auto& indexed_shape : ShapeUtil::GetLeafShapes(instruction->shape())) {
|
|
||||||
const HloInstruction* hlo_generating_output =
|
|
||||||
instruction->fused_expression_root();
|
|
||||||
for (int64 i = 0; i < indexed_shape.index.size(); ++i) {
|
|
||||||
if (hlo_generating_output->opcode() == HloOpcode::kTuple) {
|
|
||||||
hlo_generating_output =
|
|
||||||
hlo_generating_output->operand(indexed_shape.index[i]);
|
|
||||||
} else {
|
|
||||||
CHECK_EQ(i, indexed_shape.index.size() - 1);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (IsInPlaceOperation(hlo_generating_output->opcode()) &&
|
|
||||||
hlo_generating_output->operand(0)->opcode() == HloOpcode::kParameter) {
|
|
||||||
input_output_pairs.emplace_back(
|
|
||||||
hlo_generating_output->operand(0)->parameter_number(),
|
|
||||||
indexed_shape.index);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return input_output_pairs;
|
|
||||||
}
|
|
||||||
|
|
||||||
HloAliasAnalysis::HloAliasAnalysis(const HloModule* module) : module_(module) {}
|
HloAliasAnalysis::HloAliasAnalysis(const HloModule* module) : module_(module) {}
|
||||||
|
|
||||||
const HloBuffer& HloAliasAnalysis::GetUniqueBufferAt(
|
const HloBuffer& HloAliasAnalysis::GetUniqueBufferAt(
|
||||||
|
|||||||
@ -120,15 +120,6 @@ class HloAliasAnalysis {
|
|||||||
return results;
|
return results;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Returns true if the operation is an in-place operation and its operand 0
|
|
||||||
// must alias with the output.
|
|
||||||
static bool IsInPlaceOperation(HloOpcode opcode);
|
|
||||||
|
|
||||||
// Returns a vector consisting of operand number and output shape index of the
|
|
||||||
// in-place operations within this HLO.
|
|
||||||
static std::vector<std::pair<int64, ShapeIndex>> GetInPlaceInputOutputPairs(
|
|
||||||
const HloInstruction* instruction);
|
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
explicit HloAliasAnalysis(const HloModule* module);
|
explicit HloAliasAnalysis(const HloModule* module);
|
||||||
|
|
||||||
|
|||||||
@ -1062,118 +1062,6 @@ TEST_F(HloAliasAnalysisTest, MergeBuffersReverse) {
|
|||||||
analysis.BufferLivesOut(analysis.buffers()[0]);
|
analysis.BufferLivesOut(analysis.buffers()[0]);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(HloAliasAnalysisTest, DynamicUpdateSlice) {
|
|
||||||
Shape shape = ShapeUtil::MakeShape(F32, {8});
|
|
||||||
Shape update_shape = ShapeUtil::MakeShape(F32, {4});
|
|
||||||
Shape index_shape = ShapeUtil::MakeShape(S32, {});
|
|
||||||
auto builder = HloComputation::Builder(TestName());
|
|
||||||
auto param0 = builder.AddInstruction(
|
|
||||||
HloInstruction::CreateParameter(0, shape, "param0"));
|
|
||||||
auto param1 = builder.AddInstruction(
|
|
||||||
HloInstruction::CreateParameter(1, update_shape, "param1"));
|
|
||||||
auto param2 = builder.AddInstruction(
|
|
||||||
HloInstruction::CreateParameter(2, index_shape, "param2"));
|
|
||||||
auto copy0 = builder.AddInstruction(
|
|
||||||
HloInstruction::CreateUnary(shape, HloOpcode::kCopy, param0));
|
|
||||||
auto dynamic_update_slice = builder.AddInstruction(
|
|
||||||
HloInstruction::CreateDynamicUpdateSlice(shape, copy0, param1, {param2}));
|
|
||||||
|
|
||||||
module_->AddEntryComputation(builder.Build());
|
|
||||||
SCOPED_TRACE(module_->ToString());
|
|
||||||
|
|
||||||
HloAliasAnalysis& analysis = RunAnalysis();
|
|
||||||
|
|
||||||
EXPECT_EQ(analysis.GetUniqueBufferAt(copy0),
|
|
||||||
analysis.GetUniqueBufferAt(dynamic_update_slice));
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(HloAliasAnalysisTest, DynamicUpdateSliceMultiOutputFusion) {
|
|
||||||
absl::string_view hlo_string = R"(
|
|
||||||
HloModule Module
|
|
||||||
|
|
||||||
fused_computation {
|
|
||||||
param0 = f32[1280,1,128] parameter(0)
|
|
||||||
param1 = f32[1280,1,128] parameter(1)
|
|
||||||
param2 = f32[1280,1,128] parameter(2)
|
|
||||||
constant.1 = f32[] constant(0)
|
|
||||||
broadcast.6 = f32[128,1,128] broadcast(constant.1), dimensions={}
|
|
||||||
constant.3 = s32[] constant(0)
|
|
||||||
add.1 = f32[1280,1,128] add(param0, param0)
|
|
||||||
dynamic-update-slice.5 = f32[1280,1,128] dynamic-update-slice(param1, broadcast.6, constant.3, constant.3, constant.3)
|
|
||||||
dynamic-update-slice.6 = f32[1280,1,128] dynamic-update-slice(param2, broadcast.6, constant.3, constant.3, constant.3)
|
|
||||||
ROOT tuple.1 = (f32[1280,1,128], f32[1280,1,128], f32[1280,1,128]) tuple(add.1, dynamic-update-slice.5, dynamic-update-slice.6)
|
|
||||||
}
|
|
||||||
|
|
||||||
ENTRY main {
|
|
||||||
param = f32[1280,1,128] parameter(0)
|
|
||||||
negate0 = f32[1280,1,128] negate(param)
|
|
||||||
negate1 = f32[1280,1,128] negate(param)
|
|
||||||
negate2 = f32[1280,1,128] negate(param)
|
|
||||||
ROOT fusion = (f32[1280,1,128], f32[1280,1,128], f32[1280,1,128]) fusion(negate0, negate1, negate2), kind=kLoop, calls=fused_computation
|
|
||||||
}
|
|
||||||
)";
|
|
||||||
TF_ASSERT_OK_AND_ASSIGN(module_, ParseAndReturnVerifiedModule(hlo_string));
|
|
||||||
|
|
||||||
SCOPED_TRACE(module_->ToString());
|
|
||||||
|
|
||||||
HloAliasAnalysis& analysis = RunAnalysis();
|
|
||||||
LOG(INFO) << analysis.ToString();
|
|
||||||
|
|
||||||
// Expect negate1 and negate2 to alias with fusion{1} and fusion{2}
|
|
||||||
// respectively (due to DUS), but not negate0 and fusion{0}.
|
|
||||||
const HloInstruction* fusion =
|
|
||||||
module_->entry_computation()->GetInstructionWithName("fusion");
|
|
||||||
const HloInstruction* negate0 =
|
|
||||||
module_->entry_computation()->GetInstructionWithName("negate0");
|
|
||||||
const HloInstruction* negate1 =
|
|
||||||
module_->entry_computation()->GetInstructionWithName("negate1");
|
|
||||||
const HloInstruction* negate2 =
|
|
||||||
module_->entry_computation()->GetInstructionWithName("negate2");
|
|
||||||
EXPECT_EQ(analysis.GetUniqueBufferAt(negate1),
|
|
||||||
analysis.GetUniqueBufferAt(fusion, {1}));
|
|
||||||
EXPECT_EQ(analysis.GetUniqueBufferAt(negate2),
|
|
||||||
analysis.GetUniqueBufferAt(fusion, {2}));
|
|
||||||
EXPECT_NE(analysis.GetUniqueBufferAt(negate0),
|
|
||||||
analysis.GetUniqueBufferAt(fusion, {0}));
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(HloAliasAnalysisTest, ChainedDynamicUpdateSliceFusion) {
|
|
||||||
// CPU and GPU backends may generate fusions with dynamic update slices
|
|
||||||
// feeding each other. They expect the fusion to not be in-place if that is
|
|
||||||
// the case.
|
|
||||||
absl::string_view hlo_string = R"(
|
|
||||||
HloModule Module
|
|
||||||
|
|
||||||
fused_computation {
|
|
||||||
param0 = f32[1280,1,128] parameter(0)
|
|
||||||
constant.1 = f32[] constant(0)
|
|
||||||
broadcast.6 = f32[128,1,128] broadcast(constant.1), dimensions={}
|
|
||||||
constant.3 = s32[] constant(0)
|
|
||||||
dynamic-update-slice.5 = f32[1280,1,128] dynamic-update-slice(param0, broadcast.6, constant.3, constant.3, constant.3)
|
|
||||||
ROOT dynamic-update-slice.6 = f32[1280,1,128] dynamic-update-slice(dynamic-update-slice.5, broadcast.6, constant.3, constant.3, constant.3)
|
|
||||||
}
|
|
||||||
|
|
||||||
ENTRY main {
|
|
||||||
param = f32[1280,1,128] parameter(0)
|
|
||||||
negate0 = f32[1280,1,128] negate(param)
|
|
||||||
ROOT fusion = f32[1280,1,128] fusion(negate0), kind=kLoop, calls=fused_computation
|
|
||||||
}
|
|
||||||
)";
|
|
||||||
TF_ASSERT_OK_AND_ASSIGN(module_, ParseAndReturnVerifiedModule(hlo_string));
|
|
||||||
|
|
||||||
SCOPED_TRACE(module_->ToString());
|
|
||||||
|
|
||||||
HloAliasAnalysis& analysis = RunAnalysis();
|
|
||||||
LOG(INFO) << analysis.ToString();
|
|
||||||
|
|
||||||
const HloInstruction* fusion =
|
|
||||||
module_->entry_computation()->GetInstructionWithName("fusion");
|
|
||||||
const HloInstruction* negate0 =
|
|
||||||
module_->entry_computation()->GetInstructionWithName("negate0");
|
|
||||||
EXPECT_NE(analysis.GetUniqueBufferAt(negate0),
|
|
||||||
analysis.GetUniqueBufferAt(fusion));
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(HloAliasAnalysisTest, BitcastInterference) {
|
TEST_F(HloAliasAnalysisTest, BitcastInterference) {
|
||||||
// A bitcast value simultaneously live with its operand should not cause
|
// A bitcast value simultaneously live with its operand should not cause
|
||||||
// interference.
|
// interference.
|
||||||
|
|||||||
@ -232,24 +232,6 @@ class MemorySpaceAssignmentTest : public HloTestBase,
|
|||||||
return copies;
|
return copies;
|
||||||
}
|
}
|
||||||
|
|
||||||
int64 GetAlternateMemoryOffset(const PresetAssignments& preset_assignments,
|
|
||||||
const HloInstruction* instruction,
|
|
||||||
const ShapeIndex& index = {}) const {
|
|
||||||
// Returns the offset of the assignment, -1 if it's not in the alternate
|
|
||||||
// memory.
|
|
||||||
const HloModule* module = instruction->parent()->parent();
|
|
||||||
auto alias_analysis = HloAliasAnalysis::Run(module).ValueOrDie();
|
|
||||||
HloBuffer& buffer = alias_analysis->GetUniqueBufferAt(instruction, index);
|
|
||||||
for (auto& pos_and_chunk : preset_assignments.chunks()) {
|
|
||||||
for (auto& value : buffer.values()) {
|
|
||||||
if (pos_and_chunk.first == value->defining_position()) {
|
|
||||||
return pos_and_chunk.second.offset;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return -1;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::unique_ptr<HloModule> CreateEvictAndPrefetchModule() {
|
std::unique_ptr<HloModule> CreateEvictAndPrefetchModule() {
|
||||||
HloComputation::Builder builder(TestName());
|
HloComputation::Builder builder(TestName());
|
||||||
Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
|
Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
|
||||||
@ -4433,47 +4415,6 @@ TEST_P(MemorySpaceAssignmentTest, Determinism) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_P(MemorySpaceAssignmentTest, InPlaceOp) {
|
|
||||||
// Tests that in-place ops like DynamicUpdateSlice get the same allocation as
|
|
||||||
// its input.
|
|
||||||
absl::string_view hlo_string = R"(
|
|
||||||
HloModule Module, is_scheduled=true
|
|
||||||
|
|
||||||
fused_computation {
|
|
||||||
param0 = f32[2,3] parameter(0)
|
|
||||||
constant.1 = f32[] constant(0)
|
|
||||||
broadcast = f32[2,1] broadcast(constant.1), dimensions={}
|
|
||||||
constant.3 = s32[] constant(0)
|
|
||||||
ROOT dynamic-update-slice.5 = f32[2,3] dynamic-update-slice(param0, broadcast, constant.3, constant.3)
|
|
||||||
}
|
|
||||||
|
|
||||||
ENTRY main {
|
|
||||||
param = f32[2,3] parameter(0)
|
|
||||||
negate = f32[2,3] negate(param)
|
|
||||||
fusion = f32[2,3] fusion(negate), kind=kLoop, calls=fused_computation
|
|
||||||
ROOT add = f32[2,3] add(fusion, fusion)
|
|
||||||
}
|
|
||||||
)";
|
|
||||||
|
|
||||||
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
|
||||||
ParseAndReturnVerifiedModule(hlo_string));
|
|
||||||
auto preset_assignments = AssignMemorySpace(module.get());
|
|
||||||
HloInstruction* negate_instruction =
|
|
||||||
module->entry_computation()->GetInstructionWithName("negate");
|
|
||||||
int64 negate_offset =
|
|
||||||
GetAlternateMemoryOffset(*preset_assignments, negate_instruction);
|
|
||||||
HloInstruction* fusion_instruction =
|
|
||||||
module->entry_computation()->GetInstructionWithName("fusion");
|
|
||||||
int64 fusion_offset =
|
|
||||||
GetAlternateMemoryOffset(*preset_assignments, fusion_instruction);
|
|
||||||
// We expect negate and fusion to get the same offsets.
|
|
||||||
EXPECT_EQ(negate_offset, fusion_offset);
|
|
||||||
const bool allocate_across_sequential_calls = GetParam();
|
|
||||||
if (allocate_across_sequential_calls) {
|
|
||||||
EXPECT_NE(negate_offset, -1);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
INSTANTIATE_TEST_SUITE_P(MemorySpaceAssignmentInstantiation,
|
INSTANTIATE_TEST_SUITE_P(MemorySpaceAssignmentInstantiation,
|
||||||
MemorySpaceAssignmentTest,
|
MemorySpaceAssignmentTest,
|
||||||
::testing::Values(false, true));
|
::testing::Values(false, true));
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user