[XLA] Try committing DUS buffer sharing again with fixes.

There were two sources of failures with the previous CL (cl/331779239):

1- Some models caused the dynamic update slice (DUS) input and output to have
   different precisions. We don't want to disable mixed precision, since the
   update can have a different precision than the input/output and disabling
   this mode has a performance cost. To fix this, bfloat16 propagation now
   forces the DUS input and output to have the same precision.
2- Two models triggered patterns where two dynamic update slices that shared the
   same input were fused into the same multi-output fusion. This is wrong because
   the DUS happens in place, so we can't have DUS multi-output fusions whose
   DUSes use the same operand. We now explicitly check for this pattern and
   disallow it in multi-output fusions.

PiperOrigin-RevId: 332247949
Change-Id: If76f15a0d26e45e256269b2ec51f9c648d2f203f
This commit is contained in:
Berkin Ilbeyi 2020-09-17 09:15:16 -07:00 committed by TensorFlower Gardener
parent 105f48a75e
commit 593bf310d5
17 changed files with 643 additions and 236 deletions

View File

@ -83,6 +83,7 @@ cc_library(
deps = [
":bfloat16_support",
":hlo",
":hlo_dataflow_analysis",
":hlo_pass",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:xla_data_proto_cc",
@ -1684,6 +1685,7 @@ cc_library(
hdrs = ["multi_output_fusion.h"],
deps = [
":hlo",
":hlo_dataflow_analysis",
":hlo_dce",
":hlo_pass",
":hlo_reachability",

View File

@ -18,6 +18,7 @@ limitations under the License.
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/status_macros.h"
@ -159,19 +160,20 @@ Status BFloat16ConversionFoldingVisitor::TryFoldBF16Conversions(
Status BFloat16ConversionFoldingVisitor::DefaultAction(HloInstruction* hlo) {
// Do not fold BF16 conversions for instructions related to tuples, entry and
// exit of a computation, fusion, convert, side-effecting instructions and
// control flow.
if (hlo->opcode() == HloOpcode::kTuple || //
hlo->opcode() == HloOpcode::kGetTupleElement || //
hlo->opcode() == HloOpcode::kConstant || //
hlo->opcode() == HloOpcode::kParameter || //
hlo->opcode() == HloOpcode::kFusion || //
hlo->opcode() == HloOpcode::kBitcastConvert || //
hlo->opcode() == HloOpcode::kConvert || //
hlo->opcode() == HloOpcode::kCall || //
hlo->opcode() == HloOpcode::kCustomCall || //
hlo->opcode() == HloOpcode::kWhile || //
hlo->opcode() == HloOpcode::kConditional || //
// exit of a computation, fusion, convert, side-effecting instructions,
// in-place operations and control flow.
if (hlo->opcode() == HloOpcode::kTuple || //
hlo->opcode() == HloOpcode::kGetTupleElement || //
hlo->opcode() == HloOpcode::kConstant || //
hlo->opcode() == HloOpcode::kParameter || //
hlo->opcode() == HloOpcode::kFusion || //
hlo->opcode() == HloOpcode::kBitcastConvert || //
hlo->opcode() == HloOpcode::kConvert || //
hlo->opcode() == HloOpcode::kCall || //
hlo->opcode() == HloOpcode::kCustomCall || //
hlo->opcode() == HloOpcode::kWhile || //
hlo->opcode() == HloOpcode::kConditional || //
HloDataflowAnalysis::IsInPlaceOperation(hlo->opcode()) || //
hlo->HasSideEffectNoRecurse()) {
return Status::OK();
}

View File

@ -598,6 +598,31 @@ bool BFloat16Propagation::ResolveInconsistencyOfAliasingBuffersHelper(
type = F32;
break;
}
// In order to find aliases due to in-place operations, use
// GetInPlaceInputOutputPairs. Ideally, we'd use HloAliasAnalysis here,
// but this code works with HloModules that aren't ready yet to use
// HloAliasAnalysis (e.g., their computation graphs may not have been
// flattened yet).
for (const auto& operand_and_output_index :
HloDataflowAnalysis::GetInPlaceInputOutputPairs(hlo)) {
if (operand_and_output_index.second == index) {
const HloUse& operand = operand_and_output_index.first;
for (const auto* value :
dataflow_
->GetValueSet(hlo->operand(operand.operand_number),
operand.operand_index)
.values()) {
auto value_type = ValueTypeAfterChange(value);
if (value_type == BF16) {
continue;
}
CHECK_EQ(value_type, F32);
type = F32;
break;
}
}
}
// It's possible that a user has been changed from BF16 to F32
// during this final adjustment pass, so we need to check
// AllUsersConsumeBF16() again.

View File

@ -1156,4 +1156,30 @@ ENTRY entry {
EXPECT_FALSE(PropagatePrecision(module.get()));
}
// Regression test: a dynamic-update-slice (DUS) must keep operand 0 and its
// output at the same precision.
TEST_F(BFloat16PropagationTest, DynamicUpdateSlice) {
// This test is crafted so that the DUS has an f32 input (due to parameter)
// and bf16 output (due to dot). But we should enforce DUS operand 0 and
// output to get the same precision since it's an in-place operation.
const string module_str = R"(
HloModule Module
ENTRY main {
param = f32[128,128] parameter(0)
constant.1 = f32[] constant(0)
broadcast.6 = f32[128,1] broadcast(constant.1), dimensions={}
constant.3 = s32[] constant(0)
dynamic-update-slice = f32[128,128] dynamic-update-slice(param, broadcast.6, constant.3, constant.3, constant.3)
ROOT dot = f32[128,128] dot(dynamic-update-slice, dynamic-update-slice), lhs_contracting_dims={1}, rhs_contracting_dims={0}
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(module_str));
// PropagatePrecision is expected to return false for this module
// (presumably meaning "nothing was changed" -- confirm against the fixture).
EXPECT_FALSE(PropagatePrecision(module.get()));
HloInstruction* dus = module->entry_computation()->GetInstructionWithName(
"dynamic-update-slice");
// The DUS output must not have been switched to BF16.
EXPECT_FALSE(OutputsBF16(dus));
}
} // namespace xla

View File

@ -1007,102 +1007,6 @@ bool BufferAssigner::MaybeAssignBuffer(BufferAllocation* allocation,
return true;
} // namespace xla
// Tries to make dynamic-update-slice operations run in place by merging the
// HLO buffer of the instruction's output with the HLO buffer of its operand 0.
//
// Candidates are top-level (non-fused) kDynamicUpdateSlice instructions and
// kFusion instructions whose fused root is a kDynamicUpdateSlice. A candidate
// is skipped (no merge) as soon as any of the safety checks below fails.
//
// Args:
//   assignment: the assignment whose alias analysis is updated in place.
//
// Returns Status::OK() unconditionally; skipped candidates are not errors.
Status BufferAssigner::MergeInplaceOpBuffers(BufferAssignment* assignment) {
  // Try allocate same buffer for dynamic update slice's operand and output.

  // If memory_space_assignment is run and there is information about a color in
  // preset assignments, don't merge those buffers. We expect
  // memory_space_assignment to have merged these buffers. If
  // memory_space_assignment didn't merge these buffers and have assigned
  // different offsets to the operand and the output buffer, merging the buffers
  // can cause memory corruption if memory_space_assignment assigned a different
  // buffer at the same offset.
  absl::flat_hash_set<int64> excluded_colors;
  if (preset_assignments_) {
    for (const auto& color_and_info :
         preset_assignments_->assignment_informations()) {
      excluded_colors.insert(color_and_info.first);
    }
  }

  // TODO(yunxing): Moving this logic to alias analysis and add must-alias rule
  // to operations that can be done in place.
  for (HloComputation* computation : assignment->module().computations()) {
    for (HloInstruction* instruction : computation->instructions()) {
      if (!(instruction->opcode() == HloOpcode::kDynamicUpdateSlice ||
            (instruction->opcode() == HloOpcode::kFusion &&
             (instruction->fused_expression_root()->opcode() ==
              HloOpcode::kDynamicUpdateSlice)))) {
        continue;
      }
      // Instructions inside a fusion computation are handled through their
      // enclosing fusion instruction, not individually.
      if (instruction->parent()->IsFusionComputation()) {
        continue;
      }
      // Operand 0 is dereferenced below, so it must exist.
      if (instruction->operand_count() == 0) {
        continue;
      }
      // The operand can't share the same buffer with the user based on dataflow
      // analysis.
      if (!assignment->dataflow_analysis().CanShareOperandBufferWithUser(
              instruction->mutable_operand(0), {}, instruction, {})) {
        continue;
      }
      HloBuffer& instruction_buffer =
          assignment->alias_analysis().GetUniqueBufferAt(instruction, {});
      HloBuffer& operand_buffer =
          assignment->alias_analysis().GetUniqueBufferAt(
              instruction->operand(0), {});
      // The instruction or operand color is excluded because it was assigned by
      // memory_space_assignment.
      if (excluded_colors.contains(instruction_buffer.color()) ||
          excluded_colors.contains(operand_buffer.color())) {
        continue;
      }
      // Already have the same buffer. No need to merge those.
      if (instruction_buffer.id() == operand_buffer.id()) {
        continue;
      }
      // Do not perform in-place dynamic update slice if the operand buffer is
      // read-only.
      if (HloBufferIsReadOnly(operand_buffer)) {
        continue;
      }

      // Returns true as soon as any value of the output buffer may interfere
      // with any value of the operand buffer. Returning from the helper stops
      // the whole pairwise scan at the first hit, instead of only leaving the
      // inner loop and continuing to query the remaining pairs.
      const auto buffers_may_interfere = [&]() -> bool {
        for (const HloValue* instruction_value : instruction_buffer.values()) {
          for (const HloValue* operand_value : operand_buffer.values()) {
            if (assignment->hlo_ordering().MayInterfere(
                    *instruction_value, *operand_value,
                    assignment->dataflow_analysis())) {
              return true;
            }
          }
        }
        return false;
      };
      if (buffers_may_interfere()) {
        continue;
      }
      // Buffers that live out of the module are not merged.
      if (assignment->alias_analysis().BufferLivesOut(instruction_buffer)) {
        continue;
      }
      // Only buffers of the same color may be merged.
      if (instruction_buffer.color() != operand_buffer.color()) {
        continue;
      }
      VLOG(3) << "Merging inplace " << instruction_buffer << " and "
              << operand_buffer;
      assignment->alias_analysis().MergeBuffers(instruction_buffer,
                                                operand_buffer);
    }
  }

  return Status::OK();
}
Status BufferAssigner::AssignSingleHloBuffer(
const HloBuffer* hlo_buffer, bool is_thread_local,
absl::flat_hash_map<const HloComputation*,
@ -1654,7 +1558,6 @@ StatusOr<std::unique_ptr<BufferAssignment>> BufferAssigner::CreateAssignment(
VLOG(3) << "After coloring:";
XLA_VLOG_LINES(3,
assignment->alias_analysis().dataflow_analysis().ToString());
TF_RETURN_IF_ERROR(MergeInplaceOpBuffers(assignment.get()));
std::vector<const HloComputation*> thread_local_computations;
std::vector<const HloComputation*> global_computations;

View File

@ -635,10 +635,6 @@ class BufferAssigner {
absl::flat_hash_set<const HloBuffer*>* assigned_buffers,
BufferAssignment* assignment);
// Promotes operations (DUS, scatter) to be done in place: If an operation can
// be done in place, merge its buffer with its operand buffer.
Status MergeInplaceOpBuffers(BufferAssignment* assignment);
// Assigns a single hlo buffer to an HLO allocation.
Status AssignSingleHloBuffer(
const HloBuffer* hlo_buffer, bool is_thread_local,

View File

@ -1925,8 +1925,10 @@ ENTRY main {
TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_text));
HloInstruction* parameter =
m->entry_computation()->GetInstructionWithName("get-tuple-element.4");
HloInstruction* dus =
HloInstruction* dus1 =
m->entry_computation()->GetInstructionWithName("dynamic-update-slice.5");
HloInstruction* dus2 =
m->entry_computation()->GetInstructionWithName("dynamic-update-slice.9");
auto buffers = RunBufferAssignment(m.get());
@ -1934,8 +1936,10 @@ ENTRY main {
const BufferAllocation& parameter_alloc =
GetTopLevelAllocation(*buffers, parameter);
const BufferAllocation& dus_alloc = GetTopLevelAllocation(*buffers, dus);
EXPECT_NE(parameter_alloc, dus_alloc);
const BufferAllocation& dus1_alloc = GetTopLevelAllocation(*buffers, dus1);
EXPECT_EQ(parameter_alloc, dus1_alloc);
const BufferAllocation& dus2_alloc = GetTopLevelAllocation(*buffers, dus2);
EXPECT_EQ(parameter_alloc, dus2_alloc);
}
}

View File

@ -362,6 +362,19 @@ Status AddCopiesForConditional(const HloAliasAnalysis& alias_analysis,
return Status::OK();
}
// Inserts a deep copy of operand `operand_number` of `in_place_op` and rewires
// `in_place_op` to consume that copy instead of the original operand. Copies
// that turn out to be unnecessary are left for RemoveUnnecessaryCopies to
// drop. `alias_analysis` is accepted for signature parity with the sibling
// AddCopiesFor* helpers but is not consulted here.
//
// Returns the error from DeepCopyInstruction or ReplaceUseWith, if any.
Status AddCopiesForInPlaceOperation(const HloAliasAnalysis& alias_analysis,
                                    HloInstruction* in_place_op,
                                    int64 operand_number) {
  VLOG(2) << "Adding copies for in-place operation " << in_place_op->name();

  HloInstruction* original_operand =
      in_place_op->mutable_operand(operand_number);
  HloComputation* enclosing_computation = in_place_op->parent();

  TF_ASSIGN_OR_RETURN(
      HloInstruction * operand_copy,
      enclosing_computation->DeepCopyInstruction(original_operand));

  // Redirect the in-place op from the original operand to its fresh copy.
  return original_operand->ReplaceUseWith(in_place_op, operand_copy);
}
// Conservatively adds copies before root instruction of entry computation and
// each aliased parameter to resolve interference of aliased input and output
// buffer. We later rely on RemoveUnnecessaryCopies to drop the unnecessary
@ -509,6 +522,12 @@ class CopyRemover {
// value. The map is used to construct the copy info map below.
absl::flat_hash_map<const HloValue*, ValueNode*> value_to_node;
for (const HloBuffer& buffer : alias_analysis.buffers()) {
// No copies should have been inserted within fused computations, so no
// need to remove them. HloOrdering isn't compatible with HloValues inside
// fusions, so skip copy removal for them.
if (buffer.values().at(0)->defining_instruction()->IsFused()) {
continue;
}
// Verify values contained in the buffer are strictly ordered. This
// should always be the case after adding copies to eliminate
// interference. Specifically, the addition of the control flow edges
@ -591,7 +610,7 @@ class CopyRemover {
void CreateCopyMap(
const HloModule& module,
const absl::flat_hash_map<const HloValue*, ValueNode*>& value_to_node) {
for (HloComputation* computation : module.computations()) {
for (HloComputation* computation : module.MakeNonfusionComputations()) {
for (HloInstruction* instruction : computation->instructions()) {
// Add copies with unambiguous source values to the map. Copies with
// ambiguous sources are not removable.
@ -1005,7 +1024,7 @@ Status CopyInsertion::AddCopiesToResolveInterference(HloModule* module) {
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
HloAliasAnalysis::Run(module, can_share_buffer_));
for (HloComputation* computation : module->MakeComputationPostOrder()) {
for (HloComputation* computation : module->MakeNonfusionComputations()) {
for (HloInstruction* instruction :
computation->MakeInstructionPostOrder()) {
if (instruction->opcode() == HloOpcode::kWhile) {
@ -1013,6 +1032,15 @@ Status CopyInsertion::AddCopiesToResolveInterference(HloModule* module) {
} else if (instruction->opcode() == HloOpcode::kConditional) {
TF_RETURN_IF_ERROR(
AddCopiesForConditional(*alias_analysis, instruction));
} else {
for (const auto& operand_and_output_index :
HloDataflowAnalysis::GetInPlaceInputOutputPairs(instruction)) {
const HloUse& operand = operand_and_output_index.first;
CHECK_EQ(operand.operand_index, ShapeIndex{})
<< "Support for non-{} shape operand not currently implemented.";
TF_RETURN_IF_ERROR(AddCopiesForInPlaceOperation(
*alias_analysis, instruction, operand.operand_number));
}
}
}
}

View File

@ -2530,5 +2530,250 @@ ENTRY Entry {
EXPECT_EQ(CountCopies(*module), 1);
}
// A top-level dynamic-update-slice whose in-place operand (`negate`) has no
// other user: the module is expected to contain no copies after InsertCopies.
TEST_F(CopyInsertionTest, DynamicUpdateSliceNoCopy) {
absl::string_view hlo_string = R"(
HloModule Module
ENTRY main {
param = f32[1280,1,128] parameter(0)
negate = f32[1280,1,128] negate(param)
constant.1 = f32[] constant(0)
broadcast.6 = f32[128,1,128] broadcast(constant.1), dimensions={}
constant.3 = s32[] constant(0)
ROOT dynamic-update-slice.5 = f32[1280,1,128] dynamic-update-slice(negate, broadcast.6, constant.3, constant.3, constant.3)
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(hlo_string));
InsertCopies(module.get());
EXPECT_EQ(CountCopies(*module), 0);
}
// Same scenario as the unfused no-copy case, but the dynamic-update-slice is
// the root of a loop fusion. The fusion's only operand (`negate`) has no other
// user, and no copies are expected after InsertCopies.
TEST_F(CopyInsertionTest, FusedDynamicUpdateSliceNoCopy) {
absl::string_view hlo_string = R"(
HloModule Module
fused_computation {
param0 = f32[1280,1,128] parameter(0)
constant.1 = f32[] constant(0)
broadcast.6 = f32[128,1,128] broadcast(constant.1), dimensions={}
constant.3 = s32[] constant(0)
ROOT dynamic-update-slice.5 = f32[1280,1,128] dynamic-update-slice(param0, broadcast.6, constant.3, constant.3, constant.3)
}
ENTRY main {
param = f32[1280,1,128] parameter(0)
negate = f32[1280,1,128] negate(param)
ROOT fusion = f32[1280,1,128] fusion(negate), kind=kLoop, calls=fused_computation
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(hlo_string));
InsertCopies(module.get());
EXPECT_EQ(CountCopies(*module), 0);
}
// The in-place operand of the dynamic-update-slice (`negate`) is also read by
// `add`, and both results are returned in the root tuple. Exactly one copy is
// expected to remain after InsertCopies.
TEST_F(CopyInsertionTest, DynamicUpdateSliceCopy) {
absl::string_view hlo_string = R"(
HloModule Module
ENTRY main {
param = f32[1280,1,128] parameter(0)
negate = f32[1280,1,128] negate(param)
constant.1 = f32[] constant(0)
broadcast.6 = f32[128,1,128] broadcast(constant.1), dimensions={}
constant.3 = s32[] constant(0)
add = f32[1280,1,128] add(negate, negate)
dynamic-update-slice.5 = f32[1280,1,128] dynamic-update-slice(negate, broadcast.6, constant.3, constant.3, constant.3)
ROOT tuple = (f32[1280,1,128], f32[1280,1,128]) tuple(add, dynamic-update-slice.5)
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(hlo_string));
InsertCopies(module.get());
EXPECT_EQ(CountCopies(*module), 1);
}
// The dynamic-update-slice operates directly on the entry parameter and is the
// module root. Exactly one copy is expected after InsertCopies (presumably so
// that the parameter's buffer is not overwritten in place -- confirm against
// the copy-insertion rules for entry parameters).
TEST_F(CopyInsertionTest, DynamicUpdateSliceParameterShareCopy) {
absl::string_view hlo_string = R"(
HloModule Module
ENTRY main {
param = f32[1280,1,128] parameter(0)
constant.1 = f32[] constant(0)
broadcast.6 = f32[128,1,128] broadcast(constant.1), dimensions={}
constant.3 = s32[] constant(0)
ROOT dynamic-update-slice.5 = f32[1280,1,128] dynamic-update-slice(param, broadcast.6, constant.3, constant.3, constant.3)
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(hlo_string));
InsertCopies(module.get());
EXPECT_EQ(CountCopies(*module), 1);
}
// Fused variant of the copy-required case: `negate` is the in-place operand of
// the DUS fusion and is also returned directly in the root tuple. Exactly one
// copy is expected after InsertCopies.
// Note: `add` is defined in the entry computation but is not used by the root.
TEST_F(CopyInsertionTest, FusedDynamicUpdateSliceCopy) {
absl::string_view hlo_string = R"(
HloModule Module
fused_computation {
param0 = f32[1280,1,128] parameter(0)
constant.1 = f32[] constant(0)
broadcast.6 = f32[128,1,128] broadcast(constant.1), dimensions={}
constant.3 = s32[] constant(0)
ROOT dynamic-update-slice.5 = f32[1280,1,128] dynamic-update-slice(param0, broadcast.6, constant.3, constant.3, constant.3)
}
ENTRY main {
param = f32[1280,1,128] parameter(0)
negate = f32[1280,1,128] negate(param)
add = f32[1280,1,128] add(negate, negate)
fusion = f32[1280,1,128] fusion(negate), kind=kLoop, calls=fused_computation
ROOT tuple = (f32[1280,1,128], f32[1280,1,128]) tuple(negate, fusion)
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(hlo_string));
InsertCopies(module.get());
EXPECT_EQ(CountCopies(*module), 1);
}
// Two chained dynamic-update-slices: the first updates an element of the entry
// parameter tuple, the second updates the result of the first. Exactly one
// copy is expected after InsertCopies for the whole chain.
TEST_F(CopyInsertionTest, ChainDynamicUpdateSliceCopy) {
  // The declared shape of the ROOT tuple lists exactly the shapes of its two
  // operands (add.5 and dynamic-update-slice.9).
  absl::string_view hlo_string = R"(
HloModule Module
ENTRY main {
state = (s32[], f32[1280,1,128]{2,1,0}) parameter(0)
constant.1 = f32[] constant(0)
broadcast.6 = f32[128,1,128]{2,1,0} broadcast(constant.1), dimensions={}
get-tuple-element.4 = f32[1280,1,128]{2,1,0} get-tuple-element(state), index=1
get-tuple-element.3 = s32[] get-tuple-element(state), index=0
constant.2 = s32[] constant(128)
add.5 = s32[] add(get-tuple-element.3, constant.2)
constant.3 = s32[] constant(0)
dynamic-update-slice.5 = f32[1280,1,128]{2,1,0} dynamic-update-slice(get-tuple-element.4, broadcast.6, constant.3, constant.3, constant.3)
dynamic-update-slice.9 = f32[1280,1,128]{2,1,0} dynamic-update-slice(dynamic-update-slice.5, broadcast.6, constant.3, constant.3, constant.3)
ROOT tuple.85 = (s32[], f32[1280,1,128]{2,1,0}) tuple(add.5, dynamic-update-slice.9)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_string));
  InsertCopies(module.get());
  EXPECT_EQ(CountCopies(*module), 1);
}
// Two DUS fusions in sequence: fusion1 updates `negate` in place, and fusion2
// updates fusion1's result in place while also reading `negate` (as param1).
// `negate` therefore has users besides fusion1, and exactly one copy is
// expected after InsertCopies.
TEST_F(CopyInsertionTest, FusedDynamicUpdateSliceCopy2) {
absl::string_view hlo_string = R"(
HloModule Module
fused_computation.1 {
param0 = f32[1280,1,128] parameter(0)
constant.1 = f32[] constant(0)
broadcast.6 = f32[128,1,128] broadcast(constant.1), dimensions={}
constant.3 = s32[] constant(0)
ROOT dynamic-update-slice.5 = f32[1280,1,128] dynamic-update-slice(param0, broadcast.6, constant.3, constant.3, constant.3)
}
fused_computation.2 {
param0 = f32[1280,1,128] parameter(0)
param1 = f32[1280,1,128] parameter(1)
slice = f32[128,1,128] slice(param1), slice={[0:128], [0:1], [0:128]}
constant.3 = s32[] constant(0)
ROOT dynamic-update-slice.5 = f32[1280,1,128] dynamic-update-slice(param0, slice, constant.3, constant.3, constant.3)
}
ENTRY main {
param = f32[1280,1,128] parameter(0)
negate = f32[1280,1,128] negate(param)
add = f32[1280,1,128] add(negate, negate)
fusion1 = f32[1280,1,128] fusion(negate), kind=kLoop, calls=fused_computation.1
ROOT fusion2 = f32[1280,1,128] fusion(fusion1, negate), kind=kLoop, calls=fused_computation.2
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(hlo_string));
InsertCopies(module.get());
EXPECT_EQ(CountCopies(*module), 1);
}
TEST_F(CopyInsertionTest, MultiOutputFusedDynamicUpdateSliceCopy) {
// Tests multi-output fusion with two DUS outputs, requiring two copies.
// negate1 and negate2 are the in-place DUS operands (param1 and param2 of the
// fusion) and both are read again after the fusion, by add1 and add2.
// negate0 only feeds the non-in-place add.1 inside the fusion.
absl::string_view hlo_string = R"(
HloModule Module
fused_computation {
param0 = f32[1280,1,128] parameter(0)
param1 = f32[1280,1,128] parameter(1)
param2 = f32[1280,1,128] parameter(2)
constant.1 = f32[] constant(0)
broadcast.6 = f32[128,1,128] broadcast(constant.1), dimensions={}
constant.3 = s32[] constant(0)
add.1 = f32[1280,1,128] add(param0, param0)
dynamic-update-slice.5 = f32[1280,1,128] dynamic-update-slice(param1, broadcast.6, constant.3, constant.3, constant.3)
dynamic-update-slice.6 = f32[1280,1,128] dynamic-update-slice(param2, broadcast.6, constant.3, constant.3, constant.3)
ROOT tuple.1 = (f32[1280,1,128], f32[1280,1,128], f32[1280,1,128]) tuple(add.1, dynamic-update-slice.5, dynamic-update-slice.6)
}
ENTRY main {
param = f32[1280,1,128] parameter(0)
negate0 = f32[1280,1,128] negate(param)
negate1 = f32[1280,1,128] negate(param)
negate2 = f32[1280,1,128] negate(param)
fusion = (f32[1280,1,128], f32[1280,1,128], f32[1280,1,128]) fusion(negate0, negate1, negate2), kind=kLoop, calls=fused_computation
gte0 = f32[1280,1,128] get-tuple-element(fusion), index=0
gte1 = f32[1280,1,128] get-tuple-element(fusion), index=1
gte2 = f32[1280,1,128] get-tuple-element(fusion), index=2
add0 = f32[1280,1,128] add(negate0, gte0)
add1 = f32[1280,1,128] add(negate1, gte1)
add2 = f32[1280,1,128] add(negate2, gte2)
ROOT tuple = (f32[1280,1,128], f32[1280,1,128], f32[1280,1,128]) tuple(add0, add1, add2)
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(hlo_string));
InsertCopies(module.get());
EXPECT_EQ(CountCopies(*module), 2);
}
TEST_F(CopyInsertionTest, MultiOutputFusedDynamicUpdateSliceNoCopy) {
// Same as above, but negate1 is not used beyond the fusion (add1 reads gte1
// only), so only one copy is expected: for negate2, the other in-place DUS
// operand, which is still read by add2 after the fusion. negate0 only feeds
// the non-in-place add.1 inside the fusion.
absl::string_view hlo_string = R"(
HloModule Module
fused_computation {
param0 = f32[1280,1,128] parameter(0)
param1 = f32[1280,1,128] parameter(1)
param2 = f32[1280,1,128] parameter(2)
constant.1 = f32[] constant(0)
broadcast.6 = f32[128,1,128] broadcast(constant.1), dimensions={}
constant.3 = s32[] constant(0)
add.1 = f32[1280,1,128] add(param0, param0)
dynamic-update-slice.5 = f32[1280,1,128] dynamic-update-slice(param1, broadcast.6, constant.3, constant.3, constant.3)
dynamic-update-slice.6 = f32[1280,1,128] dynamic-update-slice(param2, broadcast.6, constant.3, constant.3, constant.3)
ROOT tuple.1 = (f32[1280,1,128], f32[1280,1,128], f32[1280,1,128]) tuple(add.1, dynamic-update-slice.5, dynamic-update-slice.6)
}
ENTRY main {
param = f32[1280,1,128] parameter(0)
negate0 = f32[1280,1,128] negate(param)
negate1 = f32[1280,1,128] negate(param)
negate2 = f32[1280,1,128] negate(param)
fusion = (f32[1280,1,128], f32[1280,1,128], f32[1280,1,128]) fusion(negate0, negate1, negate2), kind=kLoop, calls=fused_computation
gte0 = f32[1280,1,128] get-tuple-element(fusion), index=0
gte1 = f32[1280,1,128] get-tuple-element(fusion), index=1
gte2 = f32[1280,1,128] get-tuple-element(fusion), index=2
add0 = f32[1280,1,128] add(negate0, gte0)
add1 = f32[1280,1,128] add(gte1, gte1)
add2 = f32[1280,1,128] add(negate2, gte2)
ROOT tuple = (f32[1280,1,128], f32[1280,1,128], f32[1280,1,128]) tuple(add0, add1, add2)
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(hlo_string));
InsertCopies(module.get());
EXPECT_EQ(CountCopies(*module), 1);
}
} // namespace
} // namespace xla

View File

@ -1,6 +1,6 @@
// RUN: hlo_to_llvm_ir %s | FileCheck %s
// CHECK-LABEL: define void @scatter_TensorFlowScatterV1(i8* noalias align 64 dereferenceable(36) %alloc0, i8* noalias align 16 dereferenceable(36) %alloc1, i8* noalias align 16 dereferenceable(24) %alloc2, i8* noalias align 16 dereferenceable(8) %alloc3) {
// CHECK-LABEL: define void @scatter_TensorFlowScatterV1(i8* noalias align 16 dereferenceable(36) %alloc0, i8* noalias align 16 dereferenceable(24) %alloc1, i8* noalias align 16 dereferenceable(8) %alloc2) {
// CHECK: entry:
// CHECK: %[[VAL_32:.*]] = alloca i32, align 4
// CHECK: %[[VAL_0:.*]] = getelementptr inbounds i8, i8* %[[VAL_1:.*]], i64 0
@ -43,8 +43,8 @@
// CHECK: store atomic i32 %[[VAL_36]], i32* %[[VAL_31]] unordered, align 4
// CHECK: br label %[[VAL_23]]
// CHECK: !nvvm.annotations = !{!0, !1}
// CHECK: !0 = !{void (i8*, i8*, i8*, i8*)* @scatter_TensorFlowScatterV1, !"kernel", i32 1}
// CHECK: !1 = !{void (i8*, i8*, i8*, i8*)* @scatter_TensorFlowScatterV1, !"reqntidx", i32 6}
// CHECK: !0 = !{void (i8*, i8*, i8*)* @scatter_TensorFlowScatterV1, !"kernel", i32 1}
// CHECK: !1 = !{void (i8*, i8*, i8*)* @scatter_TensorFlowScatterV1, !"reqntidx", i32 6}
// CHECK: !2 = !{i32 0, i32 1}
// CHECK: !3 = !{i32 0, i32 6}
// CHECK: !4 = !{}
@ -72,7 +72,7 @@ ENTRY main {
// -----
// CHECK-LABEL: define void @scatter_ScatterIntoScalar(i8* noalias align 64 dereferenceable(4) %alloc0, i8* noalias align 16 dereferenceable(4) %alloc1, i8* noalias align 16 dereferenceable(4) %alloc2, i8* noalias align 16 %alloc3) {
// CHECK-LABEL: define void @scatter_ScatterIntoScalar(i8* noalias align 16 dereferenceable(4) %alloc0, i8* noalias align 16 dereferenceable(4) %alloc1, i8* noalias align 16 %alloc2) {
// CHECK: entry:
// CHECK: %[[VAL_60:.*]] = alloca i32, align 4
// CHECK: %[[VAL_37:.*]] = getelementptr inbounds i8, i8* %[[VAL_38:.*]], i64 0
@ -104,8 +104,8 @@ ENTRY main {
// CHECK: store atomic i32 %[[VAL_62]], i32* %[[VAL_39]] unordered, align 4
// CHECK: br label %[[VAL_57]]
// CHECK: !nvvm.annotations = !{!0, !1}
// CHECK: !0 = !{void (i8*, i8*, i8*, i8*)* @scatter_ScatterIntoScalar, !"kernel", i32 1}
// CHECK: !1 = !{void (i8*, i8*, i8*, i8*)* @scatter_ScatterIntoScalar, !"reqntidx", i32 1}
// CHECK: !0 = !{void (i8*, i8*, i8*)* @scatter_ScatterIntoScalar, !"kernel", i32 1}
// CHECK: !1 = !{void (i8*, i8*, i8*)* @scatter_ScatterIntoScalar, !"reqntidx", i32 1}
// CHECK: !2 = !{i32 0, i32 1}
// CHECK: !3 = !{}
@ -131,7 +131,7 @@ ENTRY main {
// -----
// CHECK-LABEL: define void @scatter_TensorFlowScatter_Mul(i8* noalias align 64 dereferenceable(36) %alloc0, i8* noalias align 16 dereferenceable(36) %alloc1, i8* noalias align 16 dereferenceable(24) %alloc2, i8* noalias align 16 dereferenceable(8) %alloc3) {
// CHECK-LABEL: define void @scatter_TensorFlowScatter_Mul(i8* noalias align 16 dereferenceable(36) %alloc0, i8* noalias align 16 dereferenceable(24) %alloc1, i8* noalias align 16 dereferenceable(8) %alloc2) {
// CHECK: %[[VAL_63:.*]] = alloca i32, align 4
// CHECK: %[[VAL_64:.*]] = alloca i32, align 4
// CHECK: %[[VAL_98:.*]] = alloca i32, align 4
@ -188,8 +188,8 @@ ENTRY main {
// CHECK: %[[VAL_109:.*]] = extractvalue { i32, i1 } %[[VAL_107]], 1
// CHECK: br i1 %[[VAL_109]], label %[[VAL_96]], label %[[VAL_104]]
// CHECK: !nvvm.annotations = !{!0, !1}
// CHECK: !0 = !{void (i8*, i8*, i8*, i8*)* @scatter_TensorFlowScatter_Mul, !"kernel", i32 1}
// CHECK: !1 = !{void (i8*, i8*, i8*, i8*)* @scatter_TensorFlowScatter_Mul, !"reqntidx", i32 6}
// CHECK: !0 = !{void (i8*, i8*, i8*)* @scatter_TensorFlowScatter_Mul, !"kernel", i32 1}
// CHECK: !1 = !{void (i8*, i8*, i8*)* @scatter_TensorFlowScatter_Mul, !"reqntidx", i32 6}
// CHECK: !2 = !{i32 0, i32 1}
// CHECK: !3 = !{i32 0, i32 6}
// CHECK: !4 = !{}
@ -216,7 +216,7 @@ ENTRY main {
// -----
// CHECK-LABEL: define void @scatter_ScalarUpdate(i8* noalias align 64 dereferenceable(16) %alloc0, i8* noalias align 16 dereferenceable(16) %alloc1, i8* noalias align 16 dereferenceable(4) %alloc2, i8* noalias align 16 dereferenceable(4) %alloc3) {
// CHECK-LABEL: define void @scatter_ScalarUpdate(i8* noalias align 16 dereferenceable(16) %alloc0, i8* noalias align 16 dereferenceable(4) %alloc1, i8* noalias align 16 dereferenceable(4) %alloc2) {
// CHECK: entry:
// CHECK: %[[VAL_146:.*]] = alloca i32, align 4
// CHECK: %[[VAL_118:.*]] = getelementptr inbounds i8, i8* %[[VAL_119:.*]], i64 0
@ -253,8 +253,8 @@ ENTRY main {
// CHECK: store atomic i32 %[[VAL_148]], i32* %[[VAL_145]] unordered, align 4
// CHECK: br label %[[VAL_138]]
// CHECK: !nvvm.annotations = !{!0, !1}
// CHECK: !0 = !{void (i8*, i8*, i8*, i8*)* @scatter_ScalarUpdate, !"kernel", i32 1}
// CHECK: !1 = !{void (i8*, i8*, i8*, i8*)* @scatter_ScalarUpdate, !"reqntidx", i32 1}
// CHECK: !0 = !{void (i8*, i8*, i8*)* @scatter_ScalarUpdate, !"kernel", i32 1}
// CHECK: !1 = !{void (i8*, i8*, i8*)* @scatter_ScalarUpdate, !"reqntidx", i32 1}
// CHECK: !2 = !{i32 0, i32 1}
// CHECK: !3 = !{}

View File

@ -308,6 +308,39 @@ class BufferValueMap {
}
}
// Appends to `aliased_buffers` the buffers that `value` must share because of
// in-place operations, as reported by
// HloDataflowAnalysis::GetInPlaceInputOutputPairs. Two directions are covered:
//  * `value` appears at the output of an in-place operation: the buffer of the
//    matching operand's value is added;
//  * `value` is consumed as the in-place operand of an operation: the buffer
//    of the value at the matching output index is added.
// The vector is only appended to; duplicates are not removed here.
void ComputeInPlaceOperationAliasedBuffers(
    const HloValue& value, std::vector<BufferNumber>* aliased_buffers) {
  VLOG(3) << "Compute aliases for in-place operations (e.g. "
             "kDynamicUpdateSlice and kScatter)";

  // Direction 1: positions where `value` is the output of an in-place op.
  for (const HloPosition& pos : value.positions()) {
    HloInstruction* producer = pos.instruction;
    for (const auto& input_output_pair :
         HloDataflowAnalysis::GetInPlaceInputOutputPairs(producer)) {
      const ShapeIndex& output_index = input_output_pair.second;
      if (!(pos.index == output_index)) {
        continue;
      }
      const HloUse& in_place_use = input_output_pair.first;
      const HloValue& operand_value = dataflow_.GetUniqueValueAt(
          producer->operand(in_place_use.operand_number),
          in_place_use.operand_index);
      VLOG(3) << " operand value " << operand_value.ToShortString()
              << " aliases.";
      aliased_buffers->push_back(GetBufferForValue(operand_value));
    }
  }

  // Direction 2: uses where `value` is the in-place operand of an op.
  for (const HloUse& value_use : value.uses()) {
    for (const auto& input_output_pair :
         HloDataflowAnalysis::GetInPlaceInputOutputPairs(
             value_use.instruction)) {
      if (!(value_use == input_output_pair.first)) {
        continue;
      }
      const HloValue& output_value = dataflow_.GetUniqueValueAt(
          value_use.instruction, input_output_pair.second);
      VLOG(3) << " use value " << output_value.ToShortString() << " aliases.";
      aliased_buffers->push_back(GetBufferForValue(output_value));
    }
  }
}
// Compute and return a vector of buffers that the given value must be
// contained in due to HLO aliasing rules.
std::vector<BufferNumber> ComputeAliasedBuffers(const HloValue& value) {
@ -318,6 +351,7 @@ class BufferValueMap {
ComputeInputOutputAliasedBuffers(value, &aliased_buffers);
ComputeWhileAliasedBuffers(value, &aliased_buffers);
ComputeConditionalAliasedBuffers(value, &aliased_buffers);
ComputeInPlaceOperationAliasedBuffers(value, &aliased_buffers);
// Uniquify aliased buffers.
absl::c_sort(aliased_buffers);
aliased_buffers.erase(

View File

@ -1062,6 +1062,118 @@ TEST_F(HloAliasAnalysisTest, MergeBuffersReverse) {
analysis.BufferLivesOut(analysis.buffers()[0]);
}
// Builds copy(param0) -> dynamic-update-slice(copy, param1, param2) and checks
// that alias analysis places the copy and the dynamic-update-slice output in
// the same buffer.
TEST_F(HloAliasAnalysisTest, DynamicUpdateSlice) {
  const Shape operand_shape = ShapeUtil::MakeShape(F32, {8});
  const Shape slice_shape = ShapeUtil::MakeShape(F32, {4});
  const Shape scalar_index_shape = ShapeUtil::MakeShape(S32, {});

  HloComputation::Builder computation_builder(TestName());
  HloInstruction* operand_param = computation_builder.AddInstruction(
      HloInstruction::CreateParameter(0, operand_shape, "param0"));
  HloInstruction* update_param = computation_builder.AddInstruction(
      HloInstruction::CreateParameter(1, slice_shape, "param1"));
  HloInstruction* index_param = computation_builder.AddInstruction(
      HloInstruction::CreateParameter(2, scalar_index_shape, "param2"));
  HloInstruction* operand_copy = computation_builder.AddInstruction(
      HloInstruction::CreateUnary(operand_shape, HloOpcode::kCopy,
                                  operand_param));
  HloInstruction* dus = computation_builder.AddInstruction(
      HloInstruction::CreateDynamicUpdateSlice(operand_shape, operand_copy,
                                               update_param, {index_param}));

  module_->AddEntryComputation(computation_builder.Build());
  SCOPED_TRACE(module_->ToString());

  HloAliasAnalysis& analysis = RunAnalysis();

  // The in-place operand and the DUS output must share one buffer.
  EXPECT_EQ(analysis.GetUniqueBufferAt(operand_copy),
            analysis.GetUniqueBufferAt(dus));
}
// Multi-output loop fusion with one elementwise output (add.1 on param0) and
// two dynamic-update-slice outputs (on param1 and param2). Checks which fusion
// operands share a buffer with which fusion outputs.
TEST_F(HloAliasAnalysisTest, DynamicUpdateSliceMultiOutputFusion) {
absl::string_view hlo_string = R"(
HloModule Module
fused_computation {
param0 = f32[1280,1,128] parameter(0)
param1 = f32[1280,1,128] parameter(1)
param2 = f32[1280,1,128] parameter(2)
constant.1 = f32[] constant(0)
broadcast.6 = f32[128,1,128] broadcast(constant.1), dimensions={}
constant.3 = s32[] constant(0)
add.1 = f32[1280,1,128] add(param0, param0)
dynamic-update-slice.5 = f32[1280,1,128] dynamic-update-slice(param1, broadcast.6, constant.3, constant.3, constant.3)
dynamic-update-slice.6 = f32[1280,1,128] dynamic-update-slice(param2, broadcast.6, constant.3, constant.3, constant.3)
ROOT tuple.1 = (f32[1280,1,128], f32[1280,1,128], f32[1280,1,128]) tuple(add.1, dynamic-update-slice.5, dynamic-update-slice.6)
}
ENTRY main {
param = f32[1280,1,128] parameter(0)
negate0 = f32[1280,1,128] negate(param)
negate1 = f32[1280,1,128] negate(param)
negate2 = f32[1280,1,128] negate(param)
ROOT fusion = (f32[1280,1,128], f32[1280,1,128], f32[1280,1,128]) fusion(negate0, negate1, negate2), kind=kLoop, calls=fused_computation
}
)";
TF_ASSERT_OK_AND_ASSIGN(module_, ParseAndReturnVerifiedModule(hlo_string));
SCOPED_TRACE(module_->ToString());
HloAliasAnalysis& analysis = RunAnalysis();
// Log the full analysis to help debugging if an expectation below fails.
LOG(INFO) << analysis.ToString();
// Expect negate1 and negate2 to alias with fusion{1} and fusion{2}
// respectively (due to DUS), but not negate0 and fusion{0}.
const HloInstruction* fusion =
module_->entry_computation()->GetInstructionWithName("fusion");
const HloInstruction* negate0 =
module_->entry_computation()->GetInstructionWithName("negate0");
const HloInstruction* negate1 =
module_->entry_computation()->GetInstructionWithName("negate1");
const HloInstruction* negate2 =
module_->entry_computation()->GetInstructionWithName("negate2");
EXPECT_EQ(analysis.GetUniqueBufferAt(negate1),
analysis.GetUniqueBufferAt(fusion, {1}));
EXPECT_EQ(analysis.GetUniqueBufferAt(negate2),
analysis.GetUniqueBufferAt(fusion, {2}));
// negate0 feeds the elementwise add.1, which is not an in-place operation.
EXPECT_NE(analysis.GetUniqueBufferAt(negate0),
analysis.GetUniqueBufferAt(fusion, {0}));
}
// CPU and GPU backends may produce fusions in which one dynamic-update-slice
// feeds another. Those backends expect such a fusion not to be in-place, so the
// fusion output must not share a buffer with the fusion operand.
TEST_F(HloAliasAnalysisTest, ChainedDynamicUpdateSliceFusion) {
  absl::string_view hlo_string = R"(
HloModule Module
fused_computation {
param0 = f32[1280,1,128] parameter(0)
constant.1 = f32[] constant(0)
broadcast.6 = f32[128,1,128] broadcast(constant.1), dimensions={}
constant.3 = s32[] constant(0)
dynamic-update-slice.5 = f32[1280,1,128] dynamic-update-slice(param0, broadcast.6, constant.3, constant.3, constant.3)
ROOT dynamic-update-slice.6 = f32[1280,1,128] dynamic-update-slice(dynamic-update-slice.5, broadcast.6, constant.3, constant.3, constant.3)
}
ENTRY main {
param = f32[1280,1,128] parameter(0)
negate0 = f32[1280,1,128] negate(param)
ROOT fusion = f32[1280,1,128] fusion(negate0), kind=kLoop, calls=fused_computation
}
)";
  TF_ASSERT_OK_AND_ASSIGN(module_, ParseAndReturnVerifiedModule(hlo_string));
  SCOPED_TRACE(module_->ToString());
  HloAliasAnalysis& analysis = RunAnalysis();
  LOG(INFO) << analysis.ToString();

  auto* entry = module_->entry_computation();
  const HloInstruction* fusion = entry->GetInstructionWithName("fusion");
  const HloInstruction* negate0 = entry->GetInstructionWithName("negate0");

  // The root dynamic-update-slice updates another dynamic-update-slice rather
  // than a fusion parameter, so no aliasing is expected.
  EXPECT_NE(analysis.GetUniqueBufferAt(negate0),
            analysis.GetUniqueBufferAt(fusion));
}
TEST_F(HloAliasAnalysisTest, BitcastInterference) {
// A bitcast value simultaneously live with its operand should not cause
// interference.

View File

@ -1178,69 +1178,49 @@ bool HloDataflowAnalysis::DoesNotUseOperandBuffer(
return true;
}
// Given a fusion whose root is a dynamic-update-slice op, determines whether
// the fusion's output buffer can be shared with the buffer of fusion_param,
// which must be a fused parameter of the fusion.
//
// Preconditions:
//
// - fusion's root is a dynamic-update-slice op.
// - fusion_param is a parameter within the fusion.
//
// fusion_param may point to a subelement of the actual parameter instruction if
// the param is a tuple; i.e. fusion_param->index() need not be the empty list.
//
// Returns true if:
//
// * fusion_param is used by the root of dynamic-update-slice as the "base" of
// the update, i.e. the thing being updated, AND
// * all other uses of fusion_param are dynamic-slices that slice the same
// indices as are overwritten in the dynamic-update-slice.
//
// In the case that there are no other uses of fusion_param (last bullet point
// is vacuously true) it's easy to see why an in-place DUS is safe; this is just
// the "natural" implementation of DUS. If there are other users, in-place DUS
// is safe on the assumption that the thread which writes element i of the
// output will be the only one to read element i of fusion_param (via the
// dynamic-slice ops).
static bool CanDoInPlaceDynamicUpdateSlice(HloInstruction* fusion,
const HloValue& fusion_param_value) {
auto* root =
Cast<HloDynamicUpdateSliceInstruction>(fusion->fused_expression_root());
auto* fusion_param = fusion_param_value.instruction();
CHECK_EQ(fusion_param->opcode(), HloOpcode::kParameter);
CHECK_EQ(fusion_param->parent(), fusion->fused_instructions_computation());
// Reports whether `opcode` names an in-place operation. Only
// dynamic-update-slice and scatter are treated as in-place.
/*static*/ bool HloDataflowAnalysis::IsInPlaceOperation(HloOpcode opcode) {
  switch (opcode) {
    case HloOpcode::kDynamicUpdateSlice:
    case HloOpcode::kScatter:
      return true;
    default:
      return false;
  }
}
// fusion_param must be used by the root as the "base" of the
// dynamic-update-slice. The natural way to check this would be
//
// `if (root->operand(0) != fusion_param)`
//
// but we also have to handle the case where the fusion parameter is
// tuple-shaped and we're considering just one element of that tuple, i.e.
// fusion_param.index() != {}.
if (absl::c_count_if(fusion_param_value.uses(), [&](const HloUse& use) {
return use.instruction == root;
}) != 1) {
return false;
/*static*/ std::vector<std::pair<HloUse, ShapeIndex>>
HloDataflowAnalysis::GetInPlaceInputOutputPairs(HloInstruction* instruction) {
if (IsInPlaceOperation(instruction->opcode())) {
return {{HloUse{instruction, 0, {}}, {}}};
} else if (instruction->opcode() != HloOpcode::kFusion) {
return {};
}
// All other uses of fusion_param must be dynamic-slices that slice the same
// indices as are overwritten by the dynamic-update-slice.
for (const HloUse& use : fusion_param_value.uses()) {
auto* user = use.instruction;
if (user == root) {
continue;
std::vector<std::pair<HloUse, ShapeIndex>> input_output_pairs;
for (auto& indexed_shape : ShapeUtil::GetLeafShapes(instruction->shape())) {
const HloInstruction* hlo_generating_output =
instruction->fused_expression_root();
for (int64 i = 0; i < indexed_shape.index.size(); ++i) {
if (hlo_generating_output->opcode() == HloOpcode::kTuple) {
hlo_generating_output =
hlo_generating_output->operand(indexed_shape.index[i]);
} else {
CHECK_EQ(i, indexed_shape.index.size() - 1);
}
}
// Check that `user` is a dynamic-slice op and has the same slice indices as
// `root`.
auto* ds = DynCast<HloDynamicSliceInstruction>(user);
if (!ds || ds->index_operands() != root->index_operands()) {
return false;
if (IsInPlaceOperation(hlo_generating_output->opcode())) {
ShapeIndex operand_index;
const HloInstruction* fusion_parameter =
hlo_generating_output->operand(0);
while (fusion_parameter->opcode() == HloOpcode::kGetTupleElement) {
operand_index.push_front(fusion_parameter->tuple_index());
fusion_parameter = fusion_parameter->operand(0);
}
if (fusion_parameter->opcode() == HloOpcode::kParameter) {
input_output_pairs.emplace_back(
HloUse{instruction, fusion_parameter->parameter_number(),
operand_index},
indexed_shape.index);
}
}
}
return true;
return input_output_pairs;
}
bool HloDataflowAnalysis::CanShareOperandBufferWithUser(
@ -1261,24 +1241,17 @@ bool HloDataflowAnalysis::CanShareOperandBufferWithUser(
return false;
}
if (user->opcode() == HloOpcode::kFusion) {
// Get the parameter associated with 'operand';
HloInstruction* fusion_param =
user->fused_parameter(user->operand_index(operand));
const HloValue& fusion_param_value =
GetValueDefinedAt(fusion_param, operand_index);
// TODO(b/80315712): This code is in a bit of a weird intermediate state
// at the moment. The in-place DUS check really needs to be common to all
// backends, so it runs first. Then we run the backend-specific check if
// provided, or go through the target-independent check if not.
// Unfortunately, the notionally "target-independent" path actually contains
// some target-specific code, so we can't run all of it *in addition* to the
// target-specific function, like the interface documentation says.
if (user->fused_expression_root()->opcode() ==
HloOpcode::kDynamicUpdateSlice) {
return CanDoInPlaceDynamicUpdateSlice(user, fusion_param_value);
// Must-alias relationship returns true for in-place operations (DUS and DUS
// fusions), regardless of the backend.
for (const auto& operand_and_output_index :
GetInPlaceInputOutputPairs(user)) {
if (operand_and_output_index.second != user_index) {
continue;
}
for (const HloUse& use : GetUniqueValueAt(operand, operand_index).uses()) {
if (use == operand_and_output_index.first) {
return true;
}
}
}

View File

@ -49,6 +49,9 @@ class HloDataflowAnalysis {
// Infrastructure for passing may-alias hints: HLO passes can populate the
// may-alias table. If an empty optional is returned, default rules are used.
//
// Must-alias rules (as defined by GetInPlaceInputOutputPairs) cannot be
// overridden using backend-specific overrides.
//
// The first parameter of the function should be the instruction, the
// second parameter should be an operand of the instruction. The third
// parameter should be the output index of the instruction.
@ -160,6 +163,15 @@ class HloDataflowAnalysis {
const HloModule& module() const { return module_; }
// Returns true if the operation is an in-place operation and its operand 0
// must alias with the output.
static bool IsInPlaceOperation(HloOpcode opcode);
// Returns one (input, output) pair for every in-place operation within this
// HLO. The first element of a pair is the HloUse naming the input that is
// updated in place (the instruction, its operand number and the shape index
// inside that operand); the second element is the shape index of the output
// produced from that input. The vector is empty when the instruction is
// neither an in-place operation nor a fusion.
static std::vector<std::pair<HloUse, ShapeIndex>> GetInPlaceInputOutputPairs(
HloInstruction* instruction);
protected:
HloDataflowAnalysis(const HloModule& module, bool ssa_form,
bool bitcast_defines_value = false,

View File

@ -2324,36 +2324,6 @@ TEST_F(CanShareOperandBufferWithUserTest,
dataflow_analysis_->CanShareOperandBufferWithUser(param, {}, fusion, {}));
}
// The fused dynamic-slice reads p0 at indices (p1, p2, p3) while the root
// dynamic-update-slice writes p0 at indices (p1, p3, p2). Because the two index
// lists differ, the fusion is expected not to share its output buffer with p0.
TEST_F(CanShareOperandBufferWithUserTest, DUSWithSliceWithDifferentIndices) {
  const char* kModule = R"(
HloModule test
fused_computation {
p0 = f32[10,20,30] parameter(0)
p1 = s32[] parameter(1)
p2 = s32[] parameter(2)
p3 = s32[] parameter(3)
slice = f32[1,1,30] dynamic-slice(p0, p1, p2, p3), dynamic_slice_sizes={1,1,30}
ROOT dus = f32[10,20,30] dynamic-update-slice(p0, slice, p1, p3, p2)
}
ENTRY test {
p0 = f32[10,20,30] parameter(0)
p1 = s32[] parameter(1)
p2 = s32[] parameter(2)
p3 = s32[] parameter(3)
ROOT fusion = f32[10,20,30] fusion(p0, p1, p2, p3), kind=kLoop, calls=fused_computation
}
)";
  TF_ASSERT_OK_AND_ASSIGN(module_, ParseAndReturnVerifiedModule(kModule));

  auto* entry = module_->entry_computation();
  auto* dus_fusion = entry->root_instruction();
  auto* base_operand = entry->parameter_instruction(0);
  RunAnalysis();

  EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(
      base_operand, {}, dus_fusion, {}));
}
TEST_F(CanShareOperandBufferWithUserTest, DUSWithSliceWithSameIndices) {
const char* kModule = R"(
HloModule test

View File

@ -232,6 +232,24 @@ class MemorySpaceAssignmentTest : public HloTestBase,
return copies;
}
// Returns the offset of the chunk assigned to the buffer found at
// (`instruction`, `index`), or -1 if that buffer is not in the alternate
// memory. Alias analysis is re-run on the instruction's module on every call.
int64 GetAlternateMemoryOffset(const PresetAssignments& preset_assignments,
                               const HloInstruction* instruction,
                               const ShapeIndex& index = {}) const {
  const HloModule* module = instruction->parent()->parent();
  auto alias_analysis = HloAliasAnalysis::Run(module).ValueOrDie();
  const HloBuffer& buffer =
      alias_analysis->GetUniqueBufferAt(instruction, index);

  // True when `position` is the defining position of one of the values that
  // make up the buffer.
  auto defines_buffer_value = [&buffer](const auto& position) {
    for (const auto& value : buffer.values()) {
      if (position == value->defining_position()) {
        return true;
      }
    }
    return false;
  };

  // The first preset chunk whose position defines a value of the buffer wins.
  for (const auto& position_and_chunk : preset_assignments.chunks()) {
    if (defines_buffer_value(position_and_chunk.first)) {
      return position_and_chunk.second.offset;
    }
  }
  return -1;
}
std::unique_ptr<HloModule> CreateEvictAndPrefetchModule() {
HloComputation::Builder builder(TestName());
Shape shape = ShapeUtil::MakeShape(F32, {2, 3});
@ -4415,6 +4433,47 @@ TEST_P(MemorySpaceAssignmentTest, Determinism) {
}
}
// Checks that an in-place op such as DynamicUpdateSlice (here wrapped in a
// fusion) receives the same allocation as its input.
TEST_P(MemorySpaceAssignmentTest, InPlaceOp) {
  absl::string_view hlo_string = R"(
HloModule Module, is_scheduled=true
fused_computation {
param0 = f32[2,3] parameter(0)
constant.1 = f32[] constant(0)
broadcast = f32[2,1] broadcast(constant.1), dimensions={}
constant.3 = s32[] constant(0)
ROOT dynamic-update-slice.5 = f32[2,3] dynamic-update-slice(param0, broadcast, constant.3, constant.3)
}
ENTRY main {
param = f32[2,3] parameter(0)
negate = f32[2,3] negate(param)
fusion = f32[2,3] fusion(negate), kind=kLoop, calls=fused_computation
ROOT add = f32[2,3] add(fusion, fusion)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(hlo_string));
  auto preset_assignments = AssignMemorySpace(module.get());
  auto* entry = module->entry_computation();

  // Offset of the fusion's input.
  HloInstruction* input = entry->GetInstructionWithName("negate");
  const int64 input_offset =
      GetAlternateMemoryOffset(*preset_assignments, input);

  // Offset of the in-place fusion itself.
  HloInstruction* in_place_fusion = entry->GetInstructionWithName("fusion");
  const int64 output_offset =
      GetAlternateMemoryOffset(*preset_assignments, in_place_fusion);

  // Both must land at the same offset.
  EXPECT_EQ(input_offset, output_offset);

  // Only when allocating across sequential calls do we additionally require
  // that the shared allocation is actually in the alternate memory.
  const bool allocate_across_sequential_calls = GetParam();
  if (allocate_across_sequential_calls) {
    EXPECT_NE(input_offset, -1);
  }
}
// Instantiates every MemorySpaceAssignmentTest case twice: once with the
// boolean test parameter set to false and once with it set to true.
INSTANTIATE_TEST_SUITE_P(MemorySpaceAssignmentInstantiation,
MemorySpaceAssignmentTest,
::testing::Values(false, true));

View File

@ -17,6 +17,7 @@ limitations under the License.
#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/xla/debug_options_flags.h"
#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
#include "tensorflow/compiler/xla/service/hlo_dce.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
@ -338,6 +339,21 @@ bool MultiOutputFusion::LegalToFuseMainConstraints(HloInstruction* instr1,
if (!ShapesCompatibleForFusion(instr1, instr2)) {
return false;
}
// If both nodes are in-place operations and they use a common in-place
// operand, we can't fuse these two.
for (const auto& operand_and_output_index1 :
HloDataflowAnalysis::GetInPlaceInputOutputPairs(instr1)) {
const HloInstruction* operand =
instr1->operand(operand_and_output_index1.first.operand_number);
for (const auto& operand_and_output_index2 :
HloDataflowAnalysis::GetInPlaceInputOutputPairs(instr2)) {
if (operand ==
instr2->operand(operand_and_output_index2.first.operand_number)) {
return false;
}
}
}
return true;
}