diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index bfb3fca6a9b..6874d00445c 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -83,6 +83,7 @@ cc_library( deps = [ ":bfloat16_support", ":hlo", + ":hlo_dataflow_analysis", ":hlo_pass", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:xla_data_proto_cc", @@ -1684,6 +1685,7 @@ cc_library( hdrs = ["multi_output_fusion.h"], deps = [ ":hlo", + ":hlo_dataflow_analysis", ":hlo_dce", ":hlo_pass", ":hlo_reachability", diff --git a/tensorflow/compiler/xla/service/bfloat16_conversion_folding.cc b/tensorflow/compiler/xla/service/bfloat16_conversion_folding.cc index 23d2a9225a8..73210e6b3dc 100644 --- a/tensorflow/compiler/xla/service/bfloat16_conversion_folding.cc +++ b/tensorflow/compiler/xla/service/bfloat16_conversion_folding.cc @@ -18,6 +18,7 @@ limitations under the License. #include "absl/types/span.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -159,19 +160,20 @@ Status BFloat16ConversionFoldingVisitor::TryFoldBF16Conversions( Status BFloat16ConversionFoldingVisitor::DefaultAction(HloInstruction* hlo) { // Do not fold BF16 conversions for instructions related to tuples, entry and - // exit of a computation, fusion, convert, side-effecting instructions and - // control flow. - if (hlo->opcode() == HloOpcode::kTuple || // - hlo->opcode() == HloOpcode::kGetTupleElement || // - hlo->opcode() == HloOpcode::kConstant || // - hlo->opcode() == HloOpcode::kParameter || // - hlo->opcode() == HloOpcode::kFusion || // - hlo->opcode() == HloOpcode::kBitcastConvert || // - hlo->opcode() == HloOpcode::kConvert || // - hlo->opcode() == HloOpcode::kCall || // - hlo->opcode() == HloOpcode::kCustomCall || // - hlo->opcode() == HloOpcode::kWhile || // - hlo->opcode() == HloOpcode::kConditional || // + // exit of a computation, fusion, convert, side-effecting instructions, + // in-place operations and control flow. + if (hlo->opcode() == HloOpcode::kTuple || // + hlo->opcode() == HloOpcode::kGetTupleElement || // + hlo->opcode() == HloOpcode::kConstant || // + hlo->opcode() == HloOpcode::kParameter || // + hlo->opcode() == HloOpcode::kFusion || // + hlo->opcode() == HloOpcode::kBitcastConvert || // + hlo->opcode() == HloOpcode::kConvert || // + hlo->opcode() == HloOpcode::kCall || // + hlo->opcode() == HloOpcode::kCustomCall || // + hlo->opcode() == HloOpcode::kWhile || // + hlo->opcode() == HloOpcode::kConditional || // + HloDataflowAnalysis::IsInPlaceOperation(hlo->opcode()) || // hlo->HasSideEffectNoRecurse()) { return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation.cc b/tensorflow/compiler/xla/service/bfloat16_propagation.cc index a0fe0eaa1d9..f9e19493a86 100644 --- a/tensorflow/compiler/xla/service/bfloat16_propagation.cc +++ b/tensorflow/compiler/xla/service/bfloat16_propagation.cc @@ -598,6 +598,31 @@ bool BFloat16Propagation::ResolveInconsistencyOfAliasingBuffersHelper( type = F32; break; } + // In order to find aliases due to in-place operations, use + // GetInPlaceInputOutputPairs. 
Ideally, we'd use HloAliasAnalysis here, + // but this code works with HloModules that aren't ready yet to use + // HloAliasAnalysis (e.g., their computation graphs may not have been + // flattened yet). + for (const auto& operand_and_output_index : + HloDataflowAnalysis::GetInPlaceInputOutputPairs(hlo)) { + if (operand_and_output_index.second == index) { + const HloUse& operand = operand_and_output_index.first; + for (const auto* value : + dataflow_ + ->GetValueSet(hlo->operand(operand.operand_number), + operand.operand_index) + .values()) { + auto value_type = ValueTypeAfterChange(value); + if (value_type == BF16) { + continue; + } + CHECK_EQ(value_type, F32); + type = F32; + break; + } + } + } + // It's possible that a user has been changed from BF16 to F32 // during this final adjustment pass, so we need to check // AllUsersConsumeBF16() again. diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc b/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc index 02d79025f1b..9a898833373 100644 --- a/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc +++ b/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc @@ -1156,4 +1156,30 @@ ENTRY entry { EXPECT_FALSE(PropagatePrecision(module.get())); } +TEST_F(BFloat16PropagationTest, DynamicUpdateSlice) { + // This test is crafted so that the DUS has an f32 input (due to parameter) + // and bf16 output (due to dot). But we should enforce DUS operand 0 and + // output to get the same precision since it's an in-place operation. + const string module_str = R"( +HloModule Module + +ENTRY main { + param = f32[128,128] parameter(0) + constant.1 = f32[] constant(0) + broadcast.6 = f32[128,1] broadcast(constant.1), dimensions={} + constant.3 = s32[] constant(0) + dynamic-update-slice = f32[128,128] dynamic-update-slice(param, broadcast.6, constant.3, constant.3) + ROOT dot = f32[128,128] dot(dynamic-update-slice, dynamic-update-slice), lhs_contracting_dims={1}, rhs_contracting_dims={0} +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str)); + EXPECT_FALSE(PropagatePrecision(module.get())); + + HloInstruction* dus = module->entry_computation()->GetInstructionWithName( + "dynamic-update-slice"); + EXPECT_FALSE(OutputsBF16(dus)); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc index a0989d5765e..db34f054f35 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment.cc @@ -1007,102 +1007,6 @@ bool BufferAssigner::MaybeAssignBuffer(BufferAllocation* allocation, return true; } // namespace xla -Status BufferAssigner::MergeInplaceOpBuffers(BufferAssignment* assignment) { - // Try allocate same buffer for dynamic update slice's operand and output. - - // If memory_space_assignment is run and there is information about a color in - // preset assignments, don't merge those buffers. We expect - // memory_space_assignment to have merged these buffers. If - // memory_space_assignment didn't merge these buffers and have assigned - // different offsets to the operand and the output buffer, merging the buffers - // can cause memory corruption if memory_space_assignment assigned a different - // buffer at the same offset. 
- absl::flat_hash_set excluded_colors; - if (preset_assignments_) { - for (const auto& color_and_info : - preset_assignments_->assignment_informations()) { - excluded_colors.insert(color_and_info.first); - } - } - - // TODO(yunxing): Moving this logic to alias analysis and add must-alias rule - // to operations that can be done in place. - for (HloComputation* computation : assignment->module().computations()) { - for (HloInstruction* instruction : computation->instructions()) { - if (!(instruction->opcode() == HloOpcode::kDynamicUpdateSlice || - (instruction->opcode() == HloOpcode::kFusion && - (instruction->fused_expression_root()->opcode() == - HloOpcode::kDynamicUpdateSlice)))) { - continue; - } - if (instruction->parent()->IsFusionComputation()) { - continue; - } - if (instruction->operand_count() == 0) { - continue; - } - - // The operand can't share the same buffer with the user based on dataflow - // analysis. - if (!assignment->dataflow_analysis().CanShareOperandBufferWithUser( - instruction->mutable_operand(0), {}, instruction, {})) { - continue; - } - HloBuffer& instruction_buffer = - assignment->alias_analysis().GetUniqueBufferAt(instruction, {}); - - HloBuffer& operand_buffer = - assignment->alias_analysis().GetUniqueBufferAt( - instruction->operand(0), {}); - - // The instruction or operand color is excluded because it was assigned by - // memory_space_assignment. - if (excluded_colors.contains(instruction_buffer.color()) || - excluded_colors.contains(operand_buffer.color())) { - continue; - } - - // Already have the same buffer. No need to merge those. - if (instruction_buffer.id() == operand_buffer.id()) { - continue; - } - - // Do not perform in-place dynamic update slice if the operand buffer is - // read-only. - if (HloBufferIsReadOnly(operand_buffer)) { - continue; - } - - bool interfere = false; - - for (const HloValue* instruction_value : instruction_buffer.values()) { - for (const HloValue* operand_value : operand_buffer.values()) { - if (assignment->hlo_ordering().MayInterfere( - *instruction_value, *operand_value, - assignment->dataflow_analysis())) { - interfere = true; - break; - } - } - } - if (interfere) { - continue; - } - if (assignment->alias_analysis().BufferLivesOut(instruction_buffer)) { - continue; - } - if (instruction_buffer.color() != operand_buffer.color()) { - continue; - } - VLOG(3) << "Merging inplace " << instruction_buffer << " and " - << operand_buffer; - assignment->alias_analysis().MergeBuffers(instruction_buffer, - operand_buffer); - } - } - return Status::OK(); -} - Status BufferAssigner::AssignSingleHloBuffer( const HloBuffer* hlo_buffer, bool is_thread_local, absl::flat_hash_map> BufferAssigner::CreateAssignment( VLOG(3) << "After coloring:"; XLA_VLOG_LINES(3, assignment->alias_analysis().dataflow_analysis().ToString()); - TF_RETURN_IF_ERROR(MergeInplaceOpBuffers(assignment.get())); std::vector thread_local_computations; std::vector global_computations; diff --git a/tensorflow/compiler/xla/service/buffer_assignment.h b/tensorflow/compiler/xla/service/buffer_assignment.h index 60422965832..dfde46ca4b1 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.h +++ b/tensorflow/compiler/xla/service/buffer_assignment.h @@ -635,10 +635,6 @@ class BufferAssigner { absl::flat_hash_set* assigned_buffers, BufferAssignment* assignment); - // Promotes operations (DUS, scatter) to be done in place: If an operation can - // be done in place, merge its buffer with its operand buffer. 
- Status MergeInplaceOpBuffers(BufferAssignment* assignment); - // Assigns a single hlo buffer to an HLO allocation. Status AssignSingleHloBuffer( const HloBuffer* hlo_buffer, bool is_thread_local, diff --git a/tensorflow/compiler/xla/service/buffer_assignment_test.cc b/tensorflow/compiler/xla/service/buffer_assignment_test.cc index bc024f7144b..b49ca649f9a 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment_test.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment_test.cc @@ -1925,8 +1925,10 @@ ENTRY main { TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(hlo_text)); HloInstruction* parameter = m->entry_computation()->GetInstructionWithName("get-tuple-element.4"); - HloInstruction* dus = + HloInstruction* dus1 = m->entry_computation()->GetInstructionWithName("dynamic-update-slice.5"); + HloInstruction* dus2 = + m->entry_computation()->GetInstructionWithName("dynamic-update-slice.9"); auto buffers = RunBufferAssignment(m.get()); @@ -1934,8 +1936,10 @@ ENTRY main { const BufferAllocation& parameter_alloc = GetTopLevelAllocation(*buffers, parameter); - const BufferAllocation& dus_alloc = GetTopLevelAllocation(*buffers, dus); - EXPECT_NE(parameter_alloc, dus_alloc); + const BufferAllocation& dus1_alloc = GetTopLevelAllocation(*buffers, dus1); + EXPECT_EQ(parameter_alloc, dus1_alloc); + const BufferAllocation& dus2_alloc = GetTopLevelAllocation(*buffers, dus2); + EXPECT_EQ(parameter_alloc, dus2_alloc); } } diff --git a/tensorflow/compiler/xla/service/copy_insertion.cc b/tensorflow/compiler/xla/service/copy_insertion.cc index b88120d8128..f2e37ca23b6 100644 --- a/tensorflow/compiler/xla/service/copy_insertion.cc +++ b/tensorflow/compiler/xla/service/copy_insertion.cc @@ -362,6 +362,19 @@ Status AddCopiesForConditional(const HloAliasAnalysis& alias_analysis, return Status::OK(); } +// Add copies for the operands of in-place operations. RemoveUnnecessaryCopies +// will remove the unnecessary copies. +Status AddCopiesForInPlaceOperation(const HloAliasAnalysis& alias_analysis, + HloInstruction* in_place_op, + int64 operand_number) { + VLOG(2) << "Adding copies for in-place operation " << in_place_op->name(); + HloInstruction* operand = in_place_op->mutable_operand(operand_number); + TF_ASSIGN_OR_RETURN(HloInstruction * deep_copy, + in_place_op->parent()->DeepCopyInstruction(operand)); + TF_RETURN_IF_ERROR(operand->ReplaceUseWith(in_place_op, deep_copy)); + return Status::OK(); +} + // Conservatively adds copies before root instruction of entry computation and // each aliased parameter to resolve interference of aliased input and output // buffer. We later rely on RemoveUnnecessaryCopies to drop the unnecessary @@ -509,6 +522,12 @@ class CopyRemover { // value. The map is used to construct the copy info map below. absl::flat_hash_map value_to_node; for (const HloBuffer& buffer : alias_analysis.buffers()) { + // No copies should have been inserted within fused computations, so no + // need to remove them. HloOrdering isn't compatible with HloValues inside + // fusions, so skip copy removal for them. + if (buffer.values().at(0)->defining_instruction()->IsFused()) { + continue; + } // Verify values contained in the buffer are strictly ordered. This // should always be the case after adding copies to eliminate // interference. 
Specifically, the addition of the control flow edges @@ -591,7 +610,7 @@ class CopyRemover { void CreateCopyMap( const HloModule& module, const absl::flat_hash_map& value_to_node) { - for (HloComputation* computation : module.computations()) { + for (HloComputation* computation : module.MakeNonfusionComputations()) { for (HloInstruction* instruction : computation->instructions()) { // Add copies with unambiguous source values to the map. Copies with // ambiguous sources are not removable. @@ -1005,7 +1024,7 @@ Status CopyInsertion::AddCopiesToResolveInterference(HloModule* module) { TF_ASSIGN_OR_RETURN(std::unique_ptr alias_analysis, HloAliasAnalysis::Run(module, can_share_buffer_)); - for (HloComputation* computation : module->MakeComputationPostOrder()) { + for (HloComputation* computation : module->MakeNonfusionComputations()) { for (HloInstruction* instruction : computation->MakeInstructionPostOrder()) { if (instruction->opcode() == HloOpcode::kWhile) { @@ -1013,6 +1032,15 @@ Status CopyInsertion::AddCopiesToResolveInterference(HloModule* module) { } else if (instruction->opcode() == HloOpcode::kConditional) { TF_RETURN_IF_ERROR( AddCopiesForConditional(*alias_analysis, instruction)); + } else { + for (const auto& operand_and_output_index : + HloDataflowAnalysis::GetInPlaceInputOutputPairs(instruction)) { + const HloUse& operand = operand_and_output_index.first; + CHECK_EQ(operand.operand_index, ShapeIndex{}) + << "Support for non-{} shape operand not currently implemented."; + TF_RETURN_IF_ERROR(AddCopiesForInPlaceOperation( + *alias_analysis, instruction, operand.operand_number)); + } } } } diff --git a/tensorflow/compiler/xla/service/copy_insertion_test.cc b/tensorflow/compiler/xla/service/copy_insertion_test.cc index 3ee6b200da5..78730cbdcb8 100644 --- a/tensorflow/compiler/xla/service/copy_insertion_test.cc +++ b/tensorflow/compiler/xla/service/copy_insertion_test.cc @@ -2530,5 +2530,250 @@ ENTRY Entry { EXPECT_EQ(CountCopies(*module), 1); } +TEST_F(CopyInsertionTest, DynamicUpdateSliceNoCopy) { + absl::string_view hlo_string = R"( +HloModule Module + +ENTRY main { + param = f32[1280,1,128] parameter(0) + negate = f32[1280,1,128] negate(param) + constant.1 = f32[] constant(0) + broadcast.6 = f32[128,1,128] broadcast(constant.1), dimensions={} + constant.3 = s32[] constant(0) + ROOT dynamic-update-slice.5 = f32[1280,1,128] dynamic-update-slice(negate, broadcast.6, constant.3, constant.3, constant.3) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string)); + InsertCopies(module.get()); + EXPECT_EQ(CountCopies(*module), 0); +} + +TEST_F(CopyInsertionTest, FusedDynamicUpdateSliceNoCopy) { + absl::string_view hlo_string = R"( +HloModule Module + +fused_computation { + param0 = f32[1280,1,128] parameter(0) + constant.1 = f32[] constant(0) + broadcast.6 = f32[128,1,128] broadcast(constant.1), dimensions={} + constant.3 = s32[] constant(0) + ROOT dynamic-update-slice.5 = f32[1280,1,128] dynamic-update-slice(param0, broadcast.6, constant.3, constant.3, constant.3) +} + +ENTRY main { + param = f32[1280,1,128] parameter(0) + negate = f32[1280,1,128] negate(param) + ROOT fusion = f32[1280,1,128] fusion(negate), kind=kLoop, calls=fused_computation +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string)); + InsertCopies(module.get()); + EXPECT_EQ(CountCopies(*module), 0); +} + +TEST_F(CopyInsertionTest, DynamicUpdateSliceCopy) { + absl::string_view hlo_string = R"( +HloModule Module + +ENTRY main { + 
param = f32[1280,1,128] parameter(0) + negate = f32[1280,1,128] negate(param) + constant.1 = f32[] constant(0) + broadcast.6 = f32[128,1,128] broadcast(constant.1), dimensions={} + constant.3 = s32[] constant(0) + add = f32[1280,1,128] add(negate, negate) + dynamic-update-slice.5 = f32[1280,1,128] dynamic-update-slice(negate, broadcast.6, constant.3, constant.3, constant.3) + ROOT tuple = (f32[1280,1,128], f32[1280,1,128]) tuple(add, dynamic-update-slice.5) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string)); + InsertCopies(module.get()); + EXPECT_EQ(CountCopies(*module), 1); +} + +TEST_F(CopyInsertionTest, DynamicUpdateSliceParameterShareCopy) { + absl::string_view hlo_string = R"( +HloModule Module + +ENTRY main { + param = f32[1280,1,128] parameter(0) + constant.1 = f32[] constant(0) + broadcast.6 = f32[128,1,128] broadcast(constant.1), dimensions={} + constant.3 = s32[] constant(0) + ROOT dynamic-update-slice.5 = f32[1280,1,128] dynamic-update-slice(param, broadcast.6, constant.3, constant.3, constant.3) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string)); + InsertCopies(module.get()); + EXPECT_EQ(CountCopies(*module), 1); +} + +TEST_F(CopyInsertionTest, FusedDynamicUpdateSliceCopy) { + absl::string_view hlo_string = R"( +HloModule Module + +fused_computation { + param0 = f32[1280,1,128] parameter(0) + constant.1 = f32[] constant(0) + broadcast.6 = f32[128,1,128] broadcast(constant.1), dimensions={} + constant.3 = s32[] constant(0) + ROOT dynamic-update-slice.5 = f32[1280,1,128] dynamic-update-slice(param0, broadcast.6, constant.3, constant.3, constant.3) +} + +ENTRY main { + param = f32[1280,1,128] parameter(0) + negate = f32[1280,1,128] negate(param) + add = f32[1280,1,128] add(negate, negate) + fusion = f32[1280,1,128] fusion(negate), kind=kLoop, calls=fused_computation + ROOT tuple = (f32[1280,1,128], f32[1280,1,128]) tuple(negate, fusion) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string)); + InsertCopies(module.get()); + EXPECT_EQ(CountCopies(*module), 1); +} + +TEST_F(CopyInsertionTest, ChainDynamicUpdateSliceCopy) { + absl::string_view hlo_string = R"( +HloModule Module + +ENTRY main { + state = (s32[], f32[1280,1,128]{2,1,0}) parameter(0) + constant.1 = f32[] constant(0) + broadcast.6 = f32[128,1,128]{2,1,0} broadcast(constant.1), dimensions={} + get-tuple-element.4 = f32[1280,1,128]{2,1,0} get-tuple-element(state), index=1 + get-tuple-element.3 = s32[] get-tuple-element(state), index=0 + constant.2 = s32[] constant(128) + add.5 = s32[] add(get-tuple-element.3, constant.2) + constant.3 = s32[] constant(0) + dynamic-update-slice.5 = f32[1280,1,128]{2,1,0} dynamic-update-slice(get-tuple-element.4, broadcast.6, constant.3, constant.3, constant.3) + dynamic-update-slice.9 = f32[1280,1,128]{2,1,0} dynamic-update-slice(dynamic-update-slice.5, broadcast.6, constant.3, constant.3, constant.3) + ROOT tuple.85 = (s32[], s32[], s32[2]{0}, f32[1280,1,128]{2,1,0}) tuple(add.5, dynamic-update-slice.9) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string)); + InsertCopies(module.get()); + EXPECT_EQ(CountCopies(*module), 1); +} + +TEST_F(CopyInsertionTest, FusedDynamicUpdateSliceCopy2) { + absl::string_view hlo_string = R"( +HloModule Module + +fused_computation.1 { + param0 = f32[1280,1,128] parameter(0) + constant.1 = f32[] constant(0) + broadcast.6 = f32[128,1,128] broadcast(constant.1), 
dimensions={} + constant.3 = s32[] constant(0) + ROOT dynamic-update-slice.5 = f32[1280,1,128] dynamic-update-slice(param0, broadcast.6, constant.3, constant.3, constant.3) +} + +fused_computation.2 { + param0 = f32[1280,1,128] parameter(0) + param1 = f32[1280,1,128] parameter(1) + slice = f32[128,1,128] slice(param1), slice={[0:128], [0:1], [0:128]} + constant.3 = s32[] constant(0) + ROOT dynamic-update-slice.5 = f32[1280,1,128] dynamic-update-slice(param0, slice, constant.3, constant.3, constant.3) +} + +ENTRY main { + param = f32[1280,1,128] parameter(0) + negate = f32[1280,1,128] negate(param) + add = f32[1280,1,128] add(negate, negate) + fusion1 = f32[1280,1,128] fusion(negate), kind=kLoop, calls=fused_computation.1 + ROOT fusion2 = f32[1280,1,128] fusion(fusion1, negate), kind=kLoop, calls=fused_computation.2 +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string)); + InsertCopies(module.get()); + EXPECT_EQ(CountCopies(*module), 1); +} + +TEST_F(CopyInsertionTest, MultiOutputFusedDynamicUpdateSliceCopy) { + // Tests multi-output fusion with two DUS outputs, requiring two copies. + absl::string_view hlo_string = R"( +HloModule Module + +fused_computation { + param0 = f32[1280,1,128] parameter(0) + param1 = f32[1280,1,128] parameter(1) + param2 = f32[1280,1,128] parameter(2) + constant.1 = f32[] constant(0) + broadcast.6 = f32[128,1,128] broadcast(constant.1), dimensions={} + constant.3 = s32[] constant(0) + add.1 = f32[1280,1,128] add(param0, param0) + dynamic-update-slice.5 = f32[1280,1,128] dynamic-update-slice(param1, broadcast.6, constant.3, constant.3, constant.3) + dynamic-update-slice.6 = f32[1280,1,128] dynamic-update-slice(param2, broadcast.6, constant.3, constant.3, constant.3) + ROOT tuple.1 = (f32[1280,1,128], f32[1280,1,128], f32[1280,1,128]) tuple(add.1, dynamic-update-slice.5, dynamic-update-slice.6) +} + +ENTRY main { + param = f32[1280,1,128] parameter(0) + negate0 = f32[1280,1,128] negate(param) + negate1 = f32[1280,1,128] negate(param) + negate2 = f32[1280,1,128] negate(param) + fusion = (f32[1280,1,128], f32[1280,1,128], f32[1280,1,128]) fusion(negate0, negate1, negate2), kind=kLoop, calls=fused_computation + gte0 = f32[1280,1,128] get-tuple-element(fusion), index=0 + gte1 = f32[1280,1,128] get-tuple-element(fusion), index=1 + gte2 = f32[1280,1,128] get-tuple-element(fusion), index=2 + add0 = f32[1280,1,128] add(negate0, gte0) + add1 = f32[1280,1,128] add(negate1, gte1) + add2 = f32[1280,1,128] add(negate2, gte2) + ROOT tuple = (f32[1280,1,128], f32[1280,1,128], f32[1280,1,128]) tuple(add0, add1, add2) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string)); + InsertCopies(module.get()); + EXPECT_EQ(CountCopies(*module), 2); +} + +TEST_F(CopyInsertionTest, MultiOutputFusedDynamicUpdateSliceNoCopy) { + // Same as above, but negate1 is not used beyond fusion, so it only needs one + // copy for negate0. 
+ absl::string_view hlo_string = R"( +HloModule Module + +fused_computation { + param0 = f32[1280,1,128] parameter(0) + param1 = f32[1280,1,128] parameter(1) + param2 = f32[1280,1,128] parameter(2) + constant.1 = f32[] constant(0) + broadcast.6 = f32[128,1,128] broadcast(constant.1), dimensions={} + constant.3 = s32[] constant(0) + add.1 = f32[1280,1,128] add(param0, param0) + dynamic-update-slice.5 = f32[1280,1,128] dynamic-update-slice(param1, broadcast.6, constant.3, constant.3, constant.3) + dynamic-update-slice.6 = f32[1280,1,128] dynamic-update-slice(param2, broadcast.6, constant.3, constant.3, constant.3) + ROOT tuple.1 = (f32[1280,1,128], f32[1280,1,128], f32[1280,1,128]) tuple(add.1, dynamic-update-slice.5, dynamic-update-slice.6) +} + +ENTRY main { + param = f32[1280,1,128] parameter(0) + negate0 = f32[1280,1,128] negate(param) + negate1 = f32[1280,1,128] negate(param) + negate2 = f32[1280,1,128] negate(param) + fusion = (f32[1280,1,128], f32[1280,1,128], f32[1280,1,128]) fusion(negate0, negate1, negate2), kind=kLoop, calls=fused_computation + gte0 = f32[1280,1,128] get-tuple-element(fusion), index=0 + gte1 = f32[1280,1,128] get-tuple-element(fusion), index=1 + gte2 = f32[1280,1,128] get-tuple-element(fusion), index=2 + add0 = f32[1280,1,128] add(negate0, gte0) + add1 = f32[1280,1,128] add(gte1, gte1) + add2 = f32[1280,1,128] add(negate2, gte2) + ROOT tuple = (f32[1280,1,128], f32[1280,1,128], f32[1280,1,128]) tuple(add0, add1, add2) +} +)"; + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string)); + InsertCopies(module.get()); + EXPECT_EQ(CountCopies(*module), 1); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/tests/scatter.hlo b/tensorflow/compiler/xla/service/gpu/tests/scatter.hlo index c9e7daeb3bc..f625abe6612 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/scatter.hlo +++ b/tensorflow/compiler/xla/service/gpu/tests/scatter.hlo @@ -1,6 +1,6 @@ // RUN: hlo_to_llvm_ir %s | FileCheck %s -// CHECK-LABEL: define void @scatter_TensorFlowScatterV1(i8* noalias align 64 dereferenceable(36) %alloc0, i8* noalias align 16 dereferenceable(36) %alloc1, i8* noalias align 16 dereferenceable(24) %alloc2, i8* noalias align 16 dereferenceable(8) %alloc3) { +// CHECK-LABEL: define void @scatter_TensorFlowScatterV1(i8* noalias align 16 dereferenceable(36) %alloc0, i8* noalias align 16 dereferenceable(24) %alloc1, i8* noalias align 16 dereferenceable(8) %alloc2) { // CHECK: entry: // CHECK: %[[VAL_32:.*]] = alloca i32, align 4 // CHECK: %[[VAL_0:.*]] = getelementptr inbounds i8, i8* %[[VAL_1:.*]], i64 0 @@ -43,8 +43,8 @@ // CHECK: store atomic i32 %[[VAL_36]], i32* %[[VAL_31]] unordered, align 4 // CHECK: br label %[[VAL_23]] // CHECK: !nvvm.annotations = !{!0, !1} -// CHECK: !0 = !{void (i8*, i8*, i8*, i8*)* @scatter_TensorFlowScatterV1, !"kernel", i32 1} -// CHECK: !1 = !{void (i8*, i8*, i8*, i8*)* @scatter_TensorFlowScatterV1, !"reqntidx", i32 6} +// CHECK: !0 = !{void (i8*, i8*, i8*)* @scatter_TensorFlowScatterV1, !"kernel", i32 1} +// CHECK: !1 = !{void (i8*, i8*, i8*)* @scatter_TensorFlowScatterV1, !"reqntidx", i32 6} // CHECK: !2 = !{i32 0, i32 1} // CHECK: !3 = !{i32 0, i32 6} // CHECK: !4 = !{} @@ -72,7 +72,7 @@ ENTRY main { // ----- -// CHECK-LABEL: define void @scatter_ScatterIntoScalar(i8* noalias align 64 dereferenceable(4) %alloc0, i8* noalias align 16 dereferenceable(4) %alloc1, i8* noalias align 16 dereferenceable(4) %alloc2, i8* noalias align 16 %alloc3) { +// CHECK-LABEL: define void 
@scatter_ScatterIntoScalar(i8* noalias align 16 dereferenceable(4) %alloc0, i8* noalias align 16 dereferenceable(4) %alloc1, i8* noalias align 16 %alloc2) { // CHECK: entry: // CHECK: %[[VAL_60:.*]] = alloca i32, align 4 // CHECK: %[[VAL_37:.*]] = getelementptr inbounds i8, i8* %[[VAL_38:.*]], i64 0 @@ -104,8 +104,8 @@ ENTRY main { // CHECK: store atomic i32 %[[VAL_62]], i32* %[[VAL_39]] unordered, align 4 // CHECK: br label %[[VAL_57]] // CHECK: !nvvm.annotations = !{!0, !1} -// CHECK: !0 = !{void (i8*, i8*, i8*, i8*)* @scatter_ScatterIntoScalar, !"kernel", i32 1} -// CHECK: !1 = !{void (i8*, i8*, i8*, i8*)* @scatter_ScatterIntoScalar, !"reqntidx", i32 1} +// CHECK: !0 = !{void (i8*, i8*, i8*)* @scatter_ScatterIntoScalar, !"kernel", i32 1} +// CHECK: !1 = !{void (i8*, i8*, i8*)* @scatter_ScatterIntoScalar, !"reqntidx", i32 1} // CHECK: !2 = !{i32 0, i32 1} // CHECK: !3 = !{} @@ -131,7 +131,7 @@ ENTRY main { // ----- -// CHECK-LABEL: define void @scatter_TensorFlowScatter_Mul(i8* noalias align 64 dereferenceable(36) %alloc0, i8* noalias align 16 dereferenceable(36) %alloc1, i8* noalias align 16 dereferenceable(24) %alloc2, i8* noalias align 16 dereferenceable(8) %alloc3) { +// CHECK-LABEL: define void @scatter_TensorFlowScatter_Mul(i8* noalias align 16 dereferenceable(36) %alloc0, i8* noalias align 16 dereferenceable(24) %alloc1, i8* noalias align 16 dereferenceable(8) %alloc2) { // CHECK: %[[VAL_63:.*]] = alloca i32, align 4 // CHECK: %[[VAL_64:.*]] = alloca i32, align 4 // CHECK: %[[VAL_98:.*]] = alloca i32, align 4 @@ -188,8 +188,8 @@ ENTRY main { // CHECK: %[[VAL_109:.*]] = extractvalue { i32, i1 } %[[VAL_107]], 1 // CHECK: br i1 %[[VAL_109]], label %[[VAL_96]], label %[[VAL_104]] // CHECK: !nvvm.annotations = !{!0, !1} -// CHECK: !0 = !{void (i8*, i8*, i8*, i8*)* @scatter_TensorFlowScatter_Mul, !"kernel", i32 1} -// CHECK: !1 = !{void (i8*, i8*, i8*, i8*)* @scatter_TensorFlowScatter_Mul, !"reqntidx", i32 6} +// CHECK: !0 = !{void (i8*, i8*, i8*)* @scatter_TensorFlowScatter_Mul, !"kernel", i32 1} +// CHECK: !1 = !{void (i8*, i8*, i8*)* @scatter_TensorFlowScatter_Mul, !"reqntidx", i32 6} // CHECK: !2 = !{i32 0, i32 1} // CHECK: !3 = !{i32 0, i32 6} // CHECK: !4 = !{} @@ -216,7 +216,7 @@ ENTRY main { // ----- -// CHECK-LABEL: define void @scatter_ScalarUpdate(i8* noalias align 64 dereferenceable(16) %alloc0, i8* noalias align 16 dereferenceable(16) %alloc1, i8* noalias align 16 dereferenceable(4) %alloc2, i8* noalias align 16 dereferenceable(4) %alloc3) { +// CHECK-LABEL: define void @scatter_ScalarUpdate(i8* noalias align 16 dereferenceable(16) %alloc0, i8* noalias align 16 dereferenceable(4) %alloc1, i8* noalias align 16 dereferenceable(4) %alloc2) { // CHECK: entry: // CHECK: %[[VAL_146:.*]] = alloca i32, align 4 // CHECK: %[[VAL_118:.*]] = getelementptr inbounds i8, i8* %[[VAL_119:.*]], i64 0 @@ -253,8 +253,8 @@ ENTRY main { // CHECK: store atomic i32 %[[VAL_148]], i32* %[[VAL_145]] unordered, align 4 // CHECK: br label %[[VAL_138]] // CHECK: !nvvm.annotations = !{!0, !1} -// CHECK: !0 = !{void (i8*, i8*, i8*, i8*)* @scatter_ScalarUpdate, !"kernel", i32 1} -// CHECK: !1 = !{void (i8*, i8*, i8*, i8*)* @scatter_ScalarUpdate, !"reqntidx", i32 1} +// CHECK: !0 = !{void (i8*, i8*, i8*)* @scatter_ScalarUpdate, !"kernel", i32 1} +// CHECK: !1 = !{void (i8*, i8*, i8*)* @scatter_ScalarUpdate, !"reqntidx", i32 1} // CHECK: !2 = !{i32 0, i32 1} // CHECK: !3 = !{} diff --git a/tensorflow/compiler/xla/service/hlo_alias_analysis.cc b/tensorflow/compiler/xla/service/hlo_alias_analysis.cc index 
384ae272dc1..cf09ddeec27 100644 --- a/tensorflow/compiler/xla/service/hlo_alias_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_alias_analysis.cc @@ -308,6 +308,39 @@ class BufferValueMap { } } + void ComputeInPlaceOperationAliasedBuffers( + const HloValue& value, std::vector<BufferNumber>* aliased_buffers) { + VLOG(3) << "Compute aliases for in-place operations (e.g. " + "kDynamicUpdateSlice and kScatter)"; + for (const HloPosition& position : value.positions()) { + HloInstruction* instruction = position.instruction; + for (const auto& operand_and_output_index : + HloDataflowAnalysis::GetInPlaceInputOutputPairs(instruction)) { + if (position.index == operand_and_output_index.second) { + const HloUse& operand = operand_and_output_index.first; + const HloValue& operand_value = dataflow_.GetUniqueValueAt( + instruction->operand(operand.operand_number), + operand.operand_index); + VLOG(3) << " operand value " << operand_value.ToShortString() + << " aliases."; + aliased_buffers->push_back(GetBufferForValue(operand_value)); + } + } + } + + for (const HloUse& use : value.uses()) { + for (const auto& operand_and_output_index : + HloDataflowAnalysis::GetInPlaceInputOutputPairs(use.instruction)) { + if (use == operand_and_output_index.first) { + const HloValue& use_value = dataflow_.GetUniqueValueAt( + use.instruction, operand_and_output_index.second); + VLOG(3) << " use value " << use_value.ToShortString() << " aliases."; + aliased_buffers->push_back(GetBufferForValue(use_value)); + } + } + } + } + // Compute and return a vector of buffers that the given value must be // contained in due to HLO aliasing rules. std::vector<BufferNumber> ComputeAliasedBuffers(const HloValue& value) { @@ -318,6 +351,7 @@ class BufferValueMap { ComputeInputOutputAliasedBuffers(value, &aliased_buffers); ComputeWhileAliasedBuffers(value, &aliased_buffers); ComputeConditionalAliasedBuffers(value, &aliased_buffers); + ComputeInPlaceOperationAliasedBuffers(value, &aliased_buffers); // Uniquify aliased buffers.
absl::c_sort(aliased_buffers); aliased_buffers.erase( diff --git a/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc index 2666cb0872d..5e94f1d173e 100644 --- a/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc @@ -1062,6 +1062,118 @@ TEST_F(HloAliasAnalysisTest, MergeBuffersReverse) { analysis.BufferLivesOut(analysis.buffers()[0]); } +TEST_F(HloAliasAnalysisTest, DynamicUpdateSlice) { + Shape shape = ShapeUtil::MakeShape(F32, {8}); + Shape update_shape = ShapeUtil::MakeShape(F32, {4}); + Shape index_shape = ShapeUtil::MakeShape(S32, {}); + auto builder = HloComputation::Builder(TestName()); + auto param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, shape, "param0")); + auto param1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, update_shape, "param1")); + auto param2 = builder.AddInstruction( + HloInstruction::CreateParameter(2, index_shape, "param2")); + auto copy0 = builder.AddInstruction( + HloInstruction::CreateUnary(shape, HloOpcode::kCopy, param0)); + auto dynamic_update_slice = builder.AddInstruction( + HloInstruction::CreateDynamicUpdateSlice(shape, copy0, param1, {param2})); + + module_->AddEntryComputation(builder.Build()); + SCOPED_TRACE(module_->ToString()); + + HloAliasAnalysis& analysis = RunAnalysis(); + + EXPECT_EQ(analysis.GetUniqueBufferAt(copy0), + analysis.GetUniqueBufferAt(dynamic_update_slice)); +} + +TEST_F(HloAliasAnalysisTest, DynamicUpdateSliceMultiOutputFusion) { + absl::string_view hlo_string = R"( +HloModule Module + +fused_computation { + param0 = f32[1280,1,128] parameter(0) + param1 = f32[1280,1,128] parameter(1) + param2 = f32[1280,1,128] parameter(2) + constant.1 = f32[] constant(0) + broadcast.6 = f32[128,1,128] broadcast(constant.1), dimensions={} + constant.3 = s32[] constant(0) + add.1 = f32[1280,1,128] add(param0, param0) + dynamic-update-slice.5 = f32[1280,1,128] dynamic-update-slice(param1, broadcast.6, constant.3, constant.3, constant.3) + dynamic-update-slice.6 = f32[1280,1,128] dynamic-update-slice(param2, broadcast.6, constant.3, constant.3, constant.3) + ROOT tuple.1 = (f32[1280,1,128], f32[1280,1,128], f32[1280,1,128]) tuple(add.1, dynamic-update-slice.5, dynamic-update-slice.6) +} + +ENTRY main { + param = f32[1280,1,128] parameter(0) + negate0 = f32[1280,1,128] negate(param) + negate1 = f32[1280,1,128] negate(param) + negate2 = f32[1280,1,128] negate(param) + ROOT fusion = (f32[1280,1,128], f32[1280,1,128], f32[1280,1,128]) fusion(negate0, negate1, negate2), kind=kLoop, calls=fused_computation +} +)"; + TF_ASSERT_OK_AND_ASSIGN(module_, ParseAndReturnVerifiedModule(hlo_string)); + + SCOPED_TRACE(module_->ToString()); + + HloAliasAnalysis& analysis = RunAnalysis(); + LOG(INFO) << analysis.ToString(); + + // Expect negate1 and negate2 to alias with fusion{1} and fusion{2} + // respectively (due to DUS), but not negate0 and fusion{0}. 
+ const HloInstruction* fusion = + module_->entry_computation()->GetInstructionWithName("fusion"); + const HloInstruction* negate0 = + module_->entry_computation()->GetInstructionWithName("negate0"); + const HloInstruction* negate1 = + module_->entry_computation()->GetInstructionWithName("negate1"); + const HloInstruction* negate2 = + module_->entry_computation()->GetInstructionWithName("negate2"); + EXPECT_EQ(analysis.GetUniqueBufferAt(negate1), + analysis.GetUniqueBufferAt(fusion, {1})); + EXPECT_EQ(analysis.GetUniqueBufferAt(negate2), + analysis.GetUniqueBufferAt(fusion, {2})); + EXPECT_NE(analysis.GetUniqueBufferAt(negate0), + analysis.GetUniqueBufferAt(fusion, {0})); +} + +TEST_F(HloAliasAnalysisTest, ChainedDynamicUpdateSliceFusion) { + // CPU and GPU backends may generate fusions with dynamic update slices + // feeding each other. They expect the fusion to not be in-place if that is + // the case. + absl::string_view hlo_string = R"( +HloModule Module + +fused_computation { + param0 = f32[1280,1,128] parameter(0) + constant.1 = f32[] constant(0) + broadcast.6 = f32[128,1,128] broadcast(constant.1), dimensions={} + constant.3 = s32[] constant(0) + dynamic-update-slice.5 = f32[1280,1,128] dynamic-update-slice(param0, broadcast.6, constant.3, constant.3, constant.3) + ROOT dynamic-update-slice.6 = f32[1280,1,128] dynamic-update-slice(dynamic-update-slice.5, broadcast.6, constant.3, constant.3, constant.3) +} + +ENTRY main { + param = f32[1280,1,128] parameter(0) + negate0 = f32[1280,1,128] negate(param) + ROOT fusion = f32[1280,1,128] fusion(negate0), kind=kLoop, calls=fused_computation +} +)"; + TF_ASSERT_OK_AND_ASSIGN(module_, ParseAndReturnVerifiedModule(hlo_string)); + + SCOPED_TRACE(module_->ToString()); + + HloAliasAnalysis& analysis = RunAnalysis(); + LOG(INFO) << analysis.ToString(); + + const HloInstruction* fusion = + module_->entry_computation()->GetInstructionWithName("fusion"); + const HloInstruction* negate0 = + module_->entry_computation()->GetInstructionWithName("negate0"); + EXPECT_NE(analysis.GetUniqueBufferAt(negate0), + analysis.GetUniqueBufferAt(fusion)); +} + TEST_F(HloAliasAnalysisTest, BitcastInterference) { // A bitcast value simultaneously live with its operand should not cause // interference. diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc index 878442df2a2..72899ffe163 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc @@ -1178,69 +1178,49 @@ bool HloDataflowAnalysis::DoesNotUseOperandBuffer( return true; } -// Given a fusion whose root is a dynamic-update-slice op, determines whether -// the fusion's output buffer can be shared with the buffer of fusion_param, -// which must be a fused parameter of the fusion. -// -// Preconditions: -// -// - fusion's root is a dynamic-update-slice op. -// - fusion_param is a parameter within the fusion. -// -// fusion_param may point to a subelement of the actual parameter instruction if -// the param is a tuple; i.e. fusion_param->index() need not be the empty list. -// -// Returns true if: -// -// * fusion_param is used by the root of dynamic-update-slice as the "base" of -// the update, i.e. the thing being updated, AND -// * all other uses of fusion_param are dynamic-slices that slice the same -// indices as are overwritten in the dynamic-update-slice. 
-// -// In the case that there are no other uses of fusion_param (last bullet point -// is vacuously true) it's easy to see why an in-place DUS is safe; this is just -// the "natural" implementation of DUS. If there are other users, in-place DUS -// is safe on the assumption that the thread which writes element i of the -// output will be the only one to read element i of fusion_param (via the -// dynamic-slice ops). -static bool CanDoInPlaceDynamicUpdateSlice(HloInstruction* fusion, - const HloValue& fusion_param_value) { - auto* root = - Cast<HloDynamicUpdateSliceInstruction>(fusion->fused_expression_root()); - auto* fusion_param = fusion_param_value.instruction(); - CHECK_EQ(fusion_param->opcode(), HloOpcode::kParameter); - CHECK_EQ(fusion_param->parent(), fusion->fused_instructions_computation()); +/*static*/ bool HloDataflowAnalysis::IsInPlaceOperation(HloOpcode opcode) { + return opcode == HloOpcode::kDynamicUpdateSlice || + opcode == HloOpcode::kScatter; +} - // fusion_param must be used by the root as the "base" of the - // dynamic-update-slice. The natural way to check this would be - // - // `if (root->operand(0) != fusion_param)` - // - // but we also have to handle the case where the fusion parameter is - // tuple-shaped and we're considering just one element of that tuple, i.e. - // fusion_param.index() != {}. - if (absl::c_count_if(fusion_param_value.uses(), [&](const HloUse& use) { - return use.instruction == root; - }) != 1) { - return false; +/*static*/ std::vector<std::pair<HloUse, ShapeIndex>> +HloDataflowAnalysis::GetInPlaceInputOutputPairs(HloInstruction* instruction) { + if (IsInPlaceOperation(instruction->opcode())) { + return {{HloUse{instruction, 0, {}}, {}}}; + } else if (instruction->opcode() != HloOpcode::kFusion) { + return {}; } - - // All other uses of fusion_param must be dynamic-slices that slice the same - // indices as are overwritten by the dynamic-update-slice. - for (const HloUse& use : fusion_param_value.uses()) { - auto* user = use.instruction; - if (user == root) { - continue; + std::vector<std::pair<HloUse, ShapeIndex>> input_output_pairs; + for (auto& indexed_shape : ShapeUtil::GetLeafShapes(instruction->shape())) { + const HloInstruction* hlo_generating_output = + instruction->fused_expression_root(); + for (int64 i = 0; i < indexed_shape.index.size(); ++i) { + if (hlo_generating_output->opcode() == HloOpcode::kTuple) { + hlo_generating_output = + hlo_generating_output->operand(indexed_shape.index[i]); + } else { + CHECK_EQ(i, indexed_shape.index.size() - 1); + } } - // Check that `user` is a dynamic-slice op and has the same slice indices as - // `root`.
- auto* ds = DynCast<HloDynamicSliceInstruction>(user); - if (!ds || ds->index_operands() != root->index_operands()) { - return false; + if (IsInPlaceOperation(hlo_generating_output->opcode())) { + ShapeIndex operand_index; + const HloInstruction* fusion_parameter = + hlo_generating_output->operand(0); + while (fusion_parameter->opcode() == HloOpcode::kGetTupleElement) { + operand_index.push_front(fusion_parameter->tuple_index()); + fusion_parameter = fusion_parameter->operand(0); + } + + if (fusion_parameter->opcode() == HloOpcode::kParameter) { + input_output_pairs.emplace_back( + HloUse{instruction, fusion_parameter->parameter_number(), + operand_index}, + indexed_shape.index); + } } } - return true; + return input_output_pairs; } bool HloDataflowAnalysis::CanShareOperandBufferWithUser( @@ -1261,24 +1241,17 @@ bool HloDataflowAnalysis::CanShareOperandBufferWithUser( return false; } - if (user->opcode() == HloOpcode::kFusion) { - // Get the parameter associated with 'operand'; - HloInstruction* fusion_param = - user->fused_parameter(user->operand_index(operand)); - - const HloValue& fusion_param_value = - GetValueDefinedAt(fusion_param, operand_index); - - // TODO(b/80315712): This code is in a bit of a weird intermediate state - // at the moment. The in-place DUS check really needs to be common to all - // backends, so it runs first. Then we run the backend-specific check if - // provided, or go through the target-independent check if not. - // Unfortunately, the notionally "target-independent" path actually contains - // some target-specific code, so we can't run all of it *in addition* to the - // target-specific function, like the interface documentation says. - if (user->fused_expression_root()->opcode() == - HloOpcode::kDynamicUpdateSlice) { - return CanDoInPlaceDynamicUpdateSlice(user, fusion_param_value); + // Must-alias relationship returns true for in-place operations (DUS and DUS + // fusions), regardless of the backend. + for (const auto& operand_and_output_index : + GetInPlaceInputOutputPairs(user)) { + if (operand_and_output_index.second != user_index) { + continue; + } + for (const HloUse& use : GetUniqueValueAt(operand, operand_index).uses()) { + if (use == operand_and_output_index.first) { + return true; + } } } diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h index bec592aeb20..ffa307d71dd 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h @@ -49,6 +49,9 @@ class HloDataflowAnalysis { // Infrastructure for passing may-alias hints: HLO passes can populate the // may-alias table. If an empty optional is returned, default rules are used. // + // Must-alias rules (as defined by GetInPlaceInputOutputPairs) cannot be + // overridden using backend-specific overrides. + // + // The first parameter of the function should be the instruction, the // second parameter should be an operand of the instruction. The third // parameter should be the output index of the instruction. @@ -160,6 +163,15 @@ class HloDataflowAnalysis { const HloModule& module() const { return module_; } + // Returns true if the operation is an in-place operation and its operand 0 + // must alias with the output. + static bool IsInPlaceOperation(HloOpcode opcode); + + // Returns a vector consisting of the HloUse (operand number and shape index) + // and output shape index of the in-place operations within this HLO.
+ static std::vector<std::pair<HloUse, ShapeIndex>> GetInPlaceInputOutputPairs( HloInstruction* instruction); + protected: HloDataflowAnalysis(const HloModule& module, bool ssa_form, bool bitcast_defines_value = false, diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc index 551ffb52031..1fa6fe95c40 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc @@ -2324,36 +2324,6 @@ TEST_F(CanShareOperandBufferWithUserTest, dataflow_analysis_->CanShareOperandBufferWithUser(param, {}, fusion, {})); } -TEST_F(CanShareOperandBufferWithUserTest, DUSWithSliceWithDifferentIndices) { - const char* kModule = R"( - HloModule test - - fused_computation { - p0 = f32[10,20,30] parameter(0) - p1 = s32[] parameter(1) - p2 = s32[] parameter(2) - p3 = s32[] parameter(3) - slice = f32[1,1,30] dynamic-slice(p0, p1, p2, p3), dynamic_slice_sizes={1,1,30} - ROOT dus = f32[10,20,30] dynamic-update-slice(p0, slice, p1, p3, p2) - } - - ENTRY test { - p0 = f32[10,20,30] parameter(0) - p1 = s32[] parameter(1) - p2 = s32[] parameter(2) - p3 = s32[] parameter(3) - ROOT fusion = f32[10,20,30] fusion(p0, p1, p2, p3), kind=kLoop, calls=fused_computation - } - )"; - TF_ASSERT_OK_AND_ASSIGN(module_, ParseAndReturnVerifiedModule(kModule)); - auto* fusion = module_->entry_computation()->root_instruction(); - auto* param = module_->entry_computation()->parameter_instruction(0); - - RunAnalysis(); - EXPECT_FALSE( - dataflow_analysis_->CanShareOperandBufferWithUser(param, {}, fusion, {})); -} - TEST_F(CanShareOperandBufferWithUserTest, DUSWithSliceWithSameIndices) { const char* kModule = R"( HloModule test diff --git a/tensorflow/compiler/xla/service/memory_space_assignment_test.cc b/tensorflow/compiler/xla/service/memory_space_assignment_test.cc index 896e214858f..5af61eac5d1 100644 --- a/tensorflow/compiler/xla/service/memory_space_assignment_test.cc +++ b/tensorflow/compiler/xla/service/memory_space_assignment_test.cc @@ -232,6 +232,24 @@ class MemorySpaceAssignmentTest : public HloTestBase, return copies; } + int64 GetAlternateMemoryOffset(const PresetAssignments& preset_assignments, + const HloInstruction* instruction, + const ShapeIndex& index = {}) const { + // Returns the offset of the assignment, -1 if it's not in the alternate + // memory. + const HloModule* module = instruction->parent()->parent(); + auto alias_analysis = HloAliasAnalysis::Run(module).ValueOrDie(); + HloBuffer& buffer = alias_analysis->GetUniqueBufferAt(instruction, index); + for (auto& pos_and_chunk : preset_assignments.chunks()) { + for (auto& value : buffer.values()) { + if (pos_and_chunk.first == value->defining_position()) { + return pos_and_chunk.second.offset; + } + } + } + return -1; + } + std::unique_ptr<HloModule> CreateEvictAndPrefetchModule() { HloComputation::Builder builder(TestName()); Shape shape = ShapeUtil::MakeShape(F32, {2, 3}); @@ -4415,6 +4433,47 @@ TEST_P(MemorySpaceAssignmentTest, Determinism) { } } +TEST_P(MemorySpaceAssignmentTest, InPlaceOp) { + // Tests that in-place ops like DynamicUpdateSlice get the same allocation as + // their input.
+ absl::string_view hlo_string = R"( +HloModule Module, is_scheduled=true + +fused_computation { + param0 = f32[2,3] parameter(0) + constant.1 = f32[] constant(0) + broadcast = f32[2,1] broadcast(constant.1), dimensions={} + constant.3 = s32[] constant(0) + ROOT dynamic-update-slice.5 = f32[2,3] dynamic-update-slice(param0, broadcast, constant.3, constant.3) +} + +ENTRY main { + param = f32[2,3] parameter(0) + negate = f32[2,3] negate(param) + fusion = f32[2,3] fusion(negate), kind=kLoop, calls=fused_computation + ROOT add = f32[2,3] add(fusion, fusion) +} + )"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, + ParseAndReturnVerifiedModule(hlo_string)); + auto preset_assignments = AssignMemorySpace(module.get()); + HloInstruction* negate_instruction = + module->entry_computation()->GetInstructionWithName("negate"); + int64 negate_offset = + GetAlternateMemoryOffset(*preset_assignments, negate_instruction); + HloInstruction* fusion_instruction = + module->entry_computation()->GetInstructionWithName("fusion"); + int64 fusion_offset = + GetAlternateMemoryOffset(*preset_assignments, fusion_instruction); + // We expect negate and fusion to get the same offsets. + EXPECT_EQ(negate_offset, fusion_offset); + const bool allocate_across_sequential_calls = GetParam(); + if (allocate_across_sequential_calls) { + EXPECT_NE(negate_offset, -1); + } +} + INSTANTIATE_TEST_SUITE_P(MemorySpaceAssignmentInstantiation, MemorySpaceAssignmentTest, ::testing::Values(false, true)); diff --git a/tensorflow/compiler/xla/service/multi_output_fusion.cc b/tensorflow/compiler/xla/service/multi_output_fusion.cc index a21cec538d1..c5c2d081686 100644 --- a/tensorflow/compiler/xla/service/multi_output_fusion.cc +++ b/tensorflow/compiler/xla/service/multi_output_fusion.cc @@ -17,6 +17,7 @@ limitations under the License. #include "absl/container/flat_hash_set.h" #include "tensorflow/compiler/xla/debug_options_flags.h" +#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h" #include "tensorflow/compiler/xla/service/hlo_dce.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" @@ -338,6 +339,21 @@ bool MultiOutputFusion::LegalToFuseMainConstraints(HloInstruction* instr1, if (!ShapesCompatibleForFusion(instr1, instr2)) { return false; } + + // If both nodes are in-place operations and they use a common in-place + // operand, we can't fuse these two. + for (const auto& operand_and_output_index1 : + HloDataflowAnalysis::GetInPlaceInputOutputPairs(instr1)) { + const HloInstruction* operand = + instr1->operand(operand_and_output_index1.first.operand_number); + for (const auto& operand_and_output_index2 : + HloDataflowAnalysis::GetInPlaceInputOutputPairs(instr2)) { + if (operand == + instr2->operand(operand_and_output_index2.first.operand_number)) { + return false; + } + } + } return true; }
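
// ----------------------------------------------------------------------------
// Reviewer sketch (not part of the patch): a minimal example of how a pass is
// expected to consume the new must-alias API, modeled on the copy_insertion.cc
// change above. LogInPlacePairs is a hypothetical helper added only for
// illustration; the API it calls (HloDataflowAnalysis::GetInPlaceInputOutputPairs,
// HloUse, ShapeIndex) is what this patch introduces or what already exists in XLA.
// ----------------------------------------------------------------------------
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/core/platform/logging.h"

namespace xla {

// Logs every operand/output pair that must alias because the instruction
// updates its input in place (kDynamicUpdateSlice, kScatter, or a fusion whose
// leaf outputs are produced by such ops).
void LogInPlacePairs(HloComputation* computation) {
  for (HloInstruction* instruction : computation->MakeInstructionPostOrder()) {
    for (const auto& pair :
         HloDataflowAnalysis::GetInPlaceInputOutputPairs(instruction)) {
      const HloUse& operand_use = pair.first;        // operand number + index
      const ShapeIndex& output_index = pair.second;  // aliased output index
      VLOG(2) << instruction->name() << ": operand "
              << operand_use.operand_number << " at "
              << operand_use.operand_index.ToString()
              << " must alias output " << output_index.ToString();
    }
  }
}

}  // namespace xla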