diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD
index 39b32ec2d1d..39dad267acf 100644
--- a/tensorflow/compiler/xla/service/gpu/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/BUILD
@@ -1814,10 +1814,10 @@ cc_library(
     deps = [
         ":gpu_fusible",
         ":ir_emission_utils",
+        "//tensorflow/compiler/xla:shape_util",
         "//tensorflow/compiler/xla/service:hlo",
         "//tensorflow/compiler/xla/service:hlo_creation_utils",
         "//tensorflow/compiler/xla/service:hlo_pass",
-        "//tensorflow/compiler/xla:shape_util",
         "//tensorflow/core:lib",
         "//tensorflow/core:lib_internal",
         "@com_google_absl//absl/container:flat_hash_set",
@@ -1827,22 +1827,17 @@ cc_library(
 tf_cc_test(
     name = "horizontal_input_fusion_test",
     srcs = ["horizontal_input_fusion_test.cc"],
+    tags = tf_cuda_tests_tags(),
     deps = [
-        ":fusion_merger",
         ":horizontal_input_fusion",
-        ":instruction_fusion",
         ":multi_output_fusion",
-        "//tensorflow/compiler/jit:xla_gpu_jit",
-        "//tensorflow/compiler/xla:literal",
-        "//tensorflow/compiler/xla/service/gpu/tests:gpu_codegen_test",
-        "//tensorflow/compiler/xla/service:hlo_dce",
-        "//tensorflow/compiler/xla/service:hlo_matchers",
-        "//tensorflow/compiler/xla/service:hlo_parser",
-        "//tensorflow/compiler/xla/service:hlo_pass",
-        "//tensorflow/compiler/xla/service:hlo_pass_pipeline",
         "//tensorflow/compiler/xla:shape_util",
         "//tensorflow/compiler/xla:test",
         "//tensorflow/compiler/xla:test_helpers",
+        "//tensorflow/compiler/xla/service:hlo_matchers",
+        "//tensorflow/compiler/xla/service:hlo_parser",
+        "//tensorflow/compiler/xla/service:hlo_pass_pipeline",
+        "//tensorflow/compiler/xla/service/gpu/tests:gpu_codegen_test",
         "//tensorflow/compiler/xla/tests:filecheck",
         "//tensorflow/compiler/xla/tests:xla_internal_test_main",
     ],
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc b/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc
index 9f8c3c81ad2..b69b32c17c5 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_fusible.cc
@@ -143,7 +143,7 @@ bool IsInputFusibleReduction(const HloInstruction& instr) {
          IsReductionFromOrToContiguousDimensions(instr);
 }
 
-const HloInstruction* GetMajorNodeForMultiOutputFusion(
+const HloInstruction* GetRealHeroForMultiOutputFusion(
     const HloInstruction& instr) {
   if (instr.opcode() != HloOpcode::kFusion) {
     return &instr;
@@ -152,8 +152,8 @@ const HloInstruction* GetMajorNodeForMultiOutputFusion(
   if (!instr.IsMultiOutputFusion()) {
     return fused_expression_root;
   }
-  // If possible, we want to pick a reduction-to-vector operand of the
-  // fusion root, because it has the most constraints.
+  // If possible, we want to pick a reduction-from-or-to-contiguous-dims
+  // operand of the fusion root, because it has the most constraints.
   for (const auto* inst : fused_expression_root->operands()) {
     if (IsReductionFromOrToContiguousDimensions(*inst)) {
       return inst;
@@ -179,8 +179,8 @@ bool ShapesCompatibleForMultiOutputFusion(const HloInstruction& instr1,
   // root ops should have equal output shapes. An exception are
   // reduction-to-vector ops. Here the input shapes of the reduction (first
   // operand shape) and the reduction dimensions need to match.
-  auto* instr_1 = GetMajorNodeForMultiOutputFusion(instr1);
-  auto* instr_2 = GetMajorNodeForMultiOutputFusion(instr2);
+  auto* instr_1 = GetRealHeroForMultiOutputFusion(instr1);
+  auto* instr_2 = GetRealHeroForMultiOutputFusion(instr2);
   if (IsReductionFromOrToContiguousDimensions(*instr_1) &&
       IsReductionFromOrToContiguousDimensions(*instr_2) &&
       !AreFusedReductionOutputsConsistent({instr_1, instr_2}, instr_1)) {
@@ -528,16 +528,16 @@ bool IsConsumerTheOnlyNonRootUser(const HloInstruction& instr,
     if (user->opcode() == HloOpcode::kGetTupleElement) {
       // Skip GTE.
       return IsConsumerTheOnlyNonRootUser(*user, consumer);
-    } else if (user == &consumer) {
+    }
+    if (user == &consumer) {
       // `user` is `consumer`.
       return true;
-    } else if (user == user->parent()->root_instruction()) {
-      // Consumed by ROOT is always fine, since it is impossible to create
-      // cycles through ROOT.
-      return true;
-    } else {
-      return false;
     }
+    if (user == user->parent()->root_instruction()) {
+      // Consumed by ROOT.
+      return true;
+    }
+    return false;
   });
 }
 
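For review context, the selection rule being renamed above can be stated in isolation. The sketch below is not part of the patch: it mirrors the logic of GetRealHeroForMultiOutputFusion using a hypothetical Node type standing in for HloInstruction, so every name in it is illustrative rather than XLA API.

#include <vector>

// Hypothetical stand-in for HloInstruction, carrying only what the rule needs.
struct Node {
  bool is_fusion = false;
  bool is_multi_output = false;
  bool is_contiguous_dims_reduction = false;  // cf. IsReductionFromOrToContiguousDimensions
  const Node* fused_root = nullptr;           // root of the fused computation, if a fusion
  std::vector<const Node*> operands;
};

// Mirrors the rule above: a non-fusion is its own hero; a single-output
// fusion is represented by its root; a multi-output fusion prefers a
// reduction-from-or-to-contiguous-dims operand of the root, because that
// operand places the most constraints on which emitter can lower the fusion.
const Node* GetRealHero(const Node& instr) {
  if (!instr.is_fusion) return &instr;
  const Node* root = instr.fused_root;
  if (!instr.is_multi_output) return root;
  for (const Node* operand : root->operands) {
    if (operand->is_contiguous_dims_reduction) return operand;
  }
  return root;  // No reduction operand; the root itself is the hero.
}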
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_fusible.h b/tensorflow/compiler/xla/service/gpu/gpu_fusible.h
index 8595bb24ddf..9fa098a3394 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_fusible.h
+++ b/tensorflow/compiler/xla/service/gpu/gpu_fusible.h
@@ -73,7 +73,7 @@ bool CreatesNestedLoop(const HloInstruction& producer,
 
 // Returns the instruction that determines the emitter used for lowering,
 // sometimes referred to as "the real hero".
-const HloInstruction* GetMajorNodeForMultiOutputFusion(
+const HloInstruction* GetRealHeroForMultiOutputFusion(
     const HloInstruction& instr);
 
 // Whether instruction shapes are compatible for multi-output fusion, i.e.
diff --git a/tensorflow/compiler/xla/service/gpu/horizontal_input_fusion.cc b/tensorflow/compiler/xla/service/gpu/horizontal_input_fusion.cc
index 75a69611780..f25a283e4b9 100644
--- a/tensorflow/compiler/xla/service/gpu/horizontal_input_fusion.cc
+++ b/tensorflow/compiler/xla/service/gpu/horizontal_input_fusion.cc
@@ -33,8 +33,8 @@ namespace {
 
 // Gets the representative input shape of the multi-output fusion.
 Shape GetInputShapeForMultiOutputFusion(const HloInstruction& instr) {
-  // Get the major node used in the emitter.
-  const HloInstruction* real_hero = GetMajorNodeForMultiOutputFusion(instr);
+  // Get the HLO that determines the emitter used for lowering.
+  const HloInstruction* real_hero = GetRealHeroForMultiOutputFusion(instr);
   if (real_hero->operands().empty()) {
     // Simply return an empty shape if the representative node has no input
     // operands.
@@ -118,15 +118,13 @@ StatusOr<bool> HorizontalInputFusionImpl::Run() {
         HloInstruction* fused = candidates[j];
         if (ShapesCompatibleForMultiOutputFusion(*fusion_anchor, *fused) &&
             !FusionWouldBeTooLarge(*fusion_anchor, *fused)) {
-          VLOG(3) << absl::StrCat("Fuse ", fused->ToString(), " into ",
-                                  fusion_anchor->ToString());
+          VLOG(3) << "Fuse " << fused->ToString() << " into " << fusion_anchor->ToString();
           fusion_anchor->MergeFusionInstructionIntoMultiOutput(fused);
           changed = true;
         } else {
           // Update the `fusion_anchor_id` since `fused` is either not
           // compatible or not beneficial to be fused with current fusion anchor.
-          VLOG(3) << absl::StrCat(j - fusion_anchor_id - 1,
-                                  " instructions are fused");
+          VLOG(3) << j - fusion_anchor_id - 1 << " instructions are fused.";
           fusion_anchor_id = j;
         }
       }
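Two notes on the horizontal_input_fusion.cc hunks above. Streaming directly into VLOG(3) drops the redundant absl::StrCat, since the log stream already does the formatting. And the surrounding Run() loop is a single left-to-right scan over shape-sorted candidates: each candidate is either merged into the current anchor or, on the first mismatch, becomes the new anchor. The compilable sketch below is not part of the patch and reproduces only that control flow; Candidate, compatible, and merge are hypothetical stand-ins for the XLA types and predicates.

#include <cstddef>
#include <functional>
#include <vector>

// Hypothetical stand-in for a sibling fusion candidate.
struct Candidate { int id; };

// Sketch of the anchor scan in HorizontalInputFusionImpl::Run(). `compatible`
// stands in for ShapesCompatibleForMultiOutputFusion combined with the
// !FusionWouldBeTooLarge check; `merge` stands in for
// MergeFusionInstructionIntoMultiOutput.
bool HorizontalFuseOnePass(
    std::vector<Candidate>& candidates,
    const std::function<bool(const Candidate&, const Candidate&)>& compatible,
    const std::function<void(Candidate&, Candidate&)>& merge) {
  bool changed = false;
  std::size_t fusion_anchor_id = 0;
  for (std::size_t j = 1; j < candidates.size(); ++j) {
    if (compatible(candidates[fusion_anchor_id], candidates[j])) {
      // Merge the candidate into the current anchor's multi-output fusion.
      merge(candidates[fusion_anchor_id], candidates[j]);
      changed = true;
    } else {
      // Incompatible (or too large): this candidate starts a new group.
      fusion_anchor_id = j;
    }
  }
  return changed;
}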
diff --git a/tensorflow/compiler/xla/service/gpu/horizontal_input_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/horizontal_input_fusion_test.cc
index 035658fe55e..f27e77fad68 100644
--- a/tensorflow/compiler/xla/service/gpu/horizontal_input_fusion_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/horizontal_input_fusion_test.cc
@@ -132,7 +132,7 @@ TEST_F(HorizontalInputFusionTest, ManyInputFusions) {
   // Verify that horizontal fusion is kicked in. Check that there are multiple
   // `reduce` instructions fused into the same fusion. 6 is just a randomly
   // picked number as we don't exactly know how large the fusion will be
-  // created.
+  // created due to the `FusionWouldBeTooLarge` constraint.
   CompileAndVerifyIr(module->Clone(), R"(CHECK: reduce-group-6)",
                      /*match_optimized_ir=*/false);
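The clarified comment names the actual source of uncertainty in the expected group count: FusionWouldBeTooLarge caps how much can be merged into one fusion, so the test can only assert that a group with index 6 exists rather than predict the exact total. As a rough illustration only (the constant and the accounting below are invented, not XLA's real numbers), the cap behaves like a budget on the fused instruction's combined operand and output count:

#include <cstddef>

// Made-up budget for illustration; XLA enforces its own limit on the operand
// and output count of a fusion (see FusionWouldBeTooLarge in gpu_fusible.cc).
constexpr std::size_t kToyFusionBudget = 64;

// Toy analogue of the check referenced by the test comment above: a merge is
// rejected once the combined operand/output count of the would-be fusion
// exceeds the budget, so the number of reduces that land in one group (and
// hence the largest reduce-group index in the emitted IR) is not known up
// front.
bool WouldBeTooLarge(std::size_t anchor_io_count,
                     std::size_t candidate_io_count) {
  return anchor_io_count + candidate_io_count > kToyFusionBudget;
}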