[XLA/GPU] Address review comments.

This commit is contained in:
Trent Lo 2020-09-09 13:25:52 -07:00
parent 6774af43c2
commit ea0b5fa33f
5 changed files with 24 additions and 31 deletions

View File

@ -1814,10 +1814,10 @@ cc_library(
deps = [ deps = [
":gpu_fusible", ":gpu_fusible",
":ir_emission_utils", ":ir_emission_utils",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_creation_utils", "//tensorflow/compiler/xla/service:hlo_creation_utils",
"//tensorflow/compiler/xla/service:hlo_pass", "//tensorflow/compiler/xla/service:hlo_pass",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/core:lib", "//tensorflow/core:lib",
"//tensorflow/core:lib_internal", "//tensorflow/core:lib_internal",
"@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:flat_hash_set",
@ -1827,22 +1827,17 @@ cc_library(
tf_cc_test( tf_cc_test(
name = "horizontal_input_fusion_test", name = "horizontal_input_fusion_test",
srcs = ["horizontal_input_fusion_test.cc"], srcs = ["horizontal_input_fusion_test.cc"],
tags = tf_cuda_tests_tags(),
deps = [ deps = [
":fusion_merger",
":horizontal_input_fusion", ":horizontal_input_fusion",
":instruction_fusion",
":multi_output_fusion", ":multi_output_fusion",
"//tensorflow/compiler/jit:xla_gpu_jit",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla/service/gpu/tests:gpu_codegen_test",
"//tensorflow/compiler/xla/service:hlo_dce",
"//tensorflow/compiler/xla/service:hlo_matchers",
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/service:hlo_pass",
"//tensorflow/compiler/xla/service:hlo_pass_pipeline",
"//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla/service:hlo_matchers",
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/service:hlo_pass_pipeline",
"//tensorflow/compiler/xla/service/gpu/tests:gpu_codegen_test",
"//tensorflow/compiler/xla/tests:filecheck", "//tensorflow/compiler/xla/tests:filecheck",
"//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/compiler/xla/tests:xla_internal_test_main",
], ],

View File

@ -143,7 +143,7 @@ bool IsInputFusibleReduction(const HloInstruction& instr) {
IsReductionFromOrToContiguousDimensions(instr); IsReductionFromOrToContiguousDimensions(instr);
} }
const HloInstruction* GetMajorNodeForMultiOutputFusion( const HloInstruction* GetRealHeroForMultiOutputFusion(
const HloInstruction& instr) { const HloInstruction& instr) {
if (instr.opcode() != HloOpcode::kFusion) { if (instr.opcode() != HloOpcode::kFusion) {
return &instr; return &instr;
@ -152,8 +152,8 @@ const HloInstruction* GetMajorNodeForMultiOutputFusion(
if (!instr.IsMultiOutputFusion()) { if (!instr.IsMultiOutputFusion()) {
return fused_expression_root; return fused_expression_root;
} }
// If possible, we want to pick a reduction-to-vector operand of the // If possible, we want to pick a reduction-from-or-to-contiguous-dims
// fusion root, because it has the most constraints. // operand of the fusion root, because it has the most constraints.
for (const auto* inst : fused_expression_root->operands()) { for (const auto* inst : fused_expression_root->operands()) {
if (IsReductionFromOrToContiguousDimensions(*inst)) { if (IsReductionFromOrToContiguousDimensions(*inst)) {
return inst; return inst;
@ -179,8 +179,8 @@ bool ShapesCompatibleForMultiOutputFusion(const HloInstruction& instr1,
// root ops should have equal output shapes. Exceptions are // root ops should have equal output shapes. Exceptions are
// reduction-to-vector ops. Here the input shapes of the reduction (first // reduction-to-vector ops. Here the input shapes of the reduction (first
// operand shape) and the reduction dimensions need to match. // operand shape) and the reduction dimensions need to match.
auto* instr_1 = GetMajorNodeForMultiOutputFusion(instr1); auto* instr_1 = GetRealHeroForMultiOutputFusion(instr1);
auto* instr_2 = GetMajorNodeForMultiOutputFusion(instr2); auto* instr_2 = GetRealHeroForMultiOutputFusion(instr2);
if (IsReductionFromOrToContiguousDimensions(*instr_1) && if (IsReductionFromOrToContiguousDimensions(*instr_1) &&
IsReductionFromOrToContiguousDimensions(*instr_2) && IsReductionFromOrToContiguousDimensions(*instr_2) &&
!AreFusedReductionOutputsConsistent({instr_1, instr_2}, instr_1)) { !AreFusedReductionOutputsConsistent({instr_1, instr_2}, instr_1)) {
@ -528,16 +528,16 @@ bool IsConsumerTheOnlyNonRootUser(const HloInstruction& instr,
if (user->opcode() == HloOpcode::kGetTupleElement) { if (user->opcode() == HloOpcode::kGetTupleElement) {
// Skip GTE. // Skip GTE.
return IsConsumerTheOnlyNonRootUser(*user, consumer); return IsConsumerTheOnlyNonRootUser(*user, consumer);
} else if (user == &consumer) { }
if (user == &consumer) {
// `user` is `consumer`. // `user` is `consumer`.
return true; return true;
} else if (user == user->parent()->root_instruction()) {
// Consumed by ROOT is always fine, since it is impossible to create
// cycles through ROOT.
return true;
} else {
return false;
} }
if (user == user->parent()->root_instruction()) {
// Consumed by ROOT.
return true;
}
return false;
}); });
} }

View File

@ -73,7 +73,7 @@ bool CreatesNestedLoop(const HloInstruction& producer,
// Returns the instruction that determines the emitter used for lowering, // Returns the instruction that determines the emitter used for lowering,
// sometimes referred to as "the real hero". // sometimes referred to as "the real hero".
const HloInstruction* GetMajorNodeForMultiOutputFusion( const HloInstruction* GetRealHeroForMultiOutputFusion(
const HloInstruction& instr); const HloInstruction& instr);
// Whether instruction shapes are compatible for multi-output fusion, i.e. // Whether instruction shapes are compatible for multi-output fusion, i.e.

View File

@ -33,8 +33,8 @@ namespace {
// Gets the representative input shape of the multi-output fusion. // Gets the representative input shape of the multi-output fusion.
Shape GetInputShapeForMultiOutputFusion(const HloInstruction& instr) { Shape GetInputShapeForMultiOutputFusion(const HloInstruction& instr) {
// Get the major node used in the emitter. // Get the HLO that determines the emitter used for lowering.
const HloInstruction* real_hero = GetMajorNodeForMultiOutputFusion(instr); const HloInstruction* real_hero = GetRealHeroForMultiOutputFusion(instr);
if (real_hero->operands().empty()) { if (real_hero->operands().empty()) {
// Simply return an empty shape if the representative node has no input // Simply return an empty shape if the representative node has no input
// operands. // operands.
@ -118,15 +118,13 @@ StatusOr<bool> HorizontalInputFusionImpl::Run() {
HloInstruction* fused = candidates[j]; HloInstruction* fused = candidates[j];
if (ShapesCompatibleForMultiOutputFusion(*fusion_anchor, *fused) && if (ShapesCompatibleForMultiOutputFusion(*fusion_anchor, *fused) &&
!FusionWouldBeTooLarge(*fusion_anchor, *fused)) { !FusionWouldBeTooLarge(*fusion_anchor, *fused)) {
VLOG(3) << absl::StrCat("Fuse ", fused->ToString(), " into ", VLOG(3) << "Fuse " << fused->ToString() << " into " << fusion_anchor->ToString();
fusion_anchor->ToString());
fusion_anchor->MergeFusionInstructionIntoMultiOutput(fused); fusion_anchor->MergeFusionInstructionIntoMultiOutput(fused);
changed = true; changed = true;
} else { } else {
// Update the `fusion_anchor_id` since `fused` is either not // Update the `fusion_anchor_id` since `fused` is either not
// compatible or not beneficial to be fused with the current fusion anchor. // compatible or not beneficial to be fused with the current fusion anchor.
VLOG(3) << absl::StrCat(j - fusion_anchor_id - 1, VLOG(3) << j - fusion_anchor_id - 1 << " instructions are fused.";
" instructions are fused");
fusion_anchor_id = j; fusion_anchor_id = j;
} }
} }

View File

@ -132,7 +132,7 @@ TEST_F(HorizontalInputFusionTest, ManyInputFusions) {
// Verify that horizontal fusion has kicked in. Check that there are multiple // Verify that horizontal fusion has kicked in. Check that there are multiple
// `reduce` instructions fused into the same fusion. 6 is just a randomly // `reduce` instructions fused into the same fusion. 6 is just a randomly
// picked number as we don't exactly know how large the fusion will be // picked number as we don't exactly know how large the fusion will be
// created. // created due to the `FusionWouldBeTooLarge` constraint.
CompileAndVerifyIr(module->Clone(), CompileAndVerifyIr(module->Clone(),
R"(CHECK: reduce-group-6)", R"(CHECK: reduce-group-6)",
/*match_optimized_ir=*/false); /*match_optimized_ir=*/false);