Internal change
PiperOrigin-RevId: 331353990 Change-Id: I66751ebf00239baa016b39d2a8e6d4b2c31d4b48
This commit is contained in:
parent
e49b2326ad
commit
d33c01b88a
@ -1007,6 +1007,102 @@ bool BufferAssigner::MaybeAssignBuffer(BufferAllocation* allocation,
|
||||
return true;
|
||||
} // namespace xla
|
||||
|
||||
Status BufferAssigner::MergeInplaceOpBuffers(BufferAssignment* assignment) {
|
||||
// Try allocate same buffer for dynamic update slice's operand and output.
|
||||
|
||||
// If memory_space_assignment is run and there is information about a color in
|
||||
// preset assignments, don't merge those buffers. We expect
|
||||
// memory_space_assignment to have merged these buffers. If
|
||||
// memory_space_assignment didn't merge these buffers and have assigned
|
||||
// different offsets to the operand and the output buffer, merging the buffers
|
||||
// can cause memory corruption if memory_space_assignment assigned a different
|
||||
// buffer at the same offset.
|
||||
absl::flat_hash_set<int64> excluded_colors;
|
||||
if (preset_assignments_) {
|
||||
for (const auto& color_and_info :
|
||||
preset_assignments_->assignment_informations()) {
|
||||
excluded_colors.insert(color_and_info.first);
|
||||
}
|
||||
}
|
||||
|
||||
// TODO(yunxing): Moving this logic to alias analysis and add must-alias rule
|
||||
// to operations that can be done in place.
|
||||
for (HloComputation* computation : assignment->module().computations()) {
|
||||
for (HloInstruction* instruction : computation->instructions()) {
|
||||
if (!(instruction->opcode() == HloOpcode::kDynamicUpdateSlice ||
|
||||
(instruction->opcode() == HloOpcode::kFusion &&
|
||||
(instruction->fused_expression_root()->opcode() ==
|
||||
HloOpcode::kDynamicUpdateSlice)))) {
|
||||
continue;
|
||||
}
|
||||
if (instruction->parent()->IsFusionComputation()) {
|
||||
continue;
|
||||
}
|
||||
if (instruction->operand_count() == 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// The operand can't share the same buffer with the user based on dataflow
|
||||
// analysis.
|
||||
if (!assignment->dataflow_analysis().CanShareOperandBufferWithUser(
|
||||
instruction->mutable_operand(0), {}, instruction, {})) {
|
||||
continue;
|
||||
}
|
||||
HloBuffer& instruction_buffer =
|
||||
assignment->alias_analysis().GetUniqueBufferAt(instruction, {});
|
||||
|
||||
HloBuffer& operand_buffer =
|
||||
assignment->alias_analysis().GetUniqueBufferAt(
|
||||
instruction->operand(0), {});
|
||||
|
||||
// The instruction or operand color is excluded because it was assigned by
|
||||
// memory_space_assignment.
|
||||
if (excluded_colors.contains(instruction_buffer.color()) ||
|
||||
excluded_colors.contains(operand_buffer.color())) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Already have the same buffer. No need to merge those.
|
||||
if (instruction_buffer.id() == operand_buffer.id()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Do not perform in-place dynamic update slice if the operand buffer is
|
||||
// read-only.
|
||||
if (HloBufferIsReadOnly(operand_buffer)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
bool interfere = false;
|
||||
|
||||
for (const HloValue* instruction_value : instruction_buffer.values()) {
|
||||
for (const HloValue* operand_value : operand_buffer.values()) {
|
||||
if (assignment->hlo_ordering().MayInterfere(
|
||||
*instruction_value, *operand_value,
|
||||
assignment->dataflow_analysis())) {
|
||||
interfere = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (interfere) {
|
||||
continue;
|
||||
}
|
||||
if (assignment->alias_analysis().BufferLivesOut(instruction_buffer)) {
|
||||
continue;
|
||||
}
|
||||
if (instruction_buffer.color() != operand_buffer.color()) {
|
||||
continue;
|
||||
}
|
||||
VLOG(3) << "Merging inplace " << instruction_buffer << " and "
|
||||
<< operand_buffer;
|
||||
assignment->alias_analysis().MergeBuffers(instruction_buffer,
|
||||
operand_buffer);
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status BufferAssigner::AssignSingleHloBuffer(
|
||||
const HloBuffer* hlo_buffer, bool is_thread_local,
|
||||
absl::flat_hash_map<const HloComputation*,
|
||||
@ -1558,6 +1654,7 @@ StatusOr<std::unique_ptr<BufferAssignment>> BufferAssigner::CreateAssignment(
|
||||
VLOG(3) << "After coloring:";
|
||||
XLA_VLOG_LINES(3,
|
||||
assignment->alias_analysis().dataflow_analysis().ToString());
|
||||
TF_RETURN_IF_ERROR(MergeInplaceOpBuffers(assignment.get()));
|
||||
|
||||
std::vector<const HloComputation*> thread_local_computations;
|
||||
std::vector<const HloComputation*> global_computations;
|
||||
|
@ -635,6 +635,10 @@ class BufferAssigner {
|
||||
absl::flat_hash_set<const HloBuffer*>* assigned_buffers,
|
||||
BufferAssignment* assignment);
|
||||
|
||||
// Promotes operations (DUS, scatter) to be done in place: If an operation can
|
||||
// be done in place, merge its buffer with its operand buffer.
|
||||
Status MergeInplaceOpBuffers(BufferAssignment* assignment);
|
||||
|
||||
// Assigns a single hlo buffer to an HLO allocation.
|
||||
Status AssignSingleHloBuffer(
|
||||
const HloBuffer* hlo_buffer, bool is_thread_local,
|
||||
|
@ -1925,10 +1925,8 @@ ENTRY main {
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_text));
|
||||
HloInstruction* parameter =
|
||||
m->entry_computation()->GetInstructionWithName("get-tuple-element.4");
|
||||
HloInstruction* dus1 =
|
||||
HloInstruction* dus =
|
||||
m->entry_computation()->GetInstructionWithName("dynamic-update-slice.5");
|
||||
HloInstruction* dus2 =
|
||||
m->entry_computation()->GetInstructionWithName("dynamic-update-slice.9");
|
||||
|
||||
auto buffers = RunBufferAssignment(m.get());
|
||||
|
||||
@ -1936,10 +1934,8 @@ ENTRY main {
|
||||
const BufferAllocation& parameter_alloc =
|
||||
GetTopLevelAllocation(*buffers, parameter);
|
||||
|
||||
const BufferAllocation& dus1_alloc = GetTopLevelAllocation(*buffers, dus1);
|
||||
EXPECT_EQ(parameter_alloc, dus1_alloc);
|
||||
const BufferAllocation& dus2_alloc = GetTopLevelAllocation(*buffers, dus2);
|
||||
EXPECT_EQ(parameter_alloc, dus2_alloc);
|
||||
const BufferAllocation& dus_alloc = GetTopLevelAllocation(*buffers, dus);
|
||||
EXPECT_NE(parameter_alloc, dus_alloc);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -362,19 +362,6 @@ Status AddCopiesForConditional(const HloAliasAnalysis& alias_analysis,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Add copies for the operands of in-place operations. RemoveUnnecessaryCopies
|
||||
// will remove the unnecessary copies.
|
||||
Status AddCopiesForInPlaceOperation(const HloAliasAnalysis& alias_analysis,
|
||||
HloInstruction* in_place_op,
|
||||
int64 operand_number) {
|
||||
VLOG(2) << "Adding copies for in-place operation " << in_place_op->name();
|
||||
HloInstruction* operand = in_place_op->mutable_operand(operand_number);
|
||||
TF_ASSIGN_OR_RETURN(HloInstruction * deep_copy,
|
||||
in_place_op->parent()->DeepCopyInstruction(operand));
|
||||
TF_RETURN_IF_ERROR(operand->ReplaceUseWith(in_place_op, deep_copy));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Conservatively adds copies before root instruction of entry computation and
|
||||
// each aliased parameter to resolve interference of aliased input and output
|
||||
// buffer. We later rely on RemoveUnnecessaryCopies to drop the unnecessary
|
||||
@ -522,12 +509,6 @@ class CopyRemover {
|
||||
// value. The map is used to construct the copy info map below.
|
||||
absl::flat_hash_map<const HloValue*, ValueNode*> value_to_node;
|
||||
for (const HloBuffer& buffer : alias_analysis.buffers()) {
|
||||
// No copies should have been inserted within fused computations, so no
|
||||
// need to remove them. HloOrdering isn't compatible with HloValues inside
|
||||
// fusions, so skip copy removal for them.
|
||||
if (buffer.values().at(0)->defining_instruction()->IsFused()) {
|
||||
continue;
|
||||
}
|
||||
// Verify values contained in the buffer are strictly ordered. This
|
||||
// should always be the case after adding copies to eliminate
|
||||
// interference. Specifically, the addition of the control flow edges
|
||||
@ -610,7 +591,7 @@ class CopyRemover {
|
||||
void CreateCopyMap(
|
||||
const HloModule& module,
|
||||
const absl::flat_hash_map<const HloValue*, ValueNode*>& value_to_node) {
|
||||
for (HloComputation* computation : module.MakeNonfusionComputations()) {
|
||||
for (HloComputation* computation : module.computations()) {
|
||||
for (HloInstruction* instruction : computation->instructions()) {
|
||||
// Add copies with unambiguous source values to the map. Copies with
|
||||
// ambiguous sources are not removable.
|
||||
@ -1024,7 +1005,7 @@ Status CopyInsertion::AddCopiesToResolveInterference(HloModule* module) {
|
||||
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
|
||||
HloAliasAnalysis::Run(module, can_share_buffer_));
|
||||
|
||||
for (HloComputation* computation : module->MakeNonfusionComputations()) {
|
||||
for (HloComputation* computation : module->MakeComputationPostOrder()) {
|
||||
for (HloInstruction* instruction :
|
||||
computation->MakeInstructionPostOrder()) {
|
||||
if (instruction->opcode() == HloOpcode::kWhile) {
|
||||
@ -1032,13 +1013,6 @@ Status CopyInsertion::AddCopiesToResolveInterference(HloModule* module) {
|
||||
} else if (instruction->opcode() == HloOpcode::kConditional) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
AddCopiesForConditional(*alias_analysis, instruction));
|
||||
} else {
|
||||
for (const auto& operand_number_and_output_index :
|
||||
HloAliasAnalysis::GetInPlaceInputOutputPairs(instruction)) {
|
||||
TF_RETURN_IF_ERROR(AddCopiesForInPlaceOperation(
|
||||
*alias_analysis, instruction,
|
||||
operand_number_and_output_index.first));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -2530,250 +2530,5 @@ ENTRY Entry {
|
||||
EXPECT_EQ(CountCopies(*module), 1);
|
||||
}
|
||||
|
||||
TEST_F(CopyInsertionTest, DynamicUpdateSliceNoCopy) {
|
||||
absl::string_view hlo_string = R"(
|
||||
HloModule Module
|
||||
|
||||
ENTRY main {
|
||||
param = f32[1280,1,128] parameter(0)
|
||||
negate = f32[1280,1,128] negate(param)
|
||||
constant.1 = f32[] constant(0)
|
||||
broadcast.6 = f32[128,1,128] broadcast(constant.1), dimensions={}
|
||||
constant.3 = s32[] constant(0)
|
||||
ROOT dynamic-update-slice.5 = f32[1280,1,128] dynamic-update-slice(negate, broadcast.6, constant.3, constant.3, constant.3)
|
||||
}
|
||||
)";
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
||||
ParseAndReturnVerifiedModule(hlo_string));
|
||||
InsertCopies(module.get());
|
||||
EXPECT_EQ(CountCopies(*module), 0);
|
||||
}
|
||||
|
||||
TEST_F(CopyInsertionTest, FusedDynamicUpdateSliceNoCopy) {
|
||||
absl::string_view hlo_string = R"(
|
||||
HloModule Module
|
||||
|
||||
fused_computation {
|
||||
param0 = f32[1280,1,128] parameter(0)
|
||||
constant.1 = f32[] constant(0)
|
||||
broadcast.6 = f32[128,1,128] broadcast(constant.1), dimensions={}
|
||||
constant.3 = s32[] constant(0)
|
||||
ROOT dynamic-update-slice.5 = f32[1280,1,128] dynamic-update-slice(param0, broadcast.6, constant.3, constant.3, constant.3)
|
||||
}
|
||||
|
||||
ENTRY main {
|
||||
param = f32[1280,1,128] parameter(0)
|
||||
negate = f32[1280,1,128] negate(param)
|
||||
ROOT fusion = f32[1280,1,128] fusion(negate), kind=kLoop, calls=fused_computation
|
||||
}
|
||||
)";
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
||||
ParseAndReturnVerifiedModule(hlo_string));
|
||||
InsertCopies(module.get());
|
||||
EXPECT_EQ(CountCopies(*module), 0);
|
||||
}
|
||||
|
||||
TEST_F(CopyInsertionTest, DynamicUpdateSliceCopy) {
|
||||
absl::string_view hlo_string = R"(
|
||||
HloModule Module
|
||||
|
||||
ENTRY main {
|
||||
param = f32[1280,1,128] parameter(0)
|
||||
negate = f32[1280,1,128] negate(param)
|
||||
constant.1 = f32[] constant(0)
|
||||
broadcast.6 = f32[128,1,128] broadcast(constant.1), dimensions={}
|
||||
constant.3 = s32[] constant(0)
|
||||
add = f32[1280,1,128] add(negate, negate)
|
||||
dynamic-update-slice.5 = f32[1280,1,128] dynamic-update-slice(negate, broadcast.6, constant.3, constant.3, constant.3)
|
||||
ROOT tuple = (f32[1280,1,128], f32[1280,1,128]) tuple(add, dynamic-update-slice.5)
|
||||
}
|
||||
)";
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
||||
ParseAndReturnVerifiedModule(hlo_string));
|
||||
InsertCopies(module.get());
|
||||
EXPECT_EQ(CountCopies(*module), 1);
|
||||
}
|
||||
|
||||
TEST_F(CopyInsertionTest, DynamicUpdateSliceParameterShareCopy) {
|
||||
absl::string_view hlo_string = R"(
|
||||
HloModule Module
|
||||
|
||||
ENTRY main {
|
||||
param = f32[1280,1,128] parameter(0)
|
||||
constant.1 = f32[] constant(0)
|
||||
broadcast.6 = f32[128,1,128] broadcast(constant.1), dimensions={}
|
||||
constant.3 = s32[] constant(0)
|
||||
ROOT dynamic-update-slice.5 = f32[1280,1,128] dynamic-update-slice(param, broadcast.6, constant.3, constant.3, constant.3)
|
||||
}
|
||||
)";
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
||||
ParseAndReturnVerifiedModule(hlo_string));
|
||||
InsertCopies(module.get());
|
||||
EXPECT_EQ(CountCopies(*module), 1);
|
||||
}
|
||||
|
||||
TEST_F(CopyInsertionTest, FusedDynamicUpdateSliceCopy) {
|
||||
absl::string_view hlo_string = R"(
|
||||
HloModule Module
|
||||
|
||||
fused_computation {
|
||||
param0 = f32[1280,1,128] parameter(0)
|
||||
constant.1 = f32[] constant(0)
|
||||
broadcast.6 = f32[128,1,128] broadcast(constant.1), dimensions={}
|
||||
constant.3 = s32[] constant(0)
|
||||
ROOT dynamic-update-slice.5 = f32[1280,1,128] dynamic-update-slice(param0, broadcast.6, constant.3, constant.3, constant.3)
|
||||
}
|
||||
|
||||
ENTRY main {
|
||||
param = f32[1280,1,128] parameter(0)
|
||||
negate = f32[1280,1,128] negate(param)
|
||||
add = f32[1280,1,128] add(negate, negate)
|
||||
fusion = f32[1280,1,128] fusion(negate), kind=kLoop, calls=fused_computation
|
||||
ROOT tuple = (f32[1280,1,128], f32[1280,1,128]) tuple(negate, fusion)
|
||||
}
|
||||
)";
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
||||
ParseAndReturnVerifiedModule(hlo_string));
|
||||
InsertCopies(module.get());
|
||||
EXPECT_EQ(CountCopies(*module), 1);
|
||||
}
|
||||
|
||||
TEST_F(CopyInsertionTest, ChainDynamicUpdateSliceCopy) {
|
||||
absl::string_view hlo_string = R"(
|
||||
HloModule Module
|
||||
|
||||
ENTRY main {
|
||||
state = (s32[], f32[1280,1,128]{2,1,0}) parameter(0)
|
||||
constant.1 = f32[] constant(0)
|
||||
broadcast.6 = f32[128,1,128]{2,1,0} broadcast(constant.1), dimensions={}
|
||||
get-tuple-element.4 = f32[1280,1,128]{2,1,0} get-tuple-element(state), index=1
|
||||
get-tuple-element.3 = s32[] get-tuple-element(state), index=0
|
||||
constant.2 = s32[] constant(128)
|
||||
add.5 = s32[] add(get-tuple-element.3, constant.2)
|
||||
constant.3 = s32[] constant(0)
|
||||
dynamic-update-slice.5 = f32[1280,1,128]{2,1,0} dynamic-update-slice(get-tuple-element.4, broadcast.6, constant.3, constant.3, constant.3)
|
||||
dynamic-update-slice.9 = f32[1280,1,128]{2,1,0} dynamic-update-slice(dynamic-update-slice.5, broadcast.6, constant.3, constant.3, constant.3)
|
||||
ROOT tuple.85 = (s32[], s32[], s32[2]{0}, f32[1280,1,128]{2,1,0}) tuple(add.5, dynamic-update-slice.9)
|
||||
}
|
||||
)";
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
||||
ParseAndReturnVerifiedModule(hlo_string));
|
||||
InsertCopies(module.get());
|
||||
EXPECT_EQ(CountCopies(*module), 1);
|
||||
}
|
||||
|
||||
TEST_F(CopyInsertionTest, FusedDynamicUpdateSliceCopy2) {
|
||||
absl::string_view hlo_string = R"(
|
||||
HloModule Module
|
||||
|
||||
fused_computation.1 {
|
||||
param0 = f32[1280,1,128] parameter(0)
|
||||
constant.1 = f32[] constant(0)
|
||||
broadcast.6 = f32[128,1,128] broadcast(constant.1), dimensions={}
|
||||
constant.3 = s32[] constant(0)
|
||||
ROOT dynamic-update-slice.5 = f32[1280,1,128] dynamic-update-slice(param0, broadcast.6, constant.3, constant.3, constant.3)
|
||||
}
|
||||
|
||||
fused_computation.2 {
|
||||
param0 = f32[1280,1,128] parameter(0)
|
||||
param1 = f32[1280,1,128] parameter(1)
|
||||
slice = f32[128,1,128] slice(param1), slice={[0:128], [0:1], [0:128]}
|
||||
constant.3 = s32[] constant(0)
|
||||
ROOT dynamic-update-slice.5 = f32[1280,1,128] dynamic-update-slice(param0, slice, constant.3, constant.3, constant.3)
|
||||
}
|
||||
|
||||
ENTRY main {
|
||||
param = f32[1280,1,128] parameter(0)
|
||||
negate = f32[1280,1,128] negate(param)
|
||||
add = f32[1280,1,128] add(negate, negate)
|
||||
fusion1 = f32[1280,1,128] fusion(negate), kind=kLoop, calls=fused_computation.1
|
||||
ROOT fusion2 = f32[1280,1,128] fusion(fusion1, negate), kind=kLoop, calls=fused_computation.2
|
||||
}
|
||||
)";
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
||||
ParseAndReturnVerifiedModule(hlo_string));
|
||||
InsertCopies(module.get());
|
||||
EXPECT_EQ(CountCopies(*module), 1);
|
||||
}
|
||||
|
||||
TEST_F(CopyInsertionTest, MultiOutputFusedDynamicUpdateSliceCopy) {
|
||||
// Tests multi-output fusion with two DUS outputs, requiring two copies.
|
||||
absl::string_view hlo_string = R"(
|
||||
HloModule Module
|
||||
|
||||
fused_computation {
|
||||
param0 = f32[1280,1,128] parameter(0)
|
||||
param1 = f32[1280,1,128] parameter(1)
|
||||
param2 = f32[1280,1,128] parameter(2)
|
||||
constant.1 = f32[] constant(0)
|
||||
broadcast.6 = f32[128,1,128] broadcast(constant.1), dimensions={}
|
||||
constant.3 = s32[] constant(0)
|
||||
add.1 = f32[1280,1,128] add(param0, param0)
|
||||
dynamic-update-slice.5 = f32[1280,1,128] dynamic-update-slice(param1, broadcast.6, constant.3, constant.3, constant.3)
|
||||
dynamic-update-slice.6 = f32[1280,1,128] dynamic-update-slice(param2, broadcast.6, constant.3, constant.3, constant.3)
|
||||
ROOT tuple.1 = (f32[1280,1,128], f32[1280,1,128], f32[1280,1,128]) tuple(add.1, dynamic-update-slice.5, dynamic-update-slice.6)
|
||||
}
|
||||
|
||||
ENTRY main {
|
||||
param = f32[1280,1,128] parameter(0)
|
||||
negate0 = f32[1280,1,128] negate(param)
|
||||
negate1 = f32[1280,1,128] negate(param)
|
||||
negate2 = f32[1280,1,128] negate(param)
|
||||
fusion = (f32[1280,1,128], f32[1280,1,128], f32[1280,1,128]) fusion(negate0, negate1, negate2), kind=kLoop, calls=fused_computation
|
||||
gte0 = f32[1280,1,128] get-tuple-element(fusion), index=0
|
||||
gte1 = f32[1280,1,128] get-tuple-element(fusion), index=1
|
||||
gte2 = f32[1280,1,128] get-tuple-element(fusion), index=2
|
||||
add0 = f32[1280,1,128] add(negate0, gte0)
|
||||
add1 = f32[1280,1,128] add(negate1, gte1)
|
||||
add2 = f32[1280,1,128] add(negate2, gte2)
|
||||
ROOT tuple = (f32[1280,1,128], f32[1280,1,128], f32[1280,1,128]) tuple(add0, add1, add2)
|
||||
}
|
||||
)";
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
||||
ParseAndReturnVerifiedModule(hlo_string));
|
||||
InsertCopies(module.get());
|
||||
EXPECT_EQ(CountCopies(*module), 2);
|
||||
}
|
||||
|
||||
TEST_F(CopyInsertionTest, MultiOutputFusedDynamicUpdateSliceNoCopy) {
|
||||
// Same as above, but negate1 is not used beyond fusion, so it only needs one
|
||||
// copy for negate0.
|
||||
absl::string_view hlo_string = R"(
|
||||
HloModule Module
|
||||
|
||||
fused_computation {
|
||||
param0 = f32[1280,1,128] parameter(0)
|
||||
param1 = f32[1280,1,128] parameter(1)
|
||||
param2 = f32[1280,1,128] parameter(2)
|
||||
constant.1 = f32[] constant(0)
|
||||
broadcast.6 = f32[128,1,128] broadcast(constant.1), dimensions={}
|
||||
constant.3 = s32[] constant(0)
|
||||
add.1 = f32[1280,1,128] add(param0, param0)
|
||||
dynamic-update-slice.5 = f32[1280,1,128] dynamic-update-slice(param1, broadcast.6, constant.3, constant.3, constant.3)
|
||||
dynamic-update-slice.6 = f32[1280,1,128] dynamic-update-slice(param2, broadcast.6, constant.3, constant.3, constant.3)
|
||||
ROOT tuple.1 = (f32[1280,1,128], f32[1280,1,128], f32[1280,1,128]) tuple(add.1, dynamic-update-slice.5, dynamic-update-slice.6)
|
||||
}
|
||||
|
||||
ENTRY main {
|
||||
param = f32[1280,1,128] parameter(0)
|
||||
negate0 = f32[1280,1,128] negate(param)
|
||||
negate1 = f32[1280,1,128] negate(param)
|
||||
negate2 = f32[1280,1,128] negate(param)
|
||||
fusion = (f32[1280,1,128], f32[1280,1,128], f32[1280,1,128]) fusion(negate0, negate1, negate2), kind=kLoop, calls=fused_computation
|
||||
gte0 = f32[1280,1,128] get-tuple-element(fusion), index=0
|
||||
gte1 = f32[1280,1,128] get-tuple-element(fusion), index=1
|
||||
gte2 = f32[1280,1,128] get-tuple-element(fusion), index=2
|
||||
add0 = f32[1280,1,128] add(negate0, gte0)
|
||||
add1 = f32[1280,1,128] add(gte1, gte1)
|
||||
add2 = f32[1280,1,128] add(negate2, gte2)
|
||||
ROOT tuple = (f32[1280,1,128], f32[1280,1,128], f32[1280,1,128]) tuple(add0, add1, add2)
|
||||
}
|
||||
)";
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
||||
ParseAndReturnVerifiedModule(hlo_string));
|
||||
InsertCopies(module.get());
|
||||
EXPECT_EQ(CountCopies(*module), 1);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace xla
|
||||
|
@ -1,6 +1,6 @@
|
||||
// RUN: hlo_to_llvm_ir %s | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: define void @scatter_TensorFlowScatterV1(i8* noalias align 16 dereferenceable(36) %alloc0, i8* noalias align 16 dereferenceable(24) %alloc1, i8* noalias align 16 dereferenceable(8) %alloc2) {
|
||||
// CHECK-LABEL: define void @scatter_TensorFlowScatterV1(i8* noalias align 64 dereferenceable(36) %alloc0, i8* noalias align 16 dereferenceable(36) %alloc1, i8* noalias align 16 dereferenceable(24) %alloc2, i8* noalias align 16 dereferenceable(8) %alloc3) {
|
||||
// CHECK: entry:
|
||||
// CHECK: %[[VAL_32:.*]] = alloca i32, align 4
|
||||
// CHECK: %[[VAL_0:.*]] = getelementptr inbounds i8, i8* %[[VAL_1:.*]], i64 0
|
||||
@ -43,8 +43,8 @@
|
||||
// CHECK: store atomic i32 %[[VAL_36]], i32* %[[VAL_31]] unordered, align 4
|
||||
// CHECK: br label %[[VAL_23]]
|
||||
// CHECK: !nvvm.annotations = !{!0, !1}
|
||||
// CHECK: !0 = !{void (i8*, i8*, i8*)* @scatter_TensorFlowScatterV1, !"kernel", i32 1}
|
||||
// CHECK: !1 = !{void (i8*, i8*, i8*)* @scatter_TensorFlowScatterV1, !"reqntidx", i32 6}
|
||||
// CHECK: !0 = !{void (i8*, i8*, i8*, i8*)* @scatter_TensorFlowScatterV1, !"kernel", i32 1}
|
||||
// CHECK: !1 = !{void (i8*, i8*, i8*, i8*)* @scatter_TensorFlowScatterV1, !"reqntidx", i32 6}
|
||||
// CHECK: !2 = !{i32 0, i32 1}
|
||||
// CHECK: !3 = !{i32 0, i32 6}
|
||||
// CHECK: !4 = !{}
|
||||
@ -72,7 +72,7 @@ ENTRY main {
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: define void @scatter_ScatterIntoScalar(i8* noalias align 16 dereferenceable(4) %alloc0, i8* noalias align 16 dereferenceable(4) %alloc1, i8* noalias align 16 %alloc2) {
|
||||
// CHECK-LABEL: define void @scatter_ScatterIntoScalar(i8* noalias align 64 dereferenceable(4) %alloc0, i8* noalias align 16 dereferenceable(4) %alloc1, i8* noalias align 16 dereferenceable(4) %alloc2, i8* noalias align 16 %alloc3) {
|
||||
// CHECK: entry:
|
||||
// CHECK: %[[VAL_60:.*]] = alloca i32, align 4
|
||||
// CHECK: %[[VAL_37:.*]] = getelementptr inbounds i8, i8* %[[VAL_38:.*]], i64 0
|
||||
@ -104,8 +104,8 @@ ENTRY main {
|
||||
// CHECK: store atomic i32 %[[VAL_62]], i32* %[[VAL_39]] unordered, align 4
|
||||
// CHECK: br label %[[VAL_57]]
|
||||
// CHECK: !nvvm.annotations = !{!0, !1}
|
||||
// CHECK: !0 = !{void (i8*, i8*, i8*)* @scatter_ScatterIntoScalar, !"kernel", i32 1}
|
||||
// CHECK: !1 = !{void (i8*, i8*, i8*)* @scatter_ScatterIntoScalar, !"reqntidx", i32 1}
|
||||
// CHECK: !0 = !{void (i8*, i8*, i8*, i8*)* @scatter_ScatterIntoScalar, !"kernel", i32 1}
|
||||
// CHECK: !1 = !{void (i8*, i8*, i8*, i8*)* @scatter_ScatterIntoScalar, !"reqntidx", i32 1}
|
||||
// CHECK: !2 = !{i32 0, i32 1}
|
||||
// CHECK: !3 = !{}
|
||||
|
||||
@ -131,7 +131,7 @@ ENTRY main {
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: define void @scatter_TensorFlowScatter_Mul(i8* noalias align 16 dereferenceable(36) %alloc0, i8* noalias align 16 dereferenceable(24) %alloc1, i8* noalias align 16 dereferenceable(8) %alloc2) {
|
||||
// CHECK-LABEL: define void @scatter_TensorFlowScatter_Mul(i8* noalias align 64 dereferenceable(36) %alloc0, i8* noalias align 16 dereferenceable(36) %alloc1, i8* noalias align 16 dereferenceable(24) %alloc2, i8* noalias align 16 dereferenceable(8) %alloc3) {
|
||||
// CHECK: %[[VAL_63:.*]] = alloca i32, align 4
|
||||
// CHECK: %[[VAL_64:.*]] = alloca i32, align 4
|
||||
// CHECK: %[[VAL_98:.*]] = alloca i32, align 4
|
||||
@ -188,8 +188,8 @@ ENTRY main {
|
||||
// CHECK: %[[VAL_109:.*]] = extractvalue { i32, i1 } %[[VAL_107]], 1
|
||||
// CHECK: br i1 %[[VAL_109]], label %[[VAL_96]], label %[[VAL_104]]
|
||||
// CHECK: !nvvm.annotations = !{!0, !1}
|
||||
// CHECK: !0 = !{void (i8*, i8*, i8*)* @scatter_TensorFlowScatter_Mul, !"kernel", i32 1}
|
||||
// CHECK: !1 = !{void (i8*, i8*, i8*)* @scatter_TensorFlowScatter_Mul, !"reqntidx", i32 6}
|
||||
// CHECK: !0 = !{void (i8*, i8*, i8*, i8*)* @scatter_TensorFlowScatter_Mul, !"kernel", i32 1}
|
||||
// CHECK: !1 = !{void (i8*, i8*, i8*, i8*)* @scatter_TensorFlowScatter_Mul, !"reqntidx", i32 6}
|
||||
// CHECK: !2 = !{i32 0, i32 1}
|
||||
// CHECK: !3 = !{i32 0, i32 6}
|
||||
// CHECK: !4 = !{}
|
||||
@ -216,7 +216,7 @@ ENTRY main {
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: define void @scatter_ScalarUpdate(i8* noalias align 16 dereferenceable(16) %alloc0, i8* noalias align 16 dereferenceable(4) %alloc1, i8* noalias align 16 dereferenceable(4) %alloc2) {
|
||||
// CHECK-LABEL: define void @scatter_ScalarUpdate(i8* noalias align 64 dereferenceable(16) %alloc0, i8* noalias align 16 dereferenceable(16) %alloc1, i8* noalias align 16 dereferenceable(4) %alloc2, i8* noalias align 16 dereferenceable(4) %alloc3) {
|
||||
// CHECK: entry:
|
||||
// CHECK: %[[VAL_146:.*]] = alloca i32, align 4
|
||||
// CHECK: %[[VAL_118:.*]] = getelementptr inbounds i8, i8* %[[VAL_119:.*]], i64 0
|
||||
@ -253,8 +253,8 @@ ENTRY main {
|
||||
// CHECK: store atomic i32 %[[VAL_148]], i32* %[[VAL_145]] unordered, align 4
|
||||
// CHECK: br label %[[VAL_138]]
|
||||
// CHECK: !nvvm.annotations = !{!0, !1}
|
||||
// CHECK: !0 = !{void (i8*, i8*, i8*)* @scatter_ScalarUpdate, !"kernel", i32 1}
|
||||
// CHECK: !1 = !{void (i8*, i8*, i8*)* @scatter_ScalarUpdate, !"reqntidx", i32 1}
|
||||
// CHECK: !0 = !{void (i8*, i8*, i8*, i8*)* @scatter_ScalarUpdate, !"kernel", i32 1}
|
||||
// CHECK: !1 = !{void (i8*, i8*, i8*, i8*)* @scatter_ScalarUpdate, !"reqntidx", i32 1}
|
||||
// CHECK: !2 = !{i32 0, i32 1}
|
||||
// CHECK: !3 = !{}
|
||||
|
||||
|
@ -308,39 +308,6 @@ class BufferValueMap {
|
||||
}
|
||||
}
|
||||
|
||||
void ComputeInPlaceOperationAliasedBuffers(
|
||||
const HloValue& value, std::vector<BufferNumber>* aliased_buffers) {
|
||||
VLOG(3) << "Compute aliases for in-place operations (e.g. "
|
||||
"kDynamicUpdateSlice and kScatter)";
|
||||
for (const HloPosition& position : value.positions()) {
|
||||
HloInstruction* instruction = position.instruction;
|
||||
for (const auto& operand_number_and_output_index :
|
||||
HloAliasAnalysis::GetInPlaceInputOutputPairs(instruction)) {
|
||||
if (position.index == operand_number_and_output_index.second) {
|
||||
int64 operand_number = operand_number_and_output_index.first;
|
||||
const HloValue& operand_value = dataflow_.GetUniqueValueAt(
|
||||
instruction->operand(operand_number), {});
|
||||
VLOG(3) << " operand value " << operand_value.ToShortString()
|
||||
<< " aliases.";
|
||||
aliased_buffers->push_back(GetBufferForValue(operand_value));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (const HloUse& use : value.uses()) {
|
||||
for (const auto& operand_number_and_output_index :
|
||||
HloAliasAnalysis::GetInPlaceInputOutputPairs(use.instruction)) {
|
||||
int64 operand_number = operand_number_and_output_index.first;
|
||||
if (use.operand_number == operand_number) {
|
||||
const HloValue& use_value = dataflow_.GetUniqueValueAt(
|
||||
use.instruction, operand_number_and_output_index.second);
|
||||
VLOG(3) << " use value " << use_value.ToShortString() << " aliases.";
|
||||
aliased_buffers->push_back(GetBufferForValue(use_value));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Compute and return a vector of buffers that the given value must be
|
||||
// contained in due to HLO aliasing rules.
|
||||
std::vector<BufferNumber> ComputeAliasedBuffers(const HloValue& value) {
|
||||
@ -351,7 +318,6 @@ class BufferValueMap {
|
||||
ComputeInputOutputAliasedBuffers(value, &aliased_buffers);
|
||||
ComputeWhileAliasedBuffers(value, &aliased_buffers);
|
||||
ComputeConditionalAliasedBuffers(value, &aliased_buffers);
|
||||
ComputeInPlaceOperationAliasedBuffers(value, &aliased_buffers);
|
||||
// Uniquify aliased buffers.
|
||||
absl::c_sort(aliased_buffers);
|
||||
aliased_buffers.erase(
|
||||
@ -376,42 +342,6 @@ class BufferValueMap {
|
||||
BufferNumber next_buffer_number_ = 0;
|
||||
};
|
||||
|
||||
/*static*/ bool HloAliasAnalysis::IsInPlaceOperation(HloOpcode opcode) {
|
||||
return opcode == HloOpcode::kDynamicUpdateSlice ||
|
||||
opcode == HloOpcode::kScatter;
|
||||
}
|
||||
|
||||
/*static*/ std::vector<std::pair<int64, ShapeIndex>>
|
||||
HloAliasAnalysis::GetInPlaceInputOutputPairs(
|
||||
const HloInstruction* instruction) {
|
||||
if (IsInPlaceOperation(instruction->opcode())) {
|
||||
return {{0, {}}};
|
||||
} else if (instruction->opcode() != HloOpcode::kFusion) {
|
||||
return {};
|
||||
}
|
||||
std::vector<std::pair<int64, ShapeIndex>> input_output_pairs;
|
||||
for (auto& indexed_shape : ShapeUtil::GetLeafShapes(instruction->shape())) {
|
||||
const HloInstruction* hlo_generating_output =
|
||||
instruction->fused_expression_root();
|
||||
for (int64 i = 0; i < indexed_shape.index.size(); ++i) {
|
||||
if (hlo_generating_output->opcode() == HloOpcode::kTuple) {
|
||||
hlo_generating_output =
|
||||
hlo_generating_output->operand(indexed_shape.index[i]);
|
||||
} else {
|
||||
CHECK_EQ(i, indexed_shape.index.size() - 1);
|
||||
}
|
||||
}
|
||||
|
||||
if (IsInPlaceOperation(hlo_generating_output->opcode()) &&
|
||||
hlo_generating_output->operand(0)->opcode() == HloOpcode::kParameter) {
|
||||
input_output_pairs.emplace_back(
|
||||
hlo_generating_output->operand(0)->parameter_number(),
|
||||
indexed_shape.index);
|
||||
}
|
||||
}
|
||||
return input_output_pairs;
|
||||
}
|
||||
|
||||
HloAliasAnalysis::HloAliasAnalysis(const HloModule* module) : module_(module) {}
|
||||
|
||||
const HloBuffer& HloAliasAnalysis::GetUniqueBufferAt(
|
||||
|
@ -120,15 +120,6 @@ class HloAliasAnalysis {
|
||||
return results;
|
||||
}
|
||||
|
||||
// Returns true if the operation is an in-place operation and its operand 0
|
||||
// must alias with the output.
|
||||
static bool IsInPlaceOperation(HloOpcode opcode);
|
||||
|
||||
// Returns a vector consisting of operand number and output shape index of the
|
||||
// in-place operations within this HLO.
|
||||
static std::vector<std::pair<int64, ShapeIndex>> GetInPlaceInputOutputPairs(
|
||||
const HloInstruction* instruction);
|
||||
|
||||
protected:
|
||||
explicit HloAliasAnalysis(const HloModule* module);
|
||||
|
||||
|
@ -1062,118 +1062,6 @@ TEST_F(HloAliasAnalysisTest, MergeBuffersReverse) {
|
||||
analysis.BufferLivesOut(analysis.buffers()[0]);
|
||||
}
|
||||
|
||||
TEST_F(HloAliasAnalysisTest, DynamicUpdateSlice) {
|
||||
Shape shape = ShapeUtil::MakeShape(F32, {8});
|
||||
Shape update_shape = ShapeUtil::MakeShape(F32, {4});
|
||||
Shape index_shape = ShapeUtil::MakeShape(S32, {});
|
||||
auto builder = HloComputation::Builder(TestName());
|
||||
auto param0 = builder.AddInstruction(
|
||||
HloInstruction::CreateParameter(0, shape, "param0"));
|
||||
auto param1 = builder.AddInstruction(
|
||||
HloInstruction::CreateParameter(1, update_shape, "param1"));
|
||||
auto param2 = builder.AddInstruction(
|
||||
HloInstruction::CreateParameter(2, index_shape, "param2"));
|
||||
auto copy0 = builder.AddInstruction(
|
||||
HloInstruction::CreateUnary(shape, HloOpcode::kCopy, param0));
|
||||
auto dynamic_update_slice = builder.AddInstruction(
|
||||
HloInstruction::CreateDynamicUpdateSlice(shape, copy0, param1, {param2}));
|
||||
|
||||
module_->AddEntryComputation(builder.Build());
|
||||
SCOPED_TRACE(module_->ToString());
|
||||
|
||||
HloAliasAnalysis& analysis = RunAnalysis();
|
||||
|
||||
EXPECT_EQ(analysis.GetUniqueBufferAt(copy0),
|
||||
analysis.GetUniqueBufferAt(dynamic_update_slice));
|
||||
}
|
||||
|
||||
TEST_F(HloAliasAnalysisTest, DynamicUpdateSliceMultiOutputFusion) {
|
||||
absl::string_view hlo_string = R"(
|
||||
HloModule Module
|
||||
|
||||
fused_computation {
|
||||
param0 = f32[1280,1,128] parameter(0)
|
||||
param1 = f32[1280,1,128] parameter(1)
|
||||
param2 = f32[1280,1,128] parameter(2)
|
||||
constant.1 = f32[] constant(0)
|
||||
broadcast.6 = f32[128,1,128] broadcast(constant.1), dimensions={}
|
||||
constant.3 = s32[] constant(0)
|
||||
add.1 = f32[1280,1,128] add(param0, param0)
|
||||
dynamic-update-slice.5 = f32[1280,1,128] dynamic-update-slice(param1, broadcast.6, constant.3, constant.3, constant.3)
|
||||
dynamic-update-slice.6 = f32[1280,1,128] dynamic-update-slice(param2, broadcast.6, constant.3, constant.3, constant.3)
|
||||
ROOT tuple.1 = (f32[1280,1,128], f32[1280,1,128], f32[1280,1,128]) tuple(add.1, dynamic-update-slice.5, dynamic-update-slice.6)
|
||||
}
|
||||
|
||||
ENTRY main {
|
||||
param = f32[1280,1,128] parameter(0)
|
||||
negate0 = f32[1280,1,128] negate(param)
|
||||
negate1 = f32[1280,1,128] negate(param)
|
||||
negate2 = f32[1280,1,128] negate(param)
|
||||
ROOT fusion = (f32[1280,1,128], f32[1280,1,128], f32[1280,1,128]) fusion(negate0, negate1, negate2), kind=kLoop, calls=fused_computation
|
||||
}
|
||||
)";
|
||||
TF_ASSERT_OK_AND_ASSIGN(module_, ParseAndReturnVerifiedModule(hlo_string));
|
||||
|
||||
SCOPED_TRACE(module_->ToString());
|
||||
|
||||
HloAliasAnalysis& analysis = RunAnalysis();
|
||||
LOG(INFO) << analysis.ToString();
|
||||
|
||||
// Expect negate1 and negate2 to alias with fusion{1} and fusion{2}
|
||||
// respectively (due to DUS), but not negate0 and fusion{0}.
|
||||
const HloInstruction* fusion =
|
||||
module_->entry_computation()->GetInstructionWithName("fusion");
|
||||
const HloInstruction* negate0 =
|
||||
module_->entry_computation()->GetInstructionWithName("negate0");
|
||||
const HloInstruction* negate1 =
|
||||
module_->entry_computation()->GetInstructionWithName("negate1");
|
||||
const HloInstruction* negate2 =
|
||||
module_->entry_computation()->GetInstructionWithName("negate2");
|
||||
EXPECT_EQ(analysis.GetUniqueBufferAt(negate1),
|
||||
analysis.GetUniqueBufferAt(fusion, {1}));
|
||||
EXPECT_EQ(analysis.GetUniqueBufferAt(negate2),
|
||||
analysis.GetUniqueBufferAt(fusion, {2}));
|
||||
EXPECT_NE(analysis.GetUniqueBufferAt(negate0),
|
||||
analysis.GetUniqueBufferAt(fusion, {0}));
|
||||
}
|
||||
|
||||
TEST_F(HloAliasAnalysisTest, ChainedDynamicUpdateSliceFusion) {
|
||||
// CPU and GPU backends may generate fusions with dynamic update slices
|
||||
// feeding each other. They expect the fusion to not be in-place if that is
|
||||
// the case.
|
||||
absl::string_view hlo_string = R"(
|
||||
HloModule Module
|
||||
|
||||
fused_computation {
|
||||
param0 = f32[1280,1,128] parameter(0)
|
||||
constant.1 = f32[] constant(0)
|
||||
broadcast.6 = f32[128,1,128] broadcast(constant.1), dimensions={}
|
||||
constant.3 = s32[] constant(0)
|
||||
dynamic-update-slice.5 = f32[1280,1,128] dynamic-update-slice(param0, broadcast.6, constant.3, constant.3, constant.3)
|
||||
ROOT dynamic-update-slice.6 = f32[1280,1,128] dynamic-update-slice(dynamic-update-slice.5, broadcast.6, constant.3, constant.3, constant.3)
|
||||
}
|
||||
|
||||
ENTRY main {
|
||||
param = f32[1280,1,128] parameter(0)
|
||||
negate0 = f32[1280,1,128] negate(param)
|
||||
ROOT fusion = f32[1280,1,128] fusion(negate0), kind=kLoop, calls=fused_computation
|
||||
}
|
||||
)";
|
||||
TF_ASSERT_OK_AND_ASSIGN(module_, ParseAndReturnVerifiedModule(hlo_string));
|
||||
|
||||
SCOPED_TRACE(module_->ToString());
|
||||
|
||||
HloAliasAnalysis& analysis = RunAnalysis();
|
||||
LOG(INFO) << analysis.ToString();
|
||||
|
||||
const HloInstruction* fusion =
|
||||
module_->entry_computation()->GetInstructionWithName("fusion");
|
||||
const HloInstruction* negate0 =
|
||||
module_->entry_computation()->GetInstructionWithName("negate0");
|
||||
EXPECT_NE(analysis.GetUniqueBufferAt(negate0),
|
||||
analysis.GetUniqueBufferAt(fusion));
|
||||
}
|
||||
|
||||
TEST_F(HloAliasAnalysisTest, BitcastInterference) {
|
||||
// A bitcast value simultaneously live with its operand should not cause
|
||||
// interference.
|
||||
|
@ -232,24 +232,6 @@ class MemorySpaceAssignmentTest : public HloTestBase,
|
||||
return copies;
|
||||
}
|
||||
|
||||
int64 GetAlternateMemoryOffset(const PresetAssignments& preset_assignments,
|
||||
const HloInstruction* instruction,
|
||||
const ShapeIndex& index = {}) const {
|
||||
// Returns the offset of the assignment, -1 if it's not in the alternate
|
||||
// memory.
|
||||
const HloModule* module = instruction->parent()->parent();
|
||||
auto alias_analysis = HloAliasAnalysis::Run(module).ValueOrDie();
|
||||
HloBuffer& buffer = alias_analysis->GetUniqueBufferAt(instruction, index);
|
||||
for (auto& pos_and_chunk : preset_assignments.chunks()) {
|
||||
for (auto& value : buffer.values()) {
|
||||
if (pos_and_chunk.first == value->defining_position()) {
|
||||
return pos_and_chunk.second.offset;
|
||||
}
|
||||
}
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
|
||||
std::unique_ptr<HloModule> CreateEvictAndPrefetchModule() {
|
||||
HloComputation::Builder builder(TestName());
|
||||
Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
|
||||
@ -4433,47 +4415,6 @@ TEST_P(MemorySpaceAssignmentTest, Determinism) {
|
||||
}
|
||||
}
|
||||
|
||||
TEST_P(MemorySpaceAssignmentTest, InPlaceOp) {
|
||||
// Tests that in-place ops like DynamicUpdateSlice get the same allocation as
|
||||
// its input.
|
||||
absl::string_view hlo_string = R"(
|
||||
HloModule Module, is_scheduled=true
|
||||
|
||||
fused_computation {
|
||||
param0 = f32[2,3] parameter(0)
|
||||
constant.1 = f32[] constant(0)
|
||||
broadcast = f32[2,1] broadcast(constant.1), dimensions={}
|
||||
constant.3 = s32[] constant(0)
|
||||
ROOT dynamic-update-slice.5 = f32[2,3] dynamic-update-slice(param0, broadcast, constant.3, constant.3)
|
||||
}
|
||||
|
||||
ENTRY main {
|
||||
param = f32[2,3] parameter(0)
|
||||
negate = f32[2,3] negate(param)
|
||||
fusion = f32[2,3] fusion(negate), kind=kLoop, calls=fused_computation
|
||||
ROOT add = f32[2,3] add(fusion, fusion)
|
||||
}
|
||||
)";
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto module,
|
||||
ParseAndReturnVerifiedModule(hlo_string));
|
||||
auto preset_assignments = AssignMemorySpace(module.get());
|
||||
HloInstruction* negate_instruction =
|
||||
module->entry_computation()->GetInstructionWithName("negate");
|
||||
int64 negate_offset =
|
||||
GetAlternateMemoryOffset(*preset_assignments, negate_instruction);
|
||||
HloInstruction* fusion_instruction =
|
||||
module->entry_computation()->GetInstructionWithName("fusion");
|
||||
int64 fusion_offset =
|
||||
GetAlternateMemoryOffset(*preset_assignments, fusion_instruction);
|
||||
// We expect negate and fusion to get the same offsets.
|
||||
EXPECT_EQ(negate_offset, fusion_offset);
|
||||
const bool allocate_across_sequential_calls = GetParam();
|
||||
if (allocate_across_sequential_calls) {
|
||||
EXPECT_NE(negate_offset, -1);
|
||||
}
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(MemorySpaceAssignmentInstantiation,
|
||||
MemorySpaceAssignmentTest,
|
||||
::testing::Values(false, true));
|
||||
|
Loading…
x
Reference in New Issue
Block a user