[XLA/GPU] Address review comments.
This commit is contained in:
parent
6774af43c2
commit
ea0b5fa33f
@ -1814,10 +1814,10 @@ cc_library(
|
||||
deps = [
|
||||
":gpu_fusible",
|
||||
":ir_emission_utils",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla/service:hlo",
|
||||
"//tensorflow/compiler/xla/service:hlo_creation_utils",
|
||||
"//tensorflow/compiler/xla/service:hlo_pass",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"@com_google_absl//absl/container:flat_hash_set",
|
||||
@ -1827,22 +1827,17 @@ cc_library(
|
||||
tf_cc_test(
|
||||
name = "horizontal_input_fusion_test",
|
||||
srcs = ["horizontal_input_fusion_test.cc"],
|
||||
tags = tf_cuda_tests_tags(),
|
||||
deps = [
|
||||
":fusion_merger",
|
||||
":horizontal_input_fusion",
|
||||
":instruction_fusion",
|
||||
":multi_output_fusion",
|
||||
"//tensorflow/compiler/jit:xla_gpu_jit",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla/service/gpu/tests:gpu_codegen_test",
|
||||
"//tensorflow/compiler/xla/service:hlo_dce",
|
||||
"//tensorflow/compiler/xla/service:hlo_matchers",
|
||||
"//tensorflow/compiler/xla/service:hlo_parser",
|
||||
"//tensorflow/compiler/xla/service:hlo_pass",
|
||||
"//tensorflow/compiler/xla/service:hlo_pass_pipeline",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:test",
|
||||
"//tensorflow/compiler/xla:test_helpers",
|
||||
"//tensorflow/compiler/xla/service:hlo_matchers",
|
||||
"//tensorflow/compiler/xla/service:hlo_parser",
|
||||
"//tensorflow/compiler/xla/service:hlo_pass_pipeline",
|
||||
"//tensorflow/compiler/xla/service/gpu/tests:gpu_codegen_test",
|
||||
"//tensorflow/compiler/xla/tests:filecheck",
|
||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
||||
],
|
||||
|
@ -143,7 +143,7 @@ bool IsInputFusibleReduction(const HloInstruction& instr) {
|
||||
IsReductionFromOrToContiguousDimensions(instr);
|
||||
}
|
||||
|
||||
const HloInstruction* GetMajorNodeForMultiOutputFusion(
|
||||
const HloInstruction* GetRealHeroForMultiOutputFusion(
|
||||
const HloInstruction& instr) {
|
||||
if (instr.opcode() != HloOpcode::kFusion) {
|
||||
return &instr;
|
||||
@ -152,8 +152,8 @@ const HloInstruction* GetMajorNodeForMultiOutputFusion(
|
||||
if (!instr.IsMultiOutputFusion()) {
|
||||
return fused_expression_root;
|
||||
}
|
||||
// If possible, we want to pick a reduction-to-vector operand of the
|
||||
// fusion root, because it has the most constraints.
|
||||
// If possible, we want to pick a reduction-from-or-to-contiguous-dims
|
||||
// operand of the fusion root, because it has the most constraints.
|
||||
for (const auto* inst : fused_expression_root->operands()) {
|
||||
if (IsReductionFromOrToContiguousDimensions(*inst)) {
|
||||
return inst;
|
||||
@ -179,8 +179,8 @@ bool ShapesCompatibleForMultiOutputFusion(const HloInstruction& instr1,
|
||||
// root ops should have equal output shapes. An exception are
|
||||
// reduction-to-vector ops. Here the input shapes of the reduction (first
|
||||
// operand shape) and the reduction dimensions need to match.
|
||||
auto* instr_1 = GetMajorNodeForMultiOutputFusion(instr1);
|
||||
auto* instr_2 = GetMajorNodeForMultiOutputFusion(instr2);
|
||||
auto* instr_1 = GetRealHeroForMultiOutputFusion(instr1);
|
||||
auto* instr_2 = GetRealHeroForMultiOutputFusion(instr2);
|
||||
if (IsReductionFromOrToContiguousDimensions(*instr_1) &&
|
||||
IsReductionFromOrToContiguousDimensions(*instr_2) &&
|
||||
!AreFusedReductionOutputsConsistent({instr_1, instr_2}, instr_1)) {
|
||||
@ -528,16 +528,16 @@ bool IsConsumerTheOnlyNonRootUser(const HloInstruction& instr,
|
||||
if (user->opcode() == HloOpcode::kGetTupleElement) {
|
||||
// Skip GTE.
|
||||
return IsConsumerTheOnlyNonRootUser(*user, consumer);
|
||||
} else if (user == &consumer) {
|
||||
}
|
||||
if (user == &consumer) {
|
||||
// `user` is `consumer`.
|
||||
return true;
|
||||
} else if (user == user->parent()->root_instruction()) {
|
||||
// Consumed by ROOT is always fine, since it is impossible to create
|
||||
// cycles through ROOT.
|
||||
return true;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
if (user == user->parent()->root_instruction()) {
|
||||
// Consumed by ROOT.
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
});
|
||||
}
|
||||
|
||||
|
@ -73,7 +73,7 @@ bool CreatesNestedLoop(const HloInstruction& producer,
|
||||
|
||||
// Returns the instruction that determines the emitter used for lowering,
|
||||
// sometimes referred to as "the real hero".
|
||||
const HloInstruction* GetMajorNodeForMultiOutputFusion(
|
||||
const HloInstruction* GetRealHeroForMultiOutputFusion(
|
||||
const HloInstruction& instr);
|
||||
|
||||
// Whether instruction shapes are compatible for multi-output fusion, i.e.
|
||||
|
@ -33,8 +33,8 @@ namespace {
|
||||
|
||||
// Gets the representative input shape of the multi-output fusion.
|
||||
Shape GetInputShapeForMultiOutputFusion(const HloInstruction& instr) {
|
||||
// Get the major node used in the emitter.
|
||||
const HloInstruction* real_hero = GetMajorNodeForMultiOutputFusion(instr);
|
||||
// Get the HLO that determines the emitter used for lowering.
|
||||
const HloInstruction* real_hero = GetRealHeroForMultiOutputFusion(instr);
|
||||
if (real_hero->operands().empty()) {
|
||||
// Simply return an empty shape if the representative node has no input
|
||||
// operands.
|
||||
@ -118,15 +118,13 @@ StatusOr<bool> HorizontalInputFusionImpl::Run() {
|
||||
HloInstruction* fused = candidates[j];
|
||||
if (ShapesCompatibleForMultiOutputFusion(*fusion_anchor, *fused) &&
|
||||
!FusionWouldBeTooLarge(*fusion_anchor, *fused)) {
|
||||
VLOG(3) << absl::StrCat("Fuse ", fused->ToString(), " into ",
|
||||
fusion_anchor->ToString());
|
||||
VLOG(3) << "Fuse " << fused->ToString() << " into " << fusion_anchor->ToString();
|
||||
fusion_anchor->MergeFusionInstructionIntoMultiOutput(fused);
|
||||
changed = true;
|
||||
} else {
|
||||
// Update the `fusion_anchor_id` since `fused` is either not
|
||||
// compatible or not beneficial to be fused with current fusion anchor.
|
||||
VLOG(3) << absl::StrCat(j - fusion_anchor_id - 1,
|
||||
" instructions are fused");
|
||||
VLOG(3) << j - fusion_anchor_id - 1 << " instructions are fused.";
|
||||
fusion_anchor_id = j;
|
||||
}
|
||||
}
|
||||
|
@ -132,7 +132,7 @@ TEST_F(HorizontalInputFusionTest, ManyInputFusions) {
|
||||
// Verify that horizontal fusion is kicked in. Check that there are multiple
|
||||
// `reduce` instructions fused into the same fusion. 6 is just a randomly
|
||||
// picked number as we don't exactly know how large the fusion will be
|
||||
// created.
|
||||
// created due to the `FusionWouldBeTooLarge` constraint.
|
||||
CompileAndVerifyIr(module->Clone(),
|
||||
R"(CHECK: reduce-group-6)",
|
||||
/*match_optimized_ir=*/false);
|
||||
|
Loading…
x
Reference in New Issue
Block a user