Merge pull request #43964 from trentlo:horizontal_input_fusion_again

PiperOrigin-RevId: 337261311
Change-Id: I15498bba7ba9b77a2abf7001c3fe519408ee975c
This commit is contained in:
TensorFlower Gardener 2020-10-15 01:40:44 -07:00
commit bc8f385f4e
8 changed files with 533 additions and 42 deletions

View File

@ -1187,6 +1187,7 @@ cc_library(
":gpu_layout_assignment",
":gpu_sanitize_constant_names",
":gpu_scatter_expander",
":horizontal_input_fusion",
":horizontal_loop_fusion",
":instruction_fusion",
":ir_emission_utils",
@ -1770,6 +1771,7 @@ cc_library(
srcs = ["horizontal_loop_fusion.cc"],
hdrs = ["horizontal_loop_fusion.h"],
deps = [
":gpu_fusible",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_creation_utils",
@ -1806,6 +1808,45 @@ tf_cc_test(
],
)
cc_library(
name = "horizontal_input_fusion",
srcs = ["horizontal_input_fusion.cc"],
hdrs = ["horizontal_input_fusion.h"],
deps = [
":gpu_fusible",
":ir_emission_utils",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_creation_utils",
"//tensorflow/compiler/xla/service:hlo_pass",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
],
)
tf_cc_test(
name = "horizontal_input_fusion_test",
srcs = ["horizontal_input_fusion_test.cc"],
tags = tf_cuda_tests_tags(),
deps = [
":horizontal_input_fusion",
":multi_output_fusion",
"//tensorflow/compiler/jit:xla_gpu_jit",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla/service:hlo_matchers",
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/service:hlo_pass_pipeline",
"//tensorflow/compiler/xla/service/gpu/tests:gpu_codegen_test",
"//tensorflow/compiler/xla/tests:filecheck",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
],
)
cc_library(
name = "reduction_degenerate_dim_remover",
srcs = ["reduction_degenerate_dim_remover.cc"],

View File

@ -59,6 +59,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_sanitize_constant_names.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_scatter_expander.h"
#include "tensorflow/compiler/xla/service/gpu/horizontal_input_fusion.h"
#include "tensorflow/compiler/xla/service/gpu/horizontal_loop_fusion.h"
#include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
@ -306,6 +307,7 @@ Status GpuCompiler::OptimizeHloModule(
HloPassPipeline horizontal_fusion("horizontal_fusion");
horizontal_fusion.AddPass<GpuHorizontalLoopFusion>();
horizontal_fusion.AddPass<GpuHorizontalInputFusion>();
horizontal_fusion.AddPass<HloCSE>(/*is_layout_sensitive=*/true,
/*only_fusion_computations=*/true);
horizontal_fusion.AddPass<HloDCE>();

View File

@ -143,29 +143,27 @@ bool IsInputFusibleReduction(const HloInstruction& instr) {
IsReductionFromOrToContiguousDimensions(instr);
}
const HloInstruction* GetRealHeroForMultiOutputFusion(
const HloInstruction& instr) {
if (instr.opcode() != HloOpcode::kFusion) {
return &instr;
}
auto fused_expression_root = instr.fused_expression_root();
if (!instr.IsMultiOutputFusion()) {
return fused_expression_root;
}
// If possible, we want to pick a reduction-from-or-to-contiguous-dims
// operand of the fusion root, because it has the most constraints.
for (const auto* inst : fused_expression_root->operands()) {
if (IsReductionFromOrToContiguousDimensions(*inst)) {
return inst;
}
}
return fused_expression_root->operands()[0];
}
bool ShapesCompatibleForMultiOutputFusion(const HloInstruction& instr1,
const HloInstruction& instr2) {
// Returns the instructions that determines the emitter used for lowering,
// sometimes referred to as "the real hero".
auto get_real_hero =
[&](const HloInstruction* instr) -> const HloInstruction* {
if (instr->opcode() != HloOpcode::kFusion) {
return instr;
}
auto fused_expression_root = instr->fused_expression_root();
if (!instr->IsMultiOutputFusion()) {
return fused_expression_root;
}
// If possible, we want to pick a reduction-to-vector operand of the
// fusion root, because it has the most constraints.
for (const auto* inst : fused_expression_root->operands()) {
if (IsReductionFromOrToContiguousDimensions(*inst)) {
return inst;
}
}
return fused_expression_root->operands()[0];
};
// Multi-output fusion kernels share a common parallel loop. The loop
// dimensions are determined by instruction shapes.
auto get_loop_shape = [&](const HloInstruction* element_instr) {
@ -181,8 +179,8 @@ bool ShapesCompatibleForMultiOutputFusion(const HloInstruction& instr1,
// root ops should have equal output shapes. An exception are
// reduction-to-vector ops. Here the input shapes of the reduction (first
// operand shape) and the reduction dimensions need to match.
auto* instr_1 = get_real_hero(&instr1);
auto* instr_2 = get_real_hero(&instr2);
auto* instr_1 = GetRealHeroForMultiOutputFusion(instr1);
auto* instr_2 = GetRealHeroForMultiOutputFusion(instr2);
if (IsReductionFromOrToContiguousDimensions(*instr_1) &&
IsReductionFromOrToContiguousDimensions(*instr_2) &&
!AreFusedReductionOutputsConsistent({instr_1, instr_2}, instr_1)) {
@ -524,5 +522,24 @@ HloInstruction::FusionKind ChooseFusionKind(const HloInstruction& /*producer*/,
: HloInstruction::FusionKind::kLoop;
}
bool IsConsumerTheOnlyNonRootUser(const HloInstruction& instr,
const HloInstruction& consumer) {
return absl::c_all_of(instr.users(), [&](const HloInstruction* user) {
if (user->opcode() == HloOpcode::kGetTupleElement) {
// Skip GTE.
return IsConsumerTheOnlyNonRootUser(*user, consumer);
}
if (user == &consumer) {
// `user` is `consumer`.
return true;
}
if (user == user->parent()->root_instruction()) {
// Consumed by ROOT.
return true;
}
return false;
});
}
} // namespace gpu
} // namespace xla

View File

@ -71,6 +71,11 @@ bool FusionWouldBeTooLarge(const HloInstruction& instr1,
bool CreatesNestedLoop(const HloInstruction& producer,
const HloInstruction& consumer);
// Returns the instruction that determines the emitter used for lowering,
// sometimes referred to as "the real hero".
const HloInstruction* GetRealHeroForMultiOutputFusion(
const HloInstruction& instr);
// Whether instruction shapes are compatible for multi-output fusion, i.e.
// whether the emitters support lowering the resulting fusion.
// This function works for both, sibling and producer-consumer multi-output
@ -100,6 +105,10 @@ bool IsFusibleAsMultiOutputFusionRoot(const HloInstruction& instr);
HloInstruction::FusionKind ChooseFusionKind(const HloInstruction& producer,
const HloInstruction& consumer);
// Returns whether `consumer` is the only non-root user of `instr`.
bool IsConsumerTheOnlyNonRootUser(const HloInstruction& instr,
const HloInstruction& consumer);
} // namespace gpu
} // namespace xla

View File

@ -0,0 +1,167 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/gpu/horizontal_input_fusion.h"
#include <algorithm>
#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_cat.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_fusible.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
#include "tensorflow/core/platform/errors.h"
namespace xla {
namespace gpu {
namespace {
// Gets the representative input shape of the multi-output fusion.
Shape GetInputShapeForMultiOutputFusion(const HloInstruction& instr) {
// Get the HLO that determines the emitter used for lowering.
const HloInstruction* real_hero = GetRealHeroForMultiOutputFusion(instr);
if (real_hero->operands().empty()) {
// Simply return an empty shape if the representative node has no input
// operands.
return Shape();
} else {
return real_hero->operand(0)->shape();
}
}
class HorizontalInputFusionImpl {
public:
explicit HorizontalInputFusionImpl(HloComputation* computation)
: computation_(computation) {}
~HorizontalInputFusionImpl() {}
StatusOr<bool> Run();
private:
HloComputation* computation_;
}; // HorizontalInputFusionImpl
// Compares one-by-one the dimensions of `shape_a` and `shape_b` from left to
// right.
bool CompareShapeDimsFromLeftToRight(const Shape& shape_a,
const Shape& shape_b) {
if (shape_a.rank() != shape_b.rank()) {
return shape_a.rank() < shape_b.rank();
}
auto dims_a = shape_a.dimensions();
auto dims_b = shape_b.dimensions();
for (size_t i = 0; i < dims_a.size(); ++i) {
if (dims_a[i] != dims_b[i]) {
return dims_a[i] < dims_b[i];
}
}
return true;
}
std::vector<HloInstruction*> FindAndSortFusionCandidates(
HloInstruction* consumer) {
absl::flat_hash_set<HloInstruction*> fusion_instr_set;
for (auto opnd : consumer->operands()) {
HloInstruction* predecessor = opnd->LatestNonGteAncestor();
// Find out the input fusion instructions whose only consumer is `consumer`.
// This guarantees that fusing these candidates will never create cycles, as
// there is no back edge.
if (IsReduceInputFusion(*predecessor) &&
IsConsumerTheOnlyNonRootUser(*predecessor, *consumer)) {
fusion_instr_set.insert(predecessor);
}
}
std::vector<HloInstruction*> fusion_instrs;
fusion_instrs.insert(fusion_instrs.end(), fusion_instr_set.begin(),
fusion_instr_set.end());
std::sort(fusion_instrs.begin(), fusion_instrs.end(),
[&](const HloInstruction* a, const HloInstruction* b) {
Shape shape_a = GetInputShapeForMultiOutputFusion(*a);
Shape shape_b = GetInputShapeForMultiOutputFusion(*b);
if (!ShapeUtil::EqualIgnoringElementType(shape_a, shape_b)) {
// Sort shapes according to dimensions, so that the same input
// shapes will be placed adjacent each other.
return CompareShapeDimsFromLeftToRight(shape_a, shape_b);
}
// Sort `fusion_instrs` according to instruction counts, because
// we'd like to fuse together computations of similar sizes.
return a->fused_instruction_count() <
b->fused_instruction_count();
});
return fusion_instrs;
}
StatusOr<bool> HorizontalInputFusionImpl::Run() {
bool changed = false;
XLA_VLOG_LINES(3, computation_->ToString());
// Using def-to-use order is sound since we do not modify users.
std::vector<HloInstruction*> def_to_use_order =
computation_->MakeInstructionPostOrder();
for (auto consumer : def_to_use_order) {
auto candidates = FindAndSortFusionCandidates(consumer);
if (candidates.empty()) {
continue;
}
size_t fusion_anchor_id = 0;
for (size_t j = 1; j < candidates.size(); ++j) {
HloInstruction* fusion_anchor = candidates[fusion_anchor_id];
HloInstruction* fused = candidates[j];
if (ShapesCompatibleForMultiOutputFusion(*fusion_anchor, *fused) &&
!FusionWouldBeTooLarge(*fusion_anchor, *fused)) {
VLOG(3) << "Fuse " << fused->ToString() << " into "
<< fusion_anchor->ToString();
fusion_anchor->MergeFusionInstructionIntoMultiOutput(fused);
changed = true;
} else {
// Update the `fusion_anchor_id` since `fused` is either not
// compatible or not beneficial to be fused with current fusion anchor.
VLOG(3) << j - fusion_anchor_id - 1 << " instructions are fused.";
fusion_anchor_id = j;
}
}
}
return changed;
}
} // namespace
StatusOr<bool> GpuHorizontalInputFusion::RunOnComputation(
HloComputation* computation) {
HorizontalInputFusionImpl horizontal_fusion_impl(computation);
return horizontal_fusion_impl.Run();
}
StatusOr<bool> GpuHorizontalInputFusion::Run(HloModule* module) {
bool changed = false;
VLOG(2) << "Run horizontal input fusion.";
for (auto* comp : module->MakeNonfusionComputations()) {
TF_ASSIGN_OR_RETURN(changed, RunOnComputation(comp));
}
return changed;
}
} // namespace gpu
} // namespace xla

View File

@ -0,0 +1,57 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_HORIZONTAL_INPUT_FUSION_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_HORIZONTAL_INPUT_FUSION_H_
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
#include "tensorflow/core/platform/macros.h"
namespace xla {
namespace gpu {
// This optimization pass horizontally fuses kInput fusions to both reduce the
// kernel launch overhead and increase parallelism degree. See
// GpuHorizontalFusion for general description and motivation about horizontal
// fusion. GpuHorizontalFusion deals with kLoop fusions while this pass deals
// with kInput fusions.
//
// Following GpuHorizontalFusion, a simple yet effective heuristic is used
// to search the fusion candidates while avoiding creating cycles. That is,
// we simply search for fusion candidates by looking for instructions whose
// outputs are all consumed by the same instruction. This catches the typical
// target cases; often, the candidate instructions are just consumed by the
// ROOT tuple of the entry computation.
class GpuHorizontalInputFusion : public HloModulePass {
public:
GpuHorizontalInputFusion() {}
absl::string_view name() const override {
return "gpu_horizontal_input_fusion";
}
StatusOr<bool> Run(HloModule* module) override;
private:
StatusOr<bool> RunOnComputation(HloComputation*);
};
} // namespace gpu
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_HORIZONTAL_INPUT_FUSION_H_

View File

@ -0,0 +1,216 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/gpu/horizontal_input_fusion.h"
#include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/filecheck.h"
namespace xla {
namespace gpu {
namespace {
namespace op = xla::testing::opcode_matchers;
class HorizontalInputFusionTest : public GpuCodegenTest {};
TEST_F(HorizontalInputFusionTest, BasicTest) {
auto module = ParseAndReturnVerifiedModule(R"(
HloModule BasicTest
%add_f16 {
%x = f16[] parameter(0)
%y = f16[] parameter(1)
ROOT %add = f16[] add(%x, %y)
}
fused_computation.1 {
arg.1 = f16[1024]{0} parameter(0)
constant0 = f16[] constant(0)
ROOT reduce1 = f16[] reduce(arg.1, constant0), dimensions={0}, to_apply=%add_f16
}
fused_computation.2 {
arg.1 = f16[1024]{0} parameter(0)
constant0 = f16[] constant(0)
ROOT reduce1 = f16[] reduce(arg.1, constant0), dimensions={0}, to_apply=%add_f16
}
ENTRY entry_computation {
arg.1 = f16[1024]{0} parameter(0)
arg.2 = f16[1024]{0} parameter(1)
fusion.1 = f16[] fusion(arg.1), kind=kInput, calls=fused_computation.1
fusion.2 = f16[] fusion(arg.2), kind=kInput, calls=fused_computation.2
ROOT tuple.1 = (f16[], f16[]) tuple(fusion.1, fusion.2)
}
)")
.ValueOrDie();
EXPECT_TRUE(GpuHorizontalInputFusion().Run(module.get()).ValueOrDie());
const HloInstruction* entry_root =
module->entry_computation()->root_instruction();
EXPECT_THAT(entry_root, op::Tuple((op::GetTupleElement(op::Fusion())),
(op::GetTupleElement(op::Fusion()))));
const HloInstruction* fusion = entry_root->operand(0)->operand(0);
ASSERT_TRUE(fusion->IsMultiOutputFusion());
EXPECT_THAT(fusion->fused_expression_root(),
op::Tuple(op::Reduce(), op::Reduce()));
}
TEST_F(HorizontalInputFusionTest, ManyInputFusions) {
auto module = CreateNewVerifiedModule();
HloComputation* reduce_computation;
{
auto embedded_builder = HloComputation::Builder("add");
auto lhs = embedded_builder.AddInstruction(HloInstruction::CreateParameter(
0, ShapeUtil::MakeShape(F32, {}), "lhs"));
auto rhs = embedded_builder.AddInstruction(HloInstruction::CreateParameter(
1, ShapeUtil::MakeShape(F32, {}), "rhs"));
embedded_builder.AddInstruction(
HloInstruction::CreateBinary(lhs->shape(), HloOpcode::kAdd, lhs, rhs));
reduce_computation =
module->AddEmbeddedComputation(embedded_builder.Build());
}
HloComputation::Builder builder(TestName());
std::vector<HloInstruction*> var_outs;
auto input_shape = ShapeUtil::MakeShape(F32, {1024, 1024});
auto output_shape = ShapeUtil::MakeShape(F32, {1024});
for (int64 i = 0; i < 130; ++i) {
// %fused_computation.3 (param_0: f32[1024,1024], param_1: f32[]) ->
// f32[1024] {
// %param_0 = f32[1024,1024]{1,0} parameter(0)
// %param_1 = f32[] parameter(1)
// %broadcast = f32[1024,1024]{1,0} broadcast(f32[] %param_1),
// dimensions={}
// %multiply = f32[1024,1024]{1,0}
// multiply(f32[1024,1024]{1,0} %param_0, f32[1024,1024]{1,0}
// %broadcast)
// %constant0 = f32[] constant(0)
// ROOT %reduce = f32[1024]{0}
// reduce(f32[1024,1024]{1,0} %multiply, f32[] %constant0),
// dimensions={1}, to_apply=%add
// }
HloInstruction* param_var_in = builder.AddInstruction(
HloInstruction::CreateParameter(i * 2 + 0, input_shape, "var.in"));
HloInstruction* param_alpha =
builder.AddInstruction(HloInstruction::CreateParameter(
i * 2 + 1, ShapeUtil::MakeShape(F32, {}), "alpha"));
auto alpha_broadcasted = builder.AddInstruction(
HloInstruction::CreateBroadcast(input_shape, param_alpha, {}));
auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
input_shape, HloOpcode::kMultiply, param_var_in, alpha_broadcasted));
HloInstruction* const0 = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0)));
auto reduce = builder.AddInstruction(HloInstruction::CreateReduce(
output_shape, mul, const0, {1}, reduce_computation));
var_outs.push_back(reduce);
}
builder.AddInstruction(HloInstruction::CreateTuple(var_outs));
module->AddEntryComputation(builder.Build());
// Verify that horizontal fusion is kicked in. Check that there are multiple
// `reduce` instructions fused into the same fusion. 6 is just a randomly
// picked number as we don't exactly know how large the fusion will be
// created due to the `FusionWouldBeTooLarge` constraint.
CompileAndVerifyIr(module->Clone(), R"(CHECK: reduce-group-6)",
/*match_optimized_ir=*/false);
// Testing with the entire gpu optimization pipeline.
EXPECT_TRUE(RunAndCompare(std::move(module), ErrorSpec{1e-5, 1e-5}));
}
TEST_F(HorizontalInputFusionTest, MultiOutputFusionTest) {
// This tests the below pattern. One known issue is that gtes (to fusions) can
// be removed after their producer fusions are merged. In the below case, gte2
// and gte6 will be gone if Fusion2 is fused into Fusion1.
//
// Fusion1 Fusion2
// | | | |
// | gte1 gte2 |
// | | | |
// | Fusion3 |
// | | | |
// gte3 gte4 gte5 gte6
// \ | | /
// =====ROOT=====
//
auto module = ParseAndReturnVerifiedModule(R"(
HloModule MultiOutputFusionTest
%add_f16 {
%x = f16[] parameter(0)
%y = f16[] parameter(1)
ROOT %add = f16[] add(%x, %y)
}
fused_computation.1 {
arg.1 = f16[1024]{0} parameter(0)
constant0 = f16[] constant(0)
reduce.1 = f16[] reduce(arg.1, constant0), dimensions={0}, to_apply=%add_f16
add.0 = f16[1024] add(arg.1, arg.1)
ROOT tuple.1 = (f16[], f16[1024]) tuple(reduce.1, add.0)
}
fused_computation.2 {
arg.1 = f16[1024]{0} parameter(0)
constant0 = f16[] constant(0)
reduce.1 = f16[] reduce(arg.1, constant0), dimensions={0}, to_apply=%add_f16
add.0 = f16[1024] add(arg.1, arg.1)
ROOT tuple.1 = (f16[], f16[1024]) tuple(reduce.1, add.0)
}
fused_computation.3 {
arg.0 = f16[1024]{0} parameter(0)
arg.1 = f16[1024]{0} parameter(1)
add.0 = f16[1024] add(arg.0, arg.1)
mul.0 = f16[1024] multiply(arg.0, arg.1)
ROOT tuple.1 = (f16[1024], f16[1024]) tuple(add.0, mul.0)
}
ENTRY entry_computation {
arg.1 = f16[1024]{0} parameter(0)
arg.2 = f16[1024]{0} parameter(1)
fusion.1 = (f16[],f16[1024]) fusion(arg.1), kind=kInput, calls=fused_computation.1
fusion.2 = (f16[],f16[1024]) fusion(arg.2), kind=kInput, calls=fused_computation.2
gte.3 = f16[] get-tuple-element(fusion.1), index=0
gte.1 = f16[1024]{0} get-tuple-element(fusion.1), index=1
gte.2 = f16[1024]{0} get-tuple-element(fusion.2), index=1
gte.6 = f16[] get-tuple-element(fusion.2), index=0
fusion.3 = (f16[1024],f16[1024]) fusion(gte.1, gte.2),
kind=kLoop, calls=fused_computation.3
gte.4 = f16[1024] get-tuple-element(fusion.3), index=0
gte.5 = f16[1024]{0} get-tuple-element(fusion.3), index=1
ROOT tuple.1 = (f16[], f16[1024]{0}, f16[], f16[1024]{0})
tuple(gte.3, gte.4, gte.5, gte.6)
}
)")
.ValueOrDie();
EXPECT_TRUE(GpuHorizontalInputFusion().Run(module.get()).ValueOrDie());
}
} // namespace
} // namespace gpu
} // namespace xla

View File

@ -20,6 +20,7 @@ limitations under the License.
#include "absl/container/flat_hash_set.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_fusible.h"
#include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/util/env_var.h"
@ -137,25 +138,6 @@ bool IsFusionSupported(const HloInstruction& instr) {
return true;
}
bool IsConsumerTheOnlyNonRootUser(const HloInstruction& instr,
const HloInstruction& consumer) {
return absl::c_all_of(instr.users(), [&](const HloInstruction* user) {
if (user->opcode() == HloOpcode::kGetTupleElement) {
// Skip GTE.
return IsConsumerTheOnlyNonRootUser(*user, consumer);
} else if (user == &consumer) {
// `user` is `consumer`.
return true;
} else if (user == user->parent()->root_instruction()) {
// Consumed by ROOT is always fine, since it is impossible to create
// cycles through ROOT.
return true;
} else {
return false;
}
});
}
// Returns whether `instr` is a profitable candidate to be horizontally fused.
// Since the primary benefit of horizontal fusion comes from reducing the
// kernel launch overhead, we want to exclude the instructions with