Merge pull request #43964 from trentlo:horizontal_input_fusion_again
PiperOrigin-RevId: 337261311
Change-Id: I15498bba7ba9b77a2abf7001c3fe519408ee975c

commit bc8f385f4e
tensorflow/compiler/xla/service/gpu/BUILD

@@ -1187,6 +1187,7 @@ cc_library(
         ":gpu_layout_assignment",
         ":gpu_sanitize_constant_names",
         ":gpu_scatter_expander",
+        ":horizontal_input_fusion",
         ":horizontal_loop_fusion",
         ":instruction_fusion",
         ":ir_emission_utils",
@@ -1770,6 +1771,7 @@ cc_library(
     srcs = ["horizontal_loop_fusion.cc"],
     hdrs = ["horizontal_loop_fusion.h"],
     deps = [
+        ":gpu_fusible",
         "//tensorflow/compiler/xla:shape_util",
         "//tensorflow/compiler/xla/service:hlo",
         "//tensorflow/compiler/xla/service:hlo_creation_utils",
@@ -1806,6 +1808,45 @@ tf_cc_test(
     ],
 )
 
+cc_library(
+    name = "horizontal_input_fusion",
+    srcs = ["horizontal_input_fusion.cc"],
+    hdrs = ["horizontal_input_fusion.h"],
+    deps = [
+        ":gpu_fusible",
+        ":ir_emission_utils",
+        "//tensorflow/compiler/xla:shape_util",
+        "//tensorflow/compiler/xla/service:hlo",
+        "//tensorflow/compiler/xla/service:hlo_creation_utils",
+        "//tensorflow/compiler/xla/service:hlo_pass",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:lib_internal",
+        "@com_google_absl//absl/container:flat_hash_set",
+        "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/types:span",
+    ],
+)
+
+tf_cc_test(
+    name = "horizontal_input_fusion_test",
+    srcs = ["horizontal_input_fusion_test.cc"],
+    tags = tf_cuda_tests_tags(),
+    deps = [
+        ":horizontal_input_fusion",
+        ":multi_output_fusion",
+        "//tensorflow/compiler/jit:xla_gpu_jit",
+        "//tensorflow/compiler/xla:shape_util",
+        "//tensorflow/compiler/xla:test",
+        "//tensorflow/compiler/xla:test_helpers",
+        "//tensorflow/compiler/xla/service:hlo_matchers",
+        "//tensorflow/compiler/xla/service:hlo_parser",
+        "//tensorflow/compiler/xla/service:hlo_pass_pipeline",
+        "//tensorflow/compiler/xla/service/gpu/tests:gpu_codegen_test",
+        "//tensorflow/compiler/xla/tests:filecheck",
+        "//tensorflow/compiler/xla/tests:xla_internal_test_main",
+    ],
+)
+
 cc_library(
     name = "reduction_degenerate_dim_remover",
     srcs = ["reduction_degenerate_dim_remover.cc"],
tensorflow/compiler/xla/service/gpu/gpu_compiler.cc

@@ -59,6 +59,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h"
 #include "tensorflow/compiler/xla/service/gpu/gpu_sanitize_constant_names.h"
 #include "tensorflow/compiler/xla/service/gpu/gpu_scatter_expander.h"
+#include "tensorflow/compiler/xla/service/gpu/horizontal_input_fusion.h"
 #include "tensorflow/compiler/xla/service/gpu/horizontal_loop_fusion.h"
 #include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h"
 #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
@@ -306,6 +307,7 @@ Status GpuCompiler::OptimizeHloModule(
 
   HloPassPipeline horizontal_fusion("horizontal_fusion");
   horizontal_fusion.AddPass<GpuHorizontalLoopFusion>();
+  horizontal_fusion.AddPass<GpuHorizontalInputFusion>();
   horizontal_fusion.AddPass<HloCSE>(/*is_layout_sensitive=*/true,
                                     /*only_fusion_computations=*/true);
   horizontal_fusion.AddPass<HloDCE>();
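For context, the sub-pipeline extended above can also be exercised in isolation. The following is a minimal sketch, assuming an `HloModule` already parsed and verified elsewhere (e.g. via `ParseAndReturnVerifiedModule` in a test); it mirrors the pipeline in the diff and is illustrative, not part of the commit:

  // Build and run only the horizontal-fusion sub-pipeline on `module`.
  HloPassPipeline horizontal_fusion("horizontal_fusion");
  horizontal_fusion.AddPass<GpuHorizontalLoopFusion>();   // Fuse kLoop fusions first.
  horizontal_fusion.AddPass<GpuHorizontalInputFusion>();  // Then kInput fusions.
  horizontal_fusion.AddPass<HloCSE>(/*is_layout_sensitive=*/true,
                                    /*only_fusion_computations=*/true);
  horizontal_fusion.AddPass<HloDCE>();  // Clean up instructions orphaned by CSE.
  TF_ASSIGN_OR_RETURN(bool changed, horizontal_fusion.Run(module.get()));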
tensorflow/compiler/xla/service/gpu/gpu_fusible.cc

@@ -143,29 +143,27 @@ bool IsInputFusibleReduction(const HloInstruction& instr) {
          IsReductionFromOrToContiguousDimensions(instr);
 }
 
+const HloInstruction* GetRealHeroForMultiOutputFusion(
+    const HloInstruction& instr) {
+  if (instr.opcode() != HloOpcode::kFusion) {
+    return &instr;
+  }
+  auto fused_expression_root = instr.fused_expression_root();
+  if (!instr.IsMultiOutputFusion()) {
+    return fused_expression_root;
+  }
+  // If possible, we want to pick a reduction-from-or-to-contiguous-dims
+  // operand of the fusion root, because it has the most constraints.
+  for (const auto* inst : fused_expression_root->operands()) {
+    if (IsReductionFromOrToContiguousDimensions(*inst)) {
+      return inst;
+    }
+  }
+  return fused_expression_root->operands()[0];
+}
+
 bool ShapesCompatibleForMultiOutputFusion(const HloInstruction& instr1,
                                           const HloInstruction& instr2) {
-  // Returns the instructions that determines the emitter used for lowering,
-  // sometimes referred to as "the real hero".
-  auto get_real_hero =
-      [&](const HloInstruction* instr) -> const HloInstruction* {
-    if (instr->opcode() != HloOpcode::kFusion) {
-      return instr;
-    }
-    auto fused_expression_root = instr->fused_expression_root();
-    if (!instr->IsMultiOutputFusion()) {
-      return fused_expression_root;
-    }
-    // If possible, we want to pick a reduction-to-vector operand of the
-    // fusion root, because it has the most constraints.
-    for (const auto* inst : fused_expression_root->operands()) {
-      if (IsReductionFromOrToContiguousDimensions(*inst)) {
-        return inst;
-      }
-    }
-    return fused_expression_root->operands()[0];
-  };
-
   // Multi-output fusion kernels share a common parallel loop. The loop
   // dimensions are determined by instruction shapes.
   auto get_loop_shape = [&](const HloInstruction* element_instr) {
@@ -181,8 +179,8 @@ bool ShapesCompatibleForMultiOutputFusion(const HloInstruction& instr1,
   // root ops should have equal output shapes. An exception are
   // reduction-to-vector ops. Here the input shapes of the reduction (first
   // operand shape) and the reduction dimensions need to match.
-  auto* instr_1 = get_real_hero(&instr1);
-  auto* instr_2 = get_real_hero(&instr2);
+  auto* instr_1 = GetRealHeroForMultiOutputFusion(instr1);
+  auto* instr_2 = GetRealHeroForMultiOutputFusion(instr2);
   if (IsReductionFromOrToContiguousDimensions(*instr_1) &&
       IsReductionFromOrToContiguousDimensions(*instr_2) &&
       !AreFusedReductionOutputsConsistent({instr_1, instr_2}, instr_1)) {
@@ -524,5 +522,24 @@ HloInstruction::FusionKind ChooseFusionKind(const HloInstruction& /*producer*/,
                  : HloInstruction::FusionKind::kLoop;
 }
 
+bool IsConsumerTheOnlyNonRootUser(const HloInstruction& instr,
+                                  const HloInstruction& consumer) {
+  return absl::c_all_of(instr.users(), [&](const HloInstruction* user) {
+    if (user->opcode() == HloOpcode::kGetTupleElement) {
+      // Skip GTE.
+      return IsConsumerTheOnlyNonRootUser(*user, consumer);
+    }
+    if (user == &consumer) {
+      // `user` is `consumer`.
+      return true;
+    }
+    if (user == user->parent()->root_instruction()) {
+      // Consumed by ROOT.
+      return true;
+    }
+    return false;
+  });
+}
+
 }  // namespace gpu
 }  // namespace xla
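To make the GTE-skipping recursion in IsConsumerTheOnlyNonRootUser concrete, here is a test-style sketch. The HLO text, the GpuFusibleTest fixture, and the FindInstruction helper (from HloTestBase) are assumptions for illustration, not part of this commit. Every user of fusion.1 is a get-tuple-element, and each GTE feeds either the consumer or the ROOT tuple, so the predicate holds:

  TEST_F(GpuFusibleTest, ConsumerOnlyNonRootUserSketch) {
    auto module = ParseAndReturnVerifiedModule(R"(
  HloModule Sketch

  fused_computation {
    arg = f16[8]{0} parameter(0)
    neg = f16[8] negate(arg)
    abs = f16[8] abs(arg)
    ROOT tuple = (f16[8], f16[8]) tuple(neg, abs)
  }

  ENTRY entry {
    arg = f16[8]{0} parameter(0)
    fusion.1 = (f16[8], f16[8]) fusion(arg), kind=kLoop, calls=fused_computation
    gte.0 = f16[8] get-tuple-element(fusion.1), index=0
    gte.1 = f16[8] get-tuple-element(fusion.1), index=1
    consumer = f16[8] add(gte.0, gte.0)
    ROOT root = (f16[8], f16[8]) tuple(consumer, gte.1)
  }
  )")
                      .ValueOrDie();
    const HloInstruction* fusion1 = FindInstruction(module.get(), "fusion.1");
    const HloInstruction* consumer = FindInstruction(module.get(), "consumer");
    // gte.0 reaches only `consumer`; gte.1 reaches only the ROOT tuple.
    EXPECT_TRUE(IsConsumerTheOnlyNonRootUser(*fusion1, *consumer));
  }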
tensorflow/compiler/xla/service/gpu/gpu_fusible.h

@@ -71,6 +71,11 @@ bool FusionWouldBeTooLarge(const HloInstruction& instr1,
 bool CreatesNestedLoop(const HloInstruction& producer,
                        const HloInstruction& consumer);
 
+// Returns the instruction that determines the emitter used for lowering,
+// sometimes referred to as "the real hero".
+const HloInstruction* GetRealHeroForMultiOutputFusion(
+    const HloInstruction& instr);
+
 // Whether instruction shapes are compatible for multi-output fusion, i.e.
 // whether the emitters support lowering the resulting fusion.
 // This function works for both sibling and producer-consumer multi-output
@@ -100,6 +105,10 @@ bool IsFusibleAsMultiOutputFusionRoot(const HloInstruction& instr);
 HloInstruction::FusionKind ChooseFusionKind(const HloInstruction& producer,
                                             const HloInstruction& consumer);
 
+// Returns whether `consumer` is the only non-root user of `instr`.
+bool IsConsumerTheOnlyNonRootUser(const HloInstruction& instr,
+                                  const HloInstruction& consumer);
+
 }  // namespace gpu
 }  // namespace xla
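The "real hero" declaration above is easiest to read with an example. A sketch under the same assumptions as the previous one (test-style harness, illustrative HLO): for a multi-output fusion whose root is tuple(reduce, multiply), the reduce is returned, because reductions place the most constraints on the emitter.

    auto module = ParseAndReturnVerifiedModule(R"(
  HloModule RealHeroSketch

  add_f32 {
    x = f32[] parameter(0)
    y = f32[] parameter(1)
    ROOT add = f32[] add(x, y)
  }

  fused_computation {
    arg = f32[128]{0} parameter(0)
    c0 = f32[] constant(0)
    reduce.0 = f32[] reduce(arg, c0), dimensions={0}, to_apply=add_f32
    mul.0 = f32[128] multiply(arg, arg)
    ROOT tuple.0 = (f32[], f32[128]) tuple(reduce.0, mul.0)
  }

  ENTRY entry {
    arg = f32[128]{0} parameter(0)
    ROOT fusion.0 = (f32[], f32[128]) fusion(arg), kind=kInput, calls=fused_computation
  }
  )")
                      .ValueOrDie();
    const HloInstruction* fusion = FindInstruction(module.get(), "fusion.0");
    // The reduce operand of the root tuple is preferred over the multiply.
    EXPECT_EQ(GetRealHeroForMultiOutputFusion(*fusion)->opcode(),
              HloOpcode::kReduce);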
tensorflow/compiler/xla/service/gpu/horizontal_input_fusion.cc (new file, 167 lines)
@@ -0,0 +1,167 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/horizontal_input_fusion.h"

#include <algorithm>

#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_cat.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_fusible.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
#include "tensorflow/core/platform/errors.h"

namespace xla {
namespace gpu {

namespace {

// Gets the representative input shape of the multi-output fusion.
Shape GetInputShapeForMultiOutputFusion(const HloInstruction& instr) {
  // Get the HLO that determines the emitter used for lowering.
  const HloInstruction* real_hero = GetRealHeroForMultiOutputFusion(instr);
  if (real_hero->operands().empty()) {
    // Simply return an empty shape if the representative node has no input
    // operands.
    return Shape();
  } else {
    return real_hero->operand(0)->shape();
  }
}

class HorizontalInputFusionImpl {
 public:
  explicit HorizontalInputFusionImpl(HloComputation* computation)
      : computation_(computation) {}

  ~HorizontalInputFusionImpl() {}

  StatusOr<bool> Run();

 private:
  HloComputation* computation_;
};  // HorizontalInputFusionImpl

// Compares the dimensions of `shape_a` and `shape_b` one by one, from left to
// right, after first ordering by rank.
bool CompareShapeDimsFromLeftToRight(const Shape& shape_a,
                                     const Shape& shape_b) {
  if (shape_a.rank() != shape_b.rank()) {
    return shape_a.rank() < shape_b.rank();
  }
  auto dims_a = shape_a.dimensions();
  auto dims_b = shape_b.dimensions();
  for (size_t i = 0; i < dims_a.size(); ++i) {
    if (dims_a[i] != dims_b[i]) {
      return dims_a[i] < dims_b[i];
    }
  }
  // Equal dimensions: return false so that the comparison is a strict weak
  // ordering, as std::sort requires (a comparator that returns true for
  // equivalent keys is undefined behavior).
  return false;
}
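As a quick illustration of the ordering this comparator induces (the shapes below are arbitrary examples, not taken from the commit): rank is compared first, then dimensions left to right.

  std::vector<Shape> shapes = {
      ShapeUtil::MakeShape(F32, {8, 16}), ShapeUtil::MakeShape(F32, {64}),
      ShapeUtil::MakeShape(F32, {8, 8}), ShapeUtil::MakeShape(F32, {32})};
  std::sort(shapes.begin(), shapes.end(), CompareShapeDimsFromLeftToRight);
  // Resulting order: f32[32], f32[64], f32[8,8], f32[8,16].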
std::vector<HloInstruction*> FindAndSortFusionCandidates(
    HloInstruction* consumer) {
  absl::flat_hash_set<HloInstruction*> fusion_instr_set;
  for (auto opnd : consumer->operands()) {
    HloInstruction* predecessor = opnd->LatestNonGteAncestor();
    // Find the input fusion instructions whose only consumer is `consumer`.
    // This guarantees that fusing these candidates will never create cycles,
    // as there is no back edge.
    if (IsReduceInputFusion(*predecessor) &&
        IsConsumerTheOnlyNonRootUser(*predecessor, *consumer)) {
      fusion_instr_set.insert(predecessor);
    }
  }

  std::vector<HloInstruction*> fusion_instrs;
  fusion_instrs.insert(fusion_instrs.end(), fusion_instr_set.begin(),
                       fusion_instr_set.end());

  std::sort(fusion_instrs.begin(), fusion_instrs.end(),
            [&](const HloInstruction* a, const HloInstruction* b) {
              Shape shape_a = GetInputShapeForMultiOutputFusion(*a);
              Shape shape_b = GetInputShapeForMultiOutputFusion(*b);
              if (!ShapeUtil::EqualIgnoringElementType(shape_a, shape_b)) {
                // Sort shapes according to dimensions, so that the same input
                // shapes will be placed adjacent to each other.
                return CompareShapeDimsFromLeftToRight(shape_a, shape_b);
              }
              // Sort `fusion_instrs` according to instruction counts, because
              // we'd like to fuse together computations of similar sizes.
              return a->fused_instruction_count() <
                     b->fused_instruction_count();
            });

  return fusion_instrs;
}

StatusOr<bool> HorizontalInputFusionImpl::Run() {
  bool changed = false;
  XLA_VLOG_LINES(3, computation_->ToString());

  // Using def-to-use order is sound since we do not modify users.
  std::vector<HloInstruction*> def_to_use_order =
      computation_->MakeInstructionPostOrder();
  for (auto consumer : def_to_use_order) {
    auto candidates = FindAndSortFusionCandidates(consumer);
    if (candidates.empty()) {
      continue;
    }

    size_t fusion_anchor_id = 0;
    for (size_t j = 1; j < candidates.size(); ++j) {
      HloInstruction* fusion_anchor = candidates[fusion_anchor_id];
      HloInstruction* fused = candidates[j];
      if (ShapesCompatibleForMultiOutputFusion(*fusion_anchor, *fused) &&
          !FusionWouldBeTooLarge(*fusion_anchor, *fused)) {
        VLOG(3) << "Fuse " << fused->ToString() << " into "
                << fusion_anchor->ToString();
        fusion_anchor->MergeFusionInstructionIntoMultiOutput(fused);
        changed = true;
      } else {
        // Update the `fusion_anchor_id` since `fused` is either not
        // compatible or not beneficial to be fused with the current fusion
        // anchor.
        VLOG(3) << j - fusion_anchor_id - 1 << " instructions are fused.";
        fusion_anchor_id = j;
      }
    }
  }

  return changed;
}

}  // namespace

StatusOr<bool> GpuHorizontalInputFusion::RunOnComputation(
    HloComputation* computation) {
  HorizontalInputFusionImpl horizontal_fusion_impl(computation);
  return horizontal_fusion_impl.Run();
}

StatusOr<bool> GpuHorizontalInputFusion::Run(HloModule* module) {
  bool changed = false;
  VLOG(2) << "Run horizontal input fusion.";
  for (auto* comp : module->MakeNonfusionComputations()) {
    // Accumulate with |= so that a change in any computation is reported,
    // not just a change in the last one visited.
    TF_ASSIGN_OR_RETURN(const bool comp_changed, RunOnComputation(comp));
    changed |= comp_changed;
  }

  return changed;
}

}  // namespace gpu
}  // namespace xla
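The grouping loop in Run() above is a greedy anchor-and-scan over the sorted candidates. The following self-contained sketch shows the same control flow on plain integers, with "compatible and not too large" modeled as equality; it is illustrative only.

  #include <cstdio>
  #include <vector>

  int main() {
    // Sorted "candidates"; equal values stand in for fusions with compatible
    // shapes that would not become too large when merged.
    std::vector<int> candidates = {1, 1, 1, 2, 3, 3};
    size_t anchor = 0;  // Plays the role of fusion_anchor_id.
    for (size_t j = 1; j < candidates.size(); ++j) {
      if (candidates[j] == candidates[anchor]) {
        std::printf("merge candidate %zu into anchor %zu\n", j, anchor);
      } else {
        anchor = j;  // Incompatible: start a new group at j.
      }
    }
    return 0;
  }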
tensorflow/compiler/xla/service/gpu/horizontal_input_fusion.h (new file)

@@ -0,0 +1,57 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_HORIZONTAL_INPUT_FUSION_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_HORIZONTAL_INPUT_FUSION_H_

#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
#include "tensorflow/core/platform/macros.h"

namespace xla {
namespace gpu {

// This optimization pass horizontally fuses kInput fusions to both reduce
// kernel launch overhead and increase the degree of parallelism. See
// GpuHorizontalLoopFusion for a general description of and motivation for
// horizontal fusion. GpuHorizontalLoopFusion deals with kLoop fusions, while
// this pass deals with kInput fusions.
//
// Following GpuHorizontalLoopFusion, this pass uses a simple yet effective
// heuristic to find fusion candidates while avoiding the creation of cycles:
// it looks for instructions whose outputs are all consumed by the same
// instruction. This catches the typical target cases; often, the candidate
// instructions are simply consumed by the ROOT tuple of the entry computation.
class GpuHorizontalInputFusion : public HloModulePass {
 public:
  GpuHorizontalInputFusion() {}

  absl::string_view name() const override {
    return "gpu_horizontal_input_fusion";
  }

  StatusOr<bool> Run(HloModule* module) override;

 private:
  StatusOr<bool> RunOnComputation(HloComputation*);
};

}  // namespace gpu
}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_HORIZONTAL_INPUT_FUSION_H_
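A minimal usage sketch for the pass declared above, assuming a `module` obtained elsewhere; it mirrors how the tests below and gpu_compiler.cc invoke the pass and is not itself part of the commit:

  // Run the pass directly on a module...
  GpuHorizontalInputFusion input_fusion;
  TF_ASSIGN_OR_RETURN(bool changed, input_fusion.Run(module.get()));

  // ...or register it in a pipeline, as gpu_compiler.cc does.
  HloPassPipeline pipeline("horizontal_fusion");
  pipeline.AddPass<GpuHorizontalInputFusion>();
  TF_ASSIGN_OR_RETURN(changed, pipeline.Run(module.get()));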
tensorflow/compiler/xla/service/gpu/horizontal_input_fusion_test.cc (new file)

@@ -0,0 +1,216 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/horizontal_input_fusion.h"

#include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/filecheck.h"

namespace xla {
namespace gpu {
namespace {

namespace op = xla::testing::opcode_matchers;

class HorizontalInputFusionTest : public GpuCodegenTest {};

TEST_F(HorizontalInputFusionTest, BasicTest) {
  auto module = ParseAndReturnVerifiedModule(R"(
 HloModule BasicTest

 %add_f16 {
   %x = f16[] parameter(0)
   %y = f16[] parameter(1)
   ROOT %add = f16[] add(%x, %y)
 }

 fused_computation.1 {
   arg.1 = f16[1024]{0} parameter(0)
   constant0 = f16[] constant(0)
   ROOT reduce1 = f16[] reduce(arg.1, constant0), dimensions={0}, to_apply=%add_f16
 }

 fused_computation.2 {
   arg.1 = f16[1024]{0} parameter(0)
   constant0 = f16[] constant(0)
   ROOT reduce1 = f16[] reduce(arg.1, constant0), dimensions={0}, to_apply=%add_f16
 }

 ENTRY entry_computation {
   arg.1 = f16[1024]{0} parameter(0)
   arg.2 = f16[1024]{0} parameter(1)
   fusion.1 = f16[] fusion(arg.1), kind=kInput, calls=fused_computation.1
   fusion.2 = f16[] fusion(arg.2), kind=kInput, calls=fused_computation.2
   ROOT tuple.1 = (f16[], f16[]) tuple(fusion.1, fusion.2)
 }
)")
                    .ValueOrDie();

  EXPECT_TRUE(GpuHorizontalInputFusion().Run(module.get()).ValueOrDie());

  const HloInstruction* entry_root =
      module->entry_computation()->root_instruction();
  EXPECT_THAT(entry_root, op::Tuple((op::GetTupleElement(op::Fusion())),
                                    (op::GetTupleElement(op::Fusion()))));

  const HloInstruction* fusion = entry_root->operand(0)->operand(0);
  ASSERT_TRUE(fusion->IsMultiOutputFusion());
  EXPECT_THAT(fusion->fused_expression_root(),
              op::Tuple(op::Reduce(), op::Reduce()));
}

TEST_F(HorizontalInputFusionTest, ManyInputFusions) {
  auto module = CreateNewVerifiedModule();

  HloComputation* reduce_computation;
  {
    auto embedded_builder = HloComputation::Builder("add");
    auto lhs = embedded_builder.AddInstruction(HloInstruction::CreateParameter(
        0, ShapeUtil::MakeShape(F32, {}), "lhs"));
    auto rhs = embedded_builder.AddInstruction(HloInstruction::CreateParameter(
        1, ShapeUtil::MakeShape(F32, {}), "rhs"));
    embedded_builder.AddInstruction(
        HloInstruction::CreateBinary(lhs->shape(), HloOpcode::kAdd, lhs, rhs));
    reduce_computation =
        module->AddEmbeddedComputation(embedded_builder.Build());
  }

  HloComputation::Builder builder(TestName());
  std::vector<HloInstruction*> var_outs;
  auto input_shape = ShapeUtil::MakeShape(F32, {1024, 1024});
  auto output_shape = ShapeUtil::MakeShape(F32, {1024});
  for (int64 i = 0; i < 130; ++i) {
    // %fused_computation.3 (param_0: f32[1024,1024], param_1: f32[]) ->
    // f32[1024] {
    //   %param_0 = f32[1024,1024]{1,0} parameter(0)
    //   %param_1 = f32[] parameter(1)
    //   %broadcast = f32[1024,1024]{1,0} broadcast(f32[] %param_1),
    //       dimensions={}
    //   %multiply = f32[1024,1024]{1,0}
    //       multiply(f32[1024,1024]{1,0} %param_0, f32[1024,1024]{1,0}
    //       %broadcast)
    //   %constant0 = f32[] constant(0)
    //   ROOT %reduce = f32[1024]{0}
    //       reduce(f32[1024,1024]{1,0} %multiply, f32[] %constant0),
    //       dimensions={1}, to_apply=%add
    // }
    HloInstruction* param_var_in = builder.AddInstruction(
        HloInstruction::CreateParameter(i * 2 + 0, input_shape, "var.in"));
    HloInstruction* param_alpha =
        builder.AddInstruction(HloInstruction::CreateParameter(
            i * 2 + 1, ShapeUtil::MakeShape(F32, {}), "alpha"));
    auto alpha_broadcasted = builder.AddInstruction(
        HloInstruction::CreateBroadcast(input_shape, param_alpha, {}));
    auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
        input_shape, HloOpcode::kMultiply, param_var_in, alpha_broadcasted));
    HloInstruction* const0 = builder.AddInstruction(
        HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(0)));
    auto reduce = builder.AddInstruction(HloInstruction::CreateReduce(
        output_shape, mul, const0, {1}, reduce_computation));
    var_outs.push_back(reduce);
  }
  builder.AddInstruction(HloInstruction::CreateTuple(var_outs));
  module->AddEntryComputation(builder.Build());

  // Verify that horizontal fusion kicked in by checking that multiple
  // `reduce` instructions were fused into the same fusion. 6 is just an
  // arbitrarily picked number, as we don't know exactly how large the created
  // fusions will be, due to the `FusionWouldBeTooLarge` constraint.
  CompileAndVerifyIr(module->Clone(), R"(CHECK: reduce-group-6)",
                     /*match_optimized_ir=*/false);

  // Testing with the entire gpu optimization pipeline.
  EXPECT_TRUE(RunAndCompare(std::move(module), ErrorSpec{1e-5, 1e-5}));
}

TEST_F(HorizontalInputFusionTest, MultiOutputFusionTest) {
  // This tests the pattern below. One known issue is that GTEs (to fusions)
  // can be removed after their producer fusions are merged; in the case
  // below, gte2 and gte6 will be gone if Fusion2 is fused into Fusion1.
  //
  //        Fusion1   Fusion2
  //         |  |      |  |
  //         | gte1  gte2 |
  //         |  |      |  |
  //         |  Fusion3   |
  //         |   |    |   |
  //        gte3 gte4 gte5 gte6
  //          \   |    |   /
  //          =====ROOT=====
  //
  auto module = ParseAndReturnVerifiedModule(R"(
 HloModule MultiOutputFusionTest

 %add_f16 {
   %x = f16[] parameter(0)
   %y = f16[] parameter(1)
   ROOT %add = f16[] add(%x, %y)
 }

 fused_computation.1 {
   arg.1 = f16[1024]{0} parameter(0)
   constant0 = f16[] constant(0)
   reduce.1 = f16[] reduce(arg.1, constant0), dimensions={0}, to_apply=%add_f16
   add.0 = f16[1024] add(arg.1, arg.1)
   ROOT tuple.1 = (f16[], f16[1024]) tuple(reduce.1, add.0)
 }

 fused_computation.2 {
   arg.1 = f16[1024]{0} parameter(0)
   constant0 = f16[] constant(0)
   reduce.1 = f16[] reduce(arg.1, constant0), dimensions={0}, to_apply=%add_f16
   add.0 = f16[1024] add(arg.1, arg.1)
   ROOT tuple.1 = (f16[], f16[1024]) tuple(reduce.1, add.0)
 }

 fused_computation.3 {
   arg.0 = f16[1024]{0} parameter(0)
   arg.1 = f16[1024]{0} parameter(1)
   add.0 = f16[1024] add(arg.0, arg.1)
   mul.0 = f16[1024] multiply(arg.0, arg.1)
   ROOT tuple.1 = (f16[1024], f16[1024]) tuple(add.0, mul.0)
 }

 ENTRY entry_computation {
   arg.1 = f16[1024]{0} parameter(0)
   arg.2 = f16[1024]{0} parameter(1)
   fusion.1 = (f16[],f16[1024]) fusion(arg.1), kind=kInput, calls=fused_computation.1
   fusion.2 = (f16[],f16[1024]) fusion(arg.2), kind=kInput, calls=fused_computation.2
   gte.3 = f16[] get-tuple-element(fusion.1), index=0
   gte.1 = f16[1024]{0} get-tuple-element(fusion.1), index=1
   gte.2 = f16[1024]{0} get-tuple-element(fusion.2), index=1
   gte.6 = f16[] get-tuple-element(fusion.2), index=0
   fusion.3 = (f16[1024],f16[1024]) fusion(gte.1, gte.2),
       kind=kLoop, calls=fused_computation.3
   gte.4 = f16[1024] get-tuple-element(fusion.3), index=0
   gte.5 = f16[1024]{0} get-tuple-element(fusion.3), index=1
   ROOT tuple.1 = (f16[], f16[1024]{0}, f16[], f16[1024]{0})
       tuple(gte.3, gte.4, gte.5, gte.6)
 }
)")
                    .ValueOrDie();

  EXPECT_TRUE(GpuHorizontalInputFusion().Run(module.get()).ValueOrDie());
}

}  // namespace
}  // namespace gpu
}  // namespace xla
tensorflow/compiler/xla/service/gpu/horizontal_loop_fusion.cc

@@ -20,6 +20,7 @@ limitations under the License.
 #include "absl/container/flat_hash_set.h"
 #include "absl/types/span.h"
 #include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/service/gpu/gpu_fusible.h"
 #include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
 #include "tensorflow/core/platform/errors.h"
 #include "tensorflow/core/util/env_var.h"
@@ -137,25 +138,6 @@ bool IsFusionSupported(const HloInstruction& instr) {
   return true;
 }
 
-bool IsConsumerTheOnlyNonRootUser(const HloInstruction& instr,
-                                  const HloInstruction& consumer) {
-  return absl::c_all_of(instr.users(), [&](const HloInstruction* user) {
-    if (user->opcode() == HloOpcode::kGetTupleElement) {
-      // Skip GTE.
-      return IsConsumerTheOnlyNonRootUser(*user, consumer);
-    } else if (user == &consumer) {
-      // `user` is `consumer`.
-      return true;
-    } else if (user == user->parent()->root_instruction()) {
-      // Consumed by ROOT is always fine, since it is impossible to create
-      // cycles through ROOT.
-      return true;
-    } else {
-      return false;
-    }
-  });
-}
-
 // Returns whether `instr` is a profitable candidate to be horizontally fused.
 // Since the primary benefit of horizontal fusion comes from reducing the
 // kernel launch overhead, we want to exclude the instructions with