[XLA/GPU] Address review comments.
This commit is contained in:
parent
6774af43c2
commit
ea0b5fa33f
@ -1814,10 +1814,10 @@ cc_library(
|
|||||||
deps = [
|
deps = [
|
||||||
":gpu_fusible",
|
":gpu_fusible",
|
||||||
":ir_emission_utils",
|
":ir_emission_utils",
|
||||||
|
"//tensorflow/compiler/xla:shape_util",
|
||||||
"//tensorflow/compiler/xla/service:hlo",
|
"//tensorflow/compiler/xla/service:hlo",
|
||||||
"//tensorflow/compiler/xla/service:hlo_creation_utils",
|
"//tensorflow/compiler/xla/service:hlo_creation_utils",
|
||||||
"//tensorflow/compiler/xla/service:hlo_pass",
|
"//tensorflow/compiler/xla/service:hlo_pass",
|
||||||
"//tensorflow/compiler/xla:shape_util",
|
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"//tensorflow/core:lib_internal",
|
"//tensorflow/core:lib_internal",
|
||||||
"@com_google_absl//absl/container:flat_hash_set",
|
"@com_google_absl//absl/container:flat_hash_set",
|
||||||
@ -1827,22 +1827,17 @@ cc_library(
|
|||||||
tf_cc_test(
|
tf_cc_test(
|
||||||
name = "horizontal_input_fusion_test",
|
name = "horizontal_input_fusion_test",
|
||||||
srcs = ["horizontal_input_fusion_test.cc"],
|
srcs = ["horizontal_input_fusion_test.cc"],
|
||||||
|
tags = tf_cuda_tests_tags(),
|
||||||
deps = [
|
deps = [
|
||||||
":fusion_merger",
|
|
||||||
":horizontal_input_fusion",
|
":horizontal_input_fusion",
|
||||||
":instruction_fusion",
|
|
||||||
":multi_output_fusion",
|
":multi_output_fusion",
|
||||||
"//tensorflow/compiler/jit:xla_gpu_jit",
|
|
||||||
"//tensorflow/compiler/xla:literal",
|
|
||||||
"//tensorflow/compiler/xla/service/gpu/tests:gpu_codegen_test",
|
|
||||||
"//tensorflow/compiler/xla/service:hlo_dce",
|
|
||||||
"//tensorflow/compiler/xla/service:hlo_matchers",
|
|
||||||
"//tensorflow/compiler/xla/service:hlo_parser",
|
|
||||||
"//tensorflow/compiler/xla/service:hlo_pass",
|
|
||||||
"//tensorflow/compiler/xla/service:hlo_pass_pipeline",
|
|
||||||
"//tensorflow/compiler/xla:shape_util",
|
"//tensorflow/compiler/xla:shape_util",
|
||||||
"//tensorflow/compiler/xla:test",
|
"//tensorflow/compiler/xla:test",
|
||||||
"//tensorflow/compiler/xla:test_helpers",
|
"//tensorflow/compiler/xla:test_helpers",
|
||||||
|
"//tensorflow/compiler/xla/service:hlo_matchers",
|
||||||
|
"//tensorflow/compiler/xla/service:hlo_parser",
|
||||||
|
"//tensorflow/compiler/xla/service:hlo_pass_pipeline",
|
||||||
|
"//tensorflow/compiler/xla/service/gpu/tests:gpu_codegen_test",
|
||||||
"//tensorflow/compiler/xla/tests:filecheck",
|
"//tensorflow/compiler/xla/tests:filecheck",
|
||||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
|
||||||
],
|
],
|
||||||
|
|||||||
@ -143,7 +143,7 @@ bool IsInputFusibleReduction(const HloInstruction& instr) {
|
|||||||
IsReductionFromOrToContiguousDimensions(instr);
|
IsReductionFromOrToContiguousDimensions(instr);
|
||||||
}
|
}
|
||||||
|
|
||||||
const HloInstruction* GetMajorNodeForMultiOutputFusion(
|
const HloInstruction* GetRealHeroForMultiOutputFusion(
|
||||||
const HloInstruction& instr) {
|
const HloInstruction& instr) {
|
||||||
if (instr.opcode() != HloOpcode::kFusion) {
|
if (instr.opcode() != HloOpcode::kFusion) {
|
||||||
return &instr;
|
return &instr;
|
||||||
@ -152,8 +152,8 @@ const HloInstruction* GetMajorNodeForMultiOutputFusion(
|
|||||||
if (!instr.IsMultiOutputFusion()) {
|
if (!instr.IsMultiOutputFusion()) {
|
||||||
return fused_expression_root;
|
return fused_expression_root;
|
||||||
}
|
}
|
||||||
// If possible, we want to pick a reduction-to-vector operand of the
|
// If possible, we want to pick a reduction-from-or-to-contiguous-dims
|
||||||
// fusion root, because it has the most constraints.
|
// operand of the fusion root, because it has the most constraints.
|
||||||
for (const auto* inst : fused_expression_root->operands()) {
|
for (const auto* inst : fused_expression_root->operands()) {
|
||||||
if (IsReductionFromOrToContiguousDimensions(*inst)) {
|
if (IsReductionFromOrToContiguousDimensions(*inst)) {
|
||||||
return inst;
|
return inst;
|
||||||
@ -179,8 +179,8 @@ bool ShapesCompatibleForMultiOutputFusion(const HloInstruction& instr1,
|
|||||||
// root ops should have equal output shapes. An exception are
|
// root ops should have equal output shapes. An exception are
|
||||||
// reduction-to-vector ops. Here the input shapes of the reduction (first
|
// reduction-to-vector ops. Here the input shapes of the reduction (first
|
||||||
// operand shape) and the reduction dimensions need to match.
|
// operand shape) and the reduction dimensions need to match.
|
||||||
auto* instr_1 = GetMajorNodeForMultiOutputFusion(instr1);
|
auto* instr_1 = GetRealHeroForMultiOutputFusion(instr1);
|
||||||
auto* instr_2 = GetMajorNodeForMultiOutputFusion(instr2);
|
auto* instr_2 = GetRealHeroForMultiOutputFusion(instr2);
|
||||||
if (IsReductionFromOrToContiguousDimensions(*instr_1) &&
|
if (IsReductionFromOrToContiguousDimensions(*instr_1) &&
|
||||||
IsReductionFromOrToContiguousDimensions(*instr_2) &&
|
IsReductionFromOrToContiguousDimensions(*instr_2) &&
|
||||||
!AreFusedReductionOutputsConsistent({instr_1, instr_2}, instr_1)) {
|
!AreFusedReductionOutputsConsistent({instr_1, instr_2}, instr_1)) {
|
||||||
@ -528,16 +528,16 @@ bool IsConsumerTheOnlyNonRootUser(const HloInstruction& instr,
|
|||||||
if (user->opcode() == HloOpcode::kGetTupleElement) {
|
if (user->opcode() == HloOpcode::kGetTupleElement) {
|
||||||
// Skip GTE.
|
// Skip GTE.
|
||||||
return IsConsumerTheOnlyNonRootUser(*user, consumer);
|
return IsConsumerTheOnlyNonRootUser(*user, consumer);
|
||||||
} else if (user == &consumer) {
|
}
|
||||||
|
if (user == &consumer) {
|
||||||
// `user` is `consumer`.
|
// `user` is `consumer`.
|
||||||
return true;
|
return true;
|
||||||
} else if (user == user->parent()->root_instruction()) {
|
|
||||||
// Consumed by ROOT is always fine, since it is impossible to create
|
|
||||||
// cycles through ROOT.
|
|
||||||
return true;
|
|
||||||
} else {
|
|
||||||
return false;
|
|
||||||
}
|
}
|
||||||
|
if (user == user->parent()->root_instruction()) {
|
||||||
|
// Consumed by ROOT.
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -73,7 +73,7 @@ bool CreatesNestedLoop(const HloInstruction& producer,
|
|||||||
|
|
||||||
// Returns the instruction that determines the emitter used for lowering,
|
// Returns the instruction that determines the emitter used for lowering,
|
||||||
// sometimes referred to as "the real hero".
|
// sometimes referred to as "the real hero".
|
||||||
const HloInstruction* GetMajorNodeForMultiOutputFusion(
|
const HloInstruction* GetRealHeroForMultiOutputFusion(
|
||||||
const HloInstruction& instr);
|
const HloInstruction& instr);
|
||||||
|
|
||||||
// Whether instruction shapes are compatible for multi-output fusion, i.e.
|
// Whether instruction shapes are compatible for multi-output fusion, i.e.
|
||||||
|
|||||||
@ -33,8 +33,8 @@ namespace {
|
|||||||
|
|
||||||
// Gets the representative input shape of the multi-output fusion.
|
// Gets the representative input shape of the multi-output fusion.
|
||||||
Shape GetInputShapeForMultiOutputFusion(const HloInstruction& instr) {
|
Shape GetInputShapeForMultiOutputFusion(const HloInstruction& instr) {
|
||||||
// Get the major node used in the emitter.
|
// Get the HLO that determines the emitter used for lowering.
|
||||||
const HloInstruction* real_hero = GetMajorNodeForMultiOutputFusion(instr);
|
const HloInstruction* real_hero = GetRealHeroForMultiOutputFusion(instr);
|
||||||
if (real_hero->operands().empty()) {
|
if (real_hero->operands().empty()) {
|
||||||
// Simply return an empty shape if the representative node has no input
|
// Simply return an empty shape if the representative node has no input
|
||||||
// operands.
|
// operands.
|
||||||
@ -118,15 +118,13 @@ StatusOr<bool> HorizontalInputFusionImpl::Run() {
|
|||||||
HloInstruction* fused = candidates[j];
|
HloInstruction* fused = candidates[j];
|
||||||
if (ShapesCompatibleForMultiOutputFusion(*fusion_anchor, *fused) &&
|
if (ShapesCompatibleForMultiOutputFusion(*fusion_anchor, *fused) &&
|
||||||
!FusionWouldBeTooLarge(*fusion_anchor, *fused)) {
|
!FusionWouldBeTooLarge(*fusion_anchor, *fused)) {
|
||||||
VLOG(3) << absl::StrCat("Fuse ", fused->ToString(), " into ",
|
VLOG(3) << "Fuse " << fused->ToString() << " into " << fusion_anchor->ToString();
|
||||||
fusion_anchor->ToString());
|
|
||||||
fusion_anchor->MergeFusionInstructionIntoMultiOutput(fused);
|
fusion_anchor->MergeFusionInstructionIntoMultiOutput(fused);
|
||||||
changed = true;
|
changed = true;
|
||||||
} else {
|
} else {
|
||||||
// Update the `fusion_anchor_id` since `fused` is either not
|
// Update the `fusion_anchor_id` since `fused` is either not
|
||||||
// compatible or not beneficial to be fused with current fusion anchor.
|
// compatible or not beneficial to be fused with current fusion anchor.
|
||||||
VLOG(3) << absl::StrCat(j - fusion_anchor_id - 1,
|
VLOG(3) << j - fusion_anchor_id - 1 << " instructions are fused.";
|
||||||
" instructions are fused");
|
|
||||||
fusion_anchor_id = j;
|
fusion_anchor_id = j;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -132,7 +132,7 @@ TEST_F(HorizontalInputFusionTest, ManyInputFusions) {
|
|||||||
// Verify that horizontal fusion is kicked in. Check that there are multiple
|
// Verify that horizontal fusion is kicked in. Check that there are multiple
|
||||||
// `reduce` instructions fused into the same fusion. 6 is just a randomly
|
// `reduce` instructions fused into the same fusion. 6 is just a randomly
|
||||||
// picked number as we don't exactly know how large the fusion will be
|
// picked number as we don't exactly know how large the fusion will be
|
||||||
// created.
|
// created due to the `FusionWouldBeTooLarge` constraint.
|
||||||
CompileAndVerifyIr(module->Clone(),
|
CompileAndVerifyIr(module->Clone(),
|
||||||
R"(CHECK: reduce-group-6)",
|
R"(CHECK: reduce-group-6)",
|
||||||
/*match_optimized_ir=*/false);
|
/*match_optimized_ir=*/false);
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user