[XLA/GPU] Address review comments.
This commit is contained in:
		
							parent
							
								
									6774af43c2
								
							
						
					
					
						commit
						ea0b5fa33f
					
				@ -1814,10 +1814,10 @@ cc_library(
 | 
			
		||||
    deps = [
 | 
			
		||||
        ":gpu_fusible",
 | 
			
		||||
        ":ir_emission_utils",
 | 
			
		||||
        "//tensorflow/compiler/xla:shape_util",
 | 
			
		||||
        "//tensorflow/compiler/xla/service:hlo",
 | 
			
		||||
        "//tensorflow/compiler/xla/service:hlo_creation_utils",
 | 
			
		||||
        "//tensorflow/compiler/xla/service:hlo_pass",
 | 
			
		||||
        "//tensorflow/compiler/xla:shape_util",
 | 
			
		||||
        "//tensorflow/core:lib",
 | 
			
		||||
        "//tensorflow/core:lib_internal",
 | 
			
		||||
        "@com_google_absl//absl/container:flat_hash_set",
 | 
			
		||||
@ -1827,22 +1827,17 @@ cc_library(
 | 
			
		||||
tf_cc_test(
 | 
			
		||||
    name = "horizontal_input_fusion_test",
 | 
			
		||||
    srcs = ["horizontal_input_fusion_test.cc"],
 | 
			
		||||
    tags = tf_cuda_tests_tags(),
 | 
			
		||||
    deps = [
 | 
			
		||||
        ":fusion_merger",
 | 
			
		||||
        ":horizontal_input_fusion",
 | 
			
		||||
        ":instruction_fusion",
 | 
			
		||||
        ":multi_output_fusion",
 | 
			
		||||
        "//tensorflow/compiler/jit:xla_gpu_jit",
 | 
			
		||||
        "//tensorflow/compiler/xla:literal",
 | 
			
		||||
        "//tensorflow/compiler/xla/service/gpu/tests:gpu_codegen_test",
 | 
			
		||||
        "//tensorflow/compiler/xla/service:hlo_dce",
 | 
			
		||||
        "//tensorflow/compiler/xla/service:hlo_matchers",
 | 
			
		||||
        "//tensorflow/compiler/xla/service:hlo_parser",
 | 
			
		||||
        "//tensorflow/compiler/xla/service:hlo_pass",
 | 
			
		||||
        "//tensorflow/compiler/xla/service:hlo_pass_pipeline",
 | 
			
		||||
        "//tensorflow/compiler/xla:shape_util",
 | 
			
		||||
        "//tensorflow/compiler/xla:test",
 | 
			
		||||
        "//tensorflow/compiler/xla:test_helpers",
 | 
			
		||||
        "//tensorflow/compiler/xla/service:hlo_matchers",
 | 
			
		||||
        "//tensorflow/compiler/xla/service:hlo_parser",
 | 
			
		||||
        "//tensorflow/compiler/xla/service:hlo_pass_pipeline",
 | 
			
		||||
        "//tensorflow/compiler/xla/service/gpu/tests:gpu_codegen_test",
 | 
			
		||||
        "//tensorflow/compiler/xla/tests:filecheck",
 | 
			
		||||
        "//tensorflow/compiler/xla/tests:xla_internal_test_main",
 | 
			
		||||
    ],
 | 
			
		||||
 | 
			
		||||
@ -143,7 +143,7 @@ bool IsInputFusibleReduction(const HloInstruction& instr) {
 | 
			
		||||
         IsReductionFromOrToContiguousDimensions(instr);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
const HloInstruction* GetMajorNodeForMultiOutputFusion(
 | 
			
		||||
const HloInstruction* GetRealHeroForMultiOutputFusion(
 | 
			
		||||
    const HloInstruction& instr) {
 | 
			
		||||
  if (instr.opcode() != HloOpcode::kFusion) {
 | 
			
		||||
    return &instr;
 | 
			
		||||
@ -152,8 +152,8 @@ const HloInstruction* GetMajorNodeForMultiOutputFusion(
 | 
			
		||||
  if (!instr.IsMultiOutputFusion()) {
 | 
			
		||||
    return fused_expression_root;
 | 
			
		||||
  }
 | 
			
		||||
  // If possible, we want to pick a reduction-to-vector operand of the
 | 
			
		||||
  // fusion root, because it has the most constraints.
 | 
			
		||||
  // If possible, we want to pick a reduction-from-or-to-contiguous-dims
 | 
			
		||||
  // operand of the fusion root, because it has the most constraints.
 | 
			
		||||
  for (const auto* inst : fused_expression_root->operands()) {
 | 
			
		||||
    if (IsReductionFromOrToContiguousDimensions(*inst)) {
 | 
			
		||||
      return inst;
 | 
			
		||||
@ -179,8 +179,8 @@ bool ShapesCompatibleForMultiOutputFusion(const HloInstruction& instr1,
 | 
			
		||||
  // root ops should have equal output shapes. An exception are
 | 
			
		||||
  // reduction-to-vector ops. Here the input shapes of the reduction (first
 | 
			
		||||
  // operand shape) and the reduction dimensions need to match.
 | 
			
		||||
  auto* instr_1 = GetMajorNodeForMultiOutputFusion(instr1);
 | 
			
		||||
  auto* instr_2 = GetMajorNodeForMultiOutputFusion(instr2);
 | 
			
		||||
  auto* instr_1 = GetRealHeroForMultiOutputFusion(instr1);
 | 
			
		||||
  auto* instr_2 = GetRealHeroForMultiOutputFusion(instr2);
 | 
			
		||||
  if (IsReductionFromOrToContiguousDimensions(*instr_1) &&
 | 
			
		||||
      IsReductionFromOrToContiguousDimensions(*instr_2) &&
 | 
			
		||||
      !AreFusedReductionOutputsConsistent({instr_1, instr_2}, instr_1)) {
 | 
			
		||||
@ -528,16 +528,16 @@ bool IsConsumerTheOnlyNonRootUser(const HloInstruction& instr,
 | 
			
		||||
    if (user->opcode() == HloOpcode::kGetTupleElement) {
 | 
			
		||||
      // Skip GTE.
 | 
			
		||||
      return IsConsumerTheOnlyNonRootUser(*user, consumer);
 | 
			
		||||
    } else if (user == &consumer) {
 | 
			
		||||
    }
 | 
			
		||||
    if (user == &consumer) {
 | 
			
		||||
      // `user` is `consumer`.
 | 
			
		||||
      return true;
 | 
			
		||||
    } else if (user == user->parent()->root_instruction()) {
 | 
			
		||||
      // Consumed by ROOT is always fine, since it is impossible to create
 | 
			
		||||
      // cycles through ROOT.
 | 
			
		||||
      return true;
 | 
			
		||||
    } else {
 | 
			
		||||
      return false;
 | 
			
		||||
    }
 | 
			
		||||
    if (user == user->parent()->root_instruction()) {
 | 
			
		||||
      // Consumed by ROOT.
 | 
			
		||||
      return true;
 | 
			
		||||
    }
 | 
			
		||||
    return false;
 | 
			
		||||
  });
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -73,7 +73,7 @@ bool CreatesNestedLoop(const HloInstruction& producer,
 | 
			
		||||
 | 
			
		||||
// Returns the instruction that determines the emitter used for lowering,
 | 
			
		||||
// sometimes referred to as "the real hero".
 | 
			
		||||
const HloInstruction* GetMajorNodeForMultiOutputFusion(
 | 
			
		||||
const HloInstruction* GetRealHeroForMultiOutputFusion(
 | 
			
		||||
    const HloInstruction& instr);
 | 
			
		||||
 | 
			
		||||
// Whether instruction shapes are compatible for multi-output fusion, i.e.
 | 
			
		||||
 | 
			
		||||
@ -33,8 +33,8 @@ namespace {
 | 
			
		||||
 | 
			
		||||
// Gets the representative input shape of the multi-output fusion.
 | 
			
		||||
Shape GetInputShapeForMultiOutputFusion(const HloInstruction& instr) {
 | 
			
		||||
  // Get the major node used in the emitter.
 | 
			
		||||
  const HloInstruction* real_hero = GetMajorNodeForMultiOutputFusion(instr);
 | 
			
		||||
  // Get the HLO that determines the emitter used for lowering.
 | 
			
		||||
  const HloInstruction* real_hero = GetRealHeroForMultiOutputFusion(instr);
 | 
			
		||||
  if (real_hero->operands().empty()) {
 | 
			
		||||
    // Simply return an empty shape if the representative node has no input
 | 
			
		||||
    // operands.
 | 
			
		||||
@ -118,15 +118,13 @@ StatusOr<bool> HorizontalInputFusionImpl::Run() {
 | 
			
		||||
      HloInstruction* fused = candidates[j];
 | 
			
		||||
      if (ShapesCompatibleForMultiOutputFusion(*fusion_anchor, *fused) &&
 | 
			
		||||
          !FusionWouldBeTooLarge(*fusion_anchor, *fused)) {
 | 
			
		||||
        VLOG(3) << absl::StrCat("Fuse ", fused->ToString(), " into ",
 | 
			
		||||
                                fusion_anchor->ToString());
 | 
			
		||||
        VLOG(3) << "Fuse " << fused->ToString() << " into " << fusion_anchor->ToString();
 | 
			
		||||
        fusion_anchor->MergeFusionInstructionIntoMultiOutput(fused);
 | 
			
		||||
        changed = true;
 | 
			
		||||
      } else {
 | 
			
		||||
        // Update the `fusion_anchor_id` since `fused` is either not
 | 
			
		||||
        // compatible or not beneficial to be fused with current fusion anchor.
 | 
			
		||||
        VLOG(3) << absl::StrCat(j - fusion_anchor_id - 1,
 | 
			
		||||
                                " instructions are fused");
 | 
			
		||||
        VLOG(3) << j - fusion_anchor_id - 1 << " instructions are fused.";
 | 
			
		||||
        fusion_anchor_id = j;
 | 
			
		||||
      }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
@ -132,7 +132,7 @@ TEST_F(HorizontalInputFusionTest, ManyInputFusions) {
 | 
			
		||||
  // Verify that horizontal fusion is kicked in. Check that there are multiple
 | 
			
		||||
  // `reduce` instructions fused into the same fusion. 6 is just a randomly
 | 
			
		||||
  // picked number as we don't exactly know how large the fusion will be
 | 
			
		||||
  // created.
 | 
			
		||||
  // created due to the `FusionWouldBeTooLarge` constraint.
 | 
			
		||||
  CompileAndVerifyIr(module->Clone(),
 | 
			
		||||
                     R"(CHECK: reduce-group-6)",
 | 
			
		||||
                     /*match_optimized_ir=*/false);
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user