[XLA/GPU] Address review comments.

Trent Lo 2020-09-09 13:25:52 -07:00
parent 6774af43c2
commit ea0b5fa33f
5 changed files with 24 additions and 31 deletions

tensorflow/compiler/xla/service/gpu/BUILD

@@ -1814,10 +1814,10 @@ cc_library(
     deps = [
         ":gpu_fusible",
         ":ir_emission_utils",
+        "//tensorflow/compiler/xla:shape_util",
         "//tensorflow/compiler/xla/service:hlo",
         "//tensorflow/compiler/xla/service:hlo_creation_utils",
         "//tensorflow/compiler/xla/service:hlo_pass",
-        "//tensorflow/compiler/xla:shape_util",
         "//tensorflow/core:lib",
         "//tensorflow/core:lib_internal",
         "@com_google_absl//absl/container:flat_hash_set",
@@ -1827,22 +1827,17 @@ cc_library(
 tf_cc_test(
     name = "horizontal_input_fusion_test",
     srcs = ["horizontal_input_fusion_test.cc"],
+    tags = tf_cuda_tests_tags(),
     deps = [
-        ":fusion_merger",
         ":horizontal_input_fusion",
-        ":instruction_fusion",
-        ":multi_output_fusion",
-        "//tensorflow/compiler/jit:xla_gpu_jit",
         "//tensorflow/compiler/xla:literal",
-        "//tensorflow/compiler/xla/service/gpu/tests:gpu_codegen_test",
-        "//tensorflow/compiler/xla/service:hlo_dce",
-        "//tensorflow/compiler/xla/service:hlo_matchers",
-        "//tensorflow/compiler/xla/service:hlo_parser",
-        "//tensorflow/compiler/xla/service:hlo_pass",
-        "//tensorflow/compiler/xla/service:hlo_pass_pipeline",
         "//tensorflow/compiler/xla:shape_util",
         "//tensorflow/compiler/xla:test",
         "//tensorflow/compiler/xla:test_helpers",
+        "//tensorflow/compiler/xla/service:hlo_matchers",
+        "//tensorflow/compiler/xla/service:hlo_parser",
+        "//tensorflow/compiler/xla/service:hlo_pass_pipeline",
+        "//tensorflow/compiler/xla/service/gpu/tests:gpu_codegen_test",
         "//tensorflow/compiler/xla/tests:filecheck",
         "//tensorflow/compiler/xla/tests:xla_internal_test_main",
     ],

tensorflow/compiler/xla/service/gpu/gpu_fusible.cc

@@ -143,7 +143,7 @@ bool IsInputFusibleReduction(const HloInstruction& instr) {
          IsReductionFromOrToContiguousDimensions(instr);
 }
 
-const HloInstruction* GetMajorNodeForMultiOutputFusion(
+const HloInstruction* GetRealHeroForMultiOutputFusion(
     const HloInstruction& instr) {
   if (instr.opcode() != HloOpcode::kFusion) {
     return &instr;
@@ -152,8 +152,8 @@ const HloInstruction* GetMajorNodeForMultiOutputFusion(
   if (!instr.IsMultiOutputFusion()) {
     return fused_expression_root;
   }
-  // If possible, we want to pick a reduction-to-vector operand of the
-  // fusion root, because it has the most constraints.
+  // If possible, we want to pick a reduction-from-or-to-contiguous-dims
+  // operand of the fusion root, because it has the most constraints.
   for (const auto* inst : fused_expression_root->operands()) {
     if (IsReductionFromOrToContiguousDimensions(*inst)) {
       return inst;
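
An aside on the comment above: here is a hand-written illustration, not part of this commit and with all names invented, of why the reduction operand is the most constrained pick. In a multi-output fusion whose root tuple pairs a row reduction with an elementwise multiply, the reduction dictates the emitter's tiling and layout, so GetRealHeroForMultiOutputFusion would return r0 rather than m0:

    // Illustrative HLO only; module and instruction names are invented.
    const char* const kMultiOutputFusionExample = R"(
    HloModule example

    add {
      x = f32[] parameter(0)
      y = f32[] parameter(1)
      ROOT sum = f32[] add(x, y)
    }

    fused_computation {
      p0 = f32[128,1024]{1,0} parameter(0)
      c0 = f32[] constant(0)
      // Row reduction: the "real hero" that constrains the emitter.
      r0 = f32[128]{0} reduce(p0, c0), dimensions={1}, to_apply=add
      m0 = f32[128,1024]{1,0} multiply(p0, p0)
      ROOT t = (f32[128]{0}, f32[128,1024]{1,0}) tuple(r0, m0)
    }

    ENTRY entry {
      p = f32[128,1024]{1,0} parameter(0)
      ROOT f = (f32[128]{0}, f32[128,1024]{1,0}) fusion(p), kind=kInput,
          calls=fused_computation
    }
    )";

The elementwise multiply alone would leave the emitter unconstrained; the reduce fixes both the input shape and the reduction dimensions that ShapesCompatibleForMultiOutputFusion compares below.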
@@ -179,8 +179,8 @@ bool ShapesCompatibleForMultiOutputFusion(const HloInstruction& instr1,
   // root ops should have equal output shapes. An exception are
   // reduction-to-vector ops. Here the input shapes of the reduction (first
   // operand shape) and the reduction dimensions need to match.
-  auto* instr_1 = GetMajorNodeForMultiOutputFusion(instr1);
-  auto* instr_2 = GetMajorNodeForMultiOutputFusion(instr2);
+  auto* instr_1 = GetRealHeroForMultiOutputFusion(instr1);
+  auto* instr_2 = GetRealHeroForMultiOutputFusion(instr2);
   if (IsReductionFromOrToContiguousDimensions(*instr_1) &&
       IsReductionFromOrToContiguousDimensions(*instr_2) &&
       !AreFusedReductionOutputsConsistent({instr_1, instr_2}, instr_1)) {
@@ -528,16 +528,16 @@ bool IsConsumerTheOnlyNonRootUser(const HloInstruction& instr,
   return absl::c_all_of(instr.users(), [&](const HloInstruction* user) {
     if (user->opcode() == HloOpcode::kGetTupleElement) {
       // Skip GTE.
       return IsConsumerTheOnlyNonRootUser(*user, consumer);
-    } else if (user == &consumer) {
+    }
+    if (user == &consumer) {
       // `user` is `consumer`.
       return true;
-    } else if (user == user->parent()->root_instruction()) {
-      // Consumed by ROOT is always fine, since it is impossible to create
-      // cycles through ROOT.
-      return true;
-    } else {
-      return false;
-    }
+    }
+    if (user == user->parent()->root_instruction()) {
+      // Consumed by ROOT.
+      return true;
+    }
+    return false;
   });
 }
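
For readers skimming the control flow, a self-contained sketch of the same early-return structure over a toy node type; everything below is invented for illustration (the real predicate operates on HloInstruction and absl::c_all_of):

    #include <algorithm>
    #include <vector>

    struct Node {
      bool is_gte = false;   // Stands in for HloOpcode::kGetTupleElement.
      bool is_root = false;  // Stands in for parent()->root_instruction().
      std::vector<const Node*> users;
    };

    bool IsConsumerTheOnlyNonRootUser(const Node& instr, const Node& consumer) {
      return std::all_of(
          instr.users.begin(), instr.users.end(), [&](const Node* user) {
            if (user->is_gte) {
              // Look through get-tuple-element, as the real code does.
              return IsConsumerTheOnlyNonRootUser(*user, consumer);
            }
            if (user == &consumer) return true;  // `user` is `consumer`.
            if (user->is_root) return true;      // ROOT cannot create a cycle.
            return false;
          });
    }

Behavior is unchanged from the else-if chain; since every branch returns, the early-return form simply drops a level of nesting.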

tensorflow/compiler/xla/service/gpu/gpu_fusible.h

@@ -73,7 +73,7 @@ bool CreatesNestedLoop(const HloInstruction& producer,
 
 // Returns the instruction that determines the emitter used for lowering,
 // sometimes referred to as "the real hero".
-const HloInstruction* GetMajorNodeForMultiOutputFusion(
+const HloInstruction* GetRealHeroForMultiOutputFusion(
     const HloInstruction& instr);
 
 // Whether instruction shapes are compatible for multi-output fusion, i.e.

tensorflow/compiler/xla/service/gpu/horizontal_input_fusion.cc

@@ -33,8 +33,8 @@ namespace {
 
 // Gets the representative input shape of the multi-output fusion.
 Shape GetInputShapeForMultiOutputFusion(const HloInstruction& instr) {
-  // Get the major node used in the emitter.
-  const HloInstruction* real_hero = GetMajorNodeForMultiOutputFusion(instr);
+  // Get the HLO that determines the emitter used for lowering.
+  const HloInstruction* real_hero = GetRealHeroForMultiOutputFusion(instr);
   if (real_hero->operands().empty()) {
     // Simply return an empty shape if the representative node has no input
     // operands.
@@ -118,15 +118,13 @@ StatusOr<bool> HorizontalInputFusionImpl::Run() {
       HloInstruction* fused = candidates[j];
       if (ShapesCompatibleForMultiOutputFusion(*fusion_anchor, *fused) &&
           !FusionWouldBeTooLarge(*fusion_anchor, *fused)) {
-        VLOG(3) << absl::StrCat("Fuse ", fused->ToString(), " into ",
-                                fusion_anchor->ToString());
+        VLOG(3) << "Fuse " << fused->ToString() << " into " << fusion_anchor->ToString();
         fusion_anchor->MergeFusionInstructionIntoMultiOutput(fused);
         changed = true;
       } else {
         // Update the `fusion_anchor_id` since `fused` is either not
         // compatible or not beneficial to be fused with current fusion anchor.
-        VLOG(3) << absl::StrCat(j - fusion_anchor_id - 1,
-                                " instructions are fused");
+        VLOG(3) << j - fusion_anchor_id - 1 << " instructions are fused.";
         fusion_anchor_id = j;
       }
     }
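
The VLOG lines above live inside the pass's anchor scan. A compilable sketch of that pattern, simplified and with invented names (the real loop also sorts candidates by representative input shape and applies the FusionWouldBeTooLarge check through the `compatible` predicate):

    #include <cstddef>
    #include <vector>

    // Walks a pre-sorted candidate list, merging each compatible candidate
    // into the current anchor fusion and re-anchoring on the first mismatch.
    template <typename Fusion, typename Compatible, typename Merge>
    bool ScanAndMerge(std::vector<Fusion*>& candidates, Compatible compatible,
                      Merge merge) {
      bool changed = false;
      size_t anchor = 0;
      for (size_t j = 1; j < candidates.size(); ++j) {
        if (compatible(*candidates[anchor], *candidates[j])) {
          merge(*candidates[anchor], *candidates[j]);  // Multi-output merge.
          changed = true;
        } else {
          anchor = j;  // Start a new fusion group at the first mismatch.
        }
      }
      return changed;
    }

Re-anchoring at the first incompatible candidate mirrors the else branch above: because the candidates are pre-sorted, a shape mismatch ends the current group for good.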

tensorflow/compiler/xla/service/gpu/horizontal_input_fusion_test.cc

@@ -132,7 +132,7 @@ TEST_F(HorizontalInputFusionTest, ManyInputFusions) {
   // Verify that horizontal fusion is kicked in. Check that there are multiple
   // `reduce` instructions fused into the same fusion. 6 is just a randomly
   // picked number as we don't exactly know how large the fusion will be
-  // created.
+  // created due to the `FusionWouldBeTooLarge` constraint.
   CompileAndVerifyIr(module->Clone(),
                      R"(CHECK: reduce-group-6)",
                      /*match_optimized_ir=*/false);